Skip to content

Commit 5f6b058

Browse files
authored
[DirectX] Match DXC when storing RWBuffer<float> (#129911)
Update the lowering of `llvm.dx.resource.store.typedbuffer` to match DXC and repeat the first element in cases where we are storing fewer than 4 elements. Fixes #128110
1 parent 1493f42 commit 5f6b058

File tree

4 files changed

+59
-28
lines changed

4 files changed

+59
-28
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def int_dx_resource_load_typedbuffer
3434
: DefaultAttrsIntrinsic<[llvm_any_ty, llvm_i1_ty],
3535
[llvm_any_ty, llvm_i32_ty], [IntrReadMem]>;
3636
def int_dx_resource_store_typedbuffer
37-
: DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_anyvector_ty],
37+
: DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_any_ty],
3838
[IntrWriteMem]>;
3939
def int_dx_resource_load_rawbuffer
4040
: DefaultAttrsIntrinsic<[llvm_any_ty, llvm_i1_ty],

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -649,16 +649,13 @@ class OpLowerer {
649649

650650
uint64_t NumElements =
651651
DL.getTypeSizeInBits(DataTy) / DL.getTypeSizeInBits(ScalarTy);
652-
Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
652+
Value *Mask =
653+
ConstantInt::get(Int8Ty, IsRaw ? ~(~0U << NumElements) : 15U);
653654

654655
// TODO: check that we only have vector or scalar...
655-
if (!IsRaw && NumElements != 4)
656-
return make_error<StringError>(
657-
"typedBufferStore data must be a vector of 4 elements",
658-
inconvertibleErrorCode());
659-
else if (NumElements > 4)
656+
if (NumElements > 4)
660657
return make_error<StringError>(
661-
"rawBufferStore data must have at most 4 elements",
658+
"Buffer store data must have at most 4 elements",
662659
inconvertibleErrorCode());
663660

664661
std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
@@ -687,10 +684,13 @@ class OpLowerer {
687684
if (DataElements[I] == nullptr)
688685
DataElements[I] =
689686
IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));
690-
// For any elements beyond the length of the vector, fill up with undef.
687+
688+
// For any elements beyond the length of the vector, we should fill it up
689+
// with undef - however, for typed buffers we repeat the first element to
690+
// match DXC.
691691
for (int I = NumElements, E = 4; I < E; ++I)
692692
if (DataElements[I] == nullptr)
693-
DataElements[I] = UndefValue::get(ScalarTy);
693+
DataElements[I] = IsRaw ? UndefValue::get(ScalarTy) : DataElements[0];
694694

695695
dxil::OpCode Op = OpCode::BufferStore;
696696
SmallVector<Value *, 9> Args{

llvm/test/CodeGen/DirectX/BufferStore-errors.ll

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ target triple = "dxil-pc-shadermodel6.6-compute"
55

66
; CHECK: error:
77
; CHECK-SAME: in function storetoomany
8-
; CHECK-SAME: typedBufferStore data must be a vector of 4 elements
8+
; CHECK-SAME: Buffer store data must have at most 4 elements
99
define void @storetoomany(<5 x float> %data, i32 %index) "hlsl.export" {
1010
%buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
1111
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f32_1_0_0(
@@ -18,20 +18,4 @@ define void @storetoomany(<5 x float> %data, i32 %index) "hlsl.export" {
1818
ret void
1919
}
2020

21-
; CHECK: error:
22-
; CHECK-SAME: in function storetoofew
23-
; CHECK-SAME: typedBufferStore data must be a vector of 4 elements
24-
define void @storetoofew(<3 x i32> %data, i32 %index) "hlsl.export" {
25-
%buffer = call target("dx.TypedBuffer", <4 x i32>, 1, 0, 0)
26-
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4i32_1_0_0(
27-
i32 0, i32 0, i32 1, i32 0, i1 false)
28-
29-
call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v4i32_1_0_0t.v3i32(
30-
target("dx.TypedBuffer", <4 x i32>, 1, 0, 0) %buffer,
31-
i32 %index, <3 x i32> %data)
32-
33-
ret void
34-
}
35-
3621
declare void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v4f32_1_0_0t.v5f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0), i32, <5 x float>)
37-
declare void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v4i32_1_0_0t.v3i32(target("dx.TypedBuffer", <4 x i32>, 1, 0, 0), i32, <3 x i32>)

llvm/test/CodeGen/DirectX/BufferStore.ll

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
target triple = "dxil-pc-shadermodel6.6-compute"
44

5-
define void @storefloat(<4 x float> %data, i32 %index) {
5+
; CHECK-LABEL: define void @storefloats
6+
define void @storefloats(<4 x float> %data, i32 %index) {
67

78
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
89
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
@@ -25,6 +26,49 @@ define void @storefloat(<4 x float> %data, i32 %index) {
2526
ret void
2627
}
2728

29+
; CHECK-LABEL: define void @storeonefloat
30+
define void @storeonefloat(float %data, i32 %index) {
31+
32+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
33+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
34+
%buffer = call target("dx.TypedBuffer", float, 1, 0, 0)
35+
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_f32_1_0_0(
36+
i32 0, i32 0, i32 1, i32 0, i1 false)
37+
38+
; The temporary casts should all have been cleaned up
39+
; CHECK-NOT: %dx.resource.casthandle
40+
41+
; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float %data, float %data, float %data, float %data, i8 15){{$}}
42+
call void @llvm.dx.resource.store.typedbuffer(
43+
target("dx.TypedBuffer", float, 1, 0, 0) %buffer,
44+
i32 %index, float %data)
45+
46+
ret void
47+
}
48+
49+
; CHECK-LABEL: define void @storetwofloat
50+
define void @storetwofloat(<2 x float> %data, i32 %index) {
51+
52+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
53+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
54+
%buffer = call target("dx.TypedBuffer", <2 x float>, 1, 0, 0)
55+
@llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2f32_1_0_0(
56+
i32 0, i32 0, i32 1, i32 0, i1 false)
57+
58+
; The temporary casts should all have been cleaned up
59+
; CHECK-NOT: %dx.resource.casthandle
60+
61+
; CHECK: [[DATA0_0:%.*]] = extractelement <2 x float> %data, i32 0
62+
; CHECK: [[DATA0_1:%.*]] = extractelement <2 x float> %data, i32 1
63+
; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float [[DATA0_0]], float [[DATA0_1]], float [[DATA0_0]], float [[DATA0_0]], i8 15){{$}}
64+
call void @llvm.dx.resource.store.typedbuffer(
65+
target("dx.TypedBuffer", <2 x float>, 1, 0, 0) %buffer,
66+
i32 %index, <2 x float> %data)
67+
68+
ret void
69+
}
70+
71+
; CHECK-LABEL: define void @storeint
2872
define void @storeint(<4 x i32> %data, i32 %index) {
2973

3074
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
@@ -45,6 +89,7 @@ define void @storeint(<4 x i32> %data, i32 %index) {
4589
ret void
4690
}
4791

92+
; CHECK-LABEL: define void @storehalf
4893
define void @storehalf(<4 x half> %data, i32 %index) {
4994

5095
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
@@ -68,6 +113,7 @@ define void @storehalf(<4 x half> %data, i32 %index) {
68113
ret void
69114
}
70115

116+
; CHECK-LABEL: define void @storei16
71117
define void @storei16(<4 x i16> %data, i32 %index) {
72118

73119
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
@@ -91,6 +137,7 @@ define void @storei16(<4 x i16> %data, i32 %index) {
91137
ret void
92138
}
93139

140+
; CHECK-LABEL: define void @store_scalarized_floats
94141
define void @store_scalarized_floats(float %data0, float %data1, float %data2, float %data3, i32 %index) {
95142

96143
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,

0 commit comments

Comments
 (0)