Skip to content

Commit b967616

Browse files
committed
[NVPTX] Fixup rotate lowering correctness
1 parent 0e3ba7a commit b967616

File tree

4 files changed

+303
-369
lines changed

4 files changed

+303
-369
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -594,20 +594,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
594594
setOperationAction(ISD::BITREVERSE, MVT::i32, Legal);
595595
setOperationAction(ISD::BITREVERSE, MVT::i64, Legal);
596596

597-
// TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs
598-
// that don't have h/w rotation we lower them to multi-instruction assembly.
599-
// See ROT*_sw in NVPTXIntrInfo.td
600-
setOperationAction(ISD::ROTL, MVT::i64, Legal);
601-
setOperationAction(ISD::ROTR, MVT::i64, Legal);
602-
setOperationAction(ISD::ROTL, MVT::i32, Legal);
603-
setOperationAction(ISD::ROTR, MVT::i32, Legal);
604-
605-
setOperationAction(ISD::ROTL, MVT::i16, Expand);
606-
setOperationAction(ISD::ROTL, MVT::v2i16, Expand);
607-
setOperationAction(ISD::ROTR, MVT::i16, Expand);
608-
setOperationAction(ISD::ROTR, MVT::v2i16, Expand);
609-
setOperationAction(ISD::ROTL, MVT::i8, Expand);
610-
setOperationAction(ISD::ROTR, MVT::i8, Expand);
597+
setOperationAction({ISD::ROTL, ISD::ROTR},
598+
{MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
599+
Expand);
600+
601+
if (STI.hasHWROT32())
602+
setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal);
603+
611604
setOperationAction(ISD::BSWAP, MVT::i16, Expand);
612605

613606
setOperationAction(ISD::BR_JT, MVT::Other, Custom);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 24 additions & 206 deletions
Original file line numberDiff line numberDiff line change
@@ -3497,222 +3497,40 @@ def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))),
34973497
(CVT_u32_u16 Int16Regs:$a, CvtNONE)>;
34983498

34993499
//
3500-
// Rotate: Use ptx shf instruction if available.
3500+
// Funnel-Shift
35013501
//
35023502

35033503
// Create SDNodes so they can be used in the DAG code, e.g.
35043504
// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts)
3505-
def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>;
3506-
def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>;
3505+
def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>;
3506+
def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>;
35073507

35083508
// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so
35093509
// no side effects.
35103510
let hasSideEffects = false in {
3511+
multiclass ShfInst<string mode, SDNode op> {
3512+
def _i
3513+
: NVPTXInst<(outs Int32Regs:$dst),
3514+
(ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt),
3515+
"shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;",
3516+
[(set Int32Regs:$dst,
3517+
(op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>,
3518+
Requires<[hasHWROT32]>;
35113519

3512-
def SHF_L_CLAMP_B32_REG :
3513-
NVPTXInst<(outs Int32Regs:$dst),
3514-
(ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
3515-
"shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;",
3516-
[(set Int32Regs:$dst,
3517-
(FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>,
3518-
Requires<[hasHWROT32]>;
3519-
3520-
def SHF_R_CLAMP_B32_REG :
3521-
NVPTXInst<(outs Int32Regs:$dst),
3522-
(ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
3523-
"shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;",
3524-
[(set Int32Regs:$dst,
3525-
(FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>,
3526-
Requires<[hasHWROT32]>;
3527-
3528-
def SHF_L_WRAP_B32_IMM
3529-
: NVPTXInst<(outs Int32Regs:$dst),
3530-
(ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt),
3531-
"shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
3532-
Requires<[hasHWROT32]>;
3533-
3534-
def SHF_L_WRAP_B32_REG
3535-
: NVPTXInst<(outs Int32Regs:$dst),
3536-
(ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
3537-
"shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
3538-
Requires<[hasHWROT32]>;
3539-
3540-
def SHF_R_WRAP_B32_IMM
3541-
: NVPTXInst<(outs Int32Regs:$dst),
3542-
(ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt),
3543-
"shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
3544-
Requires<[hasHWROT32]>;
3545-
3546-
def SHF_R_WRAP_B32_REG
3547-
: NVPTXInst<(outs Int32Regs:$dst),
3548-
(ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
3549-
"shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
3550-
Requires<[hasHWROT32]>;
3551-
}
3552-
3553-
// 32 bit r2 = rotl r1, n
3554-
// =>
3555-
// r2 = shf.l r1, r1, n
3556-
def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)),
3557-
(SHF_L_WRAP_B32_IMM Int32Regs:$src, Int32Regs:$src, imm:$amt)>,
3558-
Requires<[hasHWROT32]>;
3559-
3560-
def : Pat<(rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)),
3561-
(SHF_L_WRAP_B32_IMM Int32Regs:$src, Int32Regs:$src, Int32Regs:$amt)>,
3562-
Requires<[hasHWROT32]>;
3563-
3564-
// 32 bit r2 = rotr r1, n
3565-
// =>
3566-
// r2 = shf.r r1, r1, n
3567-
def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)),
3568-
(SHF_R_WRAP_B32_IMM Int32Regs:$src, Int32Regs:$src, imm:$amt)>,
3569-
Requires<[hasHWROT32]>;
3570-
3571-
def : Pat<(rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)),
3572-
(SHF_R_WRAP_B32_IMM Int32Regs:$src, Int32Regs:$src, Int32Regs:$amt)>,
3573-
Requires<[hasHWROT32]>;
3574-
3575-
// HW version of rotate 64
3576-
def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)),
3577-
(V2I32toI64
3578-
(SHF_L_WRAP_B32_IMM (I64toI32H Int64Regs:$src),
3579-
(I64toI32L Int64Regs:$src), imm:$amt),
3580-
(SHF_L_WRAP_B32_IMM (I64toI32L Int64Regs:$src),
3581-
(I64toI32H Int64Regs:$src), imm:$amt))>,
3582-
Requires<[hasHWROT32]>;
3583-
3584-
def : Pat<(rotl Int64Regs:$src, (i32 Int32Regs:$amt)),
3585-
(V2I32toI64
3586-
(SHF_L_WRAP_B32_REG (I64toI32H Int64Regs:$src),
3587-
(I64toI32L Int64Regs:$src), Int32Regs:$amt),
3588-
(SHF_L_WRAP_B32_REG (I64toI32L Int64Regs:$src),
3589-
(I64toI32H Int64Regs:$src), Int32Regs:$amt))>,
3590-
Requires<[hasHWROT32]>;
3591-
3592-
def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)),
3593-
(V2I32toI64
3594-
(SHF_R_WRAP_B32_IMM (I64toI32L Int64Regs:$src),
3595-
(I64toI32H Int64Regs:$src), imm:$amt),
3596-
(SHF_R_WRAP_B32_IMM (I64toI32H Int64Regs:$src),
3597-
(I64toI32L Int64Regs:$src), imm:$amt))>,
3598-
Requires<[hasHWROT32]>;
3599-
3600-
def : Pat<(rotr Int64Regs:$src, (i32 Int32Regs:$amt)),
3601-
(V2I32toI64
3602-
(SHF_R_WRAP_B32_REG (I64toI32L Int64Regs:$src),
3603-
(I64toI32H Int64Regs:$src), Int32Regs:$amt),
3604-
(SHF_R_WRAP_B32_REG (I64toI32H Int64Regs:$src),
3605-
(I64toI32L Int64Regs:$src), Int32Regs:$amt))>,
3606-
Requires<[hasHWROT32]>;
3607-
3608-
// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1.
3609-
def ROT32imm_sw :
3610-
NVPTXInst<(outs Int32Regs:$dst),
3611-
(ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2),
3612-
"{{\n\t"
3613-
".reg .b32 %lhs;\n\t"
3614-
".reg .b32 %rhs;\n\t"
3615-
"shl.b32 \t%lhs, $src, $amt1;\n\t"
3616-
"shr.b32 \t%rhs, $src, $amt2;\n\t"
3617-
"add.u32 \t$dst, %lhs, %rhs;\n\t"
3618-
"}}",
3619-
[]>;
3620-
3621-
def SUB_FRM_32 : SDNodeXForm<imm, [{
3622-
return CurDAG->getTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32);
3623-
}]>;
3624-
3625-
def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)),
3626-
(ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>,
3627-
Requires<[noHWROT32]>;
3628-
def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)),
3629-
(ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>,
3630-
Requires<[noHWROT32]>;
3631-
3632-
// 32-bit software rotate left by register.
3633-
def ROTL32reg_sw :
3634-
NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
3635-
"{{\n\t"
3636-
".reg .b32 %lhs;\n\t"
3637-
".reg .b32 %rhs;\n\t"
3638-
".reg .b32 %amt2;\n\t"
3639-
"shl.b32 \t%lhs, $src, $amt;\n\t"
3640-
"sub.s32 \t%amt2, 32, $amt;\n\t"
3641-
"shr.b32 \t%rhs, $src, %amt2;\n\t"
3642-
"add.u32 \t$dst, %lhs, %rhs;\n\t"
3643-
"}}",
3644-
[(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
3645-
Requires<[noHWROT32]>;
3646-
3647-
// 32-bit software rotate right by register.
3648-
def ROTR32reg_sw :
3649-
NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
3650-
"{{\n\t"
3651-
".reg .b32 %lhs;\n\t"
3652-
".reg .b32 %rhs;\n\t"
3653-
".reg .b32 %amt2;\n\t"
3654-
"shr.b32 \t%lhs, $src, $amt;\n\t"
3655-
"sub.s32 \t%amt2, 32, $amt;\n\t"
3656-
"shl.b32 \t%rhs, $src, %amt2;\n\t"
3657-
"add.u32 \t$dst, %lhs, %rhs;\n\t"
3658-
"}}",
3659-
[(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
3660-
Requires<[noHWROT32]>;
3661-
3662-
// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1.
3663-
def ROT64imm_sw :
3664-
NVPTXInst<(outs Int64Regs:$dst),
3665-
(ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2),
3666-
"{{\n\t"
3667-
".reg .b64 %lhs;\n\t"
3668-
".reg .b64 %rhs;\n\t"
3669-
"shl.b64 \t%lhs, $src, $amt1;\n\t"
3670-
"shr.b64 \t%rhs, $src, $amt2;\n\t"
3671-
"add.u64 \t$dst, %lhs, %rhs;\n\t"
3672-
"}}",
3673-
[]>;
3674-
3675-
def SUB_FRM_64 : SDNodeXForm<imm, [{
3676-
return CurDAG->getTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32);
3677-
}]>;
3520+
def _r
3521+
: NVPTXInst<(outs Int32Regs:$dst),
3522+
(ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
3523+
"shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;",
3524+
[(set Int32Regs:$dst,
3525+
(op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>,
3526+
Requires<[hasHWROT32]>;
3527+
}
36783528

3679-
def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)),
3680-
(ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>,
3681-
Requires<[noHWROT32]>;
3682-
def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)),
3683-
(ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>,
3684-
Requires<[noHWROT32]>;
3685-
3686-
// 64-bit software rotate left by register.
3687-
def ROTL64reg_sw :
3688-
NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt),
3689-
"{{\n\t"
3690-
".reg .b64 %lhs;\n\t"
3691-
".reg .b64 %rhs;\n\t"
3692-
".reg .u32 %amt2;\n\t"
3693-
"and.b32 \t%amt2, $amt, 63;\n\t"
3694-
"shl.b64 \t%lhs, $src, %amt2;\n\t"
3695-
"sub.u32 \t%amt2, 64, %amt2;\n\t"
3696-
"shr.b64 \t%rhs, $src, %amt2;\n\t"
3697-
"add.u64 \t$dst, %lhs, %rhs;\n\t"
3698-
"}}",
3699-
[(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>,
3700-
Requires<[noHWROT32]>;
3701-
3702-
def ROTR64reg_sw :
3703-
NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt),
3704-
"{{\n\t"
3705-
".reg .b64 %lhs;\n\t"
3706-
".reg .b64 %rhs;\n\t"
3707-
".reg .u32 %amt2;\n\t"
3708-
"and.b32 \t%amt2, $amt, 63;\n\t"
3709-
"shr.b64 \t%lhs, $src, %amt2;\n\t"
3710-
"sub.u32 \t%amt2, 64, %amt2;\n\t"
3711-
"shl.b64 \t%rhs, $src, %amt2;\n\t"
3712-
"add.u64 \t$dst, %lhs, %rhs;\n\t"
3713-
"}}",
3714-
[(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>,
3715-
Requires<[noHWROT32]>;
3529+
defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>;
3530+
defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>;
3531+
defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>;
3532+
defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>;
3533+
}
37163534

37173535
// Count leading zeros
37183536
let hasSideEffects = false in {

0 commit comments

Comments
 (0)