Skip to content

Commit d01d3dd

Browse files
committed
[HLSL] Implement array temporary support
In HLSL, function parameters are passed by value, including array parameters. This change introduces a new AST node to represent array temporary expressions. These expressions behave as lvalues referring to temporary arrays, and decay to pointers for overload resolution and code generation. The behavior of HLSL function calls is documented in the [draft language specification](https://microsoft.github.io/hlsl-specs/specs/hlsl.pdf) under the Expr.Post.Call heading. Additionally, the design of this implementation approach is documented in [Clang's documentation](https://clang.llvm.org/docs/HLSL/FunctionCalls.html).
1 parent 6e4930c commit d01d3dd

File tree

20 files changed

+213
-1
lines changed

20 files changed

+213
-1
lines changed

clang/include/clang/AST/Expr.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6651,6 +6651,44 @@ class RecoveryExpr final : public Expr,
66516651
friend class ASTStmtWriter;
66526652
};
66536653

6654+
/// HLSLArrayTemporaryExpr - In HLSL, default parameter passing is by value
6655+
/// including for arrays. This AST node represents a materialized temporary of a
6656+
/// constant-size array.
6657+
class HLSLArrayTemporaryExpr : public Expr {
6658+
Expr *SourceExpr;
6659+
6660+
HLSLArrayTemporaryExpr(Expr *S)
6661+
: Expr(HLSLArrayTemporaryExprClass, S->getType(), VK_LValue, OK_Ordinary),
6662+
SourceExpr(S) {}
6663+
6664+
HLSLArrayTemporaryExpr(EmptyShell Empty)
6665+
: Expr(HLSLArrayTemporaryExprClass, Empty), SourceExpr(nullptr) {}
6666+
6667+
public:
6668+
static HLSLArrayTemporaryExpr *Create(const ASTContext &Ctx, Expr *S);
6669+
static HLSLArrayTemporaryExpr *CreateEmpty(const ASTContext &Ctx);
6670+
6671+
const Expr *getSourceExpr() const { return SourceExpr; }
6672+
Expr *getSourceExpr() { return SourceExpr; }
6673+
void setSourceExpr(Expr *S) { SourceExpr = S; }
6674+
6675+
SourceLocation getBeginLoc() const { return SourceExpr->getBeginLoc(); }
6676+
6677+
SourceLocation getEndLoc() const { return SourceExpr->getEndLoc(); }
6678+
6679+
static bool classof(const Stmt *T) {
6680+
return T->getStmtClass() == HLSLArrayTemporaryExprClass;
6681+
}
6682+
6683+
// Iterators
6684+
child_range children() {
6685+
return child_range(child_iterator(), child_iterator());
6686+
}
6687+
const_child_range children() const {
6688+
return const_child_range(const_child_iterator(), const_child_iterator());
6689+
}
6690+
};
6691+
66546692
} // end namespace clang
66556693

66566694
#endif // LLVM_CLANG_AST_EXPR_H

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3171,6 +3171,8 @@ DEF_TRAVERSE_STMT(OMPTargetParallelGenericLoopDirective,
31713171
DEF_TRAVERSE_STMT(OMPErrorDirective,
31723172
{ TRY_TO(TraverseOMPExecutableDirective(S)); })
31733173

3174+
// HLSLArrayTemporaryExpr supplies no node-specific traversal code (the body
// is empty). NOTE(review): confirm that the wrapped source expression is
// still reached through whatever the node reports from children().
DEF_TRAVERSE_STMT(HLSLArrayTemporaryExpr, {})
3175+
31743176
// OpenMP clauses.
31753177
template <typename Derived>
31763178
bool RecursiveASTVisitor<Derived>::TraverseOMPClause(OMPClause *C) {

clang/include/clang/Basic/StmtNodes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,6 @@ def OMPTargetTeamsGenericLoopDirective : StmtNode<OMPLoopDirective>;
295295
def OMPParallelGenericLoopDirective : StmtNode<OMPLoopDirective>;
296296
def OMPTargetParallelGenericLoopDirective : StmtNode<OMPLoopDirective>;
297297
def OMPErrorDirective : StmtNode<OMPExecutableDirective>;
298+
299+
// HLSL Extensions
300+
// Expression node for the materialized temporary of a constant-size array.
def HLSLArrayTemporaryExpr : StmtNode<Expr>;

clang/lib/AST/Expr.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3569,6 +3569,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
35693569
case ConceptSpecializationExprClass:
35703570
case RequiresExprClass:
35713571
case SYCLUniqueStableNameExprClass:
3572+
case HLSLArrayTemporaryExprClass:
35723573
// These never have a side-effect.
35733574
return false;
35743575

@@ -5227,3 +5228,13 @@ OMPIteratorExpr *OMPIteratorExpr::CreateEmpty(const ASTContext &Context,
52275228
alignof(OMPIteratorExpr));
52285229
return new (Mem) OMPIteratorExpr(EmptyShell(), NumIterators);
52295230
}
5231+
5232+
/// Allocates a new HLSLArrayTemporaryExpr through \p Ctx that wraps \p Base.
HLSLArrayTemporaryExpr *
HLSLArrayTemporaryExpr::Create(const ASTContext &Ctx, Expr *Base) {
  auto *Node = new (Ctx) HLSLArrayTemporaryExpr(Base);
  return Node;
}
5236+
5237+
/// Allocates an empty-shell HLSLArrayTemporaryExpr through \p Ctx; the source
/// expression is left unset.
HLSLArrayTemporaryExpr *
HLSLArrayTemporaryExpr::CreateEmpty(const ASTContext &Ctx) {
  auto *Node = new (Ctx) HLSLArrayTemporaryExpr(EmptyShell());
  return Node;
}

clang/lib/AST/ExprClassification.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
148148
case Expr::OMPArraySectionExprClass:
149149
case Expr::OMPArrayShapingExprClass:
150150
case Expr::OMPIteratorExprClass:
151+
case Expr::HLSLArrayTemporaryExprClass:
151152
return Cl::CL_LValue;
152153

153154
// C99 6.5.2.5p5 says that compound literals are lvalues.

clang/lib/AST/ExprConstant.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16044,6 +16044,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
1604416044
case Expr::CoyieldExprClass:
1604516045
case Expr::SYCLUniqueStableNameExprClass:
1604616046
case Expr::CXXParenListInitExprClass:
16047+
case Expr::HLSLArrayTemporaryExprClass:
1604716048
return ICEDiag(IK_NotICE, E->getBeginLoc());
1604816049

1604916050
case Expr::InitListExprClass: {

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4701,6 +4701,10 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
47014701
E = cast<ConstantExpr>(E)->getSubExpr();
47024702
goto recurse;
47034703

4704+
case Expr::HLSLArrayTemporaryExprClass:
4705+
E = cast<HLSLArrayTemporaryExpr>(E)->getSourceExpr();
4706+
goto recurse;
4707+
47044708
// FIXME: invent manglings for all these.
47054709
case Expr::BlockExprClass:
47064710
case Expr::ChooseExprClass:

clang/lib/AST/StmtPrinter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2749,6 +2749,10 @@ void StmtPrinter::VisitAsTypeExpr(AsTypeExpr *Node) {
27492749
OS << ")";
27502750
}
27512751

2752+
void StmtPrinter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *Node) {
  // The temporary adds no text of its own: print only the wrapped expression.
  Expr *Wrapped = Node->getSourceExpr();
  PrintExpr(Wrapped);
}
2755+
27522756
//===----------------------------------------------------------------------===//
27532757
// Stmt method implementations
27542758
//===----------------------------------------------------------------------===//

clang/lib/AST/StmtProfile.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,6 +2433,15 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
24332433
}
24342434
}
24352435

2436+
//===----------------------------------------------------------------------===//
2437+
// HLSL AST Nodes
2438+
//===----------------------------------------------------------------------===//
2439+
2440+
void StmtProfiler::VisitHLSLArrayTemporaryExpr(
2441+
const HLSLArrayTemporaryExpr *S) {
2442+
VisitExpr(S);
2443+
}
2444+
24362445
void Stmt::Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
24372446
bool Canonical, bool ProfileLambdaExpr) const {
24382447
StmtProfilerWithPointers Profiler(ID, Context, Canonical, ProfileLambdaExpr);

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,6 +1590,7 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E,
15901590
case Expr::CXXUuidofExprClass:
15911591
return EmitCXXUuidofLValue(cast<CXXUuidofExpr>(E));
15921592
case Expr::LambdaExprClass:
1593+
case Expr::HLSLArrayTemporaryExprClass:
15931594
return EmitAggExprToLValue(E);
15941595

15951596
case Expr::ExprWithCleanupsClass: {

clang/lib/CodeGen/CGExprAgg.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
235235
RValue Res = CGF.EmitAtomicExpr(E);
236236
EmitFinalDestCopy(E->getType(), Res);
237237
}
238+
239+
void VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *E);
238240
};
239241
} // end anonymous namespace.
240242

@@ -1923,6 +1925,10 @@ void AggExprEmitter::VisitDesignatedInitUpdateExpr(DesignatedInitUpdateExpr *E)
19231925
VisitInitListExpr(E->getUpdater());
19241926
}
19251927

1928+
void AggExprEmitter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *E) {
  // Emit the temporary by emitting the expression it wraps.
  Expr *Src = E->getSourceExpr();
  Visit(Src);
}
1931+
19261932
//===----------------------------------------------------------------------===//
19271933
// Entry Points into this File
19281934
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaExceptionSpec.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,6 +1414,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
14141414
case Expr::SourceLocExprClass:
14151415
case Expr::ConceptSpecializationExprClass:
14161416
case Expr::RequiresExprClass:
1417+
case Expr::HLSLArrayTemporaryExprClass:
14171418
// These expressions can never throw.
14181419
return CT_Cannot;
14191420

clang/lib/Sema/SemaInit.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10524,6 +10524,11 @@ Sema::PerformCopyInitialization(const InitializedEntity &Entity,
1052410524
Expr *InitE = Init.get();
1052510525
assert(InitE && "No initialization expression?");
1052610526

10527+
if (LangOpts.HLSL)
10528+
if (auto AdjTy = dyn_cast<DecayedType>(Entity.getType()))
10529+
if (AdjTy->getOriginalType()->isConstantArrayType())
10530+
InitE = HLSLArrayTemporaryExpr::Create(getASTContext(), InitE);
10531+
1052710532
if (EqualLoc.isInvalid())
1052810533
EqualLoc = InitE->getBeginLoc();
1052910534

clang/lib/Sema/TreeTransform.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15461,6 +15461,19 @@ TreeTransform<Derived>::TransformCapturedStmt(CapturedStmt *S) {
1546115461
return getSema().ActOnCapturedRegionEnd(Body.get());
1546215462
}
1546315463

15464+
template <typename Derived>
15465+
ExprResult TreeTransform<Derived>::TransformHLSLArrayTemporaryExpr(
15466+
HLSLArrayTemporaryExpr *E) {
15467+
ExprResult SrcExpr = getDerived().TransformExpr(E->getSourceExpr());
15468+
if (SrcExpr.isInvalid())
15469+
return ExprError();
15470+
15471+
if (!getDerived().AlwaysRebuild() && SrcExpr.get() == E->getSourceExpr())
15472+
return E;
15473+
15474+
return HLSLArrayTemporaryExpr::Create(getSema().Context, SrcExpr.get());
15475+
}
15476+
1546415477
} // end namespace clang
1546515478

1546615479
#endif // LLVM_CLANG_LIB_SEMA_TREETRANSFORM_H

clang/lib/Serialization/ASTReaderStmt.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2776,6 +2776,14 @@ void ASTStmtReader::VisitOMPTargetParallelGenericLoopDirective(
27762776
VisitOMPLoopDirective(D);
27772777
}
27782778

2779+
//===----------------------------------------------------------------------===//
2780+
// HLSL AST Nodes
2781+
//===----------------------------------------------------------------------===//
2782+
2783+
void ASTStmtReader::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *S) {
  // Only the generic expression state is read.
  // NOTE(review): the source expression is never restored in this body
  // (no setSourceExpr call); confirm deserialized nodes are usable.
  this->VisitExpr(S);
}
2786+
27792787
//===----------------------------------------------------------------------===//
27802788
// ASTReader Implementation
27812789
//===----------------------------------------------------------------------===//

clang/lib/Serialization/ASTWriterStmt.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2825,6 +2825,14 @@ void ASTStmtWriter::VisitOMPTargetParallelGenericLoopDirective(
28252825
Code = serialization::STMT_OMP_TARGET_PARALLEL_GENERIC_LOOP_DIRECTIVE;
28262826
}
28272827

2828+
//===----------------------------------------------------------------------===//
2829+
// HLSL AST Nodes
2830+
//===----------------------------------------------------------------------===//
2831+
2832+
void ASTStmtWriter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *S) {
  // Only the generic expression state is written.
  // NOTE(review): unlike the preceding visitor, this body assigns no record
  // `Code` and does not emit the source expression; confirm this node can be
  // written to and read back from an AST file.
  this->VisitExpr(S);
}
2835+
28282836
//===----------------------------------------------------------------------===//
28292837
// ASTWriter Implementation
28302838
//===----------------------------------------------------------------------===//

clang/lib/StaticAnalyzer/Core/ExprEngine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1821,7 +1821,8 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred,
18211821
case Stmt::OMPTargetParallelGenericLoopDirectiveClass:
18221822
case Stmt::CapturedStmtClass:
18231823
case Stmt::OMPUnrollDirectiveClass:
1824-
case Stmt::OMPMetaDirectiveClass: {
1824+
case Stmt::OMPMetaDirectiveClass:
1825+
case Stmt::HLSLArrayTemporaryExprClass: {
18251826
const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState());
18261827
Engine.addAbortedBlock(node, currBldrCtx->getBlock());
18271828
break;
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s

// Code generation test for by-value array arguments: each call site is
// expected to copy the argument array into a second alloca and pass the
// decayed pointer to that copy to the callee.

void fn(float x[2]) { }

// CHECK-LABEL: define void {{.*}}call{{.*}}
// CHECK: [[Arr:%.*]] = alloca [2 x float]
// CHECK: [[Tmp:%.*]] = alloca [2 x float]
// CHECK: call void @llvm.memset.p0.i32(ptr align 4 [[Arr]], i8 0, i32 8, i1 false)
// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 8, i1 false)
// CHECK: [[Decay:%.*]] = getelementptr inbounds [2 x float], ptr [[Tmp]], i32 0, i32 0
// CHECK: call void {{.*}}fn{{.*}}(ptr noundef [[Decay]])
void call() {
  float Arr[2] = {0, 0};
  fn(Arr);
}

struct Obj {
  float V;
  int X;
};

void fn2(Obj O[4]) { }

// CHECK-LABEL: define void {{.*}}call2{{.*}}
// CHECK: [[Arr:%.*]] = alloca [4 x %struct.Obj]
// CHECK: [[Tmp:%.*]] = alloca [4 x %struct.Obj]
// CHECK: call void @llvm.memset.p0.i32(ptr align 4 [[Arr]], i8 0, i32 32, i1 false)
// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 32, i1 false)
// CHECK: [[Decay:%.*]] = getelementptr inbounds [4 x %struct.Obj], ptr [[Tmp]], i32 0, i32 0
// CHECK: call void {{.*}}fn2{{.*}}(ptr noundef [[Decay]])
void call2() {
  Obj Arr[4] = {};
  fn2(Arr);
}


void fn3(float x[2][2]) { }

// CHECK-LABEL: define void {{.*}}call3{{.*}}
// CHECK: [[Arr:%.*]] = alloca [2 x [2 x float]]
// CHECK: [[Tmp:%.*]] = alloca [2 x [2 x float]]
// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Arr]], ptr align 4 {{.*}}, i32 16, i1 false)
// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 16, i1 false)
// CHECK: [[Decay:%.*]] = getelementptr inbounds [2 x [2 x float]], ptr [[Tmp]], i32 0, i32 0
// CHECK: call void {{.*}}fn3{{.*}}(ptr noundef [[Decay]])
void call3() {
  float Arr[2][2] = {{0, 0}, {1,1}};
  fn3(Arr);
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -ast-dump %s | FileCheck %s

// AST dump test for by-value array arguments: each call argument is expected
// to be an ArrayToPointerDecay cast applied to an HLSLArrayTemporaryExpr
// lvalue of the original array type.

void fn(float x[2]) { }

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float *)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float *)' lvalue Function {{.*}} 'fn' 'void (float *)'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float *' <ArrayToPointerDecay>
// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'float[2]' lvalue

void call() {
  float Arr[2] = {0, 0};
  fn(Arr);
}

struct Obj {
  float V;
  int X;
};

void fn2(Obj O[4]) { }

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(Obj *)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (Obj *)' lvalue Function {{.*}} 'fn2' 'void (Obj *)'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'Obj *' <ArrayToPointerDecay>
// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'Obj[4]' lvalue

void call2() {
  Obj Arr[4] = {};
  fn2(Arr);
}


void fn3(float x[2][2]) { }

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float (*)[2])' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float (*)[2])' lvalue Function {{.*}} 'fn3' 'void (float (*)[2])'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float (*)[2]' <ArrayToPointerDecay>
// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'float[2][2]' lvalue

void call3() {
  float Arr[2][2] = {{0, 0}, {1,1}};
  fn3(Arr);
}

clang/tools/libclang/CXCursor.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent,
335335
case Stmt::ObjCSubscriptRefExprClass:
336336
case Stmt::RecoveryExprClass:
337337
case Stmt::SYCLUniqueStableNameExprClass:
338+
case Stmt::HLSLArrayTemporaryExprClass:
338339
K = CXCursor_UnexposedExpr;
339340
break;
340341

0 commit comments

Comments
 (0)