Skip to content

Commit 6c25b30

Browse files
author
Dewen Qi
committed
update TM as per latest estimator changes
1 parent bcb4e4c commit 6c25b30

File tree

4 files changed

+59
-13
lines changed

4 files changed

+59
-13
lines changed

src/sagemaker/instance_group.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@
1313
"""Defines the InstanceGroup class that configures a heterogeneous cluster."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class InstanceGroup(object):
1822
"""The class to create instance groups for a heterogeneous cluster."""
1923

2024
def __init__(
2125
self,
22-
instance_group_name=None,
23-
instance_type=None,
24-
instance_count=None,
26+
instance_group_name: Optional[Union[str, PipelineVariable]] = None,
27+
instance_type: Optional[Union[str, PipelineVariable]] = None,
28+
instance_count: Optional[Union[int, PipelineVariable]] = None,
2529
):
2630
"""It initializes an ``InstanceGroup`` instance.
2731

tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from sagemaker import ModelMetrics, MetricsSource, FileSource, Predictor
2121
from sagemaker.drift_check_baselines import DriftCheckBaselines
22+
from sagemaker.instance_group import InstanceGroup
2223
from sagemaker.metadata_properties import MetadataProperties
2324
from sagemaker.model import FrameworkModel
2425
from sagemaker.parameter import IntegerParameter
@@ -233,14 +234,17 @@ def _generate_all_pipeline_vars() -> dict:
233234
)
234235

235236

237+
# TODO: we should remove the _IS_TRUE_TMP and replace its usages with IS_TRUE
238+
# As currently the `instance_groups` does not work well with some estimator subclasses,
239+
# we temporarily hard code it to False which disables the instance_groups
240+
_IS_TRUE_TMP = False
236241
IS_TRUE = bool(getrandbits(1))
237242
PIPELINE_SESSION = _generate_mock_pipeline_session()
238243
PIPELINE_VARIABLES = _generate_all_pipeline_vars()
239244

240245
# TODO: need to recursively assign with Pipeline Variable in later changes
241246
FIXED_ARGUMENTS = dict(
242247
common=dict(
243-
instance_type=INSTANCE_TYPE,
244248
role=ROLE,
245249
sagemaker_session=PIPELINE_SESSION,
246250
source_dir=f"s3://{BUCKET}/source",
@@ -281,6 +285,7 @@ def _generate_all_pipeline_vars() -> dict:
281285
response_types=["application/json"],
282286
),
283287
processor=dict(
288+
instance_type=INSTANCE_TYPE,
284289
estimator_cls=PyTorch,
285290
code=f"s3://{BUCKET}/code",
286291
spark_event_logs_s3_uri=f"s3://{BUCKET}/my-spark-output-path",
@@ -438,13 +443,31 @@ def _generate_all_pipeline_vars() -> dict:
438443
input_mode=ParameterString(name="train_inputs_input_mode"),
439444
attribute_names=[ParameterString(name="train_inputs_attribute_name")],
440445
target_attribute_name=ParameterString(name="train_inputs_target_attr_name"),
446+
instance_groups=[ParameterString(name="train_inputs_instance_groups")],
441447
),
442448
},
449+
instance_groups=[
450+
InstanceGroup(
451+
instance_group_name=ParameterString(name="instance_group_name"),
452+
# hard code the instance_type here because InstanceGroup.instance_type
453+
# would be used to retrieve image_uri if image_uri is not presented
454+
# and currently the test mechanism does not support skip the test case
455+
# relating to bonded parameters in composite variables (i.e. the InstanceGroup)
456+
# TODO: we should support skip testing on bonded parameters in composite vars
457+
instance_type="ml.m5.xlarge",
458+
instance_count=ParameterString(name="instance_group_instance_count"),
459+
),
460+
] if _IS_TRUE_TMP else None,
461+
instance_type="ml.m5.xlarge" if not _IS_TRUE_TMP else None,
462+
instance_count=1 if not _IS_TRUE_TMP else None,
463+
distribution={} if not _IS_TRUE_TMP else None,
443464
),
444465
transformer=dict(
466+
instance_type=INSTANCE_TYPE,
445467
data=f"s3://{BUCKET}/data",
446468
),
447469
tuner=dict(
470+
instance_type=INSTANCE_TYPE,
448471
estimator=TensorFlow(
449472
entry_point=TENSORFLOW_ENTRY_POINT,
450473
role=ROLE,
@@ -475,12 +498,14 @@ def _generate_all_pipeline_vars() -> dict:
475498
include_cls_metadata={"estimator-1": IS_TRUE},
476499
),
477500
model=dict(
501+
instance_type=INSTANCE_TYPE,
478502
serverless_inference_config=ServerlessInferenceConfig(),
479503
framework_version="1.11.0",
480504
py_version="py3",
481505
accelerator_type="ml.eia2.xlarge",
482506
),
483507
pipelinemodel=dict(
508+
instance_type=INSTANCE_TYPE,
484509
models=[
485510
SparkMLModel(
486511
name="MySparkMLModel",
@@ -577,12 +602,17 @@ def _generate_all_pipeline_vars() -> dict:
577602
},
578603
),
579604
)
580-
# A dict to keep the optional arguments which should not be None according to the logic
581-
# specific to the subclass.
605+
# A dict to keep the optional arguments which should not be set to None
606+
# in the test iteration according to the logic specific to the subclass.
582607
PARAMS_SHOULD_NOT_BE_NONE = dict(
583608
estimator=dict(
584609
init=dict(
585-
common={"instance_count", "instance_type"},
610+
# TODO: we should remove the three instance_ parameters here
611+
# For mutually exclusive parameters: instance group
612+
# vs instance count/instance type, if any side is set to None during iteration,
613+
# the other side should get a not None value, instead of listing them here
614+
# and force them to be not None
615+
common={"instance_count", "instance_type", "instance_groups"},
586616
LDA={"mini_batch_size"},
587617
)
588618
),
@@ -692,7 +722,10 @@ def _generate_all_pipeline_vars() -> dict:
692722
),
693723
estimator=dict(
694724
init=dict(
695-
common=dict(),
725+
common=dict(
726+
entry_point={"enable_network_isolation"},
727+
source_dir={"enable_network_isolation"},
728+
),
696729
TensorFlow=dict(
697730
image_uri={"compiler_config"},
698731
compiler_config={"image_uri"},
@@ -701,7 +734,13 @@ def _generate_all_pipeline_vars() -> dict:
701734
image_uri={"compiler_config"},
702735
compiler_config={"image_uri"},
703736
),
704-
)
737+
),
738+
fit=dict(
739+
common=dict(
740+
instance_count={"instance_groups"},
741+
instance_type={"instance_groups"},
742+
),
743+
),
705744
),
706745
)
707746

tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import json
1616

1717
from random import getrandbits
18-
from typing import Optional
18+
from typing import Optional, List
1919
from typing_extensions import get_origin
2020

2121
from sagemaker import Model, PipelineModel, AlgorithmEstimator
@@ -368,14 +368,14 @@ def _verify_composite_object_against_pipeline_var(
368368
self,
369369
param_with_none: str,
370370
step_dsl: str,
371-
step_dsl_obj: object,
371+
step_dsl_obj: List[dict],
372372
):
373373
"""verify pipeline definition regarding composite objects against pipeline variables
374374
375375
Args:
376376
param_with_none (str): The name of the parameter with None value.
377377
step_dsl (str): The step definition retrieved from the pipeline definition DSL.
378-
step_dsl_obj (objet): The json load object of the step definition.
378+
step_dsl_obj (List[dict]): The json load object of the step definition.
379379
"""
380380
# TODO: remove the following hard code assertion once recursive assignment is added
381381
if issubclass(self.clazz, Processor):
@@ -398,6 +398,9 @@ def _verify_composite_object_against_pipeline_var(
398398
assert '{"Get": "Parameters.proc_input_s3_data_type"}' in step_dsl
399399
assert '{"Get": "Parameters.proc_input_app_managed"}' in step_dsl
400400
elif issubclass(self.clazz, EstimatorBase):
401+
if param_with_none != "instance_groups" and self.default_args[CLAZZ_ARGS]["instance_groups"]:
402+
assert '{"Get": "Parameters.instance_group_name"}' in step_dsl
403+
assert '{"Get": "Parameters.instance_group_instance_count"}' in step_dsl
401404
if issubclass(self.clazz, AmazonAlgorithmEstimatorBase):
402405
# AmazonAlgorithmEstimatorBase's input is records
403406
if param_with_none != "records":
@@ -415,6 +418,7 @@ def _verify_composite_object_against_pipeline_var(
415418
assert '{"Get": "Parameters.train_inputs_input_mode"}' in step_dsl
416419
assert '{"Get": "Parameters.train_inputs_attribute_name"}' in step_dsl
417420
assert '{"Get": "Parameters.train_inputs_target_attr_name"}' in step_dsl
421+
assert '{"Get": "Parameters.train_inputs_instance_groups"}' in step_dsl
418422
if not issubclass(self.clazz, (TensorFlow, MXNet, PyTorch, AlgorithmEstimator)):
419423
# debugger_hook_config may be disabled for these first 3 frameworks
420424
# AlgorithmEstimator ignores the kwargs

tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,6 @@ def test_sklearn_estimator_compatibility():
291291
clazz_args=dict(
292292
py_version="py3",
293293
instance_count=1,
294-
instance_type="ml.m5.xlarge",
295294
framework_version="0.20.0",
296295
),
297296
func_args=dict(),

0 commit comments

Comments
 (0)