25
25
#include " llvm/IR/PassManager.h"
26
26
#include " llvm/IR/Type.h"
27
27
#include " llvm/Pass.h"
28
+ #include " llvm/Support/Casting.h"
28
29
#include " llvm/Support/ErrorHandling.h"
29
30
#include " llvm/Support/MathExtras.h"
30
31
@@ -70,15 +71,17 @@ static bool isIntrinsicExpansion(Function &F) {
70
71
case Intrinsic::vector_reduce_add:
71
72
case Intrinsic::vector_reduce_fadd:
72
73
return true ;
73
- case Intrinsic::dx_resource_load_typedbuffer:
74
- // We need to handle doubles and vector of doubles.
75
- return F.getReturnType ()
76
- ->getStructElementType (0 )
77
- ->getScalarType ()
78
- ->isDoubleTy ();
79
- case Intrinsic::dx_resource_store_typedbuffer:
80
- // We need to handle doubles and vector of doubles.
81
- return F.getFunctionType ()->getParamType (2 )->getScalarType ()->isDoubleTy ();
74
+ case Intrinsic::dx_resource_load_typedbuffer: {
75
+ // We need to handle i64, doubles, and vectors of them.
76
+ Type *ScalarTy =
77
+ F.getReturnType ()->getStructElementType (0 )->getScalarType ();
78
+ return ScalarTy->isDoubleTy () || ScalarTy->isIntegerTy (64 );
79
+ }
80
+ case Intrinsic::dx_resource_store_typedbuffer: {
81
+ // We need to handle i64, doubles, and vectors of either.
82
+ Type *ScalarTy = F.getFunctionType ()->getParamType (2 )->getScalarType ();
83
+ return ScalarTy->isDoubleTy () || ScalarTy->isIntegerTy (64 );
84
+ }
82
85
}
83
86
return false ;
84
87
}
@@ -545,13 +548,15 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
545
548
IRBuilder<> Builder (Orig);
546
549
547
550
Type *BufferTy = Orig->getType ()->getStructElementType (0 );
548
- assert (BufferTy->getScalarType ()->isDoubleTy () &&
549
- " Only expand double or double2" );
551
+ Type *ScalarTy = BufferTy->getScalarType ();
552
+ bool IsDouble = ScalarTy->isDoubleTy ();
553
+ assert (IsDouble || ScalarTy->isIntegerTy (64 ) &&
554
+ " Only expand double or int64 scalars or vectors" );
550
555
551
556
unsigned ExtractNum = 2 ;
552
557
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
553
558
assert (VT->getNumElements () == 2 &&
554
- " TypedBufferLoad double vector has wrong size" );
559
+ " TypedBufferLoad vector must be size 2 " );
555
560
ExtractNum = 4 ;
556
561
}
557
562
@@ -570,22 +575,42 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
570
575
ExtractElements.push_back (
571
576
Builder.CreateExtractElement (Extract, Builder.getInt32 (I)));
572
577
573
- // combine into double(s)
578
+ // combine into double(s) or int64(s)
574
579
Value *Result = PoisonValue::get (BufferTy);
575
580
for (unsigned I = 0 ; I < ExtractNum; I += 2 ) {
576
- Value *Dbl =
577
- Builder.CreateIntrinsic (Builder.getDoubleTy (), Intrinsic::dx_asdouble,
578
- {ExtractElements[I], ExtractElements[I + 1 ]});
581
+ Value *Combined = nullptr ;
582
+ if (IsDouble)
583
+ // For doubles, use dx_asdouble intrinsic
584
+ Combined =
585
+ Builder.CreateIntrinsic (Builder.getDoubleTy (), Intrinsic::dx_asdouble,
586
+ {ExtractElements[I], ExtractElements[I + 1 ]});
587
+ else {
588
+ // For int64, manually combine two int32s
589
+ // First, zero-extend both values to i64
590
+ Value *Lo = Builder.CreateZExt (ExtractElements[I], Builder.getInt64Ty ());
591
+ Value *Hi =
592
+ Builder.CreateZExt (ExtractElements[I + 1 ], Builder.getInt64Ty ());
593
+ // Shift the high bits left by 32 bits
594
+ Value *ShiftedHi = Builder.CreateShl (Hi, Builder.getInt64 (32 ));
595
+ // OR the high and low bits together
596
+ Combined = Builder.CreateOr (Lo, ShiftedHi);
597
+ }
598
+
579
599
if (ExtractNum == 4 )
580
- Result =
581
- Builder. CreateInsertElement (Result, Dbl, Builder.getInt32 (I / 2 ));
600
+ Result = Builder. CreateInsertElement (Result, Combined,
601
+ Builder.getInt32 (I / 2 ));
582
602
else
583
- Result = Dbl ;
603
+ Result = Combined ;
584
604
}
585
605
586
606
Value *CheckBit = nullptr ;
587
607
for (User *U : make_early_inc_range (Orig->users ())) {
588
- auto *EVI = cast<ExtractValueInst>(U);
608
+ // If it's not an ExtractValueInst, we don't know how to
609
+ // handle it
610
+ auto *EVI = dyn_cast<ExtractValueInst>(U);
611
+ if (!EVI)
612
+ llvm_unreachable (" Unexpected user of typedbufferload" );
613
+
589
614
ArrayRef<unsigned > Indices = EVI->getIndices ();
590
615
assert (Indices.size () == 1 );
591
616
@@ -609,38 +634,61 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
609
634
IRBuilder<> Builder (Orig);
610
635
611
636
Type *BufferTy = Orig->getFunctionType ()->getParamType (2 );
612
- assert (BufferTy->getScalarType ()->isDoubleTy () &&
613
- " Only expand double or double2" );
614
-
615
- unsigned ExtractNum = 2 ;
616
- if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
617
- assert (VT->getNumElements () == 2 &&
618
- " TypedBufferStore double vector has wrong size" );
619
- ExtractNum = 4 ;
637
+ Type *ScalarTy = BufferTy->getScalarType ();
638
+ bool IsDouble = ScalarTy->isDoubleTy ();
639
+ assert ((IsDouble || ScalarTy->isIntegerTy (64 )) &&
640
+ " Only expand double or int64 scalars or vectors" );
641
+
642
+ // Determine if we're dealing with a vector or scalar
643
+ bool IsVector = isa<FixedVectorType>(BufferTy);
644
+ if (IsVector) {
645
+ assert (cast<FixedVectorType>(BufferTy)->getNumElements () == 2 &&
646
+ " TypedBufferStore vector must be size 2" );
620
647
}
621
648
622
- Type *SplitElementTy = Builder.getInt32Ty ();
623
- if (ExtractNum == 4 )
649
+ // Create the appropriate vector type for the result
650
+ Type *Int32Ty = Builder.getInt32Ty ();
651
+ Type *ResultTy = VectorType::get (Int32Ty, IsVector ? 4 : 2 , false );
652
+ Value *Val = PoisonValue::get (ResultTy);
653
+
654
+ Type *SplitElementTy = Int32Ty;
655
+ if (IsVector)
624
656
SplitElementTy = VectorType::get (SplitElementTy, 2 , false );
625
657
626
- // split our double(s)
627
- auto *SplitTy = llvm::StructType::get (SplitElementTy, SplitElementTy);
628
- Value *Split = Builder.CreateIntrinsic (SplitTy, Intrinsic::dx_splitdouble,
629
- Orig->getOperand (2 ));
630
- // create our vector
631
- Value *LowBits = Builder.CreateExtractValue (Split, 0 );
632
- Value *HighBits = Builder.CreateExtractValue (Split, 1 );
633
- Value *Val;
634
- if (ExtractNum == 2 ) {
635
- Val = PoisonValue::get (VectorType::get (SplitElementTy, 2 , false ));
658
+ Value *LowBits = nullptr ;
659
+ Value *HighBits = nullptr ;
660
+ // Split the 64-bit values into 32-bit components
661
+ if (IsDouble) {
662
+ auto *SplitTy = llvm::StructType::get (SplitElementTy, SplitElementTy);
663
+ Value *Split = Builder.CreateIntrinsic (SplitTy, Intrinsic::dx_splitdouble,
664
+ {Orig->getOperand (2 )});
665
+ LowBits = Builder.CreateExtractValue (Split, 0 );
666
+ HighBits = Builder.CreateExtractValue (Split, 1 );
667
+ } else {
668
+ // Handle int64 type(s)
669
+ Value *InputVal = Orig->getOperand (2 );
670
+ Constant *ShiftAmt = Builder.getInt64 (32 );
671
+ if (IsVector)
672
+ ShiftAmt = ConstantVector::getSplat (ElementCount::getFixed (2 ), ShiftAmt);
673
+
674
+ // Split into low and high 32-bit parts
675
+ LowBits = Builder.CreateTrunc (InputVal, SplitElementTy);
676
+ Value *ShiftedVal = Builder.CreateLShr (InputVal, ShiftAmt);
677
+ HighBits = Builder.CreateTrunc (ShiftedVal, SplitElementTy);
678
+ }
679
+
680
+ if (IsVector) {
681
+ Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
682
+ } else {
636
683
Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
637
684
Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
638
- } else
639
- Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
685
+ }
640
686
687
+ // Create the final intrinsic call
641
688
Builder.CreateIntrinsic (Builder.getVoidTy (),
642
689
Intrinsic::dx_resource_store_typedbuffer,
643
690
{Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
691
+
644
692
Orig->eraseFromParent ();
645
693
return true ;
646
694
}
0 commit comments