Skip to content

Commit 96e57c2

Browse files
Add support for namespace to the remote connection builder
1 parent 5d57c82 commit 96e57c2

File tree

6 files changed

+155
-10
lines changed

6 files changed

+155
-10
lines changed

libsql-server/src/http/user/db_factory.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ pub fn namespace_from_headers(
5050

5151
if let Some(from_metadata) = headers.get(NAMESPACE_METADATA_KEY) {
5252
try_namespace_from_metadata(from_metadata)
53+
} else if let Some(from_ns_header) = headers.get("x-namespace") {
54+
try_namespace_from_header(from_ns_header)
5355
} else if let Some(from_host) = headers.get("host") {
5456
try_namespace_from_host(from_host, disable_default_namespace)
5557
} else if !disable_default_namespace {
@@ -59,6 +61,11 @@ pub fn namespace_from_headers(
5961
}
6062
}
6163

64+
fn try_namespace_from_header(header: &axum::http::HeaderValue) -> Result<NamespaceName, Error> {
65+
NamespaceName::from_bytes(header.as_bytes().to_vec().into())
66+
.map_err(|_| Error::InvalidNamespace)
67+
}
68+
6269
fn try_namespace_from_host(
6370
from_host: &axum::http::HeaderValue,
6471
disable_default_namespace: bool,

libsql-server/tests/embedded_replica/mod.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,3 +1696,103 @@ fn schema_db() {
16961696

16971697
sim.run().unwrap();
16981698
}
1699+
1700+
#[test]
1701+
fn remote_namespace_header_support() {
1702+
let tmp_host = tempdir().unwrap();
1703+
let tmp_host_path = tmp_host.path().to_owned();
1704+
1705+
let mut sim = Builder::new()
1706+
.simulation_duration(Duration::from_secs(1000))
1707+
.build();
1708+
1709+
make_primary(&mut sim, tmp_host_path.clone());
1710+
1711+
sim.client("client", async move {
1712+
let client = Client::new();
1713+
1714+
client
1715+
.post("http://primary:9090/v1/namespaces/foo/create", json!({}))
1716+
.await?;
1717+
1718+
let db_url = "http://primary:8080";
1719+
1720+
let remote = libsql::Builder::new_remote(db_url.to_string(), String::new())
1721+
.namespace("foo")
1722+
.connector(TurmoilConnector)
1723+
.build()
1724+
.await
1725+
.unwrap();
1726+
1727+
let conn = remote.connect().unwrap();
1728+
1729+
conn.execute("CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY)", ())
1730+
.await?;
1731+
1732+
conn.execute("INSERT into user(id) values (1);", ()).await?;
1733+
1734+
Ok(())
1735+
});
1736+
1737+
sim.run().unwrap();
1738+
}
1739+
1740+
#[test]
1741+
fn remote_replica_namespace_header_support() {
1742+
let tmp_host = tempdir().unwrap();
1743+
let tmp_host_path = tmp_host.path().to_owned();
1744+
1745+
let tmp_embedded = tempdir().unwrap();
1746+
let tmp_embedded_path = tmp_embedded.path().to_owned();
1747+
1748+
let mut sim = Builder::new()
1749+
.simulation_duration(Duration::from_secs(1000))
1750+
.build();
1751+
1752+
make_primary(&mut sim, tmp_host_path.clone());
1753+
1754+
sim.client("client", async move {
1755+
let client = Client::new();
1756+
1757+
client
1758+
.post("http://primary:9090/v1/namespaces/foo/create", json!({}))
1759+
.await?;
1760+
1761+
let db_url = "http://primary:8080";
1762+
1763+
let remote = libsql::Builder::new_remote(db_url.to_string(), String::new())
1764+
.namespace("foo")
1765+
.connector(TurmoilConnector)
1766+
.build()
1767+
.await
1768+
.unwrap();
1769+
1770+
let conn = remote.connect().unwrap();
1771+
1772+
conn.execute("CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY)", ())
1773+
.await?;
1774+
1775+
conn.execute("INSERT into user(id) values (1);", ()).await?;
1776+
1777+
let path = tmp_embedded_path.join("embedded");
1778+
1779+
let remote_replica = libsql::Builder::new_remote_replica(
1780+
path.to_str().unwrap(),
1781+
db_url.to_string(),
1782+
String::new(),
1783+
)
1784+
.namespace("foo")
1785+
.connector(TurmoilConnector)
1786+
.build()
1787+
.await
1788+
.unwrap();
1789+
1790+
let rep = remote_replica.sync().await.unwrap();
1791+
assert_eq!(rep.frame_no(), Some(2));
1792+
assert_eq!(rep.frames_synced(), 3);
1793+
1794+
Ok(())
1795+
});
1796+
1797+
sim.run().unwrap();
1798+
}

libsql/src/database.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ enum DbType {
107107
auth_token: String,
108108
connector: crate::util::ConnectorService,
109109
version: Option<String>,
110+
namespace: Option<String>,
110111
},
111112
}
112113

@@ -250,6 +251,7 @@ cfg_replication! {
250251
OpenFlags::default(),
251252
encryption_config.clone(),
252253
None,
254+
None,
253255
).await?;
254256

255257
Ok(Database {
@@ -541,6 +543,7 @@ cfg_remote! {
541543
auth_token: auth_token.into(),
542544
connector: crate::util::ConnectorService::new(svc),
543545
version,
546+
namespace: None,
544547
},
545548
max_write_replication_index: Default::default(),
546549
})
@@ -704,6 +707,7 @@ impl Database {
704707
auth_token.clone(),
705708
connector.clone(),
706709
None,
710+
None,
707711
),
708712
read_your_writes: *read_your_writes,
709713
context: db.sync_ctx.clone().unwrap(),
@@ -724,13 +728,15 @@ impl Database {
724728
auth_token,
725729
connector,
726730
version,
731+
namespace,
727732
} => {
728733
let conn = std::sync::Arc::new(
729734
crate::hrana::connection::HttpConnection::new_with_connector(
730735
url,
731736
auth_token,
732737
connector.clone(),
733738
version.as_ref().map(|s| s.as_str()),
739+
namespace.as_ref().map(|s| s.as_str()),
734740
),
735741
);
736742

libsql/src/database/builder.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ impl Builder<()> {
5959
auth_token,
6060
connector: None,
6161
version: None,
62+
namespace: None,
6263
},
6364
encryption_config: None,
6465
read_your_writes: true,
6566
sync_interval: None,
6667
http_request_callback: None,
67-
namespace: None,
6868
skip_safety_assert: false,
6969
#[cfg(feature = "sync")]
7070
sync_protocol: Default::default(),
@@ -102,6 +102,7 @@ impl Builder<()> {
102102
auth_token,
103103
connector: None,
104104
version: None,
105+
namespace: None,
105106
},
106107
connector: None,
107108
read_your_writes: true,
@@ -122,6 +123,7 @@ impl Builder<()> {
122123
auth_token,
123124
connector: None,
124125
version: None,
126+
namespace: None,
125127
},
126128
}
127129
}
@@ -135,6 +137,7 @@ cfg_replication_or_remote_or_sync! {
135137
auth_token: String,
136138
connector: Option<crate::util::ConnectorService>,
137139
version: Option<String>,
140+
namespace: Option<String>,
138141
}
139142
}
140143

@@ -223,7 +226,6 @@ cfg_replication! {
223226
read_your_writes: bool,
224227
sync_interval: Option<std::time::Duration>,
225228
http_request_callback: Option<crate::util::HttpRequestCallback>,
226-
namespace: Option<String>,
227229
skip_safety_assert: bool,
228230
#[cfg(feature = "sync")]
229231
sync_protocol: super::SyncProtocol,
@@ -300,7 +302,7 @@ cfg_replication! {
300302
/// Set the namespace that will be communicated to remote replica in the http header.
301303
pub fn namespace(mut self, namespace: impl Into<String>) -> Builder<RemoteReplica>
302304
{
303-
self.inner.namespace = Some(namespace.into());
305+
self.inner.remote.namespace = Some(namespace.into());
304306
self
305307
}
306308

@@ -334,12 +336,12 @@ cfg_replication! {
334336
auth_token,
335337
connector,
336338
version,
339+
namespace,
337340
},
338341
encryption_config,
339342
read_your_writes,
340343
sync_interval,
341344
http_request_callback,
342-
namespace,
343345
skip_safety_assert,
344346
#[cfg(feature = "sync")]
345347
sync_protocol,
@@ -500,6 +502,7 @@ cfg_replication! {
500502
auth_token,
501503
connector,
502504
version,
505+
namespace,
503506
}) = remote
504507
{
505508
let connector = if let Some(connector) = connector {
@@ -524,6 +527,7 @@ cfg_replication! {
524527
flags,
525528
encryption_config.clone(),
526529
http_request_callback,
530+
namespace,
527531
)
528532
.await?
529533
} else {
@@ -606,6 +610,7 @@ cfg_sync! {
606610
auth_token,
607611
connector: _,
608612
version: _,
613+
namespace: _,
609614
},
610615
connector,
611616
remote_writes,
@@ -730,13 +735,21 @@ cfg_remote! {
730735
self
731736
}
732737

738+
/// Set the namespace that will be communicated to the remote in the http header.
739+
pub fn namespace(mut self, namespace: impl Into<String>) -> Builder<Remote>
740+
{
741+
self.inner.namespace = Some(namespace.into());
742+
self
743+
}
744+
733745
/// Build the remote database client.
734746
pub async fn build(self) -> Result<Database> {
735747
let Remote {
736748
url,
737749
auth_token,
738750
connector,
739751
version,
752+
namespace,
740753
} = self.inner;
741754

742755
let connector = if let Some(connector) = connector {
@@ -758,6 +771,7 @@ cfg_remote! {
758771
auth_token,
759772
connector,
760773
version,
774+
namespace,
761775
},
762776
max_write_replication_index: Default::default(),
763777
})

libsql/src/hrana/hyper.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,27 @@ pub type ByteStream = Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Syn
2626
pub struct HttpSender {
2727
inner: hyper::Client<ConnectorService, hyper::Body>,
2828
version: HeaderValue,
29+
namespace: Option<HeaderValue>,
2930
}
3031

3132
impl HttpSender {
32-
pub fn new(connector: ConnectorService, version: Option<&str>) -> Self {
33+
pub fn new(
34+
connector: ConnectorService,
35+
version: Option<&str>,
36+
namespace: Option<&str>,
37+
) -> Self {
3338
let ver = version.unwrap_or(env!("CARGO_PKG_VERSION"));
3439

3540
let version = HeaderValue::try_from(format!("libsql-remote-{ver}")).unwrap();
41+
let namespace = namespace.map(|v| HeaderValue::try_from(v).unwrap());
3642

3743
let inner = hyper::Client::builder().build(connector);
3844

39-
Self { inner, version }
45+
Self {
46+
inner,
47+
version,
48+
namespace,
49+
}
4050
}
4151

4252
async fn send(
@@ -45,9 +55,15 @@ impl HttpSender {
4555
auth: Arc<str>,
4656
body: String,
4757
) -> Result<super::HttpBody<ByteStream>> {
48-
let req = hyper::Request::post(url.as_ref())
58+
let mut req_builder = hyper::Request::post(url.as_ref())
4959
.header(AUTHORIZATION, auth.as_ref())
50-
.header("x-libsql-client-version", self.version.clone())
60+
.header("x-libsql-client-version", self.version.clone());
61+
62+
if let Some(namespace) = self.namespace {
63+
req_builder = req_builder.header("x-namespace", namespace);
64+
}
65+
66+
let req = req_builder
5167
.body(hyper::Body::from(body))
5268
.map_err(|err| HranaError::Http(format!("{:?}", err)))?;
5369

@@ -109,8 +125,9 @@ impl HttpConnection<HttpSender> {
109125
token: impl Into<String>,
110126
connector: ConnectorService,
111127
version: Option<&str>,
128+
namespace: Option<&str>,
112129
) -> Self {
113-
let inner = HttpSender::new(connector, version);
130+
let inner = HttpSender::new(connector, version, namespace);
114131
Self::new(url.into(), token.into(), inner)
115132
}
116133
}

libsql/src/local/database.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ impl Database {
266266
flags: OpenFlags,
267267
encryption_config: Option<EncryptionConfig>,
268268
http_request_callback: Option<crate::util::HttpRequestCallback>,
269+
namespace: Option<String>,
269270
) -> Result<Database> {
270271
use std::path::PathBuf;
271272

@@ -284,7 +285,7 @@ impl Database {
284285
auth_token,
285286
version.as_deref(),
286287
http_request_callback,
287-
None,
288+
namespace,
288289
)
289290
.map_err(|e| crate::Error::Replication(e.into()))?;
290291

0 commit comments

Comments
 (0)