[DirectX] Legalize i8 allocas #137399

Merged
merged 3 commits into from
Apr 29, 2025
farzonl (Member) commented Apr 25, 2025

Fixes #137202.

While investigating i8 allocas, I found that our i8 legalization was missing handling for load, store, and select instructions, so this change adds those three.

To legalize i8 allocas correctly, though, we need to walk the alloca's uses and find the casts.

After finding the casts, I chose the smallest cast type as the type to transform the alloca to. That lets me preserve the larger casts that come later.
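
The "pick the smallest cast" rule can be sketched outside of LLVM as follows. This is only an illustrative model, not the pass itself: `CastUse`, `pickSmallestCastWidth`, and `needsExtension` are hypothetical names, and a cast is reduced to its destination bit width.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a cast instruction that consumes a load of the
// i8 alloca. Only the destination bit width matters for this sketch.
struct CastUse {
  unsigned DestBits;
};

// Returns the narrowest destination width among the casts, or 0 when there
// are no casts at all (in which case the alloca would be left untouched).
unsigned pickSmallestCastWidth(const std::vector<CastUse> &Casts) {
  unsigned Smallest = 0;
  for (const CastUse &C : Casts) {
    assert(C.DestBits > 0 && "cast destination widths are positive");
    if (Smallest == 0 || C.DestBits < Smallest)
      Smallest = C.DestBits;
  }
  return Smallest;
}

// A cast whose destination equals the new alloca width becomes redundant;
// a wider cast is preserved as an extension from the new width.
bool needsExtension(unsigned NewAllocaBits, const CastUse &C) {
  return C.DestBits > NewAllocaBits;
}
```

With casts to i16 and i32, as in the `conflicting_cast` test in the diff, this model picks 16 bits for the alloca and keeps only the i32 cast as an extension.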

llvmbot (Member) commented Apr 25, 2025

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/137399.diff

2 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXILLegalizePass.cpp (+102-12)
  • (added) llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll (+99)
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index b62ff4c52f70c..b7b209fcecbc9 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -12,6 +12,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/Pass.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include <functional>
@@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I,
   ToRemove.push_back(FI);
 }
 
-static void fixI8TruncUseChain(Instruction &I,
-                               SmallVectorImpl<Instruction *> &ToRemove,
-                               DenseMap<Value *, Value *> &ReplacedValues) {
+static void fixI8UseChain(Instruction &I,
+                          SmallVectorImpl<Instruction *> &ToRemove,
+                          DenseMap<Value *, Value *> &ReplacedValues) {
 
   auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
     Type *InstrType = IntegerType::get(I.getContext(), 32);
 
     for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
       Value *Op = I.getOperand(OpIdx);
-      if (ReplacedValues.count(Op))
+      if (ReplacedValues.count(Op) &&
+          ReplacedValues[Op]->getType()->isIntegerTy())
         InstrType = ReplacedValues[Op]->getType();
     }
 
@@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I,
     }
   }
 
+  if (auto *Store = dyn_cast<StoreInst>(&I)) {
+    if (!Store->getValueOperand()->getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]);
+    ReplacedValues[Store] = NewStore;
+    ToRemove.push_back(Store);
+    return;
+  }
+
+  if (auto *Load = dyn_cast<LoadInst>(&I)) {
+    if (!I.getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    Type *ElementType = NewOperands[0]->getType();
+    if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0]))
+      ElementType = AI->getAllocatedType();
+    LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
+    ReplacedValues[Load] = NewLoad;
+    ToRemove.push_back(Load);
+    return;
+  }
+
   if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
     if (!I.getType()->isIntegerTy(8))
       return;
@@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I,
     Value *NewInst =
         Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
     if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
-      if (OBO->hasNoSignedWrap())
-        cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
-      if (OBO->hasNoUnsignedWrap())
-        cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
+      auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
+      if (NewBO && OBO->hasNoSignedWrap())
+        NewBO->setHasNoSignedWrap();
+      if (NewBO && OBO->hasNoUnsignedWrap())
+        NewBO->setHasNoUnsignedWrap();
     }
     ReplacedValues[BO] = NewInst;
     ToRemove.push_back(BO);
     return;
   }
 
+  if (auto *Sel = dyn_cast<SelectInst>(&I)) {
+    if (!I.getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1],
+                                          NewOperands[2]);
+    ReplacedValues[Sel] = NewInst;
+    ToRemove.push_back(Sel);
+    return;
+  }
+
   if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
     if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
       return;
@@ -105,13 +145,62 @@ static void fixI8TruncUseChain(Instruction &I,
   }
 
   if (auto *Cast = dyn_cast<CastInst>(&I)) {
-    if (Cast->getSrcTy()->isIntegerTy(8)) {
-      ToRemove.push_back(Cast);
-      Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
+    if (!Cast->getSrcTy()->isIntegerTy(8))
+      return;
+
+    ToRemove.push_back(Cast);
+    auto *Replacement = ReplacedValues[Cast->getOperand(0)];
+    if (Cast->getType() == Replacement->getType()) {
+      Cast->replaceAllUsesWith(Replacement);
+      return;
     }
+    Value *AdjustedCast = nullptr;
+    if (Cast->getOpcode() == Instruction::ZExt)
+      AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType());
+    if (Cast->getOpcode() == Instruction::SExt)
+      AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType());
+
+    if (AdjustedCast)
+      Cast->replaceAllUsesWith(AdjustedCast);
   }
 }
 
+static void upcastI8AllocasAndUses(Instruction &I,
+                                   SmallVectorImpl<Instruction *> &ToRemove,
+                                   DenseMap<Value *, Value *> &ReplacedValues) {
+  auto *AI = dyn_cast<AllocaInst>(&I);
+  if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
+    return;
+
+  Type *SmallestType = nullptr;
+
+  // Gather all cast targets
+  for (User *U : AI->users()) {
+    auto *Load = dyn_cast<LoadInst>(U);
+    if (!Load)
+      continue;
+    for (User *LU : Load->users()) {
+      auto *Cast = dyn_cast<CastInst>(LU);
+      if (!Cast)
+        continue;
+      Type *Ty = Cast->getType();
+      if (!SmallestType ||
+          Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
+        SmallestType = Ty;
+    }
+  }
+
+  if (!SmallestType)
+    return; // no valid casts found
+
+  // Replace alloca
+  IRBuilder<> Builder(AI);
+  auto *NewAlloca =
+      Builder.CreateAlloca(SmallestType);
+  ReplacedValues[AI] = NewAlloca;
+  ToRemove.push_back(AI);
+}
+
 static void
 downcastI64toI32InsertExtractElements(Instruction &I,
                                       SmallVectorImpl<Instruction *> &ToRemove,
@@ -178,7 +267,8 @@ class DXILLegalizationPipeline {
       LegalizationPipeline;
 
   void initializeLegalizationPipeline() {
-    LegalizationPipeline.push_back(fixI8TruncUseChain);
+    LegalizationPipeline.push_back(upcastI8AllocasAndUses);
+    LegalizationPipeline.push_back(fixI8UseChain);
     LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
     LegalizationPipeline.push_back(legalizeFreeze);
   }
diff --git a/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
new file mode 100644
index 0000000000000..529a69fca5d34
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define void @const_i8_store() {
+; CHECK-LABEL: define void @const_i8_store() {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    store i32 1, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  %i = alloca i8, align 4
+  store i8 1, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i32
+  %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+  store i32 %z, ptr %gep, align 4
+  ret void
+}
+
+define void @const_add_i8_store() {
+; CHECK-LABEL: define void @const_add_i8_store() {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    store i32 4, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  %i = alloca i8, align 4
+  %add_i8 = add nsw i8 3, 1
+  store i8 %add_i8, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i32
+  %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+  store i32 %z, ptr %gep, align 4
+  ret void
+}
+
+define void @var_i8_store(i1 %cmp.i8) {
+; CHECK-LABEL: define void @var_i8_store(
+; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 [[TMP3]], ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  %i = alloca i8, align 4
+  %select.i8 = select i1 %cmp.i8, i8 1, i8 2
+  store i8 %select.i8, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i32
+  %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+  store i32 %z, ptr %gep, align 4
+  ret void
+}
+
+define void @conflicting_cast(i1 %cmp.i8) {
+; CHECK-LABEL: define void @conflicting_cast(
+; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [2 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i16, align 2
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = load i16, ptr [[TMP1]], align 2
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i16 [[TMP3]], ptr [[GEP1]], align 2
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT:    store i16 [[TMP3]], ptr [[GEP2]], align 2
+; CHECK-NEXT:    [[TMP4:%.*]] = zext i16 [[TMP3]] to i32
+; CHECK-NEXT:    [[GEP3:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT:    store i32 [[TMP4]], ptr [[GEP3]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [2 x i32], align 4
+  %i = alloca i8, align 4
+  %select.i8 = select i1 %cmp.i8, i8 1, i8 2
+  store i8 %select.i8, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i16
+  %gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
+  store i16 %z, ptr %gep1, align 2
+  %gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
+  store i16 %z, ptr %gep2, align 2
+  %z2 = zext i8 %i8.load to i32
+  %gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
+  store i32 %z2, ptr %gep3, align 4
+  ret void
+}

github-actions bot commented Apr 25, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

farzonl force-pushed the bugfix/issue-137202 branch from f09d663 to 6e3a11d on April 28, 2025 20:04
farzonl self-assigned this Apr 29, 2025

   auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
     Type *InstrType = IntegerType::get(I.getContext(), 32);
 
     for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
       Value *Op = I.getOperand(OpIdx);
-      if (ReplacedValues.count(Op))
+      if (ReplacedValues.count(Op) &&
+          ReplacedValues[Op]->getType()->isIntegerTy())
Contributor

note: this check is required so that we don't replace the type of a store

Member Author

Yes, specifically: Load/Store have ptr operands, and I wanted to exclude pointers from the replacement value tracking. It also didn't make sense to consider any other types, like fp.
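
The effect of that integer-only check can be modeled with a small standalone sketch. It is not the LLVM code: `Ty`, `pickWidenedBits`, and the operand names are hypothetical, and types are reduced to "integer of N bits" versus "anything else".

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical, simplified type: an integer of a given width, or a
// non-integer such as a pointer (Bits is ignored in that case).
struct Ty {
  bool IsInteger;
  unsigned Bits;
};

// Chooses the width to widen an instruction's i8 operands to. The default
// is 32 bits; only integer-typed replacements may override it, so a
// replaced pointer operand (for example a rewritten alloca feeding a store)
// never changes the chosen width.
unsigned pickWidenedBits(const std::vector<std::string> &Operands,
                         const std::map<std::string, Ty> &Replaced) {
  unsigned Bits = 32;
  for (const std::string &Op : Operands) {
    auto It = Replaced.find(Op);
    if (It != Replaced.end() && It->second.IsInteger)
      Bits = It->second.Bits;
  }
  return Bits;
}
```

For a store whose value operand was replaced by an i16 and whose pointer operand was replaced by a new alloca, this model returns 16, since the pointer entry is skipped.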

-        cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
-      if (OBO->hasNoUnsignedWrap())
-        cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
+      auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
Contributor

nit: Can this ever fail? Maybe we can just assert it? If it does fail should we be early exiting?

Member Author

If arg 0 and arg 1 are both constants (int/fp), the builder folds the operation, so the resulting Value is a Constant and you can't dyn_cast it back to a BinaryOperator. So we can't assert.
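
A toy model of that situation, using plain C++ `dynamic_cast` in place of LLVM's `dyn_cast`, might look like the sketch below. All of the types and functions here (`Value`, `Constant`, `Argument`, `BinaryOp`, `createAdd`, `tryPropagateNSW`) are hypothetical miniatures for illustration, not LLVM's classes.

```cpp
#include <cassert>
#include <memory>

// Hypothetical miniature of a value hierarchy.
struct Value {
  virtual ~Value() = default;
};
struct Constant : Value {
  int Val;
  explicit Constant(int V) : Val(V) {}
};
struct Argument : Value {}; // a non-constant input
struct BinaryOp : Value {
  bool NoSignedWrap = false;
};

// Behaves like a folding builder: two constant operands yield a Constant,
// anything else yields an instruction.
std::unique_ptr<Value> createAdd(const Value &L, const Value &R) {
  const auto *CL = dynamic_cast<const Constant *>(&L);
  const auto *CR = dynamic_cast<const Constant *>(&R);
  if (CL && CR)
    return std::make_unique<Constant>(CL->Val + CR->Val);
  return std::make_unique<BinaryOp>();
}

// Sets the flag only when the result really is an instruction. Returns
// false when the add was folded away, which is why an assert would fire.
bool tryPropagateNSW(Value &New) {
  if (auto *BO = dynamic_cast<BinaryOp *>(&New)) {
    BO->NoSignedWrap = true;
    return true;
  }
  return false;
}
```

In this model, adding the constants 3 and 1 produces a `Constant` holding 4 (mirroring `store i32 4` in the `const_add_i8_store` test), so the checked cast yields null and the flag is simply skipped.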

Contributor

Shouldn't we be casting to OverflowingBinaryOperator here anyway? It's in the Operator class hierarchy that abstracts over instructions and constant exprs.

It is pretty surprising / confusing that BinaryOperator is not in this hierarchy, but I guess that naming is a bit of a historical accident.

farzonl (Member Author), Apr 29, 2025

> Shouldn't we be casting to OverflowingBinaryOperator here anyway

OverflowingBinaryOperator does not expose the setters we need as public members. That's why we are casting back to BinaryOperator. But when we can't cast to BinaryOperator, we can't call setHasNoSignedWrap or setHasNoUnsignedWrap.

Contributor

Ah, I hadn't realized that setHasNoUnsignedWrap is private in OverflowingBinaryOperator and got confused. Thanks for double checking.

bogner (Contributor) left a comment

Generally looks good, but the logic around casting binary operators is a bit fishy.

@farzonl farzonl merged commit d3d35ad into llvm:main Apr 29, 2025
10 of 12 checks passed
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
Fixes llvm#137202.

While investigating i8 allocas, I found some instructions missing from
our i8 legalization around load, store, and select.
Added those three.

To do i8 allocas right, though, we needed to walk the uses and find the
casts.

After finding the casts, I chose the smallest cast as the cast to
transform to. That then lets me preserve the larger casts that come
later.
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request May 9, 2025
@damyanp damyanp removed this from HLSL Support Jun 6, 2025