@@ -807,6 +807,25 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
807
807
// TypeCastingOp
808
808
// ===----------------------------------------------------------------------===//
809
809
810
+ static std::optional<spirv::FPRoundingMode>
811
+ convertArithRoundingModeToSPIRV (arith::RoundingMode roundingMode) {
812
+ switch (roundingMode) {
813
+ case arith::RoundingMode::downward:
814
+ return spirv::FPRoundingMode::RTN;
815
+ case arith::RoundingMode::to_nearest_even:
816
+ return spirv::FPRoundingMode::RTE;
817
+ case arith::RoundingMode::toward_zero:
818
+ return spirv::FPRoundingMode::RTZ;
819
+ case arith::RoundingMode::upward:
820
+ return spirv::FPRoundingMode::RTP;
821
+ case arith::RoundingMode::to_nearest_away:
822
+ // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
823
+ // (as of SPIR-V 1.6)
824
+ return std::nullopt;
825
+ }
826
+ llvm_unreachable (" Unhandled rounding mode" );
827
+ }
828
+
810
829
// / Converts type-casting standard operations to SPIR-V operations.
811
830
template <typename Op, typename SPIRVOp>
812
831
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
@@ -829,15 +848,22 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
829
848
// Then we can just erase this operation by forwarding its operand.
830
849
rewriter.replaceOp (op, adaptor.getOperands ().front ());
831
850
} else {
832
- rewriter.template replaceOpWithNewOp <SPIRVOp>(op, dstType,
833
- adaptor.getOperands ());
851
+ auto newOp = rewriter.template replaceOpWithNewOp <SPIRVOp>(
852
+ op, dstType, adaptor.getOperands ());
834
853
if (auto roundingModeOp =
835
854
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
836
855
if (arith::RoundingModeAttr roundingMode =
837
856
roundingModeOp.getRoundingModeAttr ()) {
838
- // TODO: Perform rounding mode attribute conversion and attach to new
839
- // operation when defined in the dialect.
840
- return failure ();
857
+ if (auto rm =
858
+ convertArithRoundingModeToSPIRV (roundingMode.getValue ())) {
859
+ newOp->setAttr (
860
+ getDecorationString (spirv::Decoration::FPRoundingMode),
861
+ spirv::FPRoundingModeAttr::get (rewriter.getContext (), *rm));
862
+ } else {
863
+ return rewriter.notifyMatchFailure (
864
+ op->getLoc (),
865
+ llvm::formatv (" unsupported rounding mode '{0}'" , roundingMode));
866
+ }
841
867
}
842
868
}
843
869
}
0 commit comments