Skip to content

Commit 5dbd877

Browse files
authored
[DirectX] add support for i64 buffer load/stores (#145047)
fixes #140321 Specifically it fixes ` error: Cannot create BufferLoad operation: Invalid overload type` https://hlsl.godbolt.org/z/dTq4q7o58 but no new DML shaders are building. This change now exposes #144747. The change does two things it adds i64 support for intrinsic expansion for the `dx_resource_load_typedbuffer`, and `dx_resource_store_typedbuffer` intrinsics. It also lets loaded typedbuffers crash more gracefully because of ` auto *EVI = cast<ExtractValueInst>(U);` is now a `dyn_cast` and `llvm_unreachable`.
1 parent 23f1ba3 commit 5dbd877

File tree

4 files changed

+177
-43
lines changed

4 files changed

+177
-43
lines changed

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 90 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/IR/PassManager.h"
2626
#include "llvm/IR/Type.h"
2727
#include "llvm/Pass.h"
28+
#include "llvm/Support/Casting.h"
2829
#include "llvm/Support/ErrorHandling.h"
2930
#include "llvm/Support/MathExtras.h"
3031

@@ -70,15 +71,17 @@ static bool isIntrinsicExpansion(Function &F) {
7071
case Intrinsic::vector_reduce_add:
7172
case Intrinsic::vector_reduce_fadd:
7273
return true;
73-
case Intrinsic::dx_resource_load_typedbuffer:
74-
// We need to handle doubles and vector of doubles.
75-
return F.getReturnType()
76-
->getStructElementType(0)
77-
->getScalarType()
78-
->isDoubleTy();
79-
case Intrinsic::dx_resource_store_typedbuffer:
80-
// We need to handle doubles and vector of doubles.
81-
return F.getFunctionType()->getParamType(2)->getScalarType()->isDoubleTy();
74+
case Intrinsic::dx_resource_load_typedbuffer: {
75+
// We need to handle i64, doubles, and vectors of them.
76+
Type *ScalarTy =
77+
F.getReturnType()->getStructElementType(0)->getScalarType();
78+
return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
79+
}
80+
case Intrinsic::dx_resource_store_typedbuffer: {
81+
// We need to handle i64 and doubles and vectors of i64 and doubles.
82+
Type *ScalarTy = F.getFunctionType()->getParamType(2)->getScalarType();
83+
return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
84+
}
8285
}
8386
return false;
8487
}
@@ -545,13 +548,15 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
545548
IRBuilder<> Builder(Orig);
546549

547550
Type *BufferTy = Orig->getType()->getStructElementType(0);
548-
assert(BufferTy->getScalarType()->isDoubleTy() &&
549-
"Only expand double or double2");
551+
Type *ScalarTy = BufferTy->getScalarType();
552+
bool IsDouble = ScalarTy->isDoubleTy();
553+
assert(IsDouble || ScalarTy->isIntegerTy(64) &&
554+
"Only expand double or int64 scalars or vectors");
550555

551556
unsigned ExtractNum = 2;
552557
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
553558
assert(VT->getNumElements() == 2 &&
554-
"TypedBufferLoad double vector has wrong size");
559+
"TypedBufferLoad vector must be size 2");
555560
ExtractNum = 4;
556561
}
557562

@@ -570,22 +575,42 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
570575
ExtractElements.push_back(
571576
Builder.CreateExtractElement(Extract, Builder.getInt32(I)));
572577

573-
// combine into double(s)
578+
// combine into double(s) or int64(s)
574579
Value *Result = PoisonValue::get(BufferTy);
575580
for (unsigned I = 0; I < ExtractNum; I += 2) {
576-
Value *Dbl =
577-
Builder.CreateIntrinsic(Builder.getDoubleTy(), Intrinsic::dx_asdouble,
578-
{ExtractElements[I], ExtractElements[I + 1]});
581+
Value *Combined = nullptr;
582+
if (IsDouble)
583+
// For doubles, use dx_asdouble intrinsic
584+
Combined =
585+
Builder.CreateIntrinsic(Builder.getDoubleTy(), Intrinsic::dx_asdouble,
586+
{ExtractElements[I], ExtractElements[I + 1]});
587+
else {
588+
// For int64, manually combine two int32s
589+
// First, zero-extend both values to i64
590+
Value *Lo = Builder.CreateZExt(ExtractElements[I], Builder.getInt64Ty());
591+
Value *Hi =
592+
Builder.CreateZExt(ExtractElements[I + 1], Builder.getInt64Ty());
593+
// Shift the high bits left by 32 bits
594+
Value *ShiftedHi = Builder.CreateShl(Hi, Builder.getInt64(32));
595+
// OR the high and low bits together
596+
Combined = Builder.CreateOr(Lo, ShiftedHi);
597+
}
598+
579599
if (ExtractNum == 4)
580-
Result =
581-
Builder.CreateInsertElement(Result, Dbl, Builder.getInt32(I / 2));
600+
Result = Builder.CreateInsertElement(Result, Combined,
601+
Builder.getInt32(I / 2));
582602
else
583-
Result = Dbl;
603+
Result = Combined;
584604
}
585605

586606
Value *CheckBit = nullptr;
587607
for (User *U : make_early_inc_range(Orig->users())) {
588-
auto *EVI = cast<ExtractValueInst>(U);
608+
// If it's not a ExtractValueInst, we don't know how to
609+
// handle it
610+
auto *EVI = dyn_cast<ExtractValueInst>(U);
611+
if (!EVI)
612+
llvm_unreachable("Unexpected user of typedbufferload");
613+
589614
ArrayRef<unsigned> Indices = EVI->getIndices();
590615
assert(Indices.size() == 1);
591616

@@ -609,38 +634,61 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
609634
IRBuilder<> Builder(Orig);
610635

611636
Type *BufferTy = Orig->getFunctionType()->getParamType(2);
612-
assert(BufferTy->getScalarType()->isDoubleTy() &&
613-
"Only expand double or double2");
614-
615-
unsigned ExtractNum = 2;
616-
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
617-
assert(VT->getNumElements() == 2 &&
618-
"TypedBufferStore double vector has wrong size");
619-
ExtractNum = 4;
637+
Type *ScalarTy = BufferTy->getScalarType();
638+
bool IsDouble = ScalarTy->isDoubleTy();
639+
assert((IsDouble || ScalarTy->isIntegerTy(64)) &&
640+
"Only expand double or int64 scalars or vectors");
641+
642+
// Determine if we're dealing with a vector or scalar
643+
bool IsVector = isa<FixedVectorType>(BufferTy);
644+
if (IsVector) {
645+
assert(cast<FixedVectorType>(BufferTy)->getNumElements() == 2 &&
646+
"TypedBufferStore vector must be size 2");
620647
}
621648

622-
Type *SplitElementTy = Builder.getInt32Ty();
623-
if (ExtractNum == 4)
649+
// Create the appropriate vector type for the result
650+
Type *Int32Ty = Builder.getInt32Ty();
651+
Type *ResultTy = VectorType::get(Int32Ty, IsVector ? 4 : 2, false);
652+
Value *Val = PoisonValue::get(ResultTy);
653+
654+
Type *SplitElementTy = Int32Ty;
655+
if (IsVector)
624656
SplitElementTy = VectorType::get(SplitElementTy, 2, false);
625657

626-
// split our double(s)
627-
auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy);
628-
Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble,
629-
Orig->getOperand(2));
630-
// create our vector
631-
Value *LowBits = Builder.CreateExtractValue(Split, 0);
632-
Value *HighBits = Builder.CreateExtractValue(Split, 1);
633-
Value *Val;
634-
if (ExtractNum == 2) {
635-
Val = PoisonValue::get(VectorType::get(SplitElementTy, 2, false));
658+
Value *LowBits = nullptr;
659+
Value *HighBits = nullptr;
660+
// Split the 64-bit values into 32-bit components
661+
if (IsDouble) {
662+
auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy);
663+
Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble,
664+
{Orig->getOperand(2)});
665+
LowBits = Builder.CreateExtractValue(Split, 0);
666+
HighBits = Builder.CreateExtractValue(Split, 1);
667+
} else {
668+
// Handle int64 type(s)
669+
Value *InputVal = Orig->getOperand(2);
670+
Constant *ShiftAmt = Builder.getInt64(32);
671+
if (IsVector)
672+
ShiftAmt = ConstantVector::getSplat(ElementCount::getFixed(2), ShiftAmt);
673+
674+
// Split into low and high 32-bit parts
675+
LowBits = Builder.CreateTrunc(InputVal, SplitElementTy);
676+
Value *ShiftedVal = Builder.CreateLShr(InputVal, ShiftAmt);
677+
HighBits = Builder.CreateTrunc(ShiftedVal, SplitElementTy);
678+
}
679+
680+
if (IsVector) {
681+
Val = Builder.CreateShuffleVector(LowBits, HighBits, {0, 2, 1, 3});
682+
} else {
636683
Val = Builder.CreateInsertElement(Val, LowBits, Builder.getInt32(0));
637684
Val = Builder.CreateInsertElement(Val, HighBits, Builder.getInt32(1));
638-
} else
639-
Val = Builder.CreateShuffleVector(LowBits, HighBits, {0, 2, 1, 3});
685+
}
640686

687+
// Create the final intrinsic call
641688
Builder.CreateIntrinsic(Builder.getVoidTy(),
642689
Intrinsic::dx_resource_store_typedbuffer,
643690
{Orig->getOperand(0), Orig->getOperand(1), Val});
691+
644692
Orig->eraseFromParent();
645693
return true;
646694
}

llvm/test/CodeGen/DirectX/BufferLoadDouble.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,4 @@ define void @loadf64WithCheckBit() {
8888
; CHECK-NOT: extractvalue { double, i1 }
8989
%cb = extractvalue {double, i1} %load0, 1
9090
ret void
91-
}
91+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -dxil-intrinsic-expansion %s | FileCheck %s
3+
4+
target triple = "dxil-pc-shadermodel6.2-compute"
5+
6+
define void @loadi64() {
7+
; CHECK-LABEL: define void @loadi64() {
8+
; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
9+
; CHECK-NEXT: [[TMP1:%.*]] = call { <2 x i32>, i1 } @llvm.dx.resource.load.typedbuffer.v2i32.tdx.TypedBuffer_i64_1_0_0t(target("dx.TypedBuffer", i64, 1, 0, 0) [[BUFFER]], i32 0)
10+
; CHECK-NEXT: [[TMP2:%.*]] = extractvalue { <2 x i32>, i1 } [[TMP1]], 0
11+
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x i32> [[TMP2]], i32 0
12+
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[TMP2]], i32 1
13+
; CHECK-NEXT: [[TMP5:%.*]] = zext i32 [[TMP3]] to i64
14+
; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP4]] to i64
15+
; CHECK-NEXT: [[TMP7:%.*]] = shl i64 [[TMP6]], 32
16+
; CHECK-NEXT: [[TMP8:%.*]] = or i64 [[TMP5]], [[TMP7]]
17+
; CHECK-NEXT: ret void
18+
;
19+
%buffer = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
20+
%result = call { i64, i1 } @llvm.dx.resource.load.typedbuffer.tdx.TypedBuffer_i64_1_0_0t(target("dx.TypedBuffer", i64, 1, 0, 0) %buffer, i32 0)
21+
ret void
22+
}
23+
24+
define void @loadv2i64() {
25+
; CHECK-LABEL: define void @loadv2i64() {
26+
; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
27+
; CHECK-NEXT: [[TMP1:%.*]] = call { <4 x i32>, i1 } @llvm.dx.resource.load.typedbuffer.v4i32.tdx.TypedBuffer_v2i64_1_0_0t(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) [[BUFFER]], i32 0)
28+
; CHECK-NEXT: [[TMP2:%.*]] = extractvalue { <4 x i32>, i1 } [[TMP1]], 0
29+
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i32> [[TMP2]], i32 0
30+
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x i32> [[TMP2]], i32 1
31+
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x i32> [[TMP2]], i32 2
32+
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i32> [[TMP2]], i32 3
33+
; CHECK-NEXT: [[TMP7:%.*]] = zext i32 [[TMP3]] to i64
34+
; CHECK-NEXT: [[TMP8:%.*]] = zext i32 [[TMP4]] to i64
35+
; CHECK-NEXT: [[TMP9:%.*]] = shl i64 [[TMP8]], 32
36+
; CHECK-NEXT: [[TMP10:%.*]] = or i64 [[TMP7]], [[TMP9]]
37+
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <2 x i64> poison, i64 [[TMP10]], i32 0
38+
; CHECK-NEXT: [[TMP12:%.*]] = zext i32 [[TMP5]] to i64
39+
; CHECK-NEXT: [[TMP13:%.*]] = zext i32 [[TMP6]] to i64
40+
; CHECK-NEXT: [[TMP14:%.*]] = shl i64 [[TMP13]], 32
41+
; CHECK-NEXT: [[TMP15:%.*]] = or i64 [[TMP12]], [[TMP14]]
42+
; CHECK-NEXT: [[TMP16:%.*]] = insertelement <2 x i64> [[TMP11]], i64 [[TMP15]], i32 1
43+
; CHECK-NEXT: ret void
44+
;
45+
%buffer = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
46+
%result = call { <2 x i64>, i1 } @llvm.dx.resource.load.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) %buffer, i32 0)
47+
ret void
48+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -dxil-intrinsic-expansion %s | FileCheck %s
3+
4+
target triple = "dxil-pc-shadermodel6.6-compute"
5+
6+
define void @storei64(i64 %0) {
7+
; CHECK-LABEL: define void @storei64(
8+
; CHECK-SAME: i64 [[TMP0:%.*]]) {
9+
; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
10+
; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[TMP0]] to i32
11+
; CHECK-NEXT: [[TMP3:%.*]] = lshr i64 [[TMP0]], 32
12+
; CHECK-NEXT: [[TMP4:%.*]] = trunc i64 [[TMP3]] to i32
13+
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP2]], i32 0
14+
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP4]], i32 1
15+
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_i64_1_0_0t.v2i32(target("dx.TypedBuffer", i64, 1, 0, 0) [[BUFFER]], i32 0, <2 x i32> [[TMP6]])
16+
; CHECK-NEXT: ret void
17+
;
18+
%buffer = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
19+
call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_i64_1_0_0t(target("dx.TypedBuffer", i64, 1, 0, 0) %buffer, i32 0,i64 %0)
20+
ret void
21+
}
22+
23+
24+
define void @storev2i64(<2 x i64> %0) {
25+
; CHECK-LABEL: define void @storev2i64(
26+
; CHECK-SAME: <2 x i64> [[TMP0:%.*]]) {
27+
; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
28+
; CHECK-NEXT: [[TMP2:%.*]] = trunc <2 x i64> [[TMP0]] to <2 x i32>
29+
; CHECK-NEXT: [[TMP3:%.*]] = lshr <2 x i64> [[TMP0]], splat (i64 32)
30+
; CHECK-NEXT: [[TMP4:%.*]] = trunc <2 x i64> [[TMP3]] to <2 x i32>
31+
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP4]], <4 x i32> <i32 0, i32 2, i32 1, i32 3>
32+
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t.v4i32(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) [[BUFFER]], i32 0, <4 x i32> [[TMP13]])
33+
; CHECK-NEXT: ret void
34+
;
35+
%buffer = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
36+
call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) %buffer, i32 0, <2 x i64> %0)
37+
ret void
38+
}

0 commit comments

Comments
 (0)