diff --git a/.clippy.toml b/.clippy.toml new file mode 100644 index 0000000..f357307 --- /dev/null +++ b/.clippy.toml @@ -0,0 +1 @@ +doc-valid-idents = ["PostgreSQL"] diff --git a/Cargo.toml b/Cargo.toml index 520ee65..8b83023 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "tokio-postgres-rustls" description = "Rustls integration for tokio-postgres" -version = "0.9.0" +version = "0.13.0" authors = ["Jasper Hugo "] repository = "https://github.com/jbg/tokio-postgres-rustls" edition = "2018" @@ -9,15 +9,23 @@ license = "MIT" readme = "README.md" [dependencies] -futures = "0.3" -ring = "0.16" -rustls = "0.20" -tokio = "1" -tokio-postgres = "0.7" -tokio-rustls = "0.23" +const-oid = { version = "0.9.6", default-features = false, features = ["db"] } +ring = { version = "0.17", default-features = false } +rustls = { version = "0.23", default-features = false } +tokio = { version = "1", default-features = false } +tokio-postgres = { version = "0.7", default-features = false } +tokio-rustls = { version = "0.26", default-features = false } +x509-cert = { version = "0.2.5", default-features = false, features = ["std"] } [dev-dependencies] -env_logger = { version = "0.8", default-features = false } -tokio = { version = "1", features = ["macros", "rt"] } -rustls = { version = "0.20", features = ["dangerous_configuration"] } - +env_logger = { version = "0.11", default-features = false } +tokio = { version = "1", default-features = false, features = ["macros", "rt"] } +tokio-postgres = { version = "0.7", default-features = false, features = [ + "runtime", +] } +rustls = { version = "0.23", default-features = false, features = [ + "std", + "logging", + "tls12", + "ring", +] } diff --git a/README.md b/README.md index b9b5caf..e10efb5 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,6 @@ and the [tokio-postgres asynchronous PostgreSQL client library](https://github.c ``` let config = rustls::ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(rustls::RootCertStore::empty()) .with_no_client_auth(); let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config); diff --git a/src/lib.rs b/src/lib.rs index a3a4fc2..eccd1ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,151 +1,252 @@ -use std::{ - convert::TryFrom, - future::Future, - io, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use futures::future::{FutureExt, TryFutureExt}; -use ring::digest; -use rustls::{ClientConfig, ServerName}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect}; -use tokio_rustls::{client::TlsStream, TlsConnector}; +#![doc = include_str!("../README.md")] +#![forbid(rust_2018_idioms)] +#![deny(missing_docs, unsafe_code)] +#![warn(clippy::all, clippy::pedantic)] -#[derive(Clone)] -pub struct MakeRustlsConnect { - config: Arc, -} +use std::{convert::TryFrom, sync::Arc}; -impl MakeRustlsConnect { - pub fn new(config: ClientConfig) -> Self { - Self { - config: Arc::new(config), +use rustls::{pki_types::ServerName, ClientConfig}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_postgres::tls::MakeTlsConnect; + +mod private { + use std::{ + future::Future, + io, + pin::Pin, + task::{Context, Poll}, + }; + + use const_oid::db::{ + rfc5912::{ + ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512, + SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION, + SHA_512_WITH_RSA_ENCRYPTION, + }, + rfc8410::ID_ED_25519, + }; + use ring::digest; + use rustls::pki_types::ServerName; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_postgres::tls::{ChannelBinding, TlsConnect}; + use tokio_rustls::{client::TlsStream, TlsConnector}; + use x509_cert::{der::Decode, TbsCertificate}; + + pub struct TlsConnectFuture { + pub inner: tokio_rustls::Connect, + } + + impl Future for TlsConnectFuture + where + S: AsyncRead + AsyncWrite + Unpin, + { + type Output = io::Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: If `self` is pinned, so is `inner`. + #[allow(unsafe_code)] + let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) }; + fut.poll(cx).map_ok(RustlsStream) } } -} -impl MakeTlsConnect for MakeRustlsConnect -where - S: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - type Stream = RustlsStream; - type TlsConnect = RustlsConnect; - type Error = io::Error; - - fn make_tls_connect(&mut self, hostname: &str) -> io::Result { - ServerName::try_from(hostname) - .map(|dns_name| { - RustlsConnect(Some(RustlsConnectData { - hostname: dns_name, - connector: Arc::clone(&self.config).into(), - })) - }) - .or(Ok(RustlsConnect(None))) + pub struct RustlsConnect(pub RustlsConnectData); + + pub struct RustlsConnectData { + pub hostname: ServerName<'static>, + pub connector: TlsConnector, } -} -pub struct RustlsConnect(Option); + impl TlsConnect for RustlsConnect + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Stream = RustlsStream; + type Error = io::Error; + type Future = TlsConnectFuture; -struct RustlsConnectData { - hostname: ServerName, - connector: TlsConnector, -} + fn connect(self, stream: S) -> Self::Future { + TlsConnectFuture { + inner: self.0.connector.connect(self.0.hostname, stream), + } + } + } -impl TlsConnect for RustlsConnect -where - S: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - type Stream = RustlsStream; - type Error = io::Error; - type Future = Pin>> + Send>>; - - fn connect(self, stream: S) -> Self::Future { - match self.0 { - None => Box::pin(core::future::ready(Err(io::ErrorKind::InvalidInput.into()))), - Some(c) => c - .connector - .connect(c.hostname, stream) - .map_ok(|s| RustlsStream(Box::pin(s))) - .boxed(), + pub struct RustlsStream(TlsStream); + + impl RustlsStream { + pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream> { + // SAFETY: When `Self` is pinned, so is the inner `TlsStream`. + #[allow(unsafe_code)] + unsafe { + self.map_unchecked_mut(|this| &mut this.0) + } } } -} -pub struct RustlsStream(Pin>>); + impl tokio_postgres::tls::TlsStream for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn channel_binding(&self) -> ChannelBinding { + let (_, session) = self.0.get_ref(); + match session.peer_certificates() { + Some(certs) if !certs.is_empty() => TbsCertificate::from_der(&certs[0]) + .ok() + .and_then(|cert| { + let digest = match cert.signature.oid { + // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1 + ID_SHA_1 + | ID_SHA_256 + | SHA_1_WITH_RSA_ENCRYPTION + | SHA_256_WITH_RSA_ENCRYPTION + | ECDSA_WITH_SHA_256 => &digest::SHA256, + ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => { + &digest::SHA384 + } + ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => { + &digest::SHA512 + } + _ => return None, + }; -impl tokio_postgres::tls::TlsStream for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn channel_binding(&self) -> ChannelBinding { - let (_, session) = self.0.get_ref(); - match session.peer_certificates() { - Some(certs) if !certs.is_empty() => { - let sha256 = digest::digest(&digest::SHA256, certs[0].as_ref()); - ChannelBinding::tls_server_end_point(sha256.as_ref().into()) + Some(digest) + }) + .map_or_else(ChannelBinding::none, |algorithm| { + let hash = digest::digest(algorithm, certs[0].as_ref()); + ChannelBinding::tls_server_end_point(hash.as_ref().into()) + }), + _ => ChannelBinding::none(), } - _ => ChannelBinding::none(), + } + } + + impl AsyncRead for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project_stream().poll_read(cx, buf) + } + } + + impl AsyncWrite for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project_stream().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project_stream().poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project_stream().poll_shutdown(cx) } } } -impl AsyncRead for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - self.0.as_mut().poll_read(cx, buf) +/// A `MakeTlsConnect` implementation using `rustls`. +/// +/// That way you can connect to PostgreSQL using `rustls` as the TLS stack. +#[derive(Clone)] +pub struct MakeRustlsConnect { + config: Arc, +} + +impl MakeRustlsConnect { + /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. + #[must_use] + pub fn new(config: ClientConfig) -> Self { + Self { + config: Arc::new(config), + } } } -impl AsyncWrite for RustlsStream +impl MakeTlsConnect for MakeRustlsConnect where - S: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - self.0.as_mut().poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.0.as_mut().poll_flush(cx) - } + type Stream = private::RustlsStream; + type TlsConnect = private::RustlsConnect; + type Error = rustls::pki_types::InvalidDnsNameError; - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.0.as_mut().poll_shutdown(cx) + fn make_tls_connect(&mut self, hostname: &str) -> Result { + ServerName::try_from(hostname).map(|dns_name| { + private::RustlsConnect(private::RustlsConnectData { + hostname: dns_name.to_owned(), + connector: Arc::clone(&self.config).into(), + }) + }) } } #[cfg(test)] mod tests { use super::*; - use futures::future::TryFutureExt; - use rustls::{client::ServerCertVerified, client::ServerCertVerifier, Certificate, Error}; - use std::time::SystemTime; + use rustls::pki_types::{CertificateDer, UnixTime}; + use rustls::{ + client::danger::ServerCertVerifier, + client::danger::{HandshakeSignatureValid, ServerCertVerified}, + Error, SignatureScheme, + }; + #[derive(Debug)] struct AcceptAllVerifier {} impl ServerCertVerifier for AcceptAllVerifier { fn verify_server_cert( &self, - _end_entity: &Certificate, - _intermediates: &[Certificate], - _server_name: &ServerName, - _scts: &mut dyn Iterator, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, _ocsp_response: &[u8], - _now: SystemTime, + _now: UnixTime, ) -> Result { Ok(ServerCertVerified::assertion()) } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::ED25519, + ] + } } #[tokio::test] @@ -153,7 +254,6 @@ mod tests { env_logger::builder().is_test(true).try_init().unwrap(); let mut config = rustls::ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(rustls::RootCertStore::empty()) .with_no_client_auth(); config @@ -166,7 +266,7 @@ mod tests { ) .await .expect("connect"); - tokio::spawn(conn.map_err(|e| panic!("{:?}", e))); + tokio::spawn(async move { conn.await.map_err(|e| panic!("{:?}", e)) }); let stmt = client.prepare("SELECT 1").await.expect("prepare"); let _ = client.query(&stmt, &[]).await.expect("query"); }