Skip to content

Commit 5b60d27

Browse files
committed
delete the old file
1 parent de86e07 commit 5b60d27

File tree

2 files changed

+48
-82
lines changed

2 files changed

+48
-82
lines changed

tests/py/dynamo/models/test_export_kwargs_serde.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
import torch_tensorrt as torchtrt
1111
import torchvision.models as models
1212
from torch import nn
13+
from torch_tensorrt.dynamo._compiler import convert_module_to_trt_engine
1314
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
14-
from transformers import BertModel
15-
from transformers.utils.fx import symbolic_trace as transformers_trace
1615

1716
assertions = unittest.TestCase()
1817

@@ -60,8 +59,6 @@ def forward(self, x, b=5, c=None, d=None):
6059
"min_block_size": 1,
6160
"ir": "dynamo",
6261
}
63-
# TODO: Support torchtrt.compile
64-
# trt_mod = torchtrt.compile(model, **compile_spec)
6562

6663
exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
6764
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
@@ -76,3 +73,50 @@ def forward(self, x, b=5, c=None, d=None):
7673
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
7774
# Clean up model env
7875
torch._dynamo.reset()
76+
77+
78+
def test_custom_model_compile_engine():
79+
class net(nn.Module):
80+
def __init__(self):
81+
super().__init__()
82+
self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
83+
self.bn = nn.BatchNorm2d(12)
84+
self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
85+
self.fc1 = nn.Linear(12 * 56 * 56, 10)
86+
87+
def forward(self, x, b=5, c=None, d=None):
88+
x = self.conv1(x)
89+
x = F.relu(x)
90+
x = self.bn(x)
91+
x = F.max_pool2d(x, (2, 2))
92+
x = self.conv2(x)
93+
x = F.relu(x)
94+
x = F.max_pool2d(x, (2, 2))
95+
x = torch.flatten(x, 1)
96+
x = x + b
97+
if c is not None:
98+
x = x * c
99+
if d is not None:
100+
x = x - d["value"]
101+
return self.fc1(x)
102+
103+
model = net().eval().to("cuda")
104+
args = [torch.rand((1, 3, 224, 224)).to("cuda")]
105+
kwargs = {
106+
"b": torch.tensor(6).to("cuda"),
107+
"d": {"value": torch.tensor(8).to("cuda")},
108+
}
109+
110+
compile_spec = {
111+
"inputs": args,
112+
"kwarg_inputs": kwargs,
113+
"device": torchtrt.Device("cuda:0"),
114+
"enabled_precisions": {torch.float},
115+
"pass_through_build_failures": True,
116+
"optimization_level": 1,
117+
"min_block_size": 1,
118+
"ir": "dynamo",
119+
}
120+
121+
exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
122+
engine = convert_module_to_trt_engine(exp_program, **compile_spec)

tests/py/dynamo/models/test_models_export_kwargs.py

Lines changed: 0 additions & 78 deletions
This file was deleted.

0 commit comments

Comments
 (0)