Skip to content

Commit 98f6289

Browse files
committed
[mlir][Vector] Add support for Value indices to vector.extract/insert
`vector.extract/insert` ops only support constant indices. This PR is extending them so that arbitrary values can be used instead. This work is part of the RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops Differential Revision: https://reviews.llvm.org/D155034
1 parent 6ebc179 commit 98f6289

File tree

19 files changed

+535
-197
lines changed

19 files changed

+535
-197
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,24 @@ inline bool isReductionIterator(Attribute attr) {
131131
return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction;
132132
}
133133

134+
/// Returns the integer numbers in `values`. `values` are expected to be
135+
/// constant operations.
136+
SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values);
137+
138+
/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
139+
/// be constant operations.
140+
SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults);
141+
142+
/// Convert `foldResults` into Values. Integer attributes are converted to
143+
/// constant op.
144+
SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
145+
ArrayRef<OpFoldResult> foldResults);
146+
147+
/// Returns the constant index ops in `values`. `values` are expected to be
148+
/// constant operations.
149+
SmallVector<arith::ConstantIndexOp>
150+
getAsConstantIndexOps(ArrayRef<Value> values);
151+
134152
//===----------------------------------------------------------------------===//
135153
// Vector Masking Utilities
136154
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -523,9 +523,7 @@ def Vector_ExtractOp :
523523
Vector_Op<"extract", [Pure,
524524
PredOpTrait<"operand and result have same element type",
525525
TCresVTEtIsSameAsOpBase<0, 0>>,
526-
InferTypeOpAdaptorWithIsCompatible]>,
527-
Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>,
528-
Results<(outs AnyType)> {
526+
InferTypeOpAdaptorWithIsCompatible]> {
529527
let summary = "extract operation";
530528
let description = [{
531529
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
@@ -535,21 +533,55 @@ def Vector_ExtractOp :
535533

536534
```mlir
537535
%1 = vector.extract %0[3]: vector<4x8x16xf32>
538-
%2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32>
536+
%2 = vector.extract %0[2, 1, 3]: vector<4x8x16xf32>
539537
%3 = vector.extract %1[]: vector<f32>
538+
%4 = vector.extract %0[%a, %b, %c]: vector<4x8x16xf32>
539+
%5 = vector.extract %0[2, %b]: vector<4x8x16xf32>
540540
```
541541
}];
542+
543+
let arguments = (ins
544+
AnyVectorOfAnyRank:$vector,
545+
Variadic<Index>:$dynamic_position,
546+
DenseI64ArrayAttr:$static_position
547+
);
548+
let results = (outs AnyType:$result);
549+
542550
let builders = [
543-
// Convenience builder which assumes the values in `position` are defined by
544-
// ConstantIndexOp.
545-
OpBuilder<(ins "Value":$source, "ValueRange":$position)>
551+
OpBuilder<(ins "Value":$source, "int64_t":$position)>,
552+
OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>,
553+
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
554+
OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
546555
];
556+
547557
let extraClassDeclaration = [{
548558
VectorType getSourceVectorType() {
549559
return ::llvm::cast<VectorType>(getVector().getType());
550560
}
561+
562+
/// Return a vector with all the static and dynamic position indices.
563+
SmallVector<OpFoldResult> getMixedPosition() {
564+
OpBuilder builder(getContext());
565+
return getMixedValues(getStaticPosition(), getDynamicPosition(), builder);
566+
}
567+
568+
unsigned getNumIndices() {
569+
return getStaticPosition().size();
570+
}
571+
572+
bool hasDynamicPosition() {
573+
auto dynPos = getDynamicPosition();
574+
return std::any_of(dynPos.begin(), dynPos.end(),
575+
[](Value operand) { return operand != nullptr; });
576+
}
551577
}];
552-
let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
578+
579+
let assemblyFormat = [{
580+
$vector ``
581+
custom<DynamicIndexList>($dynamic_position, $static_position)
582+
attr-dict `:` type($vector)
583+
}];
584+
553585
let hasCanonicalizer = 1;
554586
let hasFolder = 1;
555587
let hasVerifier = 1;
@@ -638,9 +670,7 @@ def Vector_InsertOp :
638670
Vector_Op<"insert", [Pure,
639671
PredOpTrait<"source operand and result have same element type",
640672
TCresVTEtIsSameAsOpBase<0, 0>>,
641-
AllTypesMatch<["dest", "res"]>]>,
642-
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>,
643-
Results<(outs AnyVectorOfAnyRank:$res)> {
673+
AllTypesMatch<["dest", "result"]>]> {
644674
let summary = "insert operation";
645675
let description = [{
646676
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
@@ -651,24 +681,53 @@ def Vector_InsertOp :
651681

652682
```mlir
653683
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
654-
%5 = vector.insert %3, %4[3, 3, 3] : f32 into vector<4x8x16xf32>
684+
%5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
655685
%8 = vector.insert %6, %7[] : f32 into vector<f32>
656-
%11 = vector.insert %9, %10[3, 3, 3] : vector<f32> into vector<4x8x16xf32>
686+
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
687+
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
657688
```
658689
}];
659-
let assemblyFormat = [{
660-
$source `,` $dest $position attr-dict `:` type($source) `into` type($dest)
661-
}];
690+
691+
let arguments = (ins
692+
AnyType:$source,
693+
AnyVectorOfAnyRank:$dest,
694+
Variadic<Index>:$dynamic_position,
695+
DenseI64ArrayAttr:$static_position
696+
);
697+
let results = (outs AnyVectorOfAnyRank:$result);
662698

663699
let builders = [
664-
// Convenience builder which assumes all values are constant indices.
665-
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
700+
OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
701+
OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
702+
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
703+
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
666704
];
705+
667706
let extraClassDeclaration = [{
668707
Type getSourceType() { return getSource().getType(); }
669708
VectorType getDestVectorType() {
670709
return ::llvm::cast<VectorType>(getDest().getType());
671710
}
711+
712+
/// Return a vector with all the static and dynamic position indices.
713+
SmallVector<OpFoldResult> getMixedPosition() {
714+
OpBuilder builder(getContext());
715+
return getMixedValues(getStaticPosition(), getDynamicPosition(), builder);
716+
}
717+
718+
unsigned getNumIndices() {
719+
return getStaticPosition().size();
720+
}
721+
722+
bool hasDynamicPosition() {
723+
return llvm::any_of(getDynamicPosition(),
724+
[](Value operand) { return operand != nullptr; });
725+
}
726+
}];
727+
728+
let assemblyFormat = [{
729+
$source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
730+
attr-dict `:` type($source) `into` type($dest)
672731
}];
673732

674733
let hasCanonicalizer = 1;

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,18 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
126126
return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
127127
}
128128

129+
/// Convert `foldResult` into a Value. Integer attribute is converted to
130+
/// an LLVM constant op.
131+
static Value getAsLLVMValue(OpBuilder &builder, Location loc,
132+
OpFoldResult foldResult) {
133+
if (auto attr = foldResult.dyn_cast<Attribute>()) {
134+
auto intAttr = cast<IntegerAttr>(attr);
135+
return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
136+
}
137+
138+
return foldResult.get<Value>();
139+
}
140+
129141
namespace {
130142

131143
/// Trivial Vector to LLVM conversions
@@ -1079,41 +1091,53 @@ class VectorExtractOpConversion
10791091
auto loc = extractOp->getLoc();
10801092
auto resultType = extractOp.getResult().getType();
10811093
auto llvmResultType = typeConverter->convertType(resultType);
1082-
ArrayRef<int64_t> positionArray = extractOp.getPosition();
1083-
10841094
// Bail if result type cannot be lowered.
10851095
if (!llvmResultType)
10861096
return failure();
10871097

1098+
SmallVector<OpFoldResult> positionVec;
1099+
for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
1100+
if (pos.is<Value>())
1101+
// Make sure we use the value that has been already converted to LLVM.
1102+
positionVec.push_back(adaptor.getDynamicPosition()[idx]);
1103+
else
1104+
positionVec.push_back(pos);
1105+
}
1106+
10881107
// Extract entire vector. Should be handled by folder, but just to be safe.
1089-
if (positionArray.empty()) {
1108+
ArrayRef<OpFoldResult> position(positionVec);
1109+
if (position.empty()) {
10901110
rewriter.replaceOp(extractOp, adaptor.getVector());
10911111
return success();
10921112
}
10931113

10941114
// One-shot extraction of vector from array (only requires extractvalue).
10951115
if (isa<VectorType>(resultType)) {
1116+
if (extractOp.hasDynamicPosition())
1117+
return failure();
1118+
10961119
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
1097-
loc, adaptor.getVector(), positionArray);
1120+
loc, adaptor.getVector(), getAsIntegers(position));
10981121
rewriter.replaceOp(extractOp, extracted);
10991122
return success();
11001123
}
11011124

11021125
// Potential extraction of 1-D vector from array.
11031126
Value extracted = adaptor.getVector();
1104-
if (positionArray.size() > 1) {
1105-
extracted = rewriter.create<LLVM::ExtractValueOp>(
1106-
loc, extracted, positionArray.drop_back());
1107-
}
1127+
if (position.size() > 1) {
1128+
if (extractOp.hasDynamicPosition())
1129+
return failure();
11081130

1109-
// Remaining extraction of element from 1-D LLVM vector
1110-
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
1111-
auto constant =
1112-
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
1113-
extracted =
1114-
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
1115-
rewriter.replaceOp(extractOp, extracted);
1131+
SmallVector<int64_t> nMinusOnePosition =
1132+
getAsIntegers(position.drop_back());
1133+
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
1134+
nMinusOnePosition);
1135+
}
11161136

1137+
Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
1138+
// Remaining extraction of element from 1-D LLVM vector.
1139+
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
1140+
lastPosition);
11171141
return success();
11181142
}
11191143
};
@@ -1194,48 +1218,63 @@ class VectorInsertOpConversion
11941218
auto sourceType = insertOp.getSourceType();
11951219
auto destVectorType = insertOp.getDestVectorType();
11961220
auto llvmResultType = typeConverter->convertType(destVectorType);
1197-
ArrayRef<int64_t> positionArray = insertOp.getPosition();
1198-
11991221
// Bail if result type cannot be lowered.
12001222
if (!llvmResultType)
12011223
return failure();
12021224

1225+
SmallVector<OpFoldResult> positionVec;
1226+
for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
1227+
if (pos.is<Value>())
1228+
// Make sure we use the value that has been already converted to LLVM.
1229+
positionVec.push_back(adaptor.getDynamicPosition()[idx]);
1230+
else
1231+
positionVec.push_back(pos);
1232+
}
1233+
12031234
// Overwrite entire vector with value. Should be handled by folder, but
12041235
// just to be safe.
1205-
if (positionArray.empty()) {
1236+
ArrayRef<OpFoldResult> position(positionVec);
1237+
if (position.empty()) {
12061238
rewriter.replaceOp(insertOp, adaptor.getSource());
12071239
return success();
12081240
}
12091241

12101242
// One-shot insertion of a vector into an array (only requires insertvalue).
12111243
if (isa<VectorType>(sourceType)) {
1244+
if (insertOp.hasDynamicPosition())
1245+
return failure();
1246+
12121247
Value inserted = rewriter.create<LLVM::InsertValueOp>(
1213-
loc, adaptor.getDest(), adaptor.getSource(), positionArray);
1248+
loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
12141249
rewriter.replaceOp(insertOp, inserted);
12151250
return success();
12161251
}
12171252

12181253
// Potential extraction of 1-D vector from array.
12191254
Value extracted = adaptor.getDest();
12201255
auto oneDVectorType = destVectorType;
1221-
if (positionArray.size() > 1) {
1256+
if (position.size() > 1) {
1257+
if (insertOp.hasDynamicPosition())
1258+
return failure();
1259+
12221260
oneDVectorType = reducedVectorTypeBack(destVectorType);
12231261
extracted = rewriter.create<LLVM::ExtractValueOp>(
1224-
loc, extracted, positionArray.drop_back());
1262+
loc, extracted, getAsIntegers(position.drop_back()));
12251263
}
12261264

12271265
// Insertion of an element into a 1-D LLVM vector.
1228-
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
1229-
auto constant =
1230-
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
12311266
Value inserted = rewriter.create<LLVM::InsertElementOp>(
12321267
loc, typeConverter->convertType(oneDVectorType), extracted,
1233-
adaptor.getSource(), constant);
1268+
adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
12341269

12351270
// Potential insertion of resulting 1-D vector into array.
1236-
if (positionArray.size() > 1) {
1271+
if (position.size() > 1) {
1272+
if (insertOp.hasDynamicPosition())
1273+
return failure();
1274+
12371275
inserted = rewriter.create<LLVM::InsertValueOp>(
1238-
loc, adaptor.getDest(), inserted, positionArray.drop_back());
1276+
loc, adaptor.getDest(), inserted,
1277+
getAsIntegers(position.drop_back()));
12391278
}
12401279

12411280
rewriter.replaceOp(insertOp, inserted);

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,10 +1063,11 @@ struct UnrollTransferReadConversion
10631063
/// If the result of the TransferReadOp has exactly one user, which is a
10641064
/// vector::InsertOp, return that operation's indices.
10651065
void getInsertionIndices(TransferReadOp xferOp,
1066-
SmallVector<int64_t, 8> &indices) const {
1067-
if (auto insertOp = getInsertOp(xferOp))
1068-
indices.assign(insertOp.getPosition().begin(),
1069-
insertOp.getPosition().end());
1066+
SmallVectorImpl<OpFoldResult> &indices) const {
1067+
if (auto insertOp = getInsertOp(xferOp)) {
1068+
auto pos = insertOp.getMixedPosition();
1069+
indices.append(pos.begin(), pos.end());
1070+
}
10701071
}
10711072

10721073
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
@@ -1110,9 +1111,9 @@ struct UnrollTransferReadConversion
11101111
getXferIndices(b, xferOp, iv, xferIndices);
11111112

11121113
// Indices for the new vector.insert op.
1113-
SmallVector<int64_t, 8> insertionIndices;
1114+
SmallVector<OpFoldResult, 8> insertionIndices;
11141115
getInsertionIndices(xferOp, insertionIndices);
1115-
insertionIndices.push_back(i);
1116+
insertionIndices.push_back(rewriter.getIndexAttr(i));
11161117

11171118
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
11181119
auto newXferOp = b.create<vector::TransferReadOp>(
@@ -1195,10 +1196,11 @@ struct UnrollTransferWriteConversion
11951196
/// If the input of the given TransferWriteOp is an ExtractOp, return its
11961197
/// indices.
11971198
void getExtractionIndices(TransferWriteOp xferOp,
1198-
SmallVector<int64_t, 8> &indices) const {
1199-
if (auto extractOp = getExtractOp(xferOp))
1200-
indices.assign(extractOp.getPosition().begin(),
1201-
extractOp.getPosition().end());
1199+
SmallVectorImpl<OpFoldResult> &indices) const {
1200+
if (auto extractOp = getExtractOp(xferOp)) {
1201+
auto pos = extractOp.getMixedPosition();
1202+
indices.append(pos.begin(), pos.end());
1203+
}
12021204
}
12031205

12041206
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
@@ -1235,9 +1237,9 @@ struct UnrollTransferWriteConversion
12351237
getXferIndices(b, xferOp, iv, xferIndices);
12361238

12371239
// Indices for the new vector.extract op.
1238-
SmallVector<int64_t, 8> extractionIndices;
1240+
SmallVector<OpFoldResult, 8> extractionIndices;
12391241
getExtractionIndices(xferOp, extractionIndices);
1240-
extractionIndices.push_back(i);
1242+
extractionIndices.push_back(b.getI64IntegerAttr(i));
12411243

12421244
auto extracted =
12431245
b.create<vector::ExtractOp>(loc, vec, extractionIndices);

0 commit comments

Comments
 (0)