import {CompilerError, SourceLocation} from '..';
import {
CallExpression,
Effect,
Environment,
FinishMemoize,
FunctionExpression,
HIRFunction,
IdentifierId,
Instruction,
InstructionId,
InstructionValue,
LoadGlobal,
LoadLocal,
ManualMemoDependency,
MethodCall,
Place,
PropertyLoad,
SpreadPattern,
StartMemoize,
TInstruction,
getHookKindForType,
makeInstructionId,
} from '../HIR';
import {createTemporaryPlace, markInstructionIds} from '../HIR/HIRBuilder';
/**
 * A temporary that holds React's `useMemo` or `useCallback`, as recorded by
 * `collectTemporaries`.
 */
type ManualMemoCallee = {
  /**
   * The instruction that produced the hook value. This is either a
   * `LoadGlobal` of the hook or a `PropertyLoad` of `useMemo`/`useCallback`
   * off a temporary that holds the `React` global.
   */
  loadInstr: TInstruction<PropertyLoad> | TInstruction<LoadGlobal>;
  /** Which manual memoization hook was loaded. */
  kind: 'useCallback' | 'useMemo';
};
/**
 * Per-function lookup tables built while scanning instructions in order.
 * Every map/set is keyed by the `IdentifierId` of the instruction lvalue
 * (or stored place) that the entry describes.
 */
type IdentifierSidemap = {
  /** Temporaries produced by a `LoadGlobal` whose binding name is `React`. */
  react: Set<IdentifierId>;
  /** Temporaries that hold `useMemo` / `useCallback` (global or `React.` member). */
  manualMemos: Map<IdentifierId, ManualMemoCallee>;
  /** Temporaries produced by an inline `FunctionExpression` instruction. */
  functions: Map<IdentifierId, TInstruction<FunctionExpression>>;
  /** Array literals whose elements are all plain identifier places. */
  maybeDepsLists: Map<IdentifierId, Array<Place>>;
  /** Temporaries that resolve to a simple dependency path (e.g. `x`, `x.y`). */
  maybeDeps: Map<IdentifierId, ManualMemoDependency>;
  /** Identifiers reported by `findOptionalPlaces`; marks path entries optional. */
  optionals: Set<IdentifierId>;
};
/**
 * Attempts to interpret `value` as a "simple" dependency expression of the
 * form `global`, `local`, or a property chain rooted in one of those.
 *
 * @param value - the instruction value to classify
 * @param maybeDeps - dependencies already resolved for earlier temporaries.
 *   For a `StoreLocal` into an unnamed place whose stored value is known, the
 *   alias is also recorded here as a side effect.
 * @param optional - whether a `PropertyLoad` step should be flagged optional
 * @returns the resolved dependency, or `null` if `value` is not simple
 */
export function collectMaybeMemoDependencies(
  value: InstructionValue,
  maybeDeps: Map<IdentifierId, ManualMemoDependency>,
  optional: boolean,
): ManualMemoDependency | null {
  if (value.kind === 'LoadGlobal') {
    // Any global read is a dependency root with an empty path.
    return {
      root: {kind: 'Global', identifierName: value.binding.name},
      path: [],
    };
  }

  if (value.kind === 'PropertyLoad') {
    const base = maybeDeps.get(value.object.identifier.id);
    if (base == null) {
      return null;
    }
    return {
      root: base.root,
      path: [...base.path, {property: value.property, optional}],
    };
  }

  if (value.kind === 'LoadLocal' || value.kind === 'LoadContext') {
    const {place} = value;
    const known = maybeDeps.get(place.identifier.id);
    if (known != null) {
      return known;
    }
    const name = place.identifier.name;
    if (name != null && name.kind === 'named') {
      return {
        root: {kind: 'NamedLocal', value: {...place}},
        path: [],
      };
    }
    return null;
  }

  if (value.kind === 'StoreLocal') {
    // Value blocks (e.g. optional chains) write their result through a
    // StoreLocal into an unnamed place; forward the known dependency to it.
    const target = value.lvalue.place.identifier;
    const aliased = maybeDeps.get(value.value.identifier.id);
    if (aliased == null || target.name?.kind === 'named') {
      return null;
    }
    maybeDeps.set(target.id, aliased);
    return aliased;
  }

  return null;
}
/**
 * Records what a single non-call instruction contributes to `sidemap`:
 * inline function expressions, loads of `useMemo`/`useCallback` (directly or
 * as `React.useMemo`/`React.useCallback`), loads of the `React` global,
 * identifier-only array literals, and any simple dependency the instruction's
 * value resolves to.
 */
function collectTemporaries(
  instr: Instruction,
  env: Environment,
  sidemap: IdentifierSidemap,
): void {
  const {value, lvalue} = instr;
  const target = lvalue.identifier.id;

  if (value.kind === 'FunctionExpression') {
    sidemap.functions.set(target, instr as TInstruction<FunctionExpression>);
  } else if (value.kind === 'LoadGlobal') {
    const global = env.getGlobalDeclaration(value.binding, value.loc);
    const hookKind = global !== null ? getHookKindForType(env, global) : null;
    if (hookKind === 'useMemo' || hookKind === 'useCallback') {
      sidemap.manualMemos.set(target, {
        kind: hookKind,
        loadInstr: instr as TInstruction<LoadGlobal>,
      });
    } else if (value.binding.name === 'React') {
      // Remember the namespace so `React.useMemo` can be recognized later.
      sidemap.react.add(target);
    }
  } else if (value.kind === 'PropertyLoad') {
    const isReactMember = sidemap.react.has(value.object.identifier.id);
    const hook = value.property;
    if (isReactMember && (hook === 'useMemo' || hook === 'useCallback')) {
      sidemap.manualMemos.set(target, {
        kind: hook,
        loadInstr: instr as TInstruction<PropertyLoad>,
      });
    }
  } else if (
    value.kind === 'ArrayExpression' &&
    value.elements.every(element => element.kind === 'Identifier')
  ) {
    // Only identifier-only literals can serve as a manual dependency list.
    sidemap.maybeDepsLists.set(target, value.elements as Array<Place>);
  }

  const maybeDep = collectMaybeMemoDependencies(
    value,
    sidemap.maybeDeps,
    sidemap.optionals.has(target),
  );
  if (maybeDep != null) {
    sidemap.maybeDeps.set(target, maybeDep);
  }
}
/**
 * Builds the `StartMemoize` / `FinishMemoize` marker pair for one manual
 * memoization site. Both markers share `manualMemoId` and use the callback's
 * source location. Instruction ids are placeholders (0); the caller
 * renumbers after splicing.
 *
 * @param fnExpr - the callback place passed to the hook (source of `loc`)
 * @param env - environment used to allocate the markers' temporary lvalues
 * @param depsList - the resolved source dependency list, or `null` if absent
 * @param memoDecl - the place holding the memoized value (copied shallowly)
 * @param manualMemoId - identifier linking the start and finish markers
 * @returns `[start, finish]`
 */
function makeManualMemoizationMarkers(
  fnExpr: Place,
  env: Environment,
  depsList: Array<ManualMemoDependency> | null,
  memoDecl: Place,
  manualMemoId: number,
): [TInstruction<StartMemoize>, TInstruction<FinishMemoize>] {
  const {loc} = fnExpr;
  // Temporaries are allocated start-first, then finish.
  const start: TInstruction<StartMemoize> = {
    id: makeInstructionId(0),
    lvalue: createTemporaryPlace(env, loc),
    value: {kind: 'StartMemoize', manualMemoId, deps: depsList, loc},
    loc,
  };
  const finish: TInstruction<FinishMemoize> = {
    id: makeInstructionId(0),
    lvalue: createTemporaryPlace(env, loc),
    value: {kind: 'FinishMemoize', manualMemoId, decl: {...memoDecl}, loc},
    loc,
  };
  return [start, finish];
}
/**
 * Produces the instruction value that replaces a `useMemo`/`useCallback` call:
 * `useMemo(fn, deps)` becomes `fn()`; `useCallback(fn, deps)` becomes a plain
 * load of `fn`.
 */
function getManualMemoizationReplacement(
  fn: Place,
  loc: SourceLocation,
  kind: 'useMemo' | 'useCallback',
): LoadLocal | CallExpression {
  if (kind === 'useMemo') {
    return {kind: 'CallExpression', callee: fn, args: [], loc};
  }
  // Fresh place for the callback identifier, not yet marked reactive.
  const place: Place = {
    kind: 'Identifier',
    identifier: fn.identifier,
    effect: Effect.Unknown,
    reactive: false,
    loc,
  };
  return {kind: 'LoadLocal', place, loc};
}
/**
 * Pulls the callback and optional dependency list out of a
 * `useMemo`/`useCallback` call.
 *
 * @returns the callback place, plus the resolved dependencies (`null` when no
 *   second argument was passed)
 * @throws (via `CompilerError.throwInvalidReact`) if the callback is missing,
 *   either argument is a spread, the deps argument is not a tracked array
 *   literal, or any element is not a simple dependency expression
 */
function extractManualMemoizationArgs(
  instr: TInstruction<CallExpression> | TInstruction<MethodCall>,
  kind: 'useCallback' | 'useMemo',
  sidemap: IdentifierSidemap,
): {
  fnPlace: Place;
  depsList: Array<ManualMemoDependency> | null;
} {
  const args = instr.value.args as Array<Place | SpreadPattern | undefined>;
  const fnPlace = args[0];
  const depsListPlace = args[1];

  if (fnPlace == null) {
    CompilerError.throwInvalidReact({
      reason: `Expected a callback function to be passed to ${kind}`,
      loc: instr.value.loc,
      suggestions: null,
    });
  }
  if (fnPlace.kind === 'Spread' || depsListPlace?.kind === 'Spread') {
    CompilerError.throwInvalidReact({
      reason: `Unexpected spread argument to ${kind}`,
      loc: instr.value.loc,
      suggestions: null,
    });
  }

  if (depsListPlace == null) {
    return {fnPlace, depsList: null};
  }

  const elements = sidemap.maybeDepsLists.get(depsListPlace.identifier.id);
  if (elements == null) {
    CompilerError.throwInvalidReact({
      reason: `Expected the dependency list for ${kind} to be an array literal`,
      suggestions: null,
      loc: depsListPlace.loc,
    });
  }

  const depsList = elements.map(element => {
    const resolved = sidemap.maybeDeps.get(element.identifier.id);
    if (resolved == null) {
      CompilerError.throwInvalidReact({
        reason: `Expected the dependency list to be an array of simple expressions (e.g. \`x\`, \`x.y.z\`, \`x?.y?.z\`)`,
        suggestions: null,
        loc: element.loc,
      });
    }
    return resolved;
  });

  return {fnPlace, depsList};
}
/**
 * Removes manual `useMemo` / `useCallback` calls from `func`.
 *
 * Phase 1 scans instructions in block order. Each call whose callee
 * (`CallExpression`) or method property (`MethodCall`) is a known manual-memo
 * temporary is rewritten in place: `useMemo(fn, deps)` becomes `fn()` and
 * `useCallback(fn, deps)` becomes a load of `fn`. All other instructions feed
 * the sidemap via `collectTemporaries`.
 *
 * When any of the memoization-validation config flags is set, each rewritten
 * site must have an inline function expression as its first argument (else
 * `CompilerError.throwInvalidReact`). A `StartMemoize` marker is queued right
 * after the hook's load instruction, and a `FinishMemoize` marker right after
 * the rewritten call.
 *
 * Phase 2 splices the queued markers into their blocks and renumbers the
 * instruction ids if anything was inserted.
 */
export function dropManualMemoization(func: HIRFunction): void {
  const {env} = func;
  const isValidationEnabled =
    env.config.validatePreserveExistingMemoizationGuarantees ||
    env.config.validateNoSetStateInRender ||
    env.config.enablePreserveExistingMemoizationGuarantees;
  const optionals = findOptionalPlaces(func);
  const sidemap: IdentifierSidemap = {
    functions: new Map(),
    manualMemos: new Map(),
    react: new Set(),
    maybeDeps: new Map(),
    maybeDepsLists: new Map(),
    optionals,
  };

  // Phase 1: rewrite manual memo calls and queue markers (keyed by the id of
  // the instruction each marker must follow).
  let nextManualMemoId = 0;
  const queuedInserts = new Map<
    InstructionId,
    TInstruction<StartMemoize> | TInstruction<FinishMemoize>
  >();
  for (const block of func.body.blocks.values()) {
    for (const instr of block.instructions) {
      const {value} = instr;
      if (value.kind !== 'CallExpression' && value.kind !== 'MethodCall') {
        collectTemporaries(instr, env, sidemap);
        continue;
      }

      const calleeId =
        value.kind === 'CallExpression'
          ? value.callee.identifier.id
          : value.property.identifier.id;
      const manualMemo = sidemap.manualMemos.get(calleeId);
      if (manualMemo == null) {
        continue;
      }

      const {fnPlace, depsList} = extractManualMemoizationArgs(
        instr as TInstruction<CallExpression> | TInstruction<MethodCall>,
        manualMemo.kind,
        sidemap,
      );
      instr.value = getManualMemoizationReplacement(
        fnPlace,
        value.loc,
        manualMemo.kind,
      );

      if (!isValidationEnabled) {
        continue;
      }
      if (!sidemap.functions.has(fnPlace.identifier.id)) {
        CompilerError.throwInvalidReact({
          reason: `Expected the first argument to be an inline function expression`,
          suggestions: [],
          loc: fnPlace.loc,
        });
      }
      // useMemo memoizes the call result; useCallback memoizes the callback.
      const memoDecl: Place =
        manualMemo.kind === 'useMemo'
          ? instr.lvalue
          : {
              kind: 'Identifier',
              identifier: fnPlace.identifier,
              effect: Effect.Unknown,
              reactive: false,
              loc: fnPlace.loc,
            };
      const [startMarker, finishMarker] = makeManualMemoizationMarkers(
        fnPlace,
        env,
        depsList,
        memoDecl,
        nextManualMemoId++,
      );
      // The start marker follows the hook load so the temporaries that
      // lower the inline callback fall inside the marked range.
      queuedInserts.set(manualMemo.loadInstr.id, startMarker);
      queuedInserts.set(instr.id, finishMarker);
    }
  }

  // Phase 2: splice queued markers after their anchor instructions.
  if (queuedInserts.size === 0) {
    return;
  }
  let hasChanges = false;
  for (const block of func.body.blocks.values()) {
    // Copy-on-first-insert: stays null for blocks with no markers.
    let rewritten: Array<Instruction> | null = null;
    for (const [index, instr] of block.instructions.entries()) {
      const marker = queuedInserts.get(instr.id);
      if (marker == null) {
        rewritten?.push(instr);
        continue;
      }
      if (rewritten === null) {
        rewritten = block.instructions.slice(0, index);
      }
      rewritten.push(instr, marker);
    }
    if (rewritten !== null) {
      block.instructions = rewritten;
      hasChanges = true;
    }
  }
  if (hasChanges) {
    markInstructionIds(func.body);
  }
}
/**
 * Collects the identifiers stored as the result of optional (`?.`) chains.
 *
 * For every `optional` terminal with `optional` set, walks forward from its
 * `test` block. It skips over nested `optional`/`logical`/`sequence`/`ternary`
 * terminals and over `branch` terminals that fall through elsewhere, by
 * following their `fallthrough`. It stops at the `branch` that falls through
 * to the same block as the optional terminal. If that branch's consequent
 * block ends in a `StoreLocal`, the stored value's identifier id is recorded.
 * Any other terminal kind on the walk is an invariant violation.
 */
function findOptionalPlaces(fn: HIRFunction): Set<IdentifierId> {
  const optionals = new Set<IdentifierId>();
  for (const block of fn.body.blocks.values()) {
    const {terminal} = block;
    if (terminal.kind !== 'optional' || !terminal.optional) {
      continue;
    }
    const exitBlockId = terminal.fallthrough;
    let current = fn.body.blocks.get(terminal.test)!;
    for (;;) {
      const next = current.terminal;
      if (next.kind === 'branch') {
        if (next.fallthrough === exitBlockId) {
          const consequent = fn.body.blocks.get(next.consequent)!;
          const last = consequent.instructions.at(-1);
          if (last !== undefined && last.value.kind === 'StoreLocal') {
            optionals.add(last.value.value.identifier.id);
          }
          break;
        }
        current = fn.body.blocks.get(next.fallthrough)!;
      } else if (
        next.kind === 'optional' ||
        next.kind === 'logical' ||
        next.kind === 'sequence' ||
        next.kind === 'ternary'
      ) {
        current = fn.body.blocks.get(next.fallthrough)!;
      } else {
        CompilerError.invariant(false, {
          reason: `Unexpected terminal in optional`,
          loc: next.loc,
        });
      }
    }
  }
  return optionals;
}