Skip to content

Commit 558d7ad

Browse files
authored
[mlir][linalg] fix linalg.batch_reduce_matmul auto cast (#102585)
Fix the auto-cast of `linalg.batch_reduce_matmul` from `cast_to_T(A * cast_to_T(B)) + C` to `cast_to_T(A) * cast_to_T(B) + C`
1 parent c6062d3 commit 558d7ad

File tree

3 files changed

+41
-14
lines changed

3 files changed

+41
-14
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1908,25 +1908,25 @@ structured_op: !LinalgStructuredOpConfig
19081908
scalar_arg: C
19091909
- !ScalarExpression
19101910
scalar_fn:
1911-
kind: type
1912-
fn_name: cast_signed
1913-
type_var: U
1911+
kind: binary
1912+
fn_name: mul
19141913
operands:
19151914
- !ScalarExpression
19161915
scalar_fn:
1917-
kind: binary
1918-
fn_name: mul
1916+
kind: type
1917+
fn_name: cast_signed
1918+
type_var: U
19191919
operands:
19201920
- !ScalarExpression
19211921
scalar_arg: A
1922+
- !ScalarExpression
1923+
scalar_fn:
1924+
kind: type
1925+
fn_name: cast_signed
1926+
type_var: U
1927+
operands:
19221928
- !ScalarExpression
1923-
scalar_fn:
1924-
kind: type
1925-
fn_name: cast_signed
1926-
type_var: U
1927-
operands:
1928-
- !ScalarExpression
1929-
scalar_arg: B
1929+
scalar_arg: B
19301930
--- !LinalgOpConfig
19311931
metadata: !LinalgOpMetadata
19321932
name: matvec

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,8 @@ def batch_reduce_matmul(
592592
"""
593593
domain(D.b, D.m, D.n, D.k)
594594
implements(ContractionOpInterface)
595-
C[D.m, D.n] += TypeFn.cast_signed(
596-
U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
595+
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
596+
U, B[D.b, D.k, D.n]
597597
)
598598

599599

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,33 @@ func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %
329329
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
330330
// CHECK: linalg.yield %[[ADD]] : f32
331331

332+
// -----
333+
334+
func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: memref<7x9x8xbf16>, %out: memref<8x8xf32>) {
335+
linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xbf16>, memref<7x9x8xbf16>)
336+
outs(%out: memref<8x8xf32>)
337+
return
338+
}
339+
340+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
341+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
342+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
343+
344+
// CHECK: @generalize_batch_reduce_gemm_bf16
345+
346+
// CHECK: linalg.generic
347+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
348+
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
349+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xbf16>, memref<7x9x8xbf16>)
350+
// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
351+
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: bf16, %[[BBARG1:.+]]: bf16, %[[BBARG2:.+]]: f32)
352+
// CHECK: %[[EXTBF16_0:.+]] = arith.extf %[[BBARG0]] : bf16 to f32
353+
// CHECK: %[[EXTBF16_1:.+]] = arith.extf %[[BBARG1]] : bf16 to f32
354+
// CHECK: %[[MUL:.+]] = arith.mulf %[[EXTBF16_0]], %[[EXTBF16_1]] : f32
355+
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
356+
// CHECK: linalg.yield %[[ADD]] : f32
357+
358+
332359
// -----
333360

334361
// CHECK-LABEL: generalize_linalg_map

0 commit comments

Comments
 (0)