struct field writes

This commit is contained in:
nora 2023-08-05 14:05:41 +02:00
parent a6fea036d0
commit 5bac67b84c
7 changed files with 279 additions and 112 deletions

View file

@ -21,6 +21,7 @@ import {
varUnreachable,
} from "./ast";
import { GlobalContext } from "./context";
import { unreachable } from "./error";
import { printTy } from "./printer";
import { ComplexMap, encodeUtf8, unwrap } from "./utils";
import * as wasm from "./wasm/defs";
@ -336,10 +337,15 @@ type ArgRetAbi = wasm.ValType[];
type VarLocation = { localIdx: number; types: wasm.ValType[]; ty: Ty };
type StructFieldLayout = {
types: { offset: number; type: wasm.ValType }[];
types: MemoryLayoutPart[];
ty: Ty;
};
type MemoryLayoutPart = {
offset: number;
type: wasm.ValType;
};
type StructLayout = {
size: number;
align: number;
@ -394,10 +400,120 @@ Expression lowering.
- the result of an expression evaluation is stored on the top of the stack
*/
type LValue =
| {
// An assignment to a local will be simple local.set.
// Example: `a = 0`, `a = (0, 1)`
// Nothing will be on the stack, the loc is here.
kind: "fullLocal";
loc: VarLocation;
}
| {
// An assignment to a global will be global.set.
// Example: `AAA = 0`
// Nothing will be on the stack, the id is here.
kind: "global";
res: Resolution;
}
| {
// Writes to a tuple fields of locals will always be local.set.
// Example: `a.0 = 1`, `a.1.2.3 = 123`.
// Nothing will be on the stack, the loc and offset+size are here.
kind: "localTupleField";
loc: VarLocation;
offset: number;
size: number;
}
| {
// Writes to struct or tuple fields in memory will always be memory stores.
// Example: `a.a = 0`, `a.0.b = 5`, `a.b.c.0.1 = 14`.
// A pointer to the base will be on the stack, the offset is here.
kind: "memoryField";
parts: MemoryLayoutPart[];
};
/**
* Tries to lower an expression as an lvalue. If it succeeds, the lvalue value
* is left on the stack while the lvalue kind is returned.
*
* If the expression is not an lvalue, don't do anything and return undefined.
*/
function tryLowerLValue(
fcx: FuncContext,
instrs: wasm.Instr[],
expr: Expr<Typecked>,
): LValue | undefined {
const fieldParts = (ty: TyStruct, fieldIdx: number): MemoryLayoutPart[] =>
unwrap(layoutOfStruct(ty).fields[fieldIdx]).types;
switch (expr.kind) {
case "ident":
case "path": {
const { res } = expr.value;
switch (res.kind) {
case "local": {
return {
kind: "fullLocal",
loc: fcx.varLocations[fcx.varLocations.length - 1 - res.index],
};
}
case "item": {
const item = fcx.cx.gcx.findItem(res.id);
if (item.kind !== "global") {
throw new Error("cannot store to non-global item");
}
return {
kind: "global",
res,
};
}
case "builtin": {
throw new Error("cannot store to builtin");
}
}
}
case "fieldAccess": {
// Field access lvalues (or rather, lvalues in general) are made of two important parts:
// the _final place_, and the base.
// `a.0.b` -> `.b` is the final place, `a.0` is the base.
// `a.0.1` -> `a.0.1` is the final place with no base.
// We go from the outside in. Tuple fields are a projection and therefore part of the
// final place. A struct field means that everything left of the field access is the base.
// The base can be codegened like a normal expression (except without reference count changes).
// Only the final place needs special handling.
switch (expr.lhs.ty.kind) {
case "tuple": {
// _potentially_ a localTupleField
todo("tuple access");
}
case "struct": {
// Codegen the base, this leaves us with the base pointer on the stack.
// Do not increment the refcount for an lvalue base.
lowerExpr(fcx, instrs, expr.lhs, true);
return {
kind: "memoryField",
parts: fieldParts(expr.lhs.ty, expr.field.fieldIdx!),
};
}
default:
unreachable("invalid field access lhs ty");
}
}
default: {
return undefined;
}
}
}
function lowerExpr(
fcx: FuncContext,
instrs: wasm.Instr[],
expr: Expr<Typecked>,
// Note: This is only forwarded through field accesses.
noRefcountIfLvalue = false,
) {
const ty = expr.ty;
@ -424,41 +540,29 @@ function lowerExpr(
}
case "assign": {
lowerExpr(fcx, instrs, expr.rhs);
const { lhs } = expr;
switch (lhs.kind) {
case "ident":
case "path": {
const { res } = lhs.value;
switch (res.kind) {
case "local": {
const location =
fcx.varLocations[fcx.varLocations.length - 1 - res.index];
storeVariable(instrs, location);
break;
}
case "item": {
const item = fcx.cx.gcx.findItem(res.id);
if (item.kind !== "global") {
throw new Error("cannot store to non-global item");
}
const lvalue = unwrap(
tryLowerLValue(fcx, instrs, expr.lhs),
"invalid assign lhs",
);
const instr: wasm.Instr = { kind: "global.set", imm: DUMMY_IDX };
const rel: Relocation = { kind: "globalref", instr, res };
fcx.cx.relocations.push(rel);
instrs.push(instr);
break;
}
case "builtin": {
throw new Error("cannot store to builtin");
}
}
switch (lvalue.kind) {
case "fullLocal": {
storeVariable(instrs, lvalue.loc);
break;
}
default: {
throw new Error("invalid lhs side of assignment");
case "global": {
const instr: wasm.Instr = { kind: "global.set", imm: DUMMY_IDX };
const rel: Relocation = { kind: "globalref", instr, res: lvalue.res };
fcx.cx.relocations.push(rel);
instrs.push(instr);
break;
}
case "localTupleField":
todo("local tuple fields");
case "memoryField": {
storeMemory(fcx, instrs, lvalue.parts);
}
}
@ -520,7 +624,7 @@ function lowerExpr(
const refcount = needsRefcount(expr.ty);
if (refcount !== undefined) {
if (!noRefcountIfLvalue && refcount !== undefined) {
addRefcount(
fcx,
instrs,
@ -766,18 +870,6 @@ function lowerExpr(
break;
}
case "fieldAccess": {
// We could just naively always evaluate the LHS normally, but that's kinda
// stupid as it would cause way too much code for `let a = (0, 0, 0); a.0`
// as that operation would first load the entire tuple onto the stack!
// Therefore, we are a little clever be peeking into the LHS and doing
// something smarter if it's another field access or ident (in the future,
// we should be able to generalize this to all "places"/"lvalues").
// TODO: Actually do this instead of being naive.
const _isPlace = (expr: Expr<Typecked>) =>
expr.kind === "ident" || expr.kind === "fieldAccess";
lowerExpr(fcx, instrs, expr.lhs);
switch (expr.lhs.ty.kind) {
@ -963,38 +1055,9 @@ function lowerExpr(
// Now, set all fields.
expr.fields.forEach((field, i) => {
instrs.push({ kind: "local.get", imm: ptrLocal });
lowerExpr(fcx, instrs, field.expr);
const fieldLayout = [...layout.fields[i].types];
fieldLayout.reverse();
fieldLayout.forEach((fieldPart) => {
switch (fieldPart.type) {
case "i32":
instrs.push({
kind: "i32.store",
imm: {
align: sizeOfValtype(fieldPart.type),
offset: fieldPart.offset,
},
});
break;
case "i64":
instrs.push({
kind: "i64.store",
imm: {
align: sizeOfValtype(fieldPart.type),
offset: fieldPart.offset,
},
});
break;
default: {
throw new Error(
`unsupported struct content type: ${fieldPart.type}`,
);
}
}
});
instrs.push({ kind: "local.get", imm: ptrLocal });
storeMemory(fcx, instrs, layout.fields[i].types);
});
// Last, load the pointer and pass that on.
@ -1125,6 +1188,37 @@ function storeVariable(instrs: wasm.Instr[], loc: VarLocation) {
});
}
/**
* Generates TYPE.store instructions for memory parts.
* STACK IN: [...types_, i32 (base pointer)]
* STACK OUT: []
*/
function storeMemory(
fcx: FuncContext,
instrs: wasm.Instr[],
types_: MemoryLayoutPart[],
) {
const ptr = getScratchLocals(fcx, "i32", 1)[0];
instrs.push({ kind: "local.set", imm: ptr });
const types = [...types_];
types.reverse();
types.forEach(({ type, offset }) => {
if (type === "externref" || type === "funcref") {
unreachable("non-i32/i64 stores");
}
const val = getScratchLocals(fcx, type, 1)[0];
instrs.push({ kind: "local.set", imm: val });
instrs.push({ kind: "local.get", imm: ptr });
instrs.push({ kind: "local.get", imm: val });
const kind: wasm.SimpleStoreKind = `${type}.store`;
instrs.push({
kind,
imm: { offset, align: sizeOfValtype(type) },
});
});
}
function argRetAbi(param: Ty): ArgRetAbi {
switch (param.kind) {
case "string":

View file

@ -14,35 +14,28 @@ import { loadCrate } from "./loader";
const INPUT = `
type A = struct { a: Int };
type Uwu = (Int, Int);
function main() = (
uwu();
);
function uwu() = (
let a = A { a: 100 };
eat(a /*+1*/);
A { a: 100 };
/*-1*/
);
type B = struct {
a: (Int, Int, Int, Int, Int),
type Complex = struct {
a: Int,
b: (Int, A),
};
function test(b: B) = (
b.a;
function main() = (
let a = A { a: 0 };
// let b = 0;
let c = Complex { a: 0, b: (0, a) };
write(a);
std.printlnInt(a.a);
// ptr = c + offset(b) + offset(1)
// a = load(ptr)
// store(a + offset(a), 1)
// c.b.1.a = 1;
);
mod aa (
global UWU: Int = 0;
);
function write(a: A) = a.a = 1;
function eat(a: A) =;
function ret(): A = A { a: 0 };
`;
function main() {

View file

@ -676,6 +676,10 @@ export function checkBody(
}
break;
}
case "fieldAccess": {
checkLValue(lhs);
break;
}
default: {
throw new CompilerError(
"invalid left-hand side of assignment",
@ -949,6 +953,19 @@ export function checkBody(
return resolved;
}
function checkLValue(expr: Expr<Typecked>) {
switch (expr.kind) {
case "ident":
case "path":
break;
case "fieldAccess":
checkLValue(expr.lhs);
break;
default:
throw new CompilerError("invalid left-hand side of assignment", expr.span);
}
}
function checkBinary(
fcx: FuncCtx,
expr: Expr<Resolved> & ExprBinary<Resolved>,

View file

@ -10,9 +10,9 @@ export class Ids {
}
}
export function unwrap<T>(value: T | undefined): T {
export function unwrap<T>(value: T | undefined, msg?: string): T {
if (value === undefined) {
throw new Error("tried to unwrap undefined value");
throw new Error(msg ?? "tried to unwrap undefined value");
}
return value;
}

View file

@ -194,9 +194,13 @@ export type MemArg = {
align?: u32;
};
export type SimpleStoreKind = `${`${"i" | "f"}${BitWidth}` | "v128"}.${
| "load"
| "store"}`;
export type MemoryInstr =
| {
kind: `${`${"i" | "f"}${BitWidth}` | "v128"}.${"load" | "store"}`;
kind: SimpleStoreKind;
imm: MemArg;
}
| {

View file

@ -67,7 +67,7 @@ function deallocateItem(ptr: I32, objSize: I32) = (
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
global HEAP_START: I32 = 2048_I32;
global HEAP_START: HeapPtr = 2048_I32;
// heap size = start+end+bin_t*BIN_COUNT
// 4+ 4+ 4* 9 = 8+36=42 (round to 64)
global HEAP_REGION_START: I32 = 2112_I32;
@ -78,6 +78,10 @@ global HEAP_REGION_START: I32 = 2112_I32;
// struct node_t* next;
// struct node_t* prev;
// } node_t;
global NODE_HOLE: I32 = 0_I32;
global NODE_SIZE: I32 = 4_I32;
global NODE_NEXT: I32 = 8_I32;
global NODE_PREV: I32 = 12_I32;
// typedef struct {
// node_t *header;
@ -90,23 +94,77 @@ global HEAP_REGION_START: I32 = 2112_I32;
global SIZEOF_NODE: I32 = 16_I32;
global SIZEOF_FOOTER: I32 = 4_I32;
type HeapPtr = I32;
type NodePtr = I32;
type FootPtr = I32;
type BinPtr = I32;
function initHeap() = (
let heap_init_size = 65536_I32 - HEAP_REGION_START;
__i32_store(HEAP_REGION_START, 1_I32); // START.hole =
__i32_store(HEAP_REGION_START + 4_I32, heap_init_size - SIZEOF_NODE - SIZEOF_FOOTER); // START.size =
__i32_store(HEAP_REGION_START + NODE_HOLE, 1_I32); // START.hole =
__i32_store(HEAP_REGION_START + NODE_SIZE, heap_init_size - SIZEOF_NODE - SIZEOF_FOOTER); // START.size =
createFoot(HEAP_REGION_START);
);
function createFoot(head_node: I32) = (
function createFoot(head_node: NodePtr) = (
let foot = getFoot(head_node);
__i32_store(foot, head_node); // foot.header = head_node
);
function getFoot(node: I32): I32 = (
let node_size = __i32_load(node + 4_I32); // node.size
function getFoot(node: NodePtr): FootPtr = (
let node_size = __i32_load(node + NODE_SIZE);
node + SIZEOF_NODE + node_size
);
function addNode(bin: I32, node: I32) = ;
function getWilderness() =;
// llist.c
function addNode(bin: BinPtr, node: NodePtr) = (
__i32_store(node + NODE_NEXT, 0_I32); // node.next =
__i32_store(node + NODE_PREV, 0_I32); // node.prev =
let bin_head: NodePtr = __i32_load(bin); // bin.head
if (bin_head == 0_I32) then
__i32_store(bin, node) // bin.head =
else (
let current: NodePtr = bin_head;
let previous: NodePtr = 0_I32;
loop (
if (current != 0_I32)
& (__i32_load(current + NODE_SIZE) // current.size
<= __i32_load(node + NODE_SIZE)) // node.size
then break;
previous = current;
current = __i32_load(current + NODE_NEXT); // current.next
);
if (current == 0_I32) then (
__i32_store(previous + NODE_NEXT, node); // previous.next
__i32_store(node + NODE_PREV, previous); // node.prev
) else (
if (previous != 0_I32) then (
__i32_store(node + NODE_NEXT, current); // node.next
__i32_store(previous + NODE_NEXT, node); // previous.next
__i32_store(node + NODE_PREV, previous); // node.prev
__i32_store(current + NODE_PREV, node); // current.prev
) else (
__i32_store(node + NODE_NEXT, __i32_load(bin)); // node.next = bin.head
__i32_store(__i32_load(bin) + NODE_PREV, node); // bin.head.prev = node;
__i32_store(bin, node); // bin.head
);
);
)
);
function removeNode(bin: BinPtr, node: NodePtr) = (
);
function test() =;

View file

@ -1,5 +1,6 @@
extern mod std;
type A = struct { a: Int };
function main() = (
std.printInt(10);
let a = A { a: 0 };
a.a = 1;
);