import {isValidIdentifier} from '@babel/types';
import {CompilerError} from '../CompilerError';
import {
GotoVariant,
HIRFunction,
IdentifierId,
Instruction,
InstructionValue,
LoadGlobal,
Phi,
Place,
Primitive,
assertConsistentIdentifiers,
assertTerminalSuccessorsExist,
makePropertyLiteral,
markInstructionIds,
markPredecessors,
mergeConsecutiveBlocks,
reversePostorderBlocks,
} from '../HIR';
import {
removeDeadDoWhileStatements,
removeUnnecessaryTryCatch,
removeUnreachableForUpdates,
} from '../HIR/HIRBuilder';
import {eliminateRedundantPhi} from '../SSA';
/**
 * Entry point of the constant propagation pass for `fn`.
 *
 * Starts from an empty table of known constants. The same table is shared with
 * any nested function expressions / object methods that are visited while
 * propagating.
 */
export function constantPropagation(fn: HIRFunction): void {
  const knownConstants: Constants = new Map();
  constantPropagationImpl(fn, knownConstants);
}
/**
 * Runs constant-propagation rounds over `fn` until a round no longer rewrites
 * any terminal.
 *
 * After every round that did change a terminal, the CFG is cleaned up and
 * re-validated with the HIR helpers below: block reordering, removal of
 * unreachable for-updates / dead do-while / unnecessary try-catch, instruction
 * renumbering, predecessor recomputation, pruning of phi operands whose
 * predecessor edge disappeared, redundant-phi elimination, block merging, and
 * consistency assertions.
 */
function constantPropagationImpl(fn: HIRFunction, constants: Constants): void {
  while (applyConstantPropagation(fn, constants)) {
    // These passes take the body object directly; `fn.body` is re-read on the
    // next round.
    const body = fn.body;
    reversePostorderBlocks(body);
    removeUnreachableForUpdates(body);
    removeDeadDoWhileStatements(body);
    removeUnnecessaryTryCatch(body);
    markInstructionIds(body);
    markPredecessors(body);

    // Predecessors were just recomputed, so phis may still reference edges
    // that no longer exist.
    pruneStalePhiOperands(fn);

    eliminateRedundantPhi(fn);
    mergeConsecutiveBlocks(fn);
    assertConsistentIdentifiers(fn);
    assertTerminalSuccessorsExist(fn);
  }
}

/**
 * Removes every phi operand whose predecessor block is no longer listed in the
 * owning block's `preds`.
 */
function pruneStalePhiOperands(fn: HIRFunction): void {
  for (const block of fn.body.blocks.values()) {
    for (const phi of block.phis) {
      for (const predecessor of phi.operands.keys()) {
        if (!block.preds.has(predecessor)) {
          // Deleting the current key during Map iteration is well-defined.
          phi.operands.delete(predecessor);
        }
      }
    }
  }
}
/**
 * Performs a single forward pass over the blocks of `fn`.
 *
 * For each block it:
 * - records a constant for every phi whose operands all agree;
 * - evaluates instructions, recording constants for their lvalues;
 * - replaces an `if` terminal whose test is a known primitive with a `goto` to
 *   the branch selected by the test's truthiness.
 *
 * @returns true when at least one terminal was rewritten.
 */
function applyConstantPropagation(
  fn: HIRFunction,
  constants: Constants,
): boolean {
  let terminalsChanged = false;

  for (const block of fn.body.blocks.values()) {
    for (const phi of block.phis) {
      const phiValue = evaluatePhi(phi, constants);
      if (phiValue !== null) {
        constants.set(phi.place.identifier.id, phiValue);
      }
    }

    // The last instruction of a `sequence` block is intentionally left
    // unevaluated. NOTE(review): presumably folding it would disturb the
    // ordering of the value the sequence produces; confirm.
    const {instructions} = block;
    const evaluatedCount =
      block.kind === 'sequence' ? instructions.length - 1 : instructions.length;
    for (let index = 0; index < evaluatedCount; index++) {
      const instr = instructions[index]!;
      const instrValue = evaluateInstruction(constants, instr);
      if (instrValue !== null) {
        constants.set(instr.lvalue.identifier.id, instrValue);
      }
    }

    const terminal = block.terminal;
    if (terminal.kind !== 'if') {
      continue;
    }
    const test = read(constants, terminal.test);
    if (test === null || test.kind !== 'Primitive') {
      continue;
    }
    terminalsChanged = true;
    block.terminal = {
      kind: 'goto',
      variant: GotoVariant.Break,
      block: test.value ? terminal.consequent : terminal.alternate,
      id: terminal.id,
      loc: terminal.loc,
    };
  }

  return terminalsChanged;
}
/**
 * Computes the constant value of a phi node.
 *
 * A constant is returned only when every operand already has a known constant
 * and all of them agree:
 * - they have the same kind;
 * - primitives hold the same value;
 * - globals refer to the same binding name.
 *
 * Returns null for a phi with no operands, or as soon as any operand is unknown
 * (for example, not yet evaluated) or disagrees with the others.
 *
 * Primitive values are compared with SameValue (`Object.is`) rather than `===`.
 * `===` treats `0` and `-0` as equal, which would fold a phi merging `0` and
 * `-0` to whichever operand came first; the difference is observable via e.g.
 * `1 / x`. `===` also treats `NaN` as unequal to itself, which needlessly
 * prevents folding a phi whose operands are all `NaN`.
 */
function evaluatePhi(phi: Phi, constants: Constants): Constant | null {
  let value: Constant | null = null;
  for (const [, operand] of phi.operands) {
    const operandValue = constants.get(operand.identifier.id) ?? null;
    // Unknown operand: the phi cannot be constant.
    if (operandValue === null) {
      return null;
    }

    // First operand: remember it and compare the rest against it.
    if (value === null) {
      value = operandValue;
      continue;
    }

    // Mixed constant kinds cannot be merged.
    if (operandValue.kind !== value.kind) {
      return null;
    }

    switch (operandValue.kind) {
      case 'Primitive': {
        CompilerError.invariant(value.kind === 'Primitive', {
          reason: 'value kind expected to be Primitive',
          loc: null,
          suggestions: null,
        });
        // SameValue: distinguishes 0 from -0 and treats NaN as equal to NaN.
        if (!Object.is(operandValue.value, value.value)) {
          return null;
        }
        break;
      }
      case 'LoadGlobal': {
        CompilerError.invariant(value.kind === 'LoadGlobal', {
          reason: 'value kind expected to be LoadGlobal',
          loc: null,
          suggestions: null,
        });
        // Different globals cannot be merged.
        if (operandValue.binding.name !== value.binding.name) {
          return null;
        }
        break;
      }
      default:
        return null;
    }
  }

  return value;
}
/**
 * Evaluates one instruction against the currently known constants.
 *
 * Besides computing the instruction's constant value, this may rewrite
 * `instr.value` in place:
 * - `ComputedLoad`/`ComputedStore` whose key is a known number, or a known
 *   string that is a valid identifier, become `PropertyLoad`/`PropertyStore`.
 * - These are replaced by the resulting `Primitive`/constant:
 *   - foldable unary and binary expressions;
 *   - `.length` loads on known strings;
 *   - template literals whose parts are all known;
 *   - `LoadLocal`s of known constants.
 *
 * Update expressions and `StoreLocal` additionally record the constant for the
 * variable they write directly in `constants`. Function expressions and object
 * methods are propagated recursively, sharing the same `constants` map.
 *
 * @returns the constant value produced by the instruction (the caller records
 *   it for the instruction's lvalue), or null when it is not a known constant.
 */
function evaluateInstruction(
  constants: Constants,
  instr: Instruction,
): Constant | null {
  const value = instr.value;
  switch (value.kind) {
    case 'Primitive':
    case 'LoadGlobal': {
      return value;
    }
    case 'ComputedLoad': {
      const key = staticPropertyKey(read(constants, value.property));
      if (key !== null) {
        const nextValue: InstructionValue = {
          kind: 'PropertyLoad',
          loc: value.loc,
          property: makePropertyLiteral(key),
          object: value.object,
        };
        instr.value = nextValue;
      }
      return null;
    }
    case 'ComputedStore': {
      const key = staticPropertyKey(read(constants, value.property));
      if (key !== null) {
        const nextValue: InstructionValue = {
          kind: 'PropertyStore',
          loc: value.loc,
          property: makePropertyLiteral(key),
          object: value.object,
          value: value.value,
        };
        instr.value = nextValue;
      }
      return null;
    }
    case 'PostfixUpdate': {
      const previous = read(constants, value.value);
      const previousNumber = numericValue(previous);
      if (previousNumber === null) {
        return null;
      }
      // The updated variable receives the new value...
      constants.set(value.lvalue.identifier.id, {
        kind: 'Primitive',
        value: applyUpdateOperation(value.operation, previousNumber),
        loc: value.loc,
      });
      // ...but a postfix expression evaluates to the value before the update.
      return previous;
    }
    case 'PrefixUpdate': {
      const previousNumber = numericValue(read(constants, value.value));
      if (previousNumber === null) {
        return null;
      }
      // A prefix expression evaluates to the updated value.
      const next: Primitive = {
        kind: 'Primitive',
        value: applyUpdateOperation(value.operation, previousNumber),
        loc: value.loc,
      };
      constants.set(value.lvalue.identifier.id, next);
      return next;
    }
    case 'UnaryExpression': {
      const operand = read(constants, value.value);
      if (operand === null || operand.kind !== 'Primitive') {
        return null;
      }
      let folded: Primitive['value'];
      if (value.operator === '!') {
        folded = !operand.value;
      } else if (value.operator === '-' && typeof operand.value === 'number') {
        folded = operand.value * -1;
      } else {
        return null;
      }
      const result: Primitive = {kind: 'Primitive', value: folded, loc: value.loc};
      instr.value = result;
      return result;
    }
    case 'BinaryExpression': {
      const lhs = read(constants, value.left);
      const rhs = read(constants, value.right);
      if (
        lhs === null ||
        rhs === null ||
        lhs.kind !== 'Primitive' ||
        rhs.kind !== 'Primitive'
      ) {
        return null;
      }
      const folded = foldBinaryOperator(value.operator, lhs.value, rhs.value);
      if (folded === null) {
        return null;
      }
      const result: Primitive = {kind: 'Primitive', value: folded, loc: value.loc};
      instr.value = result;
      return result;
    }
    case 'PropertyLoad': {
      const object = read(constants, value.object);
      if (
        object !== null &&
        object.kind === 'Primitive' &&
        typeof object.value === 'string' &&
        value.property === 'length'
      ) {
        const result: Primitive = {
          kind: 'Primitive',
          value: object.value.length,
          loc: value.loc,
        };
        instr.value = result;
        return result;
      }
      return null;
    }
    case 'TemplateLiteral': {
      // Every quasi needs a cooked string, including in the zero-expression
      // case. An undefined cooked value would otherwise be silently folded
      // to ''.
      const cooked: Array<string> = [];
      for (const quasi of value.quasis) {
        if (quasi.cooked === undefined) {
          return null;
        }
        cooked.push(quasi.cooked);
      }

      let folded: string;
      if (value.subexprs.length === 0) {
        folded = cooked.join('');
      } else {
        // A well-formed template has exactly one more quasi than expressions.
        if (value.subexprs.length !== cooked.length - 1) {
          return null;
        }
        folded = cooked[0]!;
        for (let i = 0; i < value.subexprs.length; i++) {
          const part = read(constants, value.subexprs[i]!);
          if (part === null || part.kind !== 'Primitive') {
            return null;
          }
          const partValue = part.value;
          // Only primitives whose string conversion is safe to perform at
          // compile time are folded; anything else (e.g. undefined) bails out.
          if (
            typeof partValue !== 'number' &&
            typeof partValue !== 'string' &&
            typeof partValue !== 'boolean' &&
            partValue !== null
          ) {
            return null;
          }
          // String(x) is the spec's ToString for these primitives, matching
          // how a template literal stringifies its substitutions at runtime.
          folded = folded.concat(String(partValue), cooked[i + 1]!);
        }
      }

      const result: Primitive = {kind: 'Primitive', value: folded, loc: value.loc};
      instr.value = result;
      return result;
    }
    case 'LoadLocal': {
      const placeValue = read(constants, value.place);
      if (placeValue !== null) {
        instr.value = placeValue;
      }
      return placeValue;
    }
    case 'StoreLocal': {
      const placeValue = read(constants, value.value);
      if (placeValue !== null) {
        constants.set(value.lvalue.place.identifier.id, placeValue);
      }
      return placeValue;
    }
    case 'ObjectMethod':
    case 'FunctionExpression': {
      constantPropagationImpl(value.loweredFunc.func, constants);
      return null;
    }
    default: {
      return null;
    }
  }
}

/**
 * Interprets a constant as a static property key: a number, or a string that is
 * a valid identifier. Returns null when it cannot be used as a static key.
 */
function staticPropertyKey(constant: Constant | null): string | number | null {
  if (constant === null || constant.kind !== 'Primitive') {
    return null;
  }
  const key = constant.value;
  if (typeof key === 'number') {
    return key;
  }
  if (typeof key === 'string' && isValidIdentifier(key)) {
    return key;
  }
  return null;
}

/** Returns the numeric value of a known numeric primitive constant, else null. */
function numericValue(constant: Constant | null): number | null {
  if (
    constant !== null &&
    constant.kind === 'Primitive' &&
    typeof constant.value === 'number'
  ) {
    return constant.value;
  }
  return null;
}

/** Applies an update operator (`++` increments, anything else decrements). */
function applyUpdateOperation(operation: string, previous: number): number {
  return operation === '++' ? previous + 1 : previous - 1;
}

/**
 * Folds a binary operator over two primitive operands using JS semantics.
 *
 * The folding rules are:
 * - Equality operators fold for any primitive operands.
 * - `+` also folds string + string.
 * - All other supported operators require two numbers.
 *
 * Returns null when the operator or operand types are not handled.
 */
function foldBinaryOperator(
  operator: string,
  lhs: Primitive['value'],
  rhs: Primitive['value'],
): number | string | boolean | null {
  switch (operator) {
    case '==':
      return lhs == rhs;
    case '===':
      return lhs === rhs;
    case '!=':
      return lhs != rhs;
    case '!==':
      return lhs !== rhs;
    default:
      break;
  }
  if (typeof lhs === 'string' && typeof rhs === 'string') {
    return operator === '+' ? lhs + rhs : null;
  }
  if (typeof lhs !== 'number' || typeof rhs !== 'number') {
    return null;
  }
  switch (operator) {
    case '+':
      return lhs + rhs;
    case '-':
      return lhs - rhs;
    case '*':
      return lhs * rhs;
    case '/':
      return lhs / rhs;
    case '%':
      return lhs % rhs;
    case '**':
      return lhs ** rhs;
    case '|':
      return lhs | rhs;
    case '&':
      return lhs & rhs;
    case '^':
      return lhs ^ rhs;
    case '<<':
      return lhs << rhs;
    case '>>':
      return lhs >> rhs;
    case '>>>':
      return lhs >>> rhs;
    case '<':
      return lhs < rhs;
    case '<=':
      return lhs <= rhs;
    case '>':
      return lhs > rhs;
    case '>=':
      return lhs >= rhs;
    default:
      return null;
  }
}
/**
 * Looks up the known constant for the identifier referenced by `place`.
 *
 * @returns the recorded constant, or null if none has been recorded.
 */
function read(constants: Constants, place: Place): Constant | null {
  const found = constants.get(place.identifier.id);
  return found === undefined ? null : found;
}
/** A value known at compile time: a global reference or a primitive literal. */
type Constant = LoadGlobal | Primitive;
/** Known constants, keyed by the id of the identifier that holds them. */
type Constants = Map<IdentifierId, Constant>;