Skip to content

Commit 32db6fb

Browse files
authored
[mlir][vector] Implement speculation for vector.transferx ops (#111533)
This patch implements speculation for vector.transfer_read/vector.transfer_write ops, allowing these ops to work with LICM.
1 parent 21da4e7 commit 32db6fb

File tree

3 files changed

+121
-0
lines changed

3 files changed

+121
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,7 @@ def Vector_TransferReadOp :
12401240
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
12411241
DeclareOpInterfaceMethods<MaskableOpInterface>,
12421242
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1243+
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
12431244
AttrSizedOperandSegments,
12441245
DestinationStyleOpInterface
12451246
]>,
@@ -1487,6 +1488,7 @@ def Vector_TransferWriteOp :
14871488
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
14881489
DeclareOpInterfaceMethods<MaskableOpInterface>,
14891490
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1491+
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
14901492
AttrSizedOperandSegments,
14911493
DestinationStyleOpInterface
14921494
]>,

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4245,6 +4245,12 @@ void TransferReadOp::getEffects(
42454245
SideEffects::DefaultResource::get());
42464246
}
42474247

4248+
/// Reports whether this transfer_read may be executed speculatively.
/// Returns Speculation::Speculatable when hasPureTensorSemantics() holds and
/// Speculation::NotSpeculatable in every other case.
Speculation::Speculatability TransferReadOp::getSpeculatability() {
4249+
if (hasPureTensorSemantics())
4250+
return Speculation::Speculatable;
4251+
return Speculation::NotSpeculatable;
4252+
}
4253+
42484254
namespace {
42494255
/// Store to load forwarding for transfer operations with permutation maps.
42504256
/// Even if the permutation maps are different we can still propagate the store
@@ -4627,6 +4633,12 @@ void TransferWriteOp::getEffects(
46274633
SideEffects::DefaultResource::get());
46284634
}
46294635

4636+
/// Reports whether this transfer_write may be executed speculatively.
/// Returns Speculation::Speculatable when hasPureTensorSemantics() holds and
/// Speculation::NotSpeculatable in every other case.
Speculation::Speculatability TransferWriteOp::getSpeculatability() {
4637+
if (hasPureTensorSemantics())
4638+
return Speculation::Speculatable;
4639+
return Speculation::NotSpeculatable;
4640+
}
4641+
46304642
namespace {
46314643
/// Remove dead transfer write from the SSA chain so that it can be eliminated by
46324644
/// DCE

mlir/test/Transforms/loop-invariant-code-motion.mlir

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,3 +1209,110 @@ func.func @hoist_linalg_ops_div_by_zero(%a : tensor<128x128xi32>,
12091209

12101210
func.return %final : tensor<?x128xi32>
12111211
}
1212+
1213+
// -----
1214+
1215+
// CHECK-LABEL: func @hoist_vector_transfer_ops
1216+
// CHECK: vector.transfer_read
1217+
// CHECK: scf.for
1218+
// CHECK-NOT: vector.transfer_read
1219+
// CHECK: arith.addf
1220+
// CHECK: scf.yield
1221+
// Positive case: the transfer_read in the loop body reads from a tensor and
// uses only values defined outside the loop (%a, %ida, %idb, %cst_0), so the
// expected output has it placed ahead of the scf.for.
func.func @hoist_vector_transfer_ops(
1222+
%a : tensor<128x128xf32>,
1223+
%lb : index,
1224+
%ub : index,
1225+
%step : index,
1226+
%ida : index,
1227+
%idb : index) -> vector<4x4xf32> {
1228+
%cst_0 = arith.constant 0.0 : f32
1229+
%cst = arith.constant dense<0.0> : vector<4x4xf32>
1230+
%final =
1231+
scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
1232+
%read = vector.transfer_read %a[%ida, %idb], %cst_0 : tensor<128x128xf32>, vector<4x4xf32>
1233+
%out = arith.addf %read, %acc : vector<4x4xf32>
1234+
scf.yield %out : vector<4x4xf32>
1235+
}
1236+
func.return %final : vector<4x4xf32>
1237+
}
1238+
1239+
// -----
1240+
1241+
// CHECK-LABEL: func @hoist_vector_transfer_ops
1242+
// CHECK: vector.transfer_write
1243+
// CHECK: vector.transfer_read
1244+
// CHECK: scf.for
1245+
// CHECK-NOT: vector.transfer_write
1246+
// CHECK-NOT: vector.transfer_read
1247+
// CHECK: arith.addf
1248+
// CHECK: scf.yield
1249+
// Positive case with a write followed by a read: the transfer_write stores
// %cst into the tensor %empty and the transfer_read reads the resulting
// tensor back. Every operand of the write is defined outside the loop, and the
// read depends only on the write result, so the expected output has both ops
// ahead of the scf.for. %ida and %idb are not referenced in the body.
func.func @hoist_vector_transfer_ops(
1250+
%lb : index,
1251+
%ub : index,
1252+
%step : index,
1253+
%ida : index,
1254+
%idb : index) -> vector<4x4xf32> {
1255+
%c0 = arith.constant 0 : index
1256+
%cst_0 = arith.constant 0.0 : f32
1257+
%cst = arith.constant dense<0.0> : vector<4x4xf32>
1258+
%empty = tensor.empty() : tensor<4x4xf32>
1259+
%final =
1260+
scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
1261+
%a = vector.transfer_write %cst, %empty[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
1262+
%read = vector.transfer_read %a[%c0, %c0], %cst_0 : tensor<4x4xf32>, vector<4x4xf32>
1263+
%out = arith.addf %read, %acc : vector<4x4xf32>
1264+
scf.yield %out : vector<4x4xf32>
1265+
}
1266+
func.return %final : vector<4x4xf32>
1267+
}
1268+
1269+
// -----
1270+
1271+
// CHECK-LABEL: func @do_not_hoist_vector_transfer_ops_loop_dep
1272+
// CHECK-NOT: vector.transfer_read
1273+
// CHECK: scf.for
1274+
// CHECK: vector.transfer_read
1275+
// CHECK: arith.addf
1276+
// CHECK: scf.yield
1277+
// Negative case: the transfer_read indexes the tensor with the induction
// variable %i, so it depends on the loop and the expected output keeps it
// inside the scf.for.
func.func @do_not_hoist_vector_transfer_ops_loop_dep(
1278+
%a : tensor<128x128xf32>,
1279+
%lb : index,
1280+
%ub : index,
1281+
%step : index,
1282+
%ida : index) -> vector<4x4xf32> {
1283+
%cst_0 = arith.constant 0.0 : f32
1284+
%cst = arith.constant dense<0.0> : vector<4x4xf32>
1285+
%final =
1286+
scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
1287+
%read = vector.transfer_read %a[%ida, %i], %cst_0 : tensor<128x128xf32>, vector<4x4xf32>
1288+
%out = arith.addf %read, %acc : vector<4x4xf32>
1289+
scf.yield %out : vector<4x4xf32>
1290+
}
1291+
func.return %final : vector<4x4xf32>
1292+
}
1293+
1294+
// -----
1295+
1296+
// CHECK-LABEL: func @do_not_hoist_vector_transfer_ops_memref
1297+
// CHECK-NOT: vector.transfer_read
1298+
// CHECK: scf.for
1299+
// CHECK: vector.transfer_read
1300+
// CHECK: arith.addf
1301+
// CHECK: scf.yield
1302+
// Negative case: the operands are all defined outside the loop, but the
// source is a memref rather than a tensor, and the expected output keeps the
// transfer_read inside the scf.for.
func.func @do_not_hoist_vector_transfer_ops_memref(
1303+
%a : memref<128x128xf32>,
1304+
%lb : index,
1305+
%ub : index,
1306+
%step : index,
1307+
%ida : index,
1308+
%idb : index) -> vector<4x4xf32> {
1309+
%cst_0 = arith.constant 0.0 : f32
1310+
%cst = arith.constant dense<0.0> : vector<4x4xf32>
1311+
%final =
1312+
scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
1313+
%read = vector.transfer_read %a[%ida, %idb], %cst_0 : memref<128x128xf32>, vector<4x4xf32>
1314+
%out = arith.addf %read, %acc : vector<4x4xf32>
1315+
scf.yield %out : vector<4x4xf32>
1316+
}
1317+
func.return %final : vector<4x4xf32>
1318+
}

0 commit comments

Comments
 (0)