import {NodePath} from '@babel/core';
import * as t from '@babel/types';
import {CompilerError} from '../CompilerError';
import {GeneratedSource} from '../HIR';
import {ProgramContext} from './Imports';
import {ExternalFunction} from '..';
/**
 * Gates a function declaration that is referenced before its declaration
 * site. Instead of turning the declaration into a `const`, this keeps a
 * `function` declaration under the original name. That declaration forwards
 * its arguments to either the compiled or the original implementation.
 *
 * The resulting statement order is:
 *
 *   const <gating>_result = <gating>();
 *   function <name>_optimized(...) { ...compiled body... }
 *   function <name>_unoptimized(...) { ...original body... }
 *   function <name>(arg0, ..., ...argN) {
 *     if (<gating>_result) return <name>_optimized(arg0, ..., ...argN);
 *     else return <name>_unoptimized(arg0, ..., ...argN);
 *   }
 *
 * Generated names come from `programContext.newUid`, so the exact
 * identifiers may differ from the names shown above.
 *
 * NOTE(review): the gating result is stored in a `const` placed just before
 * the original declaration. A call to the wrapper that runs before that
 * `const` is initialized would read it in its temporal dead zone. Confirm
 * callers only invoke the function after this point has executed.
 *
 * NOTE(review): the original declaration's `id` is renamed in place. If it
 * sits inside an `export` declaration, that export would now name the
 * `_unoptimized` function. Confirm callers never route exported
 * declarations here.
 *
 * @param fnPath - path to the original (uncompiled) function declaration;
 *   its `id` is replaced with the `_unoptimized` name.
 * @param compiled - compiled version of the same function; its `id.name` is
 *   mutated to the `_optimized` name.
 * @param programContext - used to allocate unique identifier names.
 * @param gatingFunctionIdentifierName - local name of the gating function,
 *   which is called exactly once in the emitted `const` initializer.
 * @throws via `CompilerError.invariant` if either function is unnamed, or if
 *   the compiled function declares a different number of parameters than the
 *   original.
 */
function insertAdditionalFunctionDeclaration(
  fnPath: NodePath<t.FunctionDeclaration>,
  compiled: t.FunctionDeclaration,
  programContext: ProgramContext,
  gatingFunctionIdentifierName: string,
): void {
  const originalFnName = fnPath.node.id;
  const originalFnParams = fnPath.node.params;
  // The wrapper forwards exactly `originalFnParams.length` arguments to BOTH
  // implementations. The arity check must therefore look at the compiled
  // function's own params, not at the original's a second time.
  const compiledParams = compiled.params;
  CompilerError.invariant(originalFnName != null && compiled.id != null, {
    reason:
      'Expected function declarations that are referenced elsewhere to have a named identifier',
    loc: fnPath.node.loc ?? GeneratedSource,
  });
  CompilerError.invariant(originalFnParams.length === compiledParams.length, {
    reason:
      'Expected React Compiler optimized function declarations to have the same number of parameters as source',
    loc: fnPath.node.loc ?? GeneratedSource,
  });
  const gatingCondition = t.identifier(
    programContext.newUid(`${gatingFunctionIdentifierName}_result`),
  );
  const unoptimizedFnName = t.identifier(
    programContext.newUid(`${originalFnName.name}_unoptimized`),
  );
  const optimizedFnName = t.identifier(
    programContext.newUid(`${originalFnName.name}_optimized`),
  );
  // Only the name string is copied here; `optimizedFnName` itself stays
  // detached. That lets it be used below as the wrapper's callee without
  // sharing a node with `compiled`.
  compiled.id.name = optimizedFnName.name;
  fnPath.get('id').replaceInline(unoptimizedFnName);

  // Build positional forwarding parameters `arg0..argN`. A rest parameter is
  // forwarded with a spread. Every other pattern (defaults, destructuring) is
  // passed through as a plain identifier and handled by the callee.
  const newParams: Array<t.Identifier | t.RestElement> = [];
  const genNewArgs: Array<() => t.Identifier | t.SpreadElement> = [];
  for (let i = 0; i < originalFnParams.length; i++) {
    const argName = `arg${i}`;
    if (originalFnParams[i].type === 'RestElement') {
      newParams.push(t.restElement(t.identifier(argName)));
      genNewArgs.push(() => t.spreadElement(t.identifier(argName)));
    } else {
      newParams.push(t.identifier(argName));
      genNewArgs.push(() => t.identifier(argName));
    }
  }

  // The wrapper keeps the original name, so existing references still resolve.
  // Identifiers that also appear elsewhere in the AST are cloned, so that no
  // node object is attached at two locations.
  fnPath.insertAfter(
    t.functionDeclaration(
      originalFnName,
      newParams,
      t.blockStatement([
        t.ifStatement(
          t.cloneNode(gatingCondition),
          t.returnStatement(
            t.callExpression(
              optimizedFnName,
              genNewArgs.map(fn => fn()),
            ),
          ),
          t.returnStatement(
            t.callExpression(
              t.cloneNode(unoptimizedFnName),
              genNewArgs.map(fn => fn()),
            ),
          ),
        ),
      ]),
    ),
  );
  // Each insertBefore lands immediately before the original declaration. The
  // final order is therefore: const, compiled declaration, original (renamed)
  // declaration, wrapper.
  fnPath.insertBefore(
    t.variableDeclaration('const', [
      t.variableDeclarator(
        gatingCondition,
        t.callExpression(t.identifier(gatingFunctionIdentifierName), []),
      ),
    ]),
  );
  fnPath.insertBefore(compiled);
}
/**
 * Rewrites a function so that, at runtime, either its compiled version or its
 * original version is used, depending on the result of calling the gating
 * function.
 *
 * The gating function is registered as an import through
 * `programContext.addImportSpecifier`, and the returned local name is what
 * the emitted code calls. The exact rewrite depends on the function's shape
 * and location:
 *
 * - A `FunctionDeclaration` with `referencedBeforeDeclaration` set is
 *   delegated to `insertAdditionalFunctionDeclaration`. That helper keeps a
 *   `function` declaration under the original name.
 * - A named `FunctionDeclaration` that is not directly under
 *   `export default` becomes `const Name = gate() ? <compiled> : <original>;`.
 * - A named non-arrow function directly under `export default` becomes
 *   `const Name = gate() ? ... : ...;` followed by `export default Name;`.
 * - In every other case, the function node itself is replaced by the
 *   `gate() ? <compiled> : <original>` expression.
 *
 * @param fnPath - path to the original function.
 * @param compiled - compiled replacement for `fnPath.node`.
 * @param programContext - program-wide state used to add the gating import.
 * @param gating - the external gating function to import and call.
 * @param referencedBeforeDeclaration - whether the function is used before
 *   its declaration site; only consulted for function declarations.
 * @throws via `CompilerError.invariant` if the referenced-before-declaration
 *   path receives a compiled node that is not a `FunctionDeclaration`.
 */
export function insertGatedFunctionDeclaration(
  fnPath: NodePath<
    t.FunctionDeclaration | t.ArrowFunctionExpression | t.FunctionExpression
  >,
  compiled:
    | t.FunctionDeclaration
    | t.ArrowFunctionExpression
    | t.FunctionExpression,
  programContext: ProgramContext,
  gating: ExternalFunction,
  referencedBeforeDeclaration: boolean,
): void {
  const gatingLocalName = programContext.addImportSpecifier(gating).name;

  if (referencedBeforeDeclaration && fnPath.isFunctionDeclaration()) {
    CompilerError.invariant(compiled.type === 'FunctionDeclaration', {
      reason: 'Expected compiled node type to match input type',
      description: `Got ${compiled.type} but expected FunctionDeclaration`,
      loc: fnPath.node.loc ?? GeneratedSource,
    });
    insertAdditionalFunctionDeclaration(
      fnPath,
      compiled,
      programContext,
      gatingLocalName,
    );
    return;
  }

  const gateCall = t.callExpression(t.identifier(gatingLocalName), []);
  const optimizedBranch = buildFunctionExpression(compiled);
  const originalBranch = buildFunctionExpression(fnPath.node);
  const gatingExpression = t.conditionalExpression(
    gateCall,
    optimizedBranch,
    originalBranch,
  );
  const isDefaultExport =
    fnPath.parentPath.node.type === 'ExportDefaultDeclaration';

  if (!isDefaultExport) {
    if (fnPath.node.type === 'FunctionDeclaration' && fnPath.node.id != null) {
      // `function Foo() {}` -> `const Foo = gate() ? ... : ...;`
      fnPath.replaceWith(
        t.variableDeclaration('const', [
          t.variableDeclarator(fnPath.node.id, gatingExpression),
        ]),
      );
      return;
    }
  } else if (
    fnPath.node.type !== 'ArrowFunctionExpression' &&
    fnPath.node.id != null
  ) {
    // `export default function Foo() {}` ->
    //   `const Foo = gate() ? ... : ...;` + `export default Foo;`
    const reExport = t.exportDefaultDeclaration(
      t.identifier(fnPath.node.id.name),
    );
    fnPath.insertAfter(reExport);
    const binding = t.variableDeclarator(
      t.identifier(fnPath.node.id.name),
      gatingExpression,
    );
    fnPath.parentPath.replaceWith(t.variableDeclaration('const', [binding]));
    return;
  }

  // Arrow functions, function expressions outside `export default`, and
  // anonymous default exports: swap the conditional in for the node itself.
  fnPath.replaceWith(gatingExpression);
}
/**
 * Returns an expression-form function suitable for one branch of the gating
 * conditional.
 *
 * Arrow functions and function expressions are returned as-is (the very same
 * node object). A function declaration is converted into a
 * `FunctionExpression` that reuses the declaration's `id`, `params` and `body`
 * nodes. The conversion keeps its `async`/`generator` flags and source
 * location, plus any `typeParameters` and `returnType` annotations.
 *
 * @param node - the function to convert.
 * @returns an `ArrowFunctionExpression` or `FunctionExpression`.
 */
function buildFunctionExpression(
  node:
    | t.FunctionDeclaration
    | t.ArrowFunctionExpression
    | t.FunctionExpression,
): t.ArrowFunctionExpression | t.FunctionExpression {
  if (
    node.type === 'ArrowFunctionExpression' ||
    node.type === 'FunctionExpression'
  ) {
    return node;
  }
  const fn: t.FunctionExpression = {
    type: 'FunctionExpression',
    async: node.async,
    generator: node.generator,
    loc: node.loc ?? null,
    id: node.id ?? null,
    params: node.params,
    body: node.body,
    // `params` are reused verbatim, including their annotations. Dropping the
    // declaration's type parameters would leave e.g. `function f<T>(x: T): T`
    // with a `T` that is referenced but no longer declared.
    typeParameters: node.typeParameters ?? null,
    returnType: node.returnType ?? null,
  };
  return fn;
}