Skip to content

[mlir][arith][spirv] Convert arith.truncf rounding mode to SPIR-V #101547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,25 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
// TypeCastingOp
//===----------------------------------------------------------------------===//

static std::optional<spirv::FPRoundingMode>
convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
switch (roundingMode) {
case arith::RoundingMode::downward:
return spirv::FPRoundingMode::RTN;
case arith::RoundingMode::to_nearest_even:
return spirv::FPRoundingMode::RTE;
case arith::RoundingMode::toward_zero:
return spirv::FPRoundingMode::RTZ;
case arith::RoundingMode::upward:
return spirv::FPRoundingMode::RTP;
case arith::RoundingMode::to_nearest_away:
// SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
// (as of SPIR-V 1.6)
return std::nullopt;
}
llvm_unreachable("Unhandled rounding mode");
}

/// Converts type-casting standard operations to SPIR-V operations.
template <typename Op, typename SPIRVOp>
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
Expand All @@ -829,15 +848,22 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
// Then we can just erase this operation by forwarding its operand.
rewriter.replaceOp(op, adaptor.getOperands().front());
} else {
rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
adaptor.getOperands());
auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
op, dstType, adaptor.getOperands());
if (auto roundingModeOp =
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
if (arith::RoundingModeAttr roundingMode =
roundingModeOp.getRoundingModeAttr()) {
// TODO: Perform rounding mode attribute conversion and attach to new
// operation when defined in the dialect.
return failure();
if (auto rm =
convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
newOp->setAttr(
getDecorationString(spirv::Decoration::FPRoundingMode),
spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
} else {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s

///===----------------------------------------------------------------------===//
// Cast ops
//===----------------------------------------------------------------------===//

module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Float16, Kernel], []>, #spirv.resource_limits<>>
} {

func.func @experimental_constrained_fptrunc(%arg0 : f32) {
// expected-error@+1 {{failed to legalize operation 'arith.truncf'}}
%3 = arith.truncf %arg0 to_nearest_away : f32 to f16
return
}

} // end module

///===----------------------------------------------------------------------===//
// Binary ops
//===----------------------------------------------------------------------===//
Expand Down
19 changes: 17 additions & 2 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func.func @one_elem_vector(%arg0: vector<1xi32>) {
// -----

//===----------------------------------------------------------------------===//
// std bit ops
// Bit ops
//===----------------------------------------------------------------------===//

module attributes {
Expand Down Expand Up @@ -653,7 +653,7 @@ func.func @corner_cases() {
// -----

//===----------------------------------------------------------------------===//
// std cast ops
// Cast ops
//===----------------------------------------------------------------------===//

module attributes {
Expand Down Expand Up @@ -754,6 +754,21 @@ func.func @fptrunc2(%arg0: f32) -> f16 {
return %0 : f16
}


// CHECK-LABEL: @experimental_constrained_fptrunc
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTE>} : f64 to f32
%0 = arith.truncf %arg0 to_nearest_even : f64 to f32
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f64 to f32
%1 = arith.truncf %arg0 downward : f64 to f32
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTP>} : f64 to f32
%2 = arith.truncf %arg0 upward : f64 to f32
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTZ>} : f64 to f32
%3 = arith.truncf %arg0 toward_zero : f64 to f32
return
}


// CHECK-LABEL: @sitofp1
func.func @sitofp1(%arg0 : i32) -> f32 {
// CHECK: spirv.ConvertSToF %{{.*}} : i32 to f32
Expand Down
Loading