Skip to content

Commit 5133eb8

Browse files
melty-chocolateHui Wang
andauthored
Upgrade Pytorch version to 1.2.0 in Neo notebook (#1098)
* set pytorch version for neo sagemaker notebook * Upgrade pytorch version to 1.2.0 for neo notebook Co-authored-by: Hui Wang <[email protected]>
1 parent bc8d9d5 commit 5133eb8

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

sagemaker_neo_compilation_jobs/pytorch_torchvision/pytorch_torchvision_neo.ipynb

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
{
1818
"cell_type": "code",
1919
"execution_count": null,
20-
"metadata": {},
20+
"metadata": {
21+
"scrolled": true
22+
},
2123
"outputs": [],
2224
"source": [
23-
"!~/anaconda3/envs/pytorch_p36/bin/pip install torch==1.0"
25+
"!~/anaconda3/envs/pytorch_p36/bin/pip install torch==1.2.0"
2426
]
2527
},
2628
{
@@ -48,7 +50,9 @@
4850
"import tarfile\n",
4951
"\n",
5052
"resnet18 = models.resnet18(pretrained=True)\n",
51-
"torch.save(resnet18, 'model.pth')\n",
53+
"input_shape = [1,3,224,224]\n",
54+
"trace = torch.jit.trace(resnet18.float().eval(), torch.zeros(input_shape).float())\n",
55+
"trace.save('model.pth')\n",
5256
"\n",
5357
"with tarfile.open('model.tar.gz', 'w:gz') as f:\n",
5458
" f.add('model.pth')"
@@ -94,7 +98,7 @@
9498
"data_shape = '{\"input0\":[1,3,224,224]}'\n",
9599
"target_device = 'ml_c5'\n",
96100
"framework = 'PYTORCH'\n",
97-
"framework_version = '0.4.0'\n",
101+
"framework_version = '1.2.0'\n",
98102
"compiled_model_path = 's3://{}/{}/output'.format(bucket, compilation_job_name)"
99103
]
100104
},

0 commit comments

Comments
 (0)