import {CompilerError} from '../CompilerError';
import {
DeclarationId,
Environment,
Identifier,
InstructionId,
Pattern,
Place,
ReactiveFunction,
ReactiveInstruction,
ReactiveScopeBlock,
ReactiveStatement,
ReactiveTerminal,
ReactiveTerminalStatement,
ReactiveValue,
ScopeId,
getHookKind,
isMutableEffect,
} from '../HIR';
import {getFunctionCallSignature} from '../Inference/InferReferenceEffects';
import {assertExhaustive, getOrInsertDefault} from '../Utils/utils';
import {getPlaceScope, ReactiveScope} from '../HIR/HIR';
import {
ReactiveFunctionTransform,
ReactiveFunctionVisitor,
Transformed,
eachReactiveValueOperand,
visitReactiveFunction,
} from './visitors';
import {printPlace} from '../HIR/PrintHIR';
/**
 * Removes reactive scopes whose declarations and reassignments do not need to
 * be memoized, replacing each such scope with its own instructions. Mutates
 * `fn` in place.
 *
 * Works in three steps:
 * 1. Seed the analysis state with the function's parameters, then collect a
 *    data-flow graph (keyed by DeclarationId) with `CollectDependenciesVisitor`.
 * 2. Walk that graph outward from the escaping values (returned values and
 *    hook arguments) to compute the set of declarations that must be memoized.
 * 3. Run `PruneScopesTransform` against that set. Scopes with no
 *    declarations/reassignments, or with an early-return value, are kept.
 */
export function pruneNonEscapingScopes(fn: ReactiveFunction): void {
  const state = new State(fn.env);

  // Parameters get graph nodes up front so later lookups for them succeed.
  for (const param of fn.params) {
    const paramPlace = param.kind === 'Identifier' ? param : param.place;
    state.declare(paramPlace.identifier.declarationId);
  }

  const collector = new CollectDependenciesVisitor(fn.env, state);
  visitReactiveFunction(fn, collector, []);

  const memoizedDeclarations = computeMemoizedIdentifiers(state);
  visitReactiveFunction(fn, new PruneScopesTransform(), memoizedDeclarations);
}
/**
 * Options controlling which memoization levels `CollectDependenciesVisitor`
 * assigns to certain instruction kinds.
 */
export type MemoizationOptions = {
  /**
   * When true, primitive-producing instructions and property loads are given
   * the `Memoized` level rather than their default level.
   */
  forceMemoizePrimitives: boolean;
  /**
   * When true, JSX elements and fragments are `Memoized`; otherwise they are
   * `Unmemoized` and only memoized when forced by a scope dependency.
   */
  memoizeJsxElements: boolean;
};
/**
 * How strongly a value needs to be memoized. Members are listed from weakest
 * to strongest; `joinAliases` returns the stronger of two levels.
 */
enum MemoizationLevel {
  // Never memoized by the escape walk.
  Never = 'Never',
  // Memoized only when forced, i.e. when it is a dependency of a scope that
  // contains a memoized value.
  Unmemoized = 'Unmemoized',
  // Memoized when any of its dependencies is memoized, or when forced.
  Conditional = 'Conditional',
  // Memoized whenever it is reached from an escaping value.
  Memoized = 'Memoized',
}
/**
 * Joins two memoization levels into the stronger of the two, using the
 * ordering Never < Unmemoized < Conditional < Memoized.
 */
function joinAliases(
  kind1: MemoizationLevel,
  kind2: MemoizationLevel,
): MemoizationLevel {
  // Checked strongest-first: the first level held by either input wins.
  const strongestFirst = [
    MemoizationLevel.Memoized,
    MemoizationLevel.Conditional,
    MemoizationLevel.Unmemoized,
  ];
  for (const candidate of strongestFirst) {
    if (candidate === kind1 || candidate === kind2) {
      return candidate;
    }
  }
  return MemoizationLevel.Never;
}
/** Per-declaration node in the memoization dependency graph. */
interface IdentifierNode {
  // Join (via `joinAliases`) of every level assigned to this declaration.
  level: MemoizationLevel;
  // Outcome of the memoization walk; only meaningful once `seen` is true.
  memoized: boolean;
  // Declarations whose values flow into this one.
  dependencies: Set<DeclarationId>;
  // Reactive scopes this declaration was observed in.
  scopes: Set<ScopeId>;
  // Set once the memoization walk has started visiting this node.
  seen: boolean;
}
/** Per-scope node in the memoization dependency graph. */
interface ScopeNode {
  // Declaration ids of the scope's dependencies.
  dependencies: Array<DeclarationId>;
  // Set once the scope's dependencies have been force-visited.
  seen: boolean;
}
/**
 * Analysis state shared by dependency collection and the memoization walk.
 *
 * Nodes are keyed by DeclarationId, so every value assigned to the same
 * declared variable shares one graph node.
 */
class State {
  env: Environment;

  // Result declaration of a `LoadLocal` -> declaration it loaded. Lookups go
  // through this map so uses of the loaded temporary resolve to the source.
  definitions: Map<DeclarationId, DeclarationId> = new Map();

  // One graph node per declaration.
  identifiers: Map<DeclarationId, IdentifierNode> = new Map();

  // One graph node per reactive scope encountered via `visitOperand`.
  scopes: Map<ScopeId, ScopeNode> = new Map();

  // Declarations whose values escape; these seed the memoization walk.
  escapingValues: Set<DeclarationId> = new Set();

  constructor(env: Environment) {
    this.env = env;
  }

  /**
   * Registers `id` with a fresh node at level `Never`, with no dependencies
   * and no scopes. Any node already registered for `id` is replaced.
   */
  declare(id: DeclarationId): void {
    const freshNode: IdentifierNode = {
      level: MemoizationLevel.Never,
      memoized: false,
      dependencies: new Set(),
      scopes: new Set(),
      seen: false,
    };
    this.identifiers.set(id, freshNode);
  }

  /**
   * Looks up the reactive scope of `place` at instruction `id` via
   * `getPlaceScope`. When there is one, this ensures the scope has a node
   * (capturing its dependencies' declaration ids the first time it is seen).
   * It then adds the scope to the node of `identifier`. Does nothing when
   * there is no scope.
   *
   * Fails a CompilerError invariant if `identifier` has no node.
   */
  visitOperand(
    id: InstructionId,
    place: Place,
    identifier: DeclarationId,
  ): void {
    const scope = getPlaceScope(id, place);
    if (scope === null) {
      return;
    }

    if (!this.scopes.has(scope.id)) {
      this.scopes.set(scope.id, {
        dependencies: Array.from(
          scope.dependencies,
          dep => dep.identifier.declarationId,
        ),
        seen: false,
      });
    }

    const identifierNode = this.identifiers.get(identifier);
    CompilerError.invariant(identifierNode !== undefined, {
      reason: 'Expected identifier to be initialized',
      description: `[${id}] operand=${printPlace(place)} for identifier declaration ${identifier}`,
      loc: place.loc,
      suggestions: null,
    });
    identifierNode.scopes.add(scope.id);
  }
}
/**
 * Walks the dependency graph in `state` from every escaping value and returns
 * the declarations that must be memoized.
 *
 * A node is memoized when:
 * - its level is `Memoized`; or
 * - its level is `Conditional` and a dependency is memoized or memoization is
 *   forced; or
 * - its level is `Unmemoized` and memoization is forced.
 *
 * Whenever a node is memoized, the dependencies of every scope it belongs to
 * are visited with memoization forced. Each identifier node and scope node is
 * processed at most once. A node that is revisited, including during a cycle
 * while it is still being processed, just reports its current `memoized`
 * flag. Forcing therefore has no effect on a node that was already visited.
 *
 * Marks `seen`/`memoized` on the nodes in `state`. Fails a CompilerError
 * invariant if a referenced identifier or scope has no node.
 */
function computeMemoizedIdentifiers(state: State): Set<DeclarationId> {
  const memoizedIds = new Set<DeclarationId>();

  // Memoization decision for one node, given its level, whether any
  // dependency ended up memoized, and whether memoization is being forced.
  const requiresMemoization = (
    level: MemoizationLevel,
    anyDependencyMemoized: boolean,
    forced: boolean,
  ): boolean => {
    switch (level) {
      case MemoizationLevel.Memoized:
        return true;
      case MemoizationLevel.Conditional:
        return anyDependencyMemoized || forced;
      case MemoizationLevel.Unmemoized:
        return forced;
      default:
        return false;
    }
  };

  function visitIdentifier(id: DeclarationId, forced: boolean = false): boolean {
    const node = state.identifiers.get(id);
    CompilerError.invariant(node !== undefined, {
      reason: `Expected a node for all identifiers, none found for \`${id}\``,
      description: null,
      loc: null,
      suggestions: null,
    });
    if (node.seen) {
      return node.memoized;
    }
    node.seen = true;
    // Provisionally unmemoized, so a cycle leading back here terminates.
    node.memoized = false;

    let anyDependencyMemoized = false;
    for (const dependency of node.dependencies) {
      // Every dependency must be visited; do not short-circuit.
      if (visitIdentifier(dependency)) {
        anyDependencyMemoized = true;
      }
    }

    if (!requiresMemoization(node.level, anyDependencyMemoized, forced)) {
      return node.memoized;
    }
    node.memoized = true;
    memoizedIds.add(id);
    for (const scopeId of node.scopes) {
      forceScopeDependencies(scopeId);
    }
    return node.memoized;
  }

  // Visits (with memoization forced) every dependency of the given scope,
  // at most once per scope.
  function forceScopeDependencies(scopeId: ScopeId): void {
    const scopeNode = state.scopes.get(scopeId);
    CompilerError.invariant(scopeNode !== undefined, {
      reason: 'Expected a node for all scopes',
      description: null,
      loc: null,
      suggestions: null,
    });
    if (!scopeNode.seen) {
      scopeNode.seen = true;
      for (const dependency of scopeNode.dependencies) {
        visitIdentifier(dependency, true);
      }
    }
  }

  state.escapingValues.forEach(root => {
    visitIdentifier(root);
  });
  return memoizedIds;
}
/** A place written by an instruction, paired with the level it requires. */
interface LValueMemoization {
  place: Place;
  level: MemoizationLevel;
}
/**
 * Lists the places bound by a destructuring pattern, with their levels.
 *
 * Array patterns: identifier items are `Conditional` and spread items are
 * `Memoized`; any other item kind binds nothing. Object patterns:
 * `ObjectProperty` entries are `Conditional` and every other entry is
 * `Memoized`.
 */
function computePatternLValues(pattern: Pattern): Array<LValueMemoization> {
  const result: Array<LValueMemoization> = [];
  switch (pattern.kind) {
    case 'ArrayPattern': {
      for (const item of pattern.items) {
        if (item.kind === 'Spread') {
          result.push({place: item.place, level: MemoizationLevel.Memoized});
        } else if (item.kind === 'Identifier') {
          result.push({place: item, level: MemoizationLevel.Conditional});
        }
      }
      break;
    }
    case 'ObjectPattern': {
      for (const property of pattern.properties) {
        const level =
          property.kind === 'ObjectProperty'
            ? MemoizationLevel.Conditional
            : MemoizationLevel.Memoized;
        result.push({place: property.place, level});
      }
      break;
    }
    default: {
      assertExhaustive(
        pattern,
        `Unexpected pattern kind \`${(pattern as any).kind}\``,
      );
    }
  }
  return result;
}
/**
 * First pass of `pruneNonEscapingScopes`. Walks the reactive function and
 * records into `state` a data-flow graph keyed by DeclarationId:
 * - for every lvalue, its memoization level and the rvalues flowing into it;
 * - for every operand, the reactive scope `getPlaceScope` reports for it at
 *   the current instruction (if any);
 * - the escaping values that seed the memoization walk: returned values, and
 *   arguments of hook calls unless the hook's signature is `noAlias`.
 *
 * The visitor state is the stack of reactive scopes enclosing the current
 * statement, outermost first.
 */
class CollectDependenciesVisitor extends ReactiveFunctionVisitor<
  Array<ReactiveScope>
> {
  env: Environment;
  state: State;
  options: MemoizationOptions;

  constructor(env: Environment, state: State) {
    super();
    this.env = env;
    this.state = state;
    // `enableForest` lowers JSX results from `Memoized` to `Unmemoized` and
    // raises primitive-producing instructions to `Memoized`.
    this.options = {
      memoizeJsxElements: !this.env.config.enableForest,
      forceMemoizePrimitives: this.env.config.enableForest,
    };
  }

  /**
   * Describes the data flow of a single instruction value.
   *
   * @param value the value to analyze
   * @param lvalue the place the instruction assigns its result to, or null
   *   (null is also used when analyzing nested sub-values)
   * @returns `lvalues`: places written by the value, each with the memoization
   *   level it requires; `rvalues`: places whose values flow into the lvalues.
   *
   * For `SequenceExpression` this also records the sequence's nested
   * instructions via `visitValueForMemoization`.
   */
  computeMemoizationInputs(
    value: ReactiveValue,
    lvalue: Place | null,
  ): {
    lvalues: Array<LValueMemoization>;
    rvalues: Array<Place>;
  } {
    const options = this.options;
    switch (value.kind) {
      case 'ConditionalExpression': {
        return {
          // Memoized only if a branch value is.
          lvalues:
            lvalue !== null
              ? [{place: lvalue, level: MemoizationLevel.Conditional}]
              : [],
          // Only the branch values (not the test) flow into the result.
          rvalues: [
            ...this.computeMemoizationInputs(value.consequent, null).rvalues,
            ...this.computeMemoizationInputs(value.alternate, null).rvalues,
          ],
        };
      }
      case 'LogicalExpression': {
        return {
          lvalues:
            lvalue !== null
              ? [{place: lvalue, level: MemoizationLevel.Conditional}]
              : [],
          rvalues: [
            ...this.computeMemoizationInputs(value.left, null).rvalues,
            ...this.computeMemoizationInputs(value.right, null).rvalues,
          ],
        };
      }
      case 'SequenceExpression': {
        // The sequence's own instructions are recorded as separate nodes...
        for (const instr of value.instructions) {
          this.visitValueForMemoization(instr.id, instr.value, instr.lvalue);
        }
        return {
          lvalues:
            lvalue !== null
              ? [{place: lvalue, level: MemoizationLevel.Conditional}]
              : [],
          // ...and only the sequence's final value flows into the result.
          rvalues: this.computeMemoizationInputs(value.value, null).rvalues,
        };
      }
      case 'JsxExpression': {
        const operands: Array<Place> = [];
        if (value.tag.kind === 'Identifier') {
          operands.push(value.tag);
        }
        for (const prop of value.props) {
          if (prop.kind === 'JsxAttribute') {
            operands.push(prop.place);
          } else {
            operands.push(prop.argument);
          }
        }
        if (value.children !== null) {
          for (const child of value.children) {
            operands.push(child);
          }
        }
        const level = options.memoizeJsxElements
          ? MemoizationLevel.Memoized
          : MemoizationLevel.Unmemoized;
        return {
          lvalues: lvalue !== null ? [{place: lvalue, level}] : [],
          rvalues: operands,
        };
      }
      case 'JsxFragment': {
        const level = options.memoizeJsxElements
          ? MemoizationLevel.Memoized
          : MemoizationLevel.Unmemoized;
        return {
          lvalues: lvalue !== null ? [{place: lvalue, level}] : [],
          rvalues: value.children,
        };
      }
      case 'NextPropertyOf':
      case 'StartMemoize':
      case 'FinishMemoize':
      case 'Debugger':
      case 'ComputedDelete':
      case 'PropertyDelete':
      case 'LoadGlobal':
      case 'MetaProperty':
      case 'TemplateLiteral':
      case 'Primitive':
      case 'JSXText':
      case 'BinaryExpression':
      case 'UnaryExpression': {
        if (options.forceMemoizePrimitives) {
          /*
           * When primitives are force-memoized the result is `Memoized`, and
           * its operands must be recorded as rvalues. Otherwise values (and
           * the scopes producing them) that flow into an escaping primitive
           * would be unreachable from it, and could be pruned.
           */
          return {
            lvalues:
              lvalue !== null
                ? [{place: lvalue, level: MemoizationLevel.Memoized}]
                : [],
            rvalues: [...eachReactiveValueOperand(value)],
          };
        }
        // Otherwise the result is never memoized and records no data flow.
        return {
          lvalues:
            lvalue !== null
              ? [{place: lvalue, level: MemoizationLevel.Never}]
              : [],
          rvalues: [],
        };
      }
      case 'Await':
      case 'TypeCastExpression': {
        return {
          // Indirection: memoized only if the inner value is.
          lvalues:
            lvalue !== null
              ? [{place: lvalue, level: MemoizationLevel.Conditional}]
              : [],
          rvalues: [value.value],
        };
      }
      case 'IteratorNext': {
        return {
          lvalues:
            lvalue !== null
              ? [{place: lvalue, level: MemoizationLevel.Conditional}]
              : [],
          rvalues: [value.iterator, value.collection],
        };
      }
      case 'GetIterator': {
        return {
          lvalues:
            lvalue !== null
              ? [{place: lvalue, level: MemoizationLevel.Conditional}]
              : [],
          rvalues: [value.collection],
        };
      }
      case 'LoadLocal': {
        return {
          // Indirection: memoized only if the loaded value is.
          lvalues:
            lvalue !== null
              ? [{place: lvalue, level: MemoizationLevel.Conditional}]
              : [],
          rvalues: [value.place],
        };
      }
      case 'LoadContext': {
        return {
          lvalues:
            lvalue !== null
              ? [{place: lvalue, level: MemoizationLevel.Conditional}]
              : [],
          rvalues: [value.place],
        };
      }
      case 'DeclareContext': {
        // The declared context variable itself is always memoized.
        const lvalues: Array<LValueMemoization> = [
          {place: value.lvalue.place, level: MemoizationLevel.Memoized},
        ];
        if (lvalue !== null) {
          lvalues.push({place: lvalue, level: MemoizationLevel.Unmemoized});
        }
        return {
          lvalues,
          rvalues: [],
        };
      }
      case 'DeclareLocal': {
        const lvalues: Array<LValueMemoization> = [
          {place: value.lvalue.place, level: MemoizationLevel.Unmemoized},
        ];
        if (lvalue !== null) {
          lvalues.push({place: lvalue, level: MemoizationLevel.Unmemoized});
        }
        return {
          lvalues,
          rvalues: [],
        };
      }
      case 'PrefixUpdate':
      case 'PostfixUpdate': {
        const lvalues: Array<LValueMemoization> = [
          {place: value.lvalue, level: MemoizationLevel.Conditional},
        ];
        if (lvalue !== null) {
          lvalues.push({place: lvalue, level: MemoizationLevel.Conditional});
        }
        return {
          lvalues,
          rvalues: [value.value],
        };
      }
      case 'StoreLocal': {
        const lvalues: Array<LValueMemoization> = [
          {place: value.lvalue.place, level: MemoizationLevel.Conditional},
        ];
        if (lvalue !== null) {
          lvalues.push({place: lvalue, level: MemoizationLevel.Conditional});
        }
        return {
          lvalues,
          rvalues: [value.value],
        };
      }
      case 'StoreContext': {
        // The stored context variable itself is always memoized.
        const lvalues: Array<LValueMemoization> = [
          {place: value.lvalue.place, level: MemoizationLevel.Memoized},
        ];
        if (lvalue !== null) {
          lvalues.push({place: lvalue, level: MemoizationLevel.Conditional});
        }
        return {
          lvalues,
          rvalues: [value.value],
        };
      }
      case 'StoreGlobal': {
        const lvalues: Array<LValueMemoization> = [];
        if (lvalue !== null) {
          lvalues.push({place: lvalue, level: MemoizationLevel.Unmemoized});
        }
        return {
          lvalues,
          rvalues: [value.value],
        };
      }
      case 'Destructure': {
        const lvalues: Array<LValueMemoization> = [];
        if (lvalue !== null) {
          lvalues.push({place: lvalue, level: MemoizationLevel.Conditional});
        }
        lvalues.push(...computePatternLValues(value.lvalue.pattern));
        return {
          lvalues,
          rvalues: [value.value],
        };
      }
      case 'ComputedLoad':
      case 'PropertyLoad': {
        const level = options.forceMemoizePrimitives
          ? MemoizationLevel.Memoized
          : MemoizationLevel.Conditional;
        return {
          lvalues: lvalue !== null ? [{place: lvalue, level}] : [],
          // Only the object (not a computed key) flows into the result.
          rvalues: [value.object],
        };
      }
      case 'ComputedStore': {
        // The object written to is treated as an lvalue of the stored value.
        const lvalues: Array<LValueMemoization> = [
          {place: value.object, level: MemoizationLevel.Conditional},
        ];
        if (lvalue !== null) {
          lvalues.push({place: lvalue, level: MemoizationLevel.Conditional});
        }
        return {
          lvalues,
          rvalues: [value.value],
        };
      }
      case 'OptionalExpression': {
        const lvalues: Array<LValueMemoization> = [];
        if (lvalue !== null) {
          lvalues.push({place: lvalue, level: MemoizationLevel.Conditional});
        }
        return {
          lvalues,
          rvalues: [
            ...this.computeMemoizationInputs(value.value, null).rvalues,
          ],
        };
      }
      case 'TaggedTemplateExpression': {
        return this.computeCallMemoizationInputs(value, value.tag, lvalue);
      }
      case 'CallExpression': {
        return this.computeCallMemoizationInputs(value, value.callee, lvalue);
      }
      case 'MethodCall': {
        return this.computeCallMemoizationInputs(
          value,
          value.property,
          lvalue,
        );
      }
      case 'RegExpLiteral':
      case 'ObjectMethod':
      case 'FunctionExpression':
      case 'ArrayExpression':
      case 'NewExpression':
      case 'ObjectExpression':
      case 'PropertyStore': {
        /*
         * These produce values that are memoized whenever reachable. Operands
         * with a mutable effect are also treated as memoized lvalues.
         */
        const operands = [...eachReactiveValueOperand(value)];
        const lvalues = operands
          .filter(operand => isMutableEffect(operand.effect, operand.loc))
          .map(place => ({place, level: MemoizationLevel.Memoized}));
        if (lvalue !== null) {
          lvalues.push({place: lvalue, level: MemoizationLevel.Memoized});
        }
        return {
          lvalues,
          rvalues: operands,
        };
      }
      case 'UnsupportedNode': {
        CompilerError.invariant(false, {
          reason: `Unexpected unsupported node`,
          description: null,
          loc: value.loc,
          suggestions: null,
        });
      }
      default: {
        assertExhaustive(
          value,
          `Unexpected value kind \`${(value as any).kind}\``,
        );
      }
    }
  }

  /**
   * Shared handling for call-like values (`CallExpression`, `MethodCall` and
   * `TaggedTemplateExpression`), whose callee place is passed as `callee`.
   *
   * The result (if any) is always `Memoized`. If the callee's call signature
   * is marked `noAlias`, no rvalues are reported. Otherwise every operand is
   * an rvalue, and every operand whose effect `isMutableEffect` reports as
   * mutable is also a `Memoized` lvalue.
   */
  computeCallMemoizationInputs(
    value: ReactiveValue,
    callee: Place,
    lvalue: Place | null,
  ): {
    lvalues: Array<LValueMemoization>;
    rvalues: Array<Place>;
  } {
    const signature = getFunctionCallSignature(
      this.env,
      callee.identifier.type,
    );
    const lvalues: Array<LValueMemoization> = [];
    if (lvalue !== null) {
      lvalues.push({place: lvalue, level: MemoizationLevel.Memoized});
    }
    if (signature?.noAlias === true) {
      return {
        lvalues,
        rvalues: [],
      };
    }
    const operands = [...eachReactiveValueOperand(value)];
    for (const operand of operands) {
      if (isMutableEffect(operand.effect, operand.loc)) {
        lvalues.push({place: operand, level: MemoizationLevel.Memoized});
      }
    }
    return {
      lvalues,
      rvalues: operands,
    };
  }

  /**
   * Records the graph edges produced by one instruction value:
   * - associates each rvalue with its reactive scope (if any);
   * - for each lvalue, creates its node if needed, joins in the new level, and
   *   adds every rvalue (except itself) as a dependency;
   * - for `LoadLocal`, maps the result to the loaded declaration so later
   *   uses resolve through it;
   * - for hook calls (unless `noAlias`), marks every argument as escaping.
   *
   * Declaration ids are resolved through `state.definitions` before use.
   */
  visitValueForMemoization(
    id: InstructionId,
    value: ReactiveValue,
    lvalue: Place | null,
  ): void {
    const state = this.state;
    // Resolves a place through the LoadLocal indirections recorded so far.
    const resolve = (place: Place) =>
      state.definitions.get(place.identifier.declarationId) ??
      place.identifier.declarationId;

    const aliasing = this.computeMemoizationInputs(value, lvalue);

    for (const operand of aliasing.rvalues) {
      state.visitOperand(id, operand, resolve(operand));
    }

    for (const {place, level} of aliasing.lvalues) {
      const lvalueId = resolve(place);
      let node = state.identifiers.get(lvalueId);
      if (node === undefined) {
        node = {
          level: MemoizationLevel.Never,
          memoized: false,
          dependencies: new Set(),
          scopes: new Set(),
          seen: false,
        };
        state.identifiers.set(lvalueId, node);
      }
      node.level = joinAliases(node.level, level);
      // Every rvalue flows into every lvalue; self-edges are skipped.
      for (const operand of aliasing.rvalues) {
        const operandId = resolve(operand);
        if (operandId !== lvalueId) {
          node.dependencies.add(operandId);
        }
      }
      state.visitOperand(id, place, lvalueId);
    }

    if (value.kind === 'LoadLocal' && lvalue !== null) {
      state.definitions.set(
        lvalue.identifier.declarationId,
        value.place.identifier.declarationId,
      );
    } else if (value.kind === 'CallExpression' || value.kind === 'MethodCall') {
      const callee =
        value.kind === 'CallExpression' ? value.callee : value.property;
      if (getHookKind(state.env, callee.identifier) != null) {
        const signature = getFunctionCallSignature(
          this.env,
          callee.identifier.type,
        );
        // Arguments to a `noAlias` hook are not treated as escaping.
        if (signature?.noAlias === true) {
          return;
        }
        for (const arg of value.args) {
          const place = arg.kind === 'Spread' ? arg.place : arg;
          state.escapingValues.add(place.identifier.declarationId);
        }
      }
    }
  }

  /**
   * Records the instruction's value. Nested statements of the value are not
   * traversed here (no call to `traverseInstruction`).
   */
  override visitInstruction(
    instruction: ReactiveInstruction,
    _scopes: Array<ReactiveScope>,
  ): void {
    this.visitValueForMemoization(
      instruction.id,
      instruction.value,
      instruction.lvalue,
    );
  }

  /**
   * Traverses the terminal's children. Then, for a `return`, marks the
   * returned value as escaping and adds every enclosing scope to its node.
   */
  override visitTerminal(
    stmt: ReactiveTerminalStatement<ReactiveTerminal>,
    scopes: Array<ReactiveScope>,
  ): void {
    this.traverseTerminal(stmt, scopes);
    if (stmt.terminal.kind === 'return') {
      this.state.escapingValues.add(
        stmt.terminal.value.identifier.declarationId,
      );
      const identifierNode = this.state.identifiers.get(
        stmt.terminal.value.identifier.declarationId,
      );
      CompilerError.invariant(identifierNode !== undefined, {
        reason: 'Expected identifier to be initialized',
        description: null,
        loc: stmt.terminal.loc,
        suggestions: null,
      });
      for (const scope of scopes) {
        identifierNode.scopes.add(scope.id);
      }
    }
  }

  /**
   * For each identifier the scope reassigns, adds every enclosing scope and
   * this scope to its node. Then traverses the scope body with this scope
   * pushed onto the stack.
   */
  override visitScope(
    scope: ReactiveScopeBlock,
    scopes: Array<ReactiveScope>,
  ): void {
    for (const reassignment of scope.scope.reassignments) {
      const identifierNode = this.state.identifiers.get(
        reassignment.declarationId,
      );
      CompilerError.invariant(identifierNode !== undefined, {
        reason: 'Expected identifier to be initialized',
        description: null,
        loc: reassignment.loc,
        suggestions: null,
      });
      for (const outer of scopes) {
        identifierNode.scopes.add(outer.id);
      }
      identifierNode.scopes.add(scope.scope.id);
    }
    this.traverseScope(scope, [...scopes, scope.scope]);
  }
}
/**
 * Final pass of `pruneNonEscapingScopes`. The transform state is the set of
 * declarations that must be memoized.
 *
 * Scopes that declare or reassign none of those declarations are replaced by
 * their own instructions. A `FinishMemoize` is flagged `pruned` when every
 * identifier it covers has no scope, or has a scope this transform pruned.
 */
class PruneScopesTransform extends ReactiveFunctionTransform<
  Set<DeclarationId>
> {
  // Ids of scopes this transform replaced with their instructions.
  prunedScopes: Set<ScopeId> = new Set();
  // Declaration -> identifiers of the values reassigned into it through a
  // `StoreLocal` whose lvalue kind is `Reassign`.
  reassignments: Map<DeclarationId, Set<Identifier>> = new Map();

  /**
   * Visits the scope's contents, then decides whether to keep the scope.
   *
   * These scopes are always kept:
   * - scopes with no declarations and no reassignments;
   * - scopes with an early-return value.
   *
   * Any other scope is kept only if one of its declarations or reassignments
   * is in `state`. Otherwise it is recorded in `prunedScopes` and replaced by
   * its instructions.
   */
  override transformScope(
    scopeBlock: ReactiveScopeBlock,
    state: Set<DeclarationId>,
  ): Transformed<ReactiveStatement> {
    this.visitScope(scopeBlock, state);

    const {scope} = scopeBlock;
    const producesNothing =
      scope.declarations.size === 0 && scope.reassignments.size === 0;
    if (producesNothing || scope.earlyReturnValue !== null) {
      return {kind: 'keep'};
    }

    const outputs = [
      ...Array.from(scope.declarations.values(), decl => decl.identifier),
      ...scope.reassignments,
    ];
    if (outputs.some(identifier => state.has(identifier.declarationId))) {
      return {kind: 'keep'};
    }

    this.prunedScopes.add(scope.id);
    return {kind: 'replace-many', value: scopeBlock.instructions};
  }

  /**
   * Traverses the instruction. Then it either:
   * - records a `StoreLocal` reassignment, or
   * - for a `FinishMemoize`, sets `pruned` when every identifier it covers
   *   lacks a surviving scope.
   *
   * If the `FinishMemoize` decl has no scope, the identifiers checked are the
   * values recorded as reassigned into its declaration (falling back to the
   * decl itself). Always keeps the instruction.
   */
  override transformInstruction(
    instruction: ReactiveInstruction,
    state: Set<DeclarationId>,
  ): Transformed<ReactiveStatement> {
    this.traverseInstruction(instruction, state);

    const {value} = instruction;
    switch (value.kind) {
      case 'StoreLocal': {
        if (value.lvalue.kind === 'Reassign') {
          getOrInsertDefault(
            this.reassignments,
            value.lvalue.place.identifier.declarationId,
            new Set(),
          ).add(value.value.identifier);
        }
        break;
      }
      case 'FinishMemoize': {
        const memoizedIdentifier = value.decl.identifier;
        const candidates: Iterable<Identifier> =
          memoizedIdentifier.scope == null
            ? (this.reassignments.get(memoizedIdentifier.declarationId) ?? [
                memoizedIdentifier,
              ])
            : [memoizedIdentifier];
        const noneSurvive = [...candidates].every(
          candidate =>
            candidate.scope == null ||
            this.prunedScopes.has(candidate.scope.id),
        );
        if (noneSurvive) {
          value.pruned = true;
        }
        break;
      }
    }
    return {kind: 'keep'};
  }
}