Implement calling rust functions from lua, make the Table struct support metatables

This commit is contained in:
SeanOMik 2024-01-25 23:47:30 -05:00
parent 19ee453172
commit 71199bc905
Signed by: SeanOMik
GPG Key ID: FEC9E2FC15235964
5 changed files with 415 additions and 45 deletions

View File

@ -1,6 +1,6 @@
use std::ffi::CStr;
use crate::{LuaRef, FromLuaStack, State, PushToLuaStack};
use crate::{FromLuaStack, LuaRef, PushToLuaStack, StackGuard, State};
use mlua_sys as lua;
@ -25,12 +25,21 @@ impl<'a> PushToLuaStack for Function<'a> {
}
impl<'a> Function<'a> {
pub fn from_ref(state: &'a State, lref: LuaRef) -> Self {
Self {
state,
lref,
}
}
pub fn exec<A, R>(&self, args: A) -> crate::Result<R>
where
A: FunctionArgs,
R: FunctionResults<'a>,
A: PushToLuaStackMulti<'a>,
R: FromLuaStackMulti<'a>,
{
unsafe {
let _g = StackGuard::new(self.state);
self.push_to_lua_stack(self.state)?;
args.push_args_to_lua_stack(self.state)?;
@ -50,12 +59,12 @@ impl<'a> Function<'a> {
}
}
pub trait FunctionArgs {
pub trait PushToLuaStackMulti<'a> {
fn len(&self) -> usize;
fn push_args_to_lua_stack(&self, state: &State) -> crate::Result<()>;
}
impl<'a, T> FunctionArgs for T
impl<'a, T> PushToLuaStackMulti<'a> for T
where
T: PushToLuaStack,
{
@ -72,14 +81,14 @@ where
}
}
pub trait FunctionResults<'a>: Sized {
pub trait FromLuaStackMulti<'a>: Sized {
fn len() -> usize;
fn results_from_lua_stack(state: &'a State) -> crate::Result<Self>;
}
impl<'a> FunctionResults<'a> for () {
impl<'a> FromLuaStackMulti<'a> for () {
fn len() -> usize {
0
1
}
fn results_from_lua_stack(_state: &'a State) -> crate::Result<Self> {
@ -87,7 +96,7 @@ impl<'a> FunctionResults<'a> for () {
}
}
impl<'a, T: FromLuaStack<'a>> FunctionResults<'a> for T {
impl<'a, T: FromLuaStack<'a>> FromLuaStackMulti<'a> for T {
fn len() -> usize {
1
}
@ -100,7 +109,7 @@ impl<'a, T: FromLuaStack<'a>> FunctionResults<'a> for T {
macro_rules! impl_function_arg_tuple {
( $count: expr, $first: tt, $( $name: tt ),+ ) => (
#[allow(non_snake_case)]
impl<$first: PushToLuaStack, $($name: PushToLuaStack,)+> FunctionArgs for ($first, $($name,)+) {
impl<'a, $first: PushToLuaStack, $($name: PushToLuaStack,)+> PushToLuaStackMulti<'a> for ($first, $($name,)+) {
fn len(&self) -> usize {
// this will end up generating $count - 1 - 1 - 1... hopefully the compiler will
// optimize that out
@ -119,7 +128,7 @@ macro_rules! impl_function_arg_tuple {
}
}
impl<'a, $first: FromLuaStack<'a>, $($name: FromLuaStack<'a>,)+> FunctionResults<'a> for ($first, $($name,)+) {
impl<'a, $first: FromLuaStack<'a>, $($name: FromLuaStack<'a>,)+> FromLuaStackMulti<'a> for ($first, $($name,)+) {
fn len() -> usize {
$count
}
@ -132,10 +141,10 @@ macro_rules! impl_function_arg_tuple {
impl_function_arg_tuple!( $count - 1, $( $name ),+ );
);
// implements FunctionArgs and FunctionResults for a tuple with a single element
// implements PushToLuaStackMulti and FromLuaStackMulti for a tuple with a single element
( $count: expr, $only: tt ) => {
#[allow(non_snake_case)]
impl<$only: PushToLuaStack> FunctionArgs for ($only,) {
impl<'a, $only: PushToLuaStack> PushToLuaStackMulti<'a> for ($only,) {
fn len(&self) -> usize {
1
}
@ -151,7 +160,7 @@ macro_rules! impl_function_arg_tuple {
}
}
impl<'a, $only: FromLuaStack<'a>> FunctionResults<'a> for ($only,) {
impl<'a, $only: FromLuaStack<'a>> FromLuaStackMulti<'a> for ($only,) {
fn len() -> usize {
1
}

46
src/guard.rs Normal file
View File

@ -0,0 +1,46 @@
use crate::State;
use mlua_sys as lua;
/// A stack guard will protect the LuaStack from leaks.
///
/// When its first created, it keeps note of how large the stack is and when its dropped,
/// it will pop all new values from the top of the stack.
pub struct StackGuard<'a> {
state: &'a State,
top: i32,
}
impl<'a> Drop for StackGuard<'a> {
fn drop(&mut self) {
unsafe {
let s = self.state.state_ptr();
let now_top = lua::lua_gettop(s);
if now_top > self.top {
lua::lua_pop(s, now_top - self.top);
}
}
}
}
impl<'a> StackGuard<'a> {
pub fn new(state: &'a State) -> Self {
let top = unsafe { lua::lua_gettop(state.state_ptr()) };
Self {
state,
top
}
}
}
pub unsafe fn lua_error_guard<F, R>(lua: &State, func: F) -> R
where
F: Fn() -> crate::Result<R>
{
match func() {
Ok(v) => v,
Err(e) => e.throw_lua(lua.state_ptr())
}
}

View File

@ -1,4 +1,4 @@
use std::{sync::Arc, ffi::CStr};
use std::{ffi::CStr, str::Utf8Error, sync::Arc};
use lua::{lua_typename, lua_type};
use mlua_sys as lua;
@ -15,14 +15,73 @@ use function::*;
pub mod value;
use value::*;
pub mod guard;
use guard::*;
/* struct RustFn {
} */
/* struct RustFnUpvalue {
} */
///
pub fn ptr_to_string(ptr: *const i8) -> std::result::Result<String, Utf8Error> {
let c = unsafe { CStr::from_ptr(ptr) };
let s= c.to_str()?;
Ok(s.to_string())
}
fn main() -> Result<()> {
let lua = State::new();
lua.expose_libraries(&[StdLibrary::Debug]);
let globals = lua.globals()?;
let a = |lua: &State, (num,): (i32,)| -> Result<i32> {
println!("Rust got number from lua: {}", num);
Ok(999)
};
let f = lua.create_function(a)?;
globals.set("native_test", f)?;
//let tbl = lua.create_table()?;
let vec2_add = lua.create_function(|lua: &State, (a, b): (Table, Table)| -> Result<Table> {
let ax: i32 = a.get("x")?;
let ay: i32 = a.get("y")?;
let bx: i32 = b.get("x")?;
let by: i32 = b.get("y")?;
let rx = ax + bx;
let ry = ay + by;
let mt = lua.create_meta_table("Vec2")?;
mt.set("x", rx)?;
mt.set("y", ry)?;
Ok(mt)
})?;
let mt = lua.create_meta_table("Vec2")?;
mt.set("x", 50)?;
mt.set("y", 50)?;
mt.set_meta("__add", vec2_add)?;
globals.set("pos1", mt)?;
let mt = lua.create_meta_table("Vec2")?;
mt.set("x", 25)?;
mt.set("y", 25)?;
globals.set("pos2", mt)?;
let tbl = lua.create_table()?;
tbl.set("x", 10)?;
let globals = lua.globals()?;
//let globals = lua.globals()?;
globals.set("X", tbl)?;
lua.execute(r#"
@ -47,21 +106,28 @@ fn main() -> Result<()> {
end
end]]--
print("x is " .. X.x)
cool_num = 50
function say_number(num)
print("I'm lua and I said " .. num)
end
function multiply_print(a, b)
print(a .. " * " .. b .. " = " .. a * b)
print(a .. " * " .. b .. " = " .. a*b)
end
function multiply_ret(a, b)
return a * b
end
function say_number(a)
print("Lua says " .. a)
end
cool_num = 50
local res = native_test(50)
print("Lua got " .. res .. " back from rust!")
print("Pos1 is (" .. pos1.x .. ", " .. pos1.y .. ")")
print("Pos2 is (" .. pos2.x .. ", " .. pos2.y .. ")")
local add_pos = pos1 + pos2
print("Pos1 + pos2 is (" .. add_pos.x .. ", " .. add_pos.y .. ")")
"#).unwrap();
let num = globals.get::<_, i32>("cool_num")?;
@ -99,8 +165,10 @@ impl LuaRef {
/// Creates a reference to what is at the top of the stack.
pub unsafe fn from_stack(state: &State) -> Result<Self> {
let s = state.state_ptr();
let t = lua::lua_gettop(s);
let r = lua::luaL_ref(s, lua::LUA_REGISTRYINDEX);
let t = lua::lua_gettop(s);
if r == lua::LUA_REFNIL {
Err(Error::Nil)
} else {
@ -115,7 +183,13 @@ impl PushToLuaStack for LuaRef {
unsafe {
state.ensure_stack(1)?;
lua::lua_rawgeti(s, lua::LUA_REGISTRYINDEX, *self.0 as i64);
let top = lua::lua_gettop(s);
let ty = lua::lua_rawgeti(s, lua::LUA_REGISTRYINDEX, *self.0 as i64);
let new_top = lua::lua_gettop(s);
if ty == lua::LUA_TNIL || ty == lua::LUA_TNONE || top == new_top {
return Err(Error::Nil);
}
}
Ok(())
@ -124,16 +198,26 @@ impl PushToLuaStack for LuaRef {
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Lua runtime error: {0}")]
/// An error returned from lua
#[error("Lua runtime error: {0}")]
Runtime(String),
#[error("Ran out of memory when attempting to use `lua_checkstack`")]
/// Ran into a not enough memory error when trying to grow the lua stack.
#[error("Ran out of memory when attempting to use `lua_checkstack`")]
Oom,
#[error("Ran into a nill value on the stack")]
Nil,
#[error("Unexpected type, expected {0} but got {1}")]
UnexpectedType(String, String)
UnexpectedType(String, String),
#[error("Bad argument provided to {func:?}! Argument #{arg_index} (name: {arg_name:?}), cause: {error}")]
BadArgument {
func: Option<String>,
arg_index: i32,
arg_name: Option<String>,
/// the error that describes what was wrong for this argument
error: Arc<Error>
},
#[error("There is already a registry entry with the key {0}")]
RegistryConflict(String)
}
impl Error {
@ -144,6 +228,16 @@ impl Error {
pub fn unexpected_type(expected: &str, got: &str) -> Self {
Self::UnexpectedType(expected.to_string(), got.to_string())
}
/// Throw the error in lua.
///
/// This method never returns
pub unsafe fn throw_lua(self, lua: *mut lua::lua_State) -> ! {
let msg = format!("{}", self);
let msg_c = msg.as_ptr() as *const i8;
lua::luaL_error(lua, msg_c);
panic!("never gets here");
}
}
/// A result for use with lua functions
@ -157,6 +251,12 @@ pub trait FromLuaStack<'a>: Sized {
unsafe fn from_lua_stack(state: &'a State) -> Result<Self>;
}
/* impl<'a> FromLuaStack<'a> for () {
unsafe fn from_lua_stack(state: &'a State) -> Result<Self> {
Ok(())
}
} */
/// Implements PushToLuaStack for a number
macro_rules! impl_push_to_lua_stack_number {
($ty: ident) => {
@ -176,7 +276,7 @@ impl<'a> FromLuaStack<'a> for i32 {
if lua::lua_isnumber(s, -1) == 1 {
let v = lua::lua_tonumber(s, -1) as i32;
lua::lua_pop(s, -1);
lua::lua_pop(s, 1);
Ok(v)
} else {
let lua_ty = lua_type(s, -1);

View File

@ -1,9 +1,9 @@
use core::ffi;
use std::{ptr::NonNull, ffi::{CString, CStr}};
use std::{ffi::{CString, CStr}, mem, ptr::{self, NonNull}};
use mlua_sys as lua;
use crate::{Table, Result, Error, LuaRef};
use crate::{lua_error_guard, Error, FromLuaStack, FromLuaStackMulti, Function, LuaRef, PushToLuaStack, PushToLuaStackMulti, Result, Table};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StdLibrary {
@ -92,6 +92,12 @@ impl State {
}
}
pub fn from_ptr(lua: *mut lua::lua_State) -> Self {
Self {
lua: unsafe { NonNull::new_unchecked(lua) }
}
}
pub fn state_ptr(&self) -> *mut lua::lua_State {
self.lua.as_ptr()
}
@ -142,6 +148,10 @@ impl State {
Table::new(self)
}
pub fn create_meta_table<'a>(&'a self, name: &str) -> Result<Table<'a>> {
Table::new_meta_table(self, name)
}
/// This called `lua_checkstack` and returns a result.
///
/// The result will be Ok if the stack has space for `n` values.
@ -157,7 +167,105 @@ impl State {
unsafe {
let s = self.state_ptr();
lua::lua_rawgeti(s, lua::LUA_REGISTRYINDEX, lua::LUA_RIDX_GLOBALS);
Table::with_ref(self, LuaRef::from_stack(self)?)
Table::with_ref(self, LuaRef::from_stack(self)?, None)
}
}
pub(crate) unsafe fn print_stack(state: &State) {
let s = state.state_ptr();
let t = lua::lua_gettop(s);
for i in 1..(t+1) {
let ty = lua::lua_type(s, i);
let tyname = lua::lua_typename(s, ty);
let tyname = CStr::from_ptr(tyname);
let tyname = tyname.to_str().unwrap();
println!("{}: {}", i, tyname);
}
}
pub fn create_function<'a, A, R, F>(&'a self, f: F) -> Result<Function<'a>>
where
A: FromLuaStackMulti<'a>,
R: PushToLuaStackMulti<'a>,
F: Fn(&'a State, A) -> Result<R> + 'static,
{
unsafe extern "C-unwind" fn rust_closure(s: *mut lua::lua_State) -> i32 {
// ensure validity of data
let upv_idx = lua::lua_upvalueindex(1);
let ltype = lua::lua_type(s, upv_idx);
match ltype {
lua::LUA_TUSERDATA => {
let ud_ptr = lua::lua_touserdata(s, upv_idx);
if ud_ptr.is_null() {
crate::Error::runtime("null upvalue provided to luacclosure")
.throw_lua(s);
}
let data_ptr = ud_ptr as *mut ClosureData;
let top = lua::lua_gettop(s);
let wrap = &(*data_ptr).wrapper_fn;
let s = (*data_ptr).state;
wrap(s, top)
},
_ => {
let name = CStr::from_ptr(lua::lua_typename(s, ltype));
if let Ok(n_str) = name.to_str() {
crate::Error::Runtime(format!("unexpected type ({}) provided to luacclosure", n_str))
.throw_lua(s);
} else {
crate::Error::runtime("unexpected type provided to luacclosure")
.throw_lua(s);
}
}
}
}
struct ClosureData<'a> {
wrapper_fn: Box<dyn Fn(&'a State, i32) -> i32>,
state: &'a State,
}
let wrapper_fn = move |lua: &State, narg: i32| -> i32 {
unsafe {
let lua: &State = mem::transmute(lua); // transmute lifetimes
if narg != A::len() as i32 {
Error::Runtime(format!("incorrect number of arguments provided to lua function, expected {}", A::len()))
.throw_lua(lua.state_ptr());
}
lua_error_guard(lua, || {
let args = A::results_from_lua_stack(lua)?;
let r = f(lua, args)?;
r.push_args_to_lua_stack(lua)?;
Ok(r.len() as i32)
})
}
};
let data = ClosureData {
wrapper_fn: Box::new(wrapper_fn),
state: self,
};
let s = self.state_ptr();
unsafe {
let ptr = lua::lua_newuserdata(s, mem::size_of::<ClosureData>());
let ptr = ptr.cast();
ptr::write(ptr, data);
lua::lua_pushcclosure(s, rust_closure, 1);
let lref = LuaRef::from_stack(self)?;
Ok(Function::from_ref(self, lref))
}
}
}

View File

@ -1,12 +1,20 @@
use std::{ffi::CStr, ops::Deref, sync::Arc};
use mlua_sys as lua;
use crate::{State, Result, PushToLuaStack, LuaRef, FromLuaStack};
use crate::{FromLuaStack, LuaRef, PushToLuaStack, Result, StackGuard, State};
pub(crate) struct MetaTableInfo {
name: Option<String>,
lref: LuaRef,
}
#[derive(Clone)]
pub struct Table<'a> {
state: &'a State,
lref: LuaRef,
meta: Option<LuaRef>, // Some if this table is a metatable
}
impl<'a> Table<'a> {
@ -23,24 +31,62 @@ impl<'a> Table<'a> {
Ok(Self {
state,
lref,
meta: None,
})
}
pub fn new_meta_table(state: &'a State, name: &str) -> Result<Self> {
let name_term = format!("{}\0", name);
let name_term_c = name_term.as_str().as_ptr() as *const i8;
let name_term_arc = Arc::new(name_term_c);
let (lref, meta_ref) = unsafe {
let _g = StackGuard::new(state);
state.ensure_stack(2)?;
let s = state.state_ptr();
if lua::luaL_newmetatable(s, name_term_c) == 0 {
// lua::luaL_getmetatable does not return the type that was
// retrieved from the registry
lua::lua_pushstring(s, name_term_c);
let ty = lua::lua_rawget(s, lua::LUA_REGISTRYINDEX);
if ty != lua::LUA_TTABLE {
return Err(crate::Error::RegistryConflict(name.to_string()));
}
}
let meta = LuaRef::from_stack(state)?;
let s = state.state_ptr();
lua::lua_newtable(s);
(LuaRef::from_stack(state)?, meta)
};
Ok(Self {
state,
lref,
meta: Some(meta_ref)
})
}
/// Construct a table with a lua reference to one.
pub fn with_ref(state: &'a State, lua_ref: LuaRef) -> Result<Self> {
pub fn with_ref(state: &'a State, lua_ref: LuaRef, name: Option<String>) -> Result<Self> {
let s = state.state_ptr();
unsafe {
let _g = StackGuard::new(state);
lua_ref.push_to_lua_stack(state)?;
if lua::lua_istable(s, -1) == 0 {
panic!("Index is not a table")
}
lua::lua_pop(s, -1);
}
Ok(Self {
state,
lref: lua_ref,
meta: None,
})
}
@ -54,14 +100,24 @@ impl<'a> Table<'a> {
{
let s = self.state.state_ptr();
unsafe {
self.state.ensure_stack(3)?;
let _g = StackGuard::new(self.state);
if let Some(_) = &self.meta {
self.state.ensure_stack(4)?;
} else {
self.state.ensure_stack(3)?;
}
self.lref.push_to_lua_stack(self.state)?;
key.push_to_lua_stack(self.state)?;
val.push_to_lua_stack(self.state)?;
lua::lua_settable(s, -3);
lua::lua_pop(self.state.state_ptr(), -1);
if let Some(mt) = &self.meta {
mt.push_to_lua_stack(self.state)?;
lua::lua_setmetatable(s, -2);
}
}
Ok(())
@ -77,14 +133,14 @@ impl<'a> Table<'a> {
{
let s = self.state.state_ptr();
unsafe {
self.state.ensure_stack(2)?;
let _g = StackGuard::new(self.state);
self.lref.push_to_lua_stack(self.state)?;
key.push_to_lua_stack(self.state)?;
lua::lua_gettable(s, -2); // table[key] is at top of stack
let top = lua::lua_gettop(s);
let val = V::from_lua_stack(self.state)?;
let new_top = lua::lua_gettop(s);
debug_assert!(new_top < top, "V::from_lua_stack did not remove anything from the stack!");
Ok(val)
}
@ -96,11 +152,13 @@ impl<'a> Table<'a> {
pub fn len(&self) -> Result<u64> {
let s = self.state.state_ptr();
unsafe {
self.state.ensure_stack(1)?;
let _g = StackGuard::new(self.state);
self.lref.push_to_lua_stack(self.state)?;
lua::lua_len(s, -1);
let len = lua::lua_tonumber(s, -1);
lua::lua_pop(self.state.state_ptr(), -1);
Ok(len as u64)
}
@ -115,6 +173,7 @@ impl<'a> Table<'a> {
let s = self.state.state_ptr();
unsafe {
self.state.ensure_stack(3)?;
let _g = StackGuard::new(self.state);
self.lref.push_to_lua_stack(self.state)?;
key.push_to_lua_stack(self.state)?;
@ -135,6 +194,9 @@ impl<'a> Table<'a> {
{
let s = self.state.state_ptr();
unsafe {
self.state.ensure_stack(2)?;
let _g = StackGuard::new(self.state);
self.lref.push_to_lua_stack(self.state)?;
key.push_to_lua_stack(self.state)?;
lua::lua_rawget(s, -2); // table[key] is at top of stack
@ -146,13 +208,39 @@ impl<'a> Table<'a> {
pub fn raw_len(&self) -> Result<u64> {
let s = self.state.state_ptr();
unsafe {
self.state.ensure_stack(1)?;
let _g = StackGuard::new(self.state);
self.lref.push_to_lua_stack(self.state)?;
let len = lua::lua_rawlen(s, -1);
lua::lua_pop(s, -1); // pop table
lua::lua_pop(s, 1); // pop table
Ok(len as u64)
}
}
pub fn set_meta<K, V>(&self, key: K, val: V) -> Result<()>
where
K: PushToLuaStack,
V: PushToLuaStack
{
let mt = self.meta.as_ref()
.expect("this table is not a meta table!");
unsafe {
let s = self.state.state_ptr();
self.state.ensure_stack(3)?;
let _g = StackGuard::new(self.state);
//lua::luaL_getmetatable(s, **cname);
mt.push_to_lua_stack(self.state)?;
key.push_to_lua_stack(self.state)?;
val.push_to_lua_stack(self.state)?;
lua::lua_settable(s, -3);
}
Ok(())
}
}
impl<'a> PushToLuaStack for Table<'a> {
@ -163,3 +251,22 @@ impl<'a> PushToLuaStack for Table<'a> {
Ok(())
}
}
impl<'a> FromLuaStack<'a> for Table<'a> {
unsafe fn from_lua_stack(state: &'a State) -> Result<Self> {
let s = state.state_ptr();
let ty = lua::lua_type(s, -1);
if ty == lua::LUA_TTABLE {
let t = Table::with_ref(state, LuaRef::from_stack(state)?, None);
t
} else {
let tyname = lua::lua_typename(s, ty);
let cstr = CStr::from_ptr(tyname);
let s = cstr.to_str()
.expect("Lua type has invalid bytes!");
Err(crate::Error::UnexpectedType("Table".to_string(), s.to_string()))
}
}
}