add raw pointers

This commit is contained in:
nora 2023-08-05 14:20:24 +02:00
parent 5bac67b84c
commit 64d81b5608
8 changed files with 101 additions and 29 deletions

View file

@ -403,6 +403,10 @@ export type TypeKind<P extends Phase> =
kind: "tuple";
elems: Type<P>[];
}
| {
kind: "rawptr";
inner: Type<P>;
}
| { kind: "never" };
export type Type<P extends Phase> = TypeKind<P> & {
@ -514,6 +518,11 @@ export type TyStruct = {
fields: [string, Ty][];
};
export type TyRawPtr = {
kind: "rawptr";
inner: TyStruct;
};
export type TyNever = {
kind: "never";
};
@ -528,6 +537,7 @@ export type Ty =
| TyFn
| TyVar
| TyStruct
| TyRawPtr
| TyNever;
export function tyIsUnit(ty: Ty): ty is TyUnit {
@ -838,6 +848,12 @@ export function superFoldType<From extends Phase, To extends Phase>(
span,
};
}
case "rawptr": {
return {
...type,
inner: folder.type(type.inner),
};
}
case "never": {
return { ...type, kind: "never" };
}

View file

@ -19,6 +19,7 @@ import {
superFoldExpr,
superFoldItem,
varUnreachable,
TyRawPtr,
} from "./ast";
import { GlobalContext } from "./context";
import { unreachable } from "./error";
@ -443,8 +444,10 @@ function tryLowerLValue(
instrs: wasm.Instr[],
expr: Expr<Typecked>,
): LValue | undefined {
const fieldParts = (ty: TyStruct, fieldIdx: number): MemoryLayoutPart[] =>
unwrap(layoutOfStruct(ty).fields[fieldIdx]).types;
const fieldParts = (
ty: TyStruct | TyRawPtr,
fieldIdx: number,
): MemoryLayoutPart[] => unwrap(layoutOfStruct(ty).fields[fieldIdx]).types;
switch (expr.kind) {
case "ident":
@ -489,7 +492,8 @@ function tryLowerLValue(
// _potentially_ a localTupleField
todo("tuple access");
}
case "struct": {
case "struct":
case "rawptr": {
// Codegen the base, this leaves us with the base pointer on the stack.
// Do not increment the refcount for an lvalue base.
lowerExpr(fcx, instrs, expr.lhs, true);
@ -908,7 +912,8 @@ function lowerExpr(
break;
}
case "struct": {
case "struct":
case "rawptr": {
const ty = expr.lhs.ty;
const layout = layoutOfStruct(ty);
const field = layout.fields[expr.field.fieldIdx!];
@ -1027,7 +1032,7 @@ function lowerExpr(
}
case "structLiteral": {
if (expr.ty.kind !== "struct") {
throw new Error("struct literal must have struct type");
unreachable("struct literal must have struct type");
}
const layout = layoutOfStruct(expr.ty);
@ -1236,6 +1241,7 @@ function argRetAbi(param: Ty): ArgRetAbi {
case "tuple":
return param.elems.flatMap(argRetAbi);
case "struct":
case "rawptr":
return ["i32"];
case "never":
return [];
@ -1290,6 +1296,7 @@ function wasmTypeForBody(ty: Ty): wasm.ValType[] {
case "fn":
todo("fn types");
case "struct":
case "rawptr":
return ["i32"];
case "never":
return [];
@ -1313,7 +1320,8 @@ function sizeOfValtype(type: wasm.ValType): number {
}
}
export function layoutOfStruct(ty: TyStruct): StructLayout {
export function layoutOfStruct(ty_: TyStruct | TyRawPtr): StructLayout {
const ty = ty_.kind === "struct" ? ty_ : ty_.inner;
const fieldWasmTys = ty.fields.map(([, field]) => wasmTypeForBody(field));
// TODO: Use the max alignment instead.

View file

@ -14,28 +14,16 @@ import { loadCrate } from "./loader";
const INPUT = `
type A = struct { a: Int };
type Complex = struct {
a: Int,
b: (Int, A),
};
function main() = (
let a = A { a: 0 };
// let b = 0;
let c = Complex { a: 0, b: (0, a) };
write(a);
std.printlnInt(a.a);
// ptr = c + offset(b) + offset(1)
// a = load(ptr)
// store(a + offset(a), 1)
// c.b.1.a = 1;
rawr(___transmute(a));
std.printInt(a.a);
);
function write(a: A) = a.a = 1;
function ret(): A = A { a: 0 };
function rawr(a: *A) = (
a.a = 1;
);
`;
function main() {

View file

@ -650,6 +650,12 @@ function parseType(t: State): [State, Type<Parsed>] {
return [t, { kind: "tuple", elems: [head, ...tail], span }];
}
case "*": {
let inner;
[t, inner] = parseType(t);
return [t, { kind: "rawptr", inner, span }];
}
default: {
throw new CompilerError(
`unexpected token: \`${tok.kind}\`, expected type`,

View file

@ -214,6 +214,8 @@ function printType(type: Type<AnyPhase>): string {
return `[${printType(type.elem)}]`;
case "tuple":
return `(${type.elems.map(printType).join(", ")})`;
case "rawptr":
return `*${printType(type.inner)}`;
case "never":
return "!";
}
@ -266,6 +268,9 @@ export function printTy(ty: Ty): string {
case "struct": {
return ty._name;
}
case "rawptr": {
return `*${printTy(ty.inner)}`;
}
case "never": {
return "!";
}

View file

@ -131,6 +131,17 @@ function lowerAstTy(cx: TypeckCtx, type: Type<Resolved>): Ty {
elems: type.elems.map((type) => lowerAstTy(cx, type)),
};
}
case "rawptr": {
const inner = lowerAstTy(cx, type.inner);
if (inner.kind !== "struct") {
throw new CompilerError(
"raw pointers must point to structs",
type.span,
);
}
return { kind: "rawptr", inner };
}
case "never": {
return TY_NEVER;
}
@ -778,11 +789,12 @@ export function checkBody(
}
break;
}
case "struct": {
case "struct":
case "rawptr": {
const fields =
lhs.ty.kind === "struct" ? lhs.ty.fields : lhs.ty.inner.fields;
if (typeof field.value === "string") {
const idx = lhs.ty.fields.findIndex(
([name]) => name === field.value,
);
const idx = fields.findIndex(([name]) => name === field.value);
if (idx === -1) {
throw new CompilerError(
`field \`${field.value}\` does not exist on ${printTy(
@ -792,7 +804,7 @@ export function checkBody(
);
}
ty = lhs.ty.fields[idx][1];
ty = fields[idx][1];
fieldIdx = idx;
} else {
throw new CompilerError(
@ -962,7 +974,10 @@ function checkLValue(expr: Expr<Typecked>) {
checkLValue(expr.lhs);
break;
default:
throw new CompilerError("invalid left-hand side of assignment", expr.span);
throw new CompilerError(
"invalid left-hand side of assignment",
expr.span,
);
}
}

13
ui-tests/raw_ptr.nil Normal file
View file

@ -0,0 +1,13 @@
//@check-pass
type A = struct { a: Int };
function main() = (
let a = A { a: 0 };
rawr(___transmute(a));
std.printInt(a.a);
);
function rawr(a: *A) = (
a.a = 1;
);

View file

@ -0,0 +1,21 @@
//@check-pass
type A = struct { a: Int };
type Complex = struct {
a: Int,
b: (Int, A),
};
function main() = (
let a = A { a: 0 };
// let b = 0;
// let c = Complex { a: 0, b: (0, a) };
write(a);
// c.b.1.a = 1; TODO
);
function write(a: A) = a.a = 1;
function ret(): A = A { a: 0 };