-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[LLVM][CodeGen][SVE] Add isel for bfloat unordered reductions. #143540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LLVM][CodeGen][SVE] Add isel for bfloat unordered reductions. #143540
Conversation
The omissions are VECREDUCE_SEQ_* and MUL. The former goes down a different code path and the latter is generally unsupported across all element types.
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-aarch64 Author: Paul Walker (paulwalker-arm) ChangesThe omissions are VECREDUCE_SEQ_* and MUL. The former goes down a different code path and the latter is generally unsupported across all element types. A future extension is to use BFDOT for add reductions when available, especially for the nxv8bf16 case. Full diff: https://github.com/llvm/llvm-project/pull/143540.diff 4 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 4a1cd642233ef..1fc5fc66c56e5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -188,6 +188,7 @@ class VectorLegalizer {
void PromoteSETCC(SDNode *Node, SmallVectorImpl<SDValue> &Results);
void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
+ void PromoteVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results);
public:
VectorLegalizer(SelectionDAG& dag) :
@@ -500,20 +501,14 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
case ISD::VECREDUCE_FADD:
- case ISD::VECREDUCE_FMUL:
- case ISD::VECTOR_FIND_LAST_ACTIVE:
- Action = TLI.getOperationAction(Node->getOpcode(),
- Node->getOperand(0).getValueType());
- break;
case ISD::VECREDUCE_FMAX:
- case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMAXIMUM:
+ case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMINIMUM:
+ case ISD::VECREDUCE_FMUL:
+ case ISD::VECTOR_FIND_LAST_ACTIVE:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
- // Defer non-vector results to LegalizeDAG.
- if (Action == TargetLowering::Promote)
- Action = TargetLowering::Legal;
break;
case ISD::VECREDUCE_SEQ_FADD:
case ISD::VECREDUCE_SEQ_FMUL:
@@ -688,6 +683,22 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node,
Results.push_back(Round.getValue(1));
}
+void VectorLegalizer::PromoteVECREDUCE(SDNode *Node,
+ SmallVectorImpl<SDValue> &Results) {
+ MVT OpVT = Node->getOperand(0).getSimpleValueType();
+ assert(OpVT.isFloatingPoint() && "Expected floating point reduction!");
+ MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT);
+
+ SDLoc DL(Node);
+ SDValue NewOp = DAG.getNode(ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand(0));
+ SDValue Rdx =
+ DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp,
+ Node->getFlags());
+ SDValue Res = DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
+ DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+ Results.push_back(Res);
+}
+
void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
// For a few operations there is a specific concept for promotion based on
// the operand's type.
@@ -719,6 +730,13 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::STRICT_FMA:
PromoteSTRICT(Node, Results);
return;
+ case ISD::VECREDUCE_FADD:
+ case ISD::VECREDUCE_FMAX:
+ case ISD::VECREDUCE_FMAXIMUM:
+ case ISD::VECREDUCE_FMIN:
+ case ISD::VECREDUCE_FMINIMUM:
+ PromoteVECREDUCE(Node, Results);
+ return;
case ISD::FP_ROUND:
case ISD::FP_EXTEND:
// These operations are used to do promotion so they can't be promoted
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index a0ffb4b6d5a4c..0d23666383cda 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11412,13 +11412,9 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
SDValue Op = Node->getOperand(0);
EVT VT = Op.getValueType();
- if (VT.isScalableVector())
- report_fatal_error(
- "Expanding reductions for scalable vectors is undefined.");
-
// Try to use a shuffle reduction for power of two vectors.
if (VT.isPow2VectorType()) {
- while (VT.getVectorNumElements() > 1) {
+ while (VT.getVectorElementCount().isKnownMultipleOf(2)) {
EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
if (!isOperationLegalOrCustom(BaseOpcode, HalfVT))
break;
@@ -11427,9 +11423,18 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
std::tie(Lo, Hi) = DAG.SplitVector(Op, dl);
Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi, Node->getFlags());
VT = HalfVT;
+
+ // Stop if splitting is enough to make the reduction legal.
+ if (isOperationLegalOrCustom(Node->getOpcode(), HalfVT))
+ return DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Op,
+ Node->getFlags());
}
}
+ if (VT.isScalableVector())
+ report_fatal_error(
+ "Expanding reductions for scalable vectors is undefined.");
+
EVT EltVT = VT.getVectorElementType();
unsigned NumElts = VT.getVectorNumElements();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index caac00c5b2faa..9322f615827d9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1780,7 +1780,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto Opcode :
{ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
- ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC}) {
+ ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC,
+ ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMAXIMUM,
+ ISD::VECREDUCE_FMIN, ISD::VECREDUCE_FMINIMUM}) {
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
new file mode 100644
index 0000000000000..eb462c780437f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
@@ -0,0 +1,235 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; FADDV
+
+define bfloat @faddv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: faddv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call fast bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat zeroinitializer, <vscale x 2 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: faddv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call fast bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat zeroinitializer, <vscale x 4 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fadd z0.s, z0.s, z1.s
+; CHECK-NEXT: faddv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a)
+ ret bfloat %res
+}
+
+; FMAXNMV
+
+define bfloat @fmaxv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmaxnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fmaxv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fmaxnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fmaxv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fmaxnm z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: fmaxnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat> %a)
+ ret bfloat %res
+}
+
+; FMINNMV
+
+define bfloat @fminv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fminnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fminv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fminnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fminv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fminnm z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: fminnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat> %a)
+ ret bfloat %res
+}
+
+; FMAXV
+
+define bfloat @fmaximumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmaxv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fmaximumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fmaxv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fmaximumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fmax z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: fmaxv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat> %a)
+ ret bfloat %res
+}
+
+; FMINV
+
+define bfloat @fminimumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fminv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fminimumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fminv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fminimumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fmin z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: fminv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat> %a)
+ ret bfloat %res
+}
+
+declare bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat, <vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat, <vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat, <vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat>)
|
if (VT.isScalableVector()) | ||
report_fatal_error( | ||
"Expanding reductions for scalable vectors is undefined."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (VT.isScalableVector()) | |
report_fatal_error( | |
"Expanding reductions for scalable vectors is undefined."); | |
if (VT.isScalableVector()) { | |
reportFatalInternalError( | |
"expanding reductions for scalable vectors is undefined"); | |
} |
MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT); | ||
|
||
SDLoc DL(Node); | ||
SDValue NewOp = DAG.getNode(ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about the strictfp case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does strictfp apply here? There are no STRICT_VECREDUCE_ nodes. There are the ordered VECREDUCE_SEQ_ nodes, but they go down a different path so are not covered by this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. This pushes on the broader issue with the current strictfp strategy where we need to duplicate all possible FP intrinsics and we're missing them here.
So for these non-strict intrinsics it's not an issue, the issue is we don't have strict versions of these
@@ -688,6 +683,22 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node, | |||
Results.push_back(Round.getValue(1)); | |||
} | |||
|
|||
void VectorLegalizer::PromoteVECREDUCE(SDNode *Node, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't clearly indicate this is the for FP operations
DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp, | ||
Node->getFlags()); | ||
SDValue Res = DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx, | ||
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use 1 here? We are converting back so it should be exactly representable, but the other cases seem to not do this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I think we can for MIN/MAX but not for ADD because those cases might legitimately round to infinity?
ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
The omissions are VECREDUCE_SEQ_* and MUL. The former goes down a different code path and the latter is generally unsupported across all element types.
A future extension is to use BFDOT for add reductions when available, especially for the nxv8bf16 case.