@@ -98,24 +98,9 @@ public static void createCocoDirectoryStructure(String baseDir, String type) thr
9898 Files .createDirectories (Paths .get (baseDir + "annotations" ));
9999 Files .createDirectories (Paths .get (baseDir + "images" ));
100100
101- Files .createDirectories (Paths .get (baseDir + "train" ));
102- Files .createDirectories (Paths .get (baseDir + "val" ));
103- Files .createDirectories (Paths .get (baseDir + "test" ));
104-
105- Files .createDirectories (Paths .get (baseDir , "annotations" , "train" ));
106- Files .createDirectories (Paths .get (baseDir , "annotations" , "val" ));
107- Files .createDirectories (Paths .get (baseDir , "annotations" , "test" ));
108-
109- Files .createDirectories (Paths .get (baseDir , "images" , "train" ));
110- Files .createDirectories (Paths .get (baseDir , "images" , "val" ));
111- Files .createDirectories (Paths .get (baseDir , "images" , "test" ));
112-
113101 // 根据类型创建特定目录 detection, classification, segmentation, keypoints, face_keypoints 使用标准结构
114102 if (TaskType .OCR .getType ().equals (type ) || TaskType .ROTATED_DETECTION .getType ().equals (type )) {
115- Files .createDirectories (Paths .get (baseDir , "labels" ));
116- Files .createDirectories (Paths .get (baseDir , "labels" , "train" ));
117- Files .createDirectories (Paths .get (baseDir , "labels" , "val" ));
118- Files .createDirectories (Paths .get (baseDir , "labels" , "test" ));
103+ Files .createDirectories (Paths .get (baseDir + "labels" ));
119104 }
120105 }
121106
@@ -346,18 +331,52 @@ public static void writeToFile(CocoDataset dataset, String outputPath) throws IO
346331
347332
348333 /**
349- * 复制图片文件到指定目录,支持URL和base64两种格式
334+ * 复制图片文件到指定目录,按类别分目录存放, 支持URL和base64两种格式
350335 * @param images 图片信息列表
351- * @param imageDir 目标图片目录
336+ * @param imageDir 目标图片根目录
337+ * @param categories 类别列表
338+ * @param annotations 标注列表
352339 * @throws IOException
353340 */
354- public static void copyImagesToDirectory (List <ImageInfo > images , String imageDir ) throws IOException {
355- // 创建图片目录
341+ public static void copyImagesToDirectory (
342+ List <ImageInfo > images , String imageDir , List <Category > categories , List <Annotation > annotations
343+ ) throws IOException {
344+ // 创建图片根目录
356345 Path imageDirPath = Paths .get (imageDir );
357346 if (!Files .exists (imageDirPath )) {
358347 Files .createDirectories (imageDirPath );
359348 }
360349
350+ // 创建类别ID到名称的映射
351+ Map <Integer , String > categoryIdToName = new HashMap <>();
352+ if (categories != null ) {
353+ for (Category category : categories ) {
354+ categoryIdToName .put (category .getId (), category .getName ());
355+ }
356+ }
357+
358+ // 创建类别目录
359+ for (String categoryName : categoryIdToName .values ()) {
360+ Path categoryDir = Paths .get (imageDir , categoryName );
361+ if (!Files .exists (categoryDir )) {
362+ Files .createDirectories (categoryDir );
363+ }
364+ }
365+
366+ // 创建图片ID到标注类别的映射
367+ Map <Integer , Set <String >> imageIdToCategories = new HashMap <>();
368+ if (annotations != null ) {
369+ for (Annotation annotation : annotations ) {
370+ int imageId = annotation .getImage_id ();
371+ int categoryId = annotation .getCategory_id ();
372+ String categoryName = categoryIdToName .get (categoryId );
373+
374+ if (categoryName != null ) {
375+ imageIdToCategories .computeIfAbsent (imageId , k -> new HashSet <>()).add (categoryName );
376+ }
377+ }
378+ }
379+
361380 for (ImageInfo image : images ) {
362381 String imgSource = image .getImg ();
363382 String fileName = image .getFile_name ();
@@ -367,27 +386,54 @@ public static void copyImagesToDirectory(List<ImageInfo> images, String imageDir
367386 continue ;
368387 }
369388
370- Path targetPath = Paths .get (imageDir , fileName );
371-
372- try {
373- if (imgSource .startsWith ("data:image/" )) {
374- // 处理base64格式
375- copyBase64Image (imgSource , targetPath );
376- } else if (imgSource .startsWith ("http://" ) || imgSource .startsWith ("https://" )) {
377- // 处理URL格式
378- copyUrlImage (imgSource , targetPath );
379- } else {
380- // 处理本地文件路径
381- copyLocalImage (imgSource , targetPath );
382- }
389+ // 获取这张图片的所有相关类别
390+ Set <String > imageCategories = imageIdToCategories .get (image .getId ());
391+ if (imageCategories == null || imageCategories .isEmpty ()) {
392+ // 如果没有标注信息,放到一个默认目录
393+ Path targetPath = Paths .get (imageDir , fileName );
394+ Files .createDirectories (targetPath .getParent ());
383395
384- System .out .println ("Successfully copied image: " + fileName );
385- } catch (Exception e ) {
386- System .err .println ("Failed to copy image " + fileName + ": " + e .getMessage ());
396+ try {
397+ copyImageToPath (imgSource , targetPath );
398+ System .out .println ("Successfully copied unlabeled image: " + fileName );
399+ } catch (Exception e ) {
400+ System .err .println ("Failed to copy image " + fileName + ": " + e .getMessage ());
401+ }
402+ } else {
403+ // 将图片复制到所有相关类别的目录
404+ for (String categoryName : imageCategories ) {
405+ Path targetPath = Paths .get (imageDir , categoryName , fileName );
406+
407+ try {
408+ copyImageToPath (imgSource , targetPath );
409+ System .out .println ("Successfully copied image " + fileName + " to category: " + categoryName );
410+ } catch (Exception e ) {
411+ System .err .println ("Failed to copy image " + fileName + " to category " + categoryName + ": " + e .getMessage ());
412+ }
413+ }
387414 }
388415 }
389416 }
390417
418+ /**
419+ * Helper that copies a single image to targetPath, dispatching on the form of imgSource: a "data:image/" prefix goes to copyBase64Image, an "http://" or "https://" prefix goes to copyUrlImage, and anything else is handed to copyLocalImage as a local file path.
420+ */
421+ private static void copyImageToPath (String imgSource , Path targetPath ) throws IOException {
422+ // Make sure the target directory exists. NOTE(review): Path.getParent() can return null for a path without a parent component - confirm callers never pass one.
423+ Files .createDirectories (targetPath .getParent ());
424+
425+ if (imgSource .startsWith ("data:image/" )) {
426+ // Base64 data-URI source
427+ copyBase64Image (imgSource , targetPath );
428+ } else if (imgSource .startsWith ("http://" ) || imgSource .startsWith ("https://" )) {
429+ // Remote URL source
430+ copyUrlImage (imgSource , targetPath );
431+ } else {
432+ // Fallback: treat the source as a local file path
433+ copyLocalImage (imgSource , targetPath );
434+ }
435+ }
436+
391437 /**
392438 * 从base64数据复制图片
393439 */
@@ -520,7 +566,8 @@ public static void generate(String outputDir, Set<TaskType> tasks) throws IOExce
520566 writeToFile (cocoDataset , outputJsonPath );
521567
522568 // 复制图片文件到指定目录 outputDir/images/
523- copyImagesToDirectory (cocoDataset .getImages (), outputDir + "/images/" );
569+ copyImagesToDirectory (cocoDataset .getImages (), outputDir + "/images/" , null , null );
570+ copyImagesToDirectory (cocoDataset .getImages (), outputDir + "/images/" , cocoDataset .getCategories (), cocoDataset .getAnnotations ());
524571
525572 System .out .println ("Successfully generated dataset at: " + outputDir );
526573 }
@@ -647,7 +694,10 @@ public static void generate(List<JSONObject> data, Set<TaskType> tasks, String o
647694 writeToFile (cocoDataset , Paths .get (outputDir , taskName + ".json" ).toString ());
648695
649696 // 复制图片文件到指定目录 outputDir/images/ train/val
650- copyImagesToDirectory (cocoDataset .getImages (), outputDir + taskName + "/" );
697+ copyImagesToDirectory (cocoDataset .getImages (), outputDir + taskName + File .separator , cocoDataset .getCategories (), cocoDataset .getAnnotations ());
698+ //if (Log.DEBUG || tasks.contains(TaskType.CLASSIFICATION)) {
699+ copyImagesToDirectory (cocoDataset .getImages (), outputDir + "images" + File .separator + taskName + File .separator , null , null );
700+ //}
651701
652702 System .out .println ("Successfully generated dataset from JSONObject data at: " + outputDir );
653703 }
0 commit comments