-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[RISCV] Vectorize phi for loop carried @llvm.vector.reduce.fadd #78244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a14e251
910dc9f
05488f6
f63eac7
85b7410
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,8 +18,9 @@ | |
#include "llvm/ADT/Statistic.h" | ||
#include "llvm/Analysis/ValueTracking.h" | ||
#include "llvm/CodeGen/TargetPassConfig.h" | ||
#include "llvm/IR/IRBuilder.h" | ||
#include "llvm/IR/InstVisitor.h" | ||
#include "llvm/IR/PatternMatch.h" | ||
#include "llvm/IR/Intrinsics.h" | ||
#include "llvm/InitializePasses.h" | ||
#include "llvm/Pass.h" | ||
|
||
|
@@ -51,6 +52,7 @@ class RISCVCodeGenPrepare : public FunctionPass, | |
|
||
bool visitInstruction(Instruction &I) { return false; } | ||
bool visitAnd(BinaryOperator &BO); | ||
bool visitIntrinsicInst(IntrinsicInst &I); | ||
}; | ||
|
||
} // end anonymous namespace | ||
|
@@ -103,6 +105,62 @@ bool RISCVCodeGenPrepare::visitAnd(BinaryOperator &BO) { | |
return true; | ||
} | ||
|
||
// LLVM vector reduction intrinsics return a scalar result, but on RISC-V vector | ||
// reduction instructions write the result in the first element of a vector | ||
// register. So when a reduction in a loop uses a scalar phi, we end up with | ||
// unnecessary scalar moves: | ||
// | ||
// loop: | ||
// vfmv.s.f v10, fa0 | ||
// vfredosum.vs v8, v8, v10 | ||
// vfmv.f.s fa0, v8 | ||
// | ||
// This mainly affects ordered fadd reductions, since other types of reduction | ||
// typically use element-wise vectorisation in the loop body. This tries to | ||
// vectorize any scalar phis that feed into a fadd reduction: | ||
// | ||
// loop: | ||
// %phi = phi float [ ..., %entry ], [ %acc, %loop ] | ||
// %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi, <vscale x 2 x float> %vec) | ||
// | ||
// -> | ||
// | ||
// loop: | ||
// %phi = phi <vscale x 2 x float> [ ..., %entry ], [ %acc.vec, %loop ] | ||
// %phi.scalar = extractelement <vscale x 2 x float> %phi, i64 0 | ||
// %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi.scalar, <vscale x 2 x float> %vec) | ||
// %acc.vec = insertelement <vscale x 2 x float> poison, float %acc, i64 0 | ||
// | ||
// Which eliminates the scalar -> vector -> scalar crossing during instruction | ||
// selection. | ||
bool RISCVCodeGenPrepare::visitIntrinsicInst(IntrinsicInst &I) { | ||
if (I.getIntrinsicID() != Intrinsic::vector_reduce_fadd) | ||
return false; | ||
|
||
auto *PHI = dyn_cast<PHINode>(I.getOperand(0)); | ||
if (!PHI || !PHI->hasOneUse() || | ||
!llvm::is_contained(PHI->incoming_values(), &I)) | ||
return false; | ||
|
||
Type *VecTy = I.getOperand(1)->getType(); | ||
IRBuilder<> Builder(PHI); | ||
auto *VecPHI = Builder.CreatePHI(VecTy, PHI->getNumIncomingValues()); | ||
|
||
for (auto *BB : PHI->blocks()) { | ||
Builder.SetInsertPoint(BB->getTerminator()); | ||
Value *InsertElt = Builder.CreateInsertElement( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Doesn't look like we verified that we're in a loop. If the phi is just the merge of control flow from an if/else or something does this unnecessarily hoist an insertelement into the if/else? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The initial revision of this PR checked that the reduction was one of the incoming values in phi: 05488f6 but I relaxed it to address the comment in #78244 (comment) I haven't included any tests for non-loop control flow though. @dtcxzyw should we maybe start with this patch restricted to loops and then relax it in a follow up patch? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please revert to the loop specific version. Please also restrict this - in the initial patch - to the case where the only use of the phi is the intrinsic. Let's start with something very narrowly focused on the motivating case and generalize in separate patches if desired. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Yeah, we can generalize it in the future if we find it exists in some real-world applications. |
||
VecTy, PHI->getIncomingValueForBlock(BB), (uint64_t)0); | ||
VecPHI->addIncoming(InsertElt, BB); | ||
} | ||
|
||
Builder.SetInsertPoint(&I); | ||
I.setOperand(0, Builder.CreateExtractElement(VecPHI, (uint64_t)0)); | ||
|
||
PHI->eraseFromParent(); | ||
|
||
return true; | ||
dtcxzyw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
bool RISCVCodeGenPrepare::runOnFunction(Function &F) { | ||
if (skipFunction(F)) | ||
return false; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4 | ||
; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s | ||
|
||
declare i64 @llvm.vscale.i64() | ||
declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>) | ||
|
||
define float @reduce_fadd(ptr %f) { | ||
; CHECK-LABEL: reduce_fadd: | ||
; CHECK: # %bb.0: # %entry | ||
; CHECK-NEXT: csrr a2, vlenb | ||
; CHECK-NEXT: srli a1, a2, 1 | ||
; CHECK-NEXT: vsetvli a3, zero, e32, m1, ta, ma | ||
; CHECK-NEXT: vmv.s.x v8, zero | ||
; CHECK-NEXT: slli a2, a2, 1 | ||
; CHECK-NEXT: li a3, 1024 | ||
; CHECK-NEXT: .LBB0_1: # %vector.body | ||
; CHECK-NEXT: # =>This Inner Loop Header: Depth=1 | ||
; CHECK-NEXT: vl2re32.v v10, (a0) | ||
; CHECK-NEXT: vsetvli a4, zero, e32, m2, ta, ma | ||
; CHECK-NEXT: vfredosum.vs v8, v10, v8 | ||
; CHECK-NEXT: sub a3, a3, a1 | ||
; CHECK-NEXT: add a0, a0, a2 | ||
; CHECK-NEXT: bnez a3, .LBB0_1 | ||
; CHECK-NEXT: # %bb.2: # %exit | ||
; CHECK-NEXT: vfmv.f.s fa0, v8 | ||
; CHECK-NEXT: ret | ||
entry: | ||
%vscale = tail call i64 @llvm.vscale.i64() | ||
%vecsize = shl nuw nsw i64 %vscale, 2 | ||
br label %vector.body | ||
|
||
vector.body: | ||
%index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ] | ||
%vec.phi = phi float [ 0.000000e+00, %entry ], [ %acc, %vector.body ] | ||
%gep = getelementptr inbounds float, ptr %f, i64 %index | ||
%wide.load = load <vscale x 4 x float>, ptr %gep, align 4 | ||
%acc = tail call float @llvm.vector.reduce.fadd.nxv4f32(float %vec.phi, <vscale x 4 x float> %wide.load) | ||
%index.next = add nuw i64 %index, %vecsize | ||
%done = icmp eq i64 %index.next, 1024 | ||
br i1 %done, label %exit, label %vector.body | ||
|
||
exit: | ||
ret float %acc | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4 | ||
; RUN: opt %s -S -riscv-codegenprepare -mtriple=riscv64 -mattr=+v | FileCheck %s | ||
|
||
declare i64 @llvm.vscale.i64() | ||
declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>) | ||
|
||
define float @reduce_fadd(ptr %f) { | ||
; CHECK-LABEL: define float @reduce_fadd( | ||
; CHECK-SAME: ptr [[F:%.*]]) #[[ATTR2:[0-9]+]] { | ||
; CHECK-NEXT: entry: | ||
; CHECK-NEXT: [[VSCALE:%.*]] = tail call i64 @llvm.vscale.i64() | ||
; CHECK-NEXT: [[VECSIZE:%.*]] = shl nuw nsw i64 [[VSCALE]], 2 | ||
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] | ||
; CHECK: vector.body: | ||
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] | ||
; CHECK-NEXT: [[TMP0:%.*]] = phi <vscale x 4 x float> [ insertelement (<vscale x 4 x float> poison, float 0.000000e+00, i64 0), [[ENTRY]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ] | ||
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds float, ptr [[F]], i64 [[INDEX]] | ||
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[GEP]], align 4 | ||
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <vscale x 4 x float> [[TMP0]], i64 0 | ||
; CHECK-NEXT: [[ACC:%.*]] = tail call float @llvm.vector.reduce.fadd.nxv4f32(float [[TMP1]], <vscale x 4 x float> [[WIDE_LOAD]]) | ||
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[VECSIZE]] | ||
; CHECK-NEXT: [[DONE:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024 | ||
; CHECK-NEXT: [[TMP2]] = insertelement <vscale x 4 x float> poison, float [[ACC]], i64 0 | ||
; CHECK-NEXT: br i1 [[DONE]], label [[EXIT:%.*]], label [[VECTOR_BODY]] | ||
; CHECK: exit: | ||
; CHECK-NEXT: ret float [[ACC]] | ||
; | ||
|
||
entry: | ||
%vscale = tail call i64 @llvm.vscale.i64() | ||
%vecsize = shl nuw nsw i64 %vscale, 2 | ||
br label %vector.body | ||
|
||
vector.body: | ||
%index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ] | ||
%vec.phi = phi float [ 0.000000e+00, %entry ], [ %acc, %vector.body ] | ||
%gep = getelementptr inbounds float, ptr %f, i64 %index | ||
%wide.load = load <vscale x 4 x float>, ptr %gep, align 4 | ||
%acc = tail call float @llvm.vector.reduce.fadd.nxv4f32(float %vec.phi, <vscale x 4 x float> %wide.load) | ||
%index.next = add nuw i64 %index, %vecsize | ||
%done = icmp eq i64 %index.next, 1024 | ||
br i1 %done, label %exit, label %vector.body | ||
|
||
exit: | ||
ret float %acc | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just for the record... This isn't checking that the phi is in a loop. It's checking it is in a cycle, which is not quite the same thing. In particular, we allow both non-loop cycles, and loops in unreachable code which aren't "loops". Not a problem here, just pointing it out for future reference.