
Commit 4283f43

fix: Add generic evaluator function

1 parent b7d9d5a

File tree

10 files changed: +86 −95 lines
py/torch_tensorrt/dynamo/conversion/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 from ._TRTInterpreter import * # noqa: F403
 from .aten_ops_converters import * # noqa: F403
 from .conversion import * # noqa: F403
+from .op_evaluators import * # noqa: F403
 from .truncate_long_and_double import repair_long_or_double_inputs

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 20 deletions

@@ -1,5 +1,4 @@
 import logging
-import operator
 from typing import Any, Dict, Optional, Sequence, Tuple, Union

 import tensorrt as trt
@@ -421,24 +420,6 @@ def aten_ops_to_copy_dtype(
     )


-@dynamo_tensorrt_converter(operator.getitem)
-def operator_getitem(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.evaluators.getitem(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.clone.default)
 def aten_ops_clone(
     network: TRTNetwork,
@@ -447,7 +428,7 @@ def aten_ops_clone(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.evaluators.clone(
+    return impl.cast.clone(
         network,
         target,
         SourceIR.ATEN,

py/torch_tensorrt/dynamo/conversion/converter_registry.py

Lines changed: 2 additions & 2 deletions

@@ -347,8 +347,8 @@ def unique_targets(self) -> Set[Target]:
         """Returns the set of unique converter targets stored across all registries"""
         return set.union(*[set(registry.keys()) for registry in self.registries])

-    # TODO: Make this a static method since it does not need state
-    def qualified_name_or_str(self, target: Target) -> str:
+    @staticmethod
+    def qualified_name_or_str(target: Target) -> str:
         """Returns string representation of an FX Node target"""
         if isinstance(target, str):
             return target
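Note: with qualified_name_or_str now a @staticmethod, it can be called directly on the class without a registry instance. A minimal sketch (the printed strings are illustrative, not verified output):

import operator

import torch
from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry

# String targets are returned unchanged; callable targets resolve to a qualified name.
print(ConverterRegistry.qualified_name_or_str("output"))                      # "output"
print(ConverterRegistry.qualified_name_or_str(operator.getitem))              # e.g. "_operator.getitem"
print(ConverterRegistry.qualified_name_or_str(torch.ops.aten.clone.default))  # e.g. "aten.clone.default"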

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -6,7 +6,6 @@
     condition,
     elementwise,
     embedding,
-    evaluators,
     matmul,
     normalization,
     permutation,

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 20 additions & 0 deletions

@@ -1,10 +1,13 @@
+import logging
 from typing import Optional

 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

+LOGGER: logging.Logger = logging.getLogger(__name__)
+

 def to_copy(
     network: TRTNetwork,
@@ -21,3 +24,20 @@ def to_copy(

     casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
     return casted_tensor
+
+
+def clone(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"clone received input {input} that is not a TensorRT ITensor"
+        )
+
+    LOGGER.debug(f"Evaluating clone on object with name: {name}")
+
+    return input
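For reference, a minimal sketch of how a converter can delegate to the new impl.cast.clone helper (mirroring aten_ops_clone in aten_ops_converters.py above); convert_clone here is a hypothetical wrapper, not part of this commit. Since clone adds no TensorRT layer, the helper only validates the input and returns it unchanged:

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def convert_clone(  # hypothetical example wrapper
    network: TRTNetwork, target: Target, input: TRTTensor, name: str
) -> TRTTensor:
    # Validates that `input` is a TensorRT ITensor, logs, and returns it as-is.
    return impl.cast.clone(network, target, SourceIR.ATEN, name, input)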

py/torch_tensorrt/dynamo/conversion/impl/evaluators.py

Lines changed: 0 additions & 40 deletions
This file was deleted.
py/torch_tensorrt/dynamo/conversion/op_evaluators.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+import logging
+import operator
+from typing import Dict, Sequence, Tuple, Union
+
+from torch.fx.node import Argument, Node, Target
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def getitem_validator(getitem_node: Node) -> bool:
+    from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS
+
+    # Getitem nodes can only be converted if their parent node also can
+    return getitem_node.args[0] in DYNAMO_CONVERTERS
+
+
+# TODO: Subsequent evaluators should be registered here with their own validators
+@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+def generic_evaluator(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    _LOGGER.debug(
+        f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
+    )
+    return target(*args, **kwargs)
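To illustrate the behavior: because generic_evaluator's target is a plain Python callable, conversion simply executes it on the already-converted arguments, and getitem_validator only lets it claim a getitem node whose producer (getitem_node.args[0]) has a registered converter. A minimal sketch using plain Python stand-ins for converted TensorRT tensors:

import operator

# Stand-ins for the tuple of outputs a multi-output converter might produce.
args = (("trt_tensor_0", "trt_tensor_1"), 1)
kwargs = {}

# Equivalent of `return target(*args, **kwargs)` with target = operator.getitem
result = operator.getitem(*args, **kwargs)
assert result == "trt_tensor_1"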

tests/py/dynamo/converters/test_casts.py

Lines changed: 30 additions & 0 deletions

@@ -5,6 +5,36 @@
 from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException


+class TestCloneConverter(DispatchTestCase):
+    def test_clone_contiguous(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                y = torch.clone(x, memory_format=torch.contiguous_format)
+                return y + 1
+
+        inputs = [torch.randn((1, 3, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
+    def test_clone_regular(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                y = torch.clone(x)
+                return y + 1
+
+        inputs = [torch.randn((8, 2, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
+
 class TestToCopyConverter(DispatchTestCase):
     def test_to_copy_half(self):
         class ToCopyHalf(nn.Module):

tests/py/dynamo/converters/test_evaluators.py

Lines changed: 0 additions & 30 deletions

@@ -7,36 +7,6 @@
 from torch.testing._internal.common_utils import run_tests


-class TestCloneConverter(DispatchTestCase):
-    def test_clone_contiguous(self):
-        class Clone(nn.Module):
-            def forward(self, x):
-                y = torch.clone(x, memory_format=torch.contiguous_format)
-                return y + 1
-
-        inputs = [torch.randn((1, 3, 10))]
-        self.run_test(
-            Clone(),
-            inputs,
-            expected_ops={torch.ops.aten.clone.default},
-            disable_passes=True,
-        )
-
-    def test_clone_regular(self):
-        class Clone(nn.Module):
-            def forward(self, x):
-                y = torch.clone(x)
-                return y + 1
-
-        inputs = [torch.randn((8, 2, 10))]
-        self.run_test(
-            Clone(),
-            inputs,
-            expected_ops={torch.ops.aten.clone.default},
-            disable_passes=True,
-        )
-
-
 # TODO: Switch this test back to self.run_test once an implementation exists
 # for a converter that returns a list, such as aten.split
 @unittest.skip("Pending aten.split converter. Currently tested by E2E")

tests/py/dynamo/models/test_models.py

Lines changed: 0 additions & 2 deletions

@@ -27,7 +27,6 @@ def test_resnet18(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "min_block_size": 10,
         "ir": "torch_compile",
     }

@@ -176,7 +175,6 @@ def test_resnet18_half(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "min_block_size": 10,
         "ir": "torch_compile",
     }
