diff --git a/.clippy.toml b/.clippy.toml new file mode 100644 index 0000000..f357307 --- /dev/null +++ b/.clippy.toml @@ -0,0 +1 @@ +doc-valid-idents = ["PostgreSQL"] diff --git a/Cargo.toml b/Cargo.toml index 986c44d..8b83023 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "tokio-postgres-rustls" description = "Rustls integration for tokio-postgres" -version = "0.7.0" +version = "0.13.0" authors = ["Jasper Hugo "] repository = "https://github.com/jbg/tokio-postgres-rustls" edition = "2018" @@ -9,14 +9,23 @@ license = "MIT" readme = "README.md" [dependencies] -futures = "0.3" -ring = "0.16" -rustls = "0.19" -tokio = "1" -tokio-postgres = "0.7" -tokio-rustls = "0.22" -webpki = "0.21" +const-oid = { version = "0.9.6", default-features = false, features = ["db"] } +ring = { version = "0.17", default-features = false } +rustls = { version = "0.23", default-features = false } +tokio = { version = "1", default-features = false } +tokio-postgres = { version = "0.7", default-features = false } +tokio-rustls = { version = "0.26", default-features = false } +x509-cert = { version = "0.2.5", default-features = false, features = ["std"] } [dev-dependencies] -env_logger = { version = "0.8", default-features = false } -tokio = { version = "1", features = ["macros", "rt"] } +env_logger = { version = "0.11", default-features = false } +tokio = { version = "1", default-features = false, features = ["macros", "rt"] } +tokio-postgres = { version = "0.7", default-features = false, features = [ + "runtime", +] } +rustls = { version = "0.23", default-features = false, features = [ + "std", + "logging", + "tls12", + "ring", +] } diff --git a/README.md b/README.md index c08a8bc..e10efb5 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,9 @@ and the [tokio-postgres asynchronous PostgreSQL client library](https://github.c # Example ``` -let config = rustls::ClientConfig::new(); +let config = rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config); let connect_fut = tokio_postgres::connect("sslmode=require host=localhost user=postgres", tls); // ... diff --git a/src/lib.rs b/src/lib.rs index 6b39861..eccd1ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,25 +1,175 @@ -use std::{ - future::Future, - io, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use futures::future::{FutureExt, TryFutureExt}; -use ring::digest; -use rustls::{ClientConfig, Session}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect}; -use tokio_rustls::{client::TlsStream, TlsConnector}; -use webpki::{DNSName, DNSNameRef}; +#![doc = include_str!("../README.md")] +#![forbid(rust_2018_idioms)] +#![deny(missing_docs, unsafe_code)] +#![warn(clippy::all, clippy::pedantic)] +use std::{convert::TryFrom, sync::Arc}; + +use rustls::{pki_types::ServerName, ClientConfig}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_postgres::tls::MakeTlsConnect; + +mod private { + use std::{ + future::Future, + io, + pin::Pin, + task::{Context, Poll}, + }; + + use const_oid::db::{ + rfc5912::{ + ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512, + SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION, + SHA_512_WITH_RSA_ENCRYPTION, + }, + rfc8410::ID_ED_25519, + }; + use ring::digest; + use rustls::pki_types::ServerName; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_postgres::tls::{ChannelBinding, TlsConnect}; + use tokio_rustls::{client::TlsStream, TlsConnector}; + use x509_cert::{der::Decode, TbsCertificate}; + + pub struct TlsConnectFuture { + pub inner: tokio_rustls::Connect, + } + + impl Future for TlsConnectFuture + where + S: AsyncRead + AsyncWrite + Unpin, + { + type Output = io::Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: If `self` is pinned, so is `inner`. + #[allow(unsafe_code)] + let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) }; + fut.poll(cx).map_ok(RustlsStream) + } + } + + pub struct RustlsConnect(pub RustlsConnectData); + + pub struct RustlsConnectData { + pub hostname: ServerName<'static>, + pub connector: TlsConnector, + } + + impl TlsConnect for RustlsConnect + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Stream = RustlsStream; + type Error = io::Error; + type Future = TlsConnectFuture; + + fn connect(self, stream: S) -> Self::Future { + TlsConnectFuture { + inner: self.0.connector.connect(self.0.hostname, stream), + } + } + } + + pub struct RustlsStream(TlsStream); + + impl RustlsStream { + pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream> { + // SAFETY: When `Self` is pinned, so is the inner `TlsStream`. + #[allow(unsafe_code)] + unsafe { + self.map_unchecked_mut(|this| &mut this.0) + } + } + } + + impl tokio_postgres::tls::TlsStream for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn channel_binding(&self) -> ChannelBinding { + let (_, session) = self.0.get_ref(); + match session.peer_certificates() { + Some(certs) if !certs.is_empty() => TbsCertificate::from_der(&certs[0]) + .ok() + .and_then(|cert| { + let digest = match cert.signature.oid { + // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1 + ID_SHA_1 + | ID_SHA_256 + | SHA_1_WITH_RSA_ENCRYPTION + | SHA_256_WITH_RSA_ENCRYPTION + | ECDSA_WITH_SHA_256 => &digest::SHA256, + ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => { + &digest::SHA384 + } + ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => { + &digest::SHA512 + } + _ => return None, + }; + + Some(digest) + }) + .map_or_else(ChannelBinding::none, |algorithm| { + let hash = digest::digest(algorithm, certs[0].as_ref()); + ChannelBinding::tls_server_end_point(hash.as_ref().into()) + }), + _ => ChannelBinding::none(), + } + } + } + + impl AsyncRead for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project_stream().poll_read(cx, buf) + } + } + + impl AsyncWrite for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project_stream().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project_stream().poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project_stream().poll_shutdown(cx) + } + } +} + +/// A `MakeTlsConnect` implementation using `rustls`. +/// +/// That way you can connect to PostgreSQL using `rustls` as the TLS stack. #[derive(Clone)] pub struct MakeRustlsConnect { config: Arc, } impl MakeRustlsConnect { + /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. + #[must_use] pub fn new(config: ClientConfig) -> Self { Self { config: Arc::new(config), @@ -31,110 +181,92 @@ impl MakeTlsConnect for MakeRustlsConnect where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - type Stream = RustlsStream; - type TlsConnect = RustlsConnect; - type Error = io::Error; + type Stream = private::RustlsStream; + type TlsConnect = private::RustlsConnect; + type Error = rustls::pki_types::InvalidDnsNameError; - fn make_tls_connect(&mut self, hostname: &str) -> io::Result { - DNSNameRef::try_from_ascii_str(hostname) - .map(|dns_name| RustlsConnect { + fn make_tls_connect(&mut self, hostname: &str) -> Result { + ServerName::try_from(hostname).map(|dns_name| { + private::RustlsConnect(private::RustlsConnectData { hostname: dns_name.to_owned(), connector: Arc::clone(&self.config).into(), }) - .map_err(|_| io::ErrorKind::InvalidInput.into()) - } -} - -pub struct RustlsConnect { - hostname: DNSName, - connector: TlsConnector, -} - -impl TlsConnect for RustlsConnect -where - S: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - type Stream = RustlsStream; - type Error = io::Error; - type Future = Pin>> + Send>>; - - fn connect(self, stream: S) -> Self::Future { - self.connector - .connect(self.hostname.as_ref(), stream) - .map_ok(|s| RustlsStream(Box::pin(s))) - .boxed() + }) } } -pub struct RustlsStream(Pin>>); +#[cfg(test)] +mod tests { + use super::*; + use rustls::pki_types::{CertificateDer, UnixTime}; + use rustls::{ + client::danger::ServerCertVerifier, + client::danger::{HandshakeSignatureValid, ServerCertVerified}, + Error, SignatureScheme, + }; -impl tokio_postgres::tls::TlsStream for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn channel_binding(&self) -> ChannelBinding { - let (_, session) = self.0.get_ref(); - match session.get_peer_certificates() { - Some(certs) if certs.len() > 0 => { - let sha256 = digest::digest(&digest::SHA256, certs[0].as_ref()); - ChannelBinding::tls_server_end_point(sha256.as_ref().into()) - } - _ => ChannelBinding::none(), + #[derive(Debug)] + struct AcceptAllVerifier {} + impl ServerCertVerifier for AcceptAllVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) } - } -} -impl AsyncRead for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - self.0.as_mut().poll_read(cx, buf) - } - -} - -impl AsyncWrite for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - self.0.as_mut().poll_write(cx, buf) - } + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.0.as_mut().poll_flush(cx) - } + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.0.as_mut().poll_shutdown(cx) + fn supported_verify_schemes(&self) -> Vec { + vec![ + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::ED25519, + ] + } } -} - -#[cfg(test)] -mod tests { - use futures::future::TryFutureExt; - #[tokio::test] async fn it_works() { env_logger::builder().is_test(true).try_init().unwrap(); - let config = rustls::ClientConfig::new(); + let mut config = rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + config + .dangerous() + .set_certificate_verifier(Arc::new(AcceptAllVerifier {})); let tls = super::MakeRustlsConnect::new(config); - let (client, conn) = - tokio_postgres::connect("sslmode=require host=localhost port=5432 user=postgres", tls) - .await - .expect("connect"); - tokio::spawn(conn.map_err(|e| panic!("{:?}", e))); + let (client, conn) = tokio_postgres::connect( + "sslmode=require host=localhost port=5432 user=postgres", + tls, + ) + .await + .expect("connect"); + tokio::spawn(async move { conn.await.map_err(|e| panic!("{:?}", e)) }); let stmt = client.prepare("SELECT 1").await.expect("prepare"); let _ = client.query(&stmt, &[]).await.expect("query"); }