diff --git a/proj3/.gitignore b/proj3/.gitignore new file mode 100644 index 0000000..183ab06 --- /dev/null +++ b/proj3/.gitignore @@ -0,0 +1,10 @@ +.vscode +.DS_Store +.bloop +.metals +**/target +**/.bloop +**/.metals +gen/* +!gen/bootstrap.c +metals.* \ No newline at end of file diff --git a/proj3/src/main/scala/project3/Compiler.scala b/proj3/src/main/scala/project3/Compiler.scala index 4ae709b..d497e13 100644 --- a/proj3/src/main/scala/project3/Compiler.scala +++ b/proj3/src/main/scala/project3/Compiler.scala @@ -52,7 +52,9 @@ abstract class X86Compiler extends BugReporter with Codegen { */ val primitives = Map( "putchar" -> Func("putchar"), - "getchar" -> Func("getchar")) + "getchar" -> Func("getchar"), + "toInt" -> Func("toInt"), + "toChar" -> Func("toChar")) private class Env { def undef(name: String) = BUG(s"Undefined identifier $name (should have been found during the semantic analysis)") @@ -124,6 +126,28 @@ abstract class X86Compiler extends BugReporter with Codegen { emitln(s"cmp ${loc(sp1)}, ${loc(sp)}") emitln(s"sete %al") emitln(s"movzbq %al, ${loc(sp)}") + case "!=" => + emitln(s"cmp ${loc(sp1)}, ${loc(sp)}") + emitln(s"setne %al") + emitln(s"movzbq %al, ${loc(sp)}") + case ">=" => + emitln(s"cmp ${loc(sp1)}, ${loc(sp)}") + emitln(s"setge %al") + emitln(s"movzbq %al, ${loc(sp)}") + case ">" => + emitln(s"cmp ${loc(sp1)}, ${loc(sp)}") + emitln(s"setg %al") + emitln(s"movzbq %al, ${loc(sp)}") + case "<=" => + emitln(s"cmp ${loc(sp1)}, ${loc(sp)}") + emitln(s"setle %al") + emitln(s"movzbq %al, ${loc(sp)}") + case "<" => + emitln(s"cmp ${loc(sp1)}, ${loc(sp)}") + emitln(s"setl %al") + emitln(s"movzbq %al, ${loc(sp)}") + case "block-get" => + emitln(s"movq (${loc(sp)}, ${loc(sp1)}, 8), ${loc(sp)}") case _ => BUG(s"Binary operator $op undefined") } @@ -136,6 +160,9 @@ abstract class X86Compiler extends BugReporter with Codegen { * Valid operators: block-set */ def transTer(op: String)(sp: Loc, sp1: Loc, sp2: Loc) = op match { + case "block-set" => + emitln(s"movq ${loc(sp2)}, (${loc(sp)}, ${loc(sp1)}, 8)") + emitln(s"movq ${loc(sp2)}, ${loc(sp)}") // returns rhs case _ => BUG(s"ternary operator $op undefined") } @@ -173,7 +200,9 @@ abstract class X86Compiler extends BugReporter with Codegen { * if the location 'sp' contains the value 'true' */ def transJumpIfTrue(sp: Loc)(label: Label) = { - ??? + emitln(s"movq $$1, ${loc(sp + 1)}") + emitln(s"cmp ${loc(sp)}, ${loc(sp + 1)}") + emitln(s"je $label") } /* @@ -188,7 +217,13 @@ abstract class X86Compiler extends BugReporter with Codegen { def trans(exp: Exp, sp: Loc)(env: LocationEnv): Unit = exp match { case Lit(x: Int) => emitln(s"movq $$$x, ${loc(sp)}") - case Lit(b: Boolean) => () // TODO + case Lit(x: Char) => + emitln(s"movq $$${x.toInt}, ${loc(sp)}") + case Lit(b: Boolean) => // TODO + if(b) + emitln(s"movq $$1, ${loc(sp)}") + else + emitln(s"movq $$0, ${loc(sp)}") case Lit(x: Unit) => () // TODO case Prim(op, args) => val idxs = List.tabulate(args.length)(i => sp + i) @@ -205,7 +240,7 @@ abstract class X86Compiler extends BugReporter with Codegen { case Ref(x) => env(x) match { case Reg(sp1) => emitln(s"movq ${loc(sp1)}, ${loc(sp)}") - case Func(name) => ??? // Extra credit + case Func(name) => emitln(s"leaq ${funcName(name)}(%rip), ${loc(sp)}") // Extra credit } case If(cond, tBranch, eBranch) => val lab = freshLabel("if") @@ -233,12 +268,24 @@ abstract class X86Compiler extends BugReporter with Codegen { transJumpIfTrue(sp)(s"${lab}_body") trans(body, sp)(env) case LetRec(funs, body) => + emitln(s"${funcName("toInt")}:", 0) + emitln(s"movzbq %dil, %rax") + emitln(s"ret\n") + + emitln(s"${funcName("toChar")}:", 0) + emitln(s"movzbq %dil, %rax") + emitln("ret\n") + emitln("################# FUNCTIONS #####################", 0) // We do not save the location of the function into register because we can use their // name as a label. val funsLoc = funs map { case FunDef(name, _, _, _) => (name, Func(name)) } // TODO complete the code + funs map { + case func@FunDef(_, _, _, _) => + trans(func, Reg(0))(env.withVals(funsLoc)) + } emitln("#################################################\n\n", 0) emitln("###################### MAIN #####################", 0) @@ -257,7 +304,8 @@ abstract class X86Compiler extends BugReporter with Codegen { // emit the main function (body of LetRec) here // TODO you may need to change that code. - trans(body, Reg(0))(LocationEnv()) + trans(body, Reg(0))(LocationEnv().withVals(funsLoc)) + emitln(s"movq ${loc(0)}, %rax") //////////// DO NOT CHANGE//////////////// @@ -280,6 +328,10 @@ abstract class X86Compiler extends BugReporter with Codegen { ////////////////////////////////////////// // TODO + val ids = (0 until args.length).toList + val arg_regs = (args zip ids) map {case (Arg(fname, _, _), idx) => (fname, sp + idx)} + trans(fbody, sp + args.length)(env.withVals(arg_regs)) + emitln(s"movq ${loc(sp + args.length)}, %rax") //////////// DO NOT CHANGE//////////////// emitln("movq %rbp, %rsp\t# reset frame") @@ -292,27 +344,51 @@ abstract class X86Compiler extends BugReporter with Codegen { // code and work on multiple arguments. // Evaluate the arguments // TODO - + val ids = (0 until args.length).toList + (args zip ids) map {case (arg, idx) => trans(arg, sp + idx)(env)} // Compute the physical location of the function to be called val fLoc: String = fun match { case Ref(fname) => env(fname) match { - case Reg(sp) => ??? // Extra credit - case Func(name) => "" // TODO + case Reg(sp) => s"*${loc(sp)}" // Extra credit + case Func(name) => funcName(name) // TODO } - case _ => ??? // Extra credit + case _ => // return from another invocation + trans(fun, sp)(env); + s"*${loc(sp)}" } // Implement the calling conventions after that point // and generate the function call // TODO - () + val spID : Int = sp match { + case Reg(i) => i + } + + for(i <- 0 to spID){ + emitln(s"pushq ${loc(i)}"); + } + + for(i <- 0 to spID) { + emitln(s"movq ${loc(sp + i)}, ${loc(i)}") + } + emitln(s"call ${fLoc}") + + for(i <- spID to 0 by -1) { + emitln(s"popq ${loc(i)}"); + } + emitln(s"movq %rax, ${loc(sp)}") + case ArrayDec(size, _) => // This node needs to allocate an area of eval(size) * 8 bytes in the heap // the assembly variable "heap" contains a pointer to the first valid byte // in the heap. Make sure to update its value accordingly. // TODO - () + emitln(s"movq (%rip), ${loc(sp)}") + trans(size, sp + 1)(env) + emitln(s"movq (%rip), ${loc(sp + 2)}") + emitln(s"movq (${loc(sp + 2)}, ${loc(sp + 1)}, 8), ${loc(sp + 2)}") + emitln(s"movq ${loc(sp + 2)}, %rip") case _ => BUG(s"don't know how to implement $exp") } } diff --git a/proj3/src/main/scala/project3/Interpreter.scala b/proj3/src/main/scala/project3/Interpreter.scala index 62a948a..9fad5fe 100644 --- a/proj3/src/main/scala/project3/Interpreter.scala +++ b/proj3/src/main/scala/project3/Interpreter.scala @@ -19,7 +19,9 @@ class ValueInterpreter extends Interpreter with BugReporter { */ val primitives = Map[String, BoxedVal]( "putchar" -> BoxedVal(Primitive("putchar")), - "getchar" -> BoxedVal(Primitive("getchar")) + "getchar" -> BoxedVal(Primitive("getchar")), + "toInt" -> BoxedVal(Primitive("toInt")), + "toChar" -> BoxedVal(Primitive("toChar")) ) /* @@ -204,6 +206,12 @@ class ValueInterpreter extends Interpreter with BugReporter { Console.out.write(c) Console.out.flush Cst(()) + case Primitive("toInt") => + val List(Cst(c: Char)) = eargs + Cst(c.toInt) + case Primitive("toChar") => + val List(Cst(c: Int)) = eargs + Cst(c.toChar) } case ArrayDec(size, _) => val Cst(s: Int) = eval(size)(env) @@ -225,7 +233,9 @@ object StackVal extends BugReporter { */ val primitives = Map[String, Loc]( "putchar" -> 0, - "getchar" -> 1 + "getchar" -> 1, + "toInt" -> 2, + "toChar" -> 3 ) /** @@ -303,6 +313,8 @@ class StackInterpreter extends Interpreter with BugReporter { val memory = new Array[Val](1000) memory(0) = Primitive("putchar") memory(1) = Primitive("getchar") + memory(2) = Primitive("toInt") + memory(3) = Primitive("toChar") var flag: Boolean = true /* @@ -332,12 +344,15 @@ class StackInterpreter extends Interpreter with BugReporter { case (">=", Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x >= y) case ("<" , Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x < y) case (">" , Cst(x: Int), Cst(y: Int)) => memory(sp) = Cst(x > y) - case ("block-get", Cst(arr: Array[Any]), Cst(i: Int)) => ??? + case ("block-get", Cst(arr: Array[Any]), Cst(i: Int)) => + if (arr(i) == null) + BUG(s"uninitialized memory") + memory(sp) = Cst(arr(i)) case _ => BUG(s"Binary operator $op undefined") } def evalTer(op: String)(sp: Loc, sp1: Loc, sp2: Loc) = (op, memory(sp), memory(sp1), memory(sp2)) match { - case ("block-set", Cst(arr: Array[Any]), Cst(i: Int), Cst(x)) => ??? + case ("block-set", Cst(arr: Array[Any]), Cst(i: Int), Cst(x)) => memory(sp) = Cst(arr(i) = x) case _ => BUG(s"ternary operator $op undefined") } @@ -353,9 +368,9 @@ class StackInterpreter extends Interpreter with BugReporter { * an empty environment and return the value. */ def run(exp: Exp): Val = { - // Start at 2, putchar and getchar are store at 0 and 1!! - eval(exp, 2)(LocationEnv()) - memory(2) + // Start at 4, putchar, getchar, toChar, toInt are store at 0 and 1, 2, 3!! + eval(exp, 4)(LocationEnv()) + memory(4) } /* @@ -405,11 +420,47 @@ class StackInterpreter extends Interpreter with BugReporter { eval(cond, sp)(env) } eval(body, sp)(env) - case FunDef(_, args, _, fbody) => ??? + case FunDef(_, args, _, fbody) => + memory(sp) = Func(args map { arg => (arg.name) }, fbody, env) case LetRec(funs, body) => // TODO modify that code - eval(body, sp)(env) - case App(fun, args) => ??? - case ArrayDec(size, _) => ??? + val ids = (sp until sp + funs.length).toList + // Evaluate all functions + val funcs = (funs zip ids) map { case (fun@FunDef(name, _, _, _), idx) => (name, idx) } + + // Add all functions to the functions environment (recursion) + (funs zip ids) foreach { case (fun@FunDef(_, _, _, _), idx:Int) => eval(fun, idx)(env.withVals(funcs)) } + + eval(body, sp + funs.length)(env.withVals(funcs)) + memory(sp) = memory(sp + funs.length) + + case App(fun, args) => + val ids = (sp until sp + args.length).toList + + (args zip ids) map { case (arg, idx) => eval(arg, idx)(env) } + + eval(fun, sp + args.length)(env) + memory(sp + args.length) match { + case Func(fargs, fbody, fenv) => + eval(fbody, sp + args.length + 1)(fenv.withVals(fargs zip ids)) + memory(sp) = memory(sp + args.length + 1) + case Primitive("getchar") => + memory(sp) = Cst(Console.in.read) + case Primitive("putchar") => + val List(Cst(c: Int)) = memory(sp) + Console.out.write(c) + Console.out.flush + Cst(()) + case Primitive("toInt") => + val List(Cst(c: Char)) = memory(sp) + memory(sp) = Cst(c.toInt) + case Primitive("toChar") => + val List(Cst(c: Int)) = memory(sp) + memory(sp) = Cst(c.toChar) + } + case ArrayDec(size, _) => + eval(size, sp)(env) + val Cst(idx: Int) = memory(sp) + memory(sp) = Cst(new Array[Any](idx)) } } diff --git a/proj3/src/main/scala/project3/Main.scala b/proj3/src/main/scala/project3/Main.scala index 4a40ae6..ada15f3 100644 --- a/proj3/src/main/scala/project3/Main.scala +++ b/proj3/src/main/scala/project3/Main.scala @@ -41,7 +41,6 @@ OPTION: intStack""") } else { args(0) } - println("============ SRC CODE ============") println(src) println("==================================\n") @@ -51,7 +50,7 @@ OPTION: intStack""") // Parser to test! // TODO: Change this as you finish parsers - val parser = new BaseParser(scanner) + val parser = new ArrayParser(scanner) val ast = try { parser.parseCode } catch { diff --git a/proj3/src/main/scala/project3/Parser.scala b/proj3/src/main/scala/project3/Parser.scala index 417eb2d..d36c614 100644 --- a/proj3/src/main/scala/project3/Parser.scala +++ b/proj3/src/main/scala/project3/Parser.scala @@ -99,11 +99,11 @@ class Scanner(in: Reader[Char]) extends Reader[Tokens.Token] with Reporter { // List of delimiters // TODO: Update this as delimiters are added to our language - val isDelim = Set('(',')','=',';','{','}',':') + val isDelim = Set('(',')','=',';','{','}',':', ',', '[', ']') // List of keywords // TODO: Update this as keywords are added to our language - val isKeyword = Set("if", "else", "val", "var", "while") + val isKeyword = Set("if", "else", "val", "var", "while", "def", "=>", "new") val isBoolean = Set("true", "false") @@ -119,7 +119,7 @@ class Scanner(in: Reader[Char]) extends Reader[Tokens.Token] with Reporter { } val s = buf.toString if (isKeyword(s)) Keyword(s) - else if (isBoolean(s)) Boolean(s) + else if (isBoolean(s)) Literal(s.toBoolean) else Ident(s) } @@ -169,6 +169,12 @@ class Scanner(in: Reader[Char]) extends Reader[Tokens.Token] with Reporter { getOperator() } else if (in.hasNext(isDigit)) { getNum() + } else if (in.hasNext(Set('\''))) { + in.next() + val s = in.next() + if (in.hasNext(_ == '\'')) in.next() + else expected(s"\'") + Literal(s) } else if (in.hasNext(isDelim)) { Delim(in.next()) } else if (!in.hasNext) { @@ -372,6 +378,7 @@ object Language { case class ArrayType(tp: Type) extends Type val IntType = BaseType("Int") + val CharType = BaseType("Char") val UnitType = BaseType("Unit") val BooleanType = BaseType("Boolean") @@ -448,6 +455,18 @@ class BaseParser(in: Scanner) extends Parser(in) { * TODO: Implement this function */ def parseType: Type = in.peek match { + case Ident("Int") => + in.next() + IntType + case Ident("Char") => + in.next() + CharType + case Ident("Boolean") => + in.next() + BooleanType + case Ident("Unit") => + in.next() + UnitType case _ => expected("type") } @@ -460,6 +479,9 @@ class BaseParser(in: Scanner) extends Parser(in) { * TODO: Implement this function */ def parseOptionalType: Type = in.peek match { + case Delim(':') => + in.next() + parseType case _ => UnknownType } @@ -626,14 +648,37 @@ class SyntacticSugarParser(in: Scanner) extends BaseParser(in) { suf + "$" + next } - override def parseSimpleExpression = in.peek match { + override def parseSimpleExpression = in.peek match { // !!! + case Keyword("if") => + val pos = in.next().pos + accept('(') + val cond = parseSimpleExpression + accept(')') + val tBranch = parseSimpleExpression + val res = if (in.peek == Keyword("else")) { + in.next() + If(cond, tBranch, parseSimpleExpression).withPos(pos) + } else { + If(cond, tBranch, Lit(Unit)).withPos(pos) + } + res case _ => super.parseSimpleExpression } override def parseExpression = { // NOTE: parse expression terminates when it parse a simples expression. - // syntax sugar allows to have an other expression after it. - val res = super.parseExpression + // syntax sugar allows to have an other expression after it. !!! + var res = in.peek match { + case Keyword("val") | Keyword("var") | Keyword("while") => + super.parseExpression + case _ => + var simp = parseSimpleExpression + if(isNewLine(in.peek)) { + accept(';') + simp = Let(freshName(), UnknownType, simp, parseExpression).withPos(simp.pos) + } + simp + } res } @@ -719,10 +764,30 @@ class FunctionParser(in: Scanner) extends SyntacticSugarParser(in) { /* * This function parse types. * - * TODO + * TODO !!! */ override def parseType = in.peek match { - case _ => super.parseType + case Delim('(') => + accept('(') + val list = parseList[Type](parseType, ',', tok => tok match { + case Delim(')') => false; + case _ => true + }) + val argList = list map { tp => ("", tp)} + accept("=>") + val rtp = parseType + FunType(argList, rtp) + case _ => + val typ = super.parseType + if(in.peek == Delim('=') && in.peek1 == Ident(">")) { + accept("=>") + val rtp = super.parseType + val typList: List[(String, Type)] = List(("", typ)) + FunType(typList, rtp) + } + else { + typ + } } /* @@ -742,7 +807,9 @@ class FunctionParser(in: Scanner) extends SyntacticSugarParser(in) { * TODO: complete the function */ def parseArg: Arg = { - ??? + val (arg, pos) = getName() + val ty = parseOptionalType + Arg(arg, ty, pos) } /* @@ -752,7 +819,16 @@ class FunctionParser(in: Scanner) extends SyntacticSugarParser(in) { * TODO: complete the function */ def parseFunction: Exp = { - ??? + accept("def") + val (fname, pos) = getName() + accept('(') + val args = parseList[Arg](parseArg, ',', _ match { + case Ident(_) => true; case _ => false + }) + accept(')') + val retTy = parseOptionalType + accept('=') + FunDef(fname, args, retTy, parseSimpleExpression).withPos(pos) } /* @@ -765,6 +841,12 @@ class FunctionParser(in: Scanner) extends SyntacticSugarParser(in) { * TODO: complete the function */ def parseProgram = in.peek match { + case Keyword("def") => + val fns = parseList[Exp](parseFunction, ';', _ match { + case Keyword("def") => true; case _ => false + }) + accept(';') + LetRec(fns, parseExpression) case _ => LetRec(Nil, parseExpression) } @@ -798,7 +880,15 @@ class FunctionParser(in: Scanner) extends SyntacticSugarParser(in) { res case _ => var res = parseAtom - // TODO: complete + while(in.peek == Delim('(')) { + val pos = in.next().pos + val params = parseList[Exp](parseSimpleExpression, ',', _ match { + case Delim(')') => false; case _ => true + }) + accept(')') + res = App(res, params).withPos(pos) + } + // Done: completed res } @@ -865,7 +955,12 @@ class ArrayParser(in: Scanner) extends FunctionParser(in) { import Tokens._ override def parseType = in.peek match { - case Ident("Array") => ??? + case Ident("Array") => + in.next() + accept('[') + val ty = parseType + accept(']') + ArrayType(ty) case _ => super.parseType } @@ -874,12 +969,37 @@ class ArrayParser(in: Scanner) extends FunctionParser(in) { * * TODO */ - override def parseTight = ??? + override def parseTight : Exp = { + val lhs = super.parseTight + if (in.peek == Delim('=')) { + accept('=') + lhs match { + case App(arr, List(idx)) => + Prim("block-set", List(arr, idx, parseSimpleExpression)).withPos(lhs.pos) + case _ => + error("Unsupported assignment") + lhs + } + } + else + lhs + } /* * Parse array declaration * * TODO */ - override def parseSimpleExpression = ??? + override def parseSimpleExpression = in.peek match { + case Keyword("new") => + val pos = in.next().pos + val ty = parseType + accept('(') + val expr = parseSimpleExpression + accept(')') + ArrayDec(expr, ty).withPos(pos) + case _ => + super.parseSimpleExpression + + } } diff --git a/proj3/src/main/scala/project3/SemanticAnalyzer.scala b/proj3/src/main/scala/project3/SemanticAnalyzer.scala index 3e4901f..ae70955 100644 --- a/proj3/src/main/scala/project3/SemanticAnalyzer.scala +++ b/proj3/src/main/scala/project3/SemanticAnalyzer.scala @@ -7,6 +7,8 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { * Primitive functions that do not need to be defined or declared. */ val primitives = Map[String,(Boolean,Type)]( + "toInt" -> (false, FunType(List(("", CharType)), IntType)), + "toChar" -> (false, FunType(List(("", IntType)), CharType)), "getchar" -> (false, FunType(List(), IntType)), "putchar" -> (false, FunType(List(("", IntType)), UnitType)) ) @@ -163,7 +165,10 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { * TODO: implement the remaining binary operators for typeBinOperator */ def typeBinOperator(op: String)(pos: Position) = op match { - case "+" => FunType(List(("", IntType), ("", IntType)), IntType) + case "+" | "-" | "*" | "/" => + FunType(List(("", IntType), ("", IntType)), IntType) + case ">=" | "<=" | ">" | "<" | "==" | "!=" => + FunType(List(("", IntType), ("", IntType)), BooleanType) case _ => error("undefined binary operator", pos) UnknownType @@ -177,6 +182,8 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { * TODO: implement typeUnOperator */ def typeUnOperator(op: String)(pos: Position) = op match { + case "+" | "-" => FunType(List(("", IntType)), IntType) +// case "!" => FunType(List(("", BooleanType)), BooleanType) case _ => error(s"undefined unary operator", pos) UnknownType @@ -188,6 +195,8 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { * operators: block-set */ def typeTerOperator(op: String)(pos: Position) = op match { + case "block-set" => + FunType(List(("", ArrayType(IntType)), ("", IntType)), IntType) case _ => error(s"undefined ternary operator", pos) UnknownType @@ -215,7 +224,8 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { case (_, UnknownType) => typeWellFormed(tp)(env, pos) // tp <: Any case (UnknownType, _) => typeWellFormed(pt)(env, pos) // for function arguments case (FunType(args1, rtp1), FunType(args2, rtp2)) if args1.length == args2.length => - ??? // TODO: Function type conformity + FunType(typeConform(args1, args2)(env, pos), typeConforms(rtp1, rtp2)(env, pos)) + // Done: Function type conformity case (ArrayType(tp), ArrayType(pt)) => ArrayType(typeConforms(tp, pt)(env, pos)) case _ => error(s"type mismatch;\nfound : $tp\nexpected: $pt", pos); pt } @@ -281,12 +291,22 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { def typeInfer(exp: Exp, pt: Type)(env: TypeEnv): Exp = exp match { case Lit(_: Int) => exp.withType(IntType) - case Lit(_: Boolean) => ??? - case Lit(_: Unit) => ??? - case Prim("block-set", args) => ??? + case Lit(_: Boolean) => exp.withType(BooleanType) + case Lit(_: Char) => exp.withType(CharType) + case Lit(_: Unit) => exp.withType(UnitType) + case Prim("block-set", args) => + Prim("block-set", List( + typeCheck(args(0), ArrayType(pt))(env), + typeCheck(args(1), IntType)(env), + typeCheck(args(2), pt)(env) + )).withType(UnitType) case Prim(op, args) => typeOperator(op, args.length)(exp.pos) match { - case FunType(atps, rtp) => ??? + case FunType(atps, rtp) => + val args_checked = (args zip atps) map { + case (arg, (_, param)) => typeCheck(arg, param)(env) + } + Prim(op, args_checked).withType(rtp) case UnknownType => exp.withType(UnknownType) case _ => BUG("operator's type needs to be FunType") } @@ -298,18 +318,28 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { Let(x, nrhs.tp, nrhs, nbody).withType(nbody.tp) case Ref(x) => env(x) match { - case Some(tp) => ??? // Remember to check that the type taken from the environment is welformed + case Some(tp) => + exp.withType(typeWellFormed(tp)(env, exp.pos)) + // Remember to check that the type taken from the environment is wellformed case _ => error("undefined identifier", exp.pos) - ??? + exp.withType(UnknownType) } case If(cond, tBranch, eBranch) => // Hint: type check the else branch before the then branch. - ??? + val condChecked = typeCheck(cond, BooleanType)(env) + val elseChecked = typeCheck(eBranch, pt)(env) + val thenChecked = typeCheck(tBranch, elseChecked.tp)(env) + If(condChecked, thenChecked, elseChecked).withType(elseChecked.tp) + case VarDec(x, tp, rhs, body) => if (env.isDefined(x)) warn("reuse of variable name", exp.pos) - ??? + val rhsChecked = typeCheck(rhs, tp)(env) + val _env = env.withVar(x, rhsChecked.tp) + val bodyChecked = typeCheck(body, pt)(_env) + VarDec(x, rhsChecked.tp, rhsChecked, bodyChecked).withType(bodyChecked.tp) + case VarAssign(x, rhs) => val xtp = if (!env.isDefined(x)) { error("undefined identifier", exp.pos) @@ -320,7 +350,7 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { env(x).get } - ??? + val rhsChecked = typeCheck(rhs, xtp)(env) /* Because of syntactic sugar, a variable assignment * statement can be accepted as an expression @@ -338,18 +368,35 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { * Without changing the semantics! */ pt match { - case UnitType => ??? - case _ => ??? + case UnitType | IntType | BooleanType | CharType => VarAssign(x, rhsChecked).withType(rhsChecked.tp) + case _ => exp.withType(UnknownType) + } + case While(cond, lbody, body) => + val condChecked = typeCheck(cond, BooleanType)(env) + val lbodyChecked = typeCheck(lbody, UnitType)(env) + val bodyChecked = typeCheck(body, pt)(env) + While(condChecked, lbodyChecked, bodyChecked).withType(bodyChecked.tp) + + case FunDef(fname, args, rtp, fbody) => + checkDuplicateNames(args) + val argsWType = args map { + a => (a.name, typeWellFormed(a.tp)(env, a.pos)) } - case While(cond, lbody, body) => ??? - case FunDef(fname, args, rtp, fbody) => ??? + val bodyChecked = typeCheck(fbody, rtp)(env.withVals(argsWType)) // check rtp w/ pt? + FunDef(fname, args, bodyChecked.tp, bodyChecked).withType( + FunType(argsWType, bodyChecked.tp) + ) + case LetRec(funs, body) => // TODO modify to handle general case - val nbody = typeCheck(body, pt)(env) - LetRec(Nil, nbody).withType(nbody.tp) + val _env = env.withVals(funs.map {case FunDef(name, _, rtp, _) => (name, rtp)}) + val funsChecked = funs map {case f@FunDef(_, _, rtp, _) => typeCheck(f, rtp)(_env) } + val bodyChecked = typeCheck(body, pt)( env.withVals(funsChecked map {case FunDef(name, _, rtp, _) => (name, rtp)})) + LetRec(funsChecked, bodyChecked).withType(bodyChecked.tp) + case App(fun, args) => // TODO Check fun type - val nFun: Exp = ??? + val nFun: Exp = typeCheck(fun, fun.tp)(env) // Handling some errors val ftp = nFun.tp match { @@ -369,7 +416,9 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { } // TODO: Check arguments type - val nargs: List[Exp] = ??? + val nargs: List[Exp] = (args zip ftp.args) map { + case (arg, (_, param)) => typeCheck(arg, param)(env) + } // Transform some function applications into primitives on arrays. nFun.tp match { @@ -380,7 +429,7 @@ class SemanticAnalyzer(parser: Parser) extends Reporter with BugReporter { case ArrayDec(size: Exp, etp: Type) => // TODO: Check array declaration // Note that etp is the type of elements - ??? + ArrayDec(typeCheck(size, IntType)(env), etp).withType(etp) case _ => BUG(s"malformed expresstion $exp") } }