Commit d7785468 authored by Eva Darulova's avatar Eva Darulova
Browse files

generation of bit shifts for fixed-point code generation

parent 4ec587b4
......@@ -7,7 +7,8 @@ import lang.Trees._
import utils.FinitePrecision._
import lang.Types._
import lang.Extractors.ArithOperator
import utils.Rational
import lang.NumAnnotation
object CodeGenerationPhase extends DaisyPhase {
......@@ -19,7 +20,12 @@ object CodeGenerationPhase extends DaisyPhase {
implicit val debugSection = DebugSectionBackend
var reporter: Reporter = null
def run(ctx: Context, prg: Program): (Context, Program) = {
reporter = ctx.reporter
reporter.info(s"\nStarting $name")
val timer = ctx.timers.rangeError.start
var uniformPrecision: Precision = Float64
......@@ -41,26 +47,39 @@ object CodeGenerationPhase extends DaisyPhase {
case _ => ;
}
//println(ScalaPrinter.apply(prg))
val newProgram = uniformPrecision match {
uniformPrecision match {
// if we have fixed-point code, we need to generate it first
case Fixed(b) =>
throw new Exception("unsupported yet")
// if we have floating-point code, we need to just change the types
// if we have fixed-point code, we need to generate it first
val newDefs = prg.defs.map(fnc => if (!fnc.body.isEmpty && !fnc.precondition.isEmpty) {
val newBody = toFixedPointCode(fnc.body.get, Fixed(b))
val valDefType = if (b == 16) Int32Type else Int64Type
fnc.copy(
params = fnc.params.map(vd => ValDef(vd.id.changeType(valDefType))),
body = Some(newBody),
returnType = valDefType)
} else {
fnc
})
val newPrg = Program(prg.id, newDefs)
val fileLocation = "./output/" + prg.id + ".scala"
ctx.reporter.info("generating code in " + fileLocation)
writeScalaFile(fileLocation, newPrg)
newPrg
case _ =>
// if we have floating-point code, we need to just change the types
val fileLocation = "./output/" + prg.id + ".scala"
ctx.reporter.info("generating code in " + fileLocation)
val typedPrg = assignFloatType(prg, FinitePrecisionType(uniformPrecision))
writeScalaFile(fileLocation, typedPrg)
typedPrg
}
(ctx, prg)
timer.stop
ctx.reporter.info(s"Finished $name")
(ctx, newProgram)
}
private def writeScalaFile(filename: String, prg: Program) {
......@@ -104,4 +123,115 @@ object CodeGenerationPhase extends DaisyPhase {
Program(prg.id, newDefs)
}
/*
* Expects code to be already in SSA form.
* @param fixed the (uniform) fixed-point precision to use
// TODO: we also need to adjust the types, no?
*/
private def toFixedPointCode(expr: Expr, fixed: Fixed): Expr = expr match {
case x @ Variable(id) => fixed match {
case Fixed(16) => Variable(id.changeType(Int32Type))
case Fixed(32) => Variable(id.changeType(Int64Type))
}
case RealLiteral(r) => // TODO: translate constant
val f = fixed.fractionalBits(r)
fixed match {
case Fixed(16) => Int32Literal((r * Rational.fromDouble(math.pow(2, f))).roundToInt)
case Fixed(32) => Int64Literal((r * Rational.fromDouble(math.pow(2, f))).roundToLong)
}
case UMinus(t) => UMinus(toFixedPointCode(t, fixed))
case Sqrt(t) =>
throw new Exception("Sqrt is not supported for fixed-points!")
null
case x @ Plus(lhs, rhs) =>
// fractional bits from lhs
val fLhs = fixed.fractionalBits(lhs.asInstanceOf[NumAnnotation].interval)
val fRhs = fixed.fractionalBits(rhs.asInstanceOf[NumAnnotation].interval)
// determine how much to shift left or right
val fAligned = math.max(fLhs, fRhs)
val newLhs =
if (fLhs < fAligned) LeftShift(toFixedPointCode(lhs, fixed), (fAligned - fLhs))
else toFixedPointCode(lhs, fixed)
val newRhs =
if (fRhs < fAligned) LeftShift(toFixedPointCode(rhs, fixed), (fAligned - fRhs))
else toFixedPointCode(rhs, fixed)
// fractional bits result
val fRes = fixed.fractionalBits(x.interval)
// shift result
if (fAligned == fRes) {
Plus(newLhs, newRhs)
} else if(fRes < fAligned) {
RightShift(Plus(newLhs, newRhs), (fAligned - fRes))
} else { //(fAligned < fRes) {
// TODO: this sounds funny. does this ever happen?
reporter.warning("funny shifting condition is happening")
LeftShift(Plus(newLhs, newRhs), (fRes - fAligned))
}
case x @ Minus(lhs, rhs) =>
// fractional bits from lhs
val fLhs = fixed.fractionalBits(lhs.asInstanceOf[NumAnnotation].interval)
val fRhs = fixed.fractionalBits(rhs.asInstanceOf[NumAnnotation].interval)
// determine how much to shift left or right
val fAligned = math.max(fLhs, fRhs)
val newLhs =
if (fLhs < fAligned) LeftShift(toFixedPointCode(lhs, fixed), (fAligned - fLhs))
else toFixedPointCode(lhs, fixed)
val newRhs =
if (fRhs < fAligned) LeftShift(toFixedPointCode(rhs, fixed), (fAligned - fRhs))
else toFixedPointCode(rhs, fixed)
// fractional bits result
val fRes = fixed.fractionalBits(x.interval)
// shift result
if (fAligned == fRes) {
Minus(newLhs, newRhs)
} else if(fRes < fAligned) {
RightShift(Minus(newLhs, newRhs), (fAligned - fRes))
} else { //(fAligned < fRes) {
// TODO: this sounds funny. does this ever happen?
reporter.warning("funny shifting condition is happening")
LeftShift(Minus(newLhs, newRhs), (fRes - fAligned))
}
case x @ Times(lhs, rhs) =>
val mult = Times(toFixedPointCode(lhs, fixed), toFixedPointCode(rhs, fixed))
val fMult = fixed.fractionalBits(lhs.asInstanceOf[NumAnnotation].interval) +
fixed.fractionalBits(rhs.asInstanceOf[NumAnnotation].interval)
// fractional bits result
val fRes = fixed.fractionalBits(x.interval)
// shift result
if (fMult == fRes) {
mult
} else if(fRes < fMult) {
RightShift(mult, (fMult - fRes))
} else { //(fAligned < fRes) {
// TODO: this sounds funny. does this ever happen?
reporter.warning("funny shifting condition is happening")
LeftShift(mult, (fRes - fMult))
}
case x @ Division(lhs, rhs) =>
val fLhs = fixed.fractionalBits(lhs.asInstanceOf[NumAnnotation].interval)
val fRhs = fixed.fractionalBits(rhs.asInstanceOf[NumAnnotation].interval)
val fRes = fixed.fractionalBits(x.interval)
val shift = fRes + fRhs - fLhs
Division(LeftShift(toFixedPointCode(lhs, fixed), shift), toFixedPointCode(rhs, fixed))
case Let(id, value, body) =>
Let(id, toFixedPointCode(value, fixed), toFixedPointCode(body, fixed))
}
}
......@@ -137,6 +137,7 @@ class PrettyPrinter(val sb: StringBuffer = new StringBuffer, printUniqueIds: Boo
case Log(expr) => ppUnary(expr, "log(", ")")
case Equals(l,r) => ppBinary(l, r, " == ")
case Int32Literal(v) => sb.append(v)
case Int64Literal(v) => sb.append(v)
case IntegerLiteral(v) => sb.append(v)
case BooleanLiteral(v) => sb.append(v)
case UnitLiteral() => sb.append("()")
......@@ -167,6 +168,9 @@ class PrettyPrinter(val sb: StringBuffer = new StringBuffer, printUniqueIds: Boo
case LessEquals(l,r) => ppBinary(l, r, " \u2264 ") // \leq
case GreaterEquals(l,r) => ppBinary(l, r, " \u2265 ") // \geq
case RightShift(t, by) => ppUnary(t, "(", " >> " + by + ")")
case LeftShift(t, by) => ppUnary(t, "(", " << " + by + ")")
case IfExpr(c, t, e) =>
sb.append("if (")
pp(c, p)
......@@ -185,6 +189,7 @@ class PrettyPrinter(val sb: StringBuffer = new StringBuffer, printUniqueIds: Boo
case Untyped => sb.append("<untyped>")
case UnitType => sb.append("Unit")
case Int32Type => sb.append("Int")
case Int64Type => sb.append("Long")
case BooleanType => sb.append("Boolean")
case RealType => sb.append("Real")
case FinitePrecisionType(Float32) => sb.append("Float")
......
......@@ -246,6 +246,12 @@ object Trees {
def deepCopy = Int32Literal(value)
}
// Long
case class Int64Literal(value: Long) extends Literal[Long] {
val getType = Int64Type
def deepCopy = Int64Literal(value)
}
/** $encodingof an infinite precision integer literal */
case class IntegerLiteral(value: BigInt) extends Literal[BigInt] {
val getType = IntegerType
......@@ -388,9 +394,7 @@ object Trees {
/** $encodingof `... + ...` */
case class Plus(lhs: Expr, rhs: Expr) extends Expr with NumAnnotation {
assert(lhs.getType == rhs.getType)
// this does not hold for function calls
assert(lhs.isInstanceOf[NumAnnotation] && rhs.isInstanceOf[NumAnnotation])
assert(lhs.getType == rhs.getType, "lhs: " + lhs.getType + ", rhs: " + rhs.getType)
val getType = lhs.getType
def deepCopy = Plus(lhs.deepCopy, rhs.deepCopy)
}
......@@ -398,14 +402,12 @@ object Trees {
/** $encodingof `... - ...` */
case class Minus(lhs: Expr, rhs: Expr) extends Expr with NumAnnotation {
assert(lhs.getType == rhs.getType)
assert(lhs.isInstanceOf[NumAnnotation] && rhs.isInstanceOf[NumAnnotation])
val getType = lhs.getType
def deepCopy = Minus(lhs.deepCopy, rhs.deepCopy)
}
/** $encodingof `- ... for BigInts`*/
case class UMinus(expr: Expr) extends Expr with NumAnnotation {
assert(expr.isInstanceOf[NumAnnotation])
val getType = expr.getType
def deepCopy = UMinus(expr.deepCopy)
}
......@@ -413,7 +415,6 @@ object Trees {
/** $encodingof `... * ...` */
case class Times(lhs: Expr, rhs: Expr) extends Expr with NumAnnotation {
assert(lhs.getType == rhs.getType)
assert(lhs.isInstanceOf[NumAnnotation] && rhs.isInstanceOf[NumAnnotation])
val getType = lhs.getType
def deepCopy = Times(lhs.deepCopy, rhs.deepCopy)
}
......@@ -421,7 +422,6 @@ object Trees {
/** $encodingof `... / ...` */
case class Division(lhs: Expr, rhs: Expr) extends Expr with NumAnnotation {
assert(lhs.getType == rhs.getType, "lhs: " + lhs.getType + ", " + rhs.getType)
assert(lhs.isInstanceOf[NumAnnotation] && rhs.isInstanceOf[NumAnnotation])
val getType = lhs.getType
def deepCopy = Division(lhs.deepCopy, rhs.deepCopy)
}
......@@ -497,6 +497,17 @@ object Trees {
def deepCopy = GreaterEquals(lhs.deepCopy, rhs.deepCopy)
}
/* Shifts */
case class RightShift(t: Expr, by: Int) extends Expr {
val getType = t.getType
def deepCopy = RightShift(t.deepCopy, by)
}
case class LeftShift(t: Expr, by: Int) extends Expr {
val getType = t.getType
def deepCopy = LeftShift(t.deepCopy, by)
}
/* Specs */
case class AbsError(lhs: Expr, rhs: Expr) extends Expr {
......
......@@ -52,6 +52,7 @@ object Types {
case object UnitType extends TypeTree
case object IntegerType extends TypeTree
case object Int32Type extends TypeTree
case object Int64Type extends TypeTree
case object RealType extends TypeTree
case class FinitePrecisionType(prec: Precision) extends TypeTree
......
......@@ -51,11 +51,11 @@ object SSATransformerPhase extends DaisyPhase {
fnc
})
println("Original program:")
println(prg)
// println("Original program:")
// println(prg)
println("\nNew program:")
println(lang.PrettyPrinter.withIDs(Program(prg.id, newDefs)))
// println("\nNew program:")
// println(lang.PrettyPrinter.withIDs(Program(prg.id, newDefs)))
......
......@@ -117,12 +117,19 @@ object FinitePrecision {
//val minNormal: Rational = ???
def absRoundoff(r: Rational): Rational = {
val intBits = bitsNeeded(math.abs(r.integerPart))
val fracBits = bitlength - intBits
val fracBits = fractionalBits(r)
Rational(1, math.pow(2, fracBits).toLong)
}
def fractionalBits(i: Interval): Int = {
fractionalBits(max(abs(i.xlo), abs(i.xhi)))
}
def fractionalBits(r: Rational): Int = {
val intBits = bitsNeeded(math.abs(r.integerPart))
bitlength - intBits
}
/**
Returns the number of bits needed to represent the given integer.
@param 32-bit integer
......@@ -131,7 +138,7 @@ object FinitePrecision {
assert(value >= 0)
// TODO: don't we have to also subtract 1 for the sign?
32 - Integer.numberOfLeadingZeros(value)
}
}
}
}
\ No newline at end of file
......@@ -8,7 +8,7 @@ object Arithmetic {
def fnc1(x: Real, y: Real): Real = {
require(-3 <= x && x <= 6 && 2 <= y && y <= 8)
(x + y) / y
(x + y) / y + 3.4
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment