
[mlir][linalg] fix linalg.batch_reduce_matmul auto cast #102585


Merged: 1 commit into llvm:main on Aug 12, 2024

Conversation

zhczhong
Member

@zhczhong zhczhong commented Aug 9, 2024

Fix the auto-cast of linalg.batch_reduce_matmul from cast_to_T(A * cast_to_T(B)) + C to cast_to_T(A) * cast_to_T(B) + C, so that both operands are promoted to the accumulator/output type before the multiplication.
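The cast placement matters for narrow element types such as bf16. The sketch below is a hypothetical pure-Python illustration, not the actual MLIR lowering: `to_bf16` is an assumed helper that models bf16 by truncating an f32 bit pattern. It shows that forming the product at bf16 precision and widening only afterwards loses low-order bits before the accumulate, whereas casting each operand first keeps the product at full accumulator precision.

```python
import struct

def to_bf16(x):
    """Hypothetical helper: model a bf16 value by truncating an f32 bit
    pattern to its top 16 bits (1 sign + 8 exponent + 7 mantissa bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

a = to_bf16(1.5)        # exactly representable in bf16
b = to_bf16(1.0078125)  # 1 + 1/128, also exact in bf16

# Product formed at narrow precision and only then widened:
# the low-order bit of the product is truncated away.
buggy = to_bf16(a * b)

# Operands widened first, product kept at full precision (the fixed order):
fixed = a * b

print(buggy)  # 1.5078125
print(fixed)  # 1.51171875
```

The two results differ in the last retained mantissa bit, which is exactly the kind of drift the operand-level cast avoids.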

@llvmbot
Member

llvmbot commented Aug 9, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: zhicong zhong (zhczhong)

Changes

Fix the auto-cast of linalg.batch_reduce_matmul from cast_to_T(A * cast_to_T(B)) + C to cast_to_T(A) * cast_to_T(B) + C, so that both operands are promoted to the accumulator/output type before the multiplication.


Full diff: https://github.com/llvm/llvm-project/pull/102585.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+13-14)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+1-2)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+27)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 46b3ec0f60ebfa..249b0f56477cc8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,5 +1,3 @@
-### AUTOGENERATED from core_named_ops.py
-### To regenerate, run: bin/update_core_linalg_named_ops.sh
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: copy
@@ -1908,25 +1906,25 @@ structured_op: !LinalgStructuredOpConfig
           scalar_arg: C
         - !ScalarExpression
           scalar_fn:
-            kind: type
-            fn_name: cast_signed
-            type_var: U
+            kind: binary
+            fn_name: mul
             operands:
             - !ScalarExpression
               scalar_fn:
-                kind: binary
-                fn_name: mul
+                kind: type
+                fn_name: cast_signed
+                type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
                 - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: B
+                  scalar_arg: B
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec
@@ -6509,3 +6507,4 @@ structured_op: !LinalgStructuredOpConfig
                           scalar_const: '2.3283063999999999E-10 : f64'
             - !ScalarExpression
               scalar_arg: min
+
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 67bde8f736ef46..afb68b471d347a 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -593,8 +593,7 @@ def batch_reduce_matmul(
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.m, D.n] += TypeFn.cast_signed(
-        U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
-    )
+        U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(U, B[D.b, D.k, D.n]) 
 
 
 @linalg_structured_op
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 31fac9b4b41659..1e8f1435ca0fa5 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -329,6 +329,33 @@ func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %
 // CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
 // CHECK:         linalg.yield %[[ADD]] : f32
 
+// -----
+
+func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: memref<7x9x8xbf16>, %out: memref<8x8xf32>) {
+  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xbf16>, memref<7x9x8xbf16>)
+                             outs(%out: memref<8x8xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK: @generalize_batch_reduce_gemm_bf16
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xbf16>, memref<7x9x8xbf16>)
+// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: bf16, %[[BBARG1:.+]]: bf16, %[[BBARG2:.+]]: f32)
+// CHECK:         %[[EXTBF16_0:.+]] = arith.extf %[[BBARG0]] : bf16 to f32
+// CHECK:         %[[EXTBF16_1:.+]] = arith.extf %[[BBARG1]] : bf16 to f32
+// CHECK:         %[[MUL:.+]] = arith.mulf %[[EXTBF16_0]], %[[EXTBF16_1]] : f32
+// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK:         linalg.yield %[[ADD]] : f32
+
+
 // -----
 
 // CHECK-LABEL: generalize_linalg_map


github-actions bot commented Aug 9, 2024

✅ With the latest revision this PR passed the Python code formatter.

@rengolin
Member

rengolin commented Aug 9, 2024

@shahidact

@rengolin
Member

rengolin commented Aug 9, 2024

Quick question: did you change the Python file and then regenerate the YAML file, or did you change both manually?

OpDSL doesn't make it easy to know the difference and I made that mistake myself already once. 😅

@shahidact
Contributor

@zhczhong Good catch! Please also include the related changes from core_named_ops.py in this PR.

Contributor

@xurui1995 xurui1995 left a comment


LGTM

@zhczhong
Member Author

Quick question: did you change the Python file and then regenerate the YAML file, or did you change both manually?

OpDSL doesn't make it easy to know the difference and I made that mistake myself already once. 😅

Thanks for the reminder! I changed the Python file and used it to regenerate the YAML file.

@zhczhong Good catch! Please also include the related changes from core_named_ops.py in this PR.

Thanks! The change has been included here.

"""Performs a batch-reduce matrix multiplication of two 3D inputs.
The partial multiplication results are reduced into a 2D output.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
    U, B[D.b, D.k, D.n]
)
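Read concretely, the docstring and index expression above amount to the following loop nest. This is a hypothetical pure-Python reference model, not the generated IR; the `cast` parameter stands in for `TypeFn.cast_signed` and defaults to plain promotion via `float`.

```python
def batch_reduce_matmul(A, B, C, cast=float):
    """Accumulate C[m, n] += cast(A[b, m, k]) * cast(B[b, k, n]),
    reducing over both the batch dimension b and the contraction
    dimension k, per the fixed operand-level cast order."""
    batches, M, K = len(A), len(A[0]), len(A[0][0])
    N = len(B[0][0])
    for b in range(batches):
        for m in range(M):
            for n in range(N):
                for k in range(K):
                    # Both operands are cast before the multiply.
                    C[m][n] += cast(A[b][m][k]) * cast(B[b][k][n])
    return C

# Two batches of 2x2 matrices; the second operand is the identity in
# each batch, so the result is the sum of the A slices over the batch.
A = [[[1, 2], [3, 4]], [[1, 1], [1, 1]]]
B = [[[1, 0], [0, 1]], [[1, 0], [0, 1]]]
print(batch_reduce_matmul(A, B, [[0, 0], [0, 0]]))  # [[2.0, 3.0], [4.0, 5.0]]
```

The partial products of every batch fold into the same 2D accumulator, which is what distinguishes this op from linalg.batch_matmul.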

Contributor

@MaheshRavishankar MaheshRavishankar left a comment


This looks right to me. Thanks for fixing it.

@zhczhong zhczhong merged commit 558d7ad into llvm:main Aug 12, 2024
8 checks passed
7 participants