Skip to content

Commit ef97199

Browse files
committed
scatter reduce decomposition
1 parent 0d4af77 commit ef97199

File tree

2 files changed

+479
-0
lines changed

2 files changed

+479
-0
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from enum import Enum, auto
23
from typing import Any, Callable, Dict, List, Optional
34

45
import torch
@@ -243,6 +244,84 @@ def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
243244
)
244245

245246

247+
# enum class for reduce operation of scatter_reduce
248+
class reduceOperation(Enum):
249+
SUM = ("Sum reduce operation", lambda x, y: torch.add(x, y))
250+
PROD = ("Product reduce operation", lambda x, y: torch.mul(x, y))
251+
MEAN = ("Mean reduce operation", lambda x, y: torch.div(torch.add(x, y), 2))
252+
AMAX = ("Amax reduce operation", lambda x, y: torch.amax(x, y))
253+
AMIN = ("Amin reduce operation", lambda x, y: torch.amin(x, y))
254+
255+
def __new__(cls, description, func):
256+
obj = object.__new__(cls)
257+
obj._value_ = auto() # Assign a unique value based on the number of members
258+
obj.description = description
259+
obj.func = func
260+
return obj
261+
262+
def reduce_operation_with_scatter(
263+
self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
264+
):
265+
scatter_tensor = None
266+
if self == reduceOperation.SUM or self == reduceOperation.MEAN:
267+
scatter_tensor = torch.zeros_like(initial_tensor)
268+
elif self == reduceOperation.PROD:
269+
scatter_tensor = torch.ones_like(initial_tensor)
270+
elif self.name == reduceOperation.AMIN or self.name == reduceOperation.AMAX:
271+
scatter_tensor = initial_tensor
272+
else:
273+
# This case would not be encountered from torch itself
274+
print("Invalid Operation for Reduce op!!")
275+
276+
operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
277+
device = to_torch_device(default_device())
278+
operation_lhs = operation_lhs.to(device)
279+
operation_rhs = operation_rhs.to(device)
280+
return self.func(operation_lhs, operation_rhs)
281+
282+
283+
@register_torch_trt_decomposition(
284+
torch.ops.aten.scatter_reduce.two, registry=TORCH_TRT_DECOMPOSITIONS
285+
)
286+
def scatter_reduce_decomposition(
287+
input_tensor: torch.Tensor,
288+
dim: int,
289+
index: torch.Tensor,
290+
src_tensor: torch.Tensor,
291+
reduce: str,
292+
) -> torch.Tensor:
293+
scatter_loop_tensor = input_tensor
294+
src_shape = list(src_tensor.shape)
295+
src_dim = src_shape[dim]
296+
297+
for i in range(0, src_dim):
298+
src_slice = torch.select(src_tensor, dim, i)
299+
index_slice = torch.select(index, dim, i)
300+
# unsqueeze src and index in dim
301+
src_slice = torch.unsqueeze(src_slice, dim)
302+
index_slice = torch.unsqueeze(index_slice, dim)
303+
device = to_torch_device(default_device())
304+
305+
# moving tensor to default device
306+
scatter_loop_tensor = scatter_loop_tensor.to(device)
307+
index_slice = index_slice.to(device)
308+
src_slice = src_slice.to(device)
309+
if reduce == "sum":
310+
reduceOp = reduceOperation.SUM
311+
elif reduce == "prod":
312+
reduceOp = reduceOperation.PROD
313+
elif reduce == "mean":
314+
reduceOp = reduceOperation.MEAN
315+
elif reduce == "amax":
316+
reduceOp = reduceOperation.AMAX
317+
elif reduce == "amin":
318+
reduceOp = reduceOperation.AMIN
319+
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
320+
scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
321+
)
322+
return scatter_loop_tensor
323+
324+
246325
def get_decompositions(
247326
enable_experimental_decompositions: bool = False,
248327
) -> Dict[OpOverload, Callable[[Any], Any]]:

0 commit comments

Comments
 (0)