Skip to content

Commit bcf7641

Browse files
authored
fix: Register tensorrt backend name (#2311)
1 parent 33c0673 commit bcf7641

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
logger = logging.getLogger(__name__)
3232

3333

34+
@td.register_backend(name="tensorrt") # type: ignore[misc]
3435
@td.register_backend(name="torch_tensorrt") # type: ignore[misc]
3536
def torch_tensorrt_backend(
3637
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any

tests/py/dynamo/backend/test_backend_compiler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,5 +312,13 @@ def forward(self, x, y):
312312
)
313313

314314

315+
class TestRegistration(TestCase):
316+
def test_torch_tensorrt_registration(self):
317+
self.assertIn("torch_tensorrt", torch._dynamo.list_backends())
318+
319+
def test_tensorrt_registration(self):
320+
self.assertIn("tensorrt", torch._dynamo.list_backends())
321+
322+
315323
if __name__ == "__main__":
316324
run_tests()

0 commit comments

Comments
 (0)