@@ -98,9 +98,24 @@ public static void createCocoDirectoryStructure(String baseDir, String type) thr
98
98
Files .createDirectories (Paths .get (baseDir + "annotations" ));
99
99
Files .createDirectories (Paths .get (baseDir + "images" ));
100
100
101
+ Files .createDirectories (Paths .get (baseDir + "train" ));
102
+ Files .createDirectories (Paths .get (baseDir + "val" ));
103
+ Files .createDirectories (Paths .get (baseDir + "test" ));
104
+
105
+ Files .createDirectories (Paths .get (baseDir , "annotations" , "train" ));
106
+ Files .createDirectories (Paths .get (baseDir , "annotations" , "val" ));
107
+ Files .createDirectories (Paths .get (baseDir , "annotations" , "test" ));
108
+
109
+ Files .createDirectories (Paths .get (baseDir , "images" , "train" ));
110
+ Files .createDirectories (Paths .get (baseDir , "images" , "val" ));
111
+ Files .createDirectories (Paths .get (baseDir , "images" , "test" ));
112
+
101
113
// 根据类型创建特定目录 detection, classification, segmentation, keypoints, face_keypoints 使用标准结构
102
114
if (TaskType .OCR .getType ().equals (type ) || TaskType .ROTATED_DETECTION .getType ().equals (type )) {
103
- Files .createDirectories (Paths .get (baseDir + "labels" ));
115
+ Files .createDirectories (Paths .get (baseDir , "labels" ));
116
+ Files .createDirectories (Paths .get (baseDir , "labels" , "train" ));
117
+ Files .createDirectories (Paths .get (baseDir , "labels" , "val" ));
118
+ Files .createDirectories (Paths .get (baseDir , "labels" , "test" ));
104
119
}
105
120
}
106
121
@@ -526,14 +541,32 @@ public static void generate(List<JSONObject> data, Set<TaskType> tasks, String o
526
541
527
542
// --- 1. 初始化构建器和通用信息 ---
528
543
DatasetBuilder builder = new DatasetBuilder ()
529
- .withInfo ("Dataset from JSONObject" , "1.0" , "2025" )
530
- .withCategory (1 , "person" , "person" )
531
- .withCategory (2 , "car" , "vehicle" )
532
- .withCategory (3 , "dog" , "animal" )
533
- .withKeypointCategory (1 , "person" , "person" ,
534
- Arrays .asList ("nose" , "left_eye" , "right_eye" ),
535
- Arrays .asList (Arrays .asList (1 , 2 ), Arrays .asList (1 , 3 ))
536
- );
544
+ .withInfo ("Dataset from JSONObject" , "1.0" , "2025" );
545
+
546
+ // 从数据中提取实际的categories
547
+ JSONArray extractedCategories = extractCategoriesFromApiJson (data );
548
+
549
+ // 动态添加categories到builder
550
+ if (extractedCategories != null ) {
551
+ boolean hasPose = tasks .contains (TaskType .POSE_KEYPOINTS );
552
+
553
+ for (int i = 0 ; i < extractedCategories .size (); i ++) {
554
+ JSONObject categoryObj = extractedCategories .getJSONObject (i );
555
+ int id = categoryObj .getIntValue ("id" );
556
+ String name = categoryObj .getString ("name" );
557
+ String supercategory = categoryObj .getString ("supercategory" );
558
+
559
+ // 如果是关键点任务且包含person类别,添加关键点信息
560
+ if (hasPose && ("person" .equals (name ) || name .contains ("人" ))) {
561
+ builder .withKeypointCategory (id , name , supercategory ,
562
+ Arrays .asList ("nose" , "left_eye" , "right_eye" ),
563
+ Arrays .asList (Arrays .asList (1 , 2 ), Arrays .asList (1 , 3 ))
564
+ );
565
+ } else {
566
+ builder .withCategory (id , name , supercategory );
567
+ }
568
+ }
569
+ }
537
570
538
571
// 用于跟踪图片ID映射
539
572
Map <String , Integer > imgNameIdMap = new HashMap <>();
0 commit comments