Skip to content

Commit db08c55

Browse files
authored
GluonCV YoloV3 Darknet53 example training and inference with Neo (#1266)
1 parent 25b0fad commit db08c55

File tree

9 files changed

+2123
-0
lines changed

9 files changed

+2123
-0
lines changed

sagemaker_neo_compilation_jobs/gluoncv_yolo_darknet/gluoncv_yolo_darknet_neo.ipynb

Lines changed: 607 additions & 0 deletions
Large diffs are not rendered by default.
Loading
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from imdb import Imdb
19+
import random
20+
21+
class ConcatDB(Imdb):
22+
"""
23+
ConcatDB is used to concatenate multiple imdbs to form a larger db.
24+
It is very useful to combine multiple dataset with same classes.
25+
Parameters
26+
----------
27+
imdbs : Imdb or list of Imdb
28+
Imdbs to be concatenated
29+
shuffle : bool
30+
whether to shuffle the initial list
31+
"""
32+
def __init__(self, imdbs, shuffle):
33+
super(ConcatDB, self).__init__('concatdb')
34+
if not isinstance(imdbs, list):
35+
imdbs = [imdbs]
36+
self.imdbs = imdbs
37+
self._check_classes()
38+
self.image_set_index = self._load_image_set_index(shuffle)
39+
40+
def _check_classes(self):
41+
"""
42+
check input imdbs, make sure they have same classes
43+
"""
44+
try:
45+
self.classes = self.imdbs[0].classes
46+
self.num_classes = len(self.classes)
47+
except AttributeError:
48+
# fine, if no classes is provided
49+
pass
50+
51+
if self.num_classes > 0:
52+
for db in self.imdbs:
53+
assert self.classes == db.classes, "Multiple imdb must have same classes"
54+
55+
def _load_image_set_index(self, shuffle):
56+
"""
57+
get total number of images, init indices
58+
59+
Parameters
60+
----------
61+
shuffle : bool
62+
whether to shuffle the initial indices
63+
"""
64+
self.num_images = 0
65+
for db in self.imdbs:
66+
self.num_images += db.num_images
67+
indices = list(range(self.num_images))
68+
if shuffle:
69+
random.shuffle(indices)
70+
return indices
71+
72+
def _locate_index(self, index):
73+
"""
74+
given index, find out sub-db and sub-index
75+
76+
Parameters
77+
----------
78+
index : int
79+
index of a specific image
80+
81+
Returns
82+
----------
83+
a tuple (sub-db, sub-index)
84+
"""
85+
assert index >= 0 and index < self.num_images, "index out of range"
86+
pos = self.image_set_index[index]
87+
for k, v in enumerate(self.imdbs):
88+
if pos >= v.num_images:
89+
pos -= v.num_images
90+
else:
91+
return (k, pos)
92+
93+
def image_path_from_index(self, index):
94+
"""
95+
given image index, find out full path
96+
97+
Parameters
98+
----------
99+
index: int
100+
index of a specific image
101+
102+
Returns
103+
----------
104+
full path of this image
105+
"""
106+
assert self.image_set_index is not None, "Dataset not initialized"
107+
pos = self.image_set_index[index]
108+
n_db, n_index = self._locate_index(index)
109+
return self.imdbs[n_db].image_path_from_index(n_index)
110+
111+
def label_from_index(self, index):
112+
"""
113+
given image index, return preprocessed ground-truth
114+
115+
Parameters
116+
----------
117+
index: int
118+
index of a specific image
119+
120+
Returns
121+
----------
122+
ground-truths of this image
123+
"""
124+
assert self.image_set_index is not None, "Dataset not initialized"
125+
pos = self.image_set_index[index]
126+
n_db, n_index = self._locate_index(index)
127+
return self.imdbs[n_db].label_from_index(n_index)

0 commit comments

Comments
 (0)