@@ -265,16 +265,50 @@ class OpLowerer {
265
265
266
266
// / Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
267
267
// / Since we expect to be post-scalarization, make an effort to avoid vectors.
268
- Error replaceResRetUses (CallInst *Intrin, CallInst *Op) {
268
+ Error replaceResRetUses (CallInst *Intrin, CallInst *Op, bool HasCheckBit ) {
269
269
IRBuilder<> &IRB = OpBuilder.getIRB ();
270
270
271
+ Instruction *OldResult = Intrin;
271
272
Type *OldTy = Intrin->getType ();
272
273
274
+ if (HasCheckBit) {
275
+ auto *ST = cast<StructType>(OldTy);
276
+
277
+ Value *CheckOp = nullptr ;
278
+ Type *Int32Ty = IRB.getInt32Ty ();
279
+ for (Use &U : make_early_inc_range (OldResult->uses ())) {
280
+ if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser ())) {
281
+ ArrayRef<unsigned > Indices = EVI->getIndices ();
282
+ assert (Indices.size () == 1 );
283
+ // We're only interested in uses of the check bit for now.
284
+ if (Indices[0 ] != 1 )
285
+ continue ;
286
+ if (!CheckOp) {
287
+ Value *NewEVI = IRB.CreateExtractValue (Op, 4 );
288
+ Expected<CallInst *> OpCall = OpBuilder.tryCreateOp (
289
+ OpCode::CheckAccessFullyMapped, {NewEVI}, Int32Ty);
290
+ if (Error E = OpCall.takeError ())
291
+ return E;
292
+ CheckOp = *OpCall;
293
+ }
294
+ EVI->replaceAllUsesWith (CheckOp);
295
+ EVI->eraseFromParent ();
296
+ }
297
+ }
298
+
299
+ OldResult = cast<Instruction>(IRB.CreateExtractValue (Op, 0 ));
300
+ OldTy = ST->getElementType (0 );
301
+ }
302
+
273
303
// For scalars, we just extract the first element.
274
304
if (!isa<FixedVectorType>(OldTy)) {
275
305
Value *EVI = IRB.CreateExtractValue (Op, 0 );
276
- Intrin->replaceAllUsesWith (EVI);
277
- Intrin->eraseFromParent ();
306
+ OldResult->replaceAllUsesWith (EVI);
307
+ OldResult->eraseFromParent ();
308
+ if (OldResult != Intrin) {
309
+ assert (Intrin->use_empty () && " Intrinsic still has uses?" );
310
+ Intrin->eraseFromParent ();
311
+ }
278
312
return Error::success ();
279
313
}
280
314
@@ -283,7 +317,7 @@ class OpLowerer {
283
317
284
318
// The users of the operation should all be scalarized, so we attempt to
285
319
// replace the extractelements with extractvalues directly.
286
- for (Use &U : make_early_inc_range (Intrin ->uses ())) {
320
+ for (Use &U : make_early_inc_range (OldResult ->uses ())) {
287
321
if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser ())) {
288
322
if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand ())) {
289
323
size_t IndexVal = IndexOp->getZExtValue ();
@@ -331,22 +365,27 @@ class OpLowerer {
331
365
// If we still have uses, then we're not fully scalarized and need to
332
366
// recreate the vector. This should only happen for things like exported
333
367
// functions from libraries.
334
- if (!Intrin ->use_empty ()) {
368
+ if (!OldResult ->use_empty ()) {
335
369
for (int I = 0 , E = N; I != E; ++I)
336
370
if (!Extracts[I])
337
371
Extracts[I] = IRB.CreateExtractValue (Op, I);
338
372
339
373
Value *Vec = UndefValue::get (OldTy);
340
374
for (int I = 0 , E = N; I != E; ++I)
341
375
Vec = IRB.CreateInsertElement (Vec, Extracts[I], I);
342
- Intrin->replaceAllUsesWith (Vec);
376
+ OldResult->replaceAllUsesWith (Vec);
377
+ }
378
+
379
+ OldResult->eraseFromParent ();
380
+ if (OldResult != Intrin) {
381
+ assert (Intrin->use_empty () && " Intrinsic still has uses?" );
382
+ Intrin->eraseFromParent ();
343
383
}
344
384
345
- Intrin->eraseFromParent ();
346
385
return Error::success ();
347
386
}
348
387
349
- [[nodiscard]] bool lowerTypedBufferLoad (Function &F) {
388
+ [[nodiscard]] bool lowerTypedBufferLoad (Function &F, bool HasCheckBit ) {
350
389
IRBuilder<> &IRB = OpBuilder.getIRB ();
351
390
Type *Int32Ty = IRB.getInt32Ty ();
352
391
@@ -358,14 +397,17 @@ class OpLowerer {
358
397
Value *Index0 = CI->getArgOperand (1 );
359
398
Value *Index1 = UndefValue::get (Int32Ty);
360
399
361
- Type *NewRetTy = OpBuilder.getResRetType (CI->getType ()->getScalarType ());
400
+ Type *OldTy = CI->getType ();
401
+ if (HasCheckBit)
402
+ OldTy = cast<StructType>(OldTy)->getElementType (0 );
403
+ Type *NewRetTy = OpBuilder.getResRetType (OldTy->getScalarType ());
362
404
363
405
std::array<Value *, 3 > Args{Handle, Index0, Index1};
364
406
Expected<CallInst *> OpCall =
365
407
OpBuilder.tryCreateOp (OpCode::BufferLoad, Args, NewRetTy);
366
408
if (Error E = OpCall.takeError ())
367
409
return E;
368
- if (Error E = replaceResRetUses (CI, *OpCall))
410
+ if (Error E = replaceResRetUses (CI, *OpCall, HasCheckBit ))
369
411
return E;
370
412
371
413
return Error::success ();
@@ -434,7 +476,10 @@ class OpLowerer {
434
476
HasErrors |= lowerHandleFromBinding (F);
435
477
break ;
436
478
case Intrinsic::dx_typedBufferLoad:
437
- HasErrors |= lowerTypedBufferLoad (F);
479
+ HasErrors |= lowerTypedBufferLoad (F, /* HasCheckBit=*/ false );
480
+ break ;
481
+ case Intrinsic::dx_typedBufferLoad_checkbit:
482
+ HasErrors |= lowerTypedBufferLoad (F, /* HasCheckBit=*/ true );
438
483
break ;
439
484
case Intrinsic::dx_typedBufferStore:
440
485
HasErrors |= lowerTypedBufferStore (F);
0 commit comments