diff --git a/src/ast.ts b/src/ast.ts index 805c323..750f7e0 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -15,6 +15,7 @@ export type ItemKind = { export type Item = ItemKind & { span: Span; + id: number; }; export type FunctionDef = { @@ -94,6 +95,7 @@ export type ExprKind = export type Expr = ExprKind & { span: Span; + ty?: Ty; }; export type Literal = @@ -128,6 +130,7 @@ export const COMPARISON_KINDS: BinaryKind[] = [ ">=", "!=", ]; +export const EQUALITY_KINDS: BinaryKind[] = ["==", "!="]; export const LOGICAL_KINDS: BinaryKind[] = ["&", "|"]; export const ARITH_TERM_KINDS: BinaryKind[] = ["+", "-"]; export const ARITH_FACTOR_KINDS: BinaryKind[] = ["*", "/"]; @@ -174,6 +177,7 @@ export type TypeKind = export type Type = TypeKind & { span: Span; + ty?: Ty; }; export type Resolution = @@ -202,8 +206,53 @@ export type Resolution = } | { kind: "builtin"; + name: string; }; +export type TyString = { + kind: "string"; +}; + +export type TyInt = { + kind: "int"; +}; + +export type TyBool = { + kind: "bool"; +}; + +export type TyList = { + kind: "list"; + elem: Ty; +}; + +export type TyTuple = { + kind: "tuple"; + elems: Ty[]; +}; + +export type TyUnit = { + kind: "tuple"; + elems: []; +}; + +export type TyFn = { + kind: "fn"; + params: Ty[]; + returnTy: Ty; +}; + +export type TyVar = { + kind: "var"; + index: number; +}; + +export type Ty = TyString | TyInt | TyBool | TyList | TyTuple | TyFn | TyVar; + +export function tyIsUnit(ty: Ty): ty is TyUnit { + return ty.kind === "tuple" && ty.elems.length === 0; +} + // folders export type FoldFn = (value: T) => T; @@ -252,6 +301,7 @@ export function super_fold_item(item: Item, folder: Folder): Item { body: folder.expr(item.node.body), returnType: item.node.returnType && folder.type(item.node.returnType), }, + id: item.id, }; } } diff --git a/src/index.ts b/src/index.ts index 9698fe6..10d9796 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,19 +3,15 @@ import { tokenize } from "./lexer"; import { parse } from "./parser"; import { printAst } from "./printer"; import { resolve } from "./resolve"; +import { typeck } from "./typeck"; const input = ` -function main(argv: [String]): () = ( - print(argv); - if 1 then ( - print("AAAAAAAAAAAAAAAAAAAA"); - let a = 0 in - a; - ) else ( - print("AAAAAAAAAAAAAAAAAAAAAA"); - let b = 0 in - b; - ) +function main() = ( + let a = 0 in + let b = a in + let c = b in + let d = c in + d; ); `; @@ -38,6 +34,11 @@ function main() { console.log("-----AST resolved------"); const resolvedPrinted = printAst(resolved); console.log(resolvedPrinted); + + console.log("-----AST typecked------"); + + const typecked = typeck(resolved); + console.dir(typecked, { depth: 10 }); }); } diff --git a/src/parser.ts b/src/parser.ts index f5cfdad..40e4c6b 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -27,7 +27,9 @@ export function parse(t: Token[]): Ast { items.push(item); } - return items; + const withIds = items.map((item, i) => ({ ...item, id: i })); + + return withIds; } function parseItem(t: Token[]): [Token[], Item] { @@ -79,7 +81,16 @@ function parseItem(t: Token[]): [Token[], Item] { body, }; - return [t, { kind: "function", node: def, span: tok.span }]; + return [ + t, + { + kind: "function", + node: def, + span: tok.span, + // Assigned later. + id: 0, + }, + ]; } else { unexpectedToken(tok); } @@ -114,7 +125,7 @@ function parseExpr(t: Token[]): [Token[], Expr] { if (peak.kind === "let") { [t] = expectNext(t, "let"); let name; - [t, name] = expectNext(t, "identifier"); + [t, name] = expectNext(t, "identifier"); let type = undefined; let colon; @@ -132,7 +143,7 @@ function parseExpr(t: Token[]): [Token[], Expr] { return [ t, - { kind: "let", name: name.ident, type, rhs, after, span: t[0].span }, + { kind: "let", name: name.ident, type, rhs, after, span: name.span }, ]; } diff --git a/src/printer.ts b/src/printer.ts index ec330dd..0bb8860 100644 --- a/src/printer.ts +++ b/src/printer.ts @@ -5,7 +5,9 @@ import { Identifier, Item, Resolution, + Ty, Type, + tyIsUnit, } from "./ast"; export function printAst(ast: Ast): string { @@ -139,6 +141,33 @@ function printIdent(ident: Identifier): string { return `${ident.name}${res}`; } +export function printTy(ty: Ty): string { + switch (ty.kind) { + case "string": { + return "String"; + } + case "int": { + return "Int"; + } + case "bool": { + return "Bool"; + } + case "list": { + return `[${printTy(ty.elem)}]`; + } + case "tuple": { + return `(${ty.elems.map(printTy).join(", ")})`; + } + case "fn": { + const ret = tyIsUnit(ty.returnTy) ? "" : `: ${printTy(ty.returnTy)}`; + return `fn(${ty.params.map(printTy).join(", ")})${ret}`; + } + case "var": { + return `?${ty.index}`; + } + } +} + function linebreak(indent: number): string { return `\n${ind(indent)}`; } diff --git a/src/resolve.ts b/src/resolve.ts index a351a62..4a4ec15 100644 --- a/src/resolve.ts +++ b/src/resolve.ts @@ -10,7 +10,14 @@ import { } from "./ast"; import { CompilerError } from "./error"; -const BUILTINS = new Set(["print", "String"]); +const BUILTINS = new Set([ + "print", + "String", + "Int", + "Bool", + "true", + "false", +]); export function resolve(ast: Ast): Ast { const items = new Map(); @@ -60,9 +67,9 @@ export function resolve(ast: Ast): Ast { } if (BUILTINS.has(ident.name)) { - return { kind: "builtin" }; + return { kind: "builtin", name: ident.name }; } - + throw new CompilerError(`cannot find ${ident.name}`, ident.span); }; @@ -92,6 +99,7 @@ export function resolve(ast: Ast): Ast { returnType, body, }, + id: item.id, }; } } diff --git a/src/typeck.ts b/src/typeck.ts new file mode 100644 index 0000000..bce1afd --- /dev/null +++ b/src/typeck.ts @@ -0,0 +1,553 @@ +import { check } from "prettier"; +import { + Ast, + COMPARISON_KINDS, + DEFAULT_FOLDER, + EQUALITY_KINDS, + Expr, + ExprBinary, + ExprCall, + ExprUnary, + Folder, + Identifier, + LOGICAL_KINDS, + Resolution, + Ty, + TyFn, + TyVar, + Type, + binaryExprPrecedenceClass, + fold_ast, + super_fold_expr, +} from "./ast"; +import { CompilerError, Span } from "./error"; +import { printTy } from "./printer"; + +const TY_UNIT: Ty = { kind: "tuple", elems: [] }; +const TY_STRING: Ty = { kind: "string" }; +const TY_BOOL: Ty = { kind: "bool" }; +const TY_INT: Ty = { kind: "int" }; + +function builtinAsTy(name: string, span: Span): Ty { + switch (name) { + case "String": { + return TY_STRING; + } + case "Int": { + return TY_INT; + } + case "Bool": { + return TY_BOOL; + } + default: { + throw new CompilerError(`\`${name}\` is not a type`, span); + } + } +} + +function typeOfBuiltinValue(name: string, span: Span): Ty { + switch (name) { + case "false": + case "true": + return TY_BOOL; + case "print": + return { kind: "fn", params: [TY_STRING], returnTy: TY_UNIT }; + default: { + throw new CompilerError(`\`${name}\` cannot be used as a value`, span); + } + } +} + +function lowerAstTyBase( + type: Type, + lowerIdentTy: (ident: Identifier) => Ty, + typeOfItem: (index: number) => Ty +): Ty { + switch (type.kind) { + case "ident": { + const res = type.value.res!; + switch (res.kind) { + case "local": { + throw new Error("Cannot resolve local here"); + } + case "item": { + return typeOfItem(res.index); + } + case "builtin": { + return builtinAsTy(res.name, type.value.span); + } + } + } + case "list": { + return { + kind: "list", + elem: lowerAstTyBase(type.elem, lowerIdentTy, typeOfItem), + }; + } + case "tuple": { + return { + kind: "tuple", + elems: type.elems.map((type) => + lowerAstTyBase(type, lowerIdentTy, typeOfItem) + ), + }; + } + } +} + +export function typeck(ast: Ast): Ast { + const itemTys = new Map(); + function typeOfItem(index: number): Ty { + const ty = itemTys.get(index); + if (ty) { + return ty; + } + if (ty === null) { + throw Error(`cycle computing type of #G${index}`); + } + itemTys.set(index, null); + const item = ast[index]; + switch (item.kind) { + case "function": { + const args = item.node.args.map((arg) => lowerAstTy(arg.type)); + const returnTy: Ty = item.node.returnType + ? lowerAstTy(item.node.returnType) + : TY_UNIT; + + return { kind: "fn", params: args, returnTy }; + } + } + } + + function lowerAstTy(type: Type): Ty { + return lowerAstTyBase( + type, + (ident) => { + const res = ident.res!; + switch (res.kind) { + case "local": { + throw new Error("Cannot resolve local here"); + } + case "item": { + return typeOfItem(res.index); + } + case "builtin": { + return builtinAsTy(res.name, ident.span); + } + } + }, + typeOfItem + ); + } + + const checker: Folder = { + ...DEFAULT_FOLDER, + item(item) { + switch (item.kind) { + case "function": { + const fnTy = typeOfItem(item.id) as TyFn; + const body = checkBody(item.node.body, fnTy, typeOfItem); + + const returnType = item.node.returnType && { + ...item.node.returnType, + ty: fnTy.returnTy, + }; + + return { + kind: "function", + node: { + name: item.node.name, + args: item.node.args.map((arg, i) => ({ + ...arg, + type: { ...arg.type, ty: fnTy.params[i] }, + })), + body, + returnType, + }, + span: item.span, + id: item.id, + }; + } + } + }, + }; + + const withTypes = fold_ast(ast, checker); + + return withTypes; +} + +type TyVarRes = + | { + kind: "final"; + ty: Ty; + } + | { + kind: "unified"; + index: number; + } + | { + kind: "unknown"; + }; + +export function checkBody( + body: Expr, + fnTy: TyFn, + typeOfItem: (index: number) => Ty +): Expr { + const localTys = [...fnTy.params]; + const tyVars: TyVarRes[] = []; + + function newVar(): Ty { + const index = tyVars.length; + tyVars.push({ kind: "unknown" }); + return { kind: "var", index }; + } + + function typeOf(res: Resolution, span: Span): Ty { + switch (res.kind) { + case "local": { + const idx = localTys.length - 1 - res.index; + return localTys[idx]; + } + case "item": { + return typeOfItem(res.index); + } + case "builtin": + return typeOfBuiltinValue(res.name, span); + } + } + + function lowerAstTy(type: Type): Ty { + return lowerAstTyBase( + type, + (ident) => { + const res = ident.res!; + return typeOf(res, ident.span); + }, + typeOfItem + ); + } + + function tryResolveVar(variable: number): Ty | undefined { + const varRes = tyVars[variable]; + switch (varRes.kind) { + case "final": { + return varRes.ty; + } + case "unified": { + const ty = tryResolveVar(varRes.index); + if (ty) { + tyVars[variable] = { kind: "final", ty }; + return ty; + } else { + return undefined; + } + } + case "unknown": { + return undefined; + } + } + } + + /** + * Try to constrain a type variable to be of a specific type. + * INVARIANT: Both sides must not be of res "final", use resolveIfPossible + * before calling this. + */ + function constrainVar(variable: number, ty: Ty) { + if (ty.kind === "var") { + // Point the lhs to the rhs. + tyVars[variable] = { kind: "unified", index: ty.index }; + } + + let idx = variable; + let nextVar; + while ((nextVar = tyVars[idx]).kind === "unified") { + idx = nextVar.index; + } + + const root = idx; + tyVars[root] = { kind: "final", ty }; + } + + function resolveIfPossible(ty: Ty): Ty { + if (ty.kind === "var") { + return tryResolveVar(ty.index) ?? ty; + } else { + return ty; + } + } + + function assign(lhs_: Ty, rhs_: Ty, span: Span) { + const lhs = resolveIfPossible(lhs_); + const rhs = resolveIfPossible(rhs_); + + if (lhs.kind === "var") { + constrainVar(lhs.index, rhs); + return; + } + if (rhs.kind === "var") { + constrainVar(rhs.index, lhs); + return; + } + // type variable handling here + + switch (lhs.kind) { + case "string": { + if (rhs.kind === "string") return; + break; + } + case "int": { + if (rhs.kind === "int") return; + break; + } + case "bool": { + if (rhs.kind === "bool") return; + break; + } + case "list": { + if (rhs.kind === "list") { + assign(lhs.elem, rhs.elem, span); + return; + } + break; + } + case "tuple": { + if (rhs.kind === "tuple" && lhs.elems.length === rhs.elems.length) { + lhs.elems.forEach((lhs, i) => assign(lhs, rhs.elems[i], span)); + return; + } + break; + } + case "fn": { + if (rhs.kind === "fn" && lhs.params.length === rhs.params.length) { + // swapping because of contravariance in the future maybe + lhs.params.forEach((lhs, i) => assign(rhs.params[i], lhs, span)); + + assign(lhs.returnTy, rhs.returnTy, span); + + return; + } + break; + } + } + + throw new CompilerError( + `cannot assign ${printTy(rhs)} to ${printTy(lhs)}`, + span + ); + } + + const checker: Folder = { + ...DEFAULT_FOLDER, + expr(expr) { + switch (expr.kind) { + case "empty": { + return { ...expr, ty: TY_UNIT }; + } + case "let": { + const loweredBindingTy = expr.type && lowerAstTy(expr.type); + let bindingTy = loweredBindingTy ? loweredBindingTy : newVar(); + + const rhs = this.expr(expr.rhs); + assign(bindingTy, rhs.ty!, expr.span); + + localTys.push(bindingTy); + const after = this.expr(expr.after); + localTys.pop(); + + const type: Type | undefined = loweredBindingTy && { + ...expr.type!, + ty: loweredBindingTy!, + }; + + return { + kind: "let", + name: expr.name, + type, + rhs, + after, + ty: after.ty!, + span: expr.span, + }; + } + case "block": { + const exprs = expr.exprs.map((expr) => this.expr(expr)); + const ty = exprs.length > 0 ? exprs[exprs.length - 1].ty! : TY_UNIT; + + return { + kind: "block", + exprs, + ty, + span: expr.span, + }; + } + case "literal": { + let ty; + switch (expr.value.kind) { + case "str": { + ty = TY_STRING; + break; + } + case "int": { + ty = TY_INT; + break; + } + } + + return { ...expr, ty }; + } + case "ident": { + const ty = typeOf(expr.value.res!, expr.value.span); + + return { ...expr, ty }; + } + case "binary": { + const lhs = this.expr(expr.lhs); + const rhs = this.expr(expr.rhs); + + lhs.ty = resolveIfPossible(lhs.ty!); + rhs.ty = resolveIfPossible(rhs.ty!); + + return checkBinary({ ...expr, lhs, rhs }); + } + case "unary": { + const rhs = this.expr(expr.rhs); + rhs.ty = resolveIfPossible(rhs.ty!); + return checkUnary({ ...expr, rhs }); + } + case "call": { + const lhs = this.expr(expr.lhs); + lhs.ty = resolveIfPossible(lhs.ty!); + const lhsTy = lhs.ty!; + if (lhsTy.kind !== "fn") { + throw new CompilerError( + `expression of type ${printTy(lhsTy)} is not callable`, + lhs.span + ); + } + + const args = expr.args.map((arg) => this.expr(arg)); + + lhsTy.params.forEach((param, i) => { + if (!args[i]) { + throw new CompilerError( + `missing argument of type ${printTy(param)}`, + expr.span + ); + } + const arg = checker.expr(args[i]); + + assign(param, arg.ty!, args[i].span); + }); + + if (args.length > lhsTy.params.length) { + throw new CompilerError( + `too many arguments passed, expected ${lhsTy.params.length}, found ${args.length}`, + expr.span + ); + } + + return { ...expr, lhs, args, ty: lhsTy.returnTy }; + } + case "if": { + const cond = this.expr(expr.cond); + const then = this.expr(expr.then); + const elsePart = expr.else && this.expr(expr.else); + + assign(TY_BOOL, cond.ty!, cond.span); + + let ty; + if (elsePart) { + assign(then.ty!, elsePart.ty!, elsePart.span); + ty = then.ty!; + } else { + assign(TY_UNIT, then.ty!, then.span); + } + + return { ...expr, cond, then, else: elsePart, ty }; + } + } + }, + }; + + const checked = checker.expr(body); + + assign(fnTy.returnTy, checked.ty!, body.span); + + return checked; +} + +function checkBinary(expr: Expr & ExprBinary): Expr { + const checkPrecedence = (inner: Expr, side: string) => { + if (inner.kind === "binary") { + const ourClass = binaryExprPrecedenceClass(expr.binaryKind); + const innerClass = binaryExprPrecedenceClass(inner.binaryKind); + + if (ourClass !== innerClass) { + throw new CompilerError( + `mixing operators without parentheses is not allowed. ${side} is ${inner.binaryKind}, which is different from ${expr.binaryKind}`, + expr.span + ); + } + } + }; + + checkPrecedence(expr.lhs, "left"); + checkPrecedence(expr.rhs, "right"); + + let lhsTy = expr.lhs.ty!; + let rhsTy = expr.rhs.ty!; + + if (lhsTy.kind === "int" && rhsTy.kind === "int") { + return { ...expr, ty: TY_INT }; + } + + if (COMPARISON_KINDS.includes(expr.binaryKind)) { + if (lhsTy.kind === "string" && rhsTy.kind === "string") { + return { ...expr, ty: TY_BOOL }; + } + + if (EQUALITY_KINDS.includes(expr.binaryKind)) { + if (lhsTy.kind === "bool" && rhsTy.kind === "bool") { + return { ...expr, ty: TY_BOOL }; + } + } + } + + if (LOGICAL_KINDS.includes(expr.binaryKind)) { + if (lhsTy.kind === "bool" && rhsTy.kind === "bool") { + return { ...expr, ty: TY_BOOL }; + } + } + + throw new CompilerError( + `invalid types for binary operation: ${printTy(expr.lhs.ty!)} ${ + expr.binaryKind + } ${printTy(expr.rhs.ty!)}`, + expr.span + ); +} + +function checkUnary(expr: Expr & ExprUnary): Expr { + let rhsTy = expr.rhs.ty!; + + if ( + expr.unaryKind === "!" && + (rhsTy.kind === "int" || rhsTy.kind === "bool") + ) { + return { ...expr, ty: rhsTy }; + } + + if (expr.unaryKind === "-" && rhsTy.kind == "int") { + return { ...expr, ty: rhsTy }; + } + + throw new CompilerError( + `invalid types for unary operation: ${expr.unaryKind} ${printTy( + expr.rhs.ty! + )}`, + expr.span + ); +}