diff --git a/.clippy.toml b/.clippy.toml new file mode 100644 index 0000000..f357307 --- /dev/null +++ b/.clippy.toml @@ -0,0 +1 @@ +doc-valid-idents = ["PostgreSQL"] diff --git a/Cargo.toml b/Cargo.toml index 93cc952..8b83023 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "tokio-postgres-rustls" description = "Rustls integration for tokio-postgres" -version = "0.12.0" +version = "0.13.0" authors = ["Jasper Hugo "] repository = "https://github.com/jbg/tokio-postgres-rustls" edition = "2018" @@ -9,15 +9,23 @@ license = "MIT" readme = "README.md" [dependencies] +const-oid = { version = "0.9.6", default-features = false, features = ["db"] } ring = { version = "0.17", default-features = false } rustls = { version = "0.23", default-features = false } tokio = { version = "1", default-features = false } tokio-postgres = { version = "0.7", default-features = false } tokio-rustls = { version = "0.26", default-features = false } -x509-certificate = {version = "0.23", default-features = false } +x509-cert = { version = "0.2.5", default-features = false, features = ["std"] } [dev-dependencies] env_logger = { version = "0.11", default-features = false } tokio = { version = "1", default-features = false, features = ["macros", "rt"] } -tokio-postgres = { version = "0.7", default-features = false, features = ["runtime"] } -rustls = { version = "0.23", default-features = false, features = ["std", "logging", "tls12", "ring"] } +tokio-postgres = { version = "0.7", default-features = false, features = [ + "runtime", +] } +rustls = { version = "0.23", default-features = false, features = [ + "std", + "logging", + "tls12", + "ring", +] } diff --git a/src/lib.rs b/src/lib.rs index 5fd02ce..eccd1ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,30 +1,175 @@ -use std::{ - convert::TryFrom, - future::Future, - io, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; -use DigestAlgorithm::{Sha1, Sha256, Sha384, Sha512}; - -use ring::digest; -use rustls::pki_types::ServerName; -use rustls::ClientConfig; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect}; -use tokio_rustls::{client::TlsStream, TlsConnector}; -use x509_certificate::{DigestAlgorithm, SignatureAlgorithm, X509Certificate}; -use SignatureAlgorithm::{ - EcdsaSha256, EcdsaSha384, Ed25519, NoSignature, RsaSha1, RsaSha256, RsaSha384, RsaSha512, -}; +#![doc = include_str!("../README.md")] +#![forbid(rust_2018_idioms)] +#![deny(missing_docs, unsafe_code)] +#![warn(clippy::all, clippy::pedantic)] +use std::{convert::TryFrom, sync::Arc}; + +use rustls::{pki_types::ServerName, ClientConfig}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_postgres::tls::MakeTlsConnect; + +mod private { + use std::{ + future::Future, + io, + pin::Pin, + task::{Context, Poll}, + }; + + use const_oid::db::{ + rfc5912::{ + ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512, + SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION, + SHA_512_WITH_RSA_ENCRYPTION, + }, + rfc8410::ID_ED_25519, + }; + use ring::digest; + use rustls::pki_types::ServerName; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_postgres::tls::{ChannelBinding, TlsConnect}; + use tokio_rustls::{client::TlsStream, TlsConnector}; + use x509_cert::{der::Decode, TbsCertificate}; + + pub struct TlsConnectFuture { + pub inner: tokio_rustls::Connect, + } + + impl Future for TlsConnectFuture + where + S: AsyncRead + AsyncWrite + Unpin, + { + type Output = io::Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: If `self` is pinned, so is `inner`. + #[allow(unsafe_code)] + let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) }; + fut.poll(cx).map_ok(RustlsStream) + } + } + + pub struct RustlsConnect(pub RustlsConnectData); + + pub struct RustlsConnectData { + pub hostname: ServerName<'static>, + pub connector: TlsConnector, + } + + impl TlsConnect for RustlsConnect + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Stream = RustlsStream; + type Error = io::Error; + type Future = TlsConnectFuture; + + fn connect(self, stream: S) -> Self::Future { + TlsConnectFuture { + inner: self.0.connector.connect(self.0.hostname, stream), + } + } + } + + pub struct RustlsStream(TlsStream); + + impl RustlsStream { + pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream> { + // SAFETY: When `Self` is pinned, so is the inner `TlsStream`. + #[allow(unsafe_code)] + unsafe { + self.map_unchecked_mut(|this| &mut this.0) + } + } + } + + impl tokio_postgres::tls::TlsStream for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn channel_binding(&self) -> ChannelBinding { + let (_, session) = self.0.get_ref(); + match session.peer_certificates() { + Some(certs) if !certs.is_empty() => TbsCertificate::from_der(&certs[0]) + .ok() + .and_then(|cert| { + let digest = match cert.signature.oid { + // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1 + ID_SHA_1 + | ID_SHA_256 + | SHA_1_WITH_RSA_ENCRYPTION + | SHA_256_WITH_RSA_ENCRYPTION + | ECDSA_WITH_SHA_256 => &digest::SHA256, + ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => { + &digest::SHA384 + } + ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => { + &digest::SHA512 + } + _ => return None, + }; + + Some(digest) + }) + .map_or_else(ChannelBinding::none, |algorithm| { + let hash = digest::digest(algorithm, certs[0].as_ref()); + ChannelBinding::tls_server_end_point(hash.as_ref().into()) + }), + _ => ChannelBinding::none(), + } + } + } + + impl AsyncRead for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project_stream().poll_read(cx, buf) + } + } + + impl AsyncWrite for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project_stream().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project_stream().poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project_stream().poll_shutdown(cx) + } + } +} + +/// A `MakeTlsConnect` implementation using `rustls`. +/// +/// That way you can connect to PostgreSQL using `rustls` as the TLS stack. #[derive(Clone)] pub struct MakeRustlsConnect { config: Arc, } impl MakeRustlsConnect { + /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. + #[must_use] pub fn new(config: ClientConfig) -> Self { Self { config: Arc::new(config), @@ -36,13 +181,13 @@ impl MakeTlsConnect for MakeRustlsConnect where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - type Stream = RustlsStream; - type TlsConnect = RustlsConnect; + type Stream = private::RustlsStream; + type TlsConnect = private::RustlsConnect; type Error = rustls::pki_types::InvalidDnsNameError; - fn make_tls_connect(&mut self, hostname: &str) -> Result { + fn make_tls_connect(&mut self, hostname: &str) -> Result { ServerName::try_from(hostname).map(|dns_name| { - RustlsConnect(RustlsConnectData { + private::RustlsConnect(private::RustlsConnectData { hostname: dns_name.to_owned(), connector: Arc::clone(&self.config).into(), }) @@ -50,100 +195,6 @@ where } } -pub struct RustlsConnect(RustlsConnectData); - -struct RustlsConnectData { - hostname: ServerName<'static>, - connector: TlsConnector, -} - -impl TlsConnect for RustlsConnect -where - S: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - type Stream = RustlsStream; - type Error = io::Error; - type Future = Pin>> + Send>>; - - fn connect(self, stream: S) -> Self::Future { - Box::pin(async move { - self.0 - .connector - .connect(self.0.hostname, stream) - .await - .map(|s| RustlsStream(Box::pin(s))) - }) - } -} - -pub struct RustlsStream(Pin>>); - -impl tokio_postgres::tls::TlsStream for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn channel_binding(&self) -> ChannelBinding { - let (_, session) = self.0.get_ref(); - match session.peer_certificates() { - Some(certs) if !certs.is_empty() => X509Certificate::from_der(&certs[0]) - .ok() - .and_then(|cert| cert.signature_algorithm()) - .map(|algorithm| match algorithm { - // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1 - RsaSha1 | RsaSha256 | EcdsaSha256 => &digest::SHA256, - RsaSha384 | EcdsaSha384 => &digest::SHA384, - RsaSha512 => &digest::SHA512, - Ed25519 => &digest::SHA512, - NoSignature(algo) => match algo { - Sha1 | Sha256 => &digest::SHA256, - Sha384 => &digest::SHA384, - Sha512 => &digest::SHA512, - }, - }) - .map(|algorithm| { - let hash = digest::digest(algorithm, certs[0].as_ref()); - ChannelBinding::tls_server_end_point(hash.as_ref().into()) - }) - .unwrap_or(ChannelBinding::none()), - _ => ChannelBinding::none(), - } - } -} - -impl AsyncRead for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - self.0.as_mut().poll_read(cx, buf) - } -} - -impl AsyncWrite for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - self.0.as_mut().poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.0.as_mut().poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.0.as_mut().poll_shutdown(cx) - } -} - #[cfg(test)] mod tests { use super::*;