diff --git a/src/pubsub/kafka.rs b/src/pubsub/kafka.rs index 3bf07f33..62b54429 100644 --- a/src/pubsub/kafka.rs +++ b/src/pubsub/kafka.rs @@ -150,35 +150,20 @@ pub fn launch_subscribers( async fn subscriber_task(client: Arc, topics: Vec) { PUBSUB_SUBSCRIBE.increment(); + let sub_topics: Vec<&str> = topics.iter().map(AsRef::as_ref).collect(); + if client.subscribe(&sub_topics).is_ok() { PUBSUB_SUBSCRIBER_CURR.add(1); PUBSUB_SUBSCRIBE_OK.increment(); - let msg_stamp = MessageValidator::new(); + + let validator = MessageValidator::new(); + while RUNNING.load(Ordering::Relaxed) { match client.recv().await { - Ok(m) => match m.payload_view::<[u8]>() { - Some(Ok(m)) => { - let mut v = m.to_owned(); - match msg_stamp.validate_msg(&mut v) { - MessageValidationResult::Unexpected => { - error!("pubsub: invalid message received"); - RESPONSE_EX.increment(); - PUBSUB_RECEIVE_INVALID.increment(); - continue; - } - MessageValidationResult::Corrupted => { - error!("pubsub: corrupt message received"); - PUBSUB_RECEIVE.increment(); - PUBSUB_RECEIVE_CORRUPT.increment(); - continue; - } - MessageValidationResult::Validated(latency) => { - let _ = PUBSUB_LATENCY.increment(latency); - PUBSUB_RECEIVE.increment(); - PUBSUB_RECEIVE_OK.increment(); - } - } + Ok(message) => match message.payload_view::<[u8]>() { + Some(Ok(message)) => { + let _ = validator.validate(&mut message.to_owned()); } Some(Err(e)) => { error!("Error in deserializing the message:{:?}", e); @@ -211,7 +196,9 @@ pub fn launch_publishers(runtime: &mut Runtime, config: Config, work_receiver: R let _guard = runtime.enter(); Arc::new(get_kafka_producer(&config)) }; + PUBSUB_PUBLISHER_CONNECT.increment(); + for _ in 0..config.pubsub().unwrap().publisher_concurrency() { runtime.spawn(publisher_task(client.clone(), work_receiver.clone())); } @@ -223,14 +210,19 @@ async fn publisher_task( work_receiver: Receiver, ) -> Result<()> { PUBSUB_PUBLISHER_CURR.add(1); - let msg_stamp = MessageValidator::new(); + + let validator = 
MessageValidator::new(); + while RUNNING.load(Ordering::Relaxed) { let work_item = work_receiver .recv() .await .map_err(|_| Error::new(ErrorKind::Other, "channel closed"))?; + REQUEST.increment(); + let start = Instant::now(); + let result = match work_item { WorkItem::Publish { topic, @@ -238,7 +230,7 @@ async fn publisher_task( key, mut message, } => { - let timestamp = msg_stamp.stamp_msg(&mut message); + let timestamp = validator.stamp(&mut message); PUBSUB_PUBLISH.increment(); client .send( @@ -255,7 +247,9 @@ async fn publisher_task( .await } }; + let stop = Instant::now(); + match result { Ok(_) => { let latency = stop.duration_since(start).as_nanos(); @@ -268,6 +262,8 @@ async fn publisher_task( } } } + PUBSUB_PUBLISHER_CURR.sub(1); + Ok(()) } diff --git a/src/pubsub/mod.rs b/src/pubsub/mod.rs index 301a2261..68aef414 100644 --- a/src/pubsub/mod.rs +++ b/src/pubsub/mod.rs @@ -10,72 +10,81 @@ use tokio::runtime::Runtime; mod kafka; mod momento; +pub fn hasher() -> RandomState { + RandomState::with_seeds( + 0xd5b96f9126d61cee, + 0x50af85c9d1b6de70, + 0xbd7bdf2fee6d15b2, + 0x3dbe88bb183ac6f4, + ) +} + struct MessageValidator { hash_builder: RandomState, } -pub enum MessageValidationResult { - // u64 is the end-to-end latency in nanosecond) - Validated(u64), + +pub enum ValidationError { Unexpected, Corrupted, } + impl MessageValidator { - // Deterministic seeds are used so that multiple MessageStamp can stamp and validate messages + /// Deterministic seeds are used so that multiple validators can stamp and + /// validate messages produced by other instances. pub fn new() -> Self { MessageValidator { - hash_builder: RandomState::with_seeds( - 0xd5b96f9126d61cee, - 0x50af85c9d1b6de70, - 0xbd7bdf2fee6d15b2, - 0x3dbe88bb183ac6f4, - ), + hash_builder: hasher(), } } - pub fn stamp_msg(&self, message: &mut [u8]) -> u64 { + + /// Sets the checksum and timestamp in the message. Returns the timestamp. 
+ pub fn stamp(&self, message: &mut [u8]) -> u64 { let timestamp = (UnixInstant::now() - UnixInstant::from_nanos(0)).as_nanos(); let ts = timestamp.to_be_bytes(); + // write the current unix time into the message - [ - message[16], - message[17], - message[18], - message[19], - message[20], - message[21], - message[22], - message[23], - ] = ts; + message[16..24].copy_from_slice(&ts[0..8]); // todo, write a sequence number into the message // checksum the message and put the checksum into the message - [ - message[8], - message[9], - message[10], - message[11], - message[12], - message[13], - message[14], - message[15], - ] = self.hash_builder.hash_one(&message).to_be_bytes(); + let checksum = self.hash_builder.hash_one(&message).to_be_bytes(); + message[8..16].copy_from_slice(&checksum); + timestamp } - pub fn validate_msg(&self, v: &mut Vec) -> MessageValidationResult { + + /// Validates the message's magic bytes and checksum, updates the receive metrics, and returns the end-to-end latency in nanoseconds on success. + pub fn validate(&self, v: &mut Vec) -> std::result::Result { let now_unix = UnixInstant::now(); - if [v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]] - != [0x54, 0x45, 0x53, 0x54, 0x49, 0x4E, 0x47, 0x21] - { - return MessageValidationResult::Unexpected; + + // check if the magic bytes match + if v[0..8] != [0x54, 0x45, 0x53, 0x54, 0x49, 0x4E, 0x47, 0x21] { + error!("pubsub: unexpected/invalid message received"); + RESPONSE_EX.increment(); + PUBSUB_RECEIVE_INVALID.increment(); + return Err(ValidationError::Unexpected); } - let csum = [v[8], v[9], v[10], v[11], v[12], v[13], v[14], v[15]]; - [v[8], v[9], v[10], v[11], v[12], v[13], v[14], v[15]] = [0; 8]; + + // validate the checksum + let csum = v[8..16].to_owned(); + v[8..16].copy_from_slice(&[0; 8]); if csum != self.hash_builder.hash_one(&v).to_be_bytes() { - return MessageValidationResult::Corrupted; + error!("pubsub: corrupt message received"); + PUBSUB_RECEIVE.increment(); + PUBSUB_RECEIVE_CORRUPT.increment(); + return Err(ValidationError::Corrupted); } + + // calculate 
and return the end to end latency let ts = u64::from_be_bytes([v[16], v[17], v[18], v[19], v[20], v[21], v[22], v[23]]); let latency = now_unix - UnixInstant::from_nanos(ts); - MessageValidationResult::Validated(latency.as_nanos()) + + let _ = PUBSUB_LATENCY.increment(latency.as_nanos()); + PUBSUB_RECEIVE.increment(); + PUBSUB_RECEIVE_OK.increment(); + + Ok(latency.as_nanos()) } } @@ -89,6 +98,7 @@ impl PubsubRuntimes { if let Some(rt) = self.publisher_rt.take() { rt.shutdown_timeout(duration); } + if let Some(rt) = self.subscriber_rt.take() { rt.shutdown_timeout(duration); } diff --git a/src/pubsub/momento.rs b/src/pubsub/momento.rs index d15e04a4..f696e70d 100644 --- a/src/pubsub/momento.rs +++ b/src/pubsub/momento.rs @@ -76,31 +76,13 @@ async fn subscriber_task(client: Arc, cache_name: String, topic: St PUBSUB_SUBSCRIBER_CURR.add(1); PUBSUB_SUBSCRIBE_OK.increment(); - let msg_stamp = MessageValidator::new(); + let validator = MessageValidator::new(); while RUNNING.load(Ordering::Relaxed) { match subscription.next().await { Some(SubscriptionItem::Value(v)) => { if let ValueKind::Binary(mut v) = v.kind { - match msg_stamp.validate_msg(&mut v) { - MessageValidationResult::Unexpected => { - error!("pubsub: invalid message received"); - RESPONSE_EX.increment(); - PUBSUB_RECEIVE_INVALID.increment(); - continue; - } - MessageValidationResult::Corrupted => { - error!("pubsub: corrupt message received"); - PUBSUB_RECEIVE.increment(); - PUBSUB_RECEIVE_CORRUPT.increment(); - continue; - } - MessageValidationResult::Validated(latency) => { - let _ = PUBSUB_LATENCY.increment(latency); - PUBSUB_RECEIVE.increment(); - PUBSUB_RECEIVE_OK.increment(); - } - } + let _ = validator.validate(&mut v); } else { error!("there was a string in the topic"); // unexpected message @@ -184,7 +166,7 @@ async fn publisher_task( }) .to_string(); - let msg_stamp = MessageValidator::new(); + let validator = MessageValidator::new(); while RUNNING.load(Ordering::Relaxed) { let work_item = 
work_receiver @@ -201,7 +183,7 @@ async fn publisher_task( partition: _, key: _, } => { - msg_stamp.stamp_msg(&mut message); + validator.stamp(&mut message); PUBSUB_PUBLISH.increment(); match timeout(