use std::collections::{HashMap, HashSet};
use react_compiler_ast::common::BaseNode;
use react_compiler_ast::declarations::{
ImportDeclaration, ImportKind, ImportSpecifier, ImportSpecifierData, ModuleExportName,
};
use react_compiler_ast::expressions::{CallExpression, Expression, Identifier};
use react_compiler_ast::literals::StringLiteral;
use react_compiler_ast::patterns::{ObjectPattern, ObjectPatternProp, ObjectPatternProperty, PatternLike};
use react_compiler_ast::scope::ScopeInfo;
use react_compiler_ast::statements::{
Statement, VariableDeclaration, VariableDeclarationKind, VariableDeclarator,
};
use react_compiler_ast::{Program, SourceType};
use react_compiler_diagnostics::{CompilerError, CompilerErrorDetail, ErrorCategory, Position, SourceLocation};
use super::compile_result::{DebugLogEntry, LoggerEvent, OrderedLogItem};
use super::plugin_options::{CompilerTarget, PluginOptions};
use super::suppression::SuppressionRange;
use crate::timing::TimingData;
/// An import binding registered on a [`ProgramContext`] for insertion into the
/// program (created by `ProgramContext::add_import_specifier` and emitted by
/// `add_imports_to_program`).
#[derive(Debug, Clone)]
pub struct NonLocalImportSpecifier {
/// Local binding name, generated through `ProgramContext::new_uid` so it does
/// not collide with names the context already knows about. Emitted as the
/// local side of the specifier (`imported as name`).
pub name: String,
/// Module source the binding is imported from.
pub module: String,
/// Name exported by `module` that is being imported.
pub imported: String,
}
/// Per-program compilation state: options, logs, generated-name bookkeeping,
/// and the set of imports that `add_imports_to_program` will insert.
pub struct ProgramContext {
/// Plugin options supplied to `ProgramContext::new`.
pub opts: PluginOptions,
/// Filename supplied to `ProgramContext::new`.
pub filename: Option<String>,
/// Source filename; only the first `Some` value passed to
/// `set_source_filename` is kept.
source_filename: Option<String>,
/// Code supplied to `ProgramContext::new`, if any.
pub code: Option<String>,
/// Module the memo-cache helper `c` is imported from
/// (`add_memo_cache_import`); resolved from `opts.target` at construction.
pub react_runtime_module: String,
/// Suppression ranges supplied at construction.
/// NOTE(review): presumably lint-suppression comments found in the source — confirm.
pub suppressions: Vec<SuppressionRange>,
/// Flag supplied at construction; not read in this file.
/// NOTE(review): name suggests a module-level opt-out directive was found — confirm.
pub has_module_scope_opt_out: bool,
/// Logger events in the order they were passed to `log_event`.
pub events: Vec<LoggerEvent>,
/// Events (`log_event`) and debug entries (`log_debug`) interleaved in
/// recording order.
pub ordered_log: Vec<OrderedLogItem>,
/// Initialized to `None` and not otherwise touched in this file.
/// NOTE(review): presumably set by instrumentation codegen elsewhere — confirm.
pub instrument_fn_name: Option<String>,
/// Initialized to `None` and not otherwise touched in this file.
pub instrument_gating_name: Option<String>,
/// Initialized to `None` and not otherwise touched in this file.
pub hook_guard_name: Option<String>,
/// Binding renames; starts empty and is not populated in this file.
pub renames: Vec<react_compiler_hir::environment::BindingRename>,
/// Timing data, created with profiling enabled iff `opts.profiling`.
pub timing: TimingData,
/// Copied from `opts.debug` at construction.
pub debug_enabled: bool,
/// Start offsets recorded by `mark_compiled` and queried by
/// `is_already_compiled`.
/// NOTE(review): presumably AST node start positions — confirm with callers.
already_compiled: HashSet<u32>,
/// Names treated as taken when `new_uid` generates identifiers; every name
/// `new_uid` returns is added here as well.
known_referenced_names: HashSet<String>,
/// Pending imports: module source -> (imported name -> specifier).
imports: HashMap<String, HashMap<String, NonLocalImportSpecifier>>,
}
impl ProgramContext {
    /// Creates a fresh context for one program.
    ///
    /// The runtime module is resolved from `opts.target`, and the profiling and
    /// debug switches are read from `opts`; every log, cache and import table
    /// starts out empty.
    pub fn new(
        opts: PluginOptions,
        filename: Option<String>,
        code: Option<String>,
        suppressions: Vec<SuppressionRange>,
        has_module_scope_opt_out: bool,
    ) -> Self {
        Self {
            // The next three fields read from `opts`, so they are written (and
            // therefore evaluated) before `opts` is moved into the struct.
            react_runtime_module: get_react_compiler_runtime_module(&opts.target),
            timing: TimingData::new(opts.profiling),
            debug_enabled: opts.debug,
            opts,
            filename,
            source_filename: None,
            code,
            suppressions,
            has_module_scope_opt_out,
            events: Vec::new(),
            ordered_log: Vec::new(),
            instrument_fn_name: None,
            instrument_gating_name: None,
            hook_guard_name: None,
            renames: Vec::new(),
            already_compiled: HashSet::new(),
            known_referenced_names: HashSet::new(),
            imports: HashMap::new(),
        }
    }

    /// Stores `filename` unless a source filename has already been recorded.
    /// Passing `None` while unset leaves it unset.
    pub fn set_source_filename(&mut self, filename: Option<String>) {
        if self.source_filename.is_some() {
            return;
        }
        self.source_filename = filename;
    }

    /// Returns an owned copy of the recorded source filename, if any.
    pub fn source_filename(&self) -> Option<String> {
        self.source_filename.clone()
    }

    /// Reports whether `start` was previously passed to `mark_compiled`.
    pub fn is_already_compiled(&self, start: u32) -> bool {
        self.already_compiled.contains(&start)
    }

    /// Records `start` so later `is_already_compiled(start)` calls return true.
    pub fn mark_compiled(&mut self, start: u32) {
        let _ = self.already_compiled.insert(start);
    }

    /// Seeds the known-name set with the name of every binding in `scope`.
    pub fn init_from_scope(&mut self, scope: &ScopeInfo) {
        for scope_binding in &scope.bindings {
            let taken = scope_binding.name.clone();
            self.known_referenced_names.insert(taken);
        }
    }

    /// Reports whether `name` is already considered taken.
    pub fn has_reference(&self, name: &str) -> bool {
        self.known_referenced_names.contains(name)
    }

    /// Produces an identifier that is not yet taken and marks it as taken.
    ///
    /// * Hook names (`useX…`) keep their prefix so they still look like hooks:
    ///   `name`, then `name_0`, `name_1`, …
    /// * Other names are returned unchanged when free; otherwise leading
    ///   underscores are stripped and `_base`, `_base2`, `_base3`, … are tried.
    pub fn new_uid(&mut self, name: &str) -> String {
        let uid = if is_hook_name(name) {
            let numbered = (0..).map(|n| format!("{}_{}", name, n));
            self.first_unreferenced(std::iter::once(name.to_string()).chain(numbered))
        } else if self.has_reference(name) {
            let base = name.trim_start_matches('_');
            let numbered = (2..).map(|n| format!("_{}{}", base, n));
            self.first_unreferenced(std::iter::once(format!("_{}", base)).chain(numbered))
        } else {
            name.to_string()
        };
        self.known_referenced_names.insert(uid.clone());
        uid
    }

    /// Returns the first candidate that is not already a known name.
    fn first_unreferenced(&self, mut candidates: impl Iterator<Item = String>) -> String {
        candidates
            .find(|candidate| !self.has_reference(candidate))
            .expect("candidate sequence is unbounded")
    }

    /// Registers (or reuses) the import of the memo-cache helper `c` from the
    /// runtime module, preferring the local name `_c`.
    pub fn add_memo_cache_import(&mut self) -> NonLocalImportSpecifier {
        // Owned copy: `add_import_specifier` needs `&mut self`.
        let runtime_module = self.react_runtime_module.clone();
        self.add_import_specifier(&runtime_module, "c", Some("_c"))
    }

    /// Returns the binding for `specifier` imported from `module`.
    ///
    /// An earlier registration for the same pair is returned as-is. Otherwise a
    /// fresh local name is generated from `name_hint` (or `specifier`) via
    /// `new_uid`, and the new binding is recorded for later insertion.
    pub fn add_import_specifier(
        &mut self,
        module: &str,
        specifier: &str,
        name_hint: Option<&str>,
    ) -> NonLocalImportSpecifier {
        let cached = self
            .imports
            .get(module)
            .and_then(|by_specifier| by_specifier.get(specifier))
            .cloned();
        if let Some(binding) = cached {
            return binding;
        }
        let binding = NonLocalImportSpecifier {
            name: self.new_uid(name_hint.unwrap_or(specifier)),
            module: module.to_owned(),
            imported: specifier.to_owned(),
        };
        self.imports
            .entry(module.to_owned())
            .or_default()
            .insert(specifier.to_owned(), binding.clone());
        binding
    }

    /// Marks `name` as taken.
    pub fn add_new_reference(&mut self, name: String) {
        self.known_referenced_names.insert(name);
    }

    /// Borrows the set of names currently considered taken.
    pub fn known_referenced_names(&self) -> &HashSet<String> {
        &self.known_referenced_names
    }

    /// Marks every name in `names` as taken.
    pub fn merge_uid_known_names(&mut self, names: &HashSet<String>) {
        for known in names {
            self.known_referenced_names.insert(known.clone());
        }
    }

    /// Records `event` both in `events` and in the interleaved `ordered_log`.
    pub fn log_event(&mut self, event: LoggerEvent) {
        self.events.push(event.clone());
        self.ordered_log.push(OrderedLogItem::Event { event });
    }

    /// Appends a debug entry to the interleaved `ordered_log`.
    pub fn log_debug(&mut self, entry: DebugLogEntry) {
        self.ordered_log.push(OrderedLogItem::Debug { entry });
    }

    /// Reports whether any import has been registered.
    pub fn has_pending_imports(&self) -> bool {
        !self.imports.is_empty()
    }

    /// Borrows the registered imports: module -> (imported name -> specifier).
    pub fn imports(&self) -> &HashMap<String, HashMap<String, NonLocalImportSpecifier>> {
        &self.imports
    }
}
/// Reports every top-level `import` whose source module is blocklisted.
///
/// Returns `None` when `blocklisted` is absent or empty, or when no import
/// declaration in `program.body` matches. Otherwise returns a
/// [`CompilerError`] containing one `Todo` detail per matching declaration,
/// located at that declaration.
pub fn validate_restricted_imports(
    program: &Program,
    blocklisted: &Option<Vec<String>>,
) -> Option<CompilerError> {
    // An absent or empty blocklist means there is nothing to validate.
    let restricted: HashSet<&str> = blocklisted
        .as_ref()
        .filter(|modules| !modules.is_empty())?
        .iter()
        .map(String::as_str)
        .collect();

    let offending_imports = program.body.iter().filter_map(|stmt| match stmt {
        Statement::ImportDeclaration(import)
            if restricted.contains(import.source.value.as_str()) =>
        {
            Some(import)
        }
        _ => None,
    });

    let mut error = CompilerError::new();
    for import in offending_imports {
        let mut detail = CompilerErrorDetail::new(
            ErrorCategory::Todo,
            "Bailing out due to blocklisted import",
        )
        .with_description(format!("Import from module {}", import.source.value));
        detail.loc = import.base.loc.as_ref().map(|loc| SourceLocation {
            start: Position {
                line: loc.start.line,
                column: loc.start.column,
                index: loc.start.index,
            },
            end: Position {
                line: loc.end.line,
                column: loc.end.column,
                index: loc.end.index,
            },
        });
        error.push_error_detail(detail);
    }

    error.has_any_errors().then_some(error)
}
pub fn add_imports_to_program(program: &mut Program, context: &ProgramContext) {
if context.imports.is_empty() {
return;
}
let existing_import_indices: HashMap<String, usize> = program
.body
.iter()
.enumerate()
.filter_map(|(idx, stmt)| {
if let Statement::ImportDeclaration(import) = stmt {
if is_non_namespaced_import(import) {
return Some((import.source.value.clone(), idx));
}
}
None
})
.collect();
let mut stmts: Vec<Statement> = Vec::new();
let mut sorted_modules: Vec<_> = context.imports.iter().collect();
sorted_modules.sort_by(|(a, _), (b, _)| a.to_lowercase().cmp(&b.to_lowercase()));
for (module_name, imports_map) in sorted_modules {
let sorted_imports = {
let mut sorted: Vec<_> = imports_map.values().collect();
sorted.sort_by_key(|s| &s.imported);
sorted
};
let import_specifiers: Vec<ImportSpecifier> = sorted_imports
.iter()
.map(|spec| make_import_specifier(spec))
.collect();
if let Some(&idx) = existing_import_indices.get(module_name.as_str()) {
if let Statement::ImportDeclaration(ref mut import) = program.body[idx] {
import.specifiers.extend(import_specifiers);
}
} else if matches!(program.source_type, SourceType::Module) {
stmts.push(Statement::ImportDeclaration(ImportDeclaration {
base: BaseNode::typed("ImportDeclaration"),
specifiers: import_specifiers,
source: StringLiteral {
base: BaseNode::typed("StringLiteral"),
value: module_name.clone(),
},
import_kind: None,
assertions: None,
attributes: None,
}));
} else {
let properties: Vec<ObjectPatternProperty> = sorted_imports
.iter()
.map(|spec| {
ObjectPatternProperty::ObjectProperty(ObjectPatternProp {
base: BaseNode::typed("ObjectProperty"),
key: Box::new(Expression::Identifier(Identifier {
base: BaseNode::typed("Identifier"),
name: spec.imported.clone(),
type_annotation: None,
optional: None,
decorators: None,
})),
value: Box::new(PatternLike::Identifier(Identifier {
base: BaseNode::typed("Identifier"),
name: spec.name.clone(),
type_annotation: None,
optional: None,
decorators: None,
})),
computed: false,
shorthand: false,
decorators: None,
method: None,
})
})
.collect();
stmts.push(Statement::VariableDeclaration(VariableDeclaration {
base: BaseNode::typed("VariableDeclaration"),
kind: VariableDeclarationKind::Const,
declarations: vec![VariableDeclarator {
base: BaseNode::typed("VariableDeclarator"),
id: PatternLike::ObjectPattern(ObjectPattern {
base: BaseNode::typed("ObjectPattern"),
properties,
type_annotation: None,
decorators: None,
}),
init: Some(Box::new(Expression::CallExpression(CallExpression {
base: BaseNode::typed("CallExpression"),
callee: Box::new(Expression::Identifier(Identifier {
base: BaseNode::typed("Identifier"),
name: "require".to_string(),
type_annotation: None,
optional: None,
decorators: None,
})),
arguments: vec![Expression::StringLiteral(StringLiteral {
base: BaseNode::typed("StringLiteral"),
value: module_name.clone(),
})],
type_parameters: None,
type_arguments: None,
optional: None,
}))),
definite: None,
}],
declare: None,
}));
}
}
if !stmts.is_empty() {
let mut new_body = stmts;
new_body.append(&mut program.body);
program.body = new_body;
}
}
/// Builds the named import specifier `imported as name` for `spec`.
fn make_import_specifier(spec: &NonLocalImportSpecifier) -> ImportSpecifier {
    // Plain identifier node with no annotations.
    let ident = |text: &str| Identifier {
        base: BaseNode::typed("Identifier"),
        name: text.to_string(),
        type_annotation: None,
        optional: None,
        decorators: None,
    };
    ImportSpecifier::ImportSpecifier(ImportSpecifierData {
        base: BaseNode::typed("ImportSpecifier"),
        local: ident(&spec.name),
        imported: ModuleExportName::Identifier(ident(&spec.imported)),
        import_kind: None,
    })
}
/// Reports whether `import` can have named specifiers appended to it.
///
/// Every specifier must be a named `ImportSpecifier` (no default or namespace
/// specifiers; an empty list qualifies). The import kind must be absent or
/// `Value`.
fn is_non_namespaced_import(import: &ImportDeclaration) -> bool {
    let has_only_named_specifiers = import
        .specifiers
        .iter()
        .all(|specifier| matches!(specifier, ImportSpecifier::ImportSpecifier(_)));
    let is_value_import = match &import.import_kind {
        Some(kind) => matches!(kind, ImportKind::Value),
        None => true,
    };
    has_only_named_specifiers && is_value_import
}
/// Returns true for names of the form `use` followed by an ASCII uppercase
/// letter or digit (e.g. `useState`, `use1`), but not `use`, `user` or `Use`.
fn is_hook_name(name: &str) -> bool {
    match name.strip_prefix("use").and_then(|rest| rest.bytes().next()) {
        Some(next) => next.is_ascii_uppercase() || next.is_ascii_digit(),
        None => false,
    }
}
/// Resolves the module that compiler runtime helpers are imported from.
///
/// * `MetaInternal` targets use their configured `runtime_module`.
/// * Versions `"17"` and `"18"` use `react-compiler-runtime`.
/// * Any other version (including `"19"`) uses `react/compiler-runtime`.
pub fn get_react_compiler_runtime_module(target: &CompilerTarget) -> String {
    let module: &str = match target {
        CompilerTarget::MetaInternal { runtime_module, .. } => return runtime_module.clone(),
        CompilerTarget::Version(v) if v == "17" || v == "18" => "react-compiler-runtime",
        CompilerTarget::Version(_) => "react/compiler-runtime",
    };
    module.to_string()
}