diff --git a/core/http/src/tls/certificate_resolver.rs b/core/http/src/tls/certificate_resolver.rs new file mode 100644 index 0000000000..02fa0bd473 --- /dev/null +++ b/core/http/src/tls/certificate_resolver.rs @@ -0,0 +1,27 @@ +use std::{io, sync::Arc}; + +use rustls::{server::ClientHello, sign::{any_supported_type, CertifiedKey}}; + +use crate::tls::Config; +use crate::tls::util::{load_certs, load_private_key}; + +pub(crate) struct CertResolver(Arc); +impl CertResolver { + pub fn new(config: &mut Config) -> Result + where R: io::BufRead, + { + let certs = load_certs(&mut config.cert_chain)?; + let private_key = load_private_key(&mut config.private_key)?; + let key = any_supported_type(&private_key) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?; + + Ok( + Self(Arc::new(CertifiedKey::new(certs, key)))) + } +} + +impl rustls::server::ResolvesServerCert for CertResolver { + fn resolve(&self, _client_hello: ClientHello<'_>) -> Option> { + Some(self.0.clone()) + } +} \ No newline at end of file diff --git a/core/http/src/tls/listener.rs b/core/http/src/tls/listener.rs index 8c675d6110..92fa92af8e 100644 --- a/core/http/src/tls/listener.rs +++ b/core/http/src/tls/listener.rs @@ -5,12 +5,14 @@ use std::task::{Context, Poll}; use std::future::Future; use std::net::SocketAddr; +use rustls::server::ResolvesServerCert; use tokio::net::{TcpListener, TcpStream}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream}; -use crate::tls::util::{load_certs, load_private_key, load_ca_certs}; +use crate::tls::util::load_ca_certs; use crate::listener::{Connection, Listener, Certificates}; +use crate::tls::CertResolver; /// A TLS listener over TCP. pub struct TlsListener { @@ -72,18 +74,12 @@ pub struct Config { } impl TlsListener { - pub async fn bind(addr: SocketAddr, mut c: Config) -> io::Result - where R: io::BufRead + pub async fn bind(addr: SocketAddr, mut c: Config, cert_resolver: Option<&Arc>) -> io::Result + where R: io::BufRead, { use rustls::server::{AllowAnyAuthenticatedClient, AllowAnyAnonymousOrAuthenticatedClient}; use rustls::server::{NoClientAuth, ServerSessionMemoryCache, ServerConfig}; - let cert_chain = load_certs(&mut c.cert_chain) - .map_err(|e| io::Error::new(e.kind(), format!("bad TLS cert chain: {}", e)))?; - - let key = load_private_key(&mut c.private_key) - .map_err(|e| io::Error::new(e.kind(), format!("bad TLS private key: {}", e)))?; - let client_auth = match c.ca_certs { Some(ref mut ca_certs) => match load_ca_certs(ca_certs) { Ok(ca) if c.mandatory_mtls => AllowAnyAuthenticatedClient::new(ca).boxed(), @@ -93,14 +89,18 @@ impl TlsListener { None => NoClientAuth::boxed(), }; + let cert_resolver = match cert_resolver { + Some(c) => c.clone(), + None => Arc::new(CertResolver::new(&mut c)?), + }; + let mut tls_config = ServerConfig::builder() .with_cipher_suites(&c.ciphersuites) .with_safe_default_kx_groups() .with_safe_default_protocol_versions() .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))? .with_client_cert_verifier(client_auth) - .with_single_cert(cert_chain, key) - .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?; + .with_cert_resolver(cert_resolver); tls_config.ignore_client_order = c.prefer_server_order; diff --git a/core/http/src/tls/mod.rs b/core/http/src/tls/mod.rs index 04959ba23e..5495d47b80 100644 --- a/core/http/src/tls/mod.rs +++ b/core/http/src/tls/mod.rs @@ -3,6 +3,9 @@ mod listener; #[cfg(feature = "mtls")] pub mod mtls; +pub(crate) mod certificate_resolver; + pub use rustls; pub use listener::{TlsListener, Config}; +pub(crate) use certificate_resolver::*; pub mod util; diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 791154c15a..cecd8b774d 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -86,4 +86,5 @@ version_check = "0.9.1" [dev-dependencies] figment = { version = "0.10", features = ["test"] } +reqwest = { version = "0.11", features = ["blocking"] } pretty_assertions = "1" diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 800c9dbe63..1b67d8e78b 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use std::time::Duration; use std::pin::Pin; -use yansi::Paint; use tokio::sync::oneshot; +use yansi::Paint; use tokio::time::sleep; use futures::stream::StreamExt; use futures::future::{FutureExt, Future, BoxFuture}; @@ -421,9 +421,12 @@ impl Rocket { if self.config.tls_enabled() { if let Some(ref config) = self.config.tls { use crate::http::tls::TlsListener; + use crate::http::tls::rustls::server::ResolvesServerCert; let conf = config.to_native_config().map_err(ErrorKind::Io)?; - let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::Bind)?; + let resolver = self.state::>(); + + let l = TlsListener::bind(addr, conf, resolver).await.map_err(ErrorKind::Bind)?; addr = l.local_addr().unwrap_or(addr); self.config.address = addr.ip(); self.config.port = addr.port(); diff --git a/core/lib/tests/tls-config-from-source-1503.rs b/core/lib/tests/tls-config-from-source-1503.rs index 92085f6861..9023a51863 100644 --- a/core/lib/tests/tls-config-from-source-1503.rs +++ b/core/lib/tests/tls-config-from-source-1503.rs @@ -11,8 +11,8 @@ fn tls_config_from_source() { use rocket::config::{Config, TlsConfig}; use rocket::figment::Figment; - let cert_path = relative!("examples/tls/private/cert.pem"); - let key_path = relative!("examples/tls/private/key.pem"); + let cert_path = relative!("../../examples/tls/private/cert.pem"); + let key_path = relative!("../../examples/tls/private/key.pem"); let rocket_config = Config { tls: Some(TlsConfig::from_paths(cert_path, key_path)), @@ -24,3 +24,71 @@ fn tls_config_from_source() { assert_eq!(tls.certs().unwrap_left(), cert_path); assert_eq!(tls.key().unwrap_left(), key_path); } + +#[test] +fn tls_server_operation() { + use std::io::Read; + + use rocket::{get, routes}; + use rocket::config::{Config, TlsConfig}; + use rocket::figment::Figment; + + let cert_path = relative!("../../examples/tls/private/rsa_sha256_cert.pem"); + let key_path = relative!("../../examples/tls/private/rsa_sha256_key.pem"); + let ca_cert_path = relative!("../../examples/tls/private/ca_cert.pem"); + + println!("{cert_path:?}"); + + let port = { + let listener = std::net::TcpListener::bind(("127.0.0.1", 0)).expect("creating listener"); + listener.local_addr().expect("getting listener's port").port() + }; + + let rocket_config = Config { + port, + tls: Some(TlsConfig::from_paths(cert_path, key_path)), + ..Default::default() + }; + let config: Config = Figment::from(rocket_config).extract().expect("creating config"); + let (shutdown_signal_sender, mut shutdown_signal_receiver) = tokio::sync::mpsc::channel::<()>(1); + + // Create a runtime in a separate thread for the server being tested + let join_handle = std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + #[get("/hello")] + fn tls_test_get() -> &'static str { + "world" + } + + rt.block_on(async { + let task_handle = tokio::spawn( async { + rocket::custom(config) + .mount("/", routes![tls_test_get]) + .launch().await.unwrap(); + }); + shutdown_signal_receiver.recv().await; + task_handle.abort(); + }); + }); + + let request_url = format!("https://localhost:{}/hello", port); + + // CA certificate is not loaded, so request should fail + assert!(reqwest::blocking::get(&request_url).is_err()); + + // Load the CA certicate for use with test client + let cert = { + let mut buf = Vec::new(); + std::fs::File::open(ca_cert_path).expect("open ca_certs") + .read_to_end(&mut buf).expect("read ca_certs"); + reqwest::Certificate::from_pem(&buf).expect("create certificate") + }; + let client = reqwest::blocking::Client::builder().add_root_certificate(cert).build().expect("build client"); + + let response = client.get(&request_url).send().expect("https request"); + assert_eq!(&response.text().unwrap(), "world"); + + shutdown_signal_sender.blocking_send(()).expect("signal shutdown"); + join_handle.join().expect("join thread"); +} \ No newline at end of file