@@ -2183,25 +2183,29 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
2183
2183
}
2184
2184
2185
2185
// Helpers for constructing an opcode (e.g. NVPTX::StoreParamV4F32_iiri)
2186
- #define getOpcV2H (ty, op0, op1 ) NVPTX::StoreParamV2##ty##_##op0##op1
2186
+ #define getOpcV2H (ty, opKind0, opKind1 ) \
2187
+ NVPTX::StoreParamV2##ty##_##opKind0##opKind1
2187
2188
2188
- #define getOpcV2H1 (ty, op0, op1 ) \
2189
- (op1 ) ? getOpcV2H(ty, op0 , i) : getOpcV2H(ty, op0 , r)
2189
+ #define getOpcV2H1 (ty, opKind0, isImm1 ) \
2190
+ (isImm1 ) ? getOpcV2H(ty, opKind0 , i) : getOpcV2H(ty, opKind0 , r)
2190
2191
2191
2192
#define getOpcodeForVectorStParamV2 (ty, isimm ) \
2192
2193
(isimm[0 ]) ? getOpcV2H1(ty, i, isimm[1 ]) : getOpcV2H1(ty, r, isimm[1 ])
2193
2194
2194
- #define getOpcV4H (ty, op0, op1, op2, op3 ) \
2195
- NVPTX::StoreParamV4##ty##_##op0##op1##op2##op3
2195
+ #define getOpcV4H (ty, opKind0, opKind1, opKind2, opKind3 ) \
2196
+ NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
2196
2197
2197
- #define getOpcV4H3 (ty, op0, op1, op2, op3 ) \
2198
- (op3) ? getOpcV4H(ty, op0, op1, op2, i) : getOpcV4H(ty, op0, op1, op2, r)
2198
+ #define getOpcV4H3 (ty, opKind0, opKind1, opKind2, isImm3 ) \
2199
+ (isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \
2200
+ : getOpcV4H(ty, opKind0, opKind1, opKind2, r)
2199
2201
2200
- #define getOpcV4H2 (ty, op0, op1, op2, op3 ) \
2201
- (op2) ? getOpcV4H3(ty, op0, op1, i, op3) : getOpcV4H3(ty, op0, op1, r, op3)
2202
+ #define getOpcV4H2 (ty, opKind0, opKind1, isImm2, isImm3 ) \
2203
+ (isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \
2204
+ : getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
2202
2205
2203
- #define getOpcV4H1 (ty, op0, op1, op2, op3 ) \
2204
- (op1) ? getOpcV4H2(ty, op0, i, op2, op3) : getOpcV4H2(ty, op0, r, op2, op3)
2206
+ #define getOpcV4H1 (ty, opKind0, isImm1, isImm2, isImm3 ) \
2207
+ (isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \
2208
+ : getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
2205
2209
2206
2210
#define getOpcodeForVectorStParamV4 (ty, isimm ) \
2207
2211
(isimm[0 ]) ? getOpcV4H1(ty, i, isimm[1 ], isimm[2 ], isimm[3 ]) \
@@ -2211,10 +2215,10 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
2211
2215
(n == 2 ) ? getOpcodeForVectorStParamV2(ty, isimm) \
2212
2216
: getOpcodeForVectorStParamV4(ty, isimm)
2213
2217
2214
- static std::optional< unsigned >
2215
- pickOpcodeForVectorStParam (SmallVector<SDValue, 8 > &Ops, unsigned NumElts,
2216
- MVT::SimpleValueType MemTy, SelectionDAG *CurDAG ,
2217
- SDLoc DL) {
2218
+ static unsigned pickOpcodeForVectorStParam (SmallVector<SDValue, 8 > &Ops,
2219
+ unsigned NumElts,
2220
+ MVT::SimpleValueType MemTy,
2221
+ SelectionDAG *CurDAG, SDLoc DL) {
2218
2222
// Determine which inputs are registers and which are immediates, and make
2219
2223
// new operators with constant values.
2220
2224
SmallVector<bool , 4 > IsImm (NumElts, false );
@@ -2244,19 +2248,31 @@ pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops, unsigned NumElts,
2244
2248
case MVT::i32 :
2245
2249
return getOpcodeForVectorStParam (NumElts, I32, IsImm);
2246
2250
case MVT::i64 :
2247
- if (NumElts == 4 )
2248
- return std::nullopt;
2251
+ assert (NumElts == 2 && " MVT too large for NumElts > 2" );
2249
2252
return getOpcodeForVectorStParamV2 (I64, IsImm);
2250
2253
case MVT::f32 :
2251
2254
return getOpcodeForVectorStParam (NumElts, F32, IsImm);
2252
2255
case MVT::f64 :
2253
- if (NumElts == 4 )
2254
- return std::nullopt;
2256
+ assert (NumElts == 2 && " MVT too large for NumElts > 2" );
2255
2257
return getOpcodeForVectorStParamV2 (F64, IsImm);
2258
+
2259
+ // These cases don't support immediates; just use the all-register version
2260
+ // and generate moves.
2261
+ case MVT::i1:
2262
+ return (NumElts == 2 ) ? NVPTX::StoreParamV2I8_rr
2263
+ : NVPTX::StoreParamV4I8_rrrr;
2256
2264
case MVT::f16 :
2265
+ case MVT::bf16 :
2266
+ return (NumElts == 2 ) ? NVPTX::StoreParamV2I16_rr
2267
+ : NVPTX::StoreParamV4I16_rrrr;
2257
2268
case MVT::v2f16:
2269
+ case MVT::v2bf16:
2270
+ case MVT::v2i16:
2271
+ case MVT::v4i8:
2272
+ return (NumElts == 2 ) ? NVPTX::StoreParamV2I32_rr
2273
+ : NVPTX::StoreParamV4I32_rrrr;
2258
2274
default :
2259
- return std::nullopt ;
2275
+ llvm_unreachable ( " Cannot select st.param for unknown MemTy " ) ;
2260
2276
}
2261
2277
}
2262
2278
@@ -2271,10 +2287,10 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
2271
2287
SDValue Glue = N->getOperand (N->getNumOperands () - 1 );
2272
2288
2273
2289
// How many elements do we have?
2274
- unsigned NumElts = 1 ;
2290
+ unsigned NumElts;
2275
2291
switch (N->getOpcode ()) {
2276
2292
default :
2277
- return false ;
2293
+ llvm_unreachable ( " Unexpected opcode " ) ;
2278
2294
case NVPTXISD::StoreParamU32:
2279
2295
case NVPTXISD::StoreParamS32:
2280
2296
case NVPTXISD::StoreParam:
@@ -2300,12 +2316,12 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
2300
2316
// Determine target opcode
2301
2317
// If we have an i1, use an 8-bit store. The lowering code in
2302
2318
// NVPTXISelLowering will have already emitted an upcast.
2303
- std::optional<unsigned > Opcode = 0 ;
2319
+ std::optional<unsigned > Opcode;
2304
2320
switch (N->getOpcode ()) {
2305
2321
default :
2306
2322
switch (NumElts) {
2307
2323
default :
2308
- return false ;
2324
+ llvm_unreachable ( " Unexpected NumElts " ) ;
2309
2325
case 1 : {
2310
2326
MVT::SimpleValueType MemTy = Mem->getMemoryVT ().getSimpleVT ().SimpleTy ;
2311
2327
SDValue Imm = Ops[0 ];
@@ -2357,8 +2373,6 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
2357
2373
break ;
2358
2374
}
2359
2375
}
2360
- if (!Opcode)
2361
- return false ;
2362
2376
break ;
2363
2377
// Special case: if we have a sign-extend/zero-extend node, insert the
2364
2378
// conversion instruction first, and use that as the value operand to
0 commit comments