Skip to content

Commit c89c31a

Browse files
[mlir][bufferization] Fix bufferization of repetitive regions
The previous strategy was too complex and faulty. Op dominance cannot be used to rule out RaW conflicts due to op ordering if the reading op and the conflicting writing op are in a sub repetitive region of the closest enclosing repetitive region of the definition of the read value. Differential Revision: https://reviews.llvm.org/D143087
1 parent 993bce9 commit c89c31a

File tree

4 files changed

+199
-87
lines changed

4 files changed

+199
-87
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,11 @@ Region *getEnclosingRepetitiveRegion(Value value,
575575
Region *getEnclosingRepetitiveRegion(Block *block,
576576
const BufferizationOptions &options);
577577

578+
/// Assuming that the given region is repetitive, find the next enclosing
579+
/// repetitive region.
580+
Region *getNextEnclosingRepetitiveRegion(Region *region,
581+
const BufferizationOptions &options);
582+
578583
namespace detail {
579584
/// This is the default implementation of
580585
/// BufferizableOpInterface::getAliasingOpOperands. Should not be called from

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
4141
using namespace mlir;
4242
using namespace bufferization;
4343

44+
static bool isRepetitiveRegion(Region *region,
45+
const BufferizationOptions &options) {
46+
Operation *op = region->getParentOp();
47+
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
48+
if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
49+
return true;
50+
return false;
51+
}
52+
4453
Region *bufferization::getEnclosingRepetitiveRegion(
4554
Operation *op, const BufferizationOptions &options) {
4655
if (!op->getBlock())
@@ -52,11 +61,9 @@ Region *bufferization::getEnclosingRepetitiveRegion(
5261
Value value, const BufferizationOptions &options) {
5362
Region *region = value.getParentRegion();
5463
while (region) {
55-
Operation *op = region->getParentOp();
56-
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
57-
if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
58-
return region;
59-
region = op->getParentRegion();
64+
if (isRepetitiveRegion(region, options))
65+
return region;
66+
region = region->getParentRegion();
6067
}
6168
return nullptr;
6269
}
@@ -67,13 +74,22 @@ Region *bufferization::getEnclosingRepetitiveRegion(
6774
Operation *op = nullptr;
6875
do {
6976
op = region->getParentOp();
70-
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
71-
if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
72-
return region;
77+
if (isRepetitiveRegion(region, options))
78+
return region;
7379
} while ((region = op->getParentRegion()));
7480
return nullptr;
7581
}
7682

83+
Region *bufferization::getNextEnclosingRepetitiveRegion(
84+
Region *region, const BufferizationOptions &options) {
85+
assert(isRepetitiveRegion(region, options) && "expected repetitive region");
86+
while ((region = region->getParentRegion())) {
87+
if (isRepetitiveRegion(region, options))
88+
break;
89+
}
90+
return region;
91+
}
92+
7793
Operation *bufferization::getOwnerOfValue(Value value) {
7894
if (auto opResult = value.dyn_cast<OpResult>())
7995
return opResult.getDefiningOp();

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 69 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -346,25 +346,27 @@ static bool happensBefore(Operation *a, Operation *b,
346346
return false;
347347
}
348348

349-
/// Return `true` if op dominance can be used to rule out read-after-write
350-
/// conflicts wrt. the given reads and writes.
349+
/// Return `true` if op dominance can be used to rule out a read-after-write
350+
/// conflicts based on the ordering of ops.
351351
///
352-
/// Op dominance can often be used to rule out potential conflicts such as
353-
/// "read" happens before "write". E.g., the following IR is not a RaW conflict
354-
/// because the the read happens *before* the write.
352+
/// Generalized op dominance can often be used to rule out potential conflicts
353+
/// due to "read happens before write". E.g., the following IR is not a RaW
354+
/// conflict because the read happens *before* the write.
355355
///
356-
/// %0 = ... : tensor<?xf32>
357-
/// "reading_op"(%0) : tensor<?xf32>
358-
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
356+
/// Example 1:
357+
/// %0 = ... : tensor<?xf32> // DEF
358+
/// "reading_op"(%0) : tensor<?xf32> // READ
359+
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> // WRITE
359360
///
360361
/// This is no longer true inside loops (or repetitive regions). In such cases,
361362
/// there may not be a meaningful `happensBefore` relationship because ops
362363
/// could be executed multiple times. E.g.:
363364
///
364-
/// %0 = ... : tensor<?xf32>
365+
/// Example 2:
366+
/// %0 = ... : tensor<?xf32> // DEF
365367
/// scf.for ... {
366-
/// "reading_op"(%0) : tensor<?xf32>
367-
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
368+
/// "reading_op"(%0) : tensor<?xf32> // READ
369+
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> // WRITE
368370
/// ...
369371
/// }
370372
///
@@ -374,92 +376,78 @@ static bool happensBefore(Operation *a, Operation *b,
374376
/// execution of writing_op. This is problematic because the tensor %0 they
375377
/// operate on (i.e., the "definition") is defined outside of the loop.
376378
///
377-
/// Counter example:
379+
/// At a high level, there is a potential RaW in a program if there exists a
380+
/// possible program execution such that there is a sequence of DEF, followed
381+
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
378382
///
383+
/// E.g.:
384+
/// No conflict: DEF, WRITE, DEF, READ
385+
/// Potential conflict: DEF, READ, WRITE, READ, WRITE
386+
///
387+
/// Example 1 has no conflict: DEF, READ, WRITE
388+
/// Example 2 has a potential conflict: DEF, (READ, WRITE)*
389+
///
390+
/// Example 3:
379391
/// scf.for ... {
380392
/// %0 = ... : tensor<?xf32>
381393
/// "reading_op"(%0) : tensor<?xf32>
382394
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
383395
/// ...
384396
/// }
397+
/// This has no conflict: (DEF, READ, WRITE)*
385398
///
386-
/// In this example, the definition %0 is in the same repetitive region as
387-
/// "writing_op", so op dominance can be used to compute the `happensBefore`
388-
/// relationship.
389-
///
390-
/// Whether op dominance can be used or not is decided as follows: Find the
391-
/// closest enclosing repetitive region of all buffer writes wrt. the given
392-
/// tensor reads and writes. (The given sets of reads and writes contain the
393-
/// entire alias set.) In case of a read, we look at the op that defines the
394-
/// read value. In case of a write, we look at the op that is writing. If all of
395-
/// those ops are in the same closest enclosing repetitive region (nullptr in
396-
/// case of "no repetitive region" found at all), then op dominance can be used.
397-
/// Otherwise, it cannot be used.
398-
///
399-
/// Example: The common enclosing repetitive region is the scf.for loop.
400-
/// Op dominance can be used.
399+
/// Example 4:
400+
/// %0 = ... : tensor<?xf32>
401401
/// scf.for ... {
402-
/// %0 = tensor.generate
403-
/// "read"(%0)
402+
/// scf.for ... { "reading_op"(%0) }
403+
/// %1 = "writing_op"(%0)
404404
/// }
405+
/// This has a potential conflict: DEF, ((READ)*, WRITE)*
405406
///
406-
/// Example: The common enclosing repetitive region is nullptr: There is no
407-
/// repetitive region around the tensor.generate. Op dominance can be
408-
/// used.
409-
/// %0 = tensor.generate
410-
/// scf.for ... { "read"(%0) }
407+
/// Example 5:
408+
/// %0 = ... : tensor<?xf32>
409+
/// scf.for ... { %1 = "writing_op"(%0) }
410+
/// scf.for ... { "reading_op"(%0) }
411+
/// This has a potential conflict: DEF, WRITE*, READ*
411412
///
412-
/// Example: The common enclosing repetitive regions of tensor.generate and
413-
/// "write" differ. Op dominance cannot be used.
414-
/// %0 = tensor.generate
415-
/// scf.for ... {
416-
/// "read"(%0)
417-
/// "write"(%0)
418-
/// }
413+
/// The following rules are used to rule out RaW conflicts via ordering of ops:
419414
///
420-
/// Example: The common enclosing repetitive regions of tensor.generate and
421-
/// "write" differ, but there is no read of %0, so op dominance can be
422-
/// used.
423-
/// %0 = tensor.generate
424-
/// scf.for ... {
425-
/// "write"(%0)
426-
/// }
415+
/// 1. If the closest enclosing repetitive region of DEF is a proper ancestor of
416+
/// a repetitive region that encloses both READ and WRITE, we cannot rule
417+
/// out a RaW conflict due to the ordering of ops.
418+
/// 2. Otherwise: There are no loops that interfere with our analysis; for
419+
/// analysis purposes, we can assume that there are no loops/repetitive
420+
/// regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
421+
/// or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
427422
///
428-
/// Note: iter_args of loops are not aliases of their respective block
429-
/// arguments, so op domanice can be used when analyzing ops that operate
430-
/// on them.
431-
bool canUseOpDominance(const DenseSet<OpOperand *> &usesRead,
432-
const DenseSet<OpOperand *> &usesWrite,
423+
bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
424+
const SetVector<Value> &definitions,
433425
const AnalysisState &state) {
434426
const BufferizationOptions &options = state.getOptions();
435-
std::optional<Region *> commonEnclosingRegion;
427+
for (Value def : definitions) {
428+
Region *rRead = getEnclosingRepetitiveRegion(uRead->getOwner(), options);
429+
Region *rDef = getEnclosingRepetitiveRegion(def, options);
436430

437-
// In case of a write, take the region in which the write takes place.
438-
for (OpOperand *uWrite : usesWrite) {
439-
Region *r = getEnclosingRepetitiveRegion(uWrite->getOwner(), options);
440-
if (!commonEnclosingRegion.has_value()) {
441-
commonEnclosingRegion = r;
431+
// READ and DEF are in the same repetitive region. `happensBefore` can be
432+
// used to rule out RaW conflicts due to op ordering.
433+
if (rRead == rDef)
442434
continue;
443-
}
444-
if (*commonEnclosingRegion != r)
445-
return false;
446-
}
447435

448-
// In case of a read, take the region which the read value is defined.
449-
for (OpOperand *uRead : usesRead) {
450-
// Optimization: Skip reads of values that have no defined contents.
451-
if (!state.bufferizesToMemoryWrite(uRead->get()))
452-
continue;
453-
Region *r = getEnclosingRepetitiveRegion(uRead->get(), options);
454-
if (!commonEnclosingRegion.has_value()) {
455-
commonEnclosingRegion = r;
456-
continue;
436+
// Find the enclosing repetitive region of READ that is closest to DEF but
437+
// not the repetitive region of DEF itself.
438+
while (true) {
439+
Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options);
440+
if (nextRegion == rDef)
441+
break;
442+
assert(nextRegion && "expected to find another repetitive region");
443+
rRead = nextRegion;
457444
}
458-
if (*commonEnclosingRegion != r)
445+
446+
// We cannot use op dominance if WRITE is inside the same repetitive region.
447+
if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
459448
return false;
460449
}
461-
462-
return commonEnclosingRegion.has_value();
450+
return true;
463451
}
464452

465453
/// Annotate IR with details about the detected RaW conflict.
@@ -507,10 +495,6 @@ static bool hasReadAfterWriteInterference(
507495
AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
508496
const BufferizationOptions &options = state.getOptions();
509497

510-
// Check if op dominance can be used to rule out read-after-write conflicts.
511-
bool useDominance = canUseOpDominance(usesRead, usesWrite, state);
512-
LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
513-
514498
for (OpOperand *uRead : usesRead) {
515499
Operation *readingOp = uRead->getOwner();
516500
LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
@@ -542,6 +526,12 @@ static bool hasReadAfterWriteInterference(
542526
<< uConflictingWrite->getOperandNumber() << " of "
543527
<< *uConflictingWrite->getOwner() << "\n");
544528

529+
// Check if op dominance can be used to rule out read-after-write
530+
// conflicts.
531+
bool useDominance =
532+
canUseOpDominance(uRead, uConflictingWrite, definitions, state);
533+
LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
534+
545535
// Throughout this loop, check for multiple requirements that have to be
546536
// met for uConflictingWrite to be an actual conflict.
547537
Operation *conflictingWritingOp = uConflictingWrite->getOwner();

mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,3 +697,104 @@ func.func @no_raw_conflict_after_repetitive_use(%arg0: tensor<4xf32>,
697697

698698
return %2, %7 : tensor<4xf32>, tensor<4xf32>
699699
}
700+
701+
// -----
702+
703+
// CHECK-LABEL: func @read_of_bbarg_in_repetitive_region(
704+
func.func @read_of_bbarg_in_repetitive_region(
705+
%t: tensor<10xf32>, %a: index, %b: index, %c: index, %cst: f32) {
706+
// CHECK: scf.for
707+
scf.for %iv = %a to %b step %c {
708+
// Must bufferize out-of-place because definition of read is in a different
709+
// repetitive region.
710+
// CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["false"]}
711+
%2 = tensor.extract_slice %t[0][4][1] : tensor<10xf32> to tensor<4xf32>
712+
%3 = tensor.extract %2[%a] : tensor<4xf32>
713+
vector.print %3 : f32
714+
// CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
715+
%4 = tensor.insert %cst into %2[%a] : tensor<4xf32>
716+
%5 = tensor.extract %4[%a] : tensor<4xf32>
717+
vector.print %5 : f32
718+
}
719+
return
720+
}
721+
722+
// -----
723+
724+
// CHECK-LABEL: func @read_definition_in_same_repetitive_region_as_write(
725+
func.func @read_definition_in_same_repetitive_region_as_write(
726+
%t: tensor<10xf32>, %a: index, %b: index, %c: index, %cst: f32) {
727+
// CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
728+
%1 = tensor.insert %cst into %t[%a] : tensor<10xf32>
729+
// CHECK: scf.for
730+
scf.for %iv = %a to %b step %c {
731+
// Can bufferize in-place.
732+
// CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true"]}
733+
%2 = tensor.extract_slice %1[0][4][1] : tensor<10xf32> to tensor<4xf32>
734+
%3 = tensor.extract %2[%a] : tensor<4xf32>
735+
vector.print %3 : f32
736+
}
737+
return
738+
}
739+
740+
// -----
741+
742+
// CHECK-LABEL: func @read_definition_in_same_repetitive_region_as_conflicting_write(
743+
func.func @read_definition_in_same_repetitive_region_as_conflicting_write(
744+
%t: tensor<10xf32>, %a: index, %b: index, %c: index, %cst: f32) {
745+
// Cannot bufferize in-place according to normal op dominance rules.
746+
// CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "false", "none"]}
747+
%1 = tensor.insert %cst into %t[%a] : tensor<10xf32>
748+
// CHECK: scf.for
749+
scf.for %iv = %a to %b step %c {
750+
// CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true"]}
751+
%2 = tensor.extract_slice %t[0][4][1] : tensor<10xf32> to tensor<4xf32>
752+
%3 = tensor.extract %2[%a] : tensor<4xf32>
753+
vector.print %3 : f32
754+
}
755+
return
756+
}
757+
758+
// -----
759+
760+
// CHECK: func @write_value_in_repetitive_region(
761+
func.func @write_value_in_repetitive_region(
762+
%t: tensor<10xf32>, %a: index, %b: index, %c: index, %cst: f32) {
763+
%0 = tensor.extract %t[%a] : tensor<10xf32>
764+
vector.print %0 : f32
765+
766+
scf.for %iv = %a to %b step %c {
767+
// No further read of %0, so this can bufferize in-place.
768+
// CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true"]}
769+
%2 = tensor.extract_slice %t[0][4][1] : tensor<10xf32> to tensor<4xf32>
770+
// CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]}
771+
%filled = linalg.fill ins(%cst : f32) outs(%2 : tensor<4xf32>) -> tensor<4xf32>
772+
%3 = tensor.extract %filled[%a] : tensor<4xf32>
773+
vector.print %3 : f32
774+
}
775+
return
776+
}
777+
778+
// -----
779+
780+
// CHECK-LABEL: func @nesting_op_repetitive_regions(
781+
func.func @nesting_op_repetitive_regions(
782+
%t: tensor<10xf32>, %a: index, %b: index, %c: index, %cst: f32) {
783+
// Cannot bufferize in-place according to normal op dominance rules.
784+
// CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "false", "none"]}
785+
%1 = tensor.insert %cst into %t[%a] : tensor<10xf32>
786+
// CHECK: scf.for
787+
scf.for %iv1 = %a to %b step %c {
788+
// CHECK: scf.for
789+
scf.for %iv2 = %a to %b step %c {
790+
// CHECK: scf.for
791+
scf.for %iv3 = %a to %b step %c {
792+
// CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true"]}
793+
%2 = tensor.extract_slice %t[0][4][1] : tensor<10xf32> to tensor<4xf32>
794+
%3 = tensor.extract %2[%a] : tensor<4xf32>
795+
vector.print %3 : f32
796+
}
797+
}
798+
}
799+
return
800+
}

0 commit comments

Comments
 (0)