
[mlir][vector] Add emulation patterns for vector masked load/store #74834


Merged
merged 1 commit into llvm:main on Dec 15, 2023

Conversation

Contributor

@Hsiangkai Hsiangkai commented Dec 8, 2023

This patch converts

vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru

to

%ivalue = %pass_thru
%m = vector.extract %mask[0]
%result0 = scf.if %m {
  %v = memref.load %base[%idx_0, %idx_1]
  %combined = vector.insert %v, %ivalue[0]
  scf.yield %combined
} else {
  scf.yield %ivalue
}
%m = vector.extract %mask[1]
%result1 = scf.if %m {
  %v = memref.load %base[%idx_0, %idx_1 + 1]
  %combined = vector.insert %v, %result0[1]
  scf.yield %combined
} else {
  scf.yield %result0
}
...

It also converts

vector.maskedstore %base[%idx_0, %idx_1], %mask, %value

to

%m = vector.extract %mask[0]
scf.if %m {
  %extracted = vector.extract %value[0]
  memref.store %extracted, %base[%idx_0, %idx_1]
}
%m = vector.extract %mask[1]
scf.if %m {
  %extracted = vector.extract %value[1]
  memref.store %extracted, %base[%idx_0, %idx_1 + 1]
}
...
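For reference, such a masked store might appear in a complete function before emulation as in the sketch below; the memref and vector types here are illustrative assumptions, not taken from the patch or its tests:

func.func @example_maskedstore(%base: memref<4x5xf32>, %value: vector<4xf32>) {
  %idx_0 = arith.constant 0 : index
  %idx_1 = arith.constant 1 : index
  // Enable only the first lane of the 4-element mask.
  %mask = vector.create_mask %idx_1 : vector<4xi1>
  vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
      : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
  return
}

The emulation then guards each lane's memref.store with the corresponding mask bit, as shown above.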

@llvmbot
Member

llvmbot commented Dec 8, 2023

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

Changes

Use spirv.mlir.loop and spirv.mlir.selection to lower vector.maskedload and vector.maskedstore.


The patch is 21.35 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/74834.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+324-1)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+111)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index e48f29a4f1702..b32c004e28a1e 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -647,6 +647,328 @@ struct VectorStoreOpConverter final
   }
 };
 
+mlir::spirv::LoopOp createSpirvLoop(ConversionPatternRewriter &rewriter,
+                                    Location loc) {
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+  loopOp.addEntryAndMergeBlock();
+
+  auto &loopBody = loopOp.getBody();
+  // Create header block.
+  loopBody.getBlocks().insert(std::next(loopBody.begin(), 1), new Block());
+  // Create continue block.
+  loopBody.getBlocks().insert(std::prev(loopBody.end(), 2), new Block());
+
+  return loopOp;
+}
+
+mlir::spirv::SelectionOp
+createSpirvSelection(ConversionPatternRewriter &rewriter, Location loc) {
+  auto selectionOp =
+      rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+  auto &loopBody = selectionOp.getBody();
+  // Create header block.
+  rewriter.createBlock(&loopBody, loopBody.end());
+  // Create if-true block.
+  rewriter.createBlock(&loopBody, loopBody.end());
+  // Create merge block.
+  rewriter.createBlock(&loopBody, loopBody.end());
+  rewriter.create<spirv::MergeOp>(loc);
+
+  return selectionOp;
+}
+
+Value addOffsetToIndices(ConversionPatternRewriter &rewriter, Location loc,
+                         SmallVectorImpl<Value> &indices, const Value offset,
+                         const SPIRVTypeConverter &typeConverter,
+                         const MemRefType memrefType, const Value base) {
+  indices.back() = rewriter.create<spirv::IAddOp>(loc, indices.back(), offset);
+  return spirv::getElementPtr(typeConverter, memrefType, base, indices, loc,
+                              rewriter);
+}
+
+Value extractMaskBit(ConversionPatternRewriter &rewriter, Location loc,
+                     Value mask, Value offset) {
+  return rewriter.create<spirv::VectorExtractDynamicOp>(
+      loc, rewriter.getI1Type(), mask, offset);
+}
+
+Value extractVectorElement(ConversionPatternRewriter &rewriter, Location loc,
+                           Type type, Value vector, Value offset) {
+  return rewriter.create<spirv::VectorExtractDynamicOp>(loc, type, vector,
+                                                        offset);
+}
+
+Value createConstantInteger(ConversionPatternRewriter &rewriter, Location loc,
+                            int32_t value) {
+  auto i32Type = rewriter.getI32Type();
+  return rewriter.create<spirv::ConstantOp>(loc, i32Type,
+                                            IntegerAttr::get(i32Type, value));
+}
+
+/// Convert vector.maskedload to spirv dialect.
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+///   %buffer = spirv.Variable
+///   spirv.mlir.loop {
+///     spirv.Branch ^bb1(0, %buffer)
+///   ^bb1(%i: i32, %partial: vector):
+///     %m = spirv.VectorExtractDynamic %mask[%i]
+///     %p = spirv.VectorExtractDynamic %pass_thru[%i]
+///     %value = spirv.Load
+///     %s = spirv.Select %m, %value, %p
+///     %v = spirv.VectorInsertDynamic %s, %partial[%i]
+///     spirv.Store %buffer, %v
+///     spirv.Branch ^bb2(%i, %v)
+///   ^bb2(%i: i32, %partial: vector):
+///     %update_i = spirv.IAdd %i, 1
+///     %cond = spirv.SLessThan %update_i, %veclen
+///     spirv.BranchConditional %cond, ^bb1(%update_i, %partial), ^bb3
+///   ^bb3:
+///     spirv.mlir.merge
+///   }
+///   %ret = spirv.Load %buffer
+///   return %ret
+///
+struct VectorMaskedLoadOpConverter final
+    : public OpConversionPattern<vector::MaskedLoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedLoadOp maskedLoadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = maskedLoadOp.getMemRefType();
+    if (!isa<spirv::StorageClassAttr>(memrefType.getMemorySpace()))
+      return failure();
+
+    VectorType maskVType = maskedLoadOp.getMaskVectorType();
+    if (maskVType.getRank() != 1)
+      return failure();
+    if (maskVType.getShape().size() != 1)
+      return failure();
+
+    // Create a local variable to store the loaded value.
+    auto loc = maskedLoadOp.getLoc();
+    auto vectorType = maskedLoadOp.getVectorType();
+    auto pointerType =
+        spirv::PointerType::get(vectorType, spirv::StorageClass::Function);
+    auto alloc = rewriter.create<spirv::VariableOp>(
+        loc, pointerType, spirv::StorageClass::Function,
+        /*initializer=*/nullptr);
+
+    // Create constants for the loop.
+    Value zero = createConstantInteger(rewriter, loc, 0);
+    Value one = createConstantInteger(rewriter, loc, 1);
+    Value maskLength =
+        createConstantInteger(rewriter, loc, maskVType.getShape()[0]);
+
+    auto emptyVector = rewriter.create<spirv::ConstantOp>(
+        loc, vectorType, rewriter.getZeroAttr(vectorType));
+
+    // Construct a loop to go through the mask value
+    auto loopOp = createSpirvLoop(rewriter, loc);
+
+    auto *headerBlock = loopOp.getHeaderBlock();
+    auto *continueBlock = loopOp.getContinueBlock();
+
+    auto i32Type = rewriter.getI32Type();
+    BlockArgument indVar = headerBlock->addArgument(i32Type, loc);
+    BlockArgument partialVector = headerBlock->addArgument(vectorType, loc);
+    BlockArgument continueIndVar = continueBlock->addArgument(i32Type, loc);
+    BlockArgument continueVector = continueBlock->addArgument(vectorType, loc);
+
+    // Insert code into loop entry block
+    rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
+
+    // Header block needs two arguments: induction variable, updated vector
+    rewriter.create<spirv::BranchOp>(loc, headerBlock,
+                                     ArrayRef<Value>({zero, emptyVector}));
+
+    // Insert code into loop header block
+    rewriter.setInsertionPointToEnd(headerBlock);
+    auto maskBit = extractMaskBit(rewriter, loc, adaptor.getMask(), indVar);
+
+    auto scalarType = memrefType.getElementType();
+    auto passThruValue = extractVectorElement(rewriter, loc, scalarType,
+                                              adaptor.getPassThru(), indVar);
+
+    // Load base[indVar]
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto indices = llvm::to_vector<4>(adaptor.getIndices());
+    auto updatedAccessChain =
+        addOffsetToIndices(rewriter, loc, indices, indVar, typeConverter,
+                           memrefType, adaptor.getBase());
+    auto loadScalar =
+        rewriter.create<spirv::LoadOp>(loc, scalarType, updatedAccessChain);
+
+    // Select the loaded value or pass-through according to the mask bit.
+    auto valueToInsert = rewriter.create<spirv::SelectOp>(
+        loc, scalarType, maskBit, loadScalar, passThruValue);
+
+    // Insert the selected value to output vector.
+    auto updatedVector = rewriter.create<spirv::VectorInsertDynamicOp>(
+        loc, vectorType, partialVector, valueToInsert, indVar);
+    rewriter.create<spirv::StoreOp>(loc, alloc, updatedVector);
+    rewriter.create<spirv::BranchOp>(loc, continueBlock,
+                                     ArrayRef<Value>({indVar, updatedVector}));
+
+    // Insert code into continue block
+    rewriter.setInsertionPointToEnd(continueBlock);
+
+    // Update induction variable.
+    auto updatedIndVar =
+        rewriter.create<spirv::IAddOp>(loc, continueIndVar, one);
+
+    // Check if the induction variable < length(mask)
+    auto cmpOp =
+        rewriter.create<spirv::SLessThanOp>(loc, updatedIndVar, maskLength);
+
+    auto *mergeBlock = loopOp.getMergeBlock();
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, cmpOp, headerBlock,
+        ArrayRef<Value>({updatedIndVar, continueVector}), mergeBlock,
+        std::nullopt);
+
+    // Insert code after loop
+    rewriter.setInsertionPointAfter(loopOp);
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(maskedLoadOp, alloc);
+
+    return success();
+  }
+};
+
+/// Convert vector.maskedstore to spirv dialect.
+///
+/// Before:
+///
+///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+///   spirv.mlir.loop {
+///     spirv.Branch ^bb1(0)
+///   ^bb1(%i: i32):
+///     %m = spirv.VectorExtractDynamic %mask[%i]
+///     spirv.mlir.selection {
+///       spirv.BranchConditional %m, ^if_bb1, ^if_bb2
+///       ^if_bb1:
+///         %v = spirv.VectorExtractDynamic %value[%i]
+///         spirv.Store %base[%i], %v
+///         spirv.Branch ^if_bb2
+///       ^if_bb2:
+///         spirv.mlir.merge
+///     }
+///     spirv.Branch ^bb2(%i)
+///   ^bb2(%i: i32):
+///     %update_i = spirv.IAdd %i, 1
+///     %cond = spirv.SLessThan %update_i, %veclen
+///     spirv.BranchConditional %cond, ^bb1, ^bb3
+///   ^bb3:
+///     spirv.mlir.merge
+///   }
+///   return
+///
+struct VectorMaskedStoreOpConverter final
+    : public OpConversionPattern<vector::MaskedStoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedStoreOp maskedStoreOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = maskedStoreOp.getMemRefType();
+    if (!isa<spirv::StorageClassAttr>(memrefType.getMemorySpace()))
+      return failure();
+
+    VectorType maskVType = maskedStoreOp.getMaskVectorType();
+    if (maskVType.getRank() != 1)
+      return failure();
+    if (maskVType.getShape().size() != 1)
+      return failure();
+
+    // Create constants.
+    auto loc = maskedStoreOp.getLoc();
+    Value zero = createConstantInteger(rewriter, loc, 0);
+    Value one = createConstantInteger(rewriter, loc, 1);
+    Value maskLength =
+        createConstantInteger(rewriter, loc, maskVType.getShape()[0]);
+
+    // Construct a loop to go through the mask value
+    auto loopOp = createSpirvLoop(rewriter, loc);
+    auto *headerBlock = loopOp.getHeaderBlock();
+    auto *continueBlock = loopOp.getContinueBlock();
+
+    auto i32Type = rewriter.getI32Type();
+    BlockArgument indVar = headerBlock->addArgument(i32Type, loc);
+    BlockArgument continueIndVar = continueBlock->addArgument(i32Type, loc);
+
+    // Insert code into loop entry block
+    rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
+    rewriter.create<spirv::BranchOp>(loc, headerBlock, ArrayRef<Value>({zero}));
+
+    // Insert code into loop header block
+    rewriter.setInsertionPointToEnd(headerBlock);
+    auto maskBit = extractMaskBit(rewriter, loc, adaptor.getMask(), indVar);
+
+    auto selectionOp = createSpirvSelection(rewriter, loc);
+    auto *selectionHeaderBlock = selectionOp.getHeaderBlock();
+    auto *selectionMergeBlock = selectionOp.getMergeBlock();
+    auto *selectionTrueBlock = &(*std::next(selectionOp.getBody().begin(), 1));
+
+    // Insert code into selection header block
+    rewriter.setInsertionPointToEnd(selectionHeaderBlock);
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, maskBit, selectionTrueBlock, std::nullopt, selectionMergeBlock,
+        std::nullopt);
+
+    // Insert code into selection true block
+    rewriter.setInsertionPointToEnd(selectionTrueBlock);
+    auto scalarType = memrefType.getElementType();
+    auto extractedStoreValue = extractVectorElement(
+        rewriter, loc, scalarType, adaptor.getValueToStore(), indVar);
+
+    // Store base[indVar]
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto indices = llvm::to_vector<4>(adaptor.getIndices());
+    auto updatedAccessChain =
+        addOffsetToIndices(rewriter, loc, indices, indVar, typeConverter,
+                           memrefType, adaptor.getBase());
+    rewriter.create<spirv::StoreOp>(loc, updatedAccessChain,
+                                    extractedStoreValue);
+    rewriter.create<spirv::BranchOp>(loc, selectionMergeBlock, std::nullopt);
+
+    // Insert code into loop header block
+    rewriter.setInsertionPointAfter(selectionOp);
+    rewriter.create<spirv::BranchOp>(loc, continueBlock,
+                                     ArrayRef<Value>({indVar}));
+
+    // Insert code into loop continue block
+    rewriter.setInsertionPointToEnd(continueBlock);
+
+    // Update induction variable.
+    auto updatedIndVar =
+        rewriter.create<spirv::IAddOp>(loc, continueIndVar, one);
+
+    // Check if the induction variable < length(mask)
+    auto cmpOp =
+        rewriter.create<spirv::SLessThanOp>(loc, updatedIndVar, maskLength);
+
+    auto *mergeBlock = loopOp.getMergeBlock();
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, cmpOp, headerBlock, ArrayRef<Value>({updatedIndVar}), mergeBlock,
+        std::nullopt);
+
+    // Insert code after loop
+    rewriter.setInsertionPointAfter(loopOp);
+    rewriter.replaceOp(maskedStoreOp, loopOp);
+
+    return success();
+  }
+};
+
 struct VectorReductionToIntDotProd final
     : OpRewritePattern<vector::ReductionOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -821,7 +1143,8 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+      VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
       typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c9984091d5acc..bc9e92981644b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -805,4 +805,115 @@ func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageB
   return
 }
 
+// CHECK-LABEL:  @vector_maskedload
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+//       CHECK:    %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[S1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:    %[[S2:.*]] = builtin.unrealized_conversion_cast %[[C4]] : index to i32
+//       CHECK:    %[[S3:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:    %[[CST_F0:.*]] = arith.constant 0.000000e+00 : f32
+//       CHECK:    %[[S4:.*]] = spirv.CompositeConstruct %[[CST_F0]], %[[CST_F0]], %[[CST_F0]], %[[CST_F0]] : (f32, f32, f32, f32) -> vector<4xf32>
+//       CHECK:    %[[S5:.*]] = spirv.Variable : !spirv.ptr<vector<4xf32>, Function>
+//       CHECK:    %[[C0_1:.*]] = spirv.Constant 0 : i32
+//       CHECK:    %[[C1_1:.*]] = spirv.Constant 1 : i32
+//       CHECK:    %[[C4_1:.*]] = spirv.Constant 4 : i32
+//       CHECK:    %[[CV0:.*]] = spirv.Constant dense<0.000000e+00> : vector<4xf32>
+//       CHECK:    spirv.mlir.loop {
+//       CHECK:      spirv.Branch ^bb1(%[[C0_1]], %[[CV0]] : i32, vector<4xf32>)
+//       CHECK:    ^bb1(%[[S7:.*]]: i32, %[[S8:.*]]: vector<4xf32>):  // 2 preds: ^bb0, ^bb2
+//       CHECK:      %[[S9:.*]] = spirv.VectorExtractDynamic %[[S3]][%[[S7]]] : vector<4xi1>, i32
+//       CHECK:      %[[S10:.*]] = spirv.VectorExtractDynamic %[[S4]][%[[S7]]] : vector<4xf32>, i32
+//       CHECK:      %[[S11:.*]] = spirv.IAdd %[[S2]], %[[S7]] : i32
+//       CHECK:      %[[C0_2:.*]] = spirv.Constant 0 : i32
+//       CHECK:      %[[C0_3:.*]] = spirv.Constant 0 : i32
+//       CHECK:      %[[C5:.*]] = spirv.Constant 5 : i32
+//       CHECK:      %[[S12:.*]] = spirv.IMul %[[C5]], %[[S1]] : i32
+//       CHECK:      %[[S13:.*]] = spirv.IAdd %[[C0_3]], %[[S12]] : i32
+//       CHECK:      %[[C1_2:.*]] = spirv.Constant 1 : i32
+//       CHECK:      %[[S14:.*]] = spirv.IMul %[[C1_2]], %[[S11]] : i32
+//       CHECK:      %[[S15:.*]] = spirv.IAdd %[[S13]], %[[S14]] : i32
+//       CHECK:      %[[S16:.*]] = spirv.AccessChain %[[S0]][%[[C0_2]], %[[S15]]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:      %[[S17:.*]] = spirv.Load "StorageBuffer" %[[S16]] : f32
+//       CHECK:      %[[S18:.*]] = spirv.Select %[[S9]], %[[S17]], %[[S10]] : i1, f32
+//       CHECK:      %[[S19:.*]] = spirv.VectorInsertDynamic %[[S18]], %[[S8]][%[[S7]]] : vector<4xf32>, i32
+//       CHECK:      spirv.Store "Function" %[[S5]], %[[S19]] : vector<4xf32>
+//       CHECK:      spirv.Branch ^bb2(%[[S7]], %[[S19]] : i32, vector<4xf32>)
+//       CHECK:    ^bb2(%[[S20:.*]]: i32, %[[S21:.*]]: vector<4xf32>):  // pred: ^bb1
+//       CHECK:      %[[S22:.*]] = spirv.IAdd %[[S20]], %[[C1_1]] : i32
+//       CHECK:      %[[S23:.*]] = spirv.SLessThan %[[S22]], %[[C4_1]] : i32
+//       CHECK:      spirv.BranchConditional %[[S23]], ^bb1(%[[S22]], %[[S21]] : i32, vector<4xf32>), ^bb3
+//       CHECK:    ^bb3:  // pred: ^bb2
+//       CHECK:      spirv.mlir.merge
+//       CHECK:    }
+//       CHECK:    %[[S6:.*]] = spirv.Load "Function" %[[S5]] : vector<4xf32>
+//       CHECK:    return %[[S6]] : vector<4xf32>
+//       CHECK:  }
+func.func @vector_maskedload(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL:  @vector_maskedstore
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %[[ARG1:.*]]: vector<4xf32>) {
+//       CHECK:    %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[S1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:    %[[S2:.*]] = builtin.unrealized_conversion_cast %[[C4]] : index to i32
+//       CHECK:    %[[S3:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:    %[[C0_1:.*]] = spirv.Constant 0 : i32
+//       CHECK:    %[[C1_1:.*]] = spirv.Constant 1 : i32
+//       CHECK:    %[[C4_1:.*]] = spirv.Constant 4 : i32
+//       CHECK:    spirv.mlir.loop {
+//       CHECK:      spirv.Branch ^bb1(%[[C0_1]] : i32)
+//       CHECK:    ^bb1(%[[S4:.*]]: i32):  // 2 preds: ^bb0, ^bb2
+//       CHECK:      %[[S5:.*]] = spirv.VectorExtractDynamic %[[S3]][%[[S4]]] : vector<4xi1>, i32
+//       CHECK:      spirv.mlir.selection {
+//       CHECK:        spirv.BranchConditional %[[S5]], ^bb1, ^bb2
+//       CHECK:      ^bb1:  // pred: ^bb0
+//       CHECK:        %[[S9:.*]] = spirv.VectorExtractDynamic %[[ARG1]][%[[S4]]] : vector<4xf32>, i32
+//       CHECK:        %[[S10:.*]] = spirv.IAdd %[[S2]], %[[S4]] : i32
+//       CHECK:        %[[C0_2:.*]] = spirv.Constant 0 : i32
+//       CHECK:        %[[C1_2:.*]] = spirv.Constant 0 : i32
+//       CHECK:        %[[C5:.*]] = spirv.Constant 5 : i32
+//       CHECK:        %[[S11:.*]] = spirv.IMul %[[C5]], %[[S1]] : i32
+//       CHECK:        %[[S12:.*]] = spirv.IAdd %[[C1_2]], %[[S11]] : i32
+//       CHECK:        %[[C1_3:.*]] = spirv.Constant 1 : i32
+//       CHECK:        %[[S13:.*]] = spirv.IMul %[[C1_3]], %[[S10]] : i32
+//       CHECK:        %[[S14:.*]] = spirv.I...
[truncated]

Member

@kuhar kuhar left a comment


Hi @Hsiangkai,

This looks very useful, but I wonder if we should perform the expansion at the level of the vector dialect instead, like we do with ops like vector.gather: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp.

I know we had the same discussion under another PR recently: #69708 (comment), so I wonder if you found staying at the level of the vector dialect problematic for some reason.

@Hsiangkai
Contributor Author

Hi @Hsiangkai,

This looks very useful, but I wonder if we should perform the expansion at the level of the vector dialect instead, like we do with ops like vector.gather: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp.

I know we had the same discussion under another PR recently: #69708 (comment), so I wonder if you found staying at the level of the vector dialect problematic for some reason.

I convert vector.transfer_read into vector.load or vector.maskedload depending on the in_bounds attribute and mask of the vector.transfer_read operator. I already have #71674 to convert vector.load to spirv.load. My goal is to convert all the dialects into the spirv dialect at some stage. What is the right way to lower vector.maskedload to spirv.load?

@kuhar
Member

kuhar commented Dec 11, 2023

My goal is to convert all the dialects into the spirv dialect at some stage. What is the right way to lower vector.maskedload to spirv.load?

I'd say that the following would be preferred:

  1. vector.gather --> vector.load + scf.if (already there)
  2. vector.transfer_read/transfer_write --> vector.load/store, vector.maskedload/maskedstore
  3. vector.maskedload/maskedstore --> vector.load/store + scf.if
  4. vector.load/store / memref.load/store --> ConvertVectorToSPIRV / ConvertMemRefToSPIRV (already there)
  5. scf.if --> ConvertSCFToSPIRV (already there)

(or memref.load/store, I don't think there's much difference?)

Unless there's something that we could do much better by going directly to SPIR-V.
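To make stages 2 and 3 concrete, here is a rough 1-D sketch; all types, shapes, and SSA names are assumed for illustration and are not taken from this PR. A masked transfer_read is first rewritten into a masked load (stage 2), which stage 3 then expands into memref.load + scf.if as in the PR description.

// Stage 2 (sketch): a masked transfer_read becomes a masked load with a
// splat of the padding value as the pass-through.
%pad = arith.constant 0.0 : f32
%read = vector.transfer_read %base[%i], %pad, %mask {in_bounds = [true]}
    : memref<8xf32>, vector<4xf32>

// ... is rewritten to:
%pass_thru = vector.splat %pad : vector<4xf32>
%read2 = vector.maskedload %base[%i], %mask, %pass_thru
    : memref<8xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>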

@Hsiangkai
Contributor Author

My goal is to convert all the dialects into the spirv dialect at some stage. What is the right way to lower vector.maskedload to spirv.load?

I'd say that the following would be preferred:

  1. vector.gather --> vector.load + scf.if (already there)
  2. vector.transfer_read/transfer_write --> vector.load/store, vector.maskedload/maskedstore
  3. vector.maskedload/maskedstore --> vector.load/store + scf.if
  4. vector.load/store / memref.load/store --> ConvertVectorToSPIRV / ConvertMemRefToSPIRV (already there)
  5. scf.if --> ConvertSCFToSPIRV (already there)

(or memref.load/store, I don't think there's much difference?)

Unless there's something that we could do much better by going directly to SPIR-V.

Thanks for your answer. It's very helpful.
Regarding item 3, I can think of 2 different ways to do the conversion (option 1 is sketched below):

  1. vector.maskedload ---> vector.load + arith.select
  2. vector.maskedload ---> spirv.load + spirv.select (I already use this pattern in this patch.)

Which way is better?
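For illustration, option 1 could look roughly like the following sketch, with 2-D indices and types assumed rather than taken from the patch; note that the full-width vector.load executes unconditionally and only the blend with the pass-through is masked:

%full = vector.load %base[%idx_0, %idx_1] : memref<4x5xf32>, vector<4xf32>
%result = arith.select %mask, %full, %pass_thru : vector<4xi1>, vector<4xf32>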

@Hsiangkai
Contributor Author

Hi @kuhar,

I have rewritten my patch to convert vector.maskedload to vector.load and vector.maskedstore to memref.store. Do these patterns make sense? Thanks for your help and review.

@Hsiangkai Hsiangkai changed the title from "[mlir][vector][spirv] Lower vector.maskedload and vector.maskedstore to SPIR-V" to "[mlir][vector] Add patterns for vector masked load/store" on Dec 12, 2023
@kuhar
Member

kuhar commented Dec 12, 2023

Hi @kuhar,

I have rewritten my patch to convert vector.maskedload to vector.load and vector.maskedstore to memref.store. Do these patterns make sense? Thanks for your help and review.

I left some comments. This lowering looks like the right direction to me, but I think we need to introduce control flow to avoid out-of-bounds memory accesses (instead of selects).
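Concretely, with a memref<4x5xf32> like the one in the test and a hypothetical start index of [3, 4], an unconditional vector<4xf32> load would touch linear elements 19 through 22 of a 20-element buffer, while a mask that enables only lane 0 should touch element 19 alone. Guarding each lane with an scf.if, as in the final form shown in the PR description, keeps disabled lanes from ever generating such an access.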

Contributor

@dcaballe dcaballe left a comment


Thanks! This will be useful in general. A few comments.

Member

@kuhar kuhar left a comment


My only remaining concern is this one:
#74834 (comment)

Do we really prefer an scf.for loop over a sequence of scf.if ops? If we rely on some loop-unrolling pattern to do that for us, do the dynamic element extracts get simplified to extracts with static indices?

@Hsiangkai
Contributor Author

My only remaining concern is this one: #74834 (comment)

Do we really prefer an scf.for loop over a sequence of scf.if ops? If we rely on some loop-unrolling pattern to do that for us, do the dynamic element extracts get simplified to extracts with static indices?

When we emulate masked load/store, we do not know the length of the mask, do we? How can we generate a sequence of scf.if ops according to the mask length?

@kuhar
Member

kuhar commented Dec 13, 2023

My only remaining concern is this one: #74834 (comment)
Do we really prefer an scf.for loop over a sequence of scf.if ops? If we rely on some loop-unrolling pattern to do that for us, do the dynamic element extracts get simplified to extracts with static indices?

When we emulate masked load/store, we do not know the length of the mask, do we? How can we generate a sequence of scf.if ops according to the mask length?

We know it based on the vector type. (Modulo scalable vectors probably, but I don't think the current lowering supports them either.)
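For illustration, in the same pseudo-IR notation as the PR description (names assumed): a loop-based emulation reads the mask with a dynamic lane index, while unrolling by the static vector length makes every lane index a constant, which folds much more easily:

// Loop-based form: the lane index %i is an SSA value.
%bit = vector.extractelement %mask[%i]

// Unrolled by the known length (e.g. 4): every lane index is static.
%bit0 = vector.extract %mask[0]
%bit1 = vector.extract %mask[1]
...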

@Hsiangkai
Contributor Author

My only remaining concern is this one: #74834 (comment)
Do we really prefer an scf.for loop over a sequence of scf.if ops? If we rely on some loop-unrolling pattern to do that for us, do the dynamic element extracts get simplified to extracts with static indices?

When we emulate masked load/store, we do not know the length of the mask, do we? How can we generate a sequence of scf.if ops according to the mask length?

We know it based on the vector type. (Modulo scalable vectors probably, but I don't think the current lowering supports them either.)

After thinking about it again, yes, we can do it. I can unroll the loop according to the vector type. I will refine it into a sequence of scf.if ops in this patch. Thank you!

Member

@kuhar kuhar left a comment


LGTM modulo tests. Also the title of the PR should probably be: '[mlir][vector] Add emulation patterns for ...'

@Hsiangkai Hsiangkai changed the title from "[mlir][vector] Add patterns for vector masked load/store" to "[mlir][vector] Add emulation patterns for vector masked load/store" on Dec 15, 2023
@Hsiangkai Hsiangkai requested a review from kuhar December 15, 2023 09:25
@Hsiangkai Hsiangkai merged commit f643eec into llvm:main Dec 15, 2023