inline assembly :3

This commit is contained in:
nora 2023-12-15 22:31:13 +01:00
parent 3ab116d7f0
commit ba5d41674c
17 changed files with 448 additions and 45 deletions

View file

@ -1,6 +1,7 @@
import { ErrorEmitted, LoadedFile, Span, unreachable } from "./error";
import { LitIntType } from "./lexer";
import { ComplexMap } from "./utils";
import { Instr, ValType } from "./wasm/defs";
export type Phase = {
res: unknown;
@ -308,6 +309,12 @@ export type ExprTupleLiteral<P extends Phase> = {
fields: Expr<P>[];
};
export type ExprInlineAsm = {
kind: "asm";
locals: ValType[];
instructions: Instr[];
};
export type ExprError = {
kind: "error";
err: ErrorEmitted;
@ -330,6 +337,7 @@ export type ExprKind<P extends Phase> =
| ExprBreak
| ExprStructLiteral<P>
| ExprTupleLiteral<P>
| ExprInlineAsm
| ExprError;
export type Expr<P extends Phase> = ExprKind<P> & {
@ -485,6 +493,8 @@ export const BUILTINS = [
"__memory_grow",
"__i32_extend_to_i64_u",
"___transmute",
"___asm",
"__locals",
] as const;
export type BuiltinName = (typeof BUILTINS)[number];
@ -920,6 +930,9 @@ export function superFoldExpr<From extends Phase, To extends Phase>(
fields: expr.fields.map(folder.expr.bind(folder)),
};
}
case "asm": {
return { ...expr };
}
case "error": {
return { ...expr };
}

View file

@ -118,9 +118,7 @@ function appendData(cx: Context, newData: Uint8Array): number {
const KNOWN_DEF_PATHS = [ALLOCATE_ITEM, DEALLOCATE_ITEM];
function getKnownDefPaths(
pkgs: Pkg<Typecked>[],
): ComplexMap<string[], ItemId> {
function getKnownDefPaths(pkgs: Pkg<Typecked>[]): ComplexMap<string[], ItemId> {
const knows = new ComplexMap<string[], ItemId>();
const folder: Folder<Typecked, Typecked> = {
@ -145,9 +143,7 @@ function getKnownDefPaths(
},
};
pkgs.forEach((pkg) =>
pkg.rootItems.forEach((item) => folder.item(item)),
);
pkgs.forEach((pkg) => pkg.rootItems.forEach((item) => folder.item(item)));
return knows;
}
@ -380,16 +376,22 @@ function lowerFunc(cx: Context, func: ItemFunction<Typecked>) {
scratchLocals: new Map(),
};
lowerExpr(fcx, wasmFunc.body, fcx.func.body);
const body = fcx.func.body;
if (body.kind === "asm") {
fcx.wasm.locals = body.locals;
fcx.wasm.body = body.instructions;
} else {
lowerExpr(fcx, wasmFunc.body, body);
paramLocations.forEach((local) => {
const refcount = needsRefcount(local.ty);
if (refcount !== undefined) {
// TODO: correctly deal with tuples
loadVariable(wasmFunc.body, local);
subRefcount(fcx, wasmFunc.body, refcount);
}
});
paramLocations.forEach((local) => {
const refcount = needsRefcount(local.ty);
if (refcount !== undefined) {
// TODO: correctly deal with tuples
loadVariable(wasmFunc.body, local);
subRefcount(fcx, wasmFunc.body, refcount);
}
});
}
const idx = fcx.cx.mod.funcs.length;
fcx.cx.mod.funcs.push(wasmFunc);
@ -1084,6 +1086,9 @@ function lowerExpr(
expr.fields.forEach((field) => lowerExpr(fcx, instrs, field));
break;
}
case "asm": {
unreachable("asm");
}
case "error":
unreachable("codegen should never see errors");
default: {

View file

@ -211,6 +211,12 @@ function printExpr(expr: Expr<AnyPhase>, indent: number): string {
.map((expr) => printExpr(expr, indent))
.join(", ")})`;
}
case "asm": {
return `___asm(___locals(${expr.locals
.map((valty) => `"${valty}"`)
// object object
.join(", ")}), ${expr.instructions.join(", ")})`;
}
case "error":
return "<ERROR>";
}

View file

@ -28,12 +28,10 @@ import {
} from "../ast";
import { CompilerError, ErrorEmitted, Span, unreachable } from "../error";
import { printTy } from "../printer";
import { INSTRS, Instr, VALTYPES, ValType } from "../wasm/defs";
import { TypeckCtx, emitError, mkTyFn, tyError, tyErrorFrom } from "./base";
import { InferContext } from "./infer";
import {
lowerAstTy,
typeOfItem,
} from "./item";
import { lowerAstTy, typeOfItem } from "./item";
export function exprError(err: ErrorEmitted, span: Span): Expr<Typecked> {
return {
@ -129,6 +127,15 @@ export function checkBody(
checkExpr: () => unreachable(),
};
if (
body.kind === "call" &&
body.lhs.kind === "ident" &&
body.lhs.value.res.kind === "builtin" &&
body.lhs.value.res.name === "___asm"
) {
return checkInlineAsm(cx, body, fnTy.returnTy);
}
const checker: Folder<Resolved, Typecked> = {
...mkDefaultFolder(),
expr(expr): Expr<Typecked> {
@ -503,6 +510,9 @@ export function checkBody(
return { ...expr, fields, ty };
}
case "asm": {
unreachable("asm expression doesn't exist before type checking");
}
case "error": {
return { ...expr, ty: tyErrorFrom(expr) };
}
@ -530,6 +540,90 @@ export function checkBody(
return resolved;
}
function checkInlineAsm(
cx: TypeckCtx,
body: Expr<Resolved> & ExprCall<Resolved>,
retTy: Ty,
): Expr<Typecked> {
const err = (msg: string, span: Span): Expr<Typecked> =>
exprError(emitError(cx, new CompilerError(msg, span)), span);
const args = body.args;
if (
args.length < 1 ||
args[0].kind !== "call" ||
args[0].lhs.kind !== "ident" ||
args[0].lhs.value.res.kind !== "builtin" ||
args[0].lhs.value.res.name !== "__locals"
) {
return err(
"inline assembly must have __locals() as first argument",
body.span,
);
}
const locals: ValType[] = [];
for (const local of args[0].args) {
const isValtype = (s: string): s is ValType =>
VALTYPES.includes(s as ValType);
if (
local.kind !== "literal" ||
local.value.kind !== "str" ||
!isValtype(local.value.value)
) {
return err(
"inline assembly local must be string literal of value type",
local.span,
);
}
locals.push(local.value.value);
}
const instructions: Instr[] = [];
for (const expr of args.slice(1)) {
if (expr.kind !== "literal" || expr.value.kind !== "str") {
return err(
"inline assembly instruction must be string literal with instruction",
expr.span,
);
}
const text = expr.value.value;
const parts = text.split(" ");
const imms = parts.slice(1);
const wasmInstr = INSTRS.find((instrVal) => instrVal.name === parts[0]);
if (!wasmInstr) {
return err(`unknown instruction: ${parts[0]}`, expr.span);
}
if (wasmInstr.immediates === "select") {
throw new Error("todo: select");
} else if (wasmInstr.immediates === "memarg") {
throw new Error("todo: memarg");
} else {
if (imms.length !== wasmInstr.immediates.length) {
return err(
`mismatched immediate lengths, expected ${wasmInstr.immediates.length}, got ${imms.length}`,
expr.span,
);
}
if (wasmInstr.immediates.length > 1) {
throw new Error("todo: immediates");
}
if (wasmInstr.immediates.length === 0) {
instructions.push({ kind: wasmInstr.name } as Instr);
} else {
instructions.push({
kind: wasmInstr.name,
imm: Number(imms[0]),
} as Instr);
}
}
}
return { kind: "asm", locals, ty: retTy, instructions, span: body.span };
}
function checkLValue(cx: TypeckCtx, expr: Expr<Typecked>) {
switch (expr.kind) {
case "ident":
@ -641,21 +735,19 @@ function checkCall(
fcx: FuncCtx,
expr: ExprCall<Resolved> & Expr<Resolved>,
): Expr<Typecked> {
if (
expr.lhs.kind === "ident" &&
expr.lhs.value.res.kind === "builtin" &&
expr.lhs.value.res.name === "___transmute"
) {
const ty = fcx.infcx.newVar();
const args = expr.args.map((arg) => fcx.checkExpr(arg));
const ret: Expr<Typecked> = {
...expr,
lhs: { ...expr.lhs, ty: TY_UNIT },
args,
ty,
};
if (expr.lhs.kind === "ident" && expr.lhs.value.res.kind === "builtin") {
if (expr.lhs.value.res.name === "___transmute") {
const ty = fcx.infcx.newVar();
const args = expr.args.map((arg) => fcx.checkExpr(arg));
const ret: Expr<Typecked> = {
...expr,
lhs: { ...expr.lhs, ty: TY_UNIT },
args,
ty,
};
return ret;
return ret;
}
}
const lhs = fcx.checkExpr(expr.lhs);

View file

@ -19,6 +19,15 @@ export type Vectype = "v128";
export type Reftype = "funcref" | "externref";
export type ValType = Numtype | Vectype | Reftype;
export const VALTYPES: ValType[] = [
"i32",
"i64",
"f32",
"f64",
"v128",
"funcref",
"externref",
];
export type ResultType = ValType[];
@ -66,11 +75,25 @@ export type Externtype =
// instructions
// Value representations of the types for the assembler
export type InstrValue = {
name: Instr["kind"];
immediates: ImmediateValue[] | "select" | "memarg";
};
type ImmediateValue = "i32" | "i64" | "f32" | "f64" | "refkind";
const ins = (name: Instr["kind"], immediates: InstrValue["immediates"]) => ({
name,
immediates,
});
// . numeric
export type BitWidth = "32" | "64";
const BIT_WIDTHS: BitWidth[] = ["32", "64"];
export type Sign = "u" | "s";
const SIGNS: Sign[] = ["u", "s"];
export type NumericInstr =
| { kind: "i32.const"; imm: bigint }
@ -86,7 +109,7 @@ export type NumericInstr =
| { kind: `f${BitWidth}.${FRelOp}` }
| { kind: `i${BitWidth}.extend8_s` }
| { kind: `i${BitWidth}.extend16_s` }
| { kind: `i64.extend32_s` }
| { kind: "i64.extend32_s" }
| { kind: "i32.wrap_i64" }
| { kind: `i64.extend_i32_${Sign}` }
| { kind: `i${BitWidth}.trunc_f${BitWidth}_${Sign}` }
@ -98,6 +121,7 @@ export type NumericInstr =
| { kind: "f32.reinterpret_i32" | "f64.reinterpret_i64" };
export type IUnOp = "clz" | "ctz" | "popcnt";
const I_UN_OPS: IUnOp[] = ["clz", "ctz", "popcnt"];
export type IBinOp =
| "add"
@ -112,19 +136,47 @@ export type IBinOp =
| `shr_${Sign}`
| "rotl"
| "rotr";
const I_BIN_OPS: IBinOp[] = [
"add",
"sub",
"mul",
"and",
"or",
"xor",
"shl",
"rotl",
"rotr",
...SIGNS.flatMap((sign): IBinOp[] => [
`div_${sign}`,
`rem_${sign}`,
`shr_${sign}`,
]),
];
export type FUnOp =
| "abs"
| "neg"
| "sqrt"
| "ceil"
| "floor"
| "trunc"
| "nearest";
const F_UN_OPS = [
"abs",
"neg",
"sqrt",
"ceil",
"floor",
"trunc",
"nearest",
] as const;
export type FUnOp = (typeof F_UN_OPS)[number];
export type FBinOp = "add" | "sub" | "mul" | "div" | "min" | "max" | "copysign";
const F_BIN_OPS = [
"add",
"sub",
"mul",
"div",
"min",
"max",
"copysign",
] as const;
export type FBinOp = (typeof F_BIN_OPS)[number];
export type ITestOp = "eqz";
const I_TEST_OPS = ["eqz"] as const;
export type ITestOp = (typeof I_TEST_OPS)[number];
export type IRelOp =
| "eq"
@ -133,12 +185,62 @@ export type IRelOp =
| `gt_${Sign}`
| `le_${Sign}`
| `ge_${Sign}`;
const I_REL_OPS: IRelOp[] = [
"eq",
"ne",
...SIGNS.flatMap((sign): IRelOp[] => [
`lt_${sign}`,
`gt_${sign}`,
`le_${sign}`,
`ge_${sign}`,
]),
];
export type FRelOp = "eq" | "ne" | "lt" | "gt" | "le" | "ge";
const F_REL_OPS = ["eq", "ne", "lt", "gt", "le", "ge"] as const;
export type FRelOp = (typeof F_REL_OPS)[number];
const NO_IMM_NUMERIC_INSTRS: NumericInstr["kind"][] = [
...BIT_WIDTHS.flatMap((bitWidth): NumericInstr["kind"][] => [
...I_UN_OPS.map((op): NumericInstr["kind"] => `i${bitWidth}.${op}`),
...F_UN_OPS.map((op): NumericInstr["kind"] => `f${bitWidth}.${op}`),
...I_BIN_OPS.map((op): NumericInstr["kind"] => `i${bitWidth}.${op}`),
...F_BIN_OPS.map((op): NumericInstr["kind"] => `f${bitWidth}.${op}`),
...I_TEST_OPS.map((op): NumericInstr["kind"] => `i${bitWidth}.${op}`),
...I_REL_OPS.map((op): NumericInstr["kind"] => `i${bitWidth}.${op}`),
...F_REL_OPS.map((op): NumericInstr["kind"] => `f${bitWidth}.${op}`),
`i${bitWidth}.extend8_s`,
`i${bitWidth}.extend16_s`,
...BIT_WIDTHS.flatMap((bitWidth2): NumericInstr["kind"][] =>
SIGNS.flatMap((sign): NumericInstr["kind"][] => [
`i${bitWidth}.trunc_f${bitWidth2}_${sign}`,
`i${bitWidth}.trunc_sat_f${bitWidth2}_${sign}`,
`f${bitWidth}.convert_i${bitWidth2}_${sign}`,
]),
),
]),
"i64.extend32_s",
"i32.wrap_i64",
...SIGNS.flatMap((sign): NumericInstr["kind"] => `i64.extend_i32_${sign}`),
"f32.demote_f64",
"f64.promote_f32",
"i32.reinterpret_f32",
"i64.reinterpret_f64",
"f32.reinterpret_i32",
"f64.reinterpret_i64",
];
const IMM_NUMERIC_INSTRS: InstrValue[] = (
["i32", "i64", "f32", "f64"] as const
).map((type) => ins(`${type}.const`, [type]));
const NUMERIC_INSTRS: InstrValue[] = [
...NO_IMM_NUMERIC_INSTRS.map((instr) => ins(instr, [])),
...IMM_NUMERIC_INSTRS,
];
// . vectors
export type VectorInstr = never;
const VECTOR_INSTRS: InstrValue[] = [];
// . reference
@ -146,12 +248,21 @@ export type ReferenceInstr =
| { kind: "ref.null"; imm: Reftype }
| { kind: "ref.is_null" }
| { kind: "ref.func"; imm: FuncIdx };
const REFERENCE_INSTRS: InstrValue[] = [
ins("ref.null", ["refkind"]),
ins("ref.is_null", []),
ins("ref.func", ["i32"]),
];
// . parametric
export type ParametricInstr =
| { kind: "drop" }
| { kind: "select"; type?: ValType[] };
const PARAMETRIC_INSTRS: InstrValue[] = [
ins("drop", []),
ins("select", "select"),
];
// . variable
@ -164,6 +275,12 @@ export type VariableInstr =
kind: `global.${"get" | "set"}`;
imm: LocalIdx;
};
const VARIABLE_INSTR: InstrValue[] = [
...(["get", "set", "tee"] as const).map((kind) =>
ins(`local.${kind}`, ["i32"]),
),
...(["get", "set"] as const).map((kind) => ins(`global.${kind}`, ["i32"])),
];
// . table
@ -186,6 +303,14 @@ export type TableInstr =
kind: "elem.drop";
imm: ElemIdx;
};
const TABLE_INSTRS: InstrValue[] = [
...(["get", "set", "size", "grow", "fill"] as const).map((kind) =>
ins(`table.${kind}`, ["i32"]),
),
ins("table.copy", ["i32", "i32"]),
ins("table.init", ["i32", "i32"]),
ins("elem.drop", ["i32"]),
];
// . memory
@ -197,6 +322,18 @@ export type MemArg = {
export type SimpleStoreKind = `${`${"i" | "f"}${BitWidth}` | "v128"}.${
| "load"
| "store"}`;
const SIMPLE_STORES: SimpleStoreKind[] = [
"i32.load",
"i32.store",
"f32.load",
"f32.store",
"i64.load",
"i64.store",
"f64.load",
"f64.store",
"v128.load",
"v128.store",
];
export type MemoryInstr =
| {
@ -221,6 +358,35 @@ export type MemoryInstr =
imm: DataIdx;
}
| { kind: "data.drop"; imm: DataIdx };
const MEMORY_INSTRS_WITH_MEMARG: MemoryInstr["kind"][] = [
...SIMPLE_STORES,
"i32.load8_u",
"i32.load8_s",
"i32.load16_u",
"i32.load16_s",
"i64.load8_u",
"i64.load8_s",
"i64.load16_u",
"i64.load16_s",
//
"i64.load32_u",
"i64.load32_s",
//
"i32.store8",
"i32.store16",
"i64.store8",
"i64.store16",
"i64.store32",
];
const MEMORY_INSTRS: InstrValue[] = [
...MEMORY_INSTRS_WITH_MEMARG.map((kind) => ins(kind, "memarg")),
ins("memory.size", []),
ins("memory.grow", []),
ins("memory.fill", []),
ins("memory.copy", []),
ins("memory.init", ["i32"]),
ins("data.drop", ["i32"]),
];
// . control
@ -278,6 +444,16 @@ export type Instr =
| MemoryInstr
| ControlInstr;
export const INSTRS: InstrValue[] = [
...NUMERIC_INSTRS,
...VECTOR_INSTRS,
...REFERENCE_INSTRS,
...PARAMETRIC_INSTRS,
...VARIABLE_INSTR,
...TABLE_INSTRS,
...MEMORY_INSTRS,
];
export type Expr = Instr[];
// Modules

10
ui-tests/asm/drop.nil Normal file
View file

@ -0,0 +1,10 @@
//@check-pass
function dropping(a: I32) =
___asm(
__locals(),
"local.get 0",
"drop",
);
function main() = ;

View file

@ -0,0 +1,7 @@
function a(a: I32) =
___asm(
__locals(),
0,
);
function main() = ;

View file

@ -0,0 +1,4 @@
error: inline assembly instruction must be string literal with instruction
--> $DIR/instr_not_string.nil:4
4 | 0,
^

View file

@ -0,0 +1,9 @@
//@check-pass
function dropping(a: I32) =
___asm(
__locals(),
"meow meow",
);
function main() = ;

View file

@ -0,0 +1,4 @@
error: unknown instruction: meow
--> $DIR/invalid_instr.nil:6
6 | "meow meow",
^^^^^^^^^^^

View file

@ -0,0 +1,9 @@
//@check-pass
function dropping(a: I32) =
___asm(
"local.get 0",
"drop",
);
function main() = ;

View file

@ -0,0 +1,4 @@
error: inline assembly must have __locals() as first argument
--> $DIR/missing_locals.nil:4
4 | ___asm(
^

View file

@ -0,0 +1,6 @@
function dropping(a: I32) = (
1;
___asm(__locals(), "drop");
);
function main() = ;

View file

@ -0,0 +1,16 @@
error: `___asm` cannot be used as a value
--> $DIR/not_toplevel.nil:3
3 | ___asm(__locals(), "drop");
^^^^^^
error: `__locals` cannot be used as a value
--> $DIR/not_toplevel.nil:3
3 | ___asm(__locals(), "drop");
^^^^^^^^
error: expression of type <ERROR> is not callable
--> $DIR/not_toplevel.nil:3
3 | ___asm(__locals(), "drop");
^^^^^^^^
error: expression of type <ERROR> is not callable
--> $DIR/not_toplevel.nil:3
3 | ___asm(__locals(), "drop");
^^^^^^

View file

@ -0,0 +1,17 @@
//@check-pass
function a(a: I32) =
___asm(
__locals(),
"local.get 0 0",
"drop",
);
function b(a: I32) =
___asm(
__locals(),
"local.get",
"drop",
);
function main() = ;

View file

@ -0,0 +1,8 @@
error: mismatched immediate lengths, expected 1, got 2
--> $DIR/wrong_imm.nil:6
6 | "local.get 0 0",
^^^^^^^^^^^^^^^
error: mismatched immediate lengths, expected 1, got 0
--> $DIR/wrong_imm.nil:13
13 | "local.get",
^^^^^^^^^^^

View file

@ -0,0 +1,17 @@
/home/nils/projects/riverdelta/target/ast.js:110
throw new Error(`substitution out of range, param index ${ty.idx} of param ${ty.name} out of range for length ${genericArgs.length}`);
^
Error: substitution out of range, param index 0 of param T out of range for length 0
at substituteTy (/home/nils/projects/riverdelta/target/ast.js:110:23)
at subst (/home/nils/projects/riverdelta/target/ast.js:106:27)
at Array.map (<anonymous>)
at substituteTy (/home/nils/projects/riverdelta/target/ast.js:125:45)
at typeOfItem (/home/nils/projects/riverdelta/target/typeck/item.js:193:33)
at Object.itemInner (/home/nils/projects/riverdelta/target/typeck/index.js:63:54)
at Object.item (/home/nils/projects/riverdelta/target/ast.js:146:34)
at /home/nils/projects/riverdelta/target/ast.js:164:55
at Array.map (<anonymous>)
at foldAst (/home/nils/projects/riverdelta/target/ast.js:164:34)
Node.js v18.18.2