use std::cell::RefCell;
use std::rc::Rc;
use indexmap::{IndexMap, IndexSet};
use react_diagnostics::Diagnostic;
use react_hir::{
BasicBlock, BlockId, BlockRewriter, BlockRewriterAction, Blocks, Environment, Function,
Identifier, IdentifierData, IdentifierId, IdentifierOperand, InstructionValue, MutableRange,
Phi, HIR,
};
/// Converts a top-level function into SSA form.
///
/// A top-level function is required to have an empty `context`. Nested function expressions
/// receive their captured definitions through [`enter_ssa_impl`] instead.
///
/// # Panics
/// Panics if `fun.context` is not empty.
///
/// # Errors
/// Propagates any [`Diagnostic`] produced by [`enter_ssa_impl`].
pub fn enter_ssa(env: &Environment, fun: &mut Function) -> Result<(), Diagnostic> {
    assert!(fun.context.is_empty());
    let captured_defs = None;
    enter_ssa_impl(env, fun, captured_defs)
}
/// Rewrites `fun` into SSA form. Nested function expressions are handled through
/// `visit_instructions`, which calls back into this function.
///
/// `context_defs`, when present, seeds the entry block with definitions for the variables
/// captured from the enclosing function. It maps each captured identifier's pre-SSA id to the
/// SSA identifier it resolved to in the outer function.
///
/// After all blocks are visited, the phis collected for each block are written back onto the
/// corresponding `BasicBlock`.
///
/// # Errors
/// Propagates any [`Diagnostic`] raised while visiting instructions.
pub fn enter_ssa_impl(
    env: &Environment,
    fun: &mut Function,
    context_defs: Option<IndexMap<IdentifierId, Identifier>>,
) -> Result<(), Diagnostic> {
    let mut builder = Builder::new(env, fun.body.entry, &fun.body.blocks);
    if let Some(context_defs) = context_defs {
        builder.initialize_context(context_defs);
    }
    for param in &mut fun.params {
        builder.visit_param(param);
    }
    visit_instructions(env, &mut builder, &mut fun.body)?;
    let mut states = builder.complete();
    for block in fun.body.blocks.iter_mut() {
        // `swap_remove` is O(1) and not deprecated (unlike `remove` in newer indexmap);
        // the resulting order of `states` is irrelevant because every entry is drained here.
        let state = states
            .swap_remove(&block.id)
            .expect("EnterSSA: expected a builder state for every block");
        block.phis = state.phis;
    }
    Ok(())
}
fn visit_instructions(
env: &Environment,
builder: &mut Builder<'_>,
hir: &mut HIR,
) -> Result<(), Diagnostic> {
let instructions = &mut hir.instructions;
let blocks = &mut hir.blocks;
let mut rewriter = BlockRewriter::new(blocks, hir.entry);
rewriter.try_each_block(|mut block, _rewriter| {
builder.start_block(&block);
for instr_ix in &block.instructions {
let instr = &mut instructions[usize::from(*instr_ix)];
instr.each_rvalue(|rvalue| builder.visit_load(rvalue));
instr.try_each_lvalue(|lvalue| builder.visit_store(lvalue))?;
if let InstructionValue::Function(fun) = &mut instr.value {
let context_defs: IndexMap<IdentifierId, Identifier> = fun
.lowered_function
.context
.iter_mut()
.map(|identifier| {
let old_id = identifier.identifier.id;
builder.visit_load(identifier);
(old_id, identifier.identifier.clone())
})
.collect();
enter_ssa_impl(env, &mut fun.lowered_function, Some(context_defs))?;
}
}
block
.terminal
.value
.each_operand(|load| builder.visit_load(load));
builder.close_block(&block);
Ok(BlockRewriterAction::Keep(block))
})
}
/// Mutable state for converting a single function body into SSA form.
///
/// NOTE(review): the scheme used by the methods on this type (local defs per block, incomplete
/// phis while a block has unsealed predecessors, completing them once all predecessors are
/// closed) appears to follow Braun et al., "Simple and Efficient Construction of SSA Form" —
/// confirm.
#[derive(Debug)]
struct Builder<'e> {
/// Environment used to allocate fresh identifier ids.
env: &'e Environment,
/// Snapshot of each block's predecessor set, taken when the builder is created.
predecessors: IndexMap<BlockId, IndexSet<BlockId>>,
/// Per-block SSA bookkeeping: definitions, pending phis and completed phis.
states: IndexMap<BlockId, BlockState>,
/// The block currently being visited.
current: BlockId,
/// For each block, how many predecessors have not yet been closed. A block has no entry until
/// the first of its predecessors is closed.
unsealed_predecessors: IndexMap<BlockId, usize>,
/// Pre-SSA ids for which a store is reported as an invariant error.
/// NOTE(review): nothing in the visible code inserts into this set — confirm whether it is
/// populated elsewhere.
unknown: IndexSet<IdentifierId>,
/// Pre-SSA ids whose stores reuse the identifier currently in scope instead of defining a new
/// one. NOTE(review): nothing in the visible code inserts into this set — confirm.
context: IndexSet<IdentifierId>,
/// Blocks that have been started.
visited: IndexSet<BlockId>,
}
/// SSA bookkeeping for a single basic block.
#[derive(Debug)]
struct BlockState {
/// Most recent SSA identifier known in this block for each pre-SSA id. This includes
/// identifiers found through predecessors, which are cached here.
defs: IndexMap<IdentifierId, Identifier>,
/// Phis created while the block still had unsealed predecessors; their operands are resolved
/// once every predecessor has been closed.
incomplete_phis: Vec<IncompletePhi>,
/// Phis whose operands have been resolved; copied onto the block when SSA construction ends.
phis: Vec<Phi>,
}
impl BlockState {
fn new() -> Self {
Self {
defs: Default::default(),
incomplete_phis: Default::default(),
phis: Default::default(),
}
}
}
/// A phi placed in a block that still had unsealed predecessors. Its operands are filled in
/// once all of the block's predecessors have been closed.
#[derive(Debug)]
struct IncompletePhi {
/// The pre-SSA identifier the phi merges.
old_id: Identifier,
/// The fresh SSA identifier the phi defines.
new_id: Identifier,
}
impl<'e> Builder<'e> {
    /// Creates a builder with an empty [`BlockState`] for every block of `blocks` and a snapshot
    /// of each block's predecessor set. `entry` becomes the current block, so parameters and
    /// captured context variables are defined there.
    fn new(env: &'e Environment, entry: BlockId, blocks: &Blocks) -> Self {
        let states = blocks
            .block_ids()
            .into_iter()
            .map(|block_id| (block_id, BlockState::new()))
            .collect();
        let predecessors = blocks
            .iter()
            .map(|block| (block.id, block.predecessors.clone()))
            .collect();
        Self {
            env,
            predecessors,
            states,
            current: entry,
            unsealed_predecessors: Default::default(),
            unknown: Default::default(),
            context: Default::default(),
            visited: Default::default(),
        }
    }

    /// Replaces the current block's definitions with `defs`. When called right after
    /// [`Builder::new`], the current block is the entry block.
    fn initialize_context(&mut self, defs: IndexMap<IdentifierId, Identifier>) {
        let state = self
            .states
            .get_mut(&self.current)
            .expect("EnterSSA: missing state for the current block");
        state.defs = defs;
    }

    /// Consumes the builder and returns the per-block states, including completed phis.
    fn complete(self) -> IndexMap<BlockId, BlockState> {
        self.states
    }

    /// Allocates a fresh identifier id from the environment.
    fn next_ssa_id(&self) -> IdentifierId {
        self.env.next_identifier_id()
    }

    /// Renames an operand that an instruction writes.
    ///
    /// Ids in `context` are not redefined; they are rewritten to the identifier already in
    /// scope. Every other id gets a fresh SSA identifier defined in the current block.
    ///
    /// # Errors
    /// Returns an invariant [`Diagnostic`] if the operand's id is in `unknown`.
    fn visit_store(&mut self, lvalue: &mut IdentifierOperand) -> Result<(), Diagnostic> {
        let old_id = lvalue.identifier.id;
        if self.unknown.contains(&old_id) {
            return Err(Diagnostic::invariant(
                "EnterSSA: Expected identifier to be defined before being used",
                None,
            ));
        }
        if self.context.contains(&old_id) {
            lvalue.identifier = self.get_id_at(self.current, &lvalue.identifier);
            return Ok(());
        }
        self.define_fresh(lvalue);
        Ok(())
    }

    /// Gives a function parameter a fresh SSA identifier defined in the current block.
    fn visit_param(&mut self, param: &mut IdentifierOperand) {
        self.define_fresh(param);
    }

    /// Creates a fresh SSA identifier for `operand` and records it as the current block's
    /// definition of the operand's original id. The operand is rewritten to use the new
    /// identifier.
    fn define_fresh(&mut self, operand: &mut IdentifierOperand) {
        let new_identifier = self.make_identifier(&operand.identifier);
        let state = self
            .states
            .get_mut(&self.current)
            .expect("EnterSSA: missing state for the current block");
        state
            .defs
            .insert(operand.identifier.id, new_identifier.clone());
        operand.identifier = new_identifier;
    }

    /// Rewrites an operand that is read to the SSA identifier reaching the current block.
    fn visit_load(&mut self, local: &mut IdentifierOperand) {
        local.identifier = self.get_id_at(self.current, &local.identifier);
    }

    /// Returns the SSA identifier for `old_identifier` that reaches `block_id`.
    ///
    /// The lookup tries these cases in order:
    /// 1. A definition already recorded in the block.
    /// 2. If the block still has unsealed predecessors, a fresh identifier registered as an
    ///    incomplete phi.
    /// 3. If the block has exactly one predecessor, the value reaching that predecessor.
    /// 4. Otherwise, a new phi over all predecessors.
    ///
    /// Any result found through predecessors is cached in the block's `defs`.
    ///
    /// # Panics
    /// Panics if the lookup reaches a block without predecessors that has no definition for
    /// `old_identifier`.
    fn get_id_at(&mut self, block_id: BlockId, old_identifier: &Identifier) -> Identifier {
        let state = self
            .states
            .get(&block_id)
            .expect("EnterSSA: missing state for block");
        if let Some(identifier) = state.defs.get(&old_identifier.id) {
            return identifier.clone();
        }

        let predecessors = self
            .predecessors
            .get(&block_id)
            .expect("EnterSSA: missing predecessor set for block");
        if predecessors.is_empty() {
            panic!("Unable to find previous id for {old_identifier:?}");
        }

        // A block has no entry until one of its predecessors is closed, so a missing entry
        // means every predecessor is still unsealed.
        let unsealed = self
            .unsealed_predecessors
            .get(&block_id)
            .copied()
            .unwrap_or(predecessors.len());
        if unsealed > 0 {
            // Not all predecessors are known yet: defer the phi until the block is sealed.
            let new_identifier = self.make_identifier(old_identifier);
            let state = self
                .states
                .get_mut(&block_id)
                .expect("EnterSSA: missing state for block");
            state.incomplete_phis.push(IncompletePhi {
                old_id: old_identifier.clone(),
                new_id: new_identifier.clone(),
            });
            state
                .defs
                .insert(old_identifier.id, new_identifier.clone());
            return new_identifier;
        }

        if predecessors.len() == 1 {
            let predecessor = *predecessors
                .first()
                .expect("EnterSSA: predecessor set has exactly one element");
            let new_identifier = self.get_id_at(predecessor, old_identifier);
            self.states
                .get_mut(&block_id)
                .expect("EnterSSA: missing state for block")
                .defs
                .insert(old_identifier.id, new_identifier.clone());
            return new_identifier;
        }

        // Multiple predecessors: a phi may be needed. Record the definition before resolving
        // the operands so that a cycle leading back to this block terminates.
        let new_identifier = self.make_identifier(old_identifier);
        self.states
            .get_mut(&block_id)
            .expect("EnterSSA: missing state for block")
            .defs
            .insert(old_identifier.id, new_identifier.clone());
        self.add_phi(block_id, old_identifier, new_identifier)
    }

    /// Builds a phi defining `new_identifier` in `block_id`, with one operand per predecessor.
    /// Each operand is resolved recursively. The phi is stored in the block's `phis`, and
    /// `new_identifier` is returned.
    fn add_phi(
        &mut self,
        block_id: BlockId,
        old_identifier: &Identifier,
        new_identifier: Identifier,
    ) -> Identifier {
        let mut phi = Phi {
            identifier: new_identifier.clone(),
            operands: Default::default(),
        };
        // Cloned because resolving each operand needs `&mut self`.
        let predecessors = self
            .predecessors
            .get(&block_id)
            .expect("EnterSSA: missing predecessor set for block")
            .clone();
        for pred_block_id in predecessors {
            let pred_id = self.get_id_at(pred_block_id, old_identifier);
            phi.operands.insert(pred_block_id, pred_id);
        }
        self.states
            .get_mut(&block_id)
            .expect("EnterSSA: missing state for block")
            .phis
            .push(phi);
        new_identifier
    }

    /// Creates an identifier with a fresh id. It keeps the original's name and type, but has a
    /// new `MutableRange` and no scope.
    fn make_identifier(&self, old_identifier: &Identifier) -> Identifier {
        let old_data = old_identifier.data.borrow();
        Identifier {
            id: self.next_ssa_id(),
            name: old_identifier.name.clone(),
            data: Rc::new(RefCell::new(IdentifierData {
                mutable_range: MutableRange::new(),
                scope: None,
                type_: old_data.type_.clone(),
            })),
        }
    }

    /// Resolves the operands of every phi that was deferred while `block_id` was unsealed.
    fn fix_incomplete_phis(&mut self, block_id: BlockId) {
        let state = self
            .states
            .get_mut(&block_id)
            .expect("EnterSSA: missing state for block");
        let incomplete_phis = std::mem::take(&mut state.incomplete_phis);
        for phi in incomplete_phis {
            self.add_phi(block_id, &phi.old_id, phi.new_id);
        }
    }

    /// Makes `block` the current block and marks it as visited.
    fn start_block(&mut self, block: &BasicBlock) {
        self.current = block.id;
        self.visited.insert(block.id);
    }

    /// Marks `block` as finished by decrementing each successor's unsealed-predecessor count.
    /// When a successor that has already been visited reaches zero, its deferred phis are
    /// completed.
    fn close_block(&mut self, block: &BasicBlock) {
        for successor in block.terminal.value.successors() {
            let predecessor_count = self
                .predecessors
                .get(&successor)
                .expect("EnterSSA: missing predecessor set for successor block")
                .len();
            let remaining = self
                .unsealed_predecessors
                .entry(successor)
                .or_insert(predecessor_count);
            *remaining -= 1;
            if *remaining == 0 && self.visited.contains(&successor) {
                self.fix_incomplete_phis(successor);
            }
        }
    }
}