[mlir] [linalg] fix side effect of linalg op #114045

Merged: 1 commit merged into llvm:main on Oct 30, 2024

Conversation

cxy-1993
Contributor

Linalg ops need to take into account memory side effects happening inside their region when determining their own side effects.

This patch fixes issue #112881
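
A hedged illustration of the problem (the function and value names here are made up for this sketch; see issue #112881 and the test added below for the real cases): before this change, a structured linalg op whose tensor results are unused was treated as side-effect free, so --canonicalize could erase it even when its region writes to memory.

```mlir
// The linalg.map result is unused, but its body stores into %buf.
// Before the patch, --canonicalize would fold @dropped_store down to a bare
// func.return, silently losing the store. With RecursiveMemoryEffects, the
// write inside the region is visible and the op is preserved.
func.func @dropped_store(%buf: memref<1xf32>, %t: tensor<1xf32>) {
  %c0 = arith.constant 0 : index
  %init = arith.constant dense<0.0> : tensor<1xf32>
  %r = linalg.map ins(%t : tensor<1xf32>) outs(%init : tensor<1xf32>)
       (%in : f32) {
         memref.store %in, %buf[%c0] : memref<1xf32>
         linalg.yield %in : f32
       }
  func.return
}
```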

@llvmbot
Member

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: donald chen (cxy-1993)

Changes

Linalg ops need to take into account memory side effects happening inside their region when determining their own side effects.

This patch fixes issue #112881


Full diff: https://github.com/llvm/llvm-project/pull/114045.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+1)
  • (added) mlir/test/Dialect/Linalg/recursive-effect.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index bfc609bd708164..c2fee8ea55c960 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -30,6 +30,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
        SingleBlockImplicitTerminator<"YieldOp">,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
        DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+       RecursiveMemoryEffects,
        DestinationStyleOpInterface,
        LinalgStructuredInterface,
        ReifyRankedShapedTypeOpInterface], props)> {
diff --git a/mlir/test/Dialect/Linalg/recursive-effect.mlir b/mlir/test/Dialect/Linalg/recursive-effect.mlir
new file mode 100644
index 00000000000000..b5063d48b84cff
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/recursive-effect.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s --canonicalize | FileCheck %s
+
+func.func @map(%arg0: memref<1xf32>, %arg1 : tensor<1xf32>) {
+  %c1 = arith.constant 1 : index
+  %init = arith.constant dense<0.0> : tensor<1xf32>
+  %mapped = linalg.map ins(%arg1:tensor<1xf32>) outs(%init :tensor<1xf32>) 
+            (%in : f32) {
+              memref.store %in, %arg0[%c1] : memref<1xf32>
+              linalg.yield %in : f32
+            }
+  func.return
+}
+
+// CHECK-LABEL: @map
+//       CHECK: linalg.map

Contributor

@dcaballe dcaballe left a comment

LG % one minor comment

Comment on lines 1240 to 1245
func.func @recursive_effect(%arg0: memref<1xf32>, %arg1 : tensor<1xf32>) {
  %c1 = arith.constant 1 : index
  %init = arith.constant dense<0.0> : tensor<1xf32>
  %mapped = linalg.map ins(%arg1:tensor<1xf32>) outs(%init :tensor<1xf32>)
            (%in : f32) {
              memref.store %in, %arg0[%c1] : memref<1xf32>
Contributor

Could we use tensors all the way here?

Contributor Author

We need an operation with memory side effects here; I'm afraid using tensors only cannot achieve this.

Contributor Author

Perhaps there are some usages that I'm unaware of. I'd be happy to hear if you could suggest any.

Member

Doesn't really make a difference, but a `vector.print %in : f32` should also work.

Contributor Author

Good point, done!~
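
For reference, a sketch of what the revised test might look like with the vector.print suggestion applied (this is an assumption about the final form; the committed test may differ in names and CHECK lines):

```mlir
// Hypothetical revision: vector.print acts as the side-effecting op inside
// the region instead of memref.store, so no memref argument is needed.
func.func @recursive_effect(%arg1 : tensor<1xf32>) {
  %init = arith.constant dense<0.0> : tensor<1xf32>
  %mapped = linalg.map ins(%arg1 : tensor<1xf32>) outs(%init : tensor<1xf32>)
            (%in : f32) {
              vector.print %in : f32
              linalg.yield %in : f32
            }
  func.return
}

// CHECK-LABEL: @recursive_effect
//       CHECK: linalg.map
```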

Linalg ops need to take into account memory side effects happening
inside their region when determining their own side effects.

This patch fixes issue llvm#112881
@cxy-1993 cxy-1993 merged commit df0d249 into llvm:main Oct 30, 2024
8 checks passed
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
Linalg ops need to take into account memory side effects happening inside
their region when determining their own side effects.

This patch fixes issue llvm#112881
nhat-nguyen pushed a commit to microsoft/triton-shared that referenced this pull request Feb 14, 2025
…alg.generic` (#227)

As discussed in #216, with the upstream fix
llvm/llvm-project#114045, `linalg` ops now
implement the `RecursiveMemoryEffects` trait, so we can convert
`tts.scatter` to `linalg.generic` with a body that performs a `memref.store`
for each scalar index and value element.

For instance,
`triton_shared/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir`:

```mlir
// RUN: triton-shared-opt --triton-to-unstructured --canonicalize --unstructured-to-memref --canonicalize %s

#map = affine_map<(d0) -> (d0)>
module {
  tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<3> : tensor<4xi32>
    %cst_0 = arith.constant dense<64> : tensor<4xi32>
    %cst_1 = arith.constant dense<4> : tensor<4xi32>
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_2 = arith.constant 9.900000e+01 : f32
    %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
    %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
    %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>)  : i32 {
      %4 = arith.divsi %arg3, %cst : tensor<4xi32>
      %5 = tt.splat %arg2 : i32 -> tensor<4xi32>
      %6 = arith.addi %4, %5 : tensor<4xi32>
      %7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32>
      %cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
      %8 = bufferization.to_tensor %cast restrict : memref<?xf32> to tensor<?xf32>
      %9 = tensor.empty() : tensor<4xf32>
      %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
      ^bb0(%in: i32, %in_4: i1, %out: f32):
        %13 = scf.if %in_4 -> (f32) {
          %14 = arith.index_cast %in : i32 to index
          %extracted = tensor.extract %8[%14] : tensor<?xf32>
          scf.yield %extracted : f32
        } else {
          scf.yield %cst_2 : f32
        }
        linalg.yield %13 : f32
      } -> tensor<4xf32>
      %cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
      // tts.scatter lowers to:
      linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %10, %7 : tensor<4xi32>, tensor<4xf32>, tensor<4xi1>) {
      ^bb0(%in: i32, %in_4: f32, %in_5: i1):
        scf.if %in_5 {
          %13 = arith.index_cast %in : i32 to index
          memref.store %in_4, %cast_3[%13] : memref<?xf32>
        }
        linalg.yield
      }
      %11 = arith.addi %6, %cst_1 : tensor<4xi32>
      %12 = arith.addi %arg4, %cst_1 : tensor<4xi32>
      scf.yield %11, %12 : tensor<4xi32>, tensor<4xi32>
    }
    tt.return
  }
}
```

We can also utilize `linalg-fuse-elementwise-ops` now:

```mlir
// RUN: triton-shared-opt --linalg-fuse-elementwise-ops --canonicalize %s

#map = affine_map<(d0) -> (d0)>
module {
  tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<3> : tensor<4xi32>
    %cst_0 = arith.constant dense<64> : tensor<4xi32>
    %cst_1 = arith.constant dense<4> : tensor<4xi32>
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
    %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
    %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>)  : i32 {
      %4 = arith.divsi %arg3, %cst : tensor<4xi32>
      %5 = tt.splat %arg2 : i32 -> tensor<4xi32>
      %6 = arith.addi %4, %5 : tensor<4xi32>
      %7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32>
      %cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
      %8 = bufferization.to_tensor %cast restrict : memref<?xf32> to tensor<?xf32>
      %cast_2 = memref.cast %0 : memref<*xf32> to memref<?xf32>
      linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) {
      ^bb0(%in: i32, %in_3: i1):
        scf.if %in_3 {
          %11 = arith.index_cast %in : i32 to index
          %extracted = tensor.extract %8[%11] : tensor<?xf32>
          %12 = arith.index_cast %in : i32 to index
          memref.store %extracted, %cast_2[%12] : memref<?xf32>
        }
        linalg.yield
      }
      %9 = arith.addi %6, %cst_1 : tensor<4xi32>
      %10 = arith.addi %arg4, %cst_1 : tensor<4xi32>
      scf.yield %9, %10 : tensor<4xi32>, tensor<4xi32>
    }
    tt.return
  }
}
```

Co-authored-by: Xiaoran Weng <[email protected]>