Skip to content

[HLSL] Move where ZExt happens in 'EmitStoreThroughExtVectorComponentLValue' to handle bug with hlsl boolean vector swizzles #140627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2025

Conversation

spall
Copy link
Contributor

@spall spall commented May 19, 2025

In 'EmitStoreThroughExtVectorComponentLValue', move the code which ZExts in the case the Destination Scalar Type is larger than the Source Scalar Type, to the top of the function, to ensure each condition is handled.

The previous code missed this case:

bool4 b = true.xxxx;
b.xyz = false.xxx;

Leading to a bad shuffle vector.

Closes #140564

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support labels May 19, 2025
@llvmbot
Copy link
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-clang-codegen

Author: Sarah Spall (spall)

Changes

In 'EmitStoreThroughExtVectorComponentLValue', move the code which ZExts in the case the Destination Scalar Type is larger than the Source Scalar Type, to the top of the function, to ensure each condition is handled.

The previous code missed this case:

bool4 b = true.xxxx;
b.xyz = false.xxx;

Leading to a bad shuffle vector.

Closes #140564


Full diff: https://github.com/llvm/llvm-project/pull/140627.diff

2 Files Affected:

  • (modified) clang/lib/CodeGen/CGExpr.cpp (+9-13)
  • (modified) clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl (+22-3)
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 37a5678aa61d5..7580a490a66c1 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2694,14 +2694,20 @@ void CodeGenFunction::EmitStoreThroughBitfieldLValue(RValue Src, LValue Dst,
 
 void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
                                                                LValue Dst) {
+  llvm::Value *SrcVal = Src.getScalarVal();
+  Address DstAddr = Dst.getExtVectorAddress();
+  if (DstAddr.getElementType()->getScalarSizeInBits() >
+      SrcVal->getType()->getScalarSizeInBits())
+    SrcVal = Builder.CreateZExt(
+        SrcVal, convertTypeForLoadStore(Dst.getType(), SrcVal->getType()));
+
   // HLSL allows storing to scalar values through ExtVector component LValues.
   // To support this we need to handle the case where the destination address is
   // a scalar.
-  Address DstAddr = Dst.getExtVectorAddress();
   if (!DstAddr.getElementType()->isVectorTy()) {
     assert(!Dst.getType()->isVectorType() &&
            "this should only occur for non-vector l-values");
-    Builder.CreateStore(Src.getScalarVal(), DstAddr, Dst.isVolatileQualified());
+    Builder.CreateStore(SrcVal, DstAddr, Dst.isVolatileQualified());
     return;
   }
 
@@ -2722,11 +2728,6 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
       for (unsigned i = 0; i != NumSrcElts; ++i)
         Mask[getAccessedFieldNo(i, Elts)] = i;
 
-      llvm::Value *SrcVal = Src.getScalarVal();
-      if (VecTy->getScalarSizeInBits() >
-          SrcVal->getType()->getScalarSizeInBits())
-        SrcVal = Builder.CreateZExt(SrcVal, VecTy);
-
       Vec = Builder.CreateShuffleVector(SrcVal, Mask);
     } else if (NumDstElts > NumSrcElts) {
       // Extended the source vector to the same length and then shuffle it
@@ -2737,8 +2738,7 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
       for (unsigned i = 0; i != NumSrcElts; ++i)
         ExtMask.push_back(i);
       ExtMask.resize(NumDstElts, -1);
-      llvm::Value *ExtSrcVal =
-          Builder.CreateShuffleVector(Src.getScalarVal(), ExtMask);
+      llvm::Value *ExtSrcVal = Builder.CreateShuffleVector(SrcVal, ExtMask);
       // build identity
       SmallVector<int, 4> Mask;
       for (unsigned i = 0; i != NumDstElts; ++i)
@@ -2764,10 +2764,6 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
     unsigned InIdx = getAccessedFieldNo(0, Elts);
     llvm::Value *Elt = llvm::ConstantInt::get(SizeTy, InIdx);
 
-    llvm::Value *SrcVal = Src.getScalarVal();
-    if (VecTy->getScalarSizeInBits() > SrcVal->getType()->getScalarSizeInBits())
-      SrcVal = Builder.CreateZExt(SrcVal, VecTy->getScalarType());
-
     Vec = Builder.CreateInsertElement(Vec, SrcVal, Elt);
   }
 
diff --git a/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl b/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
index 96e17046ee934..8a3958ad8fd04 100644
--- a/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
@@ -233,7 +233,8 @@ int AssignInt(int V){
 
 // CHECK: lor.end:
 // CHECK-NEXT: [[H:%.*]] = phi i1 [ true, %entry ], [ [[G]], %lor.rhs ]
-// CHECK-NEXT: store i1 [[H]], ptr [[XAddr]], align 4
+// CHECK-NEXT: [[J:%.*]] = zext i1 %9 to i32
+// CHECK-NEXT: store i32 [[J]], ptr [[XAddr]], align 4
 // CHECK-NEXT: [[I:%.*]] = load i32, ptr [[XAddr]], align 4
 // CHECK-NEXT: [[LoadV:%.*]] = trunc i32 [[I]] to i1
 // CHECK-NEXT: ret i1 [[LoadV]]
@@ -257,8 +258,8 @@ bool AssignBool(bool V) {
 // CHECK-NEXT: store <2 x i32> [[A]], ptr [[X]], align 8
 // CHECK-NEXT: [[B:%.*]] = load i32, ptr [[VAddr]], align 4
 // CHECK-NEXT: [[LV1:%.*]] = trunc i32 [[B]] to i1
-// CHECK-NEXT: [[C:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[D:%.*]] = zext i1 [[LV1]] to i32
+// CHECK-NEXT: [[C:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[E:%.*]] = insertelement <2 x i32> [[C]], i32 [[D]], i32 1
 // CHECK-NEXT: store <2 x i32> [[E]], ptr [[X]], align 8
 // CHECK-NEXT: ret void
@@ -275,8 +276,8 @@ void AssignBool2(bool V) {
 // CHECK-NEXT: store <2 x i32> splat (i32 1), ptr [[X]], align 8
 // CHECK-NEXT: [[Z:%.*]] = load <2 x i32>, ptr [[VAddr]], align 8
 // CHECK-NEXT: [[LV:%.*]] = trunc <2 x i32> [[Z]] to <2 x i1>
-// CHECK-NEXT: [[A:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[B:%.*]] = zext <2 x i1> [[LV]] to <2 x i32>
+// CHECK-NEXT: [[A:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[C:%.*]] = shufflevector <2 x i32> [[B]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
 // CHECK-NEXT: store <2 x i32> [[C]], ptr [[X]], align 8
 // CHECK-NEXT: ret void
@@ -302,3 +303,21 @@ bool2 AccessBools() {
   bool4 X = true.xxxx;
   return X.zw;
 }
+
+// CHECK-LABEL: define void {{.*}}BoolSizeMismatch{{.*}}
+// CHECK: [[B:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: [[Tmp:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: store <4 x i32> splat (i32 1), ptr [[B]], align 16
+// CHECK-NEXT: store <1 x i32> zeroinitializer, ptr [[Tmp]], align 4
+// CHECK-NEXT: [[L0:%.*]] = load <1 x i32>, ptr [[Tmp]], align 4
+// CHECK-NEXT: [[L1:%.*]] = shufflevector <1 x i32> [[L0]], <1 x i32> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[TruncV:%.*]] = trunc <3 x i32> [[L1]] to <3 x i1>
+// CHECK-NEXT: [[L2:%.*]] = zext <3 x i1> [[TruncV]] to <3 x i32>
+// CHECK-NEXT: [[L3:%.*]] = load <4 x i32>, ptr [[B]], align 16
+// CHECK-NEXT: [[L4:%.*]] = shufflevector <3 x i32> [[L2]], <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
+// CHECK-NEXT: [[L5:%.*]] = shufflevector <4 x i32> [[L3]], <4 x i32> [[L4]], <4 x i32> <i32 4, i32 5, i32 6, i32 3>
+// CHECK-NEXT: store <4 x i32> [[L5]], ptr [[B]], align 16
+void BoolSizeMismatch() {
+  bool4 B = {true,true,true,true};
+  B.xyz = false.xxx;
+}

@llvmbot
Copy link
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-hlsl

Author: Sarah Spall (spall)

Changes

In 'EmitStoreThroughExtVectorComponentLValue', move the code which ZExts in the case the Destination Scalar Type is larger than the Source Scalar Type, to the top of the function, to ensure each condition is handled.

The previous code missed this case:

bool4 b = true.xxxx;
b.xyz = false.xxx;

Leading to a bad shuffle vector.

Closes #140564


Full diff: https://github.com/llvm/llvm-project/pull/140627.diff

2 Files Affected:

  • (modified) clang/lib/CodeGen/CGExpr.cpp (+9-13)
  • (modified) clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl (+22-3)
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 37a5678aa61d5..7580a490a66c1 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2694,14 +2694,20 @@ void CodeGenFunction::EmitStoreThroughBitfieldLValue(RValue Src, LValue Dst,
 
 void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
                                                                LValue Dst) {
+  llvm::Value *SrcVal = Src.getScalarVal();
+  Address DstAddr = Dst.getExtVectorAddress();
+  if (DstAddr.getElementType()->getScalarSizeInBits() >
+      SrcVal->getType()->getScalarSizeInBits())
+    SrcVal = Builder.CreateZExt(
+        SrcVal, convertTypeForLoadStore(Dst.getType(), SrcVal->getType()));
+
   // HLSL allows storing to scalar values through ExtVector component LValues.
   // To support this we need to handle the case where the destination address is
   // a scalar.
-  Address DstAddr = Dst.getExtVectorAddress();
   if (!DstAddr.getElementType()->isVectorTy()) {
     assert(!Dst.getType()->isVectorType() &&
            "this should only occur for non-vector l-values");
-    Builder.CreateStore(Src.getScalarVal(), DstAddr, Dst.isVolatileQualified());
+    Builder.CreateStore(SrcVal, DstAddr, Dst.isVolatileQualified());
     return;
   }
 
@@ -2722,11 +2728,6 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
       for (unsigned i = 0; i != NumSrcElts; ++i)
         Mask[getAccessedFieldNo(i, Elts)] = i;
 
-      llvm::Value *SrcVal = Src.getScalarVal();
-      if (VecTy->getScalarSizeInBits() >
-          SrcVal->getType()->getScalarSizeInBits())
-        SrcVal = Builder.CreateZExt(SrcVal, VecTy);
-
       Vec = Builder.CreateShuffleVector(SrcVal, Mask);
     } else if (NumDstElts > NumSrcElts) {
       // Extended the source vector to the same length and then shuffle it
@@ -2737,8 +2738,7 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
       for (unsigned i = 0; i != NumSrcElts; ++i)
         ExtMask.push_back(i);
       ExtMask.resize(NumDstElts, -1);
-      llvm::Value *ExtSrcVal =
-          Builder.CreateShuffleVector(Src.getScalarVal(), ExtMask);
+      llvm::Value *ExtSrcVal = Builder.CreateShuffleVector(SrcVal, ExtMask);
       // build identity
       SmallVector<int, 4> Mask;
       for (unsigned i = 0; i != NumDstElts; ++i)
@@ -2764,10 +2764,6 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
     unsigned InIdx = getAccessedFieldNo(0, Elts);
     llvm::Value *Elt = llvm::ConstantInt::get(SizeTy, InIdx);
 
-    llvm::Value *SrcVal = Src.getScalarVal();
-    if (VecTy->getScalarSizeInBits() > SrcVal->getType()->getScalarSizeInBits())
-      SrcVal = Builder.CreateZExt(SrcVal, VecTy->getScalarType());
-
     Vec = Builder.CreateInsertElement(Vec, SrcVal, Elt);
   }
 
diff --git a/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl b/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
index 96e17046ee934..8a3958ad8fd04 100644
--- a/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
@@ -233,7 +233,8 @@ int AssignInt(int V){
 
 // CHECK: lor.end:
 // CHECK-NEXT: [[H:%.*]] = phi i1 [ true, %entry ], [ [[G]], %lor.rhs ]
-// CHECK-NEXT: store i1 [[H]], ptr [[XAddr]], align 4
+// CHECK-NEXT: [[J:%.*]] = zext i1 %9 to i32
+// CHECK-NEXT: store i32 [[J]], ptr [[XAddr]], align 4
 // CHECK-NEXT: [[I:%.*]] = load i32, ptr [[XAddr]], align 4
 // CHECK-NEXT: [[LoadV:%.*]] = trunc i32 [[I]] to i1
 // CHECK-NEXT: ret i1 [[LoadV]]
@@ -257,8 +258,8 @@ bool AssignBool(bool V) {
 // CHECK-NEXT: store <2 x i32> [[A]], ptr [[X]], align 8
 // CHECK-NEXT: [[B:%.*]] = load i32, ptr [[VAddr]], align 4
 // CHECK-NEXT: [[LV1:%.*]] = trunc i32 [[B]] to i1
-// CHECK-NEXT: [[C:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[D:%.*]] = zext i1 [[LV1]] to i32
+// CHECK-NEXT: [[C:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[E:%.*]] = insertelement <2 x i32> [[C]], i32 [[D]], i32 1
 // CHECK-NEXT: store <2 x i32> [[E]], ptr [[X]], align 8
 // CHECK-NEXT: ret void
@@ -275,8 +276,8 @@ void AssignBool2(bool V) {
 // CHECK-NEXT: store <2 x i32> splat (i32 1), ptr [[X]], align 8
 // CHECK-NEXT: [[Z:%.*]] = load <2 x i32>, ptr [[VAddr]], align 8
 // CHECK-NEXT: [[LV:%.*]] = trunc <2 x i32> [[Z]] to <2 x i1>
-// CHECK-NEXT: [[A:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[B:%.*]] = zext <2 x i1> [[LV]] to <2 x i32>
+// CHECK-NEXT: [[A:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[C:%.*]] = shufflevector <2 x i32> [[B]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
 // CHECK-NEXT: store <2 x i32> [[C]], ptr [[X]], align 8
 // CHECK-NEXT: ret void
@@ -302,3 +303,21 @@ bool2 AccessBools() {
   bool4 X = true.xxxx;
   return X.zw;
 }
+
+// CHECK-LABEL: define void {{.*}}BoolSizeMismatch{{.*}}
+// CHECK: [[B:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: [[Tmp:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: store <4 x i32> splat (i32 1), ptr [[B]], align 16
+// CHECK-NEXT: store <1 x i32> zeroinitializer, ptr [[Tmp]], align 4
+// CHECK-NEXT: [[L0:%.*]] = load <1 x i32>, ptr [[Tmp]], align 4
+// CHECK-NEXT: [[L1:%.*]] = shufflevector <1 x i32> [[L0]], <1 x i32> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[TruncV:%.*]] = trunc <3 x i32> [[L1]] to <3 x i1>
+// CHECK-NEXT: [[L2:%.*]] = zext <3 x i1> [[TruncV]] to <3 x i32>
+// CHECK-NEXT: [[L3:%.*]] = load <4 x i32>, ptr [[B]], align 16
+// CHECK-NEXT: [[L4:%.*]] = shufflevector <3 x i32> [[L2]], <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
+// CHECK-NEXT: [[L5:%.*]] = shufflevector <4 x i32> [[L3]], <4 x i32> [[L4]], <4 x i32> <i32 4, i32 5, i32 6, i32 3>
+// CHECK-NEXT: store <4 x i32> [[L5]], ptr [[B]], align 16
+void BoolSizeMismatch() {
+  bool4 B = {true,true,true,true};
+  B.xyz = false.xxx;
+}

@spall spall merged commit 5999988 into llvm:main May 20, 2025
15 checks passed
kostasalv pushed a commit to kostasalv/llvm-project that referenced this pull request May 21, 2025
…LValue' to handle bug with hlsl boolean vector swizzles (llvm#140627)

In 'EmitStoreThroughExtVectorComponentLValue', move the code which ZExts
in the case the Destination Scalar Type is larger than the Source Scalar
Type, to the top of the function, to ensure each condition is handled.

The previous code missed this case:
```
bool4 b = true.xxxx;
b.xyz = false.xxx;
```
Leading to a bad shuffle vector. 

Closes llvm#140564
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
…LValue' to handle bug with hlsl boolean vector swizzles (llvm#140627)

In 'EmitStoreThroughExtVectorComponentLValue', move the code which ZExts
in the case the Destination Scalar Type is larger than the Source Scalar
Type, to the top of the function, to ensure each condition is handled.

The previous code missed this case:
```
bool4 b = true.xxxx;
b.xyz = false.xxx;
```
Leading to a bad shuffle vector. 

Closes llvm#140564
ajaden-codes pushed a commit to Jaddyen/llvm-project that referenced this pull request Jun 6, 2025
…LValue' to handle bug with hlsl boolean vector swizzles (llvm#140627)

In 'EmitStoreThroughExtVectorComponentLValue', move the code which ZExts
in the case the Destination Scalar Type is larger than the Source Scalar
Type, to the top of the function, to ensure each condition is handled.

The previous code missed this case:
```
bool4 b = true.xxxx;
b.xyz = false.xxx;
```
Leading to a bad shuffle vector. 

Closes llvm#140564
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen IR generation bugs: mangling, exceptions, etc. clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[HLSL] Boolean vector not being converted to in memory representation <n x i32> from <n x i1>
4 participants