Skip to content

Commit dd14e58

Browse files
committed
[mlir][vector] First step of vector distribution transformation
This is the first of several steps to support distributing large vectors. This adds instructions extract_map and insert_map that allow us to do incremental lowering. Right now the transformation only apply to simple pointwise operation with a vector size matching the multiplicity of the IDs used to distribute the vector. This can be used to distribute large vectors to loops or SPMD. Differential Revision: https://reviews.llvm.org/D88341
1 parent e9b3884 commit dd14e58

File tree

8 files changed

+305
-0
lines changed

8 files changed

+305
-0
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.td

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,71 @@ def Vector_ExtractSlicesOp :
454454
}];
455455
}
456456

457+
def Vector_ExtractMapOp :
458+
Vector_Op<"extract_map", [NoSideEffect]>,
459+
Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>,
460+
Results<(outs AnyVector)> {
461+
let summary = "vector extract map operation";
462+
let description = [{
463+
Takes an 1-D vector and extract a sub-part of the vector starting at id with
464+
a size of `vector size / multiplicity`. This maps a given multiplicity of
465+
the vector to a Value such as a loop induction variable or an SPMD id.
466+
467+
Similarly to vector.tuple_get, this operation is used for progressive
468+
lowering and should be folded away before converting to LLVM.
469+
470+
471+
For instance, the following code:
472+
```mlir
473+
%a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32>
474+
%b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32>
475+
%c = addf %a, %b: vector<32xf32>
476+
vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32>
477+
```
478+
can be rewritten to:
479+
```mlir
480+
%a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32>
481+
%b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32>
482+
%ea = vector.extract_map %a[%id : 32] : vector<32xf32> to vector<1xf32>
483+
%eb = vector.extract_map %b[%id : 32] : vector<32xf32> to vector<1xf32>
484+
%ec = addf %ea, %eb : vector<1xf32>
485+
%c = vector.insert_map %ec, %id, 32 : vector<1xf32> to vector<32xf32>
486+
vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32>
487+
```
488+
489+
Where %id can be an induction variable or an SPMD id going from 0 to 31.
490+
491+
And then be rewritten to:
492+
```mlir
493+
%a = vector.transfer_read %A[%id]: memref<32xf32>, vector<1xf32>
494+
%b = vector.transfer_read %B[%id]: memref<32xf32>, vector<1xf32>
495+
%c = addf %a, %b: vector<1xf32>
496+
vector.transfer_write %c, %C[%id]: memref<32xf32>, vector<1xf32>
497+
```
498+
499+
Example:
500+
501+
```mlir
502+
%ev = vector.extract_map %v[%id:32] : vector<32xf32> to vector<1xf32>
503+
```
504+
}];
505+
let builders = [OpBuilder<
506+
"OpBuilder &builder, OperationState &result, " #
507+
"Value vector, Value id, int64_t multiplicity">];
508+
let extraClassDeclaration = [{
509+
VectorType getSourceVectorType() {
510+
return vector().getType().cast<VectorType>();
511+
}
512+
VectorType getResultType() {
513+
return getResult().getType().cast<VectorType>();
514+
}
515+
}];
516+
let assemblyFormat = [{
517+
$vector `[` $id `:` $multiplicity `]` attr-dict `:` type($vector) `to`
518+
type(results)
519+
}];
520+
}
521+
457522
def Vector_FMAOp :
458523
Op<Vector_Dialect, "fma", [NoSideEffect,
459524
AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
@@ -626,6 +691,46 @@ def Vector_InsertSlicesOp :
626691
}];
627692
}
628693

694+
def Vector_InsertMapOp :
695+
Vector_Op<"insert_map", [NoSideEffect]>,
696+
Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>,
697+
Results<(outs AnyVector)> {
698+
let summary = "vector insert map operation";
699+
let description = [{
700+
insert an 1-D vector and within a larger vector starting at id. The new
701+
vector created will have a size of `vector size * multiplicity`. This
702+
represents how a sub-part of the vector is written for a given Value such as
703+
a loop induction variable or an SPMD id.
704+
705+
Similarly to vector.tuple_get, this operation is used for progressive
706+
lowering and should be folded away before converting to LLVM.
707+
708+
This operations is meant to be used in combination with vector.extract_map.
709+
See example in extract.map description.
710+
711+
Example:
712+
713+
```mlir
714+
%v = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32>
715+
```
716+
}];
717+
let builders = [OpBuilder<
718+
"OpBuilder &builder, OperationState &result, " #
719+
"Value vector, Value id, int64_t multiplicity">];
720+
let extraClassDeclaration = [{
721+
VectorType getSourceVectorType() {
722+
return vector().getType().cast<VectorType>();
723+
}
724+
VectorType getResultType() {
725+
return getResult().getType().cast<VectorType>();
726+
}
727+
}];
728+
let assemblyFormat = [{
729+
$vector `,` $id `,` $multiplicity attr-dict `:` type($vector) `to`
730+
type(results)
731+
}];
732+
}
733+
629734
def Vector_InsertStridedSliceOp :
630735
Vector_Op<"insert_strided_slice", [NoSideEffect,
631736
PredOpTrait<"operand #0 and result have same element type",

mlir/include/mlir/Dialect/Vector/VectorTransforms.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,47 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
172172
FilterConstraintType filter;
173173
};
174174

175+
struct DistributeOps {
176+
ExtractMapOp extract;
177+
InsertMapOp insert;
178+
};
179+
180+
/// Distribute a 1D vector pointwise operation over a range of given IDs taking
181+
/// *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable or
182+
/// SPMD id). This transformation only inserts
183+
/// vector.extract_map/vector.insert_map. It is meant to be used with
184+
/// canonicalizations pattern to propagate and fold the vector
185+
/// insert_map/extract_map operations.
186+
/// Transforms:
187+
// %v = addf %a, %b : vector<32xf32>
188+
/// to:
189+
/// %v = addf %a, %b : vector<32xf32> %ev =
190+
/// vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32> %nv =
191+
/// vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32>
192+
Optional<DistributeOps> distributPointwiseVectorOp(OpBuilder &builder,
193+
Operation *op, Value id,
194+
int64_t multiplicity);
195+
/// Canonicalize an extra element using the result of a pointwise operation.
196+
/// Transforms:
197+
/// %v = addf %a, %b : vector32xf32>
198+
/// %dv = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
199+
/// to:
200+
/// %da = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
201+
/// %db = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
202+
/// %dv = addf %da, %db : vector<1xf32>
203+
struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
204+
using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;
205+
PointwiseExtractPattern(
206+
MLIRContext *context, FilterConstraintType constraint =
207+
[](ExtractMapOp op) { return success(); })
208+
: OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}
209+
LogicalResult matchAndRewrite(ExtractMapOp extract,
210+
PatternRewriter &rewriter) const override;
211+
212+
private:
213+
FilterConstraintType filter;
214+
};
215+
175216
} // namespace vector
176217

177218
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,29 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
900900
populateFromInt64AttrArray(strides(), results);
901901
}
902902

903+
//===----------------------------------------------------------------------===//
904+
// ExtractMapOp
905+
//===----------------------------------------------------------------------===//
906+
907+
void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
908+
Value vector, Value id, int64_t multiplicity) {
909+
VectorType type = vector.getType().cast<VectorType>();
910+
VectorType resultType = VectorType::get(type.getNumElements() / multiplicity,
911+
type.getElementType());
912+
ExtractMapOp::build(builder, result, resultType, vector, id, multiplicity);
913+
}
914+
915+
static LogicalResult verify(ExtractMapOp op) {
916+
if (op.getSourceVectorType().getShape().size() != 1 ||
917+
op.getResultType().getShape().size() != 1)
918+
return op.emitOpError("expects source and destination vectors of rank 1");
919+
if (op.getResultType().getNumElements() * (int64_t)op.multiplicity() !=
920+
op.getSourceVectorType().getNumElements())
921+
return op.emitOpError("vector sizes mismatch. Source size must be equal "
922+
"to destination size * multiplicity");
923+
return success();
924+
}
925+
903926
//===----------------------------------------------------------------------===//
904927
// BroadcastOp
905928
//===----------------------------------------------------------------------===//
@@ -1122,6 +1145,30 @@ void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
11221145
populateFromInt64AttrArray(strides(), results);
11231146
}
11241147

1148+
//===----------------------------------------------------------------------===//
1149+
// InsertMapOp
1150+
//===----------------------------------------------------------------------===//
1151+
1152+
void InsertMapOp::build(OpBuilder &builder, OperationState &result,
1153+
Value vector, Value id, int64_t multiplicity) {
1154+
VectorType type = vector.getType().cast<VectorType>();
1155+
VectorType resultType = VectorType::get(type.getNumElements() * multiplicity,
1156+
type.getElementType());
1157+
InsertMapOp::build(builder, result, resultType, vector, id, multiplicity);
1158+
}
1159+
1160+
static LogicalResult verify(InsertMapOp op) {
1161+
if (op.getSourceVectorType().getShape().size() != 1 ||
1162+
op.getResultType().getShape().size() != 1)
1163+
return op.emitOpError("expected source and destination vectors of rank 1");
1164+
if ((int64_t)op.multiplicity() * op.getSourceVectorType().getNumElements() !=
1165+
op.getResultType().getNumElements())
1166+
return op.emitOpError(
1167+
"vector sizes mismatch. Destination size must be equal "
1168+
"to source size * multiplicity");
1169+
return success();
1170+
}
1171+
11251172
//===----------------------------------------------------------------------===//
11261173
// InsertStridedSliceOp
11271174
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2418,6 +2418,40 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
24182418
return failure();
24192419
}
24202420

2421+
LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
2422+
ExtractMapOp extract, PatternRewriter &rewriter) const {
2423+
Operation *definedOp = extract.vector().getDefiningOp();
2424+
if (!definedOp || definedOp->getNumResults() != 1)
2425+
return failure();
2426+
// TODO: Create an interfaceOp for elementwise operations.
2427+
if (!isa<AddFOp>(definedOp))
2428+
return failure();
2429+
Location loc = extract.getLoc();
2430+
SmallVector<Value, 4> extractOperands;
2431+
for (OpOperand &operand : definedOp->getOpOperands())
2432+
extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
2433+
loc, operand.get(), extract.id(), extract.multiplicity()));
2434+
Operation *newOp = cloneOpWithOperandsAndTypes(
2435+
rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
2436+
rewriter.replaceOp(extract, newOp->getResult(0));
2437+
return success();
2438+
}
2439+
2440+
Optional<mlir::vector::DistributeOps>
2441+
mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
2442+
Value id, int64_t multiplicity) {
2443+
OpBuilder::InsertionGuard guard(builder);
2444+
builder.setInsertionPointAfter(op);
2445+
Location loc = op->getLoc();
2446+
Value result = op->getResult(0);
2447+
DistributeOps ops;
2448+
ops.extract =
2449+
builder.create<vector::ExtractMapOp>(loc, result, id, multiplicity);
2450+
ops.insert =
2451+
builder.create<vector::InsertMapOp>(loc, ops.extract, id, multiplicity);
2452+
return ops;
2453+
}
2454+
24212455
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
24222456
// TODO: Add this as DRR pattern.
24232457
void mlir::vector::populateVectorToVectorTransformationPatterns(

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,3 +1328,31 @@ func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %va
13281328
// expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
13291329
vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
13301330
}
1331+
1332+
// -----
1333+
1334+
func @extract_map_rank(%v: vector<2x32xf32>, %id : index) {
1335+
// expected-error@+1 {{'vector.extract_map' op expects source and destination vectors of rank 1}}
1336+
%0 = vector.extract_map %v[%id : 32] : vector<2x32xf32> to vector<2x1xf32>
1337+
}
1338+
1339+
// -----
1340+
1341+
func @extract_map_size(%v: vector<63xf32>, %id : index) {
1342+
// expected-error@+1 {{'vector.extract_map' op vector sizes mismatch. Source size must be equal to destination size * multiplicity}}
1343+
%0 = vector.extract_map %v[%id : 32] : vector<63xf32> to vector<2xf32>
1344+
}
1345+
1346+
// -----
1347+
1348+
func @insert_map_rank(%v: vector<2x1xf32>, %id : index) {
1349+
// expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of rank 1}}
1350+
%0 = vector.insert_map %v, %id, 32 : vector<2x1xf32> to vector<2x32xf32>
1351+
}
1352+
1353+
// -----
1354+
1355+
func @insert_map_size(%v: vector<1xf32>, %id : index) {
1356+
// expected-error@+1 {{'vector.insert_map' op vector sizes mismatch. Destination size must be equal to source size * multiplicity}}
1357+
%0 = vector.insert_map %v, %id, 32 : vector<1xf32> to vector<64xf32>
1358+
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,3 +432,14 @@ func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru:
432432
vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
433433
return
434434
}
435+
436+
// CHECK-LABEL: @extract_insert_map
437+
func @extract_insert_map(%v: vector<32xf32>, %id : index) -> vector<32xf32> {
438+
// CHECK: %[[V:.*]] = vector.extract_map %{{.*}}[%{{.*}} : 16] : vector<32xf32> to vector<2xf32>
439+
%vd = vector.extract_map %v[%id : 16] : vector<32xf32> to vector<2xf32>
440+
// CHECK: %[[R:.*]] = vector.insert_map %[[V]], %{{.*}}, 16 : vector<2xf32> to vector<32xf32>
441+
%r = vector.insert_map %vd, %id, 16 : vector<2xf32> to vector<32xf32>
442+
// CHECK: return %[[R]] : vector<32xf32>
443+
return %r : vector<32xf32>
444+
}
445+
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt %s -test-vector-distribute-patterns | FileCheck %s
2+
3+
// CHECK-LABEL: func @distribute_vector_add
4+
// CHECK-SAME: (%[[ID:.*]]: index
5+
// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32>
6+
// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32>
7+
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
8+
// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ID]], 32 : vector<1xf32> to vector<32xf32>
9+
// CHECK-NEXT: return %[[INS]] : vector<32xf32>
10+
func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
11+
%0 = addf %A, %B : vector<32xf32>
12+
return %0: vector<32xf32>
13+
}

mlir/test/lib/Transforms/TestVectorTransforms.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,28 @@ struct TestVectorUnrollingPatterns
125125
}
126126
};
127127

128+
struct TestVectorDistributePatterns
129+
: public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
130+
void getDependentDialects(DialectRegistry &registry) const override {
131+
registry.insert<VectorDialect>();
132+
}
133+
void runOnFunction() override {
134+
MLIRContext *ctx = &getContext();
135+
OwningRewritePatternList patterns;
136+
FuncOp func = getFunction();
137+
func.walk([&](AddFOp op) {
138+
OpBuilder builder(op);
139+
Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
140+
builder, op.getOperation(), func.getArgument(0), 32);
141+
assert(ops.hasValue());
142+
SmallPtrSet<Operation *, 1> extractOp({ops->extract});
143+
op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
144+
});
145+
patterns.insert<PointwiseExtractPattern>(ctx);
146+
applyPatternsAndFoldGreedily(getFunction(), patterns);
147+
}
148+
};
149+
128150
struct TestVectorTransferFullPartialSplitPatterns
129151
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
130152
FunctionPass> {
@@ -178,5 +200,9 @@ void registerTestVectorConversions() {
178200
vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
179201
"Test conversion patterns to split "
180202
"transfer ops via scf.if + linalg ops");
203+
PassRegistration<TestVectorDistributePatterns> distributePass(
204+
"test-vector-distribute-patterns",
205+
"Test conversion patterns to distribute vector ops in the vector "
206+
"dialect");
181207
}
182208
} // namespace mlir

0 commit comments

Comments
 (0)