Skip to content

feature: Support model pipelines in CreateModelStep #2845

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions src/sagemaker/workflow/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
TransformInput,
)
from sagemaker.model import Model
from sagemaker.pipeline import PipelineModel
from sagemaker.processing import (
ProcessingInput,
ProcessingJob,
Expand Down Expand Up @@ -319,7 +320,7 @@ class CreateModelStep(ConfigurableRetryStep):
def __init__(
self,
name: str,
model: Model,
model: Union[Model, PipelineModel],
inputs: CreateModelInput,
depends_on: Union[List[str], List[Step]] = None,
retry_policies: List[RetryPolicy] = None,
Expand All @@ -333,7 +334,8 @@ def __init__(

Args:
name (str): The name of the CreateModel step.
model (Model): A `sagemaker.model.Model` instance.
model (Model or PipelineModel): A `sagemaker.model.Model`
or `sagemaker.pipeline.PipelineModel` instance.
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
Defaults to `None`.
depends_on (List[str] or List[Step]): A list of step names or step instances
Expand All @@ -358,16 +360,25 @@ def arguments(self) -> RequestType:
ModelName cannot be included in the arguments.
"""

request_dict = self.model.sagemaker_session._create_model_request(
name="",
role=self.model.role,
container_defs=self.model.prepare_container_def(
instance_type=self.inputs.instance_type,
accelerator_type=self.inputs.accelerator_type,
),
vpc_config=self.model.vpc_config,
enable_network_isolation=self.model.enable_network_isolation(),
)
if isinstance(self.model, PipelineModel):
request_dict = self.model.sagemaker_session._create_model_request(
name="",
role=self.model.role,
container_defs=self.model.pipeline_container_def(self.inputs.instance_type),
vpc_config=self.model.vpc_config,
enable_network_isolation=self.model.enable_network_isolation,
)
else:
request_dict = self.model.sagemaker_session._create_model_request(
name="",
role=self.model.role,
container_defs=self.model.prepare_container_def(
instance_type=self.inputs.instance_type,
accelerator_type=self.inputs.accelerator_type,
),
vpc_config=self.model.vpc_config,
enable_network_isolation=self.model.enable_network_isolation(),
)
request_dict.pop("ModelName")

return request_dict
Expand Down
76 changes: 76 additions & 0 deletions tests/unit/sagemaker/workflow/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
CreateModelStep,
CacheConfig,
)
from sagemaker.pipeline import PipelineModel
from sagemaker.sparkml import SparkMLModel
from sagemaker.predictor import Predictor
from sagemaker.model import FrameworkModel
from tests.unit import DATA_DIR

DUMMY_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
Expand Down Expand Up @@ -89,6 +93,21 @@ def properties(self):
return self._properties


class DummyFrameworkModel(FrameworkModel):
    """Minimal concrete ``FrameworkModel`` used as a fixture in step tests.

    Pins the model artifact, image, role, and entry-point script to fixed
    values so the serialized container definition is deterministic.
    """

    def __init__(self, sagemaker_session, **kwargs):
        # Entry point lives in the shared unit-test data directory.
        entry_point = os.path.join(DATA_DIR, "dummy_script.py")
        super().__init__(
            "s3://bucket/model_1.tar.gz",
            "mi-1",
            ROLE,
            entry_point,
            sagemaker_session=sagemaker_session,
            **kwargs,
        )

    def create_predictor(self, endpoint_name):
        """Return a generic ``Predictor`` bound to *endpoint_name*."""
        return Predictor(endpoint_name, self.sagemaker_session)


@pytest.fixture
def boto_session():
role_mock = Mock()
Expand Down Expand Up @@ -704,6 +723,63 @@ def test_create_model_step(sagemaker_session):
assert step.properties.ModelName.expr == {"Get": "Steps.MyCreateModelStep.ModelName"}


@patch("tarfile.open")
@patch("time.strftime", return_value="2017-10-10-14-14-15")
def test_create_model_step_with_model_pipeline(strftime_mock, tarfile_open_mock, sagemaker_session):
    """CreateModelStep with a PipelineModel serializes a multi-container request.

    NOTE: ``mock.patch`` decorators inject mocks bottom-up, so the first
    parameter is the ``time.strftime`` mock and the second is the
    ``tarfile.open`` mock (the previous names ``tfo, time`` were bound to
    the wrong mocks, and ``time`` shadowed the module name). ``strftime``
    is frozen so the repacked source-dir S3 key below is deterministic.
    """
    framework_model = DummyFrameworkModel(sagemaker_session)
    sparkml_model = SparkMLModel(
        model_data="s3://bucket/model_2.tar.gz",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
    )
    model = PipelineModel(
        models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session
    )
    inputs = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    step = CreateModelStep(
        name="MyCreateModelStep",
        depends_on=["TestStep"],
        display_name="MyCreateModelStep",
        description="TestDescription",
        model=model,
        inputs=inputs,
    )
    # Dependencies added after construction must be merged with the ones
    # supplied via the constructor.
    step.add_depends_on(["SecondTestStep"])

    assert step.to_request() == {
        "Name": "MyCreateModelStep",
        "Type": "Model",
        "Description": "TestDescription",
        "DisplayName": "MyCreateModelStep",
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            # One container definition per model, in pipeline order.
            "Containers": [
                {
                    "Environment": {
                        "SAGEMAKER_PROGRAM": "dummy_script.py",
                        "SAGEMAKER_SUBMIT_DIRECTORY": "s3://my-bucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz",
                        "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
                        "SAGEMAKER_REGION": "us-west-2",
                    },
                    "Image": "mi-1",
                    "ModelDataUrl": "s3://bucket/model_1.tar.gz",
                },
                {
                    "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
                    "Image": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4",
                    "ModelDataUrl": "s3://bucket/model_2.tar.gz",
                },
            ],
            "ExecutionRoleArn": "DummyRole",
        },
    }
    # ModelName is only known at execution time, so it is exposed as a
    # pipeline property reference rather than in the request arguments.
    assert step.properties.ModelName.expr == {"Get": "Steps.MyCreateModelStep.ModelName"}


def test_transform_step(sagemaker_session):
transformer = Transformer(
model_name=MODEL_NAME,
Expand Down