Skip to content

Commit 5bd81d0

Browse files
committed
first commit in preparation
1 parent 554e27e commit 5bd81d0

File tree

3 files changed

+132
-112
lines changed

3 files changed

+132
-112
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,8 @@ def Vector_InsertStridedSliceOp :
11141114
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
11151115
});
11161116
}
1117+
// \return The indices in dest that the values are inserted to.
1118+
FailureOr<SmallVector<int64_t>> getLinearIndices();
11171119
}];
11181120

11191121
let hasFolder = 1;
@@ -1254,6 +1256,8 @@ def Vector_ExtractStridedSliceOp :
12541256
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
12551257
});
12561258
}
1259+
// \return The indices in source that the values are taken from.
1260+
FailureOr<SmallVector<int64_t>> getLinearIndices();
12571261
}];
12581262
let hasCanonicalizer = 1;
12591263
let hasFolder = 1;

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3178,6 +3178,101 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
31783178
stridesAttr);
31793179
}
31803180

3181+
/// Convert an array of attributes into a vector of integers, if possible.
3182+
static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
3183+
if (!attrs)
3184+
return failure();
3185+
SmallVector<int64_t> ints;
3186+
ints.reserve(attrs.size());
3187+
for (auto attr : attrs) {
3188+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
3189+
ints.push_back(intAttr.getInt());
3190+
} else {
3191+
return failure();
3192+
}
3193+
}
3194+
return ints;
3195+
}
3196+
3197+
/// Consider inserting a vector of shape `small` into a vector of shape `large`,
3198+
/// at position `offsets`: this function enumeratates all the indices in `large`
3199+
/// that are written to. The enumeration is with row-major ordering.
3200+
///
3201+
/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
3202+
/// positions written to are (1,3) and (1,4), which have linearized indices 8
3203+
/// and 9. So [8,9] is returned.
3204+
///
3205+
/// The length of the returned vector is equal to the number of elements in
3206+
/// the shape `small` (i.e. the product of dimensions of `small`).
3207+
static SmallVector<int64_t>
3208+
getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
3209+
ArrayRef<int64_t> large,
3210+
ArrayRef<int64_t> offsets) {
3211+
3212+
// Example of alignment between, `large`, `small` and `offsets`:
3213+
// large = 4, 5, 6, 7, 8
3214+
// small = 1, 6, 7, 8
3215+
// offsets = 2, 3, 0
3216+
//
3217+
// `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
3218+
assert((large.size() >= small.size()) &&
3219+
"rank of 'large' cannot be lower than rank of 'small'");
3220+
assert((large.size() >= offsets.size()) &&
3221+
"rank of 'large' cannot be lower than the number of offsets");
3222+
unsigned delta = large.size() - small.size();
3223+
unsigned nOffsets = offsets.size();
3224+
auto getSmall = [&](int64_t i) -> int64_t {
3225+
return i >= delta ? small[i - delta] : 1;
3226+
};
3227+
auto getOffset = [&](int64_t i) -> int64_t {
3228+
return i < nOffsets ? offsets[i] : 0;
3229+
};
3230+
3231+
// Using 2 vectors of indices, at each iteration populate the updated set of
3232+
// indices based on the old set of indices, and the size of the small vector
3233+
// in the current iteration.
3234+
SmallVector<int64_t> indices{0};
3235+
int64_t stride = 1;
3236+
for (int i = large.size() - 1; i >= 0; --i) {
3237+
int64_t currentSize = indices.size();
3238+
int64_t smallSize = getSmall(i);
3239+
int64_t nextSize = currentSize * smallSize;
3240+
SmallVector<int64_t> nextIndices(nextSize);
3241+
int64_t *base = nextIndices.begin();
3242+
int64_t offset = getOffset(i) * stride;
3243+
for (int j = 0; j < smallSize; ++j) {
3244+
for (int k = 0; k < currentSize; ++k) {
3245+
base[k] = indices[k] + offset;
3246+
}
3247+
offset += stride;
3248+
base += currentSize;
3249+
}
3250+
stride *= large[i];
3251+
indices = std::move(nextIndices);
3252+
}
3253+
return indices;
3254+
}
3255+
3256+
FailureOr<SmallVector<int64_t>> InsertStridedSliceOp::getLinearIndices() {
3257+
3258+
// Stride > 1 to be considered if/when the insert_strided_slice supports it.
3259+
if (hasNonUnitStrides())
3260+
return failure();
3261+
3262+
// Only when the destination has a static size can the indices be enumerated.
3263+
if (getType().isScalable())
3264+
return failure();
3265+
3266+
// Only when the offsets are all static can the indices be enumerated.
3267+
FailureOr<SmallVector<int64_t>> offsets = intsFromArrayAttr(getOffsets());
3268+
if (failed(offsets))
3269+
return failure();
3270+
3271+
return getStridedSliceInsertionIndices(getSourceVectorType().getShape(),
3272+
getDestVectorType().getShape(),
3273+
offsets.value());
3274+
}
3275+
31813276
// TODO: Should be moved to Tablegen ConfinedAttr attributes.
31823277
template <typename OpType>
31833278
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
@@ -3634,6 +3729,25 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
36343729
stridesAttr);
36353730
}
36363731

3732+
FailureOr<SmallVector<int64_t>> ExtractStridedSliceOp::getLinearIndices() {
3733+
3734+
// Stride > 1 to be considered if/when extract_strided_slice supports it.
3735+
if (hasNonUnitStrides())
3736+
return failure();
3737+
3738+
// Only when the source has a static size can the indices be enumerated.
3739+
if (getSourceVectorType().isScalable())
3740+
return failure();
3741+
3742+
// Only when the offsets are all static can the indices be enumerated.
3743+
FailureOr<SmallVector<int64_t>> offsets = intsFromArrayAttr(getOffsets());
3744+
if (failed(offsets))
3745+
return failure();
3746+
3747+
return getStridedSliceInsertionIndices(
3748+
getType().getShape(), getSourceVectorType().getShape(), offsets.value());
3749+
}
3750+
36373751
LogicalResult ExtractStridedSliceOp::verify() {
36383752
auto type = getSourceVectorType();
36393753
auto offsets = getOffsetsAttr();

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 14 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -109,90 +109,6 @@ struct LinearizeVectorizable final
109109
}
110110
};
111111

112-
template <typename TOp>
113-
static bool stridesAllOne(TOp op) {
114-
static_assert(
115-
std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116-
std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117-
"expected vector.extract_strided_slice or vector.insert_strided_slice");
118-
ArrayAttr strides = op.getStrides();
119-
return llvm::all_of(strides, isOneInteger);
120-
}
121-
122-
/// Convert an array of attributes into a vector of integers, if possible.
123-
static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
124-
if (!attrs)
125-
return failure();
126-
SmallVector<int64_t> ints;
127-
ints.reserve(attrs.size());
128-
for (auto attr : attrs) {
129-
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
130-
ints.push_back(intAttr.getInt());
131-
} else {
132-
return failure();
133-
}
134-
}
135-
return ints;
136-
}
137-
138-
/// Consider inserting a vector of shape `small` into a vector of shape `large`,
139-
/// at position `offsets`: this function enumeratates all the indices in `large`
140-
/// that are written to. The enumeration is with row-major ordering.
141-
///
142-
/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
143-
/// positions written to are (1,3) and (1,4), which have linearized indices 8
144-
/// and 9. So [8,9] is returned.
145-
///
146-
/// The length of the returned vector is equal to the number of elements in
147-
/// the shape `small` (i.e. the product of dimensions of `small`).
148-
SmallVector<int64_t> static getStridedSliceInsertionIndices(
149-
ArrayRef<int64_t> small, ArrayRef<int64_t> large,
150-
ArrayRef<int64_t> offsets) {
151-
152-
// Example of alignment between, `large`, `small` and `offsets`:
153-
// large = 4, 5, 6, 7, 8
154-
// small = 1, 6, 7, 8
155-
// offsets = 2, 3, 0
156-
//
157-
// `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
158-
assert((large.size() >= small.size()) &&
159-
"rank of 'large' cannot be lower than rank of 'small'");
160-
assert((large.size() >= offsets.size()) &&
161-
"rank of 'large' cannot be lower than the number of offsets");
162-
unsigned delta = large.size() - small.size();
163-
unsigned nOffsets = offsets.size();
164-
auto getSmall = [&](int64_t i) -> int64_t {
165-
return i >= delta ? small[i - delta] : 1;
166-
};
167-
auto getOffset = [&](int64_t i) -> int64_t {
168-
return i < nOffsets ? offsets[i] : 0;
169-
};
170-
171-
// Using 2 vectors of indices, at each iteration populate the updated set of
172-
// indices based on the old set of indices, and the size of the small vector
173-
// in the current iteration.
174-
SmallVector<int64_t> indices{0};
175-
int64_t stride = 1;
176-
for (int i = large.size() - 1; i >= 0; --i) {
177-
int64_t currentSize = indices.size();
178-
int64_t smallSize = getSmall(i);
179-
int64_t nextSize = currentSize * smallSize;
180-
SmallVector<int64_t> nextIndices(nextSize);
181-
int64_t *base = nextIndices.begin();
182-
int64_t offset = getOffset(i) * stride;
183-
for (int j = 0; j < smallSize; ++j) {
184-
for (int k = 0; k < currentSize; ++k) {
185-
base[k] = indices[k] + offset;
186-
}
187-
offset += stride;
188-
base += currentSize;
189-
}
190-
stride *= large[i];
191-
indices = std::move(nextIndices);
192-
}
193-
return indices;
194-
}
195-
196112
/// This pattern converts a vector.extract_strided_slice operation into a
197113
/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
198114
///
@@ -231,30 +147,23 @@ struct LinearizeVectorExtractStridedSlice final
231147

232148
// Expect a legalization failure if the strides are not all 1 (if ever the
233149
// verifier for extract_strided_slice allows non-1 strides).
234-
if (!stridesAllOne(extractStridedSliceOp)) {
150+
if (extractStridedSliceOp.hasNonUnitStrides()) {
235151
return rewriter.notifyMatchFailure(
236152
extractStridedSliceOp,
237153
"extract_strided_slice with strides != 1 not supported");
238154
}
239155

240-
FailureOr<SmallVector<int64_t>> offsets =
241-
intsFromArrayAttr(extractStridedSliceOp.getOffsets());
242-
if (failed(offsets)) {
156+
FailureOr<SmallVector<int64_t>> indices =
157+
extractStridedSliceOp.getLinearIndices();
158+
if (failed(indices)) {
243159
return rewriter.notifyMatchFailure(extractStridedSliceOp,
244-
"failed to get integer offsets");
160+
"failed to get indices");
245161
}
246162

247-
ArrayRef<int64_t> inputShape =
248-
extractStridedSliceOp.getSourceVectorType().getShape();
249-
250-
ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
251-
252-
SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
253-
outputShape, inputShape, offsets.value());
254-
255163
Value srcVector = adaptor.getVector();
256-
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
257-
extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
164+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractStridedSliceOp,
165+
flatOutputType, srcVector,
166+
srcVector, indices.value());
258167
return success();
259168
}
260169
};
@@ -298,31 +207,24 @@ struct LinearizeVectorInsertStridedSlice final
298207

299208
// Expect a legalization failure if the strides are not all 1 (if ever the
300209
// verifier for insert_strided_slice allows non-1 strides).
301-
if (!stridesAllOne(insertStridedSliceOp)) {
210+
if (insertStridedSliceOp.hasNonUnitStrides()) {
302211
return rewriter.notifyMatchFailure(
303212
insertStridedSliceOp,
304213
"insert_strided_slice with strides != 1 not supported");
305214
}
306215

307-
VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
308-
ArrayRef<int64_t> inputShape = inputType.getShape();
309-
310216
VectorType outputType = insertStridedSliceOp.getType();
311-
ArrayRef<int64_t> outputShape = outputType.getShape();
312217
int64_t nOutputElements = outputType.getNumElements();
313218

314-
FailureOr<SmallVector<int64_t>> offsets =
315-
intsFromArrayAttr(insertStridedSliceOp.getOffsets());
316-
if (failed(offsets)) {
219+
FailureOr<SmallVector<int64_t>> sliceIndices =
220+
insertStridedSliceOp.getLinearIndices();
221+
if (failed(sliceIndices))
317222
return rewriter.notifyMatchFailure(insertStridedSliceOp,
318-
"failed to get integer offsets");
319-
}
320-
SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
321-
inputShape, outputShape, offsets.value());
223+
"failed to get indices");
322224

323225
SmallVector<int64_t> indices(nOutputElements);
324226
std::iota(indices.begin(), indices.end(), 0);
325-
for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
227+
for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices.value())) {
326228
indices[sliceIndex] = index + nOutputElements;
327229
}
328230

0 commit comments

Comments
 (0)