19
19
20
20
from sagemaker import ModelMetrics , MetricsSource , FileSource , Predictor
21
21
from sagemaker .drift_check_baselines import DriftCheckBaselines
22
+ from sagemaker .instance_group import InstanceGroup
22
23
from sagemaker .metadata_properties import MetadataProperties
23
24
from sagemaker .model import FrameworkModel
24
25
from sagemaker .parameter import IntegerParameter
@@ -233,14 +234,17 @@ def _generate_all_pipeline_vars() -> dict:
233
234
)
234
235
235
236
237
+ # TODO: remove _IS_TRUE_TMP and replace its usages with IS_TRUE.
238
+ # Because `instance_groups` currently does not work well with some estimator subclasses,
239
+ # we temporarily hard-code it to False, which disables instance_groups.
240
+ _IS_TRUE_TMP = False
236
241
IS_TRUE = bool (getrandbits (1 ))
237
242
PIPELINE_SESSION = _generate_mock_pipeline_session ()
238
243
PIPELINE_VARIABLES = _generate_all_pipeline_vars ()
239
244
240
245
# TODO: need to recursively assign with Pipeline Variable in later changes
241
246
FIXED_ARGUMENTS = dict (
242
247
common = dict (
243
- instance_type = INSTANCE_TYPE ,
244
248
role = ROLE ,
245
249
sagemaker_session = PIPELINE_SESSION ,
246
250
source_dir = f"s3://{ BUCKET } /source" ,
@@ -281,6 +285,7 @@ def _generate_all_pipeline_vars() -> dict:
281
285
response_types = ["application/json" ],
282
286
),
283
287
processor = dict (
288
+ instance_type = INSTANCE_TYPE ,
284
289
estimator_cls = PyTorch ,
285
290
code = f"s3://{ BUCKET } /code" ,
286
291
spark_event_logs_s3_uri = f"s3://{ BUCKET } /my-spark-output-path" ,
@@ -438,13 +443,31 @@ def _generate_all_pipeline_vars() -> dict:
438
443
input_mode = ParameterString (name = "train_inputs_input_mode" ),
439
444
attribute_names = [ParameterString (name = "train_inputs_attribute_name" )],
440
445
target_attribute_name = ParameterString (name = "train_inputs_target_attr_name" ),
446
+ instance_groups = [ParameterString (name = "train_inputs_instance_groups" )],
441
447
),
442
448
},
449
+ instance_groups = [
450
+ InstanceGroup (
451
+ instance_group_name = ParameterString (name = "instance_group_name" ),
452
+ # Hard-code the instance_type here because InstanceGroup.instance_type
453
+ # would be used to retrieve image_uri if image_uri is not present,
454
+ # and the test mechanism currently does not support skipping the test case
455
+ # relating to bonded parameters in composite variables (i.e. the InstanceGroup).
456
+ # TODO: we should support skipping tests on bonded parameters in composite vars
457
+ instance_type = "ml.m5.xlarge" ,
458
+ instance_count = ParameterString (name = "instance_group_instance_count" ),
459
+ ),
460
+ ] if _IS_TRUE_TMP else None ,
461
+ instance_type = "ml.m5.xlarge" if not _IS_TRUE_TMP else None ,
462
+ instance_count = 1 if not _IS_TRUE_TMP else None ,
463
+ distribution = {} if not _IS_TRUE_TMP else None ,
443
464
),
444
465
transformer = dict (
466
+ instance_type = INSTANCE_TYPE ,
445
467
data = f"s3://{ BUCKET } /data" ,
446
468
),
447
469
tuner = dict (
470
+ instance_type = INSTANCE_TYPE ,
448
471
estimator = TensorFlow (
449
472
entry_point = TENSORFLOW_ENTRY_POINT ,
450
473
role = ROLE ,
@@ -475,12 +498,14 @@ def _generate_all_pipeline_vars() -> dict:
475
498
include_cls_metadata = {"estimator-1" : IS_TRUE },
476
499
),
477
500
model = dict (
501
+ instance_type = INSTANCE_TYPE ,
478
502
serverless_inference_config = ServerlessInferenceConfig (),
479
503
framework_version = "1.11.0" ,
480
504
py_version = "py3" ,
481
505
accelerator_type = "ml.eia2.xlarge" ,
482
506
),
483
507
pipelinemodel = dict (
508
+ instance_type = INSTANCE_TYPE ,
484
509
models = [
485
510
SparkMLModel (
486
511
name = "MySparkMLModel" ,
@@ -577,12 +602,17 @@ def _generate_all_pipeline_vars() -> dict:
577
602
},
578
603
),
579
604
)
580
- # A dict to keep the optional arguments which should not be None according to the logic
581
- # specific to the subclass.
605
+ # A dict to keep the optional arguments which should not be set to None
606
+ # in the test iteration according to the logic specific to the subclass.
582
607
PARAMS_SHOULD_NOT_BE_NONE = dict (
583
608
estimator = dict (
584
609
init = dict (
585
- common = {"instance_count" , "instance_type" },
610
+ # TODO: we should remove the three instance_ parameters here.
611
+ # For mutually exclusive parameters (instance groups
612
+ # vs. instance count/instance type), if either side is set to None during iteration,
613
+ # the other side should get a non-None value, instead of listing them here
614
+ # and forcing them to be non-None.
615
+ common = {"instance_count" , "instance_type" , "instance_groups" },
586
616
LDA = {"mini_batch_size" },
587
617
)
588
618
),
@@ -692,7 +722,10 @@ def _generate_all_pipeline_vars() -> dict:
692
722
),
693
723
estimator = dict (
694
724
init = dict (
695
- common = dict (),
725
+ common = dict (
726
+ entry_point = {"enable_network_isolation" },
727
+ source_dir = {"enable_network_isolation" },
728
+ ),
696
729
TensorFlow = dict (
697
730
image_uri = {"compiler_config" },
698
731
compiler_config = {"image_uri" },
@@ -701,7 +734,13 @@ def _generate_all_pipeline_vars() -> dict:
701
734
image_uri = {"compiler_config" },
702
735
compiler_config = {"image_uri" },
703
736
),
704
- )
737
+ ),
738
+ fit = dict (
739
+ common = dict (
740
+ instance_count = {"instance_groups" },
741
+ instance_type = {"instance_groups" },
742
+ ),
743
+ ),
705
744
),
706
745
)
707
746
0 commit comments