
[mlir][affine] Introduce a new symbol rule: the result of a Pure operation whose operands are valid symbolic identifiers #118478


Merged

Conversation

linuxlonelyeagle
Member

As title.
Consider the following code:

#map = affine_map<()[s0] -> (s0 mod 32)>

module {
  gpu.module @gpu {
    gpu.func @gemm(%arg0: memref<?x?xf32>) kernel {
      %c3 = arith.constant 3 : index
      %dim = memref.dim %arg0, %c3 : memref<?x?xf32>
      %c0 = arith.constant 0 : index
      %thread_id_x = gpu.thread_id  x
      %0 = affine.apply #map()[%thread_id_x]
      affine.for %arg3 = %c0 to %dim step 32 {
        %c128 = arith.constant 128 : index
        affine.for %arg4 = %0 to %c128 step 8 {
          %c32 = arith.constant 32 : index
        }
      }
      gpu.return
    }
  }
}

The code above is fine: `%thread_id_x` is defined at the top level of `gpu.func`, which has the `AffineScope` trait, so it is a valid symbol there. The following code causes problems: `affine.for` does not have `AffineScope` (and should not introduce one), so `%thread_id_x` defined inside the loop body is rejected as a symbol.

#map = affine_map<()[s0] -> (s0 mod 32)>

module {
  gpu.module @gpu {
    gpu.func @gemm(%arg0: memref<?x?xf32>) kernel {
      %c3 = arith.constant 3 : index
      %dim = memref.dim %arg0, %c3 : memref<?x?xf32>
      %c0 = arith.constant 0 : index
      affine.for %arg3 = %c0 to %dim step 32 {
        %thread_id_x = gpu.thread_id  x
        %0 = affine.apply #map()[%thread_id_x]
        %c128 = arith.constant 128 : index
        affine.for %arg4 = %0 to %c128 step 8 {
          %c32 = arith.constant 32 : index
        }
      }
      gpu.return
    }
  }
}

Why do we need to do this?

module {
  gpu.module @gpu {
    gpu.func @gemm(%arg0: memref<?x?xf32>) kernel {
      %c3 = arith.constant 3 : index
      %dim = memref.dim %arg0, %c3 : memref<?x?xf32>
      %c0 = arith.constant 0 : index
      affine.for %arg3 = %c0 to %dim step 32 {
         //thread load op.
      }
      gpu.return
    }
  }
}

Here is the result of a separate thread loadOp (after lowering; there is no memref.load in it yet). This demonstrates why the thread id needs to be a legal symbol: I could have hoisted these ops into the region of the funcOp instead, but that would have added an unreasonable amount of complexity.

%thread_id_x = gpu.thread_id  x
%0 = affine.apply #map()[%thread_id_x]
%c128 = arith.constant 128 : index
affine.for %arg4 = %0 to %c128 step 8 {
   %c32 = arith.constant 32 : index
 }
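For intuition, the decision procedure that `isValidSymbol` implements, and that this patch extends, can be modeled with a small Python sketch. The `Value` class and its fields are illustrative assumptions, not MLIR's API:

```python
# Hypothetical model of the affine symbol-validity rule (not the MLIR API).
# A value is a valid symbol if it is defined at the top level of a region
# carrying the AffineScope trait, is a constant, is (after this change) the
# result of gpu.thread_id, or is an affine.apply whose operands are all
# valid symbols.
from dataclasses import dataclass, field

@dataclass
class Value:
    op: str                                  # name of the defining operation
    operands: list = field(default_factory=list)
    top_level_of_affine_scope: bool = False  # e.g. defined directly in gpu.func

def is_valid_symbol(v: Value) -> bool:
    if v.top_level_of_affine_scope:
        return True
    if v.op == "arith.constant":
        return True
    if v.op == "gpu.thread_id":              # the rule this patch adds
        return True
    if v.op == "affine.apply":
        return all(is_valid_symbol(o) for o in v.operands)
    return False

# thread_id defined *inside* an affine.for body: previously invalid as a
# symbol there, valid after this change.
tid = Value("gpu.thread_id")
assert is_valid_symbol(Value("affine.apply", [tid]))
```

Under the old rules, the `affine.apply` above would only type-check if `%thread_id_x` sat at the top level of an `AffineScope` region; the new rule accepts it anywhere.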

@llvmbot
Member

llvmbot commented Dec 3, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-affine

Author: lonely eagle (linuxlonelyeagle)


Full diff: https://github.com/llvm/llvm-project/pull/118478.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+6)
  • (modified) mlir/lib/Dialect/Affine/IR/CMakeLists.txt (+1)
  • (modified) mlir/test/Dialect/Affine/ops.mlir (+36)
  • (modified) mlir/test/Dialect/GPU/transform-gpu.mlir (+28-28)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index dceebbfec586c8..cf355515deb63d 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -410,6 +411,7 @@ bool mlir::affine::isValidSymbol(Value value) {
 /// A value can be used as a symbol for `region` iff it meets one of the
 /// following conditions:
 /// *) It is a constant.
+/// *) It is a threadId Op.
 /// *) It is the result of an affine apply operation with symbol arguments.
 /// *) It is a result of the dim op on a memref whose corresponding size is
 ///    a valid symbol.
@@ -443,6 +445,10 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
   if (matchPattern(defOp, m_Constant(&operandCst)))
     return true;
 
+  // ThreadId operation is ok.
+  if (isa<gpu::ThreadIdOp>(defOp))
+    return true;
+
   // Affine apply operation is ok if all of its operands are ok.
   if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
     return applyOp.isValidSymbol(region);
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 7f7a01be891e05..9dad5cdb28cbc4 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -22,4 +22,5 @@ add_mlir_dialect_library(MLIRAffineDialect
   MLIRSideEffectInterfaces
   MLIRUBDialect
   MLIRValueBoundsOpInterface
+  MLIRGPUDialect
   )
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index c6bfb688db1c1d..5bd556619f3d5a 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -324,3 +324,39 @@ module attributes {gpu.container_module} {
 // CHECK:             affine.for %[[VAL_4:.*]] = %[[VAL_3]] to %[[VAL_2]] step 32 {
 // CHECK:             }
 // CHECK:             gpu.return
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 mod 32)>
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 mod 32)>
+
+module {
+  gpu.module @gpu {
+    gpu.func @affine_thread_id(%arg0: memref<?x?xf32>) kernel {
+      %c3 = arith.constant 3 : index
+      %dim = memref.dim %arg0, %c3 : memref<?x?xf32>
+      %c0 = arith.constant 0 : index
+      affine.for %arg3 = %c0 to %dim step 32 {
+        %thread_id_x = gpu.thread_id  x
+        %0 = affine.apply #map()[%thread_id_x]
+        %c128 = arith.constant 128 : index
+        affine.for %arg4 = %0 to %c128 step 8 {
+          %c32 = arith.constant 32 : index
+        }
+      }
+      gpu.return
+    }
+  }
+}
+
+// CHECK-LABEL:     @affine_thread_id
+// CHECK-SAME:        (%[[VAL_0:.*]]: memref<?x?xf32>) kernel {
+// CHECK:             %[[VAL_1:.*]] = arith.constant 3 : index
+// CHECK:             %[[VAL_2:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : memref<?x?xf32>
+// CHECK:             %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:             affine.for %[[VAL_4:.*]] = %[[VAL_3]] to %[[VAL_2]] step 32 {
+// CHECK:               %[[VAL_5:.*]] = gpu.thread_id  x
+// CHECK:               %[[VAL_6:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[VAL_5]]]
+// CHECK:               %[[VAL_7:.*]] = arith.constant 128 : index
+// CHECK:               affine.for %[[VAL_8:.*]] = %[[VAL_6]] to %[[VAL_7]] step 8 {
diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
index 72572c6a38de12..6018eb40bac2a8 100644
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -43,7 +43,7 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 floordiv 128)>
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 128)> 
 
 // CHECK-LABEL: func.func @warpgroup_3d(
 // CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -61,7 +61,7 @@ func.func @warpgroup_3d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
 //      CHECK:   gpu.launch
 //      CHECK:   %[[TIDX:.*]] = gpu.thread_id  x
 //      CHECK:   %[[TIDY:.*]] = gpu.thread_id  y
-//  CHECK-DAG:   %[[WG:.*]] = affine.apply #[[$MAP]](%[[TIDX]])
+//  CHECK-DAG:   %[[WG:.*]] = affine.apply #[[$MAP]]()[%[[TIDX]]]
 //  CHECK-DAG:   %[[CMPX:.*]] = arith.cmpi ult, %[[TIDX]], %[[C384]] : index
 //  CHECK-DAG:   %[[CMPY:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index
 //      CHECK:   %[[COND:.*]] = arith.andi %[[CMPX]], %[[CMPY]] : i1
@@ -95,7 +95,7 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 16)>
 
 // CHECK-LABEL: func.func @warp_3d(
 // CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -114,7 +114,7 @@ func.func @warp_3d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
 //      CHECK:   gpu.launch
 //      CHECK:   %[[TIDX:.*]] = gpu.thread_id  x
 //      CHECK:   %[[TIDY:.*]] = gpu.thread_id  y
-//  CHECK-DAG:   %[[W:.*]] = affine.apply #[[$MAP]](%[[TIDX]])
+//  CHECK-DAG:   %[[W:.*]] = affine.apply #[[$MAP]]()[%[[TIDX]]]
 //  CHECK-DAG:   %[[CMPX:.*]] = arith.cmpi ult, %[[TIDX]], %[[C32]] : index
 //  CHECK-DAG:   %[[CMPY:.*]] = arith.cmpi ult, %[[TIDY]], %[[C3]] : index
 //      CHECK:   %[[COND:.*]] = arith.andi %[[CMPX]], %[[CMPY]] : i1
@@ -354,9 +354,9 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAPWGLIN:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 32 + d2 * 256)>
-// CHECK-DAG: #[[$MAPWGX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 32) floordiv 128) mod 2)>
-// CHECK-DAG: #[[$MAPWGY:.*]] = affine_map<(d0, d1, d2) -> (d2 + ((d0 + d1 * 32) floordiv 128) floordiv 2)>
+// CHECK-DAG: #[[$MAPWGLIN:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 256)>
+// CHECK-DAG: #[[$MAPWGX:.*]] = affine_map<()[s0, s1] -> (((s0 + s1 * 32) floordiv 128) mod 2)>
+// CHECK-DAG: #[[$MAPWGY:.*]] = affine_map<()[s0, s1, s2] -> (s2 + ((s0 + s1 * 32) floordiv 128) floordiv 2)>
 
 // CHECK-LABEL: func.func @warpgroup_linear(
 // CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -376,9 +376,9 @@ func.func @warpgroup_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %st
 // CHECK-DAG: %[[TIDX:.*]] = gpu.thread_id  x
 // CHECK-DAG: %[[TIDY:.*]] = gpu.thread_id  y
 // CHECK-DAG: %[[TIDZ:.*]] = gpu.thread_id  z
-// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWGLIN]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
-// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWGX]](%[[TIDX]], %[[TIDY]])
-// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWGY]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWGLIN]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
+// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWGX]]()[%[[TIDX]], %[[TIDY]]]
+// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWGY]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
 // CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[WIDLIN]], %[[C768]] : index
 //     CHECK: scf.if %[[CMPLIN]]
 //      CHECK:   memref.load %[[ARGX]][%[[WIDX]], %[[WIDY]]]
@@ -410,9 +410,9 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAPWLIN:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 32 + d2 * 256)>
-// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1, d2) -> ((d1 + d2 * 8 + d0 floordiv 32) mod 2)>
-// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1, d2) -> ((d1 + d2 * 8 + d0 floordiv 32) floordiv 2)>
+// CHECK-DAG: #[[$MAPWLIN:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 256)>
+// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<()[s0, s1, s2] -> ((s1 + s2 * 8 + s0 floordiv 32) mod 2)>
+// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<()[s0, s1, s2] -> ((s1 + s2 * 8 + s0 floordiv 32) floordiv 2)>
 
 // CHECK-LABEL: func.func @warp_linear(
 // CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -432,9 +432,9 @@ func.func @warp_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
 // CHECK-DAG: %[[TIDX:.*]] = gpu.thread_id  x
 // CHECK-DAG: %[[TIDY:.*]] = gpu.thread_id  y
 // CHECK-DAG: %[[TIDZ:.*]] = gpu.thread_id  z
-// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWLIN]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
-// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
-// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWLIN]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
+// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
+// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
 // CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[WIDLIN]], %[[C192]] : index
 //     CHECK: scf.if %[[CMPLIN]]
 //      CHECK:   memref.load %[[ARGX]][%[[WIDX]], %[[WIDY]]]
@@ -466,12 +466,12 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 18) floordiv 32) mod 3)>
-// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1) -> ((((d0 + d1 * 18) floordiv 32) mod 6) floordiv 3)>
+// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<()[s0, s1] -> (((s0 + s1 * 18) floordiv 32) mod 3)>
+// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<()[s0, s1] -> ((((s0 + s1 * 18) floordiv 32) mod 6) floordiv 3)>
 
-// CHECK-DAG: #[[$MAPLIN:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 18)>
-// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 18) mod 10)>
-// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 18) floordiv 10)>
+// CHECK-DAG: #[[$MAPLIN:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 18)>
+// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 18) mod 10)>
+// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 18) floordiv 10)>
 
 // CHECK-LABEL: func.func @map_multi_level_linear(
 func.func @map_multi_level_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
@@ -504,9 +504,9 @@ func.func @map_multi_level_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f3
       memref.store %6, %y[%i, %j] : !type
     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
 
-    // CHECK-DAG: %[[LIN:.*]] = affine.apply #[[$MAPLIN]](%[[TIDX]], %[[TIDY]])
-    // CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]](%[[TIDX]], %[[TIDY]])
-    // CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]])
+    // CHECK-DAG: %[[LIN:.*]] = affine.apply #[[$MAPLIN]]()[%[[TIDX]], %[[TIDY]]]
+    // CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]]()[%[[TIDX]], %[[TIDY]]]
+    // CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]]()[%[[TIDX]], %[[TIDY]]]
     // CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[LIN]], %[[C192]] : index
     //     CHECK: scf.if %[[CMPLIN]]
     scf.forall (%i, %j, %k) in (%c3, %c2, %c1) {
@@ -515,8 +515,8 @@ func.func @map_multi_level_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f3
         memref.store %8, %y[%i, %j] : !type
      }  {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_2>] }
 
-    // CHECK-DAG: %[[LIDX:.*]] = affine.apply #[[$MAPLX]](%[[TIDX]], %[[TIDY]])
-    // CHECK-DAG: %[[LIDY:.*]] = affine.apply #[[$MAPLY]](%[[TIDX]], %[[TIDY]])
+    // CHECK-DAG: %[[LIDX:.*]] = affine.apply #[[$MAPLX]]()[%[[TIDX]], %[[TIDY]]]
+    // CHECK-DAG: %[[LIDY:.*]] = affine.apply #[[$MAPLY]]()[%[[TIDX]], %[[TIDY]]]
     // CHECK-DAG: %[[COND:.*]] = arith.cmpi ult, %[[LIN]], %[[C20]] : index
     //     CHECK: scf.if %[[COND]]
     //     CHECK:   memref.load %{{.*}}[%[[LIDX]]] : memref<32xf32>
@@ -648,7 +648,7 @@ module attributes {transform.with_named_sequence} {
 #map1 = affine_map<(d0) -> (d0 * 32)>
 
 // CHECK-DAG: #[[$MAPB:.*]] = affine_map<(d0) -> (d0 * 128)>
-// CHECK-DAG: #[[$MAPW:.*]] = affine_map<(d0, d1, d2) -> (d2 * 32 + ((d0 + d1 * 4) floordiv 32) * 32)>
+// CHECK-DAG: #[[$MAPW:.*]] = affine_map<()[s0, s1, s2] -> (s2 * 32 + ((s0 + s1 * 4) floordiv 32) * 32)>
 
 // CHECK-LABEL: func.func @simple_fill(
 func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
@@ -667,7 +667,7 @@ func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
 //       CHECK:     %[[TIDX:.*]] = gpu.thread_id  x
 //       CHECK:     %[[TIDY:.*]] = gpu.thread_id  y
 //       CHECK:     %[[TIDZ:.*]] = gpu.thread_id  z
-//       CHECK:     %[[THX:.*]] = affine.apply #[[$MAPW]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+//       CHECK:     %[[THX:.*]] = affine.apply #[[$MAPW]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
 //   CHECK-NOT:     scf.if
 //       CHECK:       memref.subview %{{.*}}[%[[THX]]]
       %1 = affine.apply #map1(%arg2)

@llvmbot
Member

llvmbot commented Dec 3, 2024

@llvm/pr-subscribers-mlir

@@ -667,7 +667,7 @@ func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
 //       CHECK:     %[[TIDX:.*]] = gpu.thread_id  x
 //       CHECK:     %[[TIDY:.*]] = gpu.thread_id  y
 //       CHECK:     %[[TIDZ:.*]] = gpu.thread_id  z
-//       CHECK:     %[[THX:.*]] = affine.apply #[[$MAPW]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+//       CHECK:     %[[THX:.*]] = affine.apply #[[$MAPW]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
 //   CHECK-NOT:     scf.if
 //       CHECK:       memref.subview %{{.*}}[%[[THX]]]
       %1 = affine.apply #map1(%arg2)

@llvmbot
Member

llvmbot commented Dec 3, 2024

@llvm/pr-subscribers-mlir-gpu

Author: lonely eagle (linuxlonelyeagle)

Changes

As title.
Consider the following code:

#map = affine_map<()[s0] -> (s0 mod 32)>

module {
  gpu.module @gpu {
    gpu.func @gemm(%arg0: memref<?x?xf32>) kernel {
      %c3 = arith.constant 3 : index
      %dim = memref.dim %arg0, %c3 : memref<?x?xf32>
      %c0 = arith.constant 0 : index
      %thread_id_x = gpu.thread_id x
      %0 = affine.apply #map()[%thread_id_x]
      affine.for %arg3 = %c0 to %dim step 32 {
        %c128 = arith.constant 128 : index
        affine.for %arg4 = %0 to %c128 step 8 {
          %c32 = arith.constant 32 : index
        }
      }
      gpu.return
    }
  }
}

The code above is fine. The following code causes problems: the thread ID is now used as a symbol inside the affine.for, which does not have the AffineScope trait; and affine.for should not be able to introduce AffineScope.

#map = affine_map<()[s0] -> (s0 mod 32)>

module {
  gpu.module @gpu {
    gpu.func @gemm(%arg0: memref<?x?xf32>) kernel {
      %c3 = arith.constant 3 : index
      %dim = memref.dim %arg0, %c3 : memref<?x?xf32>
      %c0 = arith.constant 0 : index
      affine.for %arg3 = %c0 to %dim step 32 {
        %thread_id_x = gpu.thread_id  x
        %0 = affine.apply #map()[%thread_id_x]
        %c128 = arith.constant 128 : index
        affine.for %arg4 = %0 to %c128 step 8 {
          %c32 = arith.constant 32 : index
        }
      }
      gpu.return
    }
  }
}

Why do we need to do this?

module {
  gpu.module @gpu {
    gpu.func @gemm(%arg0: memref<?x?xf32>) kernel {
      %c3 = arith.constant 3 : index
      %dim = memref.dim %arg0, %c3 : memref<?x?xf32>
      %c0 = arith.constant 0 : index
      affine.for %arg3 = %c0 to %dim step 32 {
        // thread load op.
      }
      gpu.return
    }
  }
}

Below is the per-thread load loop after lowering (it does not yet contain a memref.load). It demonstrates why the thread ID needs to be a legal symbol: I could have hoisted these ops into the region of the funcOp instead, but that would have added an unreasonable amount of complexity.

%thread_id_x = gpu.thread_id  x
%0 = affine.apply #map()[%thread_id_x]
%c128 = arith.constant 128 : index
affine.for %arg4 = %0 to %c128 step 8 {
   %c32 = arith.constant 32 : index
 }

Full diff: https://github.com/llvm/llvm-project/pull/118478.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+6)
  • (modified) mlir/lib/Dialect/Affine/IR/CMakeLists.txt (+1)
  • (modified) mlir/test/Dialect/Affine/ops.mlir (+36)
  • (modified) mlir/test/Dialect/GPU/transform-gpu.mlir (+28-28)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index dceebbfec586c8..cf355515deb63d 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -410,6 +411,7 @@ bool mlir::affine::isValidSymbol(Value value) {
 /// A value can be used as a symbol for `region` iff it meets one of the
 /// following conditions:
 /// *) It is a constant.
+/// *) It is a threadId Op.
 /// *) It is the result of an affine apply operation with symbol arguments.
 /// *) It is a result of the dim op on a memref whose corresponding size is
 ///    a valid symbol.
@@ -443,6 +445,10 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
   if (matchPattern(defOp, m_Constant(&operandCst)))
     return true;
 
+  // ThreadId operation is ok.
+  if (isa<gpu::ThreadIdOp>(defOp))
+    return true;
+
   // Affine apply operation is ok if all of its operands are ok.
   if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
     return applyOp.isValidSymbol(region);
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 7f7a01be891e05..9dad5cdb28cbc4 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -22,4 +22,5 @@ add_mlir_dialect_library(MLIRAffineDialect
   MLIRSideEffectInterfaces
   MLIRUBDialect
   MLIRValueBoundsOpInterface
+  MLIRGPUDialect
   )
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index c6bfb688db1c1d..5bd556619f3d5a 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -324,3 +324,39 @@ module attributes {gpu.container_module} {
 // CHECK:             affine.for %[[VAL_4:.*]] = %[[VAL_3]] to %[[VAL_2]] step 32 {
 // CHECK:             }
 // CHECK:             gpu.return
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 mod 32)>
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 mod 32)>
+
+module {
+  gpu.module @gpu {
+    gpu.func @affine_thread_id(%arg0: memref<?x?xf32>) kernel {
+      %c3 = arith.constant 3 : index
+      %dim = memref.dim %arg0, %c3 : memref<?x?xf32>
+      %c0 = arith.constant 0 : index
+      affine.for %arg3 = %c0 to %dim step 32 {
+        %thread_id_x = gpu.thread_id  x
+        %0 = affine.apply #map()[%thread_id_x]
+        %c128 = arith.constant 128 : index
+        affine.for %arg4 = %0 to %c128 step 8 {
+          %c32 = arith.constant 32 : index
+        }
+      }
+      gpu.return
+    }
+  }
+}
+
+// CHECK-LABEL:     @affine_thread_id
+// CHECK-SAME:        (%[[VAL_0:.*]]: memref<?x?xf32>) kernel {
+// CHECK:             %[[VAL_1:.*]] = arith.constant 3 : index
+// CHECK:             %[[VAL_2:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : memref<?x?xf32>
+// CHECK:             %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:             affine.for %[[VAL_4:.*]] = %[[VAL_3]] to %[[VAL_2]] step 32 {
+// CHECK:               %[[VAL_5:.*]] = gpu.thread_id  x
+// CHECK:               %[[VAL_6:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[VAL_5]]]
+// CHECK:               %[[VAL_7:.*]] = arith.constant 128 : index
+// CHECK:               affine.for %[[VAL_8:.*]] = %[[VAL_6]] to %[[VAL_7]] step 8 {
diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
index 72572c6a38de12..6018eb40bac2a8 100644
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -43,7 +43,7 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 floordiv 128)>
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 128)> 
 
 // CHECK-LABEL: func.func @warpgroup_3d(
 // CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -61,7 +61,7 @@ func.func @warpgroup_3d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
 //      CHECK:   gpu.launch
 //      CHECK:   %[[TIDX:.*]] = gpu.thread_id  x
 //      CHECK:   %[[TIDY:.*]] = gpu.thread_id  y
-//  CHECK-DAG:   %[[WG:.*]] = affine.apply #[[$MAP]](%[[TIDX]])
+//  CHECK-DAG:   %[[WG:.*]] = affine.apply #[[$MAP]]()[%[[TIDX]]]
 //  CHECK-DAG:   %[[CMPX:.*]] = arith.cmpi ult, %[[TIDX]], %[[C384]] : index
 //  CHECK-DAG:   %[[CMPY:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index
 //      CHECK:   %[[COND:.*]] = arith.andi %[[CMPX]], %[[CMPY]] : i1
@@ -95,7 +95,7 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 16)>
 
 // CHECK-LABEL: func.func @warp_3d(
 // CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -114,7 +114,7 @@ func.func @warp_3d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
 //      CHECK:   gpu.launch
 //      CHECK:   %[[TIDX:.*]] = gpu.thread_id  x
 //      CHECK:   %[[TIDY:.*]] = gpu.thread_id  y
-//  CHECK-DAG:   %[[W:.*]] = affine.apply #[[$MAP]](%[[TIDX]])
+//  CHECK-DAG:   %[[W:.*]] = affine.apply #[[$MAP]]()[%[[TIDX]]]
 //  CHECK-DAG:   %[[CMPX:.*]] = arith.cmpi ult, %[[TIDX]], %[[C32]] : index
 //  CHECK-DAG:   %[[CMPY:.*]] = arith.cmpi ult, %[[TIDY]], %[[C3]] : index
 //      CHECK:   %[[COND:.*]] = arith.andi %[[CMPX]], %[[CMPY]] : i1
@@ -354,9 +354,9 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAPWGLIN:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 32 + d2 * 256)>
-// CHECK-DAG: #[[$MAPWGX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 32) floordiv 128) mod 2)>
-// CHECK-DAG: #[[$MAPWGY:.*]] = affine_map<(d0, d1, d2) -> (d2 + ((d0 + d1 * 32) floordiv 128) floordiv 2)>
+// CHECK-DAG: #[[$MAPWGLIN:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 256)>
+// CHECK-DAG: #[[$MAPWGX:.*]] = affine_map<()[s0, s1] -> (((s0 + s1 * 32) floordiv 128) mod 2)>
+// CHECK-DAG: #[[$MAPWGY:.*]] = affine_map<()[s0, s1, s2] -> (s2 + ((s0 + s1 * 32) floordiv 128) floordiv 2)>
 
 // CHECK-LABEL: func.func @warpgroup_linear(
 // CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -376,9 +376,9 @@ func.func @warpgroup_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %st
 // CHECK-DAG: %[[TIDX:.*]] = gpu.thread_id  x
 // CHECK-DAG: %[[TIDY:.*]] = gpu.thread_id  y
 // CHECK-DAG: %[[TIDZ:.*]] = gpu.thread_id  z
-// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWGLIN]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
-// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWGX]](%[[TIDX]], %[[TIDY]])
-// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWGY]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWGLIN]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
+// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWGX]]()[%[[TIDX]], %[[TIDY]]]
+// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWGY]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
 // CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[WIDLIN]], %[[C768]] : index
 //     CHECK: scf.if %[[CMPLIN]]
 //      CHECK:   memref.load %[[ARGX]][%[[WIDX]], %[[WIDY]]]
@@ -410,9 +410,9 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAPWLIN:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 32 + d2 * 256)>
-// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1, d2) -> ((d1 + d2 * 8 + d0 floordiv 32) mod 2)>
-// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1, d2) -> ((d1 + d2 * 8 + d0 floordiv 32) floordiv 2)>
+// CHECK-DAG: #[[$MAPWLIN:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 256)>
+// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<()[s0, s1, s2] -> ((s1 + s2 * 8 + s0 floordiv 32) mod 2)>
+// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<()[s0, s1, s2] -> ((s1 + s2 * 8 + s0 floordiv 32) floordiv 2)>
 
 // CHECK-LABEL: func.func @warp_linear(
 // CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -432,9 +432,9 @@ func.func @warp_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
 // CHECK-DAG: %[[TIDX:.*]] = gpu.thread_id  x
 // CHECK-DAG: %[[TIDY:.*]] = gpu.thread_id  y
 // CHECK-DAG: %[[TIDZ:.*]] = gpu.thread_id  z
-// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWLIN]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
-// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
-// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWLIN]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
+// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
+// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
 // CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[WIDLIN]], %[[C192]] : index
 //     CHECK: scf.if %[[CMPLIN]]
 //      CHECK:   memref.load %[[ARGX]][%[[WIDX]], %[[WIDY]]]
@@ -466,12 +466,12 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 18) floordiv 32) mod 3)>
-// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1) -> ((((d0 + d1 * 18) floordiv 32) mod 6) floordiv 3)>
+// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<()[s0, s1] -> (((s0 + s1 * 18) floordiv 32) mod 3)>
+// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<()[s0, s1] -> ((((s0 + s1 * 18) floordiv 32) mod 6) floordiv 3)>
 
-// CHECK-DAG: #[[$MAPLIN:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 18)>
-// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 18) mod 10)>
-// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 18) floordiv 10)>
+// CHECK-DAG: #[[$MAPLIN:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 18)>
+// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 18) mod 10)>
+// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 18) floordiv 10)>
 
 // CHECK-LABEL: func.func @map_multi_level_linear(
 func.func @map_multi_level_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
@@ -504,9 +504,9 @@ func.func @map_multi_level_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f3
       memref.store %6, %y[%i, %j] : !type
     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
 
-    // CHECK-DAG: %[[LIN:.*]] = affine.apply #[[$MAPLIN]](%[[TIDX]], %[[TIDY]])
-    // CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]](%[[TIDX]], %[[TIDY]])
-    // CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]])
+    // CHECK-DAG: %[[LIN:.*]] = affine.apply #[[$MAPLIN]]()[%[[TIDX]], %[[TIDY]]]
+    // CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]]()[%[[TIDX]], %[[TIDY]]]
+    // CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]]()[%[[TIDX]], %[[TIDY]]]
     // CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[LIN]], %[[C192]] : index
     //     CHECK: scf.if %[[CMPLIN]]
     scf.forall (%i, %j, %k) in (%c3, %c2, %c1) {
@@ -515,8 +515,8 @@ func.func @map_multi_level_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f3
         memref.store %8, %y[%i, %j] : !type
      }  {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_2>] }
 
-    // CHECK-DAG: %[[LIDX:.*]] = affine.apply #[[$MAPLX]](%[[TIDX]], %[[TIDY]])
-    // CHECK-DAG: %[[LIDY:.*]] = affine.apply #[[$MAPLY]](%[[TIDX]], %[[TIDY]])
+    // CHECK-DAG: %[[LIDX:.*]] = affine.apply #[[$MAPLX]]()[%[[TIDX]], %[[TIDY]]]
+    // CHECK-DAG: %[[LIDY:.*]] = affine.apply #[[$MAPLY]]()[%[[TIDX]], %[[TIDY]]]
     // CHECK-DAG: %[[COND:.*]] = arith.cmpi ult, %[[LIN]], %[[C20]] : index
     //     CHECK: scf.if %[[COND]]
     //     CHECK:   memref.load %{{.*}}[%[[LIDX]]] : memref<32xf32>
@@ -648,7 +648,7 @@ module attributes {transform.with_named_sequence} {
 #map1 = affine_map<(d0) -> (d0 * 32)>
 
 // CHECK-DAG: #[[$MAPB:.*]] = affine_map<(d0) -> (d0 * 128)>
-// CHECK-DAG: #[[$MAPW:.*]] = affine_map<(d0, d1, d2) -> (d2 * 32 + ((d0 + d1 * 4) floordiv 32) * 32)>
+// CHECK-DAG: #[[$MAPW:.*]] = affine_map<()[s0, s1, s2] -> (s2 * 32 + ((s0 + s1 * 4) floordiv 32) * 32)>
 
 // CHECK-LABEL: func.func @simple_fill(
 func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
@@ -667,7 +667,7 @@ func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
 //       CHECK:     %[[TIDX:.*]] = gpu.thread_id  x
 //       CHECK:     %[[TIDY:.*]] = gpu.thread_id  y
 //       CHECK:     %[[TIDZ:.*]] = gpu.thread_id  z
-//       CHECK:     %[[THX:.*]] = affine.apply #[[$MAPW]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+//       CHECK:     %[[THX:.*]] = affine.apply #[[$MAPW]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
 //   CHECK-NOT:     scf.if
 //       CHECK:       memref.subview %{{.*}}[%[[THX]]]
       %1 = affine.apply #map1(%arg2)

@linuxlonelyeagle linuxlonelyeagle requested review from MaheshRavishankar, ftynse, bondhugula and antiagainst and removed request for ftynse December 3, 2024 12:05
@adam-smnk
Contributor

Just a fly-by comment as I'm no affine expert.

Wouldn't it be sufficient to just define the thread ID outside the loop?
For example:

#map = affine_map<()[s0] -> (s0 mod 32)>

module {
  gpu.module @gpu {
    gpu.func @gemm(%arg0: memref<?x?xf32>) kernel {
      %c3 = arith.constant 3 : index
      %dim = memref.dim %arg0, %c3 : memref<?x?xf32>
      %c0 = arith.constant 0 : index
      %thread_id_x = gpu.thread_id  x // Move declaration here
      affine.for %arg3 = %c0 to %dim step 32 {
        %0 = affine.apply #map()[%thread_id_x]
        %c128 = arith.constant 128 : index
        affine.for %arg4 = %0 to %c128 step 8 {
          %c32 = arith.constant 32 : index
        }
      }
      gpu.return
    }
  }
}

Also, the proposed change seems super specific. It might be fine, but if we go this route, would it be reasonable to allow other "special" symbols as well?

@linuxlonelyeagle
Member Author

You're right, as I commented above. But consider a loadOp or storeOp: if the number of threads is not large enough, each thread needs to access the data multiple times, which requires a for loop. The position of the thread ID should be aligned with the position of the loadOp and storeOp. So if the loadOp is inside an affine.for, I think the thread ID should also be inside that affine.for, which makes the symbol check fail. It is indeed possible to hoist the thread ID into the funcOp, but that would add extra operations, implemented by looking up the parent op of the op matched in the pattern.
I'm not an expert in the affine dialect, sorry.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:ods labels Dec 16, 2024

github-actions bot commented Dec 16, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@linuxlonelyeagle linuxlonelyeagle force-pushed the support-threadid-as-symbol branch from 7450b8c to 52562d0 Compare December 16, 2024 13:15
@krzysz00
Contributor

Yeah, seconding pulling the thread ID out of the loop and affine.apply-ing it in as an immediate solution.

However - and I think this discussion should be postponed to next year, because people are out for holidays - I do agree that gpu.thread_id has dimension-like properties, so... maybe there's something here?

I'm not sure I like the idea of "affine symbol" being the trait name, though I don't have a better one ... if we want to pursue this.

@linuxlonelyeagle
Member Author

The discussion being postponed to next year is okay with me.

@linuxlonelyeagle linuxlonelyeagle requested a review from grypp January 6, 2025 14:00
@ftynse
Member

ftynse commented Jan 6, 2025

Relevant, so far empty, discussion thread: https://discourse.llvm.org/t/make-threadid-op-is-valid-symbol-introduce-affinesymbol-trait/83702/2. I'll give @bondhugula some more time to react, given the season.

Contributor

@bondhugula bondhugula left a comment


Can you squash your commits and add a proper commit summary and update the title as well? It should be along Introduce AffineSymbol trait and use it for ... gpu.threadid op.

@bondhugula
Contributor

Extensibility like this appears welcome. Here are a few quick things to think about while on this (some of these may be side-issues/open questions):

  1. Are ops with AffineSymbol trait always also side-effect-free/pure?
  2. Can such ops always be canonicalized to be placed at (or hoisted to) the top-level of an AffineScope?
  3. Will (and should) such ops always have zero operands?

E.g. arith.constant satisfies all of the above properties

@linuxlonelyeagle linuxlonelyeagle force-pushed the support-threadid-as-symbol branch from 52562d0 to 8efb3ff Compare January 8, 2025 09:03
@linuxlonelyeagle linuxlonelyeagle changed the title [mlir][affine]make threadid op is valid symbol. [mlir][affine]Introduce AffineSymbol trait and use it for gpu.threadid op Jan 8, 2025
@linuxlonelyeagle
Member Author

linuxlonelyeagle commented Jan 8, 2025

Can you squash your commits and add a proper commit summary and update the title as well? It should be along Introduce AffineSymbol trait and use it for ... gpu.threadid op.

Sorry, I can't easily squash the commits, because the branch has synced with the LLVM main branch multiple times and rebasing pulls in many other people's commits. I did squash the recent commits made after the last sync with main.

The merge button below can squash the rest.

@linuxlonelyeagle linuxlonelyeagle force-pushed the support-threadid-as-symbol branch from 8e125de to ddcec59 Compare January 12, 2025 09:18
@linuxlonelyeagle
Member Author

I designed AffineSymbol by referring to the properties of constant ops:
1. The ops should be pure.

But gpu.thread_id is not pure! It's not marked as such.

2. Yes.
3. Yes.

Why can't an AffineSymbol trait op as proposed here be allowed to have operands which are in turn valid affine symbols? Can we generalize/extend this?

If "zero operand" is a hard implied property, the trait verifier should be checking for it, but it isn't in this PR.

You are right; pure and no side effects should be correct.

@krzysz00
Contributor

One drive-by comment: affine maps already have "symbols" - they use the s0, s1, ... sN, so we really need to find a new name for this property.

AffineDimensionLike, perhaps?

@linuxlonelyeagle linuxlonelyeagle force-pushed the support-threadid-as-symbol branch from ddcec59 to 7e5a95b Compare January 14, 2025 03:36
@bondhugula
Contributor

bondhugula commented Jan 15, 2025

I designed AffineSymbol at the beginning by referring to constant's properties for design.
1.The ops should be pure.

But gpu.thread_id not pure! It's not marked so.

2.Yes.
3.Yes.

Why can't an AffineSymbol trait op as proposed here be allowed to have operands which are also in turn valid affine symbols? Can we generalize/extend this?
If "zero operand" is a hard implied property, the trait verifier should be checking for it, but it isn't in the PR.

You are right, pure and no side effects should be correct.

Now, thinking more about it, you don't really need the AffineSymbol trait for your use case or others; we can instead make the current definition more powerful. How about allowing "any index-typed result of a pure operation whose operands are in turn valid symbols" (zero-operand pure operations producing index being a trivial case) as a valid symbol? This is in line with the symbol concept as a Value that doesn't change in the region of interest. The fact that the operation is pure also means it can be freely hoisted/canonicalized to the top level of the AffineScope or higher, making it a valid symbol per the existing rules as well.

This will cover gpu.thread_id and many other valid use cases without needing a new trait/interface. A new trait/interface may still be needed for other cases of interest, but we should do the former first instead of asking other external ops to add a new trait when it isn't needed. However, if gpu.thread_id is argued to be not "pure" for some reason, we'll have to think about the trait as proposed in this PR.
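To make the proposed rule concrete, here is a small hand-written sketch (not taken from the PR; the value names are invented). Under the relaxed definition, the result of a pure index computation whose operands are already valid symbols would itself be a valid symbol:

```mlir
#map = affine_map<()[s0] -> (s0 mod 32)>

gpu.func @example() kernel {
  %tid = gpu.thread_id x               // zero-operand pure op: trivially symbolic
  %c4  = arith.constant 4 : index      // constant: already a valid symbol
  %sum = arith.addi %tid, %c4 : index  // pure op whose operands are valid symbols
  // With the relaxed rule, %sum may appear in a symbol list:
  %0 = affine.apply #map()[%sum]
  gpu.return
}
```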

@bondhugula
Contributor

bondhugula commented Jan 15, 2025

One drive-by comment: affine maps already have "symbols" - they use the s0, s1, ... sN, so we really need to find a new name for this property.

AffineDimensionLike, perhaps?

These are symbolic identifiers in affine maps, and thus also referred to as symbols. They are separate from SSA values, which are valid affine symbols. Both are related but not obviously the same since the former are math names and the latter are SSA values. The names are fine I feel - the contexts are different.

@linuxlonelyeagle
Member Author

linuxlonelyeagle commented Jan 15, 2025

I designed AffineSymbol at the beginning by referring to constant's properties for design.
1.The ops should be pure.

But gpu.thread_id not pure! It's not marked so.

2.Yes.
3.Yes.

Why can't an AffineSymbol trait op as proposed here be allowed to have operands which are also in turn valid affine symbols? Can we generalize/extend this?
If "zero operand" is a hard implied property, the trait verifier should be checking for it, but it isn't in the PR.

You are right, pure and no side effects should be correct.

Now, thinking more about it, you don't really need the AffineSymbol trait for your use case as well as others to make the current definition more powerful. How about allowing these: "any index-typed result of a pure operation that has operands that are in turn symbols" (zero operand pure operations generating index will be a trivial case of this) as a valid symbol? This is in line with the symbol concept as a Value that doesn't change in the region of interest. The fact that the operation is pure also means it can be freely hoisted/canonicalized to the top level of the AffineScope or higher making it a valid symbol as per the existing rules as well.

This will cover gpu.thread_id any many other valid use cases without need a new trait/interface. A new trait/interface may still be needed for other cases of interest, but we should do the former first instead of having to ask other external ops to add a new trait when not needed. However, if gpu.thread_id is argued to be not "pure" for some reason, we'll have to think about the trait as proposed in this PR.

"Any index-typed result of a pure operation that has operands that are in turn symbols" (zero-operand pure operations generating index being a trivial case) makes sense to me, but I have a question: why does the operand of such an op also need to be a symbol?
I simply changed the symbol-validity check to the following, without yet checking the operands; I'm afraid the impact of checking them would be too great.

  if (isPure(defOp))
    return true;

In fact, I think the current implementation is ok. Although this trait may not really be needed, the impact it causes is within a controllable range.


********************
********************
Failed Tests (17):
  MLIR :: Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
  MLIR :: Dialect/Affine/invalid.mlir
  MLIR :: Dialect/Affine/load-store-invalid.mlir
  MLIR :: Dialect/GPU/transform-gpu.mlir
  MLIR :: Dialect/Linalg/convert-conv2d-to-img2col.mlir
  MLIR :: Dialect/Linalg/tile-and-fuse-tensors.mlir
  MLIR :: Dialect/Linalg/tile-conv.mlir
  MLIR :: Dialect/Linalg/tile-indexed.mlir
  MLIR :: Dialect/Linalg/tile-to-forall.mlir
  MLIR :: Dialect/Linalg/transform-op-pad.mlir
  MLIR :: Dialect/Linalg/transform-op-split.mlir
  MLIR :: Dialect/Linalg/transform-tile-reduction.mlir
  MLIR :: Dialect/Tensor/tiling.mlir
  MLIR :: Interfaces/TilingInterface/tile-using-interface.mlir
  MLIR :: Interfaces/TilingInterface/tile-using-scfforall.mlir
  MLIR :: Transforms/parallel-loop-collapsing.mlir
  MLIR :: Transforms/single-parallel-loop-collapsing.mlir


Testing Time: 2.97s

Total Discovered Tests: 3033
  Unsupported      :  515 (16.98%)
  Passed           : 2500 (82.43%)
  Expectedly Failed:    1 (0.03%)
  Failed           :   17 (0.56%)
FAILED: tools/mlir/test/CMakeFiles/check-mlir /root/llvm/llvm-project/build/tools/mlir/test/CMakeFiles/check-mlir 

@bondhugula
Contributor

bondhugula commented Jan 15, 2025

Why does the operand of such an Op also need to be a symbol?

This is key - otherwise, you are generating SSA values that are functions of dimensional values like loop IVs, and those are exactly the opposite of symbols. For example, affine.apply is a pure operation: when it takes dimensional values, it generates a dimensional value, not a symbolic one. So the pure op should have zero or more symbolic operands.
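A hand-written counter-example of this distinction (not from the PR; the names are invented): a pure op fed by a loop induction variable produces a dimensional value, which must still be rejected as a symbol:

```mlir
affine.for %i = 0 to 128 {
  // affine.apply is pure, but %i is a loop IV (dimensional), so %v is
  // dimensional as well, not symbolic.
  %v = affine.apply affine_map<(d0) -> (d0 * 2)>(%i)
  // Using %v in a symbol list would still be invalid:
  // %w = affine.apply affine_map<()[s0] -> (s0 mod 32)>()[%v]  // rejected
}
```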

@linuxlonelyeagle
Member Author

I designed AffineSymbol at the beginning by referring to constant's properties for design.
1.The ops should be pure.

But gpu.thread_id not pure! It's not marked so.

2.Yes.
3.Yes.

Why can't an AffineSymbol trait op as proposed here be allowed to have operands which are also in turn valid affine symbols? Can we generalize/extend this?

If "zero operand" is a hard implied property, the trait verifier should be checking for it, but it isn't in the PR.

The thread_id is indeed marked as pure. I roughly understand what you mean above. Let me ask again about the concepts of affine dimensions and symbols. In other words: a pure op applied to symbol operands yields a symbolic SSA value, and a pure op applied to dimensional operands yields a dimensional SSA value. So what does a pure op applied to both a symbol and a dimensional SSA value yield?

@linuxlonelyeagle
Member Author

I have removed the AffineSymbol trait. The current implementation instead checks whether the op is Pure and whether its operands are valid symbols. Cc @bondhugula. This PR breaks some tests, but I think it is the right change and makes sense.

Contributor

@bondhugula bondhugula left a comment


This is overall looking good now. Mostly minor comments.

Contributor

@bondhugula bondhugula left a comment


PR title and commit title/summary needs to be updated as well.

@linuxlonelyeagle linuxlonelyeagle force-pushed the support-threadid-as-symbol branch from b81ed07 to 1b4232c Compare January 17, 2025 03:05
@linuxlonelyeagle linuxlonelyeagle force-pushed the support-threadid-as-symbol branch from 1b4232c to e8746ac Compare January 17, 2025 03:08
@linuxlonelyeagle linuxlonelyeagle changed the title [mlir][affine]introduce AffineSymbol trait and use it for using gpu.threadid op in the inner loops. [mlir][affine]introducing new symbol rules that the result of a Pure operation that whose operands are valid symbolic identifiers Jan 17, 2025
@linuxlonelyeagle
Member Author

I updated the title of this PR; the commit message will be updated during the final merge.

Contributor

@bondhugula bondhugula left a comment


LGTM - thanks.

@bondhugula
Contributor

Please go ahead and merge if you have access. If not, I can do it. The commits will need to be squashed with proper summary.

@linuxlonelyeagle
Member Author

It's been an honor to work with you. I'll do the rest of the work; I can merge it. Thank you, @bondhugula.
