Skip to content

Commit 4bd5a28

Browse files
committed
[NVPTX] support immediate values in st.param instructions
1 parent 58c7785 commit 4bd5a28

File tree

3 files changed

+768
-54
lines changed

3 files changed

+768
-54
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 113 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2182,6 +2182,84 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
21822182
return true;
21832183
}
21842184

2185+
// Helpers for constructing a vector StoreParam opcode name from a per-operand
// register/immediate pattern (ex: NVPTX::StoreParamV4F32_iiri).
//
//   ty    - type token pasted into the opcode name (I8, I16, I32, I64, F32,
//           F64).
//   opN   - in the *H macros that paste tokens: the literal suffix letter
//           `i` (immediate) or `r` (register). In the *H1/*H2/*H3 macros the
//           trailing opN arguments are boolean expressions that are true when
//           the corresponding operand is an immediate; each level resolves one
//           more operand into its suffix letter.
//   isimm - indexable sequence of booleans, one per value operand.
//
// Every conditional expansion is wrapped in parentheses and every expression
// argument is parenthesized, so the macros can be embedded in a larger
// expression (comparison, arithmetic, another ?:) without changing meaning.

// Paste the final V2 opcode name; op0/op1 must be the literal tokens i or r.
#define getOpcV2H(ty, op0, op1) NVPTX::StoreParamV2##ty##_##op0##op1

// Resolve operand 1 of a V2 store (op0 is already a suffix letter).
#define getOpcV2H1(ty, op0, op1)                                               \
  ((op1) ? getOpcV2H(ty, op0, i) : getOpcV2H(ty, op0, r))

// Resolve operand 0 of a V2 store, then defer to getOpcV2H1 for operand 1.
#define getOpcodeForVectorStParamV2(ty, isimm)                                 \
  (((isimm)[0]) ? getOpcV2H1(ty, i, (isimm)[1])                                \
                : getOpcV2H1(ty, r, (isimm)[1]))

// Paste the final V4 opcode name; op0..op3 must be the literal tokens i or r.
#define getOpcV4H(ty, op0, op1, op2, op3)                                      \
  NVPTX::StoreParamV4##ty##_##op0##op1##op2##op3

// Resolve operand 3 of a V4 store (op0..op2 are already suffix letters).
#define getOpcV4H3(ty, op0, op1, op2, op3)                                     \
  ((op3) ? getOpcV4H(ty, op0, op1, op2, i) : getOpcV4H(ty, op0, op1, op2, r))

// Resolve operand 2 of a V4 store (op0..op1 are already suffix letters).
#define getOpcV4H2(ty, op0, op1, op2, op3)                                     \
  ((op2) ? getOpcV4H3(ty, op0, op1, i, op3) : getOpcV4H3(ty, op0, op1, r, op3))

// Resolve operand 1 of a V4 store (op0 is already a suffix letter).
#define getOpcV4H1(ty, op0, op1, op2, op3)                                     \
  ((op1) ? getOpcV4H2(ty, op0, i, op2, op3) : getOpcV4H2(ty, op0, r, op2, op3))

// Resolve operand 0 of a V4 store, then defer to getOpcV4H1 for the rest.
#define getOpcodeForVectorStParamV4(ty, isimm)                                 \
  (((isimm)[0]) ? getOpcV4H1(ty, i, (isimm)[1], (isimm)[2], (isimm)[3])        \
                : getOpcV4H1(ty, r, (isimm)[1], (isimm)[2], (isimm)[3]))

// Pick the V2 opcode when n == 2 and the V4 opcode otherwise. Note that both
// branches are instantiated, so opcodes for both widths must exist for `ty`.
#define getOpcodeForVectorStParam(n, ty, isimm)                                \
  (((n) == 2) ? getOpcodeForVectorStParamV2(ty, isimm)                         \
              : getOpcodeForVectorStParamV4(ty, isimm))
2213+
2214+
static std::optional<unsigned>
2215+
pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops, unsigned NumElts,
2216+
MVT::SimpleValueType MemTy, SelectionDAG *CurDAG,
2217+
SDLoc DL) {
2218+
// Determine which inputs are registers and immediates make new operators
2219+
// with constant values
2220+
SmallVector<bool, 4> IsImm(NumElts, false);
2221+
for (unsigned i = 0; i < NumElts; i++) {
2222+
IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
2223+
if (IsImm[i]) {
2224+
SDValue Imm = Ops[i];
2225+
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
2226+
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
2227+
const ConstantFP *CF = ConstImm->getConstantFPValue();
2228+
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
2229+
} else {
2230+
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
2231+
const ConstantInt *CI = ConstImm->getConstantIntValue();
2232+
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
2233+
}
2234+
Ops[i] = Imm;
2235+
}
2236+
}
2237+
2238+
// Get opcode for MemTy, size, and register/immediate operand ordering
2239+
switch (MemTy) {
2240+
case MVT::i8:
2241+
return getOpcodeForVectorStParam(NumElts, I8, IsImm);
2242+
case MVT::i16:
2243+
return getOpcodeForVectorStParam(NumElts, I16, IsImm);
2244+
case MVT::i32:
2245+
return getOpcodeForVectorStParam(NumElts, I32, IsImm);
2246+
case MVT::i64:
2247+
if (NumElts == 4)
2248+
return std::nullopt;
2249+
return getOpcodeForVectorStParamV2(I64, IsImm);
2250+
case MVT::f32:
2251+
return getOpcodeForVectorStParam(NumElts, F32, IsImm);
2252+
case MVT::f64:
2253+
if (NumElts == 4)
2254+
return std::nullopt;
2255+
return getOpcodeForVectorStParamV2(F64, IsImm);
2256+
case MVT::f16:
2257+
case MVT::v2f16:
2258+
default:
2259+
return std::nullopt;
2260+
}
2261+
}
2262+
21852263
bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
21862264
SDLoc DL(N);
21872265
SDValue Chain = N->getOperand(0);
@@ -2228,12 +2306,34 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
22282306
switch (NumElts) {
22292307
default:
22302308
return false;
2231-
case 1:
2232-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2233-
NVPTX::StoreParamI8, NVPTX::StoreParamI16,
2234-
NVPTX::StoreParamI32, NVPTX::StoreParamI64,
2235-
NVPTX::StoreParamF32, NVPTX::StoreParamF64);
2236-
if (Opcode == NVPTX::StoreParamI8) {
2309+
case 1: {
2310+
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
2311+
SDValue Imm = Ops[0];
2312+
if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
2313+
(isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
2314+
// Convert immediate to target constant
2315+
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
2316+
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
2317+
const ConstantFP *CF = ConstImm->getConstantFPValue();
2318+
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
2319+
} else {
2320+
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
2321+
const ConstantInt *CI = ConstImm->getConstantIntValue();
2322+
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
2323+
}
2324+
Ops[0] = Imm;
2325+
// Use immediate version of store param
2326+
Opcode = pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i,
2327+
NVPTX::StoreParamI16_i, NVPTX::StoreParamI32_i,
2328+
NVPTX::StoreParamI64_i, NVPTX::StoreParamF32_i,
2329+
NVPTX::StoreParamF64_i);
2330+
} else
2331+
Opcode =
2332+
pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2333+
NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
2334+
NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r,
2335+
NVPTX::StoreParamF32_r, NVPTX::StoreParamF64_r);
2336+
if (Opcode == NVPTX::StoreParamI8_r) {
22372337
// Fine tune the opcode depending on the size of the operand.
22382338
// This helps to avoid creating redundant COPY instructions in
22392339
// InstrEmitter::AddRegisterOperand().
@@ -2249,27 +2349,22 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
22492349
}
22502350
}
22512351
break;
2352+
}
22522353
case 2:
2253-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2254-
NVPTX::StoreParamV2I8, NVPTX::StoreParamV2I16,
2255-
NVPTX::StoreParamV2I32, NVPTX::StoreParamV2I64,
2256-
NVPTX::StoreParamV2F32, NVPTX::StoreParamV2F64);
2257-
break;
2258-
case 4:
2259-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2260-
NVPTX::StoreParamV4I8, NVPTX::StoreParamV4I16,
2261-
NVPTX::StoreParamV4I32, std::nullopt,
2262-
NVPTX::StoreParamV4F32, std::nullopt);
2354+
case 4: {
2355+
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
2356+
Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
22632357
break;
22642358
}
2359+
}
22652360
if (!Opcode)
22662361
return false;
22672362
break;
22682363
// Special case: if we have a sign-extend/zero-extend node, insert the
22692364
// conversion instruction first, and use that as the value operand to
22702365
// the selected StoreParam node.
22712366
case NVPTXISD::StoreParamU32: {
2272-
Opcode = NVPTX::StoreParamI32;
2367+
Opcode = NVPTX::StoreParamI32_r;
22732368
SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
22742369
MVT::i32);
22752370
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL,
@@ -2278,7 +2373,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
22782373
break;
22792374
}
22802375
case NVPTXISD::StoreParamS32: {
2281-
Opcode = NVPTX::StoreParamI32;
2376+
Opcode = NVPTX::StoreParamI32_r;
22822377
SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
22832378
MVT::i32);
22842379
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL,

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 63 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,25 +2637,49 @@ class LoadParamRegInst<NVPTXRegClass regclass, string opstr> :
26372637
[(set regclass:$dst, (LoadParam (i32 0), (i32 imm:$b)))]>;
26382638

26392639
let mayStore = true in {
2640-
class StoreParamInst<NVPTXRegClass regclass, string opstr> :
2640+
class StoreParamInstReg<NVPTXRegClass regclass, string opstr> :
26412641
NVPTXInst<(outs), (ins regclass:$val, i32imm:$a, i32imm:$b),
2642-
!strconcat("st.param", opstr, " \t[param$a+$b], $val;"),
2642+
"st.param" # opstr # " \t[param$a+$b], $val;",
26432643
[]>;
26442644

2645-
class StoreParamV2Inst<NVPTXRegClass regclass, string opstr> :
2646-
NVPTXInst<(outs), (ins regclass:$val, regclass:$val2,
2647-
i32imm:$a, i32imm:$b),
2648-
!strconcat("st.param.v2", opstr,
2649-
" \t[param$a+$b], {{$val, $val2}};"),
2650-
[]>;
2645+
multiclass StoreParamInst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
2646+
def _r: StoreParamInstReg<regclass, opstr>;
26512647

2652-
class StoreParamV4Inst<NVPTXRegClass regclass, string opstr> :
2653-
NVPTXInst<(outs), (ins regclass:$val, regclass:$val2, regclass:$val3,
2654-
regclass:$val4, i32imm:$a,
2655-
i32imm:$b),
2656-
!strconcat("st.param.v4", opstr,
2657-
" \t[param$a+$b], {{$val, $val2, $val3, $val4}};"),
2658-
[]>;
2648+
def _i:
2649+
NVPTXInst<(outs), (ins IMMType:$val, i32imm:$a, i32imm:$b),
2650+
"st.param" # opstr # " \t[param$a+$b], $val;",
2651+
[]>;
2652+
}
2653+
2654+
// Two-element vector form of st.param. Expanding this multiclass defines one
// instruction per register/immediate combination of the two value operands
// (four in total). The record name suffix is "_" followed by one letter per
// operand, in operand order: "r" when the operand is the register class
// `regclass`, "i" when it is the immediate operand type `IMMType`.
// `opstr` is the type suffix appended to "st.param.v2" in the assembly string.
multiclass StoreParamV2Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
2655+
foreach op1 = [IMMType, regclass] in
2656+
foreach op2 = [IMMType, regclass] in
2657+
// Build the "_<r|i><r|i>" suffix from the kind of each operand.
def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
2658+
# !if(!isa<NVPTXRegClass>(op2), "r", "i")
2659+
: NVPTXInst<(outs),
2660+
(ins op1:$val1, op2:$val2,
2661+
i32imm:$a, i32imm:$b),
2662+
"st.param.v2" # opstr # " \t[param$a+$b], {{$val1, $val2}};",
2663+
// Empty pattern list.
[]>;
2664+
}
2665+
2666+
// Four-element vector form of st.param. Expanding this multiclass defines one
// instruction per register/immediate combination of the four value operands
// (sixteen in total). The record name suffix is "_" followed by one letter per
// operand, in operand order: "r" when the operand is the register class
// `regclass`, "i" when it is the immediate operand type `IMMType`
// (e.g. "_iiri"). `opstr` is the type suffix appended to "st.param.v4" in the
// assembly string.
multiclass StoreParamV4Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
2667+
foreach op1 = [IMMType, regclass] in
2668+
foreach op2 = [IMMType, regclass] in
2669+
foreach op3 = [IMMType, regclass] in
2670+
foreach op4 = [IMMType, regclass] in
2671+
// Build the "_<r|i><r|i><r|i><r|i>" suffix from the kind of each operand.
def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
2672+
# !if(!isa<NVPTXRegClass>(op2), "r", "i")
2673+
# !if(!isa<NVPTXRegClass>(op3), "r", "i")
2674+
# !if(!isa<NVPTXRegClass>(op4), "r", "i")
2675+
2676+
: NVPTXInst<(outs),
2677+
(ins op1:$val1, op2:$val2, op3:$val3, op4:$val4,
2678+
i32imm:$a, i32imm:$b),
2679+
"st.param.v4" # opstr #
2680+
" \t[param$a+$b], {{$val1, $val2, $val3, $val4}};",
2681+
// Empty pattern list.
[]>;
2682+
}
26592683

26602684
class StoreRetvalInst<NVPTXRegClass regclass, string opstr> :
26612685
NVPTXInst<(outs), (ins regclass:$val, i32imm:$a),
@@ -2735,27 +2759,30 @@ def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">;
27352759
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">;
27362760
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".f32">;
27372761

2738-
def StoreParamI64 : StoreParamInst<Int64Regs, ".b64">;
2739-
def StoreParamI32 : StoreParamInst<Int32Regs, ".b32">;
2740-
2741-
def StoreParamI16 : StoreParamInst<Int16Regs, ".b16">;
2742-
def StoreParamI8 : StoreParamInst<Int16Regs, ".b8">;
2743-
def StoreParamI8TruncI32 : StoreParamInst<Int32Regs, ".b8">;
2744-
def StoreParamI8TruncI64 : StoreParamInst<Int64Regs, ".b8">;
2745-
def StoreParamV2I64 : StoreParamV2Inst<Int64Regs, ".b64">;
2746-
def StoreParamV2I32 : StoreParamV2Inst<Int32Regs, ".b32">;
2747-
def StoreParamV2I16 : StoreParamV2Inst<Int16Regs, ".b16">;
2748-
def StoreParamV2I8 : StoreParamV2Inst<Int16Regs, ".b8">;
2749-
2750-
def StoreParamV4I32 : StoreParamV4Inst<Int32Regs, ".b32">;
2751-
def StoreParamV4I16 : StoreParamV4Inst<Int16Regs, ".b16">;
2752-
def StoreParamV4I8 : StoreParamV4Inst<Int16Regs, ".b8">;
2753-
2754-
def StoreParamF32 : StoreParamInst<Float32Regs, ".f32">;
2755-
def StoreParamF64 : StoreParamInst<Float64Regs, ".f64">;
2756-
def StoreParamV2F32 : StoreParamV2Inst<Float32Regs, ".f32">;
2757-
def StoreParamV2F64 : StoreParamV2Inst<Float64Regs, ".f64">;
2758-
def StoreParamV4F32 : StoreParamV4Inst<Float32Regs, ".f32">;
2762+
// Scalar integer st.param instructions. Each defm instantiates the
// StoreParamInst multiclass with a register class, an immediate operand type
// and the type suffix used in the assembly string. The 8-bit store takes its
// register operand from Int16Regs.
defm StoreParamI64    : StoreParamInst<Int64Regs, i64imm, ".b64">;
2763+
defm StoreParamI32    : StoreParamInst<Int32Regs, i32imm, ".b32">;
2764+
defm StoreParamI16    : StoreParamInst<Int16Regs, i16imm, ".b16">;
2765+
defm StoreParamI8     : StoreParamInst<Int16Regs, i8imm, ".b8">;
2766+
2767+
// 8-bit stores whose value lives in a 32-/64-bit register. These use the
// register-only class StoreParamInstReg, so no immediate form is defined here.
def StoreParamI8TruncI32 : StoreParamInstReg<Int32Regs, ".b8">;
2768+
def StoreParamI8TruncI64 : StoreParamInstReg<Int64Regs, ".b8">;
2769+
2770+
// Two-element integer vector stores.
defm StoreParamV2I64  : StoreParamV2Inst<Int64Regs, i64imm, ".b64">;
2771+
defm StoreParamV2I32  : StoreParamV2Inst<Int32Regs, i32imm, ".b32">;
2772+
defm StoreParamV2I16  : StoreParamV2Inst<Int16Regs, i16imm, ".b16">;
2773+
defm StoreParamV2I8   : StoreParamV2Inst<Int16Regs, i8imm, ".b8">;
2774+
2775+
// Four-element integer vector stores. NOTE(review): no 64-bit V4 variant is
// defined in this list — confirm none exists elsewhere in the file.
defm StoreParamV4I32  : StoreParamV4Inst<Int32Regs, i32imm, ".b32">;
2776+
defm StoreParamV4I16  : StoreParamV4Inst<Int16Regs, i16imm, ".b16">;
2777+
defm StoreParamV4I8   : StoreParamV4Inst<Int16Regs, i8imm, ".b8">;
2778+
2779+
// Scalar floating-point stores.
defm StoreParamF32    : StoreParamInst<Float32Regs, f32imm, ".f32">;
2780+
defm StoreParamF64    : StoreParamInst<Float64Regs, f64imm, ".f64">;
2781+
2782+
// Two-element floating-point vector stores.
defm StoreParamV2F32  : StoreParamV2Inst<Float32Regs, f32imm, ".f32">;
2783+
defm StoreParamV2F64  : StoreParamV2Inst<Float64Regs, f64imm, ".f64">;
2784+
2785+
// Four-element floating-point vector store (f32 only in this list).
defm StoreParamV4F32  : StoreParamV4Inst<Float32Regs, f32imm, ".f32">;
27592786

27602787
def StoreRetvalI64 : StoreRetvalInst<Int64Regs, ".b64">;
27612788
def StoreRetvalI32 : StoreRetvalInst<Int32Regs, ".b32">;

0 commit comments

Comments
 (0)