Skip to content

Commit dea33c8

Browse files
committed
[mlir][Transforms] teach CSE about recursive memory effects
Add support for reasoning about operations with recursive memory effects to CSE. The recursive effects are gathered by a helper function. I decided to allow returning duplicates from the helper function because there's no benefit to spending the computation time to remove them in the existing use case. Differential Revision: https://reviews.llvm.org/D156805
1 parent e6d5dcf commit dea33c8

File tree

6 files changed

+121
-13
lines changed

6 files changed

+121
-13
lines changed

mlir/include/mlir/Interfaces/SideEffectInterfaces.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,17 @@ bool wouldOpBeTriviallyDead(Operation *op);
332332
/// conditions are satisfied.
333333
bool isMemoryEffectFree(Operation *op);
334334

335+
/// Returns the side effects of an operation. If the operation has
336+
/// RecursiveMemoryEffects, include all side effects of child operations.
337+
///
338+
/// std::nullopt indicates that an option did not have a memory effect interface
339+
/// and so no result could be obtained. An empty vector indicates that there
340+
/// were no memory effects found (but every operation implemented the memory
341+
/// effect interface or has RecursiveMemoryEffects). If the vector contains
342+
/// multiple effects, these effects may be duplicates.
343+
std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
344+
getEffectsRecursively(Operation *rootOp);
345+
335346
/// Returns true if the given operation is speculatable, i.e. has no undefined
336347
/// behavior or other side effects.
337348
///

mlir/lib/Interfaces/SideEffectInterfaces.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,39 @@ bool mlir::isMemoryEffectFree(Operation *op) {
182182
return true;
183183
}
184184

185+
// the returned vector may contain duplicate effects
186+
std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
187+
mlir::getEffectsRecursively(Operation *rootOp) {
188+
SmallVector<MemoryEffects::EffectInstance> effects;
189+
SmallVector<Operation *> effectingOps(1, rootOp);
190+
while (!effectingOps.empty()) {
191+
Operation *op = effectingOps.pop_back_val();
192+
193+
// If the operation has recursive effects, push all of the nested
194+
// operations on to the stack to consider.
195+
bool hasRecursiveEffects =
196+
op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
197+
if (hasRecursiveEffects) {
198+
for (Region &region : op->getRegions()) {
199+
for (Block &block : region) {
200+
for (Operation &nestedOp : block) {
201+
effectingOps.push_back(&nestedOp);
202+
}
203+
}
204+
}
205+
}
206+
207+
if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
208+
effectInterface.getEffects(effects);
209+
} else if (!hasRecursiveEffects) {
210+
// the operation does not have recursive memory effects or implement
211+
// the memory effect op interface. Its effects are unknown.
212+
return std::nullopt;
213+
}
214+
}
215+
return effects;
216+
}
217+
185218
bool mlir::isSpeculatable(Operation *op) {
186219
auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
187220
if (!conditionallySpeculatable)

mlir/lib/Transforms/CSE.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,17 +199,23 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
199199
}
200200
}
201201
while (nextOp && nextOp != toOp) {
202-
auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
203-
// TODO: Do we need to handle other effects generically?
204-
// If the operation does not implement the MemoryEffectOpInterface we
205-
// conservatively assumes it writes.
206-
if ((nextOpMemEffects &&
207-
nextOpMemEffects.hasEffect<MemoryEffects::Write>()) ||
208-
!nextOpMemEffects) {
202+
std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
203+
getEffectsRecursively(nextOp);
204+
if (!effects) {
205+
// TODO: Do we need to handle other effects generically?
206+
// If the operation does not implement the MemoryEffectOpInterface we
207+
// conservatively assume it writes.
209208
result.first->second =
210209
std::make_pair(nextOp, MemoryEffects::Write::get());
211210
return true;
212211
}
212+
213+
for (const MemoryEffects::EffectInstance &effect : *effects) {
214+
if (isa<MemoryEffects::Write>(effect.getEffect())) {
215+
result.first->second = {nextOp, MemoryEffects::Write::get()};
216+
return true;
217+
}
218+
}
213219
nextOp = nextOp->getNextNode();
214220
}
215221
result.first->second = std::make_pair(toOp, nullptr);

mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,7 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
332332
// CHECK: scf.yield %[[VAL_145]]
333333
// CHECK: }
334334
// CHECK: %[[VAL_146:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_147:.*]]]
335-
// CHECK: %[[VAL_148:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]]
336-
// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_148]]
335+
// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_137]]
337336
// CHECK: %[[VAL_150:.*]] = arith.cmpi ult, %[[VAL_136]], %[[VAL_147]]
338337
// CHECK: %[[VAL_151:.*]]:3 = scf.if %[[VAL_150]]
339338
// CHECK: %[[VAL_152:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_136]]]
@@ -529,4 +528,4 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2:
529528
func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
530529
sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
531530
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
532-
}
531+
}

mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,7 @@
142142
// CHECK: scf.yield %[[VAL_132]], %[[VAL_131]] : index, i32
143143
// CHECK: }
144144
// CHECK: %[[VAL_133:.*]] = arith.addi %[[VAL_105]], %[[VAL_7]] : index
145-
// CHECK: %[[VAL_134:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
146-
// CHECK: %[[VAL_135:.*]] = arith.addi %[[VAL_134]], %[[VAL_5]] : index
147-
// CHECK: memref.store %[[VAL_135]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
145+
// CHECK: memref.store %[[VAL_112]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
148146
// CHECK: scf.yield %[[VAL_133]], %[[VAL_136:.*]]#1, %[[VAL_2]] : index, i32, i1
149147
// CHECK: }
150148
// CHECK: %[[VAL_137:.*]] = scf.if %[[VAL_138:.*]]#2 -> (tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {

mlir/test/Transforms/cse.mlir

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,64 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
459459
// CHECK: }
460460
// CHECK-NOT: scf.if
461461
// CHECK: return %[[if]], %[[if]]
462+
463+
// CHECK-LABEL: @cse_recursive_effects_success
464+
func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
465+
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
466+
%0 = "test.op_with_memread"() : () -> (i32)
467+
468+
// do something with recursive effects, containing no side effects
469+
%true = arith.constant true
470+
// CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
471+
// CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
472+
%1 = scf.if %true -> (i32) {
473+
%c42 = arith.constant 42 : i32
474+
scf.yield %c42 : i32
475+
// CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
476+
// CHECK-NEXT: scf.yield %[[C42]]
477+
// CHECK-NEXT: } else {
478+
} else {
479+
%c24 = arith.constant 24 : i32
480+
scf.yield %c24 : i32
481+
// CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
482+
// CHECK-NEXT: scf.yield %[[C24]]
483+
// CHECK-NEXT: }
484+
}
485+
486+
// %2 can be removed
487+
// CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE]], %[[IF]] : i32, i32, i32
488+
%2 = "test.op_with_memread"() : () -> (i32)
489+
return %0, %2, %1 : i32, i32, i32
490+
}
491+
492+
// CHECK-LABEL: @cse_recursive_effects_failure
493+
func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
494+
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
495+
%0 = "test.op_with_memread"() : () -> (i32)
496+
497+
// do something with recursive effects, containing a write effect
498+
%true = arith.constant true
499+
// CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
500+
// CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
501+
%1 = scf.if %true -> (i32) {
502+
"test.op_with_memwrite"() : () -> ()
503+
// CHECK-NEXT: "test.op_with_memwrite"() : () -> ()
504+
%c42 = arith.constant 42 : i32
505+
scf.yield %c42 : i32
506+
// CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
507+
// CHECK-NEXT: scf.yield %[[C42]]
508+
// CHECK-NEXT: } else {
509+
} else {
510+
%c24 = arith.constant 24 : i32
511+
scf.yield %c24 : i32
512+
// CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
513+
// CHECK-NEXT: scf.yield %[[C24]]
514+
// CHECK-NEXT: }
515+
}
516+
517+
// %2 can not be be removed because of the write
518+
// CHECK-NEXT: %[[READ_VALUE2:.*]] = "test.op_with_memread"() : () -> i32
519+
// CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE2]], %[[IF]] : i32, i32, i32
520+
%2 = "test.op_with_memread"() : () -> (i32)
521+
return %0, %2, %1 : i32, i32, i32
522+
}

0 commit comments

Comments
 (0)