scatter reduce decomposition #3008

Merged 5 commits on Sep 11, 2024.

7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -1,6 +1,7 @@
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
@@ -17,6 +18,7 @@
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
    convert_binary_elementwise,
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
from torch_tensorrt.fx.converters.converter_utils import broadcast
@@ -67,6 +69,11 @@ def trunc_div(
        prod_output,
    )

    # Cast sign_output back to int32 for trunc_div. This is required for
    # scatter_reduce_.two with reduce='mean', where trunc_div casts the
    # result to float32 while TRTInterpreter expects int32.
    if (isinstance(sign_output, TRTTensor)) and (sign_output.dtype == trt.float32):
        sign_output = cast_trt_tensor(ctx, sign_output, trt.int32, name)

    # Convert constant input into ITensor for UnaryOperation
    if not isinstance(input, trt.tensorrt.ITensor):
        input = get_trt_tensor(ctx, input, f"{name}_input")
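The cast exists because the converter computes trunc division through float arithmetic. A minimal standalone PyTorch sketch of the dtype mismatch it guards against (float_based_trunc_div is a hypothetical illustration, not the converter code):

import torch

def float_based_trunc_div(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Trunc division via float math, roughly sign(a/b) * floor(|a/b|):
    # the intermediate result is float32 even for integer inputs.
    q = a.to(torch.float32) / b.to(torch.float32)
    out = torch.sign(q) * torch.floor(torch.abs(q))
    # Without a cast back, int32 inputs yield a float32 output; this is
    # the mismatch the added cast_trt_tensor call avoids in the TRT graph.
    if a.dtype == torch.int32 and b.dtype == torch.int32:
        out = out.to(torch.int32)
    return out

a = torch.tensor([7, -7], dtype=torch.int32)
b = torch.tensor([2, 2], dtype=torch.int32)
print(float_based_trunc_div(a, b))                   # tensor([ 3, -3], dtype=torch.int32)
print(torch.div(a, b, rounding_mode="trunc").dtype)  # torch.int32
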
96 changes: 96 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -1,4 +1,5 @@
import logging
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional

import torch
@@ -287,6 +288,101 @@ def scatter_add_decomposition(
    return scatter_add_tensor


# enum class for reduce operation of scatter_reduce
class ReduceOperation(Enum):
    SUM = ("Sum reduce operation", lambda x, y: torch.add(x, y))
    PROD = ("Product reduce operation", lambda x, y: torch.mul(x, y))
    MEAN = ("Mean reduce operation", lambda x, y: torch.add(x, y))
    AMAX = ("Amax reduce operation", lambda x, y: torch.max(x, y))
    AMIN = ("Amin reduce operation", lambda x, y: torch.min(x, y))

    def __new__(cls, description, func):
        obj = object.__new__(cls)
        obj._value_ = auto()
        obj.description = description
        obj.func = func
        return obj

    def reduce_operation_with_scatter(
        self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
    ):
        scatter_tensor = None
        if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
            scatter_tensor = torch.zeros_like(initial_tensor)
        elif self == ReduceOperation.PROD:
            scatter_tensor = torch.ones_like(initial_tensor)
        elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
            scatter_tensor = initial_tensor
        else:
            # This case would not be encountered from torch itself
            print("Invalid Operation for Reduce op!!")

        operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
        device = to_torch_device(scatter_tensor.device)
        operation_lhs = operation_lhs.to(device)
        operation_rhs = operation_rhs.to(device)
        return self.func(operation_lhs, operation_rhs)


@register_torch_trt_decomposition(
    torch.ops.aten.scatter_reduce.two, registry=TORCH_TRT_DECOMPOSITIONS
)
def scatter_reduce_decomposition(
    input_tensor: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src_tensor: torch.Tensor,
    reduce: str,
    include_self: bool = True,
) -> torch.Tensor:
[Review comment from a Contributor]
There is a kwarg include_self in https://github.com/pytorch/pytorch/blob/bc1b8f094d24de27432f4c29f0729e85a6b5ba63/aten/src/ATen/native/native_functions.yaml#L8237. Is it intentionally not handled in our decomposition?

[Reply from apbose (Collaborator, Author), Sep 10, 2024]
Thanks for the review! Most of the cases I have seen use include_self=True, so the implementation here covers the default case. No particular reason; I could add cases with include_self=False.

[Review comment from a Collaborator]
Add include_self=True to the function arguments, and raise an error saying we don't support the case where the user sets it to False.
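For context, include_self controls whether the destination tensor's original values take part in the reduction. A quick illustration with stock PyTorch:

import torch

dst = torch.tensor([10.0, 20.0, 30.0])
src = torch.tensor([1.0, 2.0, 3.0])
idx = torch.tensor([0, 0, 1])

# include_self=True: destination values join the reduction.
print(dst.scatter_reduce(0, idx, src, reduce="sum", include_self=True))
# tensor([13., 23., 30.])  i.e. [10+1+2, 20+3, 30]

# include_self=False: destination values at scattered positions are ignored.
print(dst.scatter_reduce(0, idx, src, reduce="sum", include_self=False))
# tensor([ 3.,  3., 30.])  i.e. [1+2, 3, 30]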

    scatter_loop_tensor = input_tensor
    device_input_tensor = input_tensor.device
    # required for mean reduce operation
    scatter_count_tensor = torch.zeros_like(input_tensor)
    src_shape = list(src_tensor.shape)
    src_dim = src_shape[dim]
    if include_self == False:
        raise AssertionError("include_self False for scatter reduce not yet supported")
    for i in range(0, src_dim):
        src_slice = torch.select(src_tensor, dim, i)
        index_slice = torch.select(index, dim, i)
        # unsqueeze src and index in dim
        src_slice = torch.unsqueeze(src_slice, dim)
        index_slice = torch.unsqueeze(index_slice, dim)

        # moving tensor to default device
        scatter_loop_tensor = scatter_loop_tensor.to(device_input_tensor)
        index_slice = index_slice.to(device_input_tensor)
        src_slice = src_slice.to(device_input_tensor)
        if reduce == "sum":
            reduceOp = ReduceOperation.SUM
        elif reduce == "prod":
            reduceOp = ReduceOperation.PROD
        elif reduce == "mean":
            reduceOp = ReduceOperation.MEAN
            scatter_count_tensor = reduceOp.reduce_operation_with_scatter(
                scatter_count_tensor,
                input_tensor,
                dim,
                index_slice,
                torch.ones_like(src_slice),
            )
        elif reduce == "amax":
            reduceOp = ReduceOperation.AMAX
        elif reduce == "amin":
            reduceOp = ReduceOperation.AMIN
        scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
            scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
        )
    if reduce == "mean":
        scatter_loop_tensor = torch.div(
            scatter_loop_tensor,
            torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
            rounding_mode="trunc",
        )
    return scatter_loop_tensor


def get_decompositions(
    enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
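As a sanity check on what the loop above computes, the decomposition can be replayed in eager PyTorch and compared against stock torch.scatter_reduce. A minimal sketch, assuming int32 inputs and the supported include_self=True case (variable names are illustrative, not from the PR):

import torch

inp = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
src = torch.tensor([5, 6, 7, 8], dtype=torch.int32)
idx = torch.tensor([0, 0, 1, 1])

expected = inp.scatter_reduce(0, idx, src, reduce="mean", include_self=True)

# Replay the decomposition: scatter one src slice at a time, accumulating
# a running sum and a per-element count, then trunc-divide at the end
# (count + 1 accounts for the self value included in the mean).
acc = inp.clone()
count = torch.zeros_like(inp)
for i in range(src.shape[0]):
    s = src[i].unsqueeze(0)
    j = idx[i].unsqueeze(0)
    acc = acc + torch.scatter(torch.zeros_like(acc), 0, j, s)
    count = count + torch.scatter(torch.zeros_like(count), 0, j, torch.ones_like(s))
result = torch.div(acc, count + 1, rounding_mode="trunc")

print(expected)  # tensor([4, 5, 3, 4], dtype=torch.int32)
print(result)    # matches expected
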
12 changes: 8 additions & 4 deletions py/torch_tensorrt/dynamo/utils.py
@@ -6,16 +6,17 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
+import tensorrt as trt
 import torch
 from torch._subclasses.fake_tensor import FakeTensor
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings

-import tensorrt as trt
 from packaging import version

 from .types import TRTDataType
@@ -186,11 +187,14 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
     device = None
     for parameter in list(module.parameters()):
         if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)):
-            device = parameter.device
-            break
+            return parameter.device
+
+    for buffer in list(module.buffers()):
+        if isinstance(buffer, (torch.Tensor)):
+            return buffer.device
[Review comment from a Collaborator, on lines +190 to +192]
The buffer device overrides the parameter device here, which shouldn't be the case. Check the device of the parameters first; if not found, use the buffers. Also consider adding a break once the device is found.

[Review comment from a Collaborator]
nvm

     if device is None:
-        device = torch.device("cpu")
+        device = to_torch_device(default_device())
         logger.warning(
             "Could not detect the device on which the model exists. Assuming the model is on CPU"
         )
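The fix establishes a lookup order: parameters first, then buffers, then the configured default. A minimal standalone sketch of that order (model_device is a hypothetical stand-in for get_model_device):

import torch

def model_device(module: torch.nn.Module) -> torch.device:
    # Prefer parameter devices, fall back to buffers, then to a default,
    # mirroring the patched get_model_device.
    for p in module.parameters():
        return p.device
    for b in module.buffers():
        return b.device
    return torch.device("cpu")  # stand-in for to_torch_device(default_device())

bn = torch.nn.BatchNorm1d(4, affine=False)  # no parameters, but has buffers
print(model_device(bn))  # cpu, found via the running_mean buffer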