@@ -19,6 +19,7 @@
 import os
 import re
 import copy
+from typing import List, Dict

 import sagemaker
 from sagemaker import (
@@ -38,6 +39,7 @@
 from sagemaker.async_inference import AsyncInferenceConfig
 from sagemaker.predictor_async import AsyncPredictor
 from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession

 LOGGER = logging.getLogger("sagemaker")

@@ -289,6 +291,7 @@ def __init__(
         self.uploaded_code = None
         self.repacked_model_data = None

+    @runnable_by_pipeline
     def register(
         self,
         content_types,
@@ -336,6 +339,7 @@ def register(
         Returns:
             A `sagemaker.model.ModelPackage` instance.
         """
+
         if self.model_data is None:
             raise ValueError("SageMaker Model Package cannot be created without model data.")
         if image_uri is not None:
@@ -361,15 +365,52 @@ def register(
             drift_check_baselines=drift_check_baselines,
             customer_metadata_properties=customer_metadata_properties,
         )
+
         model_package = self.sagemaker_session.create_model_package_from_containers(
             **model_pkg_args
         )
+        if isinstance(self.sagemaker_session, PipelineSession):
+            return None
         return ModelPackage(
             role=self.role,
             model_data=self.model_data,
             model_package_arn=model_package.get("ModelPackageArn"),
         )

+    @runnable_by_pipeline
+    def create(
+        self,
+        instance_type: str = None,
+        accelerator_type: str = None,
+        serverless_inference_config: ServerlessInferenceConfig = None,
+        tags: List[Dict[str, str]] = None,
+    ):
+        """Create a SageMaker Model Entity
+
+        Args:
+            instance_type (str): The EC2 instance type that this Model will be
+                used for, this is only used to determine if the image needs GPU
+                support or not.
+            accelerator_type (str): Type of Elastic Inference accelerator to
+                attach to an endpoint for model loading and inference, for
+                example, 'ml.eia1.medium'. If not specified, no Elastic
+                Inference accelerator will be attached to the endpoint.
+            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
+                Specifies configuration related to serverless endpoint. Instance type is
+                not provided in serverless inference. So this is used to find image URIs.
+            tags (List[Dict[str, str]]): Optional. The list of tags to add to
+                the model. Example: >>> tags = [{'Key': 'tagname', 'Value':
+                'tagvalue'}] For more information about tags, see
+                https://boto3.amazonaws.com/v1/documentation
+                /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
+        """
+        self._create_sagemaker_model(
+            instance_type=instance_type,
+            accelerator_type=accelerator_type,
+            tags=tags,
+            serverless_inference_config=serverless_inference_config,
+        )
+
     def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
         """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.

@@ -452,6 +493,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
         if repack and self.model_data is not None and self.entry_point is not None:
             if is_pipeline_variable(self.model_data):
                 # model is not yet there, defer repacking to later during pipeline execution
+                if isinstance(self.sagemaker_session, PipelineSession):
+                    self.sagemaker_session.mark_runtime_repack(id(self))
                 return

         bucket = self.bucket or self.sagemaker_session.default_bucket()
@@ -466,7 +509,6 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
             sagemaker_session=self.sagemaker_session,
             kms_key=self.model_kms_key,
         )
-
         self.repacked_model_data = repacked_model_data
         self.uploaded_code = fw_utils.UploadedCode(
             s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point)
@@ -533,23 +575,32 @@ def _create_sagemaker_model(
             serverless_inference_config=serverless_inference_config,
         )

-        self._ensure_base_name_if_needed(
-            image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data
-        )
-        self._set_model_name_if_needed()
+        if not isinstance(self.sagemaker_session, PipelineSession):
+            # _base_name, model_name are not needed under PipelineSession.
+            # the model_data, source_dir may be Pipeline variable
+            # which may break the _base_name generation
+            self._ensure_base_name_if_needed(
+                image_uri=container_def["Image"],
+                script_uri=self.source_dir,
+                model_uri=self.model_data,
+            )
+            self._set_model_name_if_needed()

         enable_network_isolation = self.enable_network_isolation()

         self._init_sagemaker_session_if_does_not_exist(instance_type)
-        self.sagemaker_session.create_model(
-            self.name,
-            self.role,
-            container_def,
+
+        create_model_args = dict(
+            name=self.name,
+            role=self.role,
+            container_defs=container_def,
             vpc_config=self.vpc_config,
             enable_network_isolation=enable_network_isolation,
             tags=tags,
         )

+        self.sagemaker_session.create_model(**create_model_args)
+
     def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
         """Create a base name from the image URI if there is no model name provided.

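For reference, a minimal usage sketch of what this diff enables, not part of the commit itself: a Model constructed with a PipelineSession no longer calls the SageMaker APIs directly when register() or create() is invoked; @runnable_by_pipeline captures the request arguments so a pipeline step can execute them at pipeline run time. The ModelStep wiring shown below, and the exact image/role/bucket values, are assumptions from the broader pipeline-model-step feature rather than something introduced here.

from sagemaker.model import Model
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.model_step import ModelStep  # assumed: added by the wider feature

pipeline_session = PipelineSession()

model = Model(
    image_uri="<ecr-image-uri>",                # placeholder
    model_data="s3://<bucket>/model.tar.gz",    # may also be a pipeline variable
    role="<execution-role-arn>",                # placeholder
    sagemaker_session=pipeline_session,
)

# Under PipelineSession, @runnable_by_pipeline defers the call: no SageMaker
# API request is made here; the captured arguments feed the pipeline step.
register_args = model.register(
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=["ml.m5.xlarge"],
    transform_instances=["ml.m5.xlarge"],
    model_package_group_name="MyModelPackageGroup",
)

register_step = ModelStep(name="RegisterModel", step_args=register_args)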