Skip to content

Commit d8b4042

Browse files
metrizableDan Choi
authored andcommitted
fix: construct default session when not present (aws#463)
1 parent 43a9b28 commit d8b4042

File tree

10 files changed

+95
-40
lines changed

10 files changed

+95
-40
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ def __init__(
215215
transform_instances,
216216
model_package_group_name=None,
217217
image_uri=None,
218+
model_metrics=None,
219+
approval_status="PendingManualApproval",
218220
compile_model_family=None,
219221
**kwargs,
220222
):
@@ -236,6 +238,9 @@ def __init__(
236238
versioned (default: None).
237239
image_uri (str): The container image uri for Model Package, if not specified,
238240
Estimator's training container image will be used (default: None).
241+
model_metrics (ModelMetrics): ModelMetrics object (default: None).
242+
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
243+
or "PendingManualApproval" (default: "PendingManualApproval").
239244
compile_model_family (str): Instance family for compiled model, if specified, a compiled
240245
model will be used (default: None).
241246
**kwargs: additional arguments to `create_model`.
@@ -249,6 +254,8 @@ def __init__(
249254
self.transform_instances = transform_instances
250255
self.model_package_group_name = model_package_group_name
251256
self.image_uri = image_uri
257+
self.model_metrics = model_metrics
258+
self.approval_status = approval_status
252259
self.compile_model_family = compile_model_family
253260
self.kwargs = kwargs
254261

@@ -300,6 +307,8 @@ def arguments(self) -> RequestType:
300307
inference_instances=self.inference_instances,
301308
transform_instances=self.transform_instances,
302309
model_package_group_name=self.model_package_group_name,
310+
model_metrics=self.model_metrics,
311+
approval_status=self.approval_status,
303312
)
304313
request_dict = model.sagemaker_session._get_create_model_package_request(
305314
**model_package_args
@@ -309,8 +318,6 @@ def arguments(self) -> RequestType:
309318
request_dict.pop("CertifyForMarketplace")
310319
if "Description" in request_dict:
311320
request_dict.pop("Description")
312-
if "ModelApprovalStatus" in request_dict:
313-
request_dict.pop("ModelApprovalStatus")
314321

315322
return request_dict
316323

src/sagemaker/workflow/pipeline.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ class Pipeline(Entity):
4949
`else_steps` of any `ConditionStep`. In particular, any steps that are within the
5050
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
5151
pipeline.
52-
52+
sagemaker_session (sagemaker.session.Session): Session object which manages interactions
53+
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
54+
pipeline creates one using the default AWS configuration chain.
5355
"""
5456

5557
name: str = attr.ib(factory=str)
5658
parameters: Sequence[Parameter] = attr.ib(factory=list)
5759
steps: Sequence[Union[Step, StepCollection]] = attr.ib(factory=list)
58-
sagemaker_session: Session = attr.ib(default=Session)
60+
sagemaker_session: Session = attr.ib(factory=Session)
5961

6062
_version: str = "2020-12-01"
6163
_metadata: Dict[str, Any] = dict()
@@ -180,7 +182,7 @@ def start(
180182
self,
181183
parameters: Dict[str, Any] = None,
182184
execution_description: str = None,
183-
) -> Dict[str, Any]:
185+
):
184186
"""Starts a Pipeline execution in the Workflow service.
185187
186188
Args:
@@ -189,7 +191,7 @@ def start(
189191
execution_description (str): Description of the execution.
190192
191193
Returns:
192-
Response dict from service.
194+
A `_PipelineExecution` instance, if successful.
193195
"""
194196
exists = True
195197
try:
@@ -286,10 +288,17 @@ def update_args(args: Dict[str, Any], **kwargs):
286288

287289
@attr.s
288290
class _PipelineExecution:
289-
"""Internal class for encapsulating pipeline execution instances"""
291+
"""Internal class for encapsulating pipeline execution instances.
292+
293+
Attributes:
294+
arn (str): The arn of the pipeline exeuction.
295+
sagemaker_session (sagemaker.session.Session): Session object which manages interactions
296+
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
297+
pipeline creates one using the default AWS configuration chain.
298+
"""
290299

291300
arn: str = attr.ib()
292-
sagemaker_session: Session = attr.ib(default=Session)
301+
sagemaker_session: Session = attr.ib(factory=Session)
293302

294303
def stop(self):
295304
"""Stops a pipeline execution."""

src/sagemaker/workflow/properties.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The properties definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import Union
16+
from typing import Dict, Union
1717

1818
import attr
1919

@@ -79,7 +79,7 @@ def __init__(self, path: str, shape_name: str = None):
7979
shape_name (str): botocore sagemaker service model shape name.
8080
"""
8181
super(PropertiesList, self).__init__(path, shape_name)
82-
self._items = dict()
82+
self._items: Dict[Union[int, str], Properties] = dict()
8383

8484
def __getitem__(self, item: Union[int, str]):
8585
"""Populate the indexing item with a Property, working for both list and dictionary.
@@ -114,7 +114,7 @@ class PropertyFile(Expression):
114114
path: str = attr.ib()
115115

116116
@property
117-
def expr(self):
117+
def expr(self) -> Dict[str, str]:
118118
"""The expression dict for a `PropertyFile`."""
119119
return {
120120
"PropertyFileName": self.name,

src/sagemaker/workflow/steps.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def __init__(
226226
self,
227227
name: str,
228228
transformer: Transformer,
229-
inputs: TransformInput = None,
229+
inputs: TransformInput,
230230
):
231231
"""Constructs a TrainingStep, given an `Transformer` instance.
232232
@@ -237,7 +237,6 @@ def __init__(
237237
name (str): The name of the transform step.
238238
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
239239
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
240-
Defaults to `None`.
241240
"""
242241
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM)
243242
self.transformer = transformer

tests/integ/test_artifact_analytics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def generate_artifacts(sagemaker_session):
4545

4646

4747
@pytest.mark.canary_quick
48+
@pytest.mark.skip("Failing as restricted to the SageMaker/Pipeline runtimes")
4849
def test_artifact_analytics(sagemaker_session):
4950
with generate_artifacts(sagemaker_session):
5051
analytics = ArtifactAnalytics(

tests/integ/test_workflow.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pytest
2222

2323
from botocore.config import Config
24+
from botocore.exceptions import WaiterError
2425
from sagemaker.inputs import TrainingInput
2526
from sagemaker.processing import ProcessingInput, ProcessingOutput
2627
from sagemaker.pytorch.estimator import PyTorch
@@ -263,6 +264,10 @@ def test_one_step_sklearn_processing_pipeline(
263264
response = execution.describe()
264265
assert response["PipelineArn"] == create_arn
265266

267+
try:
268+
execution.wait(delay=5, max_attempts=6)
269+
except WaiterError:
270+
pass
266271
execution_steps = execution.list_steps()
267272
assert len(execution_steps) == 1
268273
assert execution_steps[0]["StepName"] == "sklearn-process"

tests/unit/sagemaker/workflow/__init__.py

Whitespace-only changes.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
# language governing permissions and limitations under the License.
14+
"""Helper methods for testing."""
15+
from __future__ import absolute_import
16+
17+
18+
def ordered(obj):
19+
"""Helper function for dict comparison.
20+
21+
Recursively orders a json-like dict or list of dicts.
22+
23+
Args:
24+
obj: either a list or a dict
25+
26+
Returns:
27+
either a sorted list of elements or sorted list of tuples
28+
"""
29+
if isinstance(obj, dict):
30+
return sorted((k, ordered(v)) for k, v in obj.items())
31+
if isinstance(obj, list):
32+
return sorted(ordered(x) for x in obj)
33+
else:
34+
return obj

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Step,
2929
StepTypeEnum,
3030
)
31+
from tests.unit.sagemaker.workflow.helpers import ordered
3132

3233

3334
class CustomStep(Step):
@@ -49,16 +50,6 @@ def properties(self):
4950
return self._properties
5051

5152

52-
def ordered(obj):
53-
"""Helper function for dict comparison"""
54-
if isinstance(obj, dict):
55-
return sorted((k, ordered(v)) for k, v in obj.items())
56-
if isinstance(obj, list):
57-
return sorted(ordered(x) for x in obj)
58-
else:
59-
return obj
60-
61-
6253
@pytest.fixture
6354
def role_arn():
6455
return "arn:role"
@@ -211,11 +202,16 @@ def test_pipeline_basic():
211202
)
212203

213204

214-
def test_pipeline_two_step():
205+
def test_pipeline_two_step(sagemaker_session_mock):
215206
parameter = ParameterString("MyStr")
216207
step1 = CustomStep(name="MyStep1", input_data=parameter)
217208
step2 = CustomStep(name="MyStep2", input_data=step1.properties.S3Uri)
218-
pipeline = Pipeline(name="MyPipeline", parameters=[parameter], steps=[step1, step2])
209+
pipeline = Pipeline(
210+
name="MyPipeline",
211+
parameters=[parameter],
212+
steps=[step1, step2],
213+
sagemaker_session=sagemaker_session_mock,
214+
)
219215
assert pipeline.to_request() == {
220216
"Version": "2020-12-01",
221217
"Metadata": {},

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
StepCollection,
3333
RegisterModel,
3434
)
35+
from tests.unit.sagemaker.workflow.helpers import ordered
3536

3637
REGION = "us-west-2"
3738
BUCKET = "my-bucket"
@@ -124,20 +125,23 @@ def test_register_model(estimator):
124125
inference_instances=["inference_instance"],
125126
transform_instances=["transform_instance"],
126127
)
127-
assert register_model.request_dicts() == [
128-
{
129-
"Name": "RegisterModelStep",
130-
"Type": "RegisterModel",
131-
"Arguments": {
132-
"InferenceSpecification": {
133-
"Containers": [
134-
{"Image": "fakeimage", "ModelDataUrl": "s3://my-bucket/model.tar.gz"}
135-
],
136-
"SupportedContentTypes": ["content_type"],
137-
"SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
138-
"SupportedResponseMIMETypes": ["response_type"],
139-
"SupportedTransformInstanceTypes": ["transform_instance"],
128+
assert ordered(register_model.request_dicts()) == ordered(
129+
[
130+
{
131+
"Name": "RegisterModelStep",
132+
"Type": "RegisterModel",
133+
"Arguments": {
134+
"InferenceSpecification": {
135+
"Containers": [
136+
{"Image": "fakeimage", "ModelDataUrl": "s3://my-bucket/model.tar.gz"}
137+
],
138+
"SupportedContentTypes": ["content_type"],
139+
"SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
140+
"SupportedResponseMIMETypes": ["response_type"],
141+
"SupportedTransformInstanceTypes": ["transform_instance"],
142+
},
143+
"ModelApprovalStatus": "PendingManualApproval",
140144
},
141145
},
142-
},
143-
]
146+
]
147+
)

0 commit comments

Comments
 (0)