import {CompilerError, ErrorSeverity} from '../CompilerError';
import {
BlockId,
HIRFunction,
IdentifierId,
Place,
SourceLocation,
getHookKindForType,
isRefValueType,
isUseRefType,
} from '../HIR';
import {
eachInstructionOperand,
eachInstructionValueOperand,
eachPatternOperand,
eachTerminalOperand,
} from '../HIR/visitors';
import {Err, Ok, Result} from '../Utils/Result';
/**
 * Brand key that makes `RefId` nominal: a plain `number` is not assignable to
 * `RefId` without an explicit cast (as performed by `makeRefId`).
 */
const opaqueRefId: unique symbol = Symbol();
/** Branded numeric identifier for a distinct ref tracked by the analysis. */
type RefId = number & {[opaqueRefId]: 'RefId'};
/**
 * Brands a raw number as a `RefId`.
 *
 * @param id - Must be a non-negative integer; otherwise a compiler invariant
 *   error is raised.
 * @returns The same number, typed as `RefId`.
 */
function makeRefId(id: number): RefId {
  const isNonNegativeInteger = id >= 0 && Number.isInteger(id);
  CompilerError.invariant(isNonNegativeInteger, {
    reason: 'Expected identifier id to be a non-negative integer',
    description: null,
    loc: null,
    suggestions: null,
  });
  return id as RefId;
}
/** Module-level counter backing `nextRefId`; it is only ever incremented. */
let _refId = 0;
/** Allocates a fresh `RefId` by post-incrementing the module-level counter. */
function nextRefId(): RefId {
  const allocated = _refId;
  _refId += 1;
  return makeRefId(allocated);
}
/**
 * Abstract type tracked per identifier by the ref-access analysis.
 *
 * - `None`: no ref-related information.
 * - `Nullable`: produced for `null`/`undefined` primitive literals.
 * - `Guard`: produced by a binary expression combining a ref value (identified
 *   by `refId`) with a nullable value, e.g. a null check on `ref.current`.
 * - Anything else is one of the ref-carrying types in `RefAccessRefType`.
 */
type RefAccessType =
  | RefAccessRefType
  | {kind: 'Guard'; refId: RefId}
  | {kind: 'Nullable'}
  | {kind: 'None'};
/**
 * Ref-carrying abstract types.
 *
 * - `Ref`: a ref object, identified by `refId`.
 * - `RefValue`: a value read out of a ref; `loc` is where it was read and
 *   `refId` identifies the source ref when known.
 * - `Structure`: an aggregate (object, array, or function) that may contain a
 *   ref-carrying `value` and/or describe a callable via `fn`.
 */
type RefAccessRefType =
  | {kind: 'Structure'; value: null | RefAccessRefType; fn: null | RefFnType}
  | {kind: 'RefValue'; loc?: SourceLocation; refId?: RefId}
  | {kind: 'Ref'; refId: RefId};
/**
 * Summary of a function value: whether invoking it reads a ref value, and the
 * joined abstract type of what it returns.
 */
type RefFnType = {
  readRefEffect: boolean;
  returnType: RefAccessType;
};
/**
 * Identifier -> abstract type map used by the fixpoint iteration.
 *
 * `set` merges the incoming value into any existing entry with
 * `joinRefAccessTypes` and records whether the stored type actually changed.
 * The analysis uses that flag to decide whether another pass is needed.
 */
class Env extends Map<IdentifierId, RefAccessType> {
  #changed = false;

  /** Clears the change flag. */
  resetChanged(): void {
    this.#changed = false;
  }

  /** Whether any `set` since the last `resetChanged` altered a stored type. */
  hasChanged(): boolean {
    return this.#changed;
  }

  override set(key: IdentifierId, value: RefAccessType): this {
    const previous = this.get(key);
    const merged = joinRefAccessTypes(value, previous ?? {kind: 'None'});
    if (previous == null) {
      // A brand-new entry only counts as a change if it carries information.
      if (merged.kind !== 'None') {
        this.#changed = true;
      }
    } else if (!tyEqual(previous, merged)) {
      this.#changed = true;
    }
    return super.set(key, merged);
  }
}
/**
 * Validates that `fn` does not access refs during render: reading ref values
 * directly, passing refs or ref values to non-hook calls, or calling
 * functions that read refs.
 *
 * @returns `Ok` when no violations were found, otherwise `Err` carrying the
 *   collected diagnostics.
 */
export function validateNoRefAccessInRender(
  fn: HIRFunction,
): Result<void, CompilerError> {
  const analysis = validateNoRefAccessInRenderImpl(fn, new Env());
  return analysis.map(() => undefined);
}
/**
 * Seeds an abstract type for `place` from its static type.
 *
 * Ref-value-typed identifiers become `RefValue`. Identifiers typed as refs
 * (per `isUseRefType`) become a `Ref` with a freshly allocated id. Everything
 * else becomes `None`.
 */
function refTypeOfType(place: Place): RefAccessType {
  const {identifier} = place;
  if (isRefValueType(identifier)) {
    return {kind: 'RefValue'};
  }
  if (isUseRefType(identifier)) {
    return {kind: 'Ref', refId: nextRefId()};
  }
  return {kind: 'None'};
}
/**
 * Equality over abstract types, used by `Env.set` to detect changes.
 *
 * - `None`, `Nullable` and `Ref` compare equal by kind alone (`Ref` ignores
 *   `refId`).
 * - `Guard` compares by `refId`.
 * - `RefValue` compares by `loc`.
 * - `Structure` compares its `fn` summary, then its `value`, recursively.
 */
function tyEqual(a: RefAccessType, b: RefAccessType): boolean {
  if (a.kind !== b.kind) {
    return false;
  }
  switch (a.kind) {
    case 'None':
    case 'Nullable':
    case 'Ref':
      return true;
    case 'Guard': {
      CompilerError.invariant(b.kind === 'Guard', {
        reason: 'Expected ref value',
        loc: null,
      });
      return a.refId === b.refId;
    }
    case 'RefValue': {
      CompilerError.invariant(b.kind === 'RefValue', {
        reason: 'Expected ref value',
        loc: null,
      });
      return a.loc == b.loc;
    }
    case 'Structure': {
      CompilerError.invariant(b.kind === 'Structure', {
        reason: 'Expected structure',
        loc: null,
      });
      let sameFn: boolean;
      if (a.fn === null || b.fn === null) {
        // Equal only if both are null.
        sameFn = a.fn === b.fn;
      } else {
        sameFn =
          a.fn.readRefEffect === b.fn.readRefEffect &&
          tyEqual(a.fn.returnType, b.fn.returnType);
      }
      if (!sameFn) {
        return false;
      }
      if (a.value === null || b.value === null) {
        return a.value === b.value;
      }
      return a.value === b.value || tyEqual(a.value, b.value);
    }
  }
}
/**
 * Joins abstract types pairwise, left to right (e.g. the operands of a phi,
 * or an incoming value with an existing Env entry). With no arguments the
 * result is `None`.
 *
 * Pairwise rules (first match wins):
 * - `None` is the identity.
 * - A `Guard` joined with a `Guard` of the same refId stays that `Guard`.
 *   Joined with `Nullable` or a different `Guard`, it collapses to `None`.
 * - A `Guard` joined with a ref-carrying type is order-dependent:
 *   join(Guard, X) yields X, while join(X, Guard) yields the Guard.
 *   NOTE(review): confirm this asymmetry is intended.
 * - Otherwise `Nullable` is absorbed by the other side.
 * - Two ref-carrying types are merged by `joinRefAccessRefTypes`, which may
 *   allocate fresh RefIds.
 */
function joinRefAccessTypes(...types: Array<RefAccessType>): RefAccessType {
// Merges two ref-carrying types. A RefValue on either side wins over Ref and
// Structure. Otherwise a Ref on either side wins. Otherwise both are
// Structures and are merged componentwise.
function joinRefAccessRefTypes(
a: RefAccessRefType,
b: RefAccessRefType,
): RefAccessRefType {
if (a.kind === 'RefValue') {
if (b.kind === 'RefValue' && a.refId === b.refId) {
return a;
}
// NOTE(review): this drops `loc` and `refId`, while the mirrored branch
// below (a non-RefValue, b RefValue) returns `b` intact; confirm intended.
return {kind: 'RefValue'};
} else if (b.kind === 'RefValue') {
return b;
} else if (a.kind === 'Ref' || b.kind === 'Ref') {
if (a.kind === 'Ref' && b.kind === 'Ref' && a.refId === b.refId) {
return a;
}
// Distinct refs (or a ref merged with a structure) get a fresh identity.
return {kind: 'Ref', refId: nextRefId()};
} else {
CompilerError.invariant(
a.kind === 'Structure' && b.kind === 'Structure',
{
reason: 'Expected structure',
loc: null,
},
);
// `null` is the identity for both components. Two function summaries
// are merged by OR-ing their read effects and joining their return types.
const fn =
a.fn === null
? b.fn
: b.fn === null
? a.fn
: {
readRefEffect: a.fn.readRefEffect || b.fn.readRefEffect,
returnType: joinRefAccessTypes(
a.fn.returnType,
b.fn.returnType,
),
};
const value =
a.value === null
? b.value
: b.value === null
? a.value
: joinRefAccessRefTypes(a.value, b.value);
return {
kind: 'Structure',
fn,
value,
};
}
}
// Fold left starting from `None`; `a` is the accumulator, `b` the next type.
return types.reduce(
(a, b) => {
if (a.kind === 'None') {
return b;
} else if (b.kind === 'None') {
return a;
} else if (a.kind === 'Guard') {
if (b.kind === 'Guard' && a.refId === b.refId) {
return a;
} else if (b.kind === 'Nullable' || b.kind === 'Guard') {
return {kind: 'None'};
} else {
return b;
}
} else if (b.kind === 'Guard') {
if (a.kind === 'Nullable') {
return {kind: 'None'};
} else {
return b;
}
} else if (a.kind === 'Nullable') {
return b;
} else if (b.kind === 'Nullable') {
return a;
} else {
return joinRefAccessRefTypes(a, b);
}
},
{kind: 'None'},
);
}
/**
 * Runs the ref-access analysis over `fn` until its abstract types reach a
 * fixpoint, collecting diagnostics.
 *
 * `env` is shared with, and mutated by, the recursive calls made for nested
 * function expressions and object methods. This lets captured identifiers'
 * types flow between outer and inner functions.
 *
 * The body is re-analyzed until a full pass leaves `env` unchanged, up to 10
 * passes. If any pass records errors, analysis stops and they are returned.
 *
 * @returns `Ok` with the join of the types of all returned values (from the
 *   final pass), or `Err` with the diagnostics of the failing pass.
 */
function validateNoRefAccessInRenderImpl(
  fn: HIRFunction,
  env: Env,
): Result<RefAccessType, CompilerError> {
  const maxIterations = 10;
  let returnValues: Array<undefined | RefAccessType> = [];
  for (const param of fn.params) {
    const place = param.kind === 'Identifier' ? param : param.place;
    env.set(place.identifier.id, refTypeOfType(place));
  }

  // Whether the most recent pass changed `env`. It starts out true so the
  // body is always analyzed at least once.
  let changed = true;
  for (let i = 0; changed && i < maxIterations; i++) {
    env.resetChanged();
    // Recursive analysis of a nested function shares `env` and resets its
    // change flag. Remember whether this pass had already changed `env`
    // before such a call; otherwise those changes are forgotten and the
    // fixpoint can terminate early with under-approximated types.
    // NOTE(review): widenings made by the nested analysis itself on earlier
    // inner passes are still not observable here.
    let changedBeforeNestedCall = false;
    returnValues = [];
    const safeBlocks = new Map<BlockId, RefId>();
    const errors = new CompilerError();
    for (const [, block] of fn.body.blocks) {
      for (const phi of block.phis) {
        env.set(
          phi.place.identifier.id,
          joinRefAccessTypes(
            ...Array.from(
              phi.operands.values(),
              operand =>
                env.get(operand.identifier.id) ?? ({kind: 'None'} as const),
            ),
          ),
        );
      }

      for (const instr of block.instructions) {
        switch (instr.value.kind) {
          case 'JsxExpression':
          case 'JsxFragment': {
            for (const operand of eachInstructionValueOperand(instr.value)) {
              validateNoDirectRefValueAccess(errors, operand, env);
            }
            break;
          }
          case 'ComputedLoad':
          case 'PropertyLoad': {
            if (instr.value.kind === 'ComputedLoad') {
              validateNoDirectRefValueAccess(errors, instr.value.property, env);
            }
            const objType = env.get(instr.value.object.identifier.id);
            let lookupType: null | RefAccessType = null;
            if (objType?.kind === 'Structure') {
              lookupType = objType.value;
            } else if (objType?.kind === 'Ref') {
              // Any property read from a ref (e.g. `ref.current`) is a ref value.
              lookupType = {
                kind: 'RefValue',
                loc: instr.loc,
                refId: objType.refId,
              };
            }
            env.set(
              instr.lvalue.identifier.id,
              lookupType ?? refTypeOfType(instr.lvalue),
            );
            break;
          }
          case 'TypeCastExpression': {
            env.set(
              instr.lvalue.identifier.id,
              env.get(instr.value.value.identifier.id) ??
                refTypeOfType(instr.lvalue),
            );
            break;
          }
          case 'LoadContext':
          case 'LoadLocal': {
            env.set(
              instr.lvalue.identifier.id,
              env.get(instr.value.place.identifier.id) ??
                refTypeOfType(instr.lvalue),
            );
            break;
          }
          case 'StoreContext':
          case 'StoreLocal': {
            env.set(
              instr.value.lvalue.place.identifier.id,
              env.get(instr.value.value.identifier.id) ??
                refTypeOfType(instr.value.lvalue.place),
            );
            env.set(
              instr.lvalue.identifier.id,
              env.get(instr.value.value.identifier.id) ??
                refTypeOfType(instr.lvalue),
            );
            break;
          }
          case 'Destructure': {
            const objType = env.get(instr.value.value.identifier.id);
            let lookupType: null | RefAccessType = null;
            if (objType?.kind === 'Structure') {
              lookupType = objType.value;
            }
            env.set(
              instr.lvalue.identifier.id,
              lookupType ?? refTypeOfType(instr.lvalue),
            );
            for (const lval of eachPatternOperand(instr.value.lvalue.pattern)) {
              env.set(lval.identifier.id, lookupType ?? refTypeOfType(lval));
            }
            break;
          }
          case 'ObjectMethod':
          case 'FunctionExpression': {
            if (env.hasChanged()) {
              changedBeforeNestedCall = true;
            }
            let returnType: RefAccessType = {kind: 'None'};
            let readRefEffect = false;
            const result = validateNoRefAccessInRenderImpl(
              instr.value.loweredFunc.func,
              env,
            );
            if (result.isOk()) {
              returnType = result.unwrap();
            } else if (result.isErr()) {
              // The nested function's own diagnostics are not reported here.
              // Instead it is marked as reading a ref when invoked.
              readRefEffect = true;
            }
            env.set(instr.lvalue.identifier.id, {
              kind: 'Structure',
              fn: {
                readRefEffect,
                returnType,
              },
              value: null,
            });
            break;
          }
          case 'MethodCall':
          case 'CallExpression': {
            const callee =
              instr.value.kind === 'CallExpression'
                ? instr.value.callee
                : instr.value.property;
            const hookKind = getHookKindForType(fn.env, callee.identifier.type);
            let returnType: RefAccessType = {kind: 'None'};
            const fnType = env.get(callee.identifier.id);
            if (fnType?.kind === 'Structure' && fnType.fn !== null) {
              returnType = fnType.fn.returnType;
              if (fnType.fn.readRefEffect) {
                errors.push({
                  severity: ErrorSeverity.InvalidReact,
                  reason:
                    'This function accesses a ref value (the `current` property), which may not be accessed during render. (https://react.dev/reference/react/useRef)',
                  loc: callee.loc,
                  description:
                    callee.identifier.name !== null &&
                    callee.identifier.name.kind === 'named'
                      ? `Function \`${callee.identifier.name.value}\` accesses a ref`
                      : null,
                  suggestions: null,
                });
              }
            }
            for (const operand of eachInstructionValueOperand(instr.value)) {
              // Hooks may receive refs, but not raw ref values; other
              // callees may receive neither.
              if (hookKind != null) {
                validateNoDirectRefValueAccess(errors, operand, env);
              } else {
                validateNoRefAccess(errors, env, operand, operand.loc);
              }
            }
            env.set(instr.lvalue.identifier.id, returnType);
            break;
          }
          case 'ObjectExpression':
          case 'ArrayExpression': {
            const types: Array<RefAccessType> = [];
            for (const operand of eachInstructionValueOperand(instr.value)) {
              validateNoDirectRefValueAccess(errors, operand, env);
              types.push(env.get(operand.identifier.id) ?? {kind: 'None'});
            }
            const value = joinRefAccessTypes(...types);
            if (
              value.kind === 'None' ||
              value.kind === 'Guard' ||
              value.kind === 'Nullable'
            ) {
              env.set(instr.lvalue.identifier.id, {kind: 'None'});
            } else {
              env.set(instr.lvalue.identifier.id, {
                kind: 'Structure',
                value,
                fn: null,
              });
            }
            break;
          }
          case 'PropertyDelete':
          case 'PropertyStore':
          case 'ComputedDelete':
          case 'ComputedStore': {
            const safe = safeBlocks.get(block.id);
            const target = env.get(instr.value.object.identifier.id);
            // Inside the consequent of a null-check guard on this ref, the
            // first property store to the ref is allowed (lazy init).
            if (
              instr.value.kind === 'PropertyStore' &&
              safe != null &&
              target?.kind === 'Ref' &&
              target.refId === safe
            ) {
              safeBlocks.delete(block.id);
            } else {
              validateNoRefAccess(errors, env, instr.value.object, instr.loc);
            }
            for (const operand of eachInstructionValueOperand(instr.value)) {
              if (operand === instr.value.object) {
                continue;
              }
              validateNoRefValueAccess(errors, env, operand);
            }
            break;
          }
          case 'StartMemoize':
          case 'FinishMemoize':
            break;
          case 'Primitive': {
            if (instr.value.value == null) {
              env.set(instr.lvalue.identifier.id, {kind: 'Nullable'});
            }
            break;
          }
          case 'BinaryExpression': {
            const left = env.get(instr.value.left.identifier.id);
            const right = env.get(instr.value.right.identifier.id);
            let nullish: boolean = false;
            let refId: RefId | null = null;
            if (left?.kind === 'RefValue' && left.refId != null) {
              refId = left.refId;
            } else if (right?.kind === 'RefValue' && right.refId != null) {
              refId = right.refId;
            }

            if (left?.kind === 'Nullable') {
              nullish = true;
            } else if (right?.kind === 'Nullable') {
              nullish = true;
            }

            if (refId !== null && nullish) {
              env.set(instr.lvalue.identifier.id, {kind: 'Guard', refId});
            } else {
              for (const operand of eachInstructionValueOperand(instr.value)) {
                validateNoRefValueAccess(errors, env, operand);
              }
            }
            break;
          }
          default: {
            for (const operand of eachInstructionValueOperand(instr.value)) {
              validateNoRefValueAccess(errors, env, operand);
            }
            break;
          }
        }

        // Guards are derived from ref values, so they may only be used as
        // `if` tests (checked at the terminal below), never as operands.
        for (const operand of eachInstructionOperand(instr)) {
          guardCheck(errors, operand, env);
        }

        if (
          isUseRefType(instr.lvalue.identifier) &&
          env.get(instr.lvalue.identifier.id)?.kind !== 'Ref'
        ) {
          env.set(
            instr.lvalue.identifier.id,
            joinRefAccessTypes(
              env.get(instr.lvalue.identifier.id) ?? {kind: 'None'},
              {kind: 'Ref', refId: nextRefId()},
            ),
          );
        }
        if (
          isRefValueType(instr.lvalue.identifier) &&
          env.get(instr.lvalue.identifier.id)?.kind !== 'RefValue'
        ) {
          env.set(
            instr.lvalue.identifier.id,
            joinRefAccessTypes(
              env.get(instr.lvalue.identifier.id) ?? {kind: 'None'},
              {kind: 'RefValue', loc: instr.loc},
            ),
          );
        }
      }

      if (block.terminal.kind === 'if') {
        const test = env.get(block.terminal.test.identifier.id);
        if (test?.kind === 'Guard') {
          safeBlocks.set(block.terminal.consequent, test.refId);
        }
      }

      for (const operand of eachTerminalOperand(block.terminal)) {
        if (block.terminal.kind !== 'return') {
          validateNoRefValueAccess(errors, env, operand);
          if (block.terminal.kind !== 'if') {
            guardCheck(errors, operand, env);
          }
        } else {
          // Functions that capture refs may be returned; raw ref values may not.
          validateNoDirectRefValueAccess(errors, operand, env);
          guardCheck(errors, operand, env);
          returnValues.push(env.get(operand.identifier.id));
        }
      }
    }

    if (errors.hasErrors()) {
      return Err(errors);
    }
    changed = changedBeforeNestedCall || env.hasChanged();
  }

  CompilerError.invariant(!changed, {
    reason: 'Ref type environment did not converge',
    loc: null,
  });

  return Ok(
    joinRefAccessTypes(
      ...returnValues.filter((ty): ty is RefAccessType => ty !== undefined),
    ),
  );
}
/**
 * Unwraps nested `Structure` types down to the innermost non-null `value`.
 * A `Structure` with a null `value`, or any other type, is returned as-is.
 */
function destructure(
  type: RefAccessType | undefined,
): RefAccessType | undefined {
  let current = type;
  while (current?.kind === 'Structure' && current.value !== null) {
    current = current.value;
  }
  return current;
}
/**
 * Records an error when `operand` holds a `Guard` value (a value derived
 * from comparing a ref value). Callers skip this check for `if` tests.
 */
function guardCheck(errors: CompilerError, operand: Place, env: Env): void {
  const type = env.get(operand.identifier.id);
  if (type?.kind !== 'Guard') {
    return;
  }
  const {name} = operand.identifier;
  errors.push({
    severity: ErrorSeverity.InvalidReact,
    reason:
      'Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef)',
    loc: operand.loc,
    description:
      name !== null && name.kind === 'named'
        ? `Cannot access ref value \`${name.value}\``
        : null,
    suggestions: null,
  });
}
/**
 * Records an error when `operand` (after unwrapping structures) is a ref
 * value, or a function summary that reads a ref when invoked. Refs
 * themselves are allowed here.
 */
function validateNoRefValueAccess(
  errors: CompilerError,
  env: Env,
  operand: Place,
): void {
  const type = destructure(env.get(operand.identifier.id));
  if (type == null) {
    return;
  }
  const readsRefValue =
    type.kind === 'RefValue' ||
    (type.kind === 'Structure' && Boolean(type.fn?.readRefEffect));
  if (!readsRefValue) {
    return;
  }
  const {name} = operand.identifier;
  errors.push({
    severity: ErrorSeverity.InvalidReact,
    reason:
      'Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef)',
    loc: (type.kind === 'RefValue' && type.loc) || operand.loc,
    description:
      name !== null && name.kind === 'named'
        ? `Cannot access ref value \`${name.value}\``
        : null,
    suggestions: null,
  });
}
/**
 * Records an error when `operand` (after unwrapping structures) is a ref, a
 * ref value, or a function summary that reads a ref when invoked. `loc` is
 * the fallback error location when no read location is recorded.
 */
function validateNoRefAccess(
  errors: CompilerError,
  env: Env,
  operand: Place,
  loc: SourceLocation,
): void {
  const type = destructure(env.get(operand.identifier.id));
  if (type == null) {
    return;
  }
  const accessesRef =
    type.kind === 'Ref' ||
    type.kind === 'RefValue' ||
    (type.kind === 'Structure' && Boolean(type.fn?.readRefEffect));
  if (!accessesRef) {
    return;
  }
  const {name} = operand.identifier;
  errors.push({
    severity: ErrorSeverity.InvalidReact,
    reason:
      'Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef)',
    loc: (type.kind === 'RefValue' && type.loc) || loc,
    description:
      name !== null && name.kind === 'named'
        ? `Cannot access ref value \`${name.value}\``
        : null,
    suggestions: null,
  });
}
/**
 * Records an error only when `operand` (after unwrapping structures) is a
 * ref value. Refs and ref-reading functions are allowed.
 */
function validateNoDirectRefValueAccess(
  errors: CompilerError,
  operand: Place,
  env: Env,
): void {
  const type = destructure(env.get(operand.identifier.id));
  if (type?.kind !== 'RefValue') {
    return;
  }
  const {name} = operand.identifier;
  errors.push({
    severity: ErrorSeverity.InvalidReact,
    reason:
      'Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef)',
    loc: type.loc ?? operand.loc,
    description:
      name !== null && name.kind === 'named'
        ? `Cannot access ref value \`${name.value}\``
        : null,
    suggestions: null,
  });
}