import {CompilerError, CompilerErrorDetailOptions, SourceLocation} from '..';
import {
ArrayExpression,
CallExpression,
Effect,
Environment,
FunctionExpression,
GeneratedSource,
HIRFunction,
Identifier,
IdentifierId,
Instruction,
InstructionId,
InstructionKind,
InstructionValue,
isUseEffectHookType,
LoadLocal,
makeInstructionId,
NonLocalImportSpecifier,
Place,
promoteTemporary,
} from '../HIR';
import {createTemporaryPlace, markInstructionIds} from '../HIR/HIRBuilder';
import {getOrInsertWith} from '../Utils/utils';
import {
BuiltInFireFunctionId,
BuiltInFireId,
DefaultNonmutatingHook,
} from '../HIR/ObjectShape';
import {eachInstructionOperand} from '../HIR/visitors';
import {printSourceLocationLine} from '../HIR/PrintHIR';
import {USE_FIRE_FUNCTION_NAME} from '../HIR/Environment';
import {ErrorCategory} from '../CompilerError';
/**
 * Shared `reason` text for the `ErrorCategory.Fire` diagnostics pushed in this file;
 * the per-error `description` carries the specific explanation.
 */
const CANNOT_COMPILE_FIRE = 'Cannot compile `fire`';
/**
 * Entry point of the `fire` transform for a single HIR function.
 *
 * Runs the rewriting pass (`replaceFireFunctions`) over `fn`. Only when that
 * pass recorded no errors does it run the follow-up validation
 * (`ensureNoMoreFireUses`). Any errors accumulated on the shared context are
 * thrown at the end.
 *
 * @param fn - The function to transform; it is mutated in place.
 * @throws The accumulated `CompilerError` when at least one error was recorded.
 */
export function transformFire(fn: HIRFunction): void {
  const ctx = new Context(fn.env);

  replaceFireFunctions(fn, ctx);

  const rewriteReportedErrors = ctx.hasErrors();
  if (rewriteReportedErrors === false) {
    ensureNoMoreFireUses(fn, ctx);
  }

  ctx.throwIfErrorsFound();
}
/**
 * Walks every instruction of `fn` block by block. While walking it
 *  - records lookup tables on `context` (call expressions, function expressions,
 *    LoadLocal values, LoadGlobal instruction ids and array expressions, each keyed
 *    by the id of the instruction's lvalue identifier),
 *  - handles calls whose callee has the useEffect hook type by visiting the effect
 *    lambda and scheduling `useFire` setup instructions, and
 *  - handles `fire(...)` calls seen while inside a useEffect lambda by retargeting
 *    the callee load to a generated binding and deleting the `fire` call.
 *
 * Problems are reported through `context.pushError`; this function does not throw
 * them itself. `fn` (and nested lowered functions) are mutated in place.
 */
function replaceFireFunctions(fn: HIRFunction, context: Context): void {
// Import specifier for the useFire runtime function; created lazily on first use.
let importedUseFire: NonLocalImportSpecifier | null = null;
// Set once any block had instructions inserted or deleted.
let hasRewrite = false;
for (const [, block] of fn.body.blocks) {
// Instructions to insert immediately before the instruction with the given id.
// NOTE(review): this map is per-block and only applied to this block's instructions;
// if the keyed LoadGlobal lives in a different block the insertion would be dropped
// silently — confirm that cannot happen.
const rewriteInstrs = new Map<InstructionId, Array<Instruction>>();
// Ids of instructions to remove from this block after the walk.
const deleteInstrs = new Set<InstructionId>();
for (const instr of block.instructions) {
const {value, lvalue} = instr;
if (
value.kind === 'CallExpression' &&
isUseEffectHookType(value.callee.identifier) &&
value.args.length > 0 &&
value.args[0].kind === 'Identifier'
) {
// A useEffect-typed call whose first argument is an identifier: resolve that
// identifier to a previously recorded function expression (the effect lambda).
const lambda = context.getFunctionExpression(
value.args[0].identifier.id,
);
if (lambda != null) {
// Visit the lambda in a fresh useEffect scope. The result maps each callee that
// was used with fire() inside the lambda to its generated binding.
const capturedCallees =
visitFunctionExpressionAndPropagateFireDependencies(
lambda,
context,
true,
);
// One shared list of setup instructions for all callees of this useEffect call.
const newInstrs = [];
for (const [
fireCalleePlace,
fireCalleeInfo,
] of capturedCallees.entries()) {
// Only emit the setup once per callee id across the whole transform.
if (!context.hasCalleeWithInsertedFire(fireCalleePlace)) {
context.addCalleeWithInsertedFire(fireCalleePlace);
importedUseFire ??= fn.env.programContext.addImportSpecifier({
source: fn.env.programContext.reactRuntimeModule,
importSpecifierName: USE_FIRE_FUNCTION_NAME,
});
// Build, in order: LoadGlobal of the useFire import, LoadLocal of the original
// callee, a call of the former with the latter, and a const StoreLocal of the
// call result into the generated fire function binding.
const loadUseFireInstr = makeLoadUseFireInstruction(
fn.env,
importedUseFire,
);
const loadFireCalleeInstr = makeLoadFireCalleeInstruction(
fn.env,
fireCalleeInfo.capturedCalleeIdentifier,
);
const callUseFireInstr = makeCallUseFireInstruction(
fn.env,
loadUseFireInstr.lvalue,
loadFireCalleeInstr.lvalue,
);
const storeUseFireInstr = makeStoreUseFireInstruction(
fn.env,
callUseFireInstr.lvalue,
fireCalleeInfo.fireFunctionBinding,
);
newInstrs.push(
loadUseFireInstr,
loadFireCalleeInstr,
callUseFireInstr,
storeUseFireInstr,
);
// The setup is anchored on the LoadGlobal instruction that produced the
// useEffect callee, so it ends up before that load.
const loadUseEffectInstrId = context.getLoadGlobalInstrId(
value.callee.identifier.id,
);
if (loadUseEffectInstrId == null) {
context.pushError({
loc: value.loc,
description: null,
category: ErrorCategory.Invariant,
reason: '[InsertFire] No LoadGlobal found for useEffect call',
suggestions: null,
});
// Skips to the next captured callee (innermost loop).
continue;
}
rewriteInstrs.set(loadUseEffectInstrId, newInstrs);
}
}
// Report any place inside the lambda that still refers to a captured callee.
ensureNoRemainingCalleeCaptures(
lambda.loweredFunc.func,
context,
capturedCallees,
);
if (
value.args.length > 1 &&
value.args[1] != null &&
value.args[1].kind === 'Identifier'
) {
// Second argument (dependency array) is an identifier: it must resolve to a
// recorded array expression.
const depArray = value.args[1];
const depArrayExpression = context.getArrayExpression(
depArray.identifier.id,
);
if (depArrayExpression != null) {
for (const dependency of depArrayExpression.elements) {
if (dependency.kind === 'Identifier') {
const loadOfDependency = context.getLoadLocalInstr(
dependency.identifier.id,
);
if (loadOfDependency != null) {
// If this dependency loads a captured callee, make the load read the
// generated fire function binding instead (mutates the LoadLocal in place).
const replacedDepArrayItem = capturedCallees.get(
loadOfDependency.place.identifier.id,
);
if (replacedDepArrayItem != null) {
loadOfDependency.place =
replacedDepArrayItem.fireFunctionBinding;
}
}
}
}
} else {
context.pushError({
loc: value.args[1].loc,
description:
'You must use an array literal for an effect dependency array when that effect uses `fire()`',
category: ErrorCategory.Fire,
reason: CANNOT_COMPILE_FIRE,
suggestions: null,
});
}
} else if (value.args.length > 1 && value.args[1].kind === 'Spread') {
// A spread as the second argument is rejected with the same message.
context.pushError({
loc: value.args[1].place.loc,
description:
'You must use an array literal for an effect dependency array when that effect uses `fire()`',
category: ErrorCategory.Fire,
reason: CANNOT_COMPILE_FIRE,
suggestions: null,
});
}
}
} else if (
value.kind === 'CallExpression' &&
value.callee.identifier.type.kind === 'Function' &&
value.callee.identifier.type.shapeId === BuiltInFireId &&
context.inUseEffectLambda()
) {
// A call to `fire` while inside a useEffect lambda. Only the form with exactly
// one identifier argument that resolves to a recorded call expression is accepted.
if (value.args.length === 1 && value.args[0].kind === 'Identifier') {
const callExpr = context.getCallExpression(
value.args[0].identifier.id,
);
if (callExpr != null) {
// Find the LoadLocal that produced the inner call's callee.
const calleeId = callExpr.callee.identifier.id;
const loadLocal = context.getLoadLocalInstr(calleeId);
if (loadLocal == null) {
context.pushError({
loc: value.loc,
description: null,
category: ErrorCategory.Invariant,
reason:
'[InsertFire] No loadLocal found for fire call argument',
suggestions: null,
});
// Skips to the next instruction of the block.
continue;
}
const fireFunctionBinding =
context.getOrGenerateFireFunctionBinding(
loadLocal.place,
value.loc,
);
// Retarget the callee load to (a copy of) the generated binding place, then
// drop the `fire(...)` call instruction itself.
loadLocal.place = {...fireFunctionBinding};
deleteInstrs.add(instr.id);
} else {
// NOTE(review): the inline code span in this message is missing its closing
// backtick after `fire(fn(a,b))`; left untouched here because it is runtime text.
context.pushError({
loc: value.loc,
description:
'`fire()` can only receive a function call such as `fire(fn(a,b)). Method calls and other expressions are not allowed',
category: ErrorCategory.Fire,
reason: CANNOT_COMPILE_FIRE,
suggestions: null,
});
}
} else {
// Wrong arity or a spread argument: refine the message when the cause is known.
let description: string =
'fire() can only take in a single call expression as an argument';
if (value.args.length === 0) {
description += ' but received none';
} else if (value.args.length > 1) {
description += ' but received multiple arguments';
} else if (value.args[0].kind === 'Spread') {
description += ' but received a spread argument';
}
context.pushError({
loc: value.loc,
description,
category: ErrorCategory.Fire,
reason: CANNOT_COMPILE_FIRE,
suggestions: null,
});
}
} else if (value.kind === 'CallExpression') {
// Any other call: remember it so a later `fire(<this result>)` can resolve it.
context.addCallExpression(lvalue.identifier.id, value);
} else if (
value.kind === 'FunctionExpression' &&
context.inUseEffectLambda()
) {
// Function nested inside a useEffect lambda: visit it without opening a new
// useEffect scope so its fire() usages are attributed to the enclosing effect.
visitFunctionExpressionAndPropagateFireDependencies(
value,
context,
false,
);
} else if (value.kind === 'FunctionExpression') {
// Function outside any useEffect lambda: remember it as a potential effect lambda.
context.addFunctionExpression(lvalue.identifier.id, value);
} else if (value.kind === 'LoadLocal') {
context.addLoadLocalInstr(lvalue.identifier.id, value);
} else if (
value.kind === 'LoadGlobal' &&
value.binding.kind === 'ImportSpecifier' &&
value.binding.module === 'react' &&
value.binding.imported === 'fire' &&
context.inUseEffectLambda()
) {
// Load of `fire` imported from 'react' inside a useEffect lambda: remove it.
deleteInstrs.add(instr.id);
} else if (value.kind === 'LoadGlobal') {
// Remember which instruction loaded this global (used to anchor useFire setup).
context.addLoadGlobalInstrId(lvalue.identifier.id, instr.id);
} else if (value.kind === 'ArrayExpression') {
// Remember array literals so dependency arrays can be resolved later.
context.addArrayExpression(lvalue.identifier.id, value);
}
}
// Apply insertions first, then deletions, both against the ids seen during the walk.
block.instructions = rewriteInstructions(rewriteInstrs, block.instructions);
block.instructions = deleteInstructions(deleteInstrs, block.instructions);
if (rewriteInstrs.size > 0 || deleteInstrs.size > 0) {
hasRewrite = true;
fn.env.hasFireRewrite = true;
}
}
if (hasRewrite) {
// Inserted instructions are created with id 0; presumably this call renumbers the
// function's instructions — confirm against markInstructionIds.
markInstructionIds(fn.body);
}
}
/**
 * Runs `replaceFireFunctions` over the body of `fnExpr` inside the appropriate
 * context scope, then propagates the resulting callee replacements outward.
 *
 * After the visit, every entry of the lowered function's `context` list whose
 * identifier id appears in the captured-callee map is replaced by a copy of
 * the corresponding fire function binding. The captured callees are then
 * merged into the context's current capture map and returned.
 *
 * @param fnExpr - Function expression whose lowered function is visited.
 * @param context - Shared transform state.
 * @param enteringUseEffect - `true` to visit inside a fresh useEffect-lambda
 *   scope, `false` to visit inside the current scope.
 * @returns The callees captured while visiting `fnExpr`.
 */
function visitFunctionExpressionAndPropagateFireDependencies(
  fnExpr: FunctionExpression,
  context: Context,
  enteringUseEffect: boolean,
): FireCalleesToFireFunctionBinding {
  const loweredFn = fnExpr.loweredFunc.func;
  const visitBody = (): void => replaceFireFunctions(loweredFn, context);

  const capturedCallees = enteringUseEffect
    ? context.withUseEffectLambdaScope(visitBody)
    : context.withFunctionScope(visitBody);

  // Swap captured references to the original callees for the new bindings.
  loweredFn.context.forEach((capturedPlace, index) => {
    const replacement = capturedCallees.get(capturedPlace.identifier.id);
    if (replacement != null) {
      loweredFn.context[index] = {...replacement.fireFunctionBinding};
    }
  });

  context.mergeCalleesFromInnerScope(capturedCallees);
  return capturedCallees;
}
/**
 * Lazily yields the operand places of every instruction in `fn`.
 *
 * For `FunctionExpression` and `ObjectMethod` instructions the generator
 * recurses into the lowered function instead of yielding that instruction's
 * own operands; all other instructions contribute the places produced by
 * `eachInstructionOperand`.
 */
function* eachReachablePlace(fn: HIRFunction): Iterable<Place> {
  for (const [, block] of fn.body.blocks) {
    for (const instr of block.instructions) {
      const {value} = instr;
      switch (value.kind) {
        case 'FunctionExpression':
        case 'ObjectMethod': {
          yield* eachReachablePlace(value.loweredFunc.func);
          break;
        }
        default: {
          yield* eachInstructionOperand(instr);
          break;
        }
      }
    }
  }
}
/**
 * Pushes a `Fire` error for every reachable place in `fn` whose identifier id
 * is still a key of `capturedCallees`.
 *
 * The message names the callee (or `<unknown>` when its identifier has no
 * `named` name) and the line of the recorded `fire()` location.
 */
function ensureNoRemainingCalleeCaptures(
  fn: HIRFunction,
  context: Context,
  capturedCallees: FireCalleesToFireFunctionBinding,
): void {
  for (const place of eachReachablePlace(fn)) {
    const calleeInfo = capturedCallees.get(place.identifier.id);
    if (calleeInfo == null) {
      continue;
    }

    const identifierName = calleeInfo.capturedCalleeIdentifier.name;
    const calleeName =
      identifierName != null && identifierName.kind === 'named'
        ? identifierName.value
        : '<unknown>';

    const description =
      `All uses of ${calleeName} must be either used with a fire() call in ` +
      `this effect or not used with a fire() call at all. ${calleeName} was used with fire() on line ` +
      `${printSourceLocationLine(calleeInfo.fireLoc)} in this effect`;

    context.pushError({
      loc: place.loc,
      description,
      category: ErrorCategory.Fire,
      reason: CANNOT_COMPILE_FIRE,
      suggestions: null,
    });
  }
}
/**
 * Pushes a `Fire` error for every reachable place in `fn` whose identifier is
 * still typed as the built-in `fire` function (a `Function` type whose
 * `shapeId` is `BuiltInFireId`). The error is located at the identifier's
 * own `loc`.
 */
function ensureNoMoreFireUses(fn: HIRFunction, context: Context): void {
  for (const {identifier} of eachReachablePlace(fn)) {
    const identifierType = identifier.type;
    if (
      identifierType.kind !== 'Function' ||
      identifierType.shapeId !== BuiltInFireId
    ) {
      continue;
    }
    context.pushError({
      loc: identifier.loc,
      description: 'Cannot use `fire` outside of a useEffect function',
      category: ErrorCategory.Fire,
      reason: CANNOT_COMPILE_FIRE,
      suggestions: null,
    });
  }
}
/**
 * Builds a `LoadGlobal` instruction that loads the given import specifier
 * into a fresh temporary.
 *
 * The temporary's effect is `Effect.Read` and its identifier is typed as
 * `DefaultNonmutatingHook`. The instruction is created with id 0 and
 * `GeneratedSource` locations.
 */
function makeLoadUseFireInstruction(
  env: Environment,
  importedLoadUseFire: NonLocalImportSpecifier,
): Instruction {
  const target: Place = {
    ...createTemporaryPlace(env, GeneratedSource),
    effect: Effect.Read,
  };
  // The identifier object is shared by reference with the created place.
  target.identifier.type = DefaultNonmutatingHook;

  return {
    id: makeInstructionId(0),
    value: {
      kind: 'LoadGlobal',
      binding: {...importedLoadUseFire},
      loc: GeneratedSource,
    },
    lvalue: target,
    loc: GeneratedSource,
    effects: null,
  };
}
/**
 * Builds a `LoadLocal` instruction that reads `fireCalleeIdentifier` into a
 * fresh temporary.
 *
 * The loaded place is non-reactive, has `Effect.Unknown`, and takes its `loc`
 * from the identifier. The instruction is created with id 0 and
 * `GeneratedSource` locations.
 */
function makeLoadFireCalleeInstruction(
  env: Environment,
  fireCalleeIdentifier: Identifier,
): Instruction {
  const target = createTemporaryPlace(env, GeneratedSource);
  const source: Place = {
    kind: 'Identifier',
    identifier: fireCalleeIdentifier,
    reactive: false,
    effect: Effect.Unknown,
    loc: fireCalleeIdentifier.loc,
  };
  const load: LoadLocal = {
    kind: 'LoadLocal',
    loc: GeneratedSource,
    place: source,
  };

  return {
    id: makeInstructionId(0),
    value: load,
    lvalue: {...target},
    loc: GeneratedSource,
    effects: null,
  };
}
/**
 * Builds a `CallExpression` instruction that calls `useFirePlace` with
 * `argPlace` as its only argument and stores the result in a fresh temporary
 * whose effect is `Effect.Read`.
 *
 * The callee place is copied; the argument place is passed through as-is.
 * The instruction is created with id 0 and `GeneratedSource` locations.
 */
function makeCallUseFireInstruction(
  env: Environment,
  useFirePlace: Place,
  argPlace: Place,
): Instruction {
  const result: Place = {
    ...createTemporaryPlace(env, GeneratedSource),
    effect: Effect.Read,
  };

  return {
    id: makeInstructionId(0),
    value: {
      kind: 'CallExpression',
      callee: {...useFirePlace},
      args: [argPlace],
      loc: GeneratedSource,
    },
    lvalue: result,
    loc: GeneratedSource,
    effects: null,
  };
}
/**
 * Builds a `StoreLocal` instruction that declares `fireFunctionBindingPlace`
 * as a `const` initialised from `useFireCallResultPlace`.
 *
 * The binding's identifier is first passed through `promoteTemporary`. Both
 * places are copied into the instruction value; the instruction's own lvalue
 * is a fresh temporary. The instruction is created with id 0 and
 * `GeneratedSource` locations.
 */
function makeStoreUseFireInstruction(
  env: Environment,
  useFireCallResultPlace: Place,
  fireFunctionBindingPlace: Place,
): Instruction {
  promoteTemporary(fireFunctionBindingPlace.identifier);
  const storeResult = createTemporaryPlace(env, GeneratedSource);

  const declaration = {
    kind: InstructionKind.Const,
    place: {...fireFunctionBindingPlace},
  };

  return {
    id: makeInstructionId(0),
    value: {
      kind: 'StoreLocal',
      lvalue: declaration,
      value: {...useFireCallResultPlace},
      type: null,
      loc: GeneratedSource,
    },
    lvalue: storeResult,
    loc: GeneratedSource,
    effects: null,
  };
}
/** What is recorded for one callee that was used inside a `fire(...)` call. */
type FireCalleeInfo = {
  /** Generated place that replaces loads of the original callee. */
  fireFunctionBinding: Place;
  /** Identifier of the original callee. */
  capturedCalleeIdentifier: Identifier;
  /** Location of the `fire(...)` call that was recorded for this callee. */
  fireLoc: SourceLocation;
};

/** Maps the identifier id of an original callee to its recorded fire info. */
type FireCalleesToFireFunctionBinding = Map<IdentifierId, FireCalleeInfo>;
/**
 * Mutable state shared by the whole transform: accumulated errors, lookup
 * tables keyed by identifier id, and the scope-dependent capture map and
 * "inside a useEffect lambda" flag.
 */
class Context {
  readonly #env: Environment;
  readonly #errors: CompilerError = new CompilerError();

  // Lookup tables, each keyed by the id of the identifier an instruction wrote to.
  readonly #callExpressions = new Map<IdentifierId, CallExpression>();
  readonly #functionExpressions = new Map<IdentifierId, FunctionExpression>();
  readonly #loadLocals = new Map<IdentifierId, LoadLocal>();
  readonly #loadGlobalInstructionIds = new Map<IdentifierId, InstructionId>();
  readonly #arrayExpressions = new Map<IdentifierId, ArrayExpression>();

  // Original callee id -> generated fire function binding; lives for the whole transform.
  readonly #fireCalleesToFireFunctions: Map<IdentifierId, Place> = new Map();
  // Callee ids that already went through addCalleeWithInsertedFire.
  readonly #calleesWithInsertedFire = new Set<IdentifierId>();

  // Scope-dependent state; swapped out by withUseEffectLambdaScope.
  #capturedCalleeIdentifierIds: FireCalleesToFireFunctionBinding = new Map();
  #inUseEffectLambda = false;

  constructor(env: Environment) {
    this.#env = env;
  }

  /** Records an error to be thrown later by `throwIfErrorsFound`. */
  pushError(error: CompilerErrorDetailOptions): void {
    this.#errors.push(error);
  }

  /** Whether any error has been recorded so far. */
  hasErrors(): boolean {
    return this.#errors.hasAnyErrors();
  }

  /** Throws the accumulated errors object if it contains any error. */
  throwIfErrorsFound(): void {
    if (this.hasErrors()) {
      throw this.#errors;
    }
  }

  /**
   * Runs `fn` in the current scope and returns the capture map that is
   * current once `fn` has returned.
   */
  withFunctionScope(fn: () => void): FireCalleesToFireFunctionBinding {
    fn();
    return this.#capturedCalleeIdentifierIds;
  }

  /**
   * Runs `fn` with an empty capture map and the useEffect-lambda flag set,
   * then restores the previous map and flag. Returns the map that was filled
   * while `fn` ran.
   */
  withUseEffectLambdaScope(fn: () => void): FireCalleesToFireFunctionBinding {
    const outerCaptures = this.#capturedCalleeIdentifierIds;
    const outerFlag = this.#inUseEffectLambda;

    this.#capturedCalleeIdentifierIds = new Map();
    this.#inUseEffectLambda = true;
    const innerCaptures = this.withFunctionScope(fn);

    this.#capturedCalleeIdentifierIds = outerCaptures;
    this.#inUseEffectLambda = outerFlag;
    return innerCaptures;
  }

  /** Whether the walk is currently inside a useEffect lambda scope. */
  inUseEffectLambda(): boolean {
    return this.#inUseEffectLambda;
  }

  addCallExpression(id: IdentifierId, callExpr: CallExpression): void {
    this.#callExpressions.set(id, callExpr);
  }

  getCallExpression(id: IdentifierId): CallExpression | undefined {
    return this.#callExpressions.get(id);
  }

  addLoadLocalInstr(id: IdentifierId, loadLocal: LoadLocal): void {
    this.#loadLocals.set(id, loadLocal);
  }

  getLoadLocalInstr(id: IdentifierId): LoadLocal | undefined {
    return this.#loadLocals.get(id);
  }

  addFunctionExpression(id: IdentifierId, fn: FunctionExpression): void {
    this.#functionExpressions.set(id, fn);
  }

  getFunctionExpression(id: IdentifierId): FunctionExpression | undefined {
    return this.#functionExpressions.get(id);
  }

  addLoadGlobalInstrId(id: IdentifierId, instrId: InstructionId): void {
    this.#loadGlobalInstructionIds.set(id, instrId);
  }

  getLoadGlobalInstrId(id: IdentifierId): InstructionId | undefined {
    return this.#loadGlobalInstructionIds.get(id);
  }

  addArrayExpression(id: IdentifierId, array: ArrayExpression): void {
    this.#arrayExpressions.set(id, array);
  }

  getArrayExpression(id: IdentifierId): ArrayExpression | undefined {
    return this.#arrayExpressions.get(id);
  }

  /**
   * Returns the fire function binding for `callee`, creating a temporary
   * place for it on first request. Every call (re)assigns the binding's
   * identifier type and (re)records the callee, with `fireLoc`, in the
   * current scope's capture map.
   */
  getOrGenerateFireFunctionBinding(
    callee: Place,
    fireLoc: SourceLocation,
  ): Place {
    const calleeId = callee.identifier.id;
    const binding = getOrInsertWith(
      this.#fireCalleesToFireFunctions,
      calleeId,
      () => createTemporaryPlace(this.#env, GeneratedSource),
    );

    binding.identifier.type = {
      kind: 'Function',
      shapeId: BuiltInFireFunctionId,
      return: {kind: 'Poly'},
      isConstructor: false,
    };

    this.#capturedCalleeIdentifierIds.set(calleeId, {
      fireFunctionBinding: binding,
      capturedCalleeIdentifier: callee.identifier,
      fireLoc,
    });
    return binding;
  }

  /** Copies every entry of `innerCallees` into the current capture map. */
  mergeCalleesFromInnerScope(
    innerCallees: FireCalleesToFireFunctionBinding,
  ): void {
    innerCallees.forEach((calleeInfo, id) => {
      this.#capturedCalleeIdentifierIds.set(id, calleeInfo);
    });
  }

  addCalleeWithInsertedFire(id: IdentifierId): void {
    this.#calleesWithInsertedFire.add(id);
  }

  hasCalleeWithInsertedFire(id: IdentifierId): boolean {
    return this.#calleesWithInsertedFire.has(id);
  }
}
/**
 * Returns `instructions` without the entries whose `id` is in `deleteInstrs`.
 *
 * When `deleteInstrs` is empty the input array itself is returned; otherwise
 * a new array is returned and the input is left untouched.
 */
function deleteInstructions(
  deleteInstrs: Set<InstructionId>,
  instructions: Array<Instruction>,
): Array<Instruction> {
  if (deleteInstrs.size === 0) {
    return instructions;
  }

  const kept: Array<Instruction> = [];
  for (const instr of instructions) {
    if (!deleteInstrs.has(instr.id)) {
      kept.push(instr);
    }
  }
  return kept;
}
/**
 * Returns `instructions` with, for every instruction whose `id` is a key of
 * `rewriteInstrs`, the mapped instructions inserted immediately before it.
 *
 * When `rewriteInstrs` is empty the input array itself is returned; otherwise
 * a new array is returned. Keys that match no instruction insert nothing.
 */
function rewriteInstructions(
  rewriteInstrs: Map<InstructionId, Array<Instruction>>,
  instructions: Array<Instruction>,
): Array<Instruction> {
  if (rewriteInstrs.size === 0) {
    return instructions;
  }

  const result: Array<Instruction> = [];
  for (const instr of instructions) {
    const toInsertBefore = rewriteInstrs.get(instr.id);
    if (toInsertBefore != null) {
      for (const inserted of toInsertBefore) {
        result.push(inserted);
      }
    }
    result.push(instr);
  }
  return result;
}