Skip to content

[HLSL] Support vector swizzles on scalars #67700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2035,6 +2035,14 @@ RValue CodeGenFunction::EmitLoadOfExtVectorElementLValue(LValue LV) {
llvm::Value *Vec = Builder.CreateLoad(LV.getExtVectorAddress(),
LV.isVolatileQualified());

// HLSL allows treating scalars as one-element vectors. Converting the scalar
// IR value to a vector here allows the rest of codegen to behave as normal.
if (getLangOpts().HLSL && !Vec->getType()->isVectorTy()) {
llvm::Type *DstTy = llvm::FixedVectorType::get(Vec->getType(), 1);
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
Vec = Builder.CreateInsertElement(DstTy, Vec, Zero, "cast.splat");
}

const llvm::Constant *Elts = LV.getExtVectorElts();

// If the result of the expression is a non-vector type, we must be extracting
Expand Down Expand Up @@ -2304,10 +2312,20 @@ void CodeGenFunction::EmitStoreThroughBitfieldLValue(RValue Src, LValue Dst,

void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
LValue Dst) {
// HLSL allows storing to scalar values through ExtVector component LValues.
// To support this we need to handle the case where the destination address is
// a scalar.
Address DstAddr = Dst.getExtVectorAddress();
if (!DstAddr.getElementType()->isVectorTy()) {
assert(!Dst.getType()->isVectorType() &&
"this should only occur for non-vector l-values");
Builder.CreateStore(Src.getScalarVal(), DstAddr, Dst.isVolatileQualified());
return;
}

// This access turns into a read/modify/write of the vector. Load the input
// value now.
llvm::Value *Vec = Builder.CreateLoad(Dst.getExtVectorAddress(),
Dst.isVolatileQualified());
llvm::Value *Vec = Builder.CreateLoad(DstAddr, Dst.isVolatileQualified());
const llvm::Constant *Elts = Dst.getExtVectorElts();

llvm::Value *SrcVal = Src.getScalarVal();
Expand Down Expand Up @@ -2355,7 +2373,8 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
llvm_unreachable("unexpected shorten vector length");
}
} else {
// If the Src is a scalar (not a vector) it must be updating one element.
// If the Src is a scalar (not a vector), and the target is a vector it must
// be updating one element.
unsigned InIdx = getAccessedFieldNo(0, Elts);
llvm::Value *Elt = llvm::ConstantInt::get(SizeTy, InIdx);
Vec = Builder.CreateInsertElement(Vec, SrcVal, Elt);
Expand Down Expand Up @@ -4734,7 +4753,6 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
case CK_IntegralToPointer:
case CK_PointerToIntegral:
case CK_PointerToBoolean:
case CK_VectorSplat:
case CK_IntegralCast:
case CK_BooleanToSignedIntegral:
case CK_IntegralToBoolean:
Expand Down Expand Up @@ -4899,6 +4917,13 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
}
case CK_ZeroToOCLOpaqueType:
llvm_unreachable("NULL to OpenCL opaque type lvalue cast is not valid");

case CK_VectorSplat: {
// LValue results of vector splats are only supported in HLSL.
if (!getLangOpts().HLSL)
return EmitUnsupportedLValue(E, "unexpected cast lvalue");
return EmitLValue(E->getSubExpr());
}
}

llvm_unreachable("Unhandled lvalue cast kind?");
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/Lex/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,10 @@ bool Lexer::LexNumericConstant(Token &Result, const char *CurPtr) {
while (isPreprocessingNumberBody(C)) {
CurPtr = ConsumeChar(CurPtr, Size, Result);
PrevCh = C;
if (LangOpts.HLSL && C == '.' && (*CurPtr == 'x' || *CurPtr == 'r')) {
CurPtr -= Size;
break;
}
C = getCharAndSize(CurPtr, Size);
}

Expand Down
6 changes: 5 additions & 1 deletion clang/lib/Lex/LiteralSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,11 @@ NumericLiteralParser::NumericLiteralParser(StringRef TokSpelling,
// and FP constants (specifically, the 'pp-number' regex), and assumes that
// the byte at "*end" is both valid and not part of the regex. Because of
// this, it doesn't have to check for 'overscan' in various places.
if (isPreprocessingNumberBody(*ThisTokEnd)) {
// Note: For HLSL, the character at "*end" is allowed to be '.', even though
// '.' is part of the 'pp-number' regex. This is required to support vector
// swizzles on numeric constants (e.g. 1.xx or 1.5f.rrr).
if (isPreprocessingNumberBody(*ThisTokEnd) &&
!(LangOpts.HLSL && *ThisTokEnd == '.')) {
Diags.Report(TokLoc, diag::err_lexing_numeric);
hadError = true;
return;
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/Sema/SemaExprMember.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,16 @@ static ExprResult LookupMemberExpr(Sema &S, LookupResult &R,
ObjCImpDecl, HasTemplateArgs, TemplateKWLoc);
}

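// HLSL supports implicit conversion of scalar types to single-element vectors
// in member expressions. The cast keeps the base expression's value kind, and
// the lookup is then retried on the vector-typed base.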
if (S.getLangOpts().HLSL && BaseType->isScalarType()) {
QualType VectorTy = S.Context.getExtVectorType(BaseType, 1);
BaseExpr = S.ImpCastExprToType(BaseExpr.get(), VectorTy, CK_VectorSplat,
BaseExpr.get()->getValueKind());
return LookupMemberExpr(S, R, BaseExpr, IsArrow, OpLoc, SS, ObjCImpDecl,
HasTemplateArgs, TemplateKWLoc);
}

S.Diag(OpLoc, diag::err_typecheck_member_reference_struct_union)
<< BaseType << BaseExpr.get()->getSourceRange() << MemberLoc;

Expand Down
174 changes: 174 additions & 0 deletions clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s

// CHECK-LABEL: ToTwoInts
// CHECK: [[splat:%.*]] = insertelement <1 x i32> poison, i32 {{.*}}, i64 0
// CHECK: [[vec2:%.*]] = shufflevector <1 x i32> [[splat]], <1 x i32> poison, <2 x i32> zeroinitializer
// CHECK: ret <2 x i32> [[vec2]]
int2 ToTwoInts(int V){
return V.xx;
}

// Swizzling a scalar float with four components yields a float4. The scalar is
// first wrapped in a <1 x float> and then shuffled out to four lanes.

// CHECK-LABEL: ToFourFloats
// CHECK: [[splat:%.*]] = insertelement <1 x float> poison, float {{.*}}, i64 0
// CHECK: [[vec4:%.*]] = shufflevector <1 x float> [[splat]], <1 x float> poison, <4 x i32> zeroinitializer
// CHECK: ret <4 x float> [[vec4]]
float4 ToFourFloats(float V){
  return V.rrrr;
}

// CHECK-LABEL: FillOne
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x i32>, align 4
// CHECK: store <1 x i32> <i32 1>, ptr [[vec1Ptr]], align 4
// CHECK: [[vec1:%.*]] = load <1 x i32>, ptr [[vec1Ptr]], align 4
// CHECK: [[vec2:%.*]] = shufflevector <1 x i32> [[vec1]], <1 x i32> poison, <2 x i32> zeroinitializer
// CHECK: ret <2 x i32> [[vec2]]
int2 FillOne(){
return 1.xx;
}

// CHECK-LABEL: FillOneUnsigned
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x i32>, align 4
// CHECK: store <1 x i32> <i32 1>, ptr [[vec1Ptr]], align 4
// CHECK: [[vec1:%.*]] = load <1 x i32>, ptr [[vec1Ptr]], align 4
// CHECK: [[vec3:%.*]] = shufflevector <1 x i32> [[vec1]], <1 x i32> poison, <3 x i32> zeroinitializer
// CHECK: ret <3 x i32> [[vec3]]
uint3 FillOneUnsigned(){
return 1u.xxx;
}

// CHECK-LABEL: FillOneUnsignedLong
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x i64>, align 8
// CHECK: store <1 x i64> <i64 1>, ptr [[vec1Ptr]], align 8
// CHECK: [[vec1:%.*]] = load <1 x i64>, ptr [[vec1Ptr]], align 8
// CHECK: [[vec4:%.*]] = shufflevector <1 x i64> [[vec1]], <1 x i64> poison, <4 x i32> zeroinitializer
// CHECK: ret <4 x i64> [[vec4]]
vector<uint64_t,4> FillOneUnsignedLong(){
return 1ul.xxxx;
}

// CHECK-LABEL: FillTwoPointFive
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x double>, align 8
// CHECK: store <1 x double> <double 2.500000e+00>, ptr [[vec1Ptr]], align 8
// CHECK: [[vec1:%.*]] = load <1 x double>, ptr [[vec1Ptr]], align 8
// CHECK: [[vec2:%.*]] = shufflevector <1 x double> [[vec1]], <1 x double> poison, <2 x i32> zeroinitializer
// CHECK: ret <2 x double> [[vec2]]
double2 FillTwoPointFive(){
return 2.5.rr;
}

// CHECK-LABEL: FillOneHalf
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x double>, align 8
// CHECK: store <1 x double> <double 5.000000e-01>, ptr [[vec1Ptr]], align 8
// CHECK: [[vec1:%.*]] = load <1 x double>, ptr [[vec1Ptr]], align 8
// CHECK: [[vec3:%.*]] = shufflevector <1 x double> [[vec1]], <1 x double> poison, <3 x i32> zeroinitializer
// CHECK: ret <3 x double> [[vec3]]
double3 FillOneHalf(){
return .5.rrr;
}

// CHECK-LABEL: FillTwoPointFiveFloat
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x float>, align 4
// CHECK: store <1 x float> <float 2.500000e+00>, ptr [[vec1Ptr]], align 4
// CHECK: [[vec1:%.*]] = load <1 x float>, ptr [[vec1Ptr]], align 4
// CHECK: [[vec4:%.*]] = shufflevector <1 x float> [[vec1]], <1 x float> poison, <4 x i32> zeroinitializer
// CHECK: ret <4 x float> [[vec4]]
float4 FillTwoPointFiveFloat(){
return 2.5f.rrrr;
}

// The initial codegen for this case is correct but a bit odd. The IR optimizer
// cleans this up very nicely.

// CHECK-LABEL: FillOneHalfFloat
// CHECK: [[vec1Ptr:%.*]] = alloca <1 x float>, align 4
// CHECK: store <1 x float> <float 5.000000e-01>, ptr [[vec1Ptr]], align 4
// CHECK: [[vec1:%.*]] = load <1 x float>, ptr [[vec1Ptr]], align 4
// CHECK: [[vec1Ret:%.*]] = shufflevector <1 x float> [[vec1]], <1 x float> undef, <1 x i32> zeroinitializer
// CHECK: ret <1 x float> [[vec1Ret]]
vector<float, 1> FillOneHalfFloat(){
return .5f.r;
}

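// Chained swizzles on a scalar are accepted. The initial codegen for this case
// is correct but a bit odd (the intermediate vector round-trips through a
// temporary). The IR optimizer cleans this up very nicely.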

// CHECK-LABEL: HowManyFloats
// CHECK: [[VAddr:%.*]] = alloca float, align 4
// CHECK: [[vec2Ptr:%.*]] = alloca <2 x float>, align 8
// CHECK: [[VVal:%.*]] = load float, ptr [[VAddr]], align 4
// CHECK: [[splat:%.*]] = insertelement <1 x float> poison, float [[VVal]], i64 0
// CHECK: [[vec2:%.*]] = shufflevector <1 x float> [[splat]], <1 x float> poison, <2 x i32> zeroinitializer
// CHECK: store <2 x float> [[vec2]], ptr [[vec2Ptr]], align 8
// CHECK: [[vec2:%.*]] = load <2 x float>, ptr [[vec2Ptr]], align 8
// CHECK: [[vec2Res:%.*]] = shufflevector <2 x float> [[vec2]], <2 x float> poison, <2 x i32> zeroinitializer
// CHECK: ret <2 x float> [[vec2Res]]
float2 HowManyFloats(float V) {
return V.rr.rr;
}

// This codegen is gnarly because `1.` is a double, so this creates double
// vectors that need to be truncated down to floats. The optimizer cleans this
// up nicely too.

// CHECK-LABEL: AllRighty
// CHECK: [[XTmp:%.*]] = alloca <1 x double>, align 8
// CHECK: [[YTmp:%.*]] = alloca <1 x double>, align 8
// CHECK: [[ZTmp:%.*]] = alloca <1 x double>, align 8

// CHECK: store <1 x double> <double 1.000000e+00>, ptr [[XTmp]], align 8
// CHECK: [[XVec:%.*]] = load <1 x double>, ptr [[XTmp]], align 8
// CHECK: [[XVec3:%.*]] = shufflevector <1 x double> [[XVec]], <1 x double> poison, <3 x i32> zeroinitializer
// CHECK: [[XVal:%.*]] = extractelement <3 x double> [[XVec3]], i32 0
// CHECK: [[XValF:%.*]] = fptrunc double [[XVal]] to float
// CHECK: [[Vec3F1:%.*]] = insertelement <3 x float> undef, float [[XValF]], i32 0

// CHECK: store <1 x double> <double 1.000000e+00>, ptr [[YTmp]], align 8
// CHECK: [[YVec:%.*]] = load <1 x double>, ptr [[YTmp]], align 8
// CHECK: [[YVec3:%.*]] = shufflevector <1 x double> [[YVec]], <1 x double> poison, <3 x i32> zeroinitializer
// CHECK: [[YVal:%.*]] = extractelement <3 x double> [[YVec3]], i32 1
// CHECK: [[YValF:%.*]] = fptrunc double [[YVal]] to float
// CHECK: [[Vec3F2:%.*]] = insertelement <3 x float> [[Vec3F1]], float [[YValF]], i32 1

// CHECK: store <1 x double> <double 1.000000e+00>, ptr [[ZTmp]], align 8
// CHECK: [[ZVec:%.*]] = load <1 x double>, ptr [[ZTmp]], align 8
// CHECK: [[ZVec3:%.*]] = shufflevector <1 x double> [[ZVec]], <1 x double> poison, <3 x i32> zeroinitializer
// CHECK: [[ZVal:%.*]] = extractelement <3 x double> [[ZVec3]], i32 2
// CHECK: [[ZValF:%.*]] = fptrunc double [[ZVal]] to float
// CHECK: [[Vec3F3:%.*]] = insertelement <3 x float> [[Vec3F2]], float [[ZValF]], i32 2

// CHECK: ret <3 x float> [[Vec3F3]]

// `1.` is lexed as a double literal; the second `.` begins the swizzle.
float3 AllRighty() {
  return 1..rrr;
}

// CHECK-LABEL: AssignInt
// CHECK: [[VAddr:%.*]] = alloca i32, align 4
// CHECK: [[XAddr:%.*]] = alloca i32, align 4

// Load V into a vector, then extract V out and store it to X.
// CHECK: [[V:%.*]] = load i32, ptr [[VAddr]], align 4
// CHECK: [[Splat:%.*]] = insertelement <1 x i32> poison, i32 [[V]], i64 0
// CHECK: [[VExtVal:%.*]] = extractelement <1 x i32> [[Splat]], i32 0
// CHECK: store i32 [[VExtVal]], ptr [[XAddr]], align 4

// Load V into two separate vectors, then add the extracted X components.
// CHECK: [[V:%.*]] = load i32, ptr [[VAddr]], align 4
// CHECK: [[Splat:%.*]] = insertelement <1 x i32> poison, i32 [[V]], i64 0
// CHECK: [[LHS:%.*]] = extractelement <1 x i32> [[Splat]], i32 0

// CHECK: [[V:%.*]] = load i32, ptr [[VAddr]], align 4
// CHECK: [[Splat:%.*]] = insertelement <1 x i32> poison, i32 [[V]], i64 0
// CHECK: [[RHS:%.*]] = extractelement <1 x i32> [[Splat]], i32 0

// CHECK: [[Sum:%.*]] = add nsw i32 [[LHS]], [[RHS]]
// CHECK: store i32 [[Sum]], ptr [[XAddr]], align 4
// CHECK: [[X:%.*]] = load i32, ptr [[XAddr]], align 4
// CHECK: ret i32 [[X]]

int AssignInt(int V){
int X = V.x;
X.x = V.x + V.x;
return X;
}
31 changes: 31 additions & 0 deletions clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -x hlsl -finclude-default-header -verify %s

int2 ToTwoInts(int V) {
return V.xy; // expected-error{{vector component access exceeds type 'int __attribute__((ext_vector_type(1)))' (vector of 1 'int' value)}}
}

float2 ToTwoFloats(float V) {
return V.rg; // expected-error{{vector component access exceeds type 'float __attribute__((ext_vector_type(1)))' (vector of 1 'float' value)}}
}

int4 SomeNonsense(int V) {
return V.poop; // expected-error{{illegal vector component name 'p'}}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like another test that validates we catch multiple levels of dots. e.g.,

float2 HowManyFloats(float V) {
  return V.rr.rr;
}

or dots not followed by anything useful:

float2 WhatIsHappening(float V) {
  return V.;
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The really sad thing is that the first case there seems to be valid on both our reference compilers:

DXC on Compiler Explorer
FXC on Shader Playground

We may need to support that. I'll add tests either way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also add CodeGen tests that show we do... whatever that is.

}

float2 WhatIsHappening(float V) {
return V.; // expected-error{{expected unqualified-id}}
}

// These cases produce no error.

float2 HowManyFloats(float V) {
return V.rr.rr;
}

int64_t4 HooBoy() {
return 4l.xxxx;
}

float3 AllRighty() {
return 1..rrr;
}
Loading