@@ -1246,8 +1246,9 @@ def Vector_TransferReadOp :
```
This operation always reads a slice starting at `%A[%expr1, %expr2, %expr3,
- %expr4]`. The size of the slice is 3 along d2 and 5 along d0, so the slice
- is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`
+ %expr4]`. The size of the slice can be inferred from the resulting vector
+ shape by walking back through the permutation map: 3 along d2 and 5 along
+ d0, so the slice is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`

That slice needs to be read into a `vector<3x4x5xf32>`. Since the
permutation map is not full rank, there must be a broadcast along vector
@@ -1257,44 +1258,52 @@ def Vector_TransferReadOp :
```mlir
// %expr1, %expr2, %expr3, %expr4 defined before this point
- %tmp = alloc() : vector<3x4x5xf32>
- %view_in_tmp = "element_type_cast"(%tmp ) : memref<1xvector <3x4x5xf32>>
+ // Alloc a temporary buffer for performing the "gather" of the slice.
+ %tmp = memref.alloc() : memref<vector<3x4x5xf32>>
for %i = 0 to 3 {
  affine.for %j = 0 to 4 {
    affine.for %k = 0 to 5 {
-     %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] :
-       memref<?x?x?x?xf32>
-     store %tmp[%i, %j, %k] : vector<3x4x5xf32>
+     // Note that this load does not involve %j.
+     %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
+     // Update the temporary gathered slice with the individual element.
+     %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
+     %updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32>
+     memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
}}}
- %c0 = arith.constant 0 : index
- %vec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
+ // At this point we gathered the elements from the original
+ // memref into the desired vector layout, stored in the `%tmp` allocation.
+ %vec = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
```

On a GPU one could then map `i`, `j`, `k` to blocks and threads. Notice that
- the temporary storage footprint is `3 * 5` values but `3 * 4 * 5` values are
- actually transferred between `%A` and `%tmp`.
+ the temporary storage footprint could conceptually be only `3 * 5` values but
+ `3 * 4 * 5` values are actually transferred between `%A` and `%tmp`.
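The traffic claim above can be checked with a short Python count (an editorial sketch mirroring the loop bounds of the example; the tuple `(k, i)` stands in for the address `%A[%expr1 + %k, %expr2, %expr3 + %i, %expr4]`, which does not involve `%j`):

```python
# Count loads versus distinct addresses in the 3 x 4 x 5 gather loop nest.
addresses = set()
loads = 0
for i in range(3):
    for j in range(4):
        for k in range(5):
            addresses.add((k, i))  # the load address ignores j
            loads += 1

print(loads, len(addresses))  # 60 loads touch only 15 distinct elements
```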

- Alternatively, if a notional vector broadcast operation were available, the
- lowered code would resemble:
+ Alternatively, if a notional vector broadcast operation were available, we
+ could avoid the loop on `%j` and the lowered code would resemble:

```mlir
// %expr1, %expr2, %expr3, %expr4 defined before this point
- %tmp = alloc() : vector<3x4x5xf32>
- %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
+ %tmp = memref.alloc() : memref<vector<3x4x5xf32>>
for %i = 0 to 3 {
  affine.for %k = 0 to 5 {
-   %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] :
-     memref<?x?x?x?xf32>
-   store %tmp[%i, 0, %k] : vector<3x4x5xf32>
+   %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
+   %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
+   // Here we only store to position 0 along dimension 1.
+   %updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32>
+   memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
}}
- %c0 = arith.constant 0 : index
- %tmpvec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
+ // At this point we gathered the elements from the original
+ // memref into the desired vector layout, stored in the `%tmp` allocation.
+ // However we haven't replicated them along dimension 1 yet; we need
+ // to broadcast now.
+ %tmpvec = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
%vec = broadcast %tmpvec, 1 : vector<3x4x5xf32>
```

where `broadcast` broadcasts from element 0 to all others along the
- specified dimension. This time, the temporary storage footprint is `3 * 5`
- values which is the same amount of data as the `3 * 5` values transferred .
+ specified dimension. This time, only `3 * 5` elements are loaded from
+ `%A`.
An additional `1` broadcast is required. On a GPU this broadcast could be
implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`.
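To make the notional `broadcast` concrete, here is a small pure-Python model (an illustrative sketch, not the specified MLIR semantics; `broadcast_dim1` is a hypothetical helper): broadcasting along dimension 1 replicates the slice at index 0 of that dimension:

```python
def broadcast_dim1(v):
    # v is a nested list shaped 3x4x5; copy v[i][0] across dimension 1.
    return [[list(plane[0]) for _ in plane] for plane in v]

# Fill only j == 0, as the gather loop above does, then broadcast.
v = [[[i * 100 + k if j == 0 else 0.0 for k in range(5)]
      for j in range(4)] for i in range(3)]
b = broadcast_dim1(v)
# Every j-slice of b now equals the corresponding j == 0 slice.
```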
@@ -1310,7 +1319,7 @@ def Vector_TransferReadOp :
// Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32>
// and pad with %f0 to handle the boundary case:
%f0 = arith.constant 0.0f : f32
- for %i0 = 0 to %0 {
+ affine.for %i0 = 0 to %0 {
  affine.for %i1 = 0 to %1 step 256 {
    affine.for %i2 = 0 to %2 step 32 {
      %v = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
@@ -1320,7 +1329,7 @@ def Vector_TransferReadOp :

// or equivalently (rewrite with vector.transpose)
%f0 = arith.constant 0.0f : f32
- for %i0 = 0 to %0 {
+ affine.for %i0 = 0 to %0 {
  affine.for %i1 = 0 to %1 step 256 {
    affine.for %i2 = 0 to %2 step 32 {
      %v0 = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
@@ -1333,7 +1342,7 @@ def Vector_TransferReadOp :
// Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into
// vector<128xf32>. The underlying implementation will require a 1-D vector
// broadcast:
- for %i0 = 0 to %0 {
+ affine.for %i0 = 0 to %0 {
  affine.for %i1 = 0 to %1 {
    %3 = vector.transfer_read %A[%i0, %i1]
      {permutation_map: (d0, d1) -> (0)} :