use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse_macro_input, punctuated::Punctuated, token::Comma, Attribute, Data, DataEnum,
DeriveInput, Expr, ExprLit, ExprRange, Ident, Lit, RangeLimits, Result, Variant,
};
/// Derives a byte-classification lookup table for an enum.
///
/// Each variant may carry `#[bytes(...)]` and/or `#[bytes_range(...)]` attributes listing the
/// bytes it claims. Exactly one variant must be marked `#[fallback]`; it is used for every byte
/// that no variant claims. When several variants claim the same byte, the variant declared
/// last wins.
///
/// The expansion adds an associated `TABLE: [Enum; 256]` constant plus `From<u8>` and
/// `From<&u8>` impls that index into it. Those impls return table entries by value, so the
/// enum is expected to be `Copy`.
///
/// Applying the derive to a non-enum, marking more than one variant `#[fallback]`, or marking
/// none at all produces a compile error spanned at the offending item.
#[proc_macro_derive(ClassifyBytes, attributes(bytes, bytes_range, fallback))]
pub fn classify_bytes_derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let Data::Enum(DataEnum { variants, .. }) = &ast.data else {
        return syn::Error::new_spanned(
            &ast.ident,
            "ClassifyBytes can only be derived on an enum.",
        )
        .to_compile_error()
        .into();
    };
    let enum_name = &ast.ident;

    // One slot per possible byte value; `None` means "not claimed by any variant".
    let mut byte_map: [Option<Ident>; 256] = [const { None }; 256];
    let mut fallback_variant: Option<Ident> = None;

    for variant in variants {
        let variant_ident = &variant.ident;

        if has_fallback_attr(variant) {
            if fallback_variant.is_some() {
                return syn::Error::new_spanned(
                    variant_ident,
                    "Multiple variants have #[fallback]. Only one allowed.",
                )
                .to_compile_error()
                .into();
            }
            fallback_variant = Some(variant_ident.clone());
        }

        // NOTE: both helpers panic on a malformed attribute rather than returning an error.
        let single_bytes = get_bytes_attrs(&variant.attrs);
        let range_bytes = get_bytes_range_attrs(&variant.attrs);
        for b in single_bytes.into_iter().chain(range_bytes) {
            // Later claims overwrite earlier ones for the same byte.
            byte_map[usize::from(b)] = Some(variant_ident.clone());
        }
    }

    // Report a missing fallback as a regular diagnostic instead of panicking inside the macro.
    let Some(fallback_ident) = fallback_variant else {
        return syn::Error::new_spanned(
            enum_name,
            "A variant marked with #[fallback] is missing",
        )
        .to_compile_error()
        .into();
    };

    // Emit one table entry per byte, substituting the fallback for unclaimed slots.
    let fill = byte_map.iter().map(|slot| {
        let ident = slot.as_ref().unwrap_or(&fallback_ident);
        quote!(#enum_name::#ident)
    });

    let expanded = quote! {
        impl #enum_name {
            pub const TABLE: [#enum_name; 256] = [
                #(#fill),*
            ];
        }
        impl From<u8> for #enum_name {
            fn from(byte: u8) -> Self {
                #enum_name::TABLE[byte as usize]
            }
        }
        impl From<&u8> for #enum_name {
            fn from(byte: &u8) -> Self {
                #enum_name::TABLE[*byte as usize]
            }
        }
    };
    TokenStream::from(expanded)
}
/// Reports whether `variant` carries a `#[fallback]` attribute.
fn has_fallback_attr(variant: &Variant) -> bool {
    for attribute in &variant.attrs {
        if attribute.path().is_ident("fallback") {
            return true;
        }
    }
    false
}
/// Gathers every byte listed across all `#[bytes(...)]` attributes in `attrs`.
///
/// Attributes with any other path are skipped. Bytes are returned in attribute order.
///
/// # Panics
///
/// Panics if a `#[bytes(...)]` attribute fails to parse.
fn get_bytes_attrs(attrs: &[Attribute]) -> Vec<u8> {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("bytes"))
        .flat_map(|attr| {
            parse_bytes_attr(attr)
                .unwrap_or_else(|e| panic!("Error parsing #[bytes(...)]: {}", e))
        })
        .collect()
}
fn parse_bytes_attr(attr: &Attribute) -> Result<Vec<u8>> {
let items: Punctuated<Lit, Comma> = attr.parse_args_with(Punctuated::parse_terminated)?;
let mut out = Vec::new();
for lit in items {
match lit {
Lit::Byte(lb) => out.push(lb.value()),
_ => {
return Err(syn::Error::new_spanned(
lit,
"Expected a byte literal like b'a'",
))
}
}
}
Ok(out)
}
/// Gathers every byte covered by all `#[bytes_range(...)]` attributes in `attrs`.
///
/// Attributes with any other path are skipped. Bytes are returned in attribute order.
///
/// # Panics
///
/// Panics if a `#[bytes_range(...)]` attribute fails to parse.
fn get_bytes_range_attrs(attrs: &[Attribute]) -> Vec<u8> {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("bytes_range"))
        .flat_map(|attr| {
            parse_bytes_range_attr(attr)
                .unwrap_or_else(|e| panic!("Error parsing #[bytes_range(...)]: {}", e))
        })
        .collect()
}
/// Parses the arguments of one `#[bytes_range(...)]` attribute and expands them to bytes.
///
/// The arguments must be a comma-separated list of range expressions with both bounds
/// present, e.g. `b'a'..=b'z'` or `0x80..0xC0`. Each bound is read with
/// `extract_byte_literal`. An empty or reversed range contributes no bytes.
///
/// # Errors
///
/// Returns an error if the argument list does not parse, if an argument is not a range with
/// both a start and an end, or if a bound is not an acceptable byte literal.
fn parse_bytes_range_attr(attr: &Attribute) -> Result<Vec<u8>> {
    let ranges: Punctuated<Expr, Comma> = attr.parse_args_with(Punctuated::parse_terminated)?;
    let mut bytes = Vec::new();
    for range_expr in ranges {
        match range_expr {
            Expr::Range(ExprRange {
                start: Some(lower),
                end: Some(upper),
                limits,
                ..
            }) => {
                let first = extract_byte_literal(&lower)?;
                let last = extract_byte_literal(&upper)?;
                // Ranges whose start lies past their end iterate zero times, so reversed or
                // empty bounds add nothing without any explicit check.
                match limits {
                    RangeLimits::HalfOpen(_) => bytes.extend(first..last),
                    RangeLimits::Closed(_) => bytes.extend(first..=last),
                }
            }
            other => {
                return Err(syn::Error::new_spanned(
                    other,
                    "Expected a byte range like b'a'..=b'z'",
                ));
            }
        }
    }
    Ok(bytes)
}
fn extract_byte_literal(expr: &Expr) -> Result<u8> {
if let Expr::Lit(ExprLit { lit, .. }) = expr {
match lit {
Lit::Byte(lb) => Ok(lb.value()),
Lit::Int(li) => {
let value = li.base10_parse::<u64>()?;
if value <= 255 {
Ok(value as u8)
} else {
Err(syn::Error::new_spanned(
li,
format!("Integer literal {} out of range for a byte (0..255)", value),
))
}
}
_ => Err(syn::Error::new_spanned(
lit,
"Expected b'...' or an integer literal in range 0..=255",
)),
}
} else {
Err(syn::Error::new_spanned(
expr,
"Expected a literal expression like b'a' or 0x80",
))
}
}