Skip to content

Commit bc5565f

Browse files
committed
[mlir][Affine] Introduce affine.vector_load and affine.vector_store
This patch adds `affine.vector_load` and `affine.vector_store` ops to the Affine dialect and lowers them to `vector.transfer_read` and `vector.transfer_write`, respectively, in the Vector dialect. Reviewed By: bondhugula, nicolasvasilache Differential Revision: https://reviews.llvm.org/D79658
1 parent accc6b5 commit bc5565f

File tree

9 files changed

+670
-107
lines changed

9 files changed

+670
-107
lines changed

mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ Optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
4444
void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
4545
MLIRContext *ctx);
4646

47+
/// Collect a set of patterns to convert vector-related Affine ops to the Vector
48+
/// dialect.
49+
void populateAffineToVectorConversionPatterns(
50+
OwningRewritePatternList &patterns, MLIRContext *ctx);
51+
4752
/// Emit code that computes the lower bound of the given affine loop using
4853
/// standard arithmetic operations.
4954
Value lowerAffineLowerBound(AffineForOp op, OpBuilder &builder);

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 182 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,44 @@ def AffineIfOp : Affine_Op<"if",
370370
let hasFolder = 1;
371371
}
372372

373-
def AffineLoadOp : Affine_Op<"load", []> {
373+
class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
374+
Affine_Op<mnemonic, traits> {
375+
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
376+
[MemRead]>:$memref,
377+
Variadic<Index>:$indices);
378+
379+
code extraClassDeclarationBase = [{
380+
/// Returns the operand index of the memref.
381+
unsigned getMemRefOperandIndex() { return 0; }
382+
383+
/// Get memref operand.
384+
Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
385+
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
386+
MemRefType getMemRefType() {
387+
return getMemRef().getType().cast<MemRefType>();
388+
}
389+
390+
/// Get affine map operands.
391+
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
392+
393+
/// Returns the affine map used to index the memref for this operation.
394+
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
395+
AffineMapAttr getAffineMapAttr() {
396+
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
397+
}
398+
399+
/// Returns the AffineMapAttr associated with 'memref'.
400+
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
401+
assert(memref == getMemRef());
402+
return {Identifier::get(getMapAttrName(), getContext()),
403+
getAffineMapAttr()};
404+
}
405+
406+
static StringRef getMapAttrName() { return "map"; }
407+
}];
408+
}
409+
410+
def AffineLoadOp : AffineLoadOpBase<"load", []> {
374411
let summary = "affine load operation";
375412
let description = [{
376413
The "affine.load" op reads an element from a memref, where the index
@@ -393,9 +430,6 @@ def AffineLoadOp : Affine_Op<"load", []> {
393430
```
394431
}];
395432

396-
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
397-
[MemRead]>:$memref,
398-
Variadic<Index>:$indices);
399433
let results = (outs AnyType:$result);
400434

401435
let builders = [
@@ -410,35 +444,7 @@ def AffineLoadOp : Affine_Op<"load", []> {
410444
"AffineMap map, ValueRange mapOperands">
411445
];
412446

413-
let extraClassDeclaration = [{
414-
/// Returns the operand index of the memref.
415-
unsigned getMemRefOperandIndex() { return 0; }
416-
417-
/// Get memref operand.
418-
Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
419-
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
420-
MemRefType getMemRefType() {
421-
return getMemRef().getType().cast<MemRefType>();
422-
}
423-
424-
/// Get affine map operands.
425-
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
426-
427-
/// Returns the affine map used to index the memref for this operation.
428-
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
429-
AffineMapAttr getAffineMapAttr() {
430-
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
431-
}
432-
433-
/// Returns the AffineMapAttr associated with 'memref'.
434-
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
435-
assert(memref == getMemRef());
436-
return {Identifier::get(getMapAttrName(), getContext()),
437-
getAffineMapAttr()};
438-
}
439-
440-
static StringRef getMapAttrName() { return "map"; }
441-
}];
447+
let extraClassDeclaration = extraClassDeclarationBase;
442448

443449
let hasCanonicalizer = 1;
444450
let hasFolder = 1;
@@ -659,7 +665,45 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
659665
let hasFolder = 1;
660666
}
661667

662-
def AffineStoreOp : Affine_Op<"store", []> {
668+
class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
669+
Affine_Op<mnemonic, traits> {
670+
671+
code extraClassDeclarationBase = [{
672+
/// Get value to be stored by store operation.
673+
Value getValueToStore() { return getOperand(0); }
674+
675+
/// Returns the operand index of the memref.
676+
unsigned getMemRefOperandIndex() { return 1; }
677+
678+
/// Get memref operand.
679+
Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
680+
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
681+
682+
MemRefType getMemRefType() {
683+
return getMemRef().getType().cast<MemRefType>();
684+
}
685+
686+
/// Get affine map operands.
687+
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
688+
689+
/// Returns the affine map used to index the memref for this operation.
690+
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
691+
AffineMapAttr getAffineMapAttr() {
692+
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
693+
}
694+
695+
/// Returns the AffineMapAttr associated with 'memref'.
696+
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
697+
assert(memref == getMemRef());
698+
return {Identifier::get(getMapAttrName(), getContext()),
699+
getAffineMapAttr()};
700+
}
701+
702+
static StringRef getMapAttrName() { return "map"; }
703+
}];
704+
}
705+
706+
def AffineStoreOp : AffineStoreOpBase<"store", []> {
663707
let summary = "affine store operation";
664708
let description = [{
665709
The "affine.store" op writes an element to a memref, where the index
@@ -686,7 +730,6 @@ def AffineStoreOp : Affine_Op<"store", []> {
686730
[MemWrite]>:$memref,
687731
Variadic<Index>:$indices);
688732

689-
690733
let skipDefaultBuilders = 1;
691734
let builders = [
692735
OpBuilder<"OpBuilder &builder, OperationState &result, "
@@ -696,39 +739,7 @@ def AffineStoreOp : Affine_Op<"store", []> {
696739
"ValueRange mapOperands">
697740
];
698741

699-
let extraClassDeclaration = [{
700-
/// Get value to be stored by store operation.
701-
Value getValueToStore() { return getOperand(0); }
702-
703-
/// Returns the operand index of the memref.
704-
unsigned getMemRefOperandIndex() { return 1; }
705-
706-
/// Get memref operand.
707-
Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
708-
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
709-
710-
MemRefType getMemRefType() {
711-
return getMemRef().getType().cast<MemRefType>();
712-
}
713-
714-
/// Get affine map operands.
715-
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
716-
717-
/// Returns the affine map used to index the memref for this operation.
718-
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
719-
AffineMapAttr getAffineMapAttr() {
720-
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
721-
}
722-
723-
/// Returns the AffineMapAttr associated with 'memref'.
724-
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
725-
assert(memref == getMemRef());
726-
return {Identifier::get(getMapAttrName(), getContext()),
727-
getAffineMapAttr()};
728-
}
729-
730-
static StringRef getMapAttrName() { return "map"; }
731-
}];
742+
let extraClassDeclaration = extraClassDeclarationBase;
732743

733744
let hasCanonicalizer = 1;
734745
let hasFolder = 1;
@@ -765,4 +776,107 @@ def AffineTerminatorOp :
765776
let verifier = ?;
766777
}
767778

779+
def AffineVectorLoadOp : AffineLoadOpBase<"vector_load", []> {
780+
let summary = "affine vector load operation";
781+
let description = [{
782+
The "affine.vector_load" is the vector counterpart of
783+
[affine.load](#affineload-operation). It reads a slice from a
784+
[MemRef](../LangRef.md#memref-type), supplied as its first operand,
785+
into a [vector](../LangRef.md#vector-type) of the same base elemental type.
786+
The index for each memref dimension is an affine expression of loop induction
787+
variables and symbols. These indices determine the start position of the read
788+
within the memref. The shape of the return vector type determines the shape of
789+
the slice read from the memref. This slice is contiguous along the respective
790+
dimensions of the shape. Strided vector loads will be supported in the future.
791+
An affine expression of loop IVs and symbols must be specified for each
792+
dimension of the memref. The keyword 'symbol' can be used to indicate SSA
793+
identifiers which are symbolic.
794+
795+
Example 1: 8-wide f32 vector load.
796+
797+
```mlir
798+
%1 = affine.vector_load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32>
799+
```
800+
801+
Example 2: 4-wide f32 vector load. Uses 'symbol' keyword for symbols '%n' and '%m'.
802+
803+
```mlir
804+
%1 = affine.vector_load %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>, vector<4xf32>
805+
```
806+
807+
Example 3: 2-dim f32 vector load.
808+
809+
```mlir
810+
%1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
811+
```
812+
813+
TODOs:
814+
* Add support for strided vector loads.
815+
* Consider adding a permutation map to permute the slice that is read from memory
816+
(see [vector.transfer_read](../Vector/#vectortransfer_read-vectortransferreadop)).
817+
}];
818+
819+
let results = (outs AnyVector:$result);
820+
821+
let extraClassDeclaration = extraClassDeclarationBase # [{
822+
VectorType getVectorType() {
823+
return result().getType().cast<VectorType>();
824+
}
825+
}];
826+
}
827+
828+
def AffineVectorStoreOp : AffineStoreOpBase<"vector_store", []> {
829+
let summary = "affine vector store operation";
830+
let description = [{
831+
The "affine.vector_store" is the vector counterpart of
832+
[affine.store](#affinestore-affinestoreop). It writes a
833+
[vector](../LangRef.md#vector-type), supplied as its first operand,
834+
into a slice within a [MemRef](../LangRef.md#memref-type) of the same base
835+
elemental type, supplied as its second operand.
836+
The index for each memref dimension is an affine expression of loop
837+
induction variables and symbols. These indices determine the start position
838+
of the write within the memref. The shape of th input vector determines the
839+
shape of the slice written to the memref. This slice is contiguous along the
840+
respective dimensions of the shape. Strided vector stores will be supported
841+
in the future.
842+
An affine expression of loop IVs and symbols must be specified for each
843+
dimension of the memref. The keyword 'symbol' can be used to indicate SSA
844+
identifiers which are symbolic.
845+
846+
Example 1: 8-wide f32 vector store.
847+
848+
```mlir
849+
affine.vector_store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32>
850+
```
851+
852+
Example 2: 4-wide f32 vector store. Uses 'symbol' keyword for symbols '%n' and '%m'.
853+
854+
```mlir
855+
affine.vector_store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>, vector<4xf32>
856+
```
857+
858+
Example 3: 2-dim f32 vector store.
859+
860+
```mlir
861+
affine.vector_store %v0, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
862+
```
863+
864+
TODOs:
865+
* Add support for strided vector stores.
866+
* Consider adding a permutation map to permute the slice that is written to memory
867+
(see [vector.transfer_write](../Vector/#vectortransfer_write-vectortransferwriteop)).
868+
}];
869+
870+
let arguments = (ins AnyVector:$value,
871+
Arg<AnyMemRef, "the reference to store to",
872+
[MemWrite]>:$memref,
873+
Variadic<Index>:$indices);
874+
875+
let extraClassDeclaration = extraClassDeclarationBase # [{
876+
VectorType getVectorType() {
877+
return value().getType().cast<VectorType>();
878+
}
879+
}];
880+
}
881+
768882
#endif // AFFINE_OPS

mlir/include/mlir/Dialect/Vector/VectorOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,13 @@ def Vector_TransferReadOp :
994994
```
995995
}];
996996

997+
let builders = [
998+
// Builder that sets permutation map and padding to 'getMinorIdentityMap'
999+
// and zero, respectively, by default.
1000+
OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
1001+
"Value memref, ValueRange indices">
1002+
];
1003+
9971004
let extraClassDeclaration = [{
9981005
MemRefType getMemRefType() {
9991006
return memref().getType().cast<MemRefType>();
@@ -1058,6 +1065,13 @@ def Vector_TransferWriteOp :
10581065
```
10591066
}];
10601067

1068+
let builders = [
1069+
// Builder that sets permutation map and padding to 'getMinorIdentityMap'
1070+
// by default.
1071+
OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
1072+
"Value memref, ValueRange indices">
1073+
];
1074+
10611075
let extraClassDeclaration = [{
10621076
VectorType getVectorType() {
10631077
return vector().getType().cast<VectorType>();

0 commit comments

Comments
 (0)