Commit 4b67dc9
Arm backend: Do not delegate casting to FP dtypes with BI profile (#10906)
- Casting to floating-point dtypes should be rejected for delegation. The ToCopySupported class should guarantee this. However, the shallow copy used in _merge_supported_types unintentionally modified the SUPPORTED_INT_TYPES dict, merging SUPPORTED_FLOAT_TYPES into it. As a result, casts to floating-point dtypes could also pass the check under the BI profile.
- Fix this by using deepcopy.
- Add a unit test in test_to_copy.py to check that casts to FP dtypes are not delegated.

Signed-off-by: Yufeng Shi <[email protected]>
1 parent d069d65 commit 4b67dc9
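For context, here is a minimal standalone sketch of the aliasing bug this commit fixes. The dict contents are simplified stand-ins for SUPPORTED_INT_TYPES and SUPPORTED_FLOAT_TYPES in to_copy_support.py; the merge logic mirrors _merge_supported_types:

import copy

# Simplified stand-ins for the real SUPPORTED_INT_TYPES / SUPPORTED_FLOAT_TYPES,
# which map an input dtype to the list of target dtypes a cast may produce.
SUPPORTED_INT_TYPES = {"int8": ["int16", "int32"]}
SUPPORTED_FLOAT_TYPES = {"int8": ["float32"]}

def merge_shallow(dtypes1, dtypes2):
    merged = dtypes1  # BUG: `merged` aliases dtypes1; updates leak into the caller's dict
    for k, v in dtypes2.items():
        merged[k] = merged.get(k, []) + v
    return merged

def merge_deep(dtypes1, dtypes2):
    merged = copy.deepcopy(dtypes1)  # fix: mutate an independent copy
    for k, v in dtypes2.items():
        merged[k] = merged.get(k, []) + v
    return merged

merge_shallow(SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES)
print(SUPPORTED_INT_TYPES)  # {'int8': ['int16', 'int32', 'float32']} -- float dtype leaked in

SUPPORTED_INT_TYPES = {"int8": ["int16", "int32"]}  # reset for the fixed version
merge_deep(SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES)
print(SUPPORTED_INT_TYPES)  # {'int8': ['int16', 'int32']} -- unchanged

Worth noting: because the merge loop builds new lists (merged.get(k, []) + v) rather than mutating them in place, a shallow dict(dtypes1) copy would also have fixed this particular bug; deepcopy is the more defensive choice should the merge logic ever mutate values in place.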

File tree

2 files changed: +59 −4 lines

backends/arm/operator_support/to_copy_support.py (4 additions, 1 deletion)

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe
+import copy
 import logging

 import torch
@@ -42,7 +43,9 @@ def _merge_supported_types(
     dtypes1: SupportedTypeDict,
     dtypes2: SupportedTypeDict,
 ) -> SupportedTypeDict:
-    merged_dtypes = dtypes1
+    merged_dtypes = copy.deepcopy(
+        dtypes1
+    )  # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_TYPES
     for k, v in dtypes2.items():
         merged_dtypes[k] = merged_dtypes.get(k, []) + v
     return merged_dtypes

backends/arm/test/ops/test_to_copy.py (55 additions, 3 deletions)

@@ -12,7 +12,10 @@
 import torch

 from executorch.backends.arm.test import common
-from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineMI
+from executorch.backends.arm.test.tester.test_pipeline import (
+    OpNotSupportedPipeline,
+    TosaPipelineMI,
+)

 input_t1 = Tuple[torch.Tensor]  # Input x

@@ -31,11 +34,14 @@ def forward(self, x: torch.Tensor):

 Only test unquantized graphs as explicit casting of dtypes messes with the
 quantization.
+However, the model being exported may contain explicit casts to floating-point
+dtypes. These casts, or their decompositions, should be rejected during
+partitioning. This case is covered by TestToCopy_BI.

 Note: This is also covered by test_scalars.py.
 """

-_TO_COPY_TEST_DATA = {
+_TO_COPY_TEST_DATA_MI = {
     "rand_fp16": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.float32),
     "rand_fp32": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.float16),
     "rand_int8": lambda: (
@@ -53,7 +59,7 @@ def forward(self, x: torch.Tensor):
 }


-@common.parametrize("test_data", _TO_COPY_TEST_DATA)
+@common.parametrize("test_data", _TO_COPY_TEST_DATA_MI)
 def test_copy_tosa_MI(test_data: Tuple):
     test_tensor, new_dtype = test_data()

@@ -64,3 +70,49 @@ def test_copy_tosa_MI(test_data: Tuple):
         exir_op=[],
     )
     pipeline.run()
+
+
+"""
+Casting operations that output floating-point dtypes should be rejected under the
+BI profile, rather than introducing an invalid dtype into the TOSA graph.
+For example, x.to(dtype=torch.float32) will eventually be lowered to
+exir_ops.edge.dim_order_ops._to_dim_order_copy.default. We should reject this
+operation in ToCopySupported::is_node_tosa_supported() before it enters the
+delegated graph.
+"""
+_TO_COPY_TEST_DATA_BI = {
+    "rand_int8_fp32": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8),
+        torch.float32,
+    ),
+    "rand_int16_fp32": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int16),
+        torch.float32,
+    ),
+    "rand_int32_fp32": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32),
+        torch.float32,
+    ),
+    "rand_int32_fp16": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32),
+        torch.float16,
+    ),
+    "rand_int32_bf16": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32),
+        torch.bfloat16,
+    ),
+}
+
+
+@common.parametrize("test_data", _TO_COPY_TEST_DATA_BI)
+def test_copy_tosa_BI(test_data: Tuple):
+    test_tensor, new_dtype = test_data()
+
+    pipeline = OpNotSupportedPipeline[input_t1](
+        Cast(new_dtype),
+        (test_tensor,),
+        {
+            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
+        },
+        quantize=True,
+    )
+    pipeline.run()
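To make the intended behavior concrete, below is an illustrative sketch of the kind of check the new test exercises. This is not the actual ExecuTorch API: Profile, is_cast_supported, and the dtype tables here are hypothetical simplifications. The point is that under the integer-only BI profile, any cast whose output dtype is floating point must be rejected so it never enters the delegated graph, while the MI profile may accept it:

from enum import Enum, auto

import torch

class Profile(Enum):
    BI = auto()  # base inference: integer-only TOSA profile
    MI = auto()  # main inference: floating point allowed

# Hypothetical, simplified dtype tables for illustration only.
SUPPORTED_INT_TYPES = (torch.int8, torch.int16, torch.int32)
SUPPORTED_FLOAT_TYPES = (torch.float16, torch.bfloat16, torch.float32)

def is_cast_supported(output_dtype: torch.dtype, profile: Profile) -> bool:
    """Reject casts whose output dtype is not allowed by the active profile."""
    if profile is Profile.BI:
        return output_dtype in SUPPORTED_INT_TYPES
    return output_dtype in SUPPORTED_INT_TYPES + SUPPORTED_FLOAT_TYPES

assert not is_cast_supported(torch.float32, Profile.BI)  # rejected: FP output under BI
assert is_cast_supported(torch.int16, Profile.BI)        # allowed: int output under BI
assert is_cast_supported(torch.float32, Profile.MI)      # allowed: FP output under MI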
