Skip to content

Commit e95fede

Browse files
committed
fix linalg.batch_reduce_matmul auto cast
1 parent d38bae3 commit e95fede

File tree

3 files changed

+41
-16
lines changed

3 files changed

+41
-16
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
### AUTOGENERATED from core_named_ops.py
2-
### To regenerate, run: bin/update_core_linalg_named_ops.sh
31
--- !LinalgOpConfig
42
metadata: !LinalgOpMetadata
53
name: copy
@@ -1908,25 +1906,25 @@ structured_op: !LinalgStructuredOpConfig
19081906
scalar_arg: C
19091907
- !ScalarExpression
19101908
scalar_fn:
1911-
kind: type
1912-
fn_name: cast_signed
1913-
type_var: U
1909+
kind: binary
1910+
fn_name: mul
19141911
operands:
19151912
- !ScalarExpression
19161913
scalar_fn:
1917-
kind: binary
1918-
fn_name: mul
1914+
kind: type
1915+
fn_name: cast_signed
1916+
type_var: U
19191917
operands:
19201918
- !ScalarExpression
19211919
scalar_arg: A
1920+
- !ScalarExpression
1921+
scalar_fn:
1922+
kind: type
1923+
fn_name: cast_signed
1924+
type_var: U
1925+
operands:
19221926
- !ScalarExpression
1923-
scalar_fn:
1924-
kind: type
1925-
fn_name: cast_signed
1926-
type_var: U
1927-
operands:
1928-
- !ScalarExpression
1929-
scalar_arg: B
1927+
scalar_arg: B
19301928
--- !LinalgOpConfig
19311929
metadata: !LinalgOpMetadata
19321930
name: matvec
@@ -6509,3 +6507,4 @@ structured_op: !LinalgStructuredOpConfig
65096507
scalar_const: '2.3283063999999999E-10 : f64'
65106508
- !ScalarExpression
65116509
scalar_arg: min
6510+

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,8 +593,7 @@ def batch_reduce_matmul(
593593
domain(D.b, D.m, D.n, D.k)
594594
implements(ContractionOpInterface)
595595
C[D.m, D.n] += TypeFn.cast_signed(
596-
U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
597-
)
596+
U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
598597

599598

600599
@linalg_structured_op

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,33 @@ func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %
329329
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
330330
// CHECK: linalg.yield %[[ADD]] : f32
331331

332+
// -----
333+
334+
func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: memref<7x9x8xbf16>, %out: memref<8x8xf32>) {
335+
linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xbf16>, memref<7x9x8xbf16>)
336+
outs(%out: memref<8x8xf32>)
337+
return
338+
}
339+
340+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
341+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
342+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
343+
344+
// CHECK: @generalize_batch_reduce_gemm_bf16
345+
346+
// CHECK: linalg.generic
347+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
348+
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
349+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xbf16>, memref<7x9x8xbf16>)
350+
// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
351+
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: bf16, %[[BBARG1:.+]]: bf16, %[[BBARG2:.+]]: f32)
352+
// CHECK: %[[EXTBF16_0:.+]] = arith.extf %[[BBARG0]] : bf16 to f32
353+
// CHECK: %[[EXTBF16_1:.+]] = arith.extf %[[BBARG1]] : bf16 to f32
354+
// CHECK: %[[MUL:.+]] = arith.mulf %[[EXTBF16_0]], %[[EXTBF16_1]] : f32
355+
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
356+
// CHECK: linalg.yield %[[ADD]] : f32
357+
358+
332359
// -----
333360

334361
// CHECK-LABEL: generalize_linalg_map

0 commit comments

Comments
 (0)