Skip to content

Commit 70b78d2

Browse files
author
Hongshan Li
committed
helper fn to download mnist from public s3
1 parent 8ce42b1 commit 70b78d2

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

utils/datasets.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os

import boto3
2+
3+
# Name of the S3 bucket that download_mnist() fetches the MNIST archives from.
PUBLIC_BUCKET = "sagemaker-sample-files"
4+
5+
def download_mnist(data_dir='/tmp/data', train=True):
    """Download the MNIST dataset from a public S3 bucket.

    Fetches the gzipped image and label files of the requested split into
    ``data_dir``. Files that already exist in ``data_dir`` are left
    untouched, and no S3 client is created when nothing needs downloading.

    Args:
        data_dir (str): directory to save the data; created if missing
        train (bool): download the training set if True, otherwise the
            ``t10k`` files

    Returns:
        None
    """
    # exist_ok avoids the check-then-create race when several processes
    # prepare the same directory at the same time.
    os.makedirs(data_dir, exist_ok=True)

    if train:
        file_names = (
            "train-images-idx3-ubyte.gz",
            "train-labels-idx1-ubyte.gz",
        )
    else:
        file_names = (
            "t10k-images-idx3-ubyte.gz",
            "t10k-labels-idx1-ubyte.gz",
        )

    # Only files that are not already on disk need to be fetched.
    missing = [
        name for name in file_names
        if not os.path.exists(os.path.join(data_dir, name))
    ]
    if not missing:
        return

    s3 = boto3.client('s3')
    for name in missing:
        # S3 object keys always use "/" as separator; os.path.join would
        # produce backslashes on Windows and therefore a wrong key.
        key = "datasets/image/MNIST/" + name
        s3.download_file(PUBLIC_BUCKET, key, os.path.join(data_dir, name))
34+

0 commit comments

Comments
 (0)