Skip to content

Commit adb0d8d

Browse files
authored
[Clang] Distinguish expanding-pack-in-place cases for SubstTemplateTypeParmTypes (#114220)
In 50e5411, we preserved the pack substitution index within SubstTemplateTypeParmType nodes and performed in-place expansions of packs such that type constraints on a lambda that serve as a pattern of a fold expression could be evaluated if the type constraints contain any packs that are expanded by the fold expression. However, we made an incorrect assumption of the condition under which in-place expansion should occur. For example, a SizeOfPackExpr case relies on SubstTemplateTypeParmType nodes being transformed to SubstTemplateTypeParmPackTypes rather than expanding them immediately in place. This fixes that by adding a flag to SubstTemplateTypeParmType to discriminate such in-place expansion situations. Fixes #113518
1 parent 9f79615 commit adb0d8d

File tree

9 files changed

+88
-22
lines changed

9 files changed

+88
-22
lines changed

clang/include/clang/AST/ASTContext.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1747,7 +1747,9 @@ class ASTContext : public RefCountedBase<ASTContext> {
17471747
QualType
17481748
getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl,
17491749
unsigned Index,
1750-
std::optional<unsigned> PackIndex) const;
1750+
std::optional<unsigned> PackIndex,
1751+
SubstTemplateTypeParmTypeFlag Flag =
1752+
SubstTemplateTypeParmTypeFlag::None) const;
17511753
QualType getSubstTemplateTypeParmPackType(Decl *AssociatedDecl,
17521754
unsigned Index, bool Final,
17531755
const TemplateArgument &ArgPack);

clang/include/clang/AST/PropertiesBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def Selector : PropertyType;
137137
def SourceLocation : PropertyType;
138138
def StmtRef : RefPropertyType<"Stmt"> { let ConstWhenWriting = 1; }
139139
def ExprRef : SubclassPropertyType<"Expr", StmtRef>;
140+
def SubstTemplateTypeParmTypeFlag : EnumPropertyType;
140141
def TemplateArgument : PropertyType;
141142
def TemplateArgumentKind : EnumPropertyType<"TemplateArgument::ArgKind">;
142143
def TemplateName : DefaultValuePropertyType;

clang/include/clang/AST/Type.h

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,6 +1802,15 @@ enum class AutoTypeKeyword {
18021802
GNUAutoType
18031803
};
18041804

1805+
enum class SubstTemplateTypeParmTypeFlag {
1806+
None,
1807+
1808+
/// Whether to expand the pack using the stored PackIndex in place. This is
1809+
/// useful for e.g. substituting into an atomic constraint expression, where
1810+
/// that expression is part of an unexpanded pack.
1811+
ExpandPacksInPlace,
1812+
};
1813+
18051814
enum class ArraySizeModifier;
18061815
enum class ElaboratedTypeKeyword;
18071816
enum class VectorKind;
@@ -2171,6 +2180,9 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
21712180
LLVM_PREFERRED_TYPE(bool)
21722181
unsigned HasNonCanonicalUnderlyingType : 1;
21732182

2183+
LLVM_PREFERRED_TYPE(SubstTemplateTypeParmTypeFlag)
2184+
unsigned SubstitutionFlag : 1;
2185+
21742186
// The index of the template parameter this substitution represents.
21752187
unsigned Index : 15;
21762188

@@ -6387,7 +6399,8 @@ class SubstTemplateTypeParmType final
63876399
Decl *AssociatedDecl;
63886400

63896401
SubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl,
6390-
unsigned Index, std::optional<unsigned> PackIndex);
6402+
unsigned Index, std::optional<unsigned> PackIndex,
6403+
SubstTemplateTypeParmTypeFlag Flag);
63916404

63926405
public:
63936406
/// Gets the type that was substituted for the template
@@ -6416,21 +6429,31 @@ class SubstTemplateTypeParmType final
64166429
return SubstTemplateTypeParmTypeBits.PackIndex - 1;
64176430
}
64186431

6432+
SubstTemplateTypeParmTypeFlag getSubstitutionFlag() const {
6433+
return static_cast<SubstTemplateTypeParmTypeFlag>(
6434+
SubstTemplateTypeParmTypeBits.SubstitutionFlag);
6435+
}
6436+
64196437
bool isSugared() const { return true; }
64206438
QualType desugar() const { return getReplacementType(); }
64216439

64226440
void Profile(llvm::FoldingSetNodeID &ID) {
64236441
Profile(ID, getReplacementType(), getAssociatedDecl(), getIndex(),
6424-
getPackIndex());
6442+
getPackIndex(), getSubstitutionFlag());
64256443
}
64266444

64276445
static void Profile(llvm::FoldingSetNodeID &ID, QualType Replacement,
64286446
const Decl *AssociatedDecl, unsigned Index,
6429-
std::optional<unsigned> PackIndex) {
6447+
std::optional<unsigned> PackIndex,
6448+
SubstTemplateTypeParmTypeFlag Flag) {
64306449
Replacement.Profile(ID);
64316450
ID.AddPointer(AssociatedDecl);
64326451
ID.AddInteger(Index);
64336452
ID.AddInteger(PackIndex ? *PackIndex - 1 : 0);
6453+
ID.AddInteger(llvm::to_underlying(Flag));
6454+
assert((Flag != SubstTemplateTypeParmTypeFlag::ExpandPacksInPlace ||
6455+
PackIndex) &&
6456+
"ExpandPacksInPlace needs a valid PackIndex");
64346457
}
64356458

64366459
static bool classof(const Type *T) {

clang/include/clang/AST/TypeProperties.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,11 +824,14 @@ let Class = SubstTemplateTypeParmType in {
824824
def : Property<"PackIndex", Optional<UInt32>> {
825825
let Read = [{ node->getPackIndex() }];
826826
}
827+
def : Property<"SubstitutionFlag", SubstTemplateTypeParmTypeFlag> {
828+
let Read = [{ node->getSubstitutionFlag() }];
829+
}
827830

828831
// The call to getCanonicalType here existed in ASTReader.cpp, too.
829832
def : Creator<[{
830833
return ctx.getSubstTemplateTypeParmType(
831-
replacementType, associatedDecl, Index, PackIndex);
834+
replacementType, associatedDecl, Index, PackIndex, SubstitutionFlag);
832835
}]>;
833836
}
834837

clang/lib/AST/ASTContext.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5303,10 +5303,11 @@ QualType ASTContext::getHLSLAttributedResourceType(
53035303
/// Retrieve a substitution-result type.
53045304
QualType ASTContext::getSubstTemplateTypeParmType(
53055305
QualType Replacement, Decl *AssociatedDecl, unsigned Index,
5306-
std::optional<unsigned> PackIndex) const {
5306+
std::optional<unsigned> PackIndex,
5307+
SubstTemplateTypeParmTypeFlag Flag) const {
53075308
llvm::FoldingSetNodeID ID;
53085309
SubstTemplateTypeParmType::Profile(ID, Replacement, AssociatedDecl, Index,
5309-
PackIndex);
5310+
PackIndex, Flag);
53105311
void *InsertPos = nullptr;
53115312
SubstTemplateTypeParmType *SubstParm =
53125313
SubstTemplateTypeParmTypes.FindNodeOrInsertPos(ID, InsertPos);
@@ -5316,7 +5317,7 @@ QualType ASTContext::getSubstTemplateTypeParmType(
53165317
!Replacement.isCanonical()),
53175318
alignof(SubstTemplateTypeParmType));
53185319
SubstParm = new (Mem) SubstTemplateTypeParmType(Replacement, AssociatedDecl,
5319-
Index, PackIndex);
5320+
Index, PackIndex, Flag);
53205321
Types.push_back(SubstParm);
53215322
SubstTemplateTypeParmTypes.InsertNode(SubstParm, InsertPos);
53225323
}

clang/lib/AST/ASTImporter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,8 +1628,8 @@ ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmType(
16281628
return ToReplacementTypeOrErr.takeError();
16291629

16301630
return Importer.getToContext().getSubstTemplateTypeParmType(
1631-
*ToReplacementTypeOrErr, *ReplacedOrErr, T->getIndex(),
1632-
T->getPackIndex());
1631+
*ToReplacementTypeOrErr, *ReplacedOrErr, T->getIndex(), T->getPackIndex(),
1632+
T->getSubstitutionFlag());
16331633
}
16341634

16351635
ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmPackType(

clang/lib/AST/Type.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4219,7 +4219,7 @@ static const TemplateTypeParmDecl *getReplacedParameter(Decl *D,
42194219

42204220
SubstTemplateTypeParmType::SubstTemplateTypeParmType(
42214221
QualType Replacement, Decl *AssociatedDecl, unsigned Index,
4222-
std::optional<unsigned> PackIndex)
4222+
std::optional<unsigned> PackIndex, SubstTemplateTypeParmTypeFlag Flag)
42234223
: Type(SubstTemplateTypeParm, Replacement.getCanonicalType(),
42244224
Replacement->getDependence()),
42254225
AssociatedDecl(AssociatedDecl) {
@@ -4230,6 +4230,10 @@ SubstTemplateTypeParmType::SubstTemplateTypeParmType(
42304230

42314231
SubstTemplateTypeParmTypeBits.Index = Index;
42324232
SubstTemplateTypeParmTypeBits.PackIndex = PackIndex ? *PackIndex + 1 : 0;
4233+
SubstTemplateTypeParmTypeBits.SubstitutionFlag = llvm::to_underlying(Flag);
4234+
assert((Flag != SubstTemplateTypeParmTypeFlag::ExpandPacksInPlace ||
4235+
PackIndex) &&
4236+
"ExpandPacksInPlace needs a valid PackIndex");
42334237
assert(AssociatedDecl != nullptr);
42344238
}
42354239

clang/lib/Sema/SemaTemplateInstantiate.cpp

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,14 +1661,17 @@ namespace {
16611661
QualType
16621662
TransformSubstTemplateTypeParmType(TypeLocBuilder &TLB,
16631663
SubstTemplateTypeParmTypeLoc TL) {
1664-
if (SemaRef.CodeSynthesisContexts.back().Kind !=
1665-
Sema::CodeSynthesisContext::ConstraintSubstitution)
1664+
const SubstTemplateTypeParmType *Type = TL.getTypePtr();
1665+
if (Type->getSubstitutionFlag() !=
1666+
SubstTemplateTypeParmTypeFlag::ExpandPacksInPlace)
16661667
return inherited::TransformSubstTemplateTypeParmType(TLB, TL);
16671668

1668-
auto PackIndex = TL.getTypePtr()->getPackIndex();
1669-
std::optional<Sema::ArgumentPackSubstitutionIndexRAII> SubstIndex;
1670-
if (SemaRef.ArgumentPackSubstitutionIndex == -1 && PackIndex)
1671-
SubstIndex.emplace(SemaRef, *PackIndex);
1669+
assert(Type->getPackIndex());
1670+
TemplateArgument TA = TemplateArgs(
1671+
Type->getReplacedParameter()->getDepth(), Type->getIndex());
1672+
assert(*Type->getPackIndex() + 1 <= TA.pack_size());
1673+
Sema::ArgumentPackSubstitutionIndexRAII SubstIndex(
1674+
SemaRef, TA.pack_size() - 1 - *Type->getPackIndex());
16721675

16731676
return inherited::TransformSubstTemplateTypeParmType(TLB, TL);
16741677
}
@@ -3147,7 +3150,11 @@ struct ExpandPackedTypeConstraints
31473150

31483151
using inherited = TreeTransform<ExpandPackedTypeConstraints>;
31493152

3150-
ExpandPackedTypeConstraints(Sema &SemaRef) : inherited(SemaRef) {}
3153+
const MultiLevelTemplateArgumentList &TemplateArgs;
3154+
3155+
ExpandPackedTypeConstraints(
3156+
Sema &SemaRef, const MultiLevelTemplateArgumentList &TemplateArgs)
3157+
: inherited(SemaRef), TemplateArgs(TemplateArgs) {}
31513158

31523159
using inherited::TransformTemplateTypeParmType;
31533160

@@ -3163,9 +3170,15 @@ struct ExpandPackedTypeConstraints
31633170

31643171
assert(SemaRef.ArgumentPackSubstitutionIndex != -1);
31653172

3173+
TemplateArgument Arg = TemplateArgs(T->getDepth(), T->getIndex());
3174+
3175+
std::optional<unsigned> PackIndex;
3176+
if (Arg.getKind() == TemplateArgument::Pack)
3177+
PackIndex = Arg.pack_size() - 1 - SemaRef.ArgumentPackSubstitutionIndex;
3178+
31663179
QualType Result = SemaRef.Context.getSubstTemplateTypeParmType(
3167-
TL.getType(), T->getDecl(), T->getIndex(),
3168-
SemaRef.ArgumentPackSubstitutionIndex);
3180+
TL.getType(), T->getDecl(), T->getIndex(), PackIndex,
3181+
SubstTemplateTypeParmTypeFlag::ExpandPacksInPlace);
31693182
SubstTemplateTypeParmTypeLoc NewTL =
31703183
TLB.push<SubstTemplateTypeParmTypeLoc>(Result);
31713184
NewTL.setNameLoc(TL.getNameLoc());
@@ -3224,8 +3237,8 @@ bool Sema::SubstTypeConstraint(
32243237
TemplateArgumentListInfo InstArgs;
32253238
InstArgs.setLAngleLoc(TemplArgInfo->LAngleLoc);
32263239
InstArgs.setRAngleLoc(TemplArgInfo->RAngleLoc);
3227-
if (ExpandPackedTypeConstraints(*this).SubstTemplateArguments(
3228-
TemplArgInfo->arguments(), InstArgs))
3240+
if (ExpandPackedTypeConstraints(*this, TemplateArgs)
3241+
.SubstTemplateArguments(TemplArgInfo->arguments(), InstArgs))
32293242
return true;
32303243

32313244
// The type of the original parameter.

clang/test/SemaCXX/cxx20-ctad-type-alias.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,3 +494,22 @@ template <typename V> using Alias = S<V>;
494494
Alias A(42);
495495

496496
} // namespace GH111508
497+
498+
namespace GH113518 {
499+
500+
template <class T, unsigned N> struct array {
501+
T value[N];
502+
};
503+
504+
template <typename Tp, typename... Up>
505+
array(Tp, Up...) -> array<Tp, 1 + sizeof...(Up)>;
506+
507+
template <typename T> struct ArrayType {
508+
template <unsigned size> using Array = array<T, size>;
509+
};
510+
511+
template <ArrayType<int>::Array array> void test() {}
512+
513+
void foo() { test<{1, 2, 3}>(); }
514+
515+
} // namespace GH113518

0 commit comments

Comments
 (0)