@@ -651,85 +651,77 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
651
651
assert ((IsDouble || ScalarTy->isIntegerTy (64 )) &&
652
652
" Only expand double or int64 scalars or vectors" );
653
653
654
- unsigned ExtractNum = 2 ;
655
- if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
656
- assert (VT->getNumElements () == 2 &&
654
+ // Determine if we're dealing with a vector or scalar
655
+ bool IsVector = isa<FixedVectorType>(BufferTy);
656
+ if (IsVector) {
657
+ assert (cast<FixedVectorType>(BufferTy)->getNumElements () == 2 &&
657
658
" TypedBufferStore vector must be size 2" );
658
- ExtractNum = 4 ;
659
659
}
660
+
661
+ // Create the appropriate vector type for the result
662
+ Type *Int32Ty = Builder.getInt32Ty ();
663
+ Type *ResultTy = VectorType::get (Int32Ty, IsVector ? 4 : 2 , false );
664
+ Value *Val = PoisonValue::get (ResultTy);
665
+
666
+ // Split the 64-bit values into 32-bit components
660
667
if (IsDouble) {
661
- Type *SplitElementTy = Builder.getInt32Ty ();
662
- if (ExtractNum == 4 )
668
+ // Handle double type(s)
669
+ Type *SplitElementTy = Int32Ty;
670
+ if (IsVector)
663
671
SplitElementTy = VectorType::get (SplitElementTy, 2 , false );
664
672
665
- // Handle double type(s) - keep original behavior
666
673
auto *SplitTy = llvm::StructType::get (SplitElementTy, SplitElementTy);
667
674
Value *Split = Builder.CreateIntrinsic (SplitTy, Intrinsic::dx_splitdouble,
668
675
{Orig->getOperand (2 )});
669
- // create our vector
670
676
Value *LowBits = Builder.CreateExtractValue (Split, 0 );
671
677
Value *HighBits = Builder.CreateExtractValue (Split, 1 );
672
- Value *Val;
673
- if (ExtractNum == 2 ) {
674
- Val = PoisonValue::get (VectorType::get (Builder.getInt32Ty (), 2 , false ));
678
+
679
+ if (IsVector) {
680
+ // For vector doubles, use shuffle to create the final vector
681
+ Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
682
+ } else {
683
+ // For scalar doubles, insert the elements
675
684
Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
676
685
Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
677
- } else
678
- Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
679
-
680
- Builder.CreateIntrinsic (Builder.getVoidTy (),
681
- Intrinsic::dx_resource_store_typedbuffer,
682
- {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
686
+ }
683
687
} else {
684
688
// Handle int64 type(s)
685
689
Value *InputVal = Orig->getOperand (2 );
686
- Value *Val;
687
690
688
- if (ExtractNum == 4 ) {
691
+ if (IsVector ) {
689
692
// Handle vector of int64
690
- Type *Int32x4Ty = VectorType::get (Builder.getInt32Ty (), 4 , false );
691
- Val = PoisonValue::get (Int32x4Ty);
692
-
693
693
for (unsigned I = 0 ; I < 2 ; ++I) {
694
694
// Extract each int64 element
695
695
Value *Int64Val =
696
696
Builder.CreateExtractElement (InputVal, Builder.getInt32 (I));
697
697
698
- // Get low 32 bits by truncating to i32
699
- Value *LowBits = Builder.CreateTrunc (Int64Val, Builder.getInt32Ty ());
700
-
701
- // Get high 32 bits by shifting right by 32 and truncating
698
+ // Split into low and high 32-bit parts
699
+ Value *LowBits = Builder.CreateTrunc (Int64Val, Int32Ty);
702
700
Value *ShiftedVal = Builder.CreateLShr (Int64Val, Builder.getInt64 (32 ));
703
- Value *HighBits = Builder.CreateTrunc (ShiftedVal, Builder. getInt32Ty () );
701
+ Value *HighBits = Builder.CreateTrunc (ShiftedVal, Int32Ty );
704
702
705
- // Insert into our final vector
703
+ // Insert into result vector
706
704
Val =
707
705
Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (I * 2 ));
708
706
Val = Builder.CreateInsertElement (Val, HighBits,
709
707
Builder.getInt32 (I * 2 + 1 ));
710
708
}
711
709
} else {
712
710
// Handle scalar int64
713
- Type *Int32x2Ty = VectorType::get (Builder.getInt32Ty (), 2 , false );
714
- Val = PoisonValue::get (Int32x2Ty);
715
-
716
- // Get low 32 bits by truncating to i32
717
- Value *LowBits = Builder.CreateTrunc (InputVal, Builder.getInt32Ty ());
718
-
719
- // Get high 32 bits by shifting right by 32 and truncating
711
+ Value *LowBits = Builder.CreateTrunc (InputVal, Int32Ty);
720
712
Value *ShiftedVal = Builder.CreateLShr (InputVal, Builder.getInt64 (32 ));
721
- Value *HighBits = Builder.CreateTrunc (ShiftedVal, Builder. getInt32Ty () );
713
+ Value *HighBits = Builder.CreateTrunc (ShiftedVal, Int32Ty );
722
714
723
- // Insert into our final vector
724
715
Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
725
716
Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
726
717
}
727
-
728
- Builder.CreateIntrinsic (Builder.getVoidTy (),
729
- Intrinsic::dx_resource_store_typedbuffer,
730
- {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
731
718
}
732
719
720
+ // Create the final intrinsic call
721
+ Builder.CreateIntrinsic (Builder.getVoidTy (),
722
+ Intrinsic::dx_resource_store_typedbuffer,
723
+ {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
724
+
733
725
Orig->eraseFromParent ();
734
726
return true ;
735
727
}
0 commit comments