Skip to content

[DirectX] Legalize i8 allocas #137399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 101 additions & 12 deletions llvm/lib/Target/DirectX/DXILLegalizePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <functional>
Expand All @@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I,
ToRemove.push_back(FI);
}

static void fixI8TruncUseChain(Instruction &I,
SmallVectorImpl<Instruction *> &ToRemove,
DenseMap<Value *, Value *> &ReplacedValues) {
static void fixI8UseChain(Instruction &I,
SmallVectorImpl<Instruction *> &ToRemove,
DenseMap<Value *, Value *> &ReplacedValues) {

auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
Type *InstrType = IntegerType::get(I.getContext(), 32);

for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
Value *Op = I.getOperand(OpIdx);
if (ReplacedValues.count(Op))
if (ReplacedValues.count(Op) &&
ReplacedValues[Op]->getType()->isIntegerTy())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: this check is required so that we don't replace the type of a store

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, specifically Load/Store had ptr types, and I wanted to exclude ptrs from the replacement value tracking. It also didn't make sense to consider any other types, like fp.

InstrType = ReplacedValues[Op]->getType();
}

Expand Down Expand Up @@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I,
}
}

if (auto *Store = dyn_cast<StoreInst>(&I)) {
if (!Store->getValueOperand()->getType()->isIntegerTy(8))
return;
SmallVector<Value *> NewOperands;
ProcessOperands(NewOperands);
Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]);
ReplacedValues[Store] = NewStore;
ToRemove.push_back(Store);
return;
}

if (auto *Load = dyn_cast<LoadInst>(&I)) {
if (!I.getType()->isIntegerTy(8))
return;
SmallVector<Value *> NewOperands;
ProcessOperands(NewOperands);
Type *ElementType = NewOperands[0]->getType();
if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0]))
ElementType = AI->getAllocatedType();
LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
ReplacedValues[Load] = NewLoad;
ToRemove.push_back(Load);
return;
}

if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
if (!I.getType()->isIntegerTy(8))
return;
Expand All @@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I,
Value *NewInst =
Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
if (OBO->hasNoSignedWrap())
cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
if (OBO->hasNoUnsignedWrap())
cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can this ever fail? Maybe we can just assert it? If it does fail should we be early exiting?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If arg 0 and arg 1 are both constants (int/fp), the Value gets converted to a Constant, so you can't dyn_cast it back to a BinaryOperator. That is why we can't assert.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we be casting to OverflowingBinaryOperator here anyway? It's in the Operator class hierarchy that abstracts over instructions and constant exprs.

It is pretty surprising / confusing that BinaryOperator is not in this hierarchy, but I guess that naming is a bit of a historical accident.

Copy link
Member Author

@farzonl farzonl Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we be casting to OverflowingBinaryOperator here anyway

OverflowingBinaryOperator does not expose the setters we need as public members. That's why we are casting back to BinaryOperator. But when the cast to BinaryOperator fails, we can't call setHasNoSignedWrap or setHasNoUnsignedWrap.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I hadn't realized that setHasNoUnsignedWrap is private in OverflowingBinaryOperator and got confused. Thanks for double checking.

if (NewBO && OBO->hasNoSignedWrap())
NewBO->setHasNoSignedWrap();
if (NewBO && OBO->hasNoUnsignedWrap())
NewBO->setHasNoUnsignedWrap();
}
ReplacedValues[BO] = NewInst;
ToRemove.push_back(BO);
return;
}

if (auto *Sel = dyn_cast<SelectInst>(&I)) {
if (!I.getType()->isIntegerTy(8))
return;
SmallVector<Value *> NewOperands;
ProcessOperands(NewOperands);
Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1],
NewOperands[2]);
ReplacedValues[Sel] = NewInst;
ToRemove.push_back(Sel);
return;
}

if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
return;
Expand All @@ -105,13 +145,61 @@ static void fixI8TruncUseChain(Instruction &I,
}

if (auto *Cast = dyn_cast<CastInst>(&I)) {
if (Cast->getSrcTy()->isIntegerTy(8)) {
ToRemove.push_back(Cast);
Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
if (!Cast->getSrcTy()->isIntegerTy(8))
return;

ToRemove.push_back(Cast);
auto *Replacement = ReplacedValues[Cast->getOperand(0)];
if (Cast->getType() == Replacement->getType()) {
Cast->replaceAllUsesWith(Replacement);
return;
}
Value *AdjustedCast = nullptr;
if (Cast->getOpcode() == Instruction::ZExt)
AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType());
if (Cast->getOpcode() == Instruction::SExt)
AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType());

if (AdjustedCast)
Cast->replaceAllUsesWith(AdjustedCast);
}
}

/// Legalization step for `alloca i8`.
///
/// If \p I is an alloca whose allocated type is i8, look at every load of that
/// alloca and at every cast instruction that consumes such a load. The cast
/// result type with the fewest primitive bits is chosen as the new allocated
/// type (on a tie, the first one encountered is kept). A new alloca of that
/// type is created with a builder positioned at the original alloca, recorded
/// in \p ReplacedValues under the original alloca, and the original alloca is
/// queued in \p ToRemove.
///
/// Nothing is changed when \p I is not an i8 alloca, or when no
/// alloca -> load -> cast chain exists.
///
/// \param I              Instruction being visited.
/// \param ToRemove       Receives the original alloca when it is replaced.
/// \param ReplacedValues Receives the mapping original alloca -> new alloca.
static void upcastI8AllocasAndUses(Instruction &I,
                                   SmallVectorImpl<Instruction *> &ToRemove,
                                   DenseMap<Value *, Value *> &ReplacedValues) {
  auto *Alloca = dyn_cast<AllocaInst>(&I);
  if (!Alloca)
    return;
  if (!Alloca->getAllocatedType()->isIntegerTy(8))
    return;

  // Walk alloca -> load -> cast chains and remember the narrowest cast
  // destination type seen so far.
  Type *NarrowestTy = nullptr;
  for (User *AllocaUser : Alloca->users()) {
    if (auto *LoadedValue = dyn_cast<LoadInst>(AllocaUser)) {
      for (User *LoadUser : LoadedValue->users()) {
        if (auto *CastUser = dyn_cast<CastInst>(LoadUser)) {
          Type *CastTy = CastUser->getType();
          if (!NarrowestTy || CastTy->getPrimitiveSizeInBits() <
                                  NarrowestTy->getPrimitiveSizeInBits())
            NarrowestTy = CastTy;
        }
      }
    }
  }

  // Without any cast of a loaded value there is no target type to widen to.
  if (NarrowestTy == nullptr)
    return;

  // Create the replacement alloca; the old one is only scheduled for removal
  // here, it is not erased by this function.
  IRBuilder<> Builder(Alloca);
  ReplacedValues[Alloca] = Builder.CreateAlloca(NarrowestTy);
  ToRemove.push_back(Alloca);
}

static void
downcastI64toI32InsertExtractElements(Instruction &I,
SmallVectorImpl<Instruction *> &ToRemove,
Expand Down Expand Up @@ -178,7 +266,8 @@ class DXILLegalizationPipeline {
LegalizationPipeline;

void initializeLegalizationPipeline() {
LegalizationPipeline.push_back(fixI8TruncUseChain);
LegalizationPipeline.push_back(upcastI8AllocasAndUses);
LegalizationPipeline.push_back(fixI8UseChain);
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
LegalizationPipeline.push_back(legalizeFreeze);
}
Expand Down
99 changes: 99 additions & 0 deletions llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s

; Test: a constant i8 (1) is stored to an `alloca i8`, loaded back, and
; zero-extended to i32 before being stored into %accum.i.flat.
; Expected: the alloca, the store and the load are all rewritten to i32, and
; no zext remains (the i32 load feeds the final store directly).
define void @const_i8_store() {
; CHECK-LABEL: define void @const_i8_store() {
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
; CHECK-NEXT: store i32 1, ptr [[TMP1]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4
; CHECK-NEXT: ret void
;
%accum.i.flat = alloca [1 x i32], align 4
%i = alloca i8, align 4
store i8 1, ptr %i
%i8.load = load i8, ptr %i
%z = zext i8 %i8.load to i32
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
store i32 %z, ptr %gep, align 4
ret void
}

; Test: the value stored to the `alloca i8` comes from an i8 `add nsw` whose
; operands are both constants (3 and 1).
; Expected: no add instruction appears in the output; the rewritten store
; writes the constant `i32 4` to the i32 alloca, and the load is i32 with no
; zext remaining.
define void @const_add_i8_store() {
; CHECK-LABEL: define void @const_add_i8_store() {
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
; CHECK-NEXT: store i32 4, ptr [[TMP1]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4
; CHECK-NEXT: ret void
;
%accum.i.flat = alloca [1 x i32], align 4
%i = alloca i8, align 4
%add_i8 = add nsw i8 3, 1
store i8 %add_i8, ptr %i
%i8.load = load i8, ptr %i
%z = zext i8 %i8.load to i32
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
store i32 %z, ptr %gep, align 4
ret void
}

; Test: the value stored to the `alloca i8` is a non-constant i8 produced by a
; select on the i1 argument %cmp.i8.
; Expected: the select is rewritten to produce i32 (constants 1 and 2 become
; i32), the alloca/store/load become i32, and no zext remains.
define void @var_i8_store(i1 %cmp.i8) {
; CHECK-LABEL: define void @var_i8_store(
; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[TMP1]], align 4
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP]], align 4
; CHECK-NEXT: ret void
;
%accum.i.flat = alloca [1 x i32], align 4
%i = alloca i8, align 4
%select.i8 = select i1 %cmp.i8, i8 1, i8 2
store i8 %select.i8, ptr %i
%i8.load = load i8, ptr %i
%z = zext i8 %i8.load to i32
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
store i32 %z, ptr %gep, align 4
ret void
}

; Test: the same i8 load is zero-extended to two different widths (i16 via %z
; and i32 via %z2), so the cast targets disagree.
; Expected: the alloca and the load take the narrower type (i16); the zext to
; i16 disappears, and the zext to i32 is re-emitted as `zext i16 ... to i32`.
; NOTE(review): the expected output stores the i32 select result into the
; `alloca i16` slot ([[TMP1]]) — a 4-byte store to what appears to be a 2-byte
; allocation. Confirm this is intended and not an out-of-bounds write produced
; by the pass.
define void @conflicting_cast(i1 %cmp.i8) {
; CHECK-LABEL: define void @conflicting_cast(
; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x i32], align 4
; CHECK-NEXT: [[TMP1:%.*]] = alloca i16, align 2
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr [[TMP1]], align 2
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 0
; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP1]], align 2
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 1
; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP2]], align 2
; CHECK-NEXT: [[TMP4:%.*]] = zext i16 [[TMP3]] to i32
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 1
; CHECK-NEXT: store i32 [[TMP4]], ptr [[GEP3]], align 4
; CHECK-NEXT: ret void
;
%accum.i.flat = alloca [2 x i32], align 4
%i = alloca i8, align 4
%select.i8 = select i1 %cmp.i8, i8 1, i8 2
store i8 %select.i8, ptr %i
%i8.load = load i8, ptr %i
%z = zext i8 %i8.load to i16
%gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
store i16 %z, ptr %gep1, align 2
%gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
store i16 %z, ptr %gep2, align 2
%z2 = zext i8 %i8.load to i32
%gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
store i32 %z2, ptr %gep3, align 4
ret void
}
Loading