Skip to content

Commit aa54102

Browse files
committed
update data parallel integ test to use local mnist dataset
1 parent caf6580 commit aa54102

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

tests/data/smdistributed_dataparallel/mnist_pt.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
from __future__ import print_function
1414

1515
import argparse
16+
import os
1617
import time
1718
import torch
1819
import torch.nn as nn
@@ -150,8 +151,8 @@ def main():
150151
parser.add_argument(
151152
"--data-path",
152153
type=str,
153-
default="/tmp/data",
154-
help="Path for downloading " "the MNIST dataset",
154+
default=os.environ["SM_CHANNEL_TRAINING"],
155+
help="Path for downloading the MNIST dataset",
155156
)
156157

157158
args = parser.parse_args()
@@ -186,7 +187,7 @@ def main():
186187
train_dataset = datasets.MNIST(
187188
data_path,
188189
train=True,
189-
download=True,
190+
download=False, # True sets a dependency on an external site for our tests.
190191
transform=transforms.Compose(
191192
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
192193
),

tests/integ/test_smdataparallel_pt.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121

2222
from sagemaker.pytorch import PyTorch
2323
from tests.integ import timeout
24-
24+
from tests.integ.test_pytorch import _upload_training_data
2525

2626
smdataparallel_dir = os.path.join(
2727
os.path.dirname(__file__), "..", "data", "smdistributed_dataparallel"
@@ -51,4 +51,4 @@ def test_smdataparallel_pt_mnist(
5151
)
5252

5353
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
54-
estimator.fit(job_name=job_name)
54+
estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name)

0 commit comments

Comments
 (0)