|
10 | 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
11 | 11 | # ANY KIND, either express or implied. See the License for the specific
|
12 | 12 | # language governing permissions and limitations under the License.
|
13 |
| -"""Classes to modify serializer and deserializer code to be compatible with |
14 |
| -version 2.0 and later of the SageMaker Python SDK. |
15 |
| -""" |
| 13 | +"""Classes to modify SerDe code to be compatibile with version 2.0 and later.""" |
16 | 14 | from __future__ import absolute_import
|
17 | 15 |
|
18 | 16 | import ast
|
19 | 17 |
|
20 |
| -import pasta |
21 |
| - |
22 | 18 | from sagemaker.cli.compatibility.v2.modifiers import matching
|
23 | 19 | from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
|
24 | 20 |
|
25 | 21 | OLD_AMAZON_CLASS_NAMES = {"numpy_to_record_serializer", "record_deserializer"}
|
26 | 22 | NEW_AMAZON_CLASS_NAMES = {"RecordSerializer", "RecordDeserializer"}
|
| 23 | +OLD_PREDICTOR_CLASS_NAMES = { |
| 24 | + "_CsvSerializer", |
| 25 | + "_JsonSerializer", |
| 26 | + "_NpySerializer", |
| 27 | + "_CsvDeserializer", |
| 28 | + "BytesDeserializer", |
| 29 | + "StringDeserializer", |
| 30 | + "StreamDeserializer", |
| 31 | + "_NumpyDeserializer", |
| 32 | + "_JsonDeserializer", |
| 33 | +} |
27 | 34 |
|
28 | 35 | # The values are tuples so that the object can be passed to matching.matches_any.
|
29 | 36 | OLD_CLASS_NAME_TO_NAMESPACES = {
|
30 |
| - "_CsvSerializer": ("sagemaker.predictor",), |
31 |
| - "_JsonSerializer": ("sagemaker.predictor",), |
32 |
| - "_NpySerializer": ("sagemaker.predictor",), |
33 |
| - "_CsvDeserializer": ("sagemaker.predictor",), |
34 |
| - "BytesDeserializer": ("sagemaker.predictor",), |
35 |
| - "StringDeserializer": ("sagemaker.predictor",), |
36 |
| - "StreamDeserializer": ("sagemaker.predictor",), |
37 |
| - "_NumpyDeserializer": ("sagemaker.predictor",), |
38 |
| - "_JsonDeserializer": ("sagemaker.predictor",), |
| 37 | + class_name: ("sagemaker.predictor",) for class_name in OLD_PREDICTOR_CLASS_NAMES |
39 | 38 | }
|
40 | 39 | OLD_CLASS_NAME_TO_NAMESPACES.update(
|
41 | 40 | {class_name: ("sagemaker.amazon.common",) for class_name in OLD_AMAZON_CLASS_NAMES}
|
@@ -205,7 +204,7 @@ def modify_node(self, node):
|
205 | 204 | object_name = node.id if isinstance(node, ast.Name) else node.attr
|
206 | 205 | new_class_name = OLD_OBJECT_NAME_TO_NEW_CLASS_NAME[object_name]
|
207 | 206 | namespace_name = NEW_CLASS_NAME_TO_NAMESPACES[new_class_name][0]
|
208 |
| - subpackage_name = namespace_name[namespace_name.find(".") + 1 :] |
| 207 | + subpackage_name = namespace_name.split(".")[1] |
209 | 208 | return ast.Call(
|
210 | 209 | func=ast.Attribute(value=ast.Name(id=subpackage_name), attr=new_class_name),
|
211 | 210 | args=[],
|
@@ -375,7 +374,9 @@ def __init__(self):
|
375 | 374 | for class_name in NEW_CLASS_NAMES - NEW_AMAZON_CLASS_NAMES
|
376 | 375 | if "Serializer" in class_name
|
377 | 376 | }
|
378 |
| - import_node = pasta.parse("from sagemaker import serializers\n").body[0] |
| 377 | + import_node = ast.ImportFrom( |
| 378 | + module="sagemaker", names=[ast.alias(name="serializers", asname=None)], level=0 |
| 379 | + ) |
379 | 380 | super().__init__(class_names, import_node)
|
380 | 381 |
|
381 | 382 |
|
@@ -403,5 +404,7 @@ def __init__(self):
|
403 | 404 | for class_name in NEW_CLASS_NAMES - NEW_AMAZON_CLASS_NAMES
|
404 | 405 | if "Deserializer" in class_name
|
405 | 406 | }
|
406 |
| - import_node = pasta.parse("from sagemaker import deserializers\n").body[0] |
| 407 | + import_node = ast.ImportFrom( |
| 408 | + module="sagemaker", names=[ast.alias(name="deserializers", asname=None)], level=0 |
| 409 | + ) |
407 | 410 | super().__init__(class_names, import_node)
|
0 commit comments