Commit 3e8d735

feat: InstanceNorm decomposition (#3288)
1 parent afb1516 commit 3e8d735

2 files changed: +89 -0 lines changed


py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 20 additions & 0 deletions
@@ -400,6 +400,26 @@ def log_softmax_decomposition(
     )
 
 
+@register_torch_trt_decomposition(aten.instance_norm, registry=TORCH_TRT_DECOMPOSITIONS)
+def instance_norm_decomposition(
+    input: torch.Tensor,
+    weight: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    running_mean: Optional[torch.Tensor],
+    running_var: Optional[torch.Tensor],
+    use_input_stats: bool,
+    momentum: float,
+    eps: float,
+    cudnn_enabled: bool,
+) -> torch.Tensor:
+    if use_input_stats:
+        return torch.nn.functional.group_norm(input, input.shape[1], weight, bias, eps)
+    else:
+        return torch.nn.functional.batch_norm(
+            input, running_mean, running_var, weight, bias, False, momentum, eps
+        )
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
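
For context (not part of the diff): the decomposition leans on two standard identities. With use_input_stats=True, instance norm equals group norm with one group per channel; with use_input_stats=False it reduces to batch norm in eval mode using the running statistics. A minimal sketch that checks both equivalences with plain PyTorch on CPU; the shapes and tolerances below are illustrative:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, 8)
weight = torch.randn(4)
bias = torch.randn(4)
running_mean = torch.randn(4)
running_var = torch.rand(4) + 0.5  # keep variances away from zero for the check
eps = 1e-5

# use_input_stats=True path: per-sample statistics, one group per channel
a = F.instance_norm(x, weight=weight, bias=bias, use_input_stats=True, eps=eps)
b = F.group_norm(x, num_groups=x.shape[1], weight=weight, bias=bias, eps=eps)
assert torch.allclose(a, b, atol=1e-5)

# use_input_stats=False path: running statistics, identical to eval-mode batch norm
c = F.instance_norm(
    x, running_mean, running_var, weight, bias, use_input_stats=False, eps=eps
)
d = F.batch_norm(x, running_mean, running_var, weight, bias, training=False, eps=eps)
assert torch.allclose(c, d, atol=1e-5)

Any small numerical drift between the two sides comes from the kernels computing the statistics independently, hence the explicit tolerance.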

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 69 additions & 0 deletions
@@ -1587,6 +1587,75 @@ def forward(self, x):
             f"Log_softmax TRT outputs don't match with the original model.",
         )
 
+    @parameterized.expand(
+        [
+            ((1, 3, 5), True),
+            ((1, 3, 5), False),
+            ((2, 4, 6, 8), True),
+            ((2, 4, 6, 8), False),
+            ((3, 6, 9, 12, 15), True),
+            ((3, 6, 9, 12, 15), False),
+        ]
+    )
+    def test_lowering_instance_norm(self, shape, use_input_stats):
+        class TestModule(torch.nn.Module):
+            def forward(self, input, weight, bias, running_mean=None, running_var=None):
+                return torch.ops.aten.instance_norm.default(
+                    input,
+                    weight,
+                    bias,
+                    running_mean,
+                    running_var,
+                    use_input_stats,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        # Operations expected to be removed in the traced graph after decompositions
+        unexpected_ops = {torch.ops.aten.instance_norm.default}
+
+        inputs = [
+            torch.randn(shape, device="cuda"),
+            torch.randn(shape[1], device="cuda"),
+            torch.randn(shape[1], device="cuda"),
+        ]
+        if not use_input_stats:
+            inputs += [
+                torch.randn(shape[1], device="cuda"),
+                torch.rand(shape[1], device="cuda"),
+            ]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+        unexpected_ops_seen, _ = lower_graph_testing(
+            fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph, "dynamo", inputs, min_block_size=1
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "Instance_norm TRT outputs don't match with the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
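
Beyond the op-level test above, a minimal end-to-end sketch of how the new lowering gets exercised: a module using nn.InstanceNorm2d (default track_running_stats=False, so the group-norm path applies) compiled through the dynamo frontend. The network, shapes, and final difference check are illustrative, a CUDA device and a working Torch-TensorRT install are assumed, and the compile call mirrors the pattern used in the test:

import torch
import torch_tensorrt


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        # InstanceNorm2d defaults to track_running_stats=False -> use_input_stats=True
        self.norm = torch.nn.InstanceNorm2d(8, affine=True)

    def forward(self, x):
        return torch.relu(self.norm(self.conv(x)))


model = Net().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# Same compile pattern as the test above: dynamo IR, min_block_size=1
trt_model = torch_tensorrt.compile(model, "dynamo", inputs, min_block_size=1)
print(torch.max(torch.abs(trt_model(*inputs) - model(*inputs))).item())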
