Skip to content

Commit 8d46bfa

Browse files
committed
[mlir] [VectorOps] A "reference" lowering of vector.transpose to LLVM IR
Summary: Makes the vector.tranpose runnable on CPU. Reviewers: nicolasvasilache, andydavis1, rriddle Reviewed By: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D76644
1 parent 42df3e2 commit 8d46bfa

File tree

4 files changed

+139
-48
lines changed

4 files changed

+139
-48
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
5353
/// Collect a set of transformation patterns that are related to contracting
5454
/// or expanding vector operations:
5555
/// ContractionOpLowering,
56-
/// ShapeCastOp2DDownCastRewritePattern, ShapeCastOp2DUpCastRewritePattern
56+
/// ShapeCastOp2DDownCastRewritePattern,
57+
/// ShapeCastOp2DUpCastRewritePattern
58+
/// TransposeOpLowering
5759
/// OuterproductOpLowering
5860
/// These transformation express higher level vector ops in terms of more
5961
/// elementary extraction, insertion, reduction, product, and broadcast ops.

mlir/include/mlir/Dialect/Vector/VectorOps.td

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def Vector_ContractionOp :
8888
iterator in the iterator type list, to each dimension of an N-D vector.
8989

9090
Examples:
91-
91+
```mlir
9292
// Simple dot product (K = 0).
9393
#contraction_accesses = [
9494
affine_map<(i) -> (i)>,
@@ -139,6 +139,7 @@ def Vector_ContractionOp :
139139

140140
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
141141
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
142+
```
142143
}];
143144
let builders = [OpBuilder<
144145
"Builder *builder, OperationState &result, Value lhs, Value rhs, "
@@ -203,7 +204,7 @@ def Vector_ReductionOp :
203204
http://llvm.org/docs/LangRef.html#experimental-vector-reduction-intrinsics
204205

205206
Examples:
206-
```
207+
```mlir
207208
%1 = vector.reduction "add", %0 : vector<16xf32> into f32
208209

209210
%3 = vector.reduction "xor", %2 : vector<4xi32> into i32
@@ -247,7 +248,7 @@ def Vector_BroadcastOp :
247248
shaped vector with the same element type is always legal.
248249

249250
Examples:
250-
```
251+
```mlir
251252
%0 = constant 0.0 : f32
252253
%1 = vector.broadcast %0 : f32 to vector<16xf32>
253254
%2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32>
@@ -290,7 +291,7 @@ def Vector_ShuffleOp :
290291
above, all mask values are in the range [0,s_1+t_1)
291292

292293
Examples:
293-
```
294+
```mlir
294295
%0 = vector.shuffle %a, %b[0, 3]
295296
: vector<2xf32>, vector<2xf32> ; yields vector<2xf32>
296297
%1 = vector.shuffle %c, %b[0, 1, 2]
@@ -332,7 +333,7 @@ def Vector_ExtractElementOp :
332333
https://llvm.org/docs/LangRef.html#extractelement-instruction
333334

334335
Example:
335-
```
336+
```mlir
336337
%c = constant 15 : i32
337338
%1 = vector.extractelement %0[%c : i32]: vector<16xf32>
338339
```
@@ -360,7 +361,7 @@ def Vector_ExtractOp :
360361
the proper position. Degenerates to an element type in the 0-D case.
361362

362363
Examples:
363-
```
364+
```mlir
364365
%1 = vector.extract %0[3]: vector<4x8x16xf32>
365366
%2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32>
366367
```
@@ -396,7 +397,7 @@ def Vector_ExtractSlicesOp :
396397
Currently, only unit strides are supported.
397398

398399
Examples:
399-
```
400+
```mlir
400401
%0 = vector.transfer_read ...: vector<4x2xf32>
401402

402403
%1 = vector.extract_slices %0, [2, 2], [1, 1]
@@ -448,8 +449,7 @@ def Vector_FMAOp :
448449
to the `llvm.fma.*` intrinsic.
449450

450451
Example:
451-
452-
```
452+
```mlir
453453
%3 = vector.fma %0, %1, %2: vector<8x16xf32>
454454
```
455455
}];
@@ -483,7 +483,7 @@ def Vector_InsertElementOp :
483483
https://llvm.org/docs/LangRef.html#insertelement-instruction
484484

485485
Example:
486-
```
486+
```mlir
487487
%c = constant 15 : i32
488488
%f = constant 0.0f : f32
489489
%1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
@@ -516,7 +516,7 @@ def Vector_InsertOp :
516516
position. Degenerates to a scalar source type when n = 0.
517517

518518
Examples:
519-
```
519+
```mlir
520520
%2 = vector.insert %0, %1[3]:
521521
vector<8x16xf32> into vector<4x8x16xf32>
522522
%5 = vector.insert %3, %4[3, 3, 3]:
@@ -559,7 +559,7 @@ def Vector_InsertSlicesOp :
559559
Currently, only unit strides are supported.
560560

561561
Examples:
562-
```
562+
```mlir
563563
%0 = vector.extract_slices %0, [2, 2], [1, 1]
564564
: vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
565565

@@ -617,7 +617,7 @@ def Vector_InsertStridedSliceOp :
617617
the proper location as specified by the offsets.
618618

619619
Examples:
620-
```
620+
```mlir
621621
%2 = vector.insert_strided_slice %0, %1
622622
{offsets = [0, 0, 2], strides = [1, 1]}:
623623
vector<2x4xf32> into vector<16x4x8xf32>
@@ -659,8 +659,7 @@ def Vector_OuterProductOp :
659659
lower to actual `fma` instructions on x86.
660660

661661
Examples:
662-
663-
```
662+
```mlir
664663
%2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32>
665664
return %2: vector<4x8xf32>
666665

@@ -709,8 +708,8 @@ def Vector_ReshapeOp :
709708
In the examples below, valid data elements are represented by an alphabetic
710709
character, and undefined data elements are represented by '-'.
711710

712-
Example
713-
711+
Example:
712+
```mlir
714713
vector<1x8xf32> with valid data shape [6], fixed vector sizes [8]
715714

716715
input: [a, b, c, d, e, f]
@@ -719,8 +718,9 @@ def Vector_ReshapeOp :
719718

720719
vector layout: [a, b, c, d, e, f, -, -]
721720

722-
Example
723-
721+
```
722+
Example:
723+
```mlir
724724
vector<2x8xf32> with valid data shape [10], fixed vector sizes [8]
725725

726726
input: [a, b, c, d, e, f, g, h, i, j]
@@ -729,9 +729,9 @@ def Vector_ReshapeOp :
729729

730730
vector layout: [[a, b, c, d, e, f, g, h],
731731
[i, j, -, -, -, -, -, -]]
732-
733-
Example
734-
732+
```
733+
Example:
734+
```mlir
735735
vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes
736736
[2, 3]
737737

@@ -750,9 +750,9 @@ def Vector_ReshapeOp :
750750
[-, -, -]]
751751
[[n, o, -],
752752
[-, -, -]]]]
753-
754-
Example
755-
753+
```
754+
Example:
755+
```mlir
756756
%1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4]
757757
: vector<3x2x4xf32> to vector<2x3x4xf32>
758758

@@ -776,6 +776,7 @@ def Vector_ReshapeOp :
776776
[[j, k, l, m],
777777
[n, o, p, q],
778778
[r, -, -, -]]]
779+
```
779780
}];
780781

781782
let extraClassDeclaration = [{
@@ -828,7 +829,7 @@ def Vector_StridedSliceOp :
828829
`offsets` and ending at `offsets + sizes`.
829830

830831
Examples:
831-
```
832+
```mlir
832833
%1 = vector.strided_slice %0
833834
{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
834835
vector<4x8x16xf32> to vector<2x4x16xf32>
@@ -947,13 +948,12 @@ def Vector_TransferReadOp :
947948
implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`.
948949

949950
Syntax
950-
```
951+
```mlir
951952
operation ::= ssa-id `=` `vector.transfer_read` ssa-use-list
952953
`{` attribute-entry `} :` memref-type `,` vector-type
953954
```
954955

955956
Examples:
956-
957957
```mlir
958958
// Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32>
959959
// and pad with %f0 to handle the boundary case:
@@ -1028,7 +1028,7 @@ def Vector_TransferWriteOp :
10281028

10291029
Syntax:
10301030

1031-
```
1031+
```mlir
10321032
operation ::= `vector.transfer_write` ssa-use-list `{` attribute-entry `} :
10331033
` vector-type ', ' memref-type '
10341034
```
@@ -1139,7 +1139,7 @@ def Vector_TypeCastOp :
11391139

11401140
Syntax:
11411141

1142-
```
1142+
```mlir
11431143
operation ::= `vector.type_cast` ssa-use : memref-type to memref-type
11441144
```
11451145

@@ -1183,8 +1183,10 @@ def Vector_ConstantMaskOp :
11831183
define a hyper-rectangular region within which elements values are set to 1
11841184
(otherwise element values are set to 0).
11851185

1186-
Example: create a constant vector mask of size 4x3xi1 with elements in range
1187-
0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
1186+
Example:
1187+
```
1188+
create a constant vector mask of size 4x3xi1 with elements in range
1189+
0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
11881190

11891191
%1 = vector.constant_mask [3, 2] : vector<4x3xi1>
11901192

@@ -1196,6 +1198,7 @@ def Vector_ConstantMaskOp :
11961198
rows 1 | 1 1 0
11971199
2 | 1 1 0
11981200
3 | 0 0 0
1201+
```
11991202
}];
12001203

12011204
let extraClassDeclaration = [{
@@ -1217,8 +1220,10 @@ def Vector_CreateMaskOp :
12171220
hyper-rectangular region within which elements values are set to 1
12181221
(otherwise element values are set to 0).
12191222

1220-
Example: create a vector mask of size 4x3xi1 where elements in range
1221-
0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
1223+
Example:
1224+
```
1225+
create a vector mask of size 4x3xi1 where elements in range
1226+
0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
12221227

12231228
%1 = vector.create_mask %c3, %c2 : vector<4x3xi1>
12241229

@@ -1230,6 +1235,7 @@ def Vector_CreateMaskOp :
12301235
rows 1 | 1 1 0
12311236
2 | 1 1 0
12321237
3 | 0 0 0
1238+
```
12331239
}];
12341240

12351241
let hasCanonicalizer = 1;
@@ -1248,9 +1254,8 @@ def Vector_TupleOp :
12481254
transformation and should be removed before lowering to lower-level
12491255
dialects.
12501256

1251-
12521257
Examples:
1253-
```
1258+
```mlir
12541259
%0 = vector.transfer_read ... : vector<2x2xf32>
12551260
%1 = vector.transfer_read ... : vector<2x1xf32>
12561261
%2 = vector.transfer_read ... : vector<2x2xf32>
@@ -1280,20 +1285,21 @@ def Vector_TransposeOp :
12801285
Takes a n-D vector and returns the transposed n-D vector defined by
12811286
the permutation of ranks in the n-sized integer array attribute.
12821287
In the operation
1283-
1284-
%1 = vector.tranpose %0, [i_1, .., i_n]
1285-
: vector<d_1 x .. x d_n x f32>
1286-
to vector<d_trans[0] x .. x d_trans[n-1] x f32>
1287-
1288+
```mlir
1289+
%1 = vector.tranpose %0, [i_1, .., i_n]
1290+
: vector<d_1 x .. x d_n x f32>
1291+
to vector<d_trans[0] x .. x d_trans[n-1] x f32>
1292+
```
12881293
the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
12891294

12901295
Example:
1291-
1296+
```mlir
12921297
%1 = vector.tranpose %0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
12931298

12941299
[ [a, b, c], [ [a, d],
12951300
[d, e, f] ] -> [b, e],
12961301
[c, f] ]
1302+
```
12971303
}];
12981304
let extraClassDeclaration = [{
12991305
VectorType getVectorType() {
@@ -1321,7 +1327,7 @@ def Vector_TupleGetOp :
13211327
dialects.
13221328

13231329
Examples:
1324-
```
1330+
```mlir
13251331
%4 = vector.tuple %0, %1, %2, %3
13261332
: vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32>>
13271333

@@ -1351,7 +1357,7 @@ def Vector_PrintOp :
13511357
format (for testing and debugging). No return value.
13521358

13531359
Examples:
1354-
```
1360+
```mlir
13551361
%0 = constant 0.0 : f32
13561362
%1 = vector.broadcast %0 : f32 to vector<4xf32>
13571363
vector.print %1 : vector<4xf32>
@@ -1414,7 +1420,7 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
14141420

14151421
Example:
14161422

1417-
```
1423+
```mlir
14181424
%C = vector.matrix_multiply %A, %B
14191425
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
14201426
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>

0 commit comments

Comments
 (0)