/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use react_diagnostics::Diagnostic;
use crate::{
initialize_hir, BasicBlock, BlockRewriter, BlockRewriterAction, DeclareLocal, Environment,
Function, GotoKind, GotoTerminal, Identifier, IdentifierData, IdentifierId, IdentifierOperand,
InstrIx, Instruction, InstructionKind, InstructionValue, LValue, LabelTerminal, LoadLocal,
MutableRange, PlaceOrSpread, ReturnTerminal, StoreLocal, Terminal, TerminalValue, Type,
};
/// Inlines `useMemo()` calls, rewriting so that the lambda body becomes part of the
/// outer block's instructions. To account for complex control flow, the inlining works
/// as follows:
/// * First, block ids are guaranteed to be unique for all blocks within a function and
/// its recursive function expressions. Thus, the function expression's blocks can be
/// directly moved into the outer function's `blocks` map.
/// * To account for complex control flow, we create a "label" terminal just prior to
/// the useMemo call, with the useMemo function's entry block as the body of the
/// label terminal. The code following the useMemo call becomes the fallthrough.
/// All returns within the useMemo are translated to instead:
/// * Assign to a temporary identifier representing the useMemo result
/// * Break to the label's fallthrough.
///
/// ## Example
///
/// Input:
/// ```javascript
/// foo();
/// const x = useMemo(() => {
/// if (a) {
/// return b;
/// }
/// return c;
/// })
/// x;
/// ```
///
/// HIR after translation:
/// ```hir
/// bb0:
/// [ 1] #1 = LoadLocal 'foo'
/// [ 2] #2 = Call #1()
/// // label to allow substituting return -> goto in the lambda body
/// [ 3] Label body=bb1 fallthrough=bb4
/// bb1:
/// [ 4] #3 = LoadLocal 'a'
/// [ 5] If test=#3 consequent=bb2 alternate=bb3
/// bb2:
/// [ 6] #4 = LoadLocal 'b'
/// [ 7] StoreLocal '<tmp>', #4
/// [ 8] Goto bb4
/// bb3:
/// [ 9] #5 = LoadLocal 'c'
/// [10] StoreLocal '<tmp>' #5
/// [11] Goto bb4
/// bb4:
/// // code after the useMemo. save the temporary
/// [12] #6 = LoadLocal '<tmp>'
/// [13] StoreLocal 'x', #6
/// ```
///
pub fn inline_use_memo(env: &Environment, fun: &mut Function) -> Result<(), Diagnostic> {
let mut use_memo_globals: HashSet<IdentifierId> = Default::default();
let mut functions: HashMap<IdentifierId, InstrIx> = Default::default();
let blocks = &mut fun.body.blocks;
let instructions = &mut fun.body.instructions;
let mut rewriter = BlockRewriter::new(blocks, fun.body.entry);
let mut inlined = Vec::new();
rewriter.try_each_block(|mut block, rewriter| {
for (i, instr_ix) in block.instructions.iter().cloned().enumerate() {
let instr = &mut instructions[usize::from(instr_ix)];
match &mut instr.value {
InstructionValue::LoadGlobal(value) => {
if value.name.as_str() == "useMemo" {
use_memo_globals.insert(instr.lvalue.identifier.id);
}
}
InstructionValue::Function(_) => {
functions.insert(instr.lvalue.identifier.id, instr_ix);
}
InstructionValue::Call(value) => {
if !use_memo_globals.contains(&value.callee.identifier.id) {
continue;
}
// Skip useMemo calls where the argument is a spread element
let lambda_id = match &value.arguments.get(0) {
Some(PlaceOrSpread::Place(place)) => place.identifier.id,
_ => continue,
};
let lambda_ix = match functions.get(&lambda_id) {
Some(ix) => *ix,
// Skip useMemo calls where the argument is not a function expression
_ => continue,
};
let instr_id = instr.id;
// Create a temporary variable to store the useMemo result into
let temporary_id = env.next_identifier_id();
let temporary = Identifier {
id: temporary_id,
// NOTE: for memoization to work correctly this variable has to be named
name: Some("t".to_string()),
data: Rc::new(RefCell::new(IdentifierData {
mutable_range: MutableRange::new(),
scope: None,
type_: Type::Var(env.next_type_var_id()),
})),
};
// Replace the call with a load of the temporary
// this is convenient since consumers of the useMemo call
// already point to this instruction id, so by reusing the
// instruction we don't have to update the consumer(s) to
// look at a different instruction
instr.value = InstructionValue::LoadLocal(LoadLocal {
place: IdentifierOperand {
identifier: temporary.clone(),
effect: None,
},
});
// Move the function expression out of its instruction so that we own
// the value and can modify and inline its contents into the outer
// function. We replace with a tombstone value that we can filter out later
let lambda = std::mem::replace(
&mut instructions[usize::from(lambda_ix)].value,
InstructionValue::Tombstone,
);
let mut lambda = if let InstructionValue::Function(lambda) = lambda {
lambda
} else {
unreachable!("Must be a function, checked above")
};
// Additional validation
// TODO: this should be part of a separate validation pass
if !lambda.lowered_function.params.is_empty() {
return Err(Diagnostic::invalid_react(
"useMemo callbacks may not accept any arguments",
None,
));
}
if lambda.lowered_function.is_async || lambda.lowered_function.is_generator {
return Err(Diagnostic::invalid_react(
"useMemo callbacks may not be async or generator functions",
None,
));
}
// Set aside a BlockId for the code that follows the useMemo call
let continuation_block_id = env.next_block_id();
// Rewrite the body of the lambda to replace any return terminals
// with an assignment to the useMemo temporary followed by a break
// to the continuation block
for block in lambda.lowered_function.body.blocks.iter_mut() {
if let TerminalValue::Return(ReturnTerminal { value }) =
&mut block.terminal.value
{
let store_ix = InstrIx::new(
lambda.lowered_function.body.instructions.len() as u32,
);
lambda.lowered_function.body.instructions.push(Instruction {
id: instr_id,
lvalue: IdentifierOperand {
identifier: env.new_temporary(),
effect: None,
},
value: InstructionValue::StoreLocal(StoreLocal {
lvalue: LValue {
identifier: IdentifierOperand {
identifier: temporary.clone(),
effect: None,
},
kind: InstructionKind::Reassign,
},
value: value.clone(),
}),
});
block.instructions.push(store_ix);
block.terminal.value = TerminalValue::Goto(GotoTerminal {
block: continuation_block_id,
kind: GotoKind::Break,
});
}
}
// Extract the block's original terminal, which we will move to the
// continuation block. Replace it with a label terminal, necessary to
// allow the goto statements to have a target.
let terminal_id = block.terminal.id;
let terminal = std::mem::replace(
&mut block.terminal,
Terminal {
id: terminal_id,
value: TerminalValue::Label(LabelTerminal {
block: lambda.lowered_function.body.entry,
fallthrough: Some(continuation_block_id),
}),
},
);
// Extract the instructions for the continuation block
let continuation_instructions = block.instructions.split_off(i);
// Declare the temporary variable at the end of the block preceding
// the useMemo invocation
let declare_ix = InstrIx::new(instructions.len() as u32);
instructions.push(Instruction {
id: instr_id,
lvalue: IdentifierOperand {
identifier: env.new_temporary(),
effect: None,
},
value: InstructionValue::DeclareLocal(DeclareLocal {
lvalue: LValue {
identifier: IdentifierOperand {
identifier: temporary.clone(),
effect: None,
},
kind: InstructionKind::Let,
},
}),
});
block.instructions.push(declare_ix);
// Add the continuation block
let continuation_block = Box::new(BasicBlock {
id: continuation_block_id,
instructions: continuation_instructions,
kind: block.kind,
phis: Default::default(),
predecessors: Default::default(),
terminal,
});
rewriter.add_block(continuation_block);
inlined.push(lambda);
break;
}
_ => {}
}
}
Ok(BlockRewriterAction::Keep(block))
})?;
if !inlined.is_empty() {
for lambda in inlined {
fun.body.inline(lambda);
}
initialize_hir(&mut fun.body)?;
}
Ok(())
}