@@ -124,6 +124,88 @@ def __init__(
124
124
validate_py_version (py_version )
125
125
validate_framework_version (framework_version )
126
126
127
+ def register (
128
+ self ,
129
+ content_types ,
130
+ response_types ,
131
+ inference_instances ,
132
+ transform_instances ,
133
+ model_package_name = None ,
134
+ model_package_group_name = None ,
135
+ image_uri = None ,
136
+ model_metrics = None ,
137
+ marketplace_cert = False ,
138
+ approval_status = None ,
139
+ description = None ,
140
+ drift_check_baselines = None ,
141
+ customer_metadata_properties = None ,
142
+ domain = None ,
143
+ sample_payload_url = None ,
144
+ task = None ,
145
+ framework = None ,
146
+ framework_version = None ,
147
+ nearest_model_name = None ,
148
+ data_input_configuration = None ,
149
+ ):
150
+ """Creates a model package for creating SageMaker models or listing on Marketplace.
151
+
152
+ Args:
153
+ content_types (list): The supported MIME types for the input data.
154
+ response_types (list): The supported MIME types for the output data.
155
+ inference_instances (list): A list of the instance types that are used to
156
+ generate inferences in real-time.
157
+ transform_instances (list): A list of the instance types on which a transformation
158
+ job can be run or on which an endpoint can be deployed.
159
+ model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
160
+ using `model_package_name` makes the Model Package un-versioned (default: None).
161
+ model_package_group_name (str): Model Package Group name, exclusive to
162
+ `model_package_name`, using `model_package_group_name` makes the Model Package
163
+ versioned (default: None).
164
+ image_uri (str): Inference image uri for the container. Model class' self.image will
165
+ be used if it is None (default: None).
166
+ model_metrics (ModelMetrics): ModelMetrics object (default: None).
167
+ marketplace_cert (bool): A boolean value indicating if the Model Package is certified
168
+ for AWS Marketplace (default: False).
169
+ approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
170
+ or "PendingManualApproval" (default: "PendingManualApproval").
171
+ description (str): Model Package description (default: None).
172
+
173
+ Returns:
174
+ str: A string of SageMaker Model Package ARN.
175
+ """
176
+ instance_type = inference_instances [0 ]
177
+ self ._init_sagemaker_session_if_does_not_exist (instance_type )
178
+
179
+ if image_uri :
180
+ self .image_uri = image_uri
181
+ if not self .image_uri :
182
+ self .image_uri = self .serving_image_uri (
183
+ region_name = self .sagemaker_session .boto_session .region_name ,
184
+ instance_type = instance_type ,
185
+ )
186
+ return super (XGBoostModel , self ).register (
187
+ content_types ,
188
+ response_types ,
189
+ inference_instances ,
190
+ transform_instances ,
191
+ model_package_name ,
192
+ model_package_group_name ,
193
+ image_uri ,
194
+ model_metrics ,
195
+ marketplace_cert ,
196
+ approval_status ,
197
+ description ,
198
+ drift_check_baselines = drift_check_baselines ,
199
+ customer_metadata_properties = customer_metadata_properties ,
200
+ domain = domain ,
201
+ sample_payload_url = sample_payload_url ,
202
+ task = task ,
203
+ framework = framework or self ._framework_name ,
204
+ framework_version = framework_version or self .framework_version ,
205
+ nearest_model_name = nearest_model_name ,
206
+ data_input_configuration = data_input_configuration ,
207
+ )
208
+
127
209
def prepare_container_def (
128
210
self , instance_type = None , accelerator_type = None , serverless_inference_config = None
129
211
):
0 commit comments