Skip to content

Commit 60b3e51

Browse files
committed
chore: update hw_compat
1 parent 4323e36 commit 60b3e51

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# This script is used to generate hw_compat.ts file that's used in test_hw_compat.py
2+
# Generate the model on a different hardware compared to the one you're testing on to
3+
# verify HW compatibility feature.
4+
5+
import torch
6+
import torch_tensorrt
7+
8+
9+
class MyModule(torch.nn.Module):
10+
def __init__(self):
11+
super().__init__()
12+
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
13+
self.relu = torch.nn.ReLU()
14+
15+
def forward(self, x):
16+
out = self.conv(x)
17+
out = self.relu(out)
18+
return out
19+
20+
21+
model = MyModule().eval().cuda()
22+
inputs = torch.randn((1, 3, 224, 224)).to("cuda")
23+
24+
trt_gm = torch_tensorrt.compile(
25+
model,
26+
ir="dynamo",
27+
inputs=inputs,
28+
min_block_size=1,
29+
hardware_compatible=True,
30+
version_compatible=True,
31+
)
32+
trt_script_model = torch.jit.trace(trt_gm, inputs)
33+
torch.jit.save(trt_script_model, "hw_compat.ts")

tests/py/dynamo/runtime/test_hw_compat.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,14 @@ def forward(self, x):
5959
"Detected incorrect ABI version, please update this test case",
6060
)
6161
def test_hw_compat_3080_build(self):
62-
inputs = [torch.randn(5, 7).cuda()]
62+
inputs = [torch.randn(1, 3, 224, 224).cuda()]
6363

6464
cwd = os.getcwd()
6565
os.chdir(os.path.dirname(os.path.realpath(__file__)))
6666
model = torch.jit.load("../../ts/models/hw_compat.ts").cuda()
6767
out = model(*inputs)
6868
self.assertTrue(
69-
isinstance(out, tuple)
70-
and len(out) == 1
71-
and isinstance(out[0], torch.Tensor),
69+
len(out) == 1 and isinstance(out, torch.Tensor),
7270
"Invalid output detected",
7371
)
7472
os.chdir(cwd)

tests/py/ts/models/hw_compat.ts

-270 KB
Binary file not shown.

0 commit comments

Comments
 (0)