diff --git a/lightning-background-processor/Cargo.toml b/lightning-background-processor/Cargo.toml index 80794ea3403..aa91378129f 100644 --- a/lightning-background-processor/Cargo.toml +++ b/lightning-background-processor/Cargo.toml @@ -15,7 +15,7 @@ rustdoc-args = ["--cfg", "docsrs"] [features] futures = [ ] -std = ["lightning/std", "bitcoin-io/std", "bitcoin_hashes/std"] +std = ["lightning/std", "lightning-liquidity/std", "bitcoin-io/std", "bitcoin_hashes/std"] default = ["std"] @@ -25,6 +25,7 @@ bitcoin_hashes = { version = "0.14.0", default-features = false } bitcoin-io = { version = "0.1.2", default-features = false } lightning = { version = "0.2.0", path = "../lightning", default-features = false } lightning-rapid-gossip-sync = { version = "0.2.0", path = "../lightning-rapid-gossip-sync", default-features = false } +lightning-liquidity = { version = "0.2.0", path = "../lightning-liquidity", default-features = false } [dev-dependencies] tokio = { version = "1.35", features = [ "macros", "rt", "rt-multi-thread", "sync", "time" ] } diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index 46d990bb37e..1f7147f3203 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -42,6 +42,8 @@ use lightning::util::persist::Persister; use lightning::util::wakers::Sleeper; use lightning_rapid_gossip_sync::RapidGossipSync; +use lightning_liquidity::ALiquidityManager; + use core::ops::Deref; use core::time::Duration; @@ -417,7 +419,9 @@ macro_rules! define_run_body { log_trace!($logger, "Pruning and persisting network graph."); network_graph.remove_stale_channels_and_tracking_with_time(duration_since_epoch.as_secs()); } else { - log_warn!($logger, "Not pruning network graph, consider enabling `std` or doing so manually with remove_stale_channels_and_tracking_with_time."); + log_warn!($logger, + "Not pruning network graph, consider implementing the fetch_time argument or calling remove_stale_channels_and_tracking_with_time manually." + ); log_trace!($logger, "Persisting network graph."); } @@ -492,27 +496,31 @@ pub(crate) mod futures_util { A: Future + Unpin, B: Future + Unpin, C: Future + Unpin, - D: Future + Unpin, + D: Future + Unpin, + E: Future + Unpin, > { pub a: A, pub b: B, pub c: C, pub d: D, + pub e: E, } pub(crate) enum SelectorOutput { A, B, C, - D(bool), + D, + E(bool), } impl< A: Future + Unpin, B: Future + Unpin, C: Future + Unpin, - D: Future + Unpin, - > Future for Selector + D: Future + Unpin, + E: Future + Unpin, + > Future for Selector { type Output = SelectorOutput; fn poll( @@ -537,8 +545,14 @@ pub(crate) mod futures_util { Poll::Pending => {}, } match Pin::new(&mut self.d).poll(ctx) { + Poll::Ready(()) => { + return Poll::Ready(SelectorOutput::D); + }, + Poll::Pending => {}, + } + match Pin::new(&mut self.e).poll(ctx) { Poll::Ready(res) => { - return Poll::Ready(SelectorOutput::D(res)); + return Poll::Ready(SelectorOutput::E(res)); }, Poll::Pending => {}, } @@ -600,11 +614,6 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// /// See [`BackgroundProcessor::start`] for information on which actions this handles. /// -/// Requires the `futures` feature. Note that while this method is available without the `std` -/// feature, doing so will skip calling [`NetworkGraph::remove_stale_channels_and_tracking`], -/// you should call [`NetworkGraph::remove_stale_channels_and_tracking_with_time`] regularly -/// manually instead. -/// /// The `mobile_interruptable_platform` flag should be set if we're currently running on a /// mobile device, where we may need to check for interruption of the application regularly. If you /// are unsure, you should set the flag, as the performance impact of it is minimal unless there @@ -648,6 +657,7 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// # type P2PGossipSync
    = lightning::routing::gossip::P2PGossipSync, Arc
      , Arc>; /// # type ChannelManager = lightning::ln::channelmanager::SimpleArcChannelManager, B, FE, Logger>; /// # type OnionMessenger = lightning::onion_message::messenger::OnionMessenger, Arc, Arc, Arc>, Arc, Arc, Arc>>, Arc>, lightning::ln::peer_handler::IgnoringMessageHandler, lightning::ln::peer_handler::IgnoringMessageHandler, lightning::ln::peer_handler::IgnoringMessageHandler>; +/// # type LiquidityManager = lightning_liquidity::LiquidityManager, Arc>, Arc>; /// # type Scorer = RwLock, Arc>>; /// # type PeerManager = lightning::ln::peer_handler::SimpleArcPeerManager, B, FE, Arc
        , Logger>; /// # @@ -661,6 +671,7 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// # event_handler: Arc, /// # channel_manager: Arc>, /// # onion_messenger: Arc>, +/// # liquidity_manager: Arc>, /// # chain_monitor: Arc>, /// # gossip_sync: Arc>, /// # persister: Arc, @@ -681,25 +692,34 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// let background_gossip_sync = GossipSync::p2p(Arc::clone(&node.gossip_sync)); /// let background_peer_man = Arc::clone(&node.peer_manager); /// let background_onion_messenger = Arc::clone(&node.onion_messenger); +/// let background_liquidity_manager = Arc::clone(&node.liquidity_manager); /// let background_logger = Arc::clone(&node.logger); /// let background_scorer = Arc::clone(&node.scorer); /// /// // Setup the sleeper. -/// let (stop_sender, stop_receiver) = tokio::sync::watch::channel(()); -/// +#[cfg_attr( + feature = "std", + doc = " let (stop_sender, stop_receiver) = tokio::sync::watch::channel(());" +)] +#[cfg_attr(feature = "std", doc = "")] /// let sleeper = move |d| { -/// let mut receiver = stop_receiver.clone(); +#[cfg_attr(feature = "std", doc = " let mut receiver = stop_receiver.clone();")] /// Box::pin(async move { /// tokio::select!{ /// _ = tokio::time::sleep(d) => false, -/// _ = receiver.changed() => true, +#[cfg_attr(feature = "std", doc = " _ = receiver.changed() => true,")] /// } /// }) /// }; /// /// let mobile_interruptable_platform = false; /// -/// let handle = tokio::spawn(async move { +#[cfg_attr(feature = "std", doc = " let handle = tokio::spawn(async move {")] +#[cfg_attr( + not(feature = "std"), + doc = " let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();" +)] +#[cfg_attr(not(feature = "std"), doc = " rt.block_on(async move {")] /// process_events_async( /// background_persister, /// |e| background_event_handler.handle_event(e), @@ -708,6 +728,7 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// Some(background_onion_messenger), /// background_gossip_sync, /// background_peer_man, +/// Some(background_liquidity_manager), /// background_logger, /// Some(background_scorer), /// sleeper, @@ -719,20 +740,20 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// }); /// /// // Stop the background processing. -/// stop_sender.send(()).unwrap(); -/// handle.await.unwrap(); +#[cfg_attr(feature = "std", doc = " stop_sender.send(()).unwrap();")] +#[cfg_attr(feature = "std", doc = " handle.await.unwrap()")] /// # } ///``` #[cfg(feature = "futures")] pub async fn process_events_async< 'a, - UL: 'static + Deref + Send + Sync, - CF: 'static + Deref + Send + Sync, - T: 'static + Deref + Send + Sync, - F: 'static + Deref + Send + Sync, - G: 'static + Deref> + Send + Sync, - L: 'static + Deref + Send + Sync, - P: 'static + Deref + Send + Sync, + UL: 'static + Deref, + CF: 'static + Deref, + T: 'static + Deref, + F: 'static + Deref, + G: 'static + Deref>, + L: 'static + Deref, + P: 'static + Deref, EventHandlerFuture: core::future::Future>, EventHandler: Fn(Event) -> EventHandlerFuture, PS: 'static + Deref + Send, @@ -740,11 +761,12 @@ pub async fn process_events_async< + Deref::Signer, CF, T, F, L, P>> + Send + Sync, - CM: 'static + Deref + Send + Sync, - OM: 'static + Deref + Send + Sync, - PGS: 'static + Deref> + Send + Sync, - RGS: 'static + Deref> + Send, - PM: 'static + Deref + Send + Sync, + CM: 'static + Deref, + OM: 'static + Deref, + PGS: 'static + Deref>, + RGS: 'static + Deref>, + PM: 'static + Deref, + LM: 'static + Deref, S: 'static + Deref + Send + Sync, SC: for<'b> WriteableScore<'b>, SleepFuture: core::future::Future + core::marker::Unpin, @@ -753,8 +775,8 @@ pub async fn process_events_async< >( persister: PS, event_handler: EventHandler, chain_monitor: M, channel_manager: CM, onion_messenger: Option, gossip_sync: GossipSync, peer_manager: PM, - logger: L, scorer: Option, sleeper: Sleeper, mobile_interruptable_platform: bool, - fetch_time: FetchTime, + liquidity_manager: Option, logger: L, scorer: Option, sleeper: Sleeper, + mobile_interruptable_platform: bool, fetch_time: FetchTime, ) -> Result<(), lightning::io::Error> where UL::Target: 'static + UtxoLookup, @@ -764,9 +786,10 @@ where L::Target: 'static + Logger, P::Target: 'static + Persist<::Signer>, PS::Target: 'static + Persister<'a, CM, L, S>, - CM::Target: AChannelManager + Send + Sync, - OM::Target: AOnionMessenger + Send + Sync, - PM::Target: APeerManager + Send + Sync, + CM::Target: AChannelManager, + OM::Target: AOnionMessenger, + PM::Target: APeerManager, + LM::Target: ALiquidityManager, { let mut should_break = false; let async_event_handler = |event| { @@ -820,19 +843,26 @@ where } else { OptionalSelector { optional_future: None } }; + let lm_fut = if let Some(lm) = liquidity_manager.as_ref() { + let fut = lm.get_lm().get_pending_msgs_future(); + OptionalSelector { optional_future: Some(fut) } + } else { + OptionalSelector { optional_future: None } + }; let fut = Selector { a: channel_manager.get_cm().get_event_or_persistence_needed_future(), b: chain_monitor.get_update_future(), c: om_fut, - d: sleeper(if mobile_interruptable_platform { + d: lm_fut, + e: sleeper(if mobile_interruptable_platform { Duration::from_millis(100) } else { Duration::from_secs(FASTEST_TIMER) }), }; match fut.await { - SelectorOutput::A | SelectorOutput::B | SelectorOutput::C => {}, - SelectorOutput::D(exit) => { + SelectorOutput::A | SelectorOutput::B | SelectorOutput::C | SelectorOutput::D => {}, + SelectorOutput::E(exit) => { should_break = exit; }, } @@ -902,30 +932,31 @@ impl BackgroundProcessor { /// [`NetworkGraph::write`]: lightning::routing::gossip::NetworkGraph#impl-Writeable pub fn start< 'a, - UL: 'static + Deref + Send + Sync, - CF: 'static + Deref + Send + Sync, - T: 'static + Deref + Send + Sync, - F: 'static + Deref + Send + Sync, - G: 'static + Deref> + Send + Sync, - L: 'static + Deref + Send + Sync, - P: 'static + Deref + Send + Sync, + UL: 'static + Deref, + CF: 'static + Deref, + T: 'static + Deref, + F: 'static + Deref + Send, + G: 'static + Deref>, + L: 'static + Deref + Send, + P: 'static + Deref, EH: 'static + EventHandler + Send, PS: 'static + Deref + Send, M: 'static + Deref::Signer, CF, T, F, L, P>> + Send + Sync, - CM: 'static + Deref + Send + Sync, - OM: 'static + Deref + Send + Sync, - PGS: 'static + Deref> + Send + Sync, + CM: 'static + Deref + Send, + OM: 'static + Deref + Send, + PGS: 'static + Deref> + Send, RGS: 'static + Deref> + Send, - PM: 'static + Deref + Send + Sync, + PM: 'static + Deref + Send, + LM: 'static + Deref + Send, S: 'static + Deref + Send + Sync, SC: for<'b> WriteableScore<'b>, >( persister: PS, event_handler: EH, chain_monitor: M, channel_manager: CM, onion_messenger: Option, gossip_sync: GossipSync, peer_manager: PM, - logger: L, scorer: Option, + liquidity_manager: Option, logger: L, scorer: Option, ) -> Self where UL::Target: 'static + UtxoLookup, @@ -935,9 +966,10 @@ impl BackgroundProcessor { L::Target: 'static + Logger, P::Target: 'static + Persist<::Signer>, PS::Target: 'static + Persister<'a, CM, L, S>, - CM::Target: AChannelManager + Send + Sync, - OM::Target: AOnionMessenger + Send + Sync, - PM::Target: APeerManager + Send + Sync, + CM::Target: AChannelManager, + OM::Target: AOnionMessenger, + PM::Target: APeerManager, + LM::Target: ALiquidityManager, { let stop_thread = Arc::new(AtomicBool::new(false)); let stop_thread_clone = stop_thread.clone(); @@ -977,17 +1009,27 @@ impl BackgroundProcessor { scorer, stop_thread.load(Ordering::Acquire), { - let sleeper = if let Some(om) = onion_messenger.as_ref() { - Sleeper::from_three_futures( + let sleeper = match (onion_messenger.as_ref(), liquidity_manager.as_ref()) { + (Some(om), Some(lm)) => Sleeper::from_four_futures( &channel_manager.get_cm().get_event_or_persistence_needed_future(), &chain_monitor.get_update_future(), &om.get_om().get_update_future(), - ) - } else { - Sleeper::from_two_futures( + &lm.get_lm().get_pending_msgs_future(), + ), + (Some(om), None) => Sleeper::from_three_futures( + &channel_manager.get_cm().get_event_or_persistence_needed_future(), + &chain_monitor.get_update_future(), + &om.get_om().get_update_future(), + ), + (None, Some(lm)) => Sleeper::from_three_futures( &channel_manager.get_cm().get_event_or_persistence_needed_future(), &chain_monitor.get_update_future(), - ) + &lm.get_lm().get_pending_msgs_future(), + ), + (None, None) => Sleeper::from_two_futures( + &channel_manager.get_cm().get_event_or_persistence_needed_future(), + &chain_monitor.get_update_future(), + ), }; sleeper.wait_timeout(Duration::from_millis(100)); }, @@ -1100,6 +1142,7 @@ mod tests { use lightning::util::sweep::{OutputSpendStatus, OutputSweeper, PRUNE_DELAY_BLOCKS}; use lightning::util::test_utils; use lightning::{get_event, get_event_msg}; + use lightning_liquidity::LiquidityManager; use lightning_persister::fs_store::FilesystemStore; use lightning_rapid_gossip_sync::RapidGossipSync; use std::collections::VecDeque; @@ -1194,6 +1237,9 @@ mod tests { IgnoringMessageHandler, >; + type LM = + LiquidityManager, Arc, Arc>; + struct Node { node: Arc, messenger: Arc, @@ -1210,6 +1256,7 @@ mod tests { Arc, >, >, + liquidity_manager: Arc, chain_monitor: Arc, kv_store: Arc, tx_broadcaster: Arc, @@ -1629,11 +1676,20 @@ mod tests { logger.clone(), keys_manager.clone(), )); + let liquidity_manager = Arc::new(LiquidityManager::new( + Arc::clone(&keys_manager), + Arc::clone(&manager), + None, + None, + None, + None, + )); let node = Node { node: manager, p2p_gossip_sync, rapid_gossip_sync, peer_manager, + liquidity_manager, chain_monitor, kv_store, tx_broadcaster, @@ -1831,6 +1887,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].p2p_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -1924,6 +1981,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -1966,6 +2024,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -1998,6 +2057,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].rapid_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), move |dur: Duration| { @@ -2034,6 +2094,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].p2p_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2063,6 +2124,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2109,6 +2171,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2171,6 +2234,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2322,6 +2386,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2351,6 +2416,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2446,6 +2512,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].rapid_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2478,6 +2545,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].rapid_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), move |dur: Duration| { @@ -2640,6 +2708,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2690,6 +2759,7 @@ mod tests { Some(nodes[0].messenger.clone()), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), + Some(Arc::clone(&nodes[0].liquidity_manager)), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), move |dur: Duration| { diff --git a/lightning-invoice/Cargo.toml b/lightning-invoice/Cargo.toml index 7c49b2177a0..8e0c7587f4f 100644 --- a/lightning-invoice/Cargo.toml +++ b/lightning-invoice/Cargo.toml @@ -20,7 +20,7 @@ std = [] [dependencies] bech32 = { version = "0.11.0", default-features = false } lightning-types = { version = "0.3.0", path = "../lightning-types", default-features = false } -serde = { version = "1.0.118", optional = true } +serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } bitcoin = { version = "0.32.2", default-features = false, features = ["secp-recovery"] } [dev-dependencies] diff --git a/lightning-liquidity/Cargo.toml b/lightning-liquidity/Cargo.toml index 1cc0d988544..0733d387b15 100644 --- a/lightning-liquidity/Cargo.toml +++ b/lightning-liquidity/Cargo.toml @@ -27,7 +27,7 @@ bitcoin = { version = "0.32.2", default-features = false, features = ["serde"] } chrono = { version = "0.4", default-features = false, features = ["serde", "alloc"] } serde = { version = "1.0", default-features = false, features = ["derive", "alloc"] } -serde_json = "1.0" +serde_json = { version = "1.0", default-features = false, features = ["alloc"] } backtrace = { version = "0.3", optional = true } [dev-dependencies] diff --git a/lightning-liquidity/src/events.rs b/lightning-liquidity/src/events/event_queue.rs similarity index 63% rename from lightning-liquidity/src/events.rs rename to lightning-liquidity/src/events/event_queue.rs index 46308c7446c..a2589beb4e2 100644 --- a/lightning-liquidity/src/events.rs +++ b/lightning-liquidity/src/events/event_queue.rs @@ -1,23 +1,4 @@ -// This file is Copyright its original authors, visible in version control -// history. -// -// This file is licensed under the Apache License, Version 2.0 or the MIT license -// , at your option. -// You may not use this file except in accordance with one or both of these -// licenses. - -//! Events are surfaced by the library to indicate some action must be taken -//! by the end-user. -//! -//! Because we don't have a built-in runtime, it's up to the end-user to poll -//! [`LiquidityManager::get_and_clear_pending_events`] to receive events. -//! -//! [`LiquidityManager::get_and_clear_pending_events`]: crate::LiquidityManager::get_and_clear_pending_events - -use crate::lsps0; -use crate::lsps1; -use crate::lsps2; +use super::LiquidityEvent; use crate::sync::{Arc, Mutex}; use alloc::collections::VecDeque; @@ -33,37 +14,19 @@ pub(crate) struct EventQueue { queue: Arc>>, waker: Arc>>, #[cfg(feature = "std")] - condvar: crate::sync::Condvar, + condvar: Arc, } impl EventQueue { pub fn new() -> Self { let queue = Arc::new(Mutex::new(VecDeque::new())); let waker = Arc::new(Mutex::new(None)); - #[cfg(feature = "std")] - { - let condvar = crate::sync::Condvar::new(); - Self { queue, waker, condvar } + Self { + queue, + waker, + #[cfg(feature = "std")] + condvar: Arc::new(crate::sync::Condvar::new()), } - #[cfg(not(feature = "std"))] - Self { queue, waker } - } - - pub fn enqueue>(&self, event: E) { - { - let mut queue = self.queue.lock().unwrap(); - if queue.len() < MAX_EVENT_QUEUE_SIZE { - queue.push_back(event.into()); - } else { - return; - } - } - - if let Some(waker) = self.waker.lock().unwrap().take() { - waker.wake(); - } - #[cfg(feature = "std")] - self.condvar.notify_one(); } pub fn next_event(&self) -> Option { @@ -102,52 +65,40 @@ impl EventQueue { pub fn get_and_clear_pending_events(&self) -> Vec { self.queue.lock().unwrap().split_off(0).into() } -} -/// An event which you should probably take some action in response to. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum LiquidityEvent { - /// An LSPS0 client event. - LSPS0Client(lsps0::event::LSPS0ClientEvent), - /// An LSPS1 (Channel Request) client event. - LSPS1Client(lsps1::event::LSPS1ClientEvent), - /// An LSPS1 (Channel Request) server event. - #[cfg(lsps1_service)] - LSPS1Service(lsps1::event::LSPS1ServiceEvent), - /// An LSPS2 (JIT Channel) client event. - LSPS2Client(lsps2::event::LSPS2ClientEvent), - /// An LSPS2 (JIT Channel) server event. - LSPS2Service(lsps2::event::LSPS2ServiceEvent), -} - -impl From for LiquidityEvent { - fn from(event: lsps0::event::LSPS0ClientEvent) -> Self { - Self::LSPS0Client(event) + // Returns an [`EventQueueNotifierGuard`] that will notify about new event when dropped. + pub fn notifier(&self) -> EventQueueNotifierGuard { + EventQueueNotifierGuard(self) } } -impl From for LiquidityEvent { - fn from(event: lsps1::event::LSPS1ClientEvent) -> Self { - Self::LSPS1Client(event) - } -} +// A guard type that will notify about new events when dropped. +#[must_use] +pub(crate) struct EventQueueNotifierGuard<'a>(&'a EventQueue); -#[cfg(lsps1_service)] -impl From for LiquidityEvent { - fn from(event: lsps1::event::LSPS1ServiceEvent) -> Self { - Self::LSPS1Service(event) +impl<'a> EventQueueNotifierGuard<'a> { + pub fn enqueue>(&self, event: E) { + let mut queue = self.0.queue.lock().unwrap(); + if queue.len() < MAX_EVENT_QUEUE_SIZE { + queue.push_back(event.into()); + } else { + return; + } } } -impl From for LiquidityEvent { - fn from(event: lsps2::event::LSPS2ClientEvent) -> Self { - Self::LSPS2Client(event) - } -} +impl<'a> Drop for EventQueueNotifierGuard<'a> { + fn drop(&mut self) { + let should_notify = !self.0.queue.lock().unwrap().is_empty(); -impl From for LiquidityEvent { - fn from(event: lsps2::event::LSPS2ServiceEvent) -> Self { - Self::LSPS2Service(event) + if should_notify { + if let Some(waker) = self.0.waker.lock().unwrap().take() { + waker.wake(); + } + + #[cfg(feature = "std")] + self.0.condvar.notify_one(); + } } } @@ -195,7 +146,8 @@ mod tests { }); for _ in 0..3 { - event_queue.enqueue(expected_event.clone()); + let guard = event_queue.notifier(); + guard.enqueue(expected_event.clone()); } assert_eq!(event_queue.wait_next_event(), expected_event); @@ -220,14 +172,16 @@ mod tests { let mut delayed_enqueue = false; for _ in 0..25 { - event_queue.enqueue(expected_event.clone()); + let guard = event_queue.notifier(); + guard.enqueue(expected_event.clone()); enqueued_events.fetch_add(1, Ordering::SeqCst); } loop { tokio::select! { _ = tokio::time::sleep(Duration::from_millis(10)), if !delayed_enqueue => { - event_queue.enqueue(expected_event.clone()); + let guard = event_queue.notifier(); + guard.enqueue(expected_event.clone()); enqueued_events.fetch_add(1, Ordering::SeqCst); delayed_enqueue = true; } @@ -235,7 +189,8 @@ mod tests { assert_eq!(e, expected_event); received_events.fetch_add(1, Ordering::SeqCst); - event_queue.enqueue(expected_event.clone()); + let guard = event_queue.notifier(); + guard.enqueue(expected_event.clone()); enqueued_events.fetch_add(1, Ordering::SeqCst); } e = event_queue.next_event_async() => { @@ -267,8 +222,9 @@ mod tests { std::thread::spawn(move || { // Sleep a bit before we enqueue the events everybody is waiting for. std::thread::sleep(Duration::from_millis(20)); - thread_queue.enqueue(thread_event.clone()); - thread_queue.enqueue(thread_event.clone()); + let guard = thread_queue.notifier(); + guard.enqueue(thread_event.clone()); + guard.enqueue(thread_event.clone()); }); let e = event_queue.next_event_async().await; diff --git a/lightning-liquidity/src/events/mod.rs b/lightning-liquidity/src/events/mod.rs new file mode 100644 index 00000000000..506b91494c3 --- /dev/null +++ b/lightning-liquidity/src/events/mod.rs @@ -0,0 +1,72 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +//! Events are surfaced by the library to indicate some action must be taken +//! by the end-user. +//! +//! Because we don't have a built-in runtime, it's up to the end-user to poll +//! [`LiquidityManager::get_and_clear_pending_events`] to receive events. +//! +//! [`LiquidityManager::get_and_clear_pending_events`]: crate::LiquidityManager::get_and_clear_pending_events + +mod event_queue; + +pub(crate) use event_queue::EventQueue; +pub use event_queue::MAX_EVENT_QUEUE_SIZE; + +use crate::lsps0; +use crate::lsps1; +use crate::lsps2; + +/// An event which you should probably take some action in response to. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum LiquidityEvent { + /// An LSPS0 client event. + LSPS0Client(lsps0::event::LSPS0ClientEvent), + /// An LSPS1 (Channel Request) client event. + LSPS1Client(lsps1::event::LSPS1ClientEvent), + /// An LSPS1 (Channel Request) server event. + #[cfg(lsps1_service)] + LSPS1Service(lsps1::event::LSPS1ServiceEvent), + /// An LSPS2 (JIT Channel) client event. + LSPS2Client(lsps2::event::LSPS2ClientEvent), + /// An LSPS2 (JIT Channel) server event. + LSPS2Service(lsps2::event::LSPS2ServiceEvent), +} + +impl From for LiquidityEvent { + fn from(event: lsps0::event::LSPS0ClientEvent) -> Self { + Self::LSPS0Client(event) + } +} + +impl From for LiquidityEvent { + fn from(event: lsps1::event::LSPS1ClientEvent) -> Self { + Self::LSPS1Client(event) + } +} + +#[cfg(lsps1_service)] +impl From for LiquidityEvent { + fn from(event: lsps1::event::LSPS1ServiceEvent) -> Self { + Self::LSPS1Service(event) + } +} + +impl From for LiquidityEvent { + fn from(event: lsps2::event::LSPS2ClientEvent) -> Self { + Self::LSPS2Client(event) + } +} + +impl From for LiquidityEvent { + fn from(event: lsps2::event::LSPS2ServiceEvent) -> Self { + Self::LSPS2Service(event) + } +} diff --git a/lightning-liquidity/src/lib.rs b/lightning-liquidity/src/lib.rs index 909590eac96..5fb59c319c8 100644 --- a/lightning-liquidity/src/lib.rs +++ b/lightning-liquidity/src/lib.rs @@ -68,4 +68,6 @@ mod sync; mod tests; mod utils; -pub use manager::{LiquidityClientConfig, LiquidityManager, LiquidityServiceConfig}; +pub use manager::{ + ALiquidityManager, LiquidityClientConfig, LiquidityManager, LiquidityServiceConfig, +}; diff --git a/lightning-liquidity/src/lsps0/client.rs b/lightning-liquidity/src/lsps0/client.rs index 7b049e65566..5ae73005e61 100644 --- a/lightning-liquidity/src/lsps0/client.rs +++ b/lightning-liquidity/src/lsps0/client.rs @@ -61,9 +61,11 @@ where fn handle_response( &self, response: LSPS0Response, counterparty_node_id: &PublicKey, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); + match response { LSPS0Response::ListProtocols(LSPS0ListProtocolsResponse { protocols }) => { - self.pending_events.enqueue(LSPS0ClientEvent::ListProtocolsResponse { + event_queue_notifier.enqueue(LSPS0ClientEvent::ListProtocolsResponse { counterparty_node_id: *counterparty_node_id, protocols, }); diff --git a/lightning-liquidity/src/lsps1/client.rs b/lightning-liquidity/src/lsps1/client.rs index d0050abe4b1..b1b7b6a2493 100644 --- a/lightning-liquidity/src/lsps1/client.rs +++ b/lightning-liquidity/src/lsps1/client.rs @@ -110,8 +110,9 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, result: LSPS1GetInfoResponse, ) -> Result<(), LightningError> { - let outer_state_lock = self.per_peer_state.write().unwrap(); + let event_queue_notifier = self.pending_events.notifier(); + let outer_state_lock = self.per_peer_state.write().unwrap(); match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { let mut peer_state_lock = inner_state_lock.lock().unwrap(); @@ -126,7 +127,7 @@ where }); } - self.pending_events.enqueue(LSPS1ClientEvent::SupportedOptionsReady { + event_queue_notifier.enqueue(LSPS1ClientEvent::SupportedOptionsReady { counterparty_node_id: *counterparty_node_id, supported_options: result.options, request_id, @@ -147,6 +148,8 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, error: LSPSResponseError, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); + let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { @@ -162,7 +165,7 @@ where }); } - self.pending_events.enqueue(LSPS1ClientEvent::SupportedOptionsRequestFailed { + event_queue_notifier.enqueue(LSPS1ClientEvent::SupportedOptionsRequestFailed { request_id: request_id.clone(), counterparty_node_id: *counterparty_node_id, error: error.clone(), @@ -224,6 +227,8 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, response: LSPS1CreateOrderResponse, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); + let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { @@ -239,7 +244,7 @@ where }); } - self.pending_events.enqueue(LSPS1ClientEvent::OrderCreated { + event_queue_notifier.enqueue(LSPS1ClientEvent::OrderCreated { request_id, counterparty_node_id: *counterparty_node_id, order_id: response.order_id, @@ -266,6 +271,8 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, error: LSPSResponseError, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); + let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { @@ -281,7 +288,7 @@ where }); } - self.pending_events.enqueue(LSPS1ClientEvent::OrderRequestFailed { + event_queue_notifier.enqueue(LSPS1ClientEvent::OrderRequestFailed { request_id: request_id.clone(), counterparty_node_id: *counterparty_node_id, error: error.clone(), @@ -343,6 +350,8 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, response: LSPS1CreateOrderResponse, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); + let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { @@ -358,7 +367,7 @@ where }); } - self.pending_events.enqueue(LSPS1ClientEvent::OrderStatus { + event_queue_notifier.enqueue(LSPS1ClientEvent::OrderStatus { request_id, counterparty_node_id: *counterparty_node_id, order_id: response.order_id, @@ -385,6 +394,8 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, error: LSPSResponseError, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); + let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { @@ -400,7 +411,7 @@ where }); } - self.pending_events.enqueue(LSPS1ClientEvent::OrderRequestFailed { + event_queue_notifier.enqueue(LSPS1ClientEvent::OrderRequestFailed { request_id: request_id.clone(), counterparty_node_id: *counterparty_node_id, error: error.clone(), diff --git a/lightning-liquidity/src/lsps1/service.rs b/lightning-liquidity/src/lsps1/service.rs index 4b1cdcbf287..28fe72ca905 100644 --- a/lightning-liquidity/src/lsps1/service.rs +++ b/lightning-liquidity/src/lsps1/service.rs @@ -198,6 +198,7 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, params: LSPS1CreateOrderRequest, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); if !is_valid(¶ms.order, &self.config.supported_options.as_ref().unwrap()) { let response = LSPS1Response::CreateOrderError(LSPSResponseError { code: LSPS1_CREATE_ORDER_REQUEST_ORDER_MISMATCH_ERROR_CODE, @@ -231,7 +232,7 @@ where .insert(request_id.clone(), LSPS1Request::CreateOrder(params.clone())); } - self.pending_events.enqueue(LSPS1ServiceEvent::RequestForPaymentDetails { + event_queue_notifier.enqueue(LSPS1ServiceEvent::RequestForPaymentDetails { request_id, counterparty_node_id: *counterparty_node_id, order: params.order, @@ -315,6 +316,7 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, params: LSPS1GetOrderRequest, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { @@ -333,7 +335,7 @@ where if let Err(e) = outbound_channel.awaiting_payment() { peer_state_lock.outbound_channels_by_order_id.remove(¶ms.order_id); - self.pending_events.enqueue(LSPS1ServiceEvent::Refund { + event_queue_notifier.enqueue(LSPS1ServiceEvent::Refund { request_id, counterparty_node_id: *counterparty_node_id, order_id: params.order_id, @@ -345,7 +347,7 @@ where .pending_requests .insert(request_id.clone(), LSPS1Request::GetOrder(params.clone())); - self.pending_events.enqueue(LSPS1ServiceEvent::CheckPaymentConfirmation { + event_queue_notifier.enqueue(LSPS1ServiceEvent::CheckPaymentConfirmation { request_id, counterparty_node_id: *counterparty_node_id, order_id: params.order_id, diff --git a/lightning-liquidity/src/lsps2/client.rs b/lightning-liquidity/src/lsps2/client.rs index 6dc0d5350b6..3dabb83c954 100644 --- a/lightning-liquidity/src/lsps2/client.rs +++ b/lightning-liquidity/src/lsps2/client.rs @@ -191,6 +191,8 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, result: LSPS2GetInfoResponse, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); + let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { @@ -206,7 +208,7 @@ where }); } - self.pending_events.enqueue(LSPS2ClientEvent::OpeningParametersReady { + event_queue_notifier.enqueue(LSPS2ClientEvent::OpeningParametersReady { request_id, counterparty_node_id: *counterparty_node_id, opening_fee_params_menu: result.opening_fee_params_menu, @@ -257,6 +259,8 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, result: LSPS2BuyResponse, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); + let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { @@ -272,7 +276,7 @@ where })?; if let Ok(intercept_scid) = result.jit_channel_scid.to_scid() { - self.pending_events.enqueue(LSPS2ClientEvent::InvoiceParametersReady { + event_queue_notifier.enqueue(LSPS2ClientEvent::InvoiceParametersReady { request_id, counterparty_node_id: *counterparty_node_id, intercept_scid, diff --git a/lightning-liquidity/src/lsps2/payment_queue.rs b/lightning-liquidity/src/lsps2/payment_queue.rs index d956dfc9d81..30413537a9c 100644 --- a/lightning-liquidity/src/lsps2/payment_queue.rs +++ b/lightning-liquidity/src/lsps2/payment_queue.rs @@ -6,18 +6,11 @@ use lightning_types::payment::PaymentHash; /// Holds payments with the corresponding HTLCs until it is possible to pay the fee. /// When the fee is successfully paid with a forwarded payment, the queue should be consumed and the /// remaining payments forwarded. -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Default, PartialEq, Eq, Debug)] pub(crate) struct PaymentQueue { payments: Vec<(PaymentHash, Vec)>, } -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub(crate) struct InterceptedHTLC { - pub(crate) intercept_id: InterceptId, - pub(crate) expected_outbound_amount_msat: u64, - pub(crate) payment_hash: PaymentHash, -} - impl PaymentQueue { pub(crate) fn new() -> PaymentQueue { PaymentQueue { payments: Vec::new() } @@ -55,6 +48,13 @@ impl PaymentQueue { } } +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub(crate) struct InterceptedHTLC { + pub(crate) intercept_id: InterceptId, + pub(crate) expected_outbound_amount_msat: u64, + pub(crate) payment_hash: PaymentHash, +} + #[cfg(test)] mod tests { use super::*; diff --git a/lightning-liquidity/src/lsps2/service.rs b/lightning-liquidity/src/lsps2/service.rs index 9b511ad8d44..2f6318734b5 100644 --- a/lightning-liquidity/src/lsps2/service.rs +++ b/lightning-liquidity/src/lsps2/service.rs @@ -111,15 +111,15 @@ struct ForwardHTLCsAction(ChannelId, Vec); enum OutboundJITChannelState { /// The JIT channel SCID was created after a buy request, and we are awaiting an initial payment /// of sufficient size to open the channel. - PendingInitialPayment { payment_queue: Arc> }, + PendingInitialPayment { payment_queue: PaymentQueue }, /// An initial payment of sufficient size was intercepted to the JIT channel SCID, triggering the /// opening of the channel. We are awaiting the completion of the channel establishment. - PendingChannelOpen { payment_queue: Arc>, opening_fee_msat: u64 }, + PendingChannelOpen { payment_queue: PaymentQueue, opening_fee_msat: u64 }, /// The channel is open and a payment was forwarded while skimming the JIT channel fee. /// No further payments can be forwarded until the pending payment succeeds or fails, as we need /// to know whether the JIT channel fee needs to be skimmed from a next payment or not. PendingPaymentForward { - payment_queue: Arc>, + payment_queue: PaymentQueue, opening_fee_msat: u64, channel_id: ChannelId, }, @@ -127,11 +127,7 @@ enum OutboundJITChannelState { /// needs to be paid. This state can occur when the initial payment fails, e.g. due to a /// prepayment probe. We are awaiting a next payment of sufficient size to forward and skim the /// JIT channel fee. - PendingPayment { - payment_queue: Arc>, - opening_fee_msat: u64, - channel_id: ChannelId, - }, + PendingPayment { payment_queue: PaymentQueue, opening_fee_msat: u64, channel_id: ChannelId }, /// The channel is open and a payment was successfully forwarded while skimming the JIT channel /// fee. Any subsequent HTLCs can be forwarded without additional logic. PaymentForwarded { channel_id: ChannelId }, @@ -139,19 +135,16 @@ enum OutboundJITChannelState { impl OutboundJITChannelState { fn new() -> Self { - OutboundJITChannelState::PendingInitialPayment { - payment_queue: Arc::new(Mutex::new(PaymentQueue::new())), - } + OutboundJITChannelState::PendingInitialPayment { payment_queue: PaymentQueue::new() } } fn htlc_intercepted( &mut self, opening_fee_params: &LSPS2OpeningFeeParams, payment_size_msat: &Option, htlc: InterceptedHTLC, - ) -> Result<(Self, Option), ChannelStateError> { + ) -> Result, ChannelStateError> { match self { OutboundJITChannelState::PendingInitialPayment { payment_queue } => { - let (total_expected_outbound_amount_msat, num_htlcs) = - payment_queue.lock().unwrap().add_htlc(htlc); + let (total_expected_outbound_amount_msat, num_htlcs) = payment_queue.add_htlc(htlc); let (expected_payment_size_msat, mpp_mode) = if let Some(payment_size_msat) = payment_size_msat { @@ -186,8 +179,8 @@ impl OutboundJITChannelState { opening_fee_params.min_fee_msat, opening_fee_params.proportional, expected_payment_size_msat - ) - ))?; + )) + )?; let amt_to_forward_msat = expected_payment_size_msat.saturating_sub(opening_fee_msat); @@ -196,22 +189,21 @@ impl OutboundJITChannelState { if total_expected_outbound_amount_msat >= expected_payment_size_msat && amt_to_forward_msat > 0 { - let pending_channel_open = OutboundJITChannelState::PendingChannelOpen { - payment_queue: Arc::clone(&payment_queue), + *self = OutboundJITChannelState::PendingChannelOpen { + payment_queue: core::mem::take(payment_queue), opening_fee_msat, }; let open_channel = HTLCInterceptedAction::OpenChannel(OpenChannelParams { opening_fee_msat, amt_to_forward_msat, }); - Ok((pending_channel_open, Some(open_channel))) + Ok(Some(open_channel)) } else { if mpp_mode { - let pending_initial_payment = - OutboundJITChannelState::PendingInitialPayment { - payment_queue: Arc::clone(&payment_queue), - }; - Ok((pending_initial_payment, None)) + *self = OutboundJITChannelState::PendingInitialPayment { + payment_queue: core::mem::take(payment_queue), + }; + Ok(None) } else { Err(ChannelStateError( "Intercepted HTLC is too small to pay opening fee".to_string(), @@ -220,90 +212,88 @@ impl OutboundJITChannelState { } }, OutboundJITChannelState::PendingChannelOpen { payment_queue, opening_fee_msat } => { - let mut payment_queue_lock = payment_queue.lock().unwrap(); - payment_queue_lock.add_htlc(htlc); - let pending_channel_open = OutboundJITChannelState::PendingChannelOpen { - payment_queue: payment_queue.clone(), + let mut payment_queue = core::mem::take(payment_queue); + payment_queue.add_htlc(htlc); + *self = OutboundJITChannelState::PendingChannelOpen { + payment_queue, opening_fee_msat: *opening_fee_msat, }; - Ok((pending_channel_open, None)) + Ok(None) }, OutboundJITChannelState::PendingPaymentForward { payment_queue, opening_fee_msat, channel_id, } => { - let mut payment_queue_lock = payment_queue.lock().unwrap(); - payment_queue_lock.add_htlc(htlc); - let pending_payment_forward = OutboundJITChannelState::PendingPaymentForward { - payment_queue: payment_queue.clone(), + let mut payment_queue = core::mem::take(payment_queue); + payment_queue.add_htlc(htlc); + *self = OutboundJITChannelState::PendingPaymentForward { + payment_queue, opening_fee_msat: *opening_fee_msat, channel_id: *channel_id, }; - Ok((pending_payment_forward, None)) + Ok(None) }, OutboundJITChannelState::PendingPayment { payment_queue, opening_fee_msat, channel_id, } => { - let mut payment_queue_lock = payment_queue.lock().unwrap(); - payment_queue_lock.add_htlc(htlc); + let mut payment_queue = core::mem::take(payment_queue); + payment_queue.add_htlc(htlc); if let Some((_payment_hash, htlcs)) = - payment_queue_lock.pop_greater_than_msat(*opening_fee_msat) + payment_queue.pop_greater_than_msat(*opening_fee_msat) { - let pending_payment_forward = OutboundJITChannelState::PendingPaymentForward { - payment_queue: payment_queue.clone(), - opening_fee_msat: *opening_fee_msat, - channel_id: *channel_id, - }; let forward_payment = HTLCInterceptedAction::ForwardPayment( *channel_id, FeePayment { htlcs, opening_fee_msat: *opening_fee_msat }, ); - Ok((pending_payment_forward, Some(forward_payment))) + *self = OutboundJITChannelState::PendingPaymentForward { + payment_queue, + opening_fee_msat: *opening_fee_msat, + channel_id: *channel_id, + }; + Ok(Some(forward_payment)) } else { - let pending_payment = OutboundJITChannelState::PendingPayment { + *self = OutboundJITChannelState::PendingPayment { payment_queue: payment_queue.clone(), opening_fee_msat: *opening_fee_msat, channel_id: *channel_id, }; - Ok((pending_payment, None)) + Ok(None) } }, OutboundJITChannelState::PaymentForwarded { channel_id } => { - let payment_forwarded = - OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; let forward = HTLCInterceptedAction::ForwardHTLC(*channel_id); - Ok((payment_forwarded, Some(forward))) + *self = OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; + Ok(Some(forward)) }, } } fn channel_ready( - &self, channel_id: ChannelId, - ) -> Result<(Self, ForwardPaymentAction), ChannelStateError> { + &mut self, channel_id: ChannelId, + ) -> Result { match self { OutboundJITChannelState::PendingChannelOpen { payment_queue, opening_fee_msat } => { - let mut payment_queue_lock = payment_queue.lock().unwrap(); if let Some((_payment_hash, htlcs)) = - payment_queue_lock.pop_greater_than_msat(*opening_fee_msat) + payment_queue.pop_greater_than_msat(*opening_fee_msat) { - let pending_payment_forward = OutboundJITChannelState::PendingPaymentForward { - payment_queue: Arc::clone(&payment_queue), - opening_fee_msat: *opening_fee_msat, - channel_id, - }; let forward_payment = ForwardPaymentAction( channel_id, FeePayment { opening_fee_msat: *opening_fee_msat, htlcs }, ); - Ok((pending_payment_forward, forward_payment)) + *self = OutboundJITChannelState::PendingPaymentForward { + payment_queue: core::mem::take(payment_queue), + opening_fee_msat: *opening_fee_msat, + channel_id, + }; + Ok(forward_payment) } else { - Err(ChannelStateError( + return Err(ChannelStateError( "No forwardable payment available when moving to channel ready." .to_string(), - )) + )); } }, state => Err(ChannelStateError(format!( @@ -313,36 +303,33 @@ impl OutboundJITChannelState { } } - fn htlc_handling_failed( - &mut self, - ) -> Result<(Self, Option), ChannelStateError> { + fn htlc_handling_failed(&mut self) -> Result, ChannelStateError> { match self { OutboundJITChannelState::PendingPaymentForward { payment_queue, opening_fee_msat, channel_id, } => { - let mut payment_queue_lock = payment_queue.lock().unwrap(); if let Some((_payment_hash, htlcs)) = - payment_queue_lock.pop_greater_than_msat(*opening_fee_msat) + payment_queue.pop_greater_than_msat(*opening_fee_msat) { - let pending_payment_forward = OutboundJITChannelState::PendingPaymentForward { - payment_queue: payment_queue.clone(), - opening_fee_msat: *opening_fee_msat, - channel_id: *channel_id, - }; let forward_payment = ForwardPaymentAction( *channel_id, FeePayment { htlcs, opening_fee_msat: *opening_fee_msat }, ); - Ok((pending_payment_forward, Some(forward_payment))) + *self = OutboundJITChannelState::PendingPaymentForward { + payment_queue: core::mem::take(payment_queue), + opening_fee_msat: *opening_fee_msat, + channel_id: *channel_id, + }; + Ok(Some(forward_payment)) } else { - let pending_payment = OutboundJITChannelState::PendingPayment { - payment_queue: payment_queue.clone(), + *self = OutboundJITChannelState::PendingPayment { + payment_queue: core::mem::take(payment_queue), opening_fee_msat: *opening_fee_msat, channel_id: *channel_id, }; - Ok((pending_payment, None)) + Ok(None) } }, OutboundJITChannelState::PendingPayment { @@ -350,17 +337,16 @@ impl OutboundJITChannelState { opening_fee_msat, channel_id, } => { - let pending_payment = OutboundJITChannelState::PendingPayment { - payment_queue: payment_queue.clone(), + *self = OutboundJITChannelState::PendingPayment { + payment_queue: core::mem::take(payment_queue), opening_fee_msat: *opening_fee_msat, channel_id: *channel_id, }; - Ok((pending_payment, None)) + Ok(None) }, OutboundJITChannelState::PaymentForwarded { channel_id } => { - let payment_forwarded = - OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; - Ok((payment_forwarded, None)) + *self = OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; + Ok(None) }, state => Err(ChannelStateError(format!( "HTLC handling failed when JIT Channel was in state: {:?}", @@ -369,23 +355,19 @@ impl OutboundJITChannelState { } } - fn payment_forwarded( - &mut self, - ) -> Result<(Self, Option), ChannelStateError> { + fn payment_forwarded(&mut self) -> Result, ChannelStateError> { match self { OutboundJITChannelState::PendingPaymentForward { payment_queue, channel_id, .. } => { - let mut payment_queue_lock = payment_queue.lock().unwrap(); - let payment_forwarded = - OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; - let forward_htlcs = ForwardHTLCsAction(*channel_id, payment_queue_lock.clear()); - Ok((payment_forwarded, Some(forward_htlcs))) + let mut payment_queue = core::mem::take(payment_queue); + let forward_htlcs = ForwardHTLCsAction(*channel_id, payment_queue.clear()); + *self = OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; + Ok(Some(forward_htlcs)) }, OutboundJITChannelState::PaymentForwarded { channel_id } => { - let payment_forwarded = - OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; - Ok((payment_forwarded, None)) + *self = OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id }; + Ok(None) }, state => Err(ChannelStateError(format!( "Payment forwarded when JIT Channel was in state: {:?}", @@ -418,29 +400,25 @@ impl OutboundJITChannel { fn htlc_intercepted( &mut self, htlc: InterceptedHTLC, ) -> Result, LightningError> { - let (new_state, action) = + let action = self.state.htlc_intercepted(&self.opening_fee_params, &self.payment_size_msat, htlc)?; - self.state = new_state; Ok(action) } fn htlc_handling_failed(&mut self) -> Result, LightningError> { - let (new_state, action) = self.state.htlc_handling_failed()?; - self.state = new_state; + let action = self.state.htlc_handling_failed()?; Ok(action) } fn channel_ready( &mut self, channel_id: ChannelId, ) -> Result { - let (new_state, action) = self.state.channel_ready(channel_id)?; - self.state = new_state; + let action = self.state.channel_ready(channel_id)?; Ok(action) } fn payment_forwarded(&mut self) -> Result, LightningError> { - let (new_state, action) = self.state.payment_forwarded()?; - self.state = new_state; + let action = self.state.payment_forwarded()?; Ok(action) } @@ -799,6 +777,8 @@ where &self, intercept_scid: u64, intercept_id: InterceptId, expected_outbound_amount_msat: u64, payment_hash: PaymentHash, ) -> Result<(), APIError> { + let event_queue_notifier = self.pending_events.notifier(); + let peer_by_intercept_scid = self.peer_by_intercept_scid.read().unwrap(); if let Some(counterparty_node_id) = peer_by_intercept_scid.get(&intercept_scid) { let outer_state_lock = self.per_peer_state.read().unwrap(); @@ -822,7 +802,7 @@ where user_channel_id: jit_channel.user_channel_id, intercept_scid, }; - self.pending_events.enqueue(event); + event_queue_notifier.enqueue(event); }, Ok(Some(HTLCInterceptedAction::ForwardHTLC(channel_id))) => { self.channel_manager.get_cm().forward_intercepted_htlc( @@ -1088,6 +1068,7 @@ where &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, params: LSPS2GetInfoRequest, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); let (result, response) = { let mut outer_state_lock = self.per_peer_state.write().unwrap(); let inner_state_lock = @@ -1106,8 +1087,7 @@ where counterparty_node_id: *counterparty_node_id, token: params.token, }; - self.pending_events.enqueue(event); - + event_queue_notifier.enqueue(event); (Ok(()), msg) }, (e, msg) => (e, msg), @@ -1124,6 +1104,7 @@ where fn handle_buy_request( &self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey, params: LSPS2BuyRequest, ) -> Result<(), LightningError> { + let event_queue_notifier = self.pending_events.notifier(); if let Some(payment_size_msat) = params.payment_size_msat { if payment_size_msat < params.opening_fee_params.min_payment_size_msat { let response = LSPS2Response::BuyError(LSPSResponseError { @@ -1226,7 +1207,7 @@ where opening_fee_params: params.opening_fee_params, payment_size_msat: params.payment_size_msat, }; - self.pending_events.enqueue(event); + event_queue_notifier.enqueue(event); (Ok(()), msg) }, @@ -1538,7 +1519,7 @@ mod tests { let mut state = OutboundJITChannelState::new(); // Intercepts the first HTLC of a multipart payment A. { - let (new_state, action) = state + let action = state .htlc_intercepted( &opening_fee_params, &payment_size_msat, @@ -1549,13 +1530,12 @@ mod tests { }, ) .unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PendingInitialPayment { .. })); + assert!(matches!(state, OutboundJITChannelState::PendingInitialPayment { .. })); assert!(action.is_none()); - state = new_state; } // Intercepts the first HTLC of a different multipart payment B. { - let (new_state, action) = state + let action = state .htlc_intercepted( &opening_fee_params, &payment_size_msat, @@ -1566,14 +1546,13 @@ mod tests { }, ) .unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PendingInitialPayment { .. })); + assert!(matches!(state, OutboundJITChannelState::PendingInitialPayment { .. })); assert!(action.is_none()); - state = new_state; } // Intercepts the second HTLC of multipart payment A, completing the expected payment and // opening the channel. { - let (new_state, action) = state + let action = state .htlc_intercepted( &opening_fee_params, &payment_size_msat, @@ -1584,13 +1563,12 @@ mod tests { }, ) .unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PendingChannelOpen { .. })); + assert!(matches!(state, OutboundJITChannelState::PendingChannelOpen { .. })); assert!(matches!(action, Some(HTLCInterceptedAction::OpenChannel(_)))); - state = new_state; } // Channel opens, becomes ready, and multipart payment A gets forwarded. { - let (new_state, ForwardPaymentAction(channel_id, payment)) = + let ForwardPaymentAction(channel_id, payment) = state.channel_ready(ChannelId([200; 32])).unwrap(); assert_eq!(channel_id, ChannelId([200; 32])); assert_eq!(payment.opening_fee_msat, 10_000_000); @@ -1609,11 +1587,10 @@ mod tests { }, ] ); - state = new_state; } // Intercepts the first HTLC of a different payment C. { - let (new_state, action) = state + let action = state .htlc_intercepted( &opening_fee_params, &payment_size_msat, @@ -1624,21 +1601,19 @@ mod tests { }, ) .unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PendingPaymentForward { .. })); + assert!(matches!(state, OutboundJITChannelState::PendingPaymentForward { .. })); assert!(action.is_none()); - state = new_state; } // Payment A fails. { - let (new_state, action) = state.htlc_handling_failed().unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PendingPayment { .. })); + let action = state.htlc_handling_failed().unwrap(); + assert!(matches!(state, OutboundJITChannelState::PendingPayment { .. })); // No payments have received sufficient HTLCs yet. assert!(action.is_none()); - state = new_state; } // Additional HTLC of payment B arrives, completing the expectd payment. { - let (new_state, action) = state + let action = state .htlc_intercepted( &opening_fee_params, &payment_size_msat, @@ -1649,7 +1624,7 @@ mod tests { }, ) .unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PendingPaymentForward { .. })); + assert!(matches!(state, OutboundJITChannelState::PendingPaymentForward { .. })); match action { Some(HTLCInterceptedAction::ForwardPayment(channel_id, payment)) => { assert_eq!(channel_id, ChannelId([200; 32])); @@ -1672,12 +1647,11 @@ mod tests { }, _ => panic!("Unexpected action when intercepted HTLC."), } - state = new_state; } // Payment completes, queued payments get forwarded. { - let (new_state, action) = state.payment_forwarded().unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. })); + let action = state.payment_forwarded().unwrap(); + assert!(matches!(state, OutboundJITChannelState::PaymentForwarded { .. })); match action { Some(ForwardHTLCsAction(channel_id, htlcs)) => { assert_eq!(channel_id, ChannelId([200; 32])); @@ -1692,11 +1666,10 @@ mod tests { }, _ => panic!("Unexpected action when forwarded payment."), } - state = new_state; } // Any new HTLC gets automatically forwarded. { - let (new_state, action) = state + let action = state .htlc_intercepted( &opening_fee_params, &payment_size_msat, @@ -1707,7 +1680,7 @@ mod tests { }, ) .unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. })); + assert!(matches!(state, OutboundJITChannelState::PaymentForwarded { .. })); assert!( matches!(action, Some(HTLCInterceptedAction::ForwardHTLC(channel_id)) if channel_id == ChannelId([200; 32])) ); @@ -1730,7 +1703,7 @@ mod tests { let mut state = OutboundJITChannelState::new(); // Intercepts payment A, opening the channel. { - let (new_state, action) = state + let action = state .htlc_intercepted( &opening_fee_params, &payment_size_msat, @@ -1741,13 +1714,12 @@ mod tests { }, ) .unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PendingChannelOpen { .. })); + assert!(matches!(state, OutboundJITChannelState::PendingChannelOpen { .. })); assert!(matches!(action, Some(HTLCInterceptedAction::OpenChannel(_)))); - state = new_state; } // Intercepts payment B. { - let (new_state, action) = state + let action = state .htlc_intercepted( &opening_fee_params, &payment_size_msat, @@ -1758,13 +1730,12 @@ mod tests { }, ) .unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PendingChannelOpen { .. })); + assert!(matches!(state, OutboundJITChannelState::PendingChannelOpen { .. })); assert!(action.is_none()); - state = new_state; } // Channel opens, becomes ready, and payment A gets forwarded. { - let (new_state, ForwardPaymentAction(channel_id, payment)) = + let ForwardPaymentAction(channel_id, payment) = state.channel_ready(ChannelId([200; 32])).unwrap(); assert_eq!(channel_id, ChannelId([200; 32])); assert_eq!(payment.opening_fee_msat, 10_000_000); @@ -1776,11 +1747,10 @@ mod tests { payment_hash: PaymentHash([100; 32]), },] ); - state = new_state; } // Intercepts payment C. { - let (new_state, action) = state + let action = state .htlc_intercepted( &opening_fee_params, &payment_size_msat, @@ -1791,14 +1761,13 @@ mod tests { }, ) .unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PendingPaymentForward { .. })); + assert!(matches!(state, OutboundJITChannelState::PendingPaymentForward { .. })); assert!(action.is_none()); - state = new_state; } // Payment A fails, and payment B is forwarded. { - let (new_state, action) = state.htlc_handling_failed().unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PendingPaymentForward { .. })); + let action = state.htlc_handling_failed().unwrap(); + assert!(matches!(state, OutboundJITChannelState::PendingPaymentForward { .. })); match action { Some(ForwardPaymentAction(channel_id, payment)) => { assert_eq!(channel_id, ChannelId([200; 32])); @@ -1813,12 +1782,11 @@ mod tests { }, _ => panic!("Unexpected action when HTLC handling failed."), } - state = new_state; } // Payment completes, queued payments get forwarded. { - let (new_state, action) = state.payment_forwarded().unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. })); + let action = state.payment_forwarded().unwrap(); + assert!(matches!(state, OutboundJITChannelState::PaymentForwarded { .. })); match action { Some(ForwardHTLCsAction(channel_id, htlcs)) => { assert_eq!(channel_id, ChannelId([200; 32])); @@ -1833,11 +1801,10 @@ mod tests { }, _ => panic!("Unexpected action when forwarded payment."), } - state = new_state; } // Any new HTLC gets automatically forwarded. { - let (new_state, action) = state + let action = state .htlc_intercepted( &opening_fee_params, &payment_size_msat, @@ -1848,7 +1815,7 @@ mod tests { }, ) .unwrap(); - assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. })); + assert!(matches!(state, OutboundJITChannelState::PaymentForwarded { .. })); assert!( matches!(action, Some(HTLCInterceptedAction::ForwardHTLC(channel_id)) if channel_id == ChannelId([200; 32])) ); diff --git a/lightning-liquidity/src/manager.rs b/lightning-liquidity/src/manager.rs index eec9a71d632..651cb4b74a6 100644 --- a/lightning-liquidity/src/manager.rs +++ b/lightning-liquidity/src/manager.rs @@ -1,4 +1,3 @@ -use alloc::boxed::Box; use alloc::string::ToString; use alloc::vec::Vec; @@ -11,7 +10,7 @@ use crate::lsps0::ser::{ LSPS_MESSAGE_TYPE_ID, }; use crate::lsps0::service::LSPS0ServiceHandler; -use crate::message_queue::{MessageQueue, ProcessMessagesCallback}; +use crate::message_queue::MessageQueue; use crate::lsps1::client::{LSPS1ClientConfig, LSPS1ClientHandler}; use crate::lsps1::msgs::LSPS1Message; @@ -32,6 +31,7 @@ use lightning::ln::wire::CustomMessageReader; use lightning::sign::EntropySource; use lightning::util::logger::Level; use lightning::util::ser::{LengthLimitedRead, LengthReadable}; +use lightning::util::wakers::Future; use lightning_types::features::{InitFeatures, NodeFeatures}; @@ -68,14 +68,49 @@ pub struct LiquidityClientConfig { pub lsps2_client_config: Option, } +/// A trivial trait which describes any [`LiquidityManager`]. +/// +/// This is not exported to bindings users as general cover traits aren't useful in other +/// languages. +pub trait ALiquidityManager { + /// A type implementing [`EntropySource`] + type EntropySource: EntropySource + ?Sized; + /// A type that may be dereferenced to [`Self::EntropySource`]. + type ES: Deref + Clone; + /// A type implementing [`AChannelManager`] + type AChannelManager: AChannelManager + ?Sized; + /// A type that may be dereferenced to [`Self::AChannelManager`]. + type CM: Deref + Clone; + /// A type implementing [`Filter`]. + type Filter: Filter + ?Sized; + /// A type that may be dereferenced to [`Self::Filter`]. + type C: Deref + Clone; + /// Returns a reference to the actual [`LiquidityManager`] object. + fn get_lm(&self) -> &LiquidityManager; +} + +impl ALiquidityManager + for LiquidityManager +where + ES::Target: EntropySource, + CM::Target: AChannelManager, + C::Target: Filter, +{ + type EntropySource = ES::Target; + type ES = ES; + type AChannelManager = CM::Target; + type CM = CM; + type Filter = C::Target; + type C = C; + fn get_lm(&self) -> &LiquidityManager { + self + } +} + /// The main interface into LSP functionality. /// /// Should be used as a [`CustomMessageHandler`] for your [`PeerManager`]'s [`MessageHandler`]. /// -/// Users should provide a callback to process queued messages via -/// [`LiquidityManager::set_process_msgs_callback`] post construction. This allows the -/// [`LiquidityManager`] to wake the [`PeerManager`] when there are pending messages to be sent. -/// /// Users need to continually poll [`LiquidityManager::get_and_clear_pending_events`] in order to surface /// [`LiquidityEvent`]'s that likely need to be handled. /// @@ -264,63 +299,13 @@ where { self.lsps2_service_handler.as_ref() } - /// Allows to set a callback that will be called after new messages are pushed to the message - /// queue. - /// - /// Usually, you'll want to use this to call [`PeerManager::process_events`] to clear the - /// message queue. For example: - /// - /// ``` - /// # use lightning::io; - /// # use lightning_liquidity::LiquidityManager; - /// # use std::sync::{Arc, RwLock}; - /// # use std::sync::atomic::{AtomicBool, Ordering}; - /// # use std::time::SystemTime; - /// # struct MyStore {} - /// # impl lightning::util::persist::KVStore for MyStore { - /// # fn read(&self, primary_namespace: &str, secondary_namespace: &str, key: &str) -> io::Result> { Ok(Vec::new()) } - /// # fn write(&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: &[u8]) -> io::Result<()> { Ok(()) } - /// # fn remove(&self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool) -> io::Result<()> { Ok(()) } - /// # fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> io::Result> { Ok(Vec::new()) } - /// # } - /// # struct MyEntropySource {} - /// # impl lightning::sign::EntropySource for MyEntropySource { - /// # fn get_secure_random_bytes(&self) -> [u8; 32] { [0u8; 32] } - /// # } - /// # struct MyEventHandler {} - /// # impl MyEventHandler { - /// # async fn handle_event(&self, _: lightning::events::Event) {} - /// # } - /// # #[derive(Eq, PartialEq, Clone, Hash)] - /// # struct MySocketDescriptor {} - /// # impl lightning::ln::peer_handler::SocketDescriptor for MySocketDescriptor { - /// # fn send_data(&mut self, _data: &[u8], _resume_read: bool) -> usize { 0 } - /// # fn disconnect_socket(&mut self) {} - /// # } - /// # type MyBroadcaster = dyn lightning::chain::chaininterface::BroadcasterInterface + Send + Sync; - /// # type MyFeeEstimator = dyn lightning::chain::chaininterface::FeeEstimator + Send + Sync; - /// # type MyNodeSigner = dyn lightning::sign::NodeSigner + Send + Sync; - /// # type MyUtxoLookup = dyn lightning::routing::utxo::UtxoLookup + Send + Sync; - /// # type MyFilter = dyn lightning::chain::Filter + Send + Sync; - /// # type MyLogger = dyn lightning::util::logger::Logger + Send + Sync; - /// # type MyChainMonitor = lightning::chain::chainmonitor::ChainMonitor, Arc, Arc, Arc, Arc>; - /// # type MyPeerManager = lightning::ln::peer_handler::SimpleArcPeerManager, MyLogger>; - /// # type MyNetworkGraph = lightning::routing::gossip::NetworkGraph>; - /// # type MyGossipSync = lightning::routing::gossip::P2PGossipSync, Arc, Arc>; - /// # type MyChannelManager = lightning::ln::channelmanager::SimpleArcChannelManager; - /// # type MyScorer = RwLock, Arc>>; - /// # type MyLiquidityManager = LiquidityManager, Arc, Arc>; - /// # fn setup_background_processing(my_persister: Arc, my_event_handler: Arc, my_chain_monitor: Arc, my_channel_manager: Arc, my_logger: Arc, my_peer_manager: Arc, my_liquidity_manager: Arc) { - /// let process_msgs_pm = Arc::clone(&my_peer_manager); - /// let process_msgs_callback = move || process_msgs_pm.process_events(); - /// - /// my_liquidity_manager.set_process_msgs_callback(process_msgs_callback); - /// # } - /// ``` + /// Returns a [`Future`] that will complete when the next batch of pending messages is ready to + /// be processed. /// - /// [`PeerManager::process_events`]: lightning::ln::peer_handler::PeerManager::process_events - pub fn set_process_msgs_callback(&self, callback: F) { - self.pending_messages.set_process_msgs_callback(Box::new(callback)); + /// Note that callbacks registered on the [`Future`] MUST NOT call back into this + /// [`LiquidityManager`] and should instead register actions to be taken later. + pub fn get_pending_msgs_future(&self) -> Future { + self.pending_messages.get_pending_msgs_future() } /// Blocks the current thread until next event is ready and returns it. diff --git a/lightning-liquidity/src/message_queue.rs b/lightning-liquidity/src/message_queue.rs index 49a98ecfa68..58060862f07 100644 --- a/lightning-liquidity/src/message_queue.rs +++ b/lightning-liquidity/src/message_queue.rs @@ -1,11 +1,12 @@ //! Holds types and traits used to implement message queues for [`LSPSMessage`]s. -use alloc::boxed::Box; use alloc::collections::VecDeque; use alloc::vec::Vec; use crate::lsps0::ser::LSPSMessage; -use crate::sync::{Mutex, RwLock}; +use crate::sync::Mutex; + +use lightning::util::wakers::{Future, Notifier}; use bitcoin::secp256k1::PublicKey; @@ -14,53 +15,29 @@ use bitcoin::secp256k1::PublicKey; /// [`LiquidityManager`]: crate::LiquidityManager pub struct MessageQueue { queue: Mutex>, - process_msgs_callback: RwLock>>, + pending_msgs_notifier: Notifier, } impl MessageQueue { pub(crate) fn new() -> Self { let queue = Mutex::new(VecDeque::new()); - let process_msgs_callback = RwLock::new(None); - Self { queue, process_msgs_callback } - } - - pub(crate) fn set_process_msgs_callback(&self, callback: Box) { - *self.process_msgs_callback.write().unwrap() = Some(callback); + let pending_msgs_notifier = Notifier::new(); + Self { queue, pending_msgs_notifier } } pub(crate) fn get_and_clear_pending_msgs(&self) -> Vec<(PublicKey, LSPSMessage)> { self.queue.lock().unwrap().drain(..).collect() } + pub(crate) fn get_pending_msgs_future(&self) -> Future { + self.pending_msgs_notifier.get_future() + } + pub(crate) fn enqueue(&self, counterparty_node_id: &PublicKey, msg: LSPSMessage) { { let mut queue = self.queue.lock().unwrap(); queue.push_back((*counterparty_node_id, msg)); } - - if let Some(process_msgs_callback) = self.process_msgs_callback.read().unwrap().as_ref() { - process_msgs_callback.call() - } + self.pending_msgs_notifier.notify(); } } - -macro_rules! define_callback { ($($bounds: path),*) => { -/// A callback which will be called to trigger network message processing. -/// -/// Usually, this should call [`PeerManager::process_events`]. -/// -/// [`PeerManager::process_events`]: lightning::ln::peer_handler::PeerManager::process_events -pub trait ProcessMessagesCallback : $($bounds +)* { - /// The method which is called. - fn call(&self); -} - -impl ProcessMessagesCallback for F { - fn call(&self) { (self)(); } -} -} } - -#[cfg(feature = "std")] -define_callback!(Send, Sync); -#[cfg(not(feature = "std"))] -define_callback!(); diff --git a/lightning-liquidity/tests/common/mod.rs b/lightning-liquidity/tests/common/mod.rs index f114f7b9c89..2259d1eae06 100644 --- a/lightning-liquidity/tests/common/mod.rs +++ b/lightning-liquidity/tests/common/mod.rs @@ -39,7 +39,7 @@ use lightning_persister::fs_store::FilesystemStore; use std::collections::{HashMap, VecDeque}; use std::path::PathBuf; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::AtomicBool; use std::sync::mpsc::SyncSender; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -134,7 +134,6 @@ pub(crate) struct Node { >, pub(crate) liquidity_manager: Arc, Arc, Arc>>, - pub(crate) check_msgs_processed: Arc, pub(crate) chain_monitor: Arc, pub(crate) kv_store: Arc, pub(crate) tx_broadcaster: Arc, @@ -472,21 +471,12 @@ pub(crate) fn create_liquidity_node( let peer_manager = Arc::new(PeerManager::new(msg_handler, 0, &seed, logger.clone(), keys_manager.clone())); - // Rather than registering PeerManager's process_events, we handle messages manually and use a - // bool to check whether PeerManager would have been called as expected. - let check_msgs_processed = Arc::new(AtomicBool::new(false)); - - let process_msgs_flag = Arc::clone(&check_msgs_processed); - let process_msgs_callback = move || process_msgs_flag.store(true, Ordering::Release); - liquidity_manager.set_process_msgs_callback(process_msgs_callback); - Node { channel_manager, keys_manager, p2p_gossip_sync, peer_manager, liquidity_manager, - check_msgs_processed, chain_monitor, kv_store, tx_broadcaster, @@ -634,8 +624,6 @@ pub(crate) use handle_funding_generation_ready; macro_rules! get_lsps_message { ($node: expr, $expected_target_node_id: expr) => {{ - use std::sync::atomic::Ordering; - assert!($node.check_msgs_processed.swap(false, Ordering::AcqRel)); let msgs = $node.liquidity_manager.get_and_clear_pending_msg(); assert_eq!(msgs.len(), 1); let (target_node_id, message) = msgs.into_iter().next().unwrap(); diff --git a/lightning/src/util/wakers.rs b/lightning/src/util/wakers.rs index a23e866ec18..ce504f63224 100644 --- a/lightning/src/util/wakers.rs +++ b/lightning/src/util/wakers.rs @@ -30,17 +30,29 @@ use core::pin::Pin; use core::task::{Context, Poll}; /// Used to signal to one of many waiters that the condition they're waiting on has happened. -pub(crate) struct Notifier { +/// +/// This is usually used by LDK objects such as [`ChannelManager`] or [`PeerManager`] to signal to +/// the background processor that it should wake up and process pending events. +/// +/// [`ChannelManager`]: crate::ln::channelmanager::ChannelManager +/// [`PeerManager`]: crate::ln::peer_handler::PeerManager +pub struct Notifier { notify_pending: Mutex<(bool, Option>>)>, } impl Notifier { - pub(crate) fn new() -> Self { + /// Constructs a new notifier. + pub fn new() -> Self { Self { notify_pending: Mutex::new((false, None)) } } /// Wake waiters, tracking that wake needs to occur even if there are currently no waiters. - pub(crate) fn notify(&self) { + /// + /// We deem the notification successful either directly after any callbacks were made, or after + /// the user [`poll`]ed a previously-completed future. + /// + /// [`poll`]: core::future::Future::poll + pub fn notify(&self) { let mut lock = self.notify_pending.lock().unwrap(); if let Some(future_state) = &lock.1 { if complete_future(future_state) { @@ -52,7 +64,7 @@ impl Notifier { } /// Gets a [`Future`] that will get woken up with any waiters - pub(crate) fn get_future(&self) -> Future { + pub fn get_future(&self) -> Future { let mut lock = self.notify_pending.lock().unwrap(); let mut self_idx = 0; if let Some(existing_state) = &lock.1 { @@ -254,6 +266,21 @@ impl Sleeper { vec![Arc::clone(&fut_a.state), Arc::clone(&fut_b.state), Arc::clone(&fut_c.state)]; Self { notifiers } } + /// Constructs a new sleeper from four futures, allowing blocking on all four at once. + /// + // Note that this is another common case - a ChannelManager, a ChainMonitor, an + // OnionMessenger, and a LiquidityManager. + pub fn from_four_futures( + fut_a: &Future, fut_b: &Future, fut_c: &Future, fut_d: &Future, + ) -> Self { + let notifiers = vec![ + Arc::clone(&fut_a.state), + Arc::clone(&fut_b.state), + Arc::clone(&fut_c.state), + Arc::clone(&fut_d.state), + ]; + Self { notifiers } + } /// Constructs a new sleeper on many futures, allowing blocking on all at once. pub fn new(futures: Vec) -> Self { Self { notifiers: futures.into_iter().map(|f| Arc::clone(&f.state)).collect() }