Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hot Reloading TLS Certificates #2683

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 183 additions & 11 deletions core/http/src/tls/listener.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use std::io;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use std::future::Future;
use std::net::SocketAddr;

use rustls::sign::CertifiedKey;
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};

use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream};
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};


use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
use crate::listener::{Connection, Listener, Certificates, CertificateDer};
use crate::listener::{Connection, Listener, Certificates};

/// A TLS listener over TCP.
pub struct TlsListener {
Expand Down Expand Up @@ -62,19 +66,189 @@ pub enum TlsState {
Streaming(BareTlsStream<TcpStream>),
}

#[derive(Clone)]
pub enum FileOrBytes {
File(PathBuf),
Bytes(Vec<u8>),
}

/// TLS as ~configured by `TlsConfig` in `rocket` core.
pub struct Config<R> {
pub cert_chain: R,
pub private_key: R,
//pub cert_chain: R,
//pub private_key: R,
pub cert_chain: FileOrBytes,
pub private_key: FileOrBytes,
pub ciphersuites: Vec<rustls::SupportedCipherSuite>,
pub prefer_server_order: bool,
pub ca_certs: Option<R>,
pub mandatory_mtls: bool,
pub tls_updater: Option<std::sync::Arc<std::sync::RwLock<DynamicConfig>>>,
}

#[derive(Clone, Debug, Default, PartialEq)]
pub struct DynamicConfig {
pub certs: Vec<u8>,
pub key: Vec<u8>,
}

type Reader = Box<dyn std::io::BufRead + Sync + Send>;

fn to_reader(value: &FileOrBytes) -> io::Result<Reader> {
match value {
FileOrBytes::File(path) => {
let file = std::fs::File::open(&path).map_err(move |e| {
let msg = format!("error reading TLS file `{}`", e);
std::io::Error::new(e.kind(), msg)
})?;

Ok(Box::new(io::BufReader::new(file)))
}
FileOrBytes::Bytes(vec) => Ok(Box::new(io::Cursor::new(vec.clone()))),
}
}

#[derive(Debug)]
pub struct CertResolver {
pub certified_key: Arc<std::sync::RwLock<Option<Arc<CertifiedKey>>>>,
_handle: tokio::task::JoinHandle<()>,
}

impl CertResolver {
pub async fn new<R>(config: &Config<R>) -> crate::tls::Result<Arc<Self>>
where R: io::BufRead
{

let certified_key = Arc::new(std::sync::RwLock::new(None));

let private_key = config.private_key.to_owned();
let cert_chain = config.cert_chain.to_owned();

let loop_certified_key = certified_key.clone();
let loop_updater = config.tls_updater.as_ref().map(|i| i.clone());

let handle = tokio::spawn(async move {

let mut dynamic_certs = None;
let mut first_load = true;
let mut do_load = false;

let mut last_loaded = std::time::SystemTime::now();

loop {

let mut reload_pair = None;

// Have to be careful for file system errors here, if the user swapping file as part of the hot-swap
// process we could find files missing, or incorrect file pairs as files are copied in
match (&cert_chain, &private_key) {
(FileOrBytes::File(cert_chain_path), FileOrBytes::File(private_key_path)) => {

let last_modified_certs_chain =
if let Ok(metadata) = std::fs::metadata(cert_chain_path) {
metadata.modified().ok()
} else {
None
};

let last_modified_private_key =
if let Ok(metadata) = std::fs::metadata(private_key_path) {
metadata.modified().ok()
} else {
None
};

if first_load || (last_modified_certs_chain.is_some() && last_modified_private_key.is_some()) {
// Duration since will return Err(...) if the time given is in the future, i.e. if either
// of the file time are more recent than the last loaded time then this will triffer the
// conditional
if first_load || (last_loaded.duration_since(last_modified_certs_chain.unwrap()).is_err()
|| last_loaded.duration_since(last_modified_private_key.unwrap()).is_err()) {

// Attempt to open and load cert_chain and private_key from files
let loaded_cert_chain = if let Ok(file) = std::fs::File::open(&cert_chain_path) {
load_cert_chain(&mut io::BufReader::new(file)).ok()
} else {
None
};

let loaded_private_key = if let Ok(file) = std::fs::File::open(&private_key_path) {
load_key(&mut io::BufReader::new(file)).ok()
} else {
None
};

if loaded_cert_chain.is_some() && loaded_private_key.is_some() {
reload_pair = Some((loaded_cert_chain.unwrap(), loaded_private_key.unwrap()));
}
}
}

}
_ => (), // File/Vec, Vec/File and Vec/Vec options will not reload
}

if let Some(loop_updater) = loop_updater.as_ref() {
if let Ok(certs) = loop_updater.read() {
if dynamic_certs.is_none() || *certs != *dynamic_certs.as_ref().unwrap() {
dynamic_certs = Some(certs.clone());

let loaded_cert_chain = load_cert_chain(&mut io::Cursor::new(certs.certs.clone()));
let loaded_private_key = load_key(&mut io::Cursor::new(certs.key.clone()));

if loaded_cert_chain.is_ok() && loaded_private_key.is_ok() {
reload_pair = Some((
loaded_cert_chain.unwrap(),
loaded_private_key.unwrap(),
));
do_load = true; // Load immediately
}
}
}
}

if reload_pair.is_some() {
if (first_load || do_load) {
let reload_pair = reload_pair.unwrap();
if let Ok(mut certified_key) = loop_certified_key.write() {
dbg!("loading a new key");
*certified_key = Some(Arc::new(CertifiedKey::new(
reload_pair.0,
rustls::crypto::ring::sign::any_supported_type(&reload_pair.1).unwrap()
)));
last_loaded = std::time::SystemTime::now();
}
do_load = false;
first_load = false;
} else {
do_load = true; // Do the load this time round
dbg!("Defer load");
}
}

tokio::time::sleep(std::time::Duration::from_secs(10)).await;
}
});


Ok(Arc::new(Self {
certified_key: certified_key,
_handle: handle,

}))
}
}

impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(&self, _client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
let cert = self.certified_key.read().unwrap();
if cert.is_none() { return None; }
Some(cert.as_ref().unwrap().clone())
}
}

impl TlsListener {
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> crate::tls::Result<TlsListener>
where R: io::BufRead
pub async fn bind<R, T>(addr: SocketAddr, mut c: Config<R>, cert_resolver: Arc<T>) -> crate::tls::Result<TlsListener>
where R: io::BufRead, T: rustls::server::ResolvesServerCert + 'static
{
let provider = rustls::crypto::CryptoProvider {
cipher_suites: c.ciphersuites,
Expand All @@ -93,12 +267,10 @@ impl TlsListener {
None => WebPkiClientVerifier::no_client_auth(),
};

let key = load_key(&mut c.private_key)?;
let cert_chain = load_cert_chain(&mut c.cert_chain)?;
let mut config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)?;
.with_cert_resolver(cert_resolver);

config.ignore_client_order = c.prefer_server_order;
config.session_storage = ServerSessionMemoryCache::new(1024);
Expand Down Expand Up @@ -176,7 +348,7 @@ impl TlsStream {
Ok(stream) => {
if let Some(peer_certs) = stream.get_ref().1.peer_certificates() {
self.certs.set(peer_certs.into_iter()
.map(|v| CertificateDer(v.clone().into_owned()))
.map(|v| crate::listener::CertificateDer(v.clone().into_owned()))
.collect());
}

Expand Down
2 changes: 1 addition & 1 deletion core/http/src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod listener;
pub mod mtls;

pub use rustls;
pub use listener::{TlsListener, Config};
pub use listener::{TlsListener, Config, CertResolver, FileOrBytes, DynamicConfig};
pub mod util;
pub mod error;

Expand Down
1 change: 1 addition & 0 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,5 @@ version_check = "0.9.1"

[dev-dependencies]
figment = { version = "0.10", features = ["test"] }
reqwest = { version = "0.11", features = ["blocking"] }
pretty_assertions = "1"
11 changes: 11 additions & 0 deletions core/lib/src/config/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,17 @@ impl Config {
#[cfg(not(feature = "mtls"))] { false }
}

pub fn with_tls_loader(&mut self, loader: &std::sync::Arc<std::sync::RwLock<crate::http::tls::DynamicConfig>>) {
// Cannot add tls loader if tls is not set
if self.tls.is_none() {
return;
}

let mut config = self.tls.take().unwrap();
config.with_tls_loader(loader);
self.tls = Some(config);
}

#[cfg(feature = "secrets")]
pub(crate) fn known_secret_key_used(&self) -> bool {
const KNOWN_SECRET_KEYS: &'static [&'static str] = &[
Expand Down
41 changes: 39 additions & 2 deletions core/lib/src/config/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,32 @@ pub struct TlsConfig {
/// Whether to prefer the server's cipher suite order over the client's.
#[serde(default)]
pub(crate) prefer_server_cipher_order: bool,
#[serde(skip)]
pub(crate) tls_updater: Option<TlsUpdater>,
/// Configuration for mutual TLS, if any.
#[serde(default)]
#[cfg(feature = "mtls")]
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub(crate) mutual: Option<MutualTls>,
}

#[derive(Default, Debug)]
pub (crate) struct TlsUpdater {
pub(crate) tls_updater: std::sync::Arc<std::sync::RwLock<crate::http::tls::DynamicConfig>>,
}

impl PartialEq for TlsUpdater {
fn eq(&self, other: &Self) -> bool {
true
}
}

impl Clone for TlsUpdater {
fn clone(&self) -> Self {
Self { tls_updater: self.tls_updater.clone() }
}
}

/// Mutual TLS configuration.
///
/// Configuration works in concert with the [`mtls`](crate::mtls) module, which
Expand Down Expand Up @@ -254,6 +273,7 @@ impl TlsConfig {
key: Either::Right(vec![]),
ciphers: CipherSuite::default_set(),
prefer_server_cipher_order: false,
tls_updater: None,
#[cfg(feature = "mtls")]
mutual: None,
}
Expand Down Expand Up @@ -302,6 +322,11 @@ impl TlsConfig {
}
}


pub fn with_tls_loader(&mut self, loader: &std::sync::Arc<std::sync::RwLock<crate::http::tls::DynamicConfig>>) {
self.tls_updater = Some(TlsUpdater{tls_updater: loader.clone()});
}

/// Sets the cipher suites supported by the server and their order of
/// preference from most to least preferred.
///
Expand Down Expand Up @@ -659,10 +684,19 @@ mod with_tls_feature {
/// This is only called when TLS is enabled.
pub(crate) fn to_native_config(&self) -> io::Result<Config<Reader>> {
Ok(Config {
cert_chain: to_reader(&self.certs)?,
private_key: to_reader(&self.key)?,
//cert_chain: to_reader(&self.certs)?,
//private_key: to_reader(&self.key)?,
cert_chain: match &self.certs {
Either::Left(file) => crate::http::tls::FileOrBytes::File(file.relative()),
Either::Right(bytes) => crate::http::tls::FileOrBytes::Bytes(bytes.clone()),
},
private_key: match &self.key {
Either::Left(file) => crate::http::tls::FileOrBytes::File(file.relative()),
Either::Right(bytes) => crate::http::tls::FileOrBytes::Bytes(bytes.clone()),
},
ciphersuites: self.rustls_ciphers().collect(),
prefer_server_order: self.prefer_server_cipher_order,
tls_updater: self.tls_updater.as_ref().map(|i| i.tls_updater.clone()),
#[cfg(not(feature = "mtls"))]
mandatory_mtls: false,
#[cfg(not(feature = "mtls"))]
Expand Down Expand Up @@ -700,4 +734,7 @@ mod with_tls_feature {
})
}
}



}
5 changes: 5 additions & 0 deletions core/lib/src/rocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,11 @@ impl Rocket<Build> {
let mut config = Config::try_from(&self.figment).map_err(ErrorKind::Config)?;
crate::log::init(&config);

let mut dynamic_tls_config = self.state.try_get::<std::sync::Arc<std::sync::RwLock<crate::http::tls::DynamicConfig>>>();
if let Some(dynamic_tls_config) = dynamic_tls_config.take() {
config.with_tls_loader(dynamic_tls_config);
}

// Check for safely configured secrets.
#[cfg(feature = "secrets")]
if !config.secret_key.is_provided() {
Expand Down
4 changes: 3 additions & 1 deletion core/lib/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,9 @@ impl Rocket<Orbit> {
use crate::http::tls::TlsListener;

let conf = config.to_native_config().map_err(ErrorKind::Io)?;
let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::TlsBind)?;
let cert_resolver = crate::http::tls::CertResolver::new(&conf).await.unwrap();

let l = TlsListener::bind(addr, conf, cert_resolver).await.map_err(ErrorKind::TlsBind)?;
addr = l.local_addr().unwrap_or(addr);
self.config.address = addr.ip();
self.config.port = addr.port();
Expand Down
Loading