import * as t from '@babel/types';
import {
CompilerError,
CompilerErrorDetail,
ErrorCategory,
} from '../CompilerError';
import {computeUnconditionalBlocks} from '../HIR/ComputeUnconditionalBlocks';
import {isHookName} from '../HIR/Environment';
import {
HIRFunction,
IdentifierId,
Place,
SourceLocation,
getHookKind,
} from '../HIR/HIR';
import {
eachInstructionLValue,
eachInstructionOperand,
eachTerminalOperand,
} from '../HIR/visitors';
import {assertExhaustive} from '../Utils/utils';
import {Result} from '../Utils/Result';
enum Kind {
Error = 'Error',
KnownHook = 'KnownHook',
PotentialHook = 'PotentialHook',
Global = 'Global',
Local = 'Local',
}
function joinKinds(a: Kind, b: Kind): Kind {
if (a === Kind.Error || b === Kind.Error) {
return Kind.Error;
} else if (a === Kind.KnownHook || b === Kind.KnownHook) {
return Kind.KnownHook;
} else if (a === Kind.PotentialHook || b === Kind.PotentialHook) {
return Kind.PotentialHook;
} else if (a === Kind.Global || b === Kind.Global) {
return Kind.Global;
} else {
return Kind.Local;
}
}
export function validateHooksUsage(
fn: HIRFunction,
): Result<void, CompilerError> {
const unconditionalBlocks = computeUnconditionalBlocks(fn);
const errors = new CompilerError();
const errorsByPlace = new Map<t.SourceLocation, CompilerErrorDetail>();
function recordError(
loc: SourceLocation,
errorDetail: CompilerErrorDetail,
): void {
if (typeof loc === 'symbol') {
errors.pushErrorDetail(errorDetail);
} else {
errorsByPlace.set(loc, errorDetail);
}
}
function recordConditionalHookError(place: Place): void {
setKind(place, Kind.Error);
const reason =
'Hooks must always be called in a consistent order, and may not be called conditionally. See the Rules of Hooks (https://react.dev/warnings/invalid-hook-call-warning)';
const previousError =
typeof place.loc !== 'symbol' ? errorsByPlace.get(place.loc) : undefined;
if (previousError === undefined || previousError.reason !== reason) {
recordError(
place.loc,
new CompilerErrorDetail({
category: ErrorCategory.Hooks,
description: null,
reason,
loc: place.loc,
suggestions: null,
}),
);
}
}
function recordInvalidHookUsageError(place: Place): void {
const previousError =
typeof place.loc !== 'symbol' ? errorsByPlace.get(place.loc) : undefined;
if (previousError === undefined) {
recordError(
place.loc,
new CompilerErrorDetail({
category: ErrorCategory.Hooks,
description: null,
reason:
'Hooks may not be referenced as normal values, they must be called. See https://react.dev/reference/rules/react-calls-components-and-hooks#never-pass-around-hooks-as-regular-values',
loc: place.loc,
suggestions: null,
}),
);
}
}
function recordDynamicHookUsageError(place: Place): void {
const previousError =
typeof place.loc !== 'symbol' ? errorsByPlace.get(place.loc) : undefined;
if (previousError === undefined) {
recordError(
place.loc,
new CompilerErrorDetail({
category: ErrorCategory.Hooks,
description: null,
reason:
'Hooks must be the same function on every render, but this value may change over time to a different function. See https://react.dev/reference/rules/react-calls-components-and-hooks#dont-dynamically-use-hooks',
loc: place.loc,
suggestions: null,
}),
);
}
}
const valueKinds = new Map<IdentifierId, Kind>();
function getKindForPlace(place: Place): Kind {
const knownKind = valueKinds.get(place.identifier.id);
if (
place.identifier.name !== null &&
isHookName(place.identifier.name.value)
) {
return joinKinds(knownKind ?? Kind.Local, Kind.PotentialHook);
} else {
return knownKind ?? Kind.Local;
}
}
function visitPlace(place: Place): void {
const kind = valueKinds.get(place.identifier.id);
if (kind === Kind.KnownHook) {
recordInvalidHookUsageError(place);
}
}
function setKind(place: Place, kind: Kind): void {
valueKinds.set(place.identifier.id, kind);
}
for (const param of fn.params) {
const place = param.kind === 'Identifier' ? param : param.place;
const kind = getKindForPlace(place);
setKind(place, kind);
}
for (const [, block] of fn.body.blocks) {
for (const phi of block.phis) {
let kind: Kind =
phi.place.identifier.name !== null &&
isHookName(phi.place.identifier.name.value)
? Kind.PotentialHook
: Kind.Local;
for (const [, operand] of phi.operands) {
const operandKind = valueKinds.get(operand.identifier.id);
if (operandKind !== undefined) {
kind = joinKinds(kind, operandKind);
}
}
valueKinds.set(phi.place.identifier.id, kind);
}
for (const instr of block.instructions) {
switch (instr.value.kind) {
case 'LoadGlobal': {
if (getHookKind(fn.env, instr.lvalue.identifier) != null) {
setKind(instr.lvalue, Kind.KnownHook);
} else {
setKind(instr.lvalue, Kind.Global);
}
break;
}
case 'LoadContext':
case 'LoadLocal': {
visitPlace(instr.value.place);
const kind = getKindForPlace(instr.value.place);
setKind(instr.lvalue, kind);
break;
}
case 'StoreLocal':
case 'StoreContext': {
visitPlace(instr.value.value);
const kind = joinKinds(
getKindForPlace(instr.value.value),
getKindForPlace(instr.value.lvalue.place),
);
setKind(instr.value.lvalue.place, kind);
setKind(instr.lvalue, kind);
break;
}
case 'ComputedLoad': {
visitPlace(instr.value.object);
const kind = getKindForPlace(instr.value.object);
setKind(instr.lvalue, joinKinds(getKindForPlace(instr.lvalue), kind));
break;
}
case 'PropertyLoad': {
const objectKind = getKindForPlace(instr.value.object);
const isHookProperty =
typeof instr.value.property === 'string' &&
isHookName(instr.value.property);
let kind: Kind;
switch (objectKind) {
case Kind.Error: {
kind = Kind.Error;
break;
}
case Kind.KnownHook: {
kind = isHookProperty ? Kind.KnownHook : Kind.Local;
break;
}
case Kind.PotentialHook: {
kind = Kind.PotentialHook;
break;
}
case Kind.Global: {
kind = isHookProperty ? Kind.KnownHook : Kind.Global;
break;
}
case Kind.Local: {
kind = isHookProperty ? Kind.PotentialHook : Kind.Local;
break;
}
default: {
assertExhaustive(objectKind, `Unexpected kind \`${objectKind}\``);
}
}
setKind(instr.lvalue, kind);
break;
}
case 'CallExpression': {
const calleeKind = getKindForPlace(instr.value.callee);
const isHookCallee =
calleeKind === Kind.KnownHook || calleeKind === Kind.PotentialHook;
if (isHookCallee && !unconditionalBlocks.has(block.id)) {
recordConditionalHookError(instr.value.callee);
} else if (calleeKind === Kind.PotentialHook) {
recordDynamicHookUsageError(instr.value.callee);
}
for (const operand of eachInstructionOperand(instr)) {
if (operand === instr.value.callee) {
continue;
}
visitPlace(operand);
}
break;
}
case 'MethodCall': {
const calleeKind = getKindForPlace(instr.value.property);
const isHookCallee =
calleeKind === Kind.KnownHook || calleeKind === Kind.PotentialHook;
if (isHookCallee && !unconditionalBlocks.has(block.id)) {
recordConditionalHookError(instr.value.property);
} else if (calleeKind === Kind.PotentialHook) {
recordDynamicHookUsageError(instr.value.property);
}
for (const operand of eachInstructionOperand(instr)) {
if (operand === instr.value.property) {
continue;
}
visitPlace(operand);
}
break;
}
case 'Destructure': {
visitPlace(instr.value.value);
const objectKind = getKindForPlace(instr.value.value);
for (const lvalue of eachInstructionLValue(instr)) {
const isHookProperty =
lvalue.identifier.name !== null &&
isHookName(lvalue.identifier.name.value);
let kind: Kind;
switch (objectKind) {
case Kind.Error: {
kind = Kind.Error;
break;
}
case Kind.KnownHook: {
kind = Kind.KnownHook;
break;
}
case Kind.PotentialHook: {
kind = Kind.PotentialHook;
break;
}
case Kind.Global: {
kind = isHookProperty ? Kind.KnownHook : Kind.Global;
break;
}
case Kind.Local: {
kind = isHookProperty ? Kind.PotentialHook : Kind.Local;
break;
}
default: {
assertExhaustive(
objectKind,
`Unexpected kind \`${objectKind}\``,
);
}
}
setKind(lvalue, kind);
}
break;
}
case 'ObjectMethod':
case 'FunctionExpression': {
visitFunctionExpression(errors, instr.value.loweredFunc.func);
break;
}
default: {
for (const operand of eachInstructionOperand(instr)) {
visitPlace(operand);
}
for (const lvalue of eachInstructionLValue(instr)) {
const kind = getKindForPlace(lvalue);
setKind(lvalue, kind);
}
}
}
}
for (const operand of eachTerminalOperand(block.terminal)) {
visitPlace(operand);
}
}
for (const [, error] of errorsByPlace) {
errors.pushErrorDetail(error);
}
return errors.asResult();
}
function visitFunctionExpression(errors: CompilerError, fn: HIRFunction): void {
for (const [, block] of fn.body.blocks) {
for (const instr of block.instructions) {
switch (instr.value.kind) {
case 'ObjectMethod':
case 'FunctionExpression': {
visitFunctionExpression(errors, instr.value.loweredFunc.func);
break;
}
case 'MethodCall':
case 'CallExpression': {
const callee =
instr.value.kind === 'CallExpression'
? instr.value.callee
: instr.value.property;
const hookKind = getHookKind(fn.env, callee.identifier);
if (hookKind != null) {
errors.pushErrorDetail(
new CompilerErrorDetail({
category: ErrorCategory.Hooks,
reason:
'Hooks must be called at the top level in the body of a function component or custom hook, and may not be called within function expressions. See the Rules of Hooks (https://react.dev/warnings/invalid-hook-call-warning)',
loc: callee.loc,
description: `Cannot call ${hookKind === 'Custom' ? 'hook' : hookKind} within a function expression`,
suggestions: null,
}),
);
}
break;
}
}
}
}
}