Skip to content

[Async CC] Support for protocol extension methods. #34200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 24 additions & 27 deletions lib/IRGen/GenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ AsyncContextLayout irgen::getAsyncContextLayout(
SILType ty =
IGF.IGM.silConv.getSILType(localContextParameter, substitutedType,
IGF.IGM.getMaximalTypeExpansionContext());
auto &ti = IGF.getTypeInfoForLowered(ty.getASTType());
auto argumentLoweringType =
getArgumentLoweringType(ty.getASTType(), localContextParameter,
/*isNoEscape*/ true);

auto &ti = IGF.getTypeInfoForLowered(argumentLoweringType);
valTypes.push_back(ty);
typeInfos.push_back(&ti);
localContextInfo = {ty, localContextParameter.getConvention()};
Expand All @@ -153,7 +157,7 @@ AsyncContextLayout irgen::getAsyncContextLayout(
}

// ArgTypes formalArguments...;
auto bindings = NecessaryBindings::forAsyncFunctionInvocations(
auto bindings = NecessaryBindings::forAsyncFunctionInvocation(
IGF.IGM, originalType, substitutionMap);
if (!bindings.empty()) {
auto bindingsSize = bindings.getBufferSize(IGF.IGM);
Expand Down Expand Up @@ -1970,40 +1974,34 @@ class AsyncCallEmission final : public CallEmission {
llArgs.add(selfValue);
}
auto layout = getAsyncContextLayout();
auto params = fnConv.getParameters();
for (auto index : indices(params)) {
Optional<ElementLayout> fieldLayout;
if (selfValue && index == params.size() - 1) {
fieldLayout = layout.getLocalContextLayout();
} else {
fieldLayout = layout.getArgumentLayout(index);
}
for (unsigned index = 0, count = layout.getArgumentCount(); index < count;
++index) {
auto fieldLayout = layout.getArgumentLayout(index);
Address fieldAddr =
fieldLayout->project(IGF, context, /*offsets*/ llvm::None);
auto &ti = cast<LoadableTypeInfo>(fieldLayout->getType());
fieldLayout.project(IGF, context, /*offsets*/ llvm::None);
auto &ti = cast<LoadableTypeInfo>(fieldLayout.getType());
ti.initialize(IGF, llArgs, fieldAddr, isOutlined);
}
unsigned index = 0;
for (auto indirectResult : fnConv.getIndirectSILResultTypes(
IGF.IGM.getMaximalTypeExpansionContext())) {
(void)indirectResult;
for (unsigned index = 0, count = layout.getIndirectReturnCount();
index < count; ++index) {
auto fieldLayout = layout.getIndirectReturnLayout(index);
Address fieldAddr =
fieldLayout.project(IGF, context, /*offsets*/ llvm::None);
cast<LoadableTypeInfo>(fieldLayout.getType())
.initialize(IGF, llArgs, fieldAddr, isOutlined);
++index;
}
if (layout.hasBindings()) {
auto bindingLayout = layout.getBindingsLayout();
auto bindingsAddr = bindingLayout.project(IGF, context, /*offsets*/ None);
layout.getBindings().save(IGF, bindingsAddr);
layout.getBindings().save(IGF, bindingsAddr, llArgs);
}
if (selfValue) {
auto fieldLayout = layout.getLocalContextLayout();
Address fieldAddr =
fieldLayout.project(IGF, context, /*offsets*/ llvm::None);
auto &ti = cast<LoadableTypeInfo>(fieldLayout.getType());
ti.initialize(IGF, llArgs, fieldAddr, isOutlined);
}
// At this point, llArgs contains the arguments that are being passed along
// via the async context. We can safely drop them on the floor.
(void)llArgs.claimAll();
// TODO: Validation: we should be able to check that the contents of llArgs
// matches what is expected by the layout.
}
void emitCallToUnmappedExplosion(llvm::CallInst *call, Explosion &out) override {
SILFunctionConventions fnConv(getCallee().getSubstFunctionType(),
Expand All @@ -2024,15 +2022,14 @@ class AsyncCallEmission final : public CallEmission {
Explosion nativeExplosion;
auto layout = getAsyncContextLayout();
auto dataAddr = layout.emitCastTo(IGF, context);
int index = layout.getFirstDirectReturnIndex();
for (auto result : fnConv.getDirectSILResults()) {
auto &fieldLayout = layout.getElement(index);
for (unsigned index = 0, count = layout.getDirectReturnCount();
index < count; ++index) {
auto fieldLayout = layout.getDirectReturnLayout(index);
Address fieldAddr =
fieldLayout.project(IGF, dataAddr, /*offsets*/ llvm::None);
auto &fieldTI = fieldLayout.getType();
cast<LoadableTypeInfo>(fieldTI).loadAsTake(IGF, fieldAddr,
nativeExplosion);
++index;
}

out = nativeSchema.mapFromNative(IGF.IGM, IGF, nativeExplosion, resultType);
Expand Down
4 changes: 4 additions & 0 deletions lib/IRGen/GenCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ namespace irgen {
}

unsigned getFirstDirectReturnIndex() { return getIndexAfterArguments(); }
unsigned getDirectReturnCount() { return directReturnInfos.size(); }
ElementLayout getDirectReturnLayout(unsigned index) {
return getElement(getFirstDirectReturnIndex() + index);
}

AsyncContextLayout(IRGenModule &IGM, LayoutStrategy strategy,
ArrayRef<SILType> fieldTypes,
Expand Down
57 changes: 39 additions & 18 deletions lib/IRGen/GenProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2728,24 +2728,45 @@ void NecessaryBindings::restore(IRGenFunction &IGF, Address buffer,
[&](CanType type) { return type;});
}

template <typename Transform>
static void save(const NecessaryBindings &bindings, IRGenFunction &IGF,
Address buffer, Transform transform) {
emitInitOfGenericRequirementsBuffer(
IGF, bindings.getRequirements().getArrayRef(), buffer,
[&](GenericRequirement requirement) -> llvm::Value * {
CanType type = requirement.TypeParameter;
if (auto protocol = requirement.Protocol) {
if (auto archetype = dyn_cast<ArchetypeType>(type)) {
auto wtable =
emitArchetypeWitnessTableRef(IGF, archetype, protocol);
return transform(requirement, wtable);
} else {
auto conformance = bindings.getConformance(requirement);
auto wtable = emitWitnessTableRef(IGF, type, conformance);
return transform(requirement, wtable);
}
} else {
auto metadata = IGF.emitTypeMetadataRef(type);
return transform(requirement, metadata);
}
});
};

void NecessaryBindings::save(IRGenFunction &IGF, Address buffer,
Explosion &source) const {
::save(*this, IGF, buffer,
[&](GenericRequirement requirement,
llvm::Value *expected) -> llvm::Value * {
auto *value = source.claimNext();
assert(value == expected);
return value;
});
}

void NecessaryBindings::save(IRGenFunction &IGF, Address buffer) const {
emitInitOfGenericRequirementsBuffer(IGF, Requirements.getArrayRef(), buffer,
[&](GenericRequirement requirement) -> llvm::Value* {
CanType type = requirement.TypeParameter;
if (auto protocol = requirement.Protocol) {
if (auto archetype = dyn_cast<ArchetypeType>(type)) {
auto wtable = emitArchetypeWitnessTableRef(IGF, archetype, protocol);
return wtable;
} else {
auto conformance = getConformance(requirement);
auto wtable = emitWitnessTableRef(IGF, type, conformance);
return wtable;
}
} else {
auto metadata = IGF.emitTypeMetadataRef(type);
return metadata;
}
});
::save(*this, IGF, buffer,
[](GenericRequirement requirement,
llvm::Value *value) -> llvm::Value * { return value; });
}

void NecessaryBindings::addTypeMetadata(CanType type) {
Expand Down Expand Up @@ -3021,7 +3042,7 @@ void EmitPolymorphicArguments::emit(SubstitutionMap subs,
}
}

NecessaryBindings NecessaryBindings::forAsyncFunctionInvocations(
NecessaryBindings NecessaryBindings::forAsyncFunctionInvocation(
IRGenModule &IGM, CanSILFunctionType origType, SubstitutionMap subs) {
return computeBindings(IGM, origType, subs,
false /*forPartialApplyForwarder*/);
Expand Down
8 changes: 6 additions & 2 deletions lib/IRGen/NecessaryBindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include "llvm/ADT/SetVector.h"
#include "swift/AST/Types.h"

#include "Explosion.h"

namespace swift {
class CanType;
enum class MetadataState : size_t;
Expand Down Expand Up @@ -55,8 +57,8 @@ class NecessaryBindings {
/// Collect the necessary bindings to invoke a function with the given
/// signature.
static NecessaryBindings
forAsyncFunctionInvocations(IRGenModule &IGM, CanSILFunctionType origType,
SubstitutionMap subs);
forAsyncFunctionInvocation(IRGenModule &IGM, CanSILFunctionType origType,
SubstitutionMap subs);
static NecessaryBindings forPartialApplyForwarder(IRGenModule &IGM,
CanSILFunctionType origType,
SubstitutionMap subs,
Expand Down Expand Up @@ -94,6 +96,8 @@ class NecessaryBindings {
/// Save the necessary bindings to the given buffer.
void save(IRGenFunction &IGF, Address buffer) const;

void save(IRGenFunction &IGF, Address buffer, Explosion &source) const;

/// Restore the necessary bindings from the given buffer.
void restore(IRGenFunction &IGF, Address buffer, MetadataState state) const;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// RUN: %empty-directory(%t)
// RUN: %target-build-swift-dylib(%t/%target-library-name(PrintShims)) %S/../../Inputs/print-shims.swift -module-name PrintShims -emit-module -emit-module-path %t/PrintShims.swiftmodule
// RUN: %target-codesign %t/%target-library-name(PrintShims)
// RUN: %target-build-swift -Xfrontend -enable-experimental-concurrency -parse-sil %s -emit-ir -I %t -L %t -lPrintShim | %FileCheck %s --check-prefix=CHECK-LL
// RUN: %target-build-swift -Xfrontend -enable-experimental-concurrency -parse-sil %s -module-name main -o %t/main -I %t -L %t -lPrintShims %target-rpath(%t)
// RUN: %target-codesign %t/main
// RUN: %target-run %t/main | %FileCheck %s

// REQUIRES: executable_test
// REQUIRES: swift_test_mode_optimize_none
// UNSUPPORTED: use_os_stdlib

import Builtin
import Swift
import PrintShims

sil public_external @printGeneric : $@convention(thin) <T> (@in_guaranteed T) -> ()
sil public_external @printInt64 : $@convention(thin) (Int64) -> ()

protocol P {
func printMe() -> Int64
}

extension P {
func callPrintMe() async -> Int64
}

struct I : P {
@_hasStorage let int: Int64 { get }
func printMe() -> Int64
init(int: Int64)
}

// CHECK-LL: define hidden swiftcc void @callPrintMe(%swift.context* {{%[0-9]*}}) {{#[0-9]*}} {
sil hidden @callPrintMe : $@async @convention(method) <Self where Self : P> (@in_guaranteed Self) -> Int64 {
bb0(%self : $*Self):
%P_printMe = witness_method $Self, #P.printMe : <Self where Self : P> (Self) -> () -> Int64 : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0) -> Int64
%result = apply %P_printMe<Self>(%self) : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0) -> Int64
return %result : $Int64
}

sil hidden @I_printMe : $@convention(method) (I) -> Int64 {
bb0(%self : $I):
%self_addr = alloc_stack $I
store %self to %self_addr : $*I
%printGeneric = function_ref @printGeneric : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0) -> ()
%printGeneric_result = apply %printGeneric<I>(%self_addr) : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0) -> ()
dealloc_stack %self_addr : $*I
%result = struct_extract %self : $I, #I.int
return %result : $Int64
}

sil private [transparent] [thunk] @I_P_printMe : $@convention(witness_method: P) (@in_guaranteed I) -> Int64 {
bb0(%self_addr : $*I):
%self = load %self_addr : $*I
%I_printMe = function_ref @I_printMe : $@convention(method) (I) -> Int64
%result = apply %I_printMe(%self) : $@convention(method) (I) -> Int64
return %result : $Int64
}

sil @main : $@convention(c) (Int32, UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>) -> Int32 {
bb0(%0 : $Int32, %1 : $UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>):
%i_type = metatype $@thin I.Type
%i_int_literal = integer_literal $Builtin.Int64, 99
%i_int = struct $Int64 (%i_int_literal : $Builtin.Int64)
%i = struct $I (%i_int : $Int64)
%i_addr = alloc_stack $I
store %i to %i_addr : $*I
%callPrintMe = function_ref @callPrintMe : $@async @convention(method) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0) -> Int64
%result = apply %callPrintMe<I>(%i_addr) : $@async @convention(method) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0) -> Int64 // CHECK: I(int: 99)
dealloc_stack %i_addr : $*I
%printInt64 = function_ref @printInt64 : $@convention(thin) (Int64) -> ()
%printInt64_result = apply %printInt64(%result) : $@convention(thin) (Int64) -> () // CHECK: 99

%out_literal = integer_literal $Builtin.Int32, 0
%out = struct $Int32 (%out_literal : $Builtin.Int32)
return %out : $Int32
}

sil_witness_table hidden I: P module main {
method #P.printMe: <Self where Self : P> (Self) -> () -> Int64 : @I_P_printMe
}