Skip to content

Commit 8bffacc

Browse files
author
huilgolr
committed
Fix black check
1 parent 3e73f71 commit 8bffacc

File tree

2 files changed

+108
-49
lines changed

2 files changed

+108
-49
lines changed

src/sagemaker/image_uris.py

Lines changed: 103 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@
2727
from sagemaker.jumpstart import artifacts
2828
from sagemaker.workflow import is_pipeline_variable
2929
from sagemaker.workflow.utilities import override_pipeline_parameter_var
30-
from sagemaker.fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY, GRAVITON_ALLOWED_FRAMEWORKS
30+
from sagemaker.fw_utils import (
31+
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
32+
GRAVITON_ALLOWED_FRAMEWORKS,
33+
)
3134

3235
logger = logging.getLogger(__name__)
3336

@@ -164,13 +167,20 @@ def retrieve(
164167
)
165168
else:
166169
_framework = framework
167-
if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS:
170+
if (
171+
framework == HUGGING_FACE_FRAMEWORK
172+
or framework in TRAINIUM_ALLOWED_FRAMEWORKS
173+
):
168174
inference_tool = _get_inference_tool(inference_tool, instance_type)
169175
if inference_tool in ["neuron", "neuronx"]:
170176
_framework = f"{framework}-{inference_tool}"
171-
final_image_scope = _get_final_image_scope(framework, instance_type, image_scope)
177+
final_image_scope = _get_final_image_scope(
178+
framework, instance_type, image_scope
179+
)
172180
_validate_for_suppported_frameworks_and_instance_type(framework, instance_type)
173-
config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type)
181+
config = _config_for_framework_and_scope(
182+
_framework, final_image_scope, accelerator_type
183+
)
174184

175185
original_version = version
176186
version = _validate_version_and_set_if_needed(version, config, framework)
@@ -181,10 +191,14 @@ def retrieve(
181191
full_base_framework_version = version_config["version_aliases"].get(
182192
base_framework_version, base_framework_version
183193
)
184-
_validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
194+
_validate_arg(
195+
full_base_framework_version, list(version_config.keys()), "base framework"
196+
)
185197
version_config = version_config.get(full_base_framework_version)
186198

187-
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
199+
py_version = _validate_py_version_and_set_if_needed(
200+
py_version, version_config, framework
201+
)
188202
version_config = version_config.get(py_version) or version_config
189203
registry = _registry_from_region(region, version_config["registries"])
190204
endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region)
@@ -212,7 +226,9 @@ def retrieve(
212226

213227
if framework == HUGGING_FACE_FRAMEWORK:
214228
pt_or_tf_version = (
215-
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
229+
re.compile("^(pytorch|tensorflow)(.*)$")
230+
.match(base_framework_version)
231+
.group(2)
216232
)
217233
_version = original_version
218234

@@ -236,11 +252,13 @@ def retrieve(
236252
.get("version_aliases", {})
237253
.get(base_framework_version, {})
238254
):
239-
_base_framework_version = config.get("versions")[_version]["version_aliases"][
240-
base_framework_version
241-
]
255+
_base_framework_version = config.get("versions")[_version][
256+
"version_aliases"
257+
][base_framework_version]
242258
pt_or_tf_version = (
243-
re.compile("^(pytorch|tensorflow)(.*)$").match(_base_framework_version).group(2)
259+
re.compile("^(pytorch|tensorflow)(.*)$")
260+
.match(_base_framework_version)
261+
.group(2)
244262
)
245263

246264
tag_prefix = f"{pt_or_tf_version}-transformers{_version}"
@@ -267,7 +285,9 @@ def retrieve(
267285
if tag:
268286
repo += ":{}".format(tag)
269287

270-
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
288+
return ECR_URI_TEMPLATE.format(
289+
registry=registry, hostname=hostname, repository=repo
290+
)
271291

272292

273293
def _get_image_tag(
@@ -306,9 +326,13 @@ def _get_image_tag(
306326
}
307327
tag = version_to_arm64_tag_mapping[framework][version]
308328
else:
309-
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
329+
tag = _format_tag(
330+
tag_prefix, processor, py_version, container_version, inference_tool
331+
)
310332
else:
311-
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
333+
tag = _format_tag(
334+
tag_prefix, processor, py_version, container_version, inference_tool
335+
)
312336

313337
if instance_type is not None and _should_auto_select_container_version(
314338
instance_type, distribution
@@ -343,7 +367,8 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
343367

344368
if image_scope not in ("eia", "inference"):
345369
logger.warning(
346-
"Elastic inference is for inference only. Ignoring image scope: %s.", image_scope
370+
"Elastic inference is for inference only. Ignoring image scope: %s.",
371+
image_scope,
347372
)
348373
image_scope = "eia"
349374

@@ -358,7 +383,11 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
358383
)
359384
image_scope = available_scopes[0]
360385

361-
if not image_scope and "scope" in config and set(available_scopes) == {"training", "inference"}:
386+
if (
387+
not image_scope
388+
and "scope" in config
389+
and set(available_scopes) == {"training", "inference"}
390+
):
362391
logger.info(
363392
"Same images used for training and inference. Defaulting to image scope: %s.",
364393
available_scopes[0],
@@ -390,20 +419,27 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
390419
and "trn" in instance_type
391420
and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
392421
):
393-
_validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework", "Trainium")
422+
_validate_framework(
423+
framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework", "Trainium"
424+
)
394425

395426
# Validate for Graviton allowed frameowrks
396427
if (
397428
instance_type is not None
398-
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
429+
and utils.get_instance_type_family(instance_type)
430+
in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
399431
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
400432
):
401-
_validate_framework(framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton")
433+
_validate_framework(
434+
framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton"
435+
)
402436

403437

404438
def config_for_framework(framework):
405439
"""Loads the JSON config for the given framework."""
406-
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
440+
fname = os.path.join(
441+
os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework)
442+
)
407443
with open(fname) as f:
408444
return json.load(f)
409445

@@ -412,7 +448,8 @@ def _get_final_image_scope(framework, instance_type, image_scope):
412448
"""Return final image scope based on provided framework and instance type."""
413449
if (
414450
framework in GRAVITON_ALLOWED_FRAMEWORKS
415-
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
451+
and utils.get_instance_type_family(instance_type)
452+
in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
416453
):
417454
return INFERENCE_GRAVITON
418455
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
@@ -428,7 +465,9 @@ def _get_inference_tool(inference_tool, instance_type):
428465
"""Extract the inference tool name from instance type."""
429466
if not inference_tool:
430467
instance_type_family = utils.get_instance_type_family(instance_type)
431-
if instance_type_family.startswith("inf") or instance_type_family.startswith("trn"):
468+
if instance_type_family.startswith("inf") or instance_type_family.startswith(
469+
"trn"
470+
):
432471
return "neuron"
433472
return inference_tool
434473

@@ -440,10 +479,15 @@ def _get_latest_versions(list_of_versions):
440479

441480
def _validate_accelerator_type(accelerator_type):
442481
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
443-
if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook":
482+
if (
483+
not accelerator_type.startswith("ml.eia")
484+
and accelerator_type != "local_sagemaker_notebook"
485+
):
444486
raise ValueError(
445487
"Invalid SageMaker Elastic Inference accelerator type: {}. "
446-
"See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html".format(accelerator_type)
488+
"See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html".format(
489+
accelerator_type
490+
)
447491
)
448492

449493

@@ -453,11 +497,15 @@ def _validate_version_and_set_if_needed(version, config, framework):
453497
aliased_versions = list(config.get("version_aliases", {}).keys())
454498

455499
if len(available_versions) == 1 and version not in aliased_versions:
456-
log_message = "Defaulting to the only supported framework/algorithm version: {}.".format(
457-
available_versions[0]
500+
log_message = (
501+
"Defaulting to the only supported framework/algorithm version: {}.".format(
502+
available_versions[0]
503+
)
458504
)
459505
if version and version != available_versions[0]:
460-
logger.warning("%s Ignoring framework/algorithm version: %s.", log_message, version)
506+
logger.warning(
507+
"%s Ignoring framework/algorithm version: %s.", log_message, version
508+
)
461509
elif not version:
462510
logger.info(log_message)
463511

@@ -470,7 +518,9 @@ def _validate_version_and_set_if_needed(version, config, framework):
470518
]:
471519
version = _get_latest_versions(available_versions)
472520

473-
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))
521+
_validate_arg(
522+
version, available_versions + aliased_versions, "{} version".format(framework)
523+
)
474524
return version
475525

476526

@@ -496,7 +546,9 @@ def _processor(instance_type, available_processors, serverless_inference_config=
496546
return None
497547

498548
if len(available_processors) == 1 and not instance_type:
499-
logger.info("Defaulting to only supported image scope: %s.", available_processors[0])
549+
logger.info(
550+
"Defaulting to only supported image scope: %s.", available_processors[0]
551+
)
500552
return available_processors[0]
501553

502554
if serverless_inference_config is not None:
@@ -533,7 +585,9 @@ def _processor(instance_type, available_processors, serverless_inference_config=
533585
else:
534586
raise ValueError(
535587
"Invalid SageMaker instance type: {}. For options, see: "
536-
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
588+
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(
589+
instance_type
590+
)
537591
)
538592

539593
_validate_arg(processor, available_processors, "processor")
@@ -572,7 +626,9 @@ def _validate_py_version_and_set_if_needed(py_version, version_config, framework
572626
return None
573627

574628
if py_version is None and len(available_versions) == 1:
575-
logger.info("Defaulting to only available Python version: %s", available_versions[0])
629+
logger.info(
630+
"Defaulting to only available Python version: %s", available_versions[0]
631+
)
576632
return available_versions[0]
577633

578634
_validate_arg(py_version, available_versions, "Python version")
@@ -585,7 +641,9 @@ def _validate_arg(arg, available_options, arg_name):
585641
raise ValueError(
586642
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
587643
"(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
588-
"{options}.".format(arg_name=arg_name, arg=arg, options=", ".join(available_options))
644+
"{options}.".format(
645+
arg_name=arg_name, arg=arg, options=", ".join(available_options)
646+
)
589647
)
590648

591649

@@ -598,11 +656,17 @@ def _validate_framework(framework, allowed_frameworks, arg_name, hardware_name):
598656
)
599657

600658

601-
def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
659+
def _format_tag(
660+
tag_prefix, processor, py_version, container_version, inference_tool=None
661+
):
602662
"""Creates a tag for the image URI."""
603663
if inference_tool:
604-
return "-".join(x for x in (tag_prefix, inference_tool, py_version, container_version) if x)
605-
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)
664+
return "-".join(
665+
x for x in (tag_prefix, inference_tool, py_version, container_version) if x
666+
)
667+
return "-".join(
668+
x for x in (tag_prefix, processor, py_version, container_version) if x
669+
)
606670

607671

608672
@override_pipeline_parameter_var
@@ -670,7 +734,7 @@ def get_training_image_uri(
670734
container_version = "cu121"
671735
else:
672736
container_version = "cu118"
673-
737+
674738
return retrieve(
675739
framework,
676740
region,
@@ -711,4 +775,6 @@ def get_base_python_image_uri(region, py_version="310") -> str:
711775
repo = version_config["repository"] + "-" + py_version
712776
repo_and_tag = repo + ":" + version
713777

714-
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo_and_tag)
778+
return ECR_URI_TEMPLATE.format(
779+
registry=registry, hostname=hostname, repository=repo_and_tag
780+
)

tests/unit/sagemaker/image_uris/test_smp_v2.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,10 @@
2323
def test_smp_v2(load_config):
2424
VERSIONS = load_config["training"]["versions"]
2525
PROCESSORS = load_config["training"]["processors"]
26-
distribution = {"torch_distributed":
27-
{
28-
"enabled": True
29-
},
30-
"smdistributed": {
31-
"modelparallel":
32-
{
33-
"enabled": True
34-
}
35-
}
36-
}
26+
distribution = {
27+
"torch_distributed": {"enabled": True},
28+
"smdistributed": {"modelparallel": {"enabled": True}},
29+
}
3730
for processor in PROCESSORS:
3831
for version in VERSIONS:
3932
ACCOUNTS = load_config["training"]["versions"][version]["registries"]
@@ -47,7 +40,7 @@ def test_smp_v2(load_config):
4740
framework_version=version,
4841
py_version=py_version,
4942
distribution=distribution,
50-
instance_type=instance_type
43+
instance_type=instance_type,
5144
)
5245
expected = expected_uris.framework_uri(
5346
repo="smdistributed-modelparallel",

0 commit comments

Comments
 (0)