Skip to content

[NVPTX] support immediate values in st.param instructions #91523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 135 additions & 26 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,100 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
return true;
}

// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
#define getOpcV2H(ty, opKind0, opKind1) \
NVPTX::StoreParamV2##ty##_##opKind0##opKind1

#define getOpcV2H1(ty, opKind0, isImm1) \
(isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)

#define getOpcodeForVectorStParamV2(ty, isimm) \
(isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])

#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3) \
NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3

#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3) \
(isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \
: getOpcV4H(ty, opKind0, opKind1, opKind2, r)

#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3) \
(isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \
: getOpcV4H3(ty, opKind0, opKind1, r, isImm3)

#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3) \
(isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \
: getOpcV4H2(ty, opKind0, r, isImm2, isImm3)

#define getOpcodeForVectorStParamV4(ty, isimm) \
(isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3]) \
: getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])

#define getOpcodeForVectorStParam(n, ty, isimm) \
(n == 2) ? getOpcodeForVectorStParamV2(ty, isimm) \
: getOpcodeForVectorStParamV4(ty, isimm)

static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
unsigned NumElts,
MVT::SimpleValueType MemTy,
SelectionDAG *CurDAG, SDLoc DL) {
// Determine which inputs are registers and immediates make new operators
// with constant values
SmallVector<bool, 4> IsImm(NumElts, false);
for (unsigned i = 0; i < NumElts; i++) {
IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
if (IsImm[i]) {
SDValue Imm = Ops[i];
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
const ConstantFP *CF = ConstImm->getConstantFPValue();
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
} else {
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
const ConstantInt *CI = ConstImm->getConstantIntValue();
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
}
Ops[i] = Imm;
}
}

// Get opcode for MemTy, size, and register/immediate operand ordering
switch (MemTy) {
case MVT::i8:
return getOpcodeForVectorStParam(NumElts, I8, IsImm);
case MVT::i16:
return getOpcodeForVectorStParam(NumElts, I16, IsImm);
case MVT::i32:
return getOpcodeForVectorStParam(NumElts, I32, IsImm);
case MVT::i64:
assert(NumElts == 2 && "MVT too large for NumElts > 2");
return getOpcodeForVectorStParamV2(I64, IsImm);
case MVT::f32:
return getOpcodeForVectorStParam(NumElts, F32, IsImm);
case MVT::f64:
assert(NumElts == 2 && "MVT too large for NumElts > 2");
return getOpcodeForVectorStParamV2(F64, IsImm);

// These cases don't support immediates, just use the all register version
// and generate moves.
case MVT::i1:
return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
: NVPTX::StoreParamV4I8_rrrr;
case MVT::f16:
case MVT::bf16:
return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
: NVPTX::StoreParamV4I16_rrrr;
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
case MVT::v4i8:
return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
: NVPTX::StoreParamV4I32_rrrr;
default:
llvm_unreachable("Cannot select st.param for unknown MemTy");
}
}

bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
SDLoc DL(N);
SDValue Chain = N->getOperand(0);
Expand All @@ -2193,10 +2287,10 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
SDValue Glue = N->getOperand(N->getNumOperands() - 1);

// How many elements do we have?
unsigned NumElts = 1;
unsigned NumElts;
switch (N->getOpcode()) {
default:
return false;
llvm_unreachable("Unexpected opcode");
case NVPTXISD::StoreParamU32:
case NVPTXISD::StoreParamS32:
case NVPTXISD::StoreParam:
Expand All @@ -2222,54 +2316,69 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
// Determine target opcode
// If we have an i1, use an 8-bit store. The lowering code in
// NVPTXISelLowering will have already emitted an upcast.
std::optional<unsigned> Opcode = 0;
std::optional<unsigned> Opcode;
switch (N->getOpcode()) {
default:
switch (NumElts) {
default:
return false;
case 1:
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
NVPTX::StoreParamI8, NVPTX::StoreParamI16,
NVPTX::StoreParamI32, NVPTX::StoreParamI64,
NVPTX::StoreParamF32, NVPTX::StoreParamF64);
if (Opcode == NVPTX::StoreParamI8) {
llvm_unreachable("Unexpected NumElts");
case 1: {
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
SDValue Imm = Ops[0];
if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
(isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
// Convert immediate to target constant
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
const ConstantFP *CF = ConstImm->getConstantFPValue();
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
} else {
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
const ConstantInt *CI = ConstImm->getConstantIntValue();
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
}
Ops[0] = Imm;
// Use immediate version of store param
Opcode = pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i,
NVPTX::StoreParamI16_i, NVPTX::StoreParamI32_i,
NVPTX::StoreParamI64_i, NVPTX::StoreParamF32_i,
NVPTX::StoreParamF64_i);
} else
Opcode =
pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r,
NVPTX::StoreParamF32_r, NVPTX::StoreParamF64_r);
if (Opcode == NVPTX::StoreParamI8_r) {
// Fine tune the opcode depending on the size of the operand.
// This helps to avoid creating redundant COPY instructions in
// InstrEmitter::AddRegisterOperand().
switch (Ops[0].getSimpleValueType().SimpleTy) {
default:
break;
case MVT::i32:
Opcode = NVPTX::StoreParamI8TruncI32;
Opcode = NVPTX::StoreParamI8TruncI32_r;
break;
case MVT::i64:
Opcode = NVPTX::StoreParamI8TruncI64;
Opcode = NVPTX::StoreParamI8TruncI64_r;
break;
}
}
break;
}
case 2:
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
NVPTX::StoreParamV2I8, NVPTX::StoreParamV2I16,
NVPTX::StoreParamV2I32, NVPTX::StoreParamV2I64,
NVPTX::StoreParamV2F32, NVPTX::StoreParamV2F64);
break;
case 4:
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
NVPTX::StoreParamV4I8, NVPTX::StoreParamV4I16,
NVPTX::StoreParamV4I32, std::nullopt,
NVPTX::StoreParamV4F32, std::nullopt);
case 4: {
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
break;
}
if (!Opcode)
return false;
}
break;
// Special case: if we have a sign-extend/zero-extend node, insert the
// conversion instruction first, and use that as the value operand to
// the selected StoreParam node.
case NVPTXISD::StoreParamU32: {
Opcode = NVPTX::StoreParamI32;
Opcode = NVPTX::StoreParamI32_r;
SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
MVT::i32);
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL,
Expand All @@ -2278,7 +2387,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
break;
}
case NVPTXISD::StoreParamS32: {
Opcode = NVPTX::StoreParamI32;
Opcode = NVPTX::StoreParamI32_r;
SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
MVT::i32);
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL,
Expand Down
100 changes: 62 additions & 38 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -2637,25 +2637,46 @@ class LoadParamRegInst<NVPTXRegClass regclass, string opstr> :
[(set regclass:$dst, (LoadParam (i32 0), (i32 imm:$b)))]>;

let mayStore = true in {
class StoreParamInst<NVPTXRegClass regclass, string opstr> :
NVPTXInst<(outs), (ins regclass:$val, i32imm:$a, i32imm:$b),
!strconcat("st.param", opstr, " \t[param$a+$b], $val;"),
[]>;

class StoreParamV2Inst<NVPTXRegClass regclass, string opstr> :
NVPTXInst<(outs), (ins regclass:$val, regclass:$val2,
i32imm:$a, i32imm:$b),
!strconcat("st.param.v2", opstr,
" \t[param$a+$b], {{$val, $val2}};"),
[]>;
multiclass StoreParamInst<NVPTXRegClass regclass, Operand IMMType, string opstr, bit support_imm = true> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. I like it even better.

foreach op = [IMMType, regclass] in
if !or(support_imm, !isa<NVPTXRegClass>(op)) then
def _ # !if(!isa<NVPTXRegClass>(op), "r", "i")
: NVPTXInst<(outs),
(ins op:$val, i32imm:$a, i32imm:$b),
"st.param" # opstr # " \t[param$a+$b], $val;",
[]>;
}

class StoreParamV4Inst<NVPTXRegClass regclass, string opstr> :
NVPTXInst<(outs), (ins regclass:$val, regclass:$val2, regclass:$val3,
regclass:$val4, i32imm:$a,
i32imm:$b),
!strconcat("st.param.v4", opstr,
" \t[param$a+$b], {{$val, $val2, $val3, $val4}};"),
[]>;
multiclass StoreParamV2Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
foreach op1 = [IMMType, regclass] in
foreach op2 = [IMMType, regclass] in
def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
# !if(!isa<NVPTXRegClass>(op2), "r", "i")
: NVPTXInst<(outs),
(ins op1:$val1, op2:$val2,
i32imm:$a, i32imm:$b),
"st.param.v2" # opstr # " \t[param$a+$b], {{$val1, $val2}};",
[]>;
}

multiclass StoreParamV4Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
foreach op1 = [IMMType, regclass] in
foreach op2 = [IMMType, regclass] in
foreach op3 = [IMMType, regclass] in
foreach op4 = [IMMType, regclass] in
def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
# !if(!isa<NVPTXRegClass>(op2), "r", "i")
# !if(!isa<NVPTXRegClass>(op3), "r", "i")
# !if(!isa<NVPTXRegClass>(op4), "r", "i")

: NVPTXInst<(outs),
(ins op1:$val1, op2:$val2, op3:$val3, op4:$val4,
i32imm:$a, i32imm:$b),
"st.param.v4" # opstr #
" \t[param$a+$b], {{$val1, $val2, $val3, $val4}};",
[]>;
}

class StoreRetvalInst<NVPTXRegClass regclass, string opstr> :
NVPTXInst<(outs), (ins regclass:$val, i32imm:$a),
Expand Down Expand Up @@ -2735,27 +2756,30 @@ def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">;
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">;
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".f32">;

def StoreParamI64 : StoreParamInst<Int64Regs, ".b64">;
def StoreParamI32 : StoreParamInst<Int32Regs, ".b32">;

def StoreParamI16 : StoreParamInst<Int16Regs, ".b16">;
def StoreParamI8 : StoreParamInst<Int16Regs, ".b8">;
def StoreParamI8TruncI32 : StoreParamInst<Int32Regs, ".b8">;
def StoreParamI8TruncI64 : StoreParamInst<Int64Regs, ".b8">;
def StoreParamV2I64 : StoreParamV2Inst<Int64Regs, ".b64">;
def StoreParamV2I32 : StoreParamV2Inst<Int32Regs, ".b32">;
def StoreParamV2I16 : StoreParamV2Inst<Int16Regs, ".b16">;
def StoreParamV2I8 : StoreParamV2Inst<Int16Regs, ".b8">;

def StoreParamV4I32 : StoreParamV4Inst<Int32Regs, ".b32">;
def StoreParamV4I16 : StoreParamV4Inst<Int16Regs, ".b16">;
def StoreParamV4I8 : StoreParamV4Inst<Int16Regs, ".b8">;

def StoreParamF32 : StoreParamInst<Float32Regs, ".f32">;
def StoreParamF64 : StoreParamInst<Float64Regs, ".f64">;
def StoreParamV2F32 : StoreParamV2Inst<Float32Regs, ".f32">;
def StoreParamV2F64 : StoreParamV2Inst<Float64Regs, ".f64">;
def StoreParamV4F32 : StoreParamV4Inst<Float32Regs, ".f32">;
defm StoreParamI64 : StoreParamInst<Int64Regs, i64imm, ".b64">;
defm StoreParamI32 : StoreParamInst<Int32Regs, i32imm, ".b32">;
defm StoreParamI16 : StoreParamInst<Int16Regs, i16imm, ".b16">;
defm StoreParamI8 : StoreParamInst<Int16Regs, i8imm, ".b8">;

defm StoreParamI8TruncI32 : StoreParamInst<Int32Regs, i8imm, ".b8", /* support_imm */ false>;
defm StoreParamI8TruncI64 : StoreParamInst<Int64Regs, i8imm, ".b8", /* support_imm */ false>;

defm StoreParamV2I64 : StoreParamV2Inst<Int64Regs, i64imm, ".b64">;
defm StoreParamV2I32 : StoreParamV2Inst<Int32Regs, i32imm, ".b32">;
defm StoreParamV2I16 : StoreParamV2Inst<Int16Regs, i16imm, ".b16">;
defm StoreParamV2I8 : StoreParamV2Inst<Int16Regs, i8imm, ".b8">;

defm StoreParamV4I32 : StoreParamV4Inst<Int32Regs, i32imm, ".b32">;
defm StoreParamV4I16 : StoreParamV4Inst<Int16Regs, i16imm, ".b16">;
defm StoreParamV4I8 : StoreParamV4Inst<Int16Regs, i8imm, ".b8">;

defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".f32">;
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".f64">;

defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".f32">;
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".f64">;

defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".f32">;

def StoreRetvalI64 : StoreRetvalInst<Int64Regs, ".b64">;
def StoreRetvalI32 : StoreRetvalInst<Int32Regs, ".b32">;
Expand Down
Loading
Loading