use std::collections::HashMap;
use react_diagnostics::Diagnostic;
use react_estree::{BinaryOperator, JsValue};
use react_hir::{
initialize_hir, merge_consecutive_blocks, BlockKind, Environment, Function, GotoKind,
GotoTerminal, IdentifierId, InstructionValue, LoadGlobal, Primitive, TerminalValue,
};
use react_ssa::eliminate_redundant_phis;
pub fn constant_propagation(env: &Environment, fun: &mut Function) -> Result<(), Diagnostic> {
let mut constants = Constants::default();
constant_propagation_impl(env, fun, &mut constants)
}
fn constant_propagation_impl(
env: &Environment,
fun: &mut Function,
constants: &mut Constants,
) -> Result<(), Diagnostic> {
loop {
let have_terminals_changed = apply_constant_propagation(env, fun, constants)?;
if !have_terminals_changed {
break;
}
initialize_hir(&mut fun.body).unwrap();
for block in fun.body.blocks.iter_mut() {
let predecessors = block.predecessors.clone();
for phi in block.phis.iter_mut() {
phi.operands
.retain(|predecessor, _| predecessors.contains(predecessor))
}
}
eliminate_redundant_phis(env, fun);
merge_consecutive_blocks(env, fun)?;
}
Ok(())
}
fn apply_constant_propagation(
env: &Environment,
fun: &mut Function,
constants: &mut Constants,
) -> Result<bool, Diagnostic> {
let mut has_changes = false;
for block in fun.body.blocks.iter_mut() {
for phi in block.phis.iter() {
let mut value: Option<Constant> = None;
for (_, operand) in &phi.operands {
if let Some(operand_value) = constants.get(operand.id) {
match &mut value {
Some(value) if value == operand_value => {
}
Some(_) => {
value = None;
break;
}
None => {
value = Some(operand_value.clone());
}
}
} else {
value = None;
break;
}
}
if let Some(value) = value {
constants.insert(phi.identifier.id, value);
}
}
for (ix, instr_ix) in block.instructions.iter().enumerate() {
if block.kind == BlockKind::Sequence && ix == block.instructions.len() - 1 {
continue;
}
let instr_ix = usize::from(*instr_ix);
let lvalue_id = fun.body.instructions[instr_ix].lvalue.identifier.id;
let mut value = std::mem::replace(
&mut fun.body.instructions[instr_ix].value,
InstructionValue::Tombstone,
);
let const_value = evaluate_instruction(env, &mut value, constants)?;
if let Some(const_value) = const_value {
constants.insert(lvalue_id, const_value);
}
fun.body.instructions[instr_ix].value = value;
}
if let TerminalValue::If(terminal) = &mut block.terminal.value {
if let Some(primitive) = constants.get_primitive(terminal.test.identifier.id) {
let target_block_id = if primitive.value.is_truthy() {
terminal.consequent
} else {
terminal.alternate
};
block.terminal.value = TerminalValue::Goto(GotoTerminal {
block: target_block_id,
kind: GotoKind::Break,
});
has_changes = true;
}
}
}
Ok(has_changes)
}
fn evaluate_instruction(
env: &Environment,
mut instr: &mut InstructionValue,
constants: &mut Constants,
) -> Result<Option<Constant>, Diagnostic> {
match &mut instr {
InstructionValue::Primitive(value) => Ok(Some(Constant::Primitive(value.clone()))),
InstructionValue::LoadGlobal(value) => Ok(Some(Constant::Global(value.clone()))),
InstructionValue::Binary(value) => {
let left = constants.get_primitive(value.left.identifier.id);
let right = constants.get_primitive(value.right.identifier.id);
match (left, right) {
(Some(left), Some(right)) => {
if let Some(result) =
apply_binary_operator(env, &left.value, value.operator, &right.value)
{
*instr = InstructionValue::Primitive(Primitive {
value: result.clone(),
});
Ok(Some(Constant::Primitive(Primitive { value: result })))
} else {
Ok(None)
}
}
_ => {
Ok(None)
}
}
}
InstructionValue::LoadLocal(value) => {
if let Some(const_value) = constants.get(value.place.identifier.id) {
*instr = const_value.into();
Ok(Some(const_value.clone()))
} else {
Ok(None)
}
}
InstructionValue::StoreLocal(value) => {
if let Some(const_value) = constants.get(value.value.identifier.id).cloned() {
constants.insert(value.lvalue.identifier.identifier.id, const_value.clone());
Ok(Some(const_value))
} else {
Ok(None)
}
}
InstructionValue::Function(value) => {
let mut inner_constants: Constants = value
.lowered_function
.context
.iter()
.filter_map(|id| {
let value = constants.get(id.identifier.id);
value.map(|value| (id.identifier.id, value.clone()))
})
.collect();
constant_propagation_impl(env, &mut value.lowered_function, &mut inner_constants)?;
Ok(None)
}
_ => {
Ok(None)
}
}
}
fn apply_binary_operator(
_env: &Environment,
left: &JsValue,
operator: BinaryOperator,
right: &JsValue,
) -> Option<JsValue> {
match (left, right) {
(JsValue::Number(left), JsValue::Number(right)) => match operator {
BinaryOperator::Add => Some(JsValue::Number(*left + *right)),
BinaryOperator::Subtract => Some(JsValue::Number(*left - *right)),
BinaryOperator::Multiply => Some(JsValue::Number(*left * *right)),
BinaryOperator::Divide => Some(JsValue::Number(*left / *right)),
BinaryOperator::LessThan => Some(JsValue::Boolean(*left < *right)),
BinaryOperator::LessThanOrEqual => Some(JsValue::Boolean(*left <= *right)),
BinaryOperator::GreaterThan => Some(JsValue::Boolean(*left > *right)),
BinaryOperator::GreaterThanOrEqual => Some(JsValue::Boolean(*left >= *right)),
BinaryOperator::Equals => Some(JsValue::Boolean(left.equals(*right))),
BinaryOperator::NotEquals => Some(JsValue::Boolean(left.not_equals(*right))),
BinaryOperator::StrictEquals => Some(JsValue::Boolean(left.equals(*right))),
BinaryOperator::NotStrictEquals => Some(JsValue::Boolean(left.not_equals(*right))),
_ => None,
},
(left, right) => match operator {
BinaryOperator::Equals => left.loosely_equals(right).map(JsValue::Boolean),
BinaryOperator::NotEquals => left.not_loosely_equals(right).map(JsValue::Boolean),
BinaryOperator::StrictEquals => Some(JsValue::Boolean(left.strictly_equals(right))),
BinaryOperator::NotStrictEquals => {
Some(JsValue::Boolean(left.not_strictly_equals(right)))
}
_ => None,
},
}
}
#[derive(Default)]
struct Constants {
data: HashMap<IdentifierId, Constant>,
}
impl Constants {
fn get_primitive(&self, id: IdentifierId) -> Option<&Primitive> {
if let Some(Constant::Primitive(primitive)) = &self.data.get(&id) {
Some(primitive)
} else {
None
}
}
fn get(&self, id: IdentifierId) -> Option<&Constant> {
self.data.get(&id)
}
fn insert(&mut self, id: IdentifierId, constant: Constant) {
self.data.insert(id, constant);
}
}
impl FromIterator<(IdentifierId, Constant)> for Constants {
fn from_iter<T: IntoIterator<Item = (IdentifierId, Constant)>>(iter: T) -> Self {
Self {
data: FromIterator::from_iter(iter),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum Constant {
Global(LoadGlobal),
Primitive(Primitive),
}
impl From<&Constant> for InstructionValue {
fn from(value: &Constant) -> Self {
match value {
Constant::Global(value) => InstructionValue::LoadGlobal(value.clone()),
Constant::Primitive(value) => InstructionValue::Primitive(value.clone()),
}
}
}
impl From<Constant> for InstructionValue {
fn from(value: Constant) -> Self {
match value {
Constant::Global(value) => InstructionValue::LoadGlobal(value),
Constant::Primitive(value) => InstructionValue::Primitive(value),
}
}
}