Skip to content

[HLSL] Raise Diag for Invalid CounterDirection #137697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion llvm/include/llvm/Analysis/DXILResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,11 +451,14 @@ ModulePass *createDXILResourceTypeWrapperPassPass();
//===----------------------------------------------------------------------===//

class DXILResourceMap {
using CallMapTy = DenseMap<CallInst *, unsigned>;

SmallVector<dxil::ResourceInfo> Infos;
DenseMap<CallInst *, unsigned> CallMap;
CallMapTy CallMap;
unsigned FirstUAV = 0;
unsigned FirstCBuffer = 0;
unsigned FirstSampler = 0;
bool HasInvalidDirection = false;

/// Populate all the resource instance data.
void populate(Module &M, DXILResourceTypeMap &DRTM);
Expand Down Expand Up @@ -532,6 +535,23 @@ class DXILResourceMap {
return make_range(sampler_begin(), sampler_end());
}

struct call_iterator
: iterator_adaptor_base<call_iterator, CallMapTy::iterator> {
call_iterator() = default;
call_iterator(CallMapTy::iterator Iter)
: call_iterator::iterator_adaptor_base(std::move(Iter)) {}

CallInst *operator*() const { return I->first; }
};

call_iterator call_begin() { return call_iterator(CallMap.begin()); }
call_iterator call_end() { return call_iterator(CallMap.end()); }
iterator_range<call_iterator> calls() {
return make_range(call_begin(), call_end());
}

bool hasInvalidCounterDirection() const { return HasInvalidDirection; }

void print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
const DataLayout &DL) const;

Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Analysis/DXILResource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,10 @@ void DXILResourceMap::populateCounterDirections(Module &M) {
for (ResourceInfo *RBInfo : RBInfos) {
if (RBInfo->CounterDirection == ResourceCounterDirection::Unknown)
RBInfo->CounterDirection = Direction;
else if (RBInfo->CounterDirection != Direction)
else if (RBInfo->CounterDirection != Direction) {
RBInfo->CounterDirection = ResourceCounterDirection::Invalid;
HasInvalidDirection = true;
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_llvm_target(DirectXCodeGen
DXILIntrinsicExpansion.cpp
DXILOpBuilder.cpp
DXILOpLowering.cpp
DXILPostOptimizationValidation.cpp
DXILPrepare.cpp
DXILPrettyPrinter.cpp
DXILResourceAccess.cpp
Expand Down
102 changes: 102 additions & 0 deletions llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//===- DXILPostOptimizationValidation.cpp - Opt DXIL validation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "DXILPostOptimizationValidation.h"
#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"

#define DEBUG_TYPE "dxil-post-optimization-validation"

using namespace llvm;
using namespace llvm::dxil;

namespace {

static void reportInvalidDirection(Module &M, DXILResourceMap &DRM) {
for (const auto &UAV : DRM.uavs()) {
if (UAV.CounterDirection != ResourceCounterDirection::Invalid)
continue;

CallInst *ResourceHandle = nullptr;
for (CallInst *MaybeHandle : DRM.calls()) {
if (*DRM.find(MaybeHandle) == UAV) {
ResourceHandle = MaybeHandle;
break;
}
}

StringRef Message = "RWStructuredBuffers may increment or decrement their "
"counters, but not both.";
for (const auto &U : ResourceHandle->users()) {
const CallInst *CI = dyn_cast<CallInst>(U);
if (!CI && CI->getIntrinsicID() != Intrinsic::dx_resource_updatecounter)
continue;

M.getContext().diagnose(DiagnosticInfoGenericWithLoc(
Message, *CI->getFunction(), CI->getDebugLoc()));
}
}
}

} // namespace

PreservedAnalyses
DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {
DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);

if (DRM.hasInvalidCounterDirection())
reportInvalidDirection(M, DRM);

return PreservedAnalyses::all();
}

namespace {
class DXILPostOptimizationValidationLegacy : public ModulePass {
public:
bool runOnModule(Module &M) override {
DXILResourceMap &DRM =
getAnalysis<DXILResourceWrapperPass>().getResourceMap();

if (DRM.hasInvalidCounterDirection())
reportInvalidDirection(M, DRM);

return false;
}
StringRef getPassName() const override {
return "DXIL Post Optimization Validation";
}
DXILPostOptimizationValidationLegacy() : ModulePass(ID) {}

static char ID; // Pass identification.
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
AU.addRequired<DXILResourceWrapperPass>();
AU.addPreserved<DXILResourceWrapperPass>();
AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
AU.addPreserved<ShaderFlagsAnalysisWrapper>();
}
};
char DXILPostOptimizationValidationLegacy::ID = 0;
} // end anonymous namespace

INITIALIZE_PASS_BEGIN(DXILPostOptimizationValidationLegacy, DEBUG_TYPE,
"DXIL Post Optimization Validation", false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
INITIALIZE_PASS_END(DXILPostOptimizationValidationLegacy, DEBUG_TYPE,
"DXIL Post Optimization Validation", false, false)

ModulePass *llvm::createDXILPostOptimizationValidationLegacyPass() {
return new DXILPostOptimizationValidationLegacy();
}
29 changes: 29 additions & 0 deletions llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//===- DXILPostOptimizationValidation.h - Opt DXIL Validations -*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file Pass for validating IR after optimizations are applied and before
// lowering to DXIL.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H
#define LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H

#include "llvm/IR/PassManager.h"

namespace llvm {

class DXILPostOptimizationValidation
: public PassInfoMixin<DXILPostOptimizationValidation> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
};

} // namespace llvm

#endif // LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H
6 changes: 6 additions & 0 deletions llvm/lib/Target/DirectX/DirectX.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ ModulePass *createDXILPrettyPrinterLegacyPass(raw_ostream &OS);
/// Initializer for DXILPrettyPrinter.
void initializeDXILPrettyPrinterLegacyPass(PassRegistry &);

/// Initializer for DXILPostOptimizationValidation.
void initializeDXILPostOptimizationValidationLegacyPass(PassRegistry &);

/// Pass to lowering LLVM intrinsic call to DXIL op function call.
ModulePass *createDXILPostOptimizationValidationLegacyPass();

/// Initializer for dxil::ShaderFlagsAnalysisWrapper pass.
void initializeShaderFlagsAnalysisWrapperPass(PassRegistry &);

Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/DirectXPassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ MODULE_PASS("dxil-intrinsic-expansion", DXILIntrinsicExpansion())
MODULE_PASS("dxil-op-lower", DXILOpLowering())
MODULE_PASS("dxil-pretty-printer", DXILPrettyPrinterPass(dbgs()))
MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
MODULE_PASS("dxil-post-optimization-validation", DXILPostOptimizationValidation())
// TODO: rename to print<foo> after NPM switch
MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbgs()))
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "DXILIntrinsicExpansion.h"
#include "DXILLegalizePass.h"
#include "DXILOpLowering.h"
#include "DXILPostOptimizationValidation.h"
#include "DXILPrettyPrinter.h"
#include "DXILResourceAccess.h"
#include "DXILRootSignature.h"
Expand Down Expand Up @@ -63,6 +64,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
initializeDXILOpLoweringLegacyPass(*PR);
initializeDXILResourceAccessLegacyPass(*PR);
initializeDXILTranslateMetadataLegacyPass(*PR);
initializeDXILPostOptimizationValidationLegacyPass(*PR);
initializeShaderFlagsAnalysisWrapperPass(*PR);
initializeRootSignatureAnalysisWrapperPass(*PR);
initializeDXILFinalizeLinkageLegacyPass(*PR);
Expand Down Expand Up @@ -110,6 +112,7 @@ class DirectXPassConfig : public TargetPassConfig {
addPass(createDXILForwardHandleAccessesLegacyPass());
addPass(createDXILLegalizeLegacyPass());
addPass(createDXILTranslateMetadataLegacyPass());
addPass(createDXILPostOptimizationValidationLegacyPass());
addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILPrepareModulePass());
}
Expand Down
1 change: 1 addition & 0 deletions llvm/test/CodeGen/DirectX/llc-pipeline.ll
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
; CHECK-NEXT: DXIL Module Metadata analysis
; CHECK-NEXT: DXIL Shader Flag Analysis
; CHECK-NEXT: DXIL Translate Metadata
; CHECK-NEXT: DXIL Post Optimization Validation
; CHECK-NEXT: DXIL Op Lowering
; CHECK-NEXT: DXIL Prepare Module

Expand Down
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/resource_counter_error.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: not opt -S -passes='dxil-post-optimization-validation' -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
; CHECK: RWStructuredBuffers may increment or decrement their counters, but not both.

define void @inc_and_dec() {
entry:
%handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 1)
ret void
}
Loading