add raw pointers

This commit is contained in:
nora 2023-08-05 14:20:24 +02:00
parent 5bac67b84c
commit 64d81b5608
8 changed files with 101 additions and 29 deletions

View file

@ -403,6 +403,10 @@ export type TypeKind<P extends Phase> =
kind: "tuple"; kind: "tuple";
elems: Type<P>[]; elems: Type<P>[];
} }
| {
kind: "rawptr";
inner: Type<P>;
}
| { kind: "never" }; | { kind: "never" };
export type Type<P extends Phase> = TypeKind<P> & { export type Type<P extends Phase> = TypeKind<P> & {
@ -514,6 +518,11 @@ export type TyStruct = {
fields: [string, Ty][]; fields: [string, Ty][];
}; };
export type TyRawPtr = {
kind: "rawptr";
inner: TyStruct;
};
export type TyNever = { export type TyNever = {
kind: "never"; kind: "never";
}; };
@ -528,6 +537,7 @@ export type Ty =
| TyFn | TyFn
| TyVar | TyVar
| TyStruct | TyStruct
| TyRawPtr
| TyNever; | TyNever;
export function tyIsUnit(ty: Ty): ty is TyUnit { export function tyIsUnit(ty: Ty): ty is TyUnit {
@ -838,6 +848,12 @@ export function superFoldType<From extends Phase, To extends Phase>(
span, span,
}; };
} }
case "rawptr": {
return {
...type,
inner: folder.type(type.inner),
};
}
case "never": { case "never": {
return { ...type, kind: "never" }; return { ...type, kind: "never" };
} }

View file

@ -19,6 +19,7 @@ import {
superFoldExpr, superFoldExpr,
superFoldItem, superFoldItem,
varUnreachable, varUnreachable,
TyRawPtr,
} from "./ast"; } from "./ast";
import { GlobalContext } from "./context"; import { GlobalContext } from "./context";
import { unreachable } from "./error"; import { unreachable } from "./error";
@ -443,8 +444,10 @@ function tryLowerLValue(
instrs: wasm.Instr[], instrs: wasm.Instr[],
expr: Expr<Typecked>, expr: Expr<Typecked>,
): LValue | undefined { ): LValue | undefined {
const fieldParts = (ty: TyStruct, fieldIdx: number): MemoryLayoutPart[] => const fieldParts = (
unwrap(layoutOfStruct(ty).fields[fieldIdx]).types; ty: TyStruct | TyRawPtr,
fieldIdx: number,
): MemoryLayoutPart[] => unwrap(layoutOfStruct(ty).fields[fieldIdx]).types;
switch (expr.kind) { switch (expr.kind) {
case "ident": case "ident":
@ -489,7 +492,8 @@ function tryLowerLValue(
// _potentially_ a localTupleField // _potentially_ a localTupleField
todo("tuple access"); todo("tuple access");
} }
case "struct": { case "struct":
case "rawptr": {
// Codegen the base, this leaves us with the base pointer on the stack. // Codegen the base, this leaves us with the base pointer on the stack.
// Do not increment the refcount for an lvalue base. // Do not increment the refcount for an lvalue base.
lowerExpr(fcx, instrs, expr.lhs, true); lowerExpr(fcx, instrs, expr.lhs, true);
@ -908,7 +912,8 @@ function lowerExpr(
break; break;
} }
case "struct": { case "struct":
case "rawptr": {
const ty = expr.lhs.ty; const ty = expr.lhs.ty;
const layout = layoutOfStruct(ty); const layout = layoutOfStruct(ty);
const field = layout.fields[expr.field.fieldIdx!]; const field = layout.fields[expr.field.fieldIdx!];
@ -1027,7 +1032,7 @@ function lowerExpr(
} }
case "structLiteral": { case "structLiteral": {
if (expr.ty.kind !== "struct") { if (expr.ty.kind !== "struct") {
throw new Error("struct literal must have struct type"); unreachable("struct literal must have struct type");
} }
const layout = layoutOfStruct(expr.ty); const layout = layoutOfStruct(expr.ty);
@ -1236,6 +1241,7 @@ function argRetAbi(param: Ty): ArgRetAbi {
case "tuple": case "tuple":
return param.elems.flatMap(argRetAbi); return param.elems.flatMap(argRetAbi);
case "struct": case "struct":
case "rawptr":
return ["i32"]; return ["i32"];
case "never": case "never":
return []; return [];
@ -1290,6 +1296,7 @@ function wasmTypeForBody(ty: Ty): wasm.ValType[] {
case "fn": case "fn":
todo("fn types"); todo("fn types");
case "struct": case "struct":
case "rawptr":
return ["i32"]; return ["i32"];
case "never": case "never":
return []; return [];
@ -1313,7 +1320,8 @@ function sizeOfValtype(type: wasm.ValType): number {
} }
} }
export function layoutOfStruct(ty: TyStruct): StructLayout { export function layoutOfStruct(ty_: TyStruct | TyRawPtr): StructLayout {
const ty = ty_.kind === "struct" ? ty_ : ty_.inner;
const fieldWasmTys = ty.fields.map(([, field]) => wasmTypeForBody(field)); const fieldWasmTys = ty.fields.map(([, field]) => wasmTypeForBody(field));
// TODO: Use the max alignment instead. // TODO: Use the max alignment instead.

View file

@ -14,28 +14,16 @@ import { loadCrate } from "./loader";
const INPUT = ` const INPUT = `
type A = struct { a: Int }; type A = struct { a: Int };
type Complex = struct {
a: Int,
b: (Int, A),
};
function main() = ( function main() = (
let a = A { a: 0 }; let a = A { a: 0 };
// let b = 0; rawr(___transmute(a));
let c = Complex { a: 0, b: (0, a) }; std.printInt(a.a);
write(a);
std.printlnInt(a.a);
// ptr = c + offset(b) + offset(1)
// a = load(ptr)
// store(a + offset(a), 1)
// c.b.1.a = 1;
); );
function write(a: A) = a.a = 1; function rawr(a: *A) = (
a.a = 1;
function ret(): A = A { a: 0 }; );
`; `;
function main() { function main() {

View file

@ -650,6 +650,12 @@ function parseType(t: State): [State, Type<Parsed>] {
return [t, { kind: "tuple", elems: [head, ...tail], span }]; return [t, { kind: "tuple", elems: [head, ...tail], span }];
} }
case "*": {
let inner;
[t, inner] = parseType(t);
return [t, { kind: "rawptr", inner, span }];
}
default: { default: {
throw new CompilerError( throw new CompilerError(
`unexpected token: \`${tok.kind}\`, expected type`, `unexpected token: \`${tok.kind}\`, expected type`,

View file

@ -214,6 +214,8 @@ function printType(type: Type<AnyPhase>): string {
return `[${printType(type.elem)}]`; return `[${printType(type.elem)}]`;
case "tuple": case "tuple":
return `(${type.elems.map(printType).join(", ")})`; return `(${type.elems.map(printType).join(", ")})`;
case "rawptr":
return `*${printType(type.inner)}`;
case "never": case "never":
return "!"; return "!";
} }
@ -266,6 +268,9 @@ export function printTy(ty: Ty): string {
case "struct": { case "struct": {
return ty._name; return ty._name;
} }
case "rawptr": {
return `*${printTy(ty.inner)}`;
}
case "never": { case "never": {
return "!"; return "!";
} }

View file

@ -131,6 +131,17 @@ function lowerAstTy(cx: TypeckCtx, type: Type<Resolved>): Ty {
elems: type.elems.map((type) => lowerAstTy(cx, type)), elems: type.elems.map((type) => lowerAstTy(cx, type)),
}; };
} }
case "rawptr": {
const inner = lowerAstTy(cx, type.inner);
if (inner.kind !== "struct") {
throw new CompilerError(
"raw pointers must point to structs",
type.span,
);
}
return { kind: "rawptr", inner };
}
case "never": { case "never": {
return TY_NEVER; return TY_NEVER;
} }
@ -778,11 +789,12 @@ export function checkBody(
} }
break; break;
} }
case "struct": { case "struct":
case "rawptr": {
const fields =
lhs.ty.kind === "struct" ? lhs.ty.fields : lhs.ty.inner.fields;
if (typeof field.value === "string") { if (typeof field.value === "string") {
const idx = lhs.ty.fields.findIndex( const idx = fields.findIndex(([name]) => name === field.value);
([name]) => name === field.value,
);
if (idx === -1) { if (idx === -1) {
throw new CompilerError( throw new CompilerError(
`field \`${field.value}\` does not exist on ${printTy( `field \`${field.value}\` does not exist on ${printTy(
@ -792,7 +804,7 @@ export function checkBody(
); );
} }
ty = lhs.ty.fields[idx][1]; ty = fields[idx][1];
fieldIdx = idx; fieldIdx = idx;
} else { } else {
throw new CompilerError( throw new CompilerError(
@ -962,7 +974,10 @@ function checkLValue(expr: Expr<Typecked>) {
checkLValue(expr.lhs); checkLValue(expr.lhs);
break; break;
default: default:
throw new CompilerError("invalid left-hand side of assignment", expr.span); throw new CompilerError(
"invalid left-hand side of assignment",
expr.span,
);
} }
} }

13
ui-tests/raw_ptr.nil Normal file
View file

@ -0,0 +1,13 @@
//@check-pass
type A = struct { a: Int };
function main() = (
let a = A { a: 0 };
rawr(___transmute(a));
std.printInt(a.a);
);
function rawr(a: *A) = (
a.a = 1;
);

View file

@ -0,0 +1,21 @@
//@check-pass
type A = struct { a: Int };
type Complex = struct {
a: Int,
b: (Int, A),
};
function main() = (
let a = A { a: 0 };
// let b = 0;
// let c = Complex { a: 0, b: (0, a) };
write(a);
// c.b.1.a = 1; TODO
);
function write(a: A) = a.a = 1;
function ret(): A = A { a: 0 };