
Commit 755c050

MaheshRavishankar committed

[mlir][Linalg] Fix load/store operations generated while lowering to loops when
output has zero rank.

While lowering to loops, no indices should be used in the load/store operation
if the buffer is zero-rank.

Differential Revision: https://reviews.llvm.org/D75391
1 parent f708c82 commit 755c050
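For context (not part of the commit): a zero-rank memref holds a single element and is addressed with an empty index list, so the lowered code must emit load/store with no subscripts at all. A minimal sketch in the standard-dialect syntax of the time (hypothetical function name):

    func @zero_rank_access(%buf: memref<f32>) {
      // A zero-rank memref takes an empty index list.
      %v = load %buf[] : memref<f32>
      store %v, %buf[] : memref<f32>
      return
    }

Passing any index to a load/store on a memref<f32> is invalid IR, which is what the unguarded output path produced before this fix.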

File tree

2 files changed: +127 −29

mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp

Lines changed: 55 additions & 29 deletions
@@ -242,21 +242,25 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
       Value input = genericOp.getInput(i);
-      if (!input.getType().cast<ShapedType>().getRank()) {
-        indexedValues[i] = std_load(input);
-      } else {
+      if (input.getType().cast<ShapedType>().getRank()) {
         ValueHandleArray indexing(makeCanonicalAffineApplies(
             b, loc, genericOp.getInputIndexingMap(i), allIvs));
         indexedValues[i] = std_load(input, indexing);
+      } else {
+        indexedValues[i] = std_load(input);
       }
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      indexedValues[nInputs + i] =
-          std_load(genericOp.getOutputBuffer(i), indexing);
+      Value output = genericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        indexedValues[nInputs + i] = std_load(output, indexing);
+      } else {
+        indexedValues[nInputs + i] = std_load(output);
+      }
     }
 
     auto funcOp = genericOp.getFunction();
@@ -267,9 +271,14 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
 
     // 3. Emit std_store.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing);
+      Value output = genericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        std_store(callOp->getResult(i), output, indexing);
+      } else {
+        std_store(callOp->getResult(i), output);
+      }
     }
     return;
   }
@@ -288,10 +297,15 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)),
-                genericOp.getOutputBuffer(i), indexing);
+      Value output = genericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        std_store(map.lookup(yieldOp->getOperand(i)),
+                  genericOp.getOutputBuffer(i), indexing);
+      } else {
+        std_store(map.lookup(yieldOp->getOperand(i)), output);
+      }
     }
   }
 };
@@ -348,21 +362,25 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
       Value input = indexedGenericOp.getInput(i);
-      if (!input.getType().cast<ShapedType>().getRank()) {
-        indexedValues[nLoops + i] = std_load(input);
-      } else {
+      if (input.getType().cast<ShapedType>().getRank()) {
         ValueHandleArray indexing(makeCanonicalAffineApplies(
             b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
         indexedValues[nLoops + i] = std_load(input, indexing);
+      } else {
+        indexedValues[nLoops + i] = std_load(input);
       }
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-      indexedValues[nLoops + nInputs + i] =
-          std_load(indexedGenericOp.getOutputBuffer(i), indexing);
+      Value output = indexedGenericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
+      } else {
+        indexedValues[nLoops + nInputs + i] = std_load(output);
+      }
     }
 
     if (auto funcOp = indexedGenericOp.getFunction()) {
@@ -372,10 +390,14 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 
     // 3. Emit std_store.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-      std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i),
-                indexing);
+      Value output = indexedGenericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        std_store(callOp->getResult(i), output, indexing);
+      } else {
+        std_store(callOp->getResult(i), output);
+      }
     }
     return;
   }
@@ -394,10 +416,14 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)),
-                indexedGenericOp.getOutputBuffer(i), indexing);
+      Value output = indexedGenericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        std_store(map.lookup(yieldOp->getOperand(i)), output, indexing);
+      } else {
+        std_store(map.lookup(yieldOp->getOperand(i)), output);
+      }
     }
   }
 };
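All six hunks apply the same guard: compute the canonical affine indexing only when the buffer has nonzero rank, and emit the unindexed load/store otherwise. As a sketch of the effect (value names are illustrative; the shape follows the new test's CHECK lines), lowering a 1-D reduction with a memref<f32> output now produces a loop body like:

    loop.for %i = %lb to %ub step %step {
      %a = load %arg0[%i] : memref<?xf32>  // rank-1 input: indexed access
      %b = load %arg1[] : memref<f32>      // zero-rank output: no indices
      %0 = addf %a, %b : f32
      store %0, %arg1[] : memref<f32>      // zero-rank output: no indices
    }

Before the fix, the output paths unconditionally applied the (i) -> (0) indexing map and emitted a one-index access on the zero-rank buffer.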

mlir/test/Dialect/Linalg/loops.mlir

Lines changed: 72 additions & 0 deletions
@@ -411,3 +411,75 @@ func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
 // CHECK: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
 // CHECK: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
 // CHECK: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
+
+#reduce_1D_access = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (0)>
+]
+
+#trait_reduce_1D = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #reduce_1D_access,
+  iterator_types = ["reduction"],
+  library_call = "some_reduce_external_fn"
+}
+
+func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
+{
+  linalg.generic #trait_reduce_1D %arg0, %arg1 {
+    ^bb(%a: f32, %b: f32) :
+      %0 = addf %a, %b : f32
+      linalg.yield %0 : f32
+  } : memref<?xf32>, memref<f32>
+  return
+}
+// CHECK-LABEL: @generic_op_1D_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   %[[a:.*]] = load %[[ARG0]][%[[i]]]
+// CHECK:   %[[b:.*]] = load %[[ARG1]][]
+// CHECK:   %[[c:.*]] = addf %[[a]], %[[b]] : f32
+// CHECK:   store %[[c]], %[[ARG1]][]
+
+
+#reduce_init_1D_access = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (0)>,
+  affine_map<(i) -> (0)>
+]
+
+#trait_reduce_init_1D = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #reduce_init_1D_access,
+  iterator_types = ["reduction"],
+  library_call = "some_reduce_external_fn"
+}
+
+func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
+                                   %arg1: memref<f32>,
+                                   %arg2: memref<f32>)
+{
+  linalg.indexed_generic #trait_reduce_init_1D %arg0, %arg1, %arg2 {
+    ^bb(%i : index, %a: f32, %b: f32, %c: f32) :
+      %0 = constant 0 : index
+      %1 = cmpi "eq", %0, %i : index
+      %2 = select %1, %b, %c : f32
+      %3 = addf %a, %2 : f32
+      linalg.yield %3 : f32
+  } : memref<?xf32>, memref<f32>, memref<f32>
+  return
+}
+// CHECK-LABEL: @indexed_generic_op_1D_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   %[[a:.*]] = load %[[ARG0]][%[[i]]]
+// CHECK:   %[[b:.*]] = load %[[ARG1]][]
+// CHECK:   %[[c:.*]] = load %[[ARG2]][]
+// CHECK:   %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
+// CHECK:   %[[e:.*]] = addf %[[a]], %[[d]]
+// CHECK:   store %[[e]], %[[ARG2]][]
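The RUN line sits outside this hunk and is not shown; presumably the new cases are exercised the same way as the rest of loops.mlir, along the lines of:

    // RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s

Note how the select in the indexed test picks the init value %b on iteration zero and the accumulator %c afterwards; both zero-rank buffers are loaded with an empty index list on every iteration, which is exactly the access form this commit enables.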
