diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml index c02bbe229f..9df2c3e537 100644 --- a/libsql/Cargo.toml +++ b/libsql/Cargo.toml @@ -102,6 +102,7 @@ sync = [ "stream", "remote", "replication", + "dep:base64", "dep:tower", "dep:hyper", "dep:http", @@ -131,6 +132,7 @@ hrana = [ serde = ["dep:serde"] remote = [ "hrana", + "dep:base64", "dep:tower", "dep:hyper", "dep:hyper", diff --git a/libsql/examples/encryption_sync.rs b/libsql/examples/encryption_sync.rs new file mode 100644 index 0000000000..15a3b22eaf --- /dev/null +++ b/libsql/examples/encryption_sync.rs @@ -0,0 +1,83 @@ +// Example of using offline writes with encryption + +use libsql::{params, Builder}; +use libsql::{EncryptionContext, EncryptionKey}; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + // The local database path where the data will be stored. + let db_path = std::env::var("LIBSQL_DB_PATH").unwrap(); + + // The remote sync URL to use. + let sync_url = std::env::var("LIBSQL_SYNC_URL").unwrap(); + + // The authentication token for the remote sync server. + let auth_token = std::env::var("LIBSQL_AUTH_TOKEN").unwrap_or("".to_string()); + + // Optional encryption key for the database, if provided. + let encryption = if let Ok(key) = std::env::var("LIBSQL_ENCRYPTION_KEY") { + Some(EncryptionContext { + key: EncryptionKey::Base64Encoded(key), + }) + } else { + None + }; + + let db_builder = + Builder::new_synced_database(db_path, sync_url, auth_token).remote_encryption(encryption); + + let db = match db_builder.build().await { + Ok(db) => db, + Err(error) => { + eprintln!("Error connecting to remote sync server: {}", error); + return; + } + }; + + let conn = db.connect().unwrap(); + + print!("Syncing with remote database..."); + db.sync().await.unwrap(); + println!(" done"); + + let mut results = conn.query("SELECT count(*) FROM dummy", ()).await.unwrap(); + let count: u32 = results.next().await.unwrap().unwrap().get(0).unwrap(); + println!("dummy table has {} entries", count); + + conn.execute( + r#" + CREATE TABLE IF NOT EXISTS guest_book_entries ( + text TEXT + )"#, + (), + ) + .await + .unwrap(); + + let mut input = String::new(); + println!("Please write your entry to the guestbook:"); + match std::io::stdin().read_line(&mut input) { + Ok(_) => { + println!("You entered: {}", input); + let params = params![input.as_str()]; + conn.execute("INSERT INTO guest_book_entries (text) VALUES (?)", params) + .await + .unwrap(); + } + Err(error) => { + eprintln!("Error reading input: {}", error); + } + } + db.sync().await.unwrap(); + let mut results = conn + .query("SELECT * FROM guest_book_entries", ()) + .await + .unwrap(); + println!("Guest book entries:"); + while let Some(row) = results.next().await.unwrap() { + let text: String = row.get(0).unwrap(); + println!(" {}", text); + } +} diff --git a/libsql/src/database.rs b/libsql/src/database.rs index 7069799caa..322913eefc 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -8,6 +8,8 @@ pub use builder::Builder; pub use libsql_sys::{Cipher, EncryptionConfig}; use crate::{Connection, Result}; +#[cfg(any(feature = "remote", feature = "sync"))] +use base64::{engine::general_purpose, Engine}; use std::fmt; use std::sync::atomic::AtomicU64; @@ -100,6 +102,7 @@ enum DbType { auth_token: String, connector: crate::util::ConnectorService, _bg_abort: Option>, + remote_encryption: Option, }, #[cfg(feature = "remote")] Remote { @@ -108,6 +111,7 @@ enum DbType { connector: crate::util::ConnectorService, version: Option, namespace: Option, + remote_encryption: Option, }, } @@ -214,7 +218,7 @@ cfg_replication! { endpoint, auth_token, https, - encryption_config + encryption_config, ).await } @@ -524,7 +528,7 @@ cfg_remote! { url: impl Into, auth_token: impl Into, connector: C, - version: Option + version: Option, ) -> Result where C: tower::Service + Send + Clone + Sync + 'static, @@ -544,6 +548,7 @@ cfg_remote! { connector: crate::util::ConnectorService::new(svc), version, namespace: None, + remote_encryption: None }, max_write_replication_index: Default::default(), }) @@ -677,6 +682,7 @@ impl Database { url, auth_token, connector, + remote_encryption, .. } => { use crate::{ @@ -708,6 +714,7 @@ impl Database { connector.clone(), None, None, + remote_encryption.clone(), ), read_your_writes: *read_your_writes, context: db.sync_ctx.clone().unwrap(), @@ -730,6 +737,7 @@ impl Database { connector, version, namespace, + remote_encryption, } => { let conn = std::sync::Arc::new( crate::hrana::connection::HttpConnection::new_with_connector( @@ -738,6 +746,7 @@ impl Database { connector.clone(), version.as_ref().map(|s| s.as_str()), namespace.as_ref().map(|s| s.as_str()), + remote_encryption.clone(), ), ); @@ -781,3 +790,29 @@ impl std::fmt::Debug for Database { f.debug_struct("Database").finish() } } + +#[cfg(any(feature = "remote", feature = "sync"))] +#[derive(Debug, Clone)] +pub enum EncryptionKey { + /// The key is a base64-encoded string. + Base64Encoded(String), + /// The key is a byte array. + Bytes(Vec), +} + +#[cfg(any(feature = "remote", feature = "sync"))] +impl EncryptionKey { + pub fn as_string(&self) -> String { + match self { + EncryptionKey::Base64Encoded(s) => s.clone(), + EncryptionKey::Bytes(b) => general_purpose::STANDARD.encode(b), + } + } +} + +#[cfg(any(feature = "remote", feature = "sync"))] +#[derive(Debug, Clone)] +pub struct EncryptionContext { + /// The base64-encoded key for the encryption, sent on every request. + pub key: EncryptionKey, +} diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index ef70479430..623ec24a3e 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -5,6 +5,9 @@ cfg_core! { use super::DbType; use crate::{Database, Result}; +#[cfg(any(feature = "remote", feature = "sync"))] +pub use crate::database::EncryptionContext; + /// A builder for [`Database`]. This struct can be used to build /// all variants of [`Database`]. These variants include: /// @@ -60,6 +63,8 @@ impl Builder<()> { connector: None, version: None, namespace: None, + #[cfg(any(feature = "remote", feature = "sync"))] + remote_encryption: None, }, encryption_config: None, read_your_writes: true, @@ -68,6 +73,8 @@ impl Builder<()> { skip_safety_assert: false, #[cfg(feature = "sync")] sync_protocol: Default::default(), + #[cfg(feature = "sync")] + remote_encryption: None }, } } @@ -103,12 +110,14 @@ impl Builder<()> { connector: None, version: None, namespace: None, + remote_encryption: None, }, connector: None, read_your_writes: true, remote_writes: false, push_batch_size: 0, sync_interval: None, + remote_encryption: None, }, } } @@ -124,6 +133,7 @@ impl Builder<()> { connector: None, version: None, namespace: None, + remote_encryption: None, }, } } @@ -138,6 +148,8 @@ cfg_replication_or_remote_or_sync! { connector: Option, version: Option, namespace: Option, + #[cfg(any(feature = "remote", feature = "sync"))] + remote_encryption: Option, } } @@ -229,6 +241,8 @@ cfg_replication! { skip_safety_assert: bool, #[cfg(feature = "sync")] sync_protocol: super::SyncProtocol, + #[cfg(feature = "sync")] + remote_encryption: Option, } /// Local replica configuration type in [`Builder`]. @@ -290,6 +304,13 @@ cfg_replication! { self } + /// Set the encryption context if the database is encrypted in remote server. + #[cfg(feature = "sync")] + pub fn remote_encryption(mut self, encryption_context: Option) -> Builder { + self.inner.remote_encryption = encryption_context; + self + } + pub fn http_request_callback(mut self, f: F) -> Builder where F: Fn(&mut http::Request<()>) + Send + Sync + 'static @@ -337,6 +358,7 @@ cfg_replication! { connector, version, namespace, + .. }, encryption_config, read_your_writes, @@ -345,6 +367,8 @@ cfg_replication! { skip_safety_assert, #[cfg(feature = "sync")] sync_protocol, + #[cfg(feature = "sync")] + remote_encryption, } = self.inner; let connector = if let Some(connector) = connector { @@ -406,7 +430,8 @@ cfg_replication! { let builder = Builder::new_synced_database(path, url, auth_token) .connector(connector) .remote_writes(true) - .read_your_writes(read_your_writes); + .read_your_writes(read_your_writes) + .remote_encryption(remote_encryption); let builder = if let Some(sync_interval) = sync_interval { builder.sync_interval(sync_interval) @@ -463,7 +488,10 @@ cfg_replication! { Ok(Database { - db_type: DbType::Sync { db, encryption_config }, + db_type: DbType::Sync { + db, + encryption_config, + }, max_write_replication_index: Default::default(), }) } @@ -503,6 +531,7 @@ cfg_replication! { connector, version, namespace, + .. }) = remote { let connector = if let Some(connector) = connector { @@ -553,6 +582,7 @@ cfg_sync! { read_your_writes: bool, push_batch_size: u32, sync_interval: Option, + remote_encryption: Option, } impl Builder { @@ -585,6 +615,12 @@ cfg_sync! { self } + /// Set the encryption context if the database is encrypted in remote server. + pub fn remote_encryption(mut self, encryption_context: Option) -> Builder { + self.inner.remote_encryption = encryption_context; + self + } + /// Provide a custom http connector that will be used to create http connections. pub fn connector(mut self, connector: C) -> Builder where @@ -611,12 +647,14 @@ cfg_sync! { connector: _, version: _, namespace: _, + .. }, connector, remote_writes, read_your_writes, push_batch_size, sync_interval, + remote_encryption, } = self.inner; let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned(); @@ -640,6 +678,7 @@ cfg_sync! { flags, url.clone(), auth_token.clone(), + remote_encryption.clone(), ) .await?; @@ -708,6 +747,8 @@ cfg_sync! { auth_token, connector, _bg_abort: bg_abort, + #[cfg(feature = "sync")] + remote_encryption, }, max_write_replication_index: Default::default(), }) @@ -742,6 +783,12 @@ cfg_remote! { self } + /// Set the encryption context if the database is encrypted in remote server. + pub fn remote_encryption(mut self, encryption_context: Option) -> Builder { + self.inner.remote_encryption = encryption_context; + self + } + /// Build the remote database client. pub async fn build(self) -> Result { let Remote { @@ -750,6 +797,7 @@ cfg_remote! { connector, version, namespace, + remote_encryption, } = self.inner; let connector = if let Some(connector) = connector { @@ -772,6 +820,7 @@ cfg_remote! { connector, version, namespace, + remote_encryption }, max_write_replication_index: Default::default(), }) diff --git a/libsql/src/hrana/hyper.rs b/libsql/src/hrana/hyper.rs index b78838d11a..d32341796b 100644 --- a/libsql/src/hrana/hyper.rs +++ b/libsql/src/hrana/hyper.rs @@ -27,6 +27,8 @@ pub struct HttpSender { inner: hyper::Client, version: HeaderValue, namespace: Option, + #[cfg(any(feature = "remote", feature = "sync"))] + remote_encryption: Option, } impl HttpSender { @@ -34,6 +36,9 @@ impl HttpSender { connector: ConnectorService, version: Option<&str>, namespace: Option<&str>, + #[cfg(any(feature = "remote", feature = "sync"))] remote_encryption: Option< + crate::database::EncryptionContext, + >, ) -> Self { let ver = version.unwrap_or(env!("CARGO_PKG_VERSION")); @@ -41,11 +46,12 @@ impl HttpSender { let namespace = namespace.map(|v| HeaderValue::try_from(v).unwrap()); let inner = hyper::Client::builder().build(connector); - Self { inner, version, namespace, + #[cfg(any(feature = "remote", feature = "sync"))] + remote_encryption, } } @@ -63,6 +69,12 @@ impl HttpSender { req_builder = req_builder.header("x-namespace", namespace); } + #[cfg(any(feature = "remote", feature = "sync"))] + if let Some(remote_encryption) = &self.remote_encryption { + req_builder = + req_builder.header("x-turso-encryption-key", remote_encryption.key.as_string()); + } + let req = req_builder .body(hyper::Body::from(body)) .map_err(|err| HranaError::Http(format!("{:?}", err)))?; @@ -126,8 +138,17 @@ impl HttpConnection { connector: ConnectorService, version: Option<&str>, namespace: Option<&str>, + #[cfg(any(feature = "remote", feature = "sync"))] remote_encryption: Option< + crate::database::EncryptionContext, + >, ) -> Self { - let inner = HttpSender::new(connector, version, namespace); + let inner = HttpSender::new( + connector, + version, + namespace, + #[cfg(any(feature = "remote", feature = "sync"))] + remote_encryption, + ); Self::new(url.into(), token.into(), inner) } } diff --git a/libsql/src/lib.rs b/libsql/src/lib.rs index 15f98d8869..a42b0a4940 100644 --- a/libsql/src/lib.rs +++ b/libsql/src/lib.rs @@ -132,6 +132,8 @@ pub mod params; cfg_sync! { mod sync; pub use database::SyncProtocol; + pub use database::EncryptionContext; + pub use database::EncryptionKey; } cfg_replication! { diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index 7cfcf330fe..5d5eda0b06 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -212,6 +212,7 @@ impl Database { flags: OpenFlags, endpoint: String, auth_token: String, + remote_encryption: Option, ) -> Result { let db_path = db_path.into(); let endpoint = if endpoint.starts_with("libsql:") { @@ -221,8 +222,14 @@ impl Database { }; let mut db = Database::open(&db_path, flags)?; - let sync_ctx = - SyncContext::new(connector, db_path.into(), endpoint, Some(auth_token)).await?; + let sync_ctx = SyncContext::new( + connector, + db_path.into(), + endpoint, + Some(auth_token), + remote_encryption, + ) + .await?; db.sync_ctx = Some(Arc::new(Mutex::new(sync_ctx))); Ok(db) diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs index 9a74b118b0..1071646ea2 100644 --- a/libsql/src/sync.rs +++ b/libsql/src/sync.rs @@ -1,11 +1,11 @@ use crate::{local::Connection, util::ConnectorService, Error, Result}; -use std::path::Path; - +use crate::database::EncryptionContext; use bytes::Bytes; use chrono::Utc; use http::{HeaderValue, StatusCode}; use hyper::Body; +use std::path::Path; use tokio::io::AsyncWriteExt as _; use uuid::Uuid; @@ -133,6 +133,8 @@ pub struct SyncContext { /// whenever sync is called very first time, we will call the remote server /// to get the generation information and sync the db file if needed initial_server_sync: bool, + /// The encryption context for the sync. + remote_encryption: Option, } impl SyncContext { @@ -141,6 +143,7 @@ impl SyncContext { db_path: String, sync_url: String, auth_token: Option, + remote_encryption: Option, ) -> Result { let client = hyper::client::Client::builder().build::<_, hyper::Body>(connector); @@ -163,6 +166,7 @@ impl SyncContext { durable_generation: 0, durable_frame_num: 0, initial_server_sync: false, + remote_encryption, }; if let Err(e) = me.read_metadata().await { @@ -303,6 +307,10 @@ impl SyncContext { None => {} } + if let Some(remote_encryption) = &self.remote_encryption { + req = req.header("x-turso-encryption-key", remote_encryption.key.as_string()); + } + let req = req.body(body.clone().into()).expect("valid body"); let res = self @@ -414,6 +422,10 @@ impl SyncContext { None => {} } + if let Some(remote_encryption) = &self.remote_encryption { + req = req.header("x-turso-encryption-key", remote_encryption.key.as_string()); + } + let req = req.body(Body::empty()).expect("valid request"); let res = self @@ -577,6 +589,10 @@ impl SyncContext { req = req.header("Authorization", auth_token); } + if let Some(remote_encryption) = &self.remote_encryption { + req = req.header("x-turso-encryption-key", remote_encryption.key.as_string()); + } + let req = req.body(Body::empty()).expect("valid request"); let res = self @@ -673,6 +689,10 @@ impl SyncContext { req = req.header("Authorization", auth_token); } + if let Some(remote_encryption) = &self.remote_encryption { + req = req.header("x-turso-encryption-key", remote_encryption.key.as_string()); + } + let req = req.body(Body::empty()).expect("valid request"); let (res, http_duration) = diff --git a/libsql/src/sync/test.rs b/libsql/src/sync/test.rs index edd6d33eee..2232aecad5 100644 --- a/libsql/src/sync/test.rs +++ b/libsql/src/sync/test.rs @@ -20,6 +20,7 @@ async fn test_sync_context_push_frame() { db_path.to_str().unwrap().to_string(), server.url(), None, + None, ) .await .unwrap(); @@ -49,6 +50,7 @@ async fn test_sync_context_with_auth() { db_path.to_str().unwrap().to_string(), server.url(), Some("test_token".to_string()), + None, ) .await .unwrap(); @@ -73,6 +75,7 @@ async fn test_sync_context_multiple_frames() { db_path.to_str().unwrap().to_string(), server.url(), None, + None, ) .await .unwrap(); @@ -102,6 +105,7 @@ async fn test_sync_context_corrupted_metadata() { db_path.to_str().unwrap().to_string(), server.url(), None, + None, ) .await .unwrap(); @@ -123,6 +127,7 @@ async fn test_sync_context_corrupted_metadata() { db_path.to_str().unwrap().to_string(), server.url(), None, + None, ) .await .unwrap(); @@ -146,6 +151,7 @@ async fn test_sync_restarts_with_lower_max_frame_no() { db_path.to_str().unwrap().to_string(), server.url(), None, + None, ) .await .unwrap(); @@ -171,6 +177,7 @@ async fn test_sync_restarts_with_lower_max_frame_no() { db_path.to_str().unwrap().to_string(), server.url(), None, + None, ) .await .unwrap(); @@ -210,6 +217,7 @@ async fn test_sync_context_retry_on_error() { db_path.to_str().unwrap().to_string(), server.url(), None, + None, ) .await .unwrap();