import {
HIRFunction,
InstructionId,
Place,
ReactiveScope,
makeInstructionId,
} from '.';
import {getPlaceScope} from '../HIR/HIR';
import {isMutable} from '../ReactiveScopes/InferReactiveScopeVariables';
import DisjointSet from '../Utils/DisjointSet';
import {getOrInsertDefault} from '../Utils/utils';
import {
eachInstructionLValue,
eachInstructionOperand,
eachTerminalOperand,
} from './visitors';
/**
 * Merges reactive scopes whose instruction ranges overlap without nesting.
 *
 * Overlap groups come from `getOverlappingReactiveScopes`. For each group, the
 * representative scope's range is widened to span the ranges of all its
 * members. Every place that referenced a non-representative member is then
 * repointed at that representative.
 *
 * @param fn The function whose scopes are merged; scope ranges and
 *   `place.identifier.scope` links are mutated in place.
 */
export function mergeOverlappingReactiveScopesHIR(fn: HIRFunction): void {
  // Scope boundaries are gathered before the traversal, because a scope's range
  // may start before the first instruction that references one of its places.
  const info = collectScopeInfo(fn);
  const groups = getOverlappingReactiveScopes(fn, info);

  // Widen each group's representative so it covers every member's range.
  groups.forEach((member, representative) => {
    if (member === representative) {
      return;
    }
    const mergedStart = Math.min(representative.range.start, member.range.start);
    representative.range.start = makeInstructionId(mergedStart);
    const mergedEnd = Math.max(representative.range.end, member.range.end);
    representative.range.end = makeInstructionId(mergedEnd);
  });

  // Repoint each recorded place at the representative of its original scope.
  for (const [place, originalScope] of info.placeScopes) {
    const representative = groups.find(originalScope);
    if (representative === null || representative === originalScope) {
      continue;
    }
    place.identifier.scope = representative;
  }
}
/**
 * Scope boundary data collected by `collectScopeInfo`.
 */
type ScopeInfo = {
  // Scopes bucketed by start id, sorted by descending id so the smallest id is
  // last (peeked with `.at(-1)`, consumed with `.pop()`).
  scopeStarts: {id: InstructionId; scopes: Set<ReactiveScope>}[];
  // Scopes bucketed by end id, in the same descending order as `scopeStarts`.
  scopeEnds: {id: InstructionId; scopes: Set<ReactiveScope>}[];
  // Each visited place mapped to the scope it referenced at collection time.
  placeScopes: Map<Place, ReactiveScope>;
};
/**
 * Mutable state threaded through `getOverlappingReactiveScopes`.
 */
type TraversalState = {
  // Union-find grouping of the scopes that must be merged together.
  joined: DisjointSet<ReactiveScope>;
  // Scopes that have started but not yet ended, in the order they were pushed.
  activeScopes: ReactiveScope[];
};
/**
 * Walks every instruction lvalue, instruction operand and terminal operand in
 * `fn`, in that order within each block, and records:
 *  - which scope each place referenced (`placeScopes`);
 *  - for every scope with a non-empty range, its start and end ids, bucketed
 *    by id (`scopeStarts` / `scopeEnds`).
 *
 * The bucket lists are sorted by descending id, so callers can treat them as
 * stacks whose top (`.at(-1)`) is the smallest pending id.
 */
function collectScopeInfo(fn: HIRFunction): ScopeInfo {
  const startsById = new Map<InstructionId, Set<ReactiveScope>>();
  const endsById = new Map<InstructionId, Set<ReactiveScope>>();
  const placeScopes = new Map<Place, ReactiveScope>();

  const record = (place: Place): void => {
    const {scope} = place.identifier;
    if (scope == null) {
      return;
    }
    placeScopes.set(place, scope);
    // Empty ranges (start === end) are never tracked as boundaries.
    if (scope.range.start === scope.range.end) {
      return;
    }
    getOrInsertDefault(startsById, scope.range.start, new Set()).add(scope);
    getOrInsertDefault(endsById, scope.range.end, new Set()).add(scope);
  };

  for (const [, block] of fn.body.blocks) {
    for (const instr of block.instructions) {
      for (const lvalue of eachInstructionLValue(instr)) {
        record(lvalue);
      }
      for (const operand of eachInstructionOperand(instr)) {
        record(operand);
      }
    }
    for (const operand of eachTerminalOperand(block.terminal)) {
      record(operand);
    }
  }

  // Descending order puts the smallest id at the end of the array.
  const toStack = (
    byId: Map<InstructionId, Set<ReactiveScope>>,
  ): ScopeInfo['scopeStarts'] =>
    Array.from(byId, ([id, scopes]) => ({id, scopes})).sort(
      (a, b) => b.id - a.id,
    );

  return {
    scopeStarts: toStack(startsById),
    scopeEnds: toStack(endsById),
    placeScopes,
  };
}
/**
 * Advances the active-scope stack to instruction `id`.
 *
 * First it retires the scopes whose end id is at or before `id`. Then it pushes
 * the scopes whose start id is at or before `id`. Scopes found to overlap
 * without nesting are unioned in `joined`.
 *
 * Side effects: pops consumed entries from `scopeEnds` / `scopeStarts` and
 * mutates `activeScopes`.
 *
 * NOTE(review): each call consumes at most one entry from each list (`if`,
 * not a loop). This presumably relies on every boundary id coinciding with a
 * visited instruction/terminal id; confirm against the scope-alignment passes.
 */
function visitInstructionId(
  id: InstructionId,
  {scopeEnds, scopeStarts}: ScopeInfo,
  {activeScopes, joined}: TraversalState,
): void {
  // Ends are handled before starts, so a scope ending at `id` is no longer
  // active when scopes starting at `id` are pushed.
  const scopeEndTop = scopeEnds.at(-1);
  if (scopeEndTop != null && scopeEndTop.id <= id) {
    scopeEnds.pop();
    // Retire the latest-starting scopes first.
    const scopesSortedStartDescending = [...scopeEndTop.scopes].sort(
      (a, b) => b.range.start - a.range.start,
    );
    for (const scope of scopesSortedStartDescending) {
      const idx = activeScopes.indexOf(scope);
      if (idx !== -1) {
        // Scopes above `scope` on the stack were pushed after it and are still
        // active, so they overlap it without nesting: merge them all.
        if (idx !== activeScopes.length - 1) {
          joined.union([scope, ...activeScopes.slice(idx + 1)]);
        }
        activeScopes.splice(idx, 1);
      }
    }
  }
  const scopeStartTop = scopeStarts.at(-1);
  if (scopeStartTop != null && scopeStartTop.id <= id) {
    scopeStarts.pop();
    // Push the longest-lived scopes first, so scopes that end sooner sit
    // higher on the stack.
    const scopesSortedEndDescending = [...scopeStartTop.scopes].sort(
      (a, b) => b.range.end - a.range.end,
    );
    activeScopes.push(...scopesSortedEndDescending);
    // All scopes in this bucket share a start id. After sorting by end, those
    // that also share an end id are adjacent; they have identical ranges and
    // are merged.
    for (let i = 1; i < scopesSortedEndDescending.length; i++) {
      const prev = scopesSortedEndDescending[i - 1];
      const curr = scopesSortedEndDescending[i];
      if (prev.range.end === curr.range.end) {
        joined.union([prev, curr]);
      }
    }
  }
}
/**
 * Inspects one place referenced at instruction `id`.
 *
 * Nothing happens unless `getPlaceScope(id, place)` yields a scope and the
 * place is mutable at `id`. In that case, if the scope is on the active stack
 * but not on top, every scope pushed after it is unioned with it.
 */
function visitPlace(
  id: InstructionId,
  place: Place,
  {activeScopes, joined}: TraversalState,
): void {
  const scope = getPlaceScope(id, place);
  // NOTE(review): `as any` is kept from the original. It can probably be
  // dropped if `isMutable` accepts `{id: InstructionId}`; confirm its signature.
  if (scope == null || !isMutable({id} as any, place)) {
    return;
  }
  const depth = activeScopes.indexOf(scope);
  const absentOrOnTop = depth === -1 || depth === activeScopes.length - 1;
  if (absentOrOnTop) {
    return;
  }
  joined.union([scope, ...activeScopes.slice(depth + 1)]);
}
/**
 * Traverses `fn` in block/instruction order and returns a disjoint set that
 * groups the scopes which must be merged.
 *
 * At each instruction (and each block terminal), scope boundaries at that id
 * are processed first. Then the referenced places are inspected: instruction
 * operands first, then lvalues; terminal operands for terminals.
 *
 * Note: consumes `context.scopeStarts` and `context.scopeEnds` by popping
 * entries from them.
 */
function getOverlappingReactiveScopes(
  fn: HIRFunction,
  context: ScopeInfo,
): DisjointSet<ReactiveScope> {
  const state: TraversalState = {
    joined: new DisjointSet<ReactiveScope>(),
    activeScopes: [],
  };
  const visitPlaces = (id: InstructionId, places: Iterable<Place>): void => {
    for (const place of places) {
      visitPlace(id, place, state);
    }
  };

  for (const [, block] of fn.body.blocks) {
    for (const instr of block.instructions) {
      visitInstructionId(instr.id, context, state);
      visitPlaces(instr.id, eachInstructionOperand(instr));
      visitPlaces(instr.id, eachInstructionLValue(instr));
    }
    const {terminal} = block;
    visitInstructionId(terminal.id, context, state);
    visitPlaces(terminal.id, eachTerminalOperand(terminal));
  }
  return state.joined;
}