@@ -1209,3 +1209,110 @@ func.func @hoist_linalg_ops_div_by_zero(%a : tensor<128x128xi32>,
1209
1209
1210
1210
func.return %final : tensor <?x128 xi32 >
1211
1211
}
1212
+
1213
+ // -----
1214
+
1215
+ // CHECK-LABEL: func @hoist_vector_transfer_ops
1216
+ // CHECK: vector.transfer_read
1217
+ // CHECK: scf.for
1218
+ // CHECK-NOT: vector.transfer_read
1219
+ // CHECK: arith.addf
1220
+ // CHECK: scf.yield
1221
+ func.func @hoist_vector_transfer_ops (
1222
+ %a : tensor <128 x128 xf32 >,
1223
+ %lb : index ,
1224
+ %ub : index ,
1225
+ %step : index ,
1226
+ %ida : index ,
1227
+ %idb : index ) -> vector <4 x4 xf32 > {
1228
+ %cst_0 = arith.constant 0.0 : f32
1229
+ %cst = arith.constant dense <0.0 > : vector <4 x4 xf32 >
1230
+ %final =
1231
+ scf.for %i = %lb to %ub step %step iter_args (%acc = %cst ) -> vector <4 x4 xf32 > {
1232
+ %read = vector.transfer_read %a [%ida , %idb ], %cst_0 : tensor <128 x128 xf32 >, vector <4 x4 xf32 >
1233
+ %out = arith.addf %read , %acc : vector <4 x4 xf32 >
1234
+ scf.yield %out : vector <4 x4 xf32 >
1235
+ }
1236
+ func.return %final : vector <4 x4 xf32 >
1237
+ }
1238
+
1239
+ // -----
1240
+
1241
+ // CHECK-LABEL: func @hoist_vector_transfer_ops
1242
+ // CHECK: vector.transfer_write
1243
+ // CHECK: vector.transfer_read
1244
+ // CHECK: scf.for
1245
+ // CHECK-NOT: vector.transfer_write
1246
+ // CHECK-NOT: vector.transfer_read
1247
+ // CHECK: arith.addf
1248
+ // CHECK: scf.yield
1249
+ func.func @hoist_vector_transfer_ops (
1250
+ %lb : index ,
1251
+ %ub : index ,
1252
+ %step : index ,
1253
+ %ida : index ,
1254
+ %idb : index ) -> vector <4 x4 xf32 > {
1255
+ %c0 = arith.constant 0 : index
1256
+ %cst_0 = arith.constant 0.0 : f32
1257
+ %cst = arith.constant dense <0.0 > : vector <4 x4 xf32 >
1258
+ %empty = tensor.empty () : tensor <4 x4 xf32 >
1259
+ %final =
1260
+ scf.for %i = %lb to %ub step %step iter_args (%acc = %cst ) -> vector <4 x4 xf32 > {
1261
+ %a = vector.transfer_write %cst , %empty [%c0 , %c0 ] : vector <4 x4 xf32 >, tensor <4 x4 xf32 >
1262
+ %read = vector.transfer_read %a [%c0 , %c0 ], %cst_0 : tensor <4 x4 xf32 >, vector <4 x4 xf32 >
1263
+ %out = arith.addf %read , %acc : vector <4 x4 xf32 >
1264
+ scf.yield %out : vector <4 x4 xf32 >
1265
+ }
1266
+ func.return %final : vector <4 x4 xf32 >
1267
+ }
1268
+
1269
+ // -----
1270
+
1271
+ // CHECK-LABEL: func @do_not_hoist_vector_transfer_ops_loop_dep
1272
+ // CHECK-NOT: vector.transfer_read
1273
+ // CHECK: scf.for
1274
+ // CHECK: vector.transfer_read
1275
+ // CHECK: arith.addf
1276
+ // CHECK: scf.yield
1277
+ func.func @do_not_hoist_vector_transfer_ops_loop_dep (
1278
+ %a : tensor <128 x128 xf32 >,
1279
+ %lb : index ,
1280
+ %ub : index ,
1281
+ %step : index ,
1282
+ %ida : index ) -> vector <4 x4 xf32 > {
1283
+ %cst_0 = arith.constant 0.0 : f32
1284
+ %cst = arith.constant dense <0.0 > : vector <4 x4 xf32 >
1285
+ %final =
1286
+ scf.for %i = %lb to %ub step %step iter_args (%acc = %cst ) -> vector <4 x4 xf32 > {
1287
+ %read = vector.transfer_read %a [%ida , %i ], %cst_0 : tensor <128 x128 xf32 >, vector <4 x4 xf32 >
1288
+ %out = arith.addf %read , %acc : vector <4 x4 xf32 >
1289
+ scf.yield %out : vector <4 x4 xf32 >
1290
+ }
1291
+ func.return %final : vector <4 x4 xf32 >
1292
+ }
1293
+
1294
+ // -----
1295
+
1296
+ // CHECK-LABEL: func @do_not_hoist_vector_transfer_ops_memref
1297
+ // CHECK-NOT: vector.transfer_read
1298
+ // CHECK: scf.for
1299
+ // CHECK: vector.transfer_read
1300
+ // CHECK: arith.addf
1301
+ // CHECK: scf.yield
1302
+ func.func @do_not_hoist_vector_transfer_ops_memref (
1303
+ %a : memref <128 x128 xf32 >,
1304
+ %lb : index ,
1305
+ %ub : index ,
1306
+ %step : index ,
1307
+ %ida : index ,
1308
+ %idb : index ) -> vector <4 x4 xf32 > {
1309
+ %cst_0 = arith.constant 0.0 : f32
1310
+ %cst = arith.constant dense <0.0 > : vector <4 x4 xf32 >
1311
+ %final =
1312
+ scf.for %i = %lb to %ub step %step iter_args (%acc = %cst ) -> vector <4 x4 xf32 > {
1313
+ %read = vector.transfer_read %a [%ida , %idb ], %cst_0 : memref <128 x128 xf32 >, vector <4 x4 xf32 >
1314
+ %out = arith.addf %read , %acc : vector <4 x4 xf32 >
1315
+ scf.yield %out : vector <4 x4 xf32 >
1316
+ }
1317
+ func.return %final : vector <4 x4 xf32 >
1318
+ }
0 commit comments