import {CompilerError} from '../CompilerError';
import {inRange} from '../ReactiveScopes/InferReactiveScopeVariables';
import {printDependency} from '../ReactiveScopes/PrintReactiveFunction';
import {
Set_equal,
Set_filter,
Set_intersect,
Set_union,
getOrInsertDefault,
} from '../Utils/utils';
import {
BasicBlock,
BlockId,
DependencyPathEntry,
FunctionExpression,
GeneratedSource,
getHookKind,
HIRFunction,
Identifier,
IdentifierId,
InstructionId,
InstructionValue,
LoweredFunction,
PropertyLiteral,
ReactiveScopeDependency,
ScopeId,
TInstruction,
} from './HIR';
const DEBUG_PRINT = false;
export function collectHoistablePropertyLoads(
fn: HIRFunction,
temporaries: ReadonlyMap<IdentifierId, ReactiveScopeDependency>,
hoistableFromOptionals: ReadonlyMap<BlockId, ReactiveScopeDependency>,
): ReadonlyMap<BlockId, BlockInfo> {
const registry = new PropertyPathRegistry();
const knownImmutableIdentifiers = new Set<IdentifierId>();
if (fn.fnType === 'Component' || fn.fnType === 'Hook') {
for (const p of fn.params) {
if (p.kind === 'Identifier') {
knownImmutableIdentifiers.add(p.identifier.id);
}
}
}
return collectHoistablePropertyLoadsImpl(fn, {
temporaries,
knownImmutableIdentifiers,
hoistableFromOptionals,
registry,
nestedFnImmutableContext: null,
assumedInvokedFns: fn.env.config.enableTreatFunctionDepsAsConditional
? new Set()
: getAssumedInvokedFunctions(fn),
});
}
export function collectHoistablePropertyLoadsInInnerFn(
fnInstr: TInstruction<FunctionExpression>,
temporaries: ReadonlyMap<IdentifierId, ReactiveScopeDependency>,
hoistableFromOptionals: ReadonlyMap<BlockId, ReactiveScopeDependency>,
): ReadonlyMap<BlockId, BlockInfo> {
const fn = fnInstr.value.loweredFunc.func;
const initialContext: CollectHoistablePropertyLoadsContext = {
temporaries,
knownImmutableIdentifiers: new Set(),
hoistableFromOptionals,
registry: new PropertyPathRegistry(),
nestedFnImmutableContext: null,
assumedInvokedFns: fn.env.config.enableTreatFunctionDepsAsConditional
? new Set()
: getAssumedInvokedFunctions(fn),
};
const nestedFnImmutableContext = new Set(
fn.context
.filter(place =>
isImmutableAtInstr(place.identifier, fnInstr.id, initialContext),
)
.map(place => place.identifier.id),
);
initialContext.nestedFnImmutableContext = nestedFnImmutableContext;
return collectHoistablePropertyLoadsImpl(fn, initialContext);
}
type CollectHoistablePropertyLoadsContext = {
temporaries: ReadonlyMap<IdentifierId, ReactiveScopeDependency>;
knownImmutableIdentifiers: ReadonlySet<IdentifierId>;
hoistableFromOptionals: ReadonlyMap<BlockId, ReactiveScopeDependency>;
registry: PropertyPathRegistry;
nestedFnImmutableContext: ReadonlySet<IdentifierId> | null;
assumedInvokedFns: ReadonlySet<LoweredFunction>;
};
function collectHoistablePropertyLoadsImpl(
fn: HIRFunction,
context: CollectHoistablePropertyLoadsContext,
): ReadonlyMap<BlockId, BlockInfo> {
const nodes = collectNonNullsInBlocks(fn, context);
propagateNonNull(fn, nodes, context.registry);
if (DEBUG_PRINT) {
console.log('(printing hoistable nodes in blocks)');
for (const [blockId, node] of nodes) {
console.log(
`bb${blockId}: ${[...node.assumedNonNullObjects].map(n => printDependency(n.fullPath)).join(' ')}`,
);
}
}
return nodes;
}
export function keyByScopeId<T>(
fn: HIRFunction,
source: ReadonlyMap<BlockId, T>,
): ReadonlyMap<ScopeId, T> {
const keyedByScopeId = new Map<ScopeId, T>();
for (const [_, block] of fn.body.blocks) {
if (block.terminal.kind === 'scope') {
keyedByScopeId.set(
block.terminal.scope.id,
source.get(block.terminal.block)!,
);
}
}
return keyedByScopeId;
}
export type BlockInfo = {
block: BasicBlock;
assumedNonNullObjects: ReadonlySet<PropertyPathNode>;
};
type RootNode = {
properties: Map<PropertyLiteral, PropertyPathNode>;
optionalProperties: Map<PropertyLiteral, PropertyPathNode>;
parent: null;
fullPath: ReactiveScopeDependency;
hasOptional: boolean;
root: IdentifierId;
};
type PropertyPathNode =
| {
properties: Map<PropertyLiteral, PropertyPathNode>;
optionalProperties: Map<PropertyLiteral, PropertyPathNode>;
parent: PropertyPathNode;
fullPath: ReactiveScopeDependency;
hasOptional: boolean;
}
| RootNode;
class PropertyPathRegistry {
roots: Map<IdentifierId, RootNode> = new Map();
getOrCreateIdentifier(
identifier: Identifier,
reactive: boolean,
): PropertyPathNode {
let rootNode = this.roots.get(identifier.id);
if (rootNode === undefined) {
rootNode = {
root: identifier.id,
properties: new Map(),
optionalProperties: new Map(),
fullPath: {
identifier,
reactive,
path: [],
},
hasOptional: false,
parent: null,
};
this.roots.set(identifier.id, rootNode);
} else {
CompilerError.invariant(reactive === rootNode.fullPath.reactive, {
reason:
'[HoistablePropertyLoads] Found inconsistencies in `reactive` flag when deduping identifier reads within the same scope',
loc: identifier.loc,
});
}
return rootNode;
}
static getOrCreatePropertyEntry(
parent: PropertyPathNode,
entry: DependencyPathEntry,
): PropertyPathNode {
const map = entry.optional ? parent.optionalProperties : parent.properties;
let child = map.get(entry.property);
if (child == null) {
child = {
properties: new Map(),
optionalProperties: new Map(),
parent: parent,
fullPath: {
identifier: parent.fullPath.identifier,
reactive: parent.fullPath.reactive,
path: parent.fullPath.path.concat(entry),
},
hasOptional: parent.hasOptional || entry.optional,
};
map.set(entry.property, child);
}
return child;
}
getOrCreateProperty(n: ReactiveScopeDependency): PropertyPathNode {
let currNode = this.getOrCreateIdentifier(n.identifier, n.reactive);
if (n.path.length === 0) {
return currNode;
}
for (let i = 0; i < n.path.length - 1; i++) {
currNode = PropertyPathRegistry.getOrCreatePropertyEntry(
currNode,
n.path[i],
);
}
return PropertyPathRegistry.getOrCreatePropertyEntry(
currNode,
n.path.at(-1)!,
);
}
}
function getMaybeNonNullInInstruction(
instr: InstructionValue,
context: CollectHoistablePropertyLoadsContext,
): PropertyPathNode | null {
let path: ReactiveScopeDependency | null = null;
if (instr.kind === 'PropertyLoad') {
path = context.temporaries.get(instr.object.identifier.id) ?? {
identifier: instr.object.identifier,
reactive: instr.object.reactive,
path: [],
};
} else if (instr.kind === 'Destructure') {
path = context.temporaries.get(instr.value.identifier.id) ?? null;
} else if (instr.kind === 'ComputedLoad') {
path = context.temporaries.get(instr.object.identifier.id) ?? null;
}
return path != null ? context.registry.getOrCreateProperty(path) : null;
}
function isImmutableAtInstr(
identifier: Identifier,
instr: InstructionId,
context: CollectHoistablePropertyLoadsContext,
): boolean {
if (context.nestedFnImmutableContext != null) {
return context.nestedFnImmutableContext.has(identifier.id);
} else {
const mutableAtInstr =
identifier.mutableRange.end > identifier.mutableRange.start + 1 &&
identifier.scope != null &&
inRange(
{
id: instr,
},
identifier.scope.range,
);
return (
!mutableAtInstr || context.knownImmutableIdentifiers.has(identifier.id)
);
}
}
function collectNonNullsInBlocks(
fn: HIRFunction,
context: CollectHoistablePropertyLoadsContext,
): ReadonlyMap<BlockId, BlockInfo> {
const knownNonNullIdentifiers = new Set<PropertyPathNode>();
if (
fn.fnType === 'Component' &&
fn.params.length > 0 &&
fn.params[0].kind === 'Identifier'
) {
const identifier = fn.params[0].identifier;
knownNonNullIdentifiers.add(
context.registry.getOrCreateIdentifier(identifier, true),
);
}
const nodes = new Map<
BlockId,
{
block: BasicBlock;
assumedNonNullObjects: Set<PropertyPathNode>;
}
>();
for (const [_, block] of fn.body.blocks) {
const assumedNonNullObjects = new Set<PropertyPathNode>(
knownNonNullIdentifiers,
);
const maybeOptionalChain = context.hoistableFromOptionals.get(block.id);
if (maybeOptionalChain != null) {
assumedNonNullObjects.add(
context.registry.getOrCreateProperty(maybeOptionalChain),
);
}
for (const instr of block.instructions) {
const maybeNonNull = getMaybeNonNullInInstruction(instr.value, context);
if (
maybeNonNull != null &&
isImmutableAtInstr(maybeNonNull.fullPath.identifier, instr.id, context)
) {
assumedNonNullObjects.add(maybeNonNull);
}
if (instr.value.kind === 'FunctionExpression') {
const innerFn = instr.value.loweredFunc;
if (context.assumedInvokedFns.has(innerFn)) {
const innerHoistableMap = collectHoistablePropertyLoadsImpl(
innerFn.func,
{
...context,
nestedFnImmutableContext:
context.nestedFnImmutableContext ??
new Set(
innerFn.func.context
.filter(place =>
isImmutableAtInstr(place.identifier, instr.id, context),
)
.map(place => place.identifier.id),
),
},
);
const innerHoistables = assertNonNull(
innerHoistableMap.get(innerFn.func.body.entry),
);
for (const entry of innerHoistables.assumedNonNullObjects) {
assumedNonNullObjects.add(entry);
}
}
} else if (
fn.env.config.enablePreserveExistingMemoizationGuarantees &&
instr.value.kind === 'StartMemoize' &&
instr.value.deps != null
) {
for (const dep of instr.value.deps) {
if (dep.root.kind === 'NamedLocal') {
if (
!isImmutableAtInstr(dep.root.value.identifier, instr.id, context)
) {
continue;
}
for (let i = 0; i < dep.path.length; i++) {
const pathEntry = dep.path[i]!;
if (pathEntry.optional) {
break;
}
const depNode = context.registry.getOrCreateProperty({
identifier: dep.root.value.identifier,
path: dep.path.slice(0, i),
reactive: dep.root.value.reactive,
});
assumedNonNullObjects.add(depNode);
}
}
}
}
}
nodes.set(block.id, {
block,
assumedNonNullObjects,
});
}
return nodes;
}
function propagateNonNull(
fn: HIRFunction,
nodes: ReadonlyMap<BlockId, BlockInfo>,
registry: PropertyPathRegistry,
): void {
const blockSuccessors = new Map<BlockId, Set<BlockId>>();
const terminalPreds = new Set<BlockId>();
for (const [blockId, block] of fn.body.blocks) {
for (const pred of block.preds) {
getOrInsertDefault(blockSuccessors, pred, new Set()).add(blockId);
}
if (block.terminal.kind === 'throw' || block.terminal.kind === 'return') {
terminalPreds.add(blockId);
}
}
function recursivelyPropagateNonNull(
nodeId: BlockId,
direction: 'forward' | 'backward',
traversalState: Map<BlockId, 'active' | 'done'>,
): boolean {
if (traversalState.has(nodeId)) {
return false;
}
traversalState.set(nodeId, 'active');
const node = nodes.get(nodeId);
if (node == null) {
CompilerError.invariant(false, {
reason: `Bad node ${nodeId}, kind: ${direction}`,
loc: GeneratedSource,
});
}
const neighbors = Array.from(
direction === 'backward'
? (blockSuccessors.get(nodeId) ?? [])
: node.block.preds,
);
let changed = false;
for (const pred of neighbors) {
if (!traversalState.has(pred)) {
const neighborChanged = recursivelyPropagateNonNull(
pred,
direction,
traversalState,
);
changed ||= neighborChanged;
}
}
const neighborAccesses = Set_intersect(
Array.from(neighbors)
.filter(n => traversalState.get(n) === 'done')
.map(n => assertNonNull(nodes.get(n)).assumedNonNullObjects),
);
const prevObjects = assertNonNull(nodes.get(nodeId)).assumedNonNullObjects;
const mergedObjects = Set_union(prevObjects, neighborAccesses);
reduceMaybeOptionalChains(mergedObjects, registry);
assertNonNull(nodes.get(nodeId)).assumedNonNullObjects = mergedObjects;
traversalState.set(nodeId, 'done');
changed ||= !Set_equal(prevObjects, mergedObjects);
return changed;
}
const traversalState = new Map<BlockId, 'done' | 'active'>();
const reversedBlocks = [...fn.body.blocks];
reversedBlocks.reverse();
let changed;
let i = 0;
do {
CompilerError.invariant(i++ < 100, {
reason:
'[CollectHoistablePropertyLoads] fixed point iteration did not terminate after 100 loops',
loc: GeneratedSource,
});
changed = false;
for (const [blockId] of fn.body.blocks) {
const forwardChanged = recursivelyPropagateNonNull(
blockId,
'forward',
traversalState,
);
changed ||= forwardChanged;
}
traversalState.clear();
for (const [blockId] of reversedBlocks) {
const backwardChanged = recursivelyPropagateNonNull(
blockId,
'backward',
traversalState,
);
changed ||= backwardChanged;
}
traversalState.clear();
} while (changed);
}
export function assertNonNull<T extends NonNullable<U>, U>(
value: T | null | undefined,
source?: string,
): T {
CompilerError.invariant(value != null, {
reason: 'Unexpected null',
description: source != null ? `(from ${source})` : null,
loc: GeneratedSource,
});
return value;
}
function reduceMaybeOptionalChains(
nodes: Set<PropertyPathNode>,
registry: PropertyPathRegistry,
): void {
let optionalChainNodes = Set_filter(nodes, n => n.hasOptional);
if (optionalChainNodes.size === 0) {
return;
}
let changed: boolean;
do {
changed = false;
for (const original of optionalChainNodes) {
let {identifier, path: origPath, reactive} = original.fullPath;
let currNode: PropertyPathNode = registry.getOrCreateIdentifier(
identifier,
reactive,
);
for (let i = 0; i < origPath.length; i++) {
const entry = origPath[i];
const nextEntry: DependencyPathEntry =
entry.optional && nodes.has(currNode)
? {property: entry.property, optional: false}
: entry;
currNode = PropertyPathRegistry.getOrCreatePropertyEntry(
currNode,
nextEntry,
);
}
if (currNode !== original) {
changed = true;
optionalChainNodes.delete(original);
optionalChainNodes.add(currNode);
nodes.delete(original);
nodes.add(currNode);
}
}
} while (changed);
}
function getAssumedInvokedFunctions(
fn: HIRFunction,
temporaries: Map<
IdentifierId,
{fn: LoweredFunction; mayInvoke: Set<LoweredFunction>}
> = new Map(),
): ReadonlySet<LoweredFunction> {
const hoistableFunctions = new Set<LoweredFunction>();
for (const block of fn.body.blocks.values()) {
for (const {lvalue, value} of block.instructions) {
if (value.kind === 'FunctionExpression') {
temporaries.set(lvalue.identifier.id, {
fn: value.loweredFunc,
mayInvoke: new Set(),
});
} else if (value.kind === 'StoreLocal') {
const lvalue = value.lvalue.place.identifier;
const maybeLoweredFunc = temporaries.get(value.value.identifier.id);
if (maybeLoweredFunc != null) {
temporaries.set(lvalue.id, maybeLoweredFunc);
}
} else if (value.kind === 'LoadLocal') {
const maybeLoweredFunc = temporaries.get(value.place.identifier.id);
if (maybeLoweredFunc != null) {
temporaries.set(lvalue.identifier.id, maybeLoweredFunc);
}
}
}
}
for (const block of fn.body.blocks.values()) {
for (const {lvalue, value} of block.instructions) {
if (value.kind === 'CallExpression') {
const callee = value.callee;
const maybeHook = getHookKind(fn.env, callee.identifier);
const maybeLoweredFunc = temporaries.get(callee.identifier.id);
if (maybeLoweredFunc != null) {
hoistableFunctions.add(maybeLoweredFunc.fn);
} else if (maybeHook != null) {
for (const arg of value.args) {
if (arg.kind === 'Identifier') {
const maybeLoweredFunc = temporaries.get(arg.identifier.id);
if (maybeLoweredFunc != null) {
hoistableFunctions.add(maybeLoweredFunc.fn);
}
}
}
}
} else if (value.kind === 'JsxExpression') {
for (const attr of value.props) {
if (attr.kind === 'JsxSpreadAttribute') {
continue;
}
const maybeLoweredFunc = temporaries.get(attr.place.identifier.id);
if (maybeLoweredFunc != null) {
hoistableFunctions.add(maybeLoweredFunc.fn);
}
}
for (const child of value.children ?? []) {
const maybeLoweredFunc = temporaries.get(child.identifier.id);
if (maybeLoweredFunc != null) {
hoistableFunctions.add(maybeLoweredFunc.fn);
}
}
} else if (value.kind === 'FunctionExpression') {
const loweredFunc = value.loweredFunc.func;
const lambdasCalled = getAssumedInvokedFunctions(
loweredFunc,
temporaries,
);
const maybeLoweredFunc = temporaries.get(lvalue.identifier.id);
if (maybeLoweredFunc != null) {
for (const called of lambdasCalled) {
maybeLoweredFunc.mayInvoke.add(called);
}
}
}
}
if (block.terminal.kind === 'return') {
const maybeLoweredFunc = temporaries.get(
block.terminal.value.identifier.id,
);
if (maybeLoweredFunc != null) {
hoistableFunctions.add(maybeLoweredFunc.fn);
}
}
}
for (const [_, {fn, mayInvoke}] of temporaries) {
if (hoistableFunctions.has(fn)) {
for (const called of mayInvoke) {
hoistableFunctions.add(called);
}
}
}
return hoistableFunctions;
}