Skip to content

Commit 2c88ac9

Browse files
authored
[DirectX] Clean up extra vectors when lowering to buffer store (#116721)
DXILOpLowering runs after scalarization but `@llvm.dx.typedbuffer.store` takes a vector, so the argument is usually an artifact. Avoid creating a vector just to extract elements from it immediately.
1 parent 0aa7892 commit 2c88ac9

File tree

2 files changed

+59
-11
lines changed

2 files changed

+59
-11
lines changed

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -568,23 +568,47 @@ class OpLowerer {
568568
return make_error<StringError>(
569569
"typedBufferStore data must be a vector of 4 elements",
570570
inconvertibleErrorCode());
571-
Value *Data0 =
572-
IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 0));
573-
Value *Data1 =
574-
IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 1));
575-
Value *Data2 =
576-
IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 2));
577-
Value *Data3 =
578-
IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 3));
579-
580-
std::array<Value *, 8> Args{Handle, Index0, Index1, Data0,
581-
Data1, Data2, Data3, Mask};
571+
572+
// Since we're post-scalarizer, we likely have a vector that's constructed
573+
// solely for the argument of the store. If so, just use the scalar values
574+
// from before they're inserted into the temporary.
575+
std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
576+
auto *IEI = dyn_cast<InsertElementInst>(Data);
577+
while (IEI) {
578+
auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
579+
if (!IndexOp)
580+
break;
581+
size_t IndexVal = IndexOp->getZExtValue();
582+
assert(IndexVal < 4 && "Too many elements for buffer store");
583+
DataElements[IndexVal] = IEI->getOperand(1);
584+
IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
585+
}
586+
587+
// If for some reason we weren't able to forward the arguments from the
588+
// scalarizer artifact, then we need to actually extract elements from the
589+
// vector.
590+
for (int I = 0, E = 4; I != E; ++I)
591+
if (DataElements[I] == nullptr)
592+
DataElements[I] =
593+
IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));
594+
595+
std::array<Value *, 8> Args{
596+
Handle, Index0, Index1, DataElements[0],
597+
DataElements[1], DataElements[2], DataElements[3], Mask};
582598
Expected<CallInst *> OpCall =
583599
OpBuilder.tryCreateOp(OpCode::BufferStore, Args, CI->getName());
584600
if (Error E = OpCall.takeError())
585601
return E;
586602

587603
CI->eraseFromParent();
604+
// Clean up any leftover `insertelement`s
605+
IEI = dyn_cast<InsertElementInst>(Data);
606+
while (IEI && IEI->use_empty()) {
607+
InsertElementInst *Tmp = IEI;
608+
IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
609+
Tmp->eraseFromParent();
610+
}
611+
588612
return Error::success();
589613
});
590614
}

llvm/test/CodeGen/DirectX/BufferStore.ll

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,27 @@ define void @storei16(<4 x i16> %data, i32 %index) {
9090

9191
ret void
9292
}
93+
94+
define void @store_scalarized_floats(float %data0, float %data1, float %data2, float %data3, i32 %index) {
95+
96+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
97+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
98+
%buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
99+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
100+
i32 0, i32 0, i32 1, i32 0, i1 false)
101+
102+
; We shouldn't end up with any inserts/extracts.
103+
; CHECK-NOT: insertelement
104+
; CHECK-NOT: extractelement
105+
106+
; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float %data0, float %data1, float %data2, float %data3, i8 15)
107+
%vec.upto0 = insertelement <4 x float> poison, float %data0, i64 0
108+
%vec.upto1 = insertelement <4 x float> %vec.upto0, float %data1, i64 1
109+
%vec.upto2 = insertelement <4 x float> %vec.upto1, float %data2, i64 2
110+
%vec = insertelement <4 x float> %vec.upto2, float %data3, i64 3
111+
call void @llvm.dx.typedBufferStore(
112+
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer,
113+
i32 %index, <4 x float> %vec)
114+
115+
ret void
116+
}

0 commit comments

Comments
 (0)