
Commit 2cc8308

feat: support _pdist_forward dynamo converter
1 parent: 8ae9eff

3 files changed: +127 -0 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 0 deletions
@@ -2596,3 +2596,26 @@ def aten_ops_remainder(
         args[0],
         args[1],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten._pdist_forward.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_pdist(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.pdist(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args_bounds_check(args, 1, 2),
+    )
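
For context, a quick eager-mode sketch of the op this converter registers for. This is illustrative only, not part of the commit; it assumes torch.nn.functional.pdist as the eager reference, which dispatches to aten._pdist_forward for 2-D inputs:

import torch

# aten._pdist_forward(input, p) returns the pairwise p-norm distances
# between the n rows of a 2-D tensor, flattened into n*(n-1)/2 values.
x = torch.randn(4, 3)
out = torch.ops.aten._pdist_forward.default(x, 2.0)
print(out.shape)  # torch.Size([6]) -- 4*3/2 row pairs
assert torch.allclose(out, torch.nn.functional.pdist(x, p=2.0))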

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 68 additions & 0 deletions
@@ -8,6 +8,7 @@
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
     get_positive_dim,
     get_trt_tensor,
     to_numpy,
@@ -440,3 +441,70 @@ def get_softmax_dim(ndim: int) -> int:
     layer.axes = 1 << dim
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def pdist(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    p: float = 2,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    shape = input.shape
+    extend_input = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape",
+        input,
+        shape=shape[0:1] + (1,) + shape[1:],
+    )
+    x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", extend_input, input)
+
+    if p == 0:
+        # norm = torch.sum(x != 0, dim=2)
+        nonzero_val = impl.elementwise.ne(ctx, target, source_ir, f"{name}_ne", x, 0)
+        norm = impl.reduce.sum(
+            ctx, target, source_ir, f"{name}_sum", nonzero_val, dim=2, keepdim=False
+        )
+        norm = cast_trt_tensor(
+            ctx, norm, torch.float32, f"{name}_cast", target, source_ir
+        )
+    elif p == 1:
+        # norm = torch.sum(torch.abs(x), dim=2)
+        abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x)
+        norm = impl.reduce.sum(
+            ctx, target, source_ir, f"{name}_sum", abs_val, dim=2, keepdim=False
+        )
+    elif 0 < p < 1 or 1 < p < float("inf"):
+        # norm = torch.pow(torch.sum(torch.pow(torch.abs(x), p), dim=2), 1/p)
+        abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x)
+        pow_val = impl.elementwise.pow(
+            ctx, target, source_ir, f"{name}_pow1", abs_val, p
+        )
+        sum_val = impl.reduce.sum(
+            ctx, target, source_ir, f"{name}_sum", pow_val, dim=2, keepdim=False
+        )
+        norm = impl.elementwise.pow(
+            ctx, target, source_ir, f"{name}_pow2", sum_val, 1 / p
+        )
+    elif p == float("inf"):
+        # norm = torch.max(torch.abs(x), dim=2)
+        abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x)
+        norm = impl.reduce.max(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_max",
+            abs_val,
+            dim=2,
+            keepdim=False,
+            return_indices=False,
+        )
+    else:
+        raise RuntimeError(
+            f"p should be in the range [0, inf], currently p={p} is not supported!"
+        )
+    indices = np.triu_indices(shape[0], k=1)
+    return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
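
The TRT-layer calls above follow a simple broadcast-and-reduce formulation of pdist. A minimal eager PyTorch sketch of the same algorithm (illustrative only, not part of the commit; pdist_reference is a hypothetical name) may make the branches easier to follow:

import torch

def pdist_reference(x: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    # (n, 1, d) - (n, d) broadcasts to (n, n, d): all pairwise row differences.
    diff = x.unsqueeze(1) - x
    if p == 0:
        # Count of non-matching coordinates per pair.
        norm = (diff != 0).sum(dim=2).to(torch.float32)
    elif p == float("inf"):
        norm = diff.abs().max(dim=2).values
    else:
        # Covers p == 1 as well: |a|^1 summed and raised to 1/1 is the L1 case.
        norm = diff.abs().pow(p).sum(dim=2).pow(1.0 / p)
    # Keep each unordered pair once: strict upper triangle of the n x n matrix.
    rows, cols = torch.triu_indices(x.shape[0], x.shape[0], offset=1)
    return norm[rows, cols]

For p = 2 this matches torch.nn.functional.pdist(x); the converter performs the same upper-triangle gather inside the TensorRT graph via np.triu_indices and impl.select.index.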
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestPdistConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((2, 3), 0),
+            ((2, 3), 0.4),
+            ((2, 3), 1),
+            ((2, 3), 1.5),
+            ((3, 4), 2),
+            ((3, 4), 2.99),
+            ((4, 5), 3),
+            ((4, 5), 3.3),
+            ((5, 6), float("inf")),
+        ]
+    )
+    def test_pdist_float(self, shape, p):
+        class Pdist(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten._pdist_forward.default(input, p)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Pdist(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
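
Because the module calls run_tests() under __main__, the new cases can also be run standalone; a hypothetical invocation (the test file's name is not shown in this view, so the path below is a placeholder):

# run the file directly
python path/to/test_pdist.py

# or select the converter tests by keyword via pytest
python -m pytest -k "pdist" -v

Each parameterized entry expands into its own test method, so the (shape, p) grid above exercises p = 0, fractional p, integer p, and p = inf.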
