Skip to content

Commit a62b464

Browse files
committed
[DirectX] Lower @llvm.dx.typedBufferLoad to DXIL ops
The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`. There's some complexity here due to translating from a vector return type to a named struct and trying to avoid excessive IR coming out of that. Note that this change includes a bit of a hack in how it deals with `getOverloadKind` for the `dx.ResRet` types - we need to adjust how we deal with operation overloads to generate a table directly rather than proxy through the OverloadKind enum, but that's left for a later change here. Pull Request: llvm#104252
1 parent 787fc81 commit a62b464

File tree

7 files changed

+252
-9
lines changed

7 files changed

+252
-9
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def int_dx_handle_fromBinding
3030
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
3131
[IntrNoMem]>;
3232

33+
def int_dx_typedBufferLoad
34+
: DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty]>;
35+
3336
// Cast between target extension handle types and dxil-style opaque handles
3437
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
3538

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ def Int64Ty : DXILOpParamType;
4040
def HalfTy : DXILOpParamType;
4141
def FloatTy : DXILOpParamType;
4242
def DoubleTy : DXILOpParamType;
43-
def ResRetTy : DXILOpParamType;
43+
def ResRetHalfTy : DXILOpParamType;
44+
def ResRetFloatTy : DXILOpParamType;
45+
def ResRetInt16Ty : DXILOpParamType;
46+
def ResRetInt32Ty : DXILOpParamType;
4447
def HandleTy : DXILOpParamType;
4548
def ResBindTy : DXILOpParamType;
4649
def ResPropsTy : DXILOpParamType;
@@ -693,6 +696,17 @@ def CreateHandle : DXILOp<57, createHandle> {
693696
let stages = [Stages<DXIL1_0, [all_stages]>, Stages<DXIL1_6, [removed]>];
694697
}
695698

699+
def BufferLoad : DXILOp<68, bufferLoad> {
700+
let Doc = "reads from a TypedBuffer";
701+
// Handle, Coord0, Coord1
702+
let arguments = [HandleTy, Int32Ty, Int32Ty];
703+
let result = OverloadTy;
704+
let overloads =
705+
[Overloads<DXIL1_0,
706+
[ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
707+
let stages = [Stages<DXIL1_0, [all_stages]>];
708+
}
709+
696710
def ThreadId : DXILOp<93, threadId> {
697711
let Doc = "Reads the thread ID";
698712
let LLVMIntrinsic = int_dx_thread_id;

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,12 @@ static OverloadKind getOverloadKind(Type *Ty) {
120120
}
121121
case Type::PointerTyID:
122122
return OverloadKind::UserDefineType;
123-
case Type::StructTyID:
124-
return OverloadKind::ObjectType;
123+
case Type::StructTyID: {
124+
// TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
125+
// how we're handling overloads and remove the `OverloadKind` proxy enum.
126+
StructType *ST = cast<StructType>(Ty);
127+
return getOverloadKind(ST->getElementType(0));
128+
}
125129
default:
126130
return OverloadKind::UNDEFINED;
127131
}
@@ -194,10 +198,11 @@ static StructType *getOrCreateStructType(StringRef Name,
194198
return StructType::create(Ctx, EltTys, Name);
195199
}
196200

197-
static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
198-
OverloadKind Kind = getOverloadKind(OverloadTy);
201+
static StructType *getResRetType(Type *ElementTy) {
202+
LLVMContext &Ctx = ElementTy->getContext();
203+
OverloadKind Kind = getOverloadKind(ElementTy);
199204
std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
200-
Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
205+
Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
201206
Type::getInt32Ty(Ctx)};
202207
return getOrCreateStructType(TypeName, FieldTypes, Ctx);
203208
}
@@ -247,8 +252,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
247252
return Type::getInt64Ty(Ctx);
248253
case OpParamType::OverloadTy:
249254
return OverloadTy;
250-
case OpParamType::ResRetTy:
251-
return getResRetType(OverloadTy, Ctx);
255+
case OpParamType::ResRetHalfTy:
256+
return getResRetType(Type::getHalfTy(Ctx));
257+
case OpParamType::ResRetFloatTy:
258+
return getResRetType(Type::getFloatTy(Ctx));
259+
case OpParamType::ResRetInt16Ty:
260+
return getResRetType(Type::getInt16Ty(Ctx));
261+
case OpParamType::ResRetInt32Ty:
262+
return getResRetType(Type::getInt32Ty(Ctx));
252263
case OpParamType::HandleTy:
253264
return getHandleType(Ctx);
254265
case OpParamType::ResBindTy:
@@ -390,6 +401,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
390401
return makeOpError(OpCode, "Wrong number of arguments");
391402
OverloadTy = Args[ArgIndex]->getType();
392403
}
404+
393405
FunctionType *DXILOpFT =
394406
getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);
395407

@@ -450,6 +462,10 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
450462
return *Result;
451463
}
452464

465+
StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
466+
return ::getResRetType(ElementTy);
467+
}
468+
453469
StructType *DXILOpBuilder::getHandleType() {
454470
return ::getHandleType(IRB.getContext());
455471
}

llvm/lib/Target/DirectX/DXILOpBuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class DXILOpBuilder {
4646
Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
4747
Type *RetTy = nullptr);
4848

49+
/// Get a `%dx.types.ResRet` type with the given element type.
50+
StructType *getResRetType(Type *ElementTy);
4951
/// Get the `%dx.types.Handle` type.
5052
StructType *getHandleType();
5153

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,70 @@ class OpLowerer {
259259
lowerToBindAndAnnotateHandle(F);
260260
}
261261

262+
void lowerTypedBufferLoad(Function &F) {
263+
IRBuilder<> &IRB = OpBuilder.getIRB();
264+
Type *Int32Ty = IRB.getInt32Ty();
265+
266+
replaceFunction(F, [&](CallInst *CI) -> Error {
267+
IRB.SetInsertPoint(CI);
268+
269+
Value *Handle =
270+
createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
271+
Value *Index0 = CI->getArgOperand(1);
272+
Value *Index1 = UndefValue::get(Int32Ty);
273+
274+
Type *OldRetTy = CI->getType();
275+
Type *NewRetTy = OpBuilder.getResRetType(OldRetTy->getScalarType());
276+
277+
std::array<Value *, 3> Args{Handle, Index0, Index1};
278+
Expected<CallInst *> OpCall =
279+
OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, NewRetTy);
280+
if (Error E = OpCall.takeError())
281+
return E;
282+
283+
// For scalars, we extract the first element of the ResRet value.
284+
if (!isa<FixedVectorType>(OldRetTy)) {
285+
Value *EVI = IRB.CreateExtractValue(*OpCall, 0);
286+
CI->replaceAllUsesWith(EVI);
287+
CI->eraseFromParent();
288+
return Error::success();
289+
}
290+
291+
std::array<Value *, 4> Extracts = {};
292+
293+
// We've switched the return type from a vector to a struct, but at this
294+
// point most vectors have probably already been scalarized. Try to
295+
// forward arguments directly rather than inserting into and immediately
296+
// extracting from a vector.
297+
for (Use &U : make_early_inc_range(CI->uses()))
298+
if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser()))
299+
if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
300+
size_t IndexVal = IndexOp->getZExtValue();
301+
assert(IndexVal < 4 && "Index into buffer load out of range");
302+
if (!Extracts[IndexVal])
303+
Extracts[IndexVal] = IRB.CreateExtractValue(*OpCall, IndexVal);
304+
EEI->replaceAllUsesWith(Extracts[IndexVal]);
305+
EEI->eraseFromParent();
306+
}
307+
308+
unsigned N = cast<FixedVectorType>(OldRetTy)->getNumElements();
309+
// If there are still uses then we need to create a vector.
310+
if (!CI->use_empty()) {
311+
for (int I = 0, E = N; I != E; ++I)
312+
if (!Extracts[I])
313+
Extracts[I] = IRB.CreateExtractValue(*OpCall, I);
314+
315+
Value *Vec = UndefValue::get(OldRetTy);
316+
for (int I = 0, E = N; I != E; ++I)
317+
Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
318+
CI->replaceAllUsesWith(Vec);
319+
}
320+
321+
CI->eraseFromParent();
322+
return Error::success();
323+
});
324+
}
325+
262326
bool lowerIntrinsics() {
263327
bool Updated = false;
264328

@@ -276,6 +340,10 @@ class OpLowerer {
276340
#include "DXILOperation.inc"
277341
case Intrinsic::dx_handle_fromBinding:
278342
lowerHandleFromBinding(F);
343+
break;
344+
case Intrinsic::dx_typedBufferLoad:
345+
lowerTypedBufferLoad(F);
346+
break;
279347
}
280348
Updated = true;
281349
}
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
; RUN: opt -S -dxil-op-lower %s | FileCheck %s
2+
3+
target triple = "dxil-pc-shadermodel6.6-compute"
4+
5+
declare void @scalar_user(float)
6+
declare void @vector_user(<4 x float>)
7+
8+
define void @loadv4f32() {
9+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
10+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
11+
%buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
12+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
13+
i32 0, i32 0, i32 1, i32 0, i1 false)
14+
15+
; The temporary casts should all have been cleaned up
16+
; CHECK-NOT: %dx.cast_handle
17+
18+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
19+
%data0 = call <4 x float> @llvm.dx.typedBufferLoad(
20+
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
21+
22+
; The extract order depends on the users, so don't enforce that here.
23+
; CHECK-DAG: [[VAL0_0:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
24+
%data0_0 = extractelement <4 x float> %data0, i32 0
25+
; CHECK-DAG: [[VAL0_2:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
26+
%data0_2 = extractelement <4 x float> %data0, i32 2
27+
28+
; If all of the uses are extracts, we skip creating a vector
29+
; CHECK-NOT: insertelement
30+
; CHECK-DAG: call void @scalar_user(float [[VAL0_0]])
31+
; CHECK-DAG: call void @scalar_user(float [[VAL0_2]])
32+
call void @scalar_user(float %data0_0)
33+
call void @scalar_user(float %data0_2)
34+
35+
; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
36+
%data4 = call <4 x float> @llvm.dx.typedBufferLoad(
37+
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)
38+
39+
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
40+
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
41+
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
42+
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
43+
; CHECK: insertelement <4 x float> undef
44+
; CHECK: insertelement <4 x float>
45+
; CHECK: insertelement <4 x float>
46+
; CHECK: insertelement <4 x float>
47+
call void @vector_user(<4 x float> %data4)
48+
49+
; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
50+
%data12 = call <4 x float> @llvm.dx.typedBufferLoad(
51+
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
52+
53+
; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
54+
%data12_3 = extractelement <4 x float> %data12, i32 3
55+
56+
; If there are a mix of users we need the vector, but extracts are direct
57+
; CHECK: call void @scalar_user(float [[DATA12_3]])
58+
call void @scalar_user(float %data12_3)
59+
call void @vector_user(<4 x float> %data12)
60+
61+
ret void
62+
}
63+
64+
define void @loadf32() {
65+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
66+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
67+
%buffer = call target("dx.TypedBuffer", float, 0, 0, 0)
68+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_0_0_0(
69+
i32 0, i32 0, i32 1, i32 0, i1 false)
70+
71+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
72+
%data0 = call float @llvm.dx.typedBufferLoad(
73+
target("dx.TypedBuffer", float, 0, 0, 0) %buffer, i32 0)
74+
75+
; CHECK: [[VAL0:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
76+
; CHECK: call void @scalar_user(float [[VAL0]])
77+
call void @scalar_user(float %data0)
78+
79+
ret void
80+
}
81+
82+
define void @loadv2f32() {
83+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
84+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
85+
%buffer = call target("dx.TypedBuffer", <2 x float>, 0, 0, 0)
86+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v2f32_0_0_0(
87+
i32 0, i32 0, i32 1, i32 0, i1 false)
88+
89+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
90+
%data0 = call <2 x float> @llvm.dx.typedBufferLoad(
91+
target("dx.TypedBuffer", <2 x float>, 0, 0, 0) %buffer, i32 0)
92+
93+
ret void
94+
}
95+
96+
define void @loadv4i32() {
97+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
98+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
99+
%buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0)
100+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0(
101+
i32 0, i32 0, i32 1, i32 0, i1 false)
102+
103+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
104+
%data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
105+
target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)
106+
107+
ret void
108+
}
109+
110+
define void @loadv4f16() {
111+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
112+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
113+
%buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0)
114+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0(
115+
i32 0, i32 0, i32 1, i32 0, i1 false)
116+
117+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
118+
%data0 = call <4 x half> @llvm.dx.typedBufferLoad(
119+
target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)
120+
121+
ret void
122+
}
123+
124+
define void @loadv4i16() {
125+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
126+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
127+
%buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0)
128+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0(
129+
i32 0, i32 0, i32 1, i32 0, i1 false)
130+
131+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
132+
%data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
133+
target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)
134+
135+
ret void
136+
}

llvm/utils/TableGen/DXILEmitter.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ static StringRef getOverloadKindStr(const Record *R) {
187187
.Case("Int8Ty", "OverloadKind::I8")
188188
.Case("Int16Ty", "OverloadKind::I16")
189189
.Case("Int32Ty", "OverloadKind::I32")
190-
.Case("Int64Ty", "OverloadKind::I64");
190+
.Case("Int64Ty", "OverloadKind::I64")
191+
.Case("ResRetHalfTy", "OverloadKind::HALF")
192+
.Case("ResRetFloatTy", "OverloadKind::FLOAT")
193+
.Case("ResRetInt16Ty", "OverloadKind::I16")
194+
.Case("ResRetInt32Ty", "OverloadKind::I32");
191195
}
192196

193197
/// Return a string representation of valid overload information denoted

0 commit comments

Comments
 (0)