import {CompilerError} from '..';
import {
BlockId,
Effect,
HIRFunction,
Identifier,
IdentifierId,
Place,
computePostDominatorTree,
getHookKind,
isStableType,
isUseOperator,
} from '../HIR';
import {PostDominator} from '../HIR/Dominator';
import {
eachInstructionLValue,
eachInstructionValueOperand,
eachTerminalOperand,
} from '../HIR/visitors';
import {
findDisjointMutableValues,
isMutable,
} from '../ReactiveScopes/InferReactiveScopeVariables';
import DisjointSet from '../Utils/DisjointSet';
import {assertExhaustive} from '../Utils/utils';
/**
 * Infers which places in `fn` are reactive and records the result on the
 * places themselves (by setting `place.reactive = true`).
 *
 * Every parameter is seeded as reactive. The blocks are then walked
 * repeatedly, propagating reactivity through phis, instruction operands and
 * lvalues, mutable operands, and reactive control flow, until a full pass
 * marks no new identifier (`ReactivityMap.snapshot()` returns false).
 *
 * @param fn The function to analyze; its places are updated in place.
 */
export function inferReactivePlaces(fn: HIRFunction): void {
// Reactivity is tracked per alias set: ReactivityMap resolves every identifier
// through the disjoint sets returned by findDisjointMutableValues.
const reactiveIdentifiers = new ReactivityMap(findDisjointMutableValues(fn));
// Seed: all parameters are reactive. A param is either a place itself
// (kind 'Identifier') or a wrapper exposing the place as `.place`.
for (const param of fn.params) {
const place = param.kind === 'Identifier' ? param : param.place;
reactiveIdentifiers.markReactive(place);
}
// NOTE(review): `includeThrowsAsExitNode: false` presumably keeps throw
// terminals out of the exit nodes used for post-dominance — confirm against
// computePostDominatorTree.
const postDominators = computePostDominatorTree(fn, {
includeThrowsAsExitNode: false,
});
// The frontier of a block is computed once and reused across fixpoint passes.
const postDominatorFrontierCache = new Map<BlockId, Set<BlockId>>();
/*
 * Returns true if any block in the post-dominator frontier of `id` ends in an
 * `if`/`branch` whose test is reactive, or in a `switch` whose test or any
 * non-null case test is reactive. Other terminal kinds are ignored.
 */
function isReactiveControlledBlock(id: BlockId): boolean {
let controlBlocks = postDominatorFrontierCache.get(id);
if (controlBlocks === undefined) {
controlBlocks = postDominatorFrontier(fn, postDominators, id);
postDominatorFrontierCache.set(id, controlBlocks);
}
for (const blockId of controlBlocks) {
const controlBlock = fn.body.blocks.get(blockId)!;
switch (controlBlock.terminal.kind) {
case 'if':
case 'branch': {
if (reactiveIdentifiers.isReactive(controlBlock.terminal.test)) {
return true;
}
break;
}
case 'switch': {
if (reactiveIdentifiers.isReactive(controlBlock.terminal.test)) {
return true;
}
for (const case_ of controlBlock.terminal.cases) {
if (
case_.test !== null &&
reactiveIdentifiers.isReactive(case_.test)
) {
return true;
}
}
break;
}
}
}
return false;
}
// Fixpoint: repeat until a whole pass marks no new identifier as reactive.
do {
// Maps the lvalue of a load instruction to the identifier it was loaded
// from; rebuilt from scratch on every pass.
const identifierMapping = new Map<Identifier, Identifier>();
for (const [, block] of fn.body.blocks) {
let hasReactiveControl = isReactiveControlledBlock(block.id);
for (const phi of block.phis) {
if (reactiveIdentifiers.isReactiveIdentifier(phi.id)) {
// Already reactive, nothing more to learn about this phi.
continue;
}
// A phi is reactive if any of its operands is reactive...
let isPhiReactive = false;
for (const [, operand] of phi.operands) {
if (reactiveIdentifiers.isReactiveIdentifier(operand)) {
isPhiReactive = true;
break;
}
}
if (isPhiReactive) {
reactiveIdentifiers.markReactiveIdentifier(phi.id);
} else {
// ...or if any predecessor block it merges from is reactively controlled.
for (const [pred] of phi.operands) {
if (isReactiveControlledBlock(pred)) {
reactiveIdentifiers.markReactiveIdentifier(phi.id);
break;
}
}
}
}
for (const instruction of block.instructions) {
const {value} = instruction;
let hasReactiveInput = false;
// Deliberately no short-circuit: isReactive() also sets `reactive` on each
// reactive operand place, so every operand must be visited.
for (const operand of eachInstructionValueOperand(value)) {
const reactive = reactiveIdentifiers.isReactive(operand);
hasReactiveInput ||= reactive;
}
// A call or method call whose callee/property has a non-null hook kind, or
// is the `use` operator, is treated as having a reactive input regardless
// of its operands.
if (
value.kind === 'CallExpression' &&
(getHookKind(fn.env, value.callee.identifier) != null ||
isUseOperator(value.callee.identifier))
) {
hasReactiveInput = true;
} else if (
value.kind === 'MethodCall' &&
(getHookKind(fn.env, value.property.identifier) != null ||
isUseOperator(value.property.identifier))
) {
hasReactiveInput = true;
}
if (hasReactiveInput) {
// Reactive inputs make the instruction's lvalues reactive, except lvalues
// for which isStableType() returns true.
for (const lvalue of eachInstructionLValue(instruction)) {
if (isStableType(lvalue.identifier)) {
continue;
}
reactiveIdentifiers.markReactive(lvalue);
}
}
if (hasReactiveInput || hasReactiveControl) {
// Operands that this instruction may capture into, store to, or mutate
// (and that isMutable() reports as mutable here) become reactive too.
for (const operand of eachInstructionValueOperand(value)) {
switch (operand.effect) {
case Effect.Capture:
case Effect.Store:
case Effect.ConditionallyMutate:
case Effect.Mutate: {
if (isMutable(instruction, operand)) {
// If the operand is a temporary produced by a load, also mark the
// identifier it was loaded from.
const resolvedId = identifierMapping.get(operand.identifier);
if (resolvedId !== undefined) {
reactiveIdentifiers.markReactiveIdentifier(resolvedId);
}
reactiveIdentifiers.markReactive(operand);
}
break;
}
case Effect.Freeze:
case Effect.Read: {
// Not a mutation of the operand, so nothing to propagate.
break;
}
case Effect.Unknown: {
// NOTE(review): invariant(false, ...) presumably always throws, so the
// fallthrough into `default` should never execute — confirm against
// CompilerError.invariant.
CompilerError.invariant(false, {
reason: 'Unexpected unknown effect',
description: null,
loc: operand.loc,
suggestions: null,
});
}
default: {
assertExhaustive(
operand.effect,
`Unexpected effect kind \`${operand.effect}\``,
);
}
}
}
}
// Record where loaded temporaries came from so the mutable-operand step
// above can mark the source identifier.
switch (value.kind) {
case 'LoadLocal': {
identifierMapping.set(
instruction.lvalue.identifier,
value.place.identifier,
);
break;
}
case 'PropertyLoad':
case 'ComputedLoad': {
// Resolve through an existing mapping so chained loads map back to the
// identifier at the start of the chain.
const resolvedId =
identifierMapping.get(value.object.identifier) ??
value.object.identifier;
identifierMapping.set(instruction.lvalue.identifier, resolvedId);
break;
}
case 'LoadContext': {
identifierMapping.set(
instruction.lvalue.identifier,
value.place.identifier,
);
break;
}
}
}
// Called only for its side effect: isReactive() sets `reactive` on terminal
// operand places whose identifier is reactive.
for (const operand of eachTerminalOperand(block.terminal)) {
reactiveIdentifiers.isReactive(operand);
}
}
// snapshot() reports whether this pass marked anything new, and resets the flag.
} while (reactiveIdentifiers.snapshot());
}
/**
 * Computes the frontier for `targetId`: the predecessors of `targetId`, and of
 * every block reported by `postDominatorsOf` for it, that are not themselves
 * in that reported set.
 *
 * @param fn Function whose `body.blocks` map is consulted for predecessors.
 * @param postDominators Post-dominator tree, forwarded to `postDominatorsOf`.
 * @param targetId Block whose frontier is requested.
 * @returns The set of frontier block ids, in discovery order.
 */
function postDominatorFrontier(
  fn: HIRFunction,
  postDominators: PostDominator<BlockId>,
  targetId: BlockId,
): Set<BlockId> {
  const dominated = postDominatorsOf(fn, postDominators, targetId);
  /*
   * Scan the dominated blocks first and the target last. Building a Set keeps
   * that order while ensuring the target is not scanned twice should it also
   * appear in `dominated`.
   */
  const blocksToScan = new Set<BlockId>([...dominated, targetId]);
  const frontier = new Set<BlockId>();
  for (const scannedId of blocksToScan) {
    const scanned = fn.body.blocks.get(scannedId)!;
    for (const pred of scanned.preds) {
      if (dominated.has(pred)) {
        continue;
      }
      // This predecessor is outside the dominated set: it is on the frontier.
      frontier.add(pred);
    }
  }
  return frontier;
}
/**
 * Walks backwards (breadth-first over predecessor edges) from `targetId` and
 * collects every visited predecessor whose entry in `postDominators` is either
 * `targetId` itself or a block that has already been collected.
 *
 * A predecessor with no entry in `postDominators` is treated as its own
 * post-dominator.
 *
 * @param fn Function whose `body.blocks` map is consulted for predecessors.
 * @param postDominators Lookup from a block id to its post-dominator.
 * @param targetId Block to start the backwards walk from.
 * @returns The collected block ids, in the order they were added.
 */
function postDominatorsOf(
  fn: HIRFunction,
  postDominators: PostDominator<BlockId>,
  targetId: BlockId,
): Set<BlockId> {
  const collected = new Set<BlockId>();
  const seen = new Set<BlockId>();
  /*
   * FIFO worklist consumed through a cursor. Visiting order must stay
   * breadth-first because membership depends on what was collected earlier.
   */
  const worklist: Array<BlockId> = [targetId];
  for (let cursor = 0; cursor < worklist.length; cursor++) {
    const blockId = worklist[cursor]!;
    if (seen.has(blockId)) {
      continue;
    }
    seen.add(blockId);
    const block = fn.body.blocks.get(blockId)!;
    for (const pred of block.preds) {
      const exitOfPred = postDominators.get(pred) ?? pred;
      const leadsToTarget =
        exitOfPred === targetId || collected.has(exitOfPred);
      if (leadsToTarget) {
        collected.add(pred);
      }
      worklist.push(pred);
    }
  }
  return collected;
}
/**
 * Tracks which identifiers are reactive, treating all identifiers in the same
 * alias set as one: every lookup and update goes through the representative
 * returned by `aliasedIdentifiers.find`, falling back to the identifier itself
 * when `find` yields nothing.
 */
class ReactivityMap {
  /** True if an identifier was newly marked since the last `snapshot()`. */
  hasChanges: boolean = false;
  /** Ids of the (representative) identifiers known to be reactive. */
  reactive: Set<IdentifierId> = new Set();
  /** Alias sets used to pick the representative of an identifier. */
  aliasedIdentifiers: DisjointSet<Identifier>;

  constructor(aliasedIdentifiers: DisjointSet<Identifier>) {
    this.aliasedIdentifiers = aliasedIdentifiers;
  }

  /** Id of the alias-set representative of `identifier`. */
  private representativeId(identifier: Identifier): IdentifierId {
    const representative =
      this.aliasedIdentifiers.find(identifier) ?? identifier;
    return representative.id;
  }

  /**
   * Reports whether the place's identifier is reactive and, if it is, also
   * flags the place itself (`place.reactive = true`). A non-reactive place is
   * left untouched.
   */
  isReactive(place: Place): boolean {
    if (!this.isReactiveIdentifier(place.identifier)) {
      return false;
    }
    place.reactive = true;
    return true;
  }

  /** Reports whether the identifier (or its alias set) is marked reactive. */
  isReactiveIdentifier(inputIdentifier: Identifier): boolean {
    return this.reactive.has(this.representativeId(inputIdentifier));
  }

  /** Flags the place as reactive and marks its identifier. */
  markReactive(place: Place): void {
    place.reactive = true;
    this.markReactiveIdentifier(place.identifier);
  }

  /**
   * Marks the identifier (via its representative) as reactive. Records a
   * change only when the representative was not already marked.
   */
  markReactiveIdentifier(inputIdentifier: Identifier): void {
    const id = this.representativeId(inputIdentifier);
    if (this.reactive.has(id)) {
      return;
    }
    this.hasChanges = true;
    this.reactive.add(id);
  }

  /**
   * Returns whether anything was newly marked since the previous call, and
   * resets the change flag.
   */
  snapshot(): boolean {
    const changed = this.hasChanges;
    this.hasChanges = false;
    return changed;
  }
}