Commit 8995586

[MLIR][doc] Improve/fix the doc on mlir.vector.transfer_read (NFC)
This doc was written 4 years ago, some refresh in the example was overdue I suspect.

Differential Revision: https://reviews.llvm.org/D151037
1 parent 550c60e commit 8995586

1 file changed, +34 -25 lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 34 additions & 25 deletions
@@ -1246,8 +1246,9 @@ def Vector_TransferReadOp :
 ```
 
 This operation always reads a slice starting at `%A[%expr1, %expr2, %expr3,
-%expr4]`. The size of the slice is 3 along d2 and 5 along d0, so the slice
-is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`
+%expr4]`. The size of the slice can be inferred from the resulting vector
+shape and walking back through the permutation map: 3 along d2 and 5 along
+d0, so the slice is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`
 
 That slice needs to be read into a `vector<3x4x5xf32>`. Since the
 permutation map is not full rank, there must be a broadcast along vector
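
The "walking back through the permutation map" step in this hunk can be sketched in a few lines of Python. This is only an illustrative model, not MLIR code: the helper name `slice_extents` is hypothetical, and the map `(d0, d1, d2, d3) -> (d2, 0, d0)` is assumed from the sizes quoted in the text (3 along d2, 5 along d0), since the hunk itself does not show it. A `None` entry stands for the constant `0` result, i.e. a broadcast vector dimension.

```python
def slice_extents(memref_rank, perm_results, vector_shape):
    """Infer the per-dimension slice size read from the memref.

    perm_results[v] is the memref dimension feeding vector dimension v,
    or None when the map result is the constant 0 (a broadcast).
    Hypothetical helper, for illustration only.
    """
    extents = [1] * memref_rank          # unmapped memref dims read one element
    broadcast_dims = []
    for vec_dim, (src, size) in enumerate(zip(perm_results, vector_shape)):
        if src is None:
            broadcast_dims.append(vec_dim)   # not backed by any memref dim
        else:
            extents[src] = size              # vector dim size becomes slice size
    return extents, broadcast_dims


# Assumed map (d0, d1, d2, d3) -> (d2, 0, d0) with vector<3x4x5xf32>.
extents, broadcast_dims = slice_extents(4, [2, None, 0], [3, 4, 5])
print(extents)         # [5, 1, 3, 1]
print(broadcast_dims)  # [1]
```

The extents `[5, 1, 3, 1]` match the slice `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`, and the single broadcast dimension is the one of size 4.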
@@ -1257,44 +1258,52 @@ def Vector_TransferReadOp :
 
 ```mlir
 // %expr1, %expr2, %expr3, %expr4 defined before this point
-%tmp = alloc() : vector<3x4x5xf32>
-%view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
+// alloc a temporary buffer for performing the "gather" of the slice.
+%tmp = memref.alloc() : memref<vector<3x4x5xf32>>
 for %i = 0 to 3 {
 affine.for %j = 0 to 4 {
 affine.for %k = 0 to 5 {
-%a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] :
-memref<?x?x?x?xf32>
-store %tmp[%i, %j, %k] : vector<3x4x5xf32>
+// Note that this load does not involve %j.
+%a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
+// Update the temporary gathered slice with the individual element
+%slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
+%updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32>
+memref.store %updated, %temp : memref<vector<3x4x5xf32>>
 }}}
-%c0 = arith.constant 0 : index
-%vec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
+// At this point we gathered the elements from the original
+// memref into the desired vector layout, stored in the `%tmp` allocation.
+%vec = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
 ```
 
 On a GPU one could then map `i`, `j`, `k` to blocks and threads. Notice that
-the temporary storage footprint is `3 * 5` values but `3 * 4 * 5` values are
-actually transferred between `%A` and `%tmp`.
+the temporary storage footprint could conceptually be only `3 * 5` values but
+`3 * 4 * 5` values are actually transferred between `%A` and `%tmp`.
 
-Alternatively, if a notional vector broadcast operation were available, the
-lowered code would resemble:
+Alternatively, if a notional vector broadcast operation were available, we
+could avoid the loop on `%j` and the lowered code would resemble:
 
 ```mlir
 // %expr1, %expr2, %expr3, %expr4 defined before this point
-%tmp = alloc() : vector<3x4x5xf32>
-%view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
+%tmp = memref.alloc() : memref<vector<3x4x5xf32>>
 for %i = 0 to 3 {
 affine.for %k = 0 to 5 {
-%a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] :
-memref<?x?x?x?xf32>
-store %tmp[%i, 0, %k] : vector<3x4x5xf32>
+%a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
+%slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
+// Here we only store to the first element in dimension one
+%updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32>
+memref.store %updated, %temp : memref<vector<3x4x5xf32>>
 }}
-%c0 = arith.constant 0 : index
-%tmpvec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
+// At this point we gathered the elements from the original
+// memref into the desired vector layout, stored in the `%tmp` allocation.
+// However we haven't replicated them alongside the first dimension, we need
+// to broadcast now.
+%partialVec = load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
 %vec = broadcast %tmpvec, 1 : vector<3x4x5xf32>
 ```
 
 where `broadcast` broadcasts from element 0 to all others along the
-specified dimension. This time, the temporary storage footprint is `3 * 5`
-values which is the same amount of data as the `3 * 5` values transferred.
+specified dimension. This time, the number of loaded element is `3 * 5`
+values.
 An additional `1` broadcast is required. On a GPU this broadcast could be
 implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`.
 
@@ -1310,7 +1319,7 @@ def Vector_TransferReadOp :
 // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32>
 // and pad with %f0 to handle the boundary case:
 %f0 = arith.constant 0.0f : f32
-for %i0 = 0 to %0 {
+affine.for %i0 = 0 to %0 {
 affine.for %i1 = 0 to %1 step 256 {
 affine.for %i2 = 0 to %2 step 32 {
 %v = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
@@ -1320,7 +1329,7 @@ def Vector_TransferReadOp :
 
 // or equivalently (rewrite with vector.transpose)
 %f0 = arith.constant 0.0f : f32
-for %i0 = 0 to %0 {
+affine.for %i0 = 0 to %0 {
 affine.for %i1 = 0 to %1 step 256 {
 affine.for %i2 = 0 to %2 step 32 {
 %v0 = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
@@ -1333,7 +1342,7 @@ def Vector_TransferReadOp :
 // Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into
 // vector<128xf32>. The underlying implementation will require a 1-D vector
 // broadcast:
-for %i0 = 0 to %0 {
+affine.for %i0 = 0 to %0 {
 affine.for %i1 = 0 to %1 {
 %3 = vector.transfer_read %A[%i0, %i1]
 {permutation_map: (d0, d1) -> (0)} :
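
The padded read and the 1-D broadcast read touched by the last three hunks can be modeled in plain Python as follows. This is a hedged sketch, not the op's implementation: the function names are hypothetical, memrefs are nested lists, and the mapping of vector dimension 0 to d2 and vector dimension 1 to d1 is assumed from the comment "`%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32>", because the hunk does not show that permutation map. The sizes are shrunk so the example stays readable.

```python
def transfer_read_2d(A, i0, i1, i2, rows, cols, pad):
    """Read A[i0, i1:i1+cols, i2:i2+rows] into a rows x cols vector.

    Vector dim 0 walks memref d2 and vector dim 1 walks memref d1
    (assumed map). Out-of-bounds positions take the padding value.
    """
    n1 = len(A[i0])       # size of memref dimension d1
    n2 = len(A[i0][0])    # size of memref dimension d2
    out = []
    for r in range(rows):
        row = []
        for c in range(cols):
            a, b = i1 + c, i2 + r
            row.append(A[i0][a][b] if a < n1 and b < n2 else pad)
        out.append(row)
    return out


def transfer_read_broadcast_1d(A, i0, i1, length):
    """Map (d0, d1) -> (0): one element replicated along the vector."""
    return [A[i0][i1]] * length


# A 2x4x3 memref whose values encode their indices.
A = [[[float(100 * x + 10 * y + z) for z in range(3)]
      for y in range(4)]
     for x in range(2)]
v = transfer_read_2d(A, 1, 2, 1, rows=4, cols=3, pad=0.0)
for row in v:
    print(row)

B = [[1.0, 2.0], [3.0, 4.0]]
print(transfer_read_broadcast_1d(B, 1, 0, 4))
```

In this model the boundary case only replaces the elements that fall outside the memref; in-bounds elements of the same vector are read normally.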
