Skip to content

[mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering #125668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 201 additions & 10 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,47 @@
#include "llvm/ADT/Sequence.h"

#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;

// Helper function to materialize the semantically correct compare and select
// operations given a binary operation with a specific NaN propagation mode.
//
// In the case of "PROPAGATE" semantics no compare and selection is required and
// this function does nothing.
//
// In the case of "IGNORE" semantics this function materializes a comparison of
// the current operands to the op which will return true for any NaN
// argument and then selects between the non-NaN operation argument and the
// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
// code:
//
// binary<op>(lhs, rhs):
// result = op(lhs, rhs)
// if lhs == NaN return rhs
// if rhs == NaN return lhs
// return result
template <typename OpTy>
static Value
materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
Value lhs, Value rhs, Value result) {
auto nanMode = op.getNanMode();
if (nanMode == "PROPAGATE")
return result;

// Unordered comparison of NaN against itself will always return true.
Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
Value rhsOrResult =
rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
rhsOrResult);
}

template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
Expand Down Expand Up @@ -367,7 +404,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(

// tosa::MaximumOp
if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
rewriter, args[0], args[1], max);
}

if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
Expand All @@ -376,7 +415,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(

// tosa::MinimumOp
if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
rewriter, args[0], args[1], min);
}

if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
Expand Down Expand Up @@ -404,7 +445,31 @@ static Value createLinalgBodyCalculationForElementwiseOp(
loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
auto max = rewriter.create<arith::ConstantOp>(
loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
return clampFloatHelper(loc, args[0], min, max, rewriter);
auto result = clampFloatHelper(loc, args[0], min, max, rewriter);

auto clampOp = llvm::cast<tosa::ClampOp>(op);
const auto nanMode = clampOp.getNanMode();
// In the case of "PROPAGATE" semantics no compare and selection is
// required.
if (nanMode == "PROPAGATE")
return result;

// In the case of "IGNORE" semantics materialize a comparison
// of the current operand to the reduction which will return true for a NaN
// argument and then selects between the initial reduction value and the
// calculated result based on whether the argument is NaN or not. In pseudo
// code:
//
// reduce<op>(x, init):
// result = op(init, x)
// return init if x == NaN else result

// Unordered comparison of NaN against itself will always return true.
Value isNaN = rewriter.create<arith::CmpFOp>(
op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
// TOSA specifies that in "ignore" NaN mode the result is "min" if the input
// is NaN.
return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
}

if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
Expand Down Expand Up @@ -1078,7 +1143,8 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
PatternRewriter &rewriter) {
auto loc = op->getLoc();
auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
Expand All @@ -1096,6 +1162,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
}
}

SmallVector<Value> inputs, outputs;
inputs.push_back(input);

// First fill the output buffer with the init value.
auto emptyTensor =
rewriter
Expand All @@ -1113,26 +1182,127 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
.create<linalg::FillOp>(loc, ValueRange{fillValue},
ValueRange{emptyTensor})
.result();
outputs.push_back(filledTensor);

bool isNanIgnoreMode = false;
if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
if (op.getNanMode() == "IGNORE") {
isNanIgnoreMode = true;
// Because the TOSA spec requires the result be NaN iff all elements in
// the reduction are NaN we can't simply perform a compare and select.
// Additionally we have to keep track of whether we've seen any non-NaN
// values and then do a final select based on this predicate.
auto trueAttr = rewriter.getBoolAttr(true);
auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
auto emptyBoolTensor =
rewriter
.create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
dynDims)
.getResult();
auto allResultsNaNTensor =
rewriter
.create<linalg::FillOp>(loc, ValueRange{trueValue},
ValueRange{emptyBoolTensor})
.result();
// Note that because the linalg::ReduceOp has two variadic arguments
// (inputs and outputs) and it has the SameVariadicOperandSize trait we
// need to have the same number of inputs and outputs.
//
// The second input isn't actually used anywhere since the value used to
// update the NaN flag is calculated inside the body of the reduction and
// then used to update an out value.
// In order to satisfy type constraints we just pass another copy of the
// input here.
inputs.push_back(input);
outputs.push_back(allResultsNaNTensor);
}
}

bool didEncounterError = false;
auto linalgOp = rewriter.create<linalg::ReduceOp>(
loc, input, filledTensor, axis,
linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
loc, inputs, outputs, axis,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
std::array<Value, 2> binaryArgs{
blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
auto result = createLinalgBodyCalculationForReduceOp(
op, blockArgs, elementTy, rewriter);
op, binaryArgs, elementTy, rewriter);
if (result)
didEncounterError = true;

nestedBuilder.create<linalg::YieldOp>(loc, result);
SmallVector<Value> resultsToYield;
if (isNanIgnoreMode) {
auto inputValue = blockArgs[0];
auto initialValue = blockArgs[2];
auto oldAllResultsNanFlagValue = blockArgs[3];

// Unordered comparison of NaN against itself will always return true.
Value isNaN = nestedBuilder.create<arith::CmpFOp>(
op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
// If we've encountered a NaN, take the non-NaN value.
auto selectOp = nestedBuilder.create<arith::SelectOp>(
op->getLoc(), isNaN, initialValue, result);
// Update the flag which keeps track of whether we have seen a non-NaN
// value.
auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
op->getLoc(), oldAllResultsNanFlagValue, isNaN);
resultsToYield.push_back(selectOp);
resultsToYield.push_back(newAllResultsNanFlagValue);
} else {
resultsToYield.push_back(result);
}
nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
});

if (!didEncounterError)
return rewriter.notifyMatchFailure(
op, "unable to create linalg.generic body for reduce op");

if (isNanIgnoreMode) {
// Materialize a check to see whether we encountered any non-NaN values, if
// we didn't we need to select a tensor of NaNs since the result will just
// be the initial identity value propagated through all the compares and
// selects inside the reduction.

// Create a tensor full of NaNs.
auto nanValueAttr = rewriter.getFloatAttr(
elementTy,
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
auto emptyNanTensor =
rewriter
.create<tensor::EmptyOp>(loc, reduceShape,
resultTy.getElementType(), dynDims)
.getResult();
auto nanFilledTensor =
rewriter
.create<linalg::FillOp>(loc, ValueRange{nanValue},
ValueRange{emptyNanTensor})
.result();

// Create an empty tensor, non need to fill this since it will be
// overwritten by the select.
auto finalEmptyTensor =
rewriter
.create<tensor::EmptyOp>(loc, reduceShape,
resultTy.getElementType(), dynDims)
.getResult();

// Do a selection between the tensors akin to:
// result = NaN if "all results NaN" else result.
SmallVector<Value> ins, outs;
ins.push_back(linalgOp->getOpResult(1));
ins.push_back(nanFilledTensor);
ins.push_back(linalgOp->getResult(0));
outs.push_back(finalEmptyTensor);
auto linalgSelect =
rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
linalgOp = linalgSelect;
}

SmallVector<ReassociationExprs, 4> reassociationMap;
uint64_t expandInputRank =
cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
reassociationMap.resize(expandInputRank);

for (uint64_t i = 0; i < expandInputRank; i++) {
Expand All @@ -1151,7 +1321,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
// not have access to such information. This matters when handling dynamically
// sized tensors.
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
op, resultTy, linalgOp.getResults()[0], reassociationMap);
op, resultTy, linalgOp->getResults()[0], reassociationMap);
return success();
}

Expand Down Expand Up @@ -2097,6 +2267,27 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
nestedLoc, predicate, newValue, oldValue);
auto resultIndex = rewriter.create<arith::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);

// Check if we need to materialize compare and select for the given
// NaN propagation mode.

// "PROPAGATE" matches the default NaN propagation mode of the arith
// dialect so no compare and select is required.
//
// In the case "IGNORE" we check if the current argument is NaN and
// select the old index and value otherwise take the updated index and
// value.
if (const auto nanMode = argmaxOp.getNanMode(); nanMode == "IGNORE") {
// Unordered comparison of NaN against itself will always return
// true.
Value isNaN = rewriter.create<arith::CmpFOp>(
argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
newValue);
resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
oldValue, resultMax);
resultIndex = rewriter.create<arith::SelectOp>(
nestedLoc, isNaN, oldIndex, resultIndex);
}
nestedBuilder.create<linalg::YieldOp>(
nestedLoc, ValueRange({resultIndex, resultMax}));
});
Expand Down
41 changes: 37 additions & 4 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,11 +724,44 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
filledEmptyTensor, strideAttr, dilationAttr);
} else {
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
filledEmptyTensor, strideAttr, dilationAttr);
return llvm::success();
}

auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
op->getLoc(), ArrayRef<Type>{resultTy},
ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
dilationAttr);

rewriter.replaceOp(op, resultOp);
// "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
// compare and select materialization is required.
//
// In the case of "IGNORE" we need to insert a compare and select. Since
// we've already produced a named op we will just take its body and modify
// it to include the appropriate checks. If the current value is NaN the
// old value of pool will be taken otherwise we use the result.
if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") {
auto genericOp = rewriter.create<linalg::GenericOp>(
op->getLoc(), resultOp.getType(0), resultOp.getInputs(),
resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
resultOp.getIteratorTypesArray(),
[&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
IRMapping map;
auto oldBlock = resultOp.getRegion().begin();
auto oldArgs = oldBlock->getArguments();
auto &oldMaxOp = *resultOp.getBlock()->begin();
map.map(oldArgs, blockArgs);
auto *newOp = opBuilder.clone(oldMaxOp, map);
Value isNaN = opBuilder.create<arith::CmpFOp>(
op->getLoc(), arith::CmpFPredicate::UNO, blockArgs.front(),
blockArgs.front());
auto selectOp = opBuilder.create<arith::SelectOp>(
op->getLoc(), isNaN, blockArgs.back(), newOp->getResult(0));
opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
});
rewriter.replaceOp(resultOp, genericOp);
}

return success();
}
};
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -906,3 +906,27 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
%1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
return
}

// -----

// CHECK-LABEL: @max_pool2d_nan_propagate
func.func @max_pool2d_nan_propagate(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
// CHECK: linalg.pooling_nhwc_max
// CHECK-NOT: linalg.generic
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
return %0 : tensor<1x4x32x62xf32>
}

// -----

// CHECK-LABEL: @max_pool2d_nan_ignore
func.func @max_pool2d_nan_ignore(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
// CHECK-NOT: linalg.pooling_nhwc_max
// CHECK: linalg.generic
// CHECK: arith.maximumf
// CHECK: arith.cmpf uno
// CHECK: arith.select
// CHECK: linalg.yield
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
return %0: tensor<1x4x32x62xf32>
}
Loading