
Commit 3e41f59

scatter reduce decomposition
1 parent 39f8255 commit 3e41f59

File tree

3 files changed: +508 -0 lines changed


py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,7 @@
 from typing import Optional, Union

 import numpy as np
+import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
@@ -17,6 +18,7 @@
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
 from torch_tensorrt.fx.converters.converter_utils import broadcast
@@ -67,6 +69,11 @@ def trunc_div(
         prod_output,
     )

+    # Cast sign_output back to int32 for trunc_div. This is required for
+    # scatter_reduce_.two(reduce='mean'), where trunc_div casts it to float32 while the TRTInterpreter expects int32.
+    if isinstance(sign_output, TRTTensor) and sign_output.dtype == trt.float32:
+        sign_output = cast_trt_tensor(ctx, sign_output, trt.int32, name)
+
     # Convert constant input into ITensor for UnaryOperation
     if not isinstance(input, trt.tensorrt.ITensor):
         input = get_trt_tensor(ctx, input, f"{name}_input")
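For intuition on the cast above, here is a minimal plain-PyTorch sketch (no TensorRT involved; the values are illustrative only): the mean path of the scatter_reduce decomposition divides an integer running sum by an integer count with rounding_mode="trunc", so the operands need to stay int32 rather than be promoted to float32.

import torch

# Hypothetical running sum and per-element count, as the mean path would accumulate them
summed = torch.tensor([7, 9, 4], dtype=torch.int32)
counts = torch.tensor([2, 3, 1], dtype=torch.int32)

# Truncated integer division keeps the int32 dtype end to end
mean = torch.div(summed, counts, rounding_mode="trunc")
print(mean, mean.dtype)  # tensor([3, 3, 4], dtype=torch.int32) torch.int32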

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 94 additions & 0 deletions
@@ -1,4 +1,5 @@
 import logging
+from enum import Enum, auto
 from typing import Any, Callable, Dict, List, Optional

 import torch
@@ -285,6 +286,99 @@ def scatter_add_decomposition(
     return scatter_add_tensor


+# Enum class for the reduce operations of scatter_reduce
+class ReduceOperation(Enum):
+    SUM = ("Sum reduce operation", lambda x, y: torch.add(x, y))
+    PROD = ("Product reduce operation", lambda x, y: torch.mul(x, y))
+    MEAN = ("Mean reduce operation", lambda x, y: torch.add(x, y))
+    AMAX = ("Amax reduce operation", lambda x, y: torch.max(x, y))
+    AMIN = ("Amin reduce operation", lambda x, y: torch.min(x, y))
+
+    def __new__(cls, description, func):
+        obj = object.__new__(cls)
+        obj._value_ = auto()
+        obj.description = description
+        obj.func = func
+        return obj
+
+    def reduce_operation_with_scatter(
+        self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
+    ):
+        scatter_tensor = None
+        if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
+            scatter_tensor = torch.zeros_like(initial_tensor)
+        elif self == ReduceOperation.PROD:
+            scatter_tensor = torch.ones_like(initial_tensor)
+        elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
+            scatter_tensor = initial_tensor
+        else:
+            # This case would not be encountered from torch itself
+            print("Invalid Operation for Reduce op!!")
+
+        operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
+        device = to_torch_device(default_device())
+        operation_lhs = operation_lhs.to(device)
+        operation_rhs = operation_rhs.to(device)
+        return self.func(operation_lhs, operation_rhs)
+
+
+@register_torch_trt_decomposition(
+    torch.ops.aten.scatter_reduce.two, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scatter_reduce_decomposition(
+    input_tensor: torch.Tensor,
+    dim: int,
+    index: torch.Tensor,
+    src_tensor: torch.Tensor,
+    reduce: str,
+) -> torch.Tensor:
+    scatter_loop_tensor = input_tensor
+    # running element count, required for the mean reduce operation
+    scatter_count_tensor = torch.zeros_like(input_tensor)
+    src_shape = list(src_tensor.shape)
+    src_dim = src_shape[dim]
+
+    for i in range(0, src_dim):
+        src_slice = torch.select(src_tensor, dim, i)
+        index_slice = torch.select(index, dim, i)
+        # unsqueeze src and index in dim
+        src_slice = torch.unsqueeze(src_slice, dim)
+        index_slice = torch.unsqueeze(index_slice, dim)
+        device = to_torch_device(default_device())
+
+        # move tensors to the default device
+        scatter_loop_tensor = scatter_loop_tensor.to(device)
+        index_slice = index_slice.to(device)
+        src_slice = src_slice.to(device)
+        if reduce == "sum":
+            reduceOp = ReduceOperation.SUM
+        elif reduce == "prod":
+            reduceOp = ReduceOperation.PROD
+        elif reduce == "mean":
+            reduceOp = ReduceOperation.MEAN
+            scatter_count_tensor = reduceOp.reduce_operation_with_scatter(
+                scatter_count_tensor,
+                input_tensor,
+                dim,
+                index_slice,
+                torch.ones_like(src_slice),
+            )
+        elif reduce == "amax":
+            reduceOp = ReduceOperation.AMAX
+        elif reduce == "amin":
+            reduceOp = ReduceOperation.AMIN
+        scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
+            scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
+        )
+    if reduce == "mean":
+        scatter_loop_tensor = torch.div(
+            scatter_loop_tensor,
+            torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
+            rounding_mode="trunc",
+        )
+    return scatter_loop_tensor
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
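As a hedged sanity check (not part of this commit), the sketch below mirrors the per-slice loop of scatter_reduce_decomposition for reduce="sum" in plain PyTorch and compares it against torch.Tensor.scatter_reduce with its default include_self=True; shapes and values are made up for illustration.

import torch

inp = torch.zeros(3, 5, dtype=torch.int32)
src = torch.arange(1, 11, dtype=torch.int32).reshape(2, 5)
index = torch.tensor([[0, 1, 2, 0, 1], [2, 0, 1, 1, 2]])
dim = 0

# Reference result from torch itself
expected = inp.scatter_reduce(dim, index, src, reduce="sum")

# Mirror of the decomposition's loop: scatter each slice of src into a zero
# tensor, then accumulate into the running result with torch.add.
out = inp.clone()
for i in range(src.shape[dim]):
    src_slice = torch.unsqueeze(torch.select(src, dim, i), dim)
    index_slice = torch.unsqueeze(torch.select(index, dim, i), dim)
    scattered = torch.scatter(torch.zeros_like(out), dim, index_slice, src_slice)
    out = torch.add(out, scattered)

assert torch.equal(out, expected)

The mean path works the same way, but additionally scatters torch.ones_like(src_slice) into scatter_count_tensor on each iteration and finishes with torch.div(..., rounding_mode="trunc") as in the diff above, which is what the int32 cast added to trunc_div supports.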
