File tree Expand file tree Collapse file tree 1 file changed +34
-0
lines changed Expand file tree Collapse file tree 1 file changed +34
-0
lines changed Original file line number Diff line number Diff line change
1
+ import boto3
2
+
3
# Name of the public S3 bucket that download_mnist reads the dataset from.
PUBLIC_BUCKET = "sagemaker-sample-files"


def download_mnist(data_dir='/tmp/data', train=True):
    """Download the MNIST dataset from a public S3 bucket.

    The image and label archives of the selected split are saved into
    ``data_dir``. An archive that already exists there is not downloaded
    again, and no S3 client is created when nothing is missing.

    Args:
        data_dir (str): directory to save the data; created if missing
        train (bool): download the training set when True, otherwise the
            "t10k" set

    Returns:
        None
    """
    # ``os`` is not imported at module level (only boto3 is), so bring it
    # into scope here; without this the function fails with NameError.
    import os

    # exist_ok avoids the race between a separate exists() check and the
    # directory creation.
    os.makedirs(data_dir, exist_ok=True)

    if train:
        images_file = "train-images-idx3-ubyte.gz"
        labels_file = "train-labels-idx1-ubyte.gz"
    else:
        images_file = "t10k-images-idx3-ubyte.gz"
        labels_file = "t10k-labels-idx1-ubyte.gz"

    # Only fetch what is not already on disk.
    missing = [
        name
        for name in (images_file, labels_file)
        if not os.path.exists(os.path.join(data_dir, name))
    ]
    if not missing:
        return

    # download objects
    s3 = boto3.client('s3')
    for name in missing:
        # S3 keys always use "/" as the separator; os.path.join would
        # produce "\\" on Windows and point at a non-existent object.
        key = "datasets/image/MNIST/" + name
        s3.download_file(PUBLIC_BUCKET, key, os.path.join(data_dir, name))
34
+
You can’t perform that action at this time.
0 commit comments