/* NSC -- new scala compiler
 * Copyright 2005-2009 LAMP/EPFL
 * @author  Martin Odersky
 */
// $Id: Variances.scala 18387 2009-07-24 15:28:37Z odersky $

package scala.tools.nsc
package typechecker

import symtab.Flags._

/** Variances form a lattice, 0 <= COVARIANT <= VARIANCES, 0 <= CONTRAVARIANT <= VARIANCES.
 *
 *  The `varianceIn*` methods compute how a type parameter `tparam` occurs in
 *  a type, a symbol's info, a list of arguments, or an annotation. The result
 *  is one of:
 *   - VARIANCES      `tparam` does not occur (top element; identity for `&`)
 *   - COVARIANT      `tparam` occurs only in covariant positions
 *   - CONTRAVARIANT  `tparam` occurs only in contravariant positions
 *   - 0              `tparam` occurs in an invariant position, or in both
 *                    covariant and contravariant positions (bottom element)
 *  Results for separate occurrences are combined with bitwise `&`, used as the
 *  lattice meet. NOTE(review): this relies on VARIANCES being
 *  COVARIANT | CONTRAVARIANT as bit flags in `symtab.Flags` -- confirm there.
 */
trait Variances {

  val global: Global
  import global._

  /** Swap COVARIANT and CONTRAVARIANT; any other value (e.g. 0 or VARIANCES)
   *  is returned unchanged. Applied to occurrences found in contravariant
   *  positions: lower bounds, method parameters, type-parameter bounds of a
   *  PolyType, and arguments for contravariant formal type parameters.
   */
  private def flip(v: Int): Int =
    if (v == COVARIANT) CONTRAVARIANT
    else if (v == CONTRAVARIANT) COVARIANT
    else v

  /** Map everything below VARIANCES to 0.
   *  Applied to invariant positions: any actual occurrence of `tparam` there
   *  makes it invariant, while "does not occur" (VARIANCES) is kept.
   */
  private def cut(v: Int): Int =
    if (v == VARIANCES) v else 0

  /** Compute variance of type parameter `tparam' in types of all symbols `syms'.
   *  An empty list yields VARIANCES.
   */
  def varianceInSyms(syms: List[Symbol])(tparam: Symbol): Int =
    (VARIANCES /: syms) ((v, sym) => v & varianceInSym(sym)(tparam))

  /** Compute variance of type parameter `tparam' in type of symbol `sym'.
   *  The right-hand side of a type alias is an invariant position, so the
   *  result for an alias is cut to either VARIANCES or 0.
   */
  def varianceInSym(sym: Symbol)(tparam: Symbol): Int =
    if (sym.isAliasType) cut(varianceInType(sym.info)(tparam))
    else varianceInType(sym.info)(tparam)

  /** Compute variance of type parameter `tparam' in all types `tps'.
   *  An empty list yields VARIANCES.
   */
  def varianceInTypes(tps: List[Type])(tparam: Symbol): Int =
    (VARIANCES /: tps) ((v, tp) => v & varianceInType(tp)(tparam))

  /** Variance of `tparam` in a single type argument `tp` that is passed for
   *  the formal type parameter `tparam1`: kept as-is for a covariant formal,
   *  flipped for a contravariant formal, and cut for an invariant one.
   */
  private def varianceInArg(tp: Type, tparam1: Symbol)(tparam: Symbol): Int = {
    val v = varianceInType(tp)(tparam)
    if (tparam1.isCovariant) v
    else if (tparam1.isContravariant) flip(v)
    else cut(v)
  }

  /** Compute variance of type parameter `tparam' in all type arguments
   *  <code>tps</code> which correspond to formal type parameters `tparams1'.
   *  Arguments and formals are paired positionally; if the two lists differ
   *  in length the surplus elements are ignored (`zip` truncates).
   *  An empty argument list yields VARIANCES.
   */
  def varianceInArgs(tps: List[Type], tparams1: List[Symbol])(tparam: Symbol): Int =
    (VARIANCES /: (tps zip tparams1)) {
      case (v, (tp, tparam1)) => v & varianceInArg(tp, tparam1)(tparam)
    }

  /** Compute variance of type parameter `tparam' in all type annotations `annots'.
   *  An empty list yields VARIANCES.
   */
  def varianceInAttribs(annots: List[AnnotationInfo])(tparam: Symbol): Int =
    (VARIANCES /: annots) ((v, annot) => v & varianceInAttrib(annot)(tparam))

  /** Compute variance of type parameter `tparam' in type annotation `annot'
   *  (i.e. in the annotation's type `annot.atp`).
   */
  def varianceInAttrib(annot: AnnotationInfo)(tparam: Symbol): Int =
    varianceInType(annot.atp)(tparam)

  /** Compute variance of type parameter <code>tparam</code> in type <code>tp</code>.
   *
   *  There is no default case: a `Type` not listed below raises a
   *  `scala.MatchError`.
   */
  def varianceInType(tp: Type)(tparam: Symbol): Int = tp match {
    case ErrorType | WildcardType | NoType | NoPrefix | ThisType(_) | ConstantType(_) =>
      // these types cannot mention `tparam`
      VARIANCES
    case SingleType(pre, sym) =>
      varianceInType(pre)(tparam)
    case TypeRef(pre, sym, args) =>
      // a direct reference to `tparam` is the base covariant occurrence;
      // otherwise each argument is weighted by the variance of `sym`'s
      // corresponding formal type parameter
      if (sym == tparam) COVARIANT
      else varianceInType(pre)(tparam) & varianceInArgs(args, sym.typeParams)(tparam)
    case TypeBounds(lo, hi) =>
      // lower bound is a contravariant position, upper bound a covariant one
      flip(varianceInType(lo)(tparam)) & varianceInType(hi)(tparam)
    case RefinedType(parents, defs) =>
      varianceInTypes(parents)(tparam) & varianceInSyms(defs.toList)(tparam)
    case MethodType(params, restpe) =>
      // parameter types are contravariant positions, the result type covariant
      flip(varianceInSyms(params)(tparam)) & varianceInType(restpe)(tparam)
    case PolyType(tparams, restpe) =>
      flip(varianceInSyms(tparams)(tparam)) & varianceInType(restpe)(tparam)
    case ExistentialType(tparams, restpe) =>
      // unlike PolyType, the quantified symbols' infos are not flipped
      varianceInSyms(tparams)(tparam) & varianceInType(restpe)(tparam)
    case AnnotatedType(annots, tp, _) =>
      varianceInAttribs(annots)(tparam) & varianceInType(tp)(tparam)
  }
}