@@ -253,6 +253,12 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) {
   case NVPTXISD::Suld3DV4I32Trap:
     ResNode = SelectSurfaceIntrinsic(N);
     break;
+  case ISD::AND:
+  case ISD::SRA:
+  case ISD::SRL:
+    // Try to select BFE
+    ResNode = SelectBFE(N);
+    break;
   case ISD::ADDRSPACECAST:
     ResNode = SelectAddrSpaceCast(N);
     break;
@@ -2959,6 +2965,214 @@ SDNode *NVPTXDAGToDAGISel::SelectSurfaceIntrinsic(SDNode *N) {
   return Ret;
 }
 
+/// SelectBFE - Look for instruction sequences that can be made more efficient
+/// by using the 'bfe' (bit-field extract) PTX instruction
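+/// For example, (and (srl %val, 4), 255) can be selected as
+/// 'bfe.u32 %r, %val, 4, 8', which extracts the 8-bit field of %val starting
+/// at bit 4 (the bfe operands are the value, the start position, and the
+/// field length).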
+SDNode *NVPTXDAGToDAGISel::SelectBFE(SDNode *N) {
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  SDValue Len;
+  SDValue Start;
+  SDValue Val;
+  bool IsSigned = false;
+
+  if (N->getOpcode() == ISD::AND) {
+    // Canonicalize the operands
+    // We want 'and %val, %mask'
+    if (isa<ConstantSDNode>(LHS) && !isa<ConstantSDNode>(RHS)) {
+      std::swap(LHS, RHS);
+    }
+
+    ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(RHS);
+    if (!Mask) {
+      // We need a constant mask on the RHS of the AND
+      return NULL;
+    }
+
+    // Extract the mask bits
+    uint64_t MaskVal = Mask->getZExtValue();
+    if (!isMask_64(MaskVal)) {
+      // We *could* handle shifted masks here, but doing so would require an
+      // 'and' operation to fix up the low-order bits, so we would trade
+      // shr+and for bfe+and, which has the same throughput
+      return NULL;
+    }
+
+    // How many bits are in our mask?
+    uint64_t NumBits = CountTrailingOnes_64(MaskVal);
+    Len = CurDAG->getTargetConstant(NumBits, MVT::i32);
+
+    if (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SRA) {
+      // We have a 'srl/and' pair; extract the effective start bit and length
+      Val = LHS.getNode()->getOperand(0);
+      Start = LHS.getNode()->getOperand(1);
+      ConstantSDNode *StartConst = dyn_cast<ConstantSDNode>(Start);
+      if (StartConst) {
+        uint64_t StartVal = StartConst->getZExtValue();
+        // How many "good" bits do we have left? "good" is defined here as bits
+        // that exist in the original value, not shifted in.
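+        // For example, on i32, (and (srl %val, 28), 255) asks for 8 bits but
+        // only 4 "good" bits remain (GoodBits = 32 - 28), so we bail out
+        // below rather than extract bits that were shifted in.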
+        uint64_t GoodBits = Start.getValueType().getSizeInBits() - StartVal;
+        if (NumBits > GoodBits) {
+          // Do not handle the case where bits have been shifted in. In theory
+          // we could handle this, but the cost is likely higher than just
+          // emitting the srl/and pair.
+          return NULL;
+        }
+        Start = CurDAG->getTargetConstant(StartVal, MVT::i32);
+      } else {
+        // Do not handle the case where the shift amount (can be zero if no srl
+        // was found) is not constant. We could handle this case, but it would
+        // require run-time logic that would be more expensive than just
+        // emitting the srl/and pair.
+        return NULL;
+      }
+    } else {
+      // Do not handle the case where the LHS of the and is not a shift. While
+      // it would be trivial to handle this case, it would just transform
+      // 'and' -> 'bfe', but 'and' has higher throughput.
+      return NULL;
+    }
+  } else if (N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) {
+    if (LHS->getOpcode() == ISD::AND) {
+      ConstantSDNode *ShiftCnst = dyn_cast<ConstantSDNode>(RHS);
+      if (!ShiftCnst) {
+        // Shift amount must be constant
+        return NULL;
+      }
+
+      uint64_t ShiftAmt = ShiftCnst->getZExtValue();
+
+      SDValue AndLHS = LHS->getOperand(0);
+      SDValue AndRHS = LHS->getOperand(1);
+
+      // Canonicalize the AND to have the mask on the RHS
+      if (isa<ConstantSDNode>(AndLHS)) {
+        std::swap(AndLHS, AndRHS);
+      }
+
+      ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(AndRHS);
+      if (!MaskCnst) {
+        // Mask must be constant
+        return NULL;
+      }
+
+      uint64_t MaskVal = MaskCnst->getZExtValue();
+      uint64_t NumZeros;
+      uint64_t NumBits;
+      if (isMask_64(MaskVal)) {
+        NumZeros = 0;
+        // The number of bits in the result bitfield will be the number of
+        // trailing ones (the AND) minus the number of bits we shift off
+        NumBits = CountTrailingOnes_64(MaskVal) - ShiftAmt;
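+        // e.g. for (srl (and %val, 255), 4): NumBits = 8 - 4 = 4, which is
+        // equivalent to the 'bfe.u32 %r, %val, 4, 4' we form below.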
+      } else if (isShiftedMask_64(MaskVal)) {
+        NumZeros = countTrailingZeros(MaskVal);
+        unsigned NumOnes = CountTrailingOnes_64(MaskVal >> NumZeros);
+        // The number of bits in the result bitfield will be the number of
+        // trailing zeros plus the number of set bits in the mask minus the
+        // number of bits we shift off
+        NumBits = NumZeros + NumOnes - ShiftAmt;
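+        // e.g. for (srl (and %val, 0xFF0), 4): NumZeros = 4, NumOnes = 8 and
+        // ShiftAmt = 4, so NumBits = 8 and the result is equivalent to
+        // 'bfe.u32 %r, %val, 4, 8' (an 8-bit field starting at bit 4).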
+      } else {
+        // This is not a mask we can handle
+        return NULL;
+      }
+
+      if (ShiftAmt < NumZeros) {
+        // Handling this case would require extra logic that would make this
+        // transformation non-profitable
+        return NULL;
+      }
+
+      Val = AndLHS;
+      Start = CurDAG->getTargetConstant(ShiftAmt, MVT::i32);
+      Len = CurDAG->getTargetConstant(NumBits, MVT::i32);
+    } else if (LHS->getOpcode() == ISD::SHL) {
+      // Here, we have a pattern like:
+      //
+      // (sra (shl val, NN), MM)
+      // or
+      // (srl (shl val, NN), MM)
+      //
+      // If MM >= NN, we can efficiently optimize this with bfe
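+      //
+      // For example, on i32, (srl (shl %val, 8), 16) keeps bits 8..23 of
+      // %val, i.e. a 16-bit field starting at bit 8, so it can be selected
+      // as 'bfe.u32 %r, %val, 8, 16' (Start = 16 - 8, Len = 32 - 16).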
+      Val = LHS->getOperand(0);
+
+      SDValue ShlRHS = LHS->getOperand(1);
+      ConstantSDNode *ShlCnst = dyn_cast<ConstantSDNode>(ShlRHS);
+      if (!ShlCnst) {
+        // Shift amount must be constant
+        return NULL;
+      }
+      uint64_t InnerShiftAmt = ShlCnst->getZExtValue();
+
+      SDValue ShrRHS = RHS;
+      ConstantSDNode *ShrCnst = dyn_cast<ConstantSDNode>(ShrRHS);
+      if (!ShrCnst) {
+        // Shift amount must be constant
+        return NULL;
+      }
+      uint64_t OuterShiftAmt = ShrCnst->getZExtValue();
+
+      // To avoid extra codegen and be profitable, we need Outer >= Inner
+      if (OuterShiftAmt < InnerShiftAmt) {
+        return NULL;
+      }
+
+      // If the outer shift is more than the type size, we have no bitfield to
+      // extract (since we also check that the inner shift is <= the outer
+      // shift, this also implies that the inner shift is < the type size)
+      if (OuterShiftAmt >= Val.getValueType().getSizeInBits()) {
+        return NULL;
+      }
+
+      Start =
+          CurDAG->getTargetConstant(OuterShiftAmt - InnerShiftAmt, MVT::i32);
+      Len =
+          CurDAG->getTargetConstant(Val.getValueType().getSizeInBits() -
+                                    OuterShiftAmt, MVT::i32);
+
+      if (N->getOpcode() == ISD::SRA) {
+        // If we have an arithmetic right shift, we need to use the signed bfe
+        // variant
+        IsSigned = true;
+      }
+    } else {
+      // No can do...
+      return NULL;
+    }
+  } else {
+    // No can do...
+    return NULL;
+  }
+
+
+  unsigned Opc;
+  // For the BFE operations we form here from "and" and "srl", always use the
+  // unsigned variants; IsSigned is only set by the sra-of-shl pattern above.
+  if (Val.getValueType() == MVT::i32) {
+    if (IsSigned) {
+      Opc = NVPTX::BFE_S32rii;
+    } else {
+      Opc = NVPTX::BFE_U32rii;
+    }
+  } else if (Val.getValueType() == MVT::i64) {
+    if (IsSigned) {
+      Opc = NVPTX::BFE_S64rii;
+    } else {
+      Opc = NVPTX::BFE_U64rii;
+    }
+  } else {
+    // We cannot handle this type
+    return NULL;
+  }
+
+  SDValue Ops[] = {
+    Val, Start, Len
+  };
+
+  SDNode *Ret =
+      CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops);
+
+  return Ret;
+}
+
 // SelectDirectAddr - Match a direct address for DAG.
 // A direct address could be a globaladdress or externalsymbol.
 bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) {