@@ -560,6 +560,22 @@ def test_compile_model_for_edge_device(sagemaker_session, tmpdir):
560
560
assert model ._is_compiled_model is False
561
561
562
562
563
+ def test_compile_model_for_edge_device_tflite (sagemaker_session , tmpdir ):
564
+ sagemaker_session .wait_for_compilation_job = Mock (
565
+ return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
566
+ )
567
+ model = DummyFrameworkModel (sagemaker_session , source_dir = str (tmpdir ))
568
+ model .compile (
569
+ target_instance_family = "deeplens" ,
570
+ input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
571
+ output_path = "s3://output" ,
572
+ role = "role" ,
573
+ framework = "tflite" ,
574
+ job_name = "tflite-compile-model" ,
575
+ )
576
+ assert model ._is_compiled_model is False
577
+
578
+
563
579
def test_compile_model_for_cloud (sagemaker_session , tmpdir ):
564
580
sagemaker_session .wait_for_compilation_job = Mock (
565
581
return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
@@ -576,6 +592,22 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir):
576
592
assert model ._is_compiled_model is True
577
593
578
594
595
+ def test_compile_model_for_cloud_tflite (sagemaker_session , tmpdir ):
596
+ sagemaker_session .wait_for_compilation_job = Mock (
597
+ return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
598
+ )
599
+ model = DummyFrameworkModel (sagemaker_session , source_dir = str (tmpdir ))
600
+ model .compile (
601
+ target_instance_family = "ml_c4" ,
602
+ input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
603
+ output_path = "s3://output" ,
604
+ role = "role" ,
605
+ framework = "tflite" ,
606
+ job_name = "tflite-compile-model" ,
607
+ )
608
+ assert model ._is_compiled_model is True
609
+
610
+
579
611
@patch ("sagemaker.session.Session" )
580
612
@patch ("sagemaker.fw_utils.tar_and_upload_dir" , MagicMock ())
581
613
def test_compile_creates_session (session ):
0 commit comments