/**
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

import {CompilerError} from '../CompilerError';
import {BlockId, HIRFunction, Identifier, Place} from '../HIR/HIR';
import {
  eachInstructionLValue,
  eachInstructionOperand,
  eachTerminalOperand,
} from '../HIR/visitors';

/*
 * When true, eliminateRedundantPhi finishes with an extra walk over every remaining
 * phi and asserts that neither the phi's own identifier nor any of its operands
 * still has an entry in the rewrite table.
 */
const DEBUG = false;

/*
 * Removes phi nodes that carry no information, ie phis whose operands are:
 * - all one and the same identifier, eg `x2 = phi(x1, x1, x1)`, or
 * - a mix of one identifier and the phi's own output, eg `x2 = phi(x1, x2, x1, x2)`.
 *
 * Such a phi is deleted from its block and every usage of its identifier is redirected
 * to the one remaining operand (in both examples above, usages of `x2` become `x1`).
 *
 * The approach follows https://pp.ipd.kit.edu/uploads/publikationen/braun13cc.pdf, adapted
 * to need fewer walks over the CFG. Blocks are expected to be in reverse postorder. Each
 * redundant phi adds an entry (eg x2 -> x1) to a rewrite table, and that table is applied
 * to every phi, instruction and terminal visited afterwards. The walk is repeated until an
 * iteration adds no new entries; a CFG without back-edges is done after a single walk.
 *
 * When `sharedRewrites` is provided it is used as the rewrite table and extended in place;
 * otherwise a fresh table is created. Nested function expressions and object methods are
 * processed recursively with the same table.
 */
export function eliminateRedundantPhi(
  fn: HIRFunction,
  sharedRewrites?: Map<Identifier, Identifier>,
): void {
  const ir = fn.body;
  const rewrites: Map<Identifier, Identifier> =
    sharedRewrites == null ? new Map() : sharedRewrites;

  /*
   * Whether or not the CFG has a back-edge (a loop). This is discovered while walking
   * the blocks: a predecessor that has not been visited yet can only come from a loop,
   * because the blocks are in reverse postorder.
   */
  let hasBackEdge = false;
  const seenBlocks = new Set<BlockId>();

  // Set at the end of each walk: did this walk add anything to the rewrite table?
  let addedRewrites = false;
  do {
    const rewriteCountBefore = rewrites.size;

    for (const [blockId, block] of ir.blocks) {
      // Until a back-edge is found, look for a predecessor we have not visited yet
      if (!hasBackEdge) {
        for (const predId of block.preds) {
          if (!seenBlocks.has(predId)) {
            hasBackEdge = true;
            break;
          }
        }
      }
      seenBlocks.add(blockId);

      // Detect and remove redundant phis
      for (const phi of block.phis) {
        // Operands may themselves refer to phis that were eliminated earlier
        for (const [, operand] of phi.operands) {
          rewritePlace(operand, rewrites);
        }

        /*
         * `replacement` is the single identifier (other than the phi itself) seen
         * among the operands so far. A second, different identifier means the phi
         * is a real merge and has to stay.
         */
        let replacement: Identifier | null = null;
        let isRedundant = true;
        for (const [, operand] of phi.operands) {
          const candidate = operand.identifier;
          if (candidate.id === phi.place.identifier.id) {
            // Self-reference, does not count as a distinct operand
            continue;
          }
          if (replacement === null) {
            // First operand that is not the phi itself
            replacement = candidate;
            continue;
          }
          if (candidate.id !== replacement.id) {
            isRedundant = false;
            break;
          }
        }
        if (!isRedundant) {
          continue;
        }

        CompilerError.invariant(replacement !== null, {
          reason: 'Expected phis to be non-empty',
          description: null,
          loc: null,
          suggestions: null,
        });
        rewrites.set(phi.place.identifier, replacement);
        block.phis.delete(phi);
      }

      // Apply the rewrite table to instruction lvalues, then operands
      for (const instr of block.instructions) {
        for (const lvalue of eachInstructionLValue(instr)) {
          rewritePlace(lvalue, rewrites);
        }
        for (const operand of eachInstructionOperand(instr)) {
          rewritePlace(operand, rewrites);
        }

        const {value} = instr;
        if (
          value.kind !== 'FunctionExpression' &&
          value.kind !== 'ObjectMethod'
        ) {
          continue;
        }

        const nestedFn = value.loweredFunc.func;
        for (const captured of nestedFn.context) {
          rewritePlace(captured, rewrites);
        }
        /*
         * Recurse into the nested function to eliminate its own redundant phis
         * and to propagate rewrites, which may have changed since the last walk.
         */
        eliminateRedundantPhi(nestedFn, rewrites);
      }

      // Apply the rewrite table to the terminal's operands
      for (const operand of eachTerminalOperand(block.terminal)) {
        rewritePlace(operand, rewrites);
      }
    }

    addedRewrites = rewrites.size > rewriteCountBefore;
    /*
     * Another walk is only needed when this one eliminated phis *and* the CFG has
     * loops. Without loops, every rewrite has already been propagated forwards
     * because blocks are visited in reverse postorder.
     */
  } while (addedRewrites && hasBackEdge);

  if (DEBUG) {
    // Sanity check: no surviving phi may mention an identifier that was rewritten
    for (const [, block] of ir.blocks) {
      for (const phi of block.phis) {
        const mentioned: Array<Identifier> = [phi.place.identifier];
        for (const [, operand] of phi.operands) {
          mentioned.push(operand.identifier);
        }
        for (const identifier of mentioned) {
          CompilerError.invariant(!rewrites.has(identifier), {
            reason: '[EliminateRedundantPhis]: rewrite not complete',
            loc: phi.place.loc,
          });
        }
      }
    }
  }
}

/*
 * Points `place` at its replacement identifier when the rewrite table has an entry
 * for the place's current identifier; otherwise leaves `place` untouched. Only a
 * single lookup is performed (the replacement itself is not looked up again).
 */
function rewritePlace(
  place: Place,
  rewrites: Map<Identifier, Identifier>,
): void {
  const replacement = rewrites.get(place.identifier);
  if (replacement == null) {
    return;
  }
  place.identifier = replacement;
}