Skip to content

Commit 2630d72

Browse files
authored
[HLSL] Support vector swizzles on scalars (#67700)
HLSL supports vector swizzles on scalars by implicitly converting the scalar to a single-element vector. This syntax is a convienent way to initialize vectors based on filling a scalar value. There are two parts of this change. The first part in the Lexer splits numeric constant tokens when a `.x` or `.r` suffix is encountered. This splitting is a bit hacky but allows the numeric constant to be parsed separately from the vector element expression. There is an ambiguity here with the `r` suffix used by fixed point types, however fixed point types aren't supported in HLSL so this should not cause any exposable problems (a separate issue has been filed to track validating language options for HLSL: #67689). The second part of this change is in Sema::LookupMemberExpr. For HLSL, if the base type is a scalar, we implicit cast the scalar to a one-element vector then call back to perform the vector lookup. Fixes #56658 and #67511
1 parent 8391bb3 commit 2630d72

File tree

7 files changed

+387
-5
lines changed

7 files changed

+387
-5
lines changed

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,6 +2158,14 @@ RValue CodeGenFunction::EmitLoadOfExtVectorElementLValue(LValue LV) {
21582158
llvm::Value *Vec = Builder.CreateLoad(LV.getExtVectorAddress(),
21592159
LV.isVolatileQualified());
21602160

2161+
// HLSL allows treating scalars as one-element vectors. Converting the scalar
2162+
// IR value to a vector here allows the rest of codegen to behave as normal.
2163+
if (getLangOpts().HLSL && !Vec->getType()->isVectorTy()) {
2164+
llvm::Type *DstTy = llvm::FixedVectorType::get(Vec->getType(), 1);
2165+
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
2166+
Vec = Builder.CreateInsertElement(DstTy, Vec, Zero, "cast.splat");
2167+
}
2168+
21612169
const llvm::Constant *Elts = LV.getExtVectorElts();
21622170

21632171
// If the result of the expression is a non-vector type, we must be extracting
@@ -2427,10 +2435,20 @@ void CodeGenFunction::EmitStoreThroughBitfieldLValue(RValue Src, LValue Dst,
24272435

24282436
void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
24292437
LValue Dst) {
2438+
// HLSL allows storing to scalar values through ExtVector component LValues.
2439+
// To support this we need to handle the case where the destination address is
2440+
// a scalar.
2441+
Address DstAddr = Dst.getExtVectorAddress();
2442+
if (!DstAddr.getElementType()->isVectorTy()) {
2443+
assert(!Dst.getType()->isVectorType() &&
2444+
"this should only occur for non-vector l-values");
2445+
Builder.CreateStore(Src.getScalarVal(), DstAddr, Dst.isVolatileQualified());
2446+
return;
2447+
}
2448+
24302449
// This access turns into a read/modify/write of the vector. Load the input
24312450
// value now.
2432-
llvm::Value *Vec = Builder.CreateLoad(Dst.getExtVectorAddress(),
2433-
Dst.isVolatileQualified());
2451+
llvm::Value *Vec = Builder.CreateLoad(DstAddr, Dst.isVolatileQualified());
24342452
const llvm::Constant *Elts = Dst.getExtVectorElts();
24352453

24362454
llvm::Value *SrcVal = Src.getScalarVal();
@@ -2478,7 +2496,8 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
24782496
llvm_unreachable("unexpected shorten vector length");
24792497
}
24802498
} else {
2481-
// If the Src is a scalar (not a vector) it must be updating one element.
2499+
// If the Src is a scalar (not a vector), and the target is a vector it must
2500+
// be updating one element.
24822501
unsigned InIdx = getAccessedFieldNo(0, Elts);
24832502
llvm::Value *Elt = llvm::ConstantInt::get(SizeTy, InIdx);
24842503
Vec = Builder.CreateInsertElement(Vec, SrcVal, Elt);
@@ -4879,7 +4898,6 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
48794898
case CK_IntegralToPointer:
48804899
case CK_PointerToIntegral:
48814900
case CK_PointerToBoolean:
4882-
case CK_VectorSplat:
48834901
case CK_IntegralCast:
48844902
case CK_BooleanToSignedIntegral:
48854903
case CK_IntegralToBoolean:
@@ -5044,6 +5062,13 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
50445062
}
50455063
case CK_ZeroToOCLOpaqueType:
50465064
llvm_unreachable("NULL to OpenCL opaque type lvalue cast is not valid");
5065+
5066+
case CK_VectorSplat: {
5067+
// LValue results of vector splats are only supported in HLSL.
5068+
if (!getLangOpts().HLSL)
5069+
return EmitUnsupportedLValue(E, "unexpected cast lvalue");
5070+
return EmitLValue(E->getSubExpr());
5071+
}
50475072
}
50485073

50495074
llvm_unreachable("Unhandled lvalue cast kind?");

clang/lib/Lex/Lexer.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1990,6 +1990,10 @@ bool Lexer::LexNumericConstant(Token &Result, const char *CurPtr) {
19901990
while (isPreprocessingNumberBody(C)) {
19911991
CurPtr = ConsumeChar(CurPtr, Size, Result);
19921992
PrevCh = C;
1993+
if (LangOpts.HLSL && C == '.' && (*CurPtr == 'x' || *CurPtr == 'r')) {
1994+
CurPtr -= Size;
1995+
break;
1996+
}
19931997
C = getCharAndSize(CurPtr, Size);
19941998
}
19951999

clang/lib/Lex/LiteralSupport.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -930,7 +930,11 @@ NumericLiteralParser::NumericLiteralParser(StringRef TokSpelling,
930930
// and FP constants (specifically, the 'pp-number' regex), and assumes that
931931
// the byte at "*end" is both valid and not part of the regex. Because of
932932
// this, it doesn't have to check for 'overscan' in various places.
933-
if (isPreprocessingNumberBody(*ThisTokEnd)) {
933+
// Note: For HLSL, the end token is allowed to be '.' which would be in the
934+
// 'pp-number' regex. This is required to support vector swizzles on numeric
935+
// constants (i.e. 1.xx or 1.5f.rrr).
936+
if (isPreprocessingNumberBody(*ThisTokEnd) &&
937+
!(LangOpts.HLSL && *ThisTokEnd == '.')) {
934938
Diags.Report(TokLoc, diag::err_lexing_numeric);
935939
hadError = true;
936940
return;

clang/lib/Sema/SemaExprMember.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,6 +1714,16 @@ static ExprResult LookupMemberExpr(Sema &S, LookupResult &R,
17141714
ObjCImpDecl, HasTemplateArgs, TemplateKWLoc);
17151715
}
17161716

1717+
// HLSL supports implicit conversion of scalar types to single element vector
1718+
// rvalues in member expressions.
1719+
if (S.getLangOpts().HLSL && BaseType->isScalarType()) {
1720+
QualType VectorTy = S.Context.getExtVectorType(BaseType, 1);
1721+
BaseExpr = S.ImpCastExprToType(BaseExpr.get(), VectorTy, CK_VectorSplat,
1722+
BaseExpr.get()->getValueKind());
1723+
return LookupMemberExpr(S, R, BaseExpr, IsArrow, OpLoc, SS, ObjCImpDecl,
1724+
HasTemplateArgs, TemplateKWLoc);
1725+
}
1726+
17171727
S.Diag(OpLoc, diag::err_typecheck_member_reference_struct_union)
17181728
<< BaseType << BaseExpr.get()->getSourceRange() << MemberLoc;
17191729

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
2+
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
3+
// RUN: -o - | FileCheck %s
4+
5+
// CHECK-LABEL: ToTwoInts
6+
// CHECK: [[splat:%.*]] = insertelement <1 x i32> poison, i32 {{.*}}, i64 0
7+
// CHECK: [[vec2:%.*]] = shufflevector <1 x i32> [[splat]], <1 x i32> poison, <2 x i32> zeroinitializer
8+
// CHECK: ret <2 x i32> [[vec2]]
9+
int2 ToTwoInts(int V){
10+
return V.xx;
11+
}
12+
13+
// CHECK-LABEL: ToFourFloats
14+
// [[splat:%.*]] = insertelement <1 x float> poison, float {{.*}}, i64 0
15+
// [[vec4:%.*]] = shufflevector <1 x float> [[splat]], <1 x float> poison, <4 x i32> zeroinitializer
16+
// ret <4 x float> [[vec4]]
17+
float4 ToFourFloats(float V){
18+
return V.rrrr;
19+
}
20+
21+
// CHECK-LABEL: FillOne
22+
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x i32>, align 4
23+
// CHECK: store <1 x i32> <i32 1>, ptr [[vec1Ptr]], align 4
24+
// CHECK: [[vec1:%.*]] = load <1 x i32>, ptr [[vec1Ptr]], align 4
25+
// CHECK: [[vec2:%.*]] = shufflevector <1 x i32> [[vec1]], <1 x i32> poison, <2 x i32> zeroinitializer
26+
// CHECK: ret <2 x i32> [[vec2]]
27+
int2 FillOne(){
28+
return 1.xx;
29+
}
30+
31+
// CHECK-LABEL: FillOneUnsigned
32+
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x i32>, align 4
33+
// CHECK: store <1 x i32> <i32 1>, ptr [[vec1Ptr]], align 4
34+
// CHECK: [[vec1:%.*]] = load <1 x i32>, ptr [[vec1Ptr]], align 4
35+
// CHECK: [[vec3:%.*]] = shufflevector <1 x i32> [[vec1]], <1 x i32> poison, <3 x i32> zeroinitializer
36+
// CHECK: ret <3 x i32> [[vec3]]
37+
uint3 FillOneUnsigned(){
38+
return 1u.xxx;
39+
}
40+
41+
// CHECK-LABEL: FillOneUnsignedLong
42+
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x i64>, align 8
43+
// CHECK: store <1 x i64> <i64 1>, ptr [[vec1Ptr]], align 8
44+
// CHECK: [[vec1:%.*]] = load <1 x i64>, ptr [[vec1Ptr]], align 8
45+
// CHECK: [[vec4:%.*]] = shufflevector <1 x i64> [[vec1]], <1 x i64> poison, <4 x i32> zeroinitializer
46+
// CHECK: ret <4 x i64> [[vec4]]
47+
vector<uint64_t,4> FillOneUnsignedLong(){
48+
return 1ul.xxxx;
49+
}
50+
51+
// CHECK-LABEL: FillTwoPointFive
52+
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x double>, align 8
53+
// CHECK: store <1 x double> <double 2.500000e+00>, ptr [[vec1Ptr]], align 8
54+
// CHECK: [[vec1:%.*]] = load <1 x double>, ptr [[vec1Ptr]], align 8
55+
// CHECK: [[vec2:%.*]] = shufflevector <1 x double> [[vec1]], <1 x double> poison, <2 x i32> zeroinitializer
56+
// CHECK: ret <2 x double> [[vec2]]
57+
double2 FillTwoPointFive(){
58+
return 2.5.rr;
59+
}
60+
61+
// CHECK-LABEL: FillOneHalf
62+
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x double>, align 8
63+
// CHECK: store <1 x double> <double 5.000000e-01>, ptr [[vec1Ptr]], align 8
64+
// CHECK: [[vec1:%.*]] = load <1 x double>, ptr [[vec1Ptr]], align 8
65+
// CHECK: [[vec3:%.*]] = shufflevector <1 x double> [[vec1]], <1 x double> poison, <3 x i32> zeroinitializer
66+
// CHECK: ret <3 x double> [[vec3]]
67+
double3 FillOneHalf(){
68+
return .5.rrr;
69+
}
70+
71+
// CHECK-LABEL: FillTwoPointFiveFloat
72+
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x float>, align 4
73+
// CHECK: store <1 x float> <float 2.500000e+00>, ptr [[vec1Ptr]], align 4
74+
// CHECK: [[vec1:%.*]] = load <1 x float>, ptr [[vec1Ptr]], align 4
75+
// CHECK: [[vec4:%.*]] = shufflevector <1 x float> [[vec1]], <1 x float> poison, <4 x i32> zeroinitializer
76+
// CHECK: ret <4 x float> [[vec4]]
77+
float4 FillTwoPointFiveFloat(){
78+
return 2.5f.rrrr;
79+
}
80+
81+
// The initial codegen for this case is correct but a bit odd. The IR optimizer
82+
// cleans this up very nicely.
83+
84+
// CHECK-LABEL: FillOneHalfFloat
85+
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x float>, align 4
86+
// CHECK: store <1 x float> <float 5.000000e-01>, ptr [[vec1Ptr]], align 4
87+
// CHECK: [[vec1:%.*]] = load <1 x float>, ptr [[vec1Ptr]], align 4
88+
// CHECK: [[vec1Ret:%.*]] = shufflevector <1 x float> [[vec1]], <1 x float> undef, <1 x i32> zeroinitializer
89+
// CHECK: ret <1 x float> [[vec1Ret]]
90+
vector<float, 1> FillOneHalfFloat(){
91+
return .5f.r;
92+
}
93+
94+
// The initial codegen for this case is correct but a bit odd. The IR optimizer
95+
// cleans this up very nicely.
96+
97+
// CHECK-LABEL: HowManyFloats
98+
// CHECK: [[VAddr:%.*]] = alloca float, align 4
99+
// CHECK: [[vec2Ptr:%.*]] = alloca <2 x float>, align 8
100+
// CHECK: [[VVal:%.*]] = load float, ptr [[VAddr]], align 4
101+
// CHECK: [[splat:%.*]] = insertelement <1 x float> poison, float [[VVal]], i64 0
102+
// CHECK: [[vec2:%.*]] = shufflevector <1 x float> [[splat]], <1 x float> poison, <2 x i32> zeroinitializer
103+
// CHECK: store <2 x float> [[vec2]], ptr [[vec2Ptr]], align 8
104+
// CHECK: [[vec2:%.*]] = load <2 x float>, ptr [[vec2Ptr]], align 8
105+
// CHECK: [[vec2Res:%.*]] = shufflevector <2 x float> [[vec2]], <2 x float> poison, <2 x i32> zeroinitializer
106+
// CHECK: ret <2 x float> [[vec2Res]]
107+
float2 HowManyFloats(float V) {
108+
return V.rr.rr;
109+
}
110+
111+
// This codegen is gnarly because `1.` is a double, so this creates double
112+
// vectors that need to be truncated down to floats. The optimizer cleans this
113+
// up nicely too.
114+
115+
// CHECK-LABEL: AllRighty
116+
// CHECK: [[XTmp:%.*]] = alloca <1 x double>, align 8
117+
// CHECK: [[YTmp:%.*]] = alloca <1 x double>, align 8
118+
// CHECK: [[ZTmp:%.*]] = alloca <1 x double>, align 8
119+
120+
// CHECK: store <1 x double> <double 1.000000e+00>, ptr [[XTmp]], align 8
121+
// CHECK: [[XVec:%.*]] = load <1 x double>, ptr [[XTmp]], align 8
122+
// CHECK: [[XVec3:%.*]] = shufflevector <1 x double> [[XVec]], <1 x double> poison, <3 x i32> zeroinitializer
123+
// CHECK: [[XVal:%.*]] = extractelement <3 x double> [[XVec3]], i32 0
124+
// CHECK: [[XValF:%.*]] = fptrunc double [[XVal]] to float
125+
// CHECK: [[Vec3F1:%.*]] = insertelement <3 x float> undef, float [[XValF]], i32 0
126+
127+
// CHECK: store <1 x double> <double 1.000000e+00>, ptr [[YTmp]], align 8
128+
// CHECK: [[YVec:%.*]] = load <1 x double>, ptr [[YTmp]], align 8
129+
// CHECK: [[YVec3:%.*]] = shufflevector <1 x double> [[YVec]], <1 x double> poison, <3 x i32> zeroinitializer
130+
// CHECK: [[YVal:%.*]] = extractelement <3 x double> [[YVec3]], i32 1
131+
// CHECK: [[YValF:%.*]] = fptrunc double [[YVal]] to float
132+
// CHECK: [[Vec3F2:%.*]] = insertelement <3 x float> [[Vec3F1]], float [[YValF]], i32 1
133+
134+
// CHECK: store <1 x double> <double 1.000000e+00>, ptr [[ZTmp]], align 8
135+
// CHECK: [[ZVec:%.*]] = load <1 x double>, ptr [[ZTmp]], align 8
136+
// CHECK: [[ZVec3:%.*]] = shufflevector <1 x double> [[ZVec]], <1 x double> poison, <3 x i32> zeroinitializer
137+
// CHECK: [[ZVal:%.*]] = extractelement <3 x double> [[ZVec3]], i32 2
138+
// CHECK: [[ZValF:%.*]] = fptrunc double [[ZVal]] to float
139+
// CHECK: [[Vec3F3:%.*]] = insertelement <3 x float> [[Vec3F2]], float [[ZValF]], i32 2
140+
141+
// ret <3 x float> [[Vec3F3]]
142+
float3 AllRighty() {
143+
return 1..rrr;
144+
}
145+
146+
// CHECK-LABEL: AssignInt
147+
// CHECK: [[VAddr:%.*]] = alloca i32, align 4
148+
// CHECK: [[XAddr:%.*]] = alloca i32, align 4
149+
150+
// Load V into a vector, then extract V out and store it to X.
151+
// CHECK: [[V:%.*]] = load i32, ptr [[VAddr]], align 4
152+
// CHECK: [[Splat:%.*]] = insertelement <1 x i32> poison, i32 [[V]], i64 0
153+
// CHECK: [[VExtVal:%.*]] = extractelement <1 x i32> [[Splat]], i32 0
154+
// CHECK: store i32 [[VExtVal]], ptr [[XAddr]], align 4
155+
156+
// Load V into two separate vectors, then add the extracted X components.
157+
// CHECK: [[V:%.*]] = load i32, ptr [[VAddr]], align 4
158+
// CHECK: [[Splat:%.*]] = insertelement <1 x i32> poison, i32 [[V]], i64 0
159+
// CHECK: [[LHS:%.*]] = extractelement <1 x i32> [[Splat]], i32 0
160+
161+
// CHECK: [[V:%.*]] = load i32, ptr [[VAddr]], align 4
162+
// CHECK: [[Splat:%.*]] = insertelement <1 x i32> poison, i32 [[V]], i64 0
163+
// CHECK: [[RHS:%.*]] = extractelement <1 x i32> [[Splat]], i32 0
164+
165+
// CHECK: [[Sum:%.*]] = add nsw i32 [[LHS]], [[RHS]]
166+
// CHECK: store i32 [[Sum]], ptr [[XAddr]], align 4
167+
// CHECK: [[X:%.*]] = load i32, ptr [[XAddr]], align 4
168+
// CHECK: ret i32 [[X]]
169+
170+
int AssignInt(int V){
171+
int X = V.x;
172+
X.x = V.x + V.x;
173+
return X;
174+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -x hlsl -finclude-default-header -verify %s
2+
3+
int2 ToTwoInts(int V) {
4+
return V.xy; // expected-error{{vector component access exceeds type 'int __attribute__((ext_vector_type(1)))' (vector of 1 'int' value)}}
5+
}
6+
7+
float2 ToTwoFloats(float V) {
8+
return V.rg; // expected-error{{vector component access exceeds type 'float __attribute__((ext_vector_type(1)))' (vector of 1 'float' value)}}
9+
}
10+
11+
int4 SomeNonsense(int V) {
12+
return V.poop; // expected-error{{illegal vector component name 'p'}}
13+
}
14+
15+
float2 WhatIsHappening(float V) {
16+
return V.; // expected-error{{expected unqualified-id}}
17+
}
18+
19+
// These cases produce no error.
20+
21+
float2 HowManyFloats(float V) {
22+
return V.rr.rr;
23+
}
24+
25+
int64_t4 HooBoy() {
26+
return 4l.xxxx;
27+
}
28+
29+
float3 AllRighty() {
30+
return 1..rrr;
31+
}

0 commit comments

Comments
 (0)