import prettyFormat from 'pretty-format';
import {CompilerError} from '../CompilerError';
import {BlockId, HIRFunction} from './HIR';
import {eachTerminalSuccessor} from './visitors';
/**
 * Computes the dominator tree of `fn`'s control-flow graph.
 *
 * The tree is rooted at the entry of the forward graph (`fn.body.entry`).
 */
export function computeDominatorTree(fn: HIRFunction): Dominator<BlockId> {
  const cfg = buildGraph(fn);
  return new Dominator(cfg.entry, computeImmediateDominators(cfg));
}
/**
 * Computes the post-dominator tree of `fn`.
 *
 * The tree is rooted at a synthetic exit node. Every `return` block flows into
 * that exit. When `options.includeThrowsAsExitNode` is set, every `throw`
 * block does too.
 *
 * When throws are NOT exits:
 * - Blocks that cannot reach a `return` never appear in the reversed graph.
 * - Each such block is recorded as its own post-dominator.
 * - `PostDominator.get()` therefore reports `null` for it instead of failing.
 *
 * When throws ARE exits:
 * - No fill-in happens.
 * - A block that reaches neither a return nor a throw stays unmapped.
 * - Querying such a block trips the 'Unknown node' invariant.
 */
export function computePostDominatorTree(
  fn: HIRFunction,
  options: {includeThrowsAsExitNode: boolean},
): PostDominator<BlockId> {
  const reversed = buildReverseGraph(fn, options.includeThrowsAsExitNode);
  const ipdoms = computeImmediateDominators(reversed);
  if (options.includeThrowsAsExitNode) {
    return new PostDominator(reversed.entry, ipdoms);
  }
  for (const [blockId] of fn.body.blocks) {
    if (!ipdoms.has(blockId)) {
      // Unreachable from the exit in the reversed graph: self-map so that
      // `get()` reports "no post-dominator" rather than "unknown node".
      ipdoms.set(blockId, blockId);
    }
  }
  return new PostDominator(reversed.entry, ipdoms);
}
/**
 * A vertex of a (forward or reversed) control-flow graph.
 */
type Node<T> = {
  /** Identifier of this vertex. */
  id: T;
  /** Ids of vertices that have an edge into this one. */
  preds: Set<T>;
  /** Ids of vertices this one has an edge to. */
  succs: Set<T>;
  /**
   * Position of this vertex in the graph's traversal order.
   * `intersect` advances whichever finger has the larger index.
   */
  index: number;
};
/**
 * A graph of `Node`s plus the id of its root vertex.
 */
type Graph<T> = {
  /**
   * Vertices keyed by id. The dominator computation visits them in this
   * map's iteration order.
   */
  nodes: Map<T, Node<T>>;
  /**
   * Root vertex: the function entry for the forward graph, or the synthetic
   * exit for the reversed graph.
   */
  entry: T;
};
/**
 * Immediate-dominator tree over nodes of type `T`, keyed by node id.
 *
 * A node recorded as its own dominator is reported by `get()` as having no
 * dominator. `computeImmediateDominators` records the entry that way.
 */
export class Dominator<T> {
  readonly #entry: T;
  readonly #nodes: Map<T, T>;

  constructor(entry: T, nodes: Map<T, T>) {
    this.#nodes = nodes;
    this.#entry = entry;
  }

  /** Root of the tree. */
  get entry(): T {
    return this.#entry;
  }

  /**
   * Returns the immediate dominator of `id`.
   *
   * Returns `null` when `id` is recorded as its own dominator.
   * @throws CompilerError invariant ('Unknown node') when `id` has no entry.
   */
  get(id: T): T | null {
    const idom = this.#nodes.get(id);
    CompilerError.invariant(idom !== undefined, {
      reason: 'Unknown node',
      description: null,
      loc: null,
      suggestions: null,
    });
    if (idom === id) {
      return null;
    }
    return idom;
  }

  /** Pretty-prints the tree with every id rendered as `bb<id>`. */
  debug(): string {
    const dominators = new Map<string, string>(
      Array.from(this.#nodes, ([node, idom]): [string, string] => [
        `bb${node}`,
        `bb${idom}`,
      ]),
    );
    return prettyFormat({
      entry: `bb${this.#entry}`,
      dominators,
    });
  }
}
/**
 * Immediate post-dominator tree over nodes of type `T`, rooted at `exit`.
 *
 * A node recorded as its own post-dominator is reported by `get()` as having
 * none. This covers the exit itself. It also covers blocks that
 * `computePostDominatorTree` self-maps when throws are not treated as exits.
 */
export class PostDominator<T> {
  readonly #exit: T;
  readonly #nodes: Map<T, T>;

  constructor(exit: T, nodes: Map<T, T>) {
    this.#nodes = nodes;
    this.#exit = exit;
  }

  /** Root of the tree (the synthetic exit node). */
  get exit(): T {
    return this.#exit;
  }

  /**
   * Returns the immediate post-dominator of `id`.
   *
   * Returns `null` when `id` is recorded as its own post-dominator.
   * @throws CompilerError invariant ('Unknown node') when `id` has no entry.
   */
  get(id: T): T | null {
    const ipdom = this.#nodes.get(id);
    CompilerError.invariant(ipdom !== undefined, {
      reason: 'Unknown node',
      description: null,
      loc: null,
      suggestions: null,
    });
    if (ipdom === id) {
      return null;
    }
    return ipdom;
  }

  /** Pretty-prints the tree with every id rendered as `bb<id>`. */
  debug(): string {
    const postDominators = new Map<string, string>(
      Array.from(this.#nodes, ([node, ipdom]): [string, string] => [
        `bb${node}`,
        `bb${ipdom}`,
      ]),
    );
    return prettyFormat({
      exit: `bb${this.exit}`,
      postDominators,
    });
  }
}
/**
 * Computes immediate dominators with the iterative algorithm of Cooper,
 * Harvey & Kennedy ("A Simple, Fast Dominance Algorithm").
 *
 * Ordering assumptions:
 * - Nodes are visited in `graph.nodes` iteration order.
 * - `intersect` relies on each node's `index` agreeing with that order.
 * - The order is expected to be a reverse postorder from `graph.entry`. In that
 *   order, every non-entry node has a predecessor processed before it.
 *
 * @returns A map from node id to immediate-dominator id. `graph.entry` maps to
 *   itself; every other node of `graph.nodes` maps to its immediate dominator.
 * @throws CompilerError invariant when a non-entry node is reached before any
 *   of its predecessors has been assigned a dominator.
 */
function computeImmediateDominators<T>(graph: Graph<T>): Map<T, T> {
  const idoms: Map<T, T> = new Map([[graph.entry, graph.entry]]);
  let dirty = true;
  while (dirty) {
    dirty = false;
    for (const [blockId, block] of graph.nodes) {
      if (block.id === graph.entry) {
        // The root is its own dominator and never changes.
        continue;
      }

      // Seed with the first predecessor that already has a dominator.
      let candidate: T | null = null;
      for (const pred of block.preds) {
        if (idoms.has(pred)) {
          candidate = pred;
          break;
        }
      }
      CompilerError.invariant(candidate !== null, {
        reason: `At least one predecessor must have been visited for block ${blockId}`,
        description: null,
        loc: null,
        suggestions: null,
      });

      // Narrow the candidate by intersecting with every other processed
      // predecessor. Predecessors not yet processed (e.g. back edges on the
      // first pass) are ignored.
      for (const pred of block.preds) {
        if (pred !== candidate && idoms.get(pred) !== undefined) {
          candidate = intersect(pred, candidate, graph, idoms);
        }
      }

      if (idoms.get(blockId) !== candidate) {
        idoms.set(blockId, candidate);
        dirty = true;
      }
    }
  }
  return idoms;
}
/**
 * Returns the nearest common dominator of `a` and `b`.
 *
 * Two "fingers" walk up the partially built dominator tree in `nodes`. Each
 * step advances whichever finger sits at the larger `index`, until they meet.
 *
 * Preconditions (asserted non-null, not checked):
 * - `a` and `b` must be vertices of `graph`.
 * - Every node on their dominator chains must already have an entry in `nodes`.
 */
function intersect<T>(a: T, b: T, graph: Graph<T>, nodes: Map<T, T>): T {
  let finger1: Node<T> = graph.nodes.get(a)!;
  let finger2: Node<T> = graph.nodes.get(b)!;
  while (finger1 !== finger2) {
    while (finger1.index > finger2.index) {
      finger1 = graph.nodes.get(nodes.get(finger1.id)!)!;
    }
    while (finger2.index > finger1.index) {
      finger2 = graph.nodes.get(nodes.get(finger2.id)!)!;
    }
  }
  return finger1.id;
}
/**
 * Builds the forward control-flow graph of `fn`, rooted at `fn.body.entry`.
 *
 * Node fields:
 * - `index` is the block's position in `fn.body.blocks` iteration order.
 *   `intersect` compares these indices, so that order is assumed to be a
 *   reverse postorder. NOTE(review): this is an HIR invariant not checked
 *   here; confirm upstream.
 * - `preds` is the block's own predecessor set (shared, not copied).
 * - `succs` is a fresh set of the terminal's successors.
 */
function buildGraph(fn: HIRFunction): Graph<BlockId> {
  const nodes = new Map<BlockId, Node<BlockId>>(
    Array.from(
      fn.body.blocks,
      ([id, block], index): [BlockId, Node<BlockId>] => [
        id,
        {
          id,
          index,
          preds: block.preds,
          succs: new Set(eachTerminalSuccessor(block.terminal)),
        },
      ],
    ),
  );
  return {entry: fn.body.entry, nodes};
}
/**
 * Builds the reverse of `fn`'s control-flow graph, rooted at a synthetic exit.
 *
 * Edges into the exit:
 * - Every `return` block gets an edge from the synthetic exit.
 * - When `includeThrowsAsExitNode` is set, every `throw` block does too.
 *
 * Only vertices reachable from the exit in the reversed graph are kept. These
 * are the blocks that can reach an exit-connected block going forward. They
 * are ordered and indexed in reverse postorder from the exit, which is what
 * `computeImmediateDominators` and `intersect` expect.
 *
 * The postorder traversal uses an explicit stack rather than recursion, so a
 * very long chain of blocks cannot overflow the JS call stack. Visitation
 * order is identical to a recursive DFS.
 *
 * @param fn Function whose CFG is reversed.
 * @param includeThrowsAsExitNode Whether `throw` terminals also flow to the exit.
 * @returns Graph whose `entry` is the synthetic exit id.
 */
function buildReverseGraph(
  fn: HIRFunction,
  includeThrowsAsExitNode: boolean,
): Graph<BlockId> {
  const nodes: Map<BlockId, Node<BlockId>> = new Map();
  // Synthetic exit node. NOTE(review): assumes `fn.env.nextBlockId` yields an
  // id not used by any existing block; confirm against Environment.
  const exitId = fn.env.nextBlockId;
  const exit: Node<BlockId> = {
    id: exitId,
    index: 0,
    preds: new Set(),
    succs: new Set(),
  };
  nodes.set(exitId, exit);

  // Reverse every edge: forward successors become predecessors and forward
  // predecessors become successors.
  for (const [id, block] of fn.body.blocks) {
    const node: Node<BlockId> = {
      id,
      index: 0,
      preds: new Set(eachTerminalSuccessor(block.terminal)),
      succs: new Set(block.preds),
    };
    if (
      block.terminal.kind === 'return' ||
      (block.terminal.kind === 'throw' && includeThrowsAsExitNode)
    ) {
      node.preds.add(exitId);
      exit.succs.add(id);
    }
    nodes.set(id, node);
  }

  // Iterative depth-first postorder from the exit. Each frame holds a vertex
  // and an iterator over its remaining successors. A vertex is emitted once
  // its iterator is exhausted, exactly as a recursive DFS would emit it.
  const visited = new Set<BlockId>([exitId]);
  const postorder: Array<BlockId> = [];
  const stack: Array<[BlockId, Iterator<BlockId>]> = [
    [exitId, exit.succs.values()],
  ];
  while (stack.length > 0) {
    const [blockId, successors] = stack[stack.length - 1]!;
    const step = successors.next();
    if (step.done) {
      stack.pop();
      postorder.push(blockId);
      continue;
    }
    const successor = step.value;
    if (!visited.has(successor)) {
      visited.add(successor);
      stack.push([successor, nodes.get(successor)!.succs.values()]);
    }
  }

  // Re-insert reachable vertices in reverse postorder and number them.
  const rpo: Graph<BlockId> = {entry: exitId, nodes: new Map()};
  let index = 0;
  for (const id of postorder.reverse()) {
    const node = nodes.get(id)!;
    node.index = index++;
    rpo.nodes.set(id, node);
  }
  return rpo;
}