Skip to content

Commit fb4edbb

Browse files
committed
Added kwarg support for conver_module_to_engine
1 parent 0ca7d42 commit fb4edbb

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tests/py/dynamo/models/test_models_export_kwargs.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# type: ignore
2+
import os
3+
import tempfile
24
import unittest
35

46
import pytest
@@ -62,12 +64,15 @@ def forward(self, x, b=5, c=None, d=None):
6264
# trt_mod = torchtrt.compile(model, **compile_spec)
6365

6466
exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
65-
trt_mod = torchtrt.dynamo.compile(exp_program, **compile_spec)
66-
cos_sim = cosine_similarity(model(*args, **kwargs), trt_mod(*args, **kwargs)[0])
67+
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
68+
cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
6769
assertions.assertTrue(
6870
cos_sim > COSINE_THRESHOLD,
6971
msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
7072
)
7173

74+
# Save the module
75+
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
76+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
7277
# Clean up model env
7378
torch._dynamo.reset()

0 commit comments

Comments
 (0)