
Commit 9f12215 (parent: 064db24)

Recommit "[mlir][vector] Allow unroll of contraction in arbitrary order"

Fixes an issue with the default unroll permutation of vector.contract. Adds support for vector unroll transformations to unroll in different orders. For example, a vector.contract can be unrolled into a smaller set of contractions, and there is a choice of how to traverse the decomposition, e.g. over (dim0, dim1, dim2). The traversal order can now be specified by a callback provided by the caller of the transform; a usage sketch follows below. For now, only the vector.contract and vector.transfer_read/transfer_write operations support the callback.

Differential Revision: https://reviews.llvm.org/D127004
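A minimal usage sketch of the new hook. Only UnrollTraversalOrderFnType and setUnrollTraversalOrderFn come from this commit; setNativeShape and populateVectorUnrollPatterns are pre-existing vector-dialect APIs, `patterns` is an assumed in-scope RewritePatternSet, and the order {2, 0, 1} is purely illustrative:

  vector::UnrollVectorOptions options;
  // Unroll ops into native-size slices of 2x2x2.
  options.setNativeShape({2, 2, 2});
  // Pick a per-op traversal order; returning llvm::None keeps the default
  // identity order (dim0, dim1, dim2, ...).
  options.setUnrollTraversalOrderFn(
      [](Operation *op) -> Optional<SmallVector<int64_t>> {
        if (!isa<vector::ContractionOp>(op))
          return llvm::None;
        // Illustrative choice: traverse (dim2, dim0, dim1) instead.
        return SmallVector<int64_t>{2, 0, 1};
      });
  vector::populateVectorUnrollPatterns(patterns, options);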

File tree

5 files changed (+397, -76 lines)


mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 13 additions & 0 deletions
@@ -128,6 +128,19 @@ struct UnrollVectorOptions {
     };
     return *this;
   }
+
+  /// Function that returns the traversal order (in terms of "for loop order",
+  /// i.e. slowest varying dimension to fastest varying dimension) that should
+  /// be used when unrolling the given operation into units of the native
+  /// vector size.
+  using UnrollTraversalOrderFnType =
+      std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
+  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
+  UnrollVectorOptions &
+  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
+    traversalOrderCallback = std::move(traversalOrderFn);
+    return *this;
+  }
 };

 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp

Lines changed: 110 additions & 27 deletions
@@ -15,8 +15,10 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/MapVector.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include <numeric>

 #define DEBUG_TYPE "vector-unrolling"

@@ -36,20 +38,81 @@ static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
   return elementOffsets;
 }

+/// A functor that accomplishes the same thing as `getVectorOffset` but allows
+/// for reordering the traversal of the dimensions. The order of traversal is
+/// given in "for loop order" (outer to inner).
+namespace {
+class DecomposeShapeIterator {
+private:
+  SmallVector<int64_t, 4> vectorShape;
+  SmallVector<int64_t> loopOrder;
+  SmallVector<int64_t> sliceStrides;
+  int64_t maxIndexVal{1};
+
+public:
+  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
+                         ArrayRef<int64_t> targetShape,
+                         ArrayRef<int64_t> loopOrder)
+      : vectorShape(targetShape.begin(), targetShape.end()),
+        loopOrder(loopOrder.begin(), loopOrder.end()),
+        sliceStrides(originalShape.size()) {
+    assert(originalShape.size() == targetShape.size());
+    assert(loopOrder.size() == targetShape.size());
+
+    // Compute the count for each dimension.
+    SmallVector<int64_t> sliceDimCounts(originalShape.size());
+    for (unsigned r = 0; r < originalShape.size(); ++r) {
+      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
+      maxIndexVal *= sliceDimCounts[r];
+    }
+
+    // Reversing "loop order" gives dimensions from fastest varying to
+    // slowest varying (smallest stride to largest stride).
+    int64_t accum = 1;
+    for (auto idx : llvm::reverse(loopOrder)) {
+      sliceStrides[idx] = accum;
+      accum *= sliceDimCounts[idx];
+    }
+  }
+
+  // Turn the linear index into a d-tuple based on units of vectors of size
+  // `vectorShape`. The linear index is assumed to represent traversal of the
+  // dimensions based on `loopOrder`.
+  SmallVector<int64_t> delinearize(int64_t index) const {
+    // Traverse in for loop order (largest stride to smallest stride).
+    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
+    for (auto idx : loopOrder) {
+      vectorOffsets[idx] = index / sliceStrides[idx];
+      index %= sliceStrides[idx];
+    }
+    return vectorOffsets;
+  }
+
+  int64_t maxIndex() const { return maxIndexVal; }
+
+  /// Return the offset within the d-tuple based on the ordering given by
+  /// `loopOrder`.
+  SmallVector<int64_t> getVectorOffset(int64_t index) const {
+    SmallVector<int64_t> vectorOffsets = delinearize(index);
+    SmallVector<int64_t> elementOffsets =
+        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
+    return elementOffsets;
+  }
+};
+} // namespace
+
 /// Compute the indices of the slice `index` for a transfer op.
-static SmallVector<Value>
-sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
-                     ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
-                     AffineMap permutationMap, Location loc,
-                     OpBuilder &builder) {
+static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
+                                               ArrayRef<Value> indices,
+                                               AffineMap permutationMap,
+                                               Location loc,
+                                               OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   auto isBroadcast = [](AffineExpr expr) {
     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
       return constExpr.getValue() == 0;
     return false;
   };
-  SmallVector<int64_t, 4> elementOffsets =
-      getVectorOffset(originalShape, targetShape, index);
   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
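
To make the new index arithmetic concrete, here is a self-contained sketch (not part of the commit) that mirrors DecomposeShapeIterator for the 4x4 vector with 2x2 native shape used by the tests below. With loopOrder = {1, 0} (dim1 as the outer loop), the slices are visited at element offsets (0,0), (2,0), (0,2), (2,2), exactly what the ORDER-prefixed FileCheck lines expect; the default order {0, 1} yields (0,0), (0,2), (2,0), (2,2) instead.

  #include <cstdint>
  #include <cstdio>
  #include <vector>

  int main() {
    std::vector<int64_t> originalShape{4, 4}, targetShape{2, 2};
    std::vector<int64_t> loopOrder{1, 0}; // "for loop order": dim1 outer, dim0 inner
    // Per-dimension slice counts, and the linearization strides induced by the
    // loop order (the fastest varying dimension gets stride 1).
    std::vector<int64_t> counts(2), strides(2);
    int64_t maxIndex = 1, accum = 1;
    for (int r = 0; r < 2; ++r) {
      counts[r] = originalShape[r] / targetShape[r]; // ceilDiv in the real code
      maxIndex *= counts[r];
    }
    for (auto it = loopOrder.rbegin(); it != loopOrder.rend(); ++it) {
      strides[*it] = accum;
      accum *= counts[*it];
    }
    // Delinearize each slice index and scale by the native shape.
    for (int64_t i = 0; i < maxIndex; ++i) {
      int64_t rem = i;
      int64_t off[2];
      for (int64_t idx : loopOrder) {
        off[idx] = (rem / strides[idx]) * targetShape[idx];
        rem %= strides[idx];
      }
      std::printf("slice %lld -> offsets (%lld, %lld)\n", (long long)i,
                  (long long)off[0], (long long)off[1]);
    }
    return 0;
  }
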
@@ -99,6 +162,20 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
   return targetShape;
 }

+static SmallVector<int64_t>
+getUnrollOrder(unsigned numLoops, Operation *op,
+               const vector::UnrollVectorOptions &options) {
+  SmallVector<int64_t> loopOrder =
+      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
+  if (options.traversalOrderCallback != nullptr) {
+    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
+    if (order.hasValue()) {
+      loopOrder = std::move(*order);
+    }
+  }
+  return loopOrder;
+}
+
 namespace {

 struct UnrollTransferReadPattern
@@ -121,27 +198,30 @@ struct UnrollTransferReadPattern
     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
     Location loc = readOp.getLoc();
     ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
+
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
     auto targetType =
         VectorType::get(*targetShape, sourceVectorType.getElementType());
     SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                           readOp.getIndices().end());
-    for (int64_t i = 0; i < sliceCount; i++) {
+
+    SmallVector<int64_t> loopOrder =
+        getUnrollOrder(originalSize.size(), readOp, options);
+    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+                                          loopOrder);
+    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
+      SmallVector<int64_t, 4> elementOffsets =
+          indexToOffsets.getVectorOffset(i);
       SmallVector<Value, 4> indices =
-          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+          sliceTransferIndices(elementOffsets, originalIndices,
                                readOp.getPermutationMap(), loc, rewriter);
       auto slicedRead = rewriter.create<vector::TransferReadOp>(
           loc, targetType, readOp.getSource(), indices,
           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
           readOp.getInBoundsAttr());

-      SmallVector<int64_t, 4> elementOffsets =
-          getVectorOffset(originalSize, *targetShape, i);
       result = rewriter.create<vector::InsertStridedSliceOp>(
           loc, slicedRead, result, elementOffsets, strides);
     }
@@ -174,20 +254,21 @@ struct UnrollTransferWritePattern
     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
     SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                           writeOp.getIndices().end());
+
+    SmallVector<int64_t> loopOrder =
+        getUnrollOrder(originalSize.size(), writeOp, options);
+    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+                                          loopOrder);
     Value resultTensor;
-    for (int64_t i = 0; i < sliceCount; i++) {
+    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
       SmallVector<int64_t, 4> elementOffsets =
-          getVectorOffset(originalSize, *targetShape, i);
+          indexToOffsets.getVectorOffset(i);
       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
-
       SmallVector<Value, 4> indices =
-          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+          sliceTransferIndices(elementOffsets, originalIndices,
                                writeOp.getPermutationMap(), loc, rewriter);
       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
           loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
@@ -236,20 +317,22 @@ struct UnrollContractionPattern
       return failure();
     auto dstVecType = contractOp.getResultType().cast<VectorType>();
     SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);

-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = contractOp.getLoc();
     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
     AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
     llvm::MapVector<
         SmallVector<int64_t>, Value,
         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
         accCache;
+
+    SmallVector<int64_t> loopOrder = getUnrollOrder(
+        contractOp.getIteratorTypes().size(), contractOp, options);
+    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+                                          loopOrder);
+    const int64_t sliceCount = indexToOffsets.maxIndex();
     for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> offsets =
-          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
       SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());

       // Helper to compute the new shape of each operand and extract the slice.

mlir/test/Dialect/Vector/vector-transfer-unroll.mlir

Lines changed: 46 additions & 0 deletions
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns=reverse-unroll-order --split-input-file | FileCheck %s --check-prefix=ORDER

 // CHECK-LABEL: func @transfer_read_unroll
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -13,6 +14,19 @@
 // CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
 // CHECK-NEXT: return %[[VEC3]] : vector<4x4xf32>

+// ORDER-LABEL: func @transfer_read_unroll
+// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index
+// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index
+// ORDER: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// ORDER-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// ORDER-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// ORDER-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// ORDER-NEXT: return %[[VEC3]] : vector<4x4xf32>
+
 func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -33,6 +47,19 @@ func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
 // CHECK-NEXT: vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
 // CHECK-NEXT: return

+// ORDER-LABEL: func @transfer_write_unroll
+// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index
+// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index
+// ORDER: %[[S0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// ORDER-NEXT: vector.transfer_write %[[S0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// ORDER-NEXT: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// ORDER-NEXT: vector.transfer_write %[[S1]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// ORDER-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// ORDER-NEXT: vector.transfer_write %[[S2]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// ORDER-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// ORDER-NEXT: vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// ORDER-NEXT: return
+
 func.func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
@@ -222,6 +249,25 @@ func.func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) ->
 // CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
 // CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
 // CHECK-NEXT: return %[[VEC5]] : vector<6x4xf32>
+
+// ORDER-LABEL: func @transfer_read_unroll_different_rank
+// ORDER-DAG: %[[C4:.*]] = arith.constant 4 : index
+// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index
+// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index
+// ORDER: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: return %[[VEC5]] : vector<6x4xf32>
+
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
 func.func @transfer_read_unroll_different_rank(%arg0 : memref<?x?x?xf32>) -> vector<6x4xf32> {
   %c0 = arith.constant 0 : index