Skip to content

Commit 9601828

Browse files
authored
Add torchinfo package to summarize a pytorch model (#1220)
http://b/257562539
1 parent 5e58ee9 commit 9601828

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

Dockerfile.tmpl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,8 @@ RUN pip install pytorch-ignite \
559559
tables \
560560
openpyxl \
561561
timm \
562-
pycolmap && \
562+
pycolmap \
563+
torchinfo && \
563564
/tmp/clean-layer.sh
564565

565566
# Download base easyocr models.

tests/test_torchinfo.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import unittest
2+
3+
import torch.nn as tnn
4+
5+
from torchinfo import summary
6+
7+
class TestTorchinfo(unittest.TestCase):
8+
def test_info(self):
9+
model = tnn.Linear(5,3)
10+
s = summary(model)
11+
self.assertEqual(1, len(s.summary_list))

0 commit comments

Comments
 (0)