Skip to content

Commit 93a61ee

Browse files
Pull request pytorch#71: Modify MLPerf Tiny models to accept datasets with labels
Merge in AITEC/executorch from feature/nxg11066/EIEX-184-Add-labels-to-all-MLPerf-Tiny-calibration-data to main-nxp * commit '1ffc1f7c47740633cdf7c7a6aff65b5cc46d7d79': Modify MLPerf Tiny models to accept datasets with labels
2 parents a484fb2 + 1ffc1f7 commit 93a61ee

File tree

6 files changed

+16
-50
lines changed

6 files changed

+16
-50
lines changed

examples/nxp/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ We support several example models. To run these examples, we will use `aot_neutr
4848

4949
Run the script with following arguments: `--model_name` - specify model from a list of available models, `--so_library`
5050
- specify path to `libquantized_ops_aot_lib.so` library from step 1. Quantization and delegation are controlled by
51-
`--quantize` and `--quantize` flags, turned off by default.
51+
`--quantize` and `--delegate` flags, turned off by default.
5252

5353
Supported models:
5454
- cifar10

examples/nxp/models/mlperf_tiny/anomaly_detection/anomaly_detection.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import os
7-
from typing import Iterator
87

98
import torch
109

@@ -27,14 +26,5 @@ def _dts_path(self):
2726
def _input_shape(self):
2827
return self.__input_shape
2928

30-
@staticmethod
31-
def _collate_fn(data: torch.Tensor, **kwargs):
32-
return (torch.stack(list(data)),)
33-
34-
def get_calibration_inputs(self, batch_size: int = 1) -> Iterator[tuple[torch.Tensor]]:
35-
self._batch_size = batch_size
36-
data_loader = self._get_data_loader()
37-
return iter(data_loader)
38-
3929
def get_eager_model(self) -> torch.nn.Module:
4030
return self._model_manager.get_model("anomaly_detection", input_dimension=self._input_dim)

examples/nxp/models/mlperf_tiny/image_classification/image_classification.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import os
7-
from typing import Iterator
87

98
import torch
109

@@ -26,14 +25,5 @@ def _dts_path(self):
2625
def _input_shape(self):
2726
return self.__input_shape
2827

29-
@staticmethod
30-
def _collate_fn(data: torch.Tensor, **kwargs):
31-
return (torch.stack(list(data)),)
32-
33-
def get_calibration_inputs(self, batch_size: int = 1) -> Iterator[tuple[torch.Tensor]]:
34-
self._batch_size = batch_size
35-
data_loader = self._get_data_loader()
36-
return iter(data_loader)
37-
3828
def get_eager_model(self) -> torch.nn.Module:
3929
return self._model_manager.get_model("image_classification")

examples/nxp/models/mlperf_tiny/keyword_spotting/keyword_spotting.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import itertools
76
import os
8-
from typing import Iterator
97

108
import torch
119

@@ -27,17 +25,6 @@ def _dts_path(self):
2725
def _input_shape(self):
2826
return self.__input_shape
2927

30-
@staticmethod
31-
def _collate_fn(data: torch.Tensor, **kwargs):
32-
data, labels = zip(*data)
33-
return torch.stack(list(data)).to(memory_format=torch.channels_last), torch.tensor(list(labels))
34-
35-
def get_calibration_inputs(self, batch_size: int = 1) -> Iterator[tuple[torch.Tensor]]:
36-
self._batch_size = batch_size
37-
data_loader = self._get_data_loader()
38-
get_first = lambda a, b: (a,)
39-
return itertools.starmap(get_first, iter(data_loader))
40-
4128
def get_eager_model(self) -> torch.nn.Module:
4229
example_input_shape = self._input_shape[1:]
4330
return self._model_manager.get_model("keyword_spotting", input_shape=example_input_shape, num_classes=12)

examples/nxp/models/mlperf_tiny/mlperf_tiny_model.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import itertools
67
import os
78
from abc import abstractmethod
9+
from typing import Iterator
810

911
import torch
1012
from torch.utils.data import DataLoader
@@ -33,9 +35,16 @@ def _dts_path(self):
3335
def _input_shape(self):
3436
pass
3537

36-
@abstractmethod
37-
def _collate_fn(self, data: torch.Tensor):
38-
pass
38+
@staticmethod
39+
def _collate_fn(data: list[tuple]):
40+
data, labels = zip(*data)
41+
return torch.stack(list(data)), torch.tensor(list(labels))
42+
43+
def get_calibration_inputs(self, batch_size: int = 1) -> Iterator[tuple[torch.Tensor]]:
44+
self._batch_size = batch_size
45+
data_loader = self._get_data_loader()
46+
get_first = lambda a, b: (a,)
47+
return itertools.starmap(get_first, iter(data_loader))
3948

4049
@abstractmethod
4150
def get_eager_model(self) -> torch.nn.Module:
@@ -53,4 +62,7 @@ def _get_data_loader(self):
5362

5463
def _init_dataset(self):
5564
if self._dataset is None:
65+
os.makedirs(os.path.dirname(self._dts_path), exist_ok=True)
66+
if not os.path.exists(self._dts_path):
67+
raise FileNotFoundError("Calibration data file not found! For more info, follow README.md.")
5668
self._dataset = CalibrationDataset(self._dts_path)

examples/nxp/models/mlperf_tiny/visual_wake_words/visual_wake_words.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import itertools
76
import os
8-
from typing import Iterator
97

108
import torch
119

@@ -27,16 +25,5 @@ def _dts_path(self):
2725
def _input_shape(self):
2826
return self.__input_shape
2927

30-
@staticmethod
31-
def _collate_fn(data: torch.Tensor, **kwargs):
32-
images, labels = zip(*data)
33-
return torch.stack(list(images)).to(memory_format=torch.channels_last), torch.tensor(list(labels))
34-
35-
def get_calibration_inputs(self, batch_size: int = 1) -> Iterator[tuple[torch.Tensor]]:
36-
self._batch_size = batch_size
37-
data_loader = self._get_data_loader()
38-
get_first = lambda a, b: (a,)
39-
return itertools.starmap(get_first, iter(data_loader))
40-
4128
def get_eager_model(self) -> torch.nn.Module:
4229
return self._model_manager.get_model("visual_wake_words")

0 commit comments

Comments
 (0)