Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions libsql/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ enum DbType {
url: String,
auth_token: String,
connector: crate::util::ConnectorService,
remote_encryption: Option<crate::sync::EncryptionContext>,
},
#[cfg(feature = "remote")]
Remote {
Expand Down Expand Up @@ -212,7 +213,7 @@ cfg_replication! {
endpoint,
auth_token,
https,
encryption_config
encryption_config,
).await
}

Expand Down Expand Up @@ -521,7 +522,7 @@ cfg_remote! {
url: impl Into<String>,
auth_token: impl Into<String>,
connector: C,
version: Option<String>
version: Option<String>,
) -> Result<Self>
where
C: tower::Service<http::Uri> + Send + Clone + Sync + 'static,
Expand Down Expand Up @@ -673,6 +674,7 @@ impl Database {
url,
auth_token,
connector,
remote_encryption,
} => {
use crate::{
hrana::{connection::HttpConnection, hyper::HttpSender},
Expand Down Expand Up @@ -702,7 +704,7 @@ impl Database {
remote: HttpConnection::new(
url.clone(),
auth_token.clone(),
HttpSender::new(connector.clone(), None),
HttpSender::new(connector.clone(), None, remote_encryption.clone()),
),
read_your_writes: *read_your_writes,
context: db.sync_ctx.clone().unwrap(),
Expand Down Expand Up @@ -730,6 +732,7 @@ impl Database {
auth_token,
connector.clone(),
version.as_ref().map(|s| s.as_str()),
None,
),
);

Expand Down
19 changes: 18 additions & 1 deletion libsql/src/database/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use crate::{Database, Result};

use super::DbType;

pub use crate::sync::EncryptionContext;

/// A builder for [`Database`]. This struct can be used to build
/// all variants of [`Database`]. These variants include:
///
Expand Down Expand Up @@ -51,6 +53,8 @@ impl Builder<()> {
path: impl AsRef<std::path::Path>,
url: String,
auth_token: String,
#[cfg(feature = "sync")]
remote_encryption: Option<crate::sync::EncryptionContext>,
) -> Builder<RemoteReplica> {
Builder {
inner: RemoteReplica {
Expand All @@ -69,6 +73,8 @@ impl Builder<()> {
skip_safety_assert: false,
#[cfg(feature = "sync")]
sync_protocol: Default::default(),
#[cfg(feature = "sync")]
remote_encryption,
},
}
}
Expand All @@ -93,6 +99,7 @@ impl Builder<()> {
path: impl AsRef<std::path::Path>,
url: String,
auth_token: String,
remote_encryption: Option<EncryptionContext>,
) -> Builder<SyncedDatabase> {
Builder {
inner: SyncedDatabase {
Expand All @@ -108,6 +115,7 @@ impl Builder<()> {
read_your_writes: true,
remote_writes: false,
push_batch_size: 0,
remote_encryption,
},
}
}
Expand Down Expand Up @@ -227,6 +235,8 @@ cfg_replication! {
skip_safety_assert: bool,
#[cfg(feature = "sync")]
sync_protocol: super::SyncProtocol,
#[cfg(feature = "sync")]
remote_encryption: Option<crate::sync::EncryptionContext>,
}

/// Local replica configuration type in [`Builder`].
Expand Down Expand Up @@ -343,6 +353,8 @@ cfg_replication! {
skip_safety_assert,
#[cfg(feature = "sync")]
sync_protocol,
#[cfg(feature = "sync")]
remote_encryption,
} = self.inner;

let connector = if let Some(connector) = connector {
Expand Down Expand Up @@ -401,7 +413,7 @@ cfg_replication! {

if res.status().is_success() {
tracing::trace!("Using sync protocol v2 for {}", url);
return Builder::new_synced_database(path, url, auth_token)
return Builder::new_synced_database(path, url, auth_token, remote_encryption)
.connector(connector)
.remote_writes(true)
.read_your_writes(read_your_writes)
Expand Down Expand Up @@ -542,6 +554,7 @@ cfg_sync! {
remote_writes: bool,
read_your_writes: bool,
push_batch_size: u32,
remote_encryption: Option<crate::sync::EncryptionContext>,
}

impl Builder<SyncedDatabase> {
Expand Down Expand Up @@ -594,6 +607,7 @@ cfg_sync! {
remote_writes,
read_your_writes,
push_batch_size,
remote_encryption,
} = self.inner;

let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned();
Expand All @@ -617,6 +631,7 @@ cfg_sync! {
flags,
url.clone(),
auth_token.clone(),
remote_encryption.clone(),
)
.await?;

Expand All @@ -632,6 +647,8 @@ cfg_sync! {
url,
auth_token,
connector,
#[cfg(feature = "sync")]
remote_encryption,
},
max_write_replication_index: Default::default(),
})
Expand Down
28 changes: 22 additions & 6 deletions libsql/src/hrana/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,19 @@ pub type ByteStream = Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Syn
pub struct HttpSender {
inner: hyper::Client<ConnectorService, hyper::Body>,
version: HeaderValue,
#[cfg(feature = "sync")]
remote_encryption: Option<crate::sync::EncryptionContext>,
}

impl HttpSender {
pub fn new(connector: ConnectorService, version: Option<&str>) -> Self {
pub fn new(connector: ConnectorService, version: Option<&str>, #[cfg(feature = "sync")] remote_encryption: Option<crate::sync::EncryptionContext>) -> Self {
let ver = version.unwrap_or(env!("CARGO_PKG_VERSION"));

let version = HeaderValue::try_from(format!("libsql-remote-{ver}")).unwrap();

let inner = hyper::Client::builder().build(connector);

Self { inner, version }
Self { inner, version, #[cfg(feature = "sync")] remote_encryption }
}

async fn send(
Expand All @@ -45,12 +47,24 @@ impl HttpSender {
auth: Arc<str>,
body: String,
) -> Result<super::HttpBody<ByteStream>> {
let req = hyper::Request::post(url.as_ref())
let mut req = hyper::Request::post(url.as_ref())
.header(AUTHORIZATION, auth.as_ref())
.header("x-libsql-client-version", self.version.clone())
.body(hyper::Body::from(body))
.header("x-libsql-client-version", self.version.clone());

if let Some(remote_encryption) = &self.remote_encryption {
if remote_encryption.decrypt_pull {
req = req.header("x-turso-decrypt-response", "true");
}
if remote_encryption.push_is_encrypted {
req = req.header("x-turso-encrypted-request", "true");
}
req = req.header("x-turso-encryption-key", remote_encryption.key_16_bytes_base64_encoded.as_str());
}

let req = req.body(hyper::Body::from(body))
.map_err(|err| HranaError::Http(format!("{:?}", err)))?;


let resp = self.inner.request(req).await.map_err(HranaError::from)?;

let status = resp.status();
Expand Down Expand Up @@ -109,8 +123,10 @@ impl HttpConnection<HttpSender> {
token: impl Into<String>,
connector: ConnectorService,
version: Option<&str>,
#[cfg(feature = "sync")]
remote_encryption: Option<crate::sync::EncryptionContext>,
) -> Self {
let inner = HttpSender::new(connector, version);
let inner = HttpSender::new(connector, version, #[cfg(feature = "sync")] remote_encryption);
Self::new(url.into(), token.into(), inner)
}
}
Expand Down
1 change: 1 addition & 0 deletions libsql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ pub mod params;
cfg_sync! {
mod sync;
pub use database::SyncProtocol;
pub use sync::EncryptionContext;
}

cfg_replication! {
Expand Down
3 changes: 2 additions & 1 deletion libsql/src/local/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ impl Database {
flags: OpenFlags,
endpoint: String,
auth_token: String,
remote_encryption: Option<crate::sync::EncryptionContext>,
) -> Result<Database> {
let db_path = db_path.into();
let endpoint = if endpoint.starts_with("libsql:") {
Expand All @@ -222,7 +223,7 @@ impl Database {
let mut db = Database::open(&db_path, flags)?;

let sync_ctx =
SyncContext::new(connector, db_path.into(), endpoint, Some(auth_token)).await?;
SyncContext::new(connector, db_path.into(), endpoint, Some(auth_token), remote_encryption).await?;
db.sync_ctx = Some(Arc::new(Mutex::new(sync_ctx)));

Ok(db)
Expand Down
54 changes: 54 additions & 0 deletions libsql/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ struct PushFramesResult {
baton: Option<String>,
}

#[derive(Debug, Clone)]
pub struct EncryptionContext {
/// The base64-encoded key for the encryption, sent on every request.
pub key_16_bytes_base64_encoded: String,
/// Whether the pushed frames are already encrypted.
pub push_is_encrypted: bool,
/// Whether to request the server to decrypt the pulled frames.
pub decrypt_pull: bool,
}

pub struct SyncContext {
db_path: String,
client: hyper::Client<ConnectorService, Body>,
Expand All @@ -118,6 +128,8 @@ pub struct SyncContext {
/// whenever sync is called very first time, we will call the remote server
/// to get the generation information and sync the db file if needed
initial_server_sync: bool,
/// The encryption context for the sync.
remote_encryption: Option<EncryptionContext>,
}

impl SyncContext {
Expand All @@ -126,6 +138,7 @@ impl SyncContext {
db_path: String,
sync_url: String,
auth_token: Option<String>,
remote_encryption: Option<EncryptionContext>,
) -> Result<Self> {
let client = hyper::client::Client::builder().build::<_, hyper::Body>(connector);

Expand All @@ -147,6 +160,7 @@ impl SyncContext {
durable_generation: 0,
durable_frame_num: 0,
initial_server_sync: false,
remote_encryption,
};

if let Err(e) = me.read_metadata().await {
Expand Down Expand Up @@ -275,6 +289,16 @@ impl SyncContext {
None => {}
}

if let Some(remote_encryption) = &self.remote_encryption {
if remote_encryption.decrypt_pull {
req = req.header("x-turso-decrypt-response", "true");
}
if remote_encryption.push_is_encrypted {
req = req.header("x-turso-encrypted-request", "true");
}
req = req.header("x-turso-encryption-key", remote_encryption.key_16_bytes_base64_encoded.as_str());
}

let req = req.body(body.clone().into()).expect("valid body");

let res = self
Expand Down Expand Up @@ -386,6 +410,16 @@ impl SyncContext {
None => {}
}

if let Some(remote_encryption) = &self.remote_encryption {
if remote_encryption.decrypt_pull {
req = req.header("x-turso-decrypt-response", "true");
}
if remote_encryption.push_is_encrypted {
req = req.header("x-turso-encrypted-request", "true");
}
req = req.header("x-turso-encryption-key", remote_encryption.key_16_bytes_base64_encoded.as_str());
}

let req = req.body(Body::empty()).expect("valid request");

let res = self
Expand Down Expand Up @@ -527,6 +561,16 @@ impl SyncContext {
req = req.header("Authorization", auth_token);
}

if let Some(remote_encryption) = &self.remote_encryption {
if remote_encryption.decrypt_pull {
req = req.header("x-turso-decrypt-response", "true");
}
if remote_encryption.push_is_encrypted {
req = req.header("x-turso-encrypted-request", "true");
}
req = req.header("x-turso-encryption-key", remote_encryption.key_16_bytes_base64_encoded.as_str());
}

let req = req.body(Body::empty()).expect("valid request");

let res = self
Expand Down Expand Up @@ -623,6 +667,16 @@ impl SyncContext {
req = req.header("Authorization", auth_token);
}

if let Some(remote_encryption) = &self.remote_encryption {
if remote_encryption.decrypt_pull {
req = req.header("x-turso-decrypt-response", "true");
}
if remote_encryption.push_is_encrypted {
req = req.header("x-turso-encrypted-request", "true");
}
req = req.header("x-turso-encryption-key", remote_encryption.key_16_bytes_base64_encoded.as_str());
}

let req = req.body(Body::empty()).expect("valid request");

let res = self
Expand Down