add dynamic shape support for scaled_dot_product_attention, logical_or/xor #2975


Merged
merged 11 commits on Jul 13, 2024
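For context, a minimal sketch (not part of this PR) of the kind of workload these converters now cover: `scaled_dot_product_attention` compiled with a dynamic sequence length. The exact `Input`/`compile` options are illustrative and may differ between torch_tensorrt versions.

```python
import torch
import torch.nn.functional as F
import torch_tensorrt


class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        # is_causal=True exercises the dynamic tril/attn_bias path added in this PR
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


def dyn_input():
    # batch=1, heads=8, sequence length dynamic in [16, 256], head_dim=64
    return torch_tensorrt.Input(
        min_shape=(1, 8, 16, 64),
        opt_shape=(1, 8, 128, 64),
        max_shape=(1, 8, 256, 64),
        dtype=torch.float32,
    )


trt_mod = torch_tensorrt.compile(
    Attn().eval().cuda(), ir="dynamo", inputs=[dyn_input(), dyn_input(), dyn_input()]
)
```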
25 changes: 23 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -315,8 +315,8 @@ def aten_ops_embedding_bag(
)


@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor, supports_dynamic_shapes=True)
def aten_ops_fmod(
ctx: ConversionContext,
target: Target,
@@ -3435,3 +3435,24 @@ def aten_ops_prelu(
args[0],
args[1],
)


@dynamo_tensorrt_converter(
torch.ops.aten.arange.start_step, supports_dynamic_shapes=True
)
def aten_ops_arange_start_step(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.arange.arange(
ctx,
target,
SourceIR.ATEN,
name,
start=args[0],
end=args[1],
step=args_bounds_check(args, 2, 1),
)
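For illustration only (not a test from this PR), a tiny module that would route through the dynamic-shape path of this converter: the `end` argument of `arange` comes from a runtime dimension, so the converter receives a `TRTTensor` rather than a Python int. The compile options shown are assumptions and may need adjusting.

```python
import torch
import torch_tensorrt


class ArangeFromDim(torch.nn.Module):
    def forward(self, x):
        # `end` depends on a dynamic dimension of x, so it is symbolic at compile time
        return torch.arange(0, x.shape[1], 1, device=x.device)


inp = torch_tensorrt.Input(
    min_shape=(1, 4), opt_shape=(1, 32), max_shape=(1, 128), dtype=torch.float32
)
trt_mod = torch_tensorrt.compile(ArangeFromDim().eval().cuda(), ir="dynamo", inputs=[inp])
```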
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -3,6 +3,7 @@
from . import (
activation,
addmm,
arange,
attention,
cast,
cat,
58 changes: 58 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/arange.py
@@ -0,0 +1,58 @@
from typing import Optional, Union

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
SourceIR,
cast_trt_tensor,
get_trt_tensor,
)
from torch_tensorrt.fx.types import TRTTensor


def arange(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
start: Union[int, TRTTensor],
end: Union[int, TRTTensor],
step: Union[int, TRTTensor],
) -> TRTTensor:

if any(isinstance(tensor, TRTTensor) for tensor in (start, end, step)):
start_rank_0 = get_trt_tensor(ctx, start, name + "_start_rank_0", min_rank=0)
start_rank_1 = get_trt_tensor(ctx, start, name + "_start_rank_1", min_rank=1)
end = get_trt_tensor(ctx, end, name + "_end", min_rank=1)
step = get_trt_tensor(ctx, step, name + "_step", min_rank=1)
# Calculate shape = (end-start) / step
shape = impl.elementwise.sub(
ctx,
target,
source_ir,
name + "_sub",
end,
start_rank_1,
)
shape = impl.elementwise.trunc_div(
ctx,
target,
source_ir,
name + "_shape",
shape,
step,
)
shape = cast_trt_tensor(ctx, shape, end.dtype, name + "_shape_casted")
fill_layer = ctx.net.add_fill(
shape.shape, trt.FillOperation.LINSPACE, shape.dtype
)
fill_layer.set_input(0, shape)
# Set start index
fill_layer.set_input(1, start_rank_0)
# Set delta/step
fill_layer.set_input(2, step)
return fill_layer.get_output(0)
return np.arange(start, end, step)
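As a rough NumPy reference (not part of the diff), the arithmetic the dynamic branch encodes with TensorRT layers: the element count is `trunc((end - start) / step)` and the LINSPACE fill produces `start + i * step`. Note that `torch.arange` uses a ceiling count, so the two agree whenever `end - start` is a multiple of `step`, which holds for the causal-mask use in attention.py below (start=0, step=1).

```python
import numpy as np


def arange_reference(start, end, step):
    # element count as the converter computes it: trunc((end - start) / step)
    n = int(np.trunc((end - start) / step))
    # LINSPACE fill semantics: value[i] = start + i * step
    return start + step * np.arange(n)


assert np.array_equal(arange_reference(2, 11, 3), np.arange(2, 11, 3))  # [2, 5, 8]
```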
104 changes: 96 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -7,10 +7,44 @@
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.dynamo.conversion.converter_utils import (
SourceIR,
cast_trt_tensor,
get_trt_tensor,
)
from torch_tensorrt.fx.types import TRTTensor


def tril(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
) -> TRTTensor:
# lower triangle of the tensor: entries whose row index is greater than or equal to the column index
row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0)
col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1)
rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col)
arange_tensor = impl.arange.arange(
ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1
)
# row index of each flattened position: idx // col
row_tensor = impl.elementwise.trunc_div(
ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col
)
# column index of each flattened position: idx % col
col_tensor = impl.elementwise.fmod(
ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col
)
cond = impl.elementwise.ge(
ctx, target, source_ir, name + "_ge", row_tensor, col_tensor
)
return impl.shuffle.reshape(
ctx, target, source_ir, name + "_reshape", cond, [row, col]
)


def scaled_dot_product_attention(
ctx: ConversionContext,
target: Union[Target, str],
@@ -22,8 +56,7 @@ def scaled_dot_product_attention(
is_causal: bool,
scale: Optional[float],
) -> TRTTensor:
L, S = query.shape[-2], key.shape[-2]

# implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
mm = impl.matmul.matrix_multiply(
ctx,
target,
@@ -34,13 +67,21 @@ def scaled_dot_product_attention(
other_matrix_op=trt.MatrixOperation.TRANSPOSE,
)
if scale is None:
scale = query.shape[-1]
if scale < 0:
# dynamic shape
scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
else:
# static shape
sqrt_scaled = math.sqrt(scale)
scaled = impl.elementwise.div(
ctx,
target,
source_ir,
name + "_scale",
mm,
math.sqrt(query.shape[-1]),
sqrt_scaled,
)
else:
scaled = impl.elementwise.mul(
@@ -53,10 +94,57 @@ )
)

if is_causal:
attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
L, S = query.shape[-2], key.shape[-2]
if L >= 0 and S >= 0:
# static shape
attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
else:
# at least one of L, S is a dynamic shape
if L < 0:
L = impl.shape.shape(
ctx, target, source_ir, name + "_shape_0", query, -2
)
if S < 0:
S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2)

LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S)

# generate an int32 index tensor with L*S elements (reshaped to (L, S) below)
arange_tensor = impl.arange.arange(
ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1
)
shape_tensor = impl.shuffle.reshape(
ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S]
)

# attn_bias should be float32, so cast the index tensor to float32
shape_tensor = cast_trt_tensor(
ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir
)

# initialize attn_bias as a tensor of zeros
attn_bias = impl.elementwise.mul(
ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0
)

# generate the mask tensor
tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor)
temp_mask = impl.unary.logical_not(
ctx, target, source_ir, name + "_logical_not", tril_tensor
)
inf_tensor = impl.elementwise.mul(
ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf")
)
cond = impl.elementwise.eq(
ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True)
)
# where the mask is set, fill attn_bias with -inf
attn_bias = impl.condition.select(
ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond
)

scaled = impl.elementwise.add(
ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
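For reference, a short NumPy sketch (not from the PR) of the index arithmetic the dynamic-shape causal branch performs with TensorRT layers: a flat `arange` over L*S positions is reshaped to (L, S), row and column indices are recovered with integer division and modulo, and the lower-triangular mask selects where `-inf` is written into `attn_bias`.

```python
import numpy as np

L, S = 4, 5
idx = np.arange(L * S).reshape(L, S)   # flat positions 0 .. L*S-1, reshaped to (L, S)
row, col = idx // S, idx % S           # recover 2-D row/column indices
tril = row >= col                      # lower triangle, diagonal included
attn_bias = np.where(~tril, float("-inf"), 0.0)

assert np.array_equal(tril, np.tril(np.ones((L, S), dtype=bool)))
```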
56 changes: 2 additions & 54 deletions py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -1,22 +1,17 @@
# mypy: disallow-untyped-decorators=False

import logging
import operator
from typing import Dict, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
ConverterRegistry,
dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
get_trt_tensor,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise import sub, trunc_div
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

@@ -50,53 +45,6 @@ def generic_evaluator(
return target(*args)


@dynamo_tensorrt_converter(
torch.ops.aten.arange.start_step, supports_dynamic_shapes=True
)
def aten_ops_arange_start_step(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
# Case where inputs to arange are dynamic
if any(isinstance(tensor, TRTTensor) for tensor in args):
start_rank_0 = get_trt_tensor(ctx, args[0], name + "_start_rank_0", min_rank=0)
start_rank_1 = get_trt_tensor(ctx, args[0], name + "_start_rank_1", min_rank=1)
end = get_trt_tensor(ctx, args[1], name + "_end", min_rank=1)
step = args[2] if len(args) > 2 else 1
step = get_trt_tensor(ctx, step, name + "_step", min_rank=1)
# Calculate shape = (end-start) / step
shape = sub(
ctx,
target,
SourceIR.ATEN,
name + "_sub",
end,
start_rank_1,
)
shape = trunc_div(
ctx,
target,
SourceIR.ATEN,
name + "_shape",
shape,
step,
)
shape = cast_trt_tensor(ctx, shape, end.dtype, name + "_shape_casted")
fill_layer = ctx.net.add_fill(
shape.shape, trt.FillOperation.LINSPACE, shape.dtype
)
fill_layer.set_input(0, shape)
# Set start index
fill_layer.set_input(1, start_rank_0)
# Set delta/step
fill_layer.set_input(2, step)
return fill_layer.get_output(0)
return np.arange(*args)


def rand_validator(rand_node: Node) -> bool:
dtype = rand_node.kwargs.get("dtype", None)
layout = rand_node.kwargs.get("layout", None)