Commit 2c8545d0 authored by Eva Darulova's avatar Eva Darulova
Browse files

Merge branch 'fp-merge' into 'master'

Fp merge



See merge request !40
parents 1045a1af d7785468
......@@ -6,3 +6,4 @@ rawdata/*
.ensime*
/daisy
last.log
output/*.scala
......@@ -25,6 +25,7 @@ object Main {
//ParamOptionDef("timeout", "Timeout in ms.", "1000"),
ListOptionDef("debug", "For which sections to print debug info.",
List("analysis","solver")),
FlagOptionDef("codegen", "Generate code (as opposed to just doing analysis."),
optionFunctions,
optionPrintToughSMTCalls
)
......@@ -161,17 +162,31 @@ object Main {
private def computePipeline(ctx: Context): Pipeline[Program, Program] = {
// this is not ideal, using 'magic' strings
val dynamicF = ctx.hasFlag("dynamic")
val fixedPointArith = ctx.findOption(analysis.RangeErrorPhase.optionPrecision) match {
case Some("Fixed16") => true
case Some("Fixed32") => true
case _ => false
}
if (dynamicF) {
// this is not ideal, using 'magic' strings
if (ctx.hasFlag("dynamic")) {
analysis.SpecsProcessingPhase andThen
analysis.DynamicPhase
} else {
} else if (ctx.hasFlag("codegen") && fixedPointArith) {
analysis.SpecsProcessingPhase andThen
transform.SSATransformerPhase andThen
analysis.RangeErrorPhase andThen
InfoPhase andThen
backend.CodeGenerationPhase
} else if (ctx.hasFlag("codegen")) {
analysis.SpecsProcessingPhase andThen
analysis.RangeErrorPhase andThen
InfoPhase andThen
backend.CodeGenerationPhase
} else {
analysis.SpecsProcessingPhase andThen
analysis.RangeErrorPhase andThen
InfoPhase
}
}
......
......@@ -20,13 +20,17 @@ object RangeErrorPhase extends DaisyPhase with ErrorFunctions {
override val name = "range-error phase"
override val description = "Computes ranges and absolute errors"
val optionPrecision = ChoiceOptionDef("precision", "Type of precision to use",
Set("Float32", "Float64", "DoubleDouble", "QuadDouble",
"Fixed16", "Fixed32"), "Float64")
override val definedOptions: Set[CmdLineOptionDef[Any]] = Set(
ChoiceOptionDef("rangeMethod", "Method to use for range analysis",
Set("affine", "interval", "smt", "subdiv"), "interval"),
FlagOptionDef("noInitialErrors", "do not track initial errors specified by user"),
FlagOptionDef("noRoundoff", "do not track roundoff errors"),
ChoiceOptionDef("precision", "Type of precision to use",
Set("Float32", "Float64", "DoubleDouble", "QuadDouble"), "Float64")
optionPrecision
)
implicit val debugSection = DebugSectionAnalysis
......@@ -50,7 +54,7 @@ object RangeErrorPhase extends DaisyPhase with ErrorFunctions {
// setting trait variables
trackRoundoffErrs = true
uniformPrecision = Float64
var uniformPrecision: Precision = Float64
// process relevant options
for (opt <- ctx.options) opt match {
......@@ -76,8 +80,12 @@ object RangeErrorPhase extends DaisyPhase with ErrorFunctions {
case "QuadDouble" =>
uniformPrecision = QuadDouble
reporter.info(s"using $s")
case "Fixed16" =>
uniformPrecision = Fixed(16)
case "Fixed32" =>
uniformPrecision = Fixed(32)
case _ =>
reporter.warning(s"Unknown precision specified: $s, choosing default (Float64)!")
reporter.warning(s"Unknown precision specified: $s, choosing default ($uniformPrecision)!")
}
case _ =>
}
......@@ -87,7 +95,7 @@ object RangeErrorPhase extends DaisyPhase with ErrorFunctions {
for (fnc <- prg.defs)
if (!fnc.precondition.isEmpty && !fnc.body.isEmpty && fncsToConsider.contains(fnc.id.toString)){
reporter.debug("analyzing fnc: " + fnc.id)
reporter.info("analyzing fnc: " + fnc.id)
val inputValMap: Map[Identifier, Interval] = ctx.inputRanges(fnc.id)
// If we track both input and roundoff errors, then we pre-compute
......@@ -125,13 +133,13 @@ object RangeErrorPhase extends DaisyPhase with ErrorFunctions {
(rangeMethod, errorMethod) match {
case ("interval", "affine") =>
errorIntervalAffine(fnc.body.get, inputValMap, inputErrorMap)
errorIntervalAffine(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision)
case ("affine", "affine") =>
errorAffineAffine(fnc.body.get, inputValMap, inputErrorMap)
errorAffineAffine(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision)
case ("smt", "affine") =>
errorSMTAffine(fnc.body.get, inputValMap, inputErrorMap)
errorSMTAffine(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision)
// default is to use the method that attaches the info to trees.
case ("subdiv", _) =>
......
......@@ -3,46 +3,235 @@ package daisy
package backend
import daisy.lang.{ScalaPrinter, PrettyPrinter}
import lang.Trees.Program
import lang.Trees._
import utils.FinitePrecision._
import lang.Types._
import lang.Extractors.ArithOperator
import utils.Rational
import lang.NumAnnotation
object CodeGenerationPhase extends DaisyPhase {
override val name = "Codegen"
override val description = "Generates code."
override val name = "codegen"
override val description = "Generates (executable) code."
// TODO: have this generate C code
override val definedOptions: Set[CmdLineOptionDef[Any]] = Set(
FlagOptionDef("printToFile", "print generated code to file")
)
override val definedOptions: Set[CmdLineOptionDef[Any]] = Set()
implicit val debugSection = DebugSectionBackend
var reporter: Reporter = null
def run(ctx: Context, prg: Program): (Context, Program) = {
reporter = ctx.reporter
reporter.info(s"\nStarting $name")
val timer = ctx.timers.rangeError.start
var uniformPrecision: Precision = Float64
var printToFile = false
/* Process relevant options */
for (opt <- ctx.options) opt match {
case FlagOption("printToFile") => printToFile = true
// this option is defined in RangeErrorPhase. This is not ideal,
// but since for now, Codegen will only be called after RangeErrorPhase
// it should be OK for now.
case ChoiceOption("precision", s) => s match {
case "Float32" => uniformPrecision = Float32
case "Float64" => uniformPrecision = Float64
case "DoubleDouble" => uniformPrecision = DoubleDouble
case "QuadDouble" => uniformPrecision = QuadDouble
case "Fixed16" => uniformPrecision = Fixed(16)
case "Fixed32" => uniformPrecision = Fixed(32)
case _ =>
ctx.reporter.warning(s"Unknown precision specified: $s, choosing default ($uniformPrecision)!")
}
case _ => ;
}
def writeScalaFile(filename: String) {
import java.io.FileWriter
import java.io.BufferedWriter
val fstream = new FileWriter(filename)
val out = new BufferedWriter(fstream)
out.write(ScalaPrinter.apply(prg))
out.close
val newProgram = uniformPrecision match {
case Fixed(b) =>
// if we have fixed-point code, we need to generate it first
val newDefs = prg.defs.map(fnc => if (!fnc.body.isEmpty && !fnc.precondition.isEmpty) {
val newBody = toFixedPointCode(fnc.body.get, Fixed(b))
val valDefType = if (b == 16) Int32Type else Int64Type
fnc.copy(
params = fnc.params.map(vd => ValDef(vd.id.changeType(valDefType))),
body = Some(newBody),
returnType = valDefType)
} else {
fnc
})
val newPrg = Program(prg.id, newDefs)
val fileLocation = "./output/" + prg.id + ".scala"
ctx.reporter.info("generating code in " + fileLocation)
writeScalaFile(fileLocation, newPrg)
newPrg
case _ =>
// if we have floating-point code, we need to just change the types
val fileLocation = "./output/" + prg.id + ".scala"
ctx.reporter.info("generating code in " + fileLocation)
val typedPrg = assignFloatType(prg, FinitePrecisionType(uniformPrecision))
writeScalaFile(fileLocation, typedPrg)
typedPrg
}
timer.stop
ctx.reporter.info(s"Finished $name")
(ctx, newProgram)
}
private def writeScalaFile(filename: String, prg: Program) {
import java.io.FileWriter
import java.io.BufferedWriter
val fstream = new FileWriter(filename)
val out = new BufferedWriter(fstream)
out.write(ScalaPrinter.apply(prg))
out.close
}
private def assignFloatType(prg: Program, tpe: FinitePrecisionType): Program = {
def changeType(e: Expr): Expr = e match {
case Variable(id) => Variable(id.changeType(tpe))
case x @ RealLiteral(r) =>
val tmp = FinitePrecisionLiteral(r, tpe.prec)
tmp.stringValue = x.stringValue
tmp
case ArithOperator(args, recons) =>
recons(args.map(changeType))
case Let(id, value, body) =>
Let(id.changeType(tpe), changeType(value), changeType(body))
}
//println(ScalaPrinter.apply(prg))
val newDefs = prg.defs.map({
case FunDef(id, returnType, params, pre, body, post, isField) =>
FunDef(id, tpe,
params.map(vd => ValDef(vd.id.changeType(tpe))),
// this should really be changed too
pre,
body.map(changeType(_)),
post,
isField
)
})
Program(prg.id, newDefs)
if (printToFile){
val fileLocation = "./output/" + prg.id + ".scala"
ctx.reporter.info("Generated code: " + fileLocation)
writeScalaFile(fileLocation)
}
/*
* Expects code to be already in SSA form.
* @param fixed the (uniform) fixed-point precision to use
// TODO: we also need to adjust the types, no?
*/
private def toFixedPointCode(expr: Expr, fixed: Fixed): Expr = expr match {
case x @ Variable(id) => fixed match {
case Fixed(16) => Variable(id.changeType(Int32Type))
case Fixed(32) => Variable(id.changeType(Int64Type))
}
(ctx, prg)
case RealLiteral(r) => // TODO: translate constant
val f = fixed.fractionalBits(r)
fixed match {
case Fixed(16) => Int32Literal((r * Rational.fromDouble(math.pow(2, f))).roundToInt)
case Fixed(32) => Int64Literal((r * Rational.fromDouble(math.pow(2, f))).roundToLong)
}
case UMinus(t) => UMinus(toFixedPointCode(t, fixed))
case Sqrt(t) =>
throw new Exception("Sqrt is not supported for fixed-points!")
null
case x @ Plus(lhs, rhs) =>
// fractional bits from lhs
val fLhs = fixed.fractionalBits(lhs.asInstanceOf[NumAnnotation].interval)
val fRhs = fixed.fractionalBits(rhs.asInstanceOf[NumAnnotation].interval)
// determine how much to shift left or right
val fAligned = math.max(fLhs, fRhs)
val newLhs =
if (fLhs < fAligned) LeftShift(toFixedPointCode(lhs, fixed), (fAligned - fLhs))
else toFixedPointCode(lhs, fixed)
val newRhs =
if (fRhs < fAligned) LeftShift(toFixedPointCode(rhs, fixed), (fAligned - fRhs))
else toFixedPointCode(rhs, fixed)
// fractional bits result
val fRes = fixed.fractionalBits(x.interval)
// shift result
if (fAligned == fRes) {
Plus(newLhs, newRhs)
} else if(fRes < fAligned) {
RightShift(Plus(newLhs, newRhs), (fAligned - fRes))
} else { //(fAligned < fRes) {
// TODO: this sounds funny. does this ever happen?
reporter.warning("funny shifting condition is happening")
LeftShift(Plus(newLhs, newRhs), (fRes - fAligned))
}
case x @ Minus(lhs, rhs) =>
// fractional bits from lhs
val fLhs = fixed.fractionalBits(lhs.asInstanceOf[NumAnnotation].interval)
val fRhs = fixed.fractionalBits(rhs.asInstanceOf[NumAnnotation].interval)
// determine how much to shift left or right
val fAligned = math.max(fLhs, fRhs)
val newLhs =
if (fLhs < fAligned) LeftShift(toFixedPointCode(lhs, fixed), (fAligned - fLhs))
else toFixedPointCode(lhs, fixed)
val newRhs =
if (fRhs < fAligned) LeftShift(toFixedPointCode(rhs, fixed), (fAligned - fRhs))
else toFixedPointCode(rhs, fixed)
// fractional bits result
val fRes = fixed.fractionalBits(x.interval)
// shift result
if (fAligned == fRes) {
Minus(newLhs, newRhs)
} else if(fRes < fAligned) {
RightShift(Minus(newLhs, newRhs), (fAligned - fRes))
} else { //(fAligned < fRes) {
// TODO: this sounds funny. does this ever happen?
reporter.warning("funny shifting condition is happening")
LeftShift(Minus(newLhs, newRhs), (fRes - fAligned))
}
case x @ Times(lhs, rhs) =>
val mult = Times(toFixedPointCode(lhs, fixed), toFixedPointCode(rhs, fixed))
val fMult = fixed.fractionalBits(lhs.asInstanceOf[NumAnnotation].interval) +
fixed.fractionalBits(rhs.asInstanceOf[NumAnnotation].interval)
// fractional bits result
val fRes = fixed.fractionalBits(x.interval)
// shift result
if (fMult == fRes) {
mult
} else if(fRes < fMult) {
RightShift(mult, (fMult - fRes))
} else { //(fAligned < fRes) {
// TODO: this sounds funny. does this ever happen?
reporter.warning("funny shifting condition is happening")
LeftShift(mult, (fRes - fMult))
}
case x @ Division(lhs, rhs) =>
val fLhs = fixed.fractionalBits(lhs.asInstanceOf[NumAnnotation].interval)
val fRhs = fixed.fractionalBits(rhs.asInstanceOf[NumAnnotation].interval)
val fRes = fixed.fractionalBits(x.interval)
val shift = fRes + fRhs - fLhs
Division(LeftShift(toFixedPointCode(lhs, fixed), shift), toFixedPointCode(rhs, fixed))
case Let(id, value, body) =>
Let(id, toFixedPointCode(value, fixed), toFixedPointCode(body, fixed))
}
}
......@@ -65,6 +65,10 @@ object Identifiers {
// keeps identifier intact
def deepCopy: Identifier = this
def changeType(newTpe: TypeTree): Identifier = {
new Identifier(this.name, this.globalId, this.id, newTpe, this.alwaysShowUniqueID)
}
}
class UniqueCounter[K] {
......
......@@ -14,6 +14,7 @@ import scala.collection.immutable.Seq
import Trees._
import Types._
import Identifiers._
import utils.FinitePrecision._
object PrettyPrinter {
......@@ -136,10 +137,15 @@ class PrettyPrinter(val sb: StringBuffer = new StringBuffer, printUniqueIds: Boo
case Log(expr) => ppUnary(expr, "log(", ")")
case Equals(l,r) => ppBinary(l, r, " == ")
case Int32Literal(v) => sb.append(v)
case Int64Literal(v) => sb.append(v)
case IntegerLiteral(v) => sb.append(v)
case BooleanLiteral(v) => sb.append(v)
case UnitLiteral() => sb.append("()")
case RealLiteral(r) => sb.append(r.toString)
case x @ FinitePrecisionLiteral(r, Float32) =>
sb.append(x.stringValue + "f")
case x @ FinitePrecisionLiteral(r, _) =>
sb.append(x.stringValue)
case FunctionInvocation(fdId, _, args, _) =>
pp(fdId, p)
......@@ -162,6 +168,9 @@ class PrettyPrinter(val sb: StringBuffer = new StringBuffer, printUniqueIds: Boo
case LessEquals(l,r) => ppBinary(l, r, " \u2264 ") // \leq
case GreaterEquals(l,r) => ppBinary(l, r, " \u2265 ") // \geq
case RightShift(t, by) => ppUnary(t, "(", " >> " + by + ")")
case LeftShift(t, by) => ppUnary(t, "(", " << " + by + ")")
case IfExpr(c, t, e) =>
sb.append("if (")
pp(c, p)
......@@ -180,8 +189,13 @@ class PrettyPrinter(val sb: StringBuffer = new StringBuffer, printUniqueIds: Boo
case Untyped => sb.append("<untyped>")
case UnitType => sb.append("Unit")
case Int32Type => sb.append("Int")
case Int64Type => sb.append("Long")
case BooleanType => sb.append("Boolean")
case RealType => sb.append("Real")
case FinitePrecisionType(Float32) => sb.append("Float")
case FinitePrecisionType(Float64) => sb.append("Double")
case FinitePrecisionType(DoubleDouble) => sb.append("DoubleDouble")
case FinitePrecisionType(QuadDouble) => sb.append("QuadDouble")
case FunctionType(fts, tt) =>
if (fts.size > 1) {
ppNary(fts, "(", ", ", ")")
......
......@@ -36,7 +36,12 @@ class ScalaPrinter extends PrettyPrinter {
case Not(Equals(l, r)) => ppBinary(l, r, " != ") // \neq
case Not(expr) => ppUnary(expr, "!(", ")") // \neg
case RealType => sb.append("Double")
// this should never be called by this printer, i.e. all RealTypes
// should have been transformed before
case RealType =>
throw new Exception("RealType found in ScalaPrinter")
case Program(id, defs) =>
assert(lvl == 0)
......
......@@ -16,6 +16,7 @@ import Types._
import utils.Positioned
import Identifiers._
import utils.Rational
import utils.FinitePrecision.Precision
object Trees {
......@@ -245,6 +246,12 @@ object Trees {
def deepCopy = Int32Literal(value)
}
// Long
case class Int64Literal(value: Long) extends Literal[Long] {
val getType = Int64Type
def deepCopy = Int64Literal(value)
}
/** $encodingof an infinite precision integer literal */
case class IntegerLiteral(value: BigInt) extends Literal[BigInt] {
val getType = IntegerType
......@@ -284,6 +291,25 @@ object Trees {
}
}
case class FinitePrecisionLiteral(value: Rational, prec: Precision) extends Literal[Rational] with NumAnnotation {
val getType = FinitePrecisionType(prec)
private var _stringValue: String = null
def stringValue_=(s: String): Unit = {
if (_stringValue == null) {
_stringValue = s
} else {
throw new Exception("'stringValue' is a write-only-once field, but you tried twice!")
}
}
def stringValue = _stringValue
def deepCopy = {
val tmp = FinitePrecisionLiteral(value, prec)
tmp.stringValue = this.stringValue
tmp
}
}
/* Propositional logic */
......@@ -368,61 +394,35 @@ object Trees {
/** $encodingof `... + ...` */
case class Plus(lhs: Expr, rhs: Expr) extends Expr with NumAnnotation {
assert(lhs.getType == rhs.getType)
// this does not hold for function calls
//assert(lhs.isInstanceOf[NumAnnotation] && rhs.isInstanceOf[NumAnnotation])
val getType = {
if (lhs.getType == RealType) RealType
else if (lhs.getType == IntegerType) IntegerType
else Untyped
}
assert(lhs.getType == rhs.getType, "lhs: " + lhs.getType + ", rhs: " + rhs.getType)
val getType = lhs.getType
def deepCopy = Plus(lhs.deepCopy, rhs.deepCopy)
}
/** $encodingof `... - ...` */
case class Minus(lhs: Expr, rhs: Expr) extends Expr with NumAnnotation {
assert(lhs.getType == rhs.getType)
assert(lhs.isInstanceOf[NumAnnotation] && rhs.isInstanceOf[NumAnnotation])
val getType = {
if (lhs.getType == RealType) RealType
else if (lhs.getType == IntegerType) IntegerType
else Untyped
}
val getType = lhs.getType
def deepCopy = Minus(lhs.deepCopy, rhs.deepCopy)
}
/** $encodingof `- ... for BigInts`*/
case class UMinus(expr: Expr) extends Expr with NumAnnotation {
assert(expr.isInstanceOf[NumAnnotation])
val getType = {
if (expr.getType == RealType) RealType
else if (expr.getType == IntegerType) IntegerType
else Untyped
}
val getType = expr.getType
def deepCopy = UMinus(expr.deepCopy)
}
/** $encodingof `... * ...` */
case class Times(lhs: Expr, rhs: Expr) extends Expr with NumAnnotation {
assert(lhs.getType == rhs.getType)
assert(lhs.isInstanceOf[NumAnnotation] && rhs.isInstanceOf[NumAnnotation])
val getType = {
if (lhs.getType == RealType) RealType
else if (lhs.getType == IntegerType) IntegerType
else Untyped
}
val getType = lhs.getType
def deepCopy = Times(lhs.deepCopy, rhs.deepCopy)
}
/** $encodingof `... / ...` */
case class Division(lhs: Expr, rhs: Expr) extends Expr with NumAnnotation {
assert(lhs.getType == rhs.getType)
assert(lhs.isInstanceOf[NumAnnotation] && rhs.isInstanceOf[NumAnnotation])
val getType = {
if (lhs.getType == RealType) RealType
else if (lhs.getType == IntegerType) IntegerType
else Untyped
}
assert(lhs.getType == rhs.getType, "lhs: " + lhs.getType + ", " + rhs.getType)
val getType = lhs.getType
def deepCopy = Division(lhs.deepCopy, rhs.deepCopy)
}
......@@ -439,7 +439,8 @@ object Trees {
case class Sqrt(t: Expr) extends Expr with NumAnnotation {
require(t.getType == RealType)
assert(t.isInstanceOf[NumAnnotation])
val getType = RealType
// TODO: this operation may not be available for Float32
val getType = t.getType
def deepCopy = Sqrt(t.deepCopy)
}
......@@ -496,6 +497,17 @@ object Trees {
def deepCopy = GreaterEquals(lhs.deepCopy, rhs.deepCopy)
}
/* Shifts */
case class RightShift(t: Expr, by: Int) extends Expr {
val getType = t.getType
def deepCopy = RightShift(t.deepCopy, by)
}
case class LeftShift(t: Expr, by: Int) extends Expr {
val getType = t.getType
def deepCopy = LeftShift(t.deepCopy, by)
}
/* Specs */