
Commit 10f0d22

Arm backend: Fix bug of inserting unnecessary casts for aten.where.self (#11816)
- In MatchWhereSelfDtypePass, target_dtype was initialized to fp32. This works when at least one of the inputs is fp32, but when both inputs are int32 the pass incorrectly inserts int32->fp32 casts. These casts are unnecessary and may introduce operand dtype mismatch issues.
- Fix this by initializing target_dtype with input_dtype.

Signed-off-by: Yufeng Shi <[email protected]>
1 parent bc605b8 commit 10f0d22
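The fixed dtype-selection logic described in the commit message can be sketched as follows. This is an assumed simplification: the real pass reads dtypes from FX node metadata and promotes mismatched dtypes with a helper called get_largest_dtype; torch.promote_types stands in for that helper here.

```python
import torch

def pick_target_dtype(input_dtype: torch.dtype, other_dtype: torch.dtype) -> torch.dtype:
    # Fix: start from the input dtype rather than hard-coding torch.float32,
    # so two int32 operands no longer trigger spurious int32 -> fp32 casts.
    target_dtype = input_dtype
    if input_dtype != other_dtype:
        # On a mismatch, promote to the wider of the two dtypes
        # (stand-in for the pass's get_largest_dtype helper).
        target_dtype = torch.promote_types(input_dtype, other_dtype)
    return target_dtype
```

With the old behavior, pick_target_dtype(torch.int32, torch.int32) would have returned torch.float32 and forced a cast on both operands; after the fix it returns torch.int32 and no cast is inserted.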

File tree

2 files changed (+8, -1 lines changed)


backends/arm/_passes/match_where_self_arg_dtype_pass.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def call(self, graph_module: torch.fx.GraphModule):
         input_dtype = input_.meta["val"].dtype
         other_dtype = other_.meta["val"].dtype
-        target_dtype = torch.float32
+        target_dtype = input_dtype
         if input_dtype != other_dtype:
             target_dtype = get_largest_dtype(input_dtype, other_dtype)

backends/arm/test/ops/test_where.py

Lines changed: 7 additions & 0 deletions
@@ -121,6 +121,12 @@ def scalar_condition(input: torch.Tensor):
     scalar_condition,
 )
 
+int32_scalar_cond = Where(
+    1,
+    torch.int32,
+    scalar_condition,
+)
+
 test_modules_common = {
     "two_dim_tensor_cond": lambda: two_dim_tensor_cond,
     "three_dim_tensor_cond": lambda: three_dim_tensor_cond,
@@ -134,6 +140,7 @@ def scalar_condition(input: torch.Tensor):
     **test_modules_common,
     "float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype,
     "float32_tensor_cond_tuple_dtype_bool": lambda: float32_tensor_cond_tuple_dtype_bool,
+    "int32_scalar_cond": lambda: int32_scalar_cond,
 }
 
 test_modules_BI = {
