Store userdata pointers as RefCells to avoid multiple mutable borrows at once

This commit is contained in:
SeanOMik 2024-01-29 11:03:49 -05:00
parent a3dbe82613
commit eebb93a9a6
Signed by: SeanOMik
GPG Key ID: FEC9E2FC15235964
3 changed files with 25 additions and 25 deletions

View File

@ -1,4 +1,4 @@
use std::{marker::PhantomData, sync::Arc}; use std::{cell::Ref, marker::PhantomData, sync::Arc};
use mlua_sys as lua; use mlua_sys as lua;
@ -392,7 +392,6 @@ pub struct Vec2 {
} }
impl Userdata for Vec2 { impl Userdata for Vec2 {
fn build<'a>(builder: &mut UserdataBuilder<'a, Vec2>) -> crate::Result<()> { fn build<'a>(builder: &mut UserdataBuilder<'a, Vec2>) -> crate::Result<()> {
builder builder
.field_getter("x", |_, this| Ok(this.x)) .field_getter("x", |_, this| Ok(this.x))
@ -406,7 +405,7 @@ impl Userdata for Vec2 {
}) })
// method test // method test
.method("add", |lua, lhs: &Vec2, (rhs,): (&Vec3,)| { .method("add", |lua, lhs: &Vec2, (rhs,): (Ref<Vec3>,)| {
let lx = lhs.x; let lx = lhs.x;
let ly = lhs.y; let ly = lhs.y;
@ -416,7 +415,7 @@ impl Userdata for Vec2 {
lua.create_userdata(Vec2 { x: lx + rx, y: ly + ry, }) lua.create_userdata(Vec2 { x: lx + rx, y: ly + ry, })
}) })
.meta_method(MetaMethod::Add, |lua, lhs: &Vec2, (rhs,): (&Vec2,)| { .meta_method(MetaMethod::Add, |lua, lhs: &Vec2, (rhs,): (Ref<Vec2>,)| {
let lx = lhs.x; let lx = lhs.x;
let ly = lhs.y; let ly = lhs.y;
@ -450,7 +449,6 @@ impl<'a, T: Userdata> AsLua<'a> for UserdataProxy<T> {
impl<T: Userdata> Userdata for UserdataProxy<T> { impl<T: Userdata> Userdata for UserdataProxy<T> {
fn build<'a>(builder: &mut UserdataBuilder<'a, Self>) -> crate::Result<()> { fn build<'a>(builder: &mut UserdataBuilder<'a, Self>) -> crate::Result<()> {
let mut other = UserdataBuilder::<T>::new(); let mut other = UserdataBuilder::<T>::new();
T::build(&mut other)?; T::build(&mut other)?;

View File

@ -1,5 +1,5 @@
use core::ffi; use core::ffi;
use std::{alloc::{self, Layout}, any::TypeId, collections::HashMap, ffi::{CStr, CString}, mem, ptr::{self, NonNull}, str::Utf8Error, sync::Arc}; use std::{alloc::{self, Layout}, any::TypeId, cell::RefCell, collections::HashMap, ffi::{CStr, CString}, mem, ptr::{self, NonNull}, str::Utf8Error, sync::Arc};
use mlua_sys as lua; use mlua_sys as lua;
@ -402,8 +402,8 @@ impl State {
let _g = StackGuard::new(self); let _g = StackGuard::new(self);
let s = self.state_ptr(); let s = self.state_ptr();
let ptr = lua::lua_newuserdata(s, mem::size_of::<T>()).cast::<T>(); let ptr = lua::lua_newuserdata(s, mem::size_of::<RefCell<T>>()).cast::<RefCell<T>>();
ptr::write(ptr, data); ptr::write(ptr, RefCell::new(data));
// get the current metatable, or create a new one and push it to the stack. // get the current metatable, or create a new one and push it to the stack.
let udmts = &mut self.get_extra_space().userdata_metatables; let udmts = &mut self.get_extra_space().userdata_metatables;

View File

@ -1,4 +1,4 @@
use std::{collections::HashMap, ffi::CStr, marker::PhantomData}; use std::{borrow::Borrow, cell::{Ref, RefCell, RefMut}, collections::HashMap, ffi::CStr, marker::PhantomData, ops::{Deref, DerefMut}};
use crate::{ensure_type, AsLua, FromLua, FromLuaStack, FromLuaVec, LuaRef, PushToLuaStack, StackGuard, State, Value, ValueVec}; use crate::{ensure_type, AsLua, FromLua, FromLuaStack, FromLuaVec, LuaRef, PushToLuaStack, StackGuard, State, Value, ValueVec};
@ -133,7 +133,7 @@ impl<'a, T: Userdata> UserdataBuilder<'a, T> {
let val = val.pop_front().unwrap(); let val = val.pop_front().unwrap();
let this = val.as_userdata().unwrap(); // if this panics, its a bug let this = val.as_userdata().unwrap(); // if this panics, its a bug
let this = this.as_ref::<T>()?; let this = this.as_ref::<T>()?;
f(lua, this).and_then(|r| r.as_lua(lua)) f(lua, &*this).and_then(|r| r.as_lua(lua))
}; };
self.field_getters.insert(name.to_string(), Box::new(wrap)); self.field_getters.insert(name.to_string(), Box::new(wrap));
@ -149,12 +149,12 @@ impl<'a, T: Userdata> UserdataBuilder<'a, T> {
let wrap = move |lua: &'a State, mut val: ValueVec<'a>| { let wrap = move |lua: &'a State, mut val: ValueVec<'a>| {
let lua_val = val.pop_front().unwrap(); let lua_val = val.pop_front().unwrap();
let this = lua_val.as_userdata().unwrap(); // if this panics, its a bug let this = lua_val.as_userdata().unwrap(); // if this panics, its a bug
let this = this.as_mut::<T>()?; let mut this = this.as_mut::<T>()?;
let lua_val = val.pop_front().unwrap(); let lua_val = val.pop_front().unwrap();
let v_arg = V::from_lua(lua, lua_val)?; let v_arg = V::from_lua(lua, lua_val)?;
f(lua, this, v_arg).as_lua(lua) f(lua, this.deref_mut(), v_arg).as_lua(lua)
}; };
self.field_setters.insert(name.to_string(), Box::new(wrap)); self.field_setters.insert(name.to_string(), Box::new(wrap));
@ -211,7 +211,7 @@ impl<'a, T: Userdata> UserdataBuilder<'a, T> {
&this_name, &fn_name &this_name, &fn_name
)?; )?;
f(lua, this, args).and_then(|r| r.as_lua(lua)) f(lua, &*this, args).and_then(|r| r.as_lua(lua))
}; };
self.functions.insert(name.to_string(), Box::new(wrap)); self.functions.insert(name.to_string(), Box::new(wrap));
@ -238,7 +238,7 @@ impl<'a, T: Userdata> UserdataBuilder<'a, T> {
&this_name, &fn_name &this_name, &fn_name
)?; )?;
f(lua, this, args).and_then(|r| r.as_lua(lua)) f(lua, &*this, args).and_then(|r| r.as_lua(lua))
}; };
self.meta_methods.insert(name.as_ref().to_string(), Box::new(wrap)); self.meta_methods.insert(name.as_ref().to_string(), Box::new(wrap));
@ -268,7 +268,7 @@ impl<'a> AnyUserdata<'a> {
} }
/// Returns a borrow to the userdata. /// Returns a borrow to the userdata.
pub fn as_ref<T: Userdata + 'static>(&self) -> crate::Result<&'a T> { pub fn as_ref<T: Userdata + 'static>(&self) -> crate::Result<Ref<'a, T>> {
unsafe { unsafe {
self.state.ensure_stack(3)?; self.state.ensure_stack(3)?;
let _g = StackGuard::new(self.state); let _g = StackGuard::new(self.state);
@ -285,7 +285,8 @@ impl<'a> AnyUserdata<'a> {
if lua::lua_rawequal(s, -2, -1) == 1 { if lua::lua_rawequal(s, -2, -1) == 1 {
let cptr = lua::lua_touserdata(s, -3); let cptr = lua::lua_touserdata(s, -3);
Ok(cptr.cast::<T>().as_ref().unwrap()) // TODO: Ensure this userdata matches the type of T let cell = &*cptr.cast::<RefCell<T>>();
Ok(cell.borrow())
} else { } else {
return Err(crate::Error::UserdataMismatch); return Err(crate::Error::UserdataMismatch);
} }
@ -293,7 +294,7 @@ impl<'a> AnyUserdata<'a> {
} }
/// Returns a mutable reference to the userdata. /// Returns a mutable reference to the userdata.
pub fn as_mut<T: Userdata + 'static>(&self) -> crate::Result<&'a mut T> { pub fn as_mut<T: Userdata + 'static>(&self) -> crate::Result<RefMut<'a, T>> {
unsafe { unsafe {
self.state.ensure_stack(3)?; self.state.ensure_stack(3)?;
let _g = StackGuard::new(self.state); let _g = StackGuard::new(self.state);
@ -310,14 +311,15 @@ impl<'a> AnyUserdata<'a> {
if lua::lua_rawequal(s, -2, -1) == 1 { if lua::lua_rawequal(s, -2, -1) == 1 {
let cptr = lua::lua_touserdata(s, -3); let cptr = lua::lua_touserdata(s, -3);
Ok(cptr.cast::<T>().as_mut().unwrap()) let cell = &*cptr.cast::<RefCell<T>>();
Ok(cell.borrow_mut())
} else { } else {
Err(crate::Error::UserdataMismatch) Err(crate::Error::UserdataMismatch)
} }
} }
} }
/// Returns a mutable pointer to the userdata **WITHOUT verifying the type of it**. /// Returns a mutable pointer of the [`RefCell`] of userdata **WITHOUT verifying the type of it**.
/// ///
/// # Safety /// # Safety
/// * You must be certain that the type `T` is the same type that this userdata has a handle to. /// * You must be certain that the type `T` is the same type that this userdata has a handle to.
@ -325,7 +327,7 @@ impl<'a> AnyUserdata<'a> {
/// ///
/// If there is a possibility that these types do not match, use [`AnyUserdata::as_ptr`] /// If there is a possibility that these types do not match, use [`AnyUserdata::as_ptr`]
/// which does verify the types. /// which does verify the types.
pub unsafe fn as_ptr_unchecked<T: Userdata + 'static>(&self) -> crate::Result<*mut T> { pub unsafe fn as_ptr_unchecked<T: Userdata + 'static>(&self) -> crate::Result<*mut RefCell<T>> {
self.state.ensure_stack(1)?; self.state.ensure_stack(1)?;
let _g = StackGuard::new(self.state); let _g = StackGuard::new(self.state);
let s = self.state.state_ptr(); let s = self.state.state_ptr();
@ -336,11 +338,11 @@ impl<'a> AnyUserdata<'a> {
Ok(cptr.cast()) Ok(cptr.cast())
} }
/// Returns a mutable pointer to the userdata. /// Returns a mutable pointer of the [`RefCell`] storing the userdata.
/// ///
/// This function ensures that the type of the userdata this struct has a handle to is the /// This function ensures that the type of the userdata this struct has a handle to is the
/// same as `T`. If it isn't, a `UserdataMismatch` error will be returned. /// same as `T`. If it isn't, a `UserdataMismatch` error will be returned.
pub unsafe fn as_ptr<T: Userdata + 'static>(&self) -> crate::Result<*mut T> { pub unsafe fn as_ptr<T: Userdata + 'static>(&self) -> crate::Result<*mut RefCell<T>> {
let _g = StackGuard::new(self.state); let _g = StackGuard::new(self.state);
let s = self.state.state_ptr(); let s = self.state.state_ptr();
@ -355,7 +357,7 @@ impl<'a> AnyUserdata<'a> {
if lua::lua_rawequal(s, -2, -1) == 1 { if lua::lua_rawequal(s, -2, -1) == 1 {
let cptr = lua::lua_touserdata(s, -3); let cptr = lua::lua_touserdata(s, -3);
Ok(cptr.cast::<T>()) Ok(cptr.cast())
} else { } else {
Err(crate::Error::UserdataMismatch) Err(crate::Error::UserdataMismatch)
} }
@ -418,14 +420,14 @@ impl<'a> FromLua<'a> for AnyUserdata<'a> {
} }
} }
impl<'a, T: Userdata + 'static> FromLua<'a> for &'a T { impl<'a, T: Userdata + 'static> FromLua<'a> for Ref<'a, T> {
fn from_lua(_lua: &'a State, val: Value<'a>) -> crate::Result<Self> { fn from_lua(_lua: &'a State, val: Value<'a>) -> crate::Result<Self> {
let ud = val.into_userdata()?; let ud = val.into_userdata()?;
ud.as_ref::<T>() ud.as_ref::<T>()
} }
} }
impl<'a, T: Userdata + 'static> FromLua<'a> for &'a mut T { impl<'a, T: Userdata + 'static> FromLua<'a> for RefMut<'a, T> {
fn from_lua(_lua: &'a State, val: Value<'a>) -> crate::Result<Self> { fn from_lua(_lua: &'a State, val: Value<'a>) -> crate::Result<Self> {
let ud = val.into_userdata()?; let ud = val.into_userdata()?;
ud.as_mut::<T>() ud.as_mut::<T>()