Skip to content

[NVPTX] Remove NVPTX::IMAD opcode, and rely on instruction selection only #121724

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::StoreV4)
MAKE_CASE(NVPTXISD::FSHL_CLAMP)
MAKE_CASE(NVPTXISD::FSHR_CLAMP)
MAKE_CASE(NVPTXISD::IMAD)
MAKE_CASE(NVPTXISD::BFE)
MAKE_CASE(NVPTXISD::BFI)
MAKE_CASE(NVPTXISD::PRMT)
Expand Down Expand Up @@ -4451,14 +4450,8 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
if (!N0.getNode()->hasOneUse())
return SDValue();

// fold (add (mul a, b), c) -> (mad a, b, c)
//
if (N0.getOpcode() == ISD::MUL)
return DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, N0.getOperand(0),
N0.getOperand(1), N1);

// fold (add (select cond, 0, (mul a, b)), c)
// -> (select cond, c, (mad a, b, c))
// -> (select cond, c, (add (mul a, b), c))
//
if (N0.getOpcode() == ISD::SELECT) {
unsigned ZeroOpNum;
Expand All @@ -4473,8 +4466,10 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
return SDValue();

SDValue MAD = DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
M->getOperand(0), M->getOperand(1), N1);
SDLoc DL(N);
SDValue Mul =
DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
SDValue MAD = DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, N1);
return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
((ZeroOpNum == 1) ? N1 : MAD),
((ZeroOpNum == 1) ? MAD : N1));
Expand Down Expand Up @@ -4911,8 +4906,10 @@ static SDValue matchMADConstOnePattern(SDValue Add) {
static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
TargetLowering::DAGCombinerInfo &DCI) {

if (SDValue Y = matchMADConstOnePattern(Add))
return DCI.DAG.getNode(NVPTXISD::IMAD, DL, VT, X, Y, X);
if (SDValue Y = matchMADConstOnePattern(Add)) {
SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
return DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, X);
}

return SDValue();
}
Expand Down Expand Up @@ -4959,7 +4956,7 @@ PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,

SDLoc DL(N);

// (mul x, (add y, 1)) -> (mad x, y, x)
// (mul x, (add y, 1)) -> (add (mul x, y), x)
if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
return Res;
if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ enum NodeType : unsigned {
FSHR_CLAMP,
MUL_WIDE_SIGNED,
MUL_WIDE_UNSIGNED,
IMAD,
SETP_F16X2,
SETP_BF16X2,
BFE,
Expand Down
101 changes: 34 additions & 67 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def hasLDG : Predicate<"Subtarget->hasLDG()">;
def hasLDU : Predicate<"Subtarget->hasLDU()">;
def hasPTXASUnreachableBug : Predicate<"Subtarget->hasPTXASUnreachableBug()">;
def noPTXASUnreachableBug : Predicate<"!Subtarget->hasPTXASUnreachableBug()">;
def hasOptEnabled : Predicate<"TM.getOptLevel() != CodeGenOptLevel::None">;

def doF32FTZ : Predicate<"useF32FTZ()">;
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
Expand Down Expand Up @@ -1092,73 +1093,39 @@ def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)),
//
// Integer multiply-add
//
def SDTIMAD :
SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>,
SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>]>;
def imad : SDNode<"NVPTXISD::IMAD", SDTIMAD>;

def MAD16rrr :
NVPTXInst<(outs Int16Regs:$dst),
(ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
"mad.lo.s16 \t$dst, $a, $b, $c;",
[(set i16:$dst, (imad i16:$a, i16:$b, i16:$c))]>;
def MAD16rri :
NVPTXInst<(outs Int16Regs:$dst),
(ins Int16Regs:$a, Int16Regs:$b, i16imm:$c),
"mad.lo.s16 \t$dst, $a, $b, $c;",
[(set i16:$dst, (imad i16:$a, i16:$b, imm:$c))]>;
def MAD16rir :
NVPTXInst<(outs Int16Regs:$dst),
(ins Int16Regs:$a, i16imm:$b, Int16Regs:$c),
"mad.lo.s16 \t$dst, $a, $b, $c;",
[(set i16:$dst, (imad i16:$a, imm:$b, i16:$c))]>;
def MAD16rii :
NVPTXInst<(outs Int16Regs:$dst),
(ins Int16Regs:$a, i16imm:$b, i16imm:$c),
"mad.lo.s16 \t$dst, $a, $b, $c;",
[(set i16:$dst, (imad i16:$a, imm:$b, imm:$c))]>;

def MAD32rrr :
NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
"mad.lo.s32 \t$dst, $a, $b, $c;",
[(set i32:$dst, (imad i32:$a, i32:$b, i32:$c))]>;
def MAD32rri :
NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b, i32imm:$c),
"mad.lo.s32 \t$dst, $a, $b, $c;",
[(set i32:$dst, (imad i32:$a, i32:$b, imm:$c))]>;
def MAD32rir :
NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
"mad.lo.s32 \t$dst, $a, $b, $c;",
[(set i32:$dst, (imad i32:$a, imm:$b, i32:$c))]>;
def MAD32rii :
NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, i32imm:$b, i32imm:$c),
"mad.lo.s32 \t$dst, $a, $b, $c;",
[(set i32:$dst, (imad i32:$a, imm:$b, imm:$c))]>;

def MAD64rrr :
NVPTXInst<(outs Int64Regs:$dst),
(ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
"mad.lo.s64 \t$dst, $a, $b, $c;",
[(set i64:$dst, (imad i64:$a, i64:$b, i64:$c))]>;
def MAD64rri :
NVPTXInst<(outs Int64Regs:$dst),
(ins Int64Regs:$a, Int64Regs:$b, i64imm:$c),
"mad.lo.s64 \t$dst, $a, $b, $c;",
[(set i64:$dst, (imad i64:$a, i64:$b, imm:$c))]>;
def MAD64rir :
NVPTXInst<(outs Int64Regs:$dst),
(ins Int64Regs:$a, i64imm:$b, Int64Regs:$c),
"mad.lo.s64 \t$dst, $a, $b, $c;",
[(set i64:$dst, (imad i64:$a, imm:$b, i64:$c))]>;
def MAD64rii :
NVPTXInst<(outs Int64Regs:$dst),
(ins Int64Regs:$a, i64imm:$b, i64imm:$c),
"mad.lo.s64 \t$dst, $a, $b, $c;",
[(set i64:$dst, (imad i64:$a, imm:$b, imm:$c))]>;
def mul_oneuse : PatFrag<(ops node:$a, node:$b), (mul node:$a, node:$b), [{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RISC-V has a generalized form of one-use pattern:

class binop_oneuse<SDPatternOperator operator>
: PatFrag<(ops node:$A, node:$B),
(operator node:$A, node:$B), [{
return N->hasOneUse();
}]>;
def and_oneuse : binop_oneuse<and>;
def mul_oneuse : binop_oneuse<mul>;

It may be something worth extracting into a common tablegen file. We have quite a few uses of hasOneUse() in the backends. Could be in a separate patch.

return N->hasOneUse();
}]>;

// Integer multiply-add instruction family. Defines one instruction for each
// register/immediate combination of the multiplier ($b) and the addend ($c).
// Every variant matches (add (mul_oneuse a, b), c) and prints the Ptx mnemonic
// followed by the operand list "$dst, $a, $b, $c;".
//   Ptx - PTX mnemonic text emitted in front of the operand list.
//   VT  - value type of the result and of the register operands in the pattern.
//   Reg - register class used for $dst and for every register input.
//   Imm - operand type used for every immediate input.
multiclass MAD<string Ptx, ValueType VT, NVPTXRegClass Reg, Operand Imm> {
// rrr: $a, $b and $c are all registers.
def rrr:
NVPTXInst<(outs Reg:$dst),
(ins Reg:$a, Reg:$b, Reg:$c),
Ptx # " \t$dst, $a, $b, $c;",
[(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), VT:$c))]>;

// rir: the multiplier $b is an immediate; $a and $c are registers.
def rir:
NVPTXInst<(outs Reg:$dst),
(ins Reg:$a, Imm:$b, Reg:$c),
Ptx # " \t$dst, $a, $b, $c;",
[(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), VT:$c))]>;
// rri: the addend $c is an immediate; $a and $b are registers.
def rri:
NVPTXInst<(outs Reg:$dst),
(ins Reg:$a, Reg:$b, Imm:$c),
Ptx # " \t$dst, $a, $b, $c;",
[(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), imm:$c))]>;
// rii: both the multiplier $b and the addend $c are immediates.
def rii:
NVPTXInst<(outs Reg:$dst),
(ins Reg:$a, Imm:$b, Imm:$c),
Ptx # " \t$dst, $a, $b, $c;",
[(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), imm:$c))]>;
}

// Instantiate the MAD variants for 16-, 32- and 64-bit integers using the
// mad.lo.s16 / mad.lo.s32 / mad.lo.s64 mnemonics. All of these definitions are
// gated on the hasOptEnabled predicate.
let Predicates = [hasOptEnabled] in {
defm MAD16 : MAD<"mad.lo.s16", i16, Int16Regs, i16imm>;
defm MAD32 : MAD<"mad.lo.s32", i32, Int32Regs, i32imm>;
defm MAD64 : MAD<"mad.lo.s64", i64, Int64Regs, i64imm>;
}

def INEG16 :
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
Expand Down
55 changes: 55 additions & 0 deletions llvm/test/CodeGen/NVPTX/combine-mad.ll
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,58 @@ define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
%add = add i32 %c, %sel
ret i32 %add
}

declare i32 @use(i32 %0, i32 %1)

;; %mul is used both by %add and as an argument to @use, so it has more than
;; one use. The checks below expect a separate mul.lo.s32 and add.s32 rather
;; than a single mad.lo.s32.
define i32 @test_mad_multi_use(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_mad_multi_use(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<8>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_mad_multi_use_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_mad_multi_use_param_1];
; CHECK-NEXT: mul.lo.s32 %r3, %r1, %r2;
; CHECK-NEXT: ld.param.u32 %r4, [test_mad_multi_use_param_2];
; CHECK-NEXT: add.s32 %r5, %r3, %r4;
; CHECK-NEXT: { // callseq 0, 0
; CHECK-NEXT: .param .b32 param0;
; CHECK-NEXT: st.param.b32 [param0], %r3;
; CHECK-NEXT: .param .b32 param1;
; CHECK-NEXT: st.param.b32 [param1], %r5;
; CHECK-NEXT: .param .b32 retval0;
; CHECK-NEXT: call.uni (retval0),
; CHECK-NEXT: use,
; CHECK-NEXT: (
; CHECK-NEXT: param0,
; CHECK-NEXT: param1
; CHECK-NEXT: );
; CHECK-NEXT: ld.param.b32 %r6, [retval0];
; CHECK-NEXT: } // callseq 0
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
; CHECK-NEXT: ret;
%mul = mul i32 %a, %b
%add = add i32 %mul, %c
%res = call i32 @use(i32 %mul, i32 %add)
ret i32 %res
}

;; This case relies on the fold mad x 1 y => add x y; previously we emitted:
;; mad.lo.s32 %r3, %r1, 1, %r2;
define i32 @test_mad_fold(i32 %x) {
; CHECK-LABEL: test_mad_fold(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<7>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_mad_fold_param_0];
; CHECK-NEXT: mul.hi.s32 %r2, %r1, -2147221471;
; CHECK-NEXT: add.s32 %r3, %r2, %r1;
; CHECK-NEXT: shr.u32 %r4, %r3, 31;
; CHECK-NEXT: shr.s32 %r5, %r3, 12;
; CHECK-NEXT: add.s32 %r6, %r5, %r4;
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
; CHECK-NEXT: ret;
%div = sdiv i32 %x, 8191
ret i32 %div
}
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
; CHECK-NOT: __local_depot

; CHECK-32: ld.param.u32 %r[[SIZE:[0-9]]], [test_dynamic_stackalloc_param_0];
; CHECK-32-NEXT: mad.lo.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 1, 7;
; CHECK-32-NEXT: add.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 7;
; CHECK-32-NEXT: and.b32 %r[[SIZE3:[0-9]]], %r[[SIZE2]], -8;
; CHECK-32-NEXT: alloca.u32 %r[[ALLOCA:[0-9]]], %r[[SIZE3]], 16;
; CHECK-32-NEXT: cvta.local.u32 %r[[ALLOCA]], %r[[ALLOCA]];
Expand Down
Loading
Loading