@@ -62,19 +62,26 @@ def test_algo_uris(algo):
62
62
63
63
64
64
def _test_neo_framework_uris (framework , version ):
65
- framework = "neo-{}" .format (framework )
65
+ framework_in_config = f"neo-{ framework } "
66
+ framework_in_uri = f"neo-{ framework } " if framework == "tensorflow" else f"inference-{ framework } "
66
67
67
68
for region in regions .regions ():
68
69
if region in ACCOUNTS :
69
- uri = image_uris .retrieve (framework , region , instance_type = "ml_c5" , version = version )
70
- assert _expected_framework_uri (framework , version , region = region ) == uri
70
+ uri = image_uris .retrieve (
71
+ framework_in_config , region , instance_type = "ml_c5" , version = version
72
+ )
73
+ assert _expected_framework_uri (framework_in_uri , version , region = region ) == uri
71
74
else :
72
75
with pytest .raises (ValueError ) as e :
73
- image_uris .retrieve (framework , region , instance_type = "ml_c5" , version = version )
76
+ image_uris .retrieve (
77
+ framework_in_config , region , instance_type = "ml_c5" , version = version
78
+ )
74
79
assert "Unsupported region: {}." .format (region ) in str (e .value )
75
80
76
- uri = image_uris .retrieve (framework , "us-west-2" , instance_type = "ml_p2" , version = version )
77
- assert _expected_framework_uri (framework , version , processor = "gpu" ) == uri
81
+ uri = image_uris .retrieve (
82
+ framework_in_config , "us-west-2" , instance_type = "ml_p2" , version = version
83
+ )
84
+ assert _expected_framework_uri (framework_in_uri , version , processor = "gpu" ) == uri
78
85
79
86
80
87
def test_neo_mxnet (neo_mxnet_version ):
0 commit comments