Skip to content

Commit c7be2de

Browse files
authored
[mlir][VectorOps] Add fold ExtractOp(CreateMask) -> CreateMask (#69456)
This allows folding extracts from `vector.create_mask` ops that have a known value. Currently, there's no fold for this, but you get the same effect from the unrolling in LowerVectorMask (part of -convert-vector-to-llvm), followed by the folds that run after it. However, a future patch needs this simplification to be done before lowering to LLVM, hence the need for this fold.

E.g.:
```
%mask = vector.create_mask %c1, %dimA, %dimB : vector<1x[4]x[4]xi1>
%1 = vector.extract %mask[0] : vector<[4]x[4]xi1>
```
->
```
%1 = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
```
1 parent febf5c9 commit c7be2de

File tree

2 files changed

+186
-1
lines changed

2 files changed

+186
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,20 @@ static MaskFormat getMaskFormat(Value mask) {
100100
return MaskFormat::AllTrue;
101101
if (allFalse)
102102
return MaskFormat::AllFalse;
103+
} else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
104+
// Finds all-false create_masks. An all-true create_mask requires all
105+
// dims to be constants, so that'll be folded to a constant_mask, then
106+
// detected in the constant_mask case.
107+
auto maskOperands = m.getOperands();
108+
for (Value operand : maskOperands) {
109+
if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
110+
int64_t dimSize =
111+
llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
112+
if (dimSize <= 0)
113+
return MaskFormat::AllFalse;
114+
}
115+
}
116+
return MaskFormat::Unknown;
103117
}
104118
return MaskFormat::Unknown;
105119
}
@@ -1942,6 +1956,71 @@ class ExtractOpNonSplatConstantFolder final
19421956
}
19431957
};
19441958

1959+
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
1960+
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
1961+
public:
1962+
using OpRewritePattern::OpRewritePattern;
1963+
1964+
LogicalResult matchAndRewrite(ExtractOp extractOp,
1965+
PatternRewriter &rewriter) const override {
1966+
auto createMaskOp =
1967+
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
1968+
if (!createMaskOp)
1969+
return failure();
1970+
1971+
VectorType extractedMaskType =
1972+
llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
1973+
1974+
if (!extractedMaskType)
1975+
return failure();
1976+
1977+
auto maskOperands = createMaskOp.getOperands();
1978+
ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
1979+
VectorType maskType = createMaskOp.getVectorType();
1980+
1981+
bool containsUnknownDims = false;
1982+
bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
1983+
1984+
for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
1985+
dimIdx++) {
1986+
int64_t pos = extractOpPos[dimIdx];
1987+
Value operand = maskOperands[dimIdx];
1988+
auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
1989+
if (!constantOp) {
1990+
// Bounds of this dim unknown.
1991+
containsUnknownDims = true;
1992+
continue;
1993+
}
1994+
1995+
int64_t createMaskBound =
1996+
llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1997+
1998+
if (pos != ShapedType::kDynamic) {
1999+
// If any position is outside the range from the `create_mask`, then the
2000+
// extracted mask will be all-false.
2001+
allFalse |= pos >= createMaskBound;
2002+
} else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2003+
// This dim is not all-true and since this is a dynamic index we don't
2004+
// know if the extraction is within the true or false region.
2005+
// Note: Zero dims have already handled via getMaskFormat().
2006+
containsUnknownDims = true;
2007+
}
2008+
}
2009+
2010+
if (allFalse) {
2011+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2012+
extractOp, DenseElementsAttr::get(extractedMaskType, false));
2013+
} else if (!containsUnknownDims) {
2014+
rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2015+
extractOp, extractedMaskType,
2016+
maskOperands.drop_front(extractOpPos.size()));
2017+
} else {
2018+
return failure();
2019+
}
2020+
return success();
2021+
}
2022+
};
2023+
19452024
// Folds extract(shape_cast(..)) into shape_cast when the total element count
19462025
// does not change.
19472026
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -1968,7 +2047,7 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
19682047
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
19692048
MLIRContext *context) {
19702049
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
1971-
ExtractOpFromBroadcast>(context);
2050+
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
19722051
results.add(foldExtractFromShapeCastToShapeCast);
19732052
}
19742053

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,112 @@ func.func @create_mask_transpose_to_transposed_create_mask(
6767

6868
// -----
6969

70+
// Static position 1 is below the constant leading bound (2), so the extract
// is expected to fold into a create_mask over the two remaining operands.
// CHECK-LABEL: extract_from_create_mask
71+
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
72+
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
73+
%c2 = arith.constant 2 : index
74+
%mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
75+
// CHECK: vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[4]x[4]xi1>
76+
// CHECK-NOT: vector.extract
77+
%extract = vector.extract %mask[1] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
78+
return %extract : vector<[4]x[4]xi1>
79+
}
80+
81+
// -----
82+
83+
// Static position 2 is not below the constant leading bound (2), so the
// extracted slice is expected to fold to an all-false constant.
// CHECK-LABEL: extract_from_create_mask_all_false
84+
func.func @extract_from_create_mask_all_false(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
85+
%c2 = arith.constant 2 : index
86+
%mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
87+
// CHECK: arith.constant dense<false> : vector<[4]x[4]xi1>
88+
// CHECK-NOT: vector.extract
89+
%extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
90+
return %extract : vector<[4]x[4]xi1>
91+
}
92+
93+
// -----
94+
95+
// The leading dim is scalable, but the static position 1 is below its
// constant bound (3), so the extract is expected to fold into a create_mask
// over the remaining operand.
// CHECK-LABEL: extract_from_create_mask_leading_scalable
96+
// CHECK-SAME: %[[DIM0:.*]]: index
97+
func.func @extract_from_create_mask_leading_scalable(%dim0: index) -> vector<8xi1> {
98+
%c3 = arith.constant 3 : index
99+
%mask = vector.create_mask %c3, %dim0 : vector<[4]x8xi1>
100+
// CHECK: vector.create_mask %[[DIM0]] : vector<8xi1>
101+
// CHECK-NOT: vector.extract
102+
%extract = vector.extract %mask[1] : vector<8xi1> from vector<[4]x8xi1>
103+
return %extract : vector<8xi1>
104+
}
105+
106+
// -----
107+
108+
// Mixed static/dynamic positions: static position 2 is below the bound 3, and
// the dynamically indexed dim has bound 4 equal to its size 4 (all-true), so
// the extract is expected to fold into a create_mask over the last operand.
// CHECK-LABEL: extract_from_create_mask_dynamic_position
109+
// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
110+
func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index) -> vector<6xi1> {
111+
%c4 = arith.constant 4 : index
112+
%c3 = arith.constant 3 : index
113+
%mask = vector.create_mask %c3, %c4, %dim0 : vector<4x4x6xi1>
114+
// CHECK: vector.create_mask %[[DIM0]] : vector<6xi1>
115+
// CHECK-NOT: vector.extract
116+
%extract = vector.extract %mask[2, %index] : vector<6xi1> from vector<4x4x6xi1>
117+
return %extract : vector<6xi1>
118+
}
119+
120+
// -----
121+
122+
// The dynamically indexed dim has a constant bound of 0, so the mask is
// all-false whatever the index is; the extract is expected to fold to an
// all-false constant.
// CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
123+
// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
124+
func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
125+
%c0 = arith.constant 0 : index
126+
%c1 = arith.constant 1 : index
127+
%mask = vector.create_mask %c1, %c0, %dim0 : vector<1x4x6xi1>
128+
// CHECK: arith.constant dense<false> : vector<6xi1>
129+
// CHECK-NOT: vector.extract
130+
%extract = vector.extract %mask[0, %index] : vector<6xi1> from vector<1x4x6xi1>
131+
return %extract : vector<6xi1>
132+
}
133+
134+
// -----
135+
136+
// Negative test: the dynamically indexed dim has bound 2 but size 4, so it is
// unknown whether the index hits the true or the false region. The IR is
// expected to stay unchanged.
// CHECK-LABEL: extract_from_create_mask_dynamic_position_unknown
137+
// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
138+
func.func @extract_from_create_mask_dynamic_position_unknown(%dim0: index, %index: index) -> vector<6xi1> {
139+
%c2 = arith.constant 2 : index
140+
%mask = vector.create_mask %c2, %dim0 : vector<4x6xi1>
141+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
142+
// CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[DIM0]] : vector<4x6xi1>
143+
// CHECK-NEXT: vector.extract %[[MASK]][%[[INDEX]]] : vector<6xi1> from vector<4x6xi1>
144+
%extract = vector.extract %mask[%index] : vector<6xi1> from vector<4x6xi1>
145+
return %extract : vector<6xi1>
146+
}
147+
148+
// -----
149+
150+
// Negative test with mixed positions: static position 1 is below the bound 2,
// but the dynamically indexed dim has bound 2 and size 4, so the result is
// unknown. The IR is expected to stay unchanged.
// CHECK-LABEL: extract_from_create_mask_mixed_position_unknown
151+
// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
152+
func.func @extract_from_create_mask_mixed_position_unknown(%dim0: index, %index0: index) -> vector<4xi1> {
153+
%c2 = arith.constant 2 : index
154+
%mask = vector.create_mask %c2, %c2, %dim0 : vector<2x4x4xi1>
155+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
156+
// CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[C2]], %[[DIM0]] : vector<2x4x4xi1>
157+
// CHECK-NEXT: vector.extract %[[MASK]][1, %[[INDEX]]] : vector<4xi1> from vector<2x4x4xi1>
158+
%extract = vector.extract %mask[1, %index0] : vector<4xi1> from vector<2x4x4xi1>
159+
return %extract : vector<4xi1>
160+
}
161+
162+
// -----
163+
164+
// Negative test: the bound of the extracted dim is a function argument, not a
// constant, so nothing is known about it. The IR is expected to stay
// unchanged.
// CHECK-LABEL: extract_from_non_constant_create_mask
165+
// CHECK-SAME: %[[DIM0:.*]]: index
166+
func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1> {
167+
%mask = vector.create_mask %dim0, %dim0 : vector<[2]x[2]xi1>
168+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM0]] : vector<[2]x[2]xi1>
169+
// CHECK-NEXT: vector.extract %[[MASK]][0] : vector<[2]xi1> from vector<[2]x[2]xi1>
170+
%extract = vector.extract %mask[0] : vector<[2]xi1> from vector<[2]x[2]xi1>
171+
return %extract : vector<[2]xi1>
172+
}
173+
174+
// -----
175+
70176
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
71177
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
72178
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>

0 commit comments

Comments
 (0)