Skip to content

[SE-0042][AST/Sema/SILGen] Flattening the function type of unapplied method references #3836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,14 @@ class alignas(8) Expr {
< (1 << NumCheckedCastKindBits),
"unable to fit a CheckedCastKind in the given number of bits");

class FunctionConversionExprBitfields {
friend class FunctionConversionExpr;
unsigned : NumImplicitConversionExprBits;
unsigned Flattening : 1;
};
enum { NumFunctionConversionExprBits = NumImplicitConversionExprBits + 1 };
static_assert(NumFunctionConversionExprBits <= 32, "fits in an unsigned");

class CollectionUpcastConversionExprBitfields {
friend class CollectionUpcastConversionExpr;
unsigned : NumExprBits;
Expand Down Expand Up @@ -413,6 +421,7 @@ class alignas(8) Expr {
ApplyExprBitfields ApplyExprBits;
CallExprBitfields CallExprBits;
CheckedCastExprBitfields CheckedCastExprBits;
FunctionConversionExprBitfields FunctionConversionExprBits;
CollectionUpcastConversionExprBitfields CollectionUpcastConversionExprBits;
TupleShuffleExprBitfields TupleShuffleExprBits;
ObjCSelectorExprBitfields ObjCSelectorExprBits;
Expand Down Expand Up @@ -2761,8 +2770,22 @@ class UnresolvedTypeConversionExpr : public ImplicitConversionExpr {
class FunctionConversionExpr : public ImplicitConversionExpr {
public:
FunctionConversionExpr(Expr *subExpr, Type type)
: ImplicitConversionExpr(ExprKind::FunctionConversion, subExpr, type) {}

: ImplicitConversionExpr(ExprKind::FunctionConversion, subExpr, type) {
FunctionConversionExprBits.Flattening = false;
}

/// Set whether this function conversion flattens an unapplied member
/// function.
void setFlattening() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be better to make this its own type of Expr

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thought about this as well, but there's exactly one place where we set this flag …

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, so that place can make the new Expr type instead. These really are not function conversions...

FunctionConversionExprBits.Flattening = true;
}

/// Returns whether this function conversion flattens an unapplied member
/// function.
bool isFlattening() const {
return FunctionConversionExprBits.Flattening;
}

static bool classof(const Expr *E) {
return E->getKind() == ExprKind::FunctionConversion;
}
Expand Down
10 changes: 10 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -2297,6 +2297,16 @@ class AnyFunctionType : public TypeBase {
return getExtInfo().throws();
}

unsigned getCurryLevel() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should use the decl's getNumParameterLists() instead. If a function is declared as returning a function, we do not want to flatten it's type (eg func foo() -> () -> ())

unsigned Level = 0;
const AnyFunctionType *function = this;
while ((function = function->getResult()->getAs<AnyFunctionType>()))
++Level;
return Level;
}

AnyFunctionType *getUncurriedFunction();

/// Returns a new function type exactly like this one but with the ExtInfo
/// replaced.
AnyFunctionType *withExtInfo(ExtInfo info) const;
Expand Down
26 changes: 26 additions & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3131,6 +3131,32 @@ AnyFunctionType *AnyFunctionType::withExtInfo(ExtInfo info) const {
llvm_unreachable("unhandled function type");
}

AnyFunctionType *AnyFunctionType::getUncurriedFunction() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logic for this already exists in SIL TypeLowering.cpp -- opportunity to share code here?

assert(getCurryLevel() > 0 && "nothing to uncurry");

auto innerFunction = getResult()->castTo<AnyFunctionType>();
SmallVector<TupleTypeElt, 4> params{getInput()->getDesugaredType()};

if (auto tuple = dyn_cast<TupleType>(innerFunction->getInput().getPointer()))
params.append(tuple->getElements().begin(), tuple->getElements().end());
else
params.push_back(innerFunction->getInput()->getDesugaredType());

auto inputType = TupleType::get(params, getASTContext());
auto extInfo =
innerFunction->getExtInfo().withRepresentation(getRepresentation());

if (auto generic = getAs<GenericFunctionType>())
return GenericFunctionType::get(generic->getGenericSignature(), inputType,
innerFunction->getResult(), extInfo);

if (auto poly = getAs<PolymorphicFunctionType>())
return PolymorphicFunctionType::get(inputType, innerFunction->getResult(),
&poly->getGenericParams(), extInfo);

return FunctionType::get(inputType, innerFunction->getResult(), extInfo);
}

FunctionType *FunctionType::get(Type Input, Type Result,
const ExtInfo &Info) {
auto properties = getFunctionRecursiveProperties(Input, Result);
Expand Down
5 changes: 4 additions & 1 deletion lib/AST/ASTDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1897,7 +1897,10 @@ class PrintExpr : public ExprVisitor<PrintExpr> {
OS << ')';
}
void visitFunctionConversionExpr(FunctionConversionExpr *E) {
printCommon(E, "function_conversion_expr") << '\n';
printCommon(E, "function_conversion_expr");
if (E->isFlattening())
OS << " flattening";
OS << '\n';
printRec(E->getSubExpr());
OS << ')';
}
Expand Down
7 changes: 7 additions & 0 deletions lib/SILGen/SILGenApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,13 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
}

void visitFunctionConversionExpr(FunctionConversionExpr *e) {
// If this is a flattening function conversion, emit the expression
// directly.
if (e->isFlattening()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imagine you end up with both a flattening and a conversion here -- either you should handle both cases or make it a separate Expr

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand it, both cases are handled in buildThunkBody at the same time

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What you should do is have a FlatteningExpr that wraps a DeclRefExpr. It would be handled just like a DeclRefExpr, except the underlying decl would be emitted with a different uncurry level.

visitExpr(e);
return;
}

// FIXME: Check whether this function conversion requires us to build a
// thunk.
visit(e->getSubExpr());
Expand Down
3 changes: 2 additions & 1 deletion lib/SILGen/SILGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,8 @@ RValue RValueEmitter::visitFunctionConversionExpr(FunctionConversionExpr *e,
result = convertFunctionRepresentation(SGF, e, result, srcRepTy, srcTy);

if (srcTy != destTy)
result = SGF.emitTransformedValue(e, result, srcTy, destTy);
result =
SGF.emitTransformedValue(e, result, srcTy, destTy, e->isFlattening());

if (destTy != destRepTy)
result = convertFunctionRepresentation(SGF, e, result, destTy, destRepTy);
Expand Down
2 changes: 2 additions & 0 deletions lib/SILGen/SILGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1460,6 +1460,7 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
ManagedValue emitTransformedValue(SILLocation loc, ManagedValue input,
CanType inputType,
CanType outputType,
bool isFlattening = false,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have a way to emit a SILDeclRef with the right uncurry level -- I can dig into it if you want, let me know

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should probably just make a new sibling function to create a function-flattening thunk rather than complicate the core value-transformation API like this.

SGFContext ctx = SGFContext());

/// Most general form of the above.
Expand All @@ -1468,6 +1469,7 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
CanType inputSubstType,
AbstractionPattern outputOrigType,
CanType outputSubstType,
bool isFlattening = false,
SGFContext ctx = SGFContext());
RValue emitTransformedValue(SILLocation loc, RValue &&input,
AbstractionPattern inputOrigType,
Expand Down
101 changes: 75 additions & 26 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ namespace {
private:
SILGenFunction &SGF;
SILLocation Loc;
bool Flattening;

public:
Transform(SILGenFunction &SGF, SILLocation loc) : SGF(SGF), Loc(loc) {}
Transform(SILGenFunction &SGF, SILLocation loc, bool flattening = false)
: SGF(SGF), Loc(loc), Flattening(flattening) {}
virtual ~Transform() = default;

/// Transform an arbitrary value.
Expand Down Expand Up @@ -1252,7 +1254,7 @@ namespace {
return SGF.emitTransformedValue(Loc, input,
inputOrigType, inputSubstType,
outputOrigType, outputSubstType,
context);
/*isFlattening=*/false, context);
}

/// Force the given result into the given initialization.
Expand Down Expand Up @@ -2177,7 +2179,7 @@ void ResultPlanner::execute(ArrayRef<SILValue> innerDirectResults,
Gen.emitTransformedValue(Loc, innerResult,
op.InnerOrigType, op.InnerSubstType,
op.OuterOrigType, op.OuterSubstType,
outerResultCtxt);
/*isFlattening=*/false, outerResultCtxt);

// If the outer is indirect, force it into the context.
if (outerIsIndirect) {
Expand Down Expand Up @@ -2269,26 +2271,72 @@ void ResultPlanner::execute(ArrayRef<SILValue> innerDirectResults,
/// \param inputSubstType Formal AST type of function value being thunked
/// \param outputOrigType Abstraction pattern of the thunk
/// \param outputSubstType Formal AST type of the thunk
static void buildThunkBody(SILGenFunction &gen, SILLocation loc,
AbstractionPattern inputOrigType,
CanAnyFunctionType inputSubstType,
AbstractionPattern outputOrigType,
CanAnyFunctionType outputSubstType) {
static void buildThunkBody(
SILGenFunction &gen, SILLocation loc, bool isFlattening,
AbstractionPattern inputOrigType, CanAnyFunctionType inputSubstType,
AbstractionPattern outputOrigType, CanAnyFunctionType outputSubstType) {
PrettyStackTraceSILFunction stackTrace("emitting reabstraction thunk in",
&gen.F);
auto thunkType = gen.F.getLoweredFunctionType();

FullExpr scope(gen.Cleanups, CleanupLocation::get(loc));

SmallVector<ManagedValue, 8> params;
SmallVector<ManagedValue, 8> paramsBuffer;
// TODO: Could accept +0 arguments here when forwardFunctionArguments/
// emitApply can.
gen.collectThunkParams(loc, params, /*allowPlusZero*/ false);

ManagedValue fnValue = params.pop_back_val();
gen.collectThunkParams(loc, paramsBuffer, /*allowPlusZero*/ false);
ManagedValue fnValue = paramsBuffer.pop_back_val();
auto fnType = fnValue.getType().castTo<SILFunctionType>();
assert(!fnType->isPolymorphic());
auto argTypes = fnType->getParameters();

ArrayRef<ManagedValue> params(paramsBuffer);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, see the emitCurryThunk stuff - we can already do this

if (isFlattening) {
// Flatten an instance function type.
assert(
inputSubstType->getCurryLevel() - outputSubstType->getCurryLevel() == 1 &&
"Invalid (un)currying");

SmallVector<SILValue, 1> selfArg;
forwardFunctionArguments(gen, loc, fnType, params.front(), selfArg);
auto inner = gen.emitApplyWithRethrow(loc, fnValue.forward(gen),
/*substFnType*/ fnValue.getType(),
/*substitutions*/ {}, selfArg);

// For the next steps update the variables by dropping the already applied
// first parameter (`self`)
params = params.slice(1);
fnValue = ManagedValue::forUnmanaged(inner);
fnType = fnValue.getType().castTo<SILFunctionType>();
argTypes = fnType->getParameters();
inputOrigType = inputOrigType.getFunctionResultType();
inputSubstType = cast<AnyFunctionType>(inputSubstType.getResult());

Type newOrigInput, newSubstInput;
auto origFunction = cast<AnyFunctionType>(outputOrigType.getType());
if (auto origInput = dyn_cast<TupleType>(origFunction.getInput())) {
auto substInput = cast<TupleType>(outputSubstType.getInput());
assert(origInput->getNumElements() == substInput->getNumElements() &&
"invalid premise");
newOrigInput = TupleType::get(origInput->getElements().slice(1),
gen.getASTContext());
newSubstInput = TupleType::get(substInput->getElements().slice(1),
gen.getASTContext());
} else {
// In this case `self` was the only parameter, which leaves us with an
// application of `Void`
assert(!dyn_cast<TupleType>(outputSubstType.getInput()) &&
"invalid premise");
newOrigInput = newSubstInput = TupleType::getEmpty(gen.getASTContext());
}

outputOrigType = AbstractionPattern(CanFunctionType::get(
newOrigInput->getCanonicalType(), origFunction.getResult(),
origFunction->getExtInfo()));
outputSubstType = CanFunctionType::get(newSubstInput->getCanonicalType(),
outputSubstType.getResult(),
outputSubstType->getExtInfo());
}

// Translate the argument values. Function parameters are
// contravariant: we want to switch the direction of transformation
Expand Down Expand Up @@ -2419,14 +2467,12 @@ CanSILFunctionType SILGenFunction::buildThunkType(
}

/// Create a reabstraction thunk.
static ManagedValue createThunk(SILGenFunction &gen,
SILLocation loc,
ManagedValue fn,
AbstractionPattern inputOrigType,
CanAnyFunctionType inputSubstType,
AbstractionPattern outputOrigType,
CanAnyFunctionType outputSubstType,
const TypeLowering &expectedTL) {
static ManagedValue createThunk(
SILGenFunction &gen, SILLocation loc, ManagedValue fn,
AbstractionPattern inputOrigType, CanAnyFunctionType inputSubstType,
AbstractionPattern outputOrigType, CanAnyFunctionType outputSubstType,
const TypeLowering &expectedTL, bool isFlattening) {

auto expectedType = expectedTL.getLoweredType().castTo<SILFunctionType>();

// We can't do bridging here.
Expand All @@ -2452,7 +2498,7 @@ static ManagedValue createThunk(SILGenFunction &gen,
thunk->setContextGenericParams(gen.F.getContextGenericParams());
SILGenFunction thunkSGF(gen.SGM, *thunk);
auto loc = RegularLocation::getAutoGeneratedLocation();
buildThunkBody(thunkSGF, loc,
buildThunkBody(thunkSGF, loc, isFlattening,
inputOrigType, inputSubstType,
outputOrigType, outputSubstType);
}
Expand Down Expand Up @@ -2494,7 +2540,7 @@ ManagedValue Transform::transformFunction(ManagedValue fn,
return createThunk(SGF, Loc, fn,
inputOrigType, inputSubstType,
outputOrigType, outputSubstType,
expectedTL);
expectedTL, Flattening);
Copy link
Contributor

@rjmccall rjmccall Jul 29, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isFlattening? It looks like maybe you're passing 'true' here always on accident.

Oh, no, I see, you're passing the member variable. Nevermind.

}

// We do not, conversion is trivial.
Expand Down Expand Up @@ -2535,7 +2581,7 @@ SILGenFunction::emitOrigToSubstValue(SILLocation loc, ManagedValue v,
return emitTransformedValue(loc, v,
origType, substType,
AbstractionPattern(substType), substType,
ctxt);
/*isFlattening=*/false, ctxt);
}

/// Given a value with the abstraction patterns of the original formal
Expand All @@ -2561,7 +2607,7 @@ SILGenFunction::emitSubstToOrigValue(SILLocation loc, ManagedValue v,
return emitTransformedValue(loc, v,
AbstractionPattern(substType), substType,
origType, substType,
ctxt);
/*isFlattening=*/false, ctxt);
}

/// Given a value with the abstraction patterns of the substituted
Expand Down Expand Up @@ -2594,10 +2640,12 @@ ManagedValue
SILGenFunction::emitTransformedValue(SILLocation loc, ManagedValue v,
CanType inputType,
CanType outputType,
bool isFlattening,
SGFContext ctxt) {
return emitTransformedValue(loc, v,
AbstractionPattern(inputType), inputType,
AbstractionPattern(outputType), outputType);
AbstractionPattern(outputType), outputType,
isFlattening, ctxt);
}

ManagedValue
Expand All @@ -2606,8 +2654,9 @@ SILGenFunction::emitTransformedValue(SILLocation loc, ManagedValue v,
CanType inputSubstType,
AbstractionPattern outputOrigType,
CanType outputSubstType,
bool isFlattening,
SGFContext ctxt) {
return Transform(*this, loc).transform(v,
return Transform(*this, loc, isFlattening).transform(v,
inputOrigType,
inputSubstType,
outputOrigType,
Expand Down
14 changes: 12 additions & 2 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,16 @@ namespace {
Expr *result = new (context) DotSyntaxBaseIgnoredExpr(base, dotLoc,
ref);
closeExistential(result, /*force=*/openedExistential);
return result;

if (!isa<FuncDecl>(member))
return result;

auto newTy = result->getType()
->castTo<AnyFunctionType>()
->getUncurriedFunction();
auto conversion = new FunctionConversionExpr(result, newTy);
conversion->setFlattening();
return conversion;
} else {
assert((!baseIsInstance || member->isInstanceMember()) &&
"can't call a static method on an instance");
Expand Down Expand Up @@ -4637,6 +4646,8 @@ static bool isReferenceToMetatypeMember(Expr *expr) {
return dotIgnored->getLHS()->getType()->is<AnyMetatypeType>();
if (auto dotSyntax = dyn_cast<DotSyntaxCallExpr>(expr))
return dotSyntax->getBase()->getType()->is<AnyMetatypeType>();
if (auto conversion = dyn_cast<FunctionConversionExpr>(expr))
return isReferenceToMetatypeMember(conversion->getSubExpr());
return false;
}

Expand Down Expand Up @@ -6952,4 +6963,3 @@ Expr *Solution::convertOptionalToBool(Expr *expr,
isSomeExpr->setType(tc.lookupBoolType(cs.DC));
return isSomeExpr;
}

Loading