
Commit d8ee28b

[mlir][Linalg] Extend buffer allocation to support Linalg init tensors
This revision adds init_tensors support to buffer allocation for Linalg on tensors. It currently assumes that the init_tensors fold onto the first output tensors. This assumption is not yet enforced or cast in stone and requires experimenting with tiling of Linalg on tensors for ops **without reductions**. Still, this allows progress towards the end-to-end goal.
1 parent 8fa45e1 commit d8ee28b
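
To make the intended folding concrete, here is a minimal before/after sketch of the single-use case, adapted from the @generic_with_init_tensor test added in this commit (types and trait are taken from that test; the printed form of the bufferized op is indicative, not normative).

On tensors, the init tensor is tied to the op's result:

  %0 = linalg.generic #trait
     ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
    init(%arg1 : tensor<3x2xf32>) {
    ^bb(%v0: vector<3x4xi4>, %v1: f32):
      %f0 = constant 0.0 : f32
      linalg.yield %f0 : f32
  } -> tensor<3x2xf32>
  return %0 : tensor<3x2xf32>

After buffer allocation, %arg1 has a single use, so its buffer is reused directly as the output buffer (no alloc, no copy for the op itself); the function result is then materialized by copying into the result memref appended to the signature (%result0 below):

  linalg.generic #trait
     ins(%arg0 : memref<2x3x4xvector<3x4xi4>>)
    outs(%arg1 : memref<3x2xf32>) {
    ^bb(%v0: vector<3x4xi4>, %v1: f32):
      %f0 = constant 0.0 : f32
      linalg.yield %f0 : f32
  }
  linalg.copy(%arg1, %result0) : memref<3x2xf32>, memref<3x2xf32>
  return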

File tree

4 files changed: +210 -29 lines changed

  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
  mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
  mlir/test/Transforms/buffer-placement-preparation.mlir
  mlir/test/lib/Transforms/TestBufferPlacement.cpp


mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 3 deletions
@@ -374,7 +374,6 @@ LogicalResult BlockArgsVerifier<IndexedGenericOp>::verify(IndexedGenericOp op,
 
 template <typename GenericOpType>
 static LogicalResult verifyGenericOp(GenericOpType op) {
-  auto nInputViews = op.getNumInputs();
   auto nLoops = op.getNumLoops();
 
   if (op.inputs().size() + op.output_buffers().size() +
@@ -410,8 +409,7 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
     auto idx = en.index();
     auto m = en.value().template cast<AffineMapAttr>().getValue();
     indexingMaps.push_back(m); // Save reference to map for further checks.
-    auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
-                                    : op.getOutputShapedType(idx - nInputViews);
+    auto view = op.getShapedType(idx);
 
     if (m.getNumSymbols() != expectedNumSymbols)
      return op.emitOpError("expected the number of symbols in indexing_map #")

mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp

Lines changed: 35 additions & 13 deletions
@@ -39,32 +39,50 @@ class GenericOpConverter
     linalg::GenericOpAdaptor adaptor(operands,
                                      op.getOperation()->getAttrDictionary());
 
-    // TODO: support ops with reduction.
-    if (!op.init_tensors().empty())
-      return failure();
-
     // All inputs need to be turned into buffers first. Until then, bail out.
     if (llvm::any_of(adaptor.inputs(),
                      [](Value in) { return !in.getType().isa<MemRefType>(); }))
       return failure();
 
+    // All init_tensors need to be turned into buffers first. Until then, bail
+    // out.
+    if (llvm::any_of(adaptor.init_tensors(),
+                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
+      return failure();
+
     Location loc = op.getLoc();
-    SmallVector<Value, 2> outputBuffers, newOutputBuffers;
-    outputBuffers.assign(adaptor.output_buffers().begin(),
-                         adaptor.output_buffers().end());
+    SmallVector<Value, 2> newOutputBuffers;
     newOutputBuffers.reserve(op.getNumOutputs());
     newOutputBuffers.append(adaptor.output_buffers().begin(),
                             adaptor.output_buffers().end());
 
     // Update all types to memref types.
-    for (Type t : op.getResultTypes()) {
-      auto type = t.cast<ShapedType>();
+    // Assume the init tensors fold onto the first results.
+    // TODO: update this assumption because the reality is more complex under
+    // linalg on tensor based transformations.
+    for (auto en : llvm::enumerate(op.getResultTypes())) {
+      auto type = en.value().cast<ShapedType>();
       if (!type.hasStaticShape())
        return rewriter.notifyMatchFailure(
            op, "dynamic shapes not currently supported");
      auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
-      auto alloc = rewriter.create<AllocOp>(loc, memrefType);
-      newOutputBuffers.push_back(alloc);
+      bool foldedInitTensor = en.index() < op.getNumInitTensors();
+      if (foldedInitTensor) {
+        // Dealing with an init tensor requires distinguishing between 1-use
+        // and many-use cases which would create aliasing and WAR hazards.
+        Value initTensor = op.getInitTensor(en.index());
+        Value initBuffer = adaptor.init_tensors()[en.index()];
+        if (initTensor.hasOneUse()) {
+          newOutputBuffers.push_back(initBuffer);
+          continue;
+        }
+        auto alloc = rewriter.create<AllocOp>(loc, memrefType);
+        rewriter.create<linalg::CopyOp>(loc, initBuffer, alloc);
+        newOutputBuffers.push_back(alloc);
+      } else {
+        auto alloc = rewriter.create<AllocOp>(loc, memrefType);
+        newOutputBuffers.push_back(alloc);
+      }
     }
 
     // Generate a new linalg operation that works on buffers.
@@ -82,8 +100,12 @@ class GenericOpConverter
     Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                            oldBlock.getArgumentTypes());
 
-    // Add the result arguments to the new block.
-    for (Value v : newOutputBuffers)
+    // Add the result arguments that do not come from init_tensors to the new
+    // block.
+    // TODO: update this assumption because the reality is more complex under
+    // linalg on tensor based transformations.
+    for (Value v :
+         ValueRange(newOutputBuffers).drop_front(adaptor.init_tensors().size()))
      newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
 
    // Clone the body of the old block to the new block.
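
When the init tensor has more than one use, reusing its buffer in place would clobber data another user still needs (the aliasing/WAR hazard mentioned in the comment above), so the pattern allocates a fresh buffer and copies the init contents into it before the op writes. A minimal sketch of what is emitted per such result, mirroring the shape of output checked by the @init_tensor_with_2_uses test below (SSA names are illustrative):

  %0 = alloc() : memref<3x2xf32>
  linalg.copy(%init_buffer, %0) : memref<3x2xf32>, memref<3x2xf32>
  linalg.generic #trait
     ins(%input : memref<2x3x4xvector<3x4xi4>>)
    outs(%0 : memref<3x2xf32>) {
    ^bb(%v0: vector<3x4xi4>, %v1: f32):
      %f0 = constant 0.0 : f32
      linalg.yield %f0 : f32
  }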

mlir/test/Transforms/buffer-placement-preparation.mlir

Lines changed: 138 additions & 0 deletions
@@ -382,3 +382,141 @@ func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg
 // CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]])
 // CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]])
 // CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]
+
+// -----
+
+#accesses = [
+  affine_map<(i, j, k) -> (j, i, k)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+func @generic_with_init_tensor(
+    %arg0: tensor<2x3x4xvector<3x4xi4>>, %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>) {
+
+  %0 = linalg.generic #trait
+     ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
+    init(%arg1 : tensor<3x2xf32>) {
+    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } -> tensor<3x2xf32>
+
+  return %0 : tensor<3x2xf32>
+}
+// CHECK-LABEL: func @generic_with_init_tensor
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>) {
+// CHECK-NEXT: linalg.generic
+// CHECK:      linalg.copy(%[[ARG1]], %[[RESULT0]])
+// CHECK-NEXT: return
+// CHECK-NOT:  %
+
+// -----
+
+#accesses = [
+  affine_map<(i, j, k) -> (j, i, k)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+func @init_tensor_with_2_uses(
+    %arg0: tensor<2x3x4xvector<3x4xi4>>, %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>) {
+
+  %0 = linalg.generic #trait
+     ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
+    init(%arg1 : tensor<3x2xf32>) {
+    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } -> tensor<3x2xf32>
+
+  %1 = linalg.generic #trait
+     ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
+    init(%arg1 : tensor<3x2xf32>) {
+    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } -> tensor<3x2xf32>
+
+  return %0, %1 : tensor<3x2xf32>, tensor<3x2xf32>
+}
+// CHECK-LABEL: func @init_tensor_with_2_uses
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>, %[[RESULT1:.*]]: memref<3x2xf32>) {
+// CHECK-NEXT: %[[ALLOC0:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC0]])
+// CHECK-NEXT: linalg.generic
+// CHECK-SAME: outs(%[[ALLOC0]]
+// CHECK-NEXT: ^bb
+// CHECK-NEXT: constant
+// CHECK-NEXT: yield
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[ALLOC1:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC1]])
+// CHECK-NEXT: linalg.generic
+// CHECK-SAME: outs(%[[ALLOC1]]
+// CHECK-NEXT: ^bb
+// CHECK-NEXT: constant
+// CHECK-NEXT: yield
+// CHECK-NEXT: }
+// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[RESULT0]])
+// CHECK-NEXT: linalg.copy(%[[ALLOC1]], %[[RESULT1]])
+// CHECK-NEXT: return
+// CHECK-NOT:  %
+
+// -----
+
+#accesses = [
+  affine_map<(i, j, k) -> (j, i, k)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+func @init_tensor_with_1_use_def_chain(
+    %arg0: tensor<2x3x4xvector<3x4xi4>>, %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>) {
+
+  %0 = linalg.generic #trait
+     ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
+    init(%arg1 : tensor<3x2xf32>) {
+    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } -> tensor<3x2xf32>
+
+  %1 = linalg.generic #trait
+     ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
+    init(%0 : tensor<3x2xf32>) {
+    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } -> tensor<3x2xf32>
+
+  return %1 : tensor<3x2xf32>
+}
+// CHECK-LABEL: func @init_tensor_with_1_use_def_chain
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>) {
+// CHECK-NEXT: linalg.generic
+// CHECK-NEXT: ^bb
+// CHECK-NEXT: constant
+// CHECK-NEXT: yield
+// CHECK-NEXT: }
+// CHECK-NEXT: linalg.generic
+// CHECK-NEXT: ^bb
+// CHECK-NEXT: constant
+// CHECK-NEXT: yield
+// CHECK-NEXT: }
+// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[RESULT0]])
+// CHECK-NEXT: return
+// CHECK-NOT:  %
+

mlir/test/lib/Transforms/TestBufferPlacement.cpp

Lines changed: 36 additions & 13 deletions
@@ -56,34 +56,53 @@ struct TestBufferPlacementPreparationPass
       linalg::GenericOpAdaptor adaptor(operands,
                                        op.getOperation()->getAttrDictionary());
 
-      // TODO: support ops with reduction.
-      if (!op.init_tensors().empty())
-        return failure();
-
       // All inputs need to be turned into buffers first. Until then, bail out.
       if (llvm::any_of(adaptor.inputs(), [](Value in) {
             return !in.getType().isa<MemRefType>();
          }))
        return failure();
 
+      // All init_tensors need to be turned into buffers first. Until then, bail
+      // out.
+      if (llvm::any_of(adaptor.init_tensors(), [](Value in) {
+            return !in.getType().isa<MemRefType>();
+          }))
+        return failure();
+
       Location loc = op.getLoc();
-      SmallVector<Value, 2> outputBuffers, newOutputBuffers;
-      outputBuffers.assign(adaptor.output_buffers().begin(),
-                           adaptor.output_buffers().end());
+      SmallVector<Value, 2> newOutputBuffers;
       newOutputBuffers.reserve(op.getNumOutputs());
       newOutputBuffers.append(adaptor.output_buffers().begin(),
                               adaptor.output_buffers().end());
 
       // Update all types to memref types.
-      for (Type t : op.getResultTypes()) {
-        auto type = t.cast<ShapedType>();
+      // Assume the init tensors fold onto the first results.
+      // TODO: update this assumption because the reality is more complex under
+      // linalg on tensor based transformations.
+      for (auto en : llvm::enumerate(op.getResultTypes())) {
+        auto type = en.value().cast<ShapedType>();
        if (!type.hasStaticShape())
          return rewriter.notifyMatchFailure(
              op, "dynamic shapes not currently supported");
        auto memrefType =
            MemRefType::get(type.getShape(), type.getElementType());
-        auto alloc = rewriter.create<AllocOp>(loc, memrefType);
-        newOutputBuffers.push_back(alloc);
+        bool foldedInitTensor = en.index() < op.getNumInitTensors();
+        if (foldedInitTensor) {
+          // Dealing with an init tensor requires distinguishing between 1-use
+          // and many-use cases which would create aliasing and WAR hazards.
+          Value initTensor = op.getInitTensor(en.index());
+          Value initBuffer = adaptor.init_tensors()[en.index()];
+          if (initTensor.hasOneUse()) {
+            newOutputBuffers.push_back(initBuffer);
+            continue;
+          }
+          auto alloc = rewriter.create<AllocOp>(loc, memrefType);
+          rewriter.create<linalg::CopyOp>(loc, initBuffer, alloc);
+          newOutputBuffers.push_back(alloc);
+        } else {
+          auto alloc = rewriter.create<AllocOp>(loc, memrefType);
+          newOutputBuffers.push_back(alloc);
+        }
      }
 
      // Generate a new linalg operation that works on buffers.
@@ -101,8 +120,12 @@ struct TestBufferPlacementPreparationPass
       Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                              oldBlock.getArgumentTypes());
 
-      // Add the result arguments to the new block.
-      for (Value v : newOutputBuffers)
+      // Add the result arguments that do not come from init_tensors to the new
+      // block.
+      // TODO: update this assumption because the reality is more complex under
+      // linalg on tensor based transformations.
+      for (Value v : ValueRange(newOutputBuffers)
+                         .drop_front(adaptor.init_tensors().size()))
        newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
 
      // Clone the body of the old block to the new block.
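
One detail shared by both copies of this pattern: the drop_front over newOutputBuffers exists because an output tied to an init tensor already has a matching element block argument in the tensor form, so only outputs without an init tensor receive a freshly appended argument. A sketch of the region signature for the one-input, one-init-tensor case used in the tests:

  // Tensor form: %v1 is the element of the init tensor.
  ^bb(%v0: vector<3x4xi4>, %v1: f32):

  // Buffer form: %v1 now serves as the output element, so nothing is appended;
  // an output without an init tensor would get one extra element argument here.
  ^bb(%v0: vector<3x4xi4>, %v1: f32):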
