From 96e57c2e99c70ebdb4ae8cc1845fd827cd26f130 Mon Sep 17 00:00:00 2001 From: James Newton Date: Wed, 4 Dec 2024 11:53:42 -0800 Subject: [PATCH] Add support for namespace to the remote connection builder --- libsql-server/src/http/user/db_factory.rs | 7 ++ libsql-server/tests/embedded_replica/mod.rs | 100 ++++++++++++++++++++ libsql/src/database.rs | 6 ++ libsql/src/database/builder.rs | 22 ++++- libsql/src/hrana/hyper.rs | 27 +++++- libsql/src/local/database.rs | 3 +- 6 files changed, 155 insertions(+), 10 deletions(-) diff --git a/libsql-server/src/http/user/db_factory.rs b/libsql-server/src/http/user/db_factory.rs index 2a36024d5c..2a7c4c5752 100644 --- a/libsql-server/src/http/user/db_factory.rs +++ b/libsql-server/src/http/user/db_factory.rs @@ -50,6 +50,8 @@ pub fn namespace_from_headers( if let Some(from_metadata) = headers.get(NAMESPACE_METADATA_KEY) { try_namespace_from_metadata(from_metadata) + } else if let Some(from_ns_header) = headers.get("x-namespace") { + try_namespace_from_header(from_ns_header) } else if let Some(from_host) = headers.get("host") { try_namespace_from_host(from_host, disable_default_namespace) } else if !disable_default_namespace { @@ -59,6 +61,11 @@ pub fn namespace_from_headers( } } +fn try_namespace_from_header(header: &axum::http::HeaderValue) -> Result { + NamespaceName::from_bytes(header.as_bytes().to_vec().into()) + .map_err(|_| Error::InvalidNamespace) +} + fn try_namespace_from_host( from_host: &axum::http::HeaderValue, disable_default_namespace: bool, diff --git a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index e7b4b9f7f0..9c40ea4a42 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -1696,3 +1696,103 @@ fn schema_db() { sim.run().unwrap(); } + +#[test] +fn remote_namespace_header_support() { + let tmp_host = tempdir().unwrap(); + let tmp_host_path = tmp_host.path().to_owned(); + + let mut sim = Builder::new() + .simulation_duration(Duration::from_secs(1000)) + .build(); + + make_primary(&mut sim, tmp_host_path.clone()); + + sim.client("client", async move { + let client = Client::new(); + + client + .post("http://primary:9090/v1/namespaces/foo/create", json!({})) + .await?; + + let db_url = "http://primary:8080"; + + let remote = libsql::Builder::new_remote(db_url.to_string(), String::new()) + .namespace("foo") + .connector(TurmoilConnector) + .build() + .await + .unwrap(); + + let conn = remote.connect().unwrap(); + + conn.execute("CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY)", ()) + .await?; + + conn.execute("INSERT into user(id) values (1);", ()).await?; + + Ok(()) + }); + + sim.run().unwrap(); +} + +#[test] +fn remote_replica_namespace_header_support() { + let tmp_host = tempdir().unwrap(); + let tmp_host_path = tmp_host.path().to_owned(); + + let tmp_embedded = tempdir().unwrap(); + let tmp_embedded_path = tmp_embedded.path().to_owned(); + + let mut sim = Builder::new() + .simulation_duration(Duration::from_secs(1000)) + .build(); + + make_primary(&mut sim, tmp_host_path.clone()); + + sim.client("client", async move { + let client = Client::new(); + + client + .post("http://primary:9090/v1/namespaces/foo/create", json!({})) + .await?; + + let db_url = "http://primary:8080"; + + let remote = libsql::Builder::new_remote(db_url.to_string(), String::new()) + .namespace("foo") + .connector(TurmoilConnector) + .build() + .await + .unwrap(); + + let conn = remote.connect().unwrap(); + + conn.execute("CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY)", ()) + .await?; + + conn.execute("INSERT into user(id) values (1);", ()).await?; + + let path = tmp_embedded_path.join("embedded"); + + let remote_replica = libsql::Builder::new_remote_replica( + path.to_str().unwrap(), + db_url.to_string(), + String::new(), + ) + .namespace("foo") + .connector(TurmoilConnector) + .build() + .await + .unwrap(); + + let rep = remote_replica.sync().await.unwrap(); + assert_eq!(rep.frame_no(), Some(2)); + assert_eq!(rep.frames_synced(), 3); + + Ok(()) + }); + + sim.run().unwrap(); +} diff --git a/libsql/src/database.rs b/libsql/src/database.rs index 32d325cbc3..1362ec217a 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -107,6 +107,7 @@ enum DbType { auth_token: String, connector: crate::util::ConnectorService, version: Option, + namespace: Option, }, } @@ -250,6 +251,7 @@ cfg_replication! { OpenFlags::default(), encryption_config.clone(), None, + None, ).await?; Ok(Database { @@ -541,6 +543,7 @@ cfg_remote! { auth_token: auth_token.into(), connector: crate::util::ConnectorService::new(svc), version, + namespace: None, }, max_write_replication_index: Default::default(), }) @@ -704,6 +707,7 @@ impl Database { auth_token.clone(), connector.clone(), None, + None, ), read_your_writes: *read_your_writes, context: db.sync_ctx.clone().unwrap(), @@ -724,6 +728,7 @@ impl Database { auth_token, connector, version, + namespace, } => { let conn = std::sync::Arc::new( crate::hrana::connection::HttpConnection::new_with_connector( @@ -731,6 +736,7 @@ impl Database { auth_token, connector.clone(), version.as_ref().map(|s| s.as_str()), + namespace.as_ref().map(|s| s.as_str()), ), ); diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index 3150648fa8..c740e3b80d 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -59,12 +59,12 @@ impl Builder<()> { auth_token, connector: None, version: None, + namespace: None, }, encryption_config: None, read_your_writes: true, sync_interval: None, http_request_callback: None, - namespace: None, skip_safety_assert: false, #[cfg(feature = "sync")] sync_protocol: Default::default(), @@ -102,6 +102,7 @@ impl Builder<()> { auth_token, connector: None, version: None, + namespace: None, }, connector: None, read_your_writes: true, @@ -122,6 +123,7 @@ impl Builder<()> { auth_token, connector: None, version: None, + namespace: None, }, } } @@ -135,6 +137,7 @@ cfg_replication_or_remote_or_sync! { auth_token: String, connector: Option, version: Option, + namespace: Option, } } @@ -223,7 +226,6 @@ cfg_replication! { read_your_writes: bool, sync_interval: Option, http_request_callback: Option, - namespace: Option, skip_safety_assert: bool, #[cfg(feature = "sync")] sync_protocol: super::SyncProtocol, @@ -300,7 +302,7 @@ cfg_replication! { /// Set the namespace that will be communicated to remote replica in the http header. pub fn namespace(mut self, namespace: impl Into) -> Builder { - self.inner.namespace = Some(namespace.into()); + self.inner.remote.namespace = Some(namespace.into()); self } @@ -334,12 +336,12 @@ cfg_replication! { auth_token, connector, version, + namespace, }, encryption_config, read_your_writes, sync_interval, http_request_callback, - namespace, skip_safety_assert, #[cfg(feature = "sync")] sync_protocol, @@ -500,6 +502,7 @@ cfg_replication! { auth_token, connector, version, + namespace, }) = remote { let connector = if let Some(connector) = connector { @@ -524,6 +527,7 @@ cfg_replication! { flags, encryption_config.clone(), http_request_callback, + namespace, ) .await? } else { @@ -606,6 +610,7 @@ cfg_sync! { auth_token, connector: _, version: _, + namespace: _, }, connector, remote_writes, @@ -730,6 +735,13 @@ cfg_remote! { self } + /// Set the namespace that will be communicated to the remote in the http header. + pub fn namespace(mut self, namespace: impl Into) -> Builder + { + self.inner.namespace = Some(namespace.into()); + self + } + /// Build the remote database client. pub async fn build(self) -> Result { let Remote { @@ -737,6 +749,7 @@ cfg_remote! { auth_token, connector, version, + namespace, } = self.inner; let connector = if let Some(connector) = connector { @@ -758,6 +771,7 @@ cfg_remote! { auth_token, connector, version, + namespace, }, max_write_replication_index: Default::default(), }) diff --git a/libsql/src/hrana/hyper.rs b/libsql/src/hrana/hyper.rs index 675865ea24..b78838d11a 100644 --- a/libsql/src/hrana/hyper.rs +++ b/libsql/src/hrana/hyper.rs @@ -26,17 +26,27 @@ pub type ByteStream = Box> + Send + Syn pub struct HttpSender { inner: hyper::Client, version: HeaderValue, + namespace: Option, } impl HttpSender { - pub fn new(connector: ConnectorService, version: Option<&str>) -> Self { + pub fn new( + connector: ConnectorService, + version: Option<&str>, + namespace: Option<&str>, + ) -> Self { let ver = version.unwrap_or(env!("CARGO_PKG_VERSION")); let version = HeaderValue::try_from(format!("libsql-remote-{ver}")).unwrap(); + let namespace = namespace.map(|v| HeaderValue::try_from(v).unwrap()); let inner = hyper::Client::builder().build(connector); - Self { inner, version } + Self { + inner, + version, + namespace, + } } async fn send( @@ -45,9 +55,15 @@ impl HttpSender { auth: Arc, body: String, ) -> Result> { - let req = hyper::Request::post(url.as_ref()) + let mut req_builder = hyper::Request::post(url.as_ref()) .header(AUTHORIZATION, auth.as_ref()) - .header("x-libsql-client-version", self.version.clone()) + .header("x-libsql-client-version", self.version.clone()); + + if let Some(namespace) = self.namespace { + req_builder = req_builder.header("x-namespace", namespace); + } + + let req = req_builder .body(hyper::Body::from(body)) .map_err(|err| HranaError::Http(format!("{:?}", err)))?; @@ -109,8 +125,9 @@ impl HttpConnection { token: impl Into, connector: ConnectorService, version: Option<&str>, + namespace: Option<&str>, ) -> Self { - let inner = HttpSender::new(connector, version); + let inner = HttpSender::new(connector, version, namespace); Self::new(url.into(), token.into(), inner) } } diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index 7391870f7a..7cfcf330fe 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -266,6 +266,7 @@ impl Database { flags: OpenFlags, encryption_config: Option, http_request_callback: Option, + namespace: Option, ) -> Result { use std::path::PathBuf; @@ -284,7 +285,7 @@ impl Database { auth_token, version.as_deref(), http_request_callback, - None, + namespace, ) .map_err(|e| crate::Error::Replication(e.into()))?;