diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5eed895 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +<<<<<<< HEAD +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..600bd70 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "tower-github-webhook" +version = "0.1.0" +edition = "2021" +authors = ["Sebastian Rollén "] +license = "MIT" +repository = "https://github.com/SebRollen/tower-github-webhook" +description = "tower-github-webhook is a crate that simplifies validating webhooks received from GitHub " +keywords = ["tower", "layer", "service", "github", "webhook"] +categories = ["authentication", "web-programming"] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytes = "1.5.0" +hex = "0.4.3" +hmac = "0.12.1" +http = "1.0.0" +http-body = "1.0.0" +pin-project = "1.1.3" +sha2 = "0.10.8" +tower = { version = "0.4.13", features = ["util"] } +tracing = "0.1.40" + +[dev-dependencies] +axum = { version = "0.7.4", features = ["macros"] } +http-body-util = "0.1.0" +hyper = "1.1.0" +octocrab = "0.33.3" +serde = { version = "1.0.196", features = ["derive"] } +tokio = { version = "1.35.1", features = ["full"] } +tracing = "0.1.40" +tracing-subscriber = "0.3.18" diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..026d676 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Sebastian Rollén + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..cfa479b --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# tower-github-webhook + +`tower-github-webhook` is a crate that simplifies validating webhooks received from GitHub. diff --git a/examples/simple.rs b/examples/simple.rs new file mode 100644 index 0000000..7d0f511 --- /dev/null +++ b/examples/simple.rs @@ -0,0 +1,84 @@ +use axum::async_trait; +use axum::body::Bytes; +use axum::debug_handler; +use axum::extract::{FromRequest, Request}; +use axum::response::{IntoResponse, Response}; +use axum::{extract::Json, routing::post, Router}; +use octocrab::models::{ + webhook_events::{WebhookEvent, WebhookEventPayload, WebhookEventType}, + Author, Repository, +}; +use serde::{Deserialize, Serialize}; +use tower_github_webhook::ValidateGitHubWebhookLayer; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct Event { + pub kind: WebhookEventType, + pub sender: Option, + pub repository: Option, + pub payload: WebhookEventPayload, +} + +impl From for Event { + fn from(e: WebhookEvent) -> Self { + Self { + kind: e.kind, + sender: e.sender, + repository: e.repository, + payload: e.specific, + } + } +} + +#[async_trait] +impl FromRequest for Event +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request(req: Request, state: &S) -> Result { + let headers = req.headers().clone(); + let header = headers + .get("x-github-event") + .map(|x| x.to_str()) + .unwrap() + .map_err(|_| { + "Failed to convert header to string" + .to_string() + .into_response() + })?; + let bytes = Bytes::from_request(req, state) + .await + .map_err(IntoResponse::into_response)?; + let webhook_event = WebhookEvent::try_from_header_and_body(header, &bytes).unwrap(); + Ok(Self::from(webhook_event)) + } +} + +#[tokio::main] +async fn main() { + // Setup tracing + tracing_subscriber::fmt::init(); + + // Run our service + let addr = "0.0.0.0:3000"; + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + tracing::info!("Listening on {}", addr); + axum::serve(listener, app().into_make_service()) + .await + .unwrap(); +} + +fn app() -> Router { + // Build route service + Router::new().route( + "/github/events", + post(print_body).layer(ValidateGitHubWebhookLayer::new("123")), + ) +} + +#[debug_handler] +async fn print_body(Json(event): Json) { + println!("{:#?}", event); +} diff --git a/src/future.rs b/src/future.rs new file mode 100644 index 0000000..6d63638 --- /dev/null +++ b/src/future.rs @@ -0,0 +1,172 @@ +use bytes::Buf; +use hmac::{Hmac, Mac}; +use http::{Request, Response, StatusCode}; +use http_body::Body; +use pin_project::pin_project; +use sha2::Sha256; +use std::future::Future; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; +use tower::Service; + +#[pin_project] +pub struct ValidateGitHubWebhookFuture< + S: Service, Response = Response>, + ReqBody, + ResBody, +> { + req: Option>, + signature: Option>, + inner: S, + hmac: Option>, + #[pin] + state: ValidateGitHubWebhookFutureState, +} + +impl ValidateGitHubWebhookFuture +where + S: Service, Response = Response>, +{ + pub fn new(req: Request, hmac: Hmac, inner: S) -> Self { + Self { + req: Some(req), + signature: None, + inner, + hmac: Some(hmac), + state: ValidateGitHubWebhookFutureState::ExtractSignature, + } + } +} + +impl Future for ValidateGitHubWebhookFuture +where + S: Service, Response = Response, Future = F>, + F: Future, S::Error>>, + ReqBody: Body + Unpin, + ResBody: Body + Default, +{ + type Output = Result, S::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut().project(); + let mut curr_state = this.state; + match curr_state.as_mut().project() { + ValidateGitHubProj::ExtractSignature => { + tracing::trace!( + "[tower-github-webhook] ValidateGitHubWebhookFutureState::ExtractSignature" + ); + let req = this.req.take().unwrap(); + let signature = match req.headers().get("x-hub-signature-256") { + Some(sig) => { + let Some(sig) = sig.as_bytes().splitn(2, |x| x == &b'=').nth(1) else { + tracing::debug!("[tower-github-webhook] Invalid header format"); + curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); + cx.waker().wake_by_ref(); + return Poll::Pending; + }; + match hex::decode(sig) { + Ok(sig) => sig, + Err(_) => { + tracing::debug!("[tower-github-webhook] Invalid header format"); + curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + } + None => { + tracing::debug!( + "[tower-github-webhook] Missing X-HUB-SIGNATURE-256 header" + ); + curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + curr_state.set(ValidateGitHubWebhookFutureState::ExtractBody); + *this.signature = Some(signature); + *this.req = Some(req); + cx.waker().wake_by_ref(); + Poll::Pending + } + ValidateGitHubProj::ExtractBody => { + tracing::trace!( + "[tower-github-webhook] ValidateGitHubWebhookFutureState::ExtractBody" + ); + let mut req = this.req.take().unwrap(); + let body = Pin::new(req.body_mut()); + if body.is_end_stream() { + curr_state.set(ValidateGitHubWebhookFutureState::ValidateSignature); + } else { + let frame = ready!(Pin::new(req.body_mut()).poll_frame(cx)); + if let Some(Ok(frame)) = frame { + if let Ok(data) = frame.into_data() { + let mut hmac = this.hmac.take().unwrap(); + hmac.update(data.chunk()); + *this.hmac = Some(hmac); + } + } + } + *this.req = Some(req); + cx.waker().wake_by_ref(); + Poll::Pending + } + ValidateGitHubProj::ValidateSignature => { + tracing::trace!( + "[tower-github-webhook] ValidateGitHubWebhookFutureState::ValidateSignature" + ); + let signature = this.signature.take().unwrap(); + let hmac = this.hmac.take().unwrap(); + if hmac.verify_slice(&signature).is_ok() { + tracing::debug!("[tower-github-webhook] Valid signature"); + curr_state.set(ValidateGitHubWebhookFutureState::InnerBefore); + } else { + tracing::debug!("[tower-github-webhook] Invalid signature"); + curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); + } + cx.waker().wake_by_ref(); + Poll::Pending + } + ValidateGitHubProj::InnerBefore => { + tracing::trace!( + "[tower-github-webhook] ValidateGitHubWebhookFutureState::InnerBefore" + ); + let req = this.req.take().unwrap(); + let fut = this.inner.call(req); + curr_state.set(ValidateGitHubWebhookFutureState::Inner { fut }); + cx.waker().wake_by_ref(); + Poll::Pending + } + ValidateGitHubProj::Inner { fut } => { + tracing::trace!("[tower-github-webhook] ValidateGitHubWebhookFutureState::Inner"); + fut.poll(cx) + } + ValidateGitHubProj::Unauthorized => { + tracing::trace!( + "[tower-github-webhook] ValidateGitHubWebhookFutureState::Unauthorized" + ); + tracing::warn!("[tower-github-webhook] Request not authorized"); + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::UNAUTHORIZED; + Poll::Ready(Ok(res)) + } + } + } +} + +#[pin_project(project = ValidateGitHubProj)] +pub(crate) enum ValidateGitHubWebhookFutureState< + ReqBody, + ResBody, + S: Service, Response = http::Response>, +> { + ExtractSignature, + ExtractBody, + ValidateSignature, + InnerBefore, + Inner { + #[pin] + fut: S::Future, + }, + Unauthorized, +} diff --git a/src/layer.rs b/src/layer.rs new file mode 100644 index 0000000..4fe0bbb --- /dev/null +++ b/src/layer.rs @@ -0,0 +1,24 @@ +use crate::ValidateGitHubWebhook; +use tower::Layer; + +#[derive(Clone)] +pub struct ValidateGitHubWebhookLayer { + webhook_secret: Secret, +} + +impl ValidateGitHubWebhookLayer { + pub fn new(webhook_secret: Secret) -> Self { + Self { webhook_secret } + } +} + +impl Layer for ValidateGitHubWebhookLayer +where + Secret: AsRef<[u8]> + Clone, +{ + type Service = ValidateGitHubWebhook; + + fn layer(&self, inner: S) -> Self::Service { + ValidateGitHubWebhook::new(self.webhook_secret.clone(), inner) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..d1367fe --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,13 @@ +//! # Overview +//! +//! `tower-github-webhook` is a crate for verifying signed webhooks received from GitHub. +mod future; +mod layer; +mod service; +#[cfg(test)] +mod test_helpers; +#[cfg(test)] +mod tests; + +pub use layer::ValidateGitHubWebhookLayer; +pub use service::ValidateGitHubWebhook; diff --git a/src/service.rs b/src/service.rs new file mode 100644 index 0000000..02445c6 --- /dev/null +++ b/src/service.rs @@ -0,0 +1,42 @@ +use crate::future::ValidateGitHubWebhookFuture; +use hmac::{Hmac, Mac}; +use http::{Request, Response}; +use http_body::Body; +use sha2::Sha256; +use std::task::{Context, Poll}; +use tower::Service; + +#[derive(Clone)] +pub struct ValidateGitHubWebhook { + inner: S, + hmac: Hmac, +} + +impl ValidateGitHubWebhook { + pub fn new(webhook_secret: impl AsRef<[u8]>, inner: S) -> Self { + let hmac = Hmac::::new_from_slice(webhook_secret.as_ref()) + .expect("Failed to parse webhook_secret"); + Self { inner, hmac } + } +} + +impl Service> for ValidateGitHubWebhook +where + S: Service, Response = Response> + Clone, + ReqBody: Body + Unpin, + ResBody: Body + Default, +{ + type Response = Response; + type Error = S::Error; + type Future = ValidateGitHubWebhookFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let inner = self.inner.clone(); + let hmac = self.hmac.clone(); + ValidateGitHubWebhookFuture::new(req, hmac, inner) + } +} diff --git a/src/test_helpers.rs b/src/test_helpers.rs new file mode 100644 index 0000000..9b989af --- /dev/null +++ b/src/test_helpers.rs @@ -0,0 +1,74 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use http_body::Frame; +use http_body_util::BodyExt; +use tower::BoxError; + +type BoxBody = http_body_util::combinators::UnsyncBoxBody; + +#[derive(Debug)] +pub(crate) struct Body(BoxBody); + +impl Body { + pub(crate) fn new(body: B) -> Self + where + B: http_body::Body + Send + 'static, + B::Error: Into, + { + Self(body.map_err(Into::into).boxed_unsync()) + } + + pub(crate) fn empty() -> Self { + Self::new(http_body_util::Empty::new()) + } +} + +impl Default for Body { + fn default() -> Self { + Self::empty() + } +} + +macro_rules! body_from_impl { + ($ty:ty) => { + impl From<$ty> for Body { + fn from(buf: $ty) -> Self { + Self::new(http_body_util::Full::from(buf)) + } + } + }; +} + +body_from_impl!(&'static [u8]); +body_from_impl!(std::borrow::Cow<'static, [u8]>); +body_from_impl!(Vec); + +body_from_impl!(&'static str); +body_from_impl!(std::borrow::Cow<'static, str>); +body_from_impl!(String); + +body_from_impl!(Bytes); + +impl http_body::Body for Body { + type Data = Bytes; + type Error = BoxError; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Pin::new(&mut self.0).poll_frame(cx) + } + + fn size_hint(&self) -> http_body::SizeHint { + self.0.size_hint() + } + + fn is_end_stream(&self) -> bool { + self.0.is_end_stream() + } +} diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..b388786 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,54 @@ +use crate::test_helpers::Body; +use crate::ValidateGitHubWebhookLayer; +use hmac::{Hmac, Mac}; +use http::{Request, Response, StatusCode}; +use sha2::Sha256; +use tower::{service_fn, util::ServiceExt, BoxError, Layer}; + +async fn echo(req: Request) -> Result, BoxError> { + Ok(Response::new(req.into_body())) +} + +#[tokio::test] +async fn gives_unauthorized_error_when_no_header() { + let svc_fun = service_fn(echo); + let svc = ValidateGitHubWebhookLayer::new("123").layer(svc_fun); + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED) +} + +#[tokio::test] +async fn gives_unauthorized_error_when_wrong_signature() { + let svc_fun = service_fn(echo); + let svc = ValidateGitHubWebhookLayer::new("123").layer(svc_fun); + let res = svc + .oneshot( + Request::builder() + .header("x-hub-signature-256", "sha256=fake") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED) +} + +#[tokio::test] +async fn gives_ok_when_correct_signature() { + let svc_fun = service_fn(echo); + let svc = ValidateGitHubWebhookLayer::new("123").layer(svc_fun); + let hmac = + Hmac::::new_from_slice("123".as_bytes()).expect("Failed to parse webhook secret"); + let signature = format!("sha256={}", hex::encode(hmac.finalize().into_bytes())); + + let res = svc + .oneshot( + Request::builder() + .header("x-hub-signature-256", signature) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK) +}