diff --git a/libsql/src/database.rs b/libsql/src/database.rs index da0cc7dfac..57212d4fb7 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -99,6 +99,7 @@ enum DbType { url: String, auth_token: String, connector: crate::util::ConnectorService, + remote_encryption: Option, }, #[cfg(feature = "remote")] Remote { @@ -212,7 +213,7 @@ cfg_replication! { endpoint, auth_token, https, - encryption_config + encryption_config, ).await } @@ -521,7 +522,7 @@ cfg_remote! { url: impl Into, auth_token: impl Into, connector: C, - version: Option + version: Option, ) -> Result where C: tower::Service + Send + Clone + Sync + 'static, @@ -673,6 +674,7 @@ impl Database { url, auth_token, connector, + remote_encryption, } => { use crate::{ hrana::{connection::HttpConnection, hyper::HttpSender}, @@ -702,7 +704,7 @@ impl Database { remote: HttpConnection::new( url.clone(), auth_token.clone(), - HttpSender::new(connector.clone(), None), + HttpSender::new(connector.clone(), None, remote_encryption.clone()), ), read_your_writes: *read_your_writes, context: db.sync_ctx.clone().unwrap(), @@ -730,6 +732,7 @@ impl Database { auth_token, connector.clone(), version.as_ref().map(|s| s.as_str()), + None, ), ); diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index dce7a14343..7ce371974b 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -6,6 +6,8 @@ use crate::{Database, Result}; use super::DbType; +pub use crate::sync::EncryptionContext; + /// A builder for [`Database`]. This struct can be used to build /// all variants of [`Database`]. These variants include: /// @@ -51,6 +53,8 @@ impl Builder<()> { path: impl AsRef, url: String, auth_token: String, + #[cfg(feature = "sync")] + remote_encryption: Option, ) -> Builder { Builder { inner: RemoteReplica { @@ -69,6 +73,8 @@ impl Builder<()> { skip_safety_assert: false, #[cfg(feature = "sync")] sync_protocol: Default::default(), + #[cfg(feature = "sync")] + remote_encryption, }, } } @@ -93,6 +99,7 @@ impl Builder<()> { path: impl AsRef, url: String, auth_token: String, + remote_encryption: Option, ) -> Builder { Builder { inner: SyncedDatabase { @@ -108,6 +115,7 @@ impl Builder<()> { read_your_writes: true, remote_writes: false, push_batch_size: 0, + remote_encryption, }, } } @@ -227,6 +235,8 @@ cfg_replication! { skip_safety_assert: bool, #[cfg(feature = "sync")] sync_protocol: super::SyncProtocol, + #[cfg(feature = "sync")] + remote_encryption: Option, } /// Local replica configuration type in [`Builder`]. @@ -343,6 +353,8 @@ cfg_replication! { skip_safety_assert, #[cfg(feature = "sync")] sync_protocol, + #[cfg(feature = "sync")] + remote_encryption, } = self.inner; let connector = if let Some(connector) = connector { @@ -401,7 +413,7 @@ cfg_replication! { if res.status().is_success() { tracing::trace!("Using sync protocol v2 for {}", url); - return Builder::new_synced_database(path, url, auth_token) + return Builder::new_synced_database(path, url, auth_token, remote_encryption) .connector(connector) .remote_writes(true) .read_your_writes(read_your_writes) @@ -542,6 +554,7 @@ cfg_sync! { remote_writes: bool, read_your_writes: bool, push_batch_size: u32, + remote_encryption: Option, } impl Builder { @@ -594,6 +607,7 @@ cfg_sync! { remote_writes, read_your_writes, push_batch_size, + remote_encryption, } = self.inner; let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned(); @@ -617,6 +631,7 @@ cfg_sync! { flags, url.clone(), auth_token.clone(), + remote_encryption.clone(), ) .await?; @@ -632,6 +647,8 @@ cfg_sync! { url, auth_token, connector, + #[cfg(feature = "sync")] + remote_encryption, }, max_write_replication_index: Default::default(), }) diff --git a/libsql/src/hrana/hyper.rs b/libsql/src/hrana/hyper.rs index 9ed0880af8..e586092568 100644 --- a/libsql/src/hrana/hyper.rs +++ b/libsql/src/hrana/hyper.rs @@ -26,17 +26,19 @@ pub type ByteStream = Box> + Send + Syn pub struct HttpSender { inner: hyper::Client, version: HeaderValue, + #[cfg(feature = "sync")] + remote_encryption: Option, } impl HttpSender { - pub fn new(connector: ConnectorService, version: Option<&str>) -> Self { + pub fn new(connector: ConnectorService, version: Option<&str>, #[cfg(feature = "sync")] remote_encryption: Option) -> Self { let ver = version.unwrap_or(env!("CARGO_PKG_VERSION")); let version = HeaderValue::try_from(format!("libsql-remote-{ver}")).unwrap(); let inner = hyper::Client::builder().build(connector); - Self { inner, version } + Self { inner, version, #[cfg(feature = "sync")] remote_encryption } } async fn send( @@ -45,12 +47,24 @@ impl HttpSender { auth: Arc, body: String, ) -> Result> { - let req = hyper::Request::post(url.as_ref()) + let mut req = hyper::Request::post(url.as_ref()) .header(AUTHORIZATION, auth.as_ref()) - .header("x-libsql-client-version", self.version.clone()) - .body(hyper::Body::from(body)) + .header("x-libsql-client-version", self.version.clone()); + + if let Some(remote_encryption) = &self.remote_encryption { + if remote_encryption.decrypt_pull { + req = req.header("x-turso-decrypt-response", "true"); + } + if remote_encryption.push_is_encrypted { + req = req.header("x-turso-encrypted-request", "true"); + } + req = req.header("x-turso-encryption-key", remote_encryption.key_16_bytes_base64_encoded.as_str()); + } + + let req = req.body(hyper::Body::from(body)) .map_err(|err| HranaError::Http(format!("{:?}", err)))?; + let resp = self.inner.request(req).await.map_err(HranaError::from)?; let status = resp.status(); @@ -109,8 +123,10 @@ impl HttpConnection { token: impl Into, connector: ConnectorService, version: Option<&str>, + #[cfg(feature = "sync")] + remote_encryption: Option, ) -> Self { - let inner = HttpSender::new(connector, version); + let inner = HttpSender::new(connector, version, #[cfg(feature = "sync")] remote_encryption); Self::new(url.into(), token.into(), inner) } } diff --git a/libsql/src/lib.rs b/libsql/src/lib.rs index 15f98d8869..fe47e56d16 100644 --- a/libsql/src/lib.rs +++ b/libsql/src/lib.rs @@ -132,6 +132,7 @@ pub mod params; cfg_sync! { mod sync; pub use database::SyncProtocol; + pub use sync::EncryptionContext; } cfg_replication! { diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index 7391870f7a..13dac3605c 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -212,6 +212,7 @@ impl Database { flags: OpenFlags, endpoint: String, auth_token: String, + remote_encryption: Option, ) -> Result { let db_path = db_path.into(); let endpoint = if endpoint.starts_with("libsql:") { @@ -222,7 +223,7 @@ impl Database { let mut db = Database::open(&db_path, flags)?; let sync_ctx = - SyncContext::new(connector, db_path.into(), endpoint, Some(auth_token)).await?; + SyncContext::new(connector, db_path.into(), endpoint, Some(auth_token), remote_encryption).await?; db.sync_ctx = Some(Arc::new(Mutex::new(sync_ctx))); Ok(db) diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs index 871a117899..e94b34c8ed 100644 --- a/libsql/src/sync.rs +++ b/libsql/src/sync.rs @@ -104,6 +104,16 @@ struct PushFramesResult { baton: Option, } +#[derive(Debug, Clone)] +pub struct EncryptionContext { + /// The base64-encoded key for the encryption, sent on every request. + pub key_16_bytes_base64_encoded: String, + /// Whether the pushed frames are already encrypted. + pub push_is_encrypted: bool, + /// Whether to request the server to decrypt the pulled frames. + pub decrypt_pull: bool, +} + pub struct SyncContext { db_path: String, client: hyper::Client, @@ -118,6 +128,8 @@ pub struct SyncContext { /// whenever sync is called very first time, we will call the remote server /// to get the generation information and sync the db file if needed initial_server_sync: bool, + /// The encryption context for the sync. + remote_encryption: Option, } impl SyncContext { @@ -126,6 +138,7 @@ impl SyncContext { db_path: String, sync_url: String, auth_token: Option, + remote_encryption: Option, ) -> Result { let client = hyper::client::Client::builder().build::<_, hyper::Body>(connector); @@ -147,6 +160,7 @@ impl SyncContext { durable_generation: 0, durable_frame_num: 0, initial_server_sync: false, + remote_encryption, }; if let Err(e) = me.read_metadata().await { @@ -275,6 +289,16 @@ impl SyncContext { None => {} } + if let Some(remote_encryption) = &self.remote_encryption { + if remote_encryption.decrypt_pull { + req = req.header("x-turso-decrypt-response", "true"); + } + if remote_encryption.push_is_encrypted { + req = req.header("x-turso-encrypted-request", "true"); + } + req = req.header("x-turso-encryption-key", remote_encryption.key_16_bytes_base64_encoded.as_str()); + } + let req = req.body(body.clone().into()).expect("valid body"); let res = self @@ -386,6 +410,16 @@ impl SyncContext { None => {} } + if let Some(remote_encryption) = &self.remote_encryption { + if remote_encryption.decrypt_pull { + req = req.header("x-turso-decrypt-response", "true"); + } + if remote_encryption.push_is_encrypted { + req = req.header("x-turso-encrypted-request", "true"); + } + req = req.header("x-turso-encryption-key", remote_encryption.key_16_bytes_base64_encoded.as_str()); + } + let req = req.body(Body::empty()).expect("valid request"); let res = self @@ -527,6 +561,16 @@ impl SyncContext { req = req.header("Authorization", auth_token); } + if let Some(remote_encryption) = &self.remote_encryption { + if remote_encryption.decrypt_pull { + req = req.header("x-turso-decrypt-response", "true"); + } + if remote_encryption.push_is_encrypted { + req = req.header("x-turso-encrypted-request", "true"); + } + req = req.header("x-turso-encryption-key", remote_encryption.key_16_bytes_base64_encoded.as_str()); + } + let req = req.body(Body::empty()).expect("valid request"); let res = self @@ -623,6 +667,16 @@ impl SyncContext { req = req.header("Authorization", auth_token); } + if let Some(remote_encryption) = &self.remote_encryption { + if remote_encryption.decrypt_pull { + req = req.header("x-turso-decrypt-response", "true"); + } + if remote_encryption.push_is_encrypted { + req = req.header("x-turso-encrypted-request", "true"); + } + req = req.header("x-turso-encryption-key", remote_encryption.key_16_bytes_base64_encoded.as_str()); + } + let req = req.body(Body::empty()).expect("valid request"); let res = self