19
19
import os
20
20
import re
21
21
import copy
22
+ from typing import List , Dict
22
23
23
24
import sagemaker
24
25
from sagemaker import (
38
39
from sagemaker .async_inference import AsyncInferenceConfig
39
40
from sagemaker .predictor_async import AsyncPredictor
40
41
from sagemaker .workflow import is_pipeline_variable
42
+ from sagemaker .workflow .pipeline_context import runnable_by_pipeline , PipelineSession
41
43
42
44
LOGGER = logging .getLogger ("sagemaker" )
43
45
@@ -289,6 +291,7 @@ def __init__(
289
291
self .uploaded_code = None
290
292
self .repacked_model_data = None
291
293
294
+ @runnable_by_pipeline
292
295
def register (
293
296
self ,
294
297
content_types ,
@@ -309,12 +312,12 @@ def register(
309
312
"""Creates a model package for creating SageMaker models or listing on Marketplace.
310
313
311
314
Args:
312
- content_types (list): The supported MIME types for the input data (default: None) .
313
- response_types (list): The supported MIME types for the output data (default: None) .
315
+ content_types (list): The supported MIME types for the input data.
316
+ response_types (list): The supported MIME types for the output data.
314
317
inference_instances (list): A list of the instance types that are used to
315
- generate inferences in real-time (default: None) .
318
+ generate inferences in real-time.
316
319
transform_instances (list): A list of the instance types on which a transformation
317
- job can be run or on which an endpoint can be deployed (default: None) .
320
+ job can be run or on which an endpoint can be deployed.
318
321
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
319
322
using `model_package_name` makes the Model Package un-versioned (default: None).
320
323
model_package_group_name (str): Model Package Group name, exclusive to
@@ -364,12 +367,49 @@ def register(
364
367
model_package = self .sagemaker_session .create_model_package_from_containers (
365
368
** model_pkg_args
366
369
)
370
+ if isinstance (self .sagemaker_session , PipelineSession ):
371
+ return None
367
372
return ModelPackage (
368
373
role = self .role ,
369
374
model_data = self .model_data ,
370
375
model_package_arn = model_package .get ("ModelPackageArn" ),
371
376
)
372
377
378
@runnable_by_pipeline
def create(
    self,
    instance_type: str = None,
    accelerator_type: str = None,
    serverless_inference_config: ServerlessInferenceConfig = None,
    tags: List[Dict[str, str]] = None,
):
    """Create a SageMaker Model Entity.

    This is a thin public wrapper: every argument is forwarded unchanged, by
    keyword, to ``self._create_sagemaker_model``. Nothing is returned.

    Args:
        instance_type (str): The EC2 instance type that this Model will be
            used for. It is only used to determine whether the image needs
            GPU support or not (default: None).
        accelerator_type (str): Type of Elastic Inference accelerator to
            attach to an endpoint for model loading and inference, for
            example, 'ml.eia1.medium'. If not specified, no Elastic
            Inference accelerator will be attached to the endpoint
            (default: None).
        serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
            Specifies configuration related to serverless endpoint. Instance
            type is not provided in serverless inference, so this is used to
            find image URIs (default: None).
        tags (List[Dict[str, str]]): The list of tags to add to the model
            (default: None). Example:
            >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
            For more information about tags, see
            https://boto3.amazonaws.com/v1/documentation
            /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
    """
    # Gather the keyword arguments first, then hand them over in one call.
    model_kwargs = dict(
        instance_type=instance_type,
        accelerator_type=accelerator_type,
        tags=tags,
        serverless_inference_config=serverless_inference_config,
    )
    self._create_sagemaker_model(**model_kwargs)
412
+
373
413
def _init_sagemaker_session_if_does_not_exist (self , instance_type = None ):
374
414
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
375
415
@@ -453,6 +493,24 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
453
493
if repack and self .model_data is not None and self .entry_point is not None :
454
494
if is_pipeline_variable (self .model_data ):
455
495
# model is not yet there, defer repacking to later during pipeline execution
496
+ if not isinstance (self .sagemaker_session , PipelineSession ):
497
+ # TODO: link the doc in the warning once ready
498
+ logging .warning (
499
+ "The model_data is a Pipeline variable of type %s, "
500
+ "which should be used under `PipelineSession` and "
501
+ "leverage `ModelStep` to create or register model. "
502
+ "Otherwise some functionalities e.g. "
503
+ "runtime repack may be missing" ,
504
+ type (self .model_data ),
505
+ )
506
+ return
507
+ self .sagemaker_session .context .need_runtime_repack .add (id (self ))
508
+ # Add the uploaded_code and repacked_model_data to update the container env
509
+ self .repacked_model_data = self .model_data
510
+ self .uploaded_code = fw_utils .UploadedCode (
511
+ s3_prefix = self .repacked_model_data ,
512
+ script_name = os .path .basename (self .entry_point ),
513
+ )
456
514
return
457
515
458
516
bucket = self .bucket or self .sagemaker_session .default_bucket ()
@@ -534,22 +592,29 @@ def _create_sagemaker_model(
534
592
serverless_inference_config = serverless_inference_config ,
535
593
)
536
594
537
- self ._ensure_base_name_if_needed (
538
- image_uri = container_def ["Image" ], script_uri = self .source_dir , model_uri = self .model_data
539
- )
540
- self ._set_model_name_if_needed ()
595
+ if not isinstance (self .sagemaker_session , PipelineSession ):
596
+ # _base_name, model_name are not needed under PipelineSession.
597
+ # the model_data may be Pipeline variable
598
+ # which may break the _base_name generation
599
+ self ._ensure_base_name_if_needed (
600
+ image_uri = container_def ["Image" ],
601
+ script_uri = self .source_dir ,
602
+ model_uri = self .model_data ,
603
+ )
604
+ self ._set_model_name_if_needed ()
541
605
542
606
enable_network_isolation = self .enable_network_isolation ()
543
607
544
608
self ._init_sagemaker_session_if_does_not_exist (instance_type )
545
- self . sagemaker_session . create_model (
546
- self .name ,
547
- self .role ,
548
- container_def ,
609
+ create_model_args = dict (
610
+ name = self .name ,
611
+ role = self .role ,
612
+ container_defs = container_def ,
549
613
vpc_config = self .vpc_config ,
550
614
enable_network_isolation = enable_network_isolation ,
551
615
tags = tags ,
552
616
)
617
+ self .sagemaker_session .create_model (** create_model_args )
553
618
554
619
def _ensure_base_name_if_needed (self , image_uri , script_uri , model_uri ):
555
620
"""Create a base name from the image URI if there is no model name provided.
0 commit comments