Skip to content

fix: Correctly interpolate Callback output parameters #2467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 22, 2021
13 changes: 7 additions & 6 deletions src/sagemaker/workflow/callback_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,20 @@ def to_request(self) -> RequestType:
"OutputType": self.output_type.value,
}

@property
def expr(self) -> Dict[str, str]:
"""The 'Get' expression dict for a `Parameter`."""
return CallbackOutput._expr(self.output_name)
def expr(self, step_name) -> Dict[str, str]:
"""The 'Get' expression dict for a `CallbackOutput`."""
return CallbackOutput._expr(self.output_name, step_name)

@classmethod
def _expr(cls, name):
def _expr(cls, name, step_name):
"""An internal classmethod for the 'Get' expression dict for a `CallbackOutput`.

Args:
name (str): The name of the callback output.
step_name (str): The name of the step the callback step associated
with this output belongs to.
"""
return {"Get": f"Steps.{name}.OutputParameters['{name}']"}
return {"Get": f"Steps.{step_name}.OutputParameters['{name}']"}


class CallbackStep(Step):
Expand Down
45 changes: 36 additions & 9 deletions src/sagemaker/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from sagemaker._studio import _append_project_tags
from sagemaker.session import Session
from sagemaker.workflow.callback_step import CallbackOutput
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
from sagemaker.workflow.entities import (
Entity,
Expression,
Expand Down Expand Up @@ -240,9 +240,12 @@ def definition(self) -> str:
"""Converts a request structure to string representation for workflow service calls."""
request_dict = self.to_request()
request_dict["PipelineExperimentConfig"] = interpolate(
request_dict["PipelineExperimentConfig"]
request_dict["PipelineExperimentConfig"], {}
)
callback_output_to_step_map = _map_callback_outputs(self.steps)
request_dict["Steps"] = interpolate(
request_dict["Steps"], callback_output_to_step_map=callback_output_to_step_map
)
request_dict["Steps"] = interpolate(request_dict["Steps"])

return json.dumps(request_dict)

Expand All @@ -263,38 +266,62 @@ def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
return [{"Name": name, "Value": str(value)} for name, value in parameters.items()]


def interpolate(request_obj: RequestType) -> RequestType:
def interpolate(
request_obj: RequestType, callback_output_to_step_map: Dict[str, str]
) -> RequestType:
"""Replaces Parameter values in a list of nested Dict[str, Any] with their workflow expression.

Args:
request_obj (RequestType): The request dict.
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.

Returns:
RequestType: The request dict with Parameter values replaced by their expression.
"""
request_obj_copy = deepcopy(request_obj)
return _interpolate(request_obj_copy)
return _interpolate(request_obj_copy, callback_output_to_step_map=callback_output_to_step_map)


def _interpolate(obj: Union[RequestType, Any]):
def _interpolate(obj: Union[RequestType, Any], callback_output_to_step_map: Dict[str, str]):
"""Walks the nested request dict, replacing Parameter type values with workflow expressions.

Args:
obj (Union[RequestType, Any]): The request dict.
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
"""
if isinstance(obj, (Expression, Parameter, Properties, CallbackOutput)):
if isinstance(obj, (Expression, Parameter, Properties)):
return obj.expr
if isinstance(obj, CallbackOutput):
step_name = callback_output_to_step_map[obj.output_name]
return obj.expr(step_name)
if isinstance(obj, dict):
new = obj.__class__()
for key, value in obj.items():
new[key] = interpolate(value)
new[key] = interpolate(value, callback_output_to_step_map)
elif isinstance(obj, (list, set, tuple)):
new = obj.__class__(interpolate(value) for value in obj)
new = obj.__class__(interpolate(value, callback_output_to_step_map) for value in obj)
else:
return obj
return new


def _map_callback_outputs(steps: List[Step]):
"""Iterate over the provided steps, building a map of callback output parameters to step names.

Args:
step (List[Step]): The steps list.
"""

callback_output_map = {}
for step in steps:
if isinstance(step, CallbackStep):
if step.outputs:
for output in step.outputs:
callback_output_map[output.output_name] = step.name

return callback_output_map


def update_args(args: Dict[str, Any], **kwargs):
"""Updates the request arguments dict with a value, if populated.

Expand Down
41 changes: 41 additions & 0 deletions tests/integ/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,47 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
pass


def test_two_step_callback_pipeline_with_output_reference(
sagemaker_session, role, pipeline_name, region_name
):
instance_count = ParameterInteger(name="InstanceCount", default_value=2)

outputParam1 = CallbackOutput(output_name="output1", output_type=CallbackOutputTypeEnum.String)
step_callback1 = CallbackStep(
name="callback-step1",
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
inputs={"arg1": "foo"},
outputs=[outputParam1],
)

step_callback2 = CallbackStep(
name="callback-step2",
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
inputs={"arg1": outputParam1},
outputs=[],
)

pipeline = Pipeline(
name=pipeline_name,
parameters=[instance_count],
steps=[step_callback1, step_callback2],
sagemaker_session=sagemaker_session,
)

try:
response = pipeline.create(role)
create_arn = response["PipelineArn"]
assert re.match(
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
create_arn,
)
finally:
try:
pipeline.delete()
except Exception:
pass


def test_conditional_pytorch_training_model_registration(
sagemaker_session,
role,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/workflow/test_callback_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_pipeline_interpolates_callback_outputs():
name="MyCallbackStep2",
depends_on=["TestStep"],
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
inputs={"arg1": cb_step1.properties.Outputs["output1"]},
inputs={"arg1": outputParam1},
outputs=[outputParam2],
)

Expand Down