25
25
#include " llvm/IR/PassManager.h"
26
26
#include " llvm/IR/Type.h"
27
27
#include " llvm/Pass.h"
28
+ #include " llvm/Support/Casting.h"
28
29
#include " llvm/Support/ErrorHandling.h"
29
30
#include " llvm/Support/MathExtras.h"
30
31
@@ -70,15 +71,17 @@ static bool isIntrinsicExpansion(Function &F) {
70
71
case Intrinsic::vector_reduce_add:
71
72
case Intrinsic::vector_reduce_fadd:
72
73
return true ;
73
- case Intrinsic::dx_resource_load_typedbuffer:
74
- // We need to handle doubles and vector of doubles.
75
- return F.getReturnType ()
76
- ->getStructElementType (0 )
77
- ->getScalarType ()
78
- ->isDoubleTy ();
79
- case Intrinsic::dx_resource_store_typedbuffer:
80
- // We need to handle doubles and vector of doubles.
81
- return F.getFunctionType ()->getParamType (2 )->getScalarType ()->isDoubleTy ();
74
+ case Intrinsic::dx_resource_load_typedbuffer: {
75
+ // We need to handle i64, doubles, and vectors of them.
76
+ Type *ScalarTy =
77
+ F.getReturnType ()->getStructElementType (0 )->getScalarType ();
78
+ return ScalarTy->isDoubleTy () || ScalarTy->isIntegerTy (64 );
79
+ }
80
+ case Intrinsic::dx_resource_store_typedbuffer: {
81
+ // We need to handle i64, doubles, and vectors of them.
82
+ Type *ScalarTy = F.getFunctionType ()->getParamType (2 )->getScalarType ();
83
+ return ScalarTy->isDoubleTy () || ScalarTy->isIntegerTy (64 );
84
+ }
82
85
}
83
86
return false ;
84
87
}
@@ -545,13 +548,15 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
545
548
IRBuilder<> Builder (Orig);
546
549
547
550
Type *BufferTy = Orig->getType ()->getStructElementType (0 );
548
- assert (BufferTy->getScalarType ()->isDoubleTy () &&
549
- " Only expand double or double2" );
551
+ Type *ScalarTy = BufferTy->getScalarType ();
552
+ bool IsDouble = ScalarTy->isDoubleTy ();
553
+ assert (IsDouble || ScalarTy->isIntegerTy (64 ) &&
554
+ " Only expand double or int64 scalars or vectors" );
550
555
551
556
unsigned ExtractNum = 2 ;
552
557
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
553
558
assert (VT->getNumElements () == 2 &&
554
- " TypedBufferLoad double vector has wrong size" );
559
+ " TypedBufferLoad vector must be size 2 " );
555
560
ExtractNum = 4 ;
556
561
}
557
562
@@ -570,22 +575,54 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
570
575
ExtractElements.push_back (
571
576
Builder.CreateExtractElement (Extract, Builder.getInt32 (I)));
572
577
573
- // combine into double(s)
578
+ // combine into double(s) or int64(s)
574
579
Value *Result = PoisonValue::get (BufferTy);
575
580
for (unsigned I = 0 ; I < ExtractNum; I += 2 ) {
576
- Value *Dbl =
577
- Builder.CreateIntrinsic (Builder.getDoubleTy (), Intrinsic::dx_asdouble,
578
- {ExtractElements[I], ExtractElements[I + 1 ]});
581
+ Value *Combined = nullptr ;
582
+ if (IsDouble) {
583
+ // For doubles, use dx_asdouble intrinsic
584
+ Combined =
585
+ Builder.CreateIntrinsic (Builder.getDoubleTy (), Intrinsic::dx_asdouble,
586
+ {ExtractElements[I], ExtractElements[I + 1 ]});
587
+ } else {
588
+ // For int64, manually combine two int32s
589
+ // First, zero-extend both values to i64
590
+ Value *Lo = Builder.CreateZExt (ExtractElements[I], Builder.getInt64Ty ());
591
+ Value *Hi =
592
+ Builder.CreateZExt (ExtractElements[I + 1 ], Builder.getInt64Ty ());
593
+ // Shift the high bits left by 32 bits
594
+ Value *ShiftedHi = Builder.CreateShl (Hi, Builder.getInt64 (32 ));
595
+ // OR the high and low bits together
596
+ Combined = Builder.CreateOr (Lo, ShiftedHi);
597
+ }
598
+
579
599
if (ExtractNum == 4 )
580
- Result =
581
- Builder. CreateInsertElement (Result, Dbl, Builder.getInt32 (I / 2 ));
600
+ Result = Builder. CreateInsertElement (Result, Combined,
601
+ Builder.getInt32 (I / 2 ));
582
602
else
583
- Result = Dbl ;
603
+ Result = Combined ;
584
604
}
585
605
586
606
Value *CheckBit = nullptr ;
587
607
for (User *U : make_early_inc_range (Orig->users ())) {
588
- auto *EVI = cast<ExtractValueInst>(U);
608
+ if (auto *Ret = dyn_cast<ReturnInst>(U)) {
609
+ // For return instructions, we need to handle the case where the function
610
+ // is directly returning the result of the call
611
+ Type *RetTy = Ret->getFunction ()->getReturnType ();
612
+ Value *StructRet = PoisonValue::get (RetTy);
613
+ StructRet = Builder.CreateInsertValue (StructRet, Result, {0 });
614
+ Value *CheckBitForRet = Builder.CreateExtractValue (Load, {1 });
615
+ StructRet = Builder.CreateInsertValue (StructRet, CheckBitForRet, {1 });
616
+ Ret->setOperand (0 , StructRet);
617
+ continue ;
618
+ }
619
+ auto *EVI = dyn_cast<ExtractValueInst>(U);
620
+ if (!EVI) {
621
+ // If it's not a ReturnInst or ExtractValueInst, we don't know how to
622
+ // handle it
623
+ llvm_unreachable (" Unexpected user of typedbufferload" );
624
+ }
625
+
589
626
ArrayRef<unsigned > Indices = EVI->getIndices ();
590
627
assert (Indices.size () == 1 );
591
628
@@ -609,38 +646,90 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
609
646
IRBuilder<> Builder (Orig);
610
647
611
648
Type *BufferTy = Orig->getFunctionType ()->getParamType (2 );
612
- assert (BufferTy->getScalarType ()->isDoubleTy () &&
613
- " Only expand double or double2" );
649
+ Type *ScalarTy = BufferTy->getScalarType ();
650
+ bool IsDouble = ScalarTy->isDoubleTy ();
651
+ assert ((IsDouble || ScalarTy->isIntegerTy (64 )) &&
652
+ " Only expand double or int64 scalars or vectors" );
614
653
615
654
unsigned ExtractNum = 2 ;
616
655
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
617
656
assert (VT->getNumElements () == 2 &&
618
- " TypedBufferStore double vector has wrong size" );
657
+ " TypedBufferStore vector must be size 2 " );
619
658
ExtractNum = 4 ;
620
659
}
660
+ if (IsDouble) {
661
+ Type *SplitElementTy = Builder.getInt32Ty ();
662
+ if (ExtractNum == 4 )
663
+ SplitElementTy = VectorType::get (SplitElementTy, 2 , false );
664
+
665
+ // Split the double(s) into their low and high 32-bit halves.
666
+ auto *SplitTy = llvm::StructType::get (SplitElementTy, SplitElementTy);
667
+ Value *Split = Builder.CreateIntrinsic (SplitTy, Intrinsic::dx_splitdouble,
668
+ {Orig->getOperand (2 )});
669
+ // create our vector
670
+ Value *LowBits = Builder.CreateExtractValue (Split, 0 );
671
+ Value *HighBits = Builder.CreateExtractValue (Split, 1 );
672
+ Value *Val;
673
+ if (ExtractNum == 2 ) {
674
+ Val = PoisonValue::get (VectorType::get (Builder.getInt32Ty (), 2 , false ));
675
+ Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
676
+ Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
677
+ } else
678
+ Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
679
+
680
+ Builder.CreateIntrinsic (Builder.getVoidTy (),
681
+ Intrinsic::dx_resource_store_typedbuffer,
682
+ {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
683
+ } else {
684
+ // Handle int64 type(s)
685
+ Value *InputVal = Orig->getOperand (2 );
686
+ Value *Val;
687
+
688
+ if (ExtractNum == 4 ) {
689
+ // Handle vector of int64
690
+ Type *Int32x4Ty = VectorType::get (Builder.getInt32Ty (), 4 , false );
691
+ Val = PoisonValue::get (Int32x4Ty);
692
+
693
+ for (unsigned I = 0 ; I < 2 ; ++I) {
694
+ // Extract each int64 element
695
+ Value *Int64Val =
696
+ Builder.CreateExtractElement (InputVal, Builder.getInt32 (I));
697
+
698
+ // Get low 32 bits by truncating to i32
699
+ Value *LowBits = Builder.CreateTrunc (Int64Val, Builder.getInt32Ty ());
700
+
701
+ // Get high 32 bits by shifting right by 32 and truncating
702
+ Value *ShiftedVal = Builder.CreateLShr (Int64Val, Builder.getInt64 (32 ));
703
+ Value *HighBits = Builder.CreateTrunc (ShiftedVal, Builder.getInt32Ty ());
704
+
705
+ // Insert into our final vector
706
+ Val =
707
+ Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (I * 2 ));
708
+ Val = Builder.CreateInsertElement (Val, HighBits,
709
+ Builder.getInt32 (I * 2 + 1 ));
710
+ }
711
+ } else {
712
+ // Handle scalar int64
713
+ Type *Int32x2Ty = VectorType::get (Builder.getInt32Ty (), 2 , false );
714
+ Val = PoisonValue::get (Int32x2Ty);
715
+
716
+ // Get low 32 bits by truncating to i32
717
+ Value *LowBits = Builder.CreateTrunc (InputVal, Builder.getInt32Ty ());
718
+
719
+ // Get high 32 bits by shifting right by 32 and truncating
720
+ Value *ShiftedVal = Builder.CreateLShr (InputVal, Builder.getInt64 (32 ));
721
+ Value *HighBits = Builder.CreateTrunc (ShiftedVal, Builder.getInt32Ty ());
722
+
723
+ // Insert into our final vector
724
+ Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
725
+ Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
726
+ }
727
+
728
+ Builder.CreateIntrinsic (Builder.getVoidTy (),
729
+ Intrinsic::dx_resource_store_typedbuffer,
730
+ {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
731
+ }
621
732
622
- Type *SplitElementTy = Builder.getInt32Ty ();
623
- if (ExtractNum == 4 )
624
- SplitElementTy = VectorType::get (SplitElementTy, 2 , false );
625
-
626
- // split our double(s)
627
- auto *SplitTy = llvm::StructType::get (SplitElementTy, SplitElementTy);
628
- Value *Split = Builder.CreateIntrinsic (SplitTy, Intrinsic::dx_splitdouble,
629
- Orig->getOperand (2 ));
630
- // create our vector
631
- Value *LowBits = Builder.CreateExtractValue (Split, 0 );
632
- Value *HighBits = Builder.CreateExtractValue (Split, 1 );
633
- Value *Val;
634
- if (ExtractNum == 2 ) {
635
- Val = PoisonValue::get (VectorType::get (SplitElementTy, 2 , false ));
636
- Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
637
- Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
638
- } else
639
- Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
640
-
641
- Builder.CreateIntrinsic (Builder.getVoidTy (),
642
- Intrinsic::dx_resource_store_typedbuffer,
643
- {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
644
733
Orig->eraseFromParent ();
645
734
return true ;
646
735
}
0 commit comments