Skip to content

[NVPTX][NFC] Minor cleanup in NVPTXInstrInfo.td #138006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,7 @@ bool NVPTXInstrInfo::isSchedulingBoundary(const MachineInstr &MI,
 switch (MI.getOpcode()) {
 case NVPTX::CallUniPrintCallRetInst1:
 case NVPTX::CallArgBeginInst:
-case NVPTX::CallArgI32imm:
 case NVPTX::CallArgParam:
-case NVPTX::LastCallArgI32imm:
 case NVPTX::LastCallArgParam:
 case NVPTX::CallArgEndInst1:
 return true;
Expand Down
124 changes: 40 additions & 84 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1330,58 +1330,46 @@ def FDIV32ri_prec :
 // FMA
 //

-multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> {
+multiclass FMA<string OpcStr, RegTyInfo t, list<Predicate> Preds = []> {
 defvar asmstr = OpcStr # " \t$dst, $a, $b, $c;";
-def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+def rrr : NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b, t.RC:$c),
 asmstr,
-[(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>,
-Requires<[Pred]>;
-def rri : NVPTXInst<(outs RC:$dst),
-(ins RC:$a, RC:$b, ImmCls:$c),
-asmstr,
-[(set RC:$dst, (fma RC:$a, RC:$b, fpimm:$c))]>,
-Requires<[Pred]>;
-def rir : NVPTXInst<(outs RC:$dst),
-(ins RC:$a, ImmCls:$b, RC:$c),
-asmstr,
-[(set RC:$dst, (fma RC:$a, fpimm:$b, RC:$c))]>,
-Requires<[Pred]>;
-def rii : NVPTXInst<(outs RC:$dst),
-(ins RC:$a, ImmCls:$b, ImmCls:$c),
-asmstr,
-[(set RC:$dst, (fma RC:$a, fpimm:$b, fpimm:$c))]>,
-Requires<[Pred]>;
-def iir : NVPTXInst<(outs RC:$dst),
-(ins ImmCls:$a, ImmCls:$b, RC:$c),
-asmstr,
-[(set RC:$dst, (fma fpimm:$a, fpimm:$b, RC:$c))]>,
-Requires<[Pred]>;
-
-}
-
-multiclass FMA_F16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
-def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
-!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-[(set T:$dst, (fma T:$a, T:$b, T:$c))]>,
-Requires<[useFP16Math, Pred]>;
-}
-
-multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
-def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
-!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-[(set T:$dst, (fma T:$a, T:$b, T:$c))]>,
-Requires<[hasBF16Math, Pred]>;
+[(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, t.Ty:$c))]>,
+Requires<Preds>;
+
+if t.SupportsImm then {
+def rri : NVPTXInst<(outs t.RC:$dst),
+(ins t.RC:$a, t.RC:$b, t.Imm:$c),
+asmstr,
+[(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, fpimm:$c))]>,
+Requires<Preds>;
+def rir : NVPTXInst<(outs t.RC:$dst),
+(ins t.RC:$a, t.Imm:$b, t.RC:$c),
+asmstr,
+[(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, t.Ty:$c))]>,
+Requires<Preds>;
+def rii : NVPTXInst<(outs t.RC:$dst),
+(ins t.RC:$a, t.Imm:$b, t.Imm:$c),
+asmstr,
+[(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, fpimm:$c))]>,
+Requires<Preds>;
+def iir : NVPTXInst<(outs t.RC:$dst),
+(ins t.Imm:$a, t.Imm:$b, t.RC:$c),
+asmstr,
+[(set t.Ty:$dst, (fma fpimm:$a, fpimm:$b, t.Ty:$c))]>,
+Requires<Preds>;
+}
 }

-defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
-defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
-defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
-defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
-defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
-defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
-defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
-defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
-defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
+defm FMA16_ftz : FMA<"fma.rn.ftz.f16", F16RT, [useFP16Math, doF32FTZ]>;
+defm FMA16 : FMA<"fma.rn.f16", F16RT, [useFP16Math]>;
+defm FMA16x2_ftz : FMA<"fma.rn.ftz.f16x2", F16X2RT, [useFP16Math, doF32FTZ]>;
+defm FMA16x2 : FMA<"fma.rn.f16x2", F16X2RT, [useFP16Math]>;
+defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
+defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
+defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
+defm FMA32 : FMA<"fma.rn.f32", F32RT>;
+defm FMA64 : FMA<"fma.rn.f64", F64RT>;

 // sin/cos

Expand Down Expand Up @@ -1999,7 +1987,7 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
 Requires<[doF32FTZ]>;
 def : Pat<(i1 (OpNode f32:$a, f32:$b)),
 (SETP_f32rr $a, $b, Mode)>;
-def : Pat<(i1 (OpNode Float32Regs:$a, fpimm:$b)),
+def : Pat<(i1 (OpNode f32:$a, fpimm:$b)),
 (SETP_f32ri $a, fpimm:$b, ModeFTZ)>,
 Requires<[doF32FTZ]>;
 def : Pat<(i1 (OpNode f32:$a, fpimm:$b)),
Expand Down Expand Up @@ -2056,7 +2044,7 @@ def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDTStoreParam32Profile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
-def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
+def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
 def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
 def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
 def SDTCallValProfile : SDTypeProfile<1, 0, []>;
Expand Down Expand Up @@ -2352,42 +2340,10 @@ def CallArgEndInst1 : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>;
 def CallArgEndInst0 : NVPTXInst<(outs), (ins), ")", [(CallArgEnd (i32 0))]>;
 def RETURNInst : NVPTXInst<(outs), (ins), "ret;", [(RETURNNode)]>;

-class CallArgInst<NVPTXRegClass regclass> :
-NVPTXInst<(outs), (ins regclass:$a), "$a, ",
-[(CallArg (i32 0), regclass:$a)]>;
-
-class CallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
-NVPTXInst<(outs), (ins regclass:$a), "$a, ",
-[(CallArg (i32 0), vt:$a)]>;
-
-class LastCallArgInst<NVPTXRegClass regclass> :
-NVPTXInst<(outs), (ins regclass:$a), "$a",
-[(LastCallArg (i32 0), regclass:$a)]>;
-class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
-NVPTXInst<(outs), (ins regclass:$a), "$a",
-[(LastCallArg (i32 0), vt:$a)]>;
-
-def CallArgI64 : CallArgInst<Int64Regs>;
-def CallArgI32 : CallArgInstVT<Int32Regs, i32>;
-def CallArgI16 : CallArgInstVT<Int16Regs, i16>;
-def CallArgF64 : CallArgInst<Float64Regs>;
-def CallArgF32 : CallArgInst<Float32Regs>;
-
-def LastCallArgI64 : LastCallArgInst<Int64Regs>;
-def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
-def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
-def LastCallArgF64 : LastCallArgInst<Float64Regs>;
-def LastCallArgF32 : LastCallArgInst<Float32Regs>;
-
-def CallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a, ",
-[(CallArg (i32 0), (i32 imm:$a))]>;
-def LastCallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a",
-[(LastCallArg (i32 0), (i32 imm:$a))]>;
-
 def CallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a, ",
-[(CallArg (i32 1), (i32 imm:$a))]>;
+[(CallArg 1, imm:$a)]>;
 def LastCallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a",
-[(LastCallArg (i32 1), (i32 imm:$a))]>;
+[(LastCallArg 1, imm:$a)]>;

 def CallVoidInst : NVPTXInst<(outs), (ins ADDR_base:$addr), "$addr, ",
 [(CallVoid (Wrapper tglobaladdr:$addr))]>;
Expand Down