[mlir][vector] Add emulation patterns for vector masked load/store #74834
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

Changes: Use spirv.mlir.loop and spirv.mlir.selection to lower vector.maskedload and vector.maskedstore.

Patch is 21.35 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/74834.diff

2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index e48f29a4f1702..b32c004e28a1e 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -647,6 +647,328 @@ struct VectorStoreOpConverter final
}
};
+mlir::spirv::LoopOp createSpirvLoop(ConversionPatternRewriter &rewriter,
+ Location loc) {
+ auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+ loopOp.addEntryAndMergeBlock();
+
+ auto &loopBody = loopOp.getBody();
+ // Create header block.
+ loopBody.getBlocks().insert(std::next(loopBody.begin(), 1), new Block());
+ // Create continue block.
+ loopBody.getBlocks().insert(std::prev(loopBody.end(), 2), new Block());
+
+ return loopOp;
+}
+
+mlir::spirv::SelectionOp
+createSpirvSelection(ConversionPatternRewriter &rewriter, Location loc) {
+ auto selectionOp =
+ rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+ auto &loopBody = selectionOp.getBody();
+ // Create header block.
+ rewriter.createBlock(&loopBody, loopBody.end());
+ // Create if-true block.
+ rewriter.createBlock(&loopBody, loopBody.end());
+ // Create merge block.
+ rewriter.createBlock(&loopBody, loopBody.end());
+ rewriter.create<spirv::MergeOp>(loc);
+
+ return selectionOp;
+}
+
+Value addOffsetToIndices(ConversionPatternRewriter &rewriter, Location loc,
+ SmallVectorImpl<Value> &indices, const Value offset,
+ const SPIRVTypeConverter &typeConverter,
+ const MemRefType memrefType, const Value base) {
+ indices.back() = rewriter.create<spirv::IAddOp>(loc, indices.back(), offset);
+ return spirv::getElementPtr(typeConverter, memrefType, base, indices, loc,
+ rewriter);
+}
+
+Value extractMaskBit(ConversionPatternRewriter &rewriter, Location loc,
+ Value mask, Value offset) {
+ return rewriter.create<spirv::VectorExtractDynamicOp>(
+ loc, rewriter.getI1Type(), mask, offset);
+}
+
+Value extractVectorElement(ConversionPatternRewriter &rewriter, Location loc,
+ Type type, Value vector, Value offset) {
+ return rewriter.create<spirv::VectorExtractDynamicOp>(loc, type, vector,
+ offset);
+}
+
+Value createConstantInteger(ConversionPatternRewriter &rewriter, Location loc,
+ int32_t value) {
+ auto i32Type = rewriter.getI32Type();
+ return rewriter.create<spirv::ConstantOp>(loc, i32Type,
+ IntegerAttr::get(i32Type, value));
+}
+
+/// Convert vector.maskedload to spirv dialect.
+///
+/// Before:
+///
+/// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+/// %buffer = spirv.Variable
+/// spirv.mlir.loop {
+/// spirv.Branch ^bb1(0, %buffer)
+/// ^bb1(%i: i32, %partial: vector):
+/// %m = spirv.VectorExtractDynamic %mask[%i]
+/// %p = spirv.VectorExtractDynamic %pass_thru[%i]
+/// %value = spirv.Load
+/// %s = spirv.Select %m, %value, %p
+/// %v = spirv.VectorInsertDynamic %s, %partial[%i]
+/// spirv.Store %buffer, %v
+/// spirv.Branch ^bb2(%i, %v)
+/// ^bb2(%i: i32, %partial: vector):
+/// %update_i = spirv.IAdd %i, 1
+/// %cond = spirv.SLessThan %update_i, %veclen
+/// spirv.BranchConditional %cond, ^bb1(%update_i, %partial), ^bb3
+/// ^bb3:
+/// spirv.mlir.merge
+/// }
+/// %ret = spirv.Load %buffer
+/// return %ret
+///
+struct VectorMaskedLoadOpConverter final
+ : public OpConversionPattern<vector::MaskedLoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MaskedLoadOp maskedLoadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto memrefType = maskedLoadOp.getMemRefType();
+ if (!isa<spirv::StorageClassAttr>(memrefType.getMemorySpace()))
+ return failure();
+
+ VectorType maskVType = maskedLoadOp.getMaskVectorType();
+ if (maskVType.getRank() != 1)
+ return failure();
+ if (maskVType.getShape().size() != 1)
+ return failure();
+
+ // Create a local variable to store the loaded value.
+ auto loc = maskedLoadOp.getLoc();
+ auto vectorType = maskedLoadOp.getVectorType();
+ auto pointerType =
+ spirv::PointerType::get(vectorType, spirv::StorageClass::Function);
+ auto alloc = rewriter.create<spirv::VariableOp>(
+ loc, pointerType, spirv::StorageClass::Function,
+ /*initializer=*/nullptr);
+
+ // Create constants for the loop.
+ Value zero = createConstantInteger(rewriter, loc, 0);
+ Value one = createConstantInteger(rewriter, loc, 1);
+ Value maskLength =
+ createConstantInteger(rewriter, loc, maskVType.getShape()[0]);
+
+ auto emptyVector = rewriter.create<spirv::ConstantOp>(
+ loc, vectorType, rewriter.getZeroAttr(vectorType));
+
+ // Construct a loop to iterate over the mask values.
+ auto loopOp = createSpirvLoop(rewriter, loc);
+
+ auto *headerBlock = loopOp.getHeaderBlock();
+ auto *continueBlock = loopOp.getContinueBlock();
+
+ auto i32Type = rewriter.getI32Type();
+ BlockArgument indVar = headerBlock->addArgument(i32Type, loc);
+ BlockArgument partialVector = headerBlock->addArgument(vectorType, loc);
+ BlockArgument continueIndVar = continueBlock->addArgument(i32Type, loc);
+ BlockArgument continueVector = continueBlock->addArgument(vectorType, loc);
+
+ // Insert code into loop entry block
+ rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
+
+ // Header block needs two arguments: induction variable, updated vector
+ rewriter.create<spirv::BranchOp>(loc, headerBlock,
+ ArrayRef<Value>({zero, emptyVector}));
+
+ // Insert code into loop header block
+ rewriter.setInsertionPointToEnd(headerBlock);
+ auto maskBit = extractMaskBit(rewriter, loc, adaptor.getMask(), indVar);
+
+ auto scalarType = memrefType.getElementType();
+ auto passThruValue = extractVectorElement(rewriter, loc, scalarType,
+ adaptor.getPassThru(), indVar);
+
+ // Load base[indVar]
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto indices = llvm::to_vector<4>(adaptor.getIndices());
+ auto updatedAccessChain =
+ addOffsetToIndices(rewriter, loc, indices, indVar, typeConverter,
+ memrefType, adaptor.getBase());
+ auto loadScalar =
+ rewriter.create<spirv::LoadOp>(loc, scalarType, updatedAccessChain);
+
+ // Select the loaded value or pass-through according to the mask bit.
+ auto valueToInsert = rewriter.create<spirv::SelectOp>(
+ loc, scalarType, maskBit, loadScalar, passThruValue);
+
+ // Insert the selected value to output vector.
+ auto updatedVector = rewriter.create<spirv::VectorInsertDynamicOp>(
+ loc, vectorType, partialVector, valueToInsert, indVar);
+ rewriter.create<spirv::StoreOp>(loc, alloc, updatedVector);
+ rewriter.create<spirv::BranchOp>(loc, continueBlock,
+ ArrayRef<Value>({indVar, updatedVector}));
+
+ // Insert code into continue block
+ rewriter.setInsertionPointToEnd(continueBlock);
+
+ // Update induction variable.
+ auto updatedIndVar =
+ rewriter.create<spirv::IAddOp>(loc, continueIndVar, one);
+
+ // Check if the induction variable < length(mask)
+ auto cmpOp =
+ rewriter.create<spirv::SLessThanOp>(loc, updatedIndVar, maskLength);
+
+ auto *mergeBlock = loopOp.getMergeBlock();
+ rewriter.create<spirv::BranchConditionalOp>(
+ loc, cmpOp, headerBlock,
+ ArrayRef<Value>({updatedIndVar, continueVector}), mergeBlock,
+ std::nullopt);
+
+ // Insert code after loop
+ rewriter.setInsertionPointAfter(loopOp);
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(maskedLoadOp, alloc);
+
+ return success();
+ }
+};
+
+/// Convert vector.maskedstore to spirv dialect.
+///
+/// Before:
+///
+/// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+/// spirv.mlir.loop {
+/// spirv.Branch ^bb1(0)
+/// ^bb1(%i: i32):
+/// %m = spirv.VectorExtractDynamic %mask[%i]
+/// spirv.mlir.selection {
+/// spirv.BranchConditional %m, ^if_bb1, ^if_bb2
+/// ^if_bb1:
+/// %v = spirv.VectorExtractDynamic %value[%i]
+/// spirv.Store %base[%i], %v
+/// spirv.Branch ^if_bb2
+/// ^if_bb2:
+/// spirv.mlir.merge
+/// }
+/// spirv.Branch ^bb2(%i)
+/// ^bb2(%i: i32):
+/// %update_i = spirv.IAdd %i, 1
+/// %cond = spirv.SLessThan %update_i, %veclen
+/// spirv.BranchConditional %cond, ^bb1, ^bb3
+/// ^bb3:
+/// spirv.mlir.merge
+/// }
+/// return
+///
+struct VectorMaskedStoreOpConverter final
+ : public OpConversionPattern<vector::MaskedStoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MaskedStoreOp maskedStoreOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto memrefType = maskedStoreOp.getMemRefType();
+ if (!isa<spirv::StorageClassAttr>(memrefType.getMemorySpace()))
+ return failure();
+
+ VectorType maskVType = maskedStoreOp.getMaskVectorType();
+ if (maskVType.getRank() != 1)
+ return failure();
+ if (maskVType.getShape().size() != 1)
+ return failure();
+
+ // Create constants.
+ auto loc = maskedStoreOp.getLoc();
+ Value zero = createConstantInteger(rewriter, loc, 0);
+ Value one = createConstantInteger(rewriter, loc, 1);
+ Value maskLength =
+ createConstantInteger(rewriter, loc, maskVType.getShape()[0]);
+
+ // Construct a loop to iterate over the mask values.
+ auto loopOp = createSpirvLoop(rewriter, loc);
+ auto *headerBlock = loopOp.getHeaderBlock();
+ auto *continueBlock = loopOp.getContinueBlock();
+
+ auto i32Type = rewriter.getI32Type();
+ BlockArgument indVar = headerBlock->addArgument(i32Type, loc);
+ BlockArgument continueIndVar = continueBlock->addArgument(i32Type, loc);
+
+ // Insert code into loop entry block
+ rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
+ rewriter.create<spirv::BranchOp>(loc, headerBlock, ArrayRef<Value>({zero}));
+
+ // Insert code into loop header block
+ rewriter.setInsertionPointToEnd(headerBlock);
+ auto maskBit = extractMaskBit(rewriter, loc, adaptor.getMask(), indVar);
+
+ auto selectionOp = createSpirvSelection(rewriter, loc);
+ auto *selectionHeaderBlock = selectionOp.getHeaderBlock();
+ auto *selectionMergeBlock = selectionOp.getMergeBlock();
+ auto *selectionTrueBlock = &(*std::next(selectionOp.getBody().begin(), 1));
+
+ // Insert code into selection header block
+ rewriter.setInsertionPointToEnd(selectionHeaderBlock);
+ rewriter.create<spirv::BranchConditionalOp>(
+ loc, maskBit, selectionTrueBlock, std::nullopt, selectionMergeBlock,
+ std::nullopt);
+
+ // Insert code into selection true block
+ rewriter.setInsertionPointToEnd(selectionTrueBlock);
+ auto scalarType = memrefType.getElementType();
+ auto extractedStoreValue = extractVectorElement(
+ rewriter, loc, scalarType, adaptor.getValueToStore(), indVar);
+
+ // Store base[indVar]
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto indices = llvm::to_vector<4>(adaptor.getIndices());
+ auto updatedAccessChain =
+ addOffsetToIndices(rewriter, loc, indices, indVar, typeConverter,
+ memrefType, adaptor.getBase());
+ rewriter.create<spirv::StoreOp>(loc, updatedAccessChain,
+ extractedStoreValue);
+ rewriter.create<spirv::BranchOp>(loc, selectionMergeBlock, std::nullopt);
+
+ // Insert code into loop header block
+ rewriter.setInsertionPointAfter(selectionOp);
+ rewriter.create<spirv::BranchOp>(loc, continueBlock,
+ ArrayRef<Value>({indVar}));
+
+ // Insert code into loop continue block
+ rewriter.setInsertionPointToEnd(continueBlock);
+
+ // Update induction variable.
+ auto updatedIndVar =
+ rewriter.create<spirv::IAddOp>(loc, continueIndVar, one);
+
+ // Check if the induction variable < length(mask)
+ auto cmpOp =
+ rewriter.create<spirv::SLessThanOp>(loc, updatedIndVar, maskLength);
+
+ auto *mergeBlock = loopOp.getMergeBlock();
+ rewriter.create<spirv::BranchConditionalOp>(
+ loc, cmpOp, headerBlock, ArrayRef<Value>({updatedIndVar}), mergeBlock,
+ std::nullopt);
+
+ // Insert code after loop
+ rewriter.setInsertionPointAfter(loopOp);
+ rewriter.replaceOp(maskedStoreOp, loopOp);
+
+ return success();
+ }
+};
+
struct VectorReductionToIntDotProd final
: OpRewritePattern<vector::ReductionOp> {
using OpRewritePattern::OpRewritePattern;
@@ -821,7 +1143,8 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+ VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c9984091d5acc..bc9e92981644b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -805,4 +805,115 @@ func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageB
return
}
+// CHECK-LABEL: @vector_maskedload
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+// CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[S2:.*]] = builtin.unrealized_conversion_cast %[[C4]] : index to i32
+// CHECK: %[[S3:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+// CHECK: %[[CST_F0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[S4:.*]] = spirv.CompositeConstruct %[[CST_F0]], %[[CST_F0]], %[[CST_F0]], %[[CST_F0]] : (f32, f32, f32, f32) -> vector<4xf32>
+// CHECK: %[[S5:.*]] = spirv.Variable : !spirv.ptr<vector<4xf32>, Function>
+// CHECK: %[[C0_1:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C1_1:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[C4_1:.*]] = spirv.Constant 4 : i32
+// CHECK: %[[CV0:.*]] = spirv.Constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: spirv.mlir.loop {
+// CHECK: spirv.Branch ^bb1(%[[C0_1]], %[[CV0]] : i32, vector<4xf32>)
+// CHECK: ^bb1(%[[S7:.*]]: i32, %[[S8:.*]]: vector<4xf32>): // 2 preds: ^bb0, ^bb2
+// CHECK: %[[S9:.*]] = spirv.VectorExtractDynamic %[[S3]][%[[S7]]] : vector<4xi1>, i32
+// CHECK: %[[S10:.*]] = spirv.VectorExtractDynamic %[[S4]][%[[S7]]] : vector<4xf32>, i32
+// CHECK: %[[S11:.*]] = spirv.IAdd %[[S2]], %[[S7]] : i32
+// CHECK: %[[C0_2:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C0_3:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C5:.*]] = spirv.Constant 5 : i32
+// CHECK: %[[S12:.*]] = spirv.IMul %[[C5]], %[[S1]] : i32
+// CHECK: %[[S13:.*]] = spirv.IAdd %[[C0_3]], %[[S12]] : i32
+// CHECK: %[[C1_2:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[S14:.*]] = spirv.IMul %[[C1_2]], %[[S11]] : i32
+// CHECK: %[[S15:.*]] = spirv.IAdd %[[S13]], %[[S14]] : i32
+// CHECK: %[[S16:.*]] = spirv.AccessChain %[[S0]][%[[C0_2]], %[[S15]]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S17:.*]] = spirv.Load "StorageBuffer" %[[S16]] : f32
+// CHECK: %[[S18:.*]] = spirv.Select %[[S9]], %[[S17]], %[[S10]] : i1, f32
+// CHECK: %[[S19:.*]] = spirv.VectorInsertDynamic %[[S18]], %[[S8]][%[[S7]]] : vector<4xf32>, i32
+// CHECK: spirv.Store "Function" %[[S5]], %[[S19]] : vector<4xf32>
+// CHECK: spirv.Branch ^bb2(%[[S7]], %[[S19]] : i32, vector<4xf32>)
+// CHECK: ^bb2(%[[S20:.*]]: i32, %[[S21:.*]]: vector<4xf32>): // pred: ^bb1
+// CHECK: %[[S22:.*]] = spirv.IAdd %[[S20]], %[[C1_1]] : i32
+// CHECK: %[[S23:.*]] = spirv.SLessThan %[[S22]], %[[C4_1]] : i32
+// CHECK: spirv.BranchConditional %[[S23]], ^bb1(%[[S22]], %[[S21]] : i32, vector<4xf32>), ^bb3
+// CHECK: ^bb3: // pred: ^bb2
+// CHECK: spirv.mlir.merge
+// CHECK: }
+// CHECK: %[[S6:.*]] = spirv.Load "Function" %[[S5]] : vector<4xf32>
+// CHECK: return %[[S6]] : vector<4xf32>
+// CHECK: }
+func.func @vector_maskedload(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+ %idx_0 = arith.constant 0 : index
+ %idx_1 = arith.constant 1 : index
+ %idx_4 = arith.constant 4 : index
+ %mask = vector.create_mask %idx_1 : vector<4xi1>
+ %s = arith.constant 0.0 : f32
+ %pass_thru = vector.splat %s : vector<4xf32>
+ %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_maskedstore
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %[[ARG1:.*]]: vector<4xf32>) {
+// CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[S2:.*]] = builtin.unrealized_conversion_cast %[[C4]] : index to i32
+// CHECK: %[[S3:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+// CHECK: %[[C0_1:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C1_1:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[C4_1:.*]] = spirv.Constant 4 : i32
+// CHECK: spirv.mlir.loop {
+// CHECK: spirv.Branch ^bb1(%[[C0_1]] : i32)
+// CHECK: ^bb1(%[[S4:.*]]: i32): // 2 preds: ^bb0, ^bb2
+// CHECK: %[[S5:.*]] = spirv.VectorExtractDynamic %[[S3]][%[[S4]]] : vector<4xi1>, i32
+// CHECK: spirv.mlir.selection {
+// CHECK: spirv.BranchConditional %[[S5]], ^bb1, ^bb2
+// CHECK: ^bb1: // pred: ^bb0
+// CHECK: %[[S9:.*]] = spirv.VectorExtractDynamic %[[ARG1]][%[[S4]]] : vector<4xf32>, i32
+// CHECK: %[[S10:.*]] = spirv.IAdd %[[S2]], %[[S4]] : i32
+// CHECK: %[[C0_2:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C1_2:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C5:.*]] = spirv.Constant 5 : i32
+// CHECK: %[[S11:.*]] = spirv.IMul %[[C5]], %[[S1]] : i32
+// CHECK: %[[S12:.*]] = spirv.IAdd %[[C1_2]], %[[S11]] : i32
+// CHECK: %[[C1_3:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[S13:.*]] = spirv.IMul %[[C1_3]], %[[S10]] : i32
+// CHECK: %[[S14:.*]] = spirv.I...
[truncated]
Hi @Hsiangkai,
This looks very useful, but I wonder if we should perform the expansion at the level of the vector dialect instead, like we do with ops such as vector.gather: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp.
I know we had the same discussion under another PR recently: #69708 (comment), so I wonder if you found staying at the level of the vector dialect problematic for some reason.
I convert vector.transfer_read into vector.load or vector.maskedload depending on the in_bounds attribute and the mask of the vector.transfer_read operation. I already have #71674 to convert vector.load to spirv.load. My goal is to convert all the dialects into the spirv dialect at some stage. What is the recommended way to lower vector.maskedload to spirv.load?
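For reference, the rewrite described above looks roughly like this (a sketch; the exact types and attributes are illustrative, not taken from the patch):

    // A masked, in-bounds 1-D transfer_read...
    %v = vector.transfer_read %base[%i], %pad, %mask {in_bounds = [true]}
        : memref<?xf32>, vector<4xf32>

    // ...can be rewritten as a masked load, with the padding value splatted
    // as the pass-through vector.
    %pass_thru = vector.splat %pad : vector<4xf32>
    %v = vector.maskedload %base[%i], %mask, %pass_thru
        : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>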
I'd say that the preferred route would be to lower vector.maskedload within the vector dialect, to plain loads/stores guarded by control flow (or combined with selects), and then convert those through the existing lowering paths. Unless there's something that we could do much better by going directly to SPIR-V.
Thanks for your answer. It's very helpful.
Which way is better?
Hi @kuhar, I rewrote my patch to convert vector.maskedload to vector.load and vector.maskedstore to memref.store. Do these patterns make sense? Thanks for your help and review.
I left some comments. This lowering looks like the right direction to me, but I think we need to introduce control flow to avoid out-of-bounds memory accesses (instead of selects).
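The concern, sketched in illustrative IR (all names and types here are hypothetical):

    // Select-based emulation: the load executes even for inactive lanes,
    // so it can read out of bounds.
    %v = memref.load %base[%i] : memref<?xf32>
    %r = arith.select %m, %v, %pad : f32

    // Guarded emulation: the load only runs when the mask bit is set.
    %r = scf.if %m -> (f32) {
      %v = memref.load %base[%i] : memref<?xf32>
      scf.yield %v : f32
    } else {
      scf.yield %pad : f32
    }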
Thanks! This will be useful in general. A few comments.
My only remaining concern is this:
#74834 (comment)
Do we really prefer an scf.for loop over a sequence of scf.if? If we rely on some loop-unrolling pattern to do that for us, do the dynamic extractelements actually get simplified to extracts over static indices?
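The two shapes under discussion, roughly (illustrative IR, not taken from the patch):

    // Loop form: each mask bit is read with a dynamic index.
    %r = scf.for %i = %c0 to %c4 step %c1
        iter_args(%acc = %pass_thru) -> (vector<4xf32>) {
      %m = vector.extractelement %mask[%i : index] : vector<4xi1>
      // ...guarded load, insert into %acc...
      scf.yield %next : vector<4xf32>
    }

    // Unrolled form: every extract uses a static position, which folds away.
    %m0 = vector.extract %mask[0] : vector<4xi1>
    %m1 = vector.extract %mask[1] : vector<4xi1>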
When we emulate masked load/store, we do not know the length of the mask, do we? How would we generate a sequence of scf.if?
We know it based on the vector type. (Modulo scalable vectors, probably, but I don't think the current lowering supports them either.)
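That is, the unroll count is available statically, along the lines of this sketch (reusing accessors already present in the patch):

    // Sketch: the number of per-lane scf.if regions comes from the static
    // 1-D mask shape; scalable vectors would have to bail out earlier.
    int64_t maskLength = maskedLoadOp.getMaskVectorType().getShape()[0];
    for (int64_t i = 0; i < maskLength; ++i) {
      // ...emit one guarded load/insert per lane...
    }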
After thinking about it again, yes, we can do it. I can unroll the loop according to the vector type. I will refine it to a sequence of scf.if.
LGTM modulo tests. Also the title of the PR should probably be: '[mlir][vector] Add emulation patterns for ...'
In this patch, it will convert

    vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru

to

    %ivalue = %pass_thru
    %m = vector.extract %mask[0]
    %result0 = scf.if %m {
      %v = memref.load %base[%idx_0, %idx_1]
      %combined = vector.insert %v, %ivalue[0]
      scf.yield %combined
    } else {
      scf.yield %ivalue
    }
    %m = vector.extract %mask[1]
    %result1 = scf.if %m {
      %v = memref.load %base[%idx_0, %idx_1 + 1]
      %combined = vector.insert %v, %result0[1]
      scf.yield %combined
    } else {
      scf.yield %result0
    }
    ...

It will convert

    vector.maskedstore %base[%idx_0, %idx_1], %mask, %value

to

    %m = vector.extract %mask[0]
    scf.if %m {
      %extracted = vector.extract %value[0]
      memref.store %extracted, %base[%idx_0, %idx_1]
    }
    %m = vector.extract %mask[1]
    scf.if %m {
      %extracted = vector.extract %value[1]
      memref.store %extracted, %base[%idx_0, %idx_1 + 1]
    }
    ...
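To exercise these patterns from a pass, the wiring would look roughly like this (a hedged sketch: the populate-function name is assumed from this PR and should be checked against the actual header):

    #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    static void emulateMaskedLoadStore(Operation *root) {
      RewritePatternSet patterns(root->getContext());
      // Assumed entry point added by this PR; verify the exact name/signature.
      vector::populateVectorMaskedLoadStoreEmulationPatterns(patterns);
      (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
    }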