Skip to content

Commit 3c3810e

Browse files
author
Nicolas Vasilache
committed
[mlir][vector] Avoid hoisting alloca'ed temporary buffers across AutomaticAllocationScope
This revision avoids incorrectly hoisting alloca'd buffers across an AutomaticAllocationScope boundary. In the more general case, we will probably need a ParallelScope-like interface.

Differential Revision: https://reviews.llvm.org/D118768
1 parent 83b7454 commit 3c3810e

File tree

4 files changed

+62
-7
lines changed

4 files changed

+62
-7
lines changed

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -267,15 +267,22 @@ struct BufferAllocs {
267267
Value maskBuffer;
268268
};
269269

270+
// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
271+
/// Returns the ancestor operation of `op` that carries the
/// `AutomaticAllocationScope` trait, as found by `getParentWithTrait`.
/// Asserts that such an ancestor exists; callers dereference the result
/// without a null check.
static Operation *getAutomaticAllocationScope(Operation *op) {
272+
Operation *scope =
273+
op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
274+
// A null result means no enclosing op has the trait; treat that as a caller
// error rather than handing back a null scope.
assert(scope && "Expected op to be inside automatic allocation scope");
275+
return scope;
276+
}
277+
270278
/// Allocate temporary buffers for data (vector) and mask (if present).
271-
/// TODO: Parallelism and threadlocal considerations.
272279
template <typename OpTy>
273280
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
274281
Location loc = xferOp.getLoc();
275282
OpBuilder::InsertionGuard guard(b);
276-
Operation *scope =
277-
xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
278-
assert(scope && "Expected op to be inside automatic allocation scope");
283+
Operation *scope = getAutomaticAllocationScope(xferOp);
284+
assert(scope->getNumRegions() == 1 &&
285+
"AutomaticAllocationScope with >1 regions");
279286
b.setInsertionPointToStart(&scope->getRegion(0).front());
280287

281288
BufferAllocs result;

mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -438,6 +438,14 @@ static void createFullPartialVectorTransferWrite(RewriterBase &b,
438438
});
439439
}
440440

441+
// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
442+
/// Returns the ancestor operation of `op` that carries the
/// `AutomaticAllocationScope` trait, as found by `getParentWithTrait`.
/// Asserts that such an ancestor exists; callers dereference the result
/// without a null check.
static Operation *getAutomaticAllocationScope(Operation *op) {
443+
Operation *scope =
444+
op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
445+
// A null result means no enclosing op has the trait; treat that as a caller
// error rather than handing back a null scope.
assert(scope && "Expected op to be inside automatic allocation scope");
446+
return scope;
447+
}
448+
441449
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
442450
/// masking) fastpath and a slowpath.
443451
///
@@ -538,12 +546,14 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
538546
// Top of the function `alloc` for transient storage.
539547
Value alloc;
540548
{
541-
FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
542549
RewriterBase::InsertionGuard guard(b);
543-
b.setInsertionPointToStart(&funcOp.getRegion().front());
550+
Operation *scope = getAutomaticAllocationScope(xferOp);
551+
assert(scope->getNumRegions() == 1 &&
552+
"AutomaticAllocationScope with >1 regions");
553+
b.setInsertionPointToStart(&scope->getRegion(0).front());
544554
auto shape = xferOp.getVectorType().getShape();
545555
Type elementType = xferOp.getVectorType().getElementType();
546-
alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
556+
alloc = b.create<memref::AllocaOp>(scope->getLoc(),
547557
MemRefType::get(shape, elementType),
548558
ValueRange{}, b.getI64IntegerAttr(32));
549559
}

mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -481,3 +481,22 @@ func @transfer_write_strided(%A : vector<4xf32>, %B : memref<8x4xf32, affine_map
481481
// CHECK-LABEL: transfer_write_strided(
482482
// CHECK: scf.for
483483
// CHECK: store
484+
485+
// -----
486+
487+
// External callee that takes the vector produced by the transfer_read below,
// so the read result has a user (presumably to keep the read from being
// folded away -- confirm).
func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()
488+
489+
// Regression test: a transfer_read nested in async.execute. The FileCheck
// directives in the body require that no alloca appears before the
// async.execute op and that an alloca appears after it, i.e. the temporary
// buffer is not hoisted out of the async region.
// CHECK-LABEL: transfer_read_within_async_execute
490+
func @transfer_read_within_async_execute(%A : memref<2x2xf32>) -> !async.token {
491+
%c0 = arith.constant 0 : index
492+
%f0 = arith.constant 0.0 : f32
493+
// CHECK-NOT: alloca
494+
// CHECK: async.execute
495+
// CHECK: alloca
496+
%token = async.execute {
497+
%0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<2x2xf32>, vector<2x2xf32>
498+
call @fake_side_effecting_fun(%0) : (vector<2x2xf32>) -> ()
499+
async.yield
500+
}
501+
return %token : !async.token
502+
}

mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -393,3 +393,22 @@ func @split_vector_transfer_write_strided_2d(
393393
// LINALG: }
394394
// LINALG: return
395395
// LINALG: }
396+
397+
// -----
398+
399+
// External callee that takes the vector produced by the transfer_read below,
// so the read result has a user (presumably to keep the read from being
// folded away -- confirm).
func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()
400+
401+
// Regression test: a transfer_read on a dynamically shaped memref nested in
// async.execute. The FileCheck directives in the body require that no alloca
// appears before the async.execute op and that an alloca appears after it,
// i.e. the transient buffer is not hoisted out of the async region.
// CHECK-LABEL: transfer_read_within_async_execute
402+
func @transfer_read_within_async_execute(%A : memref<?x?xf32>) -> !async.token {
403+
%c0 = arith.constant 0 : index
404+
%f0 = arith.constant 0.0 : f32
405+
// CHECK-NOT: alloca
406+
// CHECK: async.execute
407+
// CHECK: alloca
408+
%token = async.execute {
409+
%0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<?x?xf32>, vector<2x2xf32>
410+
call @fake_side_effecting_fun(%0) : (vector<2x2xf32>) -> ()
411+
async.yield
412+
}
413+
return %token : !async.token
414+
}

0 commit comments

Comments
 (0)