Skip to content

Commit ca7a4f1

Browse files
[NVPTX] Add isel patterns for bit-field extract (bfe)
llvm-svn: 211932
1 parent 10c2596 commit ca7a4f1

File tree

4 files changed

+270
-0
lines changed

4 files changed

+270
-0
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,12 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) {
253253
case NVPTXISD::Suld3DV4I32Trap:
254254
ResNode = SelectSurfaceIntrinsic(N);
255255
break;
256+
case ISD::AND:
257+
case ISD::SRA:
258+
case ISD::SRL:
259+
// Try to select BFE
260+
ResNode = SelectBFE(N);
261+
break;
256262
case ISD::ADDRSPACECAST:
257263
ResNode = SelectAddrSpaceCast(N);
258264
break;
@@ -2959,6 +2965,214 @@ SDNode *NVPTXDAGToDAGISel::SelectSurfaceIntrinsic(SDNode *N) {
29592965
return Ret;
29602966
}
29612967

2968+
/// SelectBFE - Look for instruction sequences that can be made more efficient
2969+
/// by using the 'bfe' (bit-field extract) PTX instruction
2970+
SDNode *NVPTXDAGToDAGISel::SelectBFE(SDNode *N) {
2971+
SDValue LHS = N->getOperand(0);
2972+
SDValue RHS = N->getOperand(1);
2973+
SDValue Len;
2974+
SDValue Start;
2975+
SDValue Val;
2976+
bool IsSigned = false;
2977+
2978+
if (N->getOpcode() == ISD::AND) {
2979+
// Canonicalize the operands
2980+
// We want 'and %val, %mask'
2981+
if (isa<ConstantSDNode>(LHS) && !isa<ConstantSDNode>(RHS)) {
2982+
std::swap(LHS, RHS);
2983+
}
2984+
2985+
ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(RHS);
2986+
if (!Mask) {
2987+
// We need a constant mask on the RHS of the AND
2988+
return NULL;
2989+
}
2990+
2991+
// Extract the mask bits
2992+
uint64_t MaskVal = Mask->getZExtValue();
2993+
if (!isMask_64(MaskVal)) {
2994+
// We *could* handle shifted masks here, but doing so would require an
2995+
// 'and' operation to fix up the low-order bits so we would trade
2996+
// shr+and for bfe+and, which has the same throughput
2997+
return NULL;
2998+
}
2999+
3000+
// How many bits are in our mask?
3001+
uint64_t NumBits = CountTrailingOnes_64(MaskVal);
3002+
Len = CurDAG->getTargetConstant(NumBits, MVT::i32);
3003+
3004+
if (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SRA) {
3005+
// We have a 'srl/and' pair, extract the effective start bit and length
3006+
Val = LHS.getNode()->getOperand(0);
3007+
Start = LHS.getNode()->getOperand(1);
3008+
ConstantSDNode *StartConst = dyn_cast<ConstantSDNode>(Start);
3009+
if (StartConst) {
3010+
uint64_t StartVal = StartConst->getZExtValue();
3011+
// How many "good" bits do we have left? "good" is defined here as bits
3012+
// that exist in the original value, not shifted in.
3013+
uint64_t GoodBits = Start.getValueType().getSizeInBits() - StartVal;
3014+
if (NumBits > GoodBits) {
3015+
// Do not handle the case where bits have been shifted in. In theory
3016+
// we could handle this, but the cost is likely higher than just
3017+
// emitting the srl/and pair.
3018+
return NULL;
3019+
}
3020+
Start = CurDAG->getTargetConstant(StartVal, MVT::i32);
3021+
} else {
3022+
// Do not handle the case where the shift amount (can be zero if no srl
3023+
// was found) is not constant. We could handle this case, but it would
3024+
// require run-time logic that would be more expensive than just
3025+
// emitting the srl/and pair.
3026+
return NULL;
3027+
}
3028+
} else {
3029+
// Do not handle the case where the LHS of the and is not a shift. While
3030+
// it would be trivial to handle this case, it would just transform
3031+
// 'and' -> 'bfe', but 'and' has higher-throughput.
3032+
return NULL;
3033+
}
3034+
} else if (N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) {
3035+
if (LHS->getOpcode() == ISD::AND) {
3036+
ConstantSDNode *ShiftCnst = dyn_cast<ConstantSDNode>(RHS);
3037+
if (!ShiftCnst) {
3038+
// Shift amount must be constant
3039+
return NULL;
3040+
}
3041+
3042+
uint64_t ShiftAmt = ShiftCnst->getZExtValue();
3043+
3044+
SDValue AndLHS = LHS->getOperand(0);
3045+
SDValue AndRHS = LHS->getOperand(1);
3046+
3047+
// Canonicalize the AND to have the mask on the RHS
3048+
if (isa<ConstantSDNode>(AndLHS)) {
3049+
std::swap(AndLHS, AndRHS);
3050+
}
3051+
3052+
ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(AndRHS);
3053+
if (!MaskCnst) {
3054+
// Mask must be constant
3055+
return NULL;
3056+
}
3057+
3058+
uint64_t MaskVal = MaskCnst->getZExtValue();
3059+
uint64_t NumZeros;
3060+
uint64_t NumBits;
3061+
if (isMask_64(MaskVal)) {
3062+
NumZeros = 0;
3063+
// The number of bits in the result bitfield will be the number of
3064+
// trailing ones (the AND) minus the number of bits we shift off
3065+
NumBits = CountTrailingOnes_64(MaskVal) - ShiftAmt;
3066+
} else if (isShiftedMask_64(MaskVal)) {
3067+
NumZeros = countTrailingZeros(MaskVal);
3068+
unsigned NumOnes = CountTrailingOnes_64(MaskVal >> NumZeros);
3069+
// The number of bits in the result bitfield will be the number of
3070+
// trailing zeros plus the number of set bits in the mask minus the
3071+
// number of bits we shift off
3072+
NumBits = NumZeros + NumOnes - ShiftAmt;
3073+
} else {
3074+
// This is not a mask we can handle
3075+
return NULL;
3076+
}
3077+
3078+
if (ShiftAmt < NumZeros) {
3079+
// Handling this case would require extra logic that would make this
3080+
// transformation non-profitable
3081+
return NULL;
3082+
}
3083+
3084+
Val = AndLHS;
3085+
Start = CurDAG->getTargetConstant(ShiftAmt, MVT::i32);
3086+
Len = CurDAG->getTargetConstant(NumBits, MVT::i32);
3087+
} else if (LHS->getOpcode() == ISD::SHL) {
3088+
// Here, we have a pattern like:
3089+
//
3090+
// (sra (shl val, NN), MM)
3091+
// or
3092+
// (srl (shl val, NN), MM)
3093+
//
3094+
// If MM >= NN, we can efficiently optimize this with bfe
3095+
Val = LHS->getOperand(0);
3096+
3097+
SDValue ShlRHS = LHS->getOperand(1);
3098+
ConstantSDNode *ShlCnst = dyn_cast<ConstantSDNode>(ShlRHS);
3099+
if (!ShlCnst) {
3100+
// Shift amount must be constant
3101+
return NULL;
3102+
}
3103+
uint64_t InnerShiftAmt = ShlCnst->getZExtValue();
3104+
3105+
SDValue ShrRHS = RHS;
3106+
ConstantSDNode *ShrCnst = dyn_cast<ConstantSDNode>(ShrRHS);
3107+
if (!ShrCnst) {
3108+
// Shift amount must be constant
3109+
return NULL;
3110+
}
3111+
uint64_t OuterShiftAmt = ShrCnst->getZExtValue();
3112+
3113+
// To avoid extra codegen and be profitable, we need Outer >= Inner
3114+
if (OuterShiftAmt < InnerShiftAmt) {
3115+
return NULL;
3116+
}
3117+
3118+
// If the outer shift is more than the type size, we have no bitfield to
3119+
// extract (since we also check that the inner shift is <= the outer shift
3120+
// then this also implies that the inner shift is < the type size)
3121+
if (OuterShiftAmt >= Val.getValueType().getSizeInBits()) {
3122+
return NULL;
3123+
}
3124+
3125+
Start =
3126+
CurDAG->getTargetConstant(OuterShiftAmt - InnerShiftAmt, MVT::i32);
3127+
Len =
3128+
CurDAG->getTargetConstant(Val.getValueType().getSizeInBits() -
3129+
OuterShiftAmt, MVT::i32);
3130+
3131+
if (N->getOpcode() == ISD::SRA) {
3132+
// If we have a arithmetic right shift, we need to use the signed bfe
3133+
// variant
3134+
IsSigned = true;
3135+
}
3136+
} else {
3137+
// No can do...
3138+
return NULL;
3139+
}
3140+
} else {
3141+
// No can do...
3142+
return NULL;
3143+
}
3144+
3145+
3146+
unsigned Opc;
3147+
// For the BFE operations we form here from "and" and "srl", always use the
3148+
// unsigned variants.
3149+
if (Val.getValueType() == MVT::i32) {
3150+
if (IsSigned) {
3151+
Opc = NVPTX::BFE_S32rii;
3152+
} else {
3153+
Opc = NVPTX::BFE_U32rii;
3154+
}
3155+
} else if (Val.getValueType() == MVT::i64) {
3156+
if (IsSigned) {
3157+
Opc = NVPTX::BFE_S64rii;
3158+
} else {
3159+
Opc = NVPTX::BFE_U64rii;
3160+
}
3161+
} else {
3162+
// We cannot handle this type
3163+
return NULL;
3164+
}
3165+
3166+
SDValue Ops[] = {
3167+
Val, Start, Len
3168+
};
3169+
3170+
SDNode *Ret =
3171+
CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops);
3172+
3173+
return Ret;
3174+
}
3175+
29623176
// SelectDirectAddr - Match a direct address for DAG.
29633177
// A direct address could be a globaladdress or externalsymbol.
29643178
bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7171
SDNode *SelectAddrSpaceCast(SDNode *N);
7272
SDNode *SelectTextureIntrinsic(SDNode *N);
7373
SDNode *SelectSurfaceIntrinsic(SDNode *N);
74+
SDNode *SelectBFE(SDNode *N);
7475

7576
inline SDValue getI32Imm(unsigned Imm) {
7677
return CurDAG->getTargetConstant(Imm, MVT::i32);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,29 @@ def ROTR64reg_sw : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src,
11791179
!strconcat("}}", ""))))))))),
11801180
[(set Int64Regs:$dst, (rotr Int64Regs:$src, Int32Regs:$amt))]>;
11811181

1182+
// BFE - bit-field extract
1183+
1184+
multiclass BFE<string TyStr, RegisterClass RC> {
1185+
// BFE supports both 32-bit and 64-bit values, but the start and length
1186+
// operands are always 32-bit
1187+
def rrr
1188+
: NVPTXInst<(outs RC:$d),
1189+
(ins RC:$a, Int32Regs:$b, Int32Regs:$c),
1190+
!strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
1191+
def rri
1192+
: NVPTXInst<(outs RC:$d),
1193+
(ins RC:$a, Int32Regs:$b, i32imm:$c),
1194+
!strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
1195+
def rii
1196+
: NVPTXInst<(outs RC:$d),
1197+
(ins RC:$a, i32imm:$b, i32imm:$c),
1198+
!strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
1199+
}
1200+
1201+
defm BFE_S32 : BFE<"s32", Int32Regs>;
1202+
defm BFE_U32 : BFE<"u32", Int32Regs>;
1203+
defm BFE_S64 : BFE<"s64", Int64Regs>;
1204+
defm BFE_U64 : BFE<"u64", Int64Regs>;
11821205

11831206
//-----------------------------------
11841207
// General Comparison

llvm/test/CodeGen/NVPTX/bfe.ll

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
2+
3+
4+
; CHECK: bfe0
5+
define i32 @bfe0(i32 %a) {
6+
; CHECK: bfe.u32 %r{{[0-9]+}}, %r{{[0-9]+}}, 4, 4
7+
; CHECK-NOT: shr
8+
; CHECK-NOT: and
9+
%val0 = ashr i32 %a, 4
10+
%val1 = and i32 %val0, 15
11+
ret i32 %val1
12+
}
13+
14+
; CHECK: bfe1
15+
define i32 @bfe1(i32 %a) {
16+
; CHECK: bfe.u32 %r{{[0-9]+}}, %r{{[0-9]+}}, 3, 3
17+
; CHECK-NOT: shr
18+
; CHECK-NOT: and
19+
%val0 = ashr i32 %a, 3
20+
%val1 = and i32 %val0, 7
21+
ret i32 %val1
22+
}
23+
24+
; CHECK: bfe2
25+
define i32 @bfe2(i32 %a) {
26+
; CHECK: bfe.u32 %r{{[0-9]+}}, %r{{[0-9]+}}, 5, 3
27+
; CHECK-NOT: shr
28+
; CHECK-NOT: and
29+
%val0 = ashr i32 %a, 5
30+
%val1 = and i32 %val0, 7
31+
ret i32 %val1
32+
}

0 commit comments

Comments
 (0)