12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
14
15
- import pasta
16
15
import pytest
17
16
18
17
from sagemaker .cli .compatibility .v2 .modifiers import framework_version
@@ -36,8 +35,10 @@ def __init__(
36
35
self .py_version = py_version
37
36
self .py_version_for_model = py_version_for_model
38
37
39
- def constructors (self , versions = False , image = False ):
40
- return self ._frameworks (versions , image ) + self ._models (versions , image )
38
+ def constructors (self , fw_version = False , py_version = False , image = False ):
39
+ return self ._frameworks (fw_version , py_version , image ) + self ._models (
40
+ fw_version , py_version , image
41
+ )
41
42
42
43
def _templates (self , model = False ):
43
44
module = self .framework .lower ()
@@ -54,30 +55,38 @@ def _templates(self, model=False):
54
55
for template in templates
55
56
)
56
57
57
- def _frameworks (self , versions = False , image = False ):
58
- keywords = dict ()
59
- if image :
60
- keywords ["image_uri" ] = "my:image"
61
- if versions :
62
- keywords ["framework_version" ] = self .framework_version
63
- keywords ["py_version" ] = self .py_version
58
+ def _frameworks (self , fw_version = False , py_version = False , image = False ):
59
+ keywords = self ._base_keywords (fw_version , image )
60
+ if py_version :
61
+ keywords ["py_version" ] = (
62
+ "py_version" if py_version == "named" else "'{}'" .format (self .py_version )
63
+ )
64
64
return _format_templates (keywords , self ._templates ())
65
65
66
- def _models (self , versions = False , image = False ):
66
+ def _models (self , fw_version = False , py_version = False , image = False ):
67
+ keywords = self ._base_keywords (fw_version , image )
68
+ if py_version and self .py_version_for_model :
69
+ keywords ["py_version" ] = (
70
+ "py_version" if py_version == "named" else "'{}'" .format (self .py_version )
71
+ )
72
+ return _format_templates (keywords , self ._templates (model = True ))
73
+
74
+ def _base_keywords (self , fw_version = False , image = False ):
67
75
keywords = dict ()
68
76
if image :
69
- keywords ["image_uri" ] = "my:image"
70
- if versions :
71
- keywords ["framework_version" ] = self . framework_version
72
- if self .py_version_for_model :
73
- keywords [ "py_version" ] = self . py_version
74
- return _format_templates ( keywords , self . _templates ( model = True ))
77
+ keywords ["image_uri" ] = "' my:image' "
78
+ if fw_version :
79
+ keywords ["framework_version" ] = (
80
+ "fw_version" if fw_version == "named" else "'{}'" . format ( self .framework_version )
81
+ )
82
+ return keywords
75
83
76
84
77
85
def _format_templates (keywords , templates ):
78
86
args = ", " .join (
79
- "{key}=' {value}' " .format (key = key , value = value ) for key , value in keywords .items ()
87
+ "{key}={value}" .format (key = key , value = value ) for key , value in keywords .items ()
80
88
)
89
+
81
90
return [template .format (args ) for template in templates ]
82
91
83
92
@@ -100,8 +109,12 @@ def _format_templates(keywords, templates):
100
109
]
101
110
102
111
103
- def constructors (versions = False , image = False ):
104
- return [ctr for template in TEMPLATES for ctr in template .constructors (versions , image )]
112
+ def constructors (fw_version = False , py_version = False , image = False ):
113
+ return [
114
+ ctr
115
+ for template in TEMPLATES
116
+ for ctr in template .constructors (fw_version , py_version , image )
117
+ ]
105
118
106
119
107
120
@pytest .fixture
@@ -110,18 +123,34 @@ def constructors_empty():
110
123
111
124
112
125
@pytest .fixture
113
- def constructors_with_versions ():
114
- return constructors (versions = True )
126
+ def constructors_with_only_fw_version_that_need_py_version ():
127
+ ctrs = []
128
+ for template in TEMPLATES :
129
+ if template .py_version_for_model :
130
+ ctrs .extend (template .constructors (fw_version = True ))
131
+ else :
132
+ ctrs .extend (template ._frameworks (fw_version = True ))
133
+ return ctrs
115
134
116
135
117
136
@pytest .fixture
118
- def constructors_with_image ():
119
- return constructors (image = True )
137
+ def constructors_with_only_fw_version ():
138
+ return constructors (fw_version = True )
139
+
140
+
141
+ @pytest .fixture
142
+ def constructors_with_only_py_version ():
143
+ return constructors (py_version = True )
120
144
121
145
122
146
@pytest .fixture
123
- def constructors_with_both ():
124
- return constructors (versions = True , image = True )
147
+ def constructors_with_both_versions ():
148
+ return constructors (fw_version = True , py_version = True )
149
+
150
+
151
+ @pytest .fixture
152
+ def constructors_with_image ():
153
+ return constructors (image = True )
125
154
126
155
127
156
def _test_node_should_be_modified (ctrs , should_modify = True ):
@@ -138,8 +167,20 @@ def test_node_should_be_modified_empty(constructors_empty):
138
167
_test_node_should_be_modified (constructors_empty , should_modify = True )
139
168
140
169
141
- def test_node_should_be_modified_with_versions (constructors_with_versions ):
142
- _test_node_should_be_modified (constructors_with_versions , should_modify = False )
170
+ def test_node_should_be_modified_with_only_fw_versions (
171
+ constructors_with_only_fw_version_that_need_py_version ,
172
+ ):
173
+ _test_node_should_be_modified (
174
+ constructors_with_only_fw_version_that_need_py_version , should_modify = True
175
+ )
176
+
177
+
178
+ def test_node_should_be_modified_with_only_py_versions (constructors_with_only_py_version ):
179
+ _test_node_should_be_modified (constructors_with_only_py_version , should_modify = True )
180
+
181
+
182
+ def test_node_should_be_modified_with_versions (constructors_with_both_versions ):
183
+ _test_node_should_be_modified (constructors_with_both_versions , should_modify = False )
143
184
144
185
145
186
def test_node_should_be_modified_with_image (constructors_with_image ):
@@ -155,17 +196,40 @@ def _test_modify_node(ctrs_before, ctrs_expected):
155
196
for before , expected in zip (ctrs_before , ctrs_expected ):
156
197
node = ast_call (before )
157
198
modifier .modify_node (node )
158
- # NOTE: this type of equality with pasta depends on ordering of args...
159
- assert expected == pasta .dump (node )
199
+ _assert_equal_kwargs (ast_call (expected ), node )
200
+
201
+
202
+ def _assert_equal_kwargs (expected , actual ):
203
+ assert _keywords_for_node (expected ) == _keywords_for_node (actual )
160
204
161
205
162
- def test_modify_node_empty ( constructors_empty , constructors_with_versions ):
163
- _test_modify_node ( constructors_empty , constructors_with_versions )
206
+ def _keywords_for_node ( node ):
207
+ return { kw . arg : getattr ( kw . value , kw . value . _fields [ 0 ]) for kw in node . keywords }
164
208
165
209
166
- def test_modify_node_with_versions ( constructors_with_versions ):
167
- _test_modify_node (constructors_with_versions , constructors_with_versions )
210
+ def test_modify_node_empty ( constructors_empty , constructors_with_both_versions ):
211
+ _test_modify_node (constructors_empty , constructors_with_both_versions )
168
212
169
213
170
- def test_modify_node_with_image (constructors_with_image , constructors_with_both ):
171
- _test_modify_node (constructors_with_image , constructors_with_both )
214
+ def test_modify_node_only_fw_version (
215
+ constructors_with_only_fw_version , constructors_with_both_versions
216
+ ):
217
+ _test_modify_node (constructors_with_only_fw_version , constructors_with_both_versions )
218
+
219
+
220
+ def test_modify_node_only_py_version (
221
+ constructors_with_only_py_version , constructors_with_both_versions
222
+ ):
223
+ _test_modify_node (constructors_with_only_py_version , constructors_with_both_versions )
224
+
225
+
226
+ def test_modify_node_only_named_fw_version ():
227
+ _test_modify_node (
228
+ constructors (fw_version = "named" ), constructors (fw_version = "named" , py_version = "literal" )
229
+ )
230
+
231
+
232
+ def test_modify_node_only_named_py_version ():
233
+ _test_modify_node (
234
+ constructors (py_version = "named" ), constructors (fw_version = "literal" , py_version = "named" )
235
+ )
0 commit comments