Skip to content

Commit ce3fa67

Browse files
apbosegs-olive
authored andcommitted
Converter reorg fmod
1 parent 7158ca5 commit ce3fa67

File tree

3 files changed

+43
-27
lines changed

3 files changed

+43
-27
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
2929
from torch_tensorrt.fx.converters.impl import activation
3030
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
31+
from torch_tensorrt.fx.converters.impl.elementwise import fmod
3132
from torch_tensorrt.fx.converters.impl.unary import sign
3233
from torch_tensorrt.fx.converters.impl.elementwise.base import (
3334
convert_binary_elementwise,
@@ -2091,34 +2092,14 @@ def acc_ops_fmod(
20912092
kwargs: Dict[str, Argument],
20922093
name: str,
20932094
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2094-
# NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it
2095-
trunc_div_value = trunc_div(
2095+
return fmod(
20962096
network,
20972097
target,
20982098
SourceIR.ACC,
2099-
name + "_trunc_div",
2099+
name,
21002100
kwargs["input"],
21012101
kwargs["other"],
21022102
)
2103-
prod_value = convert_binary_elementwise(
2104-
network,
2105-
target,
2106-
SourceIR.ACC,
2107-
name + "_prod",
2108-
trt.ElementWiseOperation.PROD,
2109-
trunc_div_value,
2110-
kwargs["other"],
2111-
)
2112-
sub_value = convert_binary_elementwise(
2113-
network,
2114-
target,
2115-
SourceIR.ACC,
2116-
name + "_sub",
2117-
trt.ElementWiseOperation.SUB,
2118-
kwargs["input"],
2119-
prod_value,
2120-
)
2121-
return sub_value
21222103

21232104

21242105
# T113156424 embedding implemenatation is very limited and shows no usage in hf models due to the indices are int64.

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from torch_tensorrt.fx.converters.impl import activation
2424
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
2525
from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
26+
from torch_tensorrt.fx.converters.impl.elementwise import fmod
2627

2728
_LOGGER: logging.Logger = logging.getLogger(__name__)
2829

@@ -219,11 +220,7 @@ def aten_ops_fmod(
219220
kwargs: Dict[str, Argument],
220221
name: str,
221222
) -> Union[TRTTensor, Sequence[TRTTensor]]:
222-
kwargs_new = {
223-
"input": args[0],
224-
"other": args[1],
225-
}
226-
return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
223+
return fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
227224

228225

229226
@tensorrt_converter(torch.ops.aten.hardtanh.default)

py/torch_tensorrt/fx/converters/impl/elementwise/ops.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,41 @@ def rsqrt(
139139
)
140140

141141
return output
142+
143+
144+
def fmod(
145+
network: TRTNetwork,
146+
target: Target,
147+
source_ir: Optional[SourceIR],
148+
name: str,
149+
input: TRTTensor,
150+
other: TRTTensor,
151+
) -> TRTTensor:
152+
# NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it
153+
trunc_div_value = trunc_div(
154+
network,
155+
target,
156+
source_ir,
157+
name + "_trunc_div",
158+
input,
159+
other,
160+
)
161+
prod_value = convert_binary_elementwise(
162+
network,
163+
target,
164+
source_ir,
165+
name + "_prod",
166+
trt.ElementWiseOperation.PROD,
167+
trunc_div_value,
168+
other,
169+
)
170+
sub_value = convert_binary_elementwise(
171+
network,
172+
target,
173+
SourceIR.ACC,
174+
name + "_sub",
175+
trt.ElementWiseOperation.SUB,
176+
input,
177+
prod_value,
178+
)
179+
return sub_value

0 commit comments

Comments
 (0)