Skip to content

Commit f643eec

Browse files
authored
[mlir][vector] Add emulation patterns for vector masked load/store (#74834)
In this patch, it will convert ``` vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru ``` to ``` %ivalue = %pass_thru %m = vector.extract %mask[0] %result0 = scf.if %m { %v = memref.load %base[%idx_0, %idx_1] %combined = vector.insert %v, %ivalue[0] scf.yield %combined } else { scf.yield %ivalue } %m = vector.extract %mask[1] %result1 = scf.if %m { %v = memref.load %base[%idx_0, %idx_1 + 1] %combined = vector.insert %v, %result0[1] scf.yield %combined } else { scf.yield %result0 } ... ``` It will convert ``` vector.maskedstore %base[%idx_0, %idx_1], %mask, %value ``` to ``` %m = vector.extract %mask[0] scf.if %m { %extracted = vector.extract %value[0] memref.store %extracted, %base[%idx_0, %idx_1] } %m = vector.extract %mask[1] scf.if %m { %extracted = vector.extract %value[1] memref.store %extracted, %base[%idx_0, %idx_1 + 1] } ... ```
1 parent ef067f5 commit f643eec

File tree

5 files changed

+294
-0
lines changed

5 files changed

+294
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,16 @@ void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
254254
void populateVectorMaskLoweringPatternsForSideEffectingOps(
255255
RewritePatternSet &patterns);
256256

257+
/// Populate the pattern set with the following patterns:
258+
///
259+
/// [VectorMaskedLoadOpConverter]
260+
/// Turns vector.maskedload to scf.if + memref.load
261+
///
262+
/// [VectorMaskedStoreOpConverter]
263+
/// Turns vector.maskedstore to scf.if + memref.store
264+
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
265+
PatternBenefit benefit = 1);
266+
257267
} // namespace vector
258268
} // namespace mlir
259269
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
1313
SubsetOpInterfaceImpl.cpp
1414
VectorDistribute.cpp
1515
VectorDropLeadUnitDim.cpp
16+
VectorEmulateMaskedLoadStore.cpp
1617
VectorEmulateNarrowType.cpp
1718
VectorInsertExtractStridedSliceRewritePatterns.cpp
1819
VectorTransferOpTransforms.cpp
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to emulate the
10+
// 'vector.maskedload' and 'vector.maskedstore' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
15+
#include "mlir/Dialect/SCF/IR/SCF.h"
16+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17+
18+
using namespace mlir;
19+
20+
namespace {
21+
22+
/// Convert vector.maskedload
23+
///
24+
/// Before:
25+
///
26+
/// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
27+
///
28+
/// After:
29+
///
30+
/// %ivalue = %pass_thru
31+
/// %m = vector.extract %mask[0]
32+
/// %result0 = scf.if %m {
33+
/// %v = memref.load %base[%idx_0, %idx_1]
34+
/// %combined = vector.insert %v, %ivalue[0]
35+
/// scf.yield %combined
36+
/// } else {
37+
/// scf.yield %ivalue
38+
/// }
39+
/// %m = vector.extract %mask[1]
40+
/// %result1 = scf.if %m {
41+
/// %v = memref.load %base[%idx_0, %idx_1 + 1]
42+
/// %combined = vector.insert %v, %result0[1]
43+
/// scf.yield %combined
44+
/// } else {
45+
/// scf.yield %result0
46+
/// }
47+
/// ...
48+
///
49+
struct VectorMaskedLoadOpConverter final
50+
: OpRewritePattern<vector::MaskedLoadOp> {
51+
using OpRewritePattern::OpRewritePattern;
52+
53+
LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
54+
PatternRewriter &rewriter) const override {
55+
VectorType maskVType = maskedLoadOp.getMaskVectorType();
56+
if (maskVType.getShape().size() != 1)
57+
return rewriter.notifyMatchFailure(
58+
maskedLoadOp, "expected vector.maskedstore with 1-D mask");
59+
60+
Location loc = maskedLoadOp.getLoc();
61+
int64_t maskLength = maskVType.getShape()[0];
62+
63+
Type indexType = rewriter.getIndexType();
64+
Value mask = maskedLoadOp.getMask();
65+
Value base = maskedLoadOp.getBase();
66+
Value iValue = maskedLoadOp.getPassThru();
67+
auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
68+
Value one = rewriter.create<arith::ConstantOp>(
69+
loc, indexType, IntegerAttr::get(indexType, 1));
70+
for (int64_t i = 0; i < maskLength; ++i) {
71+
auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
72+
73+
auto ifOp = rewriter.create<scf::IfOp>(
74+
loc, maskBit,
75+
[&](OpBuilder &builder, Location loc) {
76+
auto loadedValue =
77+
builder.create<memref::LoadOp>(loc, base, indices);
78+
auto combinedValue =
79+
builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
80+
builder.create<scf::YieldOp>(loc, combinedValue.getResult());
81+
},
82+
[&](OpBuilder &builder, Location loc) {
83+
builder.create<scf::YieldOp>(loc, iValue);
84+
});
85+
iValue = ifOp.getResult(0);
86+
87+
indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
88+
}
89+
90+
rewriter.replaceOp(maskedLoadOp, iValue);
91+
92+
return success();
93+
}
94+
};
95+
96+
/// Convert vector.maskedstore
97+
///
98+
/// Before:
99+
///
100+
/// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
101+
///
102+
/// After:
103+
///
104+
/// %m = vector.extract %mask[0]
105+
/// scf.if %m {
106+
/// %extracted = vector.extract %value[0]
107+
/// memref.store %extracted, %base[%idx_0, %idx_1]
108+
/// }
109+
/// %m = vector.extract %mask[1]
110+
/// scf.if %m {
111+
/// %extracted = vector.extract %value[1]
112+
/// memref.store %extracted, %base[%idx_0, %idx_1 + 1]
113+
/// }
114+
/// ...
115+
///
116+
struct VectorMaskedStoreOpConverter final
117+
: OpRewritePattern<vector::MaskedStoreOp> {
118+
using OpRewritePattern::OpRewritePattern;
119+
120+
LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
121+
PatternRewriter &rewriter) const override {
122+
VectorType maskVType = maskedStoreOp.getMaskVectorType();
123+
if (maskVType.getShape().size() != 1)
124+
return rewriter.notifyMatchFailure(
125+
maskedStoreOp, "expected vector.maskedstore with 1-D mask");
126+
127+
Location loc = maskedStoreOp.getLoc();
128+
int64_t maskLength = maskVType.getShape()[0];
129+
130+
Type indexType = rewriter.getIndexType();
131+
Value mask = maskedStoreOp.getMask();
132+
Value base = maskedStoreOp.getBase();
133+
Value value = maskedStoreOp.getValueToStore();
134+
auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
135+
Value one = rewriter.create<arith::ConstantOp>(
136+
loc, indexType, IntegerAttr::get(indexType, 1));
137+
for (int64_t i = 0; i < maskLength; ++i) {
138+
auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
139+
140+
auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
141+
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
142+
auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
143+
rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);
144+
145+
rewriter.setInsertionPointAfter(ifOp);
146+
indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
147+
}
148+
149+
rewriter.eraseOp(maskedStoreOp);
150+
151+
return success();
152+
}
153+
};
154+
155+
} // namespace
156+
157+
void mlir::vector::populateVectorMaskedLoadStoreEmulationPatterns(
158+
RewritePatternSet &patterns, PatternBenefit benefit) {
159+
patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
160+
patterns.getContext(), benefit);
161+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// RUN: mlir-opt %s --test-vector-emulate-masked-load-store | FileCheck %s
2+
3+
// CHECK-LABEL: @vector_maskedload
4+
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
5+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
6+
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
7+
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
8+
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
9+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
10+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
11+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
12+
// CHECK-DAG: %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
13+
// CHECK: %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
14+
// CHECK: %[[S2:.*]] = scf.if %[[S1]] -> (vector<4xf32>) {
15+
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
16+
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[CST]] [0] : f32 into vector<4xf32>
17+
// CHECK: scf.yield %[[S10]] : vector<4xf32>
18+
// CHECK: } else {
19+
// CHECK: scf.yield %[[CST]] : vector<4xf32>
20+
// CHECK: }
21+
// CHECK: %[[S3:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
22+
// CHECK: %[[S4:.*]] = scf.if %[[S3]] -> (vector<4xf32>) {
23+
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
24+
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[S2]] [1] : f32 into vector<4xf32>
25+
// CHECK: scf.yield %[[S10]] : vector<4xf32>
26+
// CHECK: } else {
27+
// CHECK: scf.yield %[[S2]] : vector<4xf32>
28+
// CHECK: }
29+
// CHECK: %[[S5:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
30+
// CHECK: %[[S6:.*]] = scf.if %[[S5]] -> (vector<4xf32>) {
31+
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
32+
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[S4]] [2] : f32 into vector<4xf32>
33+
// CHECK: scf.yield %[[S10]] : vector<4xf32>
34+
// CHECK: } else {
35+
// CHECK: scf.yield %[[S4]] : vector<4xf32>
36+
// CHECK: }
37+
// CHECK: %[[S7:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
38+
// CHECK: %[[S8:.*]] = scf.if %[[S7]] -> (vector<4xf32>) {
39+
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
40+
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[S6]] [3] : f32 into vector<4xf32>
41+
// CHECK: scf.yield %[[S10]] : vector<4xf32>
42+
// CHECK: } else {
43+
// CHECK: scf.yield %[[S6]] : vector<4xf32>
44+
// CHECK: }
45+
// CHECK: return %[[S8]] : vector<4xf32>
46+
func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
47+
%idx_0 = arith.constant 0 : index
48+
%idx_1 = arith.constant 1 : index
49+
%idx_4 = arith.constant 4 : index
50+
%mask = vector.create_mask %idx_1 : vector<4xi1>
51+
%s = arith.constant 0.0 : f32
52+
%pass_thru = vector.splat %s : vector<4xf32>
53+
%0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
54+
return %0: vector<4xf32>
55+
}
56+
57+
// CHECK-LABEL: @vector_maskedstore
58+
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
59+
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
60+
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
61+
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
62+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
63+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
64+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
65+
// CHECK-DAG: %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
66+
// CHECK: %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
67+
// CHECK: scf.if %[[S1]] {
68+
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][0] : f32 from vector<4xf32>
69+
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
70+
// CHECK: }
71+
// CHECK: %[[S2:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
72+
// CHECK: scf.if %[[S2]] {
73+
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
74+
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
75+
// CHECK: }
76+
// CHECK: %[[S3:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
77+
// CHECK: scf.if %[[S3]] {
78+
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][2] : f32 from vector<4xf32>
79+
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
80+
// CHECK: }
81+
// CHECK: %[[S4:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
82+
// CHECK: scf.if %[[S4]] {
83+
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][3] : f32 from vector<4xf32>
84+
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
85+
// CHECK: }
86+
// CHECK: return
87+
// CHECK:}
88+
func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
89+
%idx_0 = arith.constant 0 : index
90+
%idx_1 = arith.constant 1 : index
91+
%idx_4 = arith.constant 4 : index
92+
%mask = vector.create_mask %idx_1 : vector<4xi1>
93+
vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
94+
return
95+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,31 @@ struct TestFoldArithExtensionIntoVectorContractPatterns
777777
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
778778
}
779779
};
780+
781+
struct TestVectorEmulateMaskedLoadStore final
782+
: public PassWrapper<TestVectorEmulateMaskedLoadStore,
783+
OperationPass<func::FuncOp>> {
784+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)
785+
786+
StringRef getArgument() const override {
787+
return "test-vector-emulate-masked-load-store";
788+
}
789+
StringRef getDescription() const override {
790+
return "Test patterns that emulate the maskedload/maskedstore op by "
791+
" memref.load/store and scf.if";
792+
}
793+
void getDependentDialects(DialectRegistry &registry) const override {
794+
registry
795+
.insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
796+
scf::SCFDialect, vector::VectorDialect>();
797+
}
798+
799+
void runOnOperation() override {
800+
RewritePatternSet patterns(&getContext());
801+
populateVectorMaskedLoadStoreEmulationPatterns(patterns);
802+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
803+
}
804+
};
780805
} // namespace
781806

782807
namespace mlir {
@@ -817,6 +842,8 @@ void registerTestVectorLowerings() {
817842
PassRegistration<TestVectorGatherLowering>();
818843

819844
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
845+
846+
PassRegistration<TestVectorEmulateMaskedLoadStore>();
820847
}
821848
} // namespace test
822849
} // namespace mlir

0 commit comments

Comments
 (0)