diff --git a/libsql/src/database.rs b/libsql/src/database.rs index da0cc7dfac..32d325cbc3 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -99,6 +99,7 @@ enum DbType { url: String, auth_token: String, connector: crate::util::ConnectorService, + _bg_abort: Option<std::sync::Arc<crate::sync::DropAbort>>, }, #[cfg(feature = "remote")] Remote { @@ -673,12 +674,11 @@ impl Database { url, auth_token, connector, + .. } => { use crate::{ - hrana::{connection::HttpConnection, hyper::HttpSender}, - local::impls::LibsqlConnection, - replication::connection::State, - sync::connection::SyncedConnection, + hrana::connection::HttpConnection, local::impls::LibsqlConnection, + replication::connection::State, sync::connection::SyncedConnection, }; use tokio::sync::Mutex; @@ -699,10 +699,11 @@ impl Database { if *remote_writes { let synced = SyncedConnection { local, - remote: HttpConnection::new( + remote: HttpConnection::new_with_connector( url.clone(), auth_token.clone(), - HttpSender::new(connector.clone(), None), + connector.clone(), + None, ), read_your_writes: *read_your_writes, context: db.sync_ctx.clone().unwrap(), diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index dce7a14343..a8be27598e 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -108,6 +108,7 @@ impl Builder<()> { read_your_writes: true, remote_writes: false, push_batch_size: 0, + sync_interval: None, }, } } @@ -401,12 +402,18 @@ cfg_replication!
{ if res.status().is_success() { tracing::trace!("Using sync protocol v2 for {}", url); - return Builder::new_synced_database(path, url, auth_token) + let builder = Builder::new_synced_database(path, url, auth_token) .connector(connector) .remote_writes(true) - .read_your_writes(read_your_writes) - .build() - .await; + .read_your_writes(read_your_writes); + + let builder = if let Some(sync_interval) = sync_interval { + builder.sync_interval(sync_interval) + } else { + builder + }; + + return builder.build().await; } tracing::trace!("Using sync protocol v1 for {} based on probe results", url); } @@ -542,6 +549,7 @@ cfg_sync! { remote_writes: bool, read_your_writes: bool, push_batch_size: u32, + sync_interval: Option<std::time::Duration>, } impl Builder<SyncedDatabase> { @@ -566,6 +574,14 @@ cfg_sync! { self } + /// Set the duration at which the replicator will automatically call `sync` in the + /// background. The sync will continue for the duration that the resulting `Database` + /// type is alive for, once it is dropped the background task will get dropped and stop. + pub fn sync_interval(mut self, duration: std::time::Duration) -> Builder<SyncedDatabase> { + self.inner.sync_interval = Some(duration); + self + } + /// Provide a custom http connector that will be used to create http connections. pub fn connector<C>(mut self, connector: C) -> Builder<SyncedDatabase> where @@ -580,6 +596,8 @@ cfg_sync! { /// Build a connection to a local database that can be synced to remote server. pub async fn build(self) -> Result<Database> { + use tracing::Instrument as _; + let SyncedDatabase { path, flags, @@ -594,6 +612,7 @@ cfg_sync! { remote_writes, read_your_writes, push_batch_size, + sync_interval, } = self.inner; let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned(); @@ -624,6 +643,35 @@ cfg_sync!
{ db.sync_ctx.as_ref().unwrap().lock().await.set_push_batch_size(push_batch_size); } + let mut bg_abort: Option<std::sync::Arc<crate::sync::DropAbort>> = None; + let conn = db.connect()?; + + let sync_ctx = db.sync_ctx.as_ref().unwrap().clone(); + + if let Some(sync_interval) = sync_interval { + let jh = tokio::spawn( + async move { + loop { + tracing::trace!("trying to sync"); + let mut ctx = sync_ctx.lock().await; + if remote_writes { + if let Err(e) = crate::sync::try_pull(&mut ctx, &conn).await { + tracing::error!("sync error: {}", e); + } + } else { + if let Err(e) = crate::sync::sync_offline(&mut ctx, &conn).await { + tracing::error!("sync error: {}", e); + } + } + tokio::time::sleep(sync_interval).await; + } + } + .instrument(tracing::info_span!("sync_interval")), + ); + + bg_abort.replace(std::sync::Arc::new(crate::sync::DropAbort(jh.abort_handle()))); + } + Ok(Database { db_type: DbType::Offline { db, @@ -632,6 +680,7 @@ cfg_sync! { url, auth_token, connector, + _bg_abort: bg_abort, }, max_write_replication_index: Default::default(), }) diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs index 871a117899..933ed199f7 100644 --- a/libsql/src/sync.rs +++ b/libsql/src/sync.rs @@ -6,7 +6,7 @@ use bytes::Bytes; use chrono::Utc; use http::{HeaderValue, StatusCode}; use hyper::Body; -use tokio::io::AsyncWriteExt as _; +use tokio::{io::AsyncWriteExt as _, task::AbortHandle}; use uuid::Uuid; #[cfg(test)] @@ -81,6 +81,14 @@ pub struct PushResult { baton: Option<String>, } +pub struct DropAbort(pub AbortHandle); + +impl Drop for DropAbort { + fn drop(&mut self) { + self.0.abort(); + } +} + pub enum PushStatus { Ok, Conflict, @@ -216,7 +224,9 @@ impl SyncContext { match result.status { PushStatus::Conflict => { - return Err(SyncError::InvalidPushFrameConflict(frame_no, result.max_frame_no).into()); + return Err( + SyncError::InvalidPushFrameConflict(frame_no, result.max_frame_no).into(), + ); } _ => {} } @@ -251,7 +261,11 @@ impl SyncContext { tracing::debug!(?durable_frame_num, "frame successfully pushed"); //
Update our last known max_frame_no from the server. - tracing::debug!(?generation, ?durable_frame_num, "updating remote generation and durable_frame_num"); + tracing::debug!( + ?generation, + ?durable_frame_num, + "updating remote generation and durable_frame_num" + ); self.durable_generation = generation; self.durable_frame_num = durable_frame_num; @@ -261,7 +275,12 @@ impl SyncContext { }) } - async fn push_with_retry(&self, mut uri: String, body: Bytes, max_retries: usize) -> Result<PushResult> { + async fn push_with_retry( + &self, + mut uri: String, + body: Bytes, + max_retries: usize, + ) -> Result<PushResult> { let mut nr_retries = 0; loop { let mut req = http::Request::post(uri.clone()); @@ -402,7 +421,9 @@ impl SyncContext { } // BUG ALERT: The server returns a 500 error if the remote database is empty. // This is a bug and should be fixed. - if res.status() == StatusCode::BAD_REQUEST || res.status() == StatusCode::INTERNAL_SERVER_ERROR { + if res.status() == StatusCode::BAD_REQUEST + || res.status() == StatusCode::INTERNAL_SERVER_ERROR + { let res_body = hyper::body::to_bytes(res.into_body()) .await .map_err(SyncError::HttpBody)?; @@ -417,7 +438,9 @@ impl SyncContext { let generation = generation .as_u64() .ok_or_else(|| SyncError::JsonValue(generation.clone()))?; - return Ok(PullResult::EndOfGeneration { max_generation: generation as u32 }); + return Ok(PullResult::EndOfGeneration { + max_generation: generation as u32, + }); } if res.status().is_redirection() { uri = match res.headers().get(hyper::header::LOCATION) { @@ -449,7 +472,6 @@ impl SyncContext { } } - pub(crate) fn next_generation(&mut self) { self.durable_generation += 1; self.durable_frame_num = 0; @@ -741,9 +763,7 @@ pub async fn bootstrap_db(sync_ctx: &mut SyncContext) -> Result<()> { // if we are lagging behind, then we will call the export API and get to the latest // generation directly.
let info = sync_ctx.get_remote_info().await?; - sync_ctx - .sync_db_if_needed(info.current_generation) - .await?; + sync_ctx.sync_db_if_needed(info.current_generation).await?; // when sync_ctx is initialised, we set durable_generation to 0. however, once // sync_db is called, it should be > 0. assert!(sync_ctx.durable_generation > 0, "generation should be > 0"); @@ -871,7 +891,7 @@ pub async fn try_pull( let insert_handle = conn.wal_insert_handle()?; let mut err = None; - + loop { let generation = sync_ctx.durable_generation(); let frame_no = sync_ctx.durable_frame_num() + 1;