@@ -20,6 +20,10 @@ object ContextFunctionResults:
20
20
*/
21
21
def annotateContextResults (mdef : DefDef )(using Context ): Unit =
22
22
def contextResultCount (rhs : Tree , tp : Type ): Int = tp match
23
+ case defn.DependentFunctionRefinementOf (_, mt) if mt.isContextualMethod =>
24
+ rhs match
25
+ case closureDef(meth) => 1 + contextResultCount(meth.rhs, mt.resType)
26
+ case _ => 0
23
27
case defn.ContextFunctionOf (_, resTpe) =>
24
28
rhs match
25
29
case closureDef(meth) => 1 + contextResultCount(meth.rhs, resTpe)
@@ -58,6 +62,8 @@ object ContextFunctionResults:
58
62
*/
59
63
def contextResultsAreErased (sym : Symbol )(using Context ): Boolean =
60
64
def allErased (tp : Type ): Boolean = tp.dealias match
65
+ case ft @ defn.DependentFunctionRefinementOf (_, mt) if mt.isContextualMethod =>
66
+ ! defn.erasedFunctionParams(ft).contains(false ) && allErased(mt.resType)
61
67
case ft @ defn.ContextFunctionOf (_, resTpe) =>
62
68
! defn.erasedFunctionParams(ft).contains(false ) && allErased(resTpe)
63
69
case _ => true
@@ -73,6 +79,8 @@ object ContextFunctionResults:
73
79
integrateContextResults(rt, crCount)
74
80
case tp : MethodOrPoly =>
75
81
tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount))
82
+ case defn.DependentFunctionRefinementOf (base, mt) if mt.isContextualMethod =>
83
+ integrateContextResults(base, crCount)
76
84
case defn.ContextFunctionOf (argTypes, resType) =>
77
85
MethodType (argTypes, integrateContextResults(resType, crCount - 1 ))
78
86
@@ -120,6 +128,8 @@ object ContextFunctionResults:
120
128
case Select (qual, name) =>
121
129
if name == nme.apply then
122
130
qual.tpe match
131
+ case defn.DependentFunctionRefinementOf (_, mt) if mt.isContextualMethod =>
132
+ integrateSelect(qual, n + 1 )
123
133
case defn.ContextFunctionOf (_, _) =>
124
134
integrateSelect(qual, n + 1 )
125
135
case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs
0 commit comments