Skip to content

Commit 1fdf06d

Browse files
[mlir][bufferization] Reads from tensors with undefined data are not a conflict
Reading from tensor.empty or bufferization.alloc_tensor (without copy) cannot cause a conflict because these ops do not specify the contents of their result tensors. Differential Revision: https://reviews.llvm.org/D143183
1 parent 77910ac commit 1fdf06d

File tree

5 files changed

+57
-25
lines changed

5 files changed

+57
-25
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -412,14 +412,16 @@ class AnalysisState {
412412
/// in the operands) because their defining ops do not define the contents of
413413
/// the tensor.
414414
///
415+
/// Example:
416+
/// %a = tensor.empty() : tensor<10xf32>
417+
/// %b = arith.constant ... : tensor<10xf32>
418+
/// %r = arith.select %cond, %a, %b : tensor<10xf32>
419+
/// findDefinitions(%r) = {%b}. %a is excluded because it does not define the
420+
/// contents of the tensor.
421+
///
415422
/// Note: OpResults of unknown ops are handled conservatively and assumed to
416423
/// be definitions.
417-
///
418-
/// Note: When reaching an end of the reverse SSA use-def chain, that value
419-
/// is included regardless of whether it is a definition or not unless
420-
/// `alwaysIncludeLeaves` is unset.
421-
SetVector<Value> findDefinitions(Value value,
422-
bool alwaysIncludeLeaves = true) const;
424+
SetVector<Value> findDefinitions(Value value) const;
423425

424426
/// Return `true` if the given OpOperand has been decided to bufferize inplace.
425427
virtual bool isInPlace(OpOperand &opOperand) const;

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,11 +494,10 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
494494
}
495495

496496
// Find the values that define the contents of the given value.
497-
llvm::SetVector<Value>
498-
AnalysisState::findDefinitions(Value value, bool alwaysIncludeLeaves) const {
497+
llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
499498
return findValueInReverseUseDefChain(
500499
value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
501-
/*followEquivalentOnly=*/false, alwaysIncludeLeaves);
500+
/*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
502501
}
503502

504503
AnalysisState::AnalysisState(const BufferizationOptions &options)

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
272272

273273
// If there is no preceding definition, the tensor contents are
274274
// undefined.
275-
if (findDefinitions(opResult, /*alwaysIncludeLeaves=*/false).empty())
275+
if (findDefinitions(opResult).empty())
276276
for (OpOperand &use : opResult.getUses())
277277
undefinedTensorUses.insert(&use);
278278
}
@@ -513,8 +513,11 @@ static bool hasReadAfterWriteInterference(
513513

514514
for (OpOperand *uRead : usesRead) {
515515
Operation *readingOp = uRead->getOwner();
516+
LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
517+
LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber()
518+
<< " of " << *readingOp << "\n");
516519

517-
// Find most recent writes of uRead by following the SSA use-def chain.
520+
// Find the definition of uRead by following the SSA use-def chain.
518521
// E.g.:
519522
//
520523
// %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
@@ -525,14 +528,16 @@ static bool hasReadAfterWriteInterference(
525528
// definition is %0. Note that operations that create an alias but do not
526529
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
527530
SetVector<Value> definitions = state.findDefinitions(uRead->get());
531+
if (definitions.empty()) {
532+
// Fast path: No conflict if there are no definitions.
533+
LLVM_DEBUG(llvm::dbgs()
534+
<< " no conflict: read value has no definitions\n");
535+
continue;
536+
}
528537

529538
// Look for conflicting memory writes. Potential conflicts are writes to an
530539
// alias that have been decided to bufferize inplace.
531540
for (OpOperand *uConflictingWrite : usesWrite) {
532-
LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
533-
LLVM_DEBUG(llvm::dbgs()
534-
<< " uRead = operand " << uRead->getOperandNumber() << " of "
535-
<< *uRead->getOwner() << "\n");
536541
LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand "
537542
<< uConflictingWrite->getOperandNumber() << " of "
538543
<< *uConflictingWrite->getOwner() << "\n");
@@ -608,15 +613,15 @@ static bool hasReadAfterWriteInterference(
608613
LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");
609614

610615
// No conflict if the conflicting write happens before the definition.
611-
if (Operation *writingOp = definition.getDefiningOp()) {
612-
if (happensBefore(conflictingWritingOp, writingOp, domInfo)) {
613-
// conflictingWritingOp happens before writingOp. No conflict.
616+
if (Operation *defOp = definition.getDefiningOp()) {
617+
if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
618+
// conflictingWritingOp happens before defOp. No conflict.
614619
LLVM_DEBUG(llvm::dbgs()
615620
<< " no conflict: write happens before definition\n");
616621
continue;
617622
}
618-
// No conflict if conflictingWritingOp is contained in writingOp.
619-
if (writingOp->isProperAncestor(conflictingWritingOp)) {
623+
// No conflict if conflictingWritingOp is contained in defOp.
624+
if (defOp->isProperAncestor(conflictingWritingOp)) {
620625
LLVM_DEBUG(
621626
llvm::dbgs()
622627
<< " no conflict: write is contained in definition\n");

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,27 @@ func.func @unknown_op_writing(%f: f32, %f2: f32, %pos: index) -> f32 {
3232
%3 = tensor.extract %1[%pos] : tensor<10xf32>
3333
return %3 : f32
3434
}
35+
36+
// -----
37+
38+
// CHECK-LABEL: func @read_of_undef_is_not_a_conflict(
39+
func.func @read_of_undef_is_not_a_conflict(%f: f32, %idx: index) -> f32 {
40+
%0 = tensor.empty() : tensor<10xf32>
41+
// This can be in-place because the read below reads undefined data.
42+
// CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
43+
%1 = tensor.insert %f into %0[%idx] : tensor<10xf32>
44+
%2 = tensor.extract %0[%idx] : tensor<10xf32>
45+
return %2 : f32
46+
}
47+
48+
// -----
49+
50+
// CHECK-LABEL: func @read_of_alloc_tensor_is_not_a_conflict(
51+
func.func @read_of_alloc_tensor_is_not_a_conflict(%f: f32, %idx: index) -> f32 {
52+
%0 = bufferization.alloc_tensor() : tensor<10xf32>
53+
// This can be in-place because the read below reads undefined data.
54+
// CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
55+
%1 = tensor.insert %f into %0[%idx] : tensor<10xf32>
56+
%2 = tensor.extract %0[%idx] : tensor<10xf32>
57+
return %2 : f32
58+
}

mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -713,26 +713,28 @@ func.func @scf_foreach_privatized_but_not_copied(
713713
// -----
714714

715715
// CHECK-LABEL: func @scf_if_memory_space
716-
func.func @scf_if_memory_space(%c: i1, %f: f32) -> (f32, f32)
716+
func.func @scf_if_memory_space(%c: i1, %f: f32, %cst: f32) -> (f32, f32)
717717
{
718718
%c0 = arith.constant 0 : index
719719
// CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1>
720-
%0 = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<5xf32>
720+
%alloc = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<5xf32>
721+
// CHECK: linalg.fill {{.*}} outs(%[[alloc]] : memref<5xf32, 1>)
722+
%filled = linalg.fill ins(%cst : f32) outs(%alloc : tensor<5xf32>) -> tensor<5xf32>
721723
// CHECK: scf.if %{{.*}} -> (memref<5xf32, 1>) {
722724
%1 = scf.if %c -> tensor<5xf32> {
723725
// CHECK: %[[cloned:.*]] = bufferization.clone %[[alloc]]
724726
// CHECK: scf.yield %[[cloned]]
725-
scf.yield %0 : tensor<5xf32>
727+
scf.yield %filled : tensor<5xf32>
726728
} else {
727729
// CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1>
728730
// CHECK: memref.store %{{.*}}, %[[alloc2]]
729731
// CHECK: %[[cloned2:.*]] = bufferization.clone %[[alloc2]]
730732
// CHECK: memref.dealloc %[[alloc2]]
731733
// CHECK: scf.yield %[[cloned2]]
732-
%2 = tensor.insert %f into %0[%c0] : tensor<5xf32>
734+
%2 = tensor.insert %f into %filled[%c0] : tensor<5xf32>
733735
scf.yield %2 : tensor<5xf32>
734736
}
735-
%r0 = tensor.extract %0[%c0] : tensor<5xf32>
737+
%r0 = tensor.extract %filled[%c0] : tensor<5xf32>
736738
%r1 = tensor.extract %1[%c0] : tensor<5xf32>
737739
return %r0, %r1 : f32, f32
738740
}

0 commit comments

Comments
 (0)