/**
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

import {CompilerError} from '../CompilerError';
import {
  Environment,
  Identifier,
  IdentifierId,
  InstructionId,
  Place,
  PropertyLiteral,
  ReactiveBlock,
  ReactiveFunction,
  ReactiveInstruction,
  ReactiveScopeBlock,
  ReactiveTerminalStatement,
  getHookKind,
  isUseRefType,
  isUseStateType,
} from '../HIR';
import {eachCallArgument, eachInstructionLValue} from '../HIR/visitors';
import DisjointSet from '../Utils/DisjointSet';
import {assertExhaustive} from '../Utils/utils';
import {ReactiveFunctionVisitor, visitReactiveFunction} from './visitors';

/**
 * This pass is built based on the observation by @jbrown215 that arguments
 * to useState and useRef are only used the first time a component is rendered.
 * Any subsequent times, the arguments will be evaluated but ignored. In this pass,
 * we use this fact to improve the output of the compiler by not recomputing values that
 * are only used as arguments (or inputs to arguments to) useState and useRef.
 *
 * This pass isn't yet stress-tested so it's not enabled by default. It's only enabled
 * to support certain debug modes that detect non-idempotent code, since non-idempotent
 * code can "safely" be used if its only passed to useState and useRef. We plan to rewrite
 * this pass in HIR and enable it as an optimization in the future.
 *
 * Algorithm:
 * We take two passes over the reactive function AST. In the first pass, we gather
 * aliases and build relationships between property accesses--the key thing we need
 * to do here is to find that, e.g., $0.x and $1 refer to the same value if
 * $1 = PropertyLoad $0.x.
 *
 * In the second pass, we traverse the AST in reverse order and track how each place
 * is used. If a place is read from in any Terminal, we mark the place as "Update", meaning
 * it is used whenever the component is updated/re-rendered. If a place is read from in
 * a useState or useRef hook call, we mark it as "Create", since it is only used when the
 * component is created. In other instructions, we propagate the inferred place for the
 * instructions lvalues onto any other instructions that are read.
 *
 * Whenever we finish this reverse pass over a reactive block, we can look at the blocks
 * dependencies and see whether the dependencies are used in an "Update" context or only
 * in a "Create" context. If a dependency is create-only, then we can remove that dependency
 * from the block.
 */

type CreateUpdate = 'Create' | 'Update' | 'Unknown';

type KindMap = Map<IdentifierId, CreateUpdate>;

class Visitor extends ReactiveFunctionVisitor<CreateUpdate> {
  map: KindMap = new Map();
  aliases: DisjointSet<IdentifierId>;
  paths: Map<IdentifierId, Map<PropertyLiteral, IdentifierId>>;
  env: Environment;

  constructor(
    env: Environment,
    aliases: DisjointSet<IdentifierId>,
    paths: Map<IdentifierId, Map<PropertyLiteral, IdentifierId>>,
  ) {
    super();
    this.aliases = aliases;
    this.paths = paths;
    this.env = env;
  }

  join(values: Array<CreateUpdate>): CreateUpdate {
    function join2(l: CreateUpdate, r: CreateUpdate): CreateUpdate {
      if (l === 'Update' || r === 'Update') {
        return 'Update';
      } else if (l === 'Create' || r === 'Create') {
        return 'Create';
      } else if (l === 'Unknown' || r === 'Unknown') {
        return 'Unknown';
      }
      assertExhaustive(r, `Unhandled variable kind ${r}`);
    }
    return values.reduce(join2, 'Unknown');
  }

  isCreateOnlyHook(id: Identifier): boolean {
    return isUseStateType(id) || isUseRefType(id);
  }

  override visitPlace(
    _: InstructionId,
    place: Place,
    state: CreateUpdate,
  ): void {
    this.map.set(
      place.identifier.id,
      this.join([state, this.map.get(place.identifier.id) ?? 'Unknown']),
    );
  }

  override visitBlock(block: ReactiveBlock, state: CreateUpdate): void {
    super.visitBlock([...block].reverse(), state);
  }

  override visitInstruction(instruction: ReactiveInstruction): void {
    const state = this.join(
      [...eachInstructionLValue(instruction)].map(
        operand => this.map.get(operand.identifier.id) ?? 'Unknown',
      ),
    );

    const visitCallOrMethodNonArgs = (): void => {
      switch (instruction.value.kind) {
        case 'CallExpression': {
          this.visitPlace(instruction.id, instruction.value.callee, state);
          break;
        }
        case 'MethodCall': {
          this.visitPlace(instruction.id, instruction.value.property, state);
          this.visitPlace(instruction.id, instruction.value.receiver, state);
          break;
        }
      }
    };

    const isHook = (): boolean => {
      let callee = null;
      switch (instruction.value.kind) {
        case 'CallExpression': {
          callee = instruction.value.callee.identifier;
          break;
        }
        case 'MethodCall': {
          callee = instruction.value.property.identifier;
          break;
        }
      }
      return callee != null && getHookKind(this.env, callee) != null;
    };

    switch (instruction.value.kind) {
      case 'CallExpression':
      case 'MethodCall': {
        if (
          instruction.lvalue &&
          this.isCreateOnlyHook(instruction.lvalue.identifier)
        ) {
          [...eachCallArgument(instruction.value.args)].forEach(operand =>
            this.visitPlace(instruction.id, operand, 'Create'),
          );
          visitCallOrMethodNonArgs();
        } else {
          this.traverseInstruction(instruction, isHook() ? 'Update' : state);
        }
        break;
      }
      default: {
        this.traverseInstruction(instruction, state);
      }
    }
  }

  override visitScope(scope: ReactiveScopeBlock): void {
    const state = this.join(
      [
        ...scope.scope.declarations.keys(),
        ...[...scope.scope.reassignments.values()].map(ident => ident.id),
      ].map(id => this.map.get(id) ?? 'Unknown'),
    );
    super.visitScope(scope, state);
    [...scope.scope.dependencies].forEach(ident => {
      let target: undefined | IdentifierId =
        this.aliases.find(ident.identifier.id) ?? ident.identifier.id;
      ident.path.forEach(token => {
        target &&= this.paths.get(target)?.get(token.property);
      });
      if (target && this.map.get(target) === 'Create') {
        scope.scope.dependencies.delete(ident);
      }
    });
  }

  override visitTerminal(
    stmt: ReactiveTerminalStatement,
    state: CreateUpdate,
  ): void {
    CompilerError.invariant(state !== 'Create', {
      reason: "Visiting a terminal statement with state 'Create'",
      loc: stmt.terminal.loc,
    });
    super.visitTerminal(stmt, state);
  }

  override visitReactiveFunctionValue(
    _id: InstructionId,
    _dependencies: Array<Place>,
    fn: ReactiveFunction,
    state: CreateUpdate,
  ): void {
    visitReactiveFunction(fn, this, state);
  }
}

export default function pruneInitializationDependencies(
  fn: ReactiveFunction,
): void {
  const [aliases, paths] = getAliases(fn);
  visitReactiveFunction(fn, new Visitor(fn.env, aliases, paths), 'Update');
}

function update(
  map: Map<IdentifierId, Map<PropertyLiteral, IdentifierId>>,
  key: IdentifierId,
  path: PropertyLiteral,
  value: IdentifierId,
): void {
  const inner = map.get(key) ?? new Map();
  inner.set(path, value);
  map.set(key, inner);
}

class AliasVisitor extends ReactiveFunctionVisitor {
  scopeIdentifiers: DisjointSet<IdentifierId> = new DisjointSet<IdentifierId>();
  scopePaths: Map<IdentifierId, Map<PropertyLiteral, IdentifierId>> = new Map();

  override visitInstruction(instr: ReactiveInstruction): void {
    if (
      instr.value.kind === 'StoreLocal' ||
      instr.value.kind === 'StoreContext'
    ) {
      this.scopeIdentifiers.union([
        instr.value.lvalue.place.identifier.id,
        instr.value.value.identifier.id,
      ]);
    } else if (
      instr.value.kind === 'LoadLocal' ||
      instr.value.kind === 'LoadContext'
    ) {
      instr.lvalue &&
        this.scopeIdentifiers.union([
          instr.lvalue.identifier.id,
          instr.value.place.identifier.id,
        ]);
    } else if (instr.value.kind === 'PropertyLoad') {
      instr.lvalue &&
        update(
          this.scopePaths,
          instr.value.object.identifier.id,
          instr.value.property,
          instr.lvalue.identifier.id,
        );
    } else if (instr.value.kind === 'PropertyStore') {
      update(
        this.scopePaths,
        instr.value.object.identifier.id,
        instr.value.property,
        instr.value.value.identifier.id,
      );
    }
  }
}

function getAliases(
  fn: ReactiveFunction,
): [
  DisjointSet<IdentifierId>,
  Map<IdentifierId, Map<PropertyLiteral, IdentifierId>>,
] {
  const visitor = new AliasVisitor();
  visitReactiveFunction(fn, visitor, null);
  let disjoint = visitor.scopeIdentifiers;
  let scopePaths = new Map<IdentifierId, Map<PropertyLiteral, IdentifierId>>();
  for (const [key, value] of visitor.scopePaths) {
    for (const [path, id] of value) {
      update(
        scopePaths,
        disjoint.find(key) ?? key,
        path,
        disjoint.find(id) ?? id,
      );
    }
  }
  return [disjoint, scopePaths];
}