lyra-engine/lyra-reflect/lyra-reflect-derive/src/enum_derive.rs

608 lines
17 KiB
Rust
Raw Normal View History

2023-12-30 23:55:05 +00:00
use proc_macro2::Ident;
use quote::quote;
use syn::{token::Enum, DataEnum, DeriveInput, Generics, GenericParam, parse_quote, Variant};
use crate::add_trait_bounds;
/// The lowercase ASCII alphabet, in order.
///
/// Used to name the positional bindings of tuple-variant patterns (`a`, `b`, `c`, ...).
/// Only 26 names exist, so indexing past the end (a tuple variant with more than 26 fields)
/// panics.
static ASCII_LOWER: [char; 26] = [
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
];
/// The shape of an enum variant's field list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum VariantType {
    /// No fields at all, e.g. `TestEnum::Start`.
    Unit,
    /// Named fields, e.g. `TestEnum::Error { msg, code }`.
    Struct,
    /// Positional fields, e.g. `TestEnum::Middle(a, b)`.
    Tuple,
}
impl From<&Variant> for VariantType {
fn from(value: &Variant) -> Self {
match value.fields {
syn::Fields::Named(_) => VariantType::Struct,
syn::Fields::Unnamed(_) => VariantType::Tuple,
syn::Fields::Unit => VariantType::Unit,
}
}
}
/// Builds a pattern that matches `variant` and binds each of its fields to a local name.
///
/// ```ignore
/// // struct variants bind every field by its own name
/// TestEnum::Error { msg, code }
///
/// // tuple variants bind positional fields to `a`, `b`, `c`, ... (from `ASCII_LOWER`)
/// TestEnum::Middle(a, b)
///
/// // unit variants have nothing to bind
/// TestEnum::Start
/// ```
fn gen_variant_full_id(enum_id: &proc_macro2::Ident, variant: &Variant) -> proc_macro2::TokenStream {
    let var_id = &variant.ident;

    let bindings = match VariantType::from(variant) {
        VariantType::Unit => quote! {},
        VariantType::Struct => {
            let names = variant.fields.iter().map(|f| f.ident.as_ref().unwrap());
            quote! { { #( #names ),* } }
        }
        VariantType::Tuple => {
            // Positional fields have no identifier, so synthesize one per index.
            let names = (0..variant.fields.len()).map(|pos| {
                Ident::new(&ASCII_LOWER[pos].to_string(), proc_macro2::Span::call_site())
            });
            quote! { ( #( #names ),* ) }
        }
    };

    quote! {
        #enum_id::#var_id #bindings
    }
}
/// Wraps `if_body` in `if let <pattern> = self { ... }`.
///
/// `<pattern>` is the field-binding pattern built by `gen_variant_full_id`, so `if_body` may
/// refer to the variant's fields by their bound names. When `prepend_else` is `true`, the
/// statement is emitted as `else if let ...` so that it continues a preceding `if` chain.
fn gen_variant_if(enum_id: &proc_macro2::Ident, variant: &Variant, if_body: proc_macro2::TokenStream, prepend_else: bool) -> proc_macro2::TokenStream {
    let pattern = gen_variant_full_id(enum_id, variant);
    let if_let = quote! {
        if let #pattern = self {
            #if_body
        }
    };

    if prepend_else {
        quote! { else #if_let }
    } else {
        if_let
    }
}
/// Generates one `if` statement per named field of `variant`. Each statement returns that
/// field's binding when the generated code's `name` variable matches the field's name:
///
/// ```ignore
/// // one statement per field
/// if name == "msg" {
///     return Some(msg);
/// }
///
/// if name == "code" {
///     return Some(code);
/// }
/// ```
///
/// The derived `field`/`field_mut` methods lowercase `name` before running these checks.
/// The compared literal is therefore lowercased as well; otherwise a field with an
/// uppercase letter could never be found.
///
/// # Panics
/// Panics if a field has no identifier, i.e. if `variant` is not a struct variant.
fn gen_if_field_names(variant: &Variant) -> proc_macro2::TokenStream {
    let field_ifs = variant.fields.iter().map(|field| {
        let id = field
            .ident
            .as_ref()
            .expect("struct variant fields always have identifiers");
        // `Ident::to_string` always works. `Span::source_text` returns `None` when the span has
        // no backing source text (e.g. the enum was emitted by another macro), which used to
        // make this panic.
        let id_str = id.to_string().to_lowercase();
        quote! {
            if name == #id_str {
                return Some(#id);
            }
        }
    });

    quote! {
        #( #field_ifs )*
    }
}
/// Generates a statement that returns whether `name` is one of `variant`'s field names:
///
/// ```ignore
/// return match name {
///     "msg" | "code" => true,
///     _ => false,
/// };
/// ```
///
/// One string is added to the `true` arm per struct-variant field. Field names are lowercased
/// because the derived `has_field` lowercases `name` before this check runs. A struct variant
/// with no fields (`Variant {}`) yields `return false;`. Emitting the `match` in that case
/// would produce an arm with an empty pattern, which does not compile.
///
/// # Panics
/// Panics if a field has no identifier, i.e. if `variant` is not a struct variant.
fn gen_match_names(variant: &Variant) -> proc_macro2::TokenStream {
    let field_name_strs: Vec<String> = variant
        .fields
        .iter()
        .map(|field| {
            let id = field.ident.as_ref()
                .expect("Could not find identifier for enum field!");
            // Unlike `Span::source_text`, `Ident::to_string` never fails.
            id.to_string().to_lowercase()
        })
        .collect();

    if field_name_strs.is_empty() {
        return quote! {
            return false;
        };
    }

    quote! {
        return match name {
            #( #field_name_strs )|* => true,
            _ => false,
        };
    }
}
/// Generates one `if` statement per field of `variant`. Each statement returns the field's
/// binding when the generated code's `idx` variable equals the field's position:
///
/// ```ignore
/// if idx == 0 {
///     return Some(a);
/// }
///
/// if idx == 1 {
///     return Some(b);
/// }
/// ```
///
/// Tuple-variant fields are referred to as `a`, `b`, ... (from `ASCII_LOWER`). Struct-variant
/// fields are referred to by their own identifiers. These names must match the bindings
/// produced by `gen_variant_full_id`.
fn gen_if_field_indices(variant: &Variant) -> proc_macro2::TokenStream {
    let is_tuple = VariantType::from(variant) == VariantType::Tuple;
    let mut field_ifs = proc_macro2::TokenStream::new();

    for (idx, field) in variant.fields.iter().enumerate() {
        let binding = if is_tuple {
            Ident::new(&ASCII_LOWER[idx].to_string(), proc_macro2::Span::call_site())
        } else {
            field.ident.clone().unwrap()
        };

        field_ifs.extend(quote! {
            if idx == #idx {
                return Some(#binding);
            }
        });
    }

    field_ifs
}
/// Generates one `if` statement per field of `variant`. Each statement returns the field's
/// name as an owned `String` when the generated code's `idx` variable equals the field's
/// position:
///
/// ```ignore
/// if idx == 0 {
///     return Some("msg".to_string());
/// }
///
/// if idx == 1 {
///     return Some("code".to_string());
/// }
/// ```
///
/// Tuple-variant fields are named `a`, `b`, ... (from `ASCII_LOWER`). Struct-variant fields
/// use their identifier text.
fn gen_if_field_indices_names(variant: &Variant) -> proc_macro2::TokenStream {
    let vty = VariantType::from(variant);

    let field_ifs = variant.fields.iter().enumerate().map(|(idx, field)| {
        let field_name = match vty {
            VariantType::Tuple => ASCII_LOWER[idx].to_string(),
            _ => field.ident.as_ref().unwrap().to_string(),
        };
        quote! {
            if idx == #idx {
                return Some(#field_name.to_string());
            }
        }
    });

    quote! {
        #( #field_ifs )*
    }
}
/// Generates an `if let` / `else if let` chain over the variants of the enum. Each link
/// returns one of that variant's fields.
///
/// ```ignore
/// // when `by_index` is false (struct variants only; tuple variants have no field names):
/// if let TestEnum::Error { msg, code } = self {
///     if name == "msg" {
///         return Some(msg);
///     }
///     if name == "code" {
///         return Some(code);
///     }
/// }
///
/// // when `by_index` is true (struct and tuple variants):
/// if let TestEnum::Middle(a, b) = self {
///     if idx == 0 {
///         return Some(a);
///     }
///     if idx == 1 {
///         return Some(b);
///     }
/// } else if let TestEnum::Error { msg, code } = self {
///     if idx == 0 {
///         return Some(msg);
///     }
///     if idx == 1 {
///         return Some(code);
///     }
/// }
/// ```
///
/// Unit variants never produce a link because they have no fields.
///
/// Parameters
/// * `by_index`: generate checks against `idx` instead of against `name`.
fn gen_enum_if_stmts(enum_id: &proc_macro2::Ident, data: &DataEnum, by_index: bool) -> proc_macro2::TokenStream {
    let mut if_chain = proc_macro2::TokenStream::new();
    let mut emitted_any = false;

    for var in data.variants.iter() {
        let if_body = match (VariantType::from(var), by_index) {
            (VariantType::Struct, false) => gen_if_field_names(var),
            (VariantType::Struct, true) | (VariantType::Tuple, true) => gen_if_field_indices(var),
            // Tuple variants cannot be looked up by name, and unit variants have no fields.
            (VariantType::Tuple, false) | (VariantType::Unit, _) => continue,
        };

        // Variants are mutually exclusive, so every check after the first continues the same
        // `if` / `else if` chain.
        if_chain.extend(gen_variant_if(enum_id, var, if_body, emitted_any));
        emitted_any = true;
    }

    if_chain
}
/// Generates one `if let` statement per struct variant. The statement answers `has_field`
/// when `self` is that variant:
///
/// ```ignore
/// if let TestEnum::Error { .. } = self {
///     return match name {
///         // expands for every field of the variant
///         "msg" | "code" => true,
///         _ => false,
///     };
/// }
/// ```
///
/// The statement body comes from `gen_match_names`. Tuple and unit variants produce no code.
fn gen_enum_has_field(enum_id: &proc_macro2::Ident, data: &DataEnum) -> proc_macro2::TokenStream {
    let struct_checks = data
        .variants
        .iter()
        .filter(|var| VariantType::from(*var) == VariantType::Struct)
        .map(|var| {
            let var_id = &var.ident;
            let match_name = gen_match_names(var);
            // Only the field *names* are compared, so the payload is ignored with `..` instead
            // of being bound to variables that would never be used.
            quote! {
                if let #enum_id::#var_id { .. } = self {
                    #match_name
                }
            }
        });

    quote! {
        #( #struct_checks )*
    }
}
/// Generates the following code:
///
/// ```rust
/// match self {
/// TestEnum::Start => 0,
/// TestEnum::Middle(a, b) => 2,
/// TestEnum::Error { msg, code } => 2,
/// }
/// ```
/// and so on for each variant of the enum
fn gen_enum_fields_len(enum_id: &proc_macro2::Ident, data: &DataEnum) -> proc_macro2::TokenStream {
let variant_arms = data.variants.iter().map(|var| {
let variant_ident = gen_variant_full_id(enum_id, var);
let field_len = var.fields.len();
quote! {
#variant_ident => #field_len
}
});
quote! {
match self {
#( #variant_arms ),*
}
}
}
/// Generates one `if let` statement per struct variant. The statement returns the name of
/// the field at position `idx`:
///
/// ```ignore
/// if let TestEnum::Error { .. } = self {
///     if idx == 0 {
///         return Some("msg".to_string());
///     }
///     if idx == 1 {
///         return Some("code".to_string());
///     }
/// }
/// ```
///
/// The inner `if` statements come from `gen_if_field_indices_names` and expand once per field.
/// Tuple and unit variants produce no code.
fn gen_enum_field_name_at(enum_id: &proc_macro2::Ident, data: &DataEnum) -> proc_macro2::TokenStream {
    let variant_ifs = data
        .variants
        .iter()
        .filter(|var| VariantType::from(*var) == VariantType::Struct)
        .map(|var| {
            let var_id = &var.ident;
            let name_ifs = gen_if_field_indices_names(var);
            // The body returns string literals only, so the fields do not need to be bound.
            quote! {
                if let #enum_id::#var_id { .. } = self {
                    #name_ifs
                }
            }
        });

    quote! {
        #( #variant_ifs )*
    }
}
/// Generates the following code:
/// ```rust
/// match self {
/// TestEnum::Start => 0,
/// TestEnum::Middle(a, b) => 1,
/// TestEnum::Error { msg, code } => 2,
/// }
/// ```
/// The match arms will expand for each variant the enum has.
fn gen_enum_variant_name(enum_id: &proc_macro2::Ident, data: &DataEnum, gen_index: bool) -> proc_macro2::TokenStream {
let variant_arms = data.variants.iter().enumerate().map(|(idx, var)| {
let variant_ident = gen_variant_full_id(enum_id, var);
let var_name = var.ident.to_string();
let arm_result = if gen_index {
quote! {
#idx
}
} else {
quote! {
#var_name.to_string()
}
};
quote! {
#variant_ident => #arm_result
}
});
quote! {
match self {
#( #variant_arms ),*
}
}
}
/// Generates a `match` expression that evaluates to the `EnumType` of the current variant.
///
/// Example:
/// ```ignore
/// match *self {
///     TestEnum::Start => lyra_engine::reflect::EnumType::Unit,
///     TestEnum::Middle(..) => lyra_engine::reflect::EnumType::Tuple,
///     TestEnum::Error { .. } => lyra_engine::reflect::EnumType::Struct,
/// }
/// ```
///
/// There is one arm per variant. `EnumType` is spelled with its full path, matching the
/// generated method's return type. This means the user's module does not need to import it.
fn gen_enum_variant_type(enum_id: &proc_macro2::Ident, data: &DataEnum) -> proc_macro2::TokenStream {
    let variant_arms = data.variants.iter().map(|var| {
        let var_id = &var.ident;
        // Only the variant matters, so ignore its payload instead of binding unused fields.
        let (pattern, kind) = match VariantType::from(var) {
            VariantType::Unit => (quote! { #enum_id::#var_id }, quote! { Unit }),
            VariantType::Tuple => (quote! { #enum_id::#var_id(..) }, quote! { Tuple }),
            VariantType::Struct => (quote! { #enum_id::#var_id { .. } }, quote! { Struct }),
        };
        quote! {
            #pattern => lyra_engine::reflect::EnumType::#kind
        }
    });

    // `*self` with non-binding patterns moves nothing, and it lets a variant-less enum expand to
    // the valid `match *self {}`.
    quote! {
        match *self {
            #( #variant_arms ),*
        }
    }
}
/// Create a reflect implementation for an enum.
///
/// Expands to impls of `lyra_engine::reflect::Reflect` and `lyra_engine::reflect::Enum` for the
/// enum described by `input` / `data_enum`. The `Enum` method bodies are assembled from the
/// `gen_enum_*` helpers in this file. The impl generics come from `add_trait_bounds`, which is
/// given `Reflect` and `Clone`. It presumably bounds every generic parameter by them
/// (defined elsewhere; not visible here).
pub fn derive_reflect_enum(input: &DeriveInput, data_enum: &DataEnum) -> proc_macro2::TokenStream {
    let type_path = &input.ident;
    // `Ident::to_string` always works. `Span::source_text` returns `None` when the enum's tokens
    // have no backing source text (e.g. the enum was emitted by another macro).
    let name = type_path.to_string();
    let variant_count = data_enum.variants.len();

    // Name-based lookups cover struct variants; index-based lookups cover struct and tuple
    // variants. The `_mut` variants share the same expansion: only the `self` borrow differs.
    let field_ifs = gen_enum_if_stmts(type_path, data_enum, false);
    let field_mut_ifs = gen_enum_if_stmts(type_path, data_enum, false);
    let field_at_ifs = gen_enum_if_stmts(type_path, data_enum, true);
    let field_at_mut_ifs = gen_enum_if_stmts(type_path, data_enum, true);
    let has_field = gen_enum_has_field(type_path, data_enum);
    let field_len = gen_enum_fields_len(type_path, data_enum);
    let field_name_at = gen_enum_field_name_at(type_path, data_enum);
    let variant_name_match = gen_enum_variant_name(type_path, data_enum, false);
    let variant_idx_match = gen_enum_variant_name(type_path, data_enum, true);
    let variant_type = gen_enum_variant_type(type_path, data_enum);

    // NOTE(review): the `Reflect` bound is an unqualified path, so generic enums need
    // `Reflect` in scope at the derive site. Consider `lyra_engine::reflect::Reflect` once
    // `add_trait_bounds`' parameter type is confirmed to accept a path.
    let generics = add_trait_bounds(input.generics.clone(), vec![parse_quote!(Reflect), parse_quote!(Clone)]);
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    quote! {
        impl #impl_generics lyra_engine::reflect::Reflect for #type_path #ty_generics #where_clause {
            fn name(&self) -> ::std::string::String {
                #name.to_string()
            }

            fn type_id(&self) -> std::any::TypeId {
                std::any::TypeId::of::<#type_path #ty_generics>()
            }

            fn as_any(&self) -> &dyn std::any::Any {
                self
            }

            fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
                self
            }

            fn apply(&mut self, val: &dyn lyra_engine::reflect::Reflect) {
                let val = val.as_any().downcast_ref::<Self>()
                    .expect("The type of `val` is not the same as `self`");
                *self = val.clone();
            }

            fn clone_inner(&self) -> Box<dyn lyra_engine::reflect::Reflect> {
                Box::new(self.clone())
            }

            fn reflect_ref(&self) -> lyra_engine::reflect::ReflectRef {
                lyra_engine::reflect::ReflectRef::Enum(self)
            }

            fn reflect_mut(&mut self) -> lyra_engine::reflect::ReflectMut {
                lyra_engine::reflect::ReflectMut::Enum(self)
            }

            fn reflect_val(&self) -> &dyn lyra_engine::reflect::Reflect {
                self
            }

            fn reflect_val_mut(&mut self) -> &mut dyn lyra_engine::reflect::Reflect {
                self
            }
        }

        impl #impl_generics lyra_engine::reflect::Enum for #type_path #ty_generics #where_clause {
            fn field(&self, name: &str) -> Option<&dyn lyra_engine::reflect::Reflect> {
                let name = name.to_lowercase();
                let name = name.as_str();

                #field_ifs

                None
            }

            fn field_mut(&mut self, name: &str) -> Option<&mut dyn lyra_engine::reflect::Reflect> {
                let name = name.to_lowercase();
                let name = name.as_str();

                #field_mut_ifs

                None
            }

            fn field_at(&self, idx: usize) -> Option<&dyn lyra_engine::reflect::Reflect> {
                #field_at_ifs

                None
            }

            fn field_at_mut(&mut self, idx: usize) -> Option<&mut dyn lyra_engine::reflect::Reflect> {
                #field_at_mut_ifs

                None
            }

            fn has_field(&self, name: &str) -> bool {
                let name = name.to_lowercase();
                let name = name.as_str();

                #has_field

                false
            }

            fn fields_len(&self) -> usize {
                #field_len
            }

            fn variants_len(&self) -> usize {
                #variant_count
            }

            fn field_name_at(&self, idx: usize) -> Option<String> {
                #field_name_at

                None
            }

            fn variant_name(&self) -> String {
                #variant_name_match
            }

            fn variant_index(&self) -> usize {
                #variant_idx_match
            }

            fn variant_type(&self) -> lyra_engine::reflect::EnumType {
                #variant_type
            }

            fn is_variant_name(&self, name: &str) -> bool {
                let name = name.to_lowercase();
                self.variant_name().to_lowercase() == name
            }
        }
    }
}