Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

router: reduce unnecessary String allocations #1165

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions dropshot/src/api_description.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,15 +670,15 @@ impl<Context: ServerContext> ApiDescription<Context> {
_ => panic!("reference not expected"),
};

let method_ref = match &method[..] {
"GET" => &mut pathitem.get,
"PUT" => &mut pathitem.put,
"POST" => &mut pathitem.post,
"DELETE" => &mut pathitem.delete,
"OPTIONS" => &mut pathitem.options,
"HEAD" => &mut pathitem.head,
"PATCH" => &mut pathitem.patch,
"TRACE" => &mut pathitem.trace,
let method_ref = match method {
http::Method::GET => &mut pathitem.get,
http::Method::PUT => &mut pathitem.put,
http::Method::POST => &mut pathitem.post,
http::Method::DELETE => &mut pathitem.delete,
http::Method::OPTIONS => &mut pathitem.options,
http::Method::HEAD => &mut pathitem.head,
http::Method::PATCH => &mut pathitem.patch,
http::Method::TRACE => &mut pathitem.trace,
other => panic!("unexpected method `{}`", other),
};
let mut operation = openapiv3::Operation::default();
Expand Down
97 changes: 55 additions & 42 deletions dropshot/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ use crate::ApiEndpointBodyContentType;
use http::Method;
use http::StatusCode;
use percent_encoding::percent_decode_str;
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::sync::Arc;

/// `HttpRouter` is a simple data structure for routing incoming HTTP requests to
Expand Down Expand Up @@ -81,7 +83,7 @@ pub struct HttpRouter<Context: ServerContext> {
#[derive(Debug)]
struct HttpRouterNode<Context: ServerContext> {
/// Handlers, etc. for each of the HTTP methods defined for this node.
methods: BTreeMap<String, ApiEndpoint<Context>>,
methods: HashMap<Method, ApiEndpoint<Context>>,
/// Edges linking to child nodes.
edges: Option<HttpRouterEdges<Context>>,
}
Expand Down Expand Up @@ -217,7 +219,7 @@ pub struct RouterLookupResult<Context: ServerContext> {

impl<Context: ServerContext> HttpRouterNode<Context> {
pub fn new() -> Self {
HttpRouterNode { methods: BTreeMap::new(), edges: None }
HttpRouterNode { methods: HashMap::new(), edges: None }
}
}

Expand Down Expand Up @@ -385,16 +387,15 @@ impl<Context: ServerContext> HttpRouter<Context> {
};
}

let methodname = method.as_str().to_uppercase();
if node.methods.contains_key(&methodname) {
if node.methods.contains_key(&method) {
panic!(
"URI path \"{}\": attempted to create duplicate route for \
method \"{}\"",
path, method,
);
}

node.methods.insert(methodname, endpoint);
node.methods.insert(method, endpoint);
}

/// Look up the route handler for an HTTP request having method `method` and
Expand All @@ -408,36 +409,41 @@ impl<Context: ServerContext> HttpRouter<Context> {
method: &Method,
path: InputPath<'_>,
) -> Result<RouterLookupResult<Context>, HttpError> {
let all_segments = input_path_to_segments(&path).map_err(|_| {
HttpError::for_bad_request(
None,
String::from("invalid path encoding"),
)
})?;
let mut all_segments = all_segments.into_iter();
let mut all_segments = input_path_to_segments(&path);
let mut node = &self.root;
let mut variables = VariableSet::new();

while let Some(segment) = all_segments.next() {
let segment_string = segment.to_string();
while let Some(maybe_segment) = all_segments.next() {
let segment = maybe_segment.map_err(|e| {
HttpError::for_bad_request(
None,
format!("invalid path encoding: {e}"),
)
})?;

node = match &node.edges {
None => None,

Some(HttpRouterEdges::Literals(edges)) => {
edges.get(&segment_string)
edges.get(segment.as_ref())
}
Some(HttpRouterEdges::VariableSingle(varname, ref node)) => {
variables.insert(
varname.clone(),
VariableValue::String(segment_string),
VariableValue::String(segment.into_owned()),
);
Some(node)
}
Some(HttpRouterEdges::VariableRest(varname, node)) => {
let mut rest = vec![segment];
while let Some(segment) = all_segments.next() {
rest.push(segment);
let mut rest = vec![segment.into_owned()];
while let Some(maybe_segment) = all_segments.next() {
let segment = maybe_segment.map_err(|e| {
HttpError::for_bad_request(
None,
format!("invalid path encoding: {e}"),
)
})?;
rest.push(segment.into_owned());
}
variables.insert(
varname.clone(),
Expand Down Expand Up @@ -478,9 +484,8 @@ impl<Context: ServerContext> HttpRouter<Context> {
));
}

let methodname = method.as_str().to_uppercase();
node.methods
.get(&methodname)
.get(&method)
.map(|handler| RouterLookupResult {
handler: Arc::clone(&handler.handler),
operation_id: handler.operation_id.clone(),
Expand Down Expand Up @@ -512,7 +517,7 @@ fn insert_var(
}

impl<'a, Context: ServerContext> IntoIterator for &'a HttpRouter<Context> {
type Item = (String, String, &'a ApiEndpoint<Context>);
type Item = (String, Method, &'a ApiEndpoint<Context>);
type IntoIter = HttpRouterIter<'a, Context>;
fn into_iter(self) -> Self::IntoIter {
HttpRouterIter::new(self)
Expand All @@ -529,7 +534,7 @@ impl<'a, Context: ServerContext> IntoIterator for &'a HttpRouter<Context> {
/// blank string and an iterator over the root node's children.
pub struct HttpRouterIter<'a, Context: ServerContext> {
method:
Box<dyn Iterator<Item = (&'a String, &'a ApiEndpoint<Context>)> + 'a>,
Box<dyn Iterator<Item = (&'a Method, &'a ApiEndpoint<Context>)> + 'a>,
path: Vec<(PathSegment, Box<PathIter<'a, Context>>)>,
}
type PathIter<'a, Context> =
Expand Down Expand Up @@ -592,7 +597,7 @@ impl<'a, Context: ServerContext> HttpRouterIter<'a, Context> {
}

impl<'a, Context: ServerContext> Iterator for HttpRouterIter<'a, Context> {
type Item = (String, String, &'a ApiEndpoint<Context>);
type Item = (String, Method, &'a ApiEndpoint<Context>);

fn next(&mut self) -> Option<Self::Item> {
// If there are no path components left then we've reached the end of
Expand Down Expand Up @@ -630,6 +635,14 @@ impl<'a, Context: ServerContext> Iterator for HttpRouterIter<'a, Context> {
}
}

#[derive(Debug, thiserror::Error)]
enum InputPathError {
#[error(transparent)]
PercentDecode(#[from] std::str::Utf8Error),
#[error("dot-segments are not permitted")]
DotSegment,
}

/// Helper function for taking a Uri path and producing a `Vec<String>` of
/// URL-decoded strings, each representing one segment of the path. The input is
/// percent-encoded. Empty segments i.e. due to consecutive "/" characters or a
Expand All @@ -653,7 +666,9 @@ impl<'a, Context: ServerContext> Iterator for HttpRouterIter<'a, Context> {
/// that consumers may be susceptible to other information leaks, for example
/// if a client were able to follow a symlink to the root of the filesystem. As
/// always, it is incumbent on the consumer and *critical* to validate input.
fn input_path_to_segments(path: &InputPath) -> Result<Vec<String>, String> {
fn input_path_to_segments<'path>(
path: &'path InputPath,
) -> impl Iterator<Item = Result<Cow<'path, str>, InputPathError>> + 'path {
// We're given the "path" portion of a URI and we want to construct an
// array of the segments of the path. Relevant references:
//
Expand Down Expand Up @@ -682,17 +697,12 @@ fn input_path_to_segments(path: &InputPath) -> Result<Vec<String>, String> {
// should be ignored). The net result is that that crate doesn't buy us
// much here, but it does create more work, so we'll just split it
// ourselves.
path.0
.split('/')
.filter(|segment| !segment.is_empty())
.map(|segment| match segment {
"." | ".." => Err("dot-segments are not permitted".to_string()),
_ => Ok(percent_decode_str(segment)
.decode_utf8()
.map_err(|e| e.to_string())?
.to_string()),
})
.collect()
path.0.split('/').filter(|segment| !segment.is_empty()).map(|segment| {
match segment {
"." | ".." => Err(InputPathError::DotSegment),
_ => Ok(percent_decode_str(segment).decode_utf8()?),
}
})
}

/// Whereas in `input_path_to_segments()` we must accommodate any user input, when
Expand Down Expand Up @@ -729,6 +739,7 @@ mod test {
use super::super::handler::RouteHandler;
use super::input_path_to_segments;
use super::HttpRouter;
use super::InputPathError;
use super::PathSegment;
use crate::api_description::ApiEndpointBodyContentType;
use crate::from_map::from_map;
Expand Down Expand Up @@ -1309,10 +1320,10 @@ mod test {
assert_eq!(
ret,
vec![
("/".to_string(), "GET".to_string(),),
("/".to_string(), http::Method::GET,),
(
"/projects/{project_id}/instances".to_string(),
"GET".to_string(),
http::Method::GET,
),
]
);
Expand All @@ -1335,16 +1346,18 @@ mod test {
assert_eq!(
ret,
vec![
("/".to_string(), "GET".to_string(),),
("/".to_string(), "POST".to_string(),),
("/".to_string(), http::Method::GET,),
("/".to_string(), http::Method::POST),
]
);
}

#[test]
fn test_segments() {
let segs =
input_path_to_segments(&"//foo/bar/baz%2fbuzz".into()).unwrap();
let segs = input_path_to_segments(&"//foo/bar/baz%2fbuzz".into())
.map(|seg| Ok::<String, InputPathError>(seg?.into_owned()))
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(segs, vec!["foo", "bar", "baz/buzz"]);
}

Expand Down
2 changes: 1 addition & 1 deletion dropshot/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ impl<C: ServerContext> HttpServerStarter<C> {

for (path, method, _) in &app_state.router {
debug!(&log, "registered endpoint";
"method" => &method,
"method" => %method,
"path" => &path
);
}
Expand Down
Loading