Skip to content

Commit 9a0201c

Browse files
committed
[NVPTX] Fixup rotate lowering correctness
1 parent cc2684e commit 9a0201c

File tree

4 files changed

+303
-369
lines changed

4 files changed

+303
-369
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -594,20 +594,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
594594
setOperationAction(ISD::BITREVERSE, MVT::i32, Legal);
595595
setOperationAction(ISD::BITREVERSE, MVT::i64, Legal);
596596

597-
// TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs
598-
// that don't have h/w rotation we lower them to multi-instruction assembly.
599-
// See ROT*_sw in NVPTXIntrInfo.td
600-
setOperationAction(ISD::ROTL, MVT::i64, Legal);
601-
setOperationAction(ISD::ROTR, MVT::i64, Legal);
602-
setOperationAction(ISD::ROTL, MVT::i32, Legal);
603-
setOperationAction(ISD::ROTR, MVT::i32, Legal);
604-
605-
setOperationAction(ISD::ROTL, MVT::i16, Expand);
606-
setOperationAction(ISD::ROTL, MVT::v2i16, Expand);
607-
setOperationAction(ISD::ROTR, MVT::i16, Expand);
608-
setOperationAction(ISD::ROTR, MVT::v2i16, Expand);
609-
setOperationAction(ISD::ROTL, MVT::i8, Expand);
610-
setOperationAction(ISD::ROTR, MVT::i8, Expand);
597+
setOperationAction({ISD::ROTL, ISD::ROTR},
598+
{MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
599+
Expand);
600+
601+
if (STI.hasHWROT32())
602+
setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal);
603+
611604
setOperationAction(ISD::BSWAP, MVT::i16, Expand);
612605

613606
setOperationAction(ISD::BR_JT, MVT::Other, Custom);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 24 additions & 206 deletions
Original file line numberDiff line numberDiff line change
@@ -3491,222 +3491,40 @@ def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))),
34913491
(CVT_u32_u16 Int16Regs:$a, CvtNONE)>;
34923492

34933493
//
3494-
// Rotate: Use ptx shf instruction if available.
3494+
// Funnel-Shift
34953495
//
34963496

34973497
// Create SDNodes so they can be used in the DAG code, e.g.
34983498
// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts)
3499-
def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>;
3500-
def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>;
3499+
def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>;
3500+
def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>;
35013501

35023502
// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so
35033503
// no side effects.
35043504
let hasSideEffects = false in {
3505+
multiclass ShfInst<string mode, SDNode op> {
3506+
def _i
3507+
: NVPTXInst<(outs Int32Regs:$dst),
3508+
(ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt),
3509+
"shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;",
3510+
[(set Int32Regs:$dst,
3511+
(op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>,
3512+
Requires<[hasHWROT32]>;
35053513

3506-
def SHF_L_CLAMP_B32_REG :
3507-
NVPTXInst<(outs Int32Regs:$dst),
3508-
(ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
3509-
"shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;",
3510-
[(set Int32Regs:$dst,
3511-
(FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>,
3512-
Requires<[hasHWROT32]>;
3513-
3514-
def SHF_R_CLAMP_B32_REG :
3515-
NVPTXInst<(outs Int32Regs:$dst),
3516-
(ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
3517-
"shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;",
3518-
[(set Int32Regs:$dst,
3519-
(FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>,
3520-
Requires<[hasHWROT32]>;
3521-
3522-
def SHF_L_WRAP_B32_IMM
3523-
: NVPTXInst<(outs Int32Regs:$dst),
3524-
(ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt),
3525-
"shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
3526-
Requires<[hasHWROT32]>;
3527-
3528-
def SHF_L_WRAP_B32_REG
3529-
: NVPTXInst<(outs Int32Regs:$dst),
3530-
(ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
3531-
"shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
3532-
Requires<[hasHWROT32]>;
3533-
3534-
def SHF_R_WRAP_B32_IMM
3535-
: NVPTXInst<(outs Int32Regs:$dst),
3536-
(ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt),
3537-
"shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
3538-
Requires<[hasHWROT32]>;
3539-
3540-
def SHF_R_WRAP_B32_REG
3541-
: NVPTXInst<(outs Int32Regs:$dst),
3542-
(ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
3543-
"shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
3544-
Requires<[hasHWROT32]>;
3545-
}
3546-
3547-
// 32 bit r2 = rotl r1, n
3548-
// =>
3549-
// r2 = shf.l r1, r1, n
3550-
def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)),
3551-
(SHF_L_WRAP_B32_IMM Int32Regs:$src, Int32Regs:$src, imm:$amt)>,
3552-
Requires<[hasHWROT32]>;
3553-
3554-
def : Pat<(rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)),
3555-
(SHF_L_WRAP_B32_IMM Int32Regs:$src, Int32Regs:$src, Int32Regs:$amt)>,
3556-
Requires<[hasHWROT32]>;
3557-
3558-
// 32 bit r2 = rotr r1, n
3559-
// =>
3560-
// r2 = shf.r r1, r1, n
3561-
def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)),
3562-
(SHF_R_WRAP_B32_IMM Int32Regs:$src, Int32Regs:$src, imm:$amt)>,
3563-
Requires<[hasHWROT32]>;
3564-
3565-
def : Pat<(rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)),
3566-
(SHF_R_WRAP_B32_IMM Int32Regs:$src, Int32Regs:$src, Int32Regs:$amt)>,
3567-
Requires<[hasHWROT32]>;
3568-
3569-
// HW version of rotate 64
3570-
def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)),
3571-
(V2I32toI64
3572-
(SHF_L_WRAP_B32_IMM (I64toI32H Int64Regs:$src),
3573-
(I64toI32L Int64Regs:$src), imm:$amt),
3574-
(SHF_L_WRAP_B32_IMM (I64toI32L Int64Regs:$src),
3575-
(I64toI32H Int64Regs:$src), imm:$amt))>,
3576-
Requires<[hasHWROT32]>;
3577-
3578-
def : Pat<(rotl Int64Regs:$src, (i32 Int32Regs:$amt)),
3579-
(V2I32toI64
3580-
(SHF_L_WRAP_B32_REG (I64toI32H Int64Regs:$src),
3581-
(I64toI32L Int64Regs:$src), Int32Regs:$amt),
3582-
(SHF_L_WRAP_B32_REG (I64toI32L Int64Regs:$src),
3583-
(I64toI32H Int64Regs:$src), Int32Regs:$amt))>,
3584-
Requires<[hasHWROT32]>;
3585-
3586-
def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)),
3587-
(V2I32toI64
3588-
(SHF_R_WRAP_B32_IMM (I64toI32L Int64Regs:$src),
3589-
(I64toI32H Int64Regs:$src), imm:$amt),
3590-
(SHF_R_WRAP_B32_IMM (I64toI32H Int64Regs:$src),
3591-
(I64toI32L Int64Regs:$src), imm:$amt))>,
3592-
Requires<[hasHWROT32]>;
3593-
3594-
def : Pat<(rotr Int64Regs:$src, (i32 Int32Regs:$amt)),
3595-
(V2I32toI64
3596-
(SHF_R_WRAP_B32_REG (I64toI32L Int64Regs:$src),
3597-
(I64toI32H Int64Regs:$src), Int32Regs:$amt),
3598-
(SHF_R_WRAP_B32_REG (I64toI32H Int64Regs:$src),
3599-
(I64toI32L Int64Regs:$src), Int32Regs:$amt))>,
3600-
Requires<[hasHWROT32]>;
3601-
3602-
// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1.
3603-
def ROT32imm_sw :
3604-
NVPTXInst<(outs Int32Regs:$dst),
3605-
(ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2),
3606-
"{{\n\t"
3607-
".reg .b32 %lhs;\n\t"
3608-
".reg .b32 %rhs;\n\t"
3609-
"shl.b32 \t%lhs, $src, $amt1;\n\t"
3610-
"shr.b32 \t%rhs, $src, $amt2;\n\t"
3611-
"add.u32 \t$dst, %lhs, %rhs;\n\t"
3612-
"}}",
3613-
[]>;
3614-
3615-
def SUB_FRM_32 : SDNodeXForm<imm, [{
3616-
return CurDAG->getTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32);
3617-
}]>;
3618-
3619-
def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)),
3620-
(ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>,
3621-
Requires<[noHWROT32]>;
3622-
def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)),
3623-
(ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>,
3624-
Requires<[noHWROT32]>;
3625-
3626-
// 32-bit software rotate left by register.
3627-
def ROTL32reg_sw :
3628-
NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
3629-
"{{\n\t"
3630-
".reg .b32 %lhs;\n\t"
3631-
".reg .b32 %rhs;\n\t"
3632-
".reg .b32 %amt2;\n\t"
3633-
"shl.b32 \t%lhs, $src, $amt;\n\t"
3634-
"sub.s32 \t%amt2, 32, $amt;\n\t"
3635-
"shr.b32 \t%rhs, $src, %amt2;\n\t"
3636-
"add.u32 \t$dst, %lhs, %rhs;\n\t"
3637-
"}}",
3638-
[(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
3639-
Requires<[noHWROT32]>;
3640-
3641-
// 32-bit software rotate right by register.
3642-
def ROTR32reg_sw :
3643-
NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
3644-
"{{\n\t"
3645-
".reg .b32 %lhs;\n\t"
3646-
".reg .b32 %rhs;\n\t"
3647-
".reg .b32 %amt2;\n\t"
3648-
"shr.b32 \t%lhs, $src, $amt;\n\t"
3649-
"sub.s32 \t%amt2, 32, $amt;\n\t"
3650-
"shl.b32 \t%rhs, $src, %amt2;\n\t"
3651-
"add.u32 \t$dst, %lhs, %rhs;\n\t"
3652-
"}}",
3653-
[(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
3654-
Requires<[noHWROT32]>;
3655-
3656-
// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1.
3657-
def ROT64imm_sw :
3658-
NVPTXInst<(outs Int64Regs:$dst),
3659-
(ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2),
3660-
"{{\n\t"
3661-
".reg .b64 %lhs;\n\t"
3662-
".reg .b64 %rhs;\n\t"
3663-
"shl.b64 \t%lhs, $src, $amt1;\n\t"
3664-
"shr.b64 \t%rhs, $src, $amt2;\n\t"
3665-
"add.u64 \t$dst, %lhs, %rhs;\n\t"
3666-
"}}",
3667-
[]>;
3668-
3669-
def SUB_FRM_64 : SDNodeXForm<imm, [{
3670-
return CurDAG->getTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32);
3671-
}]>;
3514+
def _r
3515+
: NVPTXInst<(outs Int32Regs:$dst),
3516+
(ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
3517+
"shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;",
3518+
[(set Int32Regs:$dst,
3519+
(op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>,
3520+
Requires<[hasHWROT32]>;
3521+
}
36723522

3673-
def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)),
3674-
(ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>,
3675-
Requires<[noHWROT32]>;
3676-
def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)),
3677-
(ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>,
3678-
Requires<[noHWROT32]>;
3679-
3680-
// 64-bit software rotate left by register.
3681-
def ROTL64reg_sw :
3682-
NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt),
3683-
"{{\n\t"
3684-
".reg .b64 %lhs;\n\t"
3685-
".reg .b64 %rhs;\n\t"
3686-
".reg .u32 %amt2;\n\t"
3687-
"and.b32 \t%amt2, $amt, 63;\n\t"
3688-
"shl.b64 \t%lhs, $src, %amt2;\n\t"
3689-
"sub.u32 \t%amt2, 64, %amt2;\n\t"
3690-
"shr.b64 \t%rhs, $src, %amt2;\n\t"
3691-
"add.u64 \t$dst, %lhs, %rhs;\n\t"
3692-
"}}",
3693-
[(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>,
3694-
Requires<[noHWROT32]>;
3695-
3696-
def ROTR64reg_sw :
3697-
NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt),
3698-
"{{\n\t"
3699-
".reg .b64 %lhs;\n\t"
3700-
".reg .b64 %rhs;\n\t"
3701-
".reg .u32 %amt2;\n\t"
3702-
"and.b32 \t%amt2, $amt, 63;\n\t"
3703-
"shr.b64 \t%lhs, $src, %amt2;\n\t"
3704-
"sub.u32 \t%amt2, 64, %amt2;\n\t"
3705-
"shl.b64 \t%rhs, $src, %amt2;\n\t"
3706-
"add.u64 \t$dst, %lhs, %rhs;\n\t"
3707-
"}}",
3708-
[(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>,
3709-
Requires<[noHWROT32]>;
3523+
defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>;
3524+
defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>;
3525+
defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>;
3526+
defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>;
3527+
}
37103528

37113529
// Count leading zeros
37123530
let hasSideEffects = false in {

0 commit comments

Comments
 (0)