Skip to content

Commit 05627cd

Browse files
committed
chore: updates
1 parent 60b3e51 commit 05627cd

File tree

2 files changed

+13
-50
lines changed

2 files changed

+13
-50
lines changed

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ def pdist(
452452
p: float = 2,
453453
) -> Union[TRTTensor, Sequence[TRTTensor]]:
454454
shape = input.shape
455+
# Extend input from shape [N, D] to [N, 1, D]
455456
extend_input = impl.shuffle.reshape(
456457
ctx,
457458
target,
@@ -460,7 +461,18 @@ def pdist(
460461
input,
461462
shape=shape[0:1] + (1,) + shape[1:],
462463
)
463-
x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", extend_input, input)
464+
# Expand the input from [N, 1, D] to [N, N, D]
465+
x = impl.slice.expand(
466+
ctx,
467+
target,
468+
source_ir,
469+
f"{name}_sub",
470+
extend_input,
471+
(shape[0], shape[0]) + shape[1:],
472+
)
473+
# Subtract the expanded input from original input. Result shape = [N, N, D]
474+
# This matrix has the distance of each sample to every other sample and hence the shape is [N, N, D]
475+
x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", x, input)
464476

465477
if p == 0:
466478
# norm = torch.sum(x!=0, dim=2)

tests/py/dynamo/conversion/test_layer_norm_aten.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -24,31 +24,6 @@ def forward(self, x):
2424
inputs,
2525
)
2626

27-
def test_layernorm_with_dynamic_shape(self):
28-
class LayerNorm(torch.nn.Module):
29-
def forward(self, x):
30-
return torch.ops.aten.layer_norm.default(
31-
x,
32-
torch.tensor([3, 224, 224]),
33-
torch.ones((3, 224, 224)),
34-
torch.zeros((3, 224, 224)),
35-
1e-05,
36-
True,
37-
)
38-
39-
input_specs = [
40-
Input(
41-
shape=(-1, 3, 224, 224),
42-
dtype=torch.float32,
43-
shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
44-
),
45-
]
46-
47-
self.run_test_with_dynamic_shape(
48-
LayerNorm(),
49-
input_specs,
50-
)
51-
5227

5328
class TestNativeLayerNormConverter(DispatchTestCase):
5429
def test_layer_norm(self):
@@ -68,30 +43,6 @@ def forward(self, x):
6843
inputs,
6944
)
7045

71-
def test_layernorm_with_dynamic_shape(self):
72-
class LayerNorm(torch.nn.Module):
73-
def forward(self, x):
74-
return torch.ops.aten.native_layer_norm.default(
75-
x,
76-
torch.tensor([3, 224, 224]),
77-
torch.ones((3, 224, 224)),
78-
torch.zeros((3, 224, 224)),
79-
1e-05,
80-
)[0]
81-
82-
input_specs = [
83-
Input(
84-
shape=(-1, 3, 224, 224),
85-
dtype=torch.float32,
86-
shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
87-
),
88-
]
89-
90-
self.run_test_with_dynamic_shape(
91-
LayerNorm(),
92-
input_specs,
93-
)
94-
9546

9647
if __name__ == "__main__":
9748
run_tests()

0 commit comments

Comments
 (0)