Skip to content

Commit 98f33f3

Browse files
committed
Split macro body checks from interpreter
1 parent 56acbc9 commit 98f33f3

File tree

2 files changed

+180
-163
lines changed

2 files changed

+180
-163
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 166 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,82 @@ object Splicer {
6464
*/
6565
def checkValidMacroBody(tree: Tree)(implicit ctx: Context): Unit = tree match {
6666
case Quoted(_) => // ok
67-
case _ => (new CheckValidMacroBody).apply(tree)
67+
case _ =>
68+
def checkValidStat(tree: Tree): Unit = tree match {
69+
case tree: ValDef if tree.symbol.is(Synthetic) =>
70+
// Check val from `foo(j = x, i = y)` which it is expanded to
71+
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
72+
checkIfValidArgument(tree.rhs)
73+
case _ =>
74+
ctx.error("Macro should not have statements", tree.sourcePos)
75+
}
76+
def checkIfValidArgument(tree: Tree): Unit = tree match {
77+
case Block(Nil, expr) => checkIfValidArgument(expr)
78+
case Typed(expr, _) => checkIfValidArgument(expr)
79+
80+
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
81+
// OK
82+
83+
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
84+
// OK
85+
86+
case Literal(Constant(value)) =>
87+
// OK
88+
89+
case _ if tree.symbol == defn.QuoteContext_macroContext =>
90+
// OK
91+
92+
case Call(fn, args)
93+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) ||
94+
fn.symbol.is(Module) || fn.symbol.isStatic ||
95+
(fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) =>
96+
args.foreach(_.foreach(checkIfValidArgument))
97+
98+
case NamedArg(_, arg) =>
99+
checkIfValidArgument(arg)
100+
101+
case SeqLiteral(elems, _) =>
102+
elems.foreach(checkIfValidArgument)
103+
104+
case tree: Ident if tree.symbol.is(Inline) || tree.symbol.is(Synthetic) =>
105+
// OK
106+
107+
case _ =>
108+
ctx.error(
109+
"""Malformed macro parameter
110+
|
111+
|Parameters may be:
112+
| * Quoted parameters or fields
113+
| * References to inline parameters
114+
| * Literal values of primitive types
115+
|""".stripMargin, tree.sourcePos)
116+
}
117+
def checkIfValidStaticCall(tree: Tree): Unit = tree match {
118+
case Block(stats, expr) =>
119+
stats.foreach(checkValidStat)
120+
checkIfValidStaticCall(expr)
121+
122+
case Typed(expr, _) => checkIfValidStaticCall(expr)
123+
case Call(fn, args)
124+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) ||
125+
fn.symbol.is(Module) || fn.symbol.isStatic ||
126+
(fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) =>
127+
args.flatten.foreach(checkIfValidArgument)
128+
case _ =>
129+
ctx.error(
130+
"""Malformed macro.
131+
|
132+
|Expected the splice ${...} to contain a single call to a static method.
133+
|""".stripMargin, tree.sourcePos)
134+
}
135+
136+
checkIfValidStaticCall(tree)
68137
}
69138

70139
/** Tree interpreter that evaluates the tree */
71-
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) extends AbstractInterpreter {
140+
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) {
72141

73-
def checking: Boolean = false
74-
75-
type Result = Object
142+
type Env = Map[Name, Object]
76143

77144
/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
78145
* Return Some of the result or None if some error happen during the interpretation.
@@ -93,22 +160,92 @@ object Splicer {
93160
}
94161
}
95162

96-
protected def interpretQuote(tree: Tree)(implicit env: Env): Object =
163+
def interpretTree(tree: Tree)(implicit env: Env): Object = tree match {
164+
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
165+
val quoted1 = quoted match {
166+
case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) =>
167+
// inline proxy for by-name parameter
168+
quoted.symbol.defTree.asInstanceOf[DefDef].rhs
169+
case Inlined(EmptyTree, _, quoted) => quoted
170+
case _ => quoted
171+
}
172+
interpretQuote(quoted1)
173+
174+
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
175+
interpretTypeQuote(quoted)
176+
177+
case Literal(Constant(value)) =>
178+
interpretLiteral(value)
179+
180+
case _ if tree.symbol == defn.QuoteContext_macroContext =>
181+
interpretQuoteContext()
182+
183+
// TODO disallow interpreted method calls as arguments
184+
case Call(fn, args) =>
185+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
186+
interpretNew(fn.symbol, args.flatten.map(interpretTree))
187+
} else if (fn.symbol.is(Module)) {
188+
interpretModuleAccess(fn.symbol)
189+
} else if (fn.symbol.isStatic) {
190+
val module = fn.symbol.owner
191+
interpretStaticMethodCall(module, fn.symbol, args.flatten.map(interpretTree))
192+
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
193+
val module = fn.qualifier.symbol.moduleClass
194+
interpretStaticMethodCall(module, fn.symbol, args.flatten.map(interpretTree))
195+
} else if (env.contains(fn.name)) {
196+
env(fn.name)
197+
} else if (tree.symbol.is(InlineProxy)) {
198+
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
199+
} else {
200+
unexpectedTree(tree)
201+
}
202+
203+
// Interpret `foo(j = x, i = y)` which it is expanded to
204+
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
205+
case Block(stats, expr) => interpretBlock(stats, expr)
206+
case NamedArg(_, arg) => interpretTree(arg)
207+
208+
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
209+
210+
case Typed(expr, _) =>
211+
interpretTree(expr)
212+
213+
case SeqLiteral(elems, _) =>
214+
interpretVarargs(elems.map(e => interpretTree(e)))
215+
216+
case _ =>
217+
unexpectedTree(tree)
218+
}
219+
220+
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
221+
var unexpected: Option[Object] = None
222+
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
223+
case stat: ValDef =>
224+
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
225+
case stat =>
226+
if (unexpected.isEmpty)
227+
unexpected = Some(unexpectedTree(stat))
228+
accEnv
229+
})
230+
unexpected.getOrElse(interpretTree(expr)(newEnv))
231+
}
232+
233+
private def interpretQuote(tree: Tree)(implicit env: Env): Object =
97234
new scala.internal.quoted.TastyTreeExpr(Inlined(EmptyTree, Nil, tree).withSpan(tree.span))
98235

99-
protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Object =
236+
private def interpretTypeQuote(tree: Tree)(implicit env: Env): Object =
100237
new scala.internal.quoted.TreeType(tree)
101238

102-
protected def interpretLiteral(value: Any)(implicit env: Env): Object =
239+
private def interpretLiteral(value: Any)(implicit env: Env): Object =
103240
value.asInstanceOf[Object]
104241

105-
protected def interpretVarargs(args: List[Object])(implicit env: Env): Object =
242+
private def interpretVarargs(args: List[Object])(implicit env: Env): Object =
106243
args.toSeq
107244

108-
protected def interpretQuoteContext()(implicit env: Env): Object =
245+
private def interpretQuoteContext()(implicit env: Env): Object =
109246
new scala.quoted.QuoteContext(ReflectionImpl(ctx, pos))
110247

111-
protected def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
248+
private def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
112249
val (inst, clazz) =
113250
if (moduleClass.name.startsWith(str.REPL_SESSION_LINE)) {
114251
(null, loadReplLineClass(moduleClass))
@@ -128,16 +265,16 @@ object Splicer {
128265
stopIfRuntimeException(method.invoke(inst, args: _*))
129266
}
130267

131-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
268+
private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
132269
loadModule(fn.moduleClass)
133270

134-
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Object = {
271+
private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
135272
val clazz = loadClass(fn.owner.fullName.toString)
136273
val constr = clazz.getConstructor(paramsSig(fn): _*)
137274
constr.newInstance(args: _*).asInstanceOf[Object]
138275
}
139276

140-
protected def unexpectedTree(tree: Tree)(implicit env: Env): Object =
277+
private def unexpectedTree(tree: Tree)(implicit env: Env): Object =
141278
throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.sourcePos)
142279

143280
private def loadModule(sym: Symbol): Object = {
@@ -265,146 +402,22 @@ object Splicer {
265402

266403
}
267404

268-
/** Tree interpreter that tests if tree can be interpreted */
269-
private class CheckValidMacroBody(implicit ctx: Context) extends AbstractInterpreter {
270-
def checking: Boolean = true
271-
272-
type Result = Unit
273-
274-
def apply(tree: Tree): Unit = interpretTree(tree)(Map.empty)
275-
276-
protected def interpretQuote(tree: tpd.Tree)(implicit env: Env): Unit = ()
277-
protected def interpretTypeQuote(tree: tpd.Tree)(implicit env: Env): Unit = ()
278-
protected def interpretLiteral(value: Any)(implicit env: Env): Unit = ()
279-
protected def interpretVarargs(args: List[Unit])(implicit env: Env): Unit = ()
280-
protected def interpretQuoteContext()(implicit env: Env): Unit = ()
281-
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity)
282-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Unit = ()
283-
protected def interpretNew(fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity)
284-
285-
def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Unit = {
286-
// Assuming that top-level splices can only be in inline methods
287-
// and splices are expanded at inline site, references to inline values
288-
// will be known literal constant trees.
289-
if (!tree.symbol.is(Inline))
290-
ctx.error(
291-
"""Malformed macro.
292-
|
293-
|Expected the splice ${...} to contain a single call to a static method.
294-
|
295-
|Where parameters may be:
296-
| * Quoted paramers or fields
297-
| * References to inline parameters
298-
| * Literal values of primitive types
299-
""".stripMargin, tree.sourcePos)
300-
}
301-
}
302-
303-
/** Abstract Tree interpreter that can interpret calls to static methods with quoted or inline arguments */
304-
private abstract class AbstractInterpreter(implicit ctx: Context) {
305-
306-
def checking: Boolean
307-
308-
type Env = Map[Name, Result]
309-
type Result
310-
311-
protected def interpretQuote(tree: Tree)(implicit env: Env): Result
312-
protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Result
313-
protected def interpretLiteral(value: Any)(implicit env: Env): Result
314-
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
315-
protected def interpretQuoteContext()(implicit env: Env): Result
316-
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Result])(implicit env: Env): Result
317-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result
318-
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
319-
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
320-
321-
protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match {
322-
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
323-
val quoted1 = quoted match {
324-
case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) =>
325-
// inline proxy for by-name parameter
326-
quoted.symbol.defTree.asInstanceOf[DefDef].rhs
327-
case Inlined(EmptyTree, _, quoted) => quoted
328-
case _ => quoted
329-
}
330-
interpretQuote(quoted1)
331-
332-
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
333-
interpretTypeQuote(quoted)
334-
335-
case Literal(Constant(value)) =>
336-
interpretLiteral(value)
337-
338-
case _ if tree.symbol == defn.QuoteContext_macroContext =>
339-
interpretQuoteContext()
340-
341-
case Call(fn, args) =>
342-
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
343-
interpretNew(fn.symbol, args.flatten.map(interpretTree))
344-
} else if (fn.symbol.is(Module)) {
345-
interpretModuleAccess(fn.symbol)
346-
} else if (fn.symbol.isStatic) {
347-
val module = fn.symbol.owner
348-
def interpretedArgs = args.flatten.map(interpretTree)
349-
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
350-
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
351-
val module = fn.qualifier.symbol.moduleClass
352-
def interpretedArgs = args.flatten.map(interpretTree)
353-
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
354-
} else if (env.contains(fn.name)) {
355-
env(fn.name)
356-
} else if (tree.symbol.is(InlineProxy)) {
357-
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
358-
} else {
359-
unexpectedTree(tree)
360-
}
361-
362-
// Interpret `foo(j = x, i = y)` which it is expanded to
363-
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
364-
case Block(stats, expr) => interpretBlock(stats, expr)
365-
case NamedArg(_, arg) => interpretTree(arg)
366-
367-
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
368-
369-
case Typed(expr, _) =>
370-
interpretTree(expr)
371-
372-
case SeqLiteral(elems, _) =>
373-
interpretVarargs(elems.map(e => interpretTree(e)))
374-
375-
case _ =>
376-
unexpectedTree(tree)
377-
}
378-
379-
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
380-
var unexpected: Option[Result] = None
381-
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
382-
case stat: ValDef if stat.symbol.is(Synthetic) || !checking =>
383-
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
384-
case stat =>
385-
if (unexpected.isEmpty)
386-
unexpected = Some(unexpectedTree(stat))
387-
accEnv
388-
})
389-
unexpected.getOrElse(interpretTree(expr)(newEnv))
390-
}
391-
392-
object Call {
393-
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] =
394-
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
395-
396-
private object Call0 {
397-
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = arg match {
398-
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
399-
Some((fn, args))
400-
case fn: RefTree => Some((fn, Nil))
401-
case Apply(f @ Call0(fn, args1), args2) =>
402-
if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1))
403-
else Some((fn, args2 :: args1))
404-
case TypeApply(Call0(fn, args), _) => Some((fn, args))
405-
case _ => None
406-
}
405+
object Call {
406+
def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] =
407+
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
408+
409+
private object Call0 {
410+
def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] = arg match {
411+
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
412+
Some((fn, args))
413+
case fn: RefTree => Some((fn, Nil))
414+
case Apply(f @ Call0(fn, args1), args2) =>
415+
if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1))
416+
else Some((fn, args2 :: args1))
417+
case TypeApply(Call0(fn, args), _) => Some((fn, args))
418+
case _ => None
407419
}
408420
}
409421
}
422+
410423
}

0 commit comments

Comments
 (0)