Skip to content

Commit 9b92c02

Browse files
committed
fix: use py3 when using pytorch_full_version
1 parent 3c737fe commit 9b92c02

File tree

5 files changed

+14
-7
lines changed

5 files changed

+14
-7
lines changed

tests/integ/test_git.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from tests.integ import lock as lock
2323
from sagemaker.mxnet.estimator import MXNet
24+
from sagemaker.pytorch.defaults import PYTORCH_VERSION
2425
from sagemaker.pytorch.estimator import PyTorch
2526
from sagemaker.sklearn.estimator import SKLearn
2627
from sagemaker.sklearn.model import SKLearnModel
@@ -51,15 +52,18 @@
5152

5253

5354
@pytest.mark.local_mode
54-
def test_github(sagemaker_local_session, pytorch_full_version):
55+
def test_github(sagemaker_local_session):
5556
script_path = "mnist.py"
5657
data_path = os.path.join(DATA_DIR, "pytorch_mnist")
5758
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
59+
60+
# TODO: fails for newer pytorch versions when using MNIST from torchvision due to missing dataset
61+
# "algo-1-v767u_1 | RuntimeError: Dataset not found. You can use download=True to download it"
5862
pytorch = PyTorch(
5963
entry_point=script_path,
6064
role="SageMakerRole",
6165
source_dir="pytorch",
62-
framework_version=pytorch_full_version,
66+
framework_version=PYTORCH_VERSION,
6367
py_version=PYTHON_VERSION,
6468
train_instance_count=1,
6569
train_instance_type="local",

tests/integ/test_pytorch_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def _get_pytorch_estimator(
191191
entry_point=entry_point,
192192
role="SageMakerRole",
193193
framework_version=pytorch_full_version,
194-
py_version=PYTHON_VERSION,
194+
py_version="py3",
195195
train_instance_count=1,
196196
train_instance_type=instance_type,
197197
sagemaker_session=sagemaker_session,

tests/integ/test_source_dirs.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,26 @@
1919
import tests.integ.lock as lock
2020
from tests.integ import DATA_DIR, PYTHON_VERSION
2121

22+
from sagemaker.pytorch.defaults import PYTORCH_VERSION
2223
from sagemaker.pytorch.estimator import PyTorch
2324

2425

2526
@pytest.mark.local_mode
26-
def test_source_dirs(tmpdir, sagemaker_local_session, pytorch_full_version):
27+
def test_source_dirs(tmpdir, sagemaker_local_session):
2728
source_dir = os.path.join(DATA_DIR, "pytorch_source_dirs")
2829
lib = os.path.join(str(tmpdir), "alexa.py")
2930

3031
with open(lib, "w") as f:
3132
f.write("def question(to_anything): return 42")
3233

34+
# TODO: fails on newer versions of pytorch in call to np.load(BytesIO(stream.read()))
35+
# "ValueError: Cannot load file containing pickled data when allow_pickle=False"
3336
estimator = PyTorch(
3437
entry_point="train.py",
3538
role="SageMakerRole",
3639
source_dir=source_dir,
3740
dependencies=[lib],
38-
framework_version=pytorch_full_version,
41+
framework_version=PYTORCH_VERSION,
3942
py_version=PYTHON_VERSION,
4043
train_instance_count=1,
4144
train_instance_type="local",

tests/integ/test_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def test_transform_pytorch_vpc_custom_model_bucket(
178178
entry_point=os.path.join(data_dir, "mnist.py"),
179179
role="SageMakerRole",
180180
framework_version=pytorch_full_version,
181-
py_version=PYTHON_VERSION,
181+
py_version="py3",
182182
sagemaker_session=sagemaker_session,
183183
vpc_config={"Subnets": subnet_ids, "SecurityGroupIds": [security_group_id]},
184184
code_location="s3://{}".format(custom_bucket_name),

tests/integ/test_tuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type, pytorch_ful
828828
role="SageMakerRole",
829829
train_instance_count=1,
830830
framework_version=pytorch_full_version,
831-
py_version=PYTHON_VERSION,
831+
py_version="py3",
832832
train_instance_type=cpu_instance_type,
833833
sagemaker_session=sagemaker_session,
834834
)

0 commit comments

Comments
 (0)