Skip to content

[NFC] Decompose function input types #10264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 15, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 68 additions & 13 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -2297,10 +2297,44 @@ getSILFunctionLanguage(SILFunctionTypeRepresentation rep) {
class AnyFunctionType : public TypeBase {
const Type Input;
const Type Output;

const unsigned NumParams;

public:
using Representation = FunctionTypeRepresentation;

class Param {
public:
explicit Param(Type t) : Ty(t), Label(Identifier()), Flags() {}
explicit Param(const TupleTypeElt &tte)
: Ty(tte.getType()), Label(tte.getName()),
Flags(tte.getParameterFlags()) {}

private:
/// The type of the parameter. For a variadic parameter, this is the
/// element type.
Type Ty;

// The label associated with the parameter, if any.
Identifier Label;

/// Parameter specific flags.
ParameterTypeFlags Flags = {};

public:
Type getType() const { return Ty; }

Identifier getLabel() const { return Label; }

/// Whether the parameter is varargs
bool isVariadic() const { return Flags.isVariadic(); }

/// Whether the parameter is marked '@autoclosure'
bool isAutoClosure() const { return Flags.isAutoClosure(); }

/// Whether the parameter is marked '@escaping'
bool isEscaping() const { return Flags.isEscaping(); }
};

/// \brief A class which abstracts out some details necessary for
/// making a call.
class ExtInfo {
Expand Down Expand Up @@ -2442,16 +2476,18 @@ class AnyFunctionType : public TypeBase {
protected:
AnyFunctionType(TypeKind Kind, const ASTContext *CanTypeContext,
Type Input, Type Output, RecursiveTypeProperties properties,
const ExtInfo &Info)
: TypeBase(Kind, CanTypeContext, properties), Input(Input), Output(Output) {
unsigned NumParams, const ExtInfo &Info)
: TypeBase(Kind, CanTypeContext, properties), Input(Input), Output(Output),
NumParams(NumParams) {
AnyFunctionTypeBits.ExtInfo = Info.Bits;
}

public:

Type getInput() const { return Input; }
Type getResult() const { return Output; }

ArrayRef<AnyFunctionType::Param> getParams() const;
unsigned getNumParams() const { return NumParams; }

ExtInfo getExtInfo() const {
return ExtInfo(AnyFunctionTypeBits.ExtInfo);
}
Expand Down Expand Up @@ -2501,22 +2537,32 @@ END_CAN_TYPE_WRAPPER(AnyFunctionType, Type)
///
/// For example:
/// let x : (Float, Int) -> Int
class FunctionType : public AnyFunctionType {
class FunctionType final : public AnyFunctionType,
private llvm::TrailingObjects<FunctionType, AnyFunctionType::Param> {
friend TrailingObjects;

public:
/// 'Constructor' Factory Function
static FunctionType *get(Type Input, Type Result) {
return get(Input, Result, ExtInfo());
}

static FunctionType *get(Type Input, Type Result, const ExtInfo &Info);



// Retrieve the input parameters of this function type.
ArrayRef<AnyFunctionType::Param> getParams() const {
return {getTrailingObjects<AnyFunctionType::Param>(), getNumParams()};
}

// Implement isa/cast/dyncast/etc.
static bool classof(const TypeBase *T) {
return T->getKind() == TypeKind::Function;
}

private:
FunctionType(Type Input, Type Result,
FunctionType(ArrayRef<AnyFunctionType::Param> params,
Type Input, Type Result,
RecursiveTypeProperties properties,
const ExtInfo &Info);
};
Expand Down Expand Up @@ -2587,25 +2633,34 @@ std::string getParamListAsString(ArrayRef<CallArgParam> parameters);
/// on those parameters and dependent member types thereof. The input and
/// output types of the generic function can be expressed in terms of those
/// generic parameters.
class GenericFunctionType : public AnyFunctionType,
public llvm::FoldingSetNode
{
class GenericFunctionType final : public AnyFunctionType,
public llvm::FoldingSetNode,
private llvm::TrailingObjects<GenericFunctionType, AnyFunctionType::Param> {
friend TrailingObjects;

GenericSignature *Signature;

/// Construct a new generic function type.
GenericFunctionType(GenericSignature *sig,
ArrayRef<AnyFunctionType::Param> params,
Type input,
Type result,
const ExtInfo &info,
const ASTContext *ctx,
RecursiveTypeProperties properties);

public:
/// Create a new generic function type.
static GenericFunctionType *get(GenericSignature *sig,
Type input,
Type result,
const ExtInfo &info);


// Retrieve the input parameters of this function type.
ArrayRef<AnyFunctionType::Param> getParams() const {
return {getTrailingObjects<AnyFunctionType::Param>(), getNumParams()};
}

/// Retrieve the generic signature of this function type.
GenericSignature *getGenericSignature() const {
return Signature;
Expand Down
71 changes: 58 additions & 13 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3136,6 +3136,17 @@ getGenericFunctionRecursiveProperties(Type Input, Type Result) {
return properties;
}

ArrayRef<AnyFunctionType::Param> AnyFunctionType::getParams() const {
switch (getKind()) {
case TypeKind::Function:
return cast<FunctionType>(this)->getParams();
case TypeKind::GenericFunction:
return cast<GenericFunctionType>(this)->getParams();
default:
llvm_unreachable("Undefined function type");
}
}

AnyFunctionType *AnyFunctionType::withExtInfo(ExtInfo info) const {
if (isa<FunctionType>(this))
return FunctionType::get(getInput(), getResult(), info);
Expand All @@ -3150,6 +3161,30 @@ AnyFunctionType *AnyFunctionType::withExtInfo(ExtInfo info) const {
llvm_unreachable("unhandled function type");
}

static SmallVector<AnyFunctionType::Param, 4> decomposeInputType(Type type) {
SmallVector<AnyFunctionType::Param, 4> result;
switch (type->getKind()) {
case TypeKind::Tuple: {
auto tupleTy = cast<TupleType>(type.getPointer());
for (auto &elt : tupleTy->getElements()) {
AnyFunctionType::Param param(elt);
result.push_back(param);
}
return result;
}

case TypeKind::Paren: {
auto ty = cast<ParenType>(type.getPointer())->getUnderlyingType();
result.push_back(AnyFunctionType::Param(ty));
return result;
}

default:
result.push_back(AnyFunctionType::Param(type));
return result;
}
}

FunctionType *FunctionType::get(Type Input, Type Result,
const ExtInfo &Info) {
auto properties = getFunctionRecursiveProperties(Input, Result);
Expand All @@ -3161,21 +3196,28 @@ FunctionType *FunctionType::get(Type Input, Type Result,
FunctionType *&Entry
= C.Impl.getArena(arena).FunctionTypes[{Input, {Result, attrKey} }];
if (Entry) return Entry;

return Entry = new (C, arena) FunctionType(Input, Result,
properties,
Info);

auto params = decomposeInputType(Input);
void *mem = C.Allocate(sizeof(FunctionType) +
sizeof(AnyFunctionType::Param) * params.size(),
alignof(FunctionType));
return Entry = new (mem) FunctionType(params, Input, Result,
properties, Info);
}

// If the input and result types are canonical, then so is the result.
FunctionType::FunctionType(Type input, Type output,
FunctionType::FunctionType(ArrayRef<AnyFunctionType::Param> params,
Type input, Type output,
RecursiveTypeProperties properties,
const ExtInfo &Info)
: AnyFunctionType(TypeKind::Function,
(input->isCanonical() && output->isCanonical())
? &input->getASTContext()
: nullptr,
input, output, properties, Info) {}
input, output, properties, params.size(), Info) {
std::uninitialized_copy(params.begin(), params.end(),
getTrailingObjects<AnyFunctionType::Param>());
}

void GenericFunctionType::Profile(llvm::FoldingSetNodeID &ID,
GenericSignature *sig,
Expand Down Expand Up @@ -3221,13 +3263,14 @@ GenericFunctionType::get(GenericSignature *sig,
= ctx.Impl.GenericFunctionTypes.FindNodeOrInsertPos(id, insertPos)) {
return result;
}

// Allocate storage for the object.
void *mem = ctx.Allocate(sizeof(GenericFunctionType),

auto params = decomposeInputType(input);
void *mem = ctx.Allocate(sizeof(GenericFunctionType) +
sizeof(AnyFunctionType::Param) * params.size(),
alignof(GenericFunctionType));

auto properties = getGenericFunctionRecursiveProperties(input, output);
auto result = new (mem) GenericFunctionType(sig, input, output, info,
auto result = new (mem) GenericFunctionType(sig, params, input, output, info,
isCanonical ? &ctx : nullptr,
properties);

Expand All @@ -3237,15 +3280,17 @@ GenericFunctionType::get(GenericSignature *sig,

GenericFunctionType::GenericFunctionType(
GenericSignature *sig,
ArrayRef<AnyFunctionType::Param> params,
Type input,
Type result,
const ExtInfo &info,
const ASTContext *ctx,
RecursiveTypeProperties properties)
: AnyFunctionType(TypeKind::GenericFunction, ctx, input, result,
properties, info),
Signature(sig)
{}
properties, params.size(), info), Signature(sig) {
std::uninitialized_copy(params.begin(), params.end(),
getTrailingObjects<AnyFunctionType::Param>());
}

GenericTypeParamType *GenericTypeParamType::get(unsigned depth, unsigned index,
const ASTContext &ctx) {
Expand Down