Skip to content

Commit 08e5a8a

Browse files
authored
Add tests for torchtext and torchvision (#1046)
1 parent 0638abf commit 08e5a8a

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

tests/test_torchtext.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import unittest
2+
3+
from torchtext.data.metrics import bleu_score
4+
5+
6+
class TestTorchtext(unittest.TestCase):
7+
def test_bleu_score(self):
8+
candidate = [['I', 'love', 'Kaggle', 'Notebooks']]
9+
refs = [[['Completely', 'Different']]]
10+
11+
self.assertEqual(0, bleu_score(candidate, refs))
12+

tests/test_torchvision.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import unittest
2+
3+
import torch
4+
import torchvision.transforms as transforms
5+
import torchvision.transforms.functional as F
6+
7+
8+
class TestTorchvision(unittest.TestCase):
9+
def test_float_to_float(self):
10+
input_dtype=torch.float32
11+
output_dtype=torch.float64
12+
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
13+
transform = transforms.ConvertImageDtype(output_dtype)
14+
transform_script = torch.jit.script(F.convert_image_dtype)
15+
16+
output_image = transform(input_image)
17+
output_image_script = transform_script(input_image, output_dtype)
18+
19+
# TODO(b/181966788) Uncomment after upgrade to pytorch 1.9.0 is done.
20+
# torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)
21+
22+
actual_min, actual_max = output_image.tolist()
23+
24+
self.assertAlmostEqual(0, actual_min)
25+
self.assertAlmostEqual(1, actual_max)

0 commit comments

Comments
 (0)