Skip to content

Commit 98542a3

Browse files
authored
[mlir][Vector] Move vector.extract canonicalizers for DenseElementsAttr to folders (#127995)
This PR moves vector.extract canonicalizers for DenseElementsAttr (splat and non splat case) to folders. Folders are local, and it's always better to implement a folder than a canonicalization pattern. This PR is mostly NFC-ish, because the functionality mostly remains same, but is now run as part of a folder, which is why some tests are changed, because GreedyPatternRewriter tries to fold by default. There is also a test change which makes the indices of a vector.extract test dynamic. This is so that it doesn't fold away after this pr.
1 parent e3ece07 commit 98542a3

File tree

5 files changed

+72
-111
lines changed

5 files changed

+72
-111
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 56 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,20 +2031,71 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
20312031
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
20322032
ArrayRef<int64_t> staticPos,
20332033
int64_t poisonVal) {
2034-
if (!llvm::is_contained(staticPos, poisonVal))
2034+
if (!is_contained(staticPos, poisonVal))
20352035
return {};
20362036

20372037
return ub::PoisonAttr::get(context);
20382038
}
20392039

20402040
/// Fold a vector extract from is a poison source.
20412041
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2042-
if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2042+
if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
20432043
return srcAttr;
20442044

20452045
return {};
20462046
}
20472047

2048+
/// Fold a vector extract extracting from a DenseElementsAttr.
2049+
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2050+
Attribute srcAttr) {
2051+
auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2052+
if (!denseAttr) {
2053+
return {};
2054+
}
2055+
2056+
if (denseAttr.isSplat()) {
2057+
Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2058+
if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2059+
newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2060+
return newAttr;
2061+
}
2062+
2063+
auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2064+
if (vecTy.isScalable())
2065+
return {};
2066+
2067+
if (extractOp.hasDynamicPosition()) {
2068+
return {};
2069+
}
2070+
2071+
// Materializing subsets of a large constant array can generally lead to
2072+
// explosion in IR size because of different combination of subsets that
2073+
// can exist. However, vector.extract is a restricted form of subset
2074+
// extract where you can only extract non-overlapping (or the same) subset for
2075+
// a given rank of the subset. Because of this property, the IR size can only
2076+
// increase at most by `rank * size(array)` from a single constant array being
2077+
// extracted by multiple extracts.
2078+
2079+
// Calculate the linearized position of the continuous chunk of elements to
2080+
// extract.
2081+
SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2082+
copy(extractOp.getStaticPosition(), completePositions.begin());
2083+
int64_t startPos =
2084+
linearize(completePositions, computeStrides(vecTy.getShape()));
2085+
auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2086+
2087+
TypedAttr newAttr;
2088+
if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2089+
SmallVector<Attribute> elementValues(
2090+
denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2091+
newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2092+
} else {
2093+
newAttr = *denseValuesBegin;
2094+
}
2095+
2096+
return newAttr;
2097+
}
2098+
20482099
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
20492100
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
20502101
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2056,6 +2107,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
20562107
return res;
20572108
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
20582109
return res;
2110+
if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
2111+
return res;
20592112
if (succeeded(foldExtractOpFromExtractChain(*this)))
20602113
return getResult();
20612114
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2119,80 +2172,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
21192172
}
21202173
};
21212174

2122-
// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
2123-
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2124-
public:
2125-
using OpRewritePattern::OpRewritePattern;
2126-
2127-
LogicalResult matchAndRewrite(ExtractOp extractOp,
2128-
PatternRewriter &rewriter) const override {
2129-
// Return if 'ExtractOp' operand is not defined by a splat vector
2130-
// ConstantOp.
2131-
Value sourceVector = extractOp.getVector();
2132-
Attribute vectorCst;
2133-
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2134-
return failure();
2135-
auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2136-
if (!splat)
2137-
return failure();
2138-
TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2139-
if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2140-
newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2141-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2142-
return success();
2143-
}
2144-
};
2145-
2146-
// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
2147-
class ExtractOpNonSplatConstantFolder final
2148-
: public OpRewritePattern<ExtractOp> {
2149-
public:
2150-
using OpRewritePattern::OpRewritePattern;
2151-
2152-
LogicalResult matchAndRewrite(ExtractOp extractOp,
2153-
PatternRewriter &rewriter) const override {
2154-
// TODO: Canonicalization for dynamic position not implemented yet.
2155-
if (extractOp.hasDynamicPosition())
2156-
return failure();
2157-
2158-
// Return if 'ExtractOp' operand is not defined by a compatible vector
2159-
// ConstantOp.
2160-
Value sourceVector = extractOp.getVector();
2161-
Attribute vectorCst;
2162-
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2163-
return failure();
2164-
2165-
auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
2166-
if (vecTy.isScalable())
2167-
return failure();
2168-
2169-
// The splat case is handled by `ExtractOpSplatConstantFolder`.
2170-
auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2171-
if (!dense || dense.isSplat())
2172-
return failure();
2173-
2174-
// Calculate the linearized position of the continuous chunk of elements to
2175-
// extract.
2176-
llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2177-
copy(extractOp.getStaticPosition(), completePositions.begin());
2178-
int64_t elemBeginPosition =
2179-
linearize(completePositions, computeStrides(vecTy.getShape()));
2180-
auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2181-
2182-
TypedAttr newAttr;
2183-
if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2184-
SmallVector<Attribute> elementValues(
2185-
denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2186-
newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2187-
} else {
2188-
newAttr = *denseValuesBegin;
2189-
}
2190-
2191-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2192-
return success();
2193-
}
2194-
};
2195-
21962175
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
21972176
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
21982177
public:
@@ -2330,8 +2309,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23302309

23312310
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
23322311
MLIRContext *context) {
2333-
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2334-
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2312+
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
23352313
results.add(foldExtractFromShapeCastToShapeCast);
23362314
results.add(foldExtractFromFromElements);
23372315
}

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,8 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
3232

3333
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
3434
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
35-
// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
36-
// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
3735

38-
// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
39-
// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
40-
// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
41-
42-
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
36+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
4337
// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
4438

4539
// -----
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
175169
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
176170
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
177171
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
178-
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
179-
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
180-
181-
// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
182-
// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
183-
// CHECK-DAG: %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
184-
// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
172+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
173+
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
174+
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
185175

186-
// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
187-
// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
176+
// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
177+
// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
188178
// CHECK: return %[[VAL_9]] : tensor<1x4xf32>
189179
// CHECK: }
190180

@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
675665
// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
676666
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
677667
// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
678-
// CHECK-DAG: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
679-
// CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
680-
// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
668+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
681669
// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
682670
// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
683671

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
310310
// -----
311311

312312
// ALL-LABEL: test_vector_extract_scalar
313-
func.func @test_vector_extract_scalar() {
313+
func.func @test_vector_extract_scalar(%idx : index) {
314314
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
315315
// ALL-NOT: vector.shuffle
316316
// ALL: vector.extract
317317
// ALL-NOT: vector.shuffle
318-
%0 = vector.extract %cst[0] : i32 from vector<4xi32>
318+
%0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
319319
return
320320
}
321321

mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2:
101101

102102
// CHECK-LABEL: func @transfer_write_arith_constant(
103103
// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
104-
// CHECK: %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
105-
// CHECK: %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
106-
// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
104+
// CHECK: %[[cst:.*]] = arith.constant 5.000000e+00 : f32
105+
// CHECK: memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
107106
func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
108107
%cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
109108
vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>

mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -242,33 +242,29 @@ func.func @strided_gather(%base : memref<100x3xf32>,
242242
// CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
243243
// CHECK-SAME: %[[VAL_4:.*]]: index,
244244
// CHECK-SAME: %[[VAL_5:.*]]: index) -> vector<4xf32> {
245+
// CHECK: %[[TRUE:.*]] = arith.constant true
245246
// CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
246-
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
247247

248248
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
249249
// CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
250250

251-
// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
252251
// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
253-
// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
252+
// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
254253
// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
255254
// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
256255

257-
// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
258256
// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
259-
// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
257+
// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
260258
// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
261259
// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
262260

263-
// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
264261
// CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
265-
// CHECK: scf.if %[[MASK_2]] -> (vector<4xf32>)
262+
// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
266263
// CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
267264
// CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
268265

269-
// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
270266
// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
271-
// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
267+
// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
272268
// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
273269
// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
274270

0 commit comments

Comments
 (0)