Skip to content

Commit 132f132

Browse files
committed
fix: Add generic evaluator function
1 parent b7d9d5a commit 132f132

File tree

8 files changed

+67
-86
lines changed

8 files changed

+67
-86
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch_tensorrt.fx.converters import acc_ops_converters
1515
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1616

17-
from .converter_registry import dynamo_tensorrt_converter
17+
from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter
1818

1919
_LOGGER: logging.Logger = logging.getLogger(__name__)
2020

@@ -421,22 +421,26 @@ def aten_ops_to_copy_dtype(
421421
)
422422

423423

424-
@dynamo_tensorrt_converter(operator.getitem)
425-
def operator_getitem(
424+
def getitem_validator(getitem_node: Node) -> bool:
425+
from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS
426+
427+
# Getitem nodes can only be converted if their parent node also can
428+
return getitem_node.args[0] in DYNAMO_CONVERTERS
429+
430+
431+
# TODO: Subsequent evaluators should be registered here with their own validators
432+
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
433+
def generic_evaluator(
426434
network: TRTNetwork,
427435
target: Target,
428436
args: Tuple[Argument, ...],
429437
kwargs: Dict[str, Argument],
430438
name: str,
431439
) -> Union[TRTTensor, Sequence[TRTTensor]]:
432-
return impl.evaluators.getitem(
433-
network,
434-
target,
435-
SourceIR.ATEN,
436-
name,
437-
args[0],
438-
args[1],
440+
_LOGGER.debug(
441+
f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
439442
)
443+
return target(*args, **kwargs)
440444

441445

442446
@dynamo_tensorrt_converter(torch.ops.aten.clone.default)

py/torch_tensorrt/dynamo/conversion/converter_registry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def dynamo_tensorrt_converter(
6464
enabled: bool = True,
6565
capability_validator: Optional[Callable[[Node], bool]] = None,
6666
priority: ConverterPriority = ConverterPriority.STANDARD,
67-
) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]:
67+
) -> Callable[[Any], TRTTensor | Sequence[TRTTensor]]:
6868
"""Decorator for Dynamo TensorRT Converter
6969
7070
Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry
@@ -347,8 +347,8 @@ def unique_targets(self) -> Set[Target]:
347347
"""Returns the set of unique converter targets stored across all registries"""
348348
return set.union(*[set(registry.keys()) for registry in self.registries])
349349

350-
# TODO: Make this a static method since it does not need state
351-
def qualified_name_or_str(self, target: Target) -> str:
350+
@staticmethod
351+
def qualified_name_or_str(target: Target) -> str:
352352
"""Returns string representation of an FX Node target"""
353353
if isinstance(target, str):
354354
return target

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
condition,
77
elementwise,
88
embedding,
9-
evaluators,
109
matmul,
1110
normalization,
1211
permutation,

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import logging
12
from typing import Optional
23

34
from torch.fx.node import Target
45
from torch_tensorrt.dynamo._SourceIR import SourceIR
56
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
67
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
78

9+
LOGGER: logging.Logger = logging.getLogger(__name__)
10+
811

912
def to_copy(
1013
network: TRTNetwork,
@@ -21,3 +24,20 @@ def to_copy(
2124

2225
casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
2326
return casted_tensor
27+
28+
29+
def clone(
30+
network: TRTNetwork,
31+
target: Target,
32+
source_ir: Optional[SourceIR],
33+
name: str,
34+
input: TRTTensor,
35+
) -> TRTTensor:
36+
if not isinstance(input, TRTTensor):
37+
raise RuntimeError(
38+
f"clone received input {input} that is not a TensorRT ITensor"
39+
)
40+
41+
LOGGER.debug(f"Evaluating clone on object with name: {name}")
42+
43+
return input

py/torch_tensorrt/dynamo/conversion/impl/evaluators.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

tests/py/dynamo/converters/test_casts.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,36 @@
55
from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException
66

77

8+
class TestCloneConverter(DispatchTestCase):
9+
def test_clone_contiguous(self):
10+
class Clone(nn.Module):
11+
def forward(self, x):
12+
y = torch.clone(x, memory_format=torch.contiguous_format)
13+
return y + 1
14+
15+
inputs = [torch.randn((1, 3, 10))]
16+
self.run_test(
17+
Clone(),
18+
inputs,
19+
expected_ops={torch.ops.aten.clone.default},
20+
disable_passes=True,
21+
)
22+
23+
def test_clone_regular(self):
24+
class Clone(nn.Module):
25+
def forward(self, x):
26+
y = torch.clone(x)
27+
return y + 1
28+
29+
inputs = [torch.randn((8, 2, 10))]
30+
self.run_test(
31+
Clone(),
32+
inputs,
33+
expected_ops={torch.ops.aten.clone.default},
34+
disable_passes=True,
35+
)
36+
37+
838
class TestToCopyConverter(DispatchTestCase):
939
def test_to_copy_half(self):
1040
class ToCopyHalf(nn.Module):

tests/py/dynamo/converters/test_evaluators.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,6 @@
77
from torch.testing._internal.common_utils import run_tests
88

99

10-
class TestCloneConverter(DispatchTestCase):
11-
def test_clone_contiguous(self):
12-
class Clone(nn.Module):
13-
def forward(self, x):
14-
y = torch.clone(x, memory_format=torch.contiguous_format)
15-
return y + 1
16-
17-
inputs = [torch.randn((1, 3, 10))]
18-
self.run_test(
19-
Clone(),
20-
inputs,
21-
expected_ops={torch.ops.aten.clone.default},
22-
disable_passes=True,
23-
)
24-
25-
def test_clone_regular(self):
26-
class Clone(nn.Module):
27-
def forward(self, x):
28-
y = torch.clone(x)
29-
return y + 1
30-
31-
inputs = [torch.randn((8, 2, 10))]
32-
self.run_test(
33-
Clone(),
34-
inputs,
35-
expected_ops={torch.ops.aten.clone.default},
36-
disable_passes=True,
37-
)
38-
39-
4010
# TODO: Switch this test back to self.run_test once an implementation exists
4111
# for a converter that returns a list, such as aten.split
4212
@unittest.skip("Pending aten.split converter. Currently tested by E2E")

tests/py/dynamo/models/test_models.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def test_resnet18(ir):
2727
"ir": ir,
2828
"pass_through_build_failures": True,
2929
"optimization_level": 1,
30-
"min_block_size": 10,
3130
"ir": "torch_compile",
3231
}
3332

@@ -176,7 +175,6 @@ def test_resnet18_half(ir):
176175
"ir": ir,
177176
"pass_through_build_failures": True,
178177
"optimization_level": 1,
179-
"min_block_size": 10,
180178
"ir": "torch_compile",
181179
}
182180

0 commit comments

Comments (0)