diff --git a/Cargo.toml b/Cargo.toml index 49cdcc5..84794ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tower-github-webhook" -version = "0.1.2" +version = "0.2.0" edition = "2021" authors = ["Sebastian Rollén "] license = "MIT" @@ -17,17 +17,18 @@ hex = "0.4.3" hmac = "0.12.1" http = "1.0.0" http-body = "1.0.0" -pin-project = "1.1.3" +http-body-util = "0.1.0" +pin-project-lite = "0.2.14" sha2 = "0.10.8" -tower = { version = "0.4.13", features = ["util"] } +tower-layer = "0.3.2" +tower-service = "0.3.2" tracing = "0.1.40" [dev-dependencies] axum = { version = "0.7.4", features = ["macros"] } http-body-util = "0.1.0" -hyper = "1.1.0" octocrab = "0.33.3" serde = { version = "1.0.196", features = ["derive"] } tokio = { version = "1.35.1", features = ["full"] } -tracing = "0.1.40" tracing-subscriber = "0.3.18" +tower = { version = "0.4.13", features = ["util"] } diff --git a/README.md b/README.md index 9e5efda..99990d4 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,5 @@ # tower-github-webhook -## WORK IN PROGRESS - -This crate does not currently work as intended—the middleware empties the request body completely rather than passing the body on to the inner service. - -I would not recommend using this crate for anything meaningful until I have had time to fix the issue - `tower-github-webhook` is a crate that simplifies validating webhooks received from GitHub. [![Crates.io](https://img.shields.io/crates/v/tower-github-webhook)](https://crates.io/crates/tower-github-webhook) diff --git a/examples/simple.rs b/examples/simple.rs index 7d0f511..9bcb30d 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -1,9 +1,19 @@ +//! # Example +//! +//! This is a simple example of how to implement a webhook handler to handle incoming GitHub +//! events. +//! It uses `octocrab` for definitions of the various webhooks that are sent from GitHub, `axum` as +//! a server and, of course, `tower-github-webhook` to handle authenticatition of the incoming +//! webhook. +//! +//! The `Event` struct has implements the `FromRequest` axum trait so that it can be used as a +//! parameter in the axum handler. use axum::async_trait; use axum::body::Bytes; use axum::debug_handler; use axum::extract::{FromRequest, Request}; use axum::response::{IntoResponse, Response}; -use axum::{extract::Json, routing::post, Router}; +use axum::{routing::post, Router}; use octocrab::models::{ webhook_events::{WebhookEvent, WebhookEventPayload, WebhookEventType}, Author, Repository, @@ -11,6 +21,8 @@ use octocrab::models::{ use serde::{Deserialize, Serialize}; use tower_github_webhook::ValidateGitHubWebhookLayer; +const WEBHOOK_SECRET: &'static str = "my little secret"; + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct Event { pub kind: WebhookEventType, @@ -74,11 +86,11 @@ fn app() -> Router { // Build route service Router::new().route( "/github/events", - post(print_body).layer(ValidateGitHubWebhookLayer::new("123")), + post(print_body).layer(ValidateGitHubWebhookLayer::new(WEBHOOK_SECRET)), ) } #[debug_handler] -async fn print_body(Json(event): Json) { +async fn print_body(event: Event) { println!("{:#?}", event); } diff --git a/src/future.rs b/src/future.rs index b3af550..2a284d2 100644 --- a/src/future.rs +++ b/src/future.rs @@ -1,179 +1,155 @@ -use bytes::Buf; +use bytes::{Buf, Bytes, BytesMut}; use hmac::{Hmac, Mac}; -use http::{Request, Response, StatusCode}; +use http::{request::Parts, Request, Response, StatusCode}; use http_body::Body; -use pin_project::pin_project; +use http_body_util::{Either, Empty, Full}; +use pin_project_lite::pin_project; use sha2::Sha256; -use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use tower::Service; +use tower_service::Service; -#[pin_project] -pub struct ValidateGitHubWebhookFuture< - S: Service, Response = Response>, - ReqBody, - ResBody, -> { - req: Option>, - signature: Option>, - inner: S, - hmac: Option>, - #[pin] - state: ValidateGitHubWebhookFutureState, +type FutureResponse = Result>>, Error>; + +pin_project! { + pub struct Future>, Response = Response>, ReqBody, ResBody> { + // We use Option here and for `hmac` to make it easy to move these fields out of the future + // later. + parts: Option, + buffer: BytesMut, + inner: S, + hmac: Option>, + #[pin] + body: ReqBody, + #[pin] + state: State, + } } -impl ValidateGitHubWebhookFuture +impl Future where - S: Service, Response = Response>, + S: Service>, Response = Response>, + ReqBody: Body, { pub fn new(req: Request, hmac: Hmac, inner: S) -> Self { + let (parts, body) = req.into_parts(); + let body_size = body.size_hint().lower().try_into().unwrap_or(0); + let buffer = BytesMut::with_capacity(body_size); Self { - req: Some(req), - signature: None, + parts: Some(parts), + body, + buffer, inner, hmac: Some(hmac), - state: ValidateGitHubWebhookFutureState::ExtractSignature, + state: State::new(), } } } -impl Future for ValidateGitHubWebhookFuture +pin_project! { + #[project = StateProj] + enum State { + ExtractSignature, + ExtractBody { + signature: Vec, + }, + Inner { + #[pin] + fut: F, + }, + } +} + +impl State { + pub fn new() -> Self { + Self::ExtractSignature + } +} + +impl std::future::Future for Future where - S: Service, Response = Response, Future = F>, - F: Future, S::Error>>, - ReqBody: Body + Unpin, - ResBody: Body + Default, + S: Service>, Response = Response, Future = F>, + F: std::future::Future, S::Error>>, + ReqBody: Body, { - type Output = Result, S::Error>; + type Output = FutureResponse; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_mut().project(); let mut curr_state = this.state; match curr_state.as_mut().project() { - ValidateGitHubProj::ExtractSignature => { - tracing::trace!( - "[tower-github-webhook] ValidateGitHubWebhookFutureState::ExtractSignature" - ); - let req = this.req.take().unwrap(); - let signature = match req.headers().get("x-hub-signature-256") { - Some(sig) => { - let Some(sig) = sig.as_bytes().splitn(2, |x| x == &b'=').nth(1) else { - tracing::debug!("[tower-github-webhook] Invalid header format"); - curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); - cx.waker().wake_by_ref(); - return Poll::Pending; - }; - match hex::decode(sig) { - Ok(sig) => sig, - Err(_) => { - tracing::debug!("[tower-github-webhook] Invalid header format"); - curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); - cx.waker().wake_by_ref(); - return Poll::Pending; - } - } - } - None => { - tracing::debug!( - "[tower-github-webhook] Missing X-HUB-SIGNATURE-256 header" - ); - curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); - cx.waker().wake_by_ref(); - return Poll::Pending; - } + StateProj::ExtractSignature => { + let parts = this + .parts + .take() + .expect("Parts is either reset at the end of this method, or we bail"); + let Some(signature) = parts.headers.get("x-hub-signature-256") else { + return bail("Missing X-HUB-SIGNATURE-256 header"); + }; + let Some(signature) = signature.as_bytes().splitn(2, |x| x == &b'=').nth(1) else { + return bail("Invalid header format"); }; - curr_state.set(ValidateGitHubWebhookFutureState::ExtractBody); - *this.signature = Some(signature); - *this.req = Some(req); - cx.waker().wake_by_ref(); - Poll::Pending + let Ok(signature) = hex::decode(signature) else { + return bail("Invalid header format"); + }; + *this.parts = Some(parts); + curr_state.set(State::ExtractBody { signature }); + rewake(cx) } - ValidateGitHubProj::ExtractBody => { - tracing::trace!( - "[tower-github-webhook] ValidateGitHubWebhookFutureState::ExtractBody" - ); - let mut req = this.req.take().unwrap(); - let body = Pin::new(req.body_mut()); - if body.is_end_stream() { - curr_state.set(ValidateGitHubWebhookFutureState::ValidateSignature); + StateProj::ExtractBody { signature } => { + if this.body.is_end_stream() { + // We're done updating the HMAC, so we can now move it out + let hmac = this + .hmac + .take() + .expect("HMAC is only moved out of the option once, here"); + if hmac.verify_slice(signature).is_ok() { + let parts = this.parts.take().unwrap(); + let body = Full::new(this.buffer.split().freeze()); + let req = Request::from_parts(parts, body); + let fut = this.inner.call(req); + curr_state.set(State::Inner { fut }); + rewake(cx) + } else { + bail("Invalid signature") + } } else { - let frame = match Pin::new(req.body_mut()).poll_frame(cx) { - Poll::Pending => { - *this.req = Some(req); - return Poll::Pending; - } - Poll::Ready(frame) => frame, + let Poll::Ready(maybe_frame) = this.body.poll_frame(cx) else { + return Poll::Pending; }; - - if let Some(Ok(frame)) = frame { + if let Some(Ok(frame)) = maybe_frame { if let Ok(data) = frame.into_data() { - let mut hmac = this.hmac.take().unwrap(); - hmac.update(data.chunk()); - *this.hmac = Some(hmac); + let bytes = data.chunk(); + this.buffer.extend(bytes); + let Some(h) = this.hmac.as_mut() else { + unreachable!() + }; + h.update(bytes); } } + rewake(cx) } - *this.req = Some(req); - cx.waker().wake_by_ref(); - Poll::Pending } - ValidateGitHubProj::ValidateSignature => { - tracing::trace!( - "[tower-github-webhook] ValidateGitHubWebhookFutureState::ValidateSignature" - ); - let signature = this.signature.take().unwrap(); - let hmac = this.hmac.take().unwrap(); - if hmac.verify_slice(&signature).is_ok() { - tracing::debug!("[tower-github-webhook] Valid signature"); - curr_state.set(ValidateGitHubWebhookFutureState::InnerBefore); - } else { - tracing::debug!("[tower-github-webhook] Invalid signature"); - curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); - } - cx.waker().wake_by_ref(); - Poll::Pending - } - ValidateGitHubProj::InnerBefore => { - tracing::trace!( - "[tower-github-webhook] ValidateGitHubWebhookFutureState::InnerBefore" - ); - let req = this.req.take().unwrap(); - let fut = this.inner.call(req); - curr_state.set(ValidateGitHubWebhookFutureState::Inner { fut }); - cx.waker().wake_by_ref(); - Poll::Pending - } - ValidateGitHubProj::Inner { fut } => { - tracing::trace!("[tower-github-webhook] ValidateGitHubWebhookFutureState::Inner"); - fut.poll(cx) - } - ValidateGitHubProj::Unauthorized => { - tracing::trace!( - "[tower-github-webhook] ValidateGitHubWebhookFutureState::Unauthorized" - ); - tracing::warn!("[tower-github-webhook] Request not authorized"); - let mut res = Response::new(ResBody::default()); - *res.status_mut() = StatusCode::UNAUTHORIZED; - Poll::Ready(Ok(res)) + StateProj::Inner { fut } => { + let Poll::Ready(response) = fut.poll(cx) else { + return Poll::Pending; + }; + let response = response?; + Poll::Ready(Ok(response.map(|b| Either::Left(b)))) } } } } -#[pin_project(project = ValidateGitHubProj)] -pub(crate) enum ValidateGitHubWebhookFutureState< - ReqBody, - ResBody, - S: Service, Response = http::Response>, -> { - ExtractSignature, - ExtractBody, - ValidateSignature, - InnerBefore, - Inner { - #[pin] - fut: S::Future, - }, - Unauthorized, +fn bail(debug_message: &str) -> Poll> { + tracing::debug!("[tower-github-webhook] {debug_message}"); + tracing::warn!("[tower-github-webhook] Request not authorized"); + let mut res = Response::new(Either::Right(Empty::new())); + *res.status_mut() = StatusCode::UNAUTHORIZED; + Poll::Ready(Ok(res)) +} + +fn rewake(cx: &mut Context<'_>) -> Poll { + cx.waker().wake_by_ref(); + Poll::Pending } diff --git a/src/layer.rs b/src/layer.rs index 77d8d72..f30bb62 100644 --- a/src/layer.rs +++ b/src/layer.rs @@ -1,5 +1,5 @@ use crate::ValidateGitHubWebhook; -use tower::Layer; +use tower_layer::Layer; /// Layer that applies the [ValidateGitHubWebhook] middleware which authorizes all requests using /// the `X-Hub-Signature-256` header. @@ -11,6 +11,10 @@ pub struct ValidateGitHubWebhookLayer { impl ValidateGitHubWebhookLayer { /// Authorize requests using the `X-Hub-Signature-256` header. If the signature specified in /// that header is not signed using the `webhook_secret` secret, the request will fail. + /// + /// The `webhook_secret` parameter can be any type that implements `AsRef<[u8]>` such as + /// `String`. However, using `secrecy::SecretString` is recommended to prevent the secret from + /// being printed in any logs. pub fn new(webhook_secret: Secret) -> Self { Self { webhook_secret } } diff --git a/src/lib.rs b/src/lib.rs index d1367fe..7458c34 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,14 @@ //! # Overview //! //! `tower-github-webhook` is a crate for verifying signed webhooks received from GitHub. +//! +//! The crate exports two structs: `ValidateGitHubWebhookLayer` and `ValidateGitHubWebhook`. These +//! structs implement `tower_layer::Layer` and `tower_service::Service`, respectively, and so can +//! be used as middleware for any servers that build on top of the Tower ecosystem. mod future; mod layer; mod service; #[cfg(test)] -mod test_helpers; -#[cfg(test)] mod tests; pub use layer::ValidateGitHubWebhookLayer; diff --git a/src/service.rs b/src/service.rs index b67e7bf..643df3e 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,10 +1,12 @@ -use crate::future::ValidateGitHubWebhookFuture; +use crate::future::Future; +use bytes::Bytes; use hmac::{Hmac, Mac}; use http::{Request, Response}; use http_body::Body; +use http_body_util::{Either, Empty, Full}; use sha2::Sha256; use std::task::{Context, Poll}; -use tower::Service; +use tower_service::Service; /// Middleware that authorizes all requests using the X-Hub-Signature-256 header. #[derive(Clone)] @@ -23,13 +25,12 @@ impl ValidateGitHubWebhook { impl Service> for ValidateGitHubWebhook where - S: Service, Response = Response> + Clone, - ReqBody: Body + Unpin, - ResBody: Body + Default, + S: Service>, Response = Response> + Clone, + ReqBody: Body, { - type Response = Response; + type Response = Response>>; type Error = S::Error; - type Future = ValidateGitHubWebhookFuture; + type Future = Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) @@ -38,6 +39,6 @@ where fn call(&mut self, req: Request) -> Self::Future { let inner = self.inner.clone(); let hmac = self.hmac.clone(); - ValidateGitHubWebhookFuture::new(req, hmac, inner) + Future::new(req, hmac, inner) } } diff --git a/src/test_helpers.rs b/src/test_helpers.rs deleted file mode 100644 index 9b989af..0000000 --- a/src/test_helpers.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use http_body::Frame; -use http_body_util::BodyExt; -use tower::BoxError; - -type BoxBody = http_body_util::combinators::UnsyncBoxBody; - -#[derive(Debug)] -pub(crate) struct Body(BoxBody); - -impl Body { - pub(crate) fn new(body: B) -> Self - where - B: http_body::Body + Send + 'static, - B::Error: Into, - { - Self(body.map_err(Into::into).boxed_unsync()) - } - - pub(crate) fn empty() -> Self { - Self::new(http_body_util::Empty::new()) - } -} - -impl Default for Body { - fn default() -> Self { - Self::empty() - } -} - -macro_rules! body_from_impl { - ($ty:ty) => { - impl From<$ty> for Body { - fn from(buf: $ty) -> Self { - Self::new(http_body_util::Full::from(buf)) - } - } - }; -} - -body_from_impl!(&'static [u8]); -body_from_impl!(std::borrow::Cow<'static, [u8]>); -body_from_impl!(Vec); - -body_from_impl!(&'static str); -body_from_impl!(std::borrow::Cow<'static, str>); -body_from_impl!(String); - -body_from_impl!(Bytes); - -impl http_body::Body for Body { - type Data = Bytes; - type Error = BoxError; - - fn poll_frame( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - Pin::new(&mut self.0).poll_frame(cx) - } - - fn size_hint(&self) -> http_body::SizeHint { - self.0.size_hint() - } - - fn is_end_stream(&self) -> bool { - self.0.is_end_stream() - } -} diff --git a/src/tests.rs b/src/tests.rs index b388786..56400f4 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,19 +1,22 @@ -use crate::test_helpers::Body; use crate::ValidateGitHubWebhookLayer; +use bytes::Bytes; use hmac::{Hmac, Mac}; use http::{Request, Response, StatusCode}; +use http_body_util::Full; use sha2::Sha256; use tower::{service_fn, util::ServiceExt, BoxError, Layer}; -async fn echo(req: Request) -> Result, BoxError> { +async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } +type EmptyBody = http_body_util::Empty; + #[tokio::test] async fn gives_unauthorized_error_when_no_header() { let svc_fun = service_fn(echo); let svc = ValidateGitHubWebhookLayer::new("123").layer(svc_fun); - let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + let res = svc.oneshot(Request::new(EmptyBody::new())).await.unwrap(); assert_eq!(res.status(), StatusCode::UNAUTHORIZED) } @@ -25,7 +28,7 @@ async fn gives_unauthorized_error_when_wrong_signature() { .oneshot( Request::builder() .header("x-hub-signature-256", "sha256=fake") - .body(Body::empty()) + .body(EmptyBody::new()) .unwrap(), ) .await @@ -35,20 +38,27 @@ async fn gives_unauthorized_error_when_wrong_signature() { #[tokio::test] async fn gives_ok_when_correct_signature() { + use http_body_util::BodyExt; + let svc_fun = service_fn(echo); let svc = ValidateGitHubWebhookLayer::new("123").layer(svc_fun); - let hmac = + let mut hmac = Hmac::::new_from_slice("123".as_bytes()).expect("Failed to parse webhook secret"); + hmac.update(b"hello world"); let signature = format!("sha256={}", hex::encode(hmac.finalize().into_bytes())); let res = svc .oneshot( Request::builder() .header("x-hub-signature-256", signature) - .body(Body::empty()) + .body(Full::new(Bytes::from_static(b"hello world"))) .unwrap(), ) .await .unwrap(); - assert_eq!(res.status(), StatusCode::OK) + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.into_body().collect().await.unwrap().to_bytes(), + Bytes::from("hello world") + ); }