import {BlockId, computePostDominatorTree, HIRFunction, Place} from '../HIR';
import {PostDominator} from '../HIR/Dominator';
/**
 * Predicate over block ids: given a block's id, answers with a boolean.
 * `createControlDominators` returns a function of this type that reports
 * whether the identified block is controlled by a value the caller cares about.
 */
export type ControlDominators = (id: BlockId) => boolean;
/**
 * Builds a predicate that lazily answers whether a block is controlled by a
 * value of interest. Which values are "of interest" is decided entirely by the
 * caller through `isControlVariable`.
 *
 * A block is reported as controlled when any block in its post-dominator
 * frontier ends in an `if`, `branch` or `switch` terminal whose test place (or,
 * for `switch`, any non-null case test place) satisfies `isControlVariable`.
 *
 * The post-dominator tree is computed once, up front, with
 * `includeThrowsAsExitNode: false` (presumably so that throw terminals are not
 * treated as exits; confirm against `computePostDominatorTree`). Frontiers are
 * computed on the first query for a block id and memoized afterwards.
 *
 * @param fn - Function whose blocks are inspected.
 * @param isControlVariable - Caller-supplied test for places of interest.
 * @returns A predicate over block ids of `fn`.
 */
export function createControlDominators(
  fn: HIRFunction,
  isControlVariable: (place: Place) => boolean,
): ControlDominators {
  const postDominators = computePostDominatorTree(fn, {
    includeThrowsAsExitNode: false,
  });
  // Frontier per queried block id, filled in on demand.
  const frontierByBlock = new Map<BlockId, Set<BlockId>>();

  // Returns the memoized frontier for `id`, computing it on first use.
  const frontierFor = (id: BlockId): Set<BlockId> => {
    const cached = frontierByBlock.get(id);
    if (cached !== undefined) {
      return cached;
    }
    const computed = postDominatorFrontier(fn, postDominators, id);
    frontierByBlock.set(id, computed);
    return computed;
  };

  return function isControlledBlock(id: BlockId): boolean {
    for (const controllerId of frontierFor(id)) {
      const terminal = fn.body.blocks.get(controllerId)!.terminal;
      if (terminal.kind === 'if' || terminal.kind === 'branch') {
        if (isControlVariable(terminal.test)) {
          return true;
        }
      } else if (terminal.kind === 'switch') {
        // The discriminant is checked first, then each case test in order.
        if (isControlVariable(terminal.test)) {
          return true;
        }
        const hasControlCase = terminal.cases.some(
          case_ => case_.test !== null && isControlVariable(case_.test),
        );
        if (hasControlCase) {
          return true;
        }
      }
      // Any other terminal kind contributes nothing.
    }
    return false;
  };
}
/**
 * Computes the post-dominator frontier of `targetId`.
 *
 * Let R be the set returned by `postDominatorsOf` for `targetId`. The frontier
 * is every predecessor of a block in R, or of `targetId` itself, that is not
 * itself a member of R. Intuitively these are the blocks where execution
 * branches such that it may or may not go on to reach the target.
 *
 * @param fn - Function whose block graph (`fn.body.blocks`, each with `preds`) is walked.
 * @param postDominators - Post-dominator information, forwarded to `postDominatorsOf`.
 * @param targetId - Block whose frontier is requested.
 * @returns The set of frontier block ids.
 */
function postDominatorFrontier(
  fn: HIRFunction,
  postDominators: PostDominator<BlockId>,
  targetId: BlockId,
): Set<BlockId> {
  const dominated = postDominatorsOf(fn, postDominators, targetId);

  // Blocks whose predecessors get inspected: the dominated set followed by the
  // target. The Set drops the duplicate if the target is already in there.
  const region = new Set<BlockId>(dominated);
  region.add(targetId);

  const frontier = new Set<BlockId>();
  for (const memberId of region) {
    const member = fn.body.blocks.get(memberId)!;
    for (const pred of member.preds) {
      if (dominated.has(pred)) {
        continue;
      }
      // This predecessor is outside the dominated set: it is on the frontier.
      frontier.add(pred);
    }
  }
  return frontier;
}
/**
 * Collects blocks reachable backwards from `targetId` (via `preds`) that lead
 * to it through `postDominators`.
 *
 * The graph is walked breadth-first over predecessor edges, starting at
 * `targetId`. For each edge, the predecessor is looked up in `postDominators`
 * (presumably yielding its immediate post-dominator; confirm against
 * `PostDominator`), falling back to the predecessor itself when the lookup is
 * nullish. The predecessor is collected when that block is `targetId` or has
 * already been collected.
 *
 * Visitation order is significant, because membership is decided against the
 * set collected so far. The worklist is therefore consumed strictly
 * first-in-first-out.
 *
 * @param fn - Function whose block graph (`fn.body.blocks`, each with `preds`) is walked.
 * @param postDominators - Lookup from a block id to its post-dominator.
 * @param targetId - Block to start from.
 * @returns The collected block ids, in the order they were first added.
 */
function postDominatorsOf(
  fn: HIRFunction,
  postDominators: PostDominator<BlockId>,
  targetId: BlockId,
): Set<BlockId> {
  const collected = new Set<BlockId>();
  const seen = new Set<BlockId>();
  // Append-only worklist read through a cursor, so it is consumed in FIFO order.
  const worklist = [targetId];
  for (let cursor = 0; cursor < worklist.length; cursor++) {
    const blockId = worklist[cursor]!;
    if (seen.has(blockId)) {
      continue;
    }
    seen.add(blockId);

    const block = fn.body.blocks.get(blockId)!;
    for (const pred of block.preds) {
      const dominatorOfPred = postDominators.get(pred) ?? pred;
      const leadsToTarget =
        dominatorOfPred === targetId || collected.has(dominatorOfPred);
      if (leadsToTarget) {
        collected.add(pred);
      }
      // Always enqueue; blocks that were already expanded are skipped on dequeue.
      worklist.push(pred);
    }
  }
  return collected;
}