Skip to content

Commit b537df9

Browse files
authored
[mlir][arith][spirv] Convert arith.truncf rounding mode to SPIR-V (#101547)
Resolves #87050.
1 parent 293df8a commit b537df9

File tree

3 files changed

+65
-7
lines changed

3 files changed

+65
-7
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,25 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
807807
// TypeCastingOp
808808
//===----------------------------------------------------------------------===//
809809

810+
static std::optional<spirv::FPRoundingMode>
811+
convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
812+
switch (roundingMode) {
813+
case arith::RoundingMode::downward:
814+
return spirv::FPRoundingMode::RTN;
815+
case arith::RoundingMode::to_nearest_even:
816+
return spirv::FPRoundingMode::RTE;
817+
case arith::RoundingMode::toward_zero:
818+
return spirv::FPRoundingMode::RTZ;
819+
case arith::RoundingMode::upward:
820+
return spirv::FPRoundingMode::RTP;
821+
case arith::RoundingMode::to_nearest_away:
822+
// SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
823+
// (as of SPIR-V 1.6)
824+
return std::nullopt;
825+
}
826+
llvm_unreachable("Unhandled rounding mode");
827+
}
828+
810829
/// Converts type-casting standard operations to SPIR-V operations.
811830
template <typename Op, typename SPIRVOp>
812831
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
@@ -829,15 +848,22 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
829848
// Then we can just erase this operation by forwarding its operand.
830849
rewriter.replaceOp(op, adaptor.getOperands().front());
831850
} else {
832-
rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
833-
adaptor.getOperands());
851+
auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
852+
op, dstType, adaptor.getOperands());
834853
if (auto roundingModeOp =
835854
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
836855
if (arith::RoundingModeAttr roundingMode =
837856
roundingModeOp.getRoundingModeAttr()) {
838-
// TODO: Perform rounding mode attribute conversion and attach to new
839-
// operation when defined in the dialect.
840-
return failure();
857+
if (auto rm =
858+
convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
859+
newOp->setAttr(
860+
getDecorationString(spirv::Decoration::FPRoundingMode),
861+
spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
862+
} else {
863+
return rewriter.notifyMatchFailure(
864+
op->getLoc(),
865+
llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
866+
}
841867
}
842868
}
843869
}

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,22 @@
11
// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s
22

3+
///===----------------------------------------------------------------------===//
4+
// Cast ops
5+
//===----------------------------------------------------------------------===//
6+
7+
module attributes {
8+
spirv.target_env = #spirv.target_env<
9+
#spirv.vce<v1.0, [Float16, Kernel], []>, #spirv.resource_limits<>>
10+
} {
11+
12+
func.func @experimental_constrained_fptrunc(%arg0 : f32) {
13+
// expected-error@+1 {{failed to legalize operation 'arith.truncf'}}
14+
%3 = arith.truncf %arg0 to_nearest_away : f32 to f16
15+
return
16+
}
17+
18+
} // end module
19+
320
///===----------------------------------------------------------------------===//
421
// Binary ops
522
//===----------------------------------------------------------------------===//

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func.func @one_elem_vector(%arg0: vector<1xi32>) {
221221
// -----
222222

223223
//===----------------------------------------------------------------------===//
224-
// std bit ops
224+
// Bit ops
225225
//===----------------------------------------------------------------------===//
226226

227227
module attributes {
@@ -653,7 +653,7 @@ func.func @corner_cases() {
653653
// -----
654654

655655
//===----------------------------------------------------------------------===//
656-
// std cast ops
656+
// Cast ops
657657
//===----------------------------------------------------------------------===//
658658

659659
module attributes {
@@ -754,6 +754,21 @@ func.func @fptrunc2(%arg0: f32) -> f16 {
754754
return %0 : f16
755755
}
756756

757+
758+
// CHECK-LABEL: @experimental_constrained_fptrunc
759+
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
760+
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTE>} : f64 to f32
761+
%0 = arith.truncf %arg0 to_nearest_even : f64 to f32
762+
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f64 to f32
763+
%1 = arith.truncf %arg0 downward : f64 to f32
764+
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTP>} : f64 to f32
765+
%2 = arith.truncf %arg0 upward : f64 to f32
766+
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTZ>} : f64 to f32
767+
%3 = arith.truncf %arg0 toward_zero : f64 to f32
768+
return
769+
}
770+
771+
757772
// CHECK-LABEL: @sitofp1
758773
func.func @sitofp1(%arg0 : i32) -> f32 {
759774
// CHECK: spirv.ConvertSToF %{{.*}} : i32 to f32

0 commit comments

Comments
 (0)