Skip to content

[mlir][vector] Add emulation patterns for vector masked load/store #74834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@ void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
void populateVectorMaskLoweringPatternsForSideEffectingOps(
RewritePatternSet &patterns);

/// Populate the pattern set with the following patterns:
///
/// [VectorMaskedLoadOpConverter]
///   Turns vector.maskedload to scf.if + memref.load
///
/// [VectorMaskedStoreOpConverter]
///   Turns vector.maskedstore to scf.if + memref.store
///
/// Note: the added converters only match ops with a 1-D mask; other masked
/// loads/stores are left untouched. `benefit` is forwarded to the patterns.
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit = 1);

} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
SubsetOpInterfaceImpl.cpp
VectorDistribute.cpp
VectorDropLeadUnitDim.cpp
VectorEmulateMaskedLoadStore.cpp
VectorEmulateNarrowType.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorTransferOpTransforms.cpp
Expand Down
161 changes: 161 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to emulate the
// 'vector.maskedload' and 'vector.maskedstore' operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"

using namespace mlir;

namespace {

/// Convert vector.maskedload
///
/// Before:
///
///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
///
/// After:
///
///   %ivalue = %pass_thru
///   %m = vector.extract %mask[0]
///   %result0 = scf.if %m {
///     %v = memref.load %base[%idx_0, %idx_1]
///     %combined = vector.insert %v, %ivalue[0]
///     scf.yield %combined
///   } else {
///     scf.yield %ivalue
///   }
///   %m = vector.extract %mask[1]
///   %result1 = scf.if %m {
///     %v = memref.load %base[%idx_0, %idx_1 + 1]
///     %combined = vector.insert %v, %result0[1]
///     scf.yield %combined
///   } else {
///     scf.yield %result0
///   }
///   ...
///
struct VectorMaskedLoadOpConverter final
    : OpRewritePattern<vector::MaskedLoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
                                PatternRewriter &rewriter) const override {
    // This scalarizing emulation only handles 1-D masks.
    VectorType maskVType = maskedLoadOp.getMaskVectorType();
    if (maskVType.getShape().size() != 1)
      return rewriter.notifyMatchFailure(
          maskedLoadOp, "expected vector.maskedload with 1-D mask");

    Location loc = maskedLoadOp.getLoc();
    int64_t maskLength = maskVType.getShape()[0];

    Type indexType = rewriter.getIndexType();
    Value mask = maskedLoadOp.getMask();
    Value base = maskedLoadOp.getBase();
    // The result starts as the pass-through vector and is updated lane by
    // lane; `iValue` always holds the latest partial result.
    Value iValue = maskedLoadOp.getPassThru();
    auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
    Value one = rewriter.create<arith::ConstantOp>(
        loc, indexType, IntegerAttr::get(indexType, 1));
    for (int64_t i = 0; i < maskLength; ++i) {
      auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);

      // Conditionally load lane `i` and insert it into the partial result;
      // otherwise yield the partial result unchanged.
      auto ifOp = rewriter.create<scf::IfOp>(
          loc, maskBit,
          [&](OpBuilder &builder, Location loc) {
            auto loadedValue =
                builder.create<memref::LoadOp>(loc, base, indices);
            auto combinedValue =
                builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
            builder.create<scf::YieldOp>(loc, combinedValue.getResult());
          },
          [&](OpBuilder &builder, Location loc) {
            builder.create<scf::YieldOp>(loc, iValue);
          });
      iValue = ifOp.getResult(0);

      // Advance the innermost index to address the next lane.
      indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
    }

    rewriter.replaceOp(maskedLoadOp, iValue);

    return success();
  }
};

/// Convert vector.maskedstore
///
/// Each lane of the stored vector becomes a scalar memref.store guarded by
/// an scf.if on the corresponding mask bit:
///
/// Before:
///
///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
///
/// After:
///
///   %m = vector.extract %mask[0]
///   scf.if %m {
///     %extracted = vector.extract %value[0]
///     memref.store %extracted, %base[%idx_0, %idx_1]
///   }
///   %m = vector.extract %mask[1]
///   scf.if %m {
///     %extracted = vector.extract %value[1]
///     memref.store %extracted, %base[%idx_0, %idx_1 + 1]
///   }
///   ...
///
struct VectorMaskedStoreOpConverter final
    : OpRewritePattern<vector::MaskedStoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
                                PatternRewriter &rewriter) const override {
    // Only 1-D masks are supported by this scalarizing emulation.
    VectorType maskVType = maskedStoreOp.getMaskVectorType();
    if (maskVType.getShape().size() != 1)
      return rewriter.notifyMatchFailure(
          maskedStoreOp, "expected vector.maskedstore with 1-D mask");

    Location loc = maskedStoreOp.getLoc();
    int64_t numLanes = maskVType.getShape()[0];

    Type indexType = rewriter.getIndexType();
    Value mask = maskedStoreOp.getMask();
    Value base = maskedStoreOp.getBase();
    Value valueToStore = maskedStoreOp.getValueToStore();
    auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
    // Constant step used to advance the innermost index between lanes.
    Value step = rewriter.create<arith::ConstantOp>(
        loc, indexType, IntegerAttr::get(indexType, 1));
    for (int64_t lane = 0; lane < numLanes; ++lane) {
      // Guard the scalar store for this lane on its mask bit.
      Value cond = rewriter.create<vector::ExtractOp>(loc, mask, lane);
      auto ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else=*/false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      Value scalar =
          rewriter.create<vector::ExtractOp>(loc, valueToStore, lane);
      rewriter.create<memref::StoreOp>(loc, scalar, base, indices);

      // Resume emission after the scf.if and bump the innermost index.
      rewriter.setInsertionPointAfter(ifOp);
      indices.back() =
          rewriter.create<arith::AddIOp>(loc, indices.back(), step);
    }

    rewriter.eraseOp(maskedStoreOp);

    return success();
  }
};

} // namespace

void mlir::vector::populateVectorMaskedLoadStoreEmulationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  // Register the scalarizing emulations for masked load and masked store.
  patterns.add<VectorMaskedLoadOpConverter>(ctx, benefit);
  patterns.add<VectorMaskedStoreOpConverter>(ctx, benefit);
}
95 changes: 95 additions & 0 deletions mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// RUN: mlir-opt %s --test-vector-emulate-masked-load-store | FileCheck %s

// Emulation of a 1-D masked load: one scf.if + memref.load per lane,
// threading the partial result vector through each scf.if, with the
// pass-through vector as the initial value.
// CHECK-LABEL: @vector_maskedload
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
// CHECK: %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
// CHECK: %[[S2:.*]] = scf.if %[[S1]] -> (vector<4xf32>) {
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[CST]] [0] : f32 into vector<4xf32>
// CHECK: scf.yield %[[S10]] : vector<4xf32>
// CHECK: } else {
// CHECK: scf.yield %[[CST]] : vector<4xf32>
// CHECK: }
// CHECK: %[[S3:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
// CHECK: %[[S4:.*]] = scf.if %[[S3]] -> (vector<4xf32>) {
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[S2]] [1] : f32 into vector<4xf32>
// CHECK: scf.yield %[[S10]] : vector<4xf32>
// CHECK: } else {
// CHECK: scf.yield %[[S2]] : vector<4xf32>
// CHECK: }
// CHECK: %[[S5:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
// CHECK: %[[S6:.*]] = scf.if %[[S5]] -> (vector<4xf32>) {
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[S4]] [2] : f32 into vector<4xf32>
// CHECK: scf.yield %[[S10]] : vector<4xf32>
// CHECK: } else {
// CHECK: scf.yield %[[S4]] : vector<4xf32>
// CHECK: }
// CHECK: %[[S7:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
// CHECK: %[[S8:.*]] = scf.if %[[S7]] -> (vector<4xf32>) {
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[S6]] [3] : f32 into vector<4xf32>
// CHECK: scf.yield %[[S10]] : vector<4xf32>
// CHECK: } else {
// CHECK: scf.yield %[[S6]] : vector<4xf32>
// CHECK: }
// CHECK: return %[[S8]] : vector<4xf32>
func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
  %idx_0 = arith.constant 0 : index
  %idx_1 = arith.constant 1 : index
  %idx_4 = arith.constant 4 : index
  %mask = vector.create_mask %idx_1 : vector<4xi1>
  %s = arith.constant 0.0 : f32
  %pass_thru = vector.splat %s : vector<4xf32>
  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  return %0: vector<4xf32>
}

// Emulation of a 1-D masked store: one scf.if-guarded scalar memref.store
// per lane, with the innermost index incremented by one between lanes.
// CHECK-LABEL: @vector_maskedstore
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
// CHECK: %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
// CHECK: scf.if %[[S1]] {
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][0] : f32 from vector<4xf32>
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
// CHECK: }
// CHECK: %[[S2:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
// CHECK: scf.if %[[S2]] {
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
// CHECK: }
// CHECK: %[[S3:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
// CHECK: scf.if %[[S3]] {
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][2] : f32 from vector<4xf32>
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
// CHECK: }
// CHECK: %[[S4:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
// CHECK: scf.if %[[S4]] {
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][3] : f32 from vector<4xf32>
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
// CHECK: }
// CHECK: return
// CHECK:}
func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
  %idx_0 = arith.constant 0 : index
  %idx_1 = arith.constant 1 : index
  %idx_4 = arith.constant 4 : index
  %mask = vector.create_mask %idx_1 : vector<4xi1>
  vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
  return
}
27 changes: 27 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,31 @@ struct TestFoldArithExtensionIntoVectorContractPatterns
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

/// Test pass exposing the masked load/store emulation patterns through
/// the -test-vector-emulate-masked-load-store command-line option.
struct TestVectorEmulateMaskedLoadStore final
    : public PassWrapper<TestVectorEmulateMaskedLoadStore,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)

  StringRef getArgument() const override {
    return "test-vector-emulate-masked-load-store";
  }
  StringRef getDescription() const override {
    // Fixed: the previous string concatenation ("op by " " memref...")
    // produced a double space in the rendered pass description.
    return "Test patterns that emulate the maskedload/maskedstore op by "
           "memref.load/store and scf.if";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    // The emulation patterns create arith, memref, and scf ops, so those
    // dialects must be registered as dependents.
    registry
        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
                scf::SCFDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorMaskedLoadStoreEmulationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
} // namespace

namespace mlir {
Expand Down Expand Up @@ -817,6 +842,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorGatherLowering>();

PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

PassRegistration<TestVectorEmulateMaskedLoadStore>();
}
} // namespace test
} // namespace mlir