From 4b2e95332ffd4dd01a6e3c37e99cbb47a59d5a8a Mon Sep 17 00:00:00 2001 From: Ty Overby Date: Fri, 19 Jun 2015 18:26:18 -0700 Subject: [PATCH] Cap calls to with_capacity with sane default For instances where the length of a collection must be obeyed, but is untrusted, calls to with_capacity() could result in OOM errors. This change makes it so that collections don't pre-allocate more than 1MB of memory. --- src/collection_impls.rs | 6 +-- src/lib.rs | 14 +++++ src/serialize.rs | 115 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 130 insertions(+), 5 deletions(-) diff --git a/src/collection_impls.rs b/src/collection_impls.rs index b68e068..f8074c9 100644 --- a/src/collection_impls.rs +++ b/src/collection_impls.rs @@ -12,7 +12,7 @@ use std::hash::Hash; -use {Decodable, Encodable, Decoder, Encoder}; +use {Decodable, Encodable, Decoder, Encoder, cap_capacity}; use std::collections::{LinkedList, VecDeque, BTreeMap, BTreeSet, HashMap, HashSet}; impl< @@ -149,7 +149,7 @@ impl Decodable for HashMap { fn decode(d: &mut D) -> Result, D::Error> { d.read_map(|d, len| { - let mut map = HashMap::with_capacity(len); + let mut map = HashMap::with_capacity(cap_capacity::<(K, V)>(len)); for i in 0..len { let key = try!(d.read_map_elt_key(i, |d| Decodable::decode(d))); let val = try!(d.read_map_elt_val(i, |d| Decodable::decode(d))); @@ -176,7 +176,7 @@ impl Encodable for HashSet where T: Encodable + Hash + Eq { impl Decodable for HashSet where T: Decodable + Hash + Eq, { fn decode(d: &mut D) -> Result, D::Error> { d.read_seq(|d, len| { - let mut set = HashSet::with_capacity(len); + let mut set = HashSet::with_capacity(cap_capacity::(len)); for i in 0..len { set.insert(try!(d.read_seq_elt(i, |d| Decodable::decode(d)))); } diff --git a/src/lib.rs b/src/lib.rs index 4e90529..8a4acf2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,20 @@ pub use self::serialize::{Decoder, Encoder, Decodable, Encodable, DecoderHelpers, EncoderHelpers}; + +// Limit collections from allocating more than +// 1 MB for calls to `with_capacity`. +fn cap_capacity(given_len: usize) -> usize { + use std::cmp::min; + use std::mem::size_of; + const PRE_ALLOCATE_CAP: usize = 0x100000; + + match size_of::() { + 0 => min(given_len, PRE_ALLOCATE_CAP), + n => min(given_len, PRE_ALLOCATE_CAP / n) + } +} + mod serialize; mod collection_impls; diff --git a/src/serialize.rs b/src/serialize.rs index fd664fa..02052f6 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -22,6 +22,8 @@ use std::sync::Arc; use std::marker::PhantomData; use std::borrow::Cow; +use cap_capacity; + pub trait Encoder { type Error; @@ -479,7 +481,7 @@ impl Encodable for Vec { impl Decodable for Vec { fn decode(d: &mut D) -> Result, D::Error> { d.read_seq(|d, len| { - let mut v = Vec::with_capacity(len); + let mut v = Vec::with_capacity(cap_capacity::(len)); for i in 0..len { v.push(try!(d.read_seq_elt(i, |d| Decodable::decode(d)))); } @@ -722,7 +724,7 @@ impl DecoderHelpers for D { FnMut(&mut D) -> Result, { self.read_seq(|this, len| { - let mut v = Vec::with_capacity(len); + let mut v = Vec::with_capacity(cap_capacity::(len)); for i in 0..len { v.push(try!(this.read_seq_elt(i, |this| f(this)))); } @@ -730,3 +732,112 @@ impl DecoderHelpers for D { }) } } + +#[test] +#[allow(unused_variables)] +fn capacity_rules() { + use std::usize::MAX; + use std::collections::{HashMap, HashSet}; + + struct MyDecoder; + impl Decoder for MyDecoder { + type Error = (); + + // Primitive types: + fn read_nil(&mut self) -> Result<(), Self::Error> { Err(()) } + fn read_usize(&mut self) -> Result { Err(()) } + fn read_u64(&mut self) -> Result { Err(()) } + fn read_u32(&mut self) -> Result { Err(()) } + fn read_u16(&mut self) -> Result { Err(()) } + fn read_u8(&mut self) -> Result { Err(()) } + fn read_isize(&mut self) -> Result { Err(()) } + fn read_i64(&mut self) -> Result { Err(()) } + fn read_i32(&mut self) -> Result { Err(()) } + fn read_i16(&mut self) -> Result { Err(()) } + fn read_i8(&mut self) -> Result { Err(()) } + fn read_bool(&mut self) -> Result { Err(()) } + fn read_f64(&mut self) -> Result { Err(()) } + fn read_f32(&mut self) -> Result { Err(()) } + fn read_char(&mut self) -> Result { Err(()) } + fn read_str(&mut self) -> Result { Err(()) } + + // Compound types: + fn read_enum(&mut self, name: &str, f: F) -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + + fn read_enum_variant(&mut self, names: &[&str], f: F) + -> Result + where F: FnMut(&mut Self, usize) -> Result { Err(()) } + fn read_enum_variant_arg(&mut self, a_idx: usize, f: F) + -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + + fn read_enum_struct_variant(&mut self, names: &[&str], f: F) + -> Result + where F: FnMut(&mut Self, usize) -> Result { Err(()) } + fn read_enum_struct_variant_field(&mut self, + f_name: &str, + f_idx: usize, + f: F) + -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + + fn read_struct(&mut self, s_name: &str, len: usize, f: F) + -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + fn read_struct_field(&mut self, + f_name: &str, + f_idx: usize, + f: F) + -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + + fn read_tuple(&mut self, len: usize, f: F) -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + fn read_tuple_arg(&mut self, a_idx: usize, f: F) + -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + + fn read_tuple_struct(&mut self, s_name: &str, len: usize, f: F) + -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + fn read_tuple_struct_arg(&mut self, a_idx: usize, f: F) + -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + + // Specialized types: + fn read_option(&mut self, f: F) -> Result + where F: FnMut(&mut Self, bool) -> Result { Err(()) } + + fn read_seq(&mut self, f: F) -> Result + where F: FnOnce(&mut Self, usize) -> Result { + f(self, MAX) + } + fn read_seq_elt(&mut self, idx: usize, f: F) -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + + fn read_map(&mut self, f: F) -> Result + where F: FnOnce(&mut Self, usize) -> Result { + f(self, MAX) + } + fn read_map_elt_key(&mut self, idx: usize, f: F) + -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + fn read_map_elt_val(&mut self, idx: usize, f: F) + -> Result + where F: FnOnce(&mut Self) -> Result { Err(()) } + + // Failure + fn error(&mut self, err: &str) -> Self::Error { () } + } + + let mut dummy = MyDecoder; + let vec_result: Result, ()> = Decodable::decode(&mut dummy); + assert!(vec_result.is_err()); + + let map_result: Result, ()> = Decodable::decode(&mut dummy); + assert!(map_result.is_err()); + + let set_result: Result, ()> = Decodable::decode(&mut dummy); + assert!(set_result.is_err()); +}