Skip to content

Commit 6fcccf1

Browse files
authored
feat: support aten._cdist_forward converter (#2726)
1 parent e0fb192 commit 6fcccf1

File tree

3 files changed

+318
-0
lines changed

3 files changed

+318
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,6 +2186,26 @@ def aten_ops_linear(
21862186
)
21872187

21882188

2189+
@dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default)
2190+
def aten_ops_cdist_forward(
2191+
ctx: ConversionContext,
2192+
target: Target,
2193+
args: Tuple[Argument, ...],
2194+
kwargs: Dict[str, Argument],
2195+
name: str,
2196+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2197+
return impl.normalization.cdist_forward(
2198+
ctx,
2199+
target,
2200+
SourceIR.ATEN,
2201+
name,
2202+
x1=args[0],
2203+
x2=args[1],
2204+
p=args[2],
2205+
compute_mode=args_bounds_check(args, 3, None),
2206+
)
2207+
2208+
21892209
def avg_pool_param_validator(pool_node: Node) -> bool:
21902210
ceil_mode = args_bounds_check(pool_node.args, 4, False)
21912211
divisor_override = args_bounds_check(pool_node.args, 6)

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
23

34
import numpy as np
@@ -21,6 +22,8 @@
2122
from torch_tensorrt.fx.types import TRTTensor
2223
from torch_tensorrt.fx.utils import get_dynamic_dims
2324

25+
_LOGGER: logging.Logger = logging.getLogger(__name__)
26+
2427

2528
def batch_norm(
2629
ctx: ConversionContext,
@@ -446,3 +449,201 @@ def pdist(
446449
)
447450
indices = np.triu_indices(shape[0], k=1)
448451
return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
452+
453+
454+
def cdist_forward(
455+
ctx: ConversionContext,
456+
target: Target,
457+
source_ir: Optional[SourceIR],
458+
name: str,
459+
x1: TRTTensor,
460+
x2: TRTTensor,
461+
p: float,
462+
compute_mode: Optional[int],
463+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
464+
"""
465+
Computes pairwise distances between sets of vectors in tensors x1 and x2 using the p-norm. The function treats the last dimension
466+
of x1 and x2 as feature dimensions, which must be identical for both inputs. The second-to-last dimensions can differ, reflecting
467+
the number of vectors in each tensor. The dimensions preceding the last are considered as batch dimensions, and pairwise distances
468+
are computed for each matching set in these dimensions.
469+
470+
The output tensor's shape is derived by matching the batch dimensions of x1 and x2, where the mismatched batch dimensions are
471+
merged, and the resulting shape reflects the computed distances for each pair of vectors. It's crucial that the batch dimensions
472+
(except for the size of sets of vectors to compare) of x1 and x2 either match or one of them is 1 (broadcasting).
473+
474+
Args:
475+
x1 (Tensor): input tensor of shape B x P x M.
476+
x2 (Tensor): input tensor of shape B x R x M.
477+
p (float): p value for the p-norm distance to calculate between each vector pair
478+
compute_mode (int): Controls the computation method based on the size of the input sets:
479+
- None ('use_mm_for_euclid_dist_if_necessary'): Default mode. Uses matrix multiplication to calculate
480+
Euclidean distance (p=2) if either the number of vectors in x1 or x2 exceeds 25 (P > 25 or R > 25).
481+
- 1 ('use_mm_for_euclid_dist'): Always use matrix multiplication approach to calculate
482+
euclidean distance (p = 2)
483+
- 2 ('donot_use_mm_for_euclid_dist'): Never use matrix multiplication approach to calculate
484+
euclidean distance (p = 2)
485+
486+
Example:
487+
- If x1.shape = [2, 3, 10, 5] and x2.shape = [2, 3, 20, 5], both having the same batch dimensions [2, 3], the output shape will be [2, 3, 10, 20].
488+
This represents computing distances in two batches of three groups, each comparing 10 vectors from x1 with 20 vectors from x2.
489+
- For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors, each of 5 features),
490+
since there are no batch dimensions to match, the output shape is simply [10, 20], comparing all vectors from x1 against all vectors from x2.
491+
492+
Note: The `compute_mode` parameter is designed to optimize the performance of the Euclidean distance calculation,
493+
especially useful when working with large datasets. This parameter allows you to control how the distances are computed,
494+
with different modes available to leverage matrix multiplication for speed improvements.
495+
"""
496+
if compute_mode is None:
497+
compute_mode = 0
498+
499+
x1_expand_shape = list(x1.shape[:-1]) + [1, x1.shape[-1]]
500+
x2_expand_shape = list(x2.shape[:-2]) + [1] + list(x2.shape[-2:])
501+
502+
# Reshape x1 and x2 for broadcasting
503+
x1_expanded = impl.shuffle.reshape(
504+
ctx, target, source_ir, f"{name}_x1_expand", x1, x1_expand_shape
505+
)
506+
x2_expanded = impl.shuffle.reshape(
507+
ctx, target, source_ir, f"{name}_x2_expand", x2, x2_expand_shape
508+
)
509+
510+
diff = impl.elementwise.sub(
511+
ctx, target, source_ir, f"{name}_diff", x1_expanded, x2_expanded
512+
)
513+
514+
if p == 0:
515+
diff_non_zero = impl.elementwise.ne(
516+
ctx, target, source_ir, f"{name}_diff_non_zero", diff, 0
517+
)
518+
diff_non_zero = cast_trt_tensor(
519+
ctx, diff_non_zero, torch.float32, f"{name}_cast", target, source_ir
520+
)
521+
dist = impl.reduce.sum(
522+
ctx,
523+
target,
524+
source_ir,
525+
f"{name}_sum",
526+
diff_non_zero,
527+
dim=-1,
528+
keepdim=False,
529+
)
530+
elif p == 1:
531+
abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff)
532+
dist = impl.reduce.sum(
533+
ctx, target, source_ir, f"{name}_sum", abs_val, dim=-1, keepdim=False
534+
)
535+
elif p == 2:
536+
if (
537+
compute_mode == 0 and (x1.shape[-2] > 25 or x2.shape[-2] > 25)
538+
) or compute_mode == 1:
539+
# Compute squared elements
540+
x1_squared = impl.elementwise.pow(
541+
ctx, target, source_ir, f"{name}_x1_squared", x1, 2
542+
)
543+
x2_squared = impl.elementwise.pow(
544+
ctx, target, source_ir, f"{name}_x2_squared", x2, 2
545+
)
546+
547+
# Sum squares along the last dimension
548+
x1_sum_squared = impl.reduce.sum(
549+
ctx,
550+
target,
551+
source_ir,
552+
f"{name}_x1_sum",
553+
x1_squared,
554+
dim=-1,
555+
keepdim=True,
556+
)
557+
x2_sum_squared = impl.reduce.sum(
558+
ctx,
559+
target,
560+
source_ir,
561+
f"{name}_x2_sum",
562+
x2_squared,
563+
dim=-1,
564+
keepdim=True,
565+
)
566+
567+
# Reshape sums for broadcasting
568+
rank = len(x2.shape)
569+
permute_shape = list(range(rank - 2)) + [rank - 1, rank - 2]
570+
x1_sum_expanded = x1_sum_squared
571+
x2_sum_expanded = impl.permutation.permute(
572+
ctx, target, source_ir, f"{name}_permute", x2_sum_squared, permute_shape
573+
)
574+
575+
# Compute dot product of x1 and transposed x2
576+
x2_tr = impl.permutation.permute(
577+
ctx, target, source_ir, f"{name}_permute_mm", x2, permute_shape
578+
)
579+
dot_product = impl.matmul.matrix_multiply(
580+
ctx,
581+
target,
582+
source_ir,
583+
f"{name}_dot_product",
584+
x1,
585+
x2_tr,
586+
input_matrix_op=trt.MatrixOperation.NONE,
587+
other_matrix_op=trt.MatrixOperation.NONE,
588+
)
589+
590+
# Combine results to get squared distances
591+
dist_squared = impl.elementwise.add(
592+
ctx,
593+
target,
594+
source_ir,
595+
f"{name}_dist_squared_initial",
596+
x1_sum_expanded,
597+
x2_sum_expanded,
598+
)
599+
dist_squared = impl.elementwise.sub(
600+
ctx,
601+
target,
602+
source_ir,
603+
f"{name}_dist_squared",
604+
dist_squared,
605+
impl.elementwise.mul(
606+
ctx, target, source_ir, f"{name}_dot_product_scaled", dot_product, 2
607+
),
608+
)
609+
610+
# Compute the Euclidean distances
611+
dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_dist", dist_squared)
612+
else:
613+
diff_squared = impl.elementwise.pow(
614+
ctx, target, source_ir, f"{name}_diff_squared", diff, 2
615+
)
616+
dist_squared = impl.reduce.sum(
617+
ctx,
618+
target,
619+
source_ir,
620+
f"{name}_dist_sq_sum",
621+
diff_squared,
622+
dim=-1,
623+
keepdim=False,
624+
)
625+
dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_sqrt", dist_squared)
626+
elif 0 < p < 1 or 1 < p < 2 or 2 < p < float("inf"):
627+
abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff)
628+
pow_val = impl.elementwise.pow(
629+
ctx, target, source_ir, f"{name}_pow_val_1", abs_val, p
630+
)
631+
sum_val = impl.reduce.sum(
632+
ctx, target, source_ir, f"{name}_sum", pow_val, dim=-1, keepdim=False
633+
)
634+
dist = impl.elementwise.pow(
635+
ctx, target, source_ir, f"{name}_pow_val_2", sum_val, 1 / p
636+
)
637+
elif p == float("inf"):
638+
abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff)
639+
dist = impl.reduce.max(
640+
ctx,
641+
target,
642+
source_ir,
643+
f"{name}_max",
644+
abs_val,
645+
dim=-1,
646+
keepdim=False,
647+
return_indices=False,
648+
)
649+
return dist
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestCdistConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("p_0", (4, 3, 4), 0, 0),
13+
("p>0_p<1_1", (10, 3, 5, 2, 6), 0.5, 1),
14+
("p>0_p<1_2", (10, 2, 15, 2, 7, 2), 0.5, 1),
15+
("p_1", (15, 10, 5), 1, None),
16+
("p>1_p<2", (19, 11, 5), 1.5, None),
17+
("small_p_2_mode_1", (6, 6, 5), 2.0, 1),
18+
("large_p_2_mode_0", (35, 35, 5), 2.0, 0),
19+
("p>2", (15, 10, 5), 2.99, None),
20+
("p_inf", (5, 15, 5), float("inf"), 0),
21+
]
22+
)
23+
def test_cdist_float_same_shape(self, name, shape, p, compute_mode):
24+
class Cdist(nn.Module):
25+
def forward(self, x1, x2):
26+
return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)
27+
28+
inputs = [torch.randn(shape), torch.randn(shape)]
29+
self.run_test(
30+
Cdist(),
31+
inputs,
32+
)
33+
34+
@parameterized.expand(
35+
[
36+
("p_0", (1, 5), (2, 3, 5), 0, 0),
37+
("p_1", (4, 5), (2, 3, 5), 1, None),
38+
("diff_shape_p_0", (2, 5, 4, 5), (2, 5, 8, 5), 0, 2),
39+
("diff_shape_p_1", (2, 4, 5), (2, 3, 5), 1, 1),
40+
("p>0_p<1", (2, 2, 4, 5), (2, 3, 5), 0.5, None),
41+
("p>1_p<2", (5, 2, 12, 5), (2, 3, 5), 1.5, 1),
42+
("p_2", (2, 2, 14, 5), (2, 3, 5), 2, 0),
43+
("p>2", (2, 2, 4, 5), (2, 10, 5), 2.99, 2),
44+
("p_inf", (2, 2, 3, 5), (2, 8, 5), float("inf"), None),
45+
]
46+
)
47+
def test_cdist_float_broadcast_and_diff_shape(
48+
self, name, shape_1, shape_2, p, compute_mode
49+
):
50+
class Cdist(nn.Module):
51+
def forward(self, x1, x2):
52+
return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)
53+
54+
inputs = [torch.randn(shape_1), torch.randn(shape_2)]
55+
self.run_test(
56+
Cdist(),
57+
inputs,
58+
)
59+
60+
@parameterized.expand(
61+
[
62+
("compute_mode_0", (15, 10, 5), (15, 35, 5), 2.0, 0),
63+
("compute_mode_1", (35, 35, 5), (35, 45, 5), 2.0, 0),
64+
("compute_mode_2", (15, 10, 5), (15, 35, 5), 2.0, 1),
65+
("compute_mode_3", (35, 35, 5), (35, 45, 5), 2.0, 2),
66+
("p_2_mm_shape_1", (2, 2, 14, 5), (3, 5), 2, 1),
67+
("p_2_mm_shape_2", (2, 2, 14, 5), (2, 3, 5), 2, 1),
68+
("p_2_mm_shape_3", (2, 2, 14, 5), (2, 2, 3, 5), 2, 1),
69+
]
70+
)
71+
def test_cdist_p_2_compute_mode(self, name, shape_1, shape_2, p, compute_mode):
72+
class Cdist(nn.Module):
73+
def forward(self, x1, x2):
74+
return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)
75+
76+
inputs = [torch.randn(shape_1), torch.randn(shape_2)]
77+
self.run_test(Cdist(), inputs)
78+
79+
@parameterized.expand(
80+
[
81+
("p_2_matmul", (50, 40, 30, 30), (50, 40, 35, 30), 2, 1),
82+
("p_2_elementwise_pow", (50, 40, 30, 50), (50, 40, 35, 50), 2, 2),
83+
]
84+
)
85+
def test_cdist_efficiency_p_2_compute_mode(
86+
self, name, shape_1, shape_2, p, compute_mode
87+
):
88+
class Cdist(nn.Module):
89+
def forward(self, x1, x2):
90+
return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)
91+
92+
inputs = [torch.randn(shape_1), torch.randn(shape_2)]
93+
self.run_test(Cdist(), inputs)
94+
95+
96+
if __name__ == "__main__":
97+
run_tests()

0 commit comments

Comments
 (0)