Skip to content

Commit 9b39f6a

Browse files
committed
fix: override register method in xgboost & chainer model class
1 parent c2cea54 commit 9b39f6a

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

src/sagemaker/chainer/model.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,88 @@ def __init__(
140140

141141
self.model_server_workers = model_server_workers
142142

143+
def register(
144+
self,
145+
content_types,
146+
response_types,
147+
inference_instances,
148+
transform_instances,
149+
model_package_name=None,
150+
model_package_group_name=None,
151+
image_uri=None,
152+
model_metrics=None,
153+
marketplace_cert=False,
154+
approval_status=None,
155+
description=None,
156+
drift_check_baselines=None,
157+
customer_metadata_properties=None,
158+
domain=None,
159+
sample_payload_url=None,
160+
task=None,
161+
framework=None,
162+
framework_version=None,
163+
nearest_model_name=None,
164+
data_input_configuration=None,
165+
):
166+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
167+
168+
Args:
169+
content_types (list): The supported MIME types for the input data.
170+
response_types (list): The supported MIME types for the output data.
171+
inference_instances (list): A list of the instance types that are used to
172+
generate inferences in real-time.
173+
transform_instances (list): A list of the instance types on which a transformation
174+
job can be run or on which an endpoint can be deployed.
175+
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
176+
using `model_package_name` makes the Model Package un-versioned (default: None).
177+
model_package_group_name (str): Model Package Group name, exclusive to
178+
`model_package_name`, using `model_package_group_name` makes the Model Package
179+
versioned (default: None).
180+
image_uri (str): Inference image uri for the container. Model class' self.image will
181+
be used if it is None (default: None).
182+
model_metrics (ModelMetrics): ModelMetrics object (default: None).
183+
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
184+
for AWS Marketplace (default: False).
185+
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
186+
or "PendingManualApproval" (default: "PendingManualApproval").
187+
description (str): Model Package description (default: None).
188+
189+
Returns:
190+
str: A string of SageMaker Model Package ARN.
191+
"""
192+
instance_type = inference_instances[0]
193+
self._init_sagemaker_session_if_does_not_exist(instance_type)
194+
195+
if image_uri:
196+
self.image_uri = image_uri
197+
if not self.image_uri:
198+
self.image_uri = self.serving_image_uri(
199+
region_name=self.sagemaker_session.boto_session.region_name,
200+
instance_type=instance_type,
201+
)
202+
return super(ChainerModel, self).register(
203+
content_types,
204+
response_types,
205+
inference_instances,
206+
transform_instances,
207+
model_package_name,
208+
model_package_group_name,
209+
image_uri,
210+
model_metrics,
211+
marketplace_cert,
212+
approval_status,
213+
description,
214+
drift_check_baselines=drift_check_baselines,
215+
customer_metadata_properties=customer_metadata_properties,
216+
domain=domain,
217+
sample_payload_url=sample_payload_url,
218+
task=task,
219+
framework=framework or self._framework_name,
220+
framework_version=framework_version or self.framework_version,
221+
nearest_model_name=nearest_model_name,
222+
data_input_configuration=data_input_configuration,
223+
)
224+
143225
def prepare_container_def(
144226
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
145227
):

src/sagemaker/xgboost/model.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,88 @@ def __init__(
124124
validate_py_version(py_version)
125125
validate_framework_version(framework_version)
126126

127+
def register(
128+
self,
129+
content_types,
130+
response_types,
131+
inference_instances,
132+
transform_instances,
133+
model_package_name=None,
134+
model_package_group_name=None,
135+
image_uri=None,
136+
model_metrics=None,
137+
marketplace_cert=False,
138+
approval_status=None,
139+
description=None,
140+
drift_check_baselines=None,
141+
customer_metadata_properties=None,
142+
domain=None,
143+
sample_payload_url=None,
144+
task=None,
145+
framework=None,
146+
framework_version=None,
147+
nearest_model_name=None,
148+
data_input_configuration=None,
149+
):
150+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
151+
152+
Args:
153+
content_types (list): The supported MIME types for the input data.
154+
response_types (list): The supported MIME types for the output data.
155+
inference_instances (list): A list of the instance types that are used to
156+
generate inferences in real-time.
157+
transform_instances (list): A list of the instance types on which a transformation
158+
job can be run or on which an endpoint can be deployed.
159+
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
160+
using `model_package_name` makes the Model Package un-versioned (default: None).
161+
model_package_group_name (str): Model Package Group name, exclusive to
162+
`model_package_name`, using `model_package_group_name` makes the Model Package
163+
versioned (default: None).
164+
image_uri (str): Inference image uri for the container. Model class' self.image will
165+
be used if it is None (default: None).
166+
model_metrics (ModelMetrics): ModelMetrics object (default: None).
167+
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
168+
for AWS Marketplace (default: False).
169+
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
170+
or "PendingManualApproval" (default: "PendingManualApproval").
171+
description (str): Model Package description (default: None).
172+
173+
Returns:
174+
str: A string of SageMaker Model Package ARN.
175+
"""
176+
instance_type = inference_instances[0]
177+
self._init_sagemaker_session_if_does_not_exist(instance_type)
178+
179+
if image_uri:
180+
self.image_uri = image_uri
181+
if not self.image_uri:
182+
self.image_uri = self.serving_image_uri(
183+
region_name=self.sagemaker_session.boto_session.region_name,
184+
instance_type=instance_type,
185+
)
186+
return super(XGBoostModel, self).register(
187+
content_types,
188+
response_types,
189+
inference_instances,
190+
transform_instances,
191+
model_package_name,
192+
model_package_group_name,
193+
image_uri,
194+
model_metrics,
195+
marketplace_cert,
196+
approval_status,
197+
description,
198+
drift_check_baselines=drift_check_baselines,
199+
customer_metadata_properties=customer_metadata_properties,
200+
domain=domain,
201+
sample_payload_url=sample_payload_url,
202+
task=task,
203+
framework=framework or self._framework_name,
204+
framework_version=framework_version or self.framework_version,
205+
nearest_model_name=nearest_model_name,
206+
data_input_configuration=data_input_configuration,
207+
)
208+
127209
def prepare_container_def(
128210
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
129211
):

0 commit comments

Comments
 (0)