Skip to content

Commit 086bf92

Browse files
committed
to_json to inherit from JumpStartDataHolderType
1 parent 4799ccb commit 086bf92

File tree

2 files changed

+21
-58
lines changed

2 files changed

+21
-58
lines changed

src/sagemaker/jumpstart/curated_hub/types.py

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -153,30 +153,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
153153
self.hub_content_version: str = json_obj["hub_content_version"]
154154
self.hub_name: str = json_obj["hub_name"]
155155

156-
def to_json(self) -> Dict[str, Any]:
157-
"""Returns json representation of DescribeHubContentsResponse object."""
158-
json_obj = {}
159-
for att in self.__slots__:
160-
if hasattr(self, att):
161-
cur_val = getattr(self, att)
162-
if issubclass(type(cur_val), JumpStartDataHolderType):
163-
json_obj[att] = cur_val.to_json()
164-
elif isinstance(cur_val, list):
165-
json_obj[att] = []
166-
for obj in cur_val:
167-
if issubclass(type(obj), JumpStartDataHolderType):
168-
json_obj[att].append(obj.to_json())
169-
else:
170-
json_obj[att].append(obj)
171-
else:
172-
json_obj[att] = cur_val
173-
return json_obj
174-
175156

176157
class HubS3StorageConfig(JumpStartDataHolderType):
177158
"""Data class for any dependencies related to hub content.
178159
179-
Includes scripts, model artifacts, datasets, or notebooks."""
160+
Includes scripts, model artifacts, datasets, or notebooks.
161+
"""
180162

181163
__slots__ = ["s3_output_path"]
182164

@@ -245,22 +227,3 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
245227
self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig(
246228
json_obj["s3_storage_config"]
247229
)
248-
249-
def to_json(self) -> Dict[str, Any]:
250-
"""Returns json representation of DescribeHubContentsResponse object."""
251-
json_obj = {}
252-
for att in self.__slots__:
253-
if hasattr(self, att):
254-
cur_val = getattr(self, att)
255-
if issubclass(type(cur_val), JumpStartDataHolderType):
256-
json_obj[att] = cur_val.to_json()
257-
elif isinstance(cur_val, list):
258-
json_obj[att] = []
259-
for obj in cur_val:
260-
if issubclass(type(obj), JumpStartDataHolderType):
261-
json_obj[att].append(obj.to_json())
262-
else:
263-
json_obj[att].append(obj)
264-
else:
265-
json_obj[att] = cur_val
266-
return json_obj

src/sagemaker/jumpstart/types.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,25 @@ def __repr__(self) -> str:
9999
}
100100
return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}"
101101

102+
def to_json(self) -> Dict[str, Any]:
103+
"""Returns json representation of object."""
104+
json_obj = {}
105+
for att in self.__slots__:
106+
if hasattr(self, att):
107+
cur_val = getattr(self, att)
108+
if issubclass(type(cur_val), JumpStartDataHolderType):
109+
json_obj[att] = cur_val.to_json()
110+
elif isinstance(cur_val, list):
111+
json_obj[att] = []
112+
for obj in cur_val:
113+
if issubclass(type(obj), JumpStartDataHolderType):
114+
json_obj[att].append(obj.to_json())
115+
else:
116+
json_obj[att].append(obj)
117+
else:
118+
json_obj[att] = cur_val
119+
return json_obj
120+
102121

103122
class JumpStartS3FileType(str, Enum):
104123
"""Type of files published in JumpStart S3 distribution buckets."""
@@ -911,25 +930,6 @@ def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None:
911930
"""
912931
# TODO: Implement
913932

914-
def to_json(self) -> Dict[str, Any]:
915-
"""Returns json representation of JumpStartModelSpecs object."""
916-
json_obj = {}
917-
for att in self.__slots__:
918-
if hasattr(self, att):
919-
cur_val = getattr(self, att)
920-
if issubclass(type(cur_val), JumpStartDataHolderType):
921-
json_obj[att] = cur_val.to_json()
922-
elif isinstance(cur_val, list):
923-
json_obj[att] = []
924-
for obj in cur_val:
925-
if issubclass(type(obj), JumpStartDataHolderType):
926-
json_obj[att].append(obj.to_json())
927-
else:
928-
json_obj[att].append(obj)
929-
else:
930-
json_obj[att] = cur_val
931-
return json_obj
932-
933933
def supports_prepacked_inference(self) -> bool:
934934
"""Returns True if the model has a prepacked inference artifact."""
935935
return getattr(self, "hosting_prepacked_artifact_key", None) is not None

0 commit comments

Comments
 (0)