@@ -362,59 +362,35 @@ def test_compile_with_tensorflow_neo_in_ml_inf(session):
362
362
)
363
363
364
364
365
- def test_compile_validates_framework_version (sagemaker_session ):
366
- sagemaker_session .wait_for_compilation_job = Mock (
367
- return_value = {
368
- "CompilationJobStatus" : "Completed" ,
369
- "ModelArtifacts" : {"S3ModelArtifacts" : "s3://output-path/model.tar.gz" },
370
- "InferenceImage" : None ,
371
- }
372
- )
365
+ @pytest .mark .parametrize (
366
+ "target,framework,fx_version,expected_fx_version" ,
367
+ [
368
+ ("ml_c4" , "pytorch" , "1.6" , "1.6" ),
369
+ ("rasp3b" , "pytorch" , "1.6.1" , "1.6" ),
370
+ ("amba_cv2" , "pytorch" , "1.6.1" , None ),
371
+ ("ml_c4" , "tensorflow" , "1.15.1" , "1.15" ),
372
+ ("ml_c4" , "tensorflow" , "2.15.1" , "2.15" ),
373
+ ("ml_inf1" , "tensorflow" , "2.15.1" , "2.15" ),
374
+ ("ml_inf2" , "pytorch" , "2.0" , "2.0" ),
375
+ ("ml_inf2" , "pytorch" , "2.0.1" , "2.0" ),
376
+ ("ml_trn1" , "pytorch" , "2.0.1" , "2.0" ),
377
+ ("ml_trn1" , "tensorflow" , "2.0.1" , "2.0" ),
378
+ ],
379
+ )
380
+ def test_compile_validates_framework_version (
381
+ sagemaker_session , target , framework , fx_version , expected_fx_version
382
+ ):
373
383
model = _create_model (sagemaker_session )
374
- model .compile (
375
- target_instance_family = "ml_c4" ,
376
- input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
377
- output_path = "s3://output" ,
378
- role = "role" ,
379
- framework = "pytorch" ,
380
- framework_version = "1.6.1" ,
381
- job_name = "compile-model" ,
382
- )
383
-
384
- assert model .image_uri is None
385
-
386
- sagemaker_session .wait_for_compilation_job = Mock (
387
- return_value = {
388
- "CompilationJobStatus" : "Completed" ,
389
- "ModelArtifacts" : {"S3ModelArtifacts" : "s3://output-path/model.tar.gz" },
390
- "InferenceImage" : None ,
391
- }
392
- )
393
-
394
- config = model ._compilation_job_config (
395
- "rasp3b" ,
396
- {"data" : [1 , 3 , 1024 , 1024 ]},
397
- "s3://output" ,
398
- "role" ,
399
- 900 ,
400
- "compile-model" ,
401
- "pytorch" ,
402
- None ,
403
- framework_version = "1.6.1" ,
404
- )
405
-
406
- assert config ["input_model_config" ]["FrameworkVersion" ] == "1.6"
407
-
408
384
config = model ._compilation_job_config (
409
- "amba_cv2" ,
385
+ target ,
410
386
{"data" : [1 , 3 , 1024 , 1024 ]},
411
387
"s3://output" ,
412
388
"role" ,
413
389
900 ,
414
390
"compile-model" ,
415
- "pytorch" ,
391
+ framework ,
416
392
None ,
417
- framework_version = "1.6.1" ,
393
+ framework_version = fx_version ,
418
394
)
419
395
420
- assert config ["input_model_config" ].get ("FrameworkVersion" , None ) is None
396
+ assert config ["input_model_config" ].get ("FrameworkVersion" , None ) == expected_fx_version
0 commit comments