27
27
from sagemaker .jumpstart import artifacts
28
28
from sagemaker .workflow import is_pipeline_variable
29
29
from sagemaker .workflow .utilities import override_pipeline_parameter_var
30
- from sagemaker .fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY , GRAVITON_ALLOWED_FRAMEWORKS
30
+ from sagemaker .fw_utils import (
31
+ GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY ,
32
+ GRAVITON_ALLOWED_FRAMEWORKS ,
33
+ )
31
34
32
35
logger = logging .getLogger (__name__ )
33
36
@@ -164,13 +167,20 @@ def retrieve(
164
167
)
165
168
else :
166
169
_framework = framework
167
- if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS :
170
+ if (
171
+ framework == HUGGING_FACE_FRAMEWORK
172
+ or framework in TRAINIUM_ALLOWED_FRAMEWORKS
173
+ ):
168
174
inference_tool = _get_inference_tool (inference_tool , instance_type )
169
175
if inference_tool in ["neuron" , "neuronx" ]:
170
176
_framework = f"{ framework } -{ inference_tool } "
171
- final_image_scope = _get_final_image_scope (framework , instance_type , image_scope )
177
+ final_image_scope = _get_final_image_scope (
178
+ framework , instance_type , image_scope
179
+ )
172
180
_validate_for_suppported_frameworks_and_instance_type (framework , instance_type )
173
- config = _config_for_framework_and_scope (_framework , final_image_scope , accelerator_type )
181
+ config = _config_for_framework_and_scope (
182
+ _framework , final_image_scope , accelerator_type
183
+ )
174
184
175
185
original_version = version
176
186
version = _validate_version_and_set_if_needed (version , config , framework )
@@ -181,10 +191,14 @@ def retrieve(
181
191
full_base_framework_version = version_config ["version_aliases" ].get (
182
192
base_framework_version , base_framework_version
183
193
)
184
- _validate_arg (full_base_framework_version , list (version_config .keys ()), "base framework" )
194
+ _validate_arg (
195
+ full_base_framework_version , list (version_config .keys ()), "base framework"
196
+ )
185
197
version_config = version_config .get (full_base_framework_version )
186
198
187
- py_version = _validate_py_version_and_set_if_needed (py_version , version_config , framework )
199
+ py_version = _validate_py_version_and_set_if_needed (
200
+ py_version , version_config , framework
201
+ )
188
202
version_config = version_config .get (py_version ) or version_config
189
203
registry = _registry_from_region (region , version_config ["registries" ])
190
204
endpoint_data = utils ._botocore_resolver ().construct_endpoint ("ecr" , region )
@@ -212,7 +226,9 @@ def retrieve(
212
226
213
227
if framework == HUGGING_FACE_FRAMEWORK :
214
228
pt_or_tf_version = (
215
- re .compile ("^(pytorch|tensorflow)(.*)$" ).match (base_framework_version ).group (2 )
229
+ re .compile ("^(pytorch|tensorflow)(.*)$" )
230
+ .match (base_framework_version )
231
+ .group (2 )
216
232
)
217
233
_version = original_version
218
234
@@ -236,11 +252,13 @@ def retrieve(
236
252
.get ("version_aliases" , {})
237
253
.get (base_framework_version , {})
238
254
):
239
- _base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
240
- base_framework_version
241
- ]
255
+ _base_framework_version = config .get ("versions" )[_version ][
256
+ "version_aliases"
257
+ ][ base_framework_version ]
242
258
pt_or_tf_version = (
243
- re .compile ("^(pytorch|tensorflow)(.*)$" ).match (_base_framework_version ).group (2 )
259
+ re .compile ("^(pytorch|tensorflow)(.*)$" )
260
+ .match (_base_framework_version )
261
+ .group (2 )
244
262
)
245
263
246
264
tag_prefix = f"{ pt_or_tf_version } -transformers{ _version } "
@@ -267,7 +285,9 @@ def retrieve(
267
285
if tag :
268
286
repo += ":{}" .format (tag )
269
287
270
- return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
288
+ return ECR_URI_TEMPLATE .format (
289
+ registry = registry , hostname = hostname , repository = repo
290
+ )
271
291
272
292
273
293
def _get_image_tag (
@@ -306,9 +326,13 @@ def _get_image_tag(
306
326
}
307
327
tag = version_to_arm64_tag_mapping [framework ][version ]
308
328
else :
309
- tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
329
+ tag = _format_tag (
330
+ tag_prefix , processor , py_version , container_version , inference_tool
331
+ )
310
332
else :
311
- tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
333
+ tag = _format_tag (
334
+ tag_prefix , processor , py_version , container_version , inference_tool
335
+ )
312
336
313
337
if instance_type is not None and _should_auto_select_container_version (
314
338
instance_type , distribution
@@ -343,7 +367,8 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
343
367
344
368
if image_scope not in ("eia" , "inference" ):
345
369
logger .warning (
346
- "Elastic inference is for inference only. Ignoring image scope: %s." , image_scope
370
+ "Elastic inference is for inference only. Ignoring image scope: %s." ,
371
+ image_scope ,
347
372
)
348
373
image_scope = "eia"
349
374
@@ -358,7 +383,11 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
358
383
)
359
384
image_scope = available_scopes [0 ]
360
385
361
- if not image_scope and "scope" in config and set (available_scopes ) == {"training" , "inference" }:
386
+ if (
387
+ not image_scope
388
+ and "scope" in config
389
+ and set (available_scopes ) == {"training" , "inference" }
390
+ ):
362
391
logger .info (
363
392
"Same images used for training and inference. Defaulting to image scope: %s." ,
364
393
available_scopes [0 ],
@@ -390,20 +419,27 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
390
419
and "trn" in instance_type
391
420
and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
392
421
):
393
- _validate_framework (framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" , "Trainium" )
422
+ _validate_framework (
423
+ framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" , "Trainium"
424
+ )
394
425
395
426
# Validate for Graviton allowed frameowrks
396
427
if (
397
428
instance_type is not None
398
- and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
429
+ and utils .get_instance_type_family (instance_type )
430
+ in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
399
431
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
400
432
):
401
- _validate_framework (framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton" )
433
+ _validate_framework (
434
+ framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton"
435
+ )
402
436
403
437
404
438
def config_for_framework (framework ):
405
439
"""Loads the JSON config for the given framework."""
406
- fname = os .path .join (os .path .dirname (__file__ ), "image_uri_config" , "{}.json" .format (framework ))
440
+ fname = os .path .join (
441
+ os .path .dirname (__file__ ), "image_uri_config" , "{}.json" .format (framework )
442
+ )
407
443
with open (fname ) as f :
408
444
return json .load (f )
409
445
@@ -412,7 +448,8 @@ def _get_final_image_scope(framework, instance_type, image_scope):
412
448
"""Return final image scope based on provided framework and instance type."""
413
449
if (
414
450
framework in GRAVITON_ALLOWED_FRAMEWORKS
415
- and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
451
+ and utils .get_instance_type_family (instance_type )
452
+ in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
416
453
):
417
454
return INFERENCE_GRAVITON
418
455
if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
@@ -428,7 +465,9 @@ def _get_inference_tool(inference_tool, instance_type):
428
465
"""Extract the inference tool name from instance type."""
429
466
if not inference_tool :
430
467
instance_type_family = utils .get_instance_type_family (instance_type )
431
- if instance_type_family .startswith ("inf" ) or instance_type_family .startswith ("trn" ):
468
+ if instance_type_family .startswith ("inf" ) or instance_type_family .startswith (
469
+ "trn"
470
+ ):
432
471
return "neuron"
433
472
return inference_tool
434
473
@@ -440,10 +479,15 @@ def _get_latest_versions(list_of_versions):
440
479
441
480
def _validate_accelerator_type (accelerator_type ):
442
481
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
443
- if not accelerator_type .startswith ("ml.eia" ) and accelerator_type != "local_sagemaker_notebook" :
482
+ if (
483
+ not accelerator_type .startswith ("ml.eia" )
484
+ and accelerator_type != "local_sagemaker_notebook"
485
+ ):
444
486
raise ValueError (
445
487
"Invalid SageMaker Elastic Inference accelerator type: {}. "
446
- "See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html" .format (accelerator_type )
488
+ "See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html" .format (
489
+ accelerator_type
490
+ )
447
491
)
448
492
449
493
@@ -453,11 +497,15 @@ def _validate_version_and_set_if_needed(version, config, framework):
453
497
aliased_versions = list (config .get ("version_aliases" , {}).keys ())
454
498
455
499
if len (available_versions ) == 1 and version not in aliased_versions :
456
- log_message = "Defaulting to the only supported framework/algorithm version: {}." .format (
457
- available_versions [0 ]
500
+ log_message = (
501
+ "Defaulting to the only supported framework/algorithm version: {}." .format (
502
+ available_versions [0 ]
503
+ )
458
504
)
459
505
if version and version != available_versions [0 ]:
460
- logger .warning ("%s Ignoring framework/algorithm version: %s." , log_message , version )
506
+ logger .warning (
507
+ "%s Ignoring framework/algorithm version: %s." , log_message , version
508
+ )
461
509
elif not version :
462
510
logger .info (log_message )
463
511
@@ -470,7 +518,9 @@ def _validate_version_and_set_if_needed(version, config, framework):
470
518
]:
471
519
version = _get_latest_versions (available_versions )
472
520
473
- _validate_arg (version , available_versions + aliased_versions , "{} version" .format (framework ))
521
+ _validate_arg (
522
+ version , available_versions + aliased_versions , "{} version" .format (framework )
523
+ )
474
524
return version
475
525
476
526
@@ -496,7 +546,9 @@ def _processor(instance_type, available_processors, serverless_inference_config=
496
546
return None
497
547
498
548
if len (available_processors ) == 1 and not instance_type :
499
- logger .info ("Defaulting to only supported image scope: %s." , available_processors [0 ])
549
+ logger .info (
550
+ "Defaulting to only supported image scope: %s." , available_processors [0 ]
551
+ )
500
552
return available_processors [0 ]
501
553
502
554
if serverless_inference_config is not None :
@@ -533,7 +585,9 @@ def _processor(instance_type, available_processors, serverless_inference_config=
533
585
else :
534
586
raise ValueError (
535
587
"Invalid SageMaker instance type: {}. For options, see: "
536
- "https://aws.amazon.com/sagemaker/pricing/instance-types" .format (instance_type )
588
+ "https://aws.amazon.com/sagemaker/pricing/instance-types" .format (
589
+ instance_type
590
+ )
537
591
)
538
592
539
593
_validate_arg (processor , available_processors , "processor" )
@@ -572,7 +626,9 @@ def _validate_py_version_and_set_if_needed(py_version, version_config, framework
572
626
return None
573
627
574
628
if py_version is None and len (available_versions ) == 1 :
575
- logger .info ("Defaulting to only available Python version: %s" , available_versions [0 ])
629
+ logger .info (
630
+ "Defaulting to only available Python version: %s" , available_versions [0 ]
631
+ )
576
632
return available_versions [0 ]
577
633
578
634
_validate_arg (py_version , available_versions , "Python version" )
@@ -585,7 +641,9 @@ def _validate_arg(arg, available_options, arg_name):
585
641
raise ValueError (
586
642
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
587
643
"(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
588
- "{options}." .format (arg_name = arg_name , arg = arg , options = ", " .join (available_options ))
644
+ "{options}." .format (
645
+ arg_name = arg_name , arg = arg , options = ", " .join (available_options )
646
+ )
589
647
)
590
648
591
649
@@ -598,11 +656,17 @@ def _validate_framework(framework, allowed_frameworks, arg_name, hardware_name):
598
656
)
599
657
600
658
601
- def _format_tag (tag_prefix , processor , py_version , container_version , inference_tool = None ):
659
+ def _format_tag (
660
+ tag_prefix , processor , py_version , container_version , inference_tool = None
661
+ ):
602
662
"""Creates a tag for the image URI."""
603
663
if inference_tool :
604
- return "-" .join (x for x in (tag_prefix , inference_tool , py_version , container_version ) if x )
605
- return "-" .join (x for x in (tag_prefix , processor , py_version , container_version ) if x )
664
+ return "-" .join (
665
+ x for x in (tag_prefix , inference_tool , py_version , container_version ) if x
666
+ )
667
+ return "-" .join (
668
+ x for x in (tag_prefix , processor , py_version , container_version ) if x
669
+ )
606
670
607
671
608
672
@override_pipeline_parameter_var
@@ -670,7 +734,7 @@ def get_training_image_uri(
670
734
container_version = "cu121"
671
735
else :
672
736
container_version = "cu118"
673
-
737
+
674
738
return retrieve (
675
739
framework ,
676
740
region ,
@@ -711,4 +775,6 @@ def get_base_python_image_uri(region, py_version="310") -> str:
711
775
repo = version_config ["repository" ] + "-" + py_version
712
776
repo_and_tag = repo + ":" + version
713
777
714
- return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo_and_tag )
778
+ return ECR_URI_TEMPLATE .format (
779
+ registry = registry , hostname = hostname , repository = repo_and_tag
780
+ )
0 commit comments