Skip to content

[HLSL] Support vector swizzles on scalars #67700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 29, 2023

Conversation

llvm-beanz
Copy link
Collaborator

HLSL supports vector swizzles on scalars by implicitly converting the scalar to a single-element vector. This syntax is a convienent way to initialize vectors based on filling a scalar value.

There are two parts of this change. The first part in the Lexer splits numeric constant tokens when a .x or .r suffix is encountered. This splitting is a bit hacky but allows the numeric constant to be parsed separately from the vector element expression. There is an ambiguity here with the r suffix used by fixed point types, however fixed point types aren't supported in HLSL so this should not cause any exposable problems (a separate issue has been filed to track validating language options for HLSL: #67689).

The second part of this change is in Sema::LookupMemberExpr. For HLSL, if the base type is a scalar, we implicit cast the scalar to a one-element vector then call back to perform the vector lookup.

Fixes #56658 and #67511

HLSL supports vector swizzles on scalars by implicitly converting the
scalar to a single-element vector. This syntax is a convienent way to
initialize vectors based on filling a scalar value.

There are two parts of this change. The first part in the Lexer splits
numeric constant tokens when a `.x` or `.r` suffix is encountered. This
splitting is a bit hacky but allows the numeric constant to be parsed
separately from the vector element expression. There is an ambiguity
here with the `r` suffix used by fixed point types, however fixed point
types aren't supported in HLSL so this should not cause any exposable
problems (a separate issue has been filed to track validating language
options for HLSL: llvm#67689).

The second part of this change is in Sema::LookupMemberExpr. For HLSL,
if the base type is a scalar, we implicit cast the scalar to a
one-element vector then call back to perform the vector lookup.

Fixes llvm#56658 and llvm#67511
@llvm-beanz llvm-beanz added the HLSL HLSL Language Support label Sep 28, 2023
@llvm-beanz llvm-beanz self-assigned this Sep 28, 2023
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" labels Sep 28, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 28, 2023

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-hlsl

Changes

HLSL supports vector swizzles on scalars by implicitly converting the scalar to a single-element vector. This syntax is a convienent way to initialize vectors based on filling a scalar value.

There are two parts of this change. The first part in the Lexer splits numeric constant tokens when a .x or .r suffix is encountered. This splitting is a bit hacky but allows the numeric constant to be parsed separately from the vector element expression. There is an ambiguity here with the r suffix used by fixed point types, however fixed point types aren't supported in HLSL so this should not cause any exposable problems (a separate issue has been filed to track validating language options for HLSL: #67689).

The second part of this change is in Sema::LookupMemberExpr. For HLSL, if the base type is a scalar, we implicit cast the scalar to a one-element vector then call back to perform the vector lookup.

Fixes #56658 and #67511


Full diff: https://github.com/llvm/llvm-project/pull/67700.diff

5 Files Affected:

  • (modified) clang/lib/Lex/Lexer.cpp (+4)
  • (modified) clang/lib/Lex/LiteralSupport.cpp (+5-1)
  • (modified) clang/lib/Sema/SemaExprMember.cpp (+10)
  • (added) clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl (+14)
  • (added) clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzles.hlsl (+78)
diff --git a/clang/lib/Lex/Lexer.cpp b/clang/lib/Lex/Lexer.cpp
index 37c3e4175d4736e..65588191cfd8fdb 100644
--- a/clang/lib/Lex/Lexer.cpp
+++ b/clang/lib/Lex/Lexer.cpp
@@ -1950,6 +1950,10 @@ bool Lexer::LexNumericConstant(Token &Result, const char *CurPtr) {
   while (isPreprocessingNumberBody(C)) {
     CurPtr = ConsumeChar(CurPtr, Size, Result);
     PrevCh = C;
+    if (LangOpts.HLSL && C == '.' && (*CurPtr == 'x' || *CurPtr == 'r')) {
+      CurPtr--;
+      break;
+    }
     C = getCharAndSize(CurPtr, Size);
   }
 
diff --git a/clang/lib/Lex/LiteralSupport.cpp b/clang/lib/Lex/LiteralSupport.cpp
index 2de307883b97ce7..0a78638f680511d 100644
--- a/clang/lib/Lex/LiteralSupport.cpp
+++ b/clang/lib/Lex/LiteralSupport.cpp
@@ -930,7 +930,11 @@ NumericLiteralParser::NumericLiteralParser(StringRef TokSpelling,
   // and FP constants (specifically, the 'pp-number' regex), and assumes that
   // the byte at "*end" is both valid and not part of the regex.  Because of
   // this, it doesn't have to check for 'overscan' in various places.
-  if (isPreprocessingNumberBody(*ThisTokEnd)) {
+  // Note: For HLSL, the end token is allowed to be '.' which would be in the
+  // 'pp-number' regex. This is required to support vector swizzles on numeric
+  // constants (i.e. 1.xx or 1.5f.rrr).
+  if (isPreprocessingNumberBody(*ThisTokEnd) &&
+      !(LangOpts.HLSL && *ThisTokEnd == '.')) {
     Diags.Report(TokLoc, diag::err_lexing_numeric);
     hadError = true;
     return;
diff --git a/clang/lib/Sema/SemaExprMember.cpp b/clang/lib/Sema/SemaExprMember.cpp
index fe92215ae46776f..2de5579df27153d 100644
--- a/clang/lib/Sema/SemaExprMember.cpp
+++ b/clang/lib/Sema/SemaExprMember.cpp
@@ -1687,6 +1687,16 @@ static ExprResult LookupMemberExpr(Sema &S, LookupResult &R,
                             ObjCImpDecl, HasTemplateArgs, TemplateKWLoc);
   }
 
+  // HLSL supports implicit conversion of scalar types to single element vector
+  // rvalues in member expressions.
+  if (S.getLangOpts().HLSL && BaseType->isScalarType()) {
+    QualType VectorTy = S.Context.getExtVectorType(BaseType, 1);
+    BaseExpr = S.ImpCastExprToType(BaseExpr.get(), VectorTy, CK_VectorSplat,
+                                   BaseExpr.get()->getValueKind());
+    return LookupMemberExpr(S, R, BaseExpr, IsArrow, OpLoc, SS,
+                                ObjCImpDecl, HasTemplateArgs, TemplateKWLoc);
+  }
+
   S.Diag(OpLoc, diag::err_typecheck_member_reference_struct_union)
     << BaseType << BaseExpr.get()->getSourceRange() << MemberLoc;
 
diff --git a/clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl b/clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl
new file mode 100644
index 000000000000000..56c6ab537261769
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl
@@ -0,0 +1,14 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library  -x hlsl -finclude-default-header -verify %s
+
+
+int2 ToTwoInts(int V) {
+  return V.xy; // expected-error{{vector component access exceeds type 'int __attribute__((ext_vector_type(1)))' (vector of 1 'int' value)}}
+}
+
+float2 ToTwoFloats(float V) {
+  return V.rg; // expected-error{{vector component access exceeds type 'float __attribute__((ext_vector_type(1)))' (vector of 1 'float' value)}}
+}
+
+int4 SomeNonsense(int V) {
+  return V.poop; // expected-error{{illegal vector component name 'p'}}
+}
diff --git a/clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzles.hlsl b/clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzles.hlsl
new file mode 100644
index 000000000000000..b25d73499fdba0f
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzles.hlsl
@@ -0,0 +1,78 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library  -x hlsl \
+// RUN:   -finclude-default-header -ast-dump %s | FileCheck %s
+
+
+// CHECK: ExtVectorElementExpr {{.*}} 'int __attribute__((ext_vector_type(2)))' xx
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int __attribute__((ext_vector_type(1)))' lvalue <VectorSplat>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} 'V' 'int'
+
+int2 ToTwoInts(int V){
+  return V.xx;
+}
+
+// CHECK: ExtVectorElementExpr {{.*}} 'float __attribute__((ext_vector_type(4)))' rrrr
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float __attribute__((ext_vector_type(1)))' lvalue <VectorSplat>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'float' lvalue ParmVar {{.*}} 'V' 'float'
+
+
+float4 ToThreeFloats(float V){
+  return V.rrrr;
+}
+
+// CHECK: ExtVectorElementExpr {{.*}} 'int __attribute__((ext_vector_type(2)))' xx
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int __attribute__((ext_vector_type(1)))' <VectorSplat>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 1
+
+int2 FillOne(){
+  return 1.xx;
+}
+
+
+// CHECK: ExtVectorElementExpr {{.*}} 'unsigned int __attribute__((ext_vector_type(3)))' xxx
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int __attribute__((ext_vector_type(1)))' <VectorSplat>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'unsigned int' 1
+
+uint3 FillOneUnsigned(){
+  return 1u.xxx;
+}
+
+// CHECK: ExtVectorElementExpr {{.*}} 'unsigned long __attribute__((ext_vector_type(4)))' xxxx
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned long __attribute__((ext_vector_type(1)))' <VectorSplat>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'unsigned long' 1
+
+vector<uint64_t,4> FillOneUnsignedLong(){
+  return 1ul.xxxx;
+}
+
+// CHECK: ExtVectorElementExpr {{.*}} 'double __attribute__((ext_vector_type(2)))' rr
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'double __attribute__((ext_vector_type(1)))' <VectorSplat>
+// CHECK-NEXT: FloatingLiteral {{.*}} 'double' 2.500000e+00
+
+double2 FillTwoPointFive(){
+  return 2.5.rr;
+}
+
+// CHECK: ExtVectorElementExpr {{.*}} 'double __attribute__((ext_vector_type(3)))' rrr
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'double __attribute__((ext_vector_type(1)))' <VectorSplat>
+// CHECK-NEXT: FloatingLiteral {{.*}} 'double' 5.000000e-01
+
+double3 FillOneHalf(){
+  return .5.rrr;
+}
+
+// CHECK: ExtVectorElementExpr {{.*}} 'float __attribute__((ext_vector_type(4)))' rrrr
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float __attribute__((ext_vector_type(1)))' <VectorSplat>
+// CHECK-NEXT: FloatingLiteral {{.*}} 'float' 2.500000e+00
+
+float4 FillTwoPointFiveFloat(){
+  return 2.5f.rrrr;
+}
+
+// CHECK: InitListExpr {{.*}} 'vector<float, 1>':'float __attribute__((ext_vector_type(1)))'
+// CHECK-NEXT: ExtVectorElementExpr {{.*}} 'float' r
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float __attribute__((ext_vector_type(1)))' <VectorSplat>
+// CHECK-NEXT: FloatingLiteral {{.*}} 'float' 5.000000e-01
+
+vector<float, 1> FillOneHalfFloat(){
+  return .5f.r;
+}

@github-actions
Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 6bbccc0bcb36689507ba98ef338527d43334c7e7 2fa183da3991c0bc4da6163609331d198f4a37af -- clang/lib/Lex/Lexer.cpp clang/lib/Lex/LiteralSupport.cpp clang/lib/Sema/SemaExprMember.cpp
View the diff from clang-format here.
diff --git a/clang/lib/Sema/SemaExprMember.cpp b/clang/lib/Sema/SemaExprMember.cpp
index 2de5579df271..96db63739f97 100644
--- a/clang/lib/Sema/SemaExprMember.cpp
+++ b/clang/lib/Sema/SemaExprMember.cpp
@@ -1693,8 +1693,8 @@ static ExprResult LookupMemberExpr(Sema &S, LookupResult &R,
     QualType VectorTy = S.Context.getExtVectorType(BaseType, 1);
     BaseExpr = S.ImpCastExprToType(BaseExpr.get(), VectorTy, CK_VectorSplat,
                                    BaseExpr.get()->getValueKind());
-    return LookupMemberExpr(S, R, BaseExpr, IsArrow, OpLoc, SS,
-                                ObjCImpDecl, HasTemplateArgs, TemplateKWLoc);
+    return LookupMemberExpr(S, R, BaseExpr, IsArrow, OpLoc, SS, ObjCImpDecl,
+                            HasTemplateArgs, TemplateKWLoc);
   }
 
   S.Diag(OpLoc, diag::err_typecheck_member_reference_struct_union)

@@ -1950,6 +1950,10 @@ bool Lexer::LexNumericConstant(Token &Result, const char *CurPtr) {
while (isPreprocessingNumberBody(C)) {
CurPtr = ConsumeChar(CurPtr, Size, Result);
PrevCh = C;
if (LangOpts.HLSL && C == '.' && (*CurPtr == 'x' || *CurPtr == 'r')) {
CurPtr--;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'd need to decrement by Size here in case there's an escaped newline or something:

  return 1\
.xx;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed!

}

int4 SomeNonsense(int V) {
return V.poop; // expected-error{{illegal vector component name 'p'}}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like another test that validates we catch multiple levels of dots. e.g.,

float2 HowManyFloats(float V) {
  return V.rr.rr;
}

or dots not followed by anything useful:

float2 WhatIsHappening(float V) {
  return V.;
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The really sad thing is that the first case there seems to be valid on both our reference compilers:

DXC on Compiler Explorer
FXC on Shader Playground

We may need to support that. I'll add tests either way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also add CodeGen tests that show we do... whatever that is.

Comment on lines +35 to +37
uint3 FillOneUnsigned(){
return 1u.xxx;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with:

auto HooBoy() {
  return 4wb.xxxx;
}

or with a float that has a trailing period followed by this Very Special™ suffix?

float3 err() {
  return 1..rrr;
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HooBoy fails because wb isn't a supported suffix in C++ mode (and all HLSL versions are C++-based). Also, even the most experimental HLSL version is only C++-11 based so the auto return type will always error. I've added a different integer test case:

int64_t4 HooBoy() {
  return 4l.xxxx;
}

err() works, which is also sadly expected despite the terrible syntax:
FXC on shader playground

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a codegen test for err()

@cor3ntin
Copy link
Contributor

This is the documentation i found.
Can you confirm the intent is only to support .x??? and r??? ?
It alludes to more options.
Maybe we need a isHLSLSwizzleStart function to avoid comparing to 'x' and 'r' in multiple places.

https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx9-graphics-reference-asm-vs-registers-modifiers-source-swizzling
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx9-graphics-reference-asm-ps-registers-modifiers-source-register-swizzling

I don't suppose there is a specification / grammar for these things?

@llvm-beanz
Copy link
Collaborator Author

This is the documentation i found. Can you confirm the intent is only to support .x??? and r??? ? It alludes to more options. Maybe we need a isHLSLSwizzleStart function to avoid comparing to 'x' and 'r' in multiple places.

The documentation for this is bad. Single element vectors can only have x and r components, so it should never be anything else. Adding a helper is probably worth doing just for code simplicity and expressivity. We also don't support the r suffix, so there shouldn't be any syntactic ambiguity with things like 1.r being interpreted as fixed point.

https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx9-graphics-reference-asm-vs-registers-modifiers-source-swizzling https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx9-graphics-reference-asm-ps-registers-modifiers-source-register-swizzling

There is also this page:
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math

It doesn't really add much.

I don't suppose there is a specification / grammar for these things?

Working on it. I've been working to draft a language specification over here. I haven't yet gotten to the vector grammar. I'm unfortunately chasing down too many paths at once trying to both get enough syntax supported that we can generate real programs to flesh out codegen, and trying to get the language semantics documented and implemented in a way that conforms with the reference compilers.

I'm digging myself out from the Dev Meeting, but I'll work on getting an update to this PR posted this week.

Thanks for the great feedback!

../clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl
}

int4 SomeNonsense(int V) {
return V.poop; // expected-error{{illegal vector component name 'p'}}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also add CodeGen tests that show we do... whatever that is.

// CHECK-NEXT: DeclRefExpr {{.*}} 'float' lvalue ParmVar {{.*}} 'V' 'float'


float4 ToThreeFloats(float V){
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this to four floats?

Comment on lines +35 to +37
uint3 FillOneUnsigned(){
return 1u.xxx;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a codegen test for err()

This change adds CodeGen support for writing to scalar l-values through
ext vector component l-value expressions. This is something only HLSL is
crazy enough to support.

../clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
../clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl
../clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzles.hlsl
@llvmbot llvmbot added the clang:codegen IR generation bugs: mangling, exceptions, etc. label Oct 27, 2023
@llvm-beanz
Copy link
Collaborator Author

@cor3ntin, I've in parallel been working on fleshing out the immediately relevant bits of our language spec. This PR describes the pp-number and vector-literal grammars roughly correctly to this change.

I think I've addressed all the other feedback on the PR, please let me know if there's anything else you think I should change.

@llvm-beanz
Copy link
Collaborator Author

Friendly ping @AaronBallman & @cor3ntin.

1 similar comment
@llvm-beanz
Copy link
Collaborator Author

Friendly ping @AaronBallman & @cor3ntin.

Copy link
Contributor

@cor3ntin cor3ntin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am happy with the lexing/sema changes.
Please give @AaronBallman the opportunity to look at the code gen bits.

Copy link
Collaborator

@AaronBallman AaronBallman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@llvm-beanz llvm-beanz merged commit 2630d72 into llvm:main Nov 29, 2023
llvm-beanz added a commit that referenced this pull request Nov 29, 2023
This fixes the test to handle the changes in the AST printer.
../clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzles.hlsl
@llvm-beanz llvm-beanz mentioned this pull request Nov 30, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

HLSL Scalar Constant Swizzles
5 participants