Skip to content

Commit fbc3a7e

Browse files
committed
fix: Add generic evaluator function
1 parent 26d1051 commit fbc3a7e

File tree

10 files changed

+89
-100
lines changed

10 files changed

+89
-100
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._TRTInterpreter import * # noqa: F403
22
from .aten_ops_converters import * # noqa: F403
33
from .conversion import * # noqa: F403
4+
from .op_evaluators import * # noqa: F403
45
from .truncate_long_and_double import repair_long_or_double_inputs

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import operator
32
from typing import Any, Dict, Optional, Sequence, Tuple, Union
43

54
import tensorrt as trt
@@ -406,24 +405,6 @@ def aten_ops_to_copy_dtype(
406405
)
407406

408407

409-
@dynamo_tensorrt_converter(operator.getitem)
410-
def operator_getitem(
411-
network: TRTNetwork,
412-
target: Target,
413-
args: Tuple[Argument, ...],
414-
kwargs: Dict[str, Argument],
415-
name: str,
416-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
417-
return impl.evaluators.getitem(
418-
network,
419-
target,
420-
SourceIR.ATEN,
421-
name,
422-
args[0],
423-
args[1],
424-
)
425-
426-
427408
@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
428409
def aten_ops_clone(
429410
network: TRTNetwork,
@@ -432,7 +413,7 @@ def aten_ops_clone(
432413
kwargs: Dict[str, Argument],
433414
name: str,
434415
) -> Union[TRTTensor, Sequence[TRTTensor]]:
435-
return impl.evaluators.clone(
416+
return impl.cast.clone(
436417
network,
437418
target,
438419
SourceIR.ATEN,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
import tensorrt as trt
66
import torch
7-
from torch.fx.node import Target, _get_qualified_name
7+
from torch.fx.node import Target
88
from torch_tensorrt.fx.converters.converter_utils import (
99
Frameworks,
1010
unified_dtype_converter,
1111
)
1212
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
1313

1414
from .._SourceIR import SourceIR
15+
from .converter_registry import ConverterRegistry
1516

1617
_LOGGER: logging.Logger = logging.getLogger(__name__)
1718

@@ -94,15 +95,12 @@ def cast_trt_tensor(
9495

9596
if input_val.dtype != trt_dtype:
9697
source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
97-
target_name = (
98-
f"{source_ir}_ops{'.' + target if target else ''}"
99-
if (isinstance(target, str))
100-
else f"{source_ir}_ops.{_get_qualified_name(target)}"
101-
)
98+
target_str = ConverterRegistry.qualified_name_or_str(target)
99+
target_name = f"{source_ir}_ops{'.' + target_str if target_str else ''}"
102100

103101
identity_layer = network.add_identity(input_val)
104102
identity_layer.set_output_type(0, trt_dtype)
105-
identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} -{name}-[{target_name}]-[{name}]"
103+
identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
106104
return identity_layer.get_output(0)
107105
else:
108106
return input_val

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
condition,
77
elementwise,
88
embedding,
9-
evaluators,
109
matmul,
1110
normalization,
1211
permutation,

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import logging
12
from typing import Optional
23

34
from torch.fx.node import Target
45
from torch_tensorrt.dynamo._SourceIR import SourceIR
56
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
67
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
78

9+
LOGGER: logging.Logger = logging.getLogger(__name__)
10+
811

912
def to_copy(
1013
network: TRTNetwork,
@@ -21,3 +24,20 @@ def to_copy(
2124

2225
casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
2326
return casted_tensor
27+
28+
29+
def clone(
30+
network: TRTNetwork,
31+
target: Target,
32+
source_ir: Optional[SourceIR],
33+
name: str,
34+
input: TRTTensor,
35+
) -> TRTTensor:
36+
if not isinstance(input, TRTTensor):
37+
raise RuntimeError(
38+
f"clone received input {input} that is not a TensorRT ITensor"
39+
)
40+
41+
LOGGER.debug(f"Evaluating clone on object with name: {name}")
42+
43+
return input

py/torch_tensorrt/dynamo/conversion/impl/evaluators.py

Lines changed: 0 additions & 40 deletions
This file was deleted.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import logging
2+
import operator
3+
from typing import Dict, Sequence, Tuple, Union
4+
5+
from torch.fx.node import Argument, Node, Target
6+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
7+
8+
from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter
9+
10+
_LOGGER: logging.Logger = logging.getLogger(__name__)
11+
12+
13+
def getitem_validator(getitem_node: Node) -> bool:
14+
from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS
15+
16+
# Getitem nodes can only be converted if their parent node also can
17+
return getitem_node.args[0] in DYNAMO_CONVERTERS
18+
19+
20+
# TODO: Subsequent evaluators should be registered here with their own validators
21+
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
22+
def generic_evaluator(
23+
network: TRTNetwork,
24+
target: Target,
25+
args: Tuple[Argument, ...],
26+
kwargs: Dict[str, Argument],
27+
name: str,
28+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
29+
_LOGGER.debug(
30+
f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
31+
)
32+
return target(*args)

tests/py/dynamo/converters/test_casts.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,36 @@
55
from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException
66

77

8+
class TestCloneConverter(DispatchTestCase):
9+
def test_clone_contiguous(self):
10+
class Clone(nn.Module):
11+
def forward(self, x):
12+
y = torch.clone(x, memory_format=torch.contiguous_format)
13+
return y + 1
14+
15+
inputs = [torch.randn((1, 3, 10))]
16+
self.run_test(
17+
Clone(),
18+
inputs,
19+
expected_ops={torch.ops.aten.clone.default},
20+
disable_passes=True,
21+
)
22+
23+
def test_clone_regular(self):
24+
class Clone(nn.Module):
25+
def forward(self, x):
26+
y = torch.clone(x)
27+
return y + 1
28+
29+
inputs = [torch.randn((8, 2, 10))]
30+
self.run_test(
31+
Clone(),
32+
inputs,
33+
expected_ops={torch.ops.aten.clone.default},
34+
disable_passes=True,
35+
)
36+
37+
838
class TestToCopyConverter(DispatchTestCase):
939
def test_to_copy_half(self):
1040
class ToCopyHalf(nn.Module):

tests/py/dynamo/converters/test_evaluators.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,6 @@
77
from torch.testing._internal.common_utils import run_tests
88

99

10-
class TestCloneConverter(DispatchTestCase):
11-
def test_clone_contiguous(self):
12-
class Clone(nn.Module):
13-
def forward(self, x):
14-
y = torch.clone(x, memory_format=torch.contiguous_format)
15-
return y + 1
16-
17-
inputs = [torch.randn((1, 3, 10))]
18-
self.run_test(
19-
Clone(),
20-
inputs,
21-
expected_ops={torch.ops.aten.clone.default},
22-
disable_passes=True,
23-
)
24-
25-
def test_clone_regular(self):
26-
class Clone(nn.Module):
27-
def forward(self, x):
28-
y = torch.clone(x)
29-
return y + 1
30-
31-
inputs = [torch.randn((8, 2, 10))]
32-
self.run_test(
33-
Clone(),
34-
inputs,
35-
expected_ops={torch.ops.aten.clone.default},
36-
disable_passes=True,
37-
)
38-
39-
4010
# TODO: Switch this test back to self.run_test once an implementation exists
4111
# for a converter that returns a list, such as aten.split
4212
@unittest.skip("Pending aten.split converter. Currently tested by E2E")

tests/py/dynamo/models/test_models.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def test_resnet18(ir):
2727
"ir": ir,
2828
"pass_through_build_failures": True,
2929
"optimization_level": 1,
30-
"min_block_size": 10,
3130
"ir": "torch_compile",
3231
}
3332

@@ -176,7 +175,6 @@ def test_resnet18_half(ir):
176175
"ir": ir,
177176
"pass_through_build_failures": True,
178177
"optimization_level": 1,
179-
"min_block_size": 10,
180178
"ir": "torch_compile",
181179
}
182180

0 commit comments

Comments
 (0)