import { CompilerError, SourceLocation } from "..";
import {
CallExpression,
Effect,
Environment,
FinishMemoize,
FunctionExpression,
HIRFunction,
IdentifierId,
Instruction,
InstructionId,
InstructionValue,
LoadGlobal,
LoadLocal,
ManualMemoDependency,
MethodCall,
Place,
PropertyLoad,
SpreadPattern,
StartMemoize,
TInstruction,
getHookKindForType,
makeInstructionId,
} from "../HIR";
import { createTemporaryPlace, markInstructionIds } from "../HIR/HIRBuilder";
/**
 * A recorded load of a manual memoization hook. It is either a `LoadGlobal`
 * whose global resolves to useMemo/useCallback, or a `PropertyLoad` of
 * `useMemo` / `useCallback` off a value loaded from a global named `React`.
 */
type ManualMemoCallee = {
  // Which hook was loaded.
  kind: "useMemo" | "useCallback";
  // The instruction that loaded the hook. When validation is enabled, the
  // StartMemoize marker is inserted immediately after this instruction.
  loadInstr: TInstruction<LoadGlobal> | TInstruction<PropertyLoad>;
};
/**
 * Facts about temporaries gathered while scanning instructions in order,
 * keyed by the identifier id of the instruction's lvalue.
 */
type IdentifierSidemap = {
  // Lvalues holding a FunctionExpression; used to check that a memo callback
  // is an inline function expression.
  functions: Map<IdentifierId, TInstruction<FunctionExpression>>;
  // Lvalues holding a load of useMemo / useCallback.
  manualMemos: Map<IdentifierId, ManualMemoCallee>;
  // Lvalues holding a LoadGlobal of a binding named `React`.
  react: Set<IdentifierId>;
  // Lvalues holding an array literal whose elements are all plain
  // identifiers, i.e. candidate dependency lists.
  maybeDepsLists: Map<IdentifierId, Array<Place>>;
  // Identifiers that resolve to a simple dependency (a global or named local
  // root plus a property path). Also includes unnamed StoreLocal targets.
  maybeDeps: Map<IdentifierId, ManualMemoDependency>;
};
/**
 * Attempts to describe the value produced by `value` as a simple memo
 * dependency: a root (a global, or a named local) followed by a path of
 * property names, e.g. `x` or `x.y.z`.
 *
 * @param value the instruction value to inspect
 * @param maybeDeps dependencies already resolved for earlier identifiers,
 *   keyed by identifier id. For a `StoreLocal` whose source is already known
 *   and whose target is not a "named" identifier, the target is added to this
 *   map as an alias of the same dependency.
 * @returns the dependency, or null when `value` is not a supported form
 */
export function collectMaybeMemoDependencies(
  value: InstructionValue,
  maybeDeps: Map<IdentifierId, ManualMemoDependency>
): ManualMemoDependency | null {
  if (value.kind === "LoadGlobal") {
    return {
      root: { kind: "Global", identifierName: value.binding.name },
      path: [],
    };
  }

  if (value.kind === "PropertyLoad") {
    const base = maybeDeps.get(value.object.identifier.id);
    if (base == null) {
      return null;
    }
    // Extend the base dependency's path with the property being read.
    return { root: base.root, path: [...base.path, value.property] };
  }

  if (value.kind === "LoadLocal" || value.kind === "LoadContext") {
    const { place } = value;
    const known = maybeDeps.get(place.identifier.id);
    if (known != null) {
      return known;
    }
    const name = place.identifier.name;
    if (name == null || name.kind !== "named") {
      return null;
    }
    // A named local becomes the root of a fresh dependency; copy the place
    // so the dependency does not share the instruction's Place object.
    return { root: { kind: "NamedLocal", value: { ...place } }, path: [] };
  }

  if (value.kind === "StoreLocal") {
    const target = value.lvalue.place.identifier;
    const aliased = maybeDeps.get(value.value.identifier.id);
    if (aliased == null || target.name?.kind === "named") {
      return null;
    }
    // Storing a known dependency into an unnamed target: alias the target.
    maybeDeps.set(target.id, aliased);
    return aliased;
  }

  return null;
}
/**
 * Records what a single (non-call) instruction tells us about its lvalue:
 * - a FunctionExpression lvalue goes into `sidemap.functions`;
 * - a LoadGlobal whose global declaration has hook kind useMemo/useCallback
 *   goes into `sidemap.manualMemos`. Otherwise, a LoadGlobal of a binding
 *   named `React` goes into `sidemap.react`;
 * - a PropertyLoad of `useMemo` / `useCallback` off a recorded `React` value
 *   goes into `sidemap.manualMemos`;
 * - an ArrayExpression whose elements are all identifiers goes into
 *   `sidemap.maybeDepsLists`.
 * Finally, any instruction that resolves to a simple dependency is recorded
 * in `sidemap.maybeDeps`.
 */
function collectTemporaries(
  instr: Instruction,
  env: Environment,
  sidemap: IdentifierSidemap
): void {
  const { value } = instr;
  const lvalId = instr.lvalue.identifier.id;

  if (value.kind === "FunctionExpression") {
    sidemap.functions.set(lvalId, instr as TInstruction<FunctionExpression>);
  } else if (value.kind === "LoadGlobal") {
    const globalDecl = env.getGlobalDeclaration(value.binding);
    const hookKind =
      globalDecl === null ? null : getHookKindForType(env, globalDecl);
    if (hookKind === "useMemo" || hookKind === "useCallback") {
      sidemap.manualMemos.set(lvalId, {
        kind: hookKind,
        loadInstr: instr as TInstruction<LoadGlobal>,
      });
    } else if (value.binding.name === "React") {
      sidemap.react.add(lvalId);
    }
  } else if (value.kind === "PropertyLoad") {
    const { property } = value;
    // Matches `React.useMemo` / `React.useCallback`.
    if (
      (property === "useMemo" || property === "useCallback") &&
      sidemap.react.has(value.object.identifier.id)
    ) {
      sidemap.manualMemos.set(lvalId, {
        kind: property,
        loadInstr: instr as TInstruction<PropertyLoad>,
      });
    }
  } else if (value.kind === "ArrayExpression") {
    const { elements } = value;
    // Only all-identifier array literals can serve as dependency lists.
    if (elements.every((element) => element.kind === "Identifier")) {
      sidemap.maybeDepsLists.set(lvalId, elements as Array<Place>);
    }
  }

  const maybeDep = collectMaybeMemoDependencies(value, sidemap.maybeDeps);
  if (maybeDep != null) {
    sidemap.maybeDeps.set(lvalId, maybeDep);
  }
}
/**
 * Builds the StartMemoize / FinishMemoize marker pair that brackets one
 * former manual memoization call.
 *
 * Both markers get a placeholder instruction id of 0 and a fresh temporary
 * lvalue, and take their source location from the callback place.
 * (`dropManualMemoization` in this file renumbers ids via markInstructionIds
 * after insertion.)
 *
 * @param fnExpr the callback passed to useMemo/useCallback; supplies `loc`
 * @param env environment used to allocate the temporary lvalues
 * @param depsList resolved dependency list, or null when none was given
 * @param memoDecl the place holding the memoized value; copied into FinishMemoize
 * @param manualMemoId id shared by both markers to pair them
 * @returns the `[start, finish]` markers
 */
function makeManualMemoizationMarkers(
  fnExpr: Place,
  env: Environment,
  depsList: Array<ManualMemoDependency> | null,
  memoDecl: Place,
  manualMemoId: number
): [TInstruction<StartMemoize>, TInstruction<FinishMemoize>] {
  const { loc } = fnExpr;

  // The start marker's temporary is allocated before the finish marker's.
  const start: TInstruction<StartMemoize> = {
    id: makeInstructionId(0),
    lvalue: createTemporaryPlace(env, loc),
    value: { kind: "StartMemoize", manualMemoId, deps: depsList, loc },
    loc,
  };
  const finish: TInstruction<FinishMemoize> = {
    id: makeInstructionId(0),
    lvalue: createTemporaryPlace(env, loc),
    value: { kind: "FinishMemoize", manualMemoId, decl: { ...memoDecl }, loc },
    loc,
  };

  return [start, finish];
}
/**
 * Produces the instruction value that replaces a manual memoization call.
 *
 * - `useMemo(fn, ...)` becomes a zero-argument call of `fn` (the callback
 *   place itself is reused as the callee).
 * - Any other kind (i.e. `useCallback(fn, ...)`) becomes a LoadLocal of a
 *   fresh place that shares `fn`'s identifier, with effect Unknown and
 *   reactive false.
 *
 * @param fn the callback argument
 * @param loc location of the original call; used for the new value(s)
 * @param kind which hook is being replaced
 */
function getManualMemoizationReplacement(
  fn: Place,
  loc: SourceLocation,
  kind: "useMemo" | "useCallback"
): LoadLocal | CallExpression {
  if (kind !== "useMemo") {
    // useCallback(fn, deps) evaluates to fn itself.
    return {
      kind: "LoadLocal",
      place: {
        kind: "Identifier",
        identifier: fn.identifier,
        effect: Effect.Unknown,
        reactive: false,
        loc,
      },
      loc,
    };
  }
  // useMemo(fn, deps) evaluates to the result of calling fn.
  return { kind: "CallExpression", callee: fn, args: [], loc };
}
/**
 * Pulls the callback and optional dependency list out of a useMemo /
 * useCallback call.
 *
 * @param instr the hook call instruction (a plain call or a method call)
 * @param kind which hook is being called; used in error messages
 * @param sidemap temporaries collected so far; used to resolve the dependency
 *   array literal and each of its elements
 * @returns the callback argument's place, and the resolved dependency list
 *   (`null` when no second argument was passed)
 * @throws via `CompilerError.throwInvalidReact` in these cases:
 *   - the callback is missing;
 *   - either of the first two arguments is a spread;
 *   - the second argument is not a recorded array literal;
 *   - an element of that array does not resolve to a simple dependency.
 */
function extractManualMemoizationArgs(
  instr: TInstruction<CallExpression> | TInstruction<MethodCall>,
  kind: "useCallback" | "useMemo",
  sidemap: IdentifierSidemap
): {
  fnPlace: Place;
  depsList: Array<ManualMemoDependency> | null;
} {
  const callArgs = instr.value.args as Array<Place | SpreadPattern | undefined>;
  const callbackArg = callArgs[0];
  const depsArg = callArgs[1];

  if (callbackArg == null) {
    CompilerError.throwInvalidReact({
      reason: `Expected a callback function to be passed to ${kind}`,
      loc: instr.value.loc,
      suggestions: null,
    });
  }
  if (callbackArg.kind === "Spread" || depsArg?.kind === "Spread") {
    CompilerError.throwInvalidReact({
      reason: `Unexpected spread argument to ${kind}`,
      loc: instr.value.loc,
      suggestions: null,
    });
  }

  if (depsArg == null) {
    return { fnPlace: callbackArg, depsList: null };
  }

  // The deps argument must be a temporary recorded as an all-identifier
  // array literal.
  const depPlaces = sidemap.maybeDepsLists.get(depsArg.identifier.id);
  if (depPlaces == null) {
    CompilerError.throwInvalidReact({
      reason: `Expected the dependency list for ${kind} to be an array literal`,
      suggestions: null,
      loc: depsArg.loc,
    });
  }

  const depsList: Array<ManualMemoDependency> = [];
  for (const depPlace of depPlaces) {
    const resolved = sidemap.maybeDeps.get(depPlace.identifier.id);
    if (resolved == null) {
      CompilerError.throwInvalidReact({
        reason: `Expected the dependency list to be an array of simple expressions (e.g. \`x\`, \`x.y.z\`, \`x?.y?.z\`)`,
        suggestions: null,
        loc: depPlace.loc,
      });
    }
    depsList.push(resolved);
  }
  return { fnPlace: callbackArg, depsList };
}
/**
 * Marker instructions queued for insertion. Each is keyed by the id of the
 * instruction it must immediately follow.
 */
type QueuedMemoMarkers = Map<
  InstructionId,
  TInstruction<StartMemoize> | TInstruction<FinishMemoize>
>;

/**
 * Removes manual memoization from `func` by rewriting each call to
 * `useMemo` / `useCallback` in place. A call is recognized when the hook is
 * loaded as a global resolving to that hook, or read as a property off a
 * global named `React`. The rewrites are:
 *
 *   - `useMemo(fn, deps)`     -> `fn()`  (CallExpression with no arguments)
 *   - `useCallback(fn, deps)` -> `fn`    (LoadLocal of the callback)
 *
 * When `validatePreserveExistingMemoizationGuarantees` or
 * `enablePreserveExistingMemoizationGuarantees` is set, the original
 * memoization boundaries are kept for later passes:
 *   - a StartMemoize marker is inserted right after the instruction that
 *     loaded the hook;
 *   - a FinishMemoize marker is inserted right after the rewritten call;
 *   - the callback must be an inline function expression.
 *
 * Throws (via `CompilerError.throwInvalidReact`) in these cases:
 *   - the callback is missing;
 *   - an argument is a spread;
 *   - the dependency list is not an array literal of simple expressions;
 *   - with validation enabled, the callback is not an inline function
 *     expression.
 * Validation happens before the call is rewritten.
 */
export function dropManualMemoization(func: HIRFunction): void {
  const queuedInserts = rewriteManualMemoCalls(func);
  if (queuedInserts.size > 0) {
    insertQueuedMemoMarkers(func, queuedInserts);
  }
}

/**
 * Phase 1: scans instructions in block order, recording temporaries in a
 * sidemap and rewriting each recognized useMemo/useCallback call in place.
 * Returns the markers to insert; the result is always empty when validation
 * is disabled.
 */
function rewriteManualMemoCalls(func: HIRFunction): QueuedMemoMarkers {
  const isValidationEnabled =
    func.env.config.validatePreserveExistingMemoizationGuarantees ||
    func.env.config.enablePreserveExistingMemoizationGuarantees;
  const sidemap: IdentifierSidemap = {
    functions: new Map(),
    manualMemos: new Map(),
    react: new Set(),
    maybeDeps: new Map(),
    maybeDepsLists: new Map(),
  };
  const queuedInserts: QueuedMemoMarkers = new Map();
  let nextManualMemoId = 0;

  for (const [, block] of func.body.blocks) {
    for (const instr of block.instructions) {
      if (
        instr.value.kind !== "CallExpression" &&
        instr.value.kind !== "MethodCall"
      ) {
        // Only non-call instructions feed the sidemap; call results are never
        // treated as functions, hooks, or dependencies.
        collectTemporaries(instr, func.env, sidemap);
        continue;
      }

      const calleeId =
        instr.value.kind === "CallExpression"
          ? instr.value.callee.identifier.id
          : instr.value.property.identifier.id;
      const manualMemo = sidemap.manualMemos.get(calleeId);
      if (manualMemo == null) {
        continue;
      }

      const { fnPlace, depsList } = extractManualMemoizationArgs(
        instr as TInstruction<CallExpression> | TInstruction<MethodCall>,
        manualMemo.kind,
        sidemap
      );

      // Validate before rewriting so a rejected call leaves the HIR untouched.
      if (isValidationEnabled && !sidemap.functions.has(fnPlace.identifier.id)) {
        CompilerError.throwInvalidReact({
          reason: `Expected the first argument to be an inline function expression`,
          suggestions: [],
          loc: fnPlace.loc,
        });
      }

      instr.value = getManualMemoizationReplacement(
        fnPlace,
        instr.value.loc,
        manualMemo.kind
      );

      if (!isValidationEnabled) {
        continue;
      }

      // useMemo's value is the call result (the instruction's lvalue);
      // useCallback's value is the callback itself.
      const memoDecl: Place =
        manualMemo.kind === "useMemo"
          ? instr.lvalue
          : {
              kind: "Identifier",
              identifier: fnPlace.identifier,
              effect: Effect.Unknown,
              reactive: false,
              loc: fnPlace.loc,
            };
      const [startMarker, finishMarker] = makeManualMemoizationMarkers(
        fnPlace,
        func.env,
        depsList,
        memoDecl,
        nextManualMemoId++
      );
      queuedInserts.set(manualMemo.loadInstr.id, startMarker);
      queuedInserts.set(instr.id, finishMarker);
    }
  }
  return queuedInserts;
}

/**
 * Phase 2: splices each queued marker into its block directly after the
 * instruction whose id it is keyed by. If any block changed, instruction ids
 * are then renumbered.
 */
function insertQueuedMemoMarkers(
  func: HIRFunction,
  queuedInserts: QueuedMemoMarkers
): void {
  let hasChanges = false;
  for (const [, block] of func.body.blocks) {
    // Stays null until the first marker in this block, so blocks without
    // markers keep their original instruction array.
    let nextInstructions: Array<Instruction> | null = null;
    for (const [i, instr] of block.instructions.entries()) {
      const marker = queuedInserts.get(instr.id);
      if (marker != null) {
        nextInstructions = nextInstructions ?? block.instructions.slice(0, i);
        nextInstructions.push(instr, marker);
      } else if (nextInstructions != null) {
        nextInstructions.push(instr);
      }
    }
    if (nextInstructions !== null) {
      block.instructions = nextInstructions;
      hasChanges = true;
    }
  }
  if (hasChanges) {
    // Inserted markers carry placeholder ids; assign real, ordered ids.
    markInstructionIds(func.body);
  }
}