Skip to content

Commit ed70312

Browse files
committed
[mlir][vector] Update syntax and representation of insert/extract_strided_slice
This commit updates the representation of both `extract_strided_slice` and `insert_strided_slice` to primitive arrays of int64_ts, rather than ArrayAttrs of IntegerAttrs. This prevents a lot of boilerplate conversions between IntegerAttr and int64_t. Because previously the offsets, strides, and sizes were in the attribute dictionary (with no special syntax), simply replacing the attribute types with `DenseI64ArrayAttr` would be a syntax break. So since a break is unavoidable this commit also tackles a long-standing TODO: ```mlir // TODO: Evolve to a range form syntax similar to: %1 = vector.extract_strided_slice %0[0:2:1][2:4:1] : vector<4x8x16xf32> to vector<2x4x16xf32> ``` This is done by introducing a new `StridedSliceAttr` attribute that can be used for both operations, with syntax based on the above example. See the attribute documentation `VectorAttributes.td` for a full overview.
1 parent dac9042 commit ed70312

File tree

11 files changed

+333
-311
lines changed

11 files changed

+333
-311
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
include "mlir/Dialect/Vector/IR/Vector.td"
1717
include "mlir/IR/EnumAttr.td"
1818

19+
class Vector_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
20+
: AttrDef<Vector_Dialect, attrName, traits> {
21+
let mnemonic = attrMnemonic;
22+
}
23+
1924
// The "kind" of combining function for contractions and reductions.
2025
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
2126
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
@@ -82,4 +87,63 @@ def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctu
8287
let assemblyFormat = "`<` $value `>`";
8388
}
8489

90+
def Vector_StridedSliceAttr : Vector_Attr<"StridedSlice", "strided_slice">
91+
{
92+
let summary = "strided vector slice";
93+
94+
let description = [{
95+
An attribute that represents a strided slice of a vector.
96+
97+
*Syntax:*
98+
99+
```
100+
offset = integer-literal
101+
stride = integer-literal
102+
size = integer-literal
103+
offset-list = offset (`,` offset)*
104+
105+
// Without sizes (used for insert_strided_slice)
106+
strided-slice-without-sizes = offset-list? (`[` offset `:` stride `]`)+
107+
108+
// With sizes (used for extract_strided_slice)
109+
strided-slice-with-sizes = (`[` offset `:` size `:` stride `]`)+
110+
```
111+
112+
*Examples:*
113+
114+
Without sizes:
115+
116+
`[0:1][4:2]`
117+
118+
- The first dimension starts at offset 0 and is strided by 1
119+
- The second dimension starts at offset 4 and is strided by 2
120+
121+
`[0, 1, 2][3:1][4:8]`
122+
123+
- The first three dimensions are indexed without striding (offsets 0, 1, 2)
124+
- The fourth dimension starts at offset 3 and is strided by 1
125+
- The fifth dimension starts at offset 4 and is strided by 8
126+
127+
With sizes (used for extract_strided_slice)
128+
129+
`[0:2:4][2:4:3]`
130+
131+
- The first dimension starts at offset 0, has size 2, and is strided by 4
132+
- The second dimension starts at offset 2, has size 4, and is strided by 3
133+
}];
134+
135+
let parameters = (ins
136+
ArrayRefParameter<"int64_t">:$offsets,
137+
OptionalArrayRefParameter<"int64_t">:$sizes,
138+
ArrayRefParameter<"int64_t">:$strides
139+
);
140+
141+
let builders = [AttrBuilder<(ins "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides), [{
142+
return $_get($_ctxt, offsets, ArrayRef<int64_t>{}, strides);
143+
}]>
144+
];
145+
146+
let hasCustomAssemblyFormat = 1;
147+
}
148+
85149
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,8 +1040,8 @@ def Vector_InsertStridedSliceOp :
10401040
PredOpTrait<"operand #0 and result have same element type",
10411041
TCresVTEtIsSameAsOpBase<0, 0>>,
10421042
AllTypesMatch<["dest", "res"]>]>,
1043-
Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
1044-
I64ArrayAttr:$strides)>,
1043+
Arguments<(ins AnyVector:$source, AnyVector:$dest,
1044+
Vector_StridedSliceAttr:$strided_slice)>,
10451045
Results<(outs AnyVector:$res)> {
10461046
let summary = "strided_slice operation";
10471047
let description = [{
@@ -1059,14 +1059,13 @@ def Vector_InsertStridedSliceOp :
10591059
Example:
10601060

10611061
```mlir
1062-
%2 = vector.insert_strided_slice %0, %1
1063-
{offsets = [0, 0, 2], strides = [1, 1]}:
1064-
vector<2x4xf32> into vector<16x4x8xf32>
1062+
%2 = vector.insert_strided_slice %0, %1[0][0:1][2:1]
1063+
: vector<2x4xf32> into vector<16x4x8xf32>
10651064
```
10661065
}];
10671066

10681067
let assemblyFormat = [{
1069-
$source `,` $dest attr-dict `:` type($source) `into` type($dest)
1068+
$source `,` $dest `` $strided_slice attr-dict `:` type($source) `into` type($dest)
10701069
}];
10711070

10721071
let builders = [
@@ -1081,10 +1080,13 @@ def Vector_InsertStridedSliceOp :
10811080
return ::llvm::cast<VectorType>(getDest().getType());
10821081
}
10831082
bool hasNonUnitStrides() {
1084-
return llvm::any_of(getStrides(), [](Attribute attr) {
1085-
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
1083+
return llvm::any_of(getStrides(), [](int64_t stride) {
1084+
return stride != 1;
10861085
});
10871086
}
1087+
1088+
ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
1089+
ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
10881090
}];
10891091

10901092
let hasFolder = 1;
@@ -1182,8 +1184,7 @@ def Vector_ExtractStridedSliceOp :
11821184
Vector_Op<"extract_strided_slice", [Pure,
11831185
PredOpTrait<"operand and result have same element type",
11841186
TCresVTEtIsSameAsOpBase<0, 0>>]>,
1185-
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
1186-
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
1187+
Arguments<(ins AnyVector:$vector, Vector_StridedSliceAttr:$strided_slice)>,
11871188
Results<(outs AnyVector)> {
11881189
let summary = "extract_strided_slice operation";
11891190
let description = [{
@@ -1200,13 +1201,8 @@ def Vector_ExtractStridedSliceOp :
12001201
Example:
12011202

12021203
```mlir
1203-
%1 = vector.extract_strided_slice %0
1204-
{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
1205-
vector<4x8x16xf32> to vector<2x4x16xf32>
1206-
1207-
// TODO: Evolve to a range form syntax similar to:
12081204
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
1209-
vector<4x8x16xf32> to vector<2x4x16xf32>
1205+
: vector<4x8x16xf32> to vector<2x4x16xf32>
12101206
```
12111207
}];
12121208
let builders = [
@@ -1217,17 +1213,20 @@ def Vector_ExtractStridedSliceOp :
12171213
VectorType getSourceVectorType() {
12181214
return ::llvm::cast<VectorType>(getVector().getType());
12191215
}
1220-
void getOffsets(SmallVectorImpl<int64_t> &results);
12211216
bool hasNonUnitStrides() {
1222-
return llvm::any_of(getStrides(), [](Attribute attr) {
1223-
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
1217+
return llvm::any_of(getStrides(), [](int64_t stride) {
1218+
return stride != 1;
12241219
});
12251220
}
1221+
1222+
ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
1223+
ArrayRef<int64_t> getSizes() { return getStridedSlice().getSizes(); }
1224+
ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
12261225
}];
12271226
let hasCanonicalizer = 1;
12281227
let hasFolder = 1;
12291228
let hasVerifier = 1;
1230-
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
1229+
let assemblyFormat = "$vector `` $strided_slice attr-dict `:` type($vector) `to` type(results)";
12311230
}
12321231

12331232
// TODO: Tighten semantics so that masks and inbounds can't be used

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -940,12 +940,6 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
940940
return success();
941941
}
942942

943-
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
944-
SmallVectorImpl<int64_t> &results) {
945-
for (auto attr : arrayAttr)
946-
results.push_back(cast<IntegerAttr>(attr).getInt());
947-
}
948-
949943
static LogicalResult
950944
convertExtractStridedSlice(RewriterBase &rewriter,
951945
vector::ExtractStridedSliceOp op,
@@ -996,11 +990,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
996990
auto sourceVector = it->second;
997991

998992
// offset and sizes at warp-level of onwership.
999-
SmallVector<int64_t> offsets;
1000-
populateFromInt64AttrArray(op.getOffsets(), offsets);
993+
ArrayRef<int64_t> offsets = op.getOffsets();
1001994

1002-
SmallVector<int64_t> sizes;
1003-
populateFromInt64AttrArray(op.getSizes(), sizes);
1004995
ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1005996

1006997
// Compute offset in vector registers. Note that the mma.sync vector registers

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ static uint64_t getFirstIntValue(ValueRange values) {
4646
static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
4747
return cast<IntegerAttr>(attr[0]).getInt();
4848
}
49-
static uint64_t getFirstIntValue(ArrayAttr attr) {
50-
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
51-
}
5249
static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
5350
auto attr = foldResults[0].dyn_cast<Attribute>();
5451
if (attr)
@@ -187,9 +184,9 @@ struct VectorExtractStridedSliceOpConvert final
187184
if (!dstType)
188185
return failure();
189186

190-
uint64_t offset = getFirstIntValue(extractOp.getOffsets());
191-
uint64_t size = getFirstIntValue(extractOp.getSizes());
192-
uint64_t stride = getFirstIntValue(extractOp.getStrides());
187+
int64_t offset = extractOp.getOffsets().front();
188+
int64_t size = extractOp.getSizes().front();
189+
int64_t stride = extractOp.getStrides().front();
193190
if (stride != 1)
194191
return failure();
195192

@@ -323,10 +320,10 @@ struct VectorInsertStridedSliceOpConvert final
323320
Value srcVector = adaptor.getOperands().front();
324321
Value dstVector = adaptor.getOperands().back();
325322

326-
uint64_t stride = getFirstIntValue(insertOp.getStrides());
323+
uint64_t stride = insertOp.getStrides().front();
327324
if (stride != 1)
328325
return failure();
329-
uint64_t offset = getFirstIntValue(insertOp.getOffsets());
326+
uint64_t offset = insertOp.getOffsets().front();
330327

331328
if (isa<spirv::ScalarType>(srcVector.getType())) {
332329
assert(!isa<spirv::ScalarType>(dstVector.getType()));

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -549,11 +549,8 @@ struct ExtensionOverExtractStridedSlice final
549549
if (failed(ext))
550550
return failure();
551551

552-
VectorType origTy = op.getType();
553-
VectorType extractTy =
554-
origTy.cloneWith(origTy.getShape(), ext->getInElementType());
555552
Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
556-
op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
553+
op.getLoc(), ext->getIn(), op.getOffsets(), op.getSizes(),
557554
op.getStrides());
558555
ext->recreateAndReplace(rewriter, op, newExtract);
559556
return success();

0 commit comments

Comments
 (0)