poc: Refactor Kad Engine to a poll-based implementation #306

Draft
wants to merge 4 commits into base: master
597 changes: 437 additions & 160 deletions Cargo.lock

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions Cargo.toml
@@ -27,26 +27,26 @@ parking_lot = "0.12.3"
pin-project = "1.1.0"
prost = "0.12.6"
quinn = { version = "0.9.3", default-features = false, features = ["tls-rustls", "runtime-tokio"], optional = true }
rand = { version = "0.8.0", features = ["getrandom"] }
rand = { version = "0.8.5", features = ["getrandom"] }
rcgen = "0.10.0"
ring = "0.16.20"
serde = "1.0.158"
sha2 = "0.10.8"
simple-dns = "0.7.0"
simple-dns = "0.9.1"
smallvec = "1.13.2"
snow = { version = "0.9.3", features = ["ring-resolver"], default-features = false }
socket2 = { version = "0.5.7", features = ["all"] }
str0m = { version = "0.6.2", optional = true }
thiserror = "1.0.61"
tokio-stream = "0.1.12"
tokio-tungstenite = { version = "0.20.0", features = ["rustls-tls-native-roots"], optional = true }
tokio-util = { version = "0.7.11", features = ["compat", "io", "codec"] }
tokio = { version = "1.26.0", features = ["rt", "net", "io-util", "time", "macros", "sync", "parking_lot"] }
thiserror = "2.0.7"
tokio-stream = "0.1.17"
tokio-tungstenite = { version = "0.25.0", features = ["rustls-tls-native-roots", "url"], optional = true }
tokio-util = { version = "0.7.13", features = ["compat", "io", "codec"] }
tokio = { version = "1.42.0", features = ["rt", "net", "io-util", "time", "macros", "sync", "parking_lot"] }
tracing = { version = "0.1.40", features = ["log"] }
hickory-resolver = "0.24.2"
uint = "0.9.5"
unsigned-varint = { version = "0.8.0", features = ["codec"] }
url = "2.4.0"
url = "2.5.4"
webpki = { version = "0.22.4", optional = true }
x25519-dalek = "2.0.0"
x509-parser = "0.16.0"
87 changes: 55 additions & 32 deletions src/protocol/libp2p/kademlia/mod.rs
@@ -21,6 +21,7 @@
//! [`/ipfs/kad/1.0.0`](https://github.com/libp2p/specs/blob/master/kad-dht/README.md) implementation.

use crate::{
addresses,
error::{Error, ImmediateDialError, SubstreamError},
protocol::{
libp2p::kademlia::{
@@ -41,8 +42,9 @@ use crate::{
};

use bytes::{Bytes, BytesMut};
use futures::StreamExt;
use futures::{sink::Close, StreamExt};
use multiaddr::Multiaddr;
use rustls::client;
use tokio::sync::mpsc::{Receiver, Sender};

use std::{
@@ -245,7 +247,7 @@ impl Kademlia {
context.add_pending_action(substream_id, action);
}
Err(error) => {
tracing::debug!(
tracing::error!(
target: LOG_TARGET,
?peer,
?action,
@@ -319,7 +321,7 @@

match pending_action.take() {
None => {
tracing::trace!(
tracing::warn!(
target: LOG_TARGET,
?peer,
?substream_id,
@@ -410,6 +412,25 @@ impl Kademlia {
}
}

fn closest_peers<K: Clone>(&mut self, target: &Key<K>) -> Vec<KademliaPeer> {
// Find closest peers from kademlia.
let mut closest_peers = self.routing_table.closest(target, self.replication_factor);

// Get the true addresses of the peers.
let mut peer_to_addresses =
self.service.peer_addresses(closest_peers.iter().map(|p| p.peer));

// Update the addresses of the peers.
for closest in closest_peers.iter_mut() {
if let Some(addresses) = peer_to_addresses.remove(&closest.peer) {
closest.addresses = addresses;
} else {
closest.addresses = Vec::new();
}
}
closest_peers
}

/// Handle received message.
async fn on_message_received(
&mut self,
@@ -448,11 +469,8 @@ impl Kademlia {
"handle `FIND_NODE` request",
);

let message = KademliaMessage::find_node_response(
&target,
self.routing_table
.closest(&Key::new(target.as_ref()), self.replication_factor),
);
let peers = self.closest_peers(&Key::new(target.as_ref()));
let message = KademliaMessage::find_node_response(&target, peers);
self.executor.send_message(peer, message.into(), substream);
}
}
@@ -500,9 +518,7 @@ impl Kademlia {
);

let value = self.store.get(&key).cloned();
let closest_peers = self
.routing_table
.closest(&Key::new(key.as_ref()), self.replication_factor);
let closest_peers = self.closest_peers(&Key::new(key.as_ref()));

let message =
KademliaMessage::get_value_response(key, closest_peers, value);
@@ -612,9 +628,7 @@ impl Kademlia {
p.addresses = self.service.public_addresses().get_addresses();
});

let closer_peers = self
.routing_table
.closest(&Key::new(key.as_ref()), self.replication_factor);
let closer_peers = self.closest_peers(&Key::new(key.as_ref()));

let message =
KademliaMessage::get_providers_response(providers, &closer_peers);
@@ -667,7 +681,7 @@ impl Kademlia {
}

/// Handle dial failure.
fn on_dial_failure(&mut self, peer: PeerId, address: Multiaddr) {
fn on_dial_failure(&mut self, peer: PeerId, address: Multiaddr, reason: String) {
tracing::trace!(target: LOG_TARGET, ?peer, ?address, "failed to dial peer");

let Some(actions) = self.pending_dials.remove(&peer) else {
@@ -681,6 +695,7 @@ impl Kademlia {
?peer,
query = ?query_id,
?address,
?reason,
"report failure for pending query",
);

@@ -877,14 +892,22 @@ impl Kademlia {
tracing::debug!(target: LOG_TARGET, "starting kademlia event loop");

loop {
// poll `QueryEngine` for next actions.
while let Some(action) = self.engine.next_action() {
if let Err((query, peer)) = self.on_query_action(action).await {
self.disconnect_peer(peer, Some(query)).await;
}
}
// // poll `QueryEngine` for next actions.
// while let Some(action) = self.engine.next_action() {
// if let Err((query, peer)) = self.on_query_action(action).await {
// self.disconnect_peer(peer, Some(query)).await;
// }
// }

tokio::select! {
action = self.engine.next() => {
if let Some(action) = action {
if let Err((query, peer)) = self.on_query_action(action).await {
self.disconnect_peer(peer, Some(query)).await;
}
}
},

event = self.service.next() => match event {
Some(TransportEvent::ConnectionEstablished { peer, .. }) => {
if let Err(error) = self.on_connection_established(peer) {
@@ -920,8 +943,8 @@ impl Kademlia {
Some(TransportEvent::SubstreamOpenFailure { substream, error }) => {
self.on_substream_open_failure(substream, error).await;
}
Some(TransportEvent::DialFailure { peer, address, .. }) =>
self.on_dial_failure(peer, address),
Some(TransportEvent::DialFailure { peer, address, reason }) =>
self.on_dial_failure(peer, address, reason),
None => return Err(Error::EssentialTaskClosed),
},
context = self.executor.next() => {
Expand Down Expand Up @@ -966,7 +989,7 @@ impl Kademlia {
"failed to read message from substream",
);

self.disconnect_peer(peer, query_id).await;
// self.disconnect_peer(peer, query_id).await;
}
}
},
@@ -980,12 +1003,11 @@ impl Kademlia {
"starting `FIND_NODE` query",
);

let closest = self.closest_peers(&Key::from(peer));
self.engine.start_find_node(
query_id,
peer,
self.routing_table
.closest(&Key::from(peer), self.replication_factor)
.into()
closest.into(),
);
}
Some(KademliaCommand::PutRecord { mut record, query_id }) => {
@@ -1009,10 +1031,11 @@ impl Kademlia {

self.store.put(record.clone());

let closest = self.closest_peers(&key);
self.engine.start_put_record(
query_id,
record,
self.routing_table.closest(&key, self.replication_factor).into(),
closest.into(),
);
}
Some(KademliaCommand::PutRecordToPeers {
@@ -1075,14 +1098,14 @@ impl Kademlia {
};

self.store.put_provider(key.clone(), provider.clone());
let key_saved = key.clone();
let closest = self.closest_peers(&Key::new(key));

self.engine.start_add_provider(
query_id,
key.clone(),
key_saved,
provider,
self.routing_table
.closest(&Key::new(key), self.replication_factor)
.into(),
closest.into(),
);
}
Some(KademliaCommand::StopProviding {
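Note on the event-loop change in src/protocol/libp2p/kademlia/mod.rs above: replacing the explicit `while let Some(action) = self.engine.next_action()` drain with an `action = self.engine.next()` branch inside `tokio::select!` implies the query engine is polled as an asynchronous stream of actions. The sketch below is illustrative only — `Engine`, `Action`, and `start_query` are made-up names, not this PR's `QueryEngine` API — and shows one way to get that shape by aggregating one stream per in-flight query with `futures::stream::SelectAll`.

```rust
// Illustrative sketch only; assumes the `futures` crate and tokio with the
// "rt", "macros" and "time" features. Not code from this PR.
use futures::stream::{SelectAll, Stream, StreamExt};
use std::{pin::Pin, time::Duration};

/// Placeholder for the per-query actions the engine would emit.
#[derive(Debug)]
enum Action {
    SendMessage { query: usize },
    QuerySucceeded { query: usize },
}

/// Engine that owns one action stream per in-flight query.
struct Engine {
    queries: SelectAll<Pin<Box<dyn Stream<Item = Action> + Send>>>,
}

impl Engine {
    fn new() -> Self {
        Self { queries: SelectAll::new() }
    }

    /// Register a new query; any `Stream` of actions can be plugged in here.
    fn start_query(&mut self, actions: impl Stream<Item = Action> + Send + 'static) {
        self.queries.push(Box::pin(actions));
    }

    /// `SelectAll` yields `None` once every registered stream has finished.
    async fn next(&mut self) -> Option<Action> {
        self.queries.next().await
    }
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let mut engine = Engine::new();
    engine.start_query(futures::stream::iter([
        Action::SendMessage { query: 1 },
        Action::QuerySucceeded { query: 1 },
    ]));

    loop {
        tokio::select! {
            action = engine.next() => match action {
                Some(action) => println!("engine action: {action:?}"),
                // No live queries left. An empty `SelectAll` resolves to `None`
                // immediately, so a long-running loop must avoid busy-polling here.
                None => break,
            },
            // Stand-in for the other event sources (service, executor, commands).
            _ = tokio::time::sleep(Duration::from_secs(5)) => break,
        }
    }
}
```

The caveat in the comment is the main design choice this shape forces: when no query is live, the engine branch either has to be parked or the loop has to tolerate it resolving to `None` without spinning.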
81 changes: 59 additions & 22 deletions src/protocol/libp2p/kademlia/query/find_node.rs
@@ -19,6 +19,7 @@
// DEALINGS IN THE SOFTWARE.

use bytes::Bytes;
use futures::Stream;

use crate::{
protocol::libp2p::kademlia::{
@@ -29,7 +30,11 @@ use crate::{
PeerId,
};

use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
use std::{
collections::{BTreeMap, HashMap, HashSet, VecDeque},
pin::Pin,
task::{Context, Poll},
};

/// Logging target for the file.
const LOG_TARGET: &str = "litep2p::ipfs::kademlia::query::find_node";
@@ -91,6 +96,11 @@ pub struct FindNodeContext<T: Clone + Into<Vec<u8>>> {
/// These represent the number of peers added to the `Self::pending` minus the number of peers
/// that have failed to respond within the `Self::peer_timeout`
pending_responses: usize,

start_time: std::time::Instant,

is_done: bool,
waker: Option<std::task::Waker>,
}

impl<T: Clone + Into<Vec<u8>>> FindNodeContext<T> {
@@ -116,11 +126,18 @@ impl<T: Clone + Into<Vec<u8>>> FindNodeContext<T> {

peer_timeout: DEFAULT_PEER_TIMEOUT,
pending_responses: 0,

is_done: false,
waker: None,

start_time: std::time::Instant::now(),
}
}

/// Register response failure for `peer`.
pub fn register_response_failure(&mut self, peer: PeerId) {
tracing::warn!(target: LOG_TARGET, query = ?self.config.query, ?peer, "peer failed to respond");

let Some((peer, instant)) = self.pending.remove(&peer) else {
tracing::debug!(target: LOG_TARGET, query = ?self.config.query, ?peer, "pending peer doesn't exist during response failure");
return;
@@ -129,7 +146,8 @@ impl<T: Clone + Into<Vec<u8>>> FindNodeContext<T> {

tracing::trace!(target: LOG_TARGET, query = ?self.config.query, ?peer, elapsed = ?instant.elapsed(), "peer failed to respond");

self.queried.insert(peer.peer);
// Add a retry mechanism for failure responses.
// self.queried.insert(peer.peer);
}

/// Register `FIND_NODE` response from `peer`.
@@ -149,25 +167,7 @@ impl<T: Clone + Into<Vec<u8>>> FindNodeContext<T> {

// always mark the peer as queried to prevent it getting queried again
self.queried.insert(peer.peer);

if self.responses.len() < self.config.replication_factor {
self.responses.insert(distance, peer);
} else {
// Update the furthest peer if this response is closer.
// Find the furthest distance.
let furthest_distance =
self.responses.last_entry().map(|entry| *entry.key()).unwrap_or(distance);

// The response received from the peer is closer than the furthest response.
if distance < furthest_distance {
self.responses.insert(distance, peer);

// Remove the furthest entry.
if self.responses.len() > self.config.replication_factor {
self.responses.pop_last();
}
}
}
self.responses.insert(distance, peer);

let to_query_candidate = peers.into_iter().filter_map(|peer| {
// Peer already produced a response.
@@ -230,6 +230,18 @@ impl<T: Clone + Into<Vec<u8>>> FindNodeContext<T> {

/// Get next action for a `FIND_NODE` query.
pub fn next_action(&mut self) -> Option<QueryAction> {
// if self.start_time.elapsed() > std::time::Duration::from_secs(10) {
// return if self.responses.is_empty() {
// Some(QueryAction::QueryFailed {
// query: self.config.query,
// })
// } else {
// Some(QueryAction::QuerySucceeded {
// query: self.config.query,
// })
// };
// }

// If we cannot make progress, return the final result.
// A query failed when we are not able to identify one single peer.
if self.is_done() {
@@ -298,6 +310,30 @@ impl<T: Clone + Into<Vec<u8>>> FindNodeContext<T> {
}
}

impl<T: Clone + Into<Vec<u8>> + Unpin> Stream for FindNodeContext<T> {
type Item = QueryAction;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.is_done {
return Poll::Ready(None);
}

let action = self.next_action();
match action {
Some(QueryAction::QueryFailed { .. }) | Some(QueryAction::QuerySucceeded { .. }) => {
self.is_done = true;
}
None => {
self.waker = Some(cx.waker().clone());
return Poll::Pending;
}
_ => (),
};

Poll::Ready(action)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -460,7 +496,8 @@ mod tests {
let in_peers_set: HashSet<_> = [peer_a, peer_b, peer_c].into_iter().collect();
assert_eq!(in_peers_set.len(), 3);

let in_peers = [peer_a, peer_b, peer_c].iter().map(|peer| peer_to_kad(*peer)).collect();
let in_peers: VecDeque<KademliaPeer> =
[peer_a, peer_b, peer_c].iter().map(|peer| peer_to_kad(*peer)).collect();
let mut context = FindNodeContext::new(config, in_peers);

// Schedule peer queries.
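Note on the `Stream` implementation for `FindNodeContext` above: `poll_next` stores the task's `Waker` when `next_action()` has nothing to yield, so whatever code later feeds the context (responses, failures, new candidate peers) must wake that stored waker, or the stream stalls after returning `Poll::Pending`. Below is a small, self-contained sketch of that register-and-wake pattern — illustrative types only, not this PR's code; it assumes the `futures` crate.

```rust
// Illustrative sketch only; assumes the `futures` crate. Not code from this PR.
use futures::{executor::block_on, Stream, StreamExt};
use std::{
    collections::VecDeque,
    pin::Pin,
    task::{Context, Poll, Waker},
};

/// Toy stand-in for a per-query context polled as a stream of actions.
struct QueryStream {
    ready_actions: VecDeque<u32>,
    waker: Option<Waker>,
}

impl QueryStream {
    /// Called when new input arrives (e.g. a peer response): queue the
    /// resulting action and wake the task that last polled us; without the
    /// wake, the stream would never be polled again.
    fn register_response(&mut self, action: u32) {
        self.ready_actions.push_back(action);
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }
}

impl Stream for QueryStream {
    type Item = u32;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.ready_actions.pop_front() {
            Some(action) => Poll::Ready(Some(action)),
            None => {
                // Nothing to do right now: remember the current task so
                // `register_response` can wake it later.
                self.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}

fn main() {
    let mut stream = QueryStream { ready_actions: VecDeque::new(), waker: None };
    stream.register_response(1);
    stream.register_response(2);

    // Drain the queued actions; a real caller would poll this from `select!`
    // or via `SelectAll` together with the other query streams.
    block_on(async {
        assert_eq!(stream.next().await, Some(1));
        assert_eq!(stream.next().await, Some(2));
    });
}
```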