import {CompilerError, SourceLocation} from '..';
import {ErrorCategory} from '../CompilerError';
import {
ArrayExpression,
BlockId,
FunctionExpression,
HIRFunction,
IdentifierId,
isSetStateType,
isUseEffectHookType,
} from '../HIR';
import {
eachInstructionValueOperand,
eachTerminalOperand,
} from '../HIR/visitors';
/**
 * Reports `useEffect` calls whose effect callback only derives state from the
 * effect's dependencies (see `validateEffect` for the per-effect analysis).
 *
 * Only effects shaped as `useEffect(<function expression>, [<identifiers>])`
 * with a non-empty dependency array literal are examined. Dependency elements
 * produced by a `LoadLocal` are resolved back to the local they were loaded
 * from before being handed to the per-effect check.
 *
 * @param fn - the function whose top-level blocks are scanned.
 * @throws CompilerError aggregating every diagnostic recorded while
 *   validating the effects, when at least one was recorded.
 */
export function validateNoDerivedComputationsInEffects(fn: HIRFunction): void {
  // Side tables keyed by the identifier each instruction assigns.
  const arrayLiterals = new Map<IdentifierId, ArrayExpression>();
  const functionExpressions = new Map<IdentifierId, FunctionExpression>();
  const loadedFrom = new Map<IdentifierId, IdentifierId>();
  const diagnostics = new CompilerError();

  for (const block of fn.body.blocks.values()) {
    for (const {lvalue, value} of block.instructions) {
      switch (value.kind) {
        case 'LoadLocal': {
          loadedFrom.set(lvalue.identifier.id, value.place.identifier.id);
          break;
        }
        case 'ArrayExpression': {
          arrayLiterals.set(lvalue.identifier.id, value);
          break;
        }
        case 'FunctionExpression': {
          functionExpressions.set(lvalue.identifier.id, value);
          break;
        }
        case 'CallExpression':
        case 'MethodCall': {
          const callee =
            value.kind === 'CallExpression' ? value.callee : value.property;
          if (!isUseEffectHookType(callee.identifier)) {
            break;
          }
          if (value.args.length !== 2) {
            break;
          }
          const [callbackArg, depsArg] = value.args;
          if (callbackArg.kind !== 'Identifier' || depsArg.kind !== 'Identifier') {
            break;
          }

          const callback = functionExpressions.get(callbackArg.identifier.id);
          const depsArray = arrayLiterals.get(depsArg.identifier.id);
          if (callback == null || depsArray == null) {
            break;
          }
          if (depsArray.elements.length === 0) {
            break;
          }
          if (!depsArray.elements.every(element => element.kind === 'Identifier')) {
            break;
          }

          // Map each dependency temporary back to the local it was loaded
          // from, falling back to the temporary itself.
          const resolvedDeps: Array<IdentifierId> = [];
          for (const element of depsArray.elements) {
            CompilerError.invariant(element.kind === 'Identifier', {
              reason: `Dependency is checked as a place above`,
              loc: value.loc,
            });
            resolvedDeps.push(
              loadedFrom.get(element.identifier.id) ?? element.identifier.id,
            );
          }

          validateEffect(callback.loweredFunc.func, resolvedDeps, diagnostics);
          break;
        }
      }
    }
  }

  if (diagnostics.hasAnyErrors()) {
    throw diagnostics;
  }
}
/**
 * Checks one effect callback and records a diagnostic for every `setState`
 * call whose argument is derived from *all* of the effect's dependencies.
 *
 * The analysis is conservative: it returns without recording anything as soon
 * as it meets a shape it does not model (extra captured values, a block with
 * an unvisited predecessor, an unsupported instruction kind, a derived value
 * used by a terminal, or a setState whose argument is not fully derived).
 *
 * @param effectFunction - the lowered body of the effect callback.
 * @param effectDeps - identifier ids of the effect's dependency-array entries,
 *   as resolved by the caller.
 * @param errors - accumulator that receives one diagnostic per offending
 *   `setState` call site; nothing is pushed if the analysis bails out.
 */
function validateEffect(
  effectFunction: HIRFunction,
  effectDeps: Array<IdentifierId>,
  errors: CompilerError,
): void {
  // Every captured value must be either a setState function or a declared
  // dependency; anything else could feed untracked data into the effect.
  for (const operand of effectFunction.context) {
    if (
      !isSetStateType(operand.identifier) &&
      !effectDeps.includes(operand.identifier.id)
    ) {
      return;
    }
  }
  // Conversely, every declared dependency must actually be captured.
  for (const dep of effectDeps) {
    if (
      !effectFunction.context.some(operand => operand.identifier.id === dep)
    ) {
      return;
    }
  }

  // A dependency array may list the same value more than once. Derived
  // dependency sets below are always deduplicated, so compare against the
  // number of *distinct* dependencies; otherwise duplicates would hide a
  // fully-derived setState.
  const distinctDepCount = new Set(effectDeps).size;

  const seenBlocks: Set<BlockId> = new Set();
  // For each identifier derived from the dependencies, the dependencies it
  // was (transitively) computed from.
  const values: Map<IdentifierId, Array<IdentifierId>> = new Map();
  for (const dep of effectDeps) {
    values.set(dep, [dep]);
  }
  const setStateLocations: Array<SourceLocation> = [];

  for (const block of effectFunction.body.blocks.values()) {
    // This is a single forward pass: bail on any block reached from a block
    // not yet visited (for example a loop back-edge).
    for (const pred of block.preds) {
      if (!seenBlocks.has(pred)) {
        return;
      }
    }

    // A phi is derived from the union of what its operands are derived from.
    for (const phi of block.phis) {
      const aggregateDeps: Set<IdentifierId> = new Set();
      for (const operand of phi.operands.values()) {
        const deps = values.get(operand.identifier.id);
        if (deps != null) {
          for (const dep of deps) {
            aggregateDeps.add(dep);
          }
        }
      }
      if (aggregateDeps.size !== 0) {
        values.set(phi.place.identifier.id, Array.from(aggregateDeps));
      }
    }

    for (const instr of block.instructions) {
      switch (instr.value.kind) {
        case 'Primitive':
        case 'JSXText':
        case 'LoadGlobal': {
          // Not derived from any dependency.
          break;
        }
        case 'LoadLocal': {
          const deps = values.get(instr.value.place.identifier.id);
          if (deps != null) {
            values.set(instr.lvalue.identifier.id, deps);
          }
          break;
        }
        case 'ComputedLoad':
        case 'PropertyLoad':
        case 'BinaryExpression':
        case 'TemplateLiteral':
        case 'CallExpression':
        case 'MethodCall': {
          // The result is derived from the union of its operands' derivations.
          const aggregateDeps: Set<IdentifierId> = new Set();
          for (const operand of eachInstructionValueOperand(instr.value)) {
            const deps = values.get(operand.identifier.id);
            if (deps != null) {
              for (const dep of deps) {
                aggregateDeps.add(dep);
              }
            }
          }
          if (aggregateDeps.size !== 0) {
            values.set(instr.lvalue.identifier.id, Array.from(aggregateDeps));
          }
          if (
            instr.value.kind === 'CallExpression' &&
            isSetStateType(instr.value.callee.identifier) &&
            instr.value.args.length === 1 &&
            instr.value.args[0].kind === 'Identifier'
          ) {
            // Only a setState whose argument depends on every dependency is
            // reported; any other setState means the effect does more than
            // derive state, so give up on the whole effect.
            const deps = values.get(instr.value.args[0].identifier.id);
            if (deps != null && new Set(deps).size === distinctDepCount) {
              setStateLocations.push(instr.value.callee.loc);
            } else {
              return;
            }
          }
          break;
        }
        default: {
          // Unmodelled instruction kind: bail out conservatively.
          return;
        }
      }
    }

    // A derived value used by the block's terminal means the effect's
    // behaviour depends on it in ways this pass does not model.
    for (const operand of eachTerminalOperand(block.terminal)) {
      if (values.has(operand.identifier.id)) {
        return;
      }
    }
    seenBlocks.add(block.id);
  }

  for (const loc of setStateLocations) {
    errors.push({
      category: ErrorCategory.EffectDerivationsOfState,
      reason:
        'Values derived from props and state should be calculated during render, not in an effect. (https://react.dev/learn/you-might-not-need-an-effect#updating-state-based-on-props-or-state)',
      description: null,
      loc,
      suggestions: null,
    });
  }
}