-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes the crate by passing through the body
- Loading branch information
Showing
9 changed files
with
171 additions
and
245 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[package] | ||
name = "tower-github-webhook" | ||
version = "0.1.2" | ||
version = "0.2.0" | ||
edition = "2021" | ||
authors = ["Sebastian Rollén <[email protected]>"] | ||
license = "MIT" | ||
|
@@ -17,17 +17,18 @@ hex = "0.4.3" | |
hmac = "0.12.1" | ||
http = "1.0.0" | ||
http-body = "1.0.0" | ||
pin-project = "1.1.3" | ||
http-body-util = "0.1.0" | ||
pin-project-lite = "0.2.14" | ||
sha2 = "0.10.8" | ||
tower = { version = "0.4.13", features = ["util"] } | ||
tower-layer = "0.3.2" | ||
tower-service = "0.3.2" | ||
tracing = "0.1.40" | ||
|
||
[dev-dependencies] | ||
axum = { version = "0.7.4", features = ["macros"] } | ||
http-body-util = "0.1.0" | ||
hyper = "1.1.0" | ||
octocrab = "0.33.3" | ||
serde = { version = "1.0.196", features = ["derive"] } | ||
tokio = { version = "1.35.1", features = ["full"] } | ||
tracing = "0.1.40" | ||
tracing-subscriber = "0.3.18" | ||
tower = { version = "0.4.13", features = ["util"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,179 +1,155 @@ | ||
use bytes::Buf; | ||
use bytes::{Buf, Bytes, BytesMut}; | ||
use hmac::{Hmac, Mac}; | ||
use http::{Request, Response, StatusCode}; | ||
use http::{request::Parts, Request, Response, StatusCode}; | ||
use http_body::Body; | ||
use pin_project::pin_project; | ||
use http_body_util::{Either, Empty, Full}; | ||
use pin_project_lite::pin_project; | ||
use sha2::Sha256; | ||
use std::future::Future; | ||
use std::pin::Pin; | ||
use std::task::{Context, Poll}; | ||
use tower::Service; | ||
use tower_service::Service; | ||
|
||
#[pin_project] | ||
pub struct ValidateGitHubWebhookFuture< | ||
S: Service<Request<ReqBody>, Response = Response<ResBody>>, | ||
ReqBody, | ||
ResBody, | ||
> { | ||
req: Option<Request<ReqBody>>, | ||
signature: Option<Vec<u8>>, | ||
inner: S, | ||
hmac: Option<Hmac<Sha256>>, | ||
#[pin] | ||
state: ValidateGitHubWebhookFutureState<ReqBody, ResBody, S>, | ||
type FutureResponse<ResBody, Error> = Result<Response<Either<ResBody, Empty<Bytes>>>, Error>; | ||
|
||
pin_project! { | ||
pub struct Future<S: Service<Request<Full<Bytes>>, Response = Response<ResBody>>, ReqBody, ResBody> { | ||
// We use Option<X> here and for `hmac` to make it easy to move these fields out of the future | ||
// later. | ||
parts: Option<Parts>, | ||
buffer: BytesMut, | ||
inner: S, | ||
hmac: Option<Hmac<Sha256>>, | ||
#[pin] | ||
body: ReqBody, | ||
#[pin] | ||
state: State<S::Future>, | ||
} | ||
} | ||
|
||
impl<S, ReqBody, ResBody> ValidateGitHubWebhookFuture<S, ReqBody, ResBody> | ||
impl<S, ReqBody, ResBody> Future<S, ReqBody, ResBody> | ||
where | ||
S: Service<Request<ReqBody>, Response = Response<ResBody>>, | ||
S: Service<Request<Full<Bytes>>, Response = Response<ResBody>>, | ||
ReqBody: Body, | ||
{ | ||
pub fn new(req: Request<ReqBody>, hmac: Hmac<Sha256>, inner: S) -> Self { | ||
let (parts, body) = req.into_parts(); | ||
let body_size = body.size_hint().lower().try_into().unwrap_or(0); | ||
let buffer = BytesMut::with_capacity(body_size); | ||
Self { | ||
req: Some(req), | ||
signature: None, | ||
parts: Some(parts), | ||
body, | ||
buffer, | ||
inner, | ||
hmac: Some(hmac), | ||
state: ValidateGitHubWebhookFutureState::ExtractSignature, | ||
state: State::new(), | ||
} | ||
} | ||
} | ||
|
||
impl<S, F, ReqBody, ResBody> Future for ValidateGitHubWebhookFuture<S, ReqBody, ResBody> | ||
pin_project! { | ||
#[project = StateProj] | ||
enum State<F> { | ||
ExtractSignature, | ||
ExtractBody { | ||
signature: Vec<u8>, | ||
}, | ||
Inner { | ||
#[pin] | ||
fut: F, | ||
}, | ||
} | ||
} | ||
|
||
impl<F> State<F> { | ||
pub fn new() -> Self { | ||
Self::ExtractSignature | ||
} | ||
} | ||
|
||
impl<S, F, ReqBody, ResBody> std::future::Future for Future<S, ReqBody, ResBody> | ||
where | ||
S: Service<Request<ReqBody>, Response = Response<ResBody>, Future = F>, | ||
F: Future<Output = Result<Response<ResBody>, S::Error>>, | ||
ReqBody: Body + Unpin, | ||
ResBody: Body + Default, | ||
S: Service<Request<Full<Bytes>>, Response = Response<ResBody>, Future = F>, | ||
F: std::future::Future<Output = Result<Response<ResBody>, S::Error>>, | ||
ReqBody: Body, | ||
{ | ||
type Output = Result<Response<ResBody>, S::Error>; | ||
type Output = FutureResponse<ResBody, S::Error>; | ||
|
||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||
let this = self.as_mut().project(); | ||
let mut curr_state = this.state; | ||
match curr_state.as_mut().project() { | ||
ValidateGitHubProj::ExtractSignature => { | ||
tracing::trace!( | ||
"[tower-github-webhook] ValidateGitHubWebhookFutureState::ExtractSignature" | ||
); | ||
let req = this.req.take().unwrap(); | ||
let signature = match req.headers().get("x-hub-signature-256") { | ||
Some(sig) => { | ||
let Some(sig) = sig.as_bytes().splitn(2, |x| x == &b'=').nth(1) else { | ||
tracing::debug!("[tower-github-webhook] Invalid header format"); | ||
curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); | ||
cx.waker().wake_by_ref(); | ||
return Poll::Pending; | ||
}; | ||
match hex::decode(sig) { | ||
Ok(sig) => sig, | ||
Err(_) => { | ||
tracing::debug!("[tower-github-webhook] Invalid header format"); | ||
curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); | ||
cx.waker().wake_by_ref(); | ||
return Poll::Pending; | ||
} | ||
} | ||
} | ||
None => { | ||
tracing::debug!( | ||
"[tower-github-webhook] Missing X-HUB-SIGNATURE-256 header" | ||
); | ||
curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); | ||
cx.waker().wake_by_ref(); | ||
return Poll::Pending; | ||
} | ||
StateProj::ExtractSignature => { | ||
let parts = this | ||
.parts | ||
.take() | ||
.expect("Parts is either reset at the end of this method, or we bail"); | ||
let Some(signature) = parts.headers.get("x-hub-signature-256") else { | ||
return bail("Missing X-HUB-SIGNATURE-256 header"); | ||
}; | ||
let Some(signature) = signature.as_bytes().splitn(2, |x| x == &b'=').nth(1) else { | ||
return bail("Invalid header format"); | ||
}; | ||
curr_state.set(ValidateGitHubWebhookFutureState::ExtractBody); | ||
*this.signature = Some(signature); | ||
*this.req = Some(req); | ||
cx.waker().wake_by_ref(); | ||
Poll::Pending | ||
let Ok(signature) = hex::decode(signature) else { | ||
return bail("Invalid header format"); | ||
}; | ||
*this.parts = Some(parts); | ||
curr_state.set(State::ExtractBody { signature }); | ||
rewake(cx) | ||
} | ||
ValidateGitHubProj::ExtractBody => { | ||
tracing::trace!( | ||
"[tower-github-webhook] ValidateGitHubWebhookFutureState::ExtractBody" | ||
); | ||
let mut req = this.req.take().unwrap(); | ||
let body = Pin::new(req.body_mut()); | ||
if body.is_end_stream() { | ||
curr_state.set(ValidateGitHubWebhookFutureState::ValidateSignature); | ||
StateProj::ExtractBody { signature } => { | ||
if this.body.is_end_stream() { | ||
// We're done updating the HMAC, so we can now move it out | ||
let hmac = this | ||
.hmac | ||
.take() | ||
.expect("HMAC is only moved out of the option once, here"); | ||
if hmac.verify_slice(signature).is_ok() { | ||
let parts = this.parts.take().unwrap(); | ||
let body = Full::new(this.buffer.split().freeze()); | ||
let req = Request::from_parts(parts, body); | ||
let fut = this.inner.call(req); | ||
curr_state.set(State::Inner { fut }); | ||
rewake(cx) | ||
} else { | ||
bail("Invalid signature") | ||
} | ||
} else { | ||
let frame = match Pin::new(req.body_mut()).poll_frame(cx) { | ||
Poll::Pending => { | ||
*this.req = Some(req); | ||
return Poll::Pending; | ||
} | ||
Poll::Ready(frame) => frame, | ||
let Poll::Ready(maybe_frame) = this.body.poll_frame(cx) else { | ||
return Poll::Pending; | ||
}; | ||
|
||
if let Some(Ok(frame)) = frame { | ||
if let Some(Ok(frame)) = maybe_frame { | ||
if let Ok(data) = frame.into_data() { | ||
let mut hmac = this.hmac.take().unwrap(); | ||
hmac.update(data.chunk()); | ||
*this.hmac = Some(hmac); | ||
let bytes = data.chunk(); | ||
this.buffer.extend(bytes); | ||
let Some(h) = this.hmac.as_mut() else { | ||
unreachable!() | ||
}; | ||
h.update(bytes); | ||
} | ||
} | ||
rewake(cx) | ||
} | ||
*this.req = Some(req); | ||
cx.waker().wake_by_ref(); | ||
Poll::Pending | ||
} | ||
ValidateGitHubProj::ValidateSignature => { | ||
tracing::trace!( | ||
"[tower-github-webhook] ValidateGitHubWebhookFutureState::ValidateSignature" | ||
); | ||
let signature = this.signature.take().unwrap(); | ||
let hmac = this.hmac.take().unwrap(); | ||
if hmac.verify_slice(&signature).is_ok() { | ||
tracing::debug!("[tower-github-webhook] Valid signature"); | ||
curr_state.set(ValidateGitHubWebhookFutureState::InnerBefore); | ||
} else { | ||
tracing::debug!("[tower-github-webhook] Invalid signature"); | ||
curr_state.set(ValidateGitHubWebhookFutureState::Unauthorized); | ||
} | ||
cx.waker().wake_by_ref(); | ||
Poll::Pending | ||
} | ||
ValidateGitHubProj::InnerBefore => { | ||
tracing::trace!( | ||
"[tower-github-webhook] ValidateGitHubWebhookFutureState::InnerBefore" | ||
); | ||
let req = this.req.take().unwrap(); | ||
let fut = this.inner.call(req); | ||
curr_state.set(ValidateGitHubWebhookFutureState::Inner { fut }); | ||
cx.waker().wake_by_ref(); | ||
Poll::Pending | ||
} | ||
ValidateGitHubProj::Inner { fut } => { | ||
tracing::trace!("[tower-github-webhook] ValidateGitHubWebhookFutureState::Inner"); | ||
fut.poll(cx) | ||
} | ||
ValidateGitHubProj::Unauthorized => { | ||
tracing::trace!( | ||
"[tower-github-webhook] ValidateGitHubWebhookFutureState::Unauthorized" | ||
); | ||
tracing::warn!("[tower-github-webhook] Request not authorized"); | ||
let mut res = Response::new(ResBody::default()); | ||
*res.status_mut() = StatusCode::UNAUTHORIZED; | ||
Poll::Ready(Ok(res)) | ||
StateProj::Inner { fut } => { | ||
let Poll::Ready(response) = fut.poll(cx) else { | ||
return Poll::Pending; | ||
}; | ||
let response = response?; | ||
Poll::Ready(Ok(response.map(|b| Either::Left(b)))) | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[pin_project(project = ValidateGitHubProj)] | ||
pub(crate) enum ValidateGitHubWebhookFutureState< | ||
ReqBody, | ||
ResBody, | ||
S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>, | ||
> { | ||
ExtractSignature, | ||
ExtractBody, | ||
ValidateSignature, | ||
InnerBefore, | ||
Inner { | ||
#[pin] | ||
fut: S::Future, | ||
}, | ||
Unauthorized, | ||
fn bail<ResBody, Error>(debug_message: &str) -> Poll<FutureResponse<ResBody, Error>> { | ||
tracing::debug!("[tower-github-webhook] {debug_message}"); | ||
tracing::warn!("[tower-github-webhook] Request not authorized"); | ||
let mut res = Response::new(Either::Right(Empty::new())); | ||
*res.status_mut() = StatusCode::UNAUTHORIZED; | ||
Poll::Ready(Ok(res)) | ||
} | ||
|
||
fn rewake<T>(cx: &mut Context<'_>) -> Poll<T> { | ||
cx.waker().wake_by_ref(); | ||
Poll::Pending | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.