Commit 94aa329a authored by ='s avatar =

Daisy can now generate mixed precision certificates for Coq. I also

implemented a new analysis phase to assign a valid but random
precision assignment for each variable.
parent 7254b138
......@@ -7,6 +7,8 @@
*/
package daisy
import scala.collection.immutable.Seq
import lang.Trees.Program
/**
......
......@@ -14,7 +14,7 @@ import scala.reflect.ClassTag
import lang.Identifiers._
import lang.Trees._
import utils.{Interval, PartialInterval, Rational}
import utils.{Interval, PartialInterval, Rational, FinitePrecision}
case class Context(
reporter: Reporter,
......@@ -42,6 +42,8 @@ case class Context(
// encoding the analysis result
intermediateAbsoluteErrors: Map[Identifier, Map[Expr, Rational]]= Map(),
intermediateRanges: Map[Identifier, Map[Expr, Interval]] = Map(),
precision: Map[Identifier, Map[Identifier, FinitePrecision.Precision]] = Map(),
constantPrecision: FinitePrecision.Precision = FinitePrecision.Float64,
// the analysed/computed roundoff errors for each function
resultAbsoluteErrors: Map[Identifier, Rational] = Map(),
......
......@@ -28,6 +28,7 @@ object Main {
ListOptionDef("debug", "For which sections to print debug info.",
List("analysis","solver")),
FlagOptionDef("codegen", "Generate code (as opposed to just doing analysis)."),
FlagOptionDef("randomMP", "random mixed precision assignment"),
optionFunctions,
optionPrintToughSMTCalls,
optionValidators
......@@ -191,6 +192,7 @@ object Main {
} else if (ctx.findOption(optionValidators) != None) {
analysis.SpecsProcessingPhase andThen
transform.SSATransformerPhase andThen
analysis.ChoosePrecisionPhase andThen
analysis.RangeErrorPhase andThen
InfoPhase andThen
backend.CertificatePhase andThen
......
......@@ -8,6 +8,7 @@ import utils._
import FinitePrecision._
import Rational._
import lang.Identifiers._
import lang.TreeOps.allVariablesOf
/**
??? Description goes here
......@@ -53,6 +54,7 @@ object RangeErrorPhase extends DaisyPhase with RoundoffEvaluators with IntervalS
var trackInitialErrs = true
var trackRoundoffErrs = true
var randomMP = false
var uniformPrecision: Precision = Float64
......@@ -67,6 +69,7 @@ object RangeErrorPhase extends DaisyPhase with RoundoffEvaluators with IntervalS
}
case FlagOption("noInitialErrors") => trackInitialErrs = false
case FlagOption("noRoundoff") => trackRoundoffErrs = false
case FlagOption("randomMP") => randomMP = true
case ChoiceOption("precision", s) => s match {
case "Float32" =>
uniformPrecision = Float32
......@@ -95,75 +98,83 @@ object RangeErrorPhase extends DaisyPhase with RoundoffEvaluators with IntervalS
var roundoffErrorMap: Map[Identifier, Map[Expr, Rational]] = Map()
var rangeResMap: Map[Identifier, Map[Expr, Interval]] = Map()
// first identifier is for function only, second one for
// variables/parameters used inside a given function
var varPrecisionMap: Map[Identifier, Map[Identifier, Precision]] = ctx.precision
val res: Map[Identifier, (Rational, Interval)] = prg.defs.filter(fnc =>
!fnc.precondition.isEmpty &&
!fnc.body.isEmpty &&
fncsToConsider.contains(fnc.id.toString)).map(fnc => {
reporter.info("analyzing fnc: " + fnc.id)
val inputValMap: Map[Identifier, Interval] = ctx.specInputRanges(fnc.id)
// If we track both input and roundoff errors, then we pre-compute
// the roundoff errors for those variables that do not have a user-defined
// error, in order to keep correlations.
val inputErrorMap: Map[Identifier, Rational] =
if (trackInitialErrs && trackRoundoffErrs){
val inputErrs = ctx.specInputErrors(fnc.id)
val allIDs = fnc.params.map(_.id).toSet
val missingIDs = allIDs -- inputErrs.keySet
inputErrs ++ missingIDs.map( id => (id -> uniformPrecision.absRoundoff(inputValMap(id))))
} else if(trackInitialErrs) {
val inputErrs = ctx.specInputErrors(fnc.id)
val allIDs = fnc.params.map(_.id).toSet
val missingIDs = allIDs -- inputErrs.keySet
inputErrs ++ missingIDs.map( id => (id -> zero))
} else if (trackRoundoffErrs) {
val allIDs = fnc.params.map(_.id)
allIDs.map( id => (id -> uniformPrecision.absRoundoff(inputValMap(id)))).toMap
} else {
val allIDs = fnc.params.map(_.id)
allIDs.map( id => (id -> zero)).toMap
}
val (resError: Rational, resRange: Interval, rangeMap:Map[Expr, Interval], errorMap:Map[Expr, Interval]) = (rangeMethod, errorMethod) match {
case ("interval", "interval") =>
uniformRoundoff_IA_IA(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision, trackRoundoffErrs)
case ("interval", "affine") =>
uniformRoundoff_IA_AA(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision, trackRoundoffErrs)
case ("affine", "affine") =>
uniformRoundoff_AA_AA(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision, trackRoundoffErrs)
case ("smt", "affine") =>
uniformRoundoff_SMT_AA(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision, trackRoundoffErrs)
// default is to use the method that attaches the info to trees.
case ("subdiv", _) =>
val tmp = doIntervalSubdivision( //evaluateSubdiv(
reporter.info("analyzing fnc: " + fnc.id)
val inputValMap: Map[Identifier, Interval] = ctx.specInputRanges(fnc.id)
val precisionMap = varPrecisionMap(fnc.id)
// If we track both input and roundoff errors, then we pre-compute
// the roundoff errors for those variables that do not have a user-defined
// error, in order to keep correlations.
val inputErrorMap: Map[Identifier, Rational] =
if (trackInitialErrs && trackRoundoffErrs){
val inputErrs = ctx.specInputErrors(fnc.id)
val allIDs = fnc.params.map(_.id).toSet
val missingIDs = allIDs -- inputErrs.keySet
inputErrs ++ missingIDs.map( id => (id -> precisionMap(id).absRoundoff(inputValMap(id))))
} else if(trackInitialErrs) {
val inputErrs = ctx.specInputErrors(fnc.id)
val allIDs = fnc.params.map(_.id).toSet
val missingIDs = allIDs -- inputErrs.keySet
inputErrs ++ missingIDs.map( id => (id -> zero))
} else if (trackRoundoffErrs) {
val allIDs = fnc.params.map(_.id)
allIDs.map( id => (id -> precisionMap(id).absRoundoff(inputValMap(id)))).toMap
} else {
val allIDs = fnc.params.map(_.id)
allIDs.map( id => (id -> zero)).toMap
}
val (resError: Rational, resRange: Interval, rangeMap:Map[Expr, Interval], errorMap:Map[Expr, Interval]) = (rangeMethod, errorMethod) match {
case ("interval", "interval") =>
if(randomMP) {
mixedPrecRoundoff_IA_IA(fnc.body.get, inputValMap, inputErrorMap, precisionMap, uniformPrecision, trackRoundoffErrs)
} else {
uniformRoundoff_IA_IA(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision, trackRoundoffErrs)
}
case ("interval", "affine") =>
uniformRoundoff_IA_AA(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision, trackRoundoffErrs)
case ("affine", "affine") =>
uniformRoundoff_AA_AA(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision, trackRoundoffErrs)
case ("smt", "affine") =>
uniformRoundoff_SMT_AA(fnc.body.get, inputValMap, inputErrorMap, uniformPrecision, trackRoundoffErrs)
// default is to use the method that attaches the info to trees.
case ("subdiv", _) =>
val tmp = doIntervalSubdivision( //evaluateSubdiv(
fnc.body.get, lang.TreeOps.freeVariablesOf(fnc.body.get),
inputValMap,
inputErrorMap,
trackRoundoffErrs,
uniformPrecision)
(tmp._2, tmp._1)
(tmp._2, tmp._1)
case _ =>
reporter.fatalError(s"Your combination of $rangeMethod and $errorMethod" +
"for computing ranges and errors is not supported.")
null
}
case _ =>
reporter.fatalError(s"Your combination of $rangeMethod and $errorMethod" +
"for computing ranges and errors is not supported.")
null
}
roundoffErrorMap += (fnc.id -> (errorMap.map(x => x._1 -> x._2.xhi)))
rangeResMap += (fnc.id -> rangeMap)
(fnc.id -> (resError, resRange))
}).toMap
(fnc.id -> (resError, resRange))
}).toMap
timer.stop
ctx.reporter.info(s"Finished $name")
......
......@@ -3,11 +3,13 @@ package backend
import daisy.lang.{ScalaPrinter, PrettyPrinter}
import lang.Trees._
import lang.Types._
import lang.TreeOps
import lang.Identifiers._
import lang.NumAnnotation
import utils.Interval
import utils.Rational
import utils.FinitePrecision._
import analysis.SpecsProcessingPhase
object CertificatePhase extends DaisyPhase {
......@@ -24,8 +26,12 @@ object CertificatePhase extends DaisyPhase {
val prover = ctx.findOption(Main.optionValidators)
val errorMap = ctx.resultAbsoluteErrors
val rangeMap = ctx.resultRealRanges
val precision = ctx.precision
val constPrecision = ctx.constantPrecision
def writeToFile (fileContent:String, prover:String){
reporter.info(s"\nStarting $name")
def writeToFile (fileContent:String, prover:String) : String = {
import java.io.FileWriter
import java.io.BufferedWriter
val fileLocation =
......@@ -40,6 +46,7 @@ object CertificatePhase extends DaisyPhase {
val out = new BufferedWriter(fstream)
out.write(fileContent)
out.close
fileLocation
}
var fullCertificate = "";
......@@ -67,13 +74,13 @@ object CertificatePhase extends DaisyPhase {
thePrecondition match {
case Some (pre) =>
//the definitions for the whole expression
val (theDefinitions, lastGenName) = getCmd(theBody.get,reporter,prv)
val (theDefinitions, lastGenName) = getCmd(theBody.get,precision(fnc.id),constPrecision, reporter,prv)
//generate the precondition
val (thePreconditionFunction, functionName) =
getPrecondFunction(pre, fnc.id.toString, reporter, prv)
//the analysis result function
val (analysisResultText, analysisResultName) =
getAbsEnvDef(theBody.get, errorMap, rangeMap, fnc.id.toString, prv, reporter)
getAbsEnvDef(theBody.get, errorMap, rangeMap, precision(fnc.id), fnc.id.toString, prv, reporter)
//generate the final evaluation statement
val functionCall = getComputeExpr(lastGenName,analysisResultName,functionName,fnc.id.toString,prv)
//val functionCall = getAllComputeExps(theBody.get, analysisResultName, functionName, prv)
......@@ -98,7 +105,8 @@ object CertificatePhase extends DaisyPhase {
reporter.info(s"Number of operations: $size")
//end iteration, write certificate
writeToFile(fullCertificate,prover.get)
val filename = writeToFile(fullCertificate,prover.get)
reporter.info(s"Wrote certificate to $filename")
(ctx, prg)
}
......@@ -114,11 +122,23 @@ object CertificatePhase extends DaisyPhase {
else
"needs \"CertificateChecker.hl\";;\nneeds \"Infra/convs.hl\";;"
private def coqVariable(vname:Identifier, id:Int, reporter:Reporter) :(String, String) =
private def coqPrecision(p:Precision, reporter:Reporter) : String =
{
p match {
case Float32 => "M32"
case Float64 => "M64"
case DoubleDouble => "M128"
case QuadDouble => "M256"
case _ => reporter.fatalError ("In coqPrecision, unknwon precision.")
}
}
private def coqVariable(vname:Identifier, id:Int, precision:Precision, reporter:Reporter) :(String, String) =
{
//FIXME: Ugly Hack to get disjoint names for multiple function encodings with same variable names:
val freshId = nextFreshVariable()
val theExpr = s"Definition ExpVar$vname$freshId :exp Q := Var Q $id.\n"
val prec = coqPrecision(precision, reporter)
val theExpr = s"Definition ExpVar$vname$freshId :exp Q := Var Q $prec $id.\n"
(theExpr,s"ExpVar$vname$freshId")
}
......@@ -136,14 +156,15 @@ object CertificatePhase extends DaisyPhase {
(theExpr, s"ExpVar$vname")
}
private def coqConstant(r:RealLiteral, id:Int, reporter:Reporter) :(String, String) =
private def coqConstant(r:RealLiteral, id:Int, precision:Precision, reporter:Reporter) :(String, String) =
r match {
case RealLiteral(v) =>
//FIXME: Ugly Hack to get disjoint names for multiple function encodings with same variable names:
val freshId = nextConstantId()
val prec = coqPrecision(precision, reporter)
val rationalStr = v.toFractionString
val coqRational = rationalStr.replace('/','#')
val theExpr = s"Definition ExpCst$id$freshId :exp Q := Const ($coqRational).\n"
val theExpr = s"Definition ExpCst$id$freshId :exp Q := Const $prec ($coqRational).\n"
(theExpr, s"ExpCst$id$freshId")
}
......@@ -167,6 +188,7 @@ object CertificatePhase extends DaisyPhase {
(theExpr, s"ExpCst$id")
}
private def coqBinOp (e: Expr, nameLHS:String, nameRHS:String, reporter:Reporter) :(String, String) =
e match {
case x @ Plus(lhs, rhs) =>
......@@ -221,6 +243,9 @@ object CertificatePhase extends DaisyPhase {
reporter.fatalError("Unsupported value")
}
private def coqDowncast(nameOp:String, prec:Precision, reporter:Reporter) :(String, String) =
(s"Definition Downcast$nameOp :exp Q := Downcast ${coqPrecision(prec, reporter)} $nameOp.\n", s"Downcast$nameOp")
private def coqUMin (e:Expr, nameOp:String, reporter:Reporter) :(String, String) =
(s"Definition UMin${nameOp} :exp Q := Unop Neg $nameOp.\n", s"UMin${nameOp}")
......@@ -228,7 +253,7 @@ object CertificatePhase extends DaisyPhase {
(s"val UMin${nameOp} = Define `UMin${nameOp}:(real exp) = Unop Neg ${nameOp}`;\n",
s"UMin${nameOp}")
private def getValues(e: Expr,reporter:Reporter,prv:String): (String, String) = {
private def getValues(e: Expr, precision:Map[Identifier, Precision], constPrecision:Precision, reporter:Reporter,prv:String): (String, String) = {
if (expressionNames.contains(e)){
......@@ -244,7 +269,7 @@ object CertificatePhase extends DaisyPhase {
identifierNums += (id -> varId)
val (definition, name) =
if (prv == "coq"){
coqVariable (id,varId,reporter)
coqVariable (id, varId, precision(id), reporter)
}else if (prv == "hol4"){
hol4Variable (id, varId, reporter)
}else {
......@@ -257,7 +282,7 @@ object CertificatePhase extends DaisyPhase {
case x @ RealLiteral(r) =>
val (definition, name) =
if (prv == "coq")
coqConstant (x,nextConstantId(),reporter)
coqConstant (x,nextConstantId(),constPrecision,reporter)
else if (prv == "hol4")
hol4Constant (x, nextConstantId(), reporter)
else
......@@ -266,8 +291,8 @@ object CertificatePhase extends DaisyPhase {
(definition,name)
case x @ Plus(lhs, rhs) =>
val (lhsText, lhsName) = getValues(lhs,reporter,prv)
val (rhsText, rhsName) = getValues(rhs,reporter,prv)
val (lhsText, lhsName) = getValues(lhs,precision,constPrecision,reporter,prv)
val (rhsText, rhsName) = getValues(rhs,precision,constPrecision,reporter,prv)
val (definition, name) =
if (prv == "coq"){
val (binOpDef, binOpName) = coqBinOp (e, lhsName, rhsName,reporter)
......@@ -283,8 +308,8 @@ object CertificatePhase extends DaisyPhase {
(definition,name)
case x @ Minus(lhs, rhs) =>
val (lhsText, lhsName) = getValues(lhs,reporter,prv)
val (rhsText, rhsName) = getValues(rhs,reporter,prv)
val (lhsText, lhsName) = getValues(lhs,precision,constPrecision,reporter,prv)
val (rhsText, rhsName) = getValues(rhs,precision,constPrecision,reporter,prv)
val (definition, name) =
if (prv == "coq"){
val (binOpDef, binOpName) = coqBinOp (e, lhsName, rhsName,reporter)
......@@ -300,8 +325,8 @@ object CertificatePhase extends DaisyPhase {
(definition,name)
case x @ Times(lhs, rhs) =>
val (lhsText, lhsName) = getValues(lhs,reporter,prv)
val (rhsText, rhsName) = getValues(rhs,reporter,prv)
val (lhsText, lhsName) = getValues(lhs,precision,constPrecision,reporter,prv)
val (rhsText, rhsName) = getValues(rhs,precision,constPrecision,reporter,prv)
val (definition, name) =
if (prv == "coq"){
val (binOpDef, binOpName) = coqBinOp (e, lhsName, rhsName,reporter)
......@@ -317,8 +342,8 @@ object CertificatePhase extends DaisyPhase {
(definition,name)
case x @ Division(lhs, rhs) =>
val (lhsText, lhsName) = getValues(lhs,reporter,prv)
val (rhsText, rhsName) = getValues(rhs,reporter,prv)
val (lhsText, lhsName) = getValues(lhs,precision,constPrecision,reporter,prv)
val (rhsText, rhsName) = getValues(rhs,precision,constPrecision,reporter,prv)
val (definition, name) =
if (prv == "coq"){
val (binOpDef, binOpName) = coqBinOp (e, lhsName, rhsName,reporter)
......@@ -334,7 +359,7 @@ object CertificatePhase extends DaisyPhase {
(definition,name)
case x @ UMinus (exp) =>
val (opDef, opName) = getValues (exp, reporter, prv)
val (opDef, opName) = getValues (exp, precision, constPrecision, reporter, prv)
val (defintion, name) =
if (prv == "coq") {
val (unopDef, unopName) = coqUMin (e, opName, reporter)
......@@ -346,33 +371,47 @@ object CertificatePhase extends DaisyPhase {
expressionNames += (e -> name)
(defintion, name)
case x @ Downcast(exp, tpe) =>
val (expDef, expName) = getValues(exp, precision, constPrecision, reporter, prv)
val tpe_prec = tpe match { case FinitePrecisionType(t) => t }
val (definition, name) =
if (prv == "coq") {
val (dDef, dName) = coqDowncast(expName, tpe_prec, reporter)
(expDef + dDef, dName)
} else {
reporter.fatalError("Downcast only implemented in coq")
}
expressionNames += (e -> name)
(definition, name)
case x @ _ =>
reporter.fatalError(s"Unsupported operation $e")
}
}
}
private def getCmd (e: Expr,reporter:Reporter,prv:String): (String, String) = {
private def getCmd (e: Expr, precision:Map[Identifier, Precision], constPrecision:Precision, reporter:Reporter,prv:String): (String, String) = {
e match {
case e @ Let(x,exp,g) =>
//first compute expression AST
val (expDef, expName) = getValues (exp,reporter,prv)
val (expDef, expName) = getValues (exp, precision, constPrecision, reporter,prv)
//now allocate a new variable
val varId = nextFreshVariable()
identifierNums += (x -> varId)
val (varDef, varName) =
if (prv == "coq"){
coqVariable (x,varId,reporter)
coqVariable (x, varId, precision(x), reporter)
}else{
hol4Variable (x, varId, reporter)
}
expressionNames += (Variable(x) -> varName)
//now recursively compute the command
val (cmdDef, cmdName) = getCmd (g, reporter,prv)
val (cmdDef, cmdName) = getCmd (g, precision, constPrecision, reporter,prv)
val prec_x = coqPrecision(precision(x), reporter)
val letName = "Let"+varName+expName+cmdName
val letDef =
if (prv == "coq"){
s"Definition $letName := Let $varId $expName $cmdName.\n"
s"Definition $letName := Let $prec_x $varId $expName $cmdName.\n"
}else {
s"val ${letName}_def = Define `$letName = Let $varId $expName $cmdName`;\n"
}
......@@ -380,7 +419,7 @@ object CertificatePhase extends DaisyPhase {
case e @ _ =>
//return statement necessary
val (expDef, expName) = getValues (e, reporter, prv)
val (expDef, expName) = getValues (e, precision, constPrecision, reporter, prv)
val retName = s"Ret$expName"
if (prv == "coq"){
(expDef + s"Definition $retName := Ret $expName.\n", retName)
......@@ -469,18 +508,19 @@ object CertificatePhase extends DaisyPhase {
"then " + thenC + "\n" +
"else " + elseC
private def coqAbsEnv (e:Expr, errorMap:Map[Expr, Rational], rangeMap:Map[Expr, Interval], cont:String, reporter:Reporter) :String =
private def coqAbsEnv (e:Expr, errorMap:Map[Expr, Rational], rangeMap:Map[Expr, Interval], precision:Map[Identifier, Precision], cont:String, reporter:Reporter) :String =
{
// Let bindings do not have names, so these are a special case here:
e match {
case x @ Let (y,exp, g) =>
val expFun = coqAbsEnv (exp, errorMap, rangeMap, cont, reporter)
val gFun = coqAbsEnv (g, errorMap, rangeMap, expFun, reporter)
val prec = coqPrecision(precision(y), reporter)
val expFun = coqAbsEnv (exp, errorMap, rangeMap, precision, cont, reporter)
val gFun = coqAbsEnv (g, errorMap, rangeMap, precision, expFun, reporter)
val intvY = coqInterval((rangeMap(Variable(y)).xlo, rangeMap(Variable(y)).xhi))
val errY = errorMap(Variable (y)).toFractionString.replace("/","#")
val nameY = expressionNames(Variable(y))
conditional (
s"( expEqBool e (Var Q ${identifierNums(y)}) )",
s"( expEqBool e (Var Q $prec ${identifierNums(y)}) )",
"(" + intvY + ", " + errY + ")",
gFun)
......@@ -499,23 +539,26 @@ object CertificatePhase extends DaisyPhase {
case x @ RealLiteral(r) => cont
case x @ Plus(lhs, rhs) =>
val lFun = coqAbsEnv (lhs, errorMap, rangeMap, cont, reporter)
coqAbsEnv (rhs, errorMap, rangeMap, lFun, reporter)
val lFun = coqAbsEnv (lhs, errorMap, rangeMap, precision, cont, reporter)
coqAbsEnv (rhs, errorMap, rangeMap, precision, lFun, reporter)
case x @ Minus(lhs, rhs) =>
val lFun = coqAbsEnv (lhs, errorMap, rangeMap, cont, reporter)
coqAbsEnv (rhs, errorMap, rangeMap, lFun, reporter)
val lFun = coqAbsEnv (lhs, errorMap, rangeMap, precision, cont, reporter)
coqAbsEnv (rhs, errorMap, rangeMap, precision, lFun, reporter)
case x @ Times(lhs, rhs) =>
val lFun = coqAbsEnv (lhs, errorMap, rangeMap, cont, reporter)
coqAbsEnv (rhs, errorMap, rangeMap, lFun, reporter)
val lFun = coqAbsEnv (lhs, errorMap, rangeMap, precision, cont, reporter)
coqAbsEnv (rhs, errorMap, rangeMap, precision, lFun, reporter)
case x @ Division(lhs, rhs) =>
val lFun = coqAbsEnv (lhs, errorMap, rangeMap, cont, reporter)
coqAbsEnv (rhs, errorMap, rangeMap, lFun, reporter)
val lFun = coqAbsEnv (lhs, errorMap, rangeMap, precision, cont, reporter)
coqAbsEnv (rhs, errorMap, rangeMap, precision, lFun, reporter)
case x @ UMinus(e) =>
coqAbsEnv (e, errorMap, rangeMap, cont, reporter)
coqAbsEnv (e, errorMap, rangeMap, precision, cont, reporter)
case x @ Downcast(e, t) =>
coqAbsEnv (e, errorMap, rangeMap, precision, cont, reporter)
case x @ _ =>
reporter.fatalError(s"Unsupported operation $e")
......@@ -644,10 +687,10 @@ object CertificatePhase extends DaisyPhase {
}
}
private def getAbsEnvDef (e:Expr, errorMap:Map[Expr, Rational], rangeMap:Map[Expr, Interval], fName:String, prv:String, reporter:Reporter) :(String,String)=
private def getAbsEnvDef (e:Expr, errorMap:Map[Expr, Rational], rangeMap:Map[Expr, Interval], precision:Map[Identifier, Precision], fName:String, prv:String, reporter:Reporter) :(String,String)=
if (prv == "coq")
(s"Definition absenv_${fName} :analysisResult := \nfun (e:exp Q) =>\n" +
coqAbsEnv(e, errorMap, rangeMap, "((0#1,0#1),0#1)", reporter) + ".",
coqAbsEnv(e, errorMap, rangeMap, precision, "((0#1,0#1),0#1)", reporter) + ".",
s"absenv_${fName}")
else if (prv == "hol4")
(s"val absenv_${fName}_def = Define `\n absenv_${fName}:analysisResult = \n\\e. \n" +
......
......@@ -2,10 +2,13 @@
package daisy
package backend
import scala.collection.immutable.Seq
import daisy.lang.{ScalaPrinter, PrettyPrinter}
import lang.Trees._
import utils.FinitePrecision._
import lang.Types._
import lang.Identifiers._
import lang.Extractors.ArithOperator
import utils.Rational
import lang.NumAnnotation
......@@ -72,7 +75,7 @@ object CodeGenerationPhase extends DaisyPhase {
// if we have floating-point code, we need to just change the types
val fileLocation = "./output/" + prg.id + ".scala"
ctx.reporter.info("generating code in " + fileLocation)
val typedPrg = assignFloatType(prg, FinitePrecisionType(uniformPrecision))
val typedPrg = assignMixedPrecision(prg, ctx.precision, ctx.constantPrecision)
writeScalaFile(fileLocation, typedPrg)
typedPrg
}
......@@ -124,6 +127,54 @@ object CodeGenerationPhase extends DaisyPhase {
}
private def assignMixedPrecision(prg: Program, typeMap:Map[Identifier, Map[Identifier, Precision]], constantType:Precision): Program = {
def changeType(e: Expr, t:Map[Identifier, Precision], ct:Precision): (Expr, Precision) = e match {
case Variable(id) => (Variable(id.changeType(FinitePrecisionType(t(id)))), t(id))
case x @ RealLiteral(r) =>
val tmp = FinitePrecisionLiteral(r, ct)
tmp.stringValue = x.stringValue
(tmp, ct)
case ArithOperator(Seq(l, r), recons) =>
val (newl, tl) = changeType(l, t, ct)
val (newr, tr) = changeType(r, t, ct)
val prec = getUpperBound(tl, tr)
(recons(Seq(newl, newr)), prec)