Skip to content

Commit 30e75da

Browse files
committed
Improve/fix scoped UserData drop
1 parent 00b80a4 commit 30e75da

File tree

3 files changed

+153
-62
lines changed

3 files changed

+153
-62
lines changed

src/scope.rs

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,14 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
186186
let state = u.lua.state;
187187
assert_stack(state, 2);
188188
u.lua.push_ref(&u);
189+
190+
// Clear uservalue
191+
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
192+
ffi::lua_pushnil(state);
193+
#[cfg(any(feature = "lua51", feature = "luajit"))]
194+
ffi::lua_newtable(state);
195+
ffi::lua_setuservalue(state, -2);
196+
189197
// We know the destructor has not run yet because we hold a reference to the
190198
// userdata.
191199
vec![Box::new(take_userdata::<UserDataCell<T>>(state))]
@@ -244,28 +252,28 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
244252
let check_ud_type = move |lua: &'callback Lua, value| {
245253
if let Some(Value::UserData(ud)) = value {
246254
unsafe {
247-
assert_stack(lua.state, 1);
255+
let _sg = StackGuard::new(lua.state);
256+
assert_stack(lua.state, 3);
248257
lua.push_ref(&ud.0);
249-
ffi::lua_getuservalue(lua.state, -1);
250-
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
251-
{
252-
ffi::lua_rawgeti(lua.state, -1, 1);
253-
ffi::lua_remove(lua.state, -2);
258+
if ffi::lua_getmetatable(lua.state, -1) == 0 {
259+
return Err(Error::UserDataTypeMismatch);
260+
}
261+
ffi::lua_pushstring(lua.state, cstr!("__mlua"));
262+
if ffi::lua_rawget(lua.state, -2) == ffi::LUA_TLIGHTUSERDATA {
263+
let ud_ptr = ffi::lua_touserdata(lua.state, -1);
264+
if ud_ptr == check_data.as_ptr() as *mut c_void {
265+
return Ok(());
266+
}
254267
}
255-
return ffi::lua_touserdata(lua.state, -1)
256-
== check_data.as_ptr() as *mut c_void;
257268
}
258-
}
259-
260-
false
269+
};
270+
Err(Error::UserDataTypeMismatch)
261271
};
262272

263273
match method {
264274
NonStaticMethod::Method(method) => {
265275
let f = Box::new(move |lua, mut args: MultiValue<'callback>| {
266-
if !check_ud_type(lua, args.pop_front()) {
267-
return Err(Error::UserDataTypeMismatch);
268-
}
276+
check_ud_type(lua, args.pop_front())?;
269277
let data = data
270278
.try_borrow()
271279
.map(|cell| Ref::map(cell, AsRef::as_ref))
@@ -277,9 +285,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
277285
NonStaticMethod::MethodMut(method) => {
278286
let method = RefCell::new(method);
279287
let f = Box::new(move |lua, mut args: MultiValue<'callback>| {
280-
if !check_ud_type(lua, args.pop_front()) {
281-
return Err(Error::UserDataTypeMismatch);
282-
}
288+
check_ud_type(lua, args.pop_front())?;
283289
let mut method = method
284290
.try_borrow_mut()
285291
.map_err(|_| Error::RecursiveMutCallback)?;
@@ -314,23 +320,18 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
314320
unsafe {
315321
let lua = self.lua;
316322
let _sg = StackGuard::new(lua.state);
317-
assert_stack(lua.state, 6);
323+
assert_stack(lua.state, 13);
318324

319325
push_userdata(lua.state, ())?;
320-
#[cfg(any(feature = "lua54", feature = "lua53"))]
321-
ffi::lua_pushlightuserdata(lua.state, data.as_ptr() as *mut c_void);
322-
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
323-
protect_lua_closure(lua.state, 0, 1, |state| {
324-
// Lua 5.2/5.1 allows to store only table. Then we will wrap the value.
325-
ffi::lua_createtable(state, 1, 0);
326-
ffi::lua_pushlightuserdata(state, data.as_ptr() as *mut c_void);
327-
ffi::lua_rawseti(state, -2, 1);
328-
})?;
329-
ffi::lua_setuservalue(lua.state, -2);
330326

331327
// Prepare metatable, add meta methods first and then meta fields
332-
protect_lua_closure(lua.state, 0, 1, move |state| {
328+
protect_lua_closure(lua.state, 0, 1, |state| {
333329
ffi::lua_newtable(state);
330+
331+
// Add internal metamethod to store reference to the data
332+
ffi::lua_pushstring(state, cstr!("__mlua"));
333+
ffi::lua_pushlightuserdata(lua.state, data.as_ptr() as *mut c_void);
334+
ffi::lua_rawset(state, -3);
334335
})?;
335336
for (k, m) in ud_methods.meta_methods {
336337
push_string(lua.state, k.validate()?.name())?;
@@ -413,8 +414,26 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
413414
ffi::lua_pop(lua.state, count);
414415

415416
ffi::lua_setmetatable(lua.state, -2);
417+
let ud = AnyUserData(lua.pop_ref());
418+
419+
self.destructors.borrow_mut().push((ud.0.clone(), |u| {
420+
let state = u.lua.state;
421+
assert_stack(state, 2);
422+
u.lua.push_ref(&u);
423+
424+
// Clear uservalue
425+
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
426+
ffi::lua_pushnil(state);
427+
#[cfg(any(feature = "lua51", feature = "luajit"))]
428+
ffi::lua_newtable(state);
429+
ffi::lua_setuservalue(state, -2);
430+
431+
// We know the destructor has not run yet because we hold a reference to the
432+
// userdata.
433+
vec![Box::new(take_userdata::<()>(state))]
434+
}));
416435

417-
Ok(AnyUserData(lua.pop_ref()))
436+
Ok(ud)
418437
}
419438
}
420439

src/userdata.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ impl MetaMethod {
186186
MetaMethod::Custom(name) if name == "__metatable" => {
187187
Err(Error::MetaMethodRestricted(name))
188188
}
189+
MetaMethod::Custom(name) if name == "__mlua" => Err(Error::MetaMethodRestricted(name)),
189190
_ => Ok(self),
190191
}
191192
}

tests/scope.rs

Lines changed: 103 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@ extern "system" {}
1212

1313
use std::cell::Cell;
1414
use std::rc::Rc;
15+
use std::sync::Arc;
1516

16-
use mlua::{Error, Function, Lua, MetaMethod, Result, String, UserData, UserDataMethods, UserDataFields};
17+
use mlua::{
18+
AnyUserData, Error, Function, Lua, MetaMethod, Result, String, UserData, UserDataFields,
19+
UserDataMethods,
20+
};
1721

1822
#[test]
1923
fn scope_func() -> Result<()> {
@@ -35,42 +39,16 @@ fn scope_func() -> Result<()> {
3539
assert_eq!(Rc::strong_count(&rc), 1);
3640

3741
match lua.globals().get::<_, Function>("bad")?.call::<_, ()>(()) {
38-
Err(Error::CallbackError { .. }) => {}
42+
Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() {
43+
Error::CallbackDestructed => {}
44+
ref err => panic!("wrong error type {:?}", err),
45+
},
3946
r => panic!("improper return for destructed function: {:?}", r),
4047
};
4148

4249
Ok(())
4350
}
4451

45-
#[test]
46-
fn scope_drop() -> Result<()> {
47-
let lua = Lua::new();
48-
49-
struct MyUserdata(Rc<()>);
50-
impl UserData for MyUserdata {
51-
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
52-
methods.add_method("method", |_, _, ()| Ok(()));
53-
}
54-
}
55-
56-
let rc = Rc::new(());
57-
58-
lua.scope(|scope| {
59-
lua.globals()
60-
.set("test", scope.create_userdata(MyUserdata(rc.clone()))?)?;
61-
assert_eq!(Rc::strong_count(&rc), 2);
62-
Ok(())
63-
})?;
64-
assert_eq!(Rc::strong_count(&rc), 1);
65-
66-
match lua.load("test:method()").exec() {
67-
Err(Error::CallbackError { .. }) => {}
68-
r => panic!("improper return for destructed userdata: {:?}", r),
69-
};
70-
71-
Ok(())
72-
}
73-
7452
#[test]
7553
fn scope_capture() -> Result<()> {
7654
let lua = Lua::new();
@@ -90,7 +68,7 @@ fn scope_capture() -> Result<()> {
9068
}
9169

9270
#[test]
93-
fn outer_lua_access() -> Result<()> {
71+
fn scope_outer_lua_access() -> Result<()> {
9472
let lua = Lua::new();
9573

9674
let table = lua.create_table()?;
@@ -273,3 +251,96 @@ fn scope_userdata_mismatch() -> Result<()> {
273251

274252
Ok(())
275253
}
254+
255+
#[test]
256+
fn scope_userdata_drop() -> Result<()> {
257+
let lua = Lua::new();
258+
259+
struct MyUserData(Rc<()>);
260+
impl UserData for MyUserData {
261+
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
262+
methods.add_method("method", |_, _, ()| Ok(()));
263+
}
264+
}
265+
266+
struct MyUserDataArc(Arc<()>);
267+
impl UserData for MyUserDataArc {}
268+
269+
let rc = Rc::new(());
270+
let arc = Arc::new(());
271+
lua.scope(|scope| {
272+
let ud = scope.create_userdata(MyUserData(rc.clone()))?;
273+
ud.set_user_value(MyUserDataArc(arc.clone()))?;
274+
lua.globals().set("ud", ud)?;
275+
assert_eq!(Rc::strong_count(&rc), 2);
276+
assert_eq!(Arc::strong_count(&arc), 2);
277+
Ok(())
278+
})?;
279+
280+
lua.gc_collect()?;
281+
assert_eq!(Rc::strong_count(&rc), 1);
282+
assert_eq!(Arc::strong_count(&arc), 1);
283+
284+
match lua.load("ud:method()").exec() {
285+
Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() {
286+
Error::CallbackDestructed => {}
287+
ref err => panic!("wrong error type {:?}", err),
288+
},
289+
r => panic!("improper return for destructed userdata: {:?}", r),
290+
};
291+
292+
let ud = lua.globals().get::<_, AnyUserData>("ud")?;
293+
match ud.borrow::<MyUserData>() {
294+
Ok(_) => panic!("succesfull borrow for destructed userdata"),
295+
Err(Error::UserDataDestructed) => {}
296+
Err(err) => panic!("improper borrow error for destructed userdata: {:?}", err),
297+
}
298+
299+
Ok(())
300+
}
301+
302+
#[test]
303+
fn scope_nonstatic_userdata_drop() -> Result<()> {
304+
let lua = Lua::new();
305+
306+
struct MyUserData<'a>(&'a i64);
307+
impl<'a> UserData for MyUserData<'a> {
308+
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
309+
methods.add_method("method", |_, _, ()| Ok(()));
310+
}
311+
}
312+
313+
struct MyUserDataArc(Arc<()>);
314+
impl UserData for MyUserDataArc {}
315+
316+
let i = 422;
317+
let arc = Arc::new(());
318+
lua.scope(|scope| {
319+
let ud = scope.create_nonstatic_userdata(MyUserData(&i))?;
320+
ud.set_user_value(MyUserDataArc(arc.clone()))?;
321+
lua.globals().set("ud", ud)?;
322+
lua.load("ud:method()").exec()?;
323+
assert_eq!(Arc::strong_count(&arc), 2);
324+
Ok(())
325+
})?;
326+
327+
lua.gc_collect()?;
328+
assert_eq!(Arc::strong_count(&arc), 1);
329+
330+
match lua.load("ud:method()").exec() {
331+
Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() {
332+
Error::CallbackDestructed => {}
333+
ref err => panic!("wrong error type {:?}", err),
334+
},
335+
r => panic!("improper return for destructed userdata: {:?}", r),
336+
};
337+
338+
let ud = lua.globals().get::<_, AnyUserData>("ud")?;
339+
match ud.borrow::<MyUserData>() {
340+
Ok(_) => panic!("succesfull borrow for destructed userdata"),
341+
Err(Error::UserDataDestructed) => {}
342+
Err(err) => panic!("improper borrow error for destructed userdata: {:?}", err),
343+
}
344+
345+
Ok(())
346+
}

0 commit comments

Comments
 (0)