Skip to content

Commit 155e465

Browse files
authored
infra: add PyTorch + custom model bucket batch transform integ test (#1407)
This converts the MXNet + VPC config + batch transform integ test to use PyTorch instead, and start from a model rather than an estimator. In addition, the model is uploaded to a non-default bucket, and the tests checks that the repacked model is saved in the same (non-default) bucket. The motivation for converting an existing test is to not increase the overall runtime of the integ tests.
1 parent b0ca952 commit 155e465

File tree

6 files changed

+62
-68
lines changed

6 files changed

+62
-68
lines changed

tests/conftest.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import tests.integ
2121
from botocore.config import Config
2222

23-
from sagemaker import Session
23+
from sagemaker import Session, utils
2424
from sagemaker.chainer import Chainer
2525
from sagemaker.local import LocalSession
2626
from sagemaker.mxnet import MXNet
@@ -30,6 +30,7 @@
3030
from sagemaker.tensorflow.estimator import TensorFlow
3131

3232
DEFAULT_REGION = "us-west-2"
33+
CUSTOM_BUCKET_NAME_PREFIX = "sagemaker-custom-bucket"
3334

3435
NO_M4_REGIONS = [
3536
"eu-west-3",
@@ -89,16 +90,16 @@ def sagemaker_runtime_config(request):
8990

9091

9192
@pytest.fixture(scope="session")
92-
def boto_config(request):
93+
def boto_session(request):
9394
config = request.config.getoption("--boto-config")
94-
return json.loads(config) if config else None
95+
if config:
96+
return boto3.Session(**json.loads(config))
97+
else:
98+
return boto3.Session(region_name=DEFAULT_REGION)
9599

96100

97101
@pytest.fixture(scope="session")
98-
def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_config):
99-
boto_session = (
100-
boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
101-
)
102+
def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_session):
102103
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
103104
sagemaker_client = (
104105
boto_session.client("sagemaker", **sagemaker_client_config)
@@ -119,14 +120,19 @@ def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_co
119120

120121

121122
@pytest.fixture(scope="session")
122-
def sagemaker_local_session(boto_config):
123-
if boto_config:
124-
boto_session = boto3.Session(**boto_config)
125-
else:
126-
boto_session = boto3.Session(region_name=DEFAULT_REGION)
123+
def sagemaker_local_session(boto_session):
127124
return LocalSession(boto_session=boto_session)
128125

129126

127+
@pytest.fixture(scope="module")
128+
def custom_bucket_name(boto_session):
129+
region = boto_session.region_name
130+
account = boto_session.client(
131+
"sts", region_name=region, endpoint_url=utils.sts_regional_endpoint(region)
132+
).get_caller_identity()["Account"]
133+
return "{}-{}-{}".format(CUSTOM_BUCKET_NAME_PREFIX, region, account)
134+
135+
130136
@pytest.fixture(scope="module", params=["4.0", "4.0.0", "4.1", "4.1.0", "5.0", "5.0.0"])
131137
def chainer_version(request):
132138
return request.param

tests/data/pytorch_mnist/model.tar.gz

142 KB
Binary file not shown.
919 KB
Binary file not shown.

tests/integ/test_processing.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import os
1616

17-
import boto3
1817
import pytest
1918
from botocore.config import Config
2019
from sagemaker import Session
@@ -28,22 +27,16 @@
2827
ProcessingJob,
2928
)
3029
from sagemaker.sklearn.processing import SKLearnProcessor
31-
from sagemaker.utils import sts_regional_endpoint
3230
from tests.integ import DATA_DIR
3331
from tests.integ.kms_utils import get_or_create_kms_key
3432

3533
ROLE = "SageMakerRole"
36-
DEFAULT_REGION = "us-west-2"
37-
CUSTOM_BUCKET_PATH_PREFIX = "sagemaker-custom-bucket"
3834

3935

4036
@pytest.fixture(scope="module")
4137
def sagemaker_session_with_custom_bucket(
42-
boto_config, sagemaker_client_config, sagemaker_runtime_config
38+
boto_session, sagemaker_client_config, sagemaker_runtime_config, custom_bucket_name
4339
):
44-
boto_session = (
45-
boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
46-
)
4740
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
4841
sagemaker_client = (
4942
boto_session.client("sagemaker", **sagemaker_client_config)
@@ -56,17 +49,11 @@ def sagemaker_session_with_custom_bucket(
5649
else None
5750
)
5851

59-
region = boto_session.region_name
60-
account = boto_session.client(
61-
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
62-
).get_caller_identity()["Account"]
63-
custom_default_bucket = "{}-{}-{}".format(CUSTOM_BUCKET_PATH_PREFIX, region, account)
64-
6552
return Session(
6653
boto_session=boto_session,
6754
sagemaker_client=sagemaker_client,
6855
sagemaker_runtime_client=runtime_client,
69-
default_bucket=custom_default_bucket,
56+
default_bucket=custom_bucket_name,
7057
)
7158

7259

@@ -221,6 +208,7 @@ def test_sklearn_with_customizations(
221208

222209
def test_sklearn_with_custom_default_bucket(
223210
sagemaker_session_with_custom_bucket,
211+
custom_bucket_name,
224212
image_uri,
225213
sklearn_full_version,
226214
cpu_instance_type,
@@ -272,10 +260,10 @@ def test_sklearn_with_custom_default_bucket(
272260
job_description = sklearn_processor.latest_job.describe()
273261

274262
assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
275-
assert CUSTOM_BUCKET_PATH_PREFIX in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
263+
assert custom_bucket_name in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
276264

277265
assert job_description["ProcessingInputs"][1]["InputName"] == "code"
278-
assert CUSTOM_BUCKET_PATH_PREFIX in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]
266+
assert custom_bucket_name in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]
279267

280268
assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations")
281269

@@ -583,7 +571,11 @@ def test_processor(sagemaker_session, image_uri, cpu_instance_type, output_kms_k
583571

584572

585573
def test_processor_with_custom_bucket(
586-
sagemaker_session_with_custom_bucket, image_uri, cpu_instance_type, output_kms_key
574+
sagemaker_session_with_custom_bucket,
575+
custom_bucket_name,
576+
image_uri,
577+
cpu_instance_type,
578+
output_kms_key,
587579
):
588580
script_path = os.path.join(DATA_DIR, "dummy_script.py")
589581

@@ -624,7 +616,7 @@ def test_processor_with_custom_bucket(
624616
job_description = processor.latest_job.describe()
625617

626618
assert job_description["ProcessingInputs"][0]["InputName"] == "code"
627-
assert CUSTOM_BUCKET_PATH_PREFIX in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
619+
assert custom_bucket_name in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
628620

629621
assert job_description["ProcessingJobName"].startswith("test-processor")
630622

tests/integ/test_session.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,12 @@
1717

1818
from sagemaker import Session
1919

20-
DEFAULT_REGION = "us-west-2"
2120
CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist"
2221

2322

2423
def test_sagemaker_session_does_not_create_bucket_on_init(
25-
sagemaker_client_config, sagemaker_runtime_config, boto_config
24+
sagemaker_client_config, sagemaker_runtime_config, boto_session
2625
):
27-
boto_session = (
28-
boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
29-
)
3026
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
3127
sagemaker_client = (
3228
boto_session.client("sagemaker", **sagemaker_client_config)
@@ -46,5 +42,5 @@ def test_sagemaker_session_does_not_create_bucket_on_init(
4642
default_bucket=CUSTOM_BUCKET_NAME,
4743
)
4844

49-
s3 = boto3.resource("s3", region_name=DEFAULT_REGION)
45+
s3 = boto3.resource("s3", region_name=boto_session.region_name)
5046
assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None

tests/integ/test_transformer.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020

2121
import pytest
2222

23-
from sagemaker import KMeans
23+
from sagemaker import KMeans, s3
2424
from sagemaker.mxnet import MXNet
25+
from sagemaker.pytorch import PyTorchModel
2526
from sagemaker.transformer import Transformer
2627
from sagemaker.estimator import Estimator
2728
from sagemaker.utils import unique_name_from_base
2829
from tests.integ import (
2930
DATA_DIR,
31+
PYTHON_VERSION,
3032
TRAINING_DEFAULT_TIMEOUT_MINUTES,
3133
TRANSFORM_DEFAULT_TIMEOUT_MINUTES,
3234
)
@@ -144,48 +146,43 @@ def test_attach_transform_kmeans(sagemaker_session, cpu_instance_type):
144146
attached_transformer.wait()
145147

146148

147-
def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version, cpu_instance_type):
148-
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
149-
script_path = os.path.join(data_path, "mnist.py")
149+
def test_transform_pytorch_vpc_custom_model_bucket(
150+
sagemaker_session, pytorch_full_version, cpu_instance_type, custom_bucket_name
151+
):
152+
data_dir = os.path.join(DATA_DIR, "pytorch_mnist")
150153

151154
ec2_client = sagemaker_session.boto_session.client("ec2")
152155
subnet_ids, security_group_id = get_or_create_vpc_resources(ec2_client)
153156

154-
mx = MXNet(
155-
entry_point=script_path,
156-
role="SageMakerRole",
157-
train_instance_count=1,
158-
train_instance_type=cpu_instance_type,
159-
sagemaker_session=sagemaker_session,
160-
framework_version=mxnet_full_version,
161-
subnets=subnet_ids,
162-
security_group_ids=[security_group_id],
157+
model_data = sagemaker_session.upload_data(
158+
path=os.path.join(data_dir, "model.tar.gz"),
159+
bucket=custom_bucket_name,
160+
key_prefix="integ-test-data/pytorch_mnist/model",
163161
)
164162

165-
train_input = mx.sagemaker_session.upload_data(
166-
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
167-
)
168-
test_input = mx.sagemaker_session.upload_data(
169-
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
163+
model = PyTorchModel(
164+
model_data=model_data,
165+
entry_point=os.path.join(data_dir, "mnist.py"),
166+
role="SageMakerRole",
167+
framework_version=pytorch_full_version,
168+
py_version=PYTHON_VERSION,
169+
sagemaker_session=sagemaker_session,
170+
vpc_config={"Subnets": subnet_ids, "SecurityGroupIds": [security_group_id]},
171+
code_location="s3://{}".format(custom_bucket_name),
170172
)
171-
job_name = unique_name_from_base("test-mxnet-vpc")
172173

173-
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
174-
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
175-
176-
job_desc = sagemaker_session.sagemaker_client.describe_training_job(
177-
TrainingJobName=mx.latest_training_job.name
174+
transform_input = sagemaker_session.upload_data(
175+
path=os.path.join(data_dir, "transform", "data.npy"),
176+
key_prefix="integ-test-data/pytorch_mnist/transform",
178177
)
179-
assert set(subnet_ids) == set(job_desc["VpcConfig"]["Subnets"])
180-
assert [security_group_id] == job_desc["VpcConfig"]["SecurityGroupIds"]
181178

182-
transform_input_path = os.path.join(data_path, "transform", "data.csv")
183-
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
184-
transform_input = mx.sagemaker_session.upload_data(
185-
path=transform_input_path, key_prefix=transform_input_key_prefix
179+
transformer = model.transformer(1, cpu_instance_type)
180+
transformer.transform(
181+
transform_input,
182+
content_type="application/x-npy",
183+
job_name=unique_name_from_base("test-transform-vpc"),
186184
)
187185

188-
transformer = _create_transformer_and_transform_job(mx, transform_input, cpu_instance_type)
189186
with timeout_and_delete_model_with_transformer(
190187
transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
191188
):
@@ -196,6 +193,9 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version, cpu_instance
196193
assert set(subnet_ids) == set(model_desc["VpcConfig"]["Subnets"])
197194
assert [security_group_id] == model_desc["VpcConfig"]["SecurityGroupIds"]
198195

196+
model_bucket, _ = s3.parse_s3_url(model_desc["PrimaryContainer"]["ModelDataUrl"])
197+
assert custom_bucket_name == model_bucket
198+
199199

200200
def test_transform_mxnet_tags(sagemaker_session, mxnet_full_version, cpu_instance_type):
201201
data_path = os.path.join(DATA_DIR, "mxnet_mnist")

0 commit comments

Comments
 (0)