Skip to content

Commit 46551f9

Browse files
committed
renamed the test cases
1 parent 81ed8e8 commit 46551f9

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# type: ignore
2+
import os
3+
import tempfile
4+
import unittest
5+
6+
import pytest
7+
import timm
8+
import torch
9+
import torch.nn.functional as F
10+
import torch_tensorrt as torchtrt
11+
import torchvision.models as models
12+
from torch import nn
13+
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
14+
from transformers import BertModel
15+
from transformers.utils.fx import symbolic_trace as transformers_trace
16+
17+
assertions = unittest.TestCase()
18+
19+
20+
@pytest.mark.unit
21+
def test_custom_model():
22+
class net(nn.Module):
23+
def __init__(self):
24+
super().__init__()
25+
self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
26+
self.bn = nn.BatchNorm2d(12)
27+
self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
28+
self.fc1 = nn.Linear(12 * 56 * 56, 10)
29+
30+
def forward(self, x, b=5, c=None, d=None):
31+
x = self.conv1(x)
32+
x = F.relu(x)
33+
x = self.bn(x)
34+
x = F.max_pool2d(x, (2, 2))
35+
x = self.conv2(x)
36+
x = F.relu(x)
37+
x = F.max_pool2d(x, (2, 2))
38+
x = torch.flatten(x, 1)
39+
x = x + b
40+
if c is not None:
41+
x = x * c
42+
if d is not None:
43+
x = x - d["value"]
44+
return self.fc1(x)
45+
46+
model = net().eval().to("cuda")
47+
args = [torch.rand((1, 3, 224, 224)).to("cuda")]
48+
kwargs = {
49+
"b": torch.tensor(6).to("cuda"),
50+
"d": {"value": torch.tensor(8).to("cuda")},
51+
}
52+
53+
compile_spec = {
54+
"inputs": args,
55+
"kwarg_inputs": kwargs,
56+
"device": torchtrt.Device("cuda:0"),
57+
"enabled_precisions": {torch.float},
58+
"pass_through_build_failures": True,
59+
"optimization_level": 1,
60+
"min_block_size": 1,
61+
"ir": "dynamo",
62+
}
63+
# TODO: Support torchtrt.compile
64+
# trt_mod = torchtrt.compile(model, **compile_spec)
65+
66+
exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
67+
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
68+
cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
69+
assertions.assertTrue(
70+
cos_sim > COSINE_THRESHOLD,
71+
msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
72+
)
73+
74+
# Save the module
75+
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
76+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
77+
# Clean up model env
78+
torch._dynamo.reset()

0 commit comments

Comments
 (0)