Skip to content

Commit c5b11a7

Browse files
authored
[NVPTX] support immediate values in st.param instructions (#91523)
Add support for generating `st.param` instructions that use immediate operands directly. This eliminates the need for a `mov` instruction before the `st.param`, resulting in more concise emitted PTX.
1 parent 58c7785 commit c5b11a7

File tree

3 files changed

+2199
-64
lines changed

3 files changed

+2199
-64
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 135 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2182,6 +2182,100 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
21822182
return true;
21832183
}
21842184

2185+
// Helpers for constructing a vector st.param opcode by token pasting
// (e.g. NVPTX::StoreParamV4F32_iiri). Each stored element contributes one
// suffix letter, in operand order: `i` when its isimm flag is true, `r`
// otherwise. The flag-taking macros peel off one flag per level and recurse
// with that letter fixed, so every combination is enumerated at compile time.
// NOTE(review): every replacement list below is an unparenthesized
// conditional expression, so these macros are only safe where the expansion
// forms the whole expression (the visible uses are all `return <macro>;`).
2186+
// V2 leaf: both suffix letters are already chosen; paste the final name.
#define getOpcV2H(ty, opKind0, opKind1) \
2187+
NVPTX::StoreParamV2##ty##_##opKind0##opKind1
2188+
2189+
// V2: first letter fixed, choose the second from isImm1.
#define getOpcV2H1(ty, opKind0, isImm1) \
2190+
(isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
2191+
2192+
// V2 entry point: isimm must be indexable with [0] and [1].
#define getOpcodeForVectorStParamV2(ty, isimm) \
2193+
(isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])
2194+
2195+
// V4 leaf: all four suffix letters are already chosen; paste the final name.
#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3) \
2196+
NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
2197+
2198+
// V4: first three letters fixed, choose the fourth from isImm3.
#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3) \
2199+
(isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \
2200+
: getOpcV4H(ty, opKind0, opKind1, opKind2, r)
2201+
2202+
// V4: first two letters fixed, choose the third from isImm2.
#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3) \
2203+
(isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \
2204+
: getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
2205+
2206+
// V4: first letter fixed, choose the second from isImm1.
#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3) \
2207+
(isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \
2208+
: getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
2209+
2210+
// V4 entry point: isimm must be indexable with [0] through [3].
#define getOpcodeForVectorStParamV4(ty, isimm) \
2211+
(isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3]) \
2212+
: getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])
2213+
2214+
// Dispatch on element count: n == 2 selects the V2 family; any other value
// selects the V4 family, which reads isimm[0..3].
#define getOpcodeForVectorStParam(n, ty, isimm) \
2215+
(n == 2) ? getOpcodeForVectorStParamV2(ty, isimm) \
2216+
: getOpcodeForVectorStParamV4(ty, isimm)
2217+
2218+
// Select the machine opcode for a 2- or 4-element vector st.param and rewrite
// constant value operands so they can be emitted as immediates.
//
// Ops      - operand list; the first NumElts entries are the stored values.
//            Any of them that is a ConstantSDNode or ConstantFPSDNode is
//            replaced in place by the corresponding target constant.
// NumElts  - number of stored elements. 2 selects a V2 opcode; any other value
//            selects a V4 opcode (the only visible caller passes 2 or 4).
// MemTy    - element memory type, used to pick the opcode family.
// CurDAG   - DAG used to create the target-constant nodes.
// DL       - debug location attached to the created nodes.
//
// Returns the NVPTX::StoreParamV{2,4}<ty>_<suffix> opcode, where the suffix
// carries one letter per element (`i` = immediate, `r` = register).
static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
2219+
unsigned NumElts,
2220+
MVT::SimpleValueType MemTy,
2221+
SelectionDAG *CurDAG, SDLoc DL) {
2222+
// Determine which inputs are registers and which are immediates, and replace
2223+
// each immediate operand with an equivalent target constant.
2224+
SmallVector<bool, 4> IsImm(NumElts, false);
2225+
for (unsigned i = 0; i < NumElts; i++) {
2226+
IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
2227+
if (IsImm[i]) {
2228+
SDValue Imm = Ops[i];
2229+
// The node class to cast to is chosen from MemTy, not from the operand
// itself: f32/f64 are treated as FP constants, everything else as integer.
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
2230+
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
2231+
const ConstantFP *CF = ConstImm->getConstantFPValue();
2232+
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
2233+
} else {
2234+
// NOTE(review): this branch also runs for the register-only types handled
// at the bottom of the switch (i1, f16, bf16, packed vectors). Confirm that
// a ConstantFPSDNode operand cannot reach this cast for f16/bf16, and that
// constant operands cannot occur at all for the types that are given an
// all-register opcode after their operand was rewritten here.
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
2235+
const ConstantInt *CI = ConstImm->getConstantIntValue();
2236+
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
2237+
}
2238+
Ops[i] = Imm;
2239+
}
2240+
}
2241+
2242+
// Get opcode for MemTy, size, and register/immediate operand ordering
2243+
switch (MemTy) {
2244+
case MVT::i8:
2245+
return getOpcodeForVectorStParam(NumElts, I8, IsImm);
2246+
case MVT::i16:
2247+
return getOpcodeForVectorStParam(NumElts, I16, IsImm);
2248+
case MVT::i32:
2249+
return getOpcodeForVectorStParam(NumElts, I32, IsImm);
2250+
case MVT::i64:
2251+
// Only the V2 form is used for 64-bit elements, so the V2 macro is invoked
// directly instead of dispatching on NumElts.
assert(NumElts == 2 && "MVT too large for NumElts > 2");
2252+
return getOpcodeForVectorStParamV2(I64, IsImm);
2253+
case MVT::f32:
2254+
return getOpcodeForVectorStParam(NumElts, F32, IsImm);
2255+
case MVT::f64:
2256+
// Same restriction as i64: V2 only.
assert(NumElts == 2 && "MVT too large for NumElts > 2");
2257+
return getOpcodeForVectorStParamV2(F64, IsImm);
2258+
2259+
// These cases don't support immediates, just use the all register version
2260+
// and generate moves.
2261+
case MVT::i1:
2262+
// i1 elements are stored with the 8-bit opcodes.
return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
2263+
: NVPTX::StoreParamV4I8_rrrr;
2264+
case MVT::f16:
2265+
case MVT::bf16:
2266+
// 16-bit float elements are stored with the 16-bit integer opcodes.
return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
2267+
: NVPTX::StoreParamV4I16_rrrr;
2268+
case MVT::v2f16:
2269+
case MVT::v2bf16:
2270+
case MVT::v2i16:
2271+
case MVT::v4i8:
2272+
// These packed vector element types are stored with the 32-bit integer
// opcodes.
return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
2273+
: NVPTX::StoreParamV4I32_rrrr;
2274+
default:
2275+
llvm_unreachable("Cannot select st.param for unknown MemTy");
2276+
}
2277+
}
2278+
21852279
bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
21862280
SDLoc DL(N);
21872281
SDValue Chain = N->getOperand(0);
@@ -2193,10 +2287,10 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
21932287
SDValue Glue = N->getOperand(N->getNumOperands() - 1);
21942288

21952289
// How many elements do we have?
2196-
unsigned NumElts = 1;
2290+
unsigned NumElts;
21972291
switch (N->getOpcode()) {
21982292
default:
2199-
return false;
2293+
llvm_unreachable("Unexpected opcode");
22002294
case NVPTXISD::StoreParamU32:
22012295
case NVPTXISD::StoreParamS32:
22022296
case NVPTXISD::StoreParam:
@@ -2222,54 +2316,69 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
22222316
// Determine target opcode
22232317
// If we have an i1, use an 8-bit store. The lowering code in
22242318
// NVPTXISelLowering will have already emitted an upcast.
2225-
std::optional<unsigned> Opcode = 0;
2319+
std::optional<unsigned> Opcode;
22262320
switch (N->getOpcode()) {
22272321
default:
22282322
switch (NumElts) {
22292323
default:
2230-
return false;
2231-
case 1:
2232-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2233-
NVPTX::StoreParamI8, NVPTX::StoreParamI16,
2234-
NVPTX::StoreParamI32, NVPTX::StoreParamI64,
2235-
NVPTX::StoreParamF32, NVPTX::StoreParamF64);
2236-
if (Opcode == NVPTX::StoreParamI8) {
2324+
llvm_unreachable("Unexpected NumElts");
2325+
case 1: {
2326+
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
2327+
SDValue Imm = Ops[0];
2328+
if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
2329+
(isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
2330+
// Convert immediate to target constant
2331+
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
2332+
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
2333+
const ConstantFP *CF = ConstImm->getConstantFPValue();
2334+
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
2335+
} else {
2336+
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
2337+
const ConstantInt *CI = ConstImm->getConstantIntValue();
2338+
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
2339+
}
2340+
Ops[0] = Imm;
2341+
// Use immediate version of store param
2342+
Opcode = pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i,
2343+
NVPTX::StoreParamI16_i, NVPTX::StoreParamI32_i,
2344+
NVPTX::StoreParamI64_i, NVPTX::StoreParamF32_i,
2345+
NVPTX::StoreParamF64_i);
2346+
} else
2347+
Opcode =
2348+
pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2349+
NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
2350+
NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r,
2351+
NVPTX::StoreParamF32_r, NVPTX::StoreParamF64_r);
2352+
if (Opcode == NVPTX::StoreParamI8_r) {
22372353
// Fine tune the opcode depending on the size of the operand.
22382354
// This helps to avoid creating redundant COPY instructions in
22392355
// InstrEmitter::AddRegisterOperand().
22402356
switch (Ops[0].getSimpleValueType().SimpleTy) {
22412357
default:
22422358
break;
22432359
case MVT::i32:
2244-
Opcode = NVPTX::StoreParamI8TruncI32;
2360+
Opcode = NVPTX::StoreParamI8TruncI32_r;
22452361
break;
22462362
case MVT::i64:
2247-
Opcode = NVPTX::StoreParamI8TruncI64;
2363+
Opcode = NVPTX::StoreParamI8TruncI64_r;
22482364
break;
22492365
}
22502366
}
22512367
break;
2368+
}
22522369
case 2:
2253-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2254-
NVPTX::StoreParamV2I8, NVPTX::StoreParamV2I16,
2255-
NVPTX::StoreParamV2I32, NVPTX::StoreParamV2I64,
2256-
NVPTX::StoreParamV2F32, NVPTX::StoreParamV2F64);
2257-
break;
2258-
case 4:
2259-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2260-
NVPTX::StoreParamV4I8, NVPTX::StoreParamV4I16,
2261-
NVPTX::StoreParamV4I32, std::nullopt,
2262-
NVPTX::StoreParamV4F32, std::nullopt);
2370+
case 4: {
2371+
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
2372+
Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
22632373
break;
22642374
}
2265-
if (!Opcode)
2266-
return false;
2375+
}
22672376
break;
22682377
// Special case: if we have a sign-extend/zero-extend node, insert the
22692378
// conversion instruction first, and use that as the value operand to
22702379
// the selected StoreParam node.
22712380
case NVPTXISD::StoreParamU32: {
2272-
Opcode = NVPTX::StoreParamI32;
2381+
Opcode = NVPTX::StoreParamI32_r;
22732382
SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
22742383
MVT::i32);
22752384
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL,
@@ -2278,7 +2387,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
22782387
break;
22792388
}
22802389
case NVPTXISD::StoreParamS32: {
2281-
Opcode = NVPTX::StoreParamI32;
2390+
Opcode = NVPTX::StoreParamI32_r;
22822391
SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
22832392
MVT::i32);
22842393
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL,

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 62 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,25 +2637,46 @@ class LoadParamRegInst<NVPTXRegClass regclass, string opstr> :
26372637
[(set regclass:$dst, (LoadParam (i32 0), (i32 imm:$b)))]>;
26382638

26392639
let mayStore = true in {
2640-
class StoreParamInst<NVPTXRegClass regclass, string opstr> :
2641-
NVPTXInst<(outs), (ins regclass:$val, i32imm:$a, i32imm:$b),
2642-
!strconcat("st.param", opstr, " \t[param$a+$b], $val;"),
2643-
[]>;
26442640

2645-
class StoreParamV2Inst<NVPTXRegClass regclass, string opstr> :
2646-
NVPTXInst<(outs), (ins regclass:$val, regclass:$val2,
2647-
i32imm:$a, i32imm:$b),
2648-
!strconcat("st.param.v2", opstr,
2649-
" \t[param$a+$b], {{$val, $val2}};"),
2650-
[]>;
2641+
// Scalar st.param. Iterates over the immediate operand type and the register
// class and defines one instruction record per accepted operand kind, named
// with a "_r" (register) or "_i" (immediate) suffix.
//   regclass    - register class of the stored value in the "_r" form.
//   IMMType     - immediate operand type of the stored value in the "_i" form.
//   opstr       - PTX type suffix appended to "st.param" (e.g. ".b32").
//   support_imm - when false, only the "_r" record is defined.
multiclass StoreParamInst<NVPTXRegClass regclass, Operand IMMType, string opstr, bit support_imm = true> {
2642+
foreach op = [IMMType, regclass] in
2643+
// `!isa<NVPTXRegClass>(op)` is the TableGen type-test operator (not a
// negation): it is true when `op` is the register class. The register form
// is therefore always defined; the immediate form only if support_imm is set.
if !or(support_imm, !isa<NVPTXRegClass>(op)) then
2644+
def _ # !if(!isa<NVPTXRegClass>(op), "r", "i")
2645+
: NVPTXInst<(outs),
2646+
(ins op:$val, i32imm:$a, i32imm:$b),
2647+
"st.param" # opstr # " \t[param$a+$b], $val;",
2648+
[]>;
2649+
}
26512650

2652-
class StoreParamV4Inst<NVPTXRegClass regclass, string opstr> :
2653-
NVPTXInst<(outs), (ins regclass:$val, regclass:$val2, regclass:$val3,
2654-
regclass:$val4, i32imm:$a,
2655-
i32imm:$b),
2656-
!strconcat("st.param.v4", opstr,
2657-
" \t[param$a+$b], {{$val, $val2, $val3, $val4}};"),
2658-
[]>;
2651+
// Two-element st.param.v2. The nested foreach loops define one record for
// every register/immediate combination of the two stored values (four in
// total). The record suffix has one letter per value, in operand order:
// "r" when that value is the register class, "i" when it is the immediate
// operand type (`!isa<...>` is the TableGen type test, not a negation).
multiclass StoreParamV2Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
2652+
foreach op1 = [IMMType, regclass] in
2653+
foreach op2 = [IMMType, regclass] in
2654+
def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
2655+
# !if(!isa<NVPTXRegClass>(op2), "r", "i")
2656+
: NVPTXInst<(outs),
2657+
(ins op1:$val1, op2:$val2,
2658+
i32imm:$a, i32imm:$b),
2659+
"st.param.v2" # opstr # " \t[param$a+$b], {{$val1, $val2}};",
2660+
[]>;
2661+
}
2662+
2663+
// Four-element st.param.v4. The nested foreach loops define one record for
// every register/immediate combination of the four stored values (sixteen in
// total). The record suffix has one letter per value, in operand order:
// "r" when that value is the register class, "i" when it is the immediate
// operand type (`!isa<...>` is the TableGen type test, not a negation).
multiclass StoreParamV4Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
2664+
foreach op1 = [IMMType, regclass] in
2665+
foreach op2 = [IMMType, regclass] in
2666+
foreach op3 = [IMMType, regclass] in
2667+
foreach op4 = [IMMType, regclass] in
2668+
def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
2669+
# !if(!isa<NVPTXRegClass>(op2), "r", "i")
2670+
# !if(!isa<NVPTXRegClass>(op3), "r", "i")
2671+
# !if(!isa<NVPTXRegClass>(op4), "r", "i")
2672+
2673+
: NVPTXInst<(outs),
2674+
(ins op1:$val1, op2:$val2, op3:$val3, op4:$val4,
2675+
i32imm:$a, i32imm:$b),
2676+
"st.param.v4" # opstr #
2677+
" \t[param$a+$b], {{$val1, $val2, $val3, $val4}};",
2678+
[]>;
2679+
}
26592680

26602681
class StoreRetvalInst<NVPTXRegClass regclass, string opstr> :
26612682
NVPTXInst<(outs), (ins regclass:$val, i32imm:$a),
@@ -2735,27 +2756,30 @@ def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">;
27352756
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">;
27362757
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".f32">;
27372758

2738-
def StoreParamI64 : StoreParamInst<Int64Regs, ".b64">;
2739-
def StoreParamI32 : StoreParamInst<Int32Regs, ".b32">;
2740-
2741-
def StoreParamI16 : StoreParamInst<Int16Regs, ".b16">;
2742-
def StoreParamI8 : StoreParamInst<Int16Regs, ".b8">;
2743-
def StoreParamI8TruncI32 : StoreParamInst<Int32Regs, ".b8">;
2744-
def StoreParamI8TruncI64 : StoreParamInst<Int64Regs, ".b8">;
2745-
def StoreParamV2I64 : StoreParamV2Inst<Int64Regs, ".b64">;
2746-
def StoreParamV2I32 : StoreParamV2Inst<Int32Regs, ".b32">;
2747-
def StoreParamV2I16 : StoreParamV2Inst<Int16Regs, ".b16">;
2748-
def StoreParamV2I8 : StoreParamV2Inst<Int16Regs, ".b8">;
2749-
2750-
def StoreParamV4I32 : StoreParamV4Inst<Int32Regs, ".b32">;
2751-
def StoreParamV4I16 : StoreParamV4Inst<Int16Regs, ".b16">;
2752-
def StoreParamV4I8 : StoreParamV4Inst<Int16Regs, ".b8">;
2753-
2754-
def StoreParamF32 : StoreParamInst<Float32Regs, ".f32">;
2755-
def StoreParamF64 : StoreParamInst<Float64Regs, ".f64">;
2756-
def StoreParamV2F32 : StoreParamV2Inst<Float32Regs, ".f32">;
2757-
def StoreParamV2F64 : StoreParamV2Inst<Float64Regs, ".f64">;
2758-
def StoreParamV4F32 : StoreParamV4Inst<Float32Regs, ".f32">;
2759+
// Scalar integer stores. Each defm expands to a "<name>_r" (register) and a
// "<name>_i" (immediate) record. The 8-bit store takes its register value
// from Int16Regs.
defm StoreParamI64 : StoreParamInst<Int64Regs, i64imm, ".b64">;
2760+
defm StoreParamI32 : StoreParamInst<Int32Regs, i32imm, ".b32">;
2761+
defm StoreParamI16 : StoreParamInst<Int16Regs, i16imm, ".b16">;
2762+
defm StoreParamI8 : StoreParamInst<Int16Regs, i8imm, ".b8">;
2763+
2764+
// 8-bit stores whose value is held in a 32- or 64-bit register. Immediates
// are disabled, so only the "_r" records are defined.
defm StoreParamI8TruncI32 : StoreParamInst<Int32Regs, i8imm, ".b8", /* support_imm */ false>;
2765+
defm StoreParamI8TruncI64 : StoreParamInst<Int64Regs, i8imm, ".b8", /* support_imm */ false>;
2766+
2767+
// Two-element integer vector stores: records "_rr", "_ri", "_ir", "_ii".
defm StoreParamV2I64 : StoreParamV2Inst<Int64Regs, i64imm, ".b64">;
2768+
defm StoreParamV2I32 : StoreParamV2Inst<Int32Regs, i32imm, ".b32">;
2769+
defm StoreParamV2I16 : StoreParamV2Inst<Int16Regs, i16imm, ".b16">;
2770+
defm StoreParamV2I8 : StoreParamV2Inst<Int16Regs, i8imm, ".b8">;
2771+
2772+
// Four-element integer vector stores: one record per four-letter r/i suffix.
// No 64-bit variant is defined.
defm StoreParamV4I32 : StoreParamV4Inst<Int32Regs, i32imm, ".b32">;
2773+
defm StoreParamV4I16 : StoreParamV4Inst<Int16Regs, i16imm, ".b16">;
2774+
defm StoreParamV4I8 : StoreParamV4Inst<Int16Regs, i8imm, ".b8">;
2775+
2776+
// Scalar floating-point stores ("_r" and "_i").
defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".f32">;
2777+
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".f64">;
2778+
2779+
// Two-element floating-point vector stores.
defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".f32">;
2780+
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".f64">;
2781+
2782+
// Four-element floating-point vector store. No f64 variant is defined.
defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".f32">;
27592783

27602784
def StoreRetvalI64 : StoreRetvalInst<Int64Regs, ".b64">;
27612785
def StoreRetvalI32 : StoreRetvalInst<Int32Regs, ".b32">;

0 commit comments

Comments
 (0)