Skip to content

[DirectX] Don't limit visitGetElementPtrInst to global ptrs #144959

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 20, 2025

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented Jun 19, 2025

fixes #144608

  • there is a getPointerOperandIndex function so we don't need to iterate the operands trying to find the pointer. This resulted in a small cleanup to visitStoreInst and visitLoadInst.

  • The meat of this change was in visitGetElementPtrInst to account for allocas and not bail when we don't find a global.

fixes llvm#144608
- there is a getPointerOperandIndex function so we don't need to iterate
  the operands trying to find the pointer. This resulted in a small
  cleanup to visitStoreInst and visitLoadInst.

- The meat of this change was in visitGetElementPtrInst to account for
  allocas and not bail when we don't find a global.
@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

Changes

fixes #144608

  • there is a getPointerOperandIndex function so we don't need to iterate the operands trying to find the pointer. This resulted in a small cleanup to visitStoreInst and visitLoadInst.

  • The meat of this change was in visitGetElementPtrInst to account for allocas and not bail when we don't find a global.


Full diff: https://github.com/llvm/llvm-project/pull/144959.diff

2 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXILDataScalarization.cpp (+55-49)
  • (modified) llvm/test/CodeGen/DirectX/scalarize-alloca.ll (+17-2)
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 06708cec00cec..61c5301ed5051 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -14,11 +14,13 @@
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/ReplaceConstant.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Local.h"
 
@@ -127,71 +129,75 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
 }
 
 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
-  unsigned NumOperands = LI.getNumOperands();
-  for (unsigned I = 0; I < NumOperands; ++I) {
-    Value *CurrOpperand = LI.getOperand(I);
-    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
-    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
-      GetElementPtrInst *OldGEP =
-          cast<GetElementPtrInst>(CE->getAsInstruction());
-      OldGEP->insertBefore(LI.getIterator());
-      IRBuilder<> Builder(&LI);
-      LoadInst *NewLoad =
-          Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
-      NewLoad->setAlignment(LI.getAlign());
-      LI.replaceAllUsesWith(NewLoad);
-      LI.eraseFromParent();
-      visitGetElementPtrInst(*OldGEP);
-      return true;
-    }
-    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
-      LI.setOperand(I, NewGlobal);
+  Value *PtrOperand = LI.getPointerOperand();
+  ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
+  if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+    GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
+    OldGEP->insertBefore(LI.getIterator());
+    IRBuilder<> Builder(&LI);
+    LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
+    NewLoad->setAlignment(LI.getAlign());
+    LI.replaceAllUsesWith(NewLoad);
+    LI.eraseFromParent();
+    visitGetElementPtrInst(*OldGEP);
+    return true;
   }
+  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
+    LI.setOperand(LI.getPointerOperandIndex(), NewGlobal);
   return false;
 }
 
 bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
-  unsigned NumOperands = SI.getNumOperands();
-  for (unsigned I = 0; I < NumOperands; ++I) {
-    Value *CurrOpperand = SI.getOperand(I);
-    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
-    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
-      GetElementPtrInst *OldGEP =
-          cast<GetElementPtrInst>(CE->getAsInstruction());
-      OldGEP->insertBefore(SI.getIterator());
-      IRBuilder<> Builder(&SI);
-      StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
-      NewStore->setAlignment(SI.getAlign());
-      SI.replaceAllUsesWith(NewStore);
-      SI.eraseFromParent();
-      visitGetElementPtrInst(*OldGEP);
-      return true;
-    }
-    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
-      SI.setOperand(I, NewGlobal);
+
+  Value *PtrOperand = SI.getPointerOperand();
+  ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
+  if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+    GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
+    OldGEP->insertBefore(SI.getIterator());
+    IRBuilder<> Builder(&SI);
+    StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
+    NewStore->setAlignment(SI.getAlign());
+    SI.replaceAllUsesWith(NewStore);
+    SI.eraseFromParent();
+    visitGetElementPtrInst(*OldGEP);
+    return true;
   }
+  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
+    SI.setOperand(SI.getPointerOperandIndex(), NewGlobal);
+
   return false;
 }
 
 bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
-
-  unsigned NumOperands = GEPI.getNumOperands();
-  GlobalVariable *NewGlobal = nullptr;
-  for (unsigned I = 0; I < NumOperands; ++I) {
-    Value *CurrOpperand = GEPI.getOperand(I);
-    NewGlobal = lookupReplacementGlobal(CurrOpperand);
-    if (NewGlobal)
-      break;
+  Value *PtrOperand = GEPI.getPointerOperand();
+  Type *OrigGEPType = GEPI.getPointerOperandType();
+  Type *NewGEPType = OrigGEPType;
+  bool NeedsTransform = false;
+
+  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
+    NewGEPType = NewGlobal->getValueType();
+    PtrOperand = NewGlobal;
+    NeedsTransform = true;
+  } else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
+    Type *AllocatedType = Alloca->getAllocatedType();
+    // OrigGEPType might just be a pointer lets make sure
+    // to add the allocated type so we have a size
+    if (AllocatedType != OrigGEPType) {
+      NewGEPType = AllocatedType;
+      NeedsTransform = true;
+    }
   }
-  if (!NewGlobal)
+
+  // Note: We bail if this isn't a gep touched via alloca or global
+  // transformations
+  if (!NeedsTransform)
     return false;
 
   IRBuilder<> Builder(&GEPI);
   SmallVector<Value *, MaxVecSize> Indices(GEPI.indices());
 
-  Value *NewGEP =
-      Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
-                        GEPI.getName(), GEPI.getNoWrapFlags());
+  Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
+                                    GEPI.getName(), GEPI.getNoWrapFlags());
   GEPI.replaceAllUsesWith(NewGEP);
   GEPI.eraseFromParent();
   return true;
diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
index 4829f3a31791f..b589136d6965c 100644
--- a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
+++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
@@ -1,10 +1,25 @@
-; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=SCHECK
-; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=FCHECK
+; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK
 
 ; CHECK-LABEL: alloca_2d__vec_test
 define void @alloca_2d__vec_test() local_unnamed_addr #2 {
   ; SCHECK:  alloca [2 x [4 x i32]], align 16
   ; FCHECK:  alloca [8 x i32], align 16
+  ; CHECK: ret void
   %1 = alloca [2 x <4 x i32>], align 16
   ret void
 }
+
+; CHECK-LABEL: alloca_2d_gep_test
+define void @alloca_2d_gep_test() {
+  ; SCHECK:  [[alloca_val:%.*]] = alloca [2 x [2 x i32]], align 16
+  ; FCHECK:  [[alloca_val:%.*]] = alloca [4 x i32], align 16
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [2 x [2 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 0, i32 [[tid]]
+  ; CHECK: ret void
+  %1 = alloca [2 x <2 x i32>], align 16
+  %2 = tail call i32 @llvm.dx.thread.id(i32 0)
+  %3 = getelementptr inbounds nuw [2 x <2 x i32>], ptr %1, i32 0, i32 %2
+  ret void
+}

@farzonl farzonl merged commit 2a4207e into llvm:main Jun 20, 2025
10 checks passed
Jaddyen pushed a commit to Jaddyen/llvm-project that referenced this pull request Jun 23, 2025
)

fixes llvm#144608
- there is a getPointerOperandIndex function so we don't need to iterate
the operands trying to find the pointer. This resulted in a small
cleanup to visitStoreInst and visitLoadInst.

- The meat of this change was in visitGetElementPtrInst to account for
allocas and not bail when we don't find a global.
@damyanp damyanp removed this from HLSL Support Jun 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[DirectX] DXIL Data Scalarization is not scalarizing GEP for an array of vectors in function parameter
4 participants