Skip to content

Commit 0676981

Browse files
committed
Converter reorg fmod
1 parent 9611d67 commit 0676981

File tree

3 files changed

+44
-28
lines changed

3 files changed

+44
-28
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
2929
from torch_tensorrt.fx.converters.impl import activation
3030
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
31+
from torch_tensorrt.fx.converters.impl.elementwise import fmod
3132
from torch_tensorrt.fx.converters.impl.unary import sign
3233
from torch_tensorrt.fx.converters.impl.elementwise.base import (
3334
convert_binary_elementwise,
@@ -2098,35 +2099,15 @@ def acc_ops_fmod(
20982099
kwargs: Dict[str, Argument],
20992100
name: str,
21002101
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2101-
# NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it
2102-
trunc_div_value = trunc_div(
2102+
return fmod(
21032103
network,
21042104
target,
21052105
SourceIR.ACC,
2106-
name + "_trunc_div",
2106+
name,
21072107
kwargs["input"],
21082108
kwargs["other"],
21092109
)
2110-
prod_value = convert_binary_elementwise(
2111-
network,
2112-
target,
2113-
SourceIR.ACC,
2114-
name + "_prod",
2115-
trt.ElementWiseOperation.PROD,
2116-
trunc_div_value,
2117-
kwargs["other"],
2118-
)
2119-
sub_value = convert_binary_elementwise(
2120-
network,
2121-
target,
2122-
SourceIR.ACC,
2123-
name + "_sub",
2124-
trt.ElementWiseOperation.SUB,
2125-
kwargs["input"],
2126-
prod_value,
2127-
)
2128-
return sub_value
2129-
2110+
21302111

21312112
# T113156424 embedding implemenatation is very limited and shows no usage in hf models due to the indices are int64.
21322113
# if we cast to int32, it will create accuracy issues. We'd better leave it to future implementation.

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
2525
from torch_tensorrt.fx.converters.impl import activation
2626
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
27+
from torch_tensorrt.fx.converters.impl.elementwise import fmod
2728

2829
_LOGGER: logging.Logger = logging.getLogger(__name__)
2930

@@ -193,11 +194,7 @@ def aten_ops_fmod(
193194
kwargs: Dict[str, Argument],
194195
name: str,
195196
) -> Union[TRTTensor, Sequence[TRTTensor]]:
196-
kwargs_new = {
197-
"input": args[0],
198-
"other": args[1],
199-
}
200-
return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
197+
return fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
201198

202199

203200
@tensorrt_converter(torch.ops.aten.linear)

py/torch_tensorrt/fx/converters/impl/elementwise/ops.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,41 @@ def trunc_div(
109109
)
110110

111111
return output
112+
113+
114+
def fmod(
115+
network: TRTNetwork,
116+
target: Target,
117+
source_ir: Optional[SourceIR],
118+
name: str,
119+
input: TRTTensor,
120+
other: TRTTensor,
121+
) -> TRTTensor:
122+
# NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it
123+
trunc_div_value = trunc_div(
124+
network,
125+
target,
126+
source_ir,
127+
name + "_trunc_div",
128+
input,
129+
other,
130+
)
131+
prod_value = convert_binary_elementwise(
132+
network,
133+
target,
134+
source_ir,
135+
name + "_prod",
136+
trt.ElementWiseOperation.PROD,
137+
trunc_div_value,
138+
other,
139+
)
140+
sub_value = convert_binary_elementwise(
141+
network,
142+
target,
143+
SourceIR.ACC,
144+
name + "_sub",
145+
trt.ElementWiseOperation.SUB,
146+
input,
147+
prod_value,
148+
)
149+
return sub_value

0 commit comments

Comments
 (0)