commit acf15f81fedfe95ec6c18cda3c178c102fbf1375 Author: bill Date: Sun Sep 17 08:07:19 2023 -0400 Initial commit diff --git a/proj3/build.sbt b/proj3/build.sbt new file mode 100644 index 0000000..dfd3b24 --- /dev/null +++ b/proj3/build.sbt @@ -0,0 +1,12 @@ +scalaVersion := "2.12.10" + +scalacOptions in ThisBuild ++= Seq("-unchecked", "-deprecation", "-feature") + +libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.1" % "test" + +excludeFilter in unmanagedSources := HiddenFileFilter || "*sample*" + +logBuffered in Test := false + +parallelExecution in Test := false + diff --git a/proj3/cleanall.sh b/proj3/cleanall.sh new file mode 100644 index 0000000..15dd033 --- /dev/null +++ b/proj3/cleanall.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +find . -name target -type d -prune -exec rm -rf {} \; diff --git a/proj3/examples/invalid_semantic_arithm.scala b/proj3/examples/invalid_semantic_arithm.scala new file mode 100644 index 0000000..2ac1108 --- /dev/null +++ b/proj3/examples/invalid_semantic_arithm.scala @@ -0,0 +1 @@ +1 + 3 -- -1 diff --git a/proj3/examples/invalid_semantic_imm_variables.scala b/proj3/examples/invalid_semantic_imm_variables.scala new file mode 100644 index 0000000..824fdc5 --- /dev/null +++ b/proj3/examples/invalid_semantic_imm_variables.scala @@ -0,0 +1 @@ +val x = y; z diff --git a/proj3/examples/invalid_semantic_mut_variables.scala b/proj3/examples/invalid_semantic_mut_variables.scala new file mode 100644 index 0000000..2d120c2 --- /dev/null +++ b/proj3/examples/invalid_semantic_mut_variables.scala @@ -0,0 +1 @@ +val x = 0; x = 1 diff --git a/proj3/examples/invalid_syntax_arithm.scala b/proj3/examples/invalid_syntax_arithm.scala new file mode 100644 index 0000000..141336f --- /dev/null +++ b/proj3/examples/invalid_syntax_arithm.scala @@ -0,0 +1 @@ +1 + 3 4 diff --git a/proj3/examples/invalid_syntax_branch.scala b/proj3/examples/invalid_syntax_branch.scala new file mode 100644 index 0000000..2de57f6 --- /dev/null +++ b/proj3/examples/invalid_syntax_branch.scala @@ -0,0 +1 @@ +if (0 == 0) 2 diff --git a/proj3/examples/invalid_syntax_imm_variables.scala b/proj3/examples/invalid_syntax_imm_variables.scala new file mode 100644 index 0000000..fae0f96 --- /dev/null +++ b/proj3/examples/invalid_syntax_imm_variables.scala @@ -0,0 +1 @@ +val x == 4; x diff --git a/proj3/examples/invalid_syntax_loop.scala b/proj3/examples/invalid_syntax_loop.scala new file mode 100644 index 0000000..10d50cf --- /dev/null +++ b/proj3/examples/invalid_syntax_loop.scala @@ -0,0 +1,5 @@ +while (2 + 2) { + 1 +}; +2 + diff --git a/proj3/examples/invalid_syntax_mut_variables.scala b/proj3/examples/invalid_syntax_mut_variables.scala new file mode 100644 index 0000000..e28dd47 --- /dev/null +++ b/proj3/examples/invalid_syntax_mut_variables.scala @@ -0,0 +1,2 @@ +var x = 0 +x + 1 diff --git a/proj3/examples/unexpected_character.scala b/proj3/examples/unexpected_character.scala new file mode 100644 index 0000000..aaa27c5 --- /dev/null +++ b/proj3/examples/unexpected_character.scala @@ -0,0 +1,2 @@ +// Test +& diff --git a/proj3/examples/valid_arithm.scala b/proj3/examples/valid_arithm.scala new file mode 100644 index 0000000..8411283 --- /dev/null +++ b/proj3/examples/valid_arithm.scala @@ -0,0 +1 @@ +1* -4 + 4/2 diff --git a/proj3/examples/valid_branch.scala b/proj3/examples/valid_branch.scala new file mode 100644 index 0000000..a306ef2 --- /dev/null +++ b/proj3/examples/valid_branch.scala @@ -0,0 +1 @@ +if (2 > 4) 1 + 4 else 4 diff --git a/proj3/examples/valid_imm_variable.scala b/proj3/examples/valid_imm_variable.scala new file mode 100644 index 0000000..2b6b27d --- /dev/null +++ b/proj3/examples/valid_imm_variable.scala @@ -0,0 +1 @@ +val x = 0; x diff --git a/proj3/examples/valid_loop.scala b/proj3/examples/valid_loop.scala new file mode 100644 index 0000000..56892c8 --- /dev/null +++ b/proj3/examples/valid_loop.scala @@ -0,0 +1,7 @@ +var x = 2; +var y = 0; +while ({ if (y < 3) true else y < 4}) { + x = x * x; + y = y + 1 +}; +x diff --git a/proj3/examples/valid_mut_variable.scala b/proj3/examples/valid_mut_variable.scala new file mode 100644 index 0000000..9b4fa67 --- /dev/null +++ b/proj3/examples/valid_mut_variable.scala @@ -0,0 +1 @@ +var x = 0; x = (x = x + 2) - 1 diff --git a/proj3/gen/bootstrap.c b/proj3/gen/bootstrap.c new file mode 100644 index 0000000..eacd8a6 --- /dev/null +++ b/proj3/gen/bootstrap.c @@ -0,0 +1,12 @@ +#include +#include + +// assembly function that we are compiling +int entry_point(char* heap); + +int main() { + char* heap = malloc(10000 * 8); + printf("Exit Code: %d\n", entry_point(heap)); + free(heap); + return 0; +} diff --git a/proj3/project/build.properties b/proj3/project/build.properties new file mode 100644 index 0000000..c0bab04 --- /dev/null +++ b/proj3/project/build.properties @@ -0,0 +1 @@ +sbt.version=1.2.8 diff --git a/proj3/src/main/scala/project3/Compiler.scala b/proj3/src/main/scala/project3/Compiler.scala new file mode 100644 index 0000000..4ae709b --- /dev/null +++ b/proj3/src/main/scala/project3/Compiler.scala @@ -0,0 +1,318 @@ +package project3 + +abstract class X86Compiler extends BugReporter with Codegen { + import Language._ + + /* + * Abstract class used to store the location + */ + abstract class Loc { + def +(y: Int): Loc + } + + /* + * Register location, the physical location + * can be addressed with the register #sp + */ + case class Reg(sp: Int) extends Loc { + def +(y: Int) = Reg(sp+y) + } + + /* + * Function location, the physical location + * can be addressed directly with the name + */ + case class Func(name: String) extends Loc { + def +(y: Int) = BUG("This Loc should not be used as a stack location.") + } + + // Function to extra physical address from Loc + // CHANGE: instead of using regs(...) directly + // we now use the function loc. + def loc(l: Loc): String = l match { + case Reg(sp) => avRegs(sp) + case Func(name) => name + } + + def loc(sp: Int): String = avRegs(sp) + + // List of available register. + // DO NOT CHANGE THE REGISTERS!! + val avRegs = Seq("%rdi", "%rsi", "%rdx", "%rcx", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15") + + /****************************************************************************/ + + def onMac = System.getProperty("os.name").toLowerCase contains "mac" + val entry_point = "entry_point" + def funcName(name: String) = (if (onMac) "_" else "") + name + + /** + * Env of the compiler. Keep track of the location + * in memory of each variable defined. + */ + val primitives = Map( + "putchar" -> Func("putchar"), + "getchar" -> Func("getchar")) + + private class Env { + def undef(name: String) = BUG(s"Undefined identifier $name (should have been found during the semantic analysis)") + def apply(name: String): Loc = undef(name) + } + + private case class LocationEnv( + vars: Map[String, Loc] = primitives, + outer: Env = new Env) extends Env { + + /* + * Return a copy of the current state plus a + * variable 'name' at the location 'loc' + */ + def withVal(name: String, loc: Loc): LocationEnv = { + copy(vars = vars + (name -> loc)) + } + + /* + * Return a copy of the current state plus all + * variables in 'list' + */ + def withVals(list: List[(String,Loc)]): LocationEnv = { + copy(vars = vars ++ list.toMap) + } + + /* + * Return the location of the variable 'name' + */ + override def apply(name: String): Loc = vars.get(name) match { + case Some(loc) => loc + case _ => outer(name) + } + } + + /* + * Generate code that computes the unary operator + * 'op' on the value at memory location 'sp' and that + * stores the result at 'sp'. + */ + def transUn(op: String)(sp: Loc) = op match { + case "+" => () // nothing to do! + case "-" => emitln(s"negq ${loc(sp)}") + case _ => BUG(s"Unary operator $op undefined") + } + + /* + * Generate code that computes the binary operator + * 'op' on the values at memory location 'sp' and + * 'sp1' and that stores the result at 'sp'. + * + * TODO: implement missing operators. + * Here are the valid operators: + * +, -, *, /, ==, !=, <=, <, >=, >, block-get + */ + def transBin(op: String)(sp: Loc, sp1: Loc) = op match { + case "+" => emitln(s"addq ${loc(sp1)}, ${loc(sp)}") + case "-" => emitln(s"subq ${loc(sp1)}, ${loc(sp)}") + case "*" => emitln(s"imul ${loc(sp1)}, ${loc(sp)}") + case "/" => + emitln(s"movq ${loc(sp)}, %rax") + emitln(s"pushq %rdx") // save $rdx for the division + emitln(s"movq ${loc(sp1)}, %rbx") // in case sp1 == %rdx + emitln(s"cqto") + emitln(s"idiv %rbx") + emitln(s"popq %rdx") // put back + emitln(s"movq %rax, ${loc(sp)}") + case "==" => + emitln(s"cmp ${loc(sp1)}, ${loc(sp)}") + emitln(s"sete %al") + emitln(s"movzbq %al, ${loc(sp)}") + case _ => BUG(s"Binary operator $op undefined") + } + + /* + * Generate code that computes the ternary operator + * 'op' on the values at memory location 'sp', 'sp1 and' + * 'sp2' and that stores the result at 'sp'. + * + * TODO: implement the missing operator + * Valid operators: block-set + */ + def transTer(op: String)(sp: Loc, sp1: Loc, sp2: Loc) = op match { + case _ => BUG(s"ternary operator $op undefined") + } + + def transPrim(op: String)(idxs: List[Loc]) = idxs match { + case List(sp, sp1, sp2) => transTer(op)(sp, sp1, sp2) + case List(sp, sp1) => transBin(op)(sp, sp1) + case List(sp) => transUn(op)(sp) + case _ => BUG(s"no prim with ${idxs.length} arguments") + } + + type Label = String + + var nLabel = 0 + def freshLabel(pref: String) = { nLabel += 1; s"$pref$nLabel" } + + /* + * Generate code that compute the result of the + * computation represented by the AST 'exp'. + */ + val global = (primitives.keySet + entry_point) map(funcName(_)) + def emitCode(exp: Exp): Unit = { + emitln(".text", 0) + emitln(s".global ${global mkString ", "}\n", 0) + + // Generate code for our AST + trans(exp, Reg(0))(LocationEnv()) + + emitln("#################### DATA #######################", 0) + emitln("\n.data\nheap:\t.quad 0",0) + emitln("#################################################", 0) + } + + /* + * Generate code that jump to the label 'label' + * if the location 'sp' contains the value 'true' + */ + def transJumpIfTrue(sp: Loc)(label: Label) = { + ??? + } + + /* + * Generate code that compute the result og the + * computation represented by the AST 'exp'. The + * value will be placed at memory location 'sp' + * + * TODO: Fill in each TODO with the appropriate code. + * + * The ??? can be filled for extra credit. + */ + def trans(exp: Exp, sp: Loc)(env: LocationEnv): Unit = exp match { + case Lit(x: Int) => + emitln(s"movq $$$x, ${loc(sp)}") + case Lit(b: Boolean) => () // TODO + case Lit(x: Unit) => () // TODO + case Prim(op, args) => + val idxs = List.tabulate(args.length)(i => sp + i) + (args zip idxs) foreach { case (arg, idx) => trans(arg, idx)(env) } + transPrim(op)(idxs) + case Let(x, tp, rhs, body) => + trans(rhs, sp)(env) + if (tp == UnitType) { // simple optimization for Daniel + trans(body, sp)(env) + } else { + trans(body, sp + 1)(env.withVal(x, sp)) + emitln(s"movq ${loc(sp + 1)}, ${loc(sp)}") + } + case Ref(x) => + env(x) match { + case Reg(sp1) => emitln(s"movq ${loc(sp1)}, ${loc(sp)}") + case Func(name) => ??? // Extra credit + } + case If(cond, tBranch, eBranch) => + val lab = freshLabel("if") + trans(cond, sp)(env) + transJumpIfTrue(sp)(s"${lab}_then") + trans(eBranch, sp)(env) + emitln(s"jmp ${lab}_end") + emitln(s"${lab}_then:", 0) + trans(tBranch, sp)(env) + emitln(s"${lab}_end:", 0) + case VarDec(x, tp, rhs, body) => + trans(rhs, sp)(env) + trans(body, sp + 1)(env.withVal(x, sp)) + emitln(s"movq ${loc(sp + 1)}, ${loc(sp)}") + case VarAssign(x, rhs) => + trans(rhs, sp)(env) + emitln(s"movq ${loc(sp)}, ${loc(env(x))}") + case While(cond, lBody, body) => + val lab = freshLabel("loop") + emitln(s"jmp ${lab}_cond") + emitln(s"${lab}_body:", 0) + trans(lBody, sp)(env) + emitln(s"${lab}_cond:", 0) + trans(cond, sp)(env) + transJumpIfTrue(sp)(s"${lab}_body") + trans(body, sp)(env) + case LetRec(funs, body) => + emitln("################# FUNCTIONS #####################", 0) + // We do not save the location of the function into register because we can use their + // name as a label. + val funsLoc = funs map { case FunDef(name, _, _, _) => (name, Func(name)) } + + // TODO complete the code + + emitln("#################################################\n\n", 0) + emitln("###################### MAIN #####################", 0) + //////////// DO NOT CHANGE//////////////// + emitln(s"${funcName(entry_point)}:", 0) + emitln("pushq %rbp\t# save stack frame for calling convention") + emitln("movq %rsp, %rbp") + emitln("movq %rdi, heap(%rip)") + + emitln("pushq %rbx") + emitln("pushq %r12") + emitln("pushq %r13") + emitln("pushq %r14") + emitln("pushq %r15") + ////////////////////////////////////////// + + // emit the main function (body of LetRec) here + // TODO you may need to change that code. + trans(body, Reg(0))(LocationEnv()) + emitln(s"movq ${loc(0)}, %rax") + + //////////// DO NOT CHANGE//////////////// + emitln("popq %r15") + emitln("popq %r14") + emitln("popq %r13") + emitln("popq %r12") + emitln("popq %rbx") + emitln("movq %rbp, %rsp\t# reset frame") + emitln("popq %rbp") + emitln("ret") + emitln("#################################################\n\n", 0) + ////////////////////////////////////////// + + case FunDef(fname, args, _, fbody) => + //////////// DO NOT CHANGE//////////////// + emitln(s"${funcName(fname)}:", 0) + emitln("pushq %rbp\t# save stack frame for calling convention") + emitln("movq %rsp, %rbp") + ////////////////////////////////////////// + + // TODO + + //////////// DO NOT CHANGE//////////////// + emitln("movq %rbp, %rsp\t# reset frame") + emitln("popq %rbp") + emitln("ret\n") + ////////////////////////////////////////// + case App(fun, args) => + // Advice: you may want to start to work on functions with only one argument + // i.e. change args to List(arg). Once it is working you can generalize your + // code and work on multiple arguments. + // Evaluate the arguments + // TODO + + // Compute the physical location of the function to be called + val fLoc: String = fun match { + case Ref(fname) => + env(fname) match { + case Reg(sp) => ??? // Extra credit + case Func(name) => "" // TODO + } + case _ => ??? // Extra credit + } + + // Implement the calling conventions after that point + // and generate the function call + // TODO + () + case ArrayDec(size, _) => + // This node needs to allocate an area of eval(size) * 8 bytes in the heap + // the assembly variable "heap" contains a pointer to the first valid byte + // in the heap. Make sure to update its value accordingly. + // TODO + () + case _ => BUG(s"don't know how to implement $exp") + } +} diff --git a/proj3/src/main/scala/project3/Interpreter.scala b/proj3/src/main/scala/project3/Interpreter.scala new file mode 100644 index 0000000..62a948a --- /dev/null +++ b/proj3/src/main/scala/project3/Interpreter.scala @@ -0,0 +1,415 @@ +package project3 + +abstract class Interpreter { + type MyVal + def run(ast: Language.Exp): MyVal +} + +/** + * This interpreter specifies the semantics of our + * programming language. + * + * The evaluation of each node returns a value. + */ +class ValueInterpreter extends Interpreter with BugReporter { + import Language._ + + /* + * Values for primitives operators. Already stored in the environment. + */ + val primitives = Map[String, BoxedVal]( + "putchar" -> BoxedVal(Primitive("putchar")), + "getchar" -> BoxedVal(Primitive("getchar")) + ) + + /* + * Definition of the values of our language + * + * We can return constant Int, Boolean, Unit or Array. + * We can also return Function. Primitive is a special + * kind of function. + */ + abstract class Val + case class Cst(x: Any) extends Val { + override def toString = if (x != null) x.toString else "null" + } + case class Func(args: List[String], fbody: Exp, var env: ValueEnv) extends Val { + // Add all the value into the existing environment + def withVals(list: List[(String,Val)]) = env = env.withVals(list) + override def toString = s"(${args mkString ","}) => $fbody" + } + case class Primitive(name: String) extends Val + + type MyVal = Val + + /** + * Env of the interpreter. Keeps track of the value + * of each variable defined. + */ + class Env { + def undef(name: String) = + BUG(s"Undefined identifier $name (should have been found during the semantic analysis)") + + def updateVar(name: String, v: Val): Val = undef(name) + def apply(name: String): Val = undef(name) + } + + case class BoxedVal(var v: Val) + case class ValueEnv( + vars: Map[String, BoxedVal] = primitives, + outer: Env = new Env) extends Env { + + /* + * Return a copy of the current state plus an immutable + * variable 'name' of value 'v' + */ + def withVal(name: String, v: Val): ValueEnv = { + copy(vars = vars + (name -> BoxedVal(v))) + } + + /* + * Return a copy of the current state plus all the immutables + * variable in list. + */ + def withVals(list: List[(String,Val)]): ValueEnv = { + copy(vars = vars ++ (list.map {case (n, v) => n -> BoxedVal(v) })) + } + + /* + * Update the variable 'name' in this scope or in the + * outer scope. + * Return the new value of the variable + */ + override def updateVar(name: String, v: Val): Val = { + if (vars.contains(name)) + vars(name).v = v + else + outer.updateVar(name, v) + v + } + + /* + * Return the value of the variable 'name' + */ + override def apply(name: String): Val = { + if (vars.contains(name)) + vars(name).v + else + outer(name) + } + } + + /* + * Compute and return the result of the unary + * operation 'op' on the value 'v' + */ + def evalUn(op: String)(v: Val) = (op, v) match { + case ("-", Cst(v: Int)) => Cst(-v) + case ("+", Cst(_: Int)) => v + case _ => BUG(s"unary operator $op undefined") + } + + /* + * Compute and return the result of the binary + * operation 'op' on the value 'v' and 'w' + * Note: v op w + */ + def evalBin(op: String)(v: Val, w: Val) = (op, v, w) match { + case ("-", Cst(v: Int), Cst(w: Int)) => Cst(v-w) + case ("+", Cst(v: Int), Cst(w: Int)) => Cst(v+w) + case ("*", Cst(v: Int), Cst(w: Int)) => Cst(v*w) + case ("/", Cst(v: Int), Cst(w: Int)) => Cst(v/w) + case ("==", Cst(v: Int), Cst(w: Int)) => Cst(v == w) + case ("!=", Cst(v: Int), Cst(w: Int)) => Cst(v != w) + case ("<=", Cst(v: Int), Cst(w: Int)) => Cst(v <= w) + case (">=", Cst(v: Int), Cst(w: Int)) => Cst(v >= w) + case ("<" , Cst(v: Int), Cst(w: Int)) => Cst(v < w) + case (">" , Cst(v: Int), Cst(w: Int)) => Cst(v > w) + case ("block-get", Cst(arr: Array[Any]), Cst(i: Int)) => + if (arr(i) == null) + BUG(s"uninitialized memory") + Cst(arr(i)) + case _ => BUG(s"binary operator $op undefined") + } + + /* + * Compute and return the result of the ternary + * operations 'op' on the value 'v', 'w' and 'z' + */ + def evalTer(op: String)(v: Val, w: Val, z: Val) = (op, v, w, z) match { + case ("block-set", Cst(arr: Array[Any]), Cst(i: Int), Cst(x)) => Cst(arr(i) = x) + case _ => BUG(s"ternary operator $op undefined") + } + + def evalPrim(op: String)(eargs: List[Val]) = eargs match { + case List(v, w, z) => evalTer(op)(v, w, z) + case List(v, w) => evalBin(op)(v, w) + case List(v) => evalUn(op)(v) + case _ => BUG(s"no prim with ${eargs.length} arguments") + } + + /* + * Evaluate the AST starting with an empty Env + */ + def run(exp: Exp) = eval(exp)(ValueEnv()) + + /* + * Evaluate the AST within the environment 'env' + */ + def eval(exp: Exp)(env: ValueEnv): Val = exp match { + case Lit(x) => Cst(x) + case Prim(op, args) => + val eargs = args map { arg => eval(arg)(env) } + evalPrim(op)(eargs) + case Let(x, tp, a, b) => + eval(b)(env.withVal(x, eval(a)(env))) + case Ref(x) => + env(x) + case If(cond, tBranch, eBranch) => + val Cst(v: Boolean) = eval(cond)(env) + if (v) + eval(tBranch)(env) + else + eval(eBranch)(env) + case VarDec(x, tp, rhs, body) => + eval(body)(env.withVal(x, eval(rhs)(env))) + case VarAssign(x, rhs) => + env.updateVar(x, eval(rhs)(env)) + case While(cond, lBody, body) => + while (eval(cond)(env) == Cst(true)) { + eval(lBody)(env) + } + eval(body)(env) + case FunDef(_, args, _, fbody) => + Func(args map { arg => arg.name }, fbody, env) + case LetRec(funs, body) => + // Evaluate all functions + val funcs = funs map { case fun@FunDef(name, _, _, _) => (name, eval(fun)(env)) } + + // Add all functions to the functions environment (recursion) + funcs foreach { case (_, func@Func(_, _, _)) => func.withVals(funcs) } + + eval(body)(env.withVals(funcs)) + case App(fun, args) => + // Evaluate the arguments + val eargs = args map { arg => eval(arg)(env) } + + // Evaluate the function to be called. + eval(fun)(env) match { + case Func(fargs, fbody, fenv) => + eval(fbody)(fenv.withVals(fargs zip eargs)) + case Primitive("getchar") => Cst(Console.in.read) + case Primitive("putchar") => + val List(Cst(c: Int)) = eargs + Console.out.write(c) + Console.out.flush + Cst(()) + } + case ArrayDec(size, _) => + val Cst(s: Int) = eval(size)(env) + Cst(new Array[Any](s)) + } + +} + +/* + * Defintion of the value produces by the StackInterpreter + */ +object StackVal extends BugReporter { + import Language._ + + type Loc = Int + + /* + * Location of primitives operators. Already stored in the environment. + */ + val primitives = Map[String, Loc]( + "putchar" -> 0, + "getchar" -> 1 + ) + + /** + * Env of the interpreter. Keep track of the location + * in memory of each variable defined. + */ + class Env { + def apply(name: String): Loc = + BUG(s"Undefined identifier $name (should have been found during the semantic analysis)") + } + + case class LocationEnv( + vars: Map[String, Loc] = primitives, + outer: Env = new Env) extends Env { + + /* + * Return a copy of the current state plus a + * variable 'name' at the location 'loc' + */ + def withVal(name: String, loc: Loc): LocationEnv = { + copy(vars = vars + (name -> loc)) + } + + /* + * Return a copy of the current state plus a + * variable 'name' at the location 'loc' + */ + def withVals(list: List[(String,Loc)]): LocationEnv = { + copy(vars = vars ++ list.toMap) + } + + /* + * Return the location of the variable 'name' + */ + override def apply(name: String) = vars.get(name) match { + case Some(loc) => loc + case None => outer(name) + } + } + + abstract class Val + + // Constant values: Int, Boolean, Array + case class Cst(x: Any) extends Val { + override def toString = if (x != null) x.toString else "null" + } + + // Function values + case class Func(args: List[String], fbody: Exp, var env: LocationEnv) extends Val { + def withVals(list: List[(String,Int)]) = env = env.withVals(list) + override def toString = s"(${args mkString ","}) => $fbody" + } + + // Primitives + case class Primitive(name: String) extends Val +} + +/** + * This interpreter is a stack-based interpreter as we have seen + * during the lecture. + * + * Rather than returning the value of a node, it stores it in memory, + * following a well-establish convention. + * + * This interpreter works in a similar manner as a processor. + */ +class StackInterpreter extends Interpreter with BugReporter { + import Language._ + import StackVal._ + + type MyVal = Val + + + // Memory and flag used by the interpreter + val memory = new Array[Val](1000) + memory(0) = Primitive("putchar") + memory(1) = Primitive("getchar") + var flag: Boolean = true + + /* + * Compute the result of the operator 'op' on the + * value stored at 'sp' and store it at 'sp' + */ + def evalUn(op: String)(sp: Loc) = (op, memory(sp)) match { + case ("+", _) => () + case ("-", Cst(x: Int)) => memory(sp) = Cst(-x) + case _ => BUG(s"Unary operator $op undefined") + } + + /* + * Compute the result of the operator 'op' on the + * value stored at 'sp' and 'sp1', and store it at 'sp' + * + * TODO: implement the missing case + */ + def evalBin(op: String)(sp: Loc, sp1: Loc) = (op, memory(sp), memory(sp1)) match { + case ("+", Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x + y) + case ("-", Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x - y) + case ("*", Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x * y) + case ("/", Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x / y) + case ("==", Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x == y) + case ("!=", Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x != y) + case ("<=", Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x <= y) + case (">=", Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x >= y) + case ("<" , Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x < y) + case (">" , Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x > y) + case ("block-get", Cst(arr: Array[Any]), Cst(i: Int)) => ??? + case _ => BUG(s"Binary operator $op undefined") + } + + def evalTer(op: String)(sp: Loc, sp1: Loc, sp2: Loc) = (op, memory(sp), memory(sp1), memory(sp2)) match { + case ("block-set", Cst(arr: Array[Any]), Cst(i: Int), Cst(x)) => ??? + case _ => BUG(s"ternary operator $op undefined") + } + + def evalPrim(op: String)(idxs: List[Int]) = idxs match { + case List(sp, sp1, sp2) => evalTer(op)(sp, sp1, sp2) + case List(sp, sp1) => evalBin(op)(sp, sp1) + case List(sp) => evalUn(op)(sp) + case _ => BUG(s"no prim with ${idxs.length} arguments") + } + + /* + * Evaluate the value of the AST 'exp' within + * an empty environment and return the value. + */ + def run(exp: Exp): Val = { + // Start at 2, putchar and getchar are store at 0 and 1!! + eval(exp, 2)(LocationEnv()) + memory(2) + } + + /* + * Evaluate the value of the AST 'exp' within + * the environment 'env 'and store the result + * at 'sp'. + * + * NOTE: Cond stores its result in the 'flag' + * variable. + * + * TODO: Remove all ???s and implement the + * appropriate code. The result must be the + * same than the evaluator defined above. + */ + def eval(exp: Exp, sp: Loc)(env: LocationEnv): Unit = exp match { + case Lit(x: Unit) => () + case Lit(x) => + memory(sp) = Cst(x) + case Prim(op, args) => + val idxs = List.tabulate(args.length)(i => sp + i) + (args zip idxs) foreach { case (arg, idx) => eval(arg, idx)(env) } + evalPrim(op)(idxs) + case Let(x, tp, a, b) => + eval(a, sp)(env) + eval(b, sp + 1)(env.withVal(x, sp)) + memory(sp) = memory(sp + 1) + case Ref(x) => + memory(sp) = memory(env(x)) + case If(cond, tBranch, eBranch) => + eval(cond, sp)(env) + val Cst(flag: Boolean) = memory(sp) + if (flag) + eval(tBranch, sp)(env) + else + eval(eBranch, sp)(env) + case VarDec(x, tp, rhs, body) => + eval(rhs, sp)(env) + eval(body, sp + 1)(env.withVal(x, sp)) + memory(sp) = memory(sp + 1) + case VarAssign(x, rhs) => + eval(rhs, sp)(env) + memory(env(x)) = memory(sp) + case While(cond, lbody, body) => + eval(cond, sp)(env) + while (memory(sp) == Cst(true)) { + eval(lbody, sp)(env) + eval(cond, sp)(env) + } + eval(body, sp)(env) + case FunDef(_, args, _, fbody) => ??? + case LetRec(funs, body) => + // TODO modify that code + eval(body, sp)(env) + case App(fun, args) => ??? + case ArrayDec(size, _) => ??? + } +} diff --git a/proj3/src/main/scala/project3/Main.scala b/proj3/src/main/scala/project3/Main.scala new file mode 100644 index 0000000..4a40ae6 --- /dev/null +++ b/proj3/src/main/scala/project3/Main.scala @@ -0,0 +1,105 @@ +package project3 + +import java.io._ +import scala.io._ + +trait CodeGenerator { + // Define the PrintWriter used to emit + // the code. + val out = new ByteArrayOutputStream + val pOut = new PrintWriter(out, true) + def stream = pOut + + def emitCode(ast: Language.Exp): Unit + + // Emit the code and return the String + // representation of the code + def code(ast: Language.Exp) = { + emitCode(ast) + out.toString.stripLineEnd + } +} + +object Runner { + import Language._ + + def printUsage: Unit = { + println("""Usage: run PROG [OPTION] + or: run FILE [OPTION] +OPTION: intStack""") + } + + def main(args: Array[String]): Unit = { + if (args.size == 0) { + printUsage + return + } + + val src = if (new File(args(0)).exists) { + val source = Source.fromFile(args(0)) + try source.getLines mkString "\n" finally source.close() + } else { + args(0) + } + + println("============ SRC CODE ============") + println(src) + println("==================================\n") + + val reader = new BaseReader(src, '\u0000') + val scanner = new Scanner(reader) + + // Parser to test! + // TODO: Change this as you finish parsers + val parser = new BaseParser(scanner) + val ast = try { + parser.parseCode + } catch { + case e: AbortException => return + case e: Throwable => throw e + } + + println("============= AST ================") + println(ast) + println("==================================\n") + + val analyzer = new SemanticAnalyzer(parser) + println("======= Semantic Analyzer ========") + val (nast, numWarning, numError) = analyzer.run(ast) + if (numError > 0) { + println("==================================\n") + return + } + println("=========== Typed AST ============") + print(nast) + println(s": ${nast.tp}") + println("==================================\n") + + + + val interpreter = if (args.contains("intStack")) + new StackInterpreter + else + new ValueInterpreter + println("========== Interpreter ===========") + println(s"Exit Code: ${interpreter.run(nast)}") + println("==================================\n") + + // Generator to test + val generator = new X86Compiler with CodeGenerator + val code = generator.code(nast) + + val runner = new ASMRunner(code) + println("============ OUTPUT ==============") + println(runner.code) + println("==================================\n") + + if (runner.assemble != 0) { + println("Compilation error!") + } else { + println("============ RESULT ==============") + println(s"Exit Code: ${runner.run}") + println("==================================") + } + } +} diff --git a/proj3/src/main/scala/project3/Parser.scala b/proj3/src/main/scala/project3/Parser.scala new file mode 100644 index 0000000..417eb2d --- /dev/null +++ b/proj3/src/main/scala/project3/Parser.scala @@ -0,0 +1,885 @@ +package project3 + +// Class used to carry position information within the source code +case class Position(gapLine: Int, gapCol: Int, startLine: Int, startCol: Int, endLine: Int, endCol: Int) { + override def toString = "pos" +} +class Positioned { + var pos: Position = _ + def withPos(p: Position) = { + pos = p + this + } +} + +object Tokens { + + abstract class Token { + var pos: Position = _ + } + case object EOF extends Token + + // CHANGED: As we added new types, instead of having a Token called Number, + // we have a Token called Literal for all constant values. + case class Literal(x: Any) extends Token + case class Ident(x: String) extends Token + case class Keyword(x: String) extends Token + case class Delim(x: Char) extends Token +} + + +// Scanner +class Scanner(in: Reader[Char]) extends Reader[Tokens.Token] with Reporter { + import Tokens._ + + // Position handling + def pos = in.pos + def input = in.input + + // Current line in the file + var line = 0 + + // lineStarts(i) contains the offset of the i th line within the file + val lineStarts = scala.collection.mutable.ArrayBuffer(0) + + // Current column in the file + def column = pos - lineStarts(line) + + // Extract the i th line of code. + def getLine(i: Int) = { + val start = lineStarts(i) + val end = input.indexOf('\n', start) + + if (end < 0) + input.substring(start) + else + input.substring(start, end) + } + + // Information for the current Position + var gapLine = 0; + var gapCol = 0; + var startLine = 0; + var startCol = 0; + var endLine = 0; + var endCol = 0; + + override def abort(msg: String) = { + abort(msg, showSource(getCurrentPos())) + } + + /* + * Show the line of code and highlight the token at position p + */ + def showSource(p: Position) = { + val width = if (p.endLine == p.startLine) (p.endCol - p.startCol) else 0 + + val header = s"${p.startLine + 1}:${p.startCol + 1}: " + val line1 = getLine(p.startLine) + val line2 = " "*(p.startCol+header.length) + "^"*(width max 1) + header + line1 + '\n' + line2 + } + + def isAlpha(c: Char) = + ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') + + def isDigit(c: Char) = '0' <= c && c <= '9' + + def isAlphaNum(c: Char) = isAlpha(c) || isDigit(c) + + def isCommentStart(c1: Char, c2: Char) = c1 == '/' && c2 == '/' + + val isWhiteSpace = Set(' ','\t','\n','\r') + + // Boolean operators start with one of the following characters + val isBOperator = Set('<', '>', '!', '=') + + // Operators start with one of the following characters + val isOperator = Set('+','-','*','/') ++ isBOperator + + // List of delimiters + // TODO: Update this as delimiters are added to our language + val isDelim = Set('(',')','=',';','{','}',':') + + // List of keywords + // TODO: Update this as keywords are added to our language + val isKeyword = Set("if", "else", "val", "var", "while") + + val isBoolean = Set("true", "false") + + /* + * Extract a name from the stream + * + * TODO: Handle Boolean literals + */ + def getName() = { + val buf = new StringBuilder + while (in.hasNext(isAlphaNum)) { + buf += in.next() + } + val s = buf.toString + if (isKeyword(s)) Keyword(s) + else if (isBoolean(s)) Boolean(s) + else Ident(s) + } + + /* + * Extract an operator from the stream + */ + def getOperator() = { + val buf = new StringBuilder + do { + buf += in.next() + } while (in.hasNext(isOperator)) + val s = buf.toString + // "=" is a delimiter, "=>" is a keyword, "==","=+", etc are operators + if (s == "=") Delim('=') + else if (isKeyword(s)) Keyword(s) + else Ident(s) + } + + /* + * Extract a number from the stream and return it. + * Raise an error if there is overflow. + * + * NOTE: An integer can be between 0 and (2 to the power 31) minus 1 + */ + val MAX_NUM = s"${(1 << 31) - 1}" + def getNum() = { + val num = new StringBuilder + while (in.hasNext(isDigit)) { + num += in.next() + } + + val sNum = num.toString + if (sNum.length < MAX_NUM.length || sNum <= MAX_NUM) + Literal(sNum.toInt) + else + abort(s"integer overflow") + } + + /* + * Extract a raw token from the stream. + * i.e. without position information. + */ + def getRawToken(): Token = { + if (in.hasNext(isAlpha)) { + getName() + } else if (in.hasNext(isOperator)) { + getOperator() + } else if (in.hasNext(isDigit)) { + getNum() + } else if (in.hasNext(isDelim)) { + Delim(in.next()) + } else if (!in.hasNext) { + EOF + } else { + abort(s"unexpected character") + } + } + + /* + * Skip whitespace and comments. Stop at the next token. + */ + def skipWhiteSpace() = { + while (in.hasNext(isWhiteSpace) || in.hasNext2(isCommentStart)) { + + // If it is a comment, consume the full line + if (in.peek == '/') { + in.next() + while (in.peek != '\n') in.next() + + } + + // Update file statistics if new line + if (in.peek == '\n') { + lineStarts += pos + 1 + line += 1 + } + in.next() + } + } + + def getCurrentPos() = { + endLine = line; endCol = column + Position(gapLine,gapCol,startLine,startCol,endLine,endCol) + } + + /* + * Extract a token and set position information + */ + def getToken(): Token = { + gapLine = line; gapCol = column + skipWhiteSpace() + startLine = line; startCol = column + val tok = getRawToken() + tok.pos = getCurrentPos() + + tok + } + + var peek = getToken() + var peek1 = getToken() + def hasNext: Boolean = peek != EOF + def hasNext(f: Token => Boolean) = f(peek) + def hasNext2(f: (Token, Token) => Boolean) = f(peek, peek1) + def next() = { + val res = peek + peek = peek1 + peek1 = getToken() + res + } +} + +class Parser(in: Scanner) extends Reporter { + import Tokens._ + + /* + * Overloaded methods that show the source code + * and highlight the current token when reporting + * an error. + */ + override def expected(msg: String) = { + expected(msg, in.showSource(in.peek.pos)) + } + + override def abort(msg: String) = { + abort(msg, in.showSource(in.peek.pos)) + } + + def error(msg: String, pos: Position): Unit = + error(msg, in.showSource(pos)) + + def warn(msg: String, pos: Position): Unit = + warn(msg, in.showSource(pos)) + + def accept(c: Char) = { + if (in.hasNext(_ == Delim(c))) in.next() + else expected(s"'$c'") + } + + def accept(s: String) = { + if (in.hasNext(_ == Keyword(s))) in.next() + else expected(s"'$s'") + } + + /* + * Auxilaries functions + * Test and extract data + */ + def isName(x: Token) = x match { + case Ident(x) => true + case _ => false + } + + def getName(): (String, Position) = { + if (!in.hasNext(isName)) expected("Name") + val pos = in.peek.pos + val Ident(x) = in.next() + (x, pos) + } + + // CHANGED: It was only Number previsously + def isLiteral(x: Token) = x match { + case Literal(x) => true + case _ => false + } + + def getLiteral(): (Any, Position) = { + if (!in.hasNext(isLiteral)) expected("Literal") + val pos = in.peek.pos + val Literal(x) = in.next() + (x, pos) + } + + def getOperator(): (String, Position) = { + if (!in.hasNext(isName)) expected("Operator") + val pos = in.peek.pos + val Ident(x) = in.next() + (x, pos) + } + + /* + * Test if the following token is an infix + * operator with highest precedence + */ + def isInfixOp(min: Int)(x: Token) = isOperator(x) && (x match { + case Ident(x) => prec(x) >= min + case _ => false + }) + + /* + * Test if the following token is an operator. + */ + def isOperator(x: Token) = x match { + case Ident(x) => in.isOperator(x.charAt(0)) + case _ => false + } + + /* + * Define precedence of operator. + * Negative precedence means that the operator can + * not be used as an infix operator within a simple expression. + * + * CHANGED: boolean operators have precedence of 0 + */ + def prec(a: String) = a match { // higher bind tighter + case "+" | "-" => 1 + case "*" | "/" => 2 + case _ if in.isBOperator(a.charAt(0)) => 0 + case _ => 0 + } + + def assoc(a: String) = a match { + case "+" | "-" | "*" | "/" => 1 + case _ => 1 + } +} + + +/** + * Definition of our target language. + * + * The different nodes of the AST also keep Position information + * for error handling during the semantic analysis. + * + * TODO: Every time you add an AST node, you must also track the position + */ +object Language { + abstract class Exp { + var pos: Position = _ + var tp: Type = UnknownType + + def withPos(p: Position) = { + pos = p + this + } + + def withType(pt: Type) = { + tp = pt + this + } + } + + abstract class Type + case object UnknownType extends Type + case class BaseType(v: String) extends Type { + override def toString = v + } + case class FunType(args: List[(String, Type)], rtp: Type) extends Type { + override def toString = s"(${args mkString ","}) => $rtp" + } + case class ArrayType(tp: Type) extends Type + + val IntType = BaseType("Int") + val UnitType = BaseType("Unit") + val BooleanType = BaseType("Boolean") + + + // Arithmetic + case class Lit(x: Any) extends Exp + // CHANGED: instead of creating a node for different operator arity, + // we use a single node with a list of arguments. + case class Prim(op: String, args: List[Exp]) extends Exp + + // Immutable variables + case class Let(x: String, xtp: Type, a: Exp, b: Exp) extends Exp + case class Ref(x: String) extends Exp + + // Branches + case class If(cond: Exp, tBranch: Exp, eBranch: Exp) extends Exp + + // Mutable variables + case class VarDec(x: String, xtp: Type, rhs: Exp, body: Exp) extends Exp + case class VarAssign(x: String, rhs: Exp) extends Exp + + // While loops + case class While(cond: Exp, lbody: Exp, body: Exp) extends Exp + + // Functions + case class LetRec(funs: List[Exp], body: Exp) extends Exp + case class Arg(name: String, tp: Type, pos: Position) + case class FunDef(name: String, args: List[Arg], rtp: Type, fbody: Exp) extends Exp + case class App(f: Exp, args: List[Exp]) extends Exp + + // Arrays + case class ArrayDec(size: Exp, etp: Type) extends Exp +} + +/* + * The BaseParser class implements all of the functionality implemented in project 2, + * with the addition of type information. + * + * To avoid repeating your effort from project 2, we have implemented all of the + * parsing for you, excluding the parsing of types. As such... + * + * TODO: Implement the two functions that parse types. + * + * ::= + * ::= ['*' | '/' | '+' | '-' | '<' | '>' | '=' | '!']+ + * ::= 'true' | 'false' + * ::= | | '()' + * | '('')' + * | + * | '{''}' + * ::= [] + * ::= []* + * | 'if' '('')' 'else' + * | '=' + * ::= + * | 'val' [:] '=' ';' + * | 'var' [:] '=' ';' + * | 'while' '('')'';' + */ +class BaseParser(in: Scanner) extends Parser(in) { + import Language._ + import Tokens._ + + /******************* Types **********************/ + + /* + * This function extracts the type information from + * the source code. Raise an error if there is no + * type information. + * + * This function will only be used to read in a type + * (i.e. you should not read in a delimiter) + * + * TODO: Implement this function + */ + def parseType: Type = in.peek match { + case _ => expected("type") + } + + + /* + * This function is parsing a type which can be omitted. + * If the type information is not in the source code, + * it returns UnknownType + * + * TODO: Implement this function + */ + def parseOptionalType: Type = in.peek match { + case _ => UnknownType + } + + /******************* Code **********************/ + + /* + * Parse the full code, + * verify that there are no unused tokens, + * and raise an error if there are. + */ + def parseCode = { + val res = parseExpression + if (in.hasNext) + expected(s"EOF") + LetRec(Nil, res) + } + + def parseAtom: Exp = (in.peek, in.peek1) match { + case (Literal(x), _) => + val (_, pos) = getLiteral + Lit(x).withPos(pos) + case (Delim('('), Delim(')')) => + val pos = in.next().pos + in.next + Lit(()).withPos(pos) + case (Delim('('), _) => + in.next() + val res = parseSimpleExpression + accept(')') + res + case (Ident(x), _) => + val (_, pos) = getName + Ref(x).withPos(pos) + case (Delim('{'), _) => + accept('{') + val res = parseExpression + accept('}') + res + case _ => abort(s"Illegal start of simple expression") + } + + def parseUAtom: Exp = if (in.hasNext(isOperator)) { + val (op, pos) = getOperator + Prim(op, List(parseAtom)).withPos(pos) + } else { + parseAtom + } + + def parseSimpleExpression(min: Int): Exp = { + var res = parseUAtom + while (in.hasNext(isInfixOp(min))) { + val (op, pos) = getOperator + val nMin = prec(op) + assoc(op) + val rhs = parseSimpleExpression(nMin) + res = Prim(op, List(res, rhs)).withPos(pos) + } + res + } + + def parseSimpleExpression: Exp = (in.peek, in.peek1) match { + case (Ident(x), Delim('=')) => + val (_, pos) = getName + accept('=') + val rhs = parseSimpleExpression + VarAssign(x, rhs).withPos(pos) + case (Keyword("if"), _) => + val pos = accept("if").pos + accept('(') + val cond = parseSimpleExpression + accept(')') + val tBranch = parseSimpleExpression + accept("else") + val eBranch = parseSimpleExpression + If(cond, tBranch, eBranch).withPos(pos) + case _ => parseSimpleExpression(0) + } + + def parseExpression: Exp = in.peek match { + case Keyword("val") => + accept("val") + val (name, pos) = getName + val tp = parseOptionalType + accept('=') + val rhs = parseSimpleExpression + accept(';') + val body = parseExpression + Let(name, tp, rhs, body).withPos(pos) + case Keyword("var") => + accept("var") + val (name, pos) = getName + val tp = parseOptionalType + accept('=') + val rhs = parseSimpleExpression + accept(';') + val body = parseExpression + VarDec(name, tp, rhs, body).withPos(pos) + case Keyword("while") => + val pos = accept("while").pos + accept('(') + val cond = parseSimpleExpression + accept(')') + val lBody = parseSimpleExpression + accept(';') + val body = parseExpression + While(cond, lBody, body).withPos(pos) + case _ => parseSimpleExpression + } +} + +/* + * We want to make our syntax easier for the programmer to use. + * + * For example, instead of writing: + * + * var x = 0; + * var y = 3; + * let dummy = x = x + 1; + * y = y + 1 + * + * We will write + * + * var x = 0; + * var y = 3; + * x = x + 1; + * y = y + 1 + * + * However the AST generated will be the same. The parser will have to create a dummy + * variable and insert a let binding. + * + * We also have some syntactic sugar for the if statement. If the else branch doesn't exist, + * then the unit literal will be used for that branch. + * + * TODO complete the two functions to handle syntactic sugar. + * + * ::= + * ::= ['*' | '/' | '+' | '-' | '<' | '>' | '=' | '!']+ + * ::= 'true' | 'false' + * ::= | | '()' + * | '('')' + * | + * | '{''}' + * ::= [] + * ::= []* + * | 'if' '('')' ['else' ] + * | '=' + * ::= [;] + * | 'val' [:] '=' ';' + * | 'var' [:] '=' ';' + * | 'while' '('')'';' + */ +class SyntacticSugarParser(in: Scanner) extends BaseParser(in) { + import Language._ + import Tokens._ + + // Can be overriden for ; inference + def isNewLine(x: Token) = x match { + case Delim(';') => true + case _ => false + } + + var next = 0 + def freshName(suf: String = "x") = { + next += 1 + suf + "$" + next + } + + override def parseSimpleExpression = in.peek match { + case _ => super.parseSimpleExpression + } + + override def parseExpression = { + // NOTE: parse expression terminates when it parse a simples expression. + // syntax sugar allows to have an other expression after it. + val res = super.parseExpression + res + } + +} + +/* + * The next parser is going to add the necessary mechanic to parse functions. + * + * With function come function declaration, function definition and function type. + * + * Here are some example of valid syntax: + * + * def f(x: Int, k: Int => Int): Int = h(x); + * + * h(1)(2, 4); + * + * val g: (Int => Int) => Int = 3; g + * + * You need to write the function to parse these expression. The job has been splitted + * in multiple small auxilary functions. Also don't forget that we already have some + * function doing part of the job in the super class. + * + * We also defined the concept of program. All function must be defined first and then + * the following expression is considered the main. + * + * Here is the formalized grammar. Most of it is already handle by the based parser. you + * only need to handle the new constructs. + * + * ::= + * | '=>' + * | '('[[',']*]')' '=>' + * ::= ['*' | '/' | '+' | '-' | '<' | '>' | '=' | '!']+ + * ::= 'true' | 'false' + * ::= | | '()' + * | '('')' + * | + * ::= ['('[[',']*]')']* + * | '{''}' + * ::= [] + * ::= []* + * | 'if' '('')' ['else' ] + * | '=' + * ::= [;] + * | 'val' [':'] '=' ';' + * | 'var' [':'] '=' ';' + * | 'while' '('')'';' + * ::= ':' + * ::= ['def''('[[',']*]')'[':' ] '=' ';']* + */ +class FunctionParser(in: Scanner) extends SyntacticSugarParser(in) { + import Language._ + import Tokens._ + + /* + * This function is an auxilary function that is parsing a list of elements of type T which are + * separated by 'sep'. + * + * 'sep' must be a valid delimiter. + * + * 12, 14, 11, 23, 10, 234 + * + * parseList[Exp](parseAtom, ',', tok => tok match { + * case Literal(x: Int) => x < 20; + * case _ => false + * }) + * + * will return the list List(Lit(12), Lit(14), lit(11)) and the next token will be Delim(',') + * + * You don't have to use this function but it may be useful. + */ + def parseList[T](parseElem: => T, sep: Char, cond: Token => Boolean, first: Boolean = true): List[T] = { + if (first && cond(in.peek) || (!first && in.peek == Delim(sep) && cond(in.peek1))) { + if (!first) { + accept(sep) + } + parseElem :: parseList(parseElem, sep, cond, false) + } else { + Nil + } + } + + + /* + * This function parse types. + * + * TODO + */ + override def parseType = in.peek match { + case _ => super.parseType + } + + /* + * Parse the program and verify that there nothing left + * to be parsed. + */ + override def parseCode = { + val prog = parseProgram + if (in.hasNext) + expected(s"EOF") + prog + } + + /* + * Parse one argument () + * + * TODO: complete the function + */ + def parseArg: Arg = { + ??? + } + + /* + * Parse one function. + * We assume that the first token is Keyword("def") + * + * TODO: complete the function + */ + def parseFunction: Exp = { + ??? + } + + /* + * Parse a program. I.e a list of function following + * by an expression. + * + * If there is no functions defined, this function + * still return a LetRec with an empty function list. + * + * TODO: complete the function + */ + def parseProgram = in.peek match { + case _ => LetRec(Nil, parseExpression) + } + + /* + * this function is called uatom to avoid reimplementing + * the previous functions. However it is parsing the + * grammar. + */ + override def parseUAtom = if (in.hasNext(isOperator)) { + val (op, pos) = getOperator + Prim(op, List(parseTight)).withPos(pos) + } else { + parseTight + } + + /* + * Parse grammar. i.e. function applications. + * + * Remember function application is left associative + * and they all have the same precedence. + * + * a(i)(k, j) is parsed to + * + * App(App(Ref("a"), List(Ref("i"))), List(Ref("k"), Ref("j"))) + */ + def parseTight = in.peek match { + case Delim('{') => + val pos = in.next().pos + val res = parseExpression + accept('}') + res + case _ => + var res = parseAtom + // TODO: complete + res + } + +} + +/* + * We are now going to add heap storage. This kind of storage is persistant + * between function calls. + * + * We are going to use the scala syntax of: new Array[Int](4). However + * we are not going to implement object. The array behavior will be closer + * to a C array. + * + * In order to access an element the element in the array we use the syntax: + * + * val arr = new Array[Int](4); + * val x = arr(0); + * + * And for the update: + * + * arr(0) = 3; + * + * The acces is going to be parse as a function application but this is fine. + * For the value update, the parser need to generate a primitive: block-set + * which take three paramter. 1 the arr, 2 the idx and 3 the value to update. + * + * arr(0) = 3; + * + * will be parsed to + * Prim("block-set", List(Ref("arr"), Lit(0), Lit(3))) + * + * One idea to parse it it to follow the following process: + * + * parse a tight, if it returns a function application with only one argument + * and the following token is an '=' then you are in the array update situation. + * + * TODO: Complete the methods + * + * ::= + * | '=>' + * | '('[[',']*]')' '=>' + * | 'Array' '[' ']' + * ::= ['*' | '/' | '+' | '-' | '<' | '>' | '=' | '!']+ + * ::= 'true' | 'false' + * ::= | | '()' + * | '('')' + * | + * ::= ['('[[',']*]')']*['('')' '=' ] + * | '{''}' + * ::= [] + * ::= []* + * | 'if' '('')' ['else' ] + * | '=' + * | 'new' 'Array' '[' ']' '('')' // type not optional '[' is the delimiter. + * ::= [;] + * | 'val' [':'] '=' ';' + * | 'var' [':'] '=' ';' + * | 'while' '('')'';' + * ::= ':' + * ::= ['def''('[[',']*]')'[':' ] '=' ';']* + */ +class ArrayParser(in: Scanner) extends FunctionParser(in) { + import Language._ + import Tokens._ + + override def parseType = in.peek match { + case Ident("Array") => ??? + case _ => super.parseType + } + + /* + * Parse array update + * + * TODO + */ + override def parseTight = ??? + + /* + * Parse array declaration + * + * TODO + */ + override def parseSimpleExpression = ??? +} diff --git a/proj3/src/main/scala/project3/SemanticAnalyzer.scala b/proj3/src/main/scala/project3/SemanticAnalyzer.scala new file mode 100644 index 0000000..3e4901f --- /dev/null +++ b/proj3/src/main/scala/project3/SemanticAnalyzer.scala @@ -0,0 +1,386 @@ +package project3 + +class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { + import Language._ + + /* + * Primitive functions that do not need to be defined or declared. + */ + val primitives = Map[String,(Boolean,Type)]( + "getchar" -> (false, FunType(List(), IntType)), + "putchar" -> (false, FunType(List(("", IntType)), UnitType)) + ) + + /* + * Define an empty state for the Semantic Analyzer. + * + * NOTE: + * val env = new Env + * + * env("hello") is equivalent to env.apply("hello") + */ + class Env { + def apply(name: String): Option[Type] = None + def isVar(name: String) = false + } + + /* + * Env that keeps track of variables defined. + * The map stores true if the variable is mutable, + * false otherwise and its type. + */ + case class TypeEnv( + vars: Map[String,(Boolean, Type)] = primitives, + outer: Env = new Env) extends Env { + + /* + * Return true if the variable is already defined + * in this scope + */ + def isDefined(name: String) = vars.contains(name) + + /* + * Make a copy of this object and add a mutable variable 'name' + */ + def withVar(name: String, tp: Type): TypeEnv = { + copy(vars = vars + (name -> (true, tp))) + } + + /* + * Make a copy of this object and add an immutable variable 'name' + */ + def withVal(name: String, tp: Type): TypeEnv = { + copy(vars = vars + (name -> (false, tp))) + } + + /* + * Make a copy of this object and add in the list of immutable variables. + */ + def withVals(list: List[(String,Type)]): TypeEnv = { + copy(vars = vars ++ (list map { t => (t._1, (false, t._2)) }).toMap) + } + + /* + * Return true if 'name' is a mutable variable defined in this scope + * or in the outer scope. + */ + override def isVar(name: String) = vars.get(name) match { + case None => outer.isVar(name) + case Some((mut, _)) => mut + } + + /* + * Return the Type if the variable 'name' is an option. + * i.e. Some(tp) if the variable exists or None if it doesn't + */ + override def apply(name: String): Option[Type] = vars.get(name) match { + case Some((_, tp)) => Some(tp) + case None => outer(name) + } + } + + // Error reporting + var numError = 0 + def error(msg: String, pos: Position): Unit = { + numError += 1 + parser.error(msg, pos) + } + + // Warning reporting + var numWarning = 0 + def warn(msg: String, pos: Position): Unit = { + numWarning += 1 + parser.warn(msg, pos) + } + + /* + * Return a fresh name if a new variable needs to be defined + */ + var next = 0 + def freshName(pref: String = "x") = { + next += 1 + s"${pref}_$next" + } + + /* + * Auxiliary functions. May be useful. + */ + def getName(arg: Any): String = arg match { + case Arg(name, _, _) => name + case FunDef(name, _, _, _) => name + case _ => BUG(s"Don't know how to extract name from $arg") + } + + def getPos(arg: Any): Position = arg match { + case Arg(_, _, pos) => pos + case fd@FunDef(_, _, _, _) => fd.pos + case _ => BUG(s"Don't know how to extract position from $arg") + } + + def checkDuplicateNames(args: List[Any]): Boolean = args match { + case h::t => + val name = getName(h) + val (dup, other) = t partition { arg => name == getName(arg) } + dup foreach { arg => + error(s"$name is already defined", getPos(arg)) + } + checkDuplicateNames(other) || dup.length > 0 + case Nil => false + } + + def funType(args: List[Arg], rtp: Type): FunType = { + FunType(args map { arg => (arg.name, arg.tp) }, rtp) + } + + def listArgType(size: Int, tp: Type) = List.fill(size)(("", tp)) + + /** + * Run the Semantic Analyzer on the given AST. + * + * Print out the number of warnings and errors found, if any. + * Return the AST with types resolved and the number of warnings + * and errors. + * + * NOTE: we want our main program to return an Int! + */ + def run(exp: Exp) = { + numError = 0 + val nexp = typeCheck(exp, IntType)(TypeEnv()) + if (numWarning > 0) + System.err.println(s"""$numWarning warning${if (numWarning != 1) "s" else ""} found""") + if (numError > 0) + System.err.println(s"""$numError error${if (numError != 1) "s" else ""} found""") + + (nexp, numWarning, numError) + } + + // List of valid infix operators + val isBOperator = Set("==","!=","<=",">=","<",">") + val isIntOperator = Set("+","-","*","/") + + /* + * Returns the type of the binary operator 'op'. See case "+" for an example + * TODO: implement the remaining binary operators for typeBinOperator + */ + def typeBinOperator(op: String)(pos: Position) = op match { + case "+" => FunType(List(("", IntType), ("", IntType)), IntType) + case _ => + error("undefined binary operator", pos) + UnknownType + } + + // List of valid unary operators + val isIntUnOperator = Set("+","-") + + /* + * Returns the type of the unary operator 'op' + * TODO: implement typeUnOperator + */ + def typeUnOperator(op: String)(pos: Position) = op match { + case _ => + error(s"undefined unary operator", pos) + UnknownType + } + + /* + * Returns the type of the ternary operator 'op' + * TODO: implement typeTerOperator + * operators: block-set + */ + def typeTerOperator(op: String)(pos: Position) = op match { + case _ => + error(s"undefined ternary operator", pos) + UnknownType + } + /* + * Return the type of the operator 'op' with arity 'arity' + */ + def typeOperator(op: String, arity: Int)(pos: Position): Type = arity match { + case 3 => typeTerOperator(op)(pos) + case 2 => typeBinOperator(op)(pos) + case 1 => typeUnOperator(op)(pos) + case _ => + error(s"undefined operator", pos) + UnknownType + } + + /* + * Check if 'tp' conforms to 'pt' and return the more precise type. + * The result needs to be well formed. + * + * TODO: implement the case of function type. + */ + def typeConforms(tp: Type, pt: Type)(env: TypeEnv, pos: Position): Type = (tp, pt) match { + case (_, _) if tp == pt => typeWellFormed(tp)(env, pos) + case (_, UnknownType) => typeWellFormed(tp)(env, pos) // tp <: Any + case (UnknownType, _) => typeWellFormed(pt)(env, pos) // for function arguments + case (FunType(args1, rtp1), FunType(args2, rtp2)) if args1.length == args2.length => + ??? // TODO: Function type conformity + case (ArrayType(tp), ArrayType(pt)) => ArrayType(typeConforms(tp, pt)(env, pos)) + case _ => error(s"type mismatch;\nfound : $tp\nexpected: $pt", pos); pt + } + + /* + * Auxiliary function used to check function type argument conformity. + * + * The function is verifying that 'tp' elements number n conforms + * to 'pt' element number n. It returns the list of precise types + * returned by each invocation to typeConforms + */ + def typeConform(tp: List[(String, Type)], pt: List[(String,Type)])(env: TypeEnv, pos: Position): List[(String, Type)] = { + if (tp.length != pt.length) BUG("length of list does not match") + + (tp zip pt) map { case ((arg1, tp1), (arg2, tp2)) => + (if (tp1 != UnknownType) arg1 + else arg2, typeConforms(tp1, tp2)(env, pos)) + } + } + + /* + * Verify that the type 'tp' is well formed. i.e there is no + * UnknownType. + */ + def typeWellFormed(tp: Type)(env: TypeEnv, pos: Position)(implicit forFunction: Boolean=false): Type = tp match { + case FunType(args, rte) => + FunType(args map { case (n, tp) => + (n, typeWellFormed(tp)(env, pos)) + }, typeWellFormed(rte)(env, pos)(true)) + case ArrayType(tp) => ArrayType(typeWellFormed(tp)(env, pos)) + case UnknownType => + if (forFunction) error("malformed type: function return types must be explicit if function is used recursively or in other functions' bodies", pos) + else error("malformed type", pos) + UnknownType + case _ => tp + } + + + /* + * typeCheck takes an expression and an expected type (which may be UnknownType). + * This is done via calling the typeInfer and typeConforms + * functions (details below), and finally returning the original + * expression with all typing information resolved. + * + * typeInfer uses the inference rules seen during the lectures + * to discover the type of an expression. As a reminder, the rules we saw can be + * found in lectures 5 and 6. + * + * TODO: Remove the ??? and add the correct implementation. + * The code must follow the inference rules seen during the lectures. + * + * The errors/warnings check that you had to implement for project 2 + * should be already implemented. However, there are new variables + * introduced that need to be check for duplicate (function name, + * variables names). We defined the rules for function semantic in + * lecture 5. + */ + def typeCheck(exp: Exp, pt: Type)(env: TypeEnv): Exp = { + val nexp = typeInfer(exp, pt)(env) + val rnexpType = typeConforms(nexp.tp, pt)(env, exp.pos) + nexp.withType(rnexpType) + } + + def typeInfer(exp: Exp, pt: Type)(env: TypeEnv): Exp = exp match { + case Lit(_: Int) => exp.withType(IntType) + case Lit(_: Boolean) => ??? + case Lit(_: Unit) => ??? + case Prim("block-set", args) => ??? + case Prim(op, args) => + typeOperator(op, args.length)(exp.pos) match { + case FunType(atps, rtp) => ??? + case UnknownType => exp.withType(UnknownType) + case _ => BUG("operator's type needs to be FunType") + } + case Let(x, tp, rhs, body) => + if (env.isDefined(x)) + warn("reuse of variable name", exp.pos) + val nrhs = typeCheck(rhs, tp)(env) + val nbody = typeCheck(body, pt)(env.withVal(x, nrhs.tp)) + Let(x, nrhs.tp, nrhs, nbody).withType(nbody.tp) + case Ref(x) => + env(x) match { + case Some(tp) => ??? // Remember to check that the type taken from the environment is welformed + case _ => + error("undefined identifier", exp.pos) + ??? + } + case If(cond, tBranch, eBranch) => + // Hint: type check the else branch before the then branch. + ??? + case VarDec(x, tp, rhs, body) => + if (env.isDefined(x)) + warn("reuse of variable name", exp.pos) + ??? + case VarAssign(x, rhs) => + val xtp = if (!env.isDefined(x)) { + error("undefined identifier", exp.pos) + UnknownType + } else { + if (!env.isVar(x)) + error("reassignment to val", exp.pos) + env(x).get + } + + ??? + + /* Because of syntactic sugar, a variable assignment + * statement can be accepted as an expression + * of type Unit. In this case, we will modify + * the AST and store the assignment value into + * a "dummy" variable and return the Unit Literal. + * + * For example, + * + * If(..., VarAssign("x", Lit(1)), Lit(())) + * + * requires the two branches of the If to be of the same + * type, in this case, Unit. Therefore the "then" branch + * will need to be modified to have the correct type. + * Without changing the semantics! + */ + pt match { + case UnitType => ??? + case _ => ??? + } + case While(cond, lbody, body) => ??? + case FunDef(fname, args, rtp, fbody) => ??? + case LetRec(funs, body) => + // TODO modify to handle general case + val nbody = typeCheck(body, pt)(env) + LetRec(Nil, nbody).withType(nbody.tp) + case App(fun, args) => + // TODO Check fun type + val nFun: Exp = ??? + + // Handling some errors + val ftp = nFun.tp match { + case ftp@FunType(fargs, _) if fargs.length == args.length => + ftp + case ftp@FunType(fargs, rtp) if fargs.length < args.length => + error(s"too many arguments for method: ($fargs)$rtp", exp.pos) + FunType(fargs ++ List.fill(args.length - fargs.length)(("", UnknownType)), rtp) + case ftp@FunType(fargs, rtp) => + error(s"not enough arguments for method: ($fargs)$rtp", exp.pos) + ftp + case ArrayType(tp) => + FunType(List(("", IntType)), tp) + case tp => + error(s"$tp does not take paramters", exp.pos) + FunType(List.fill(args.length)(("", UnknownType)), pt) + } + + // TODO: Check arguments type + val nargs: List[Exp] = ??? + + // Transform some function applications into primitives on arrays. + nFun.tp match { + case ArrayType(tp) => + Prim("block-get", List(nFun, nargs.head)).withType(tp) + case _ => App(nFun, nargs).withType(ftp.rtp) + } + case ArrayDec(size: Exp, etp: Type) => + // TODO: Check array declaration + // Note that etp is the type of elements + ??? + case _ => BUG(s"malformed expresstion $exp") + } +} diff --git a/proj3/src/main/scala/project3/Util.scala b/proj3/src/main/scala/project3/Util.scala new file mode 100644 index 0000000..51e56b6 --- /dev/null +++ b/proj3/src/main/scala/project3/Util.scala @@ -0,0 +1,88 @@ +package project3 + +import java.io._ +import scala.sys.process._ + +class AbortException extends Exception("aborted") + +// Error reporting +trait Reporter { + // report a warning + def warn(s: String): Unit = System.err.println(s"Warning: $s.") + def warn(s: String, msg: String): Unit = System.err.println(s"Warning: $s.\n" + msg) + // report an error + def error(s: String): Unit = System.err.println(s"Error: $s.") + def error(s: String, msg: String): Unit = System.err.println(s"Error: $s.\n" + msg) + // report error and halt + def abort(s: String): Nothing = { error(s); throw new AbortException()} + def abort(s: String, msg: String): Nothing = { error(s, msg); throw new AbortException()} + + def expected(s: String): Nothing = abort(s"$s expected") + def expected(s: String, msg: String): Nothing = + abort(s"$s expected", msg) +} + +trait BugReporter { + def BUG(msg: String) = throw new Exception(s"BUG: $msg") +} + +// Utilities to emit code +trait Codegen { + + def stream: PrintWriter + + // output + def emit(s: String): Unit = stream.print('\t' + s) + def emitln(s: String, nTab: Int = 1): Unit = stream.println("\t" * nTab + s) +} + +abstract class Reader[T] { + def pos: Int + def input: String + def peek: T + def peek1: T // second look-ahead character used for comments '//' + def hasNext: Boolean + def hasNext(f: T => Boolean): Boolean + def hasNext2(f: (T,T) => Boolean): Boolean + def next(): T +} + +class BaseReader(str: String, eof: Char) extends Reader[Char] { + var pos = 0 + def input = str + val in = str.iterator + var peek = if (in.hasNext) in.next() else eof + var peek1 = if (in.hasNext) in.next() else eof + def hasNext: Boolean = peek != eof + def hasNext(f: Char => Boolean) = f(peek) + def hasNext2(f: (Char,Char) => Boolean) = f(peek,peek1) + def next() = { + val x = peek; peek = peek1; + peek1 = if (in.hasNext) in.next() else eof + pos += 1 + x + } +} + +// ASM bootstrapping +class ASMRunner(snipet: String) { + + def code = snipet + + def assemble = { + val file = new File("gen/gen.s") + val writer = new PrintWriter(file) + + writer.println(snipet) + writer.flush + writer.close + + Seq("gcc", "-no-pie", "gen/bootstrap.c", "gen/gen.s", "-o", "gen/out").!.toInt + } + + def run = { + val stdout = "gen/out".!! + // output format: Exit Code: \n + stdout.split(" ").last.trim.toInt + } +} diff --git a/proj3/src/test/scala/project3/CompilerTest.scala b/proj3/src/test/scala/project3/CompilerTest.scala new file mode 100644 index 0000000..ad22458 --- /dev/null +++ b/proj3/src/test/scala/project3/CompilerTest.scala @@ -0,0 +1,41 @@ +package project3 + +import org.scalatest._ +import java.io.{ByteArrayOutputStream, PrintWriter} + +// Define the stream method +trait TestOutput { + import Language._ + + val out = new ByteArrayOutputStream + val pOut = new PrintWriter(out, true) + def stream = pOut + def emitCode(ast: Exp): Unit + + def code(ast: Exp) = { + emitCode(ast) + out.toString.stripLineEnd + } +} + +class CompilerTest extends TimedSuite { + import Language._ + + def runner(src: String) = new ASMRunner(src) + + def testCompiler(ast: Exp, res: Int) = { + val interpreter = new X86Compiler with TestOutput + + val code = interpreter.code(ast) + val asm = runner(code) + + assert(asm.assemble == 0, "Code generated couldn't be assembled") + assert(asm.run == res, "Invalid result") + } + + test("arithm") { + testCompiler(LetRec(Nil, Lit(-21)), -21) + testCompiler(LetRec(Nil, Prim("-", List(Lit(10), Lit(2)))), 8) + } + +} diff --git a/proj3/src/test/scala/project3/InterpreterTest.scala b/proj3/src/test/scala/project3/InterpreterTest.scala new file mode 100644 index 0000000..7357585 --- /dev/null +++ b/proj3/src/test/scala/project3/InterpreterTest.scala @@ -0,0 +1,20 @@ +package project3 + +import org.scalatest._ + +class InterpretTest extends TimedSuite { + import Language._ + import StackVal._ + + def testInterpreter(ast: Exp, res: Any) = { + val interpreter = new StackInterpreter + + assert(res == interpreter.run(ast), "Interpreter does not return the correct value") + } + + test("arithm") { + testInterpreter(Lit(-21), Cst(-21)) + testInterpreter(Prim("-", List(Lit(10), Lit(2))), Cst(8)) + } + +} diff --git a/proj3/src/test/scala/project3/ParserTest.scala b/proj3/src/test/scala/project3/ParserTest.scala new file mode 100644 index 0000000..dddd8a7 --- /dev/null +++ b/proj3/src/test/scala/project3/ParserTest.scala @@ -0,0 +1,33 @@ +package project3 + +import java.io._ +import org.scalatest._ + +class ParserTest extends TimedSuite { + import Language._ + + def scanner(src: String) = new Scanner(new BaseReader(src, '\u0000')) + + def testBaseParser(op: String, res: Exp) = { + val gen = new BaseParser(scanner(op)) + val ast = gen.parseCode + + assert(ast == LetRec(Nil, res), "Invalid result") + } + + test("SingleDigit") { + testBaseParser("1", Lit(1)) + } + + test("GenericPrecedence") { + testBaseParser("2-4*3", Prim("-", List(Lit(2), Prim("*", List(Lit(4), Lit(3)))))) + } + + test("ParseType") { + testBaseParser("val x: Int = 1; 2", Let("x", IntType, Lit(1), Lit(2))) + } + + test("ParseOptionalType") { + testBaseParser("val x = 1; 2", Let("x", UnknownType, Lit(1), Lit(2))) + } +} diff --git a/proj3/src/test/scala/project3/SemanticAnalyzerTest.scala b/proj3/src/test/scala/project3/SemanticAnalyzerTest.scala new file mode 100644 index 0000000..5feb474 --- /dev/null +++ b/proj3/src/test/scala/project3/SemanticAnalyzerTest.scala @@ -0,0 +1,51 @@ +package project3 + +import org.scalatest._ + +class SemanticAnalyzerTest extends TimedSuite { + import Language._ + + def astTypeEquals(ast: Exp, tsa: Exp): Boolean = ast == tsa && ast.tp == tsa.tp && { (ast, tsa) match { + case (Prim(_, args), Prim(_, sgra))=> + (args zip sgra) forall { case (arg, gra) => astTypeEquals(arg, gra) } + case (Let(_, _, a, b), Let(_, _, c, d)) => + astTypeEquals(a, c) && astTypeEquals(b, d) + case (If(cond, tBranch, eBranch), If(cond1, tBranch1, eBranch1)) => + astTypeEquals(cond, cond1) && astTypeEquals(tBranch, tBranch1) && astTypeEquals(eBranch, eBranch1) + case (VarDec(_, _, a, b), VarDec(_, _, c, d)) => + astTypeEquals(a, c) && astTypeEquals(b, d) + case (VarAssign(_, rhs), VarAssign(_, shr)) => + astTypeEquals(rhs, shr) + case (While(cond, tBranch, eBranch), While(cond1, tBranch1, eBranch1)) => + astTypeEquals(cond, cond1) && astTypeEquals(tBranch, tBranch1) && astTypeEquals(eBranch, eBranch1) + case (FunDef(_, _, _, fbody), FunDef(_, _, _, fbody1)) => + astTypeEquals(fbody, fbody1) + case (LetRec(funs, body), LetRec(funs1, body1)) => + ((funs zip funs1) forall { case (arg, gra) => astTypeEquals(arg, gra) }) && astTypeEquals(body, body1) + case (App(fun, args), App(fun1, args1)) => + ((args zip args1) forall { case (arg, gra) => astTypeEquals(arg, gra) }) && astTypeEquals(fun, fun1) + case (ArrayDec(size, _), ArrayDec(size1, _)) => + astTypeEquals(size, size1) + case _ => true + }} + + def testSemanticAnalyzer(ast: Exp, tsa: Exp, nWarning: Int, nError: Int) = { + val fakeParser = new Parser(null) { + override def error(msg: String, pos: Position) = {} + override def warn(msg: String, pos: Position) = {} + } + + val analyzer = new SemanticAnalyzer(fakeParser) + + val (tast, w, e) = analyzer.run(ast) + assert(w == nWarning, "Incorrect number of Warnings") + assert(e == nError, "Incorrect number of Errors") + assert(astTypeEquals(tast, tsa), "AST does not have correct type") + } + + test("NoErrorNoWarning") { + testSemanticAnalyzer(Lit(1), Lit(1).withType(IntType), 0, 0) + testSemanticAnalyzer(Prim("+", List(Lit(1), Lit(2))), Prim("+", List(Lit(1).withType(IntType), Lit(2).withType(IntType))).withType(IntType), 0, 0) + } + +} diff --git a/proj3/src/test/scala/project3/TimedSuite.scala b/proj3/src/test/scala/project3/TimedSuite.scala new file mode 100644 index 0000000..02bce3f --- /dev/null +++ b/proj3/src/test/scala/project3/TimedSuite.scala @@ -0,0 +1,10 @@ +package project3 +import org.scalatest._ +import org.scalatest.concurrent.{TimeLimitedTests, Signaler, ThreadSignaler} +import org.scalatest.time.{Span, Millis} + +class TimedSuite extends FunSuite with TimeLimitedTests { + + val timeLimit = Span(1000, Millis) + override val defaultTestSignaler: Signaler = ThreadSignaler +}