Skip to content

[DirectX] Implement the DXILCBufferAccess pass #134571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions llvm/include/llvm/Frontend/HLSL/CBuffer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===- CBuffer.h - HLSL constant buffer handling ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains utilities to work with constant buffers in HLSL.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_FRONTEND_HLSL_CBUFFER_H
#define LLVM_FRONTEND_HLSL_CBUFFER_H

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include <optional>

namespace llvm {
class Module;
class GlobalVariable;
class NamedMDNode;

namespace hlsl {

struct CBufferMember {
GlobalVariable *GV;
size_t Offset;

CBufferMember(GlobalVariable *GV, size_t Offset) : GV(GV), Offset(Offset) {}
};

struct CBufferMapping {
GlobalVariable *Handle;
SmallVector<CBufferMember> Members;

CBufferMapping(GlobalVariable *Handle) : Handle(Handle) {}
};

class CBufferMetadata {
NamedMDNode *MD;
SmallVector<CBufferMapping> Mappings;

CBufferMetadata(NamedMDNode *MD) : MD(MD) {}

public:
static std::optional<CBufferMetadata> get(Module &M);

using iterator = SmallVector<CBufferMapping>::iterator;
iterator begin() { return Mappings.begin(); }
iterator end() { return Mappings.end(); }

void eraseFromModule();
};

APInt translateCBufArrayOffset(const DataLayout &DL, APInt Offset,
ArrayType *Ty);

} // namespace hlsl
} // namespace llvm

#endif // LLVM_FRONTEND_HLSL_CBUFFER_H
71 changes: 71 additions & 0 deletions llvm/lib/Frontend/HLSL/CBuffer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//===- CBuffer.cpp - HLSL constant buffer handling ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/HLSL/CBuffer.h"
#include "llvm/Frontend/HLSL/HLSLResource.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"

using namespace llvm;
using namespace llvm::hlsl;

static size_t getMemberOffset(GlobalVariable *Handle, size_t Index) {
auto *HandleTy = cast<TargetExtType>(Handle->getValueType());
assert(HandleTy->getName().ends_with(".CBuffer") && "Not a cbuffer type");
assert(HandleTy->getNumTypeParameters() == 1 && "Expected layout type");

auto *LayoutTy = cast<TargetExtType>(HandleTy->getTypeParameter(0));
assert(LayoutTy->getName().ends_with(".Layout") && "Not a layout type");

// Skip the "size" parameter.
size_t ParamIndex = Index + 1;
assert(LayoutTy->getNumIntParameters() > ParamIndex &&
"Not enough parameters");

return LayoutTy->getIntParameter(ParamIndex);
}

std::optional<CBufferMetadata> CBufferMetadata::get(Module &M) {
NamedMDNode *CBufMD = M.getNamedMetadata("hlsl.cbs");
if (!CBufMD)
return std::nullopt;

std::optional<CBufferMetadata> Result({CBufMD});

for (const MDNode *MD : CBufMD->operands()) {
assert(MD->getNumOperands() && "Invalid cbuffer metadata");

auto *Handle = cast<GlobalVariable>(
cast<ValueAsMetadata>(MD->getOperand(0))->getValue());
CBufferMapping &Mapping = Result->Mappings.emplace_back(Handle);

for (int I = 1, E = MD->getNumOperands(); I < E; ++I) {
Metadata *OpMD = MD->getOperand(I);
// Some members may be null if they've been optimized out.
if (!OpMD)
continue;
auto *V = cast<GlobalVariable>(cast<ValueAsMetadata>(OpMD)->getValue());
Mapping.Members.emplace_back(V, getMemberOffset(Handle, I - 1));
}
}

return Result;
}

void CBufferMetadata::eraseFromModule() {
// Remove the cbs named metadata
MD->eraseFromParent();
}

APInt hlsl::translateCBufArrayOffset(const DataLayout &DL, APInt Offset,
ArrayType *Ty) {
int64_t TypeSize = DL.getTypeSizeInBits(Ty->getElementType()) / 8;
int64_t RoundUp = alignTo(TypeSize, Align(CBufferRowSizeInBytes));
return Offset.udiv(TypeSize) * RoundUp;
}
1 change: 1 addition & 0 deletions llvm/lib/Frontend/HLSL/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_llvm_component_library(LLVMFrontendHLSL
CBuffer.cpp
HLSLResource.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_llvm_target(DirectXCodeGen
DirectXTargetMachine.cpp
DirectXTargetTransformInfo.cpp
DXContainerGlobals.cpp
DXILCBufferAccess.cpp
DXILDataScalarization.cpp
DXILFinalizeLinkage.cpp
DXILFlattenArrays.cpp
Expand Down
210 changes: 210 additions & 0 deletions llvm/lib/Target/DirectX/DXILCBufferAccess.cpp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the right place for this transformation? I expect that we would want this pass to run for all backends. We would definitely want it for SPIR-V. Could we move it into an HLSL directory as in #134260?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thinking here is that the generic logic that's helpful for all backends should belong in Frontend/HLSL/CBuffer.h, but the pass itself is fairly DirectX specific. While all of the logic to figure out offsets and memory layout is necessary for all targets, the details of what this should transform into aren't necessarily compatible.

For example, the DirectX backend is constrained in that it needs to access the cbuffers via an operation that loads a single 16-byte row, so we need to lower to this series of dx.cbuffer.load.cbufferrow operations and then piece together the data we actually want to load. I can't imagine that this would be the best way to represent this for SPIR-V, where all we really care about is where the object ended up in memory and how it's padded but can use normal load operations from there.

So trying to put all of this in a generic pass that's aware of the various backends and their different target intrinsics feels like it would be wrong.

All that said, I'm not sure if the balance between what's in the pass itself and what's done in lib/Frontend is correct here - we may want to move more stuff to there when we implement a similar change for SPIR-V.

Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
//===- DXILCBufferAccess.cpp - Translate CBuffer Loads --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "DXILCBufferAccess.h"
#include "DirectX.h"
#include "llvm/Frontend/HLSL/CBuffer.h"
#include "llvm/Frontend/HLSL/HLSLResource.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Local.h"

#define DEBUG_TYPE "dxil-cbuffer-access"
using namespace llvm;

namespace {
/// Helper for building a `load.cbufferrow` intrinsic given a simple type.
struct CBufferRowIntrin {
Intrinsic::ID IID;
Type *RetTy;
unsigned int EltSize;
unsigned int NumElts;

CBufferRowIntrin(const DataLayout &DL, Type *Ty) {
assert(Ty == Ty->getScalarType() && "Expected scalar type");

switch (DL.getTypeSizeInBits(Ty)) {
case 16:
IID = Intrinsic::dx_resource_load_cbufferrow_8;
RetTy = StructType::get(Ty, Ty, Ty, Ty, Ty, Ty, Ty, Ty);
EltSize = 2;
NumElts = 8;
break;
case 32:
IID = Intrinsic::dx_resource_load_cbufferrow_4;
RetTy = StructType::get(Ty, Ty, Ty, Ty);
EltSize = 4;
NumElts = 4;
break;
case 64:
IID = Intrinsic::dx_resource_load_cbufferrow_2;
RetTy = StructType::get(Ty, Ty);
EltSize = 8;
NumElts = 2;
break;
default:
llvm_unreachable("Only 16, 32, and 64 bit types supported");
}
}
};
} // namespace

static size_t getOffsetForCBufferGEP(GEPOperator *GEP, GlobalVariable *Global,
const DataLayout &DL) {
// Since we should always have a constant offset, we should only ever have a
// single GEP of indirection from the Global.
assert(GEP->getPointerOperand() == Global &&
"Indirect access to resource handle");

APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset);
(void)Success;
assert(Success && "Offsets into cbuffer globals must be constant");

if (auto *ATy = dyn_cast<ArrayType>(Global->getValueType()))
ConstantOffset = hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);

return ConstantOffset.getZExtValue();
}

/// Replace access via cbuffer global with a load from the cbuffer handle
/// itself.
static void replaceAccess(LoadInst *LI, GlobalVariable *Global,
GlobalVariable *HandleGV, size_t BaseOffset,
SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
const DataLayout &DL = HandleGV->getDataLayout();

size_t Offset = BaseOffset;
if (auto *GEP = dyn_cast<GEPOperator>(LI->getPointerOperand()))
Offset += getOffsetForCBufferGEP(GEP, Global, DL);
else if (LI->getPointerOperand() != Global)
llvm_unreachable("Load instruction doesn't reference cbuffer global");

IRBuilder<> Builder(LI);
auto *Handle = Builder.CreateLoad(HandleGV->getValueType(), HandleGV,
HandleGV->getName());

Type *Ty = LI->getType();
CBufferRowIntrin Intrin(DL, Ty->getScalarType());
// The cbuffer consists of some number of 16-byte rows.
unsigned int CurrentRow = Offset / hlsl::CBufferRowSizeInBytes;
unsigned int CurrentIndex =
(Offset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize;

auto *CBufLoad = Builder.CreateIntrinsic(
Intrin.RetTy, Intrin.IID,
{Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
LI->getName());
auto *Elt =
Builder.CreateExtractValue(CBufLoad, {CurrentIndex++}, LI->getName());

Value *Result = nullptr;
unsigned int Remaining =
((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
if (Remaining == 0) {
// We only have a single element, so we're done.
Result = Elt;

// However, if we loaded a <1 x T>, then we need to adjust the type here.
if (auto *VT = dyn_cast<FixedVectorType>(LI->getType())) {
assert(VT->getNumElements() == 1 && "Can't have multiple elements here");
Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
Builder.getInt32(0));
}
} else {
// Walk each element and extract it, wrapping to new rows as needed.
SmallVector<Value *> Extracts{Elt};
while (Remaining--) {
CurrentIndex %= Intrin.NumElts;

if (CurrentIndex == 0)
CBufLoad = Builder.CreateIntrinsic(
Intrin.RetTy, Intrin.IID,
{Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
nullptr, LI->getName());

Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
LI->getName()));
}

// Finally, we build up the original loaded value.
Result = PoisonValue::get(Ty);
for (int I = 0, E = Extracts.size(); I < E; ++I)
Result =
Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I));
}

LI->replaceAllUsesWith(Result);
DeadInsts.push_back(LI);
}

static void replaceAccessesWithHandle(GlobalVariable *Global,
GlobalVariable *HandleGV,
size_t BaseOffset) {
SmallVector<WeakTrackingVH> DeadInsts;

SmallVector<User *> ToProcess{Global->users()};
while (!ToProcess.empty()) {
User *Cur = ToProcess.pop_back_val();

// If we have a load instruction, replace the access.
if (auto *LI = dyn_cast<LoadInst>(Cur)) {
replaceAccess(LI, Global, HandleGV, BaseOffset, DeadInsts);
continue;
}

// Otherwise, walk users looking for a load...
ToProcess.append(Cur->user_begin(), Cur->user_end());
}
RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);
}

static bool replaceCBufferAccesses(Module &M) {
std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
if (!CBufMD)
return false;

for (const hlsl::CBufferMapping &Mapping : *CBufMD)
for (const hlsl::CBufferMember &Member : Mapping.Members) {
replaceAccessesWithHandle(Member.GV, Mapping.Handle, Member.Offset);
Member.GV->removeFromParent();
}

CBufMD->eraseFromModule();
return true;
}

PreservedAnalyses DXILCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) {
PreservedAnalyses PA;
bool Changed = replaceCBufferAccesses(M);

if (!Changed)
return PreservedAnalyses::all();
return PA;
}

namespace {
class DXILCBufferAccessLegacy : public ModulePass {
public:
bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); }
StringRef getPassName() const override { return "DXIL CBuffer Access"; }
DXILCBufferAccessLegacy() : ModulePass(ID) {}

static char ID; // Pass identification.
};
char DXILCBufferAccessLegacy::ID = 0;
} // end anonymous namespace

INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access",
false, false)

ModulePass *llvm::createDXILCBufferAccessLegacyPass() {
return new DXILCBufferAccessLegacy();
}
28 changes: 28 additions & 0 deletions llvm/lib/Target/DirectX/DXILCBufferAccess.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- DXILCBufferAccess.h - Translate CBuffer Loads ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file Pass for replacing loads from cbuffers in the cbuffer address space to
// cbuffer load intrinsics.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
#define LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H

#include "llvm/IR/PassManager.h"

namespace llvm {

class DXILCBufferAccess : public PassInfoMixin<DXILCBufferAccess> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
};

} // namespace llvm

#endif // LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
Loading
Loading