Skip to content

Commit 16b3f0e

Browse files
committed
fix: set sagemaker_connection and image_uri in register method
1 parent e0b6353 commit 16b3f0e

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

src/sagemaker/sklearn/model.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,74 @@ def __init__(
137137
)
138138

139139
self.model_server_workers = model_server_workers
140+
141+
def register(
142+
self,
143+
content_types,
144+
response_types,
145+
inference_instances,
146+
transform_instances,
147+
model_package_name=None,
148+
model_package_group_name=None,
149+
image_uri=None,
150+
model_metrics=None,
151+
metadata_properties=None,
152+
marketplace_cert=False,
153+
approval_status=None,
154+
description=None,
155+
):
156+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
157+
158+
Args:
159+
content_types (list): The supported MIME types for the input data.
160+
response_types (list): The supported MIME types for the output data.
161+
inference_instances (list): A list of the instance types that are used to
162+
generate inferences in real-time.
163+
transform_instances (list): A list of the instance types on which a transformation
164+
job can be run or on which an endpoint can be deployed.
165+
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
166+
using `model_package_name` makes the Model Package un-versioned (default: None).
167+
model_package_group_name (str): Model Package Group name, exclusive to
168+
`model_package_name`, using `model_package_group_name` makes the Model Package
169+
versioned (default: None).
170+
image_uri (str): Inference image uri for the container. Model class' self.image will
171+
be used if it is None (default: None).
172+
model_metrics (ModelMetrics): ModelMetrics object (default: None).
173+
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
174+
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
175+
for AWS Marketplace (default: False).
176+
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
177+
or "PendingManualApproval" (default: "PendingManualApproval").
178+
description (str): Model Package description (default: None).
179+
180+
Returns:
181+
A `sagemaker.model.ModelPackage` instance.
182+
"""
183+
instance_type = inference_instances[0]
184+
self._init_sagemaker_session_if_does_not_exist(instance_type)
185+
186+
if image_uri:
187+
self.image_uri = image_uri
188+
if not self.image_uri:
189+
self.image_uri = self.serving_image_uri(
190+
region_name=self.sagemaker_session.boto_session.region_name,
191+
instance_type=instance_type,
192+
)
193+
return super(SKLearnModel, self).register(
194+
content_types,
195+
response_types,
196+
inference_instances,
197+
transform_instances,
198+
model_package_name,
199+
model_package_group_name,
200+
image_uri,
201+
model_metrics,
202+
metadata_properties,
203+
marketplace_cert,
204+
approval_status,
205+
description,
206+
)
207+
140208

141209
def prepare_container_def(self, instance_type=None, accelerator_type=None):
142210
"""Container definition with framework configuration set in model environment variables.

0 commit comments

Comments
 (0)