Skip to content

Commit add1567

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Update XNN upsample tests to handle undecomposed op
Differential Revision: D68374352
1 parent fedb035 commit add1567

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

backends/xnnpack/_passes/convert_to_upsample_bilinear2d.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def create_upsample_bilinear_2d(
2323
align_corners: bool,
2424
):
2525
output = internal_match.returning_nodes[0]
26+
if output.target == exir_ops.edge.aten.upsample_bilinear2d.vec:
27+
# Op was not decomposed, do nothing
28+
return
29+
2630
output_shape = output.meta["val"].shape
2731
output_h = output_shape[-2]
2832
output_w = output_shape[-1]

backends/xnnpack/test/ops/test_bilinear2d.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,16 @@ def test_fp32_bilinear2d_dynamic_bilinear2d_not_partitioned(self):
131131
3: torch.export.Dim("w", min=1, max=12),
132132
}
133133
}
134-
(
134+
artifact_str = str(
135135
Tester(self.StaticResizeBilinear2dModule(), example_inputs)
136136
.export(Export(dynamic_shapes))
137137
.to_edge_transform_and_lower()
138-
# NOTE The decomposition is partially delegated. This will need to be replaced
139-
# with the aten upsample op once decomp is removed.
140-
.check("executorch_exir_dialects_edge__ops_aten_index_Tensor")
138+
.get_artifact()
139+
.exported_program()
140+
)
141+
# NOTE The decomposition can be partially delegated. This will need to be replaced
142+
# with the aten upsample op once decomp is removed.
143+
self.assertTrue(
144+
"executorch_exir_dialects_edge__ops_aten_index_Tensor" in artifact_str or
145+
"executorch_exir_dialects_edge__ops_aten_upsample_bilinear2d_vec" in artifact_str
141146
)

exir/emit/test/test_emit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ class M(torch.nn.Module):
642642
def forward(self, x):
643643
return torch.nn.functional.interpolate(x, scale_factor=2)
644644

645-
x = (torch.randn(1, 1, 2, 2),)
645+
x = (torch.randn(1, 1, 2, 2, 2),)
646646
program = (
647647
to_edge(export(M(), x, strict=True)).to_executorch().executorch_program
648648
)

0 commit comments

Comments
 (0)