import {
HIRFunction,
IdentifierId,
ReactiveScope,
makeInstructionId,
} from '../HIR';
import DisjointSet from '../Utils/DisjointSet';
/**
 * Aligns the reactive scope of each `MethodCall`'s property (the loaded
 * method) with the scope of the call's result, so that the method lookup and
 * the call end up in the same reactive scope.
 *
 * For every `MethodCall` instruction:
 * - if both the call result and the property have a scope, the two scopes are
 *   unioned and the surviving root scope's range is widened to cover both;
 * - if only the call result has a scope, the property is moved into it;
 * - if only the property has a scope, the property's scope is cleared.
 *
 * Bodies of nested `FunctionExpression` / `ObjectMethod` values are processed
 * recursively. `fn` is mutated in place: root scope ranges are widened and the
 * `scope` field of instruction lvalue identifiers is rewritten.
 *
 * @param fn The function whose instructions are updated in place.
 */
export function alignMethodCallScopes(fn: HIRFunction): void {
  // Per-identifier overrides: a scope to move the identifier into, or `null`
  // to strip its scope entirely.
  const scopeMapping = new Map<IdentifierId, ReactiveScope | null>();
  // Groups of scopes that must be collapsed into a single (root) scope.
  const mergedScopes = new DisjointSet<ReactiveScope>();

  for (const [, block] of fn.body.blocks) {
    for (const instr of block.instructions) {
      const {lvalue, value} = instr;
      if (value.kind === 'MethodCall') {
        const lvalueScope = lvalue.identifier.scope;
        const propertyScope = value.property.identifier.scope;
        if (lvalueScope !== null) {
          if (propertyScope !== null) {
            // Both the call and the property have a scope: merge them.
            mergedScopes.union([lvalueScope, propertyScope]);
          } else {
            // Only the call has a scope: pull the property into it.
            scopeMapping.set(value.property.identifier.id, lvalueScope);
          }
        } else if (propertyScope !== null) {
          // Only the property has a scope: the property does not need one.
          scopeMapping.set(value.property.identifier.id, null);
        }
      } else if (
        value.kind === 'FunctionExpression' ||
        value.kind === 'ObjectMethod'
      ) {
        alignMethodCallScopes(value.loweredFunc.func);
      }
    }
  }

  // Widen every root scope so its range covers all scopes merged into it.
  mergedScopes.forEach((scope, root) => {
    if (scope === root) {
      return;
    }
    root.range.start = makeInstructionId(
      Math.min(scope.range.start, root.range.start),
    );
    root.range.end = makeInstructionId(
      Math.max(scope.range.end, root.range.end),
    );
  });

  for (const [, block] of fn.body.blocks) {
    for (const instr of block.instructions) {
      const identifier = instr.lvalue.identifier;
      const mappedScope = scopeMapping.get(identifier.id);
      if (mappedScope !== undefined) {
        // The scope recorded for this property may itself have been merged
        // into another scope by a different MethodCall. Resolve it to its
        // root so the property lands in the same scope as the call's lvalue
        // (which is rewritten via `find` below); otherwise fall back to the
        // recorded scope when it was never merged.
        identifier.scope =
          mappedScope === null
            ? null
            : (mergedScopes.find(mappedScope) ?? mappedScope);
      } else if (identifier.scope !== null) {
        const mergedScope = mergedScopes.find(identifier.scope);
        if (mergedScope != null) {
          identifier.scope = mergedScope;
        }
      }
    }
  }
}