import { CompilerError } from "..";
import { BlockId, GotoVariant, HIRFunction } from "./HIR";

/**
 * Removes `label` terminals whose labeled body is a single block that does
 * nothing but `break` directly to the label's fallthrough.
 *
 * For each such label, the body block's and fallthrough block's instructions
 * are appended to the label block, the label block takes over the
 * fallthrough's terminal, and the body and fallthrough blocks are deleted from
 * `fn.body.blocks`. Only labels whose body and fallthrough are both of kind
 * `"block"` are merged.
 *
 * Because the fallthrough block is deleted, every remaining block that listed
 * it as a predecessor (in `preds`, or as a phi-operand key) is retargeted to
 * the block it was merged into, so predecessor information stays consistent.
 *
 * @param fn - the function whose HIR is rewritten in place.
 * @throws via `CompilerError.invariant` if a body/fallthrough block being
 *   merged has phis, or has predecessors other than the single expected edge.
 */
export function pruneUnusedLabelsHIR(fn: HIRFunction): void {
  const merged: Array<{
    label: BlockId;
    next: BlockId;
    fallthrough: BlockId;
  }> = [];
  // Maps each deleted fallthrough block id to the block it was merged into.
  const rewrites: Map<BlockId, BlockId> = new Map();

  // Phase 1: find labels whose body immediately breaks to the fallthrough.
  for (const [blockId, block] of fn.body.blocks) {
    const terminal = block.terminal;
    if (terminal.kind === "label") {
      const { block: nextId, fallthrough: fallthroughId } = terminal;
      const next = fn.body.blocks.get(nextId)!;
      const fallthrough = fn.body.blocks.get(fallthroughId)!;
      if (
        next.terminal.kind === "goto" &&
        next.terminal.variant === GotoVariant.Break &&
        next.terminal.block === fallthroughId
      ) {
        if (next.kind === "block" && fallthrough.kind === "block") {
          // Only merge normal block types
          merged.push({
            label: blockId,
            next: nextId,
            fallthrough: fallthroughId,
          });
        }
      }
    }
  }

  // Phase 2: merge label + body + fallthrough into the label block.
  for (const {
    label: originalLabelId,
    next: nextId,
    fallthrough: fallthroughId,
  } of merged) {
    // If this label block was itself the fallthrough of an earlier merge, it
    // has already been deleted; continue merging into the surviving block.
    const labelId = rewrites.get(originalLabelId) ?? originalLabelId;
    const label = fn.body.blocks.get(labelId)!;
    const next = fn.body.blocks.get(nextId)!;
    const fallthrough = fn.body.blocks.get(fallthroughId)!;

    // Merge block and fallthrough
    CompilerError.invariant(
      next.phis.size === 0 && fallthrough.phis.size === 0,
      {
        reason: "Unexpected phis when merging label blocks",
        loc: label.terminal.loc,
      }
    );

    // `next.preds` still names the ORIGINAL label id: predecessor sets are not
    // retargeted until phase 3.
    CompilerError.invariant(
      next.preds.size === 1 &&
        fallthrough.preds.size === 1 &&
        next.preds.has(originalLabelId) &&
        fallthrough.preds.has(nextId),
      {
        reason: "Unexpected block predecessors when merging label blocks",
        loc: label.terminal.loc,
      }
    );

    label.instructions.push(...next.instructions, ...fallthrough.instructions);
    label.terminal = fallthrough.terminal;
    fn.body.blocks.delete(nextId);
    fn.body.blocks.delete(fallthroughId);
    rewrites.set(fallthroughId, labelId);
  }

  if (rewrites.size === 0) {
    return;
  }

  // Phase 3: successors of a deleted fallthrough must now name the surviving
  // block as their predecessor. (`next` blocks need no rewrite: their only
  // successor was the fallthrough, which is also deleted.)
  for (const [, block] of fn.body.blocks) {
    // Snapshot stale entries first so we never mutate while iterating.
    const stalePreds = Array.from(block.preds).filter((pred) =>
      rewrites.has(pred)
    );
    for (const pred of stalePreds) {
      block.preds.delete(pred);
      block.preds.add(rewrites.get(pred)!);
    }

    for (const phi of block.phis) {
      const staleOperands = Array.from(phi.operands).filter(([pred]) =>
        rewrites.has(pred)
      );
      for (const [pred, operand] of staleOperands) {
        phi.operands.delete(pred);
        phi.operands.set(rewrites.get(pred)!, operand);
      }
    }
  }
}