fix seg fault caused by garbage collection of `ClosureData`s

If the `ClosureData`s are freed on `lua_closed`, the references to the State they may have could become invalid. This was fixed by storing a StatePtr inside LuaRef's instead of a reference
This commit is contained in:
SeanOMik 2024-02-29 19:30:56 -05:00
parent feb93f2b4e
commit 54c9926a04
Signed by: SeanOMik
GPG Key ID: FEC9E2FC15235964
4 changed files with 61 additions and 45 deletions

View File

@ -1,6 +1,6 @@
use std::sync::Arc; use std::{marker::PhantomData, sync::Arc};
use crate::{Error, PushToLuaStack, Result, State}; use crate::{Error, PushToLuaStack, Result, State, StatePtr};
use mlua_sys as lua; use mlua_sys as lua;
@ -11,15 +11,16 @@ use mlua_sys as lua;
/// the inner Arc detects a single strong count. /// the inner Arc detects a single strong count.
#[derive(Clone)] #[derive(Clone)]
pub struct LuaRef<'a> { pub struct LuaRef<'a> {
lref: Arc<i32>, pub(crate) lref: Arc<i32>,
pub(crate) state: &'a State, pub(crate) state: StatePtr,
_marker: PhantomData<&'a State>,
} }
impl<'a> Drop for LuaRef<'a> { impl<'a> Drop for LuaRef<'a> {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
if Arc::strong_count(&self.lref) == 1 { if Arc::strong_count(&self.lref) == 1 {
let s = self.state.state_ptr(); let s = self.state.lua.as_ptr();
lua::luaL_unref(s, lua::LUA_REGISTRYINDEX, *self.lref); lua::luaL_unref(s, lua::LUA_REGISTRYINDEX, *self.lref);
} }
} }
@ -27,10 +28,19 @@ impl<'a> Drop for LuaRef<'a> {
} }
impl<'a> LuaRef<'a> { impl<'a> LuaRef<'a> {
pub fn new(lua_ref: i32, state: &'a State) -> Self { pub(crate) fn new(lua_ref: i32, state: &'a State) -> Self {
Self { Self {
lref: Arc::new(lua_ref), lref: Arc::new(lua_ref),
state, state: state.ptr.clone(),
_marker: PhantomData
}
}
pub(crate) fn from_arc(lua_ref: Arc<i32>, state: &'a State) -> Self {
Self {
lref: lua_ref,
state: state.ptr.clone(),
_marker: PhantomData
} }
} }

View File

@ -101,12 +101,12 @@ struct ClosureData<'a> {
} }
//#[derive(Default)] //#[derive(Default)]
pub struct ExtraSpace<'a> { pub struct ExtraSpace {
pub userdata_metatables: HashMap<TypeId, LuaRef<'a>>, pub userdata_metatables: HashMap<TypeId, Arc<i32>>,
pub state_ptr: StatePtr pub state_ptr: StatePtr,
} }
impl<'a> ExtraSpace<'a> { impl ExtraSpace {
pub fn new(state: &State) -> Self { pub fn new(state: &State) -> Self {
Self { Self {
userdata_metatables: Default::default(), userdata_metatables: Default::default(),
@ -117,11 +117,11 @@ impl<'a> ExtraSpace<'a> {
#[derive(Clone)] #[derive(Clone)]
pub struct StatePtr { pub struct StatePtr {
lua: Arc<NonNull<lua::lua_State>>, pub lua: Arc<NonNull<lua::lua_State>>,
} }
pub struct State { pub struct State {
ptr: StatePtr, pub(crate) ptr: StatePtr,
} }
impl Default for State { impl Default for State {
@ -133,28 +133,31 @@ impl Default for State {
impl Drop for State { impl Drop for State {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
println!("State count: {}", Arc::strong_count(&self.ptr.lua)); let extra = self.get_extra_space_ptr();
if Arc::strong_count(&self.ptr.lua) == 2 { // this owned arc, and the one in extra
let extra = self.get_extra_space_ptr();
{ {
// clear the refs to anything in lua before we close it and // clear the refs to anything in lua before we close it and
// attempt to drop extra after // attempt to drop extra after
let extra = &mut *extra;
extra.userdata_metatables.clear(); let s = self.state_ptr();
let extra = &mut *extra;
for (_, lref) in extra.userdata_metatables.drain() {
lua::luaL_unref(s, lua::LUA_REGISTRYINDEX, *lref);
} }
lua::lua_close(self.state_ptr()); extra.userdata_metatables.clear();
extra.drop_in_place();
// must be dealloced since it wasn't memory created from lua (i.e. userdata)
alloc::dealloc(extra.cast(), Layout::new::<ExtraSpace>());
} }
lua::lua_close(self.state_ptr());
extra.drop_in_place();
// must be dealloced since it wasn't memory created from lua
alloc::dealloc(extra.cast(), Layout::new::<ExtraSpace>());
} }
} }
} }
unsafe fn extra_space<'a>(state: *mut lua::lua_State) -> *mut ExtraSpace<'a> { unsafe fn extra_space(state: *mut lua::lua_State) -> *mut ExtraSpace {
let extra = lua::lua_getextraspace(state) let extra = lua::lua_getextraspace(state)
.cast::<*mut ExtraSpace>(); .cast::<*mut ExtraSpace>();
*extra *extra
@ -432,15 +435,14 @@ impl State {
} }
}; };
let data = ClosureData {
wrapper_fn: Box::new(wrapper_fn),
};
let s = self.state_ptr();
unsafe { unsafe {
let s = self.state_ptr();
let _g = StackGuard::new(self); let _g = StackGuard::new(self);
self.ensure_stack(3)?; self.ensure_stack(4)?;
let data = ClosureData {
wrapper_fn: Box::new(wrapper_fn),
};
let ptr = lua::lua_newuserdata(s, mem::size_of::<ClosureData>()); let ptr = lua::lua_newuserdata(s, mem::size_of::<ClosureData>());
let ptr = ptr.cast(); let ptr = ptr.cast();
@ -480,8 +482,11 @@ impl State {
let mt = self.create_userdata_metatable::<T>()?; let mt = self.create_userdata_metatable::<T>()?;
mt.push_to_lua_stack(self)?; mt.push_to_lua_stack(self)?;
} else { } else {
let mt = udmts.get(&TypeId::of::<T>()).unwrap(); let mt = self.get_userdata_metatable::<T>();
mt.push_to_lua_stack(self)?; mt.push_to_lua_stack(self)?;
/* let mt = udmts.get(&TypeId::of::<T>()).unwrap();
mt.push_to_lua_stack(self)?; */
} }
lua::lua_setmetatable(s, -2); lua::lua_setmetatable(s, -2);
@ -494,7 +499,7 @@ impl State {
let extra = self.get_extra_space(); let extra = self.get_extra_space();
let mt = extra.userdata_metatables.get(&TypeId::of::<T>()); let mt = extra.userdata_metatables.get(&TypeId::of::<T>());
mt.map(|r| Table::with_ref(self, r.clone(), true).unwrap()) mt.map(|r| Table::with_ref(self, LuaRef::from_arc(r.clone(), self), true).unwrap())
} }
pub(crate) fn create_userdata_metatable<'a, T: Userdata + 'static>(&'a self) -> Result<Table<'a>> { pub(crate) fn create_userdata_metatable<'a, T: Userdata + 'static>(&'a self) -> Result<Table<'a>> {
@ -577,7 +582,7 @@ impl State {
} }
let extra = self.get_extra_space(); let extra = self.get_extra_space();
extra.userdata_metatables.insert(TypeId::of::<T>(), mt.lref.clone()); extra.userdata_metatables.insert(TypeId::of::<T>(), mt.lref.lref.clone());
Ok(mt) Ok(mt)
} }

View File

@ -1,6 +1,6 @@
use std::marker::PhantomData; use std::{marker::PhantomData, mem};
use crate::{FromLua, FromLuaStack, LuaRef, PushToLuaStack, StackGuard, Value}; use crate::{FromLua, FromLuaStack, LuaRef, PushToLuaStack, StackGuard, State, Value};
use mlua_sys as lua; use mlua_sys as lua;
@ -25,15 +25,16 @@ impl<'a, T: FromLua<'a>> Iterator for TableIter<'a, T> {
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
unsafe { unsafe {
let state = self.lref.state; let state: &State = mem::transmute(&self.lref.state.lua);
let s = state.state_ptr(); let s = state.state_ptr();
if let Err(e) = state.ensure_stack(1) { if let Err(e) = state.ensure_stack(1) {
return Some(Err(e)); return Some(Err(e));
} }
let _g = StackGuard::new(state); let _g = StackGuard::new(&state);
if let Err(e) = self.lref.push_to_lua_stack(state) { if let Err(e) = self.lref.push_to_lua_stack(&state) {
return Some(Err(e)); return Some(Err(e));
} }

View File

@ -1,6 +1,6 @@
use std::marker::PhantomData; use std::{marker::PhantomData, mem};
use crate::{AsLua, FromLua, FromLuaStack, LuaRef, PushToLuaStack, StackGuard, Value}; use crate::{AsLua, FromLua, FromLuaStack, LuaRef, PushToLuaStack, StackGuard, State, Value};
use mlua_sys as lua; use mlua_sys as lua;
@ -20,7 +20,7 @@ impl<'a, K: FromLua<'a>, V: FromLua<'a>> TablePairs<'a, K, V> {
} }
unsafe fn get_item(&mut self) -> crate::Result<Option<(K, V)>> { unsafe fn get_item(&mut self) -> crate::Result<Option<(K, V)>> {
let state = self.lref.state; let state: &State = mem::transmute(&self.lref.state.lua);
let s = state.state_ptr(); let s = state.state_ptr();
let _g = StackGuard::new(state); let _g = StackGuard::new(state);