[mlir][linalg] fix linalg.batch_reduce_matmul auto cast #102585
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: zhicong zhong (zhczhong)

Changes: Fix the auto-cast of `linalg.batch_reduce_matmul` from `cast_to_T(A * cast_to_T(B)) + C` to `cast_to_T(A) * cast_to_T(B) + C`.

Full diff: https://github.com/llvm/llvm-project/pull/102585.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 46b3ec0f60ebfa..249b0f56477cc8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,5 +1,3 @@
-### AUTOGENERATED from core_named_ops.py
-### To regenerate, run: bin/update_core_linalg_named_ops.sh
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: copy
@@ -1908,25 +1906,25 @@ structured_op: !LinalgStructuredOpConfig
scalar_arg: C
- !ScalarExpression
scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
+ kind: binary
+ fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: binary
- fn_name: mul
+ kind: type
+ fn_name: cast_signed
+ type_var: U
operands:
- !ScalarExpression
scalar_arg: A
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
- !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
+ scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
@@ -6509,3 +6507,4 @@ structured_op: !LinalgStructuredOpConfig
scalar_const: '2.3283063999999999E-10 : f64'
- !ScalarExpression
scalar_arg: min
+
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 67bde8f736ef46..afb68b471d347a 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -593,8 +593,7 @@ def batch_reduce_matmul(
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.m, D.n] += TypeFn.cast_signed(
- U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
- )
+ U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
@linalg_structured_op
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 31fac9b4b41659..1e8f1435ca0fa5 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -329,6 +329,33 @@ func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
+// -----
+
+func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: memref<7x9x8xbf16>, %out: memref<8x8xf32>) {
+ linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xbf16>, memref<7x9x8xbf16>)
+ outs(%out: memref<8x8xf32>)
+ return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK: @generalize_batch_reduce_gemm_bf16
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xbf16>, memref<7x9x8xbf16>)
+// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: bf16, %[[BBARG1:.+]]: bf16, %[[BBARG2:.+]]: f32)
+// CHECK: %[[EXTBF16_0:.+]] = arith.extf %[[BBARG0]] : bf16 to f32
+// CHECK: %[[EXTBF16_1:.+]] = arith.extf %[[BBARG1]] : bf16 to f32
+// CHECK: %[[MUL:.+]] = arith.mulf %[[EXTBF16_0]], %[[EXTBF16_1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+
+
// -----
// CHECK-LABEL: generalize_linalg_map
✅ With the latest revision this PR passed the Python code formatter.
Quick question: did you change the Python file and then regenerate the YAML file, or did you change both manually? OpDSL doesn't make it easy to tell the difference, and I made that mistake myself once already. 😅
@zhczhong Good catch. Please also include the related changes from `core_named_ops.py` in this PR.
LGTM
Thanks for the reminder! I changed the Python file and used it to regenerate the YAML file.
Thanks! The change has been included here: llvm-project/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py, lines 587 to 597 in b7f615e.
This looks right to me. Thanks for fixing it!
Fix the auto-cast of `linalg.batch_reduce_matmul` from `cast_to_T(A * cast_to_T(B)) + C` to `cast_to_T(A) * cast_to_T(B) + C`.
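
The practical effect of moving the cast can be sketched in plain NumPy. This is a minimal illustration, not the MLIR lowering itself, and it uses float16 as a stand-in for bf16 since NumPy has no native bf16 type; the point is the same, as multiplying in the narrow type before widening can overflow or lose precision, while casting each operand first computes the product in the wide accumulator type:

```python
import numpy as np

a = np.float16(300.0)
b = np.float16(300.0)

# Old behavior: multiply in the narrow element type, then cast the product.
# 300 * 300 = 90000 exceeds float16's max (~65504), so the product is
# already inf before the widening cast can help.
old_result = np.float32(a * b)

# Fixed behavior: cast each operand to the wider type first, then multiply.
# The product is computed exactly in float32.
new_result = np.float32(a) * np.float32(b)

print(old_result)  # inf
print(new_result)  # 90000.0
```

This mirrors what the corrected YAML/opdsl expression generates: an `arith.extf` on each of `A` and `B` before the `arith.mulf`, as checked in the new `generalize_batch_reduce_gemm_bf16` test above.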