Skip to content

Commit cb3542e

Browse files
committed
[MLIR][TOSA] Added lowerings for Reduce operations to Linalg
Lowerings for min, max, prod, and sum reduction operations on int and float values. This includes reduction tests for both cases. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D97893
1 parent 9525af7 commit cb3542e

File tree

2 files changed

+263
-6
lines changed

2 files changed

+263
-6
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 165 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,15 +269,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
269269
SmallVector<Type> opResultTypes;
270270
SmallVector<Value> initTensors;
271271
for (auto result : results) {
272-
auto resultType = result.getType().template cast<ShapedType>();
273-
if (!resultType.hasStaticShape())
272+
auto resultTy = result.getType().template cast<ShapedType>();
273+
if (!resultTy.hasStaticShape())
274274
return rewriter.notifyMatchFailure(
275275
operation,
276276
"tosa to linalg conversion expects statically shaped tensors");
277277

278278
initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
279-
loc, ArrayRef<Value>({}), resultType.getShape(),
280-
resultType.getElementType()));
279+
loc, ArrayRef<Value>({}), resultTy.getShape(),
280+
resultTy.getElementType()));
281281
opResultTypes.push_back(result.getType());
282282
}
283283

@@ -330,6 +330,152 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
330330
return success();
331331
}
332332

333+
// Returns the constant initial value for a given reduction operation. The
334+
// attribute type varies depending on the element type required.
335+
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
336+
PatternRewriter &rewriter) {
337+
if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
338+
return rewriter.getFloatAttr(elementTy, 0.0);
339+
340+
if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
341+
return rewriter.getIntegerAttr(elementTy, 0);
342+
343+
if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
344+
return rewriter.getFloatAttr(elementTy, 1.0);
345+
346+
if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
347+
return rewriter.getIntegerAttr(elementTy, 1);
348+
349+
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
350+
return rewriter.getFloatAttr(
351+
elementTy, APFloat::getLargest(
352+
elementTy.cast<FloatType>().getFloatSemantics(), false));
353+
354+
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
355+
return rewriter.getIntegerAttr(
356+
elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
357+
358+
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
359+
return rewriter.getFloatAttr(
360+
elementTy, APFloat::getLargest(
361+
elementTy.cast<FloatType>().getFloatSemantics(), true));
362+
363+
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
364+
return rewriter.getIntegerAttr(
365+
elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
366+
367+
return {};
368+
}
369+
370+
// Creates the body calculation for a reduction. The operations vary depending
371+
// on the input type.
372+
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
373+
ValueRange args,
374+
Type elementTy,
375+
PatternRewriter &rewriter) {
376+
Location loc = op->getLoc();
377+
if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
378+
return rewriter.create<AddFOp>(loc, args);
379+
}
380+
381+
if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
382+
return rewriter.create<AddIOp>(loc, args);
383+
}
384+
385+
if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
386+
return rewriter.create<MulFOp>(loc, args);
387+
}
388+
389+
if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
390+
return rewriter.create<MulIOp>(loc, args);
391+
}
392+
393+
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
394+
auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
395+
args[0], args[1]);
396+
return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
397+
}
398+
399+
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
400+
auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
401+
args[0], args[1]);
402+
return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
403+
}
404+
405+
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
406+
auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
407+
args[0], args[1]);
408+
return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
409+
}
410+
411+
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
412+
auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
413+
args[0], args[1]);
414+
return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
415+
}
416+
417+
return {};
418+
}
419+
420+
// Performs the match and rewrite for reduction operations. This includes
421+
// declaring a correctly sized initial value, and the linalg.generic operation
422+
// that reduces across the specified axis.
423+
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
424+
PatternRewriter &rewriter) {
425+
auto loc = op->getLoc();
426+
auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
427+
auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
428+
auto elementTy = resultTy.getElementType();
429+
Value input = op->getOperand(0);
430+
431+
// First fill the output buffer with the init value.
432+
auto initTensor = rewriter
433+
.create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
434+
resultTy.getShape(),
435+
resultTy.getElementType())
436+
.result();
437+
438+
auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
439+
if (!fillValueAttr)
440+
return rewriter.notifyMatchFailure(
441+
op, "No initial value found for reduction operation");
442+
443+
auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
444+
auto filledTensor =
445+
rewriter.create<linalg::FillOp>(loc, initTensor, fillValue).result();
446+
447+
SmallVector<AffineExpr, 2> srcExprs;
448+
SmallVector<AffineExpr, 2> dstExprs;
449+
SmallVector<StringRef, 4> iteratorTypes;
450+
for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
451+
srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
452+
453+
iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
454+
: getParallelIteratorTypeName());
455+
if (axis != i)
456+
dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
457+
}
458+
459+
bool didEncounterError = false;
460+
auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
461+
auto linalgOp = rewriter.create<linalg::GenericOp>(
462+
loc, resultTy, input, filledTensor, maps, iteratorTypes,
463+
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
464+
auto result = createLinalgBodyCalculationForReduceOp(
465+
op, blockArgs, elementTy, rewriter);
466+
if (result)
467+
didEncounterError = true;
468+
469+
nestedBuilder.create<linalg::YieldOp>(loc, result);
470+
});
471+
472+
if (!didEncounterError)
473+
return failure();
474+
475+
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
476+
return success();
477+
}
478+
333479
namespace {
334480

335481
template <typename SrcOp>
@@ -500,6 +646,17 @@ class IdentityNConverter : public OpRewritePattern<SrcOp> {
500646
}
501647
};
502648

649+
// Pattern that lowers one TOSA reduce op kind (min/max/sum/prod) to Linalg.
// SrcOp is the concrete tosa reduce op class; it must provide an axis()
// accessor, which all four reduce ops do.
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  // Thin adapter: the entire rewrite is delegated to the shared helper,
  // which dispatches on the op's concrete kind internally.
  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
  }
};
659+
503660
} // namespace
504661

505662
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -521,6 +678,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
521678
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
522679
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
523680
IdentityNConverter<tosa::IdentityOp>,
524-
IdentityNConverter<tosa::IdentityNOp>,
525-
ReshapeOpConverter, TransposeConverter>(context);
681+
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
682+
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
683+
ReduceConverter<tosa::ReduceProdOp>, ReshapeOpConverter,
684+
TransposeConverter>(context);
526685
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,101 @@ func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
335335
%1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
336336
return
337337
}
338+
339+
// -----
340+
341+
// Maps: MAP0 = identity over (d0, d1); MAP1 keeps only d1 (axis-0 reduction
// output); MAP2 keeps only d0 (axis-1 reduction output).
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>

// Float reductions: sum (both axes), prod, min, max on a 5x4 input.
// CHECK-LABEL: @reduce_float
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
  // Axis-0 sum: init tensor of shape [4], filled with the identity 0.0.
  // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
  // CHECK: [[CST0:%.+]] = constant 0.0
  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32>)
  // CHECK: ^bb0(%arg1: f32, %arg2: f32)
  // CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32
  // CHECK: linalg.yield [[RES]] : f32
  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>

  // Axis-1 sum: output shape [5], second iterator is the reduction.
  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
  // CHECK: [[CST0:%.+]] = constant 0.0
  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32>)
  // CHECK: ^bb0(%arg1: f32, %arg2: f32)
  // CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32
  // CHECK: linalg.yield [[RES]] : f32
  %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5xf32>

  // Product: identity 1.0, combined with mulf.
  // CHECK: constant 1.0
  // CHECK: linalg.fill
  // CHECK: linalg.generic
  // CHECK: mulf
  %2 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>

  // Min: identity is the largest finite f32; combined with cmpf olt + select.
  // CHECK: constant 3.40282347E+38 : f32
  // CHECK: linalg.fill
  // CHECK: linalg.generic
  // CHECK: cmpf olt
  // CHECK: select
  %3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>

  // Max: identity is the negated largest finite f32; cmpf ogt + select.
  // CHECK: constant -3.40282347E+38 : f32
  // CHECK: linalg.fill
  // CHECK: linalg.generic
  // CHECK: cmpf ogt
  // CHECK: select
  %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
  return
}
387+
388+
// -----
389+
390+
// Maps: MAP0 = identity over (d0, d1); MAP1 keeps only d1 (axis-0 reduction
// output); MAP2 keeps only d0 (axis-1 reduction output).
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>

// Integer reductions: sum (both axes), prod, min, max on a 5x4 input.
// CHECK-LABEL: @reduce_int
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi32>
func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
  // Axis-0 sum: init tensor of shape [4], filled with the identity 0.
  // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
  // CHECK: [[CST0:%.+]] = constant 0
  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32>)
  // CHECK: ^bb0(%arg1: i32, %arg2: i32)
  // CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32
  // CHECK: linalg.yield [[RES]] : i32
  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>

  // Axis-1 sum: output shape [5], second iterator is the reduction.
  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
  // CHECK: [[CST0:%.+]] = constant 0
  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32>)
  // CHECK: ^bb0(%arg1: i32, %arg2: i32)
  // CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32
  // CHECK: linalg.yield [[RES]] : i32
  %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5xi32>

  // Product: identity 1, combined with muli.
  // CHECK: constant 1
  // CHECK: linalg.fill
  // CHECK: linalg.generic
  // CHECK: muli
  %2 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>

  // Min: identity is the signed i32 maximum; cmpi slt + select.
  // CHECK: constant 2147483647 : i32
  // CHECK: linalg.fill
  // CHECK: linalg.generic
  // CHECK: cmpi slt
  // CHECK: select
  %3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>

  // Max: identity is the signed i32 minimum; cmpi sgt + select.
  // CHECK: constant -2147483648 : i32
  // CHECK: linalg.fill
  // CHECK: linalg.generic
  // CHECK: cmpi sgt
  // CHECK: select
  %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
  return
}

0 commit comments

Comments
 (0)