Skip to content

Commit 9cbb029

Browse files
committed
Add LuaNativeFn/LuaNativeFnMut/LuaNativeAsyncFn traits for using in Function::wrap
1 parent 8274b5f commit 9cbb029

9 files changed

+292
-52
lines changed

src/function.rs

+56-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::{mem, ptr, slice};
55
use crate::error::{Error, Result};
66
use crate::state::Lua;
77
use crate::table::Table;
8+
use crate::traits::{LuaNativeFn, LuaNativeFnMut};
89
use crate::types::{Callback, LuaType, MaybeSend, ValueRef};
910
use crate::util::{
1011
assert_stack, check_stack, linenumber_to_usize, pop_error, ptr_to_lossy_str, ptr_to_str, StackGuard,
@@ -13,6 +14,7 @@ use crate::value::{FromLuaMulti, IntoLua, IntoLuaMulti, Value};
1314

1415
#[cfg(feature = "async")]
1516
use {
17+
crate::traits::LuaNativeAsyncFn,
1618
crate::types::AsyncCallback,
1719
std::future::{self, Future},
1820
};
@@ -522,55 +524,97 @@ impl Function {
522524
/// Wraps a Rust function or closure, returning an opaque type that implements [`IntoLua`]
523525
/// trait.
524526
#[inline]
525-
pub fn wrap<A, R, F>(func: F) -> impl IntoLua
527+
pub fn wrap<F, A, R>(func: F) -> impl IntoLua
526528
where
529+
F: LuaNativeFn<A, Output = Result<R>> + MaybeSend + 'static,
527530
A: FromLuaMulti,
528531
R: IntoLuaMulti,
529-
F: Fn(&Lua, A) -> Result<R> + MaybeSend + 'static,
530532
{
531533
WrappedFunction(Box::new(move |lua, nargs| unsafe {
532534
let args = A::from_stack_args(nargs, 1, None, lua)?;
533-
func(lua.lua(), args)?.push_into_stack_multi(lua)
535+
func.call(args)?.push_into_stack_multi(lua)
534536
}))
535537
}
536538

537539
/// Wraps a Rust mutable closure, returning an opaque type that implements [`IntoLua`] trait.
538-
#[inline]
539-
pub fn wrap_mut<A, R, F>(func: F) -> impl IntoLua
540+
pub fn wrap_mut<F, A, R>(func: F) -> impl IntoLua
540541
where
542+
F: LuaNativeFnMut<A, Output = Result<R>> + MaybeSend + 'static,
541543
A: FromLuaMulti,
542544
R: IntoLuaMulti,
543-
F: FnMut(&Lua, A) -> Result<R> + MaybeSend + 'static,
544545
{
545546
let func = RefCell::new(func);
546547
WrappedFunction(Box::new(move |lua, nargs| unsafe {
547548
let mut func = func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?;
548549
let args = A::from_stack_args(nargs, 1, None, lua)?;
549-
func(lua.lua(), args)?.push_into_stack_multi(lua)
550+
func.call(args)?.push_into_stack_multi(lua)
551+
}))
552+
}
553+
554+
#[inline]
555+
pub fn wrap_raw<F, A>(func: F) -> impl IntoLua
556+
where
557+
F: LuaNativeFn<A> + MaybeSend + 'static,
558+
A: FromLuaMulti,
559+
{
560+
WrappedFunction(Box::new(move |lua, nargs| unsafe {
561+
let args = A::from_stack_args(nargs, 1, None, lua)?;
562+
func.call(args).push_into_stack_multi(lua)
563+
}))
564+
}
565+
566+
#[inline]
567+
pub fn wrap_raw_mut<F, A>(func: F) -> impl IntoLua
568+
where
569+
F: LuaNativeFnMut<A> + MaybeSend + 'static,
570+
A: FromLuaMulti,
571+
{
572+
let func = RefCell::new(func);
573+
WrappedFunction(Box::new(move |lua, nargs| unsafe {
574+
let mut func = func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?;
575+
let args = A::from_stack_args(nargs, 1, None, lua)?;
576+
func.call(args).push_into_stack_multi(lua)
550577
}))
551578
}
552579

553580
/// Wraps a Rust async function or closure, returning an opaque type that implements [`IntoLua`]
554581
/// trait.
555582
#[cfg(feature = "async")]
556583
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
557-
pub fn wrap_async<A, R, F, FR>(func: F) -> impl IntoLua
584+
pub fn wrap_async<F, A, R>(func: F) -> impl IntoLua
558585
where
586+
F: LuaNativeAsyncFn<A, Output = Result<R>> + MaybeSend + 'static,
559587
A: FromLuaMulti,
560588
R: IntoLuaMulti,
561-
F: Fn(Lua, A) -> FR + MaybeSend + 'static,
562-
FR: Future<Output = Result<R>> + MaybeSend + 'static,
563589
{
564590
WrappedAsyncFunction(Box::new(move |rawlua, nargs| unsafe {
565591
let args = match A::from_stack_args(nargs, 1, None, rawlua) {
566592
Ok(args) => args,
567593
Err(e) => return Box::pin(future::ready(Err(e))),
568594
};
569-
let lua = rawlua.lua().clone();
570-
let fut = func(lua.clone(), args);
595+
let lua = rawlua.lua();
596+
let fut = func.call(args);
571597
Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) })
572598
}))
573599
}
600+
601+
#[cfg(feature = "async")]
602+
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
603+
pub fn wrap_raw_async<F, A>(func: F) -> impl IntoLua
604+
where
605+
F: LuaNativeAsyncFn<A> + MaybeSend + 'static,
606+
A: FromLuaMulti,
607+
{
608+
WrappedAsyncFunction(Box::new(move |rawlua, nargs| unsafe {
609+
let args = match A::from_stack_args(nargs, 1, None, rawlua) {
610+
Ok(args) => args,
611+
Err(e) => return Box::pin(future::ready(Err(e))),
612+
};
613+
let lua = rawlua.lua();
614+
let fut = func.call(args);
615+
Box::pin(async move { fut.await.push_into_stack_multi(lua.raw_lua()) })
616+
}))
617+
}
574618
}
575619

576620
impl IntoLua for WrappedFunction {

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ pub use crate::stdlib::StdLib;
115115
pub use crate::string::{BorrowedBytes, BorrowedStr, String};
116116
pub use crate::table::{Table, TablePairs, TableSequence};
117117
pub use crate::thread::{Thread, ThreadStatus};
118-
pub use crate::traits::ObjectLike;
118+
pub use crate::traits::{LuaNativeAsyncFn, LuaNativeFn, LuaNativeFnMut, ObjectLike};
119119
pub use crate::types::{
120120
AppDataRef, AppDataRefMut, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState,
121121
};

src/prelude.rs

+10-9
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@ pub use crate::{
55
AnyUserData as LuaAnyUserData, Chunk as LuaChunk, Error as LuaError, ErrorContext as LuaErrorContext,
66
ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, FromLua, FromLuaMulti,
77
Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, Integer as LuaInteger,
8-
IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaOptions, MetaMethod as LuaMetaMethod,
9-
MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, ObjectLike as LuaObjectLike,
10-
RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib, String as LuaString,
11-
Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence, Thread as LuaThread,
12-
ThreadStatus as LuaThreadStatus, UserData as LuaUserData, UserDataFields as LuaUserDataFields,
13-
UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods,
14-
UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut,
15-
UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, VmState as LuaVmState,
8+
IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, LuaNativeFnMut, LuaOptions,
9+
MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber,
10+
ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib,
11+
String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence,
12+
Thread as LuaThread, ThreadStatus as LuaThreadStatus, UserData as LuaUserData,
13+
UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable,
14+
UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef,
15+
UserDataRefMut as LuaUserDataRefMut, UserDataRegistry as LuaUserDataRegistry, Value as LuaValue,
16+
VmState as LuaVmState,
1617
};
1718

1819
#[cfg(not(feature = "luau"))]
@@ -25,7 +26,7 @@ pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector};
2526

2627
#[cfg(feature = "async")]
2728
#[doc(no_inline)]
28-
pub use crate::AsyncThread as LuaAsyncThread;
29+
pub use crate::{AsyncThread as LuaAsyncThread, LuaNativeAsyncFn};
2930

3031
#[cfg(feature = "serialize")]
3132
#[doc(no_inline)]

src/state.rs

+6
Original file line numberDiff line numberDiff line change
@@ -1553,6 +1553,12 @@ impl Lua {
15531553
T::from_lua(value, self)
15541554
}
15551555

1556+
/// Converts a value that implements `IntoLua` into a `FromLua` variant.
1557+
#[inline]
1558+
pub fn convert<U: FromLua>(&self, value: impl IntoLua) -> Result<U> {
1559+
U::from_lua(value.into_lua(self)?, self)
1560+
}
1561+
15561562
/// Converts a value that implements `IntoLuaMulti` into a `MultiValue` instance.
15571563
#[inline]
15581564
pub fn pack_multi(&self, t: impl IntoLuaMulti) -> Result<MultiValue> {

src/traits.rs

+92
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::string::String as StdString;
22

33
use crate::error::Result;
44
use crate::private::Sealed;
5+
use crate::types::MaybeSend;
56
use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti};
67

78
#[cfg(feature = "async")]
@@ -76,3 +77,94 @@ pub trait ObjectLike: Sealed {
7677
/// This might invoke the `__tostring` metamethod.
7778
fn to_string(&self) -> Result<StdString>;
7879
}
80+
81+
/// A trait for types that can be used as Lua functions.
82+
pub trait LuaNativeFn<A: FromLuaMulti> {
83+
type Output: IntoLuaMulti;
84+
85+
fn call(&self, args: A) -> Self::Output;
86+
}
87+
88+
/// A trait for types with mutable state that can be used as Lua functions.
89+
pub trait LuaNativeFnMut<A: FromLuaMulti> {
90+
type Output: IntoLuaMulti;
91+
92+
fn call(&mut self, args: A) -> Self::Output;
93+
}
94+
95+
/// A trait for types that returns a future and can be used as Lua functions.
96+
#[cfg(feature = "async")]
97+
pub trait LuaNativeAsyncFn<A: FromLuaMulti> {
98+
type Output: IntoLuaMulti;
99+
100+
fn call(&self, args: A) -> impl Future<Output = Self::Output> + MaybeSend + 'static;
101+
}
102+
103+
macro_rules! impl_lua_native_fn {
104+
($($A:ident),*) => {
105+
impl<FN, $($A,)* R> LuaNativeFn<($($A,)*)> for FN
106+
where
107+
FN: Fn($($A,)*) -> R + MaybeSend + 'static,
108+
($($A,)*): FromLuaMulti,
109+
R: IntoLuaMulti,
110+
{
111+
type Output = R;
112+
113+
#[allow(non_snake_case)]
114+
fn call(&self, args: ($($A,)*)) -> Self::Output {
115+
let ($($A,)*) = args;
116+
self($($A,)*)
117+
}
118+
}
119+
120+
impl<FN, $($A,)* R> LuaNativeFnMut<($($A,)*)> for FN
121+
where
122+
FN: FnMut($($A,)*) -> R + MaybeSend + 'static,
123+
($($A,)*): FromLuaMulti,
124+
R: IntoLuaMulti,
125+
{
126+
type Output = R;
127+
128+
#[allow(non_snake_case)]
129+
fn call(&mut self, args: ($($A,)*)) -> Self::Output {
130+
let ($($A,)*) = args;
131+
self($($A,)*)
132+
}
133+
}
134+
135+
#[cfg(feature = "async")]
136+
impl<FN, $($A,)* Fut, R> LuaNativeAsyncFn<($($A,)*)> for FN
137+
where
138+
FN: Fn($($A,)*) -> Fut + MaybeSend + 'static,
139+
($($A,)*): FromLuaMulti,
140+
Fut: Future<Output = R> + MaybeSend + 'static,
141+
R: IntoLuaMulti,
142+
{
143+
type Output = R;
144+
145+
#[allow(non_snake_case)]
146+
fn call(&self, args: ($($A,)*)) -> impl Future<Output = Self::Output> + MaybeSend + 'static {
147+
let ($($A,)*) = args;
148+
self($($A,)*)
149+
}
150+
}
151+
};
152+
}
153+
154+
impl_lua_native_fn!();
155+
impl_lua_native_fn!(A);
156+
impl_lua_native_fn!(A, B);
157+
impl_lua_native_fn!(A, B, C);
158+
impl_lua_native_fn!(A, B, C, D);
159+
impl_lua_native_fn!(A, B, C, D, E);
160+
impl_lua_native_fn!(A, B, C, D, E, F);
161+
impl_lua_native_fn!(A, B, C, D, E, F, G);
162+
impl_lua_native_fn!(A, B, C, D, E, F, G, H);
163+
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I);
164+
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J);
165+
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K);
166+
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L);
167+
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M);
168+
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N);
169+
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O);
170+
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P);

tests/async.rs

+28-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![cfg(feature = "async")]
22

3+
use std::string::String as StdString;
34
use std::sync::Arc;
45
use std::time::Duration;
56

@@ -39,12 +40,38 @@ async fn test_async_function() -> Result<()> {
3940
async fn test_async_function_wrap() -> Result<()> {
4041
let lua = Lua::new();
4142

42-
let f = Function::wrap_async(|_, s: String| async move { Ok(s) });
43+
let f = Function::wrap_async(|s: StdString| async move {
44+
tokio::task::yield_now().await;
45+
Ok(s)
46+
});
4347
lua.globals().set("f", f)?;
48+
let res: String = lua.load(r#"f("hello")"#).eval_async().await?;
49+
assert_eq!(res, "hello");
4450

51+
Ok(())
52+
}
53+
54+
#[tokio::test]
55+
async fn test_async_function_wrap_raw() -> Result<()> {
56+
let lua = Lua::new();
57+
58+
let f = Function::wrap_raw_async(|s: StdString| async move {
59+
tokio::task::yield_now().await;
60+
s
61+
});
62+
lua.globals().set("f", f)?;
4563
let res: String = lua.load(r#"f("hello")"#).eval_async().await?;
4664
assert_eq!(res, "hello");
4765

66+
// Return error
67+
let ferr = Function::wrap_raw_async(|| async move {
68+
tokio::task::yield_now().await;
69+
Err::<(), _>("some error")
70+
});
71+
lua.globals().set("ferr", ferr)?;
72+
let (_, err): (Value, String) = lua.load(r#"ferr()"#).eval_async().await?;
73+
assert_eq!(err, "some error");
74+
4875
Ok(())
4976
}
5077

tests/chunk.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ fn test_chunk_macro() -> Result<()> {
4242
data.raw_set("num", 1)?;
4343

4444
let ud = mlua::AnyUserData::wrap("hello");
45-
let f = mlua::Function::wrap(|_lua, ()| Ok(()));
45+
let f = mlua::Function::wrap(|| Ok(()));
4646

4747
lua.globals().set("g", 123)?;
4848

0 commit comments

Comments
 (0)