Skip to content

Commit 95fc522

Browse files
committed
update remap_mscoco_category
1 parent 5f0f43c commit 95fc522

File tree

5 files changed

+14
-181
lines changed

5 files changed

+14
-181
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ This is the official implementation of the paper "[DETRs Beat YOLOs on Real-time
3636

3737
## Updates!!!
3838
---
39+
- \[2023.11.05\] upgrade the logic of `remap_mscoco_category` to facilitate training of custom datasets, see detils in [Train custom data](./rtdetr_pytorch/) part.
3940
- \[2023.10.23\] Add [*discussion for deployments*](https://github.com/lyuwenyu/RT-DETR/issues/95), supported onnxruntime, TensorRT, openVINO
4041
- \[2023.10.12\] Add tuning code for pytorch version, now you can tuning rtdetr based on pretrained weights
4142
- \[2023.09.19\] Upload [*pytorch weights*](https://github.com/lyuwenyu/RT-DETR/issues/42) convert from paddle version

rtdetr_pytorch/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ python tools/export_onnx.py -c configs/rtdetr/rtdetr_r18vd_6x_coco.yml -r path/t
7979
<details open>
8080
<summary>Train custom data</summary>
8181

82-
1. set `remap_mscoco_category: False`. This variable only works for ms-coco dataset.
82+
1. set `remap_mscoco_category: False`. This variable only works for ms-coco dataset. If you want to use `remap_mscoco_category` logic on your dataset, please modify variable [`mscoco_category2name`](https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/data/coco/coco_dataset.py) based on your dataset.
8383

8484
2. add `-t path/to/checkpoint` (optinal) to tuning rtdetr based on pretrained checkpoint. see [training script details](./tools/README.md).
8585
</details>
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
2-
from .coco_dataset import *
1+
from .coco_dataset import (
2+
CocoDetection,
3+
mscoco_category2label,
4+
mscoco_label2category,
5+
mscoco_category2name,
6+
)
37
from .coco_eval import *
48

59
from .coco_utils import get_coco_api_from_dataset

rtdetr_pytorch/src/data/coco/coco_dataset.py

Lines changed: 4 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __call__(self, image, target):
104104
boxes[:, 1::2].clamp_(min=0, max=h)
105105

106106
if self.remap_mscoco_category:
107-
classes = [category2label[obj["category_id"]] - 1 for obj in anno]
107+
classes = [mscoco_category2label[obj["category_id"]] for obj in anno]
108108
else:
109109
classes = [obj["category_id"] for obj in anno]
110110

@@ -151,10 +151,7 @@ def __call__(self, image, target):
151151
return image, target
152152

153153

154-
155-
156-
names = {
157-
0: 'background',
154+
mscoco_category2name = {
158155
1: 'person',
159156
2: 'bicycle',
160157
3: 'car',
@@ -237,88 +234,5 @@ def __call__(self, image, target):
237234
90: 'toothbrush'
238235
}
239236

240-
241-
label2category = {
242-
1: 1,
243-
2: 2,
244-
3: 3,
245-
4: 4,
246-
5: 5,
247-
6: 6,
248-
7: 7,
249-
8: 8,
250-
9: 9,
251-
10: 10,
252-
11: 11,
253-
12: 13,
254-
13: 14,
255-
14: 15,
256-
15: 16,
257-
16: 17,
258-
17: 18,
259-
18: 19,
260-
19: 20,
261-
20: 21,
262-
21: 22,
263-
22: 23,
264-
23: 24,
265-
24: 25,
266-
25: 27,
267-
26: 28,
268-
27: 31,
269-
28: 32,
270-
29: 33,
271-
30: 34,
272-
31: 35,
273-
32: 36,
274-
33: 37,
275-
34: 38,
276-
35: 39,
277-
36: 40,
278-
37: 41,
279-
38: 42,
280-
39: 43,
281-
40: 44,
282-
41: 46,
283-
42: 47,
284-
43: 48,
285-
44: 49,
286-
45: 50,
287-
46: 51,
288-
47: 52,
289-
48: 53,
290-
49: 54,
291-
50: 55,
292-
51: 56,
293-
52: 57,
294-
53: 58,
295-
54: 59,
296-
55: 60,
297-
56: 61,
298-
57: 62,
299-
58: 63,
300-
59: 64,
301-
60: 65,
302-
61: 67,
303-
62: 70,
304-
63: 72,
305-
64: 73,
306-
65: 74,
307-
66: 75,
308-
67: 76,
309-
68: 77,
310-
69: 78,
311-
70: 79,
312-
71: 80,
313-
72: 81,
314-
73: 82,
315-
74: 84,
316-
75: 85,
317-
76: 86,
318-
77: 87,
319-
78: 88,
320-
79: 89,
321-
80: 90
322-
}
323-
324-
category2label = {v: k for k, v in label2category.items()}
237+
mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())}
238+
mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}

rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py

Lines changed: 2 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def forward(self, outputs, orig_target_sizes):
5858

5959
# TODO
6060
if self.remap_mscoco_category:
61-
labels = torch.tensor([self.mscoco_label_category_map[int(x.item()) + 1] for x in labels.flatten()])\
61+
from ...data.coco import mscoco_label2category
62+
labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\
6263
.to(boxes.device).reshape(labels.shape)
6364

6465
results = []
@@ -77,90 +78,3 @@ def deploy(self, ):
7778
@property
7879
def iou_types(self, ):
7980
return ('bbox', )
80-
81-
82-
@property
83-
def mscoco_label_category_map(self, ):
84-
return {
85-
1: 1,
86-
2: 2,
87-
3: 3,
88-
4: 4,
89-
5: 5,
90-
6: 6,
91-
7: 7,
92-
8: 8,
93-
9: 9,
94-
10: 10,
95-
11: 11,
96-
12: 13,
97-
13: 14,
98-
14: 15,
99-
15: 16,
100-
16: 17,
101-
17: 18,
102-
18: 19,
103-
19: 20,
104-
20: 21,
105-
21: 22,
106-
22: 23,
107-
23: 24,
108-
24: 25,
109-
25: 27,
110-
26: 28,
111-
27: 31,
112-
28: 32,
113-
29: 33,
114-
30: 34,
115-
31: 35,
116-
32: 36,
117-
33: 37,
118-
34: 38,
119-
35: 39,
120-
36: 40,
121-
37: 41,
122-
38: 42,
123-
39: 43,
124-
40: 44,
125-
41: 46,
126-
42: 47,
127-
43: 48,
128-
44: 49,
129-
45: 50,
130-
46: 51,
131-
47: 52,
132-
48: 53,
133-
49: 54,
134-
50: 55,
135-
51: 56,
136-
52: 57,
137-
53: 58,
138-
54: 59,
139-
55: 60,
140-
56: 61,
141-
57: 62,
142-
58: 63,
143-
59: 64,
144-
60: 65,
145-
61: 67,
146-
62: 70,
147-
63: 72,
148-
64: 73,
149-
65: 74,
150-
66: 75,
151-
67: 76,
152-
68: 77,
153-
69: 78,
154-
70: 79,
155-
71: 80,
156-
72: 81,
157-
73: 82,
158-
74: 84,
159-
75: 85,
160-
76: 86,
161-
77: 87,
162-
78: 88,
163-
79: 89,
164-
80: 90
165-
}
166-

0 commit comments

Comments
 (0)