Skip to content

Commit a3fce08

Browse files
zhaoqizqwang authored and benieric committed
Fix tests and codestyle (#1619)
Co-authored-by: Erick Benitez-Ramos <[email protected]> Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 80f44f8 commit a3fce08

28 files changed

+469
-434
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -893,7 +893,7 @@ def from_recipe(
893893
sagemaker_session = Session()
894894
logger.warning("SageMaker session not provided. Using default Session.")
895895
if role is None:
896-
role = get_execution_role(sagemaker_session=session)
896+
role = get_execution_role(sagemaker_session=sagemaker_session)
897897
logger.warning(f"Role not provided. Using default role:\n{role}")
898898

899899
# The training recipe is used to prepare the following args:

src/sagemaker/serve/builder/model_builder.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -838,14 +838,17 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
838838

839839
@_capture_telemetry("ModelBuilder.build_training_job")
def _collect_training_job_model_telemetry(self):
    """No-op hook for the training job handshake.

    The body does nothing; the method exists so that the
    ``_capture_telemetry`` decorator wrapping it runs when it is called.
    """
    return None
842843

843844
@_capture_telemetry("ModelBuilder.build_model_trainer")
def _collect_model_trainer_model_telemetry(self):
    """No-op hook for the model trainer handshake.

    The body does nothing; the method exists so that the
    ``_capture_telemetry`` decorator wrapping it runs when it is called.
    """
    return None
846848

847849
@_capture_telemetry("ModelBuilder.build_estimator")
def _collect_estimator_model_telemetry(self):
    """No-op hook for the estimator handshake.

    The body does nothing; the method exists so that the
    ``_capture_telemetry`` decorator wrapping it runs when it is called.
    """
    return None
850853

851854
# Model Builder is a class to build the model for deployment.

src/sagemaker/serve/utils/telemetry_logger.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -122,13 +122,12 @@ def wrapper(self, *args, **kwargs):
122122
extra += f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}"
123123

124124
if self.image_uri:
125-
image_uri_tail = self.image_uri.split("/")[1]
126125
image_uri_option = _get_image_uri_option(
127126
self.image_uri, getattr(self, "_is_custom_image_uri", False)
128127
)
129-
130-
if self.image_uri:
131-
extra += f"&x-imageTag={image_uri_tail}"
128+
split_image_uri = self.image_uri.split("/")
129+
if len(split_image_uri) > 1:
130+
extra += f"&x-imageTag={split_image_uri[1]}"
132131

133132
extra += f"&x-sdkVersion={SDK_VERSION}"
134133

src/sagemaker/telemetry/telemetry_logging.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,8 @@
6363

6464

6565
def _telemetry_emitter(feature: str, func_name: str):
66-
"""
66+
"""Telemetry Emitter
67+
6768
Decorator to emit telemetry logs for SageMaker Python SDK functions. This class needs
6869
sagemaker_session object as a member. Default session object is a pysdk v2 Session object
6970
in this repo. When collecting telemetry for classes using sagemaker-core Session object,

tests/integ/sagemaker/modules/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,5 +36,5 @@ def modules_sagemaker_session():
3636

3737
yield sagemaker_session
3838

39-
if region_manual_set:
39+
if region_manual_set and "AWS_DEFAULT_REGION" in os.environ:
4040
del os.environ["AWS_DEFAULT_REGION"]

tests/integ/sagemaker/modules/train/test_model_trainer.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
4444
)
4545

4646
model_trainer = ModelTrainer(
47-
session=modules_sagemaker_session,
47+
sagemaker_session=modules_sagemaker_session,
4848
training_image=DEFAULT_CPU_IMAGE,
4949
hyperparameters=EXPECTED_HYPERPARAMETERS,
5050
source_code=source_code,
@@ -60,7 +60,7 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):
6060
entry_script="train.sh",
6161
)
6262
model_trainer = ModelTrainer(
63-
session=modules_sagemaker_session,
63+
sagemaker_session=modules_sagemaker_session,
6464
training_image=DEFAULT_CPU_IMAGE,
6565
hyperparameters=EXPECTED_HYPERPARAMETERS,
6666
source_code=source_code,
@@ -77,7 +77,7 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):
7777
)
7878
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
7979
model_trainer = ModelTrainer(
80-
session=modules_sagemaker_session,
80+
sagemaker_session=modules_sagemaker_session,
8181
training_image=DEFAULT_CPU_IMAGE,
8282
compute=compute,
8383
hyperparameters=EXPECTED_HYPERPARAMETERS,
@@ -96,7 +96,7 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session):
9696
)
9797
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
9898
model_trainer = ModelTrainer(
99-
session=modules_sagemaker_session,
99+
sagemaker_session=modules_sagemaker_session,
100100
training_image=DEFAULT_CPU_IMAGE,
101101
compute=compute,
102102
hyperparameters=EXPECTED_HYPERPARAMETERS,

tests/integ/sagemaker/serve/conftest.py

Lines changed: 33 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -10,64 +10,48 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
# from __future__ import absolute_import
13+
from __future__ import absolute_import
1414

15-
# import os
16-
# import pytest
17-
# import platform
18-
# import collections
19-
# from numpy import loadtxt
20-
# from sagemaker.serve.spec.inference_spec import InferenceSpec
15+
import pytest
16+
import os
17+
import boto3
18+
import sagemaker
19+
import sagemaker_core.helper.session_helper as core_session
2120

22-
# if platform.python_version_tuple()[1] == "8":
23-
# from xgboost import XGBClassifier
24-
# from sklearn.model_selection import train_test_split
21+
DEFAULT_REGION = "us-west-2"
2522

26-
# from tests.integ.sagemaker.serve.constants import XGB_RESOURCE_DIR
2723

24+
@pytest.fixture(scope="module")
def mb_sagemaker_session():
    """Yield a module-scoped ``sagemaker.Session`` for the ModelBuilder tests.

    If ``AWS_DEFAULT_REGION`` is unset (or empty), it is temporarily set to
    ``DEFAULT_REGION`` for the lifetime of the fixture and removed again on
    teardown. A region that was already exported by the caller is used as-is
    and is left untouched afterwards.

    Yields:
        sagemaker.Session: session backed by a boto3 session in the
        effective region.
    """
    # True only when this fixture itself had to supply the region. The flag
    # previously was True in both branches, which made teardown delete a
    # region the caller had exported before the test run.
    region_manual_set = not os.environ.get("AWS_DEFAULT_REGION")
    if region_manual_set:
        os.environ["AWS_DEFAULT_REGION"] = DEFAULT_REGION

    boto_session = boto3.Session(region_name=os.environ["AWS_DEFAULT_REGION"])
    sagemaker_session = sagemaker.Session(boto_session=boto_session)

    yield sagemaker_session

    # Undo only what this fixture changed; the membership check guards
    # against something else having removed the variable in the meantime.
    if region_manual_set and "AWS_DEFAULT_REGION" in os.environ:
        del os.environ["AWS_DEFAULT_REGION"]
3740

3841

39-
# @pytest.fixture(scope="session")
40-
# def xgb_inference_spec():
41-
# class MyXGBoostModel(InferenceSpec):
42-
# def load(self, model_dir: str):
43-
# model = XGBClassifier()
44-
# model.load_model(model_dir + "/model.xgb")
45-
# return model
42+
@pytest.fixture(scope="module")
def mb_sagemaker_core_session():
    """Yield a module-scoped sagemaker-core ``Session`` for the ModelBuilder tests.

    If ``AWS_DEFAULT_REGION`` is unset (or empty), it is temporarily set to
    ``DEFAULT_REGION`` for the lifetime of the fixture and removed again on
    teardown. A region that was already exported by the caller is used as-is
    and is left untouched afterwards.

    Yields:
        sagemaker_core.helper.session_helper.Session: session backed by a
        boto3 session in the effective region.
    """
    # True only when this fixture itself had to supply the region. The flag
    # previously was True in both branches, which made teardown delete a
    # region the caller had exported before the test run.
    region_manual_set = not os.environ.get("AWS_DEFAULT_REGION")
    if region_manual_set:
        os.environ["AWS_DEFAULT_REGION"] = DEFAULT_REGION

    boto_session = boto3.Session(region_name=os.environ["AWS_DEFAULT_REGION"])
    sagemaker_session = core_session.Session(boto_session=boto_session)

    yield sagemaker_session

    # Undo only what this fixture changed; the membership check guards
    # against something else having removed the variable in the meantime.
    if region_manual_set and "AWS_DEFAULT_REGION" in os.environ:
        del os.environ["AWS_DEFAULT_REGION"]
Lines changed: 193 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,193 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
17+
from sagemaker import get_execution_role
18+
from sklearn.datasets import load_iris
19+
from sklearn.model_selection import train_test_split
20+
21+
import os
22+
23+
from sagemaker_core.main.shapes import (
24+
AlgorithmSpecification,
25+
Channel,
26+
DataSource,
27+
S3DataSource,
28+
OutputDataConfig,
29+
ResourceConfig,
30+
StoppingCondition,
31+
)
32+
import uuid
33+
from sagemaker.serve.builder.model_builder import ModelBuilder
34+
import pandas as pd
35+
import numpy as np
36+
from sagemaker.serve import InferenceSpec, SchemaBuilder
37+
from sagemaker_core.main.resources import TrainingJob
38+
from xgboost import XGBClassifier
39+
40+
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
41+
42+
from sagemaker.s3_utils import s3_path_join
43+
from sagemaker.async_inference import AsyncInferenceConfig
44+
from tests.integ.utils import cleanup_model_resources
45+
46+
47+
@pytest.fixture(scope="module")
def xgboost_model_builder(mb_sagemaker_session):
    """Train an XGBoost model on the Iris data and return a built ``ModelBuilder``.

    Steps performed once per test module:

    1. Load the Iris data set, move the ``target`` column to the front and
       write an 80/20 train/test split as header-less CSV files under the
       local ``data`` directory.
    2. Upload that directory to the session's default bucket.
    3. Run a ``TrainingJob`` on the uploaded ``train.csv`` and wait for it.
    4. Construct a ``ModelBuilder`` from the resulting model artifacts and
       call ``build()`` on it.

    Args:
        mb_sagemaker_session: session fixture used to resolve the execution
            role and default bucket and to upload the training data.

    Returns:
        ModelBuilder: the builder after ``build()`` has been called.
    """
    sagemaker_session = mb_sagemaker_session
    role = get_execution_role(sagemaker_session=sagemaker_session)
    bucket = sagemaker_session.default_bucket()

    prefix = "DEMO-scikit-iris"
    train_file = "train.csv"
    data_directory = "data"

    # Get IRIS Data
    iris = load_iris()
    iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
    iris_df["target"] = iris.target

    # Prepare Data: label column first, followed by the feature columns.
    os.makedirs(data_directory, exist_ok=True)
    feature_columns = [col for col in iris_df.columns if col != "target"]
    iris_df = iris_df[["target"] + feature_columns]

    train_data, test_data = train_test_split(iris_df, test_size=0.2, random_state=42)

    train_data.to_csv("data/train.csv", index=False, header=False)
    test_data.to_csv("data/test.csv", index=False, header=False)

    # The original code called ``test_data.drop("target", axis=1)`` here and
    # discarded the result. ``DataFrame.drop`` returns a new frame, and
    # ``test_data`` is not used again, so the statement had no effect and
    # has been removed.

    sagemaker_session.upload_data(
        data_directory, bucket=bucket, key_prefix=f"{prefix}/{data_directory}"
    )

    s3_input_path = f"s3://{bucket}/{prefix}/data/{train_file}"
    s3_output_path = f"s3://{bucket}/{prefix}/output"

    print(s3_input_path)
    print(s3_output_path)

    # NOTE(review): the image URI is pinned to us-west-2; presumably these
    # tests are expected to run in that region -- confirm.
    image = "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1"

    class XGBoostSpec(InferenceSpec):
        """Inference spec that loads the trained booster and predicts class ids."""

        def load(self, model_dir: str):
            print(model_dir)
            model = XGBClassifier()
            model.load_model(model_dir + "/xgboost-model")
            return model

        def invoke(self, input_object: object, model: object):
            # Pick the class with the highest predicted probability per row.
            prediction_probabilities = model.predict_proba(input_object)
            predictions = np.argmax(prediction_probabilities, axis=1)
            return predictions

    # Sample frame handed to SchemaBuilder as both input and output example.
    data = {"Name": ["Alice", "Bob", "Charlie"]}
    df = pd.DataFrame(data)
    training_job_name = str(uuid.uuid4())
    schema_builder = SchemaBuilder(sample_input=df, sample_output=df)

    training_job = TrainingJob.create(
        training_job_name=training_job_name,
        hyper_parameters={
            "objective": "multi:softmax",
            "num_class": "3",
            "num_round": "10",
            "eval_metric": "merror",
        },
        algorithm_specification=AlgorithmSpecification(
            training_image=image, training_input_mode="File"
        ),
        role_arn=role,
        input_data_config=[
            Channel(
                channel_name="train",
                content_type="csv",
                compression_type="None",
                record_wrapper_type="None",
                data_source=DataSource(
                    s3_data_source=S3DataSource(
                        s3_data_type="S3Prefix",
                        s3_uri=s3_input_path,
                        s3_data_distribution_type="FullyReplicated",
                    )
                ),
            )
        ],
        output_data_config=OutputDataConfig(s3_output_path=s3_output_path),
        resource_config=ResourceConfig(
            instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=30
        ),
        stopping_condition=StoppingCondition(max_runtime_in_seconds=600),
    )
    training_job.wait()

    # Named ``model_builder`` so the local no longer shadows the fixture.
    model_builder = ModelBuilder(
        name="ModelBuilderTest",
        model_path=training_job.model_artifacts.s3_model_artifacts,
        role_arn=role,
        inference_spec=XGBoostSpec(),
        image_uri=image,
        schema_builder=schema_builder,
        instance_type="ml.c6i.xlarge",
    )
    model_builder.build()
    return model_builder
150+
151+
152+
def test_real_time_deployment(xgboost_model_builder):
    """Deploy the built model to a real-time endpoint and check a predictor is returned.

    The model and endpoint are cleaned up even when the assertion fails, so
    a failing run does not leave the endpoint behind.
    """
    real_time_predictor = xgboost_model_builder.deploy(
        endpoint_name="test", initial_instance_count=1
    )

    try:
        assert real_time_predictor is not None
    finally:
        cleanup_model_resources(
            sagemaker_session=xgboost_model_builder.sagemaker_session,
            model_name=xgboost_model_builder.built_model.name,
            endpoint_name=xgboost_model_builder.built_model.endpoint_name,
        )
163+
164+
165+
def test_serverless_deployment(xgboost_model_builder):
    """Deploy the built model to a serverless endpoint and check a predictor is returned.

    The model and endpoint are cleaned up even when the assertion fails, so
    a failing run does not leave the endpoint behind.
    """
    serverless_predictor = xgboost_model_builder.deploy(
        endpoint_name="test1", inference_config=ServerlessInferenceConfig()
    )

    try:
        assert serverless_predictor is not None
    finally:
        cleanup_model_resources(
            sagemaker_session=xgboost_model_builder.sagemaker_session,
            model_name=xgboost_model_builder.built_model.name,
            endpoint_name=xgboost_model_builder.built_model.endpoint_name,
        )
176+
177+
178+
def test_async_deployment(xgboost_model_builder, mb_sagemaker_session):
    """Deploy the built model to an async endpoint and check a predictor is returned.

    Async output is written under ``async_inference/output`` in the session's
    default bucket. The model and endpoint are cleaned up even when the
    assertion fails, so a failing run does not leave the endpoint behind.
    """
    async_output_path = s3_path_join(
        "s3://", mb_sagemaker_session.default_bucket(), "async_inference/output"
    )
    async_predictor = xgboost_model_builder.deploy(
        endpoint_name="test2",
        inference_config=AsyncInferenceConfig(output_path=async_output_path),
    )

    try:
        assert async_predictor is not None
    finally:
        cleanup_model_resources(
            sagemaker_session=xgboost_model_builder.sagemaker_session,
            model_name=xgboost_model_builder.built_model.name,
            endpoint_name=xgboost_model_builder.built_model.endpoint_name,
        )

0 commit comments

Comments
 (0)