Skip to content

Commit 95bf7fa

Browse files
change: add tflite to Neo-supported frameworks (#1364)
1 parent 4505229 commit 95bf7fa

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

src/sagemaker/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424

2525
LOGGER = logging.getLogger("sagemaker")
2626

27-
NEO_ALLOWED_FRAMEWORKS = set(["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost"])
27+
NEO_ALLOWED_FRAMEWORKS = set(
28+
["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
29+
)
2830

2931
NEO_IMAGE_ACCOUNT = {
3032
"us-west-1": "710691900526",

tests/unit/test_model.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,22 @@ def test_compile_model_for_edge_device(sagemaker_session, tmpdir):
560560
assert model._is_compiled_model is False
561561

562562

563+
def test_compile_model_for_edge_device_tflite(sagemaker_session, tmpdir):
564+
sagemaker_session.wait_for_compilation_job = Mock(
565+
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE
566+
)
567+
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
568+
model.compile(
569+
target_instance_family="deeplens",
570+
input_shape={"data": [1, 3, 1024, 1024]},
571+
output_path="s3://output",
572+
role="role",
573+
framework="tflite",
574+
job_name="tflite-compile-model",
575+
)
576+
assert model._is_compiled_model is False
577+
578+
563579
def test_compile_model_for_cloud(sagemaker_session, tmpdir):
564580
sagemaker_session.wait_for_compilation_job = Mock(
565581
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE
@@ -576,6 +592,22 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir):
576592
assert model._is_compiled_model is True
577593

578594

595+
def test_compile_model_for_cloud_tflite(sagemaker_session, tmpdir):
596+
sagemaker_session.wait_for_compilation_job = Mock(
597+
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE
598+
)
599+
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
600+
model.compile(
601+
target_instance_family="ml_c4",
602+
input_shape={"data": [1, 3, 1024, 1024]},
603+
output_path="s3://output",
604+
role="role",
605+
framework="tflite",
606+
job_name="tflite-compile-model",
607+
)
608+
assert model._is_compiled_model is True
609+
610+
579611
@patch("sagemaker.session.Session")
580612
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
581613
def test_compile_creates_session(session):

0 commit comments

Comments
 (0)