Skip to content

Commit 12d669a

Browse files
authored
Merge branch 'master' into framework-versioning
2 parents 2f35471 + 467e157 commit 12d669a

File tree

2 files changed

+75
-142
lines changed

2 files changed

+75
-142
lines changed

doc/frameworks/mxnet/using_mxnet.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ It loads the model parameters from a ``model.params`` file in the SageMaker mode
377377
return net
378378
379379
MXNet on Amazon SageMaker has support for `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`__, which allows for inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance.
380-
In order to load and serve your MXNet model through Amazon Elastic Inference, the MXNet context passed to your MXNet Symbol or Module object within your ``model_fn`` needs to be set to ``eia``, as shown `here <https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-mxnet-elastic-inference.html#ei-mxnet>`__.
380+
In order to load and serve your MXNet model through Amazon Elastic Inference, import the ``eimx`` Python package and make one change in the code to partition your model and optimize it for the ``EIA`` back end, as shown `here <https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-mxnet-elastic-inference.html#ei-mxnet>`__.
381381

382382
Based on the example above, the following code-snippet shows an example custom ``model_fn`` implementation, which enables loading and serving our MXNet model through Amazon Elastic Inference.
383383

@@ -392,11 +392,12 @@ Based on the example above, the following code-snippet shows an example custom `
392392
Returns:
393393
mxnet.gluon.nn.Block: a Gluon network (for this example)
394394
"""
395-
net = models.get_model('resnet34_v2', ctx=mx.eia(), pretrained=False, classes=10)
396-
net.load_params('%s/model.params' % model_dir, ctx=mx.eia())
395+
net = models.get_model('resnet34_v2', ctx=mx.cpu(), pretrained=False, classes=10)
396+
net.load_params('%s/model.params' % model_dir, ctx=mx.cpu())
397+
net.hybridize(backend='EIA', static_alloc=True, static_shape=True)
397398
return net
398399
399-
The `default_model_fn <https://github.com/aws/sagemaker-mxnet-container/pull/55/files#diff-aabf018d906ed282a3c738377d19a8deR71>`__ loads and serve your model through Elastic Inference, if applicable, within the Amazon SageMaker MXNet containers.
400+
If you are using MXNet 1.5.1 and earlier, the `default_model_fn <https://github.com/aws/sagemaker-mxnet-container/pull/55/files#diff-aabf018d906ed282a3c738377d19a8deR71>`__ loads and serve your model through Elastic Inference, if applicable, within the Amazon SageMaker MXNet containers.
400401

401402
For more information on how to enable MXNet to interact with Amazon Elastic Inference, see `Use Elastic Inference with MXNet <https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-mxnet-elastic-inference.html>`__.
402403

tests/unit/sagemaker/lineage/test_visualizer.py

Lines changed: 70 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -49,64 +49,17 @@ def test_trial_component_name(viz, sagemaker_session):
4949
"TrialComponentArn": "tc-arn",
5050
}
5151

52-
sagemaker_session.sagemaker_client.list_associations.side_effect = [
53-
{
54-
"AssociationSummaries": [
55-
{
56-
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
57-
"SourceName": "source-name-1",
58-
"SourceType": "source-type-1",
59-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
60-
"DestinationName": "dest-name-1",
61-
"DestinationType": "dest-type-1",
62-
"AssociationType": "type-1",
63-
}
64-
]
65-
},
66-
{
67-
"AssociationSummaries": [
68-
{
69-
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
70-
"SourceName": "source-name-2",
71-
"SourceType": "source-type-2",
72-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
73-
"DestinationName": "dest-name-2",
74-
"DestinationType": "dest-type-2",
75-
"AssociationType": "type-2",
76-
}
77-
]
78-
},
79-
]
52+
get_list_associations_side_effect(sagemaker_session)
8053

8154
df = viz.show(trial_component_name=name)
8255

8356
sagemaker_session.sagemaker_client.describe_trial_component.assert_called_with(
8457
TrialComponentName=name,
8558
)
8659

87-
expected_calls = [
88-
unittest.mock.call(
89-
DestinationArn="tc-arn",
90-
),
91-
unittest.mock.call(
92-
SourceArn="tc-arn",
93-
),
94-
]
95-
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
60+
assert_list_associations_mock_calls(sagemaker_session)
9661

97-
expected_dataframe = pd.DataFrame.from_dict(
98-
OrderedDict(
99-
[
100-
("Name/Source", ["source-name-1", "dest-name-2"]),
101-
("Direction", ["Input", "Output"]),
102-
("Type", ["source-type-1", "dest-type-2"]),
103-
("Association Type", ["type-1", "type-2"]),
104-
("Lineage Type", ["artifact", "artifact"]),
105-
]
106-
)
107-
)
108-
109-
pd.testing.assert_frame_equal(expected_dataframe, df)
62+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
11063

11164

11265
def test_model_package_arn(viz, sagemaker_session):
@@ -116,34 +69,7 @@ def test_model_package_arn(viz, sagemaker_session):
11669
"ArtifactSummaries": [{"ArtifactArn": "artifact-arn"}]
11770
}
11871

119-
sagemaker_session.sagemaker_client.list_associations.side_effect = [
120-
{
121-
"AssociationSummaries": [
122-
{
123-
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
124-
"SourceName": "source-name-1",
125-
"SourceType": "source-type-1",
126-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
127-
"DestinationName": "dest-name-1",
128-
"DestinationType": "dest-type-1",
129-
"AssociationType": "type-1",
130-
}
131-
]
132-
},
133-
{
134-
"AssociationSummaries": [
135-
{
136-
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
137-
"SourceName": "source-name-2",
138-
"SourceType": "source-type-2",
139-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
140-
"DestinationName": "dest-name-2",
141-
"DestinationType": "dest-type-2",
142-
"AssociationType": "type-2",
143-
}
144-
]
145-
},
146-
]
72+
get_list_associations_side_effect(sagemaker_session)
14773

14874
df = viz.show(model_package_arn=name)
14975

@@ -161,19 +87,7 @@ def test_model_package_arn(viz, sagemaker_session):
16187
]
16288
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
16389

164-
expected_dataframe = pd.DataFrame.from_dict(
165-
OrderedDict(
166-
[
167-
("Name/Source", ["source-name-1", "dest-name-2"]),
168-
("Direction", ["Input", "Output"]),
169-
("Type", ["source-type-1", "dest-type-2"]),
170-
("Association Type", ["type-1", "type-2"]),
171-
("Lineage Type", ["artifact", "artifact"]),
172-
]
173-
)
174-
)
175-
176-
pd.testing.assert_frame_equal(expected_dataframe, df)
90+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
17791

17892

17993
def test_endpoint_arn(viz, sagemaker_session):
@@ -183,34 +97,7 @@ def test_endpoint_arn(viz, sagemaker_session):
18397
"ContextSummaries": [{"ContextArn": "context-arn"}]
18498
}
18599

186-
sagemaker_session.sagemaker_client.list_associations.side_effect = [
187-
{
188-
"AssociationSummaries": [
189-
{
190-
"SourceArn": "a:b:c:d:e:context/src-arn-1",
191-
"SourceName": "source-name-1",
192-
"SourceType": "source-type-1",
193-
"DestinationArn": "a:b:c:d:e:context/dest-arn-1",
194-
"DestinationName": "dest-name-1",
195-
"DestinationType": "dest-type-1",
196-
"AssociationType": "type-1",
197-
}
198-
]
199-
},
200-
{
201-
"AssociationSummaries": [
202-
{
203-
"SourceArn": "a:b:c:d:e:context/src-arn-2",
204-
"SourceName": "source-name-2",
205-
"SourceType": "source-type-2",
206-
"DestinationArn": "a:b:c:d:e:context/dest-arn-2",
207-
"DestinationName": "dest-name-2",
208-
"DestinationType": "dest-type-2",
209-
"AssociationType": "type-2",
210-
}
211-
]
212-
},
213-
]
100+
get_list_associations_side_effect(sagemaker_session)
214101

215102
df = viz.show(endpoint_arn=name)
216103

@@ -228,27 +115,74 @@ def test_endpoint_arn(viz, sagemaker_session):
228115
]
229116
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
230117

231-
expected_dataframe = pd.DataFrame.from_dict(
232-
OrderedDict(
233-
[
234-
("Name/Source", ["source-name-1", "dest-name-2"]),
235-
("Direction", ["Input", "Output"]),
236-
("Type", ["source-type-1", "dest-type-2"]),
237-
("Association Type", ["type-1", "type-2"]),
238-
("Lineage Type", ["context", "context"]),
239-
]
240-
)
118+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
119+
120+
121+
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
122+
123+
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
124+
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
125+
}
126+
127+
get_list_associations_side_effect(sagemaker_session)
128+
129+
step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
130+
131+
df = viz.show(pipeline_execution_step=step)
132+
133+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
134+
SourceArn="proc-job-arn",
241135
)
242136

243-
pd.testing.assert_frame_equal(expected_dataframe, df)
137+
assert_list_associations_mock_calls(sagemaker_session)
244138

139+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
245140

246-
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
141+
142+
def test_training_job_pipeline_execution_step(viz, sagemaker_session):
247143

248144
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
249145
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
250146
}
251147

148+
get_list_associations_side_effect(sagemaker_session)
149+
150+
step = {"Metadata": {"TrainingJob": {"Arn": "training-job-arn"}}}
151+
152+
df = viz.show(pipeline_execution_step=step)
153+
154+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
155+
SourceArn="training-job-arn",
156+
)
157+
158+
assert_list_associations_mock_calls(sagemaker_session)
159+
160+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
161+
162+
163+
def test_transform_job_pipeline_execution_step(viz, sagemaker_session):
164+
165+
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
166+
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
167+
}
168+
169+
get_list_associations_side_effect(sagemaker_session)
170+
171+
step = {"Metadata": {"TransformJob": {"Arn": "transform-job-arn"}}}
172+
173+
df = viz.show(pipeline_execution_step=step)
174+
175+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
176+
SourceArn="transform-job-arn",
177+
)
178+
179+
assert_list_associations_mock_calls(sagemaker_session)
180+
181+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
182+
183+
184+
def get_list_associations_side_effect(sagemaker_session):
185+
252186
sagemaker_session.sagemaker_client.list_associations.side_effect = [
253187
{
254188
"AssociationSummaries": [
@@ -278,13 +212,8 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
278212
},
279213
]
280214

281-
step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
282-
283-
df = viz.show(pipeline_execution_step=step)
284215

285-
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
286-
SourceArn="proc-job-arn",
287-
)
216+
def assert_list_associations_mock_calls(sagemaker_session):
288217

289218
expected_calls = [
290219
unittest.mock.call(
@@ -296,6 +225,9 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
296225
]
297226
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
298227

228+
229+
def get_expected_dataframe():
230+
299231
expected_dataframe = pd.DataFrame.from_dict(
300232
OrderedDict(
301233
[
@@ -308,4 +240,4 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
308240
)
309241
)
310242

311-
pd.testing.assert_frame_equal(expected_dataframe, df)
243+
return expected_dataframe

0 commit comments

Comments
 (0)