Skip to content

Commit cdd56e6

Browse files
authored
change: eliminate dependency on mnist dataset website (#986)
This commit modifies the folder structure to avoid downloading the datasets and relying on the dataset website being available. That dependency on the website caused a canary to fail. The dependency has existed since the PyTorch 1.1 upgrade, because the directory structure was modified in pytorch/vision#601.
1 parent fa14b32 commit cdd56e6

File tree

4 files changed

+5
-3
lines changed

4 files changed

+5
-3
lines changed

tests/data/pytorch_mnist/mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs):
4747
transform=transforms.Compose(
4848
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
4949
),
50-
download=True,
50+
download=False, # True sets a dependency on an external site for our canaries.
5151
)
5252
train_sampler = (
5353
torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
@@ -71,7 +71,7 @@ def _get_test_data_loader(training_dir, **kwargs):
7171
transform=transforms.Compose(
7272
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
7373
),
74-
download=True,
74+
download=False, # True sets a dependency on an external site for our canaries.
7575
),
7676
batch_size=1000,
7777
shuffle=True,

tests/integ/test_git.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from sagemaker.sklearn.model import SKLearnModel
2929
from tests.integ import DATA_DIR, PYTHON_VERSION
3030

31+
MNIST_FOLDER_NAME = "MNIST"
32+
3133
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
3234
BRANCH = "test-branch-git-config"
3335
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
@@ -69,7 +71,7 @@ def test_git_support_with_pytorch(sagemaker_local_session):
6971
git_config=git_config,
7072
)
7173

72-
pytorch.fit({"training": "file://" + os.path.join(data_path, "training")})
74+
pytorch.fit({"training": "file://" + os.path.join(data_path, "training", MNIST_FOLDER_NAME)})
7375

7476
with lock.lock(LOCK_PATH):
7577
try:

0 commit comments

Comments
 (0)