From 71199bc90513d0aa67ac6de6d329674ff88509fe Mon Sep 17 00:00:00 2001 From: SeanOMik Date: Thu, 25 Jan 2024 23:47:30 -0500 Subject: [PATCH] Implement calling rust functions from lua, make the Table struct support metatables --- src/function.rs | 37 +++++++++----- src/guard.rs | 46 +++++++++++++++++ src/main.rs | 132 ++++++++++++++++++++++++++++++++++++++++++------ src/state.rs | 114 +++++++++++++++++++++++++++++++++++++++-- src/table.rs | 131 ++++++++++++++++++++++++++++++++++++++++++----- 5 files changed, 415 insertions(+), 45 deletions(-) create mode 100644 src/guard.rs diff --git a/src/function.rs b/src/function.rs index 1bdc7d4..d83eff1 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,6 +1,6 @@ use std::ffi::CStr; -use crate::{LuaRef, FromLuaStack, State, PushToLuaStack}; +use crate::{FromLuaStack, LuaRef, PushToLuaStack, StackGuard, State}; use mlua_sys as lua; @@ -25,12 +25,21 @@ impl<'a> PushToLuaStack for Function<'a> { } impl<'a> Function<'a> { + pub fn from_ref(state: &'a State, lref: LuaRef) -> Self { + Self { + state, + lref, + } + } + pub fn exec(&self, args: A) -> crate::Result where - A: FunctionArgs, - R: FunctionResults<'a>, + A: PushToLuaStackMulti<'a>, + R: FromLuaStackMulti<'a>, { unsafe { + let _g = StackGuard::new(self.state); + self.push_to_lua_stack(self.state)?; args.push_args_to_lua_stack(self.state)?; @@ -50,12 +59,12 @@ impl<'a> Function<'a> { } } -pub trait FunctionArgs { +pub trait PushToLuaStackMulti<'a> { fn len(&self) -> usize; fn push_args_to_lua_stack(&self, state: &State) -> crate::Result<()>; } -impl<'a, T> FunctionArgs for T +impl<'a, T> PushToLuaStackMulti<'a> for T where T: PushToLuaStack, { @@ -72,14 +81,14 @@ where } } -pub trait FunctionResults<'a>: Sized { +pub trait FromLuaStackMulti<'a>: Sized { fn len() -> usize; fn results_from_lua_stack(state: &'a State) -> crate::Result; } -impl<'a> FunctionResults<'a> for () { +impl<'a> FromLuaStackMulti<'a> for () { fn len() -> usize { - 0 + 1 } fn results_from_lua_stack(_state: &'a State) -> crate::Result { @@ -87,7 +96,7 @@ impl<'a> FunctionResults<'a> for () { } } -impl<'a, T: FromLuaStack<'a>> FunctionResults<'a> for T { +impl<'a, T: FromLuaStack<'a>> FromLuaStackMulti<'a> for T { fn len() -> usize { 1 } @@ -100,7 +109,7 @@ impl<'a, T: FromLuaStack<'a>> FunctionResults<'a> for T { macro_rules! impl_function_arg_tuple { ( $count: expr, $first: tt, $( $name: tt ),+ ) => ( #[allow(non_snake_case)] - impl<$first: PushToLuaStack, $($name: PushToLuaStack,)+> FunctionArgs for ($first, $($name,)+) { + impl<'a, $first: PushToLuaStack, $($name: PushToLuaStack,)+> PushToLuaStackMulti<'a> for ($first, $($name,)+) { fn len(&self) -> usize { // this will end up generating $count - 1 - 1 - 1... hopefully the compiler will // optimize that out @@ -119,7 +128,7 @@ macro_rules! impl_function_arg_tuple { } } - impl<'a, $first: FromLuaStack<'a>, $($name: FromLuaStack<'a>,)+> FunctionResults<'a> for ($first, $($name,)+) { + impl<'a, $first: FromLuaStack<'a>, $($name: FromLuaStack<'a>,)+> FromLuaStackMulti<'a> for ($first, $($name,)+) { fn len() -> usize { $count } @@ -132,10 +141,10 @@ macro_rules! impl_function_arg_tuple { impl_function_arg_tuple!( $count - 1, $( $name ),+ ); ); - // implements FunctionArgs and FunctionResults for a tuple with a single element + // implements PushToLuaStackMulti and FromLuaStackMulti for a tuple with a single element ( $count: expr, $only: tt ) => { #[allow(non_snake_case)] - impl<$only: PushToLuaStack> FunctionArgs for ($only,) { + impl<'a, $only: PushToLuaStack> PushToLuaStackMulti<'a> for ($only,) { fn len(&self) -> usize { 1 } @@ -151,7 +160,7 @@ macro_rules! impl_function_arg_tuple { } } - impl<'a, $only: FromLuaStack<'a>> FunctionResults<'a> for ($only,) { + impl<'a, $only: FromLuaStack<'a>> FromLuaStackMulti<'a> for ($only,) { fn len() -> usize { 1 } diff --git a/src/guard.rs b/src/guard.rs new file mode 100644 index 0000000..291bbf4 --- /dev/null +++ b/src/guard.rs @@ -0,0 +1,46 @@ +use crate::State; + +use mlua_sys as lua; + +/// A stack guard will protect the LuaStack from leaks. +/// +/// When its first created, it keeps note of how large the stack is and when its dropped, +/// it will pop all new values from the top of the stack. +pub struct StackGuard<'a> { + state: &'a State, + top: i32, +} + +impl<'a> Drop for StackGuard<'a> { + fn drop(&mut self) { + unsafe { + let s = self.state.state_ptr(); + let now_top = lua::lua_gettop(s); + + if now_top > self.top { + lua::lua_pop(s, now_top - self.top); + } + } + } +} + +impl<'a> StackGuard<'a> { + pub fn new(state: &'a State) -> Self { + let top = unsafe { lua::lua_gettop(state.state_ptr()) }; + + Self { + state, + top + } + } +} + +pub unsafe fn lua_error_guard(lua: &State, func: F) -> R +where + F: Fn() -> crate::Result +{ + match func() { + Ok(v) => v, + Err(e) => e.throw_lua(lua.state_ptr()) + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index a6f8fc1..78ed43a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::{sync::Arc, ffi::CStr}; +use std::{ffi::CStr, str::Utf8Error, sync::Arc}; use lua::{lua_typename, lua_type}; use mlua_sys as lua; @@ -15,14 +15,73 @@ use function::*; pub mod value; use value::*; +pub mod guard; +use guard::*; + +/* struct RustFn { + +} */ + +/* struct RustFnUpvalue { + +} */ + +/// +pub fn ptr_to_string(ptr: *const i8) -> std::result::Result { + let c = unsafe { CStr::from_ptr(ptr) }; + let s= c.to_str()?; + Ok(s.to_string()) +} + + fn main() -> Result<()> { let lua = State::new(); lua.expose_libraries(&[StdLibrary::Debug]); + let globals = lua.globals()?; + + let a = |lua: &State, (num,): (i32,)| -> Result { + println!("Rust got number from lua: {}", num); + Ok(999) + }; + + let f = lua.create_function(a)?; + globals.set("native_test", f)?; + + //let tbl = lua.create_table()?; + + let vec2_add = lua.create_function(|lua: &State, (a, b): (Table, Table)| -> Result { + let ax: i32 = a.get("x")?; + let ay: i32 = a.get("y")?; + + let bx: i32 = b.get("x")?; + let by: i32 = b.get("y")?; + + let rx = ax + bx; + let ry = ay + by; + + let mt = lua.create_meta_table("Vec2")?; + mt.set("x", rx)?; + mt.set("y", ry)?; + Ok(mt) + })?; + + + let mt = lua.create_meta_table("Vec2")?; + mt.set("x", 50)?; + mt.set("y", 50)?; + mt.set_meta("__add", vec2_add)?; + globals.set("pos1", mt)?; + + let mt = lua.create_meta_table("Vec2")?; + mt.set("x", 25)?; + mt.set("y", 25)?; + globals.set("pos2", mt)?; + let tbl = lua.create_table()?; tbl.set("x", 10)?; - let globals = lua.globals()?; + //let globals = lua.globals()?; globals.set("X", tbl)?; lua.execute(r#" @@ -47,21 +106,28 @@ fn main() -> Result<()> { end end]]-- - print("x is " .. X.x) - - cool_num = 50 - - function say_number(num) - print("I'm lua and I said " .. num) - end - function multiply_print(a, b) - print(a .. " * " .. b .. " = " .. a * b) + print(a .. " * " .. b .. " = " .. a*b) end function multiply_ret(a, b) return a * b end + + function say_number(a) + print("Lua says " .. a) + end + + cool_num = 50 + + local res = native_test(50) + print("Lua got " .. res .. " back from rust!") + + print("Pos1 is (" .. pos1.x .. ", " .. pos1.y .. ")") + print("Pos2 is (" .. pos2.x .. ", " .. pos2.y .. ")") + + local add_pos = pos1 + pos2 + print("Pos1 + pos2 is (" .. add_pos.x .. ", " .. add_pos.y .. ")") "#).unwrap(); let num = globals.get::<_, i32>("cool_num")?; @@ -99,8 +165,10 @@ impl LuaRef { /// Creates a reference to what is at the top of the stack. pub unsafe fn from_stack(state: &State) -> Result { let s = state.state_ptr(); + let t = lua::lua_gettop(s); let r = lua::luaL_ref(s, lua::LUA_REGISTRYINDEX); + let t = lua::lua_gettop(s); if r == lua::LUA_REFNIL { Err(Error::Nil) } else { @@ -115,7 +183,13 @@ impl PushToLuaStack for LuaRef { unsafe { state.ensure_stack(1)?; - lua::lua_rawgeti(s, lua::LUA_REGISTRYINDEX, *self.0 as i64); + let top = lua::lua_gettop(s); + let ty = lua::lua_rawgeti(s, lua::LUA_REGISTRYINDEX, *self.0 as i64); + let new_top = lua::lua_gettop(s); + + if ty == lua::LUA_TNIL || ty == lua::LUA_TNONE || top == new_top { + return Err(Error::Nil); + } } Ok(()) @@ -124,16 +198,26 @@ impl PushToLuaStack for LuaRef { #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("Lua runtime error: {0}")] /// An error returned from lua + #[error("Lua runtime error: {0}")] Runtime(String), - #[error("Ran out of memory when attempting to use `lua_checkstack`")] /// Ran into a not enough memory error when trying to grow the lua stack. + #[error("Ran out of memory when attempting to use `lua_checkstack`")] Oom, #[error("Ran into a nill value on the stack")] Nil, #[error("Unexpected type, expected {0} but got {1}")] - UnexpectedType(String, String) + UnexpectedType(String, String), + #[error("Bad argument provided to {func:?}! Argument #{arg_index} (name: {arg_name:?}), cause: {error}")] + BadArgument { + func: Option, + arg_index: i32, + arg_name: Option, + /// the error that describes what was wrong for this argument + error: Arc + }, + #[error("There is already a registry entry with the key {0}")] + RegistryConflict(String) } impl Error { @@ -144,6 +228,16 @@ impl Error { pub fn unexpected_type(expected: &str, got: &str) -> Self { Self::UnexpectedType(expected.to_string(), got.to_string()) } + + /// Throw the error in lua. + /// + /// This method never returns + pub unsafe fn throw_lua(self, lua: *mut lua::lua_State) -> ! { + let msg = format!("{}", self); + let msg_c = msg.as_ptr() as *const i8; + lua::luaL_error(lua, msg_c); + panic!("never gets here"); + } } /// A result for use with lua functions @@ -157,6 +251,12 @@ pub trait FromLuaStack<'a>: Sized { unsafe fn from_lua_stack(state: &'a State) -> Result; } +/* impl<'a> FromLuaStack<'a> for () { + unsafe fn from_lua_stack(state: &'a State) -> Result { + Ok(()) + } +} */ + /// Implements PushToLuaStack for a number macro_rules! impl_push_to_lua_stack_number { ($ty: ident) => { @@ -176,7 +276,7 @@ impl<'a> FromLuaStack<'a> for i32 { if lua::lua_isnumber(s, -1) == 1 { let v = lua::lua_tonumber(s, -1) as i32; - lua::lua_pop(s, -1); + lua::lua_pop(s, 1); Ok(v) } else { let lua_ty = lua_type(s, -1); diff --git a/src/state.rs b/src/state.rs index 0facc9c..08c1562 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,9 +1,9 @@ use core::ffi; -use std::{ptr::NonNull, ffi::{CString, CStr}}; +use std::{ffi::{CString, CStr}, mem, ptr::{self, NonNull}}; use mlua_sys as lua; -use crate::{Table, Result, Error, LuaRef}; +use crate::{lua_error_guard, Error, FromLuaStack, FromLuaStackMulti, Function, LuaRef, PushToLuaStack, PushToLuaStackMulti, Result, Table}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum StdLibrary { @@ -92,6 +92,12 @@ impl State { } } + pub fn from_ptr(lua: *mut lua::lua_State) -> Self { + Self { + lua: unsafe { NonNull::new_unchecked(lua) } + } + } + pub fn state_ptr(&self) -> *mut lua::lua_State { self.lua.as_ptr() } @@ -142,6 +148,10 @@ impl State { Table::new(self) } + pub fn create_meta_table<'a>(&'a self, name: &str) -> Result> { + Table::new_meta_table(self, name) + } + /// This called `lua_checkstack` and returns a result. /// /// The result will be Ok if the stack has space for `n` values. @@ -157,7 +167,105 @@ impl State { unsafe { let s = self.state_ptr(); lua::lua_rawgeti(s, lua::LUA_REGISTRYINDEX, lua::LUA_RIDX_GLOBALS); - Table::with_ref(self, LuaRef::from_stack(self)?) + Table::with_ref(self, LuaRef::from_stack(self)?, None) } } + + pub(crate) unsafe fn print_stack(state: &State) { + let s = state.state_ptr(); + let t = lua::lua_gettop(s); + + for i in 1..(t+1) { + let ty = lua::lua_type(s, i); + let tyname = lua::lua_typename(s, ty); + let tyname = CStr::from_ptr(tyname); + let tyname = tyname.to_str().unwrap(); + println!("{}: {}", i, tyname); + } + } + + pub fn create_function<'a, A, R, F>(&'a self, f: F) -> Result> + where + A: FromLuaStackMulti<'a>, + R: PushToLuaStackMulti<'a>, + F: Fn(&'a State, A) -> Result + 'static, + { + unsafe extern "C-unwind" fn rust_closure(s: *mut lua::lua_State) -> i32 { + // ensure validity of data + let upv_idx = lua::lua_upvalueindex(1); + let ltype = lua::lua_type(s, upv_idx); + match ltype { + lua::LUA_TUSERDATA => { + let ud_ptr = lua::lua_touserdata(s, upv_idx); + + if ud_ptr.is_null() { + crate::Error::runtime("null upvalue provided to luacclosure") + .throw_lua(s); + } + + let data_ptr = ud_ptr as *mut ClosureData; + + let top = lua::lua_gettop(s); + let wrap = &(*data_ptr).wrapper_fn; + let s = (*data_ptr).state; + + wrap(s, top) + }, + _ => { + let name = CStr::from_ptr(lua::lua_typename(s, ltype)); + if let Ok(n_str) = name.to_str() { + crate::Error::Runtime(format!("unexpected type ({}) provided to luacclosure", n_str)) + .throw_lua(s); + } else { + crate::Error::runtime("unexpected type provided to luacclosure") + .throw_lua(s); + } + + } + } + } + + struct ClosureData<'a> { + wrapper_fn: Box i32>, + state: &'a State, + } + + let wrapper_fn = move |lua: &State, narg: i32| -> i32 { + unsafe { + let lua: &State = mem::transmute(lua); // transmute lifetimes + + if narg != A::len() as i32 { + Error::Runtime(format!("incorrect number of arguments provided to lua function, expected {}", A::len())) + .throw_lua(lua.state_ptr()); + } + + lua_error_guard(lua, || { + let args = A::results_from_lua_stack(lua)?; + + let r = f(lua, args)?; + r.push_args_to_lua_stack(lua)?; + Ok(r.len() as i32) + }) + } + }; + + let data = ClosureData { + wrapper_fn: Box::new(wrapper_fn), + state: self, + }; + + let s = self.state_ptr(); + + unsafe { + let ptr = lua::lua_newuserdata(s, mem::size_of::()); + let ptr = ptr.cast(); + ptr::write(ptr, data); + + lua::lua_pushcclosure(s, rust_closure, 1); + let lref = LuaRef::from_stack(self)?; + + Ok(Function::from_ref(self, lref)) + } + } + } \ No newline at end of file diff --git a/src/table.rs b/src/table.rs index e116d71..e050537 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,12 +1,20 @@ +use std::{ffi::CStr, ops::Deref, sync::Arc}; + use mlua_sys as lua; -use crate::{State, Result, PushToLuaStack, LuaRef, FromLuaStack}; +use crate::{FromLuaStack, LuaRef, PushToLuaStack, Result, StackGuard, State}; +pub(crate) struct MetaTableInfo { + name: Option, + lref: LuaRef, +} +#[derive(Clone)] pub struct Table<'a> { state: &'a State, lref: LuaRef, + meta: Option, // Some if this table is a metatable } impl<'a> Table<'a> { @@ -23,24 +31,62 @@ impl<'a> Table<'a> { Ok(Self { state, lref, + meta: None, + }) + } + + pub fn new_meta_table(state: &'a State, name: &str) -> Result { + let name_term = format!("{}\0", name); + let name_term_c = name_term.as_str().as_ptr() as *const i8; + let name_term_arc = Arc::new(name_term_c); + + let (lref, meta_ref) = unsafe { + let _g = StackGuard::new(state); + state.ensure_stack(2)?; + + let s = state.state_ptr(); + if lua::luaL_newmetatable(s, name_term_c) == 0 { + // lua::luaL_getmetatable does not return the type that was + // retrieved from the registry + lua::lua_pushstring(s, name_term_c); + let ty = lua::lua_rawget(s, lua::LUA_REGISTRYINDEX); + + if ty != lua::LUA_TTABLE { + return Err(crate::Error::RegistryConflict(name.to_string())); + } + } + let meta = LuaRef::from_stack(state)?; + + + let s = state.state_ptr(); + lua::lua_newtable(s); + (LuaRef::from_stack(state)?, meta) + }; + + Ok(Self { + state, + lref, + meta: Some(meta_ref) }) } /// Construct a table with a lua reference to one. - pub fn with_ref(state: &'a State, lua_ref: LuaRef) -> Result { + pub fn with_ref(state: &'a State, lua_ref: LuaRef, name: Option) -> Result { let s = state.state_ptr(); unsafe { + let _g = StackGuard::new(state); + lua_ref.push_to_lua_stack(state)?; if lua::lua_istable(s, -1) == 0 { panic!("Index is not a table") } - lua::lua_pop(s, -1); } - + Ok(Self { state, lref: lua_ref, + meta: None, }) } @@ -54,14 +100,24 @@ impl<'a> Table<'a> { { let s = self.state.state_ptr(); unsafe { - self.state.ensure_stack(3)?; - + let _g = StackGuard::new(self.state); + + if let Some(_) = &self.meta { + self.state.ensure_stack(4)?; + } else { + self.state.ensure_stack(3)?; + } + self.lref.push_to_lua_stack(self.state)?; key.push_to_lua_stack(self.state)?; val.push_to_lua_stack(self.state)?; lua::lua_settable(s, -3); - lua::lua_pop(self.state.state_ptr(), -1); + + if let Some(mt) = &self.meta { + mt.push_to_lua_stack(self.state)?; + lua::lua_setmetatable(s, -2); + } } Ok(()) @@ -77,14 +133,14 @@ impl<'a> Table<'a> { { let s = self.state.state_ptr(); unsafe { + self.state.ensure_stack(2)?; + let _g = StackGuard::new(self.state); + self.lref.push_to_lua_stack(self.state)?; key.push_to_lua_stack(self.state)?; lua::lua_gettable(s, -2); // table[key] is at top of stack - let top = lua::lua_gettop(s); let val = V::from_lua_stack(self.state)?; - let new_top = lua::lua_gettop(s); - debug_assert!(new_top < top, "V::from_lua_stack did not remove anything from the stack!"); Ok(val) } @@ -96,11 +152,13 @@ impl<'a> Table<'a> { pub fn len(&self) -> Result { let s = self.state.state_ptr(); unsafe { + self.state.ensure_stack(1)?; + let _g = StackGuard::new(self.state); + self.lref.push_to_lua_stack(self.state)?; lua::lua_len(s, -1); let len = lua::lua_tonumber(s, -1); - lua::lua_pop(self.state.state_ptr(), -1); Ok(len as u64) } @@ -115,6 +173,7 @@ impl<'a> Table<'a> { let s = self.state.state_ptr(); unsafe { self.state.ensure_stack(3)?; + let _g = StackGuard::new(self.state); self.lref.push_to_lua_stack(self.state)?; key.push_to_lua_stack(self.state)?; @@ -135,6 +194,9 @@ impl<'a> Table<'a> { { let s = self.state.state_ptr(); unsafe { + self.state.ensure_stack(2)?; + let _g = StackGuard::new(self.state); + self.lref.push_to_lua_stack(self.state)?; key.push_to_lua_stack(self.state)?; lua::lua_rawget(s, -2); // table[key] is at top of stack @@ -146,13 +208,39 @@ impl<'a> Table<'a> { pub fn raw_len(&self) -> Result { let s = self.state.state_ptr(); unsafe { + self.state.ensure_stack(1)?; + let _g = StackGuard::new(self.state); + self.lref.push_to_lua_stack(self.state)?; let len = lua::lua_rawlen(s, -1); - lua::lua_pop(s, -1); // pop table + lua::lua_pop(s, 1); // pop table Ok(len as u64) } } + + pub fn set_meta(&self, key: K, val: V) -> Result<()> + where + K: PushToLuaStack, + V: PushToLuaStack + { + let mt = self.meta.as_ref() + .expect("this table is not a meta table!"); + + unsafe { + let s = self.state.state_ptr(); + self.state.ensure_stack(3)?; + let _g = StackGuard::new(self.state); + + //lua::luaL_getmetatable(s, **cname); + mt.push_to_lua_stack(self.state)?; + key.push_to_lua_stack(self.state)?; + val.push_to_lua_stack(self.state)?; + lua::lua_settable(s, -3); + } + + Ok(()) + } } impl<'a> PushToLuaStack for Table<'a> { @@ -162,4 +250,23 @@ impl<'a> PushToLuaStack for Table<'a> { Ok(()) } +} + +impl<'a> FromLuaStack<'a> for Table<'a> { + unsafe fn from_lua_stack(state: &'a State) -> Result { + let s = state.state_ptr(); + + let ty = lua::lua_type(s, -1); + if ty == lua::LUA_TTABLE { + let t = Table::with_ref(state, LuaRef::from_stack(state)?, None); + + t + } else { + let tyname = lua::lua_typename(s, ty); + let cstr = CStr::from_ptr(tyname); + let s = cstr.to_str() + .expect("Lua type has invalid bytes!"); + Err(crate::Error::UnexpectedType("Table".to_string(), s.to_string())) + } + } } \ No newline at end of file