Eng 432 refactor mdns service/state (#604)

clean up mdns state and refactor advertisement system
This commit is contained in:
Oscar Beaumont 2023-03-13 18:05:46 +08:00 committed by GitHub
parent 0c25239c53
commit 1978e2ff48
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 155 additions and 142 deletions

View file

@ -227,7 +227,7 @@ impl P2PManager {
let mut len_buf = len.to_le_bytes();
debug_assert_eq!(len_buf.len(), 4);
head_buf.extend_from_slice(&mut len_buf);
head_buf.extend_from_slice(&len_buf);
head_buf.append(&mut buf);
self.manager.broadcast(head_buf).await;

View file

@ -1,25 +1,21 @@
use std::{
collections::{HashMap, HashSet},
net::SocketAddr,
sync::{atomic::AtomicBool, Arc},
};
use std::{collections::HashSet, net::SocketAddr, sync::Arc};
use libp2p::{core::muxing::StreamMuxerBox, quic, Swarm, Transport};
use thiserror::Error;
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error, warn};
use crate::{
spacetime::{SpaceTime, UnicastStream},
AsyncFn, DiscoveredPeer, Keypair, ManagerStream, ManagerStreamAction, Mdns, Metadata, PeerId,
AsyncFn, DiscoveredPeer, Keypair, ManagerStream, ManagerStreamAction, Mdns, MdnsState,
Metadata, PeerId,
};
/// Is the core component of the P2P system that holds the state and delegates actions to the other components
#[derive(Debug)]
pub struct Manager<TMetadata: Metadata> {
pub(crate) mdns_state: Arc<MdnsState<TMetadata>>,
pub(crate) peer_id: PeerId,
pub(crate) listen_addrs: RwLock<HashSet<SocketAddr>>,
pub(crate) discovered: RwLock<HashMap<PeerId, DiscoveredPeer<TMetadata>>>,
pub(crate) application_name: &'static [u8],
event_stream_tx: mpsc::Sender<ManagerStreamAction<TMetadata>>,
}
@ -40,17 +36,19 @@ impl<TMetadata: Metadata> Manager<TMetadata> {
.then_some(())
.ok_or(ManagerError::InvalidAppName)?;
let peer_id = PeerId(keypair.public().to_peer_id());
let (event_stream_tx, event_stream_rx) = mpsc::channel(1024);
let (mdns, mdns_state) = Mdns::new(application_name, peer_id, fn_get_metadata).unwrap();
let this = Arc::new(Self {
mdns_state,
// Look this is bad but it's hard to avoid. Technically a memory leak but it's a small amount of memory and is should done on startup on the P2P system.
application_name: Box::leak(Box::new(
format!("/{}/spacetime/1.0.0", application_name)
.as_bytes()
.to_vec(),
)),
peer_id: PeerId(keypair.public().to_peer_id()),
listen_addrs: RwLock::new(Default::default()),
discovered: RwLock::new(Default::default()),
peer_id,
event_stream_tx,
});
@ -77,11 +75,11 @@ impl<TMetadata: Metadata> Manager<TMetadata> {
Ok((
this.clone(),
ManagerStream {
manager: this.clone(),
manager: this,
event_stream_rx,
swarm,
mdns: Mdns::new(this, application_name, fn_get_metadata).unwrap(),
is_advertisement_queued: AtomicBool::new(false),
mdns,
queued_events: Default::default(),
},
))
}
@ -98,11 +96,17 @@ impl<TMetadata: Metadata> Manager<TMetadata> {
}
pub async fn listen_addrs(&self) -> HashSet<SocketAddr> {
self.listen_addrs.read().await.clone()
self.mdns_state.listen_addrs.read().await.clone()
}
pub async fn get_discovered_peers(&self) -> Vec<DiscoveredPeer<TMetadata>> {
self.discovered.read().await.values().cloned().collect()
self.mdns_state
.discovered
.read()
.await
.values()
.cloned()
.collect()
}
pub async fn get_connected_peers(&self) -> Result<Vec<PeerId>, ()> {

View file

@ -1,11 +1,4 @@
use std::{
fmt,
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use std::{collections::VecDeque, fmt, net::SocketAddr, sync::Arc};
use libp2p::{
futures::StreamExt,
@ -13,7 +6,7 @@ use libp2p::{
dial_opts::{DialOpts, PeerCondition},
NetworkBehaviourAction, NotifyHandler, SwarmEvent,
},
Multiaddr, Swarm,
Swarm,
};
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error, warn};
@ -63,7 +56,7 @@ where
pub(crate) event_stream_rx: mpsc::Receiver<ManagerStreamAction<TMetadata>>,
pub(crate) swarm: Swarm<SpaceTime<TMetadata>>,
pub(crate) mdns: Mdns<TMetadata, TMetadataFn>,
pub(crate) is_advertisement_queued: AtomicBool,
pub(crate) queued_events: VecDeque<Event<TMetadata>>,
}
impl<TMetadata, TMetadataFn> ManagerStream<TMetadata, TMetadataFn>
@ -75,8 +68,12 @@ where
pub async fn next(&mut self) -> Option<Event<TMetadata>> {
// We loop polling internal services until an event comes in that needs to be sent to the parent application.
loop {
if let Some(event) = self.queued_events.pop_front() {
return Some(event);
}
tokio::select! {
event = self.mdns.poll() => {
event = self.mdns.poll(&self.manager) => {
if let Some(event) = event {
return Some(event);
}
@ -101,16 +98,11 @@ where
SwarmEvent::IncomingConnectionError { local_addr, error, .. } => warn!("handshake error with incoming connection from '{}': {}", local_addr, error),
SwarmEvent::OutgoingConnectionError { peer_id, error } => warn!("error establishing connection with '{:?}': {}", peer_id, error),
SwarmEvent::BannedPeer { peer_id, .. } => warn!("banned peer '{}' attempted to connection and was rejected", peer_id),
SwarmEvent::NewListenAddr{ address, .. } => {
SwarmEvent::NewListenAddr { address, .. } => {
match quic_multiaddr_to_socketaddr(address) {
Ok(addr) => {
debug!("listen address added: {}", addr);
self.manager.listen_addrs.write().await.insert(addr);
if !self.is_advertisement_queued.load(Ordering::Relaxed) {
self.is_advertisement_queued.store(true, Ordering::Relaxed);
self.mdns.advertise();
}
self.mdns.advertise();
self.mdns.register_addr(addr).await;
return Some(Event::AddListenAddr(addr));
},
Err(err) => {
@ -120,8 +112,12 @@ where
}
},
SwarmEvent::ExpiredListenAddr { address, .. } => {
match Self::unregister_addr(&self.manager, &self.mdns, &self.is_advertisement_queued, address).await {
Ok(_) => {},
match quic_multiaddr_to_socketaddr(address) {
Ok(addr) => {
debug!("listen address added: {}", addr);
self.mdns.unregister_addr(&addr).await;
return Some(Event::RemoveListenAddr(addr));
},
Err(err) => {
warn!("error passing listen address: {}", err);
continue;
@ -131,14 +127,21 @@ where
SwarmEvent::ListenerClosed { listener_id, addresses, reason } => {
debug!("listener '{:?}' was closed due to: {:?}", listener_id, reason);
for address in addresses {
match Self::unregister_addr(&self.manager, &self.mdns, &self.is_advertisement_queued, address).await {
Ok(_) => {},
match quic_multiaddr_to_socketaddr(address) {
Ok(addr) => {
debug!("listen address added: {}", addr);
self.mdns.unregister_addr(&addr).await;
self.queued_events.push_back(Event::RemoveListenAddr(addr));
},
Err(err) => {
warn!("error passing listen address: {}", err);
continue;
}
}
}
// The `loop` will restart and begin returning the events from `queued_events`.
}
SwarmEvent::ListenerError { listener_id, error } => warn!("listener '{:?}' reported a non-fatal error: {}", listener_id, error),
SwarmEvent::Dialing(_peer_id) => {},
@ -191,13 +194,13 @@ where
);
}
ManagerStreamAction::BroadcastData(data) => {
let connected_peers = self.swarm.connected_peers().map(|v| *v).collect::<Vec<_>>();
let connected_peers = self.swarm.connected_peers().copied().collect::<Vec<_>>();
let behaviour = self.swarm.behaviour_mut();
for peer_id in connected_peers {
behaviour
.pending_events
.push_back(NetworkBehaviourAction::NotifyHandler {
peer_id: peer_id,
peer_id,
handler: NotifyHandler::Any,
event: OutboundRequest::Broadcast(data.clone()),
});
@ -205,28 +208,6 @@ where
}
}
return None;
}
// TODO: Move into mdns
async fn unregister_addr(
manager: &Arc<Manager<TMetadata>>,
mdns: &Mdns<TMetadata, TMetadataFn>,
is_advertisement_queued: &AtomicBool,
address: Multiaddr,
) -> Result<Event<TMetadata>, String> {
match quic_multiaddr_to_socketaddr(address) {
Ok(addr) => {
debug!("listen address removed: {}", addr);
manager.listen_addrs.write().await.remove(&addr);
let _ = mdns.unregister_mdns();
if !is_advertisement_queued.load(Ordering::Relaxed) {
is_advertisement_queued.store(true, Ordering::Relaxed);
mdns.advertise();
}
Ok(Event::RemoveListenAddr(addr))
}
Err(err) => Err(err),
}
None
}
}

View file

@ -1,5 +1,5 @@
use std::{
collections::HashMap,
collections::{HashMap, HashSet},
net::{IpAddr, SocketAddr},
pin::Pin,
str::FromStr,
@ -8,7 +8,10 @@ use std::{
};
use mdns_sd::{ServiceDaemon, ServiceEvent, ServiceInfo};
use tokio::time::{sleep_until, Instant, Sleep};
use tokio::{
sync::RwLock,
time::{sleep_until, Instant, Sleep},
};
use tracing::{debug, error, warn};
use crate::{AsyncFn, DiscoveredPeer, Event, Manager, Metadata, PeerId};
@ -16,18 +19,27 @@ use crate::{AsyncFn, DiscoveredPeer, Event, Manager, Metadata, PeerId};
/// TODO
const MDNS_READVERTISEMENT_INTERVAL: Duration = Duration::from_secs(60); // Every minute re-advertise
/// TODO
#[derive(Debug)]
pub struct MdnsState<TMetadata: Metadata> {
pub discovered: RwLock<HashMap<PeerId, DiscoveredPeer<TMetadata>>>,
pub listen_addrs: RwLock<HashSet<SocketAddr>>,
}
/// TODO
pub struct Mdns<TMetadata, TMetadataFn>
where
TMetadata: Metadata,
TMetadataFn: AsyncFn<Output = TMetadata>,
{
manager: Arc<Manager<TMetadata>>,
// used to ignore events from our own mdns advertisement
peer_id: PeerId,
fn_get_metadata: TMetadataFn,
mdns_daemon: ServiceDaemon,
mdns_service_receiver: flume::Receiver<ServiceEvent>,
service_name: String,
next_mdns_advertisement: Pin<Box<Sleep>>,
state: Arc<MdnsState<TMetadata>>,
}
impl<TMetadata, TMetadataFn> Mdns<TMetadata, TMetadataFn>
@ -36,10 +48,10 @@ where
TMetadataFn: AsyncFn<Output = TMetadata>,
{
pub fn new(
manager: Arc<Manager<TMetadata>>,
application_name: &'static str,
peer_id: PeerId,
fn_get_metadata: TMetadataFn,
) -> Result<Self, mdns_sd::Error>
) -> Result<(Self, Arc<MdnsState<TMetadata>>), mdns_sd::Error>
where
TMetadataFn: AsyncFn<Output = TMetadata>,
{
@ -47,88 +59,81 @@ where
let service_name = format!("_{}._udp.local.", application_name);
let mdns_service_receiver = mdns_daemon.browse(&service_name)?;
let this = Self {
manager,
fn_get_metadata,
mdns_daemon,
mdns_service_receiver,
service_name,
next_mdns_advertisement: Box::pin(sleep_until(
Instant::now() + MDNS_READVERTISEMENT_INTERVAL,
)),
};
this.advertise();
Ok(this)
let state = Arc::new(MdnsState {
discovered: RwLock::new(Default::default()),
listen_addrs: RwLock::new(Default::default()),
});
Ok((
Self {
peer_id,
fn_get_metadata,
mdns_daemon,
mdns_service_receiver,
service_name,
next_mdns_advertisement: Box::pin(sleep_until(Instant::now())), // Trigger an advertisement immediately
state: state.clone(),
},
state,
))
}
pub fn unregister_mdns(&self) -> mdns_sd::Result<mdns_sd::Receiver<mdns_sd::UnregisterStatus>> {
self.mdns_daemon
.unregister(&format!("{}.{}", self.manager.peer_id, self.service_name))
.unregister(&format!("{}.{}", self.peer_id, self.service_name))
}
/// Do an mdns advertisement to the network
pub fn advertise(&self) {
// TODO: Instead of spawning maybe do this as part of the polling loop to avoid needing persitent reference to manager.
let manager = self.manager.clone();
let service_name = self.service_name.clone();
// let fn_get_metadata = self.fn_get_metadata.clone();
let mdns_daemon = self.mdns_daemon.clone();
/// Do an mdns advertisement to the network.
async fn advertise(&mut self) {
let metadata = (self.fn_get_metadata)().await.to_hashmap();
let metadata_fut = (self.fn_get_metadata)();
// This is in simple terms converts from `Vec<(ip, port)>` to `Vec<(Vec<Ip>, port)>`
let mut services = HashMap::<u16, ServiceInfo>::new();
for addr in self.state.listen_addrs.read().await.iter() {
let addr = match addr {
SocketAddr::V4(addr) => addr,
// TODO: Our mdns library doesn't support Ipv6. This code has the infra to support it so once this issue is fixed upstream we can just flip it on.
// Refer to issue: https://github.com/keepsimple1/mdns-sd/issues/61
SocketAddr::V6(_) => continue,
};
tokio::spawn(async move {
let metadata = metadata_fut.await.to_hashmap();
let peer_id = manager.peer_id.0.to_base58();
// This is in simple terms converts from `Vec<(ip, port)>` to `Vec<(Vec<Ip>, port)>`
let mut services = HashMap::<u16, ServiceInfo>::new();
for addr in manager.listen_addrs.read().await.iter() {
let addr = match addr {
SocketAddr::V4(addr) => addr,
// TODO: Our mdns library doesn't support Ipv6. This code has the infra to support it so once this issue is fixed upstream we can just flip it on.
// Refer to issue: https://github.com/keepsimple1/mdns-sd/issues/61
SocketAddr::V6(_) => continue,
if let Some(mut service) = services.remove(&addr.port()) {
service.insert_ipv4addr(*addr.ip());
services.insert(addr.port(), service);
} else {
let service = match ServiceInfo::new(
&self.service_name,
&self.peer_id.to_string(),
&format!("{}.", self.peer_id),
*addr.ip(),
addr.port(),
Some(metadata.clone()), // TODO: Prevent the user defining a value that overflows a DNS record
) {
Ok(service) => service,
Err(err) => {
warn!("error creating mdns service info: {}", err);
continue;
}
};
if let Some(mut service) = services.remove(&addr.port()) {
service.insert_ipv4addr(*addr.ip());
services.insert(addr.port(), service);
} else {
let service = match ServiceInfo::new(
&service_name,
&peer_id,
&format!("{}.", peer_id),
*addr.ip(),
addr.port(),
Some(metadata.clone()), // TODO: Prevent the user defining a value that overflows a DNS record
) {
Ok(service) => service,
Err(err) => {
warn!("error creating mdns service info: {}", err);
continue;
}
};
services.insert(addr.port(), service);
}
services.insert(addr.port(), service);
}
}
for (_, service) in services.into_iter() {
debug!("advertising mdns service: {:?}", service);
match mdns_daemon.register(service) {
Ok(_) => {}
Err(err) => warn!("error registering mdns service: {}", err),
}
for (_, service) in services.into_iter() {
debug!("advertising mdns service: {:?}", service);
match self.mdns_daemon.register(service) {
Ok(_) => {}
Err(err) => warn!("error registering mdns service: {}", err),
}
});
}
self.next_mdns_advertisement =
Box::pin(sleep_until(Instant::now() + MDNS_READVERTISEMENT_INTERVAL));
}
// TODO: if the channel's sender is dropped will this cause the `tokio::select` in the `manager.rs` to infinitely loop?
pub async fn poll(&mut self) -> Option<Event<TMetadata>> {
pub async fn poll(&mut self, manager: &Arc<Manager<TMetadata>>) -> Option<Event<TMetadata>> {
tokio::select! {
_ = &mut self.next_mdns_advertisement => {
self.advertise();
self.next_mdns_advertisement = Box::pin(sleep_until(Instant::now() + MDNS_READVERTISEMENT_INTERVAL));
}
_ = &mut self.next_mdns_advertisement => self.advertise().await,
event = self.mdns_service_receiver.recv_async() => {
let event = event.unwrap(); // TODO: Error handling
match event {
@ -142,7 +147,7 @@ where
match PeerId::from_str(&raw_peer_id) {
Ok(peer_id) => {
// Prevent discovery of the current peer.
if peer_id == self.manager.peer_id {
if peer_id == self.peer_id {
return None;
}
@ -156,14 +161,14 @@ where
Ok(metadata) => {
let peer = {
let mut discovered_peers =
self.manager.discovered.write().await;
self.state.discovered.write().await;
let peer = if let Some(peer) = discovered_peers.remove(&peer_id) {
peer
} else {
DiscoveredPeer {
manager: self.manager.clone(),
manager: manager.clone(),
peer_id,
metadata,
addresses: info
@ -201,13 +206,13 @@ where
match PeerId::from_str(&raw_peer_id) {
Ok(peer_id) => {
// Prevent discovery of the current peer.
if peer_id == self.manager.peer_id {
if peer_id == self.peer_id {
return None;
}
{
let mut discovered_peers =
self.manager.discovered.write().await;
self.state.discovered.write().await;
let peer = discovered_peers.remove(&peer_id);
return Some(Event::PeerExpired {
@ -229,4 +234,26 @@ where
None
}
pub async fn register_addr(&mut self, addr: SocketAddr) {
self.state.listen_addrs.write().await.insert(addr);
// If the next mdns advertisement is more than 250ms away, then we should queue one closer to now.
// This acts as a debounce for advertisements when many addresses are discovered close to each other (Eg. at startup)
if self.next_mdns_advertisement.deadline() > (Instant::now() + Duration::from_millis(250)) {
self.next_mdns_advertisement =
Box::pin(sleep_until(Instant::now() + Duration::from_millis(200)));
}
}
pub async fn unregister_addr(&mut self, addr: &SocketAddr) {
self.state.listen_addrs.write().await.remove(addr);
// If the next mdns advertisement is more than 250ms away, then we should queue one closer to now.
// This acts as a debounce for advertisements when many addresses are discovered close to each other (Eg. at startup)
if self.next_mdns_advertisement.deadline() > (Instant::now() + Duration::from_millis(250)) {
self.next_mdns_advertisement =
Box::pin(sleep_until(Instant::now() + Duration::from_millis(200)));
}
}
}

View file

@ -36,8 +36,9 @@ impl SpaceTimeStream {
Self::Broadcast(mut stream) => {
if let Some(stream) = stream.0.take() {
BroadcastStream::close_inner(stream).await
} else if cfg!(debug_assertions) {
panic!("'BroadcastStream' should never be 'None' here!");
} else {
debug_assert!(true, "'BroadcastStream' should never be 'None' here!");
error!("'BroadcastStream' should never be 'None' here!");
Ok(())
}