Commit 7254b138 authored by ='s avatar =

Import of ConstantTransformerPhase for mixed rewriting

parent 092865e2
package daisy
package transform
import scala.collection.immutable.Seq
import lang.Trees._
import lang.Identifiers._
import lang.Types.RealType
import lang.Extractors._
import lang.TreeOps.replace
/**
Pulls out constants. e.g.
from
(((((((-3 * u) * u) * u) + ((3 * u) * u)) + (3 * u)) + 1) / 6.0)
to
val _const0: Double = -3.0
val _const1: Double = 3.0
val _const2: Double = 3.0
val _const3: Double = 1.0
val _const4: Double = 6.0
(((((((_const0 * u) * u) * u) + ((_const1 * u) * u)) + (_const2 * u)) + _const3) / _const4)
Prerequisites:
None
*/
object ConstantTransformerPhase extends DaisyPhase {
override val name = "Constant transformer"
override val description = "Pulls out constants"
override val definedOptions: Set[CmdLineOptionDef[Any]] = Set()
implicit val debugSection = DebugSectionTransform
var reporter: Reporter = null
override def run(ctx: Context, prg: Program): (Context, Program) = {
reporter = ctx.reporter
reporter.info(s"\nStarting $name phase")
val timer = ctx.timers.constTrans.start
// need to replace function bodies, so create a copy of the whole program
val newDefs = prg.defs.map(fnc =>
if (!fnc.body.isEmpty) {
val newBody = pullOutConstants(fnc.body.get)
fnc.copy(body = Some(newBody))
} else {
fnc
})
reporter.debug("Original program:")
reporter.debug(lang.PrettyPrinter.withIDs(prg))
reporter.debug("\nNew program:")
reporter.debug(lang.PrettyPrinter.withIDs(Program(prg.id, newDefs)))
timer.stop
reporter.info(s"Finished $name phase")
// return modified program
(ctx, Program(prg.id, newDefs))
}
def pullOutConstants(expr: Expr): Expr = {
// find all constants
var constants = Seq[(Identifier, RealLiteral)]()
var counter = 0
def mapConstants(e: Expr): Expr = (e: @unchecked) match {
case v: Variable => v
case x @ RealLiteral(r) =>
val fresh = FreshIdentifier("_const" + counter, RealType)
counter = counter + 1
constants = constants :+ (fresh, x)
Variable(fresh)
case UMinus(t) => UMinus(mapConstants(t))
case Sqrt(t) => Sqrt(mapConstants(t))
case Plus(l, r) => Plus(mapConstants(l), mapConstants(r))
case Minus(l, r) => Minus(mapConstants(l), mapConstants(r))
case Times(l, r) => Times(mapConstants(l), mapConstants(r))
case Division(l, r) => Division(mapConstants(l), mapConstants(r))
case Let(id, r @ RealLiteral(_), b) => Let(id, r, mapConstants(b))
case Let(id, v, b) => Let(id, mapConstants(v), mapConstants(b))
}
val expr2 = mapConstants(expr)
// generate vals for all constants
def makeIntoLets(consts: Seq[(Identifier, RealLiteral)]): Expr = {
if (consts.length > 0) {
val (fresh, const) = consts.head
Let(fresh, const, makeIntoLets(consts.tail))
} else {
expr2
}
}
makeIntoLets(constants)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment