/* NSC -- new Scala compiler
 * Copyright 2005-2009 LAMP/EPFL
 * @author  Martin Odersky
 */
// $Id: ConstantFolder.scala 18387 2009-07-24 15:28:37Z odersky $

package scala.tools.nsc
package typechecker

import java.lang.ArithmeticException

/** Evaluates operations on literal constants while type checking and records
 *  the outcome as a constant type on the tree.
 *
 *  Two entry points are offered: `apply(tree)` folds a unary or binary
 *  operation whose operands are literals, and `apply(tree, pt)` converts an
 *  already-constant tree to an expected type.
 *
 *  Internally, "nothing to fold" is signalled by a `null` constant; `fold`
 *  turns that into "return the tree unchanged".
 *
 *  @author Martin Odersky
 *  @version 1.0
 */
abstract class ConstantFolder {

  val global: Global
  import global._
  import definitions._

  /** If tree is a constant operation, replace with result.
   *
   *  Recognised shapes are `Literal.op(Literal)` (a single literal argument)
   *  and `Literal.op`. Any other tree is returned untouched.
   *
   *  @param tree the tree to inspect
   *  @return     the result of `fold` on `tree` and the folded constant, if any
   */
  def apply(tree: Tree): Tree = {
    // A local `def` keeps the computation lazy: it is only run inside
    // `fold`, under its ArithmeticException guard.
    def folded: Constant = tree match {
      case Apply(Select(Literal(lhs), op), List(Literal(rhs))) =>
        foldBinop(op, lhs, rhs)
      case Select(Literal(operand), op) =>
        foldUnop(op, operand)
      case _ =>
        null
    }
    fold(tree, folded)
  }

  /** If tree is a constant value that can be converted to type `pt', perform
   *  the conversion.
   *
   *  Only trees whose type is a `ConstantType` are considered; the constant
   *  it carries is passed through `convertTo pt`.
   *
   *  @param tree the tree whose type is inspected
   *  @param pt   the type the constant should be converted to
   *  @return     the result of `fold` on `tree` and the converted constant, if any
   */
  def apply(tree: Tree, pt: Type): Tree = {
    def converted: Constant = tree.tpe match {
      case ConstantType(value) => value convertTo pt
      case _                   => null
    }
    fold(tree, converted)
  }

  /** Evaluates `compX` and, when it yields a usable constant, sets the type
   *  of `tree` to the corresponding constant type.
   *
   *  The tree is returned as is when `compX` yields `null`, yields a constant
   *  tagged `UnitTag`, or throws an `ArithmeticException`.
   *
   *  @param tree  the tree being folded
   *  @param compX by-name computation of the folded constant (may be `null`)
   */
  private def fold(tree: Tree, compX: => Constant): Tree =
    try {
      val result = compX
      if ((result eq null) || result.tag == UnitTag) tree
      else tree setType mkConstantType(result)
    } catch {
      // the code will crash at runtime, but that is better than the
      // compiler itself crashing
      case _: ArithmeticException => tree
    }

  /** Folds a unary operator applied to a constant.
   *
   *  Supported: `!` on booleans; `~`, `+`, `-` on ints and longs;
   *  `+`, `-` on floats and doubles.
   *
   *  @param op the operator name
   *  @param x  the operand
   *  @return   the folded constant, or `null` if the combination is not handled
   */
  private def foldUnop(op: Name, x: Constant): Constant = x.tag match {
    case BooleanTag =>
      op match {
        case nme.UNARY_! => Constant(!x.booleanValue)
        case _           => null
      }
    case IntTag =>
      op match {
        case nme.UNARY_~ => Constant(~x.intValue)
        case nme.UNARY_+ => Constant(+x.intValue)
        case nme.UNARY_- => Constant(-x.intValue)
        case _           => null
      }
    case LongTag =>
      op match {
        case nme.UNARY_~ => Constant(~x.longValue)
        case nme.UNARY_+ => Constant(+x.longValue)
        case nme.UNARY_- => Constant(-x.longValue)
        case _           => null
      }
    case FloatTag =>
      op match {
        case nme.UNARY_+ => Constant(+x.floatValue)
        case nme.UNARY_- => Constant(-x.floatValue)
        case _           => null
      }
    case DoubleTag =>
      op match {
        case nme.UNARY_+ => Constant(+x.doubleValue)
        case nme.UNARY_- => Constant(-x.doubleValue)
        case _           => null
      }
    case _ =>
      null
  }

  /** Folds a binary operator applied to two constants.
   *
   *  The operation is carried out at the operands' common tag when they
   *  agree, or at the larger of the two tags when both are numeric. Any other
   *  pairing is not folded.
   *
   *  @param op the operator name
   *  @param x  the left operand
   *  @param y  the right operand
   *  @return   the folded constant, or `null` if the operation is not handled
   *            or raised an `ArithmeticException`
   */
  private def foldBinop(op: Name, x: Constant, y: Constant): Constant =
    try {
      val leftTag  = x.tag
      val rightTag = y.tag
      val optag =
        if (leftTag == rightTag) leftTag
        else if (!(isNumeric(leftTag) && isNumeric(rightTag))) NoTag
        else if (leftTag > rightTag) leftTag
        else rightTag

      optag match {
        case BooleanTag                            => foldBooleanOp(op, x, y)
        case ByteTag | ShortTag | CharTag | IntTag => foldIntOp(op, x, y)
        case LongTag                               => foldLongOp(op, x, y)
        case FloatTag                              => foldFloatOp(op, x, y)
        case DoubleTag                             => foldDoubleOp(op, x, y)
        case StringTag                             => foldStringOp(op, x, y)
        case _                                     => null
      }
    } catch {
      case _: ArithmeticException => null
    }

  /** Binary operators evaluated on the operands' boolean values. */
  private def foldBooleanOp(op: Name, x: Constant, y: Constant): Constant = {
    // `def`, not `val`: operands are converted only in the branch that runs.
    def lhs: Boolean = x.booleanValue
    def rhs: Boolean = y.booleanValue
    op match {
      case nme.ZOR | nme.OR   => Constant(lhs | rhs)
      case nme.XOR            => Constant(lhs ^ rhs)
      case nme.ZAND | nme.AND => Constant(lhs & rhs)
      case nme.EQ             => Constant(lhs == rhs)
      case nme.NE             => Constant(lhs != rhs)
      case _                  => null
    }
  }

  /** Binary operators evaluated on the operands' int values
   *  (used for the byte, short, char and int tags).
   */
  private def foldIntOp(op: Name, x: Constant, y: Constant): Constant = {
    def lhs: Int = x.intValue
    def rhs: Int = y.intValue
    op match {
      case nme.OR  => Constant(lhs | rhs)
      case nme.XOR => Constant(lhs ^ rhs)
      case nme.AND => Constant(lhs & rhs)
      case nme.LSL => Constant(lhs << rhs)
      case nme.LSR => Constant(lhs >>> rhs)
      case nme.ASR => Constant(lhs >> rhs)
      case nme.EQ  => Constant(lhs == rhs)
      case nme.NE  => Constant(lhs != rhs)
      case nme.LT  => Constant(lhs < rhs)
      case nme.GT  => Constant(lhs > rhs)
      case nme.LE  => Constant(lhs <= rhs)
      case nme.GE  => Constant(lhs >= rhs)
      case nme.ADD => Constant(lhs + rhs)
      case nme.SUB => Constant(lhs - rhs)
      case nme.MUL => Constant(lhs * rhs)
      case nme.DIV => Constant(lhs / rhs)
      case nme.MOD => Constant(lhs % rhs)
      case _       => null
    }
  }

  /** Binary operators evaluated on the operands' long values. */
  private def foldLongOp(op: Name, x: Constant, y: Constant): Constant = {
    def lhs: Long = x.longValue
    def rhs: Long = y.longValue
    op match {
      case nme.OR  => Constant(lhs | rhs)
      case nme.XOR => Constant(lhs ^ rhs)
      case nme.AND => Constant(lhs & rhs)
      case nme.LSL => Constant(lhs << rhs)
      case nme.LSR => Constant(lhs >>> rhs)
      case nme.ASR => Constant(lhs >> rhs)
      case nme.EQ  => Constant(lhs == rhs)
      case nme.NE  => Constant(lhs != rhs)
      case nme.LT  => Constant(lhs < rhs)
      case nme.GT  => Constant(lhs > rhs)
      case nme.LE  => Constant(lhs <= rhs)
      case nme.GE  => Constant(lhs >= rhs)
      case nme.ADD => Constant(lhs + rhs)
      case nme.SUB => Constant(lhs - rhs)
      case nme.MUL => Constant(lhs * rhs)
      case nme.DIV => Constant(lhs / rhs)
      case nme.MOD => Constant(lhs % rhs)
      case _       => null
    }
  }

  /** Binary operators evaluated on the operands' float values. */
  private def foldFloatOp(op: Name, x: Constant, y: Constant): Constant = {
    def lhs: Float = x.floatValue
    def rhs: Float = y.floatValue
    op match {
      case nme.EQ  => Constant(lhs == rhs)
      case nme.NE  => Constant(lhs != rhs)
      case nme.LT  => Constant(lhs < rhs)
      case nme.GT  => Constant(lhs > rhs)
      case nme.LE  => Constant(lhs <= rhs)
      case nme.GE  => Constant(lhs >= rhs)
      case nme.ADD => Constant(lhs + rhs)
      case nme.SUB => Constant(lhs - rhs)
      case nme.MUL => Constant(lhs * rhs)
      case nme.DIV => Constant(lhs / rhs)
      case nme.MOD => Constant(lhs % rhs)
      case _       => null
    }
  }

  /** Binary operators evaluated on the operands' double values. */
  private def foldDoubleOp(op: Name, x: Constant, y: Constant): Constant = {
    def lhs: Double = x.doubleValue
    def rhs: Double = y.doubleValue
    op match {
      case nme.EQ  => Constant(lhs == rhs)
      case nme.NE  => Constant(lhs != rhs)
      case nme.LT  => Constant(lhs < rhs)
      case nme.GT  => Constant(lhs > rhs)
      case nme.LE  => Constant(lhs <= rhs)
      case nme.GE  => Constant(lhs >= rhs)
      case nme.ADD => Constant(lhs + rhs)
      case nme.SUB => Constant(lhs - rhs)
      case nme.MUL => Constant(lhs * rhs)
      case nme.DIV => Constant(lhs / rhs)
      case nme.MOD => Constant(lhs % rhs)
      case _       => null
    }
  }

  /** Binary operators evaluated on the operands' string values
   *  (concatenation only).
   */
  private def foldStringOp(op: Name, x: Constant, y: Constant): Constant =
    op match {
      case nme.ADD => Constant(x.stringValue + y.stringValue)
      case _       => null
    }
}