
fix: use workflow parameters in training hyperparameters (#2114) (#2115) #2227

Merged: 4 commits, Mar 25, 2021
8 changes: 7 additions & 1 deletion src/sagemaker/estimator.py
@@ -52,6 +52,9 @@
_region_supports_profiler,
get_mp_parameters,
)
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.entities import Expression
from sagemaker.inputs import TrainingInput
from sagemaker.job import _Job
from sagemaker.local import LocalSession
@@ -1456,7 +1459,10 @@ def _get_train_args(cls, estimator, inputs, experiment_config):

current_hyperparameters = estimator.hyperparameters()
if current_hyperparameters is not None:
hyperparameters = {str(k): str(v) for (k, v) in current_hyperparameters.items()}
hyperparameters = {
str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v))
for (k, v) in current_hyperparameters.items()
}

train_args = config.copy()
train_args["input_mode"] = estimator.input_mode
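For context, a minimal usage sketch of what the estimator.py change above enables (the image URI, role ARN, and bucket below are placeholders, not values from this PR): a workflow Parameter passed as a hyperparameter is now forwarded into the training request as-is rather than coerced with str(), so the pipeline definition can later interpolate it.

from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput
from sagemaker.workflow.parameters import ParameterInteger
from sagemaker.workflow.steps import TrainingStep

epochs = ParameterInteger(name="TrainingEpochs", default_value=5)

estimator = Estimator(
    image_uri="my-training-image:latest",  # placeholder image URI
    role="arn:aws:iam::123456789012:role/MyRole",  # placeholder role ARN
    instance_count=1,
    instance_type="ml.m5.xlarge",
    hyperparameters={"epochs": epochs},  # Parameter object is preserved, not str()-ed
)

step = TrainingStep(
    name="MyTrainingStep",
    estimator=estimator,
    inputs=TrainingInput(s3_data="s3://my-bucket/train"),  # placeholder bucket
)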
10 changes: 6 additions & 4 deletions src/sagemaker/processing.py
@@ -32,6 +32,7 @@
from sagemaker.session import Session
from sagemaker.network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.entities import Expression
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
from sagemaker.apiutils._base_types import ApiObject
@@ -292,7 +293,9 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
if isinstance(file_input.source, Properties) or file_input.dataset_definition:
normalized_inputs.append(file_input)
continue

if isinstance(file_input.s3_input.s3_uri, (Parameter, Expression, Properties)):
normalized_inputs.append(file_input)
continue
# If the source is a local path, upload it to S3
# and save the S3 uri in the ProcessingInput source.
parse_result = urlparse(file_input.s3_input.s3_uri)
@@ -340,8 +343,7 @@ def _normalize_outputs(self, outputs=None):
# Generate a name for the ProcessingOutput if it doesn't have one.
if output.output_name is None:
output.output_name = "output-{}".format(count)
# if the output's destination is a workflow expression, do no normalization
if isinstance(output.destination, Expression):
if isinstance(output.destination, (Parameter, Expression, Properties)):
normalized_outputs.append(output)
continue
# If the output's destination is not an s3_uri, create one.
@@ -1099,7 +1101,7 @@ def _create_s3_input(self):
self.s3_data_type = self.s3_input.s3_data_type
self.s3_input_mode = self.s3_input.s3_input_mode
self.s3_data_distribution_type = self.s3_input.s3_data_distribution_type
elif self.source and self.destination:
elif self.source is not None and self.destination is not None:
self.s3_input = S3Input(
s3_uri=self.source,
local_path=self.destination,
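Similarly, a hedged sketch of the processing-side effect of this change (image URI, role ARN, and bucket are placeholders): a ParameterString used as a ProcessingInput source or a ProcessingOutput destination now passes through _normalize_inputs()/_normalize_outputs() untouched, instead of being parsed as an S3 URI or treated as a local path to upload.

from sagemaker.processing import Processor, ProcessingInput, ProcessingOutput
from sagemaker.workflow.parameters import ParameterString

input_data = ParameterString(
    name="ProcessingInputDataUri",
    default_value="s3://my-bucket/processing_manifest",  # placeholder default
)
output_dest = ParameterString(
    name="ProcessingOutputUri",
    default_value="s3://my-bucket/processing_output",  # placeholder default
)

processor = Processor(
    image_uri="my-processing-image:latest",  # placeholder image URI
    role="arn:aws:iam::123456789012:role/MyRole",  # placeholder role ARN
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# With this PR, both workflow Parameter values below are appended to the
# normalized inputs/outputs unchanged.
inputs = [ProcessingInput(source=input_data, destination="/opt/ml/processing/input")]
outputs = [ProcessingOutput(source="/opt/ml/processing/output", destination=output_dest)]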
2 changes: 1 addition & 1 deletion src/sagemaker/workflow/pipeline.py
@@ -83,7 +83,7 @@ def create(

Args:
role_arn (str): The role arn that is assumed by the pipeline to create step artifacts.
pipeline_description (str): A description of the pipeline.
description (str): A description of the pipeline.
experiment_name (str): The name of the experiment.
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
tags.
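For reference, a rough call sketch matching the corrected docstring (the role ARN is a placeholder): the keyword argument is description, which the old docstring mislabeled as pipeline_description.

from sagemaker.workflow.pipeline import Pipeline

pipeline = Pipeline(name="MyPipeline", parameters=[], steps=[])
pipeline.create(
    role_arn="arn:aws:iam::123456789012:role/MyPipelineRole",  # placeholder ARN
    description="A description of the pipeline.",
)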
45 changes: 33 additions & 12 deletions tests/unit/sagemaker/workflow/test_steps.py
@@ -35,6 +35,7 @@
from sagemaker.network import NetworkConfig
from sagemaker.transformer import Transformer
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
from sagemaker.workflow.steps import (
ProcessingStep,
Step,
@@ -112,16 +113,27 @@ def test_custom_step():


def test_training_step(sagemaker_session):
instance_type_parameter = ParameterString(name="InstanceType", default_value="c4.4xlarge")
instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
data_source_uri_parameter = ParameterString(
name="DataSourceS3Uri", default_value=f"s3://{BUCKET}/train_manifest"
)
training_epochs_parameter = ParameterInteger(name="TrainingEpochs", default_value=5)
training_batch_size_parameter = ParameterInteger(name="TrainingBatchSize", default_value=500)
estimator = Estimator(
image_uri=IMAGE_URI,
role=ROLE,
instance_count=1,
instance_type="c4.4xlarge",
instance_count=instance_count_parameter,
instance_type=instance_type_parameter,
profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
hyperparameters={
"batch-size": training_batch_size_parameter,
"epochs": training_epochs_parameter,
},
rules=[],
sagemaker_session=sagemaker_session,
)
inputs = TrainingInput(f"s3://{BUCKET}/train_manifest")
inputs = TrainingInput(s3_data=data_source_uri_parameter)
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
step = TrainingStep(
name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
@@ -131,22 +143,26 @@
"Type": "Training",
"Arguments": {
"AlgorithmSpecification": {"TrainingImage": IMAGE_URI, "TrainingInputMode": "File"},
"HyperParameters": {
"batch-size": training_batch_size_parameter,
[Review thread on this line]

Reviewer (Contributor):
I wouldn't expect to see objects in this assertion. Does it work if asserted on {"Get": "Parameters.TrainingBatchSize"}?

@nmadan (Member, Author), Mar 23, 2021:
No, that won't work. pipeline.to_request() won't interpolate parameters like that, but pipeline.definition() would. Here's the code that does that:
https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/workflow/pipeline.py#L248-L258

We have older unit tests that perform similar assertions:
https://github.com/aws/sagemaker-python-sdk/blob/master/tests/unit/sagemaker/workflow/test_pipeline.py#L187

Let me know if you think I should change the pipeline.to_request() logic to return parameter expressions instead of objects.

Reviewer (Contributor):
That's fine, I just wanted to make sure we have coverage on the conversion of those Parameter objects to JSON; looks like that's in test_pipeline.py.

(A short sketch of the to_request() vs. definition() distinction follows the last file in this diff.)
"epochs": training_epochs_parameter,
},
"InputDataConfig": [
{
"ChannelName": "training",
"DataSource": {
"S3DataSource": {
"S3DataDistributionType": "FullyReplicated",
"S3DataType": "S3Prefix",
"S3Uri": f"s3://{BUCKET}/train_manifest",
"S3Uri": data_source_uri_parameter,
}
},
}
],
"OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
"ResourceConfig": {
"InstanceCount": 1,
"InstanceType": "c4.4xlarge",
"InstanceCount": instance_count_parameter,
"InstanceType": instance_type_parameter,
"VolumeSizeInGB": 30,
},
"RoleArn": ROLE,
@@ -162,16 +178,21 @@


def test_processing_step(sagemaker_session):
processing_input_data_uri_parameter = ParameterString(
name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest"
)
instance_type_parameter = ParameterString(name="InstanceType", default_value="ml.m4.4xlarge")
instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
processor = Processor(
image_uri=IMAGE_URI,
role=ROLE,
instance_count=1,
instance_type="ml.m4.4xlarge",
instance_count=instance_count_parameter,
instance_type=instance_type_parameter,
sagemaker_session=sagemaker_session,
)
inputs = [
ProcessingInput(
source=f"s3://{BUCKET}/processing_manifest",
source=processing_input_data_uri_parameter,
destination="processing_manifest",
)
]
@@ -198,14 +219,14 @@ def test_processing_step(sagemaker_session):
"S3DataDistributionType": "FullyReplicated",
"S3DataType": "S3Prefix",
"S3InputMode": "File",
"S3Uri": "s3://my-bucket/processing_manifest",
"S3Uri": processing_input_data_uri_parameter,
},
}
],
"ProcessingResources": {
"ClusterConfig": {
"InstanceCount": 1,
"InstanceType": "ml.m4.4xlarge",
"InstanceCount": instance_count_parameter,
"InstanceType": instance_type_parameter,
"VolumeSizeInGB": 30,
}
},
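To make the to_request() vs. definition() distinction from the review thread concrete, a small hedged sketch (the parameter name is taken from the test above; the expr property is assumed from the SDK's workflow entities): to_request() keeps Parameter objects in the request dict, which is why these assertions compare against the objects themselves, while definition() serializes the request to JSON using each parameter's expression.

from sagemaker.workflow.parameters import ParameterInteger

batch_size = ParameterInteger(name="TrainingBatchSize", default_value=500)

# pipeline.to_request() leaves this object in place in the request dict,
# so the unit tests above assert on the ParameterInteger itself.
# pipeline.definition() serializes the request to JSON, replacing each
# parameter with its expression:
print(batch_size.expr)  # {'Get': 'Parameters.TrainingBatchSize'}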