/* NSC -- new scala compiler
 * Copyright 2005-2009 LAMP/EPFL
 * @author Iulian Dragos
 */
// $Id: TailCalls.scala 18579 2009-08-26 15:17:51Z extempore $

package scala.tools.nsc
package transform

import scala.tools.nsc.symtab.Flags

/** Perform tail recursive call elimination.
 *
 *  @author Iulian Dragos
 *  @version 1.0
 */
abstract class TailCalls extends Transform
                         /* with JavaLogging() */ {
  // inherits abstract value `global' and class `Phase' from Transform

  import global._                  // the global environment
  import definitions._             // standard classes and methods
  import typer.{typed, atOwner}    // methods to type trees

  val phaseName: String = "tailcalls"

  /** Create the transformer that rewrites tail calls in `unit`. */
  def newTransformer(unit: CompilationUnit): Transformer =
    new TailCallElimination(unit)

  /** Create a new phase which applies transformer */
  override def newPhase(prev: scala.tools.nsc.Phase): StdPhase = new Phase(prev)

  /** The phase defined by this transform.
   *
   *  The transformation is skipped entirely when the `debuginfo` setting
   *  has the value "notailcalls".
   */
  class Phase(prev: scala.tools.nsc.Phase) extends StdPhase(prev) {
    def apply(unit: global.CompilationUnit) {
      if (settings.debuginfo.value != "notailcalls") {
        newTransformer(unit).transformUnit(unit)
      }
    }
  }

  /** The @tailrec annotation indicates TCO is mandatory */
  private def tailrecRequired(defdef: DefDef) = defdef.symbol hasAnnotation TailrecClass

  /**
   * A Tail Call Transformer
   *
   * @author     Erik Stenman, Iulian Dragos
   * @version    1.1
   *
   * What it does:
   * <p>
   *   Finds method calls in tail-position and replaces them with jumps.
   *   A call is in a tail-position if it is the last instruction to be
   *   executed in the body of a method.  This is done by recursing over
   *   the trees that may contain calls in tail-position (trees that can't
   *   contain such calls are not transformed). However, they are not that
   *   many.
   * </p>
   * <p>
   *   Self-recursive calls in tail-position are replaced by jumps to a
   *   label at the beginning of the method. As the JVM provides no way to
   *   jump from a method to another one, non-recursive calls in
   *   tail-position are not optimized.
   * </p>
   * <p>
   *   A method call is self-recursive if it calls the current method and
   *   the method is final (otherwise, it could
   *   be a call to an overridden method in a subclass). Furthermore, if
   *   the method has type parameters, the call must contain these
   *   parameters as type arguments. Recursive calls on a different instance
   *   are optimized. Since 'this' is not a local variable, a dummy local val
   *   is added and used as a label parameter. The backend knows to load
   *   the corresponding argument in the 'this' (local at index 0). This dummy local
   *   is never used and should be cleaned up by dead code elimination (when enabled).
   * </p>
   * <p>
   *   This phase has been moved before pattern matching to catch more
   *   of the common cases of tail recursive functions. This means that
   *   more cases should be taken into account (like nested function, and
   *   pattern cases).
   * </p>
   * <p>
   *   If a method contains self-recursive calls, a label is added at
   *   the beginning of its body and the calls are replaced by jumps to
   *   that label.
   * </p>
   * <p>
   *   Assumes: <code>Uncurry</code> has been run already, and no multiple
   *            parameter lists exist.
   * </p>
   */
  class TailCallElimination(unit: CompilationUnit) extends Transformer {

    /** Mutable state describing the method currently being transformed. */
    class Context {
      /** The current method */
      var currentMethod: Symbol = NoSymbol

      /** The current tail-call label */
      var label: Symbol = NoSymbol

      /** The expected type arguments of self-recursive calls */
      var tparams: List[Symbol] = Nil

      /** Tells whether we are in a (possible) tail position */
      var tailPos = false

      /** Is the label accessed? */
      var accessed = false

      /** Copy constructor: the new context starts with all fields of `that`. */
      def this(that: Context) = {
        this()
        this.currentMethod = that.currentMethod
        this.label         = that.label
        this.tparams       = that.tparams
        this.tailPos       = that.tailPos
        this.accessed      = that.accessed
      }

      /** Create a new method symbol for the current method and store it in
       *  the label field. Also resets `accessed`, since the fresh label has
       *  not been jumped to yet.
       */
      def makeLabel(): Unit = {
        label = currentMethod.newLabel(currentMethod.pos, "_" + currentMethod.name)
        accessed = false
      }

      override def toString(): String = (
        "" + currentMethod.name + " tparams: " + tparams + " tailPos: " + tailPos +
        " accessed: " + accessed + "\nLabel: " + label + "\nLabel type: " + label.info
      )
    }

    /** A copy of `that`. */
    private def mkContext(that: Context) = new Context(that)

    /** A copy of `that` whose tail-position flag is set to `tp`. */
    private def mkContext(that: Context, tp: Boolean): Context = {
      val t = mkContext(that)
      t.tailPos = tp
      t
    }

    /** The context the transformer is currently working in. */
    private var ctx: Context = new Context()

    /** Rewrite this tree to contain no tail recursive calls.
     *
     *  Runs `transform(tree)` with `nctx` installed as the current context
     *  and reinstalls the previous context afterwards.
     */
    def transform(tree: Tree, nctx: Context): Tree = {
      val oldCtx = ctx
      ctx = nctx
      val t = transform(tree)
      this.ctx = oldCtx
      t
    }

    /** Transform `tree` under the current context `ctx`.
     *
     *  Sub-trees that remain in tail position (the result expression of a
     *  block, both branches of an `if`, case bodies, ...) are transformed
     *  under the same context; sub-trees that cannot be in tail position are
     *  transformed under a copy of the context whose `tailPos` is false.
     */
    override def transform(tree: Tree): Tree = {
      tree match {

        case dd @ DefDef(mods, name, tparams, vparams, tpt, rhs) =>
          log("Entering DefDef: " + name)
          var isTransformed = false
          val newCtx = mkContext(ctx)
          newCtx.currentMethod = tree.symbol
          newCtx.makeLabel()
          // The label takes the receiver as an extra leading parameter,
          // followed by the parameters of the method itself.
          val currentClassParam = tree.symbol.newSyntheticValueParam(currentClass.tpe)
          newCtx.label.setInfo(MethodType(currentClassParam :: tree.symbol.tpe.params, tree.symbol.tpe.finalResultType))
          newCtx.tailPos = true

          val isEligible = newCtx.currentMethod.isFinal || (newCtx.currentMethod.enclClass hasFlag Flags.MODULE)
          if (isEligible) {
            newCtx.tparams = Nil
            log(" Considering " + name + " for tailcalls")
            tree.symbol.tpe match {
              case PolyType(tpes, restpe) =>
                newCtx.tparams = tparams map (_.symbol)
                newCtx.label.setInfo(
                  newCtx.label.tpe.substSym(tpes, tparams map (_.symbol)))
              case _ =>
            }
          }

          val t1 = treeCopy.DefDef(tree, mods, name, tparams, vparams, tpt,
            transform(rhs, newCtx) match {
              case newRHS if isEligible && newCtx.accessed =>
                // At least one recursive call was turned into a jump: wrap the
                // body in the label definition the jumps refer to.
                log("Rewrote def " + newCtx.currentMethod)
                isTransformed = true
                val newThis = newCtx.currentMethod
                  .newValue(tree.pos, nme.THIS)
                  .setInfo(currentClass.tpe)
                  .setFlag(Flags.SYNTHETIC)

                typed(atPos(tree.pos)(Block(
                  List(ValDef(newThis, This(currentClass))),
                  LabelDef(newCtx.label, newThis :: (vparams.flatten map (_.symbol)), newRHS)
                )))
              case rhs => rhs
            }
          )

          if (!isTransformed && tailrecRequired(dd))
            unit.error(dd.pos, "could not optimize @tailrec annotated method")

          log("Leaving DefDef: " + name)
          t1

        case EmptyTree => tree

        case PackageDef(_, _) =>
          super.transform(tree)

        case ClassDef(_, name, _, _) =>
          log("Entering class " + name)
          val res = super.transform(tree)
          log("Leaving class " + name)
          res

        case ValDef(mods, name, tpt, rhs) => super.transform(tree)
        case LabelDef(name, params, rhs) => super.transform(tree)

        case Template(parents, self, body) =>
          super.transform(tree)

        case Block(stats, expr) =>
          // only the result expression of a block can be in tail position
          treeCopy.Block(tree,
            transformTrees(stats, mkContext(ctx, false)),
            transform(expr))

        case CaseDef(pat, guard, body) =>
          treeCopy.CaseDef(tree, pat, guard, transform(body))

        case Sequence(_) | Alternative(_) |
             Star(_)     | Bind(_, _) =>
          throw new RuntimeException("We should've never gotten inside a pattern")

        case Function(vparams, body) =>
          tree
          //throw new RuntimeException("Anonymous function should not exist at this point. at: " + unit.position(tree.pos));

        case Assign(lhs, rhs) =>
          // The store still has to happen after the right-hand side has been
          // evaluated, so a call on the right-hand side is never a tail call.
          val nonTail = mkContext(ctx, false)
          treeCopy.Assign(tree, transform(lhs, nonTail), transform(rhs, nonTail))

        case If(cond, thenp, elsep) =>
          treeCopy.If(tree, cond, transform(thenp), transform(elsep))

        case Match(selector, cases) => //super.transform(tree);
          treeCopy.Match(tree,
            transform(selector, mkContext(ctx, false)),
            transformTrees(cases).asInstanceOf[List[CaseDef]])

        case Return(expr) => super.transform(tree)

        case Try(block, catches, finalizer) =>
          // no calls inside a try are in tail position, but keep recursing for nested functions
          treeCopy.Try(tree,
            transform(block, mkContext(ctx, false)),
            transformTrees(catches, mkContext(ctx, false)).asInstanceOf[List[CaseDef]],
            transform(finalizer, mkContext(ctx, false)))

        case Throw(expr) =>
          // The exception still has to be thrown after its operand has been
          // evaluated, so a call in the operand is never a tail call.
          treeCopy.Throw(tree, transform(expr, mkContext(ctx, false)))

        case New(tpt) => super.transform(tree)
        case Typed(expr, tpt) => super.transform(tree)

        case Apply(tapply @ TypeApply(fun, targs), vargs) =>
          lazy val defaultTree = treeCopy.Apply(tree, tapply, transformTrees(vargs, mkContext(ctx, false)))
          if ( ctx.currentMethod.isFinal &&
               ctx.tailPos &&
               isSameTypes(ctx.tparams, targs map (_.tpe.typeSymbol)) &&
               isRecursiveCall(fun)) {
            fun match {
              case Select(receiver, _) =>
                val recTpe  = receiver.tpe.widen
                val enclTpe = ctx.currentMethod.enclClass.typeOfThis
                // make sure the type of 'this' doesn't change through this polymorphic recursive call
                if (!forMSIL &&
                    (receiver.tpe.typeParams.isEmpty || recTpe == enclTpe))
                  rewriteTailCall(fun, receiver :: transformTrees(vargs, mkContext(ctx, false)))
                else
                  defaultTree
              case _ =>
                rewriteTailCall(fun, This(currentClass) :: transformTrees(vargs, mkContext(ctx, false)))
            }
          } else
            defaultTree

        case TypeApply(fun, args) =>
          super.transform(tree)

        case Apply(fun, args) if (fun.symbol == definitions.Boolean_or ||
                                  fun.symbol == definitions.Boolean_and) =>
          // the argument of `||` / `&&` is transformed under the current
          // context, i.e. it keeps the tail position of the whole expression
          treeCopy.Apply(tree, fun, transformTrees(args))

        case Apply(fun, args) =>
          lazy val defaultTree = treeCopy.Apply(tree, fun, transformTrees(args, mkContext(ctx, false)))
          if (ctx.currentMethod.isFinal &&
              ctx.tailPos &&
              isRecursiveCall(fun)) {
            fun match {
              case Select(receiver, _) =>
                if (!forMSIL)
                  rewriteTailCall(fun, receiver :: transformTrees(args, mkContext(ctx, false)))
                else
                  defaultTree
              case _ =>
                rewriteTailCall(fun, This(currentClass) :: transformTrees(args, mkContext(ctx, false)))
            }
          } else
            defaultTree

        case Super(qual, mix) =>
          tree
        case This(qual) =>
          tree
        case Select(qualifier, selector) =>
          tree
        case Ident(name) =>
          tree
        case Literal(value) =>
          tree
        case TypeTree() =>
          tree

        case _ =>
          tree
      }
    }

    /** Transform every tree of `trees` under the context `nctx`. */
    def transformTrees(trees: List[Tree], nctx: Context): List[Tree] =
      trees map ((tree) => transform(tree, nctx))

    /** Replace a recursive call by an application of the current label.
     *
     *  Marks the label of the current context as accessed.
     *
     *  @param fun  the function part of the call being replaced; only its
     *              position is used
     *  @param args the arguments passed to the label: the receiver followed
     *              by the (already transformed) call arguments
     *  @return     the typed label application
     */
    private def rewriteTailCall(fun: Tree, args: List[Tree]): Tree = {
      log("Rewriting tail recursive method call at: " + (fun.pos))
      ctx.accessed = true
      //println("fun: " + fun + " args: " + args)
      val t = atPos(fun.pos)(Apply(Ident(ctx.label), args))
      // println("TAIL: "+t)
      typed(t)
    }

    /** Compare two lists of symbols pairwise with `==`, using `List.forall2`. */
    private def isSameTypes(ts1: List[Symbol], ts2: List[Symbol]): Boolean = {
      def isSameType(t1: Symbol, t2: Symbol) = {
        t1 == t2
      }
      List.forall2(ts1, ts2)(isSameType)
    }

    /** Returns <code>true</code> if the fun tree refers to the same method as
     *  the one saved in <code>ctx</code>.
     *
     *  @param fun the expression that is applied
     *  @return    <code>true</code> if the tree symbol refers to the innermost
     *             enclosing method
     */
    private def isRecursiveCall(fun: Tree): Boolean =
      (fun.symbol eq ctx.currentMethod)
  }
}