From 6465ed2e0b0d008b9e6dcdadc1c9a46011d6498c Mon Sep 17 00:00:00 2001 From: SeanOMik Date: Sat, 20 Jan 2024 11:46:42 -0500 Subject: [PATCH] scripting: support 'overloaded' like methods in wrapper macro --- .../lyra-scripting-derive/src/lib.rs | 382 ++++++++++-------- lyra-scripting/src/lua/wrappers/math.rs | 58 +-- 2 files changed, 225 insertions(+), 215 deletions(-) diff --git a/lyra-scripting/lyra-scripting-derive/src/lib.rs b/lyra-scripting/lyra-scripting-derive/src/lib.rs index ed6e7b1..beda599 100644 --- a/lyra-scripting/lyra-scripting-derive/src/lib.rs +++ b/lyra-scripting/lyra-scripting-derive/src/lib.rs @@ -1,6 +1,6 @@ use proc_macro2::{Ident, Span}; use quote::quote; -use syn::{parse_macro_input, Path, Token, token, parenthesized, punctuated::Punctuated, braced, bracketed}; +use syn::{parse_macro_input, Path, Token, token, parenthesized, punctuated::Punctuated, braced}; mod mat_wrapper; use mat_wrapper::MatWrapper; @@ -31,6 +31,221 @@ impl syn::parse::Parse for MetaMethod { } } +impl MetaMethod { + /// Returns a boolean if an identifier is a lua wrapper, and therefore also userdata + fn is_arg_wrapper(ident: &Ident) -> bool { + let s = ident.to_string(); + s.starts_with("Lua") + } + + /// returns the tokens of the body of the metamethod + /// + /// Parameters + /// * `metamethod` - The ident of the metamethod that is being implemented. + /// * `other` - The tokens of the argument used in the metamethod. + fn get_method_body(metamethod: &Ident, other: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let mm_str = metamethod.to_string(); + let mm_str = mm_str.as_str(); + match mm_str { + "Add" | "Sub" | "Div" | "Mul" | "Mod" => { + let symbol = match mm_str { + "Add" => quote!(+), + "Sub" => quote!(-), + "Div" => quote!(/), + "Mul" => quote!(*), + "Mod" => quote!(%), + _ => unreachable!(), // the string was just checked to be one of these + }; + + quote! { + Ok(Self(this.0 #symbol #other)) + } + }, + "Unm" => { + quote! { + Ok(Self(-this.0)) + } + }, + "Eq" => { + quote! { + Ok(this.0 == #other) + } + }, + "Shl" => { + quote! { + Ok(Self(this.0 << #other)) + } + } + "Shr" => { + quote! { + Ok(Self(this.0 >> #other)) + } + }, + "BAnd" | "BOr" | "BXor" => { + let symbol = match mm_str { + "BAnd" => { + quote!(&) + }, + "BOr" => { + quote!(|) + }, + "BXor" => { + quote!(^) + }, + _ => unreachable!() // the string was just checked to be one of these + }; + + quote! { + Ok(Self(this.0 #symbol #other)) + } + }, + "BNot" => { + quote! { + Ok(Self(!this.0)) + } + }, + "ToString" => { + quote! { + Ok(format!("{:?}", this.0)) + } + }, + _ => syn::Error::new_spanned(metamethod, + "unsupported auto implementation of metamethod").to_compile_error(), + } + } + + fn get_body_for_arg(mt_ident: &Ident, arg_ident: &Ident, arg_param: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let other: proc_macro2::TokenStream = if Self::is_arg_wrapper(arg_ident) { + // Lua wrappers must be dereferenced + quote! { + #arg_param.0 + } + } else { + quote! { + #arg_param + } + }; + Self::get_method_body(&mt_ident, other) + } + + pub fn to_tokens(&self, wrapper_ident: &Ident) -> proc_macro2::TokenStream { + let wrapped_str = &wrapper_ident.to_string()[3..]; // removes starting 'Lua' from name + let mt_ident = &self.ident; + let mt_lua_name = mt_ident.to_string().to_lowercase(); + + if self.mods.is_empty() { + let other = quote! { + v.0 + }; + let body = Self::get_method_body(&self.ident, other); + + quote! { + methods.add_meta_method(mlua::MetaMethod::#mt_ident, |_, this, (v,): (#wrapper_ident,)| { + #body + }); + } + } else if self.mods.len() == 1 { + let first = self.mods.iter().next().unwrap(); + let body = Self::get_body_for_arg(&self.ident, first, quote!(v)); + + quote! { + methods.add_meta_method(mlua::MetaMethod::#mt_ident, |_, this, (v,): (#first,)| { + #body + }); + } + } else { + // an optional match arm that matches mlua::Value:Number + let number_arm = { + let num_ident = self.mods.iter().find(|i| { + let is = i.to_string(); + let is = is.as_str(); + match is { + "u8" | "u16" | "u32" | "u64" | "u128" + | "i8" | "i16" | "i32" | "i64" | "i128" + | "f32" | "f64" => true, + _ => false, + } + }); + + if let Some(num_ident) = num_ident { + let body = Self::get_body_for_arg(&self.ident, num_ident, quote!(n as #num_ident)); + + quote! { + mlua::Value::Number(n) => { + #body + }, + } + } else { quote!() } + }; + + let userdata_arm = { + let wrappers: Vec<&Ident> = self.mods.iter() + .filter(|i| Self::is_arg_wrapper(i)) + .collect(); + + let if_statements = wrappers.iter().map(|i| { + let body = Self::get_method_body(&self.ident, quote!(other.0)); + + quote! { + if let Ok(other) = ud.borrow::<#i>() { + #body + } + } + + }); + + quote! { + mlua::Value::UserData(ud) => { + #(#if_statements else)* + // this is the body of the else statement + { + // try to get the name of the userdata for the error message + if let Ok(mt) = ud.get_metatable() { + if let Ok(name) = mt.get::("__name") { + return Err(mlua::Error::BadArgument { + to: Some(format!("{}.__{}", #wrapped_str, #mt_lua_name)), + pos: 2, + name: Some("rhs".to_string()), + cause: Arc::new(mlua::Error::RuntimeError( + format!("cannot multiply with unknown userdata named {}", name) + )) + }); + } + } + + Err(mlua::Error::BadArgument { + to: Some(format!("{}.__{}", #wrapped_str, #mt_lua_name)), + pos: 2, + name: Some("rhs".to_string()), + cause: Arc::new( + mlua::Error::runtime("cannot multiply with unknown userdata") + ) + }) + } + }, + } + }; + + quote! { + methods.add_meta_method(mlua::MetaMethod::#mt_ident, |_, this, (v,): (mlua::Value,)| { + match v { + #number_arm + #userdata_arm + _ => Err(mlua::Error::BadArgument { + to: Some(format!("{}.__{}", #wrapped_str, #mt_lua_name)), + pos: 2, + name: Some("rhs".to_string()), + cause: Arc::new( + mlua::Error::RuntimeError(format!("cannot multiply with {}", v.type_name())) + ) + }) + } + }); + } + } + } +} + pub(crate) struct VecWrapper { } @@ -341,6 +556,7 @@ pub fn wrap_math_vec_copy(input: proc_macro::TokenStream) -> proc_macro::TokenSt None } }; + // TODO: fix this so it doesn't cause a stack overflow /* let vec_wrapper_fields = vec_wrapper.as_ref().map(|vec| vec.to_field_tokens(&path, &wrapper_typename)); */ let vec_wrapper_fields: Option = None; @@ -389,169 +605,7 @@ pub fn wrap_math_vec_copy(input: proc_macro::TokenStream) -> proc_macro::TokenSt let meta_method_idents = { let idents = input.meta_method_idents.iter().map(|metamethod| { - let metamethod_ident = &metamethod.ident; - let mm_str = metamethod.ident.to_string(); - let mm_str = mm_str.as_str(); - match mm_str { - "Add" | "Sub" | "Div" | "Mul" | "Mod" => { - let symbol = match mm_str { - "Add" => quote!(+), - "Sub" => quote!(-), - "Div" => quote!(/), - "Mul" => quote!(*), - "Mod" => quote!(%), - _ => unreachable!(), - }; - - // create a temporary vec to chain with metamethod.mods. If no parameters - // were provided, add the wrapper to the list of parameters. - let t = if metamethod.mods.is_empty() { - vec![wrapper_typename.clone()] - } else { vec![] }; - - let mods = metamethod.mods.iter().chain(t.iter()).map(|param| { - let other = if param.to_string().starts_with("Lua") { - quote!(other.0) - } else { - quote!(other) - }; - - quote! { - methods.add_meta_method(mlua::MetaMethod::#metamethod_ident, - |_, this, (other,): (#param,)| { - Ok(#wrapper_typename(this.0 #symbol #other)) - }); - } - }); - - quote! { - #(#mods)* - } - }, - "Unm" => { - quote! { - methods.add_meta_method(mlua::MetaMethod::#metamethod_ident, |_, this, ()| { - Ok(#wrapper_typename(-this.0)) - }); - } - }, - // Eq meta method has a different implementation than the above methods. - "Eq" => { - quote! { - methods.add_meta_method(mlua::MetaMethod::#metamethod_ident, - |_, this, (other,): (#wrapper_typename,)| { - Ok(this.0 == other.0) - }); - } - }, - "Shl" => { - // create a temporary vec to chain with metamethod.mods. If no parameters - // were provided, add the wrapper to the list of parameters. - let t = if metamethod.mods.is_empty() { - vec![wrapper_typename.clone()] - } else { vec![] }; - - let mods = metamethod.mods.iter().chain(t.iter()).map(|param| { - let other = if param.to_string().starts_with("Lua") { - quote!(other.0) - } else { - quote!(other) - }; - - quote! { - methods.add_meta_method(mlua::MetaMethod::#metamethod_ident, - |_, this, (other,): (#param,)| { - Ok(#wrapper_typename(this.0 << #other)) - }); - } - }); - - quote! { - #(#mods)* - } - } - "Shr" => { - // create a temporary vec to chain with metamethod.mods. If no parameters - // were provided, add the wrapper to the list of parameters. - let t = if metamethod.mods.is_empty() { - vec![wrapper_typename.clone()] - } else { vec![] }; - - let mods = metamethod.mods.iter().chain(t.iter()).map(|param| { - let other = if param.to_string().starts_with("Lua") { - quote!(other.0) - } else { - quote!(other) - }; - - quote! { - methods.add_meta_method(mlua::MetaMethod::#metamethod_ident, - |_, this, (other,): (#param,)| { - Ok(#wrapper_typename(this.0 >> #other)) - }); - } - }); - - quote! { - #(#mods)* - } - }, - "BAnd" | "BOr" | "BXor" => { - let symbol = match mm_str { - "BAnd" => { - quote!(&) - }, - "BOr" => { - quote!(|) - }, - "BXor" => { - quote!(^) - }, - _ => unreachable!() // the string was just checked to be one of these - }; - - // create a temporary vec to chain with metamethod.mods. If no parameters - // were provided, add the wrapper to the list of parameters. - let t = if metamethod.mods.is_empty() { - vec![wrapper_typename.clone()] - } else { vec![] }; - - let mods = metamethod.mods.iter().chain(t.iter()).map(|param| { - let other = if param.to_string().starts_with("Lua") { - quote!(other.0) - } else { - quote!(other) - }; - - quote! { - methods.add_meta_method(mlua::MetaMethod::#metamethod_ident, - |_, this, (other,): (#param,)| { - Ok(#wrapper_typename(this.0 #symbol #other)) - }); - } - }); - - quote! { - #(#mods)* - } - }, - "BNot" => { - quote! { - methods.add_meta_method(mlua::MetaMethod::#metamethod_ident, |_, this, ()| { - Ok(#wrapper_typename(!this.0)) - }); - } - }, - "ToString" => { - quote! { - methods.add_meta_method(mlua::MetaMethod::ToString, |_, this, ()| { - Ok(format!("{:?}", this.0)) - }); - } - }, - _ => syn::Error::new_spanned(metamethod_ident, - "unsupported auto implementation of metamethod").to_compile_error(), - } + metamethod.to_tokens(&wrapper_typename) }); quote! { diff --git a/lyra-scripting/src/lua/wrappers/math.rs b/lyra-scripting/src/lua/wrappers/math.rs index a439fa1..58d8f4a 100644 --- a/lyra-scripting/src/lua/wrappers/math.rs +++ b/lyra-scripting/src/lua/wrappers/math.rs @@ -17,8 +17,7 @@ wrap_math_vec_copy!( Div(LuaVec2, f32), Mul(LuaVec2, f32), Mod(LuaVec2, f32), - Eq, - Unm + Eq, Unm, ToString ) ); wrap_math_vec_copy!( @@ -26,58 +25,15 @@ wrap_math_vec_copy!( derives(PartialEq), fields(x, y, z), metamethods( - Add(LuaVec3), + Add(LuaVec3, f32), Sub(LuaVec3, f32), Div(LuaVec3, f32), - //Mul(LuaVec3, f32), + Mul(LuaVec3, f32), Mod(LuaVec3, f32), - Eq, Unm, ToString, - ), - custom_methods { - methods.add_meta_method(mlua::MetaMethod::Mul, |_, this, (v,): (mlua::Value,)| { - match v { - mlua::Value::Number(n) => { - Ok(Self(this.0 * (n as f32))) - }, - mlua::Value::UserData(ud) => { - if let Ok(other_this) = ud.borrow::() { - Ok(Self(this.0 * other_this.0)) - } else { - if let Ok(mt) = ud.get_metatable() { - if let Ok(name) = mt.get::("__name") { - return Err(mlua::Error::BadArgument { - to: Some("LuaVec3.__mul".to_string()), - pos: 2, - name: Some("rhs".to_string()), - cause: Arc::new(mlua::Error::RuntimeError( - format!("cannot multiply with unknown userdata named {}", name) - )) - }); - } - } - - Err(mlua::Error::BadArgument { - to: Some("LuaVec3.__mul".to_string()), - pos: 2, - name: Some("rhs".to_string()), - cause: Arc::new( - mlua::Error::runtime("cannot multiply with unknown userdata") - ) - }) - } - }, - _ => Err(mlua::Error::BadArgument { - to: Some("LuaVec3.__mul".to_string()), - pos: 2, - name: Some("rhs".to_string()), - cause: Arc::new( - mlua::Error::RuntimeError(format!("cannot multiply with {}", v.type_name())) - ) - }) - } - }); - } + Eq, Unm, ToString + ) ); + /* wrap_math_vec_copy!( math::Vec3A, derives(PartialEq), @@ -492,7 +448,7 @@ wrap_math_vec_copy!( wrap_math_vec_copy!( math::Transform, - //derives(PartialEq), + derives(PartialEq), no_new, metamethods(ToString, Eq), custom_fields {