reference counting

This commit is contained in:
nora 2023-08-02 23:19:10 +02:00
parent d9ab81bed1
commit 9ece18a48a
18 changed files with 477 additions and 159 deletions

View file

@ -37,6 +37,7 @@ const WASM_PAGE = 65536;
const DUMMY_IDX = 9999999;
const ALLOCATE_ITEM: string[] = ["std", "rt", "allocateItem"];
const DEALLOCATE_ITEM: string[] = ["std", "rt", "deallocateItem"];
type RelocationKind =
| {
@ -111,10 +112,10 @@ function appendData(cx: Context, newData: Uint8Array): number {
}
}
const KNOWN_DEF_PATHS = [ALLOCATE_ITEM];
const KNOWN_DEF_PATHS = [ALLOCATE_ITEM, DEALLOCATE_ITEM];
function getKnownDefPaths(
crates: Crate<Typecked>[]
crates: Crate<Typecked>[],
): ComplexMap<string[], ItemId> {
const knows = new ComplexMap<string[], ItemId>();
@ -141,7 +142,7 @@ function getKnownDefPaths(
};
crates.forEach((crate) =>
crate.rootItems.forEach((item) => folder.item(item))
crate.rootItems.forEach((item) => folder.item(item)),
);
return knows;
@ -231,7 +232,7 @@ export function lower(gcx: GlobalContext): wasm.Module {
const idx = cx.funcIndices.get(rel.res);
if (idx === undefined) {
throw new Error(
`no function found for relocation '${JSON.stringify(rel.res)}'`
`no function found for relocation '${JSON.stringify(rel.res)}'`,
);
}
rel.instr.func = idx.kind === "func" ? offset + idx.idx : idx.idx;
@ -241,7 +242,7 @@ export function lower(gcx: GlobalContext): wasm.Module {
const idx = cx.globalIndices.get(rel.res);
if (idx === undefined) {
throw new Error(
`no global found for relocation '${JSON.stringify(rel.res)}'`
`no global found for relocation '${JSON.stringify(rel.res)}'`,
);
}
rel.instr.imm = idx;
@ -260,10 +261,10 @@ export function lower(gcx: GlobalContext): wasm.Module {
function lowerImport(
cx: Context,
item: Item<Typecked>,
def: ImportDef<Typecked>
def: ImportDef<Typecked>,
) {
const existing = cx.mod.imports.findIndex(
(imp) => imp.module === def.module.value && imp.name === def.func.value
(imp) => imp.module === def.module.value && imp.name === def.func.value,
);
let idx;
@ -271,7 +272,7 @@ function lowerImport(
idx = existing;
} else {
const abi = computeAbi(def.ty!);
const { type: wasmType } = wasmTypeForAbi(abi);
const { type: wasmType } = wasmTypeForAbi(abi, def.ty!);
const type = internFuncType(cx, wasmType);
idx = cx.mod.imports.length;
@ -291,7 +292,7 @@ function lowerImport(
function lowerGlobal(
cx: Context,
item: Item<Typecked>,
def: GlobalItem<Typecked>
def: GlobalItem<Typecked>,
) {
const globalIdx = cx.mod.globals.length;
@ -334,13 +335,14 @@ type FuncContext = {
varLocations: VarLocation[];
loopDepths: Map<LoopId, number>;
currentBlockDepth: number;
scratchLocals: Map<wasm.ValType, wasm.LocalIdx[]>;
};
type FnAbi = { params: ArgRetAbi[]; ret: ArgRetAbi };
type ArgRetAbi = wasm.ValType[];
type VarLocation = { localIdx: number; types: wasm.ValType[] };
type VarLocation = { localIdx: number; types: wasm.ValType[]; ty: Ty };
type StructFieldLayout = {
types: { offset: number; type: wasm.ValType }[];
@ -356,10 +358,10 @@ type StructLayout = {
function lowerFunc(
cx: Context,
item: Item<Typecked>,
func: FunctionDef<Typecked>
func: FunctionDef<Typecked>,
) {
const abi = computeAbi(func.ty!);
const { type: wasmType, paramLocations } = wasmTypeForAbi(abi);
const { type: wasmType, paramLocations } = wasmTypeForAbi(abi, func.ty!);
const type = internFuncType(cx, wasmType);
const wasmFunc: wasm.Func = {
@ -378,16 +380,26 @@ function lowerFunc(
varLocations: paramLocations,
loopDepths: new Map(),
currentBlockDepth: 0,
scratchLocals: new Map(),
};
lowerExpr(fcx, wasmFunc.body, fcx.func.body);
paramLocations.forEach((local) => {
const refcount = needsRefcount(local.ty);
if (refcount !== undefined) {
// TODO: correctly deal with tuples
loadVariable(wasmFunc.body, local);
subRefcount(fcx, wasmFunc.body, refcount);
}
});
const idx = fcx.cx.mod.funcs.length;
fcx.cx.mod.funcs.push(wasmFunc);
fcx.cx.funcIndices.set(
{ kind: "item", id: fcx.item.id },
{ kind: "func", idx }
{ kind: "func", idx },
);
}
@ -399,7 +411,7 @@ Expression lowering.
function lowerExpr(
fcx: FuncContext,
instrs: wasm.Instr[],
expr: Expr<Typecked>
expr: Expr<Typecked>,
) {
const ty = expr.ty;
@ -420,7 +432,7 @@ function lowerExpr(
instrs.push({ kind: "local.set", imm: local + i });
});
fcx.varLocations.push({ localIdx: local, types });
fcx.varLocations.push({ localIdx: local, types, ty: expr.rhs.ty });
break;
}
@ -475,7 +487,7 @@ function lowerExpr(
} else {
const instr: wasm.Instr = {
kind: "block",
instrs: lowerExprBlockBody(fcx, expr),
instrs: lowerExprBlockBody(fcx, expr, prevVarLocationLengths),
type: blockTypeForBody(fcx.cx, expr.ty),
};
@ -518,6 +530,17 @@ function lowerExpr(
const location =
fcx.varLocations[fcx.varLocations.length - 1 - res.index];
loadVariable(instrs, location);
const refcount = needsRefcount(expr.ty);
if (refcount !== undefined) {
addRefcount(
fcx,
instrs,
refcount === "string" ? "string" : "struct",
);
}
break;
}
case "item": {
@ -694,7 +717,7 @@ function lowerExpr(
case "__i32_load": {
assertArgs(1);
lowerExpr(fcx, instrs, expr.args[0]);
instrs.push({ kind: "i64.load", imm: {} });
instrs.push({ kind: "i32.load", imm: {} });
break exprKind;
}
case "__i64_load": {
@ -752,6 +775,11 @@ function lowerExpr(
instrs.push({ kind: "i64.extend_i32_u" });
break exprKind;
}
case "___transmute": {
expr.args.map((arg) => lowerExpr(fcx, instrs, arg));
// don't do anything
break exprKind;
}
}
}
@ -791,7 +819,7 @@ function lowerExpr(
const resultSize = resultAbi.length;
const wasmIdx = wasmTypeIdxForTupleField(
expr.lhs.ty,
expr.field.fieldIdx!
expr.field.fieldIdx!,
);
// lhsSize=5, resultSize=2, wasmIdx=2
@ -808,13 +836,13 @@ function lowerExpr(
if (expr.field.fieldIdx! > 0) {
// Keep the result in scratch space.
storeVariable(instrs, { localIdx, types: resultAbi });
storeVariable(instrs, { localIdx, types: resultAbi, ty: expr.ty });
Array(wasmIdx)
.fill(0)
.forEach(() => instrs.push({ kind: "drop" }));
loadVariable(instrs, { localIdx, types: resultAbi });
loadVariable(instrs, { localIdx, types: resultAbi, ty: expr.ty });
}
break;
@ -824,9 +852,7 @@ function lowerExpr(
const layout = layoutOfStruct(ty);
const field = layout.fields[expr.field.fieldIdx!];
// TODO: SCRATCH LOCALS
const ptrLocal = fcx.wasmType.params.length + fcx.wasm.locals.length;
fcx.wasm.locals.push("i32");
const ptrLocal = getScratchLocals(fcx, "i32", 1)[0];
// We save the local for getting it later for all the field parts.
instrs.push({
@ -860,7 +886,7 @@ function lowerExpr(
break;
default: {
throw new Error(
`unsupported struct content type: ${fieldPart.type}`
`unsupported struct content type: ${fieldPart.type}`,
);
}
}
@ -958,13 +984,12 @@ function lowerExpr(
res: { kind: "item", id: allocateItemId },
});
instrs.push(allocate);
// TODO: scratch locals...
const ptrLocal = fcx.wasmType.params.length + fcx.wasm.locals.length;
fcx.wasm.locals.push("i32");
const ptrLocal = getScratchLocals(fcx, "i32", 1)[0];
instrs.push({ kind: "local.tee", imm: ptrLocal });
// Store the refcount
instrs.push({ kind: "i32.const", imm: 0n });
instrs.push({ kind: "i32.const", imm: 1n });
instrs.push({ kind: "i32.store", imm: { align: 4 } });
// Now, set all fields.
@ -996,7 +1021,7 @@ function lowerExpr(
break;
default: {
throw new Error(
`unsupported struct content type: ${fieldPart.type}`
`unsupported struct content type: ${fieldPart.type}`,
);
}
}
@ -1012,8 +1037,6 @@ function lowerExpr(
expr.fields.forEach((field) => lowerExpr(fcx, instrs, field));
break;
}
case "refcount":
todo("refcount");
default: {
const _: never = expr;
}
@ -1027,29 +1050,89 @@ function lowerExpr(
// Lowers the expressions of a block expression into a fresh instruction list
// and returns that list.
//
// NOTE(review): this span is a rendered diff, so the removed and the added
// version of several lines appear next to each other (e.g. the `expr`
// parameter, the `innerInstrs`/`instrs` list and the final `return`).
function lowerExprBlockBody(
fcx: FuncContext,
expr: ExprBlock<Typecked> & Expr<Typecked>
expr: ExprBlock<Typecked> & Expr<Typecked>,
prevVarLocationLength: number,
): wasm.Instr[] {
// Track block nesting for the duration of this body; undone before returning.
fcx.currentBlockDepth++;
const innerInstrs: wasm.Instr[] = [];
const instrs: wasm.Instr[] = [];
// Every expression except the last has its value discarded; the last one is the tail.
const headExprs = expr.exprs.slice(0, -1);
const tailExpr = expr.exprs[expr.exprs.length - 1];
for (const inner of headExprs) {
lowerExpr(fcx, innerInstrs, inner);
lowerExpr(fcx, instrs, inner);
if (inner.ty.kind === "never") {
// The rest of the block is unreachable, so we don't bother codegening it.
break;
}
const types = wasmTypeForBody(inner.ty);
types.forEach(() => innerInstrs.push({ kind: "drop" }));
// An ignored value of a refcounted type is consumed by a refcount decrement;
// any other value is removed with one `drop` per wasm value.
const refcount = needsRefcount(inner.ty);
if (refcount !== undefined) {
subRefcount(fcx, instrs, refcount);
} else {
// TODO: correctly deal with tuples
types.forEach(() => instrs.push({ kind: "drop" }));
}
}
lowerExpr(fcx, innerInstrs, tailExpr);
lowerExpr(fcx, instrs, tailExpr);
// Variable locations pushed since index `prevVarLocationLength` are treated as
// this block's locals; each refcounted one is loaded and decremented here.
const thisBlockLocals = fcx.varLocations.slice(prevVarLocationLength);
thisBlockLocals.forEach((local) => {
const refcount = needsRefcount(local.ty);
if (refcount !== undefined) {
// TODO: correctly deal with tuples
loadVariable(instrs, local);
subRefcount(fcx, instrs, refcount);
}
});
fcx.currentBlockDepth--;
return innerInstrs;
return instrs;
}
/**
 * Returns the indices of `amount` scratch locals of wasm value type `type`.
 *
 * Scratch locals are cached per value type in `fcx.scratchLocals` and reused
 * across calls: asking for `n` locals of a type always yields the first `n`
 * entries of that type's pool. The pool is grown on demand by appending new
 * entries to `fcx.wasm.locals`.
 *
 * The returned array is always a fresh copy, so callers may mutate it without
 * corrupting the cached pool.
 *
 * @param fcx - context of the function currently being lowered
 * @param type - wasm value type of the requested locals
 * @param amount - number of locals requested
 * @returns the local indices, in pool order
 */
function getScratchLocals(
  fcx: FuncContext,
  type: wasm.ValType,
  amount: number,
): wasm.LocalIdx[] {
  let pool = fcx.scratchLocals.get(type);
  if (pool === undefined) {
    pool = [];
    fcx.scratchLocals.set(type, pool);
  }

  const missing = amount - pool.length;
  if (missing > 0) {
    // Local indices count the parameters first, then the declared locals.
    const firstIdx = fcx.wasm.locals.length + fcx.wasmType.params.length;
    for (let i = 0; i < missing; i++) {
      fcx.wasm.locals.push(type);
      pool.push(firstIdx + i);
    }
  }

  // Hand out a copy so the cache cannot be modified from the outside.
  return pool.slice(0, amount);
}
function loadVariable(instrs: wasm.Instr[], loc: VarLocation) {
@ -1105,17 +1188,21 @@ function computeAbi(ty: TyFn): FnAbi {
return { params, ret };
}
function wasmTypeForAbi(abi: FnAbi): {
function wasmTypeForAbi(
abi: FnAbi,
ty: TyFn,
): {
type: wasm.FuncType;
paramLocations: VarLocation[];
} {
const params: wasm.ValType[] = [];
const paramLocations: VarLocation[] = [];
abi.params.forEach((arg) => {
abi.params.forEach((arg, i) => {
paramLocations.push({
localIdx: params.length,
types: arg,
ty: ty.params[i],
});
params.push(...arg);
});
@ -1168,7 +1255,7 @@ export function layoutOfStruct(ty: TyStruct): StructLayout {
// TODO: Use the max alignment instead.
const align = fieldWasmTys.some((field) =>
field.some((type) => type === "i64")
field.some((type) => type === "i64"),
)
? 8
: 4;
@ -1185,7 +1272,8 @@ export function layoutOfStruct(ty: TyStruct): StructLayout {
const types = field.map((type) => {
const size = sizeOfValtype(type);
if (size === 8 && offset % 8 !== 0) {
// we don't want padding for the first field as the allocator takes care of that.
if (offset !== 4 && size === 8 && offset % 8 !== 0) {
// padding.
offset += 4;
}
@ -1203,6 +1291,9 @@ export function layoutOfStruct(ty: TyStruct): StructLayout {
return value;
});
// we ignore the refcount for struct size.
offset -= 4;
if (align === 8 && offset % 8 !== 0) {
offset += 4;
}
@ -1230,6 +1321,169 @@ function wasmTypeIdxForTupleField(ty: TyTuple, idx: number): number {
return head.reduce((a, b) => a + b.length, 0);
}
// Refcounts:
/*
* Injects `refcount` expressions into the code to make sure
* that no memory is leaked and no memory is freed too early.
*
* When do we need to adjust the refcount?
*
* When looking at reference counts, we need to distinguish between moves
* and copies of a struct. When a struct is moved, no reference count has
* to be changed. When it is copied, we need to increment the reference count.
*
* ```
* let a = S {};
* foo(a); // COPY
* ```
* ```
* let a = identity(S {}); // MOVE
* ```
*
* Due to the way the language is structured, this analysis is fairly simple:
* Most expressions are considered moves, but identifiers like `a` are considered
* copies. This is sound because the only way to refer to a value twice is to bind
* it to a variable. So whenever we load a variable of type struct, we need to bump
* the refcount.
*
* Then we just need to decrement all the locals (and params!) refcounts when they go
* out of scope.
*
* This leaves us with the following rules:
* - when loading an identifier, add an increment
* - when the end of a block is reached, decrement all locals
* - when the end of a function is reached, decrement all params
* - when an expression value is ignored, decrement
*/
/**
 * Decides whether values of type `ty` are reference counted.
 *
 * @returns the struct layout for struct types, and `undefined` for every type
 * that is not (yet) reference counted. Strings currently yield `undefined`.
 */
function needsRefcount(ty: Ty): StructLayout | "string" | undefined {
  if (ty.kind === "struct") {
    return layoutOfStruct(ty);
  }
  if (ty.kind === "list") {
    todo("no lists yet");
  }
  if (ty.kind === "var") {
    varUnreachable();
  }
  // TODO: deal with strings
  return undefined;
}
/**
 * Appends a wasm block to `instrs` that increments the reference count of the
 * value on top of the stack and leaves that value on the stack.
 *
 * Stack on entry and on exit: `PTR` for structs, `PTR, LEN` for strings.
 * The count is the i32 loaded from and stored back to address `PTR`.
 */
function addRefcount(
  fcx: FuncContext,
  instrs: wasm.Instr[],
  kind: "struct" | "string",
) {
  const isString = kind === "string";
  const layout: wasm.ValType[] = isString ? ["i32", "i32"] : ["i32"];
  const [ptr, len] = getScratchLocals(fcx, "i32", layout.length);

  // For strings the length sits above the pointer; park it in a scratch local
  // while the count is updated and push it back afterwards.
  const stashLen: wasm.Instr[] = isString
    ? [{ kind: "local.set", imm: len }] // stack: PTR
    : [];
  const restoreLen: wasm.Instr[] = isString
    ? [{ kind: "local.get", imm: len }] // stack: PTR, LEN
    : [];

  const bumpCount: wasm.Instr[] = [
    { kind: "local.tee", imm: ptr }, // stack: PTR
    { kind: "local.get", imm: ptr }, // stack: PTR, PTR
    { kind: "local.get", imm: ptr }, // stack: PTR, PTR, PTR
    { kind: "i32.load", imm: { align: 4 } }, // stack: PTR, PTR, cnt
    { kind: "i32.const", imm: 1n }, // stack: PTR, PTR, cnt, 1
    { kind: "i32.add" }, // stack: PTR, PTR, cnt+1
    { kind: "i32.store", imm: { align: 4 } }, // stack: PTR
  ];

  // The block takes the value and returns it unchanged.
  const blockTypeIdx = internFuncType(fcx.cx, {
    params: layout,
    returns: layout,
  });

  instrs.push({
    kind: "block",
    instrs: [...stashLen, ...bumpCount, ...restoreLen],
    type: { kind: "typeidx", idx: blockTypeIdx },
  });
}
/**
 * Appends a wasm block to `instrs` that consumes the value on top of the stack
 * and decrements its reference count.
 *
 * Stack on entry: `PTR` for structs, `PTR, LEN` for strings. Stack on exit: empty.
 * If the decremented count is non-zero it is stored back at `PTR`; otherwise
 * `std.rt.deallocateItem` is called with `PTR` and a size (the string length,
 * or `kind.size` for a struct layout). The call target is patched in later
 * through a `funccall` relocation.
 *
 * @throws Error if `std.rt.deallocateItem` is not among the known def paths.
 */
function subRefcount(
  fcx: FuncContext,
  instrs: wasm.Instr[],
  kind: StructLayout | "string",
) {
  const deallocateItemId = fcx.cx.knownDefPaths.get(DEALLOCATE_ITEM);
  if (!deallocateItemId) {
    throw new Error("std.rt.deallocateItem not found");
  }

  const layout: wasm.ValType[] = kind === "string" ? ["i32", "i32"] : ["i32"];
  const [ptr, len] = getScratchLocals(fcx, "i32", layout.length);
  // The count reuses the pointer's scratch local: by the time the count is
  // written there, the pointer itself is kept on the stack.
  const count = ptr;

  // Zero-count path. stack: PTR
  const sizeInstr: wasm.Instr =
    kind === "string"
      ? { kind: "local.get", imm: len } // stack: PTR, len
      : { kind: "i32.const", imm: BigInt(kind.size) }; // stack: PTR, len
  const deallocateCall: wasm.Instr = { kind: "call", func: DUMMY_IDX };
  fcx.cx.relocations.push({
    kind: "funccall",
    instr: deallocateCall,
    res: { kind: "item", id: deallocateItemId },
  });
  const freeInstrs: wasm.Instr[] = [sizeInstr, deallocateCall]; // stack:

  // Non-zero path. stack: PTR
  const writeBackInstrs: wasm.Instr[] = [
    { kind: "local.get", imm: count }, // stack: PTR, cnt
    { kind: "i32.store", imm: { align: 4 } }, // stack:
  ];

  const ifTypeIdx = internFuncType(fcx.cx, { params: ["i32"], returns: [] });

  // stack: PTR, {LEN}
  const stashLen: wasm.Instr[] =
    kind === "string" ? [{ kind: "local.set", imm: len }] : []; // stack: PTR
  const body: wasm.Instr[] = [
    ...stashLen,
    { kind: "local.tee", imm: ptr }, // stack: PTR
    { kind: "local.get", imm: ptr }, // stack: PTR, PTR
    { kind: "i32.load", imm: { align: 4 } }, // stack: PTR, cnt
    { kind: "i32.const", imm: 1n }, // stack: PTR, cnt, 1
    { kind: "i32.sub" }, // stack: PTR, cnt
    { kind: "local.tee", imm: count }, // stack: PTR, cnt
    {
      kind: "if",
      then: writeBackInstrs,
      else: freeInstrs,
      type: { kind: "typeidx", idx: ifTypeIdx },
    },
  ];

  const blockTypeIdx = internFuncType(fcx.cx, { params: layout, returns: [] });
  instrs.push({
    kind: "block",
    instrs: body,
    type: { kind: "typeidx", idx: blockTypeIdx },
  });
  // stack:
}
/**
 * Marks an unimplemented code path.
 *
 * @param msg - short description of what is missing
 * @throws Error always, with the message `TODO: <msg>`
 */
function todo(msg: string): never {
  const message = `TODO: ${msg}`;
  throw new Error(message);
}
@ -1304,7 +1558,7 @@ function addRt(cx: Context, crates: Crate<Typecked>[]) {
cx.funcIndices.set(
{ kind: "builtin", name: "print" },
{ kind: "func", idx: printIdx }
{ kind: "func", idx: printIdx },
);
mod.exports.push({