import { CompilerError } from "../CompilerError";
import {
Environment,
IdentifierId,
InstructionId,
Pattern,
Place,
ReactiveFunction,
ReactiveInstruction,
ReactiveScopeBlock,
ReactiveStatement,
ReactiveTerminal,
ReactiveTerminalStatement,
ReactiveValue,
ScopeId,
getHookKind,
isMutableEffect,
} from "../HIR";
import { getFunctionCallSignature } from "../Inference/InferReferenceEffects";
import { assertExhaustive } from "../Utils/utils";
import { getPlaceScope } from "./BuildReactiveBlocks";
import {
ReactiveFunctionTransform,
ReactiveFunctionVisitor,
Transformed,
eachReactiveValueOperand,
visitReactiveFunction,
} from "./visitors";
/**
 * Removes reactive scopes whose outputs are not in the memoized set.
 *
 * Runs in three steps:
 *  1. Collect, for every identifier, its memoization level, the identifiers
 *     it depends on, and the scopes it is associated with, together with the
 *     set of escaping identifiers (return values and arguments of hook calls
 *     that are not marked `noAlias`).
 *  2. Walk from the escaping identifiers to compute the memoized set.
 *  3. Replace each scope that has outputs, none of which are in that set,
 *     with its instructions.
 *
 * @param fn the reactive function to transform in place.
 */
export function pruneNonEscapingScopes(fn: ReactiveFunction): void {
  const { env } = fn;
  const state = new State(env);

  // Parameters have no defining instruction, so give each one a node up front.
  for (const param of fn.params) {
    const paramId =
      param.kind === "Identifier"
        ? param.identifier.id
        : param.place.identifier.id;
    state.declare(paramId);
  }

  // Step 1: build the dependency graph and the set of escaping identifiers.
  visitReactiveFunction(fn, new CollectDependenciesVisitor(env), state);

  // Step 2: determine which identifiers end up memoized.
  const memoizedIdentifiers = computeMemoizedIdentifiers(state);

  // Step 3: flatten scopes that produce none of those identifiers.
  visitReactiveFunction(fn, new PruneScopesTransform(), memoizedIdentifiers);
}
/**
 * Flags that adjust the memoization level `computeMemoizationInputs` assigns
 * to certain kinds of values.
 */
export type MemoizationOptions = {
// When true, JSX expressions and fragments are assigned `Memoized`;
// when false they are assigned `Unmemoized`.
memoizeJsxElements: boolean;
// When true, the values that otherwise get `Never` (primitives, binary/unary
// expressions, globals, ...) and property/computed loads (otherwise
// `Conditional`) are assigned `Memoized` instead.
forceMemoizePrimitives: boolean;
};
/**
 * How strongly an identifier wants to be memoized. Levels are combined with
 * `joinAliases` (Memoized > Conditional > Unmemoized > Never) and interpreted
 * by `computeMemoizedIdentifiers`.
 */
enum MemoizationLevel {
// Memoized whenever the identifier is reached from an escaping value.
Memoized = "Memoized",
// Memoized only if one of its dependencies is memoized, or if memoization is
// forced because it is a dependency of a scope that is kept.
Conditional = "Conditional",
// Memoized only when memoization is forced (see above); having memoized
// dependencies is not enough.
Unmemoized = "Unmemoized",
// Never added to the memoized set, even when forced.
Never = "Never",
}
/**
 * Combines two memoization levels into the stronger of the two.
 *
 * Precedence, strongest first: `Memoized`, `Conditional`, `Unmemoized`,
 * `Never`. The result is the first level in that order that either argument
 * equals; if neither matches any of the first three, the result is `Never`.
 *
 * @param kind1 first level to combine.
 * @param kind2 second level to combine.
 * @returns the stronger of `kind1` and `kind2`.
 */
function joinAliases(
  kind1: MemoizationLevel,
  kind2: MemoizationLevel
): MemoizationLevel {
  // Built per call (not hoisted) so the function has no module-level state.
  const strongestFirst = [
    MemoizationLevel.Memoized,
    MemoizationLevel.Conditional,
    MemoizationLevel.Unmemoized,
  ];
  for (const candidate of strongestFirst) {
    if (kind1 === candidate || kind2 === candidate) {
      return candidate;
    }
  }
  return MemoizationLevel.Never;
}
/**
 * Per-identifier node of the dependency graph stored in `State.identifiers`.
 */
type IdentifierNode = {
// Join (via `joinAliases`) of every level assigned to this identifier.
level: MemoizationLevel;
// Result of the traversal in `computeMemoizedIdentifiers`; only meaningful
// once `seen` is true.
memoized: boolean;
// Identifiers this one depends on (the rvalues of instructions that produce it).
dependencies: Set<IdentifierId>;
// Scopes recorded for this identifier by `State.visitOperand`.
scopes: Set<ScopeId>;
// Visited marker used by `computeMemoizedIdentifiers`.
seen: boolean;
};
/**
 * Per-scope node stored in `State.scopes`.
 */
type ScopeNode = {
// Identifier ids of the scope's dependencies, captured when the scope is
// first encountered in `State.visitOperand`.
dependencies: Array<IdentifierId>;
// Set once the scope's dependencies have been force-memoized, so that this
// happens at most once per scope.
seen: boolean;
};
/**
 * Mutable graph built while visiting a reactive function: identifier nodes,
 * scope nodes, LoadLocal redirections and the set of escaping identifiers.
 */
class State {
  env: Environment;
  // Maps the lvalue of a `LoadLocal` to the identifier it loads, so later
  // lookups resolve to the underlying identifier.
  definitions: Map<IdentifierId, IdentifierId> = new Map();
  identifiers: Map<IdentifierId, IdentifierNode> = new Map();
  scopes: Map<ScopeId, ScopeNode> = new Map();
  // Roots of the traversal performed by `computeMemoizedIdentifiers`.
  escapingValues: Set<IdentifierId> = new Set();

  constructor(env: Environment) {
    this.env = env;
  }

  /**
   * Registers (or resets) the node for `id` with level `Never`, no
   * dependencies and no scopes.
   */
  declare(id: IdentifierId): void {
    const freshNode: IdentifierNode = {
      level: MemoizationLevel.Never,
      memoized: false,
      dependencies: new Set(),
      scopes: new Set(),
      seen: false,
    };
    this.identifiers.set(id, freshNode);
  }

  /**
   * If `place` has a scope at instruction `id` (per `getPlaceScope`), makes
   * sure a node exists for that scope and records the scope on the node of
   * `identifier`.
   *
   * @param id instruction at which the place is used.
   * @param place the operand being visited.
   * @param identifier the (already resolved) identifier id to attach the
   * scope to; it must already have a node, otherwise an invariant fails.
   */
  visitOperand(
    id: InstructionId,
    place: Place,
    identifier: IdentifierId
  ): void {
    const scope = getPlaceScope(id, place);
    if (scope === null) {
      return;
    }

    // Lazily create the scope node the first time the scope is encountered.
    if (!this.scopes.has(scope.id)) {
      this.scopes.set(scope.id, {
        dependencies: Array.from(
          scope.dependencies,
          (dep) => dep.identifier.id
        ),
        seen: false,
      });
    }

    const owner = this.identifiers.get(identifier);
    CompilerError.invariant(owner !== undefined, {
      reason: "Expected identifier to be initialized",
      description: null,
      loc: place.loc,
      suggestions: null,
    });
    owner.scopes.add(scope.id);
  }
}
/**
 * Walks the dependency graph starting from `state.escapingValues` and returns
 * the ids of all identifiers that must be memoized.
 *
 * An identifier is memoized when its level is `Memoized`; when it is
 * `Conditional` and either a dependency is memoized or memoization is forced;
 * or when it is `Unmemoized` and memoization is forced. Memoization is forced
 * for the dependencies of every scope associated with a memoized identifier.
 *
 * Mutates the `seen`/`memoized` flags on the nodes of `state`.
 *
 * @param state graph produced by `CollectDependenciesVisitor`.
 * @returns the set of memoized identifier ids.
 */
function computeMemoizedIdentifiers(state: State): Set<IdentifierId> {
  const result = new Set<IdentifierId>();

  // Forces memoization of every dependency of the given scope (once per scope).
  const forceScopeDependencies = (scopeId: ScopeId): void => {
    const scopeNode = state.scopes.get(scopeId);
    CompilerError.invariant(scopeNode !== undefined, {
      reason: "Expected a node for all scopes",
      description: null,
      loc: null,
      suggestions: null,
    });
    if (scopeNode.seen) {
      return;
    }
    scopeNode.seen = true;
    for (const dependency of scopeNode.dependencies) {
      visit(dependency, true);
    }
  };

  // Depth-first visit; returns whether `id` ends up memoized.
  const visit = (id: IdentifierId, forceMemoize: boolean = false): boolean => {
    const node = state.identifiers.get(id);
    CompilerError.invariant(node !== undefined, {
      reason: `Expected a node for all identifiers, none found for \`${id}\``,
      description: null,
      loc: null,
      suggestions: null,
    });
    if (node.seen) {
      return node.memoized;
    }
    node.seen = true;
    // While this node's dependencies are being processed, a cycle back to it
    // observes it as not memoized.
    node.memoized = false;

    // Every dependency must be visited, so no short-circuiting here.
    let anyDependencyMemoized = false;
    for (const dependency of node.dependencies) {
      if (visit(dependency)) {
        anyDependencyMemoized = true;
      }
    }

    let shouldMemoize: boolean;
    switch (node.level) {
      case MemoizationLevel.Memoized:
        shouldMemoize = true;
        break;
      case MemoizationLevel.Conditional:
        shouldMemoize = anyDependencyMemoized || forceMemoize;
        break;
      case MemoizationLevel.Unmemoized:
        shouldMemoize = forceMemoize;
        break;
      default:
        shouldMemoize = false;
        break;
    }

    if (shouldMemoize) {
      node.memoized = true;
      result.add(id);
      for (const scopeId of node.scopes) {
        forceScopeDependencies(scopeId);
      }
    }
    return node.memoized;
  };

  for (const escaping of state.escapingValues) {
    visit(escaping);
  }
  return result;
}
/**
 * A place written or produced by an instruction, paired with the memoization
 * level that instruction assigns to it.
 */
type LValueMemoization = {
place: Place;
level: MemoizationLevel;
};
/**
 * Classifies a reactive value for the purposes of scope pruning.
 *
 * @param env environment, used to look up function call signatures.
 * @param value the value being classified.
 * @param lvalue the place the value is assigned to, or null when the value
 * is nested inside another value or has no lvalue.
 * @param options flags that adjust the level of JSX and primitive values.
 * @returns `lvalues`: every place this value produces or writes, each with
 * the memoization level it is assigned; `rvalues`: the places whose
 * memoization status feeds into those lvalues (the caller records them as
 * dependencies of each lvalue).
 */
function computeMemoizationInputs(
env: Environment,
value: ReactiveValue,
lvalue: Place | null,
options: MemoizationOptions
): {
lvalues: Array<LValueMemoization>;
rvalues: Array<Place>;
} {
switch (value.kind) {
// Compound values: the result is Conditional, and its rvalues are the
// rvalues of the nested values (computed recursively with a null lvalue).
case "ConditionalExpression": {
return {
lvalues:
lvalue !== null
? [{ place: lvalue, level: MemoizationLevel.Conditional }]
: [],
// Only the two branches contribute rvalues; `value.test` is not consulted.
rvalues: [
...computeMemoizationInputs(env, value.consequent, null, options)
.rvalues,
...computeMemoizationInputs(env, value.alternate, null, options)
.rvalues,
],
};
}
case "LogicalExpression": {
return {
lvalues:
lvalue !== null
? [{ place: lvalue, level: MemoizationLevel.Conditional }]
: [],
rvalues: [
...computeMemoizationInputs(env, value.left, null, options).rvalues,
...computeMemoizationInputs(env, value.right, null, options).rvalues,
],
};
}
case "SequenceExpression": {
return {
lvalues:
lvalue !== null
? [{ place: lvalue, level: MemoizationLevel.Conditional }]
: [],
// Only the final value of the sequence contributes rvalues here.
rvalues: computeMemoizationInputs(env, value.value, null, options)
.rvalues,
};
}
case "JsxExpression": {
// Operands: the tag (when it is an identifier), every attribute value or
// spread argument, and every child.
const operands: Array<Place> = [];
if (value.tag.kind === "Identifier") {
operands.push(value.tag);
}
for (const prop of value.props) {
if (prop.kind === "JsxAttribute") {
operands.push(prop.place);
} else {
operands.push(prop.argument);
}
}
if (value.children !== null) {
for (const child of value.children) {
operands.push(child);
}
}
// JSX is Memoized or Unmemoized depending on `options.memoizeJsxElements`.
const level = options.memoizeJsxElements
? MemoizationLevel.Memoized
: MemoizationLevel.Unmemoized;
return {
lvalues: lvalue !== null ? [{ place: lvalue, level }] : [],
rvalues: operands,
};
}
case "JsxFragment": {
// Same level rule as JsxExpression; the children are the rvalues.
const level = options.memoizeJsxElements
? MemoizationLevel.Memoized
: MemoizationLevel.Unmemoized;
return {
lvalues: lvalue !== null ? [{ place: lvalue, level }] : [],
rvalues: value.children,
};
}
// Values that contribute no rvalues. Their result is Never memoized unless
// `options.forceMemoizePrimitives` is set, in which case it is Memoized.
case "NextPropertyOf":
case "StartMemoize":
case "FinishMemoize":
case "Debugger":
case "ComputedDelete":
case "PropertyDelete":
case "LoadGlobal":
case "MetaProperty":
case "TemplateLiteral":
case "Primitive":
case "JSXText":
case "BinaryExpression":
case "UnaryExpression": {
const level = options.forceMemoizePrimitives
? MemoizationLevel.Memoized
: MemoizationLevel.Never;
return {
lvalues: lvalue !== null ? [{ place: lvalue, level }] : [],
rvalues: [],
};
}
// Single-operand indirections: the result is Conditional on the operand.
case "Await":
case "TypeCastExpression": {
return {
lvalues:
lvalue !== null
? [{ place: lvalue, level: MemoizationLevel.Conditional }]
: [],
rvalues: [value.value],
};
}
case "IteratorNext": {
// Conditional on both the iterator and the collection.
return {
lvalues:
lvalue !== null
? [{ place: lvalue, level: MemoizationLevel.Conditional }]
: [],
rvalues: [value.iterator, value.collection],
};
}
case "GetIterator": {
return {
lvalues:
lvalue !== null
? [{ place: lvalue, level: MemoizationLevel.Conditional }]
: [],
rvalues: [value.collection],
};
}
case "LoadLocal": {
// Conditional on the loaded place.
return {
lvalues:
lvalue !== null
? [{ place: lvalue, level: MemoizationLevel.Conditional }]
: [],
rvalues: [value.place],
};
}
case "LoadContext": {
return {
lvalues:
lvalue !== null
? [{ place: lvalue, level: MemoizationLevel.Conditional }]
: [],
rvalues: [value.place],
};
}
case "DeclareContext": {
// The declared context variable is Memoized; the instruction's own lvalue
// (if any) is Unmemoized. No rvalues.
const lvalues = [
{ place: value.lvalue.place, level: MemoizationLevel.Memoized },
];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Unmemoized });
}
return {
lvalues,
rvalues: [],
};
}
case "DeclareLocal": {
// Both the declared local and the instruction's lvalue are Unmemoized.
const lvalues = [
{ place: value.lvalue.place, level: MemoizationLevel.Unmemoized },
];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Unmemoized });
}
return {
lvalues,
rvalues: [],
};
}
case "PrefixUpdate":
case "PostfixUpdate": {
// Both the updated variable and the instruction's lvalue are Conditional
// on the operand value.
const lvalues = [
{ place: value.lvalue, level: MemoizationLevel.Conditional },
];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Conditional });
}
return {
lvalues,
rvalues: [value.value],
};
}
case "StoreLocal": {
// The assigned local and the instruction's lvalue are Conditional on the
// stored value.
const lvalues = [
{ place: value.lvalue.place, level: MemoizationLevel.Conditional },
];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Conditional });
}
return {
lvalues,
rvalues: [value.value],
};
}
case "StoreContext": {
// Unlike StoreLocal, the assigned context variable is always Memoized;
// the instruction's lvalue stays Conditional.
const lvalues = [
{ place: value.lvalue.place, level: MemoizationLevel.Memoized },
];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Conditional });
}
return {
lvalues,
rvalues: [value.value],
};
}
case "StoreGlobal": {
// Only the instruction's lvalue (if any) is recorded, as Unmemoized.
const lvalues = [];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Unmemoized });
}
return {
lvalues,
rvalues: [value.value],
};
}
case "Destructure": {
// The instruction's lvalue is Conditional; each place bound by the pattern
// gets the level computed by `computePatternLValues`.
const lvalues = [];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Conditional });
}
lvalues.push(...computePatternLValues(value.lvalue.pattern));
return {
lvalues: lvalues,
rvalues: [value.value],
};
}
case "ComputedLoad":
case "PropertyLoad": {
// Conditional on the object being read from, or Memoized when
// `options.forceMemoizePrimitives` is set. For ComputedLoad the property
// operand is not included in the rvalues.
const level = options.forceMemoizePrimitives
? MemoizationLevel.Memoized
: MemoizationLevel.Conditional;
return {
lvalues: lvalue !== null ? [{ place: lvalue, level }] : [],
rvalues: [value.object],
};
}
case "ComputedStore": {
// The object being written to is treated as an lvalue, Conditional on the
// stored value; so is the instruction's lvalue.
const lvalues = [
{ place: value.object, level: MemoizationLevel.Conditional },
];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Conditional });
}
return {
lvalues,
rvalues: [value.value],
};
}
case "OptionalExpression": {
// Conditional on the rvalues of the wrapped value.
const lvalues = [];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Conditional });
}
return {
lvalues: lvalues,
rvalues: [
...computeMemoizationInputs(env, value.value, null, options).rvalues,
],
};
}
case "CallExpression": {
// The signature is looked up from the callee's type.
const signature = getFunctionCallSignature(
env,
value.callee.identifier.type
);
const operands = [...eachReactiveValueOperand(value)];
// The call result is always Memoized.
let lvalues = [];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Memoized });
}
// With a `noAlias` signature the operands are not recorded at all: the
// result gets no dependencies and no operand is treated as an lvalue.
if (signature?.noAlias === true) {
return {
lvalues,
rvalues: [],
};
}
// Otherwise operands with a mutable effect are also Memoized lvalues, and
// every operand is an rvalue.
lvalues.push(
...operands
.filter((operand) => isMutableEffect(operand.effect, operand.loc))
.map((place) => ({ place, level: MemoizationLevel.Memoized }))
);
return {
lvalues,
rvalues: operands,
};
}
case "MethodCall": {
// Same handling as CallExpression, except the signature is looked up from
// the type of the method property.
const signature = getFunctionCallSignature(
env,
value.property.identifier.type
);
const operands = [...eachReactiveValueOperand(value)];
let lvalues = [];
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Memoized });
}
if (signature?.noAlias === true) {
return {
lvalues,
rvalues: [],
};
}
lvalues.push(
...operands
.filter((operand) => isMutableEffect(operand.effect, operand.loc))
.map((place) => ({ place, level: MemoizationLevel.Memoized }))
);
return {
lvalues,
rvalues: operands,
};
}
// For these kinds the result is Memoized, operands with a mutable effect are
// Memoized lvalues as well, and every operand is an rvalue.
case "RegExpLiteral":
case "ObjectMethod":
case "FunctionExpression":
case "TaggedTemplateExpression":
case "ArrayExpression":
case "NewExpression":
case "ObjectExpression":
case "PropertyStore": {
const operands = [...eachReactiveValueOperand(value)];
const lvalues = operands
.filter((operand) => isMutableEffect(operand.effect, operand.loc))
.map((place) => ({ place, level: MemoizationLevel.Memoized }));
if (lvalue !== null) {
lvalues.push({ place: lvalue, level: MemoizationLevel.Memoized });
}
return {
lvalues,
rvalues: operands,
};
}
// The remaining kinds are not expected here: the invariant always fails.
case "ReactiveFunctionValue": {
CompilerError.invariant(false, {
reason: `Unexpected ReactiveFunctionValue node`,
description: null,
loc: value.loc,
suggestions: null,
});
}
case "UnsupportedNode": {
CompilerError.invariant(false, {
reason: `Unexpected unsupported node`,
description: null,
loc: value.loc,
suggestions: null,
});
}
default: {
// Exhaustiveness check over `value.kind`.
assertExhaustive(
value,
`Unexpected value kind \`${(value as any).kind}\``
);
}
}
}
/**
 * Lists the places bound by a destructuring pattern together with the
 * memoization level each one is assigned.
 *
 * - Array pattern: identifier items are `Conditional`, spread items are
 *   `Memoized`, any other item kind is skipped.
 * - Object pattern: `ObjectProperty` entries are `Conditional`, every other
 *   entry is `Memoized`.
 *
 * @param pattern the array or object pattern of a `Destructure` instruction.
 * @returns one entry per bound place, in pattern order.
 */
function computePatternLValues(pattern: Pattern): Array<LValueMemoization> {
  const bound: Array<LValueMemoization> = [];
  switch (pattern.kind) {
    case "ArrayPattern": {
      for (const element of pattern.items) {
        switch (element.kind) {
          case "Identifier": {
            bound.push({
              place: element,
              level: MemoizationLevel.Conditional,
            });
            break;
          }
          case "Spread": {
            bound.push({
              place: element.place,
              level: MemoizationLevel.Memoized,
            });
            break;
          }
          default: {
            // Other item kinds bind nothing.
            break;
          }
        }
      }
      break;
    }
    case "ObjectPattern": {
      for (const entry of pattern.properties) {
        const level =
          entry.kind === "ObjectProperty"
            ? MemoizationLevel.Conditional
            : MemoizationLevel.Memoized;
        bound.push({ place: entry.place, level });
      }
      break;
    }
    default: {
      assertExhaustive(
        pattern,
        `Unexpected pattern kind \`${(pattern as any).kind}\``
      );
    }
  }
  return bound;
}
/**
 * Visitor that fills in a `State`: for each instruction it records the
 * memoization level and dependencies of every lvalue, associates operands
 * with their scopes, tracks `LoadLocal` redirections, and collects escaping
 * identifiers (return values and arguments of hook calls).
 */
class CollectDependenciesVisitor extends ReactiveFunctionVisitor<State> {
  env: Environment;
  options: MemoizationOptions;

  constructor(env: Environment) {
    super();
    this.env = env;
    // Both options are derived from the `enableForest` config flag.
    const forest = this.env.config.enableForest;
    this.options = {
      memoizeJsxElements: !forest,
      forceMemoizePrimitives: forest,
    };
  }

  override visitInstruction(
    instruction: ReactiveInstruction,
    state: State
  ): void {
    this.traverseInstruction(instruction, state);

    const { lvalues, rvalues } = computeMemoizationInputs(
      this.env,
      instruction.value,
      instruction.lvalue,
      this.options
    );

    // Follows a LoadLocal redirection, if one was recorded for the place.
    const resolve = (place: Place): IdentifierId =>
      state.definitions.get(place.identifier.id) ?? place.identifier.id;

    // Associate every rvalue with its scope (if it has one).
    for (const operand of rvalues) {
      state.visitOperand(instruction.id, operand, resolve(operand));
    }

    // Update each lvalue's level and make it depend on all rvalues.
    for (const { place: target, level } of lvalues) {
      const targetId = resolve(target);
      let node = state.identifiers.get(targetId);
      if (node === undefined) {
        node = {
          level: MemoizationLevel.Never,
          memoized: false,
          dependencies: new Set(),
          scopes: new Set(),
          seen: false,
        };
        state.identifiers.set(targetId, node);
      }
      node.level = joinAliases(node.level, level);
      for (const operand of rvalues) {
        const operandId = resolve(operand);
        // An identifier never depends on itself.
        if (operandId !== targetId) {
          node.dependencies.add(operandId);
        }
      }
      state.visitOperand(instruction.id, target, targetId);
    }

    const value = instruction.value;
    if (value.kind === "LoadLocal") {
      // Remember that the temporary is just another name for the local.
      if (instruction.lvalue !== null) {
        state.definitions.set(
          instruction.lvalue.identifier.id,
          value.place.identifier.id
        );
      }
      return;
    }
    if (value.kind !== "CallExpression" && value.kind !== "MethodCall") {
      return;
    }

    const callee =
      value.kind === "CallExpression" ? value.callee : value.property;
    if (getHookKind(state.env, callee.identifier) == null) {
      return;
    }
    const signature = getFunctionCallSignature(
      this.env,
      callee.identifier.type
    );
    // Arguments of a hook call escape, unless the hook is marked `noAlias`.
    if (signature?.noAlias === true) {
      return;
    }
    for (const arg of value.args) {
      const argPlace = arg.kind === "Spread" ? arg.place : arg;
      state.escapingValues.add(argPlace.identifier.id);
    }
  }

  override visitTerminal(
    stmt: ReactiveTerminalStatement<ReactiveTerminal>,
    state: State
  ): void {
    this.traverseTerminal(stmt, state);
    const { terminal } = stmt;
    // Returned values escape the function.
    if (terminal.kind === "return") {
      state.escapingValues.add(terminal.value.identifier.id);
    }
  }
}
/**
 * Transform that flattens every reactive scope whose declarations and
 * reassignments contain no memoized identifier. The transform state is the
 * set of memoized identifier ids.
 */
class PruneScopesTransform extends ReactiveFunctionTransform<
  Set<IdentifierId>
> {
  // Ids of the scopes removed so far; consulted when marking FinishMemoize.
  prunedScopes: Set<ScopeId> = new Set();

  override transformScope(
    scopeBlock: ReactiveScopeBlock,
    state: Set<IdentifierId>
  ): Transformed<ReactiveStatement> {
    this.visitScope(scopeBlock, state);

    const { scope } = scopeBlock;

    // Scopes without any declaration or reassignment are left in place.
    if (scope.declarations.size === 0 && scope.reassignments.size === 0) {
      return { kind: "keep" };
    }

    // Look for at least one memoized output: declarations first, then
    // reassignments.
    let hasMemoizedOutput = false;
    for (const declarationId of scope.declarations.keys()) {
      if (state.has(declarationId)) {
        hasMemoizedOutput = true;
        break;
      }
    }
    if (!hasMemoizedOutput) {
      for (const reassigned of scope.reassignments) {
        if (state.has(reassigned.id)) {
          hasMemoizedOutput = true;
          break;
        }
      }
    }

    if (!hasMemoizedOutput) {
      // Nothing memoized comes out of this scope: inline its instructions.
      this.prunedScopes.add(scope.id);
      return {
        kind: "replace-many",
        value: scopeBlock.instructions,
      };
    }
    return { kind: "keep" };
  }

  override transformInstruction(
    instruction: ReactiveInstruction,
    state: Set<IdentifierId>
  ): Transformed<ReactiveStatement> {
    this.traverseInstruction(instruction, state);

    const { value } = instruction;
    if (value.kind === "FinishMemoize") {
      // Flag the marker when the scope of its declaration was pruned.
      const declScope = value.decl.identifier.scope;
      if (declScope !== null && this.prunedScopes.has(declScope.id)) {
        value.pruned = true;
      }
    }
    return { kind: "keep" };
  }
}