@@ -2182,6 +2182,100 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
2182
2182
return true ;
2183
2183
}
2184
2184
2185
+ // Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
2186
+ #define getOpcV2H (ty, opKind0, opKind1 ) \
2187
+ NVPTX::StoreParamV2##ty##_##opKind0##opKind1
2188
+
2189
+ #define getOpcV2H1 (ty, opKind0, isImm1 ) \
2190
+ (isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
2191
+
2192
+ #define getOpcodeForVectorStParamV2 (ty, isimm ) \
2193
+ (isimm[0 ]) ? getOpcV2H1(ty, i, isimm[1 ]) : getOpcV2H1(ty, r, isimm[1 ])
2194
+
2195
+ #define getOpcV4H (ty, opKind0, opKind1, opKind2, opKind3 ) \
2196
+ NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
2197
+
2198
+ #define getOpcV4H3 (ty, opKind0, opKind1, opKind2, isImm3 ) \
2199
+ (isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \
2200
+ : getOpcV4H(ty, opKind0, opKind1, opKind2, r)
2201
+
2202
+ #define getOpcV4H2 (ty, opKind0, opKind1, isImm2, isImm3 ) \
2203
+ (isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \
2204
+ : getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
2205
+
2206
+ #define getOpcV4H1 (ty, opKind0, isImm1, isImm2, isImm3 ) \
2207
+ (isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \
2208
+ : getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
2209
+
2210
+ #define getOpcodeForVectorStParamV4 (ty, isimm ) \
2211
+ (isimm[0 ]) ? getOpcV4H1(ty, i, isimm[1 ], isimm[2 ], isimm[3 ]) \
2212
+ : getOpcV4H1(ty, r, isimm[1 ], isimm[2 ], isimm[3 ])
2213
+
2214
+ #define getOpcodeForVectorStParam (n, ty, isimm ) \
2215
+ (n == 2 ) ? getOpcodeForVectorStParamV2(ty, isimm) \
2216
+ : getOpcodeForVectorStParamV4(ty, isimm)
2217
+
2218
+ static unsigned pickOpcodeForVectorStParam (SmallVector<SDValue, 8 > &Ops,
2219
+ unsigned NumElts,
2220
+ MVT::SimpleValueType MemTy,
2221
+ SelectionDAG *CurDAG, SDLoc DL) {
2222
+ // Determine which inputs are registers and immediates make new operators
2223
+ // with constant values
2224
+ SmallVector<bool , 4 > IsImm (NumElts, false );
2225
+ for (unsigned i = 0 ; i < NumElts; i++) {
2226
+ IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
2227
+ if (IsImm[i]) {
2228
+ SDValue Imm = Ops[i];
2229
+ if (MemTy == MVT::f32 || MemTy == MVT::f64 ) {
2230
+ const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
2231
+ const ConstantFP *CF = ConstImm->getConstantFPValue ();
2232
+ Imm = CurDAG->getTargetConstantFP (*CF, DL, Imm->getValueType (0 ));
2233
+ } else {
2234
+ const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
2235
+ const ConstantInt *CI = ConstImm->getConstantIntValue ();
2236
+ Imm = CurDAG->getTargetConstant (*CI, DL, Imm->getValueType (0 ));
2237
+ }
2238
+ Ops[i] = Imm;
2239
+ }
2240
+ }
2241
+
2242
+ // Get opcode for MemTy, size, and register/immediate operand ordering
2243
+ switch (MemTy) {
2244
+ case MVT::i8 :
2245
+ return getOpcodeForVectorStParam (NumElts, I8, IsImm);
2246
+ case MVT::i16 :
2247
+ return getOpcodeForVectorStParam (NumElts, I16, IsImm);
2248
+ case MVT::i32 :
2249
+ return getOpcodeForVectorStParam (NumElts, I32, IsImm);
2250
+ case MVT::i64 :
2251
+ assert (NumElts == 2 && " MVT too large for NumElts > 2" );
2252
+ return getOpcodeForVectorStParamV2 (I64, IsImm);
2253
+ case MVT::f32 :
2254
+ return getOpcodeForVectorStParam (NumElts, F32, IsImm);
2255
+ case MVT::f64 :
2256
+ assert (NumElts == 2 && " MVT too large for NumElts > 2" );
2257
+ return getOpcodeForVectorStParamV2 (F64, IsImm);
2258
+
2259
+ // These cases don't support immediates, just use the all register version
2260
+ // and generate moves.
2261
+ case MVT::i1:
2262
+ return (NumElts == 2 ) ? NVPTX::StoreParamV2I8_rr
2263
+ : NVPTX::StoreParamV4I8_rrrr;
2264
+ case MVT::f16 :
2265
+ case MVT::bf16 :
2266
+ return (NumElts == 2 ) ? NVPTX::StoreParamV2I16_rr
2267
+ : NVPTX::StoreParamV4I16_rrrr;
2268
+ case MVT::v2f16:
2269
+ case MVT::v2bf16:
2270
+ case MVT::v2i16:
2271
+ case MVT::v4i8:
2272
+ return (NumElts == 2 ) ? NVPTX::StoreParamV2I32_rr
2273
+ : NVPTX::StoreParamV4I32_rrrr;
2274
+ default :
2275
+ llvm_unreachable (" Cannot select st.param for unknown MemTy" );
2276
+ }
2277
+ }
2278
+
2185
2279
bool NVPTXDAGToDAGISel::tryStoreParam (SDNode *N) {
2186
2280
SDLoc DL (N);
2187
2281
SDValue Chain = N->getOperand (0 );
@@ -2193,10 +2287,10 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
2193
2287
SDValue Glue = N->getOperand (N->getNumOperands () - 1 );
2194
2288
2195
2289
// How many elements do we have?
2196
- unsigned NumElts = 1 ;
2290
+ unsigned NumElts;
2197
2291
switch (N->getOpcode ()) {
2198
2292
default :
2199
- return false ;
2293
+ llvm_unreachable ( " Unexpected opcode " ) ;
2200
2294
case NVPTXISD::StoreParamU32:
2201
2295
case NVPTXISD::StoreParamS32:
2202
2296
case NVPTXISD::StoreParam:
@@ -2222,54 +2316,69 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
2222
2316
// Determine target opcode
2223
2317
// If we have an i1, use an 8-bit store. The lowering code in
2224
2318
// NVPTXISelLowering will have already emitted an upcast.
2225
- std::optional<unsigned > Opcode = 0 ;
2319
+ std::optional<unsigned > Opcode;
2226
2320
switch (N->getOpcode ()) {
2227
2321
default :
2228
2322
switch (NumElts) {
2229
2323
default :
2230
- return false ;
2231
- case 1 :
2232
- Opcode = pickOpcodeForVT (Mem->getMemoryVT ().getSimpleVT ().SimpleTy ,
2233
- NVPTX::StoreParamI8, NVPTX::StoreParamI16,
2234
- NVPTX::StoreParamI32, NVPTX::StoreParamI64,
2235
- NVPTX::StoreParamF32, NVPTX::StoreParamF64);
2236
- if (Opcode == NVPTX::StoreParamI8) {
2324
+ llvm_unreachable (" Unexpected NumElts" );
2325
+ case 1 : {
2326
+ MVT::SimpleValueType MemTy = Mem->getMemoryVT ().getSimpleVT ().SimpleTy ;
2327
+ SDValue Imm = Ops[0 ];
2328
+ if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
2329
+ (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
2330
+ // Convert immediate to target constant
2331
+ if (MemTy == MVT::f32 || MemTy == MVT::f64 ) {
2332
+ const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
2333
+ const ConstantFP *CF = ConstImm->getConstantFPValue ();
2334
+ Imm = CurDAG->getTargetConstantFP (*CF, DL, Imm->getValueType (0 ));
2335
+ } else {
2336
+ const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
2337
+ const ConstantInt *CI = ConstImm->getConstantIntValue ();
2338
+ Imm = CurDAG->getTargetConstant (*CI, DL, Imm->getValueType (0 ));
2339
+ }
2340
+ Ops[0 ] = Imm;
2341
+ // Use immediate version of store param
2342
+ Opcode = pickOpcodeForVT (MemTy, NVPTX::StoreParamI8_i,
2343
+ NVPTX::StoreParamI16_i, NVPTX::StoreParamI32_i,
2344
+ NVPTX::StoreParamI64_i, NVPTX::StoreParamF32_i,
2345
+ NVPTX::StoreParamF64_i);
2346
+ } else
2347
+ Opcode =
2348
+ pickOpcodeForVT (Mem->getMemoryVT ().getSimpleVT ().SimpleTy ,
2349
+ NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
2350
+ NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r,
2351
+ NVPTX::StoreParamF32_r, NVPTX::StoreParamF64_r);
2352
+ if (Opcode == NVPTX::StoreParamI8_r) {
2237
2353
// Fine tune the opcode depending on the size of the operand.
2238
2354
// This helps to avoid creating redundant COPY instructions in
2239
2355
// InstrEmitter::AddRegisterOperand().
2240
2356
switch (Ops[0 ].getSimpleValueType ().SimpleTy ) {
2241
2357
default :
2242
2358
break ;
2243
2359
case MVT::i32 :
2244
- Opcode = NVPTX::StoreParamI8TruncI32 ;
2360
+ Opcode = NVPTX::StoreParamI8TruncI32_r ;
2245
2361
break ;
2246
2362
case MVT::i64 :
2247
- Opcode = NVPTX::StoreParamI8TruncI64 ;
2363
+ Opcode = NVPTX::StoreParamI8TruncI64_r ;
2248
2364
break ;
2249
2365
}
2250
2366
}
2251
2367
break ;
2368
+ }
2252
2369
case 2 :
2253
- Opcode = pickOpcodeForVT (Mem->getMemoryVT ().getSimpleVT ().SimpleTy ,
2254
- NVPTX::StoreParamV2I8, NVPTX::StoreParamV2I16,
2255
- NVPTX::StoreParamV2I32, NVPTX::StoreParamV2I64,
2256
- NVPTX::StoreParamV2F32, NVPTX::StoreParamV2F64);
2257
- break ;
2258
- case 4 :
2259
- Opcode = pickOpcodeForVT (Mem->getMemoryVT ().getSimpleVT ().SimpleTy ,
2260
- NVPTX::StoreParamV4I8, NVPTX::StoreParamV4I16,
2261
- NVPTX::StoreParamV4I32, std::nullopt,
2262
- NVPTX::StoreParamV4F32, std::nullopt);
2370
+ case 4 : {
2371
+ MVT::SimpleValueType MemTy = Mem->getMemoryVT ().getSimpleVT ().SimpleTy ;
2372
+ Opcode = pickOpcodeForVectorStParam (Ops, NumElts, MemTy, CurDAG, DL);
2263
2373
break ;
2264
2374
}
2265
- if (!Opcode)
2266
- return false ;
2375
+ }
2267
2376
break ;
2268
2377
// Special case: if we have a sign-extend/zero-extend node, insert the
2269
2378
// conversion instruction first, and use that as the value operand to
2270
2379
// the selected StoreParam node.
2271
2380
case NVPTXISD::StoreParamU32: {
2272
- Opcode = NVPTX::StoreParamI32 ;
2381
+ Opcode = NVPTX::StoreParamI32_r ;
2273
2382
SDValue CvtNone = CurDAG->getTargetConstant (NVPTX::PTXCvtMode::NONE, DL,
2274
2383
MVT::i32 );
2275
2384
SDNode *Cvt = CurDAG->getMachineNode (NVPTX::CVT_u32_u16, DL,
@@ -2278,7 +2387,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
2278
2387
break ;
2279
2388
}
2280
2389
case NVPTXISD::StoreParamS32: {
2281
- Opcode = NVPTX::StoreParamI32 ;
2390
+ Opcode = NVPTX::StoreParamI32_r ;
2282
2391
SDValue CvtNone = CurDAG->getTargetConstant (NVPTX::PTXCvtMode::NONE, DL,
2283
2392
MVT::i32 );
2284
2393
SDNode *Cvt = CurDAG->getMachineNode (NVPTX::CVT_s32_s16, DL,
0 commit comments