
Commit c4df18c

Author: Dewen Qi
Message: change: Update tests
Parent: 4bc1365

File tree

7 files changed: +574 -296 lines changed


src/sagemaker/workflow/pipeline_context.py

Lines changed: 2 additions & 2 deletions
@@ -105,7 +105,7 @@ def _intercept_create_request(self, request: Dict, create, func_name: str = None
        else:
            self.context = request

-    def init_model_step_arguments(self, model):
+    def init_step_arguments(self, model):
        """Create a `_ModelStepArguments` (if not exist) as pipeline context

        Args:
@@ -161,7 +161,7 @@ def wrapper(*args, **kwargs):
                UserWarning,
            )
            if run_func.__name__ in ["register", "create"]:
-                args[0].sagemaker_session.init_model_step_arguments(args[0])
+                args[0].sagemaker_session.init_step_arguments(args[0])
            run_func(*args, **kwargs)
            context = args[0].sagemaker_session.context
            args[0].sagemaker_session.context = None
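In plain terms, the wrapper above captures the request produced by `register`/`create` as pipeline step arguments instead of sending it to the service. Below is a minimal, self-contained sketch of that intercept pattern; every class and function name in it is a hypothetical stand-in written for illustration, not one of the SDK classes this commit touches.

import functools


class FakePipelineSession:
    """Stands in for a pipeline session that records requests instead of sending them."""

    def __init__(self):
        self.context = None

    def init_step_arguments(self, model):
        # Mirrors init_step_arguments: prepare a container for the captured request.
        self.context = {"caller": type(model).__name__, "request": None}

    def _intercept_create_request(self, request):
        # Instead of calling the service, stash the request as pipeline context.
        self.context["request"] = request


def capture_step_args(run_func):
    """Decorator mirroring the wrapper shape in pipeline_context.py (simplified)."""

    @functools.wraps(run_func)
    def wrapper(self, *args, **kwargs):
        if run_func.__name__ in ["register", "create"]:
            self.sagemaker_session.init_step_arguments(self)
        run_func(self, *args, **kwargs)           # populates session.context via the intercept
        context = self.sagemaker_session.context  # hand the captured arguments back
        self.sagemaker_session.context = None
        return context

    return wrapper


class FakeModel:
    def __init__(self, session):
        self.sagemaker_session = session

    @capture_step_args
    def create(self, instance_type):
        # In the real SDK this would build a CreateModel request; here we fake one.
        self.sagemaker_session._intercept_create_request({"InstanceType": instance_type})


session = FakePipelineSession()
model = FakeModel(session)
step_args = model.create(instance_type="ml.m5.large")
print(step_args)  # {'caller': 'FakeModel', 'request': {'InstanceType': 'ml.m5.large'}}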

src/sagemaker/workflow/steps.py

Lines changed: 9 additions & 9 deletions
@@ -420,15 +420,15 @@ def __init__(

        self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeModelOutput")

-        if not self.step_args:
-            warnings.warn(
-                (
-                    "We are deprecating the instantiation of CreateModelStep using "
-                    "`Model` and a list of `CreateModelInput`. "
-                    "Instead, the new interface simply uses step_args."
-                ),
-                DeprecationWarning,
-            )
+        # TODO: add public document link here once ready
+        warnings.warn(
+            (
+                "We are deprecating the use of CreateModelStep. "
+                "Instead, please use the ModelStep, which simply takes in the step arguments "
+                "generated by model.create()."
+            ),
+            DeprecationWarning,
+        )

    @property
    def arguments(self) -> RequestType:
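The replacement the new warning points to looks roughly like the sketch below. It is based on the warning text and the ModelStep/PipelineSession interface this change is part of; the image URI, model data path, and role are placeholders, not values from this diff.

from sagemaker.model import Model
from sagemaker.workflow.model_step import ModelStep
from sagemaker.workflow.pipeline_context import PipelineSession

pipeline_session = PipelineSession()

model = Model(
    image_uri="<image-uri>",            # placeholder
    model_data="<s3-model-artifact>",   # placeholder
    role="<execution-role-arn>",        # placeholder
    sagemaker_session=pipeline_session,
)

# Old, now-deprecated pattern:
#   CreateModelStep(name="MyModel", model=model,
#                   inputs=CreateModelInput(instance_type="ml.m5.large"))
# New pattern: pass the step arguments captured from model.create() to ModelStep.
step_create_model = ModelStep(
    name="MyModel",
    step_args=model.create(instance_type="ml.m5.large"),
)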
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import pickle as pkl

import numpy as np
import sagemaker_xgboost_container.encoder as xgb_encoders


def model_fn(model_dir):
    """
    Deserialize and return the fitted model.
    """
    model_file = "xgboost-model"
    booster = pkl.load(open(os.path.join(model_dir, model_file), "rb"))
    return booster


def input_fn(request_body, request_content_type):
    """
    The SageMaker XGBoost model server receives the request data body and the content type,
    and invokes the `input_fn`.
    Return a DMatrix (an object that can be passed to predict_fn).
    """
    if request_content_type == "text/libsvm":
        return xgb_encoders.libsvm_to_dmatrix(request_body)
    else:
        raise ValueError("Content type {} is not supported.".format(request_content_type))


def predict_fn(input_data, model):
    """
    SageMaker XGBoost model server invokes `predict_fn` on the return value of `input_fn`.
    Return a two-dimensional NumPy array where the first column contains the prediction
    and the remaining columns are the feature contributions (SHAP values) for that prediction.
    """
    prediction = model.predict(input_data)
    feature_contribs = model.predict(input_data, pred_contribs=True, validate_features=False)
    output = np.hstack((prediction[:, np.newaxis], feature_contribs))
    return output


def output_fn(predictions, content_type):
    """
    After invoking predict_fn, the model server invokes `output_fn`.
    """
    if content_type == "text/csv":
        return ",".join(str(x) for x in predictions[0])
    else:
        raise ValueError("Content type {} is not supported.".format(content_type))
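For context, the SageMaker XGBoost serving container calls these four handlers in sequence for each request. A rough local walk-through of that chain, assuming model_dir holds the pickled "xgboost-model" file; the path and the libsvm payload below are made up for illustration:

# Hypothetical local run of the handler chain; "/opt/ml/model" and the payload are examples only.
model = model_fn("/opt/ml/model")                      # load the pickled booster
dmatrix = input_fn("0 1:0.5 2:1.25\n", "text/libsvm")  # deserialize the request body
predictions = predict_fn(dmatrix, model)               # prediction column + SHAP columns
response_body = output_fn(predictions, "text/csv")     # serialize the first row as CSV
print(response_body)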
Binary file (35.1 KB) not shown.

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

Lines changed: 1 addition & 124 deletions
@@ -26,9 +26,7 @@
 from botocore.exceptions import WaiterError

 import tests
-from sagemaker.parameter import IntegerParameter
 from sagemaker.tensorflow import TensorFlow, TensorFlowModel
-from sagemaker.tuner import HyperparameterTuner
 from tests.integ.retry import retries
 from sagemaker.drift_check_baselines import DriftCheckBaselines
 from sagemaker import (
@@ -50,7 +48,7 @@
 from sagemaker.workflow.parameters import ParameterInteger, ParameterString
 from sagemaker.workflow.pipeline import Pipeline
 from sagemaker.workflow.step_collections import RegisterModel
-from sagemaker.workflow.steps import CreateModelStep, ProcessingStep, TrainingStep, TuningStep
+from sagemaker.workflow.steps import CreateModelStep, ProcessingStep, TrainingStep
 from sagemaker.xgboost import XGBoostModel
 from sagemaker.xgboost import XGBoost
 from sagemaker.workflow.conditions import (
@@ -848,124 +846,3 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
            pipeline.delete()
        except Exception as error:
            logging.error(error)
-
-
-def test_tuning_single_algo_with_create_model(
-    sagemaker_session,
-    role,
-    cpu_instance_type,
-    pipeline_name,
-    region_name,
-):
-    base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
-    entry_point = os.path.join(base_dir, "mnist.py")
-    input_path = sagemaker_session.upload_data(
-        path=os.path.join(base_dir, "training"),
-        key_prefix="integ-test-data/pytorch_mnist/training",
-    )
-    inputs = TrainingInput(s3_data=input_path)
-
-    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
-
-    pytorch_estimator = PyTorch(
-        entry_point=entry_point,
-        role=role,
-        framework_version="1.5.0",
-        py_version="py3",
-        instance_count=instance_count,
-        instance_type=instance_type,
-        sagemaker_session=sagemaker_session,
-        enable_sagemaker_metrics=True,
-        max_retry_attempts=3,
-    )
-
-    min_batch_size = ParameterInteger(name="MinBatchSize", default_value=64)
-    max_batch_size = ParameterInteger(name="MaxBatchSize", default_value=128)
-    hyperparameter_ranges = {
-        "batch-size": IntegerParameter(min_batch_size, max_batch_size),
-    }
-    tuner = HyperparameterTuner(
-        estimator=pytorch_estimator,
-        objective_metric_name="test:acc",
-        objective_type="Maximize",
-        hyperparameter_ranges=hyperparameter_ranges,
-        metric_definitions=[{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
-        max_jobs=2,
-        max_parallel_jobs=2,
-    )
-    step_tune = TuningStep(
-        name="my-tuning-step",
-        tuner=tuner,
-        inputs=inputs,
-    )
-    best_model = Model(
-        image_uri=pytorch_estimator.training_image_uri(),
-        model_data=step_tune.get_top_model_s3_uri(
-            top_k=0,
-            s3_bucket=sagemaker_session.default_bucket(),
-        ),
-        sagemaker_session=sagemaker_session,
-        role=role,
-    )
-    model_inputs = CreateModelInput(
-        instance_type="ml.m5.large",
-        accelerator_type="ml.eia1.medium",
-    )
-    step_best_model = CreateModelStep(
-        name="1st-model",
-        model=best_model,
-        inputs=model_inputs,
-    )
-
-    second_best_model = Model(
-        image_uri=pytorch_estimator.training_image_uri(),
-        model_data=step_tune.get_top_model_s3_uri(
-            top_k=1,
-            s3_bucket=sagemaker_session.default_bucket(),
-        ),
-        sagemaker_session=sagemaker_session,
-        role=role,
-        entry_point=entry_point,
-        source_dir=base_dir,
-    )
-    step_second_best_model = CreateModelStep(
-        name="2nd-best-model",
-        model=second_best_model,
-        inputs=model_inputs,
-    )
-    pipeline = Pipeline(
-        name=pipeline_name,
-        parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
-        steps=[step_tune, step_best_model, step_second_best_model],
-        sagemaker_session=sagemaker_session,
-    )
-
-    try:
-        response = pipeline.create(role)
-        create_arn = response["PipelineArn"]
-        assert re.match(
-            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
-            create_arn,
-        )
-
-        execution = pipeline.start(parameters={})
-        assert re.match(
-            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
-            execution.arn,
-        )
-        try:
-            execution.wait(delay=30, max_attempts=60)
-        except WaiterError:
-            pass
-        execution_steps = execution.list_steps()
-
-        for step in execution_steps:
-            assert not step.get("FailureReason", None)
-            assert step["StepStatus"] == "Succeeded"
-        assert len(execution_steps) == 3
-    finally:
-        try:
-            pipeline.delete()
-        except Exception:
-            pass
