[mlir][spirv] Improve folding of MemRef to SPIRV Lowering #85433


Merged
merged 1 commit into llvm:main from inbelic/memref-to-spirv-folding on Mar 21, 2024

Conversation

inbelic
Contributor

@inbelic inbelic commented Mar 15, 2024

Investigate the lowering of MemRef Load/Store ops and implement additional folding of created ops

Aims to improve readability of generated lowered SPIR-V code.

Part of work #70704
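
To illustrate the approach: the patch switches builder calls from `create` to `createOrFold`, so ops whose fold hooks can resolve them (multiplication by 1, addition of 0, selects with constant conditions, and so on) are never materialized in the output. A minimal sketch of the idiom, assuming the IMul/IAdd folders this work relies on; the helper name below is illustrative and not part of the patch:

```cpp
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch of the create -> createOrFold idiom used throughout
// MemRefToSPIRV.cpp. createOrFold builds the op, immediately runs its fold
// hook, and returns the folded Value (erasing the just-created op) when
// folding succeeds; otherwise it behaves exactly like create.
static Value buildScaledIndex(OpBuilder &builder, Location loc, Value index,
                              int64_t stride, Type intType) {
  IntegerAttr strideAttr = builder.getIntegerAttr(intType, stride);
  Value strideVal =
      builder.createOrFold<spirv::ConstantOp>(loc, intType, strideAttr);
  // With stride == 1 the multiply folds away and `index` is returned
  // directly, so no spirv.IMul appears in the lowered code.
  return builder.createOrFold<spirv::IMulOp>(loc, index, strideVal);
}
```

The same change is applied to the constant, shift, bitwise, and access-chain index computations in the load/store lowering paths below.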

@inbelic
Contributor Author

inbelic commented Mar 15, 2024

This commit depends on PR #85430 and will have to be rebased onto it before the test changes pass, because it uses the SelectOp folding that #85430 implements.
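
For reference, the dependency is on a constant-condition fold for spirv.SelectOp: `castBoolToIntN` now goes through `createOrFold`, so a select whose i1 condition is a compile-time constant collapses to one of its operands. A rough sketch of such a fold hook, assuming the standard FoldAdaptor-based folder interface and the op's ODS accessor names (illustrative only; the actual folder is implemented in #85430):

```cpp
// Illustrative sketch of a constant-condition fold for spirv.Select; it would
// live alongside the other SPIR-V op folders (e.g. SPIRVCanonicalization.cpp).
OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
  // When the condition operand is a constant i1, return the corresponding
  // branch value so no spirv.Select op is materialized.
  if (auto cond = dyn_cast_if_present<BoolAttr>(adaptor.getCondition()))
    return cond.getValue() ? getTrueValue() : getFalseValue();
  return {};
}
```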

Investigate the lowering of MemRef Load/Store ops and implement
additional folding of created ops

Aims to improve readability of generated lowered SPIR-V code.

Part of work llvm#70704
@inbelic inbelic force-pushed the inbelic/memref-to-spirv-folding branch from 87c7b2e to adc3bd2 on March 19, 2024 at 20:37
@inbelic
Contributor Author

inbelic commented Mar 19, 2024

Rebased onto the required commit now that it has been merged.

@inbelic inbelic marked this pull request as ready for review March 19, 2024 23:41
@llvmbot
Member

llvmbot commented Mar 19, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Finn Plummer (inbelic)

Changes

Investigate the lowering of MemRef Load/Store ops and implement additional folding of created ops

Aims to improve readability of generated lowered SPIR-V code.

Part of work llvm#70704


Patch is 42.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85433.diff

8 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+28-24)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+5-4)
  • (modified) mlir/test/Conversion/GPUToSPIRV/load-store.mlir (+2-6)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir (+38-120)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+8-26)
  • (modified) mlir/test/Conversion/SCFToSPIRV/for.mlir (+2-10)
  • (modified) mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir (+4-6)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+6-14)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 0acb2142f3f68a..81b9f55cac80f7 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
   assert(targetBits % sourceBits == 0);
   Type type = srcIdx.getType();
   IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
-  auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
+  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
   IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
-  auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
-  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
-  return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
+  auto srcBitsValue =
+      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
+  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
+  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
 }
 
 /// Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
   Value lastDim = op->getOperand(op.getNumOperands() - 1);
   Type type = lastDim.getType();
   IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
-  auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
+  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
   auto indices = llvm::to_vector<4>(op.getIndices());
   // There are two elements if this is a 1-D tensor.
   assert(indices.size() == 2);
-  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
   Type t = typeConverter.convertType(op.getComponentPtr().getType());
   return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
 }
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
     return srcBool;
   Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
   Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
-  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
+  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
+                                               zero);
 }
 
 /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
           loc, builder.getIntegerType(targetBits), value);
     }
 
-    value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
   }
-  return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
-                                                   offset);
+  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
+                                                         value, offset);
 }
 
 /// Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
     return srcInt;
 
   auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
-  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
+  return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
 }
 
 //===----------------------------------------------------------------------===//
@@ -597,13 +599,14 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   // ____XXXX________ -> ____________XXXX
   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
-  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
+  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
       loc, spvLoadOp.getType(), spvLoadOp, offset);
 
   // Apply the mask to extract corresponding bits.
-  Value mask = rewriter.create<spirv::ConstantOp>(
+  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
-  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+  result =
+      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
 
   // Apply sign extension on the loading value unconditionally. The signedness
   // semantic is carried in the operator itself, we relies other pattern to
@@ -611,11 +614,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   IntegerAttr shiftValueAttr =
       rewriter.getIntegerAttr(dstType, dstBits - srcBits);
   Value shiftValue =
-      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
-  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
-                                                      shiftValue);
-  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
-                                                          shiftValue);
+      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
+                                                            result, shiftValue);
+  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
+      loc, dstType, result, shiftValue);
 
   rewriter.replaceOp(loadOp, result);
 
@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
 
   // Create a mask to clear the destination. E.g., if it is the second i8 in
   // i32, 0xFFFF00FF is created.
-  Value mask = rewriter.create<spirv::ConstantOp>(
+  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
-  Value clearBitsMask =
-      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
-  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
+  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
+      loc, dstType, mask, offset);
+  clearBitsMask =
+      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
 
   Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
 
     int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
     Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
-    return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
+    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
   }();
 
   rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e85..4072608dc8f873 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -991,15 +991,16 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
   // broken down into progressive small steps so we can have intermediate steps
   // using other dialects. At the moment SPIR-V is the final sink.
 
-  Value linearizedIndex = builder.create<spirv::ConstantOp>(
+  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
       loc, integerType, IntegerAttr::get(integerType, offset));
   for (const auto &index : llvm::enumerate(indices)) {
-    Value strideVal = builder.create<spirv::ConstantOp>(
+    Value strideVal = builder.createOrFold<spirv::ConstantOp>(
         loc, integerType,
         IntegerAttr::get(integerType, strides[index.index()]));
-    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+    Value update =
+        builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
     linearizedIndex =
-        builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
+        builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
   }
   return linearizedIndex;
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index fa12da8ef9d4ec..4339799ccd5eaf 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -60,13 +60,9 @@ module attributes {
       // CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
       %13 = arith.addi %arg4, %3 : index
       // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
-      // CHECK: %[[OFFSET1_0:.*]] = spirv.Constant 0 : i32
       // CHECK: %[[STRIDE1_1:.*]] = spirv.Constant 4 : i32
-      // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
-      // CHECK: %[[OFFSET1_1:.*]] = spirv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
-      // CHECK: %[[STRIDE1_2:.*]] = spirv.Constant 1 : i32
-      // CHECK: %[[UPDATE1_2:.*]] = spirv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
-      // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
+      // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[INDEX1]], %[[STRIDE1_1]] : i32
+      // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[INDEX2]], %[[UPDATE1_1]] : i32
       // CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
       // CHECK-NEXT: %[[VAL1:.*]] = spirv.Load "StorageBuffer" %[[PTR1]]
       %14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>
diff --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
index 470c8531e2e0fb..52ed14e8cce233 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
@@ -12,16 +12,10 @@ module attributes {
 // CHECK-LABEL: @load_i1
 func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
   //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: %[[T4:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -37,32 +31,20 @@ func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1
 // INDEX64-LABEL: @load_i8
 func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
   //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   //     CHECK: builtin.unrealized_conversion_cast %[[SR]]
 
   //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
-  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
-  //   INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
   //   INDEX64: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]] : i32
-  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
-  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
-  //   INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
   //   INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //   INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
   //   INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //   INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -76,15 +58,12 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
 func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
   //     CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
-  //     CHECK: %[[UPDATE:.+]] = spirv.IMul %[[ONE]], %[[ARG1_CAST]] : i32
-  //     CHECK: %[[FLAT_IDX:.+]] = spirv.IAdd %[[ZERO]], %[[UPDATE]] : i32
   //     CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
-  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
   //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
   //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
   //     CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
   //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
   //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
@@ -110,20 +89,12 @@ func.func @load_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
 func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %value: i1) {
   //     CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //     CHECK: %[[MASK:.+]] = spirv.Constant -256 : i32
   //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
   //     CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
-  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CASTED_ARG1]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
   //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CASTED_ARG1]]
   memref.store %value, %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
   return
 }
@@ -136,36 +107,22 @@ func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %val
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //     CHECK: %[[MASK2:.+]] = spirv.Constant -256 : i32
   //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
-  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
-  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
+  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
 
   //   INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //   INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
-  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
-  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
-  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
   //   INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //   INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
-  //   INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //   INDEX64: %[[MASK2:.+]] = spirv.Constant -256 : i32
   //   INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
-  //   INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
-  //   INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
-  //   INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //   INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
+  //   INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+  //   INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
   memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
   return
 }
@@ -177,19 +...
[truncated]

Member

@kuhar kuhar left a comment


Oh wow, it got so much more concise now. Thanks!

@inbelic inbelic merged commit 38f8a3c into llvm:main Mar 21, 2024
@inbelic inbelic deleted the inbelic/memref-to-spirv-folding branch March 21, 2024 15:49
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
Investigate the lowering of MemRef Load/Store ops and implement
additional folding of created ops

Aims to improve readability of generated lowered SPIR-V code.

Part of work llvm#70704