/* NSC -- new Scala compiler
 * Copyright 2005-2009 LAMP/EPFL
 * @author  Martin Odersky
 */
// $Id: RefChecks.scala 13735 2008-01-18 17:18:58Z odersky $

package scala.tools.nsc
package typechecker

import symtab.Flags._
import transform.{InfoTransform, TypingTransformers}
import scala.tools.nsc.util.{Position, NoPosition}
import scala.collection.mutable.ListBuffer

abstract class DeVirtualize extends InfoTransform with TypingTransformers {

  import global._
  import definitions._
  import typer.{typed, typedOperator, atOwner}

  /** the following two members override abstract members in Transform */
  val phaseName: String = "devirtualize"

  /** The phase might set the following new flags: */
  override def phaseNextFlags: Long = notDEFERRED | notOVERRIDE | notFINAL | lateABSTRACT

  def newTransformer(unit: CompilationUnit): DeVirtualizeTransformer =
    new DeVirtualizeTransformer(unit)
 
  /** The class does not change base-classes of existing classes */
  override def changesBaseClasses = false 

  def transformInfo(sym: Symbol, tp: Type): Type = 
    if (sym.isThisSym && sym.owner.isVirtualClass) {
      val clazz = sym.owner
      intersectionType(
        List(
          appliedType(abstractType(clazz).typeConstructor, clazz.typeParams map (_.tpe)),
          clazz.tpe))
    } else devirtualizeMap(tp)

  /* todo:
     handle constructor arguments
     check: overriding classes must have same type params
     virtual classes cannot have self types
   */

  /** Do the following transformations everywhere in a type:
   *
   * 1. Replace a virtual class
   *
   *  attrs mods class VC[Ts] <: Ps { decls }
   *  
   * by the following symbols
   *
   *  attrs mods1 type VC[Ts] <: dvm(Ps) with VC$trait[Ts]
   *  attrs mods2 trait VC$trait[Ts] extends AnyRef with ScalaObject { 
   *    this: VC[Ts] with VC$trait[Ts] => decls1
   *  }
   *  
   * The class symbol VC becomes the symbol of the workertrait.
   * 
   *  dvm    is the devirtalization mapping which converts refs to 
   *         virtual classes to refs to their abstract types (@see devirtualize)
   *  mods1  are the modifiers inherited to abstract types
   *  mods2  are the modifiers inherited to worker traits
   *  decls1 is decls but members that have an override modifier 
   *         lose it and any final modifier as well. 
   *
   * 2. For all virtual member classes VC which
   *    are not abstract and which are or inherit from a virtual class defined in current class
   *    add a factory (@see mkFactory)
   * 3. Convert TypeRef's to VC where VC is a virtual class to TypeRef's to AT, where AT
   *    is the abstract type corresponding to VC.
   *
   *  Note: If a class inherits vc's from two different paths, a vc in the 
   *  inheriting class has to be created beforehand. This is done in phase ??? (NOT YET DONE!)
   *
   *  Note: subclasses of virtual classes are treated as if they are virtual.
   *  isVirtualClass returns true for them also. 
   */
  object devirtualizeMap extends TypeMap {
    def apply(tp: Type): Type = mapOver(tp) match { 
      case tp1 @ ClassInfoType(parents, decls0, clazz) =>
        var decls = decls0
        def enter(sym: Symbol) = // at next phase because names of worker traits change
          atPhase(ownPhase.next) { decls.enter(sym) } 
        if (containsVirtuals(clazz)) {
          decls = newScope
          for (m <- decls0.toList) {
            if (m.isVirtualClass) {
              m.setFlag(notDEFERRED | notFINAL | lateABSTRACT)
              enter(mkAbstractType(m))
            }
            enter(m)
          }
          for (m <- classesInNeedOfFactories(clazz))
            enter(mkFactory(m, clazz))
        }
        if (clazz.isVirtualClass) {
          println("virtual class: "+clazz+clazz.locationString)
          transformOwnerInfo(clazz) 
          decls = newScope
          
          // add virtual fields for all primary constructor parameters
          for (row <- paramTypesAndIndices(clazz.primaryConstructor.tpe, 0))
            for ((pt, i) <- row) 
              enter(mkParamField(clazz, i, devirtualizeMap(pt)))

          // remove OVERRIDE from all workertrait members, except if they override a member in Object
          for (m <- decls0.toList) {
            if (!m.isConstructor) {
              if ((m hasFlag OVERRIDE) && m.overriddenSymbol(ObjectClass) == NoSymbol)
                m setFlag (notOVERRIDE | notFINAL)
              enter(m)
            }
          }
          if (clazz.thisSym == clazz) clazz.typeOfThis = clazz.thisType
            // ... to give a hook on which we can hang selftype transformers
          ClassInfoType(List(ObjectClass.tpe, ScalaObjectClass.tpe), decls, clazz)
        } else {
          ClassInfoType(parents map this, decls, clazz)
        }
      case tp1 @ TypeRef(pre, clazz, args) if clazz.isVirtualClass =>
        TypeRef(pre, abstractType(clazz), args)
      case tp1 =>
        tp1
    }
  }

  /** Transform owner of given clazz symbol */
  protected def transformOwnerInfo(clazz: Symbol) { atPhase(ownPhase.next) { clazz.owner.info } }

  /** Names of derived classes and factories */
  protected def concreteClassName(clazz: Symbol) = 
    atPhase(ownPhase) { newTypeName(clazz.name+"$fix") }
  protected def factoryName(clazz: Symbol) = 
    atPhase(ownPhase) { newTermName("new$"+clazz.name) }

  /** Does `clazz' contaion virtual classes? */
  protected def containsVirtuals(clazz: Symbol) = clazz.info.decls.toList exists (_.isVirtualClass)

  /** The inner classes that need factory methods in `clazz' 
   *  This is intended to catch situations like the following
   *
   *  abstract class C { 
   *    class V <: {...}
   *    class W extends V 
   *  }
   *  class D extends C {
   *    class V <: {...}
   *    // factories needed for V and W!
   *  }
   */
  protected def classesInNeedOfFactories(clazz: Symbol) = atPhase(ownPhase) {
    def isOverriddenVirtual(c: Symbol) = 
      c.isVirtualClass && clazz.info.decl(c.name).isVirtualClass
    val xs = clazz.info.members.toList filter (x => x.isVirtualClass && !x.hasFlag(ABSTRACT))
    for (m <- clazz.info.members.toList; 
         if (m.isVirtualClass && !(m hasFlag ABSTRACT) && 
             (m.info.baseClasses exists isOverriddenVirtual))) yield m
  }

  /** The abstract type corresponding to a virtual class. */  
  protected def abstractType(clazz: Symbol): Symbol =  atPhase(ownPhase.next) {
    val abstpe = clazz.owner.info.decl(atPhase(ownPhase) { clazz.name })
    assert(abstpe.isAbstractType)
    abstpe
  }

  /** The factory corresponding to a virtual class. */  
  protected def factory(clazz: Symbol, owner: Symbol) = atPhase(ownPhase.next) {
    val fsym = owner.info.member(factoryName(clazz))
    assert(fsym.isMethod, clazz)
    fsym
  }

  /** The name of the field representing a constructor parameter of a virtual class */
  protected def paramFieldName(clazz: Symbol, index: Int) = atPhase(ownPhase) {
    clazz.expandedName(newTermName("param$"+index))
  }

  /** The name of the field representing a constructor parameter of a virtual class */
  protected def fixParamName(index: Int) = newTermName("fix$"+index)

  /** The field representing a constructor parameter of a virtual class */
  protected def paramField(clazz: Symbol, index: Int) = atPhase(ownPhase.next) {
    clazz.info.decl(paramFieldName(clazz, index))
  }

  /** The flags that an abstract type can inherit from its virtual class */
  protected val absTypeFlagMask = AccessFlags | DEFERRED

  /** The flags that a factory method can inherit from its virtual class */
  protected val factoryFlagMask = AccessFlags

  /** Create a polytype with given type parameters and given type, or return just the type
   *  if type params is empty. */
  protected def mkPolyType(tparams: List[Symbol], tp: Type) = 
    if (tparams.isEmpty) tp else PolyType(tparams, tp)

  /** A lazy type to complete `sym', which is is generated for virtual class
   *  `clazz'. 
   *  The info of the symbol is computed by method `getInfo'.
   *  It is wrapped in copies of the type parameters of `clazz'.
   */
  abstract class PolyTypeCompleter(sym: Symbol, clazz: Symbol) extends LazyType {
    def getInfo: Type
    override val typeParams = cloneSymbols(clazz.typeParams, sym)
    override def complete(sym: Symbol) {
      sym.setInfo(
        mkPolyType(typeParams, getInfo.substSym(clazz.typeParams, typeParams)))
    }
  }

  protected def wasVirtualClass(sym: Symbol) = { 
    sym.isVirtualClass || {
      sym.info
      sym hasFlag notDEFERRED 
    }
  }

  protected def addOverriddenVirtuals(clazz: Symbol) = {
    (clazz.allOverriddenSymbols filter wasVirtualClass) ::: List(clazz)
  }

  protected def addOverriddenVirtuals(tpe: Type) = tpe match {
    case TypeRef(pre, sym, args) =>
      { for (vc <- sym.allOverriddenSymbols if wasVirtualClass(vc))
        yield typeRef(pre, vc, args) }.reverse ::: List(tpe)
  }

  protected def mkParamField(clazz: Symbol, index: Int, tpe: Type): Symbol = {
    val param = clazz.newMethod(clazz.pos, paramFieldName(clazz, index))
      .setFlag(PROTECTED | LOCAL | DEFERRED | EXPANDEDNAME | SYNTHETIC | STABLE)
    atPhase(ownPhase.next) {
      param.setInfo(PolyType(List(), tpe))
    }
    param
  }

  protected def mkAbstractType(clazz: Symbol): Symbol = {
    val cabstype = clazz.owner.newAbstractType(clazz.pos, clazz.name)
      .setFlag(clazz.flags & absTypeFlagMask | SYNTHETIC)
      .setAnnotations(clazz.annotations)
    atPhase(ownPhase.next) {
      cabstype setInfo new PolyTypeCompleter(cabstype, clazz) {
        def getInfo = {
          val parents1 = clazz.info.parents map {
            p => devirtualizeMap(p.substSym(clazz.typeParams, typeParams))
          }
          val parents2 = addOverriddenVirtuals(clazz) map {
            c => typeRef(clazz.owner.thisType, c, typeParams map (_.tpe))
          }
          mkTypeBounds(NothingClass.tpe, intersectionType(parents1 ::: parents2))
        }
      }
    }
  }

  protected def paramTypesAndIndices(tpe: Type, start: Int): List[List[(Type, Int)]] = tpe match {
    case PolyType(_, restpe) => paramTypesAndIndices(restpe, start)
    case MethodType(params, restpe) =>
      val end = start + params.length
      (tpe.paramTypes zip List.range(start, end)) :: paramTypesAndIndices(restpe, end)
    case _ =>
      List()
  }
    
  /* Add a factory symbol for a virtual class  
   *
   *  attrs mods class VC[Ts] <: Ps { decls }
   *  with base classes BC[Us]'s
   *
   * which corresponds to the following definition :
   *
   *  attrs mods3 def new$VC[Ts](): VC[Ts] = {
   *    class VC$fix extends v2w(BC's[Ts]) with VC$trait[Ts] { ... }
   *    new VC$fix
   *  }
   *
   * where
   *  
   *  mods3  are the modifiers inherited to factories
   *  v2w maps every virtual class to its workertrait and leaves other types alone.
   *
   *  @param  clazz    The virtual class for which factory is added
   *  @param  owner    The owner for which factory is added as a member
   *  @param  scope    The scope into which factory is entered
   */
  def mkFactory(clazz: Symbol, owner: Symbol): Symbol = {
    val pos = if (clazz.owner == owner) clazz.pos else owner.pos
    val factory = owner.newMethod(pos, factoryName(clazz))
      .setFlag(clazz.flags & factoryFlagMask | SYNTHETIC)
      .setAnnotations(clazz.annotations)
    factory setInfo new PolyTypeCompleter(factory, clazz) {
      private def copyType(tpe: Type): Type = tpe match {
        case MethodType(formals, restpe) => MethodType(formals, copyType(restpe))
        case PolyType(List(), restpe) => PolyType(List(), copyType(restpe))
        case PolyType(_, _) => throw new Error("bad case: "+tpe)
        case _ => owner.thisType.memberType(abstractType(clazz))
      }
      def getInfo = copyType(clazz.primaryConstructor.tpe)
    }
    factory
  }

  def removeDuplicates(ts: List[Type]): List[Type] = ts match {
    case List() => List()
    case t :: ts1 => t :: removeDuplicates(ts1 filter (_.typeSymbol != t.typeSymbol))
  }

  /** The concrete class symbol VC$fix in the factory symbol (@see mkFactory)
   *  @param clazz    the virtual class
   *  @param factory  the factory which returns an instance of this class
   */
  protected def mkConcreteClass(clazz: Symbol, factory: Symbol) = {
    val cclazz = factory.newClass(clazz.pos, concreteClassName(clazz))
      .setFlag(FINAL | SYNTHETIC)
      .setAnnotations(clazz.annotations)
      
    cclazz setInfo new LazyType {
      override def complete(sym: Symbol) {
        val parents1 = atPhase(ownPhase) {
          var superclazz = clazz
          do {
            superclazz = superclazz.info.parents.head.typeSymbol
          } while (wasVirtualClass(superclazz)) 
          val bcs = superclazz :: (clazz.info.baseClasses takeWhile (superclazz != )).reverse
          println("MKConcrete1 "+cclazz+factory.locationString+" "+bcs+" from "+clazz+clazz.locationString)
          println("MKConcrete2 "+cclazz+factory.locationString+" "+(bcs map factory.owner.thisType.memberType))
          bcs map factory.owner.thisType.memberType
        }
        atPhase(ownPhase.next) {
          val parents2 = 
            removeDuplicates(parents1.flatMap(addOverriddenVirtuals))
            .map(_.substSym(clazz.typeParams, factory.typeParams))
          sym setInfo ClassInfoType(parents2, newScope, cclazz)
        }
      }
    }

    cclazz
  }

  /** Perform the following tree transformations:
   *  
   *  1. Add trees for abstract types (@see devirtualize),
   *     worker traits (@see devirtualize)
   *     and factories (@see mkFactory)
   *  
   *  2. Replace a new VC().init(...) where VC is a virtual class with new$VC(...)
   *  
   *  3. Replace references to VC.this and VC.super where VC is a virtual class
   *     with VC$trait.this and VC$trait.super
   *
   *  4. Transform type references to virtual classes to type references of corresponding
   *     abstract types.
   */
  class DeVirtualizeTransformer(unit: CompilationUnit) extends TypingTransformer(unit) { 
    // all code is executed at phase ownPhase.next
    
    /** Add trees for abstract types, worker traits, and factories (@see mkFactory)
     *  to template body `stats'
     */  
    override def transformStats(stats: List[Tree], exprOwner: Symbol): List[Tree] = {
      val stats1 = stats flatMap transformStat
      val fclasses = atPhase(ownPhase) {
        if (currentOwner.isClass && containsVirtuals(currentOwner)) classesInNeedOfFactories(currentOwner)  
        else List()
      }
      val newDefs = fclasses map factoryDef
      if (newDefs.isEmpty) stats1 else stats1 ::: newDefs
    }

    def fixClassDef(clazz: Symbol, factory: Symbol): Tree = {
      val cclazz = mkConcreteClass(clazz, factory) 
      val overrideBridges = 
        for (m <- clazz.info.decls.toList if m hasFlag notOVERRIDE) 
        yield overrideBridge(m, cclazz)
      
      val vparamss: List[List[ValDef]] = atPhase(ownPhase) {
        paramTypesAndIndices(clazz.primaryConstructor.tpe, 0) map {
          _ map {
          case (pt, i) =>
            atPos(factory.pos) {
              ValDef(Modifiers(PARAMACCESSOR | PRIVATE | LOCAL), fixParamName(i), 
                     TypeTree(devirtualizeMap(pt)), EmptyTree)
            }
          }
        }
      }
      val pfields: List[DefDef] = atPhase(ownPhase) {
        paramTypesAndIndices(clazz.primaryConstructor.tpe, 0) flatMap {
          _ map {
          case (pt, i) =>
            val pfield = cclazz.newMethod(cclazz.pos, paramFieldName(clazz, i))
              .setFlag(PROTECTED | LOCAL | EXPANDEDNAME | SYNTHETIC | STABLE)
              .setInfo(PolyType(List(), pt))
            cclazz.info.decls enter pfield
            atPos(factory.pos) {
              DefDef(pfield, Ident(fixParamName(i)))
            }
          }
        }
      }
      atPos(clazz.pos) {
        ClassDef(cclazz, Modifiers(0), vparamss, List(List()), pfields ::: overrideBridges, clazz.pos.focus)
      }
    }
      

    /** The factory definition for virtual class `clazz' (@see mkFactory)
     *  For a virtual class  
     *
     *  attrs mods class VC[Ts] <: Ps { decls }
     *  with overridden classes _VC[Us]'s
     *
     * we need the following factory:
     *
     *  attrs mods3 def new$VC[Ts](): VC[Ts] = {
     *    class VC$fix extends _VC$trait's[Ts] with VC$trait[Ts] {
     *      override-bridges
     *    }
     *    new VC$fix.asInstanceOf[VC[Ts]]
     *  }
     *
     * where
     *  
     *  mods3  are the modifiers inherited to factories
     *  override-bridges are definitions that link every symbol in a worker trait
     *                   that was overriding something to the overridden symbol
     *                   //todo: not sure what happens with abstract override?
     */
    def factoryDef(clazz: Symbol): Tree = {
      val factorySym = factory(clazz, currentOwner)
      val cclazzDef = fixClassDef(clazz, factorySym)
      println("Concrete: "+cclazzDef)
      val abstpeSym = abstractType(clazz)
      localTyper.typed {
        atPos(factorySym.pos) {
          DefDef(factorySym,
            Block(
              List(cclazzDef), 
              TypeApply(
                Select(
                  gen.mkForwarder(
                    Select(New(TypeTree(cclazzDef.symbol.tpe)), nme.CONSTRUCTOR),
                    factorySym.paramss),
                  Any_asInstanceOf),
             List(
               TypeTree(
                 currentOwner.thisType.memberType(abstpeSym)
                 .substSym(abstpeSym.typeParams, factorySym.typeParams))))))
        }
      }
    }

    /** Create an override bridge for method `meth' in concrete class `cclazz'.
     *  An override bridge has the form
     *
     *   override f(xs1)...(xsN) = super.f(xs)...(xsN)
     */
    def overrideBridge(meth: Symbol, cclazz: Symbol) = atPos(meth.pos) {
      val bridge = meth.cloneSymbol(cclazz)
        .resetFlag(notOVERRIDE | notFINAL)
      cclazz.info.decls.enter(bridge)
      val superRef: Tree = Select(Super(cclazz, nme.EMPTY.toTypeName), meth)
      DefDef(bridge, gen.mkForwarder(superRef, bridge.paramss))
    }

    /** Replace definitions of virtual classes by definitions of corresponding
     *  abstract type and worker traits.
     *  Eliminate constructors of former virtual classes because these are now traits.
     */
    protected def transformStat(tree: Tree): List[Tree] = {
      val sym = tree.symbol
      tree match {
        case ClassDef(mods, name, tparams, templ) 
        if (wasVirtualClass(sym)) =>
          val clazz = sym
          val absTypeSym = abstractType(clazz)
          val abstypeDef = TypeDef(absTypeSym)
          List(localTyper.typed(abstypeDef), transform(tree))
        case DefDef(_, nme.CONSTRUCTOR, _, _, _, _) 
        if (wasVirtualClass(sym.owner)) =>
          if (atPhase(ownPhase)(sym != sym.owner.primaryConstructor))
            unit.error(tree.pos, "virtual classes cannot have auxiliary constructors")
          List()
        case _ =>
          List(transform(tree))
      }
    }

    override def transform(tree0: Tree): Tree = {
      val tree = super.transform(tree0)
      val sym = tree.symbol
      tree match {
        // Replace a new VC().init() where VC is a virtual class with new$VC
        case Apply(Select(New(tpt), name), args) if (sym.isConstructor && wasVirtualClass(sym.owner)) =>
          val clazz = sym.owner
          val fn = 
            Select(
              gen.mkAttributedQualifier(tpt.tpe.prefix),
              factory(clazz, clazz.owner).name)
          println("fac "+factory(clazz, clazz.owner).tpe)
          val targs = tpt.tpe.typeArgs
          atPos(tree.pos) {
            localTyper.typed {
              val res = 
                Apply(if (targs.isEmpty) fn else TypeApply(fn, targs map TypeTree), args)
                println("typing "+res+" from "+args)
                res
              }
            }
        
        case Template(parents, self, body) if (wasVirtualClass(sym.owner)) =>
          // add param field accessors
          val paramFieldAccessors = new ListBuffer[Tree]
          val paramFields = new ListBuffer[Tree]
          val presupers = new ListBuffer[Tree]
          val others = new ListBuffer[Tree]
          var paramFieldCount = 0
          for (stat <- body) {
            if (stat.symbol != null && (stat.symbol hasFlag PARAMACCESSOR))
              stat match {
                case pacc @ ValDef(mods, name, tpt, rhs) =>
                  pacc.symbol resetFlag PARAMACCESSOR setFlag PRESUPER
                  val pfield = paramField(sym.owner, paramFieldCount)
                  paramFieldCount += 1
                  pfield setPos pacc.pos
                  paramFields += localTyper.typed(DefDef(pfield, EmptyTree))
                  val pfieldRef = localTyper.typed { 
                    atPos(pacc.pos) { 
                      Select(This(sym.owner), pfield) 
                    } 
                  }
                  paramFieldAccessors += treeCopy.ValDef(pacc, mods, name, tpt, pfieldRef)
                case _ =>
                  stat.symbol resetFlag PARAMACCESSOR // ??? can we do this                  
                  others += stat
              }
            else 
              (if (stat.symbol != null && (stat.symbol hasFlag PRESUPER)) presupers else others) += stat
          }
          treeCopy.Template(tree, parents, self, 
                            paramFieldAccessors.toList ::: 
                            presupers.toList ::: 
                            paramFields.toList :::
                            others.toList)
        case _ =>
          tree setType atPhase(ownPhase)(devirtualizeMap(tree.tpe))
      }
    }
    override def transformUnit(unit: CompilationUnit) = atPhase(ownPhase.next) { 
      super.transformUnit(unit)
    }
  }
}

    

/*
    class A {
      trait C[X, Y] <: { 
        var x: X
        def f(y: Y): X = { println("A.T"); x } 
      }
      class D[X](xp: X) extends C[X, Int] { 
        var x: X = xp
        override def f(y: Int) = { println(y); super.f(y) } 
      }
    }
    class B extends A {
      override trait C[X, Y] <: { 
        override def f(y: Y): X = { println("B.T"); super.f(y) } 
        def g: X = x
      }
    }
    object Test extends B {
      val c = new D[String]("OK")
      println(c.g)
      println(c.f(42))
    }

maps to:

  class A {
    type C[X, Y] <: C$trait[X, Y]

    trait C$trait[X, Y] { this: C with C$trait => 
      var x: X
      def f(y: Y): X = { println("A.T"); x }
    } 

    class D[X](xp: X) extends C[X, Int] { 
      var x: X = xp
      override def f(y: Int) = { println(y); super.f(y) } 
    }

      protected[this] val x: Int; val y = x; def f(z:Int) = z + 1 }

    type D <: C with DT

    trait DT extends { self: D => def f(z:Int) = z + 2 }

    trait preDT extends { self: D => val z: Int; val x = f(z) }

    def newC(x: Int): C
    def newD(x: Int): D

    //type C = CT
    //type D = C with DT

    class CC(_x:Int) extends { val x = _x } with CT

    def newC[X, Y](x:Int): C = 
      new CC(x).asInstanceOf[C]

    class DC(_z:Int) extends { val z = _z } with preDT with CT with DT {
      override def f(z:Int) = super.f(z)
    }

    def newD(z:Int):D = new DC(z).asInstanceOf[D]
  }

  class B extends A {
    type C <: CT with CT2

    trait CT2 { self : C => def g = 2 }

    //type C = CT with CT2
    //type D = C with DT

    class CC2(_x:Int) extends { val x = _x } with CT with CT2

    def newC(x:Int): C = new CC2(x).asInstanceOf[C]

    class DC2(_z:Int) extends { val z = _z } with preDT with CT with CT2
      with DT { override def f(z:Int) = super.f(z) }

    def newD(z:Int): D = new DC2(z).asInstanceOf[D]
  }

*/