[mlir] [linalg] fix side effect of linalg op #114045
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: donald chen (cxy-1993)

Changes

Linalg ops need to take into account memory side effects happening inside their region when determining their own side effects. This patch fixes issue #112881.

Full diff: https://github.com/llvm/llvm-project/pull/114045.diff

2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index bfc609bd708164..c2fee8ea55c960 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -30,6 +30,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
SingleBlockImplicitTerminator<"YieldOp">,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+ RecursiveMemoryEffects,
DestinationStyleOpInterface,
LinalgStructuredInterface,
ReifyRankedShapedTypeOpInterface], props)> {
diff --git a/mlir/test/Dialect/Linalg/recursive-effect.mlir b/mlir/test/Dialect/Linalg/recursive-effect.mlir
new file mode 100644
index 00000000000000..b5063d48b84cff
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/recursive-effect.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s --canonicalize | FileCheck %s
+
+func.func @map(%arg0: memref<1xf32>, %arg1 : tensor<1xf32>) {
+ %c1 = arith.constant 1 : index
+ %init = arith.constant dense<0.0> : tensor<1xf32>
+ %mapped = linalg.map ins(%arg1:tensor<1xf32>) outs(%init :tensor<1xf32>)
+ (%in : f32) {
+ memref.store %in, %arg0[%c1] : memref<1xf32>
+ linalg.yield %in : f32
+ }
+ func.return
+}
+
+// CHECK-LABEL: @map
+// CHECK: linalg.map
LG % one minor comment
func.func @recursive_effect(%arg0: memref<1xf32>, %arg1 : tensor<1xf32>) {
  %c1 = arith.constant 1 : index
  %init = arith.constant dense<0.0> : tensor<1xf32>
  %mapped = linalg.map ins(%arg1:tensor<1xf32>) outs(%init :tensor<1xf32>)
    (%in : f32) {
      memref.store %in, %arg0[%c1] : memref<1xf32>
Could we use tensors all the way here?
We need an operation with memory side effects here; I'm afraid using tensors only can't achieve this.
Perhaps there are some usages that I'm unaware of. I'd be happy to hear if you could suggest any.
Doesn't really make a difference, but a `vector.print %in : f32` should also work.
Good point, done!~
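For reference, the updated test presumably ends up along these lines (a sketch based on the suggestion above; the committed file may differ in function name, signature, and CHECK lines):

```mlir
// RUN: mlir-opt %s --canonicalize | FileCheck %s

func.func @recursive_effect(%arg1 : tensor<1xf32>) {
  %init = arith.constant dense<0.0> : tensor<1xf32>
  %mapped = linalg.map ins(%arg1 : tensor<1xf32>) outs(%init : tensor<1xf32>)
    (%in : f32) {
      // vector.print is side-effecting, so the enclosing linalg.map must not
      // be erased by canonicalization even though its result is unused.
      vector.print %in : f32
      linalg.yield %in : f32
    }
  func.return
}

// CHECK-LABEL: @recursive_effect
// CHECK: linalg.map
```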
Linalg ops need to take into account memory side effects happening inside the region when determining their own side effects. This patch fixes issue llvm#112881
…alg.generic` (#227)

As discussed in #216, with the upstream fix llvm/llvm-project#114045, `linalg` ops now implement the `RecursiveMemoryEffects` trait, so we can convert `tts.scatter` to `linalg.generic` with a body that performs a `memref.store` for each scalar index and value element. For instance, `triton_shared/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir`:

```mlir
// RUN: triton-shared-opt --triton-to-unstructured --canonicalize --unstructured-to-memref --canonicalize %s
#map = affine_map<(d0) -> (d0)>
module {
  tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<3> : tensor<4xi32>
    %cst_0 = arith.constant dense<64> : tensor<4xi32>
    %cst_1 = arith.constant dense<4> : tensor<4xi32>
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_2 = arith.constant 9.900000e+01 : f32
    %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
    %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
    %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
      %4 = arith.divsi %arg3, %cst : tensor<4xi32>
      %5 = tt.splat %arg2 : i32 -> tensor<4xi32>
      %6 = arith.addi %4, %5 : tensor<4xi32>
      %7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32>
      %cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
      %8 = bufferization.to_tensor %cast restrict : memref<?xf32> to tensor<?xf32>
      %9 = tensor.empty() : tensor<4xf32>
      %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
      ^bb0(%in: i32, %in_4: i1, %out: f32):
        %13 = scf.if %in_4 -> (f32) {
          %14 = arith.index_cast %in : i32 to index
          %extracted = tensor.extract %8[%14] : tensor<?xf32>
          scf.yield %extracted : f32
        } else {
          scf.yield %cst_2 : f32
        }
        linalg.yield %13 : f32
      } -> tensor<4xf32>
      %cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
      // tts.scatter lowers to:
      linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %10, %7 : tensor<4xi32>, tensor<4xf32>, tensor<4xi1>) {
      ^bb0(%in: i32, %in_4: f32, %in_5: i1):
        scf.if %in_5 {
          %13 = arith.index_cast %in : i32 to index
          memref.store %in_4, %cast_3[%13] : memref<?xf32>
        }
        linalg.yield
      }
      %11 = arith.addi %6, %cst_1 : tensor<4xi32>
      %12 = arith.addi %arg4, %cst_1 : tensor<4xi32>
      scf.yield %11, %12 : tensor<4xi32>, tensor<4xi32>
    }
    tt.return
  }
}
```

We can also utilize `linalg-fuse-elementwise-ops` now:

```mlir
// RUN: triton-shared-opt --linalg-fuse-elementwise-ops --canonicalize %s
#map = affine_map<(d0) -> (d0)>
module {
  tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<3> : tensor<4xi32>
    %cst_0 = arith.constant dense<64> : tensor<4xi32>
    %cst_1 = arith.constant dense<4> : tensor<4xi32>
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
    %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
    %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
      %4 = arith.divsi %arg3, %cst : tensor<4xi32>
      %5 = tt.splat %arg2 : i32 -> tensor<4xi32>
      %6 = arith.addi %4, %5 : tensor<4xi32>
      %7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32>
      %cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
      %8 = bufferization.to_tensor %cast restrict : memref<?xf32> to tensor<?xf32>
      %cast_2 = memref.cast %0 : memref<*xf32> to memref<?xf32>
      linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) {
      ^bb0(%in: i32, %in_3: i1):
        scf.if %in_3 {
          %11 = arith.index_cast %in : i32 to index
          %extracted = tensor.extract %8[%11] : tensor<?xf32>
          %12 = arith.index_cast %in : i32 to index
          memref.store %extracted, %cast_2[%12] : memref<?xf32>
        }
        linalg.yield
      }
      %9 = arith.addi %6, %cst_1 : tensor<4xi32>
      %10 = arith.addi %arg4, %cst_1 : tensor<4xi32>
      scf.yield %9, %10 : tensor<4xi32>, tensor<4xi32>
    }
    tt.return
  }
}
```

Co-authored-by: Xiaoran Weng <[email protected]>
Linalg ops need to take into account memory side effects happening inside their region when determining their own side effects. This patch fixes issue #112881.
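As a closing illustration of what `RecursiveMemoryEffects` buys us (a hedged sketch, not part of the patch; the function name and exact fold behavior below are assumptions): a `linalg.map` whose body is pure and whose result is unused should still be removable by `--canonicalize`, while one whose body writes memory, as in the test above, must be kept.

```mlir
// Assumed behavior after this patch: the body below contains only pure ops,
// so the op derives no memory effects from its region, and dead-code
// elimination in --canonicalize may still erase it because its tensor
// result is unused.
func.func @pure_body_still_folds(%arg0: tensor<1xf32>) {
  %init = arith.constant dense<0.0> : tensor<1xf32>
  %mapped = linalg.map ins(%arg0 : tensor<1xf32>) outs(%init : tensor<1xf32>)
    (%in : f32) {
      %sum = arith.addf %in, %in : f32
      linalg.yield %sum : f32
    }
  func.return
}
```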