Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 130 additions & 2 deletions libsql/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ pub enum SyncError {
RedirectHeader(http::header::ToStrError),
#[error("redirect response with no location header")]
NoRedirectLocationHeader,
#[error("failed to pull db export: status={0}, error={1}")]
PullDb(StatusCode, String),
#[error("server returned a lower generation than local: local={0}, remote={1}")]
InvalidLocalGeneration(u32, u32),
}

impl SyncError {
Expand All @@ -86,6 +90,11 @@ pub enum PullResult {
EndOfGeneration { max_generation: u32 },
}

#[derive(serde::Deserialize)]
struct InfoResult {
current_generation: u32,
}

pub struct SyncContext {
db_path: String,
client: hyper::Client<ConnectorService, Body>,
Expand All @@ -97,6 +106,9 @@ pub struct SyncContext {
durable_generation: u32,
/// Represents the max_frame_no from the server.
durable_frame_num: u32,
/// whenever sync is called very first time, we will call the remote server
/// to get the generation information and sync the db file if needed
initial_server_sync: bool,
}

impl SyncContext {
Expand All @@ -123,8 +135,9 @@ impl SyncContext {
max_retries: DEFAULT_MAX_RETRIES,
push_batch_size: DEFAULT_PUSH_BATCH_SIZE,
client,
durable_generation: 1,
durable_generation: 0,
durable_frame_num: 0,
initial_server_sync: false,
};

if let Err(e) = me.read_metadata().await {
Expand Down Expand Up @@ -173,7 +186,7 @@ impl SyncContext {
frame_no,
frame_no + frames_count
);
tracing::debug!("pushing frame");
tracing::debug!("pushing frame(frame_no={}, count={}, generation={})", frame_no, frames_count, generation);

let result = self.push_with_retry(uri, frames, self.max_retries).await?;

Expand Down Expand Up @@ -458,6 +471,105 @@ impl SyncContext {

Ok(())
}

/// get_remote_info calls the remote server to get the current generation information.
async fn get_remote_info(&self) -> Result<InfoResult> {
let uri = format!("{}/info", self.sync_url);
let mut req = http::Request::builder().method("GET").uri(&uri);

if let Some(auth_token) = &self.auth_token {
req = req.header("Authorization", auth_token);
}

let req = req.body(Body::empty()).expect("valid request");

let res = self
.client
.request(req)
.await
.map_err(SyncError::HttpDispatch)?;

if !res.status().is_success() {
let status = res.status();
let body = hyper::body::to_bytes(res.into_body())
.await
.map_err(SyncError::HttpBody)?;
return Err(
SyncError::PullDb(status, String::from_utf8_lossy(&body).to_string()).into(),
);
}

let body = hyper::body::to_bytes(res.into_body())
.await
.map_err(SyncError::HttpBody)?;

let info = serde_json::from_slice(&body).map_err(SyncError::JsonDecode)?;

Ok(info)
}

async fn sync_db_if_needed(&mut self, generation: u32) -> Result<()> {
// we will get the export file only if the remote generation is different from the one we have
if generation == self.durable_generation {
return Ok(());
}
// somehow we are ahead of the remote in generations. following should not happen because
// we checkpoint only if the remote server tells us to do so.
if self.durable_generation > generation {
tracing::error!(
"server returned a lower generation than what we have: sent={}, got={}",
self.durable_generation,
generation
);
return Err(
SyncError::InvalidLocalGeneration(self.durable_generation, generation).into(),
);
}
tracing::debug!(
"syncing db file from remote server, generation={}",
generation
);
self.sync_db(generation).await
}

/// sync_db will download the db file from the remote server and replace the local file.
async fn sync_db(&mut self, generation: u32) -> Result<()> {
let uri = format!("{}/export/{}", self.sync_url, generation);
let mut req = http::Request::builder().method("GET").uri(&uri);

if let Some(auth_token) = &self.auth_token {
req = req.header("Authorization", auth_token);
}

let req = req.body(Body::empty()).expect("valid request");

let res = self
.client
.request(req)
.await
.map_err(SyncError::HttpDispatch)?;

if !res.status().is_success() {
let status = res.status();
let body = hyper::body::to_bytes(res.into_body())
.await
.map_err(SyncError::HttpBody)?;
return Err(
SyncError::PullFrame(status, String::from_utf8_lossy(&body).to_string()).into(),
);
}

// todo: do streaming write to the disk
let bytes = hyper::body::to_bytes(res.into_body())
.await
.map_err(SyncError::HttpBody)?;

atomic_write(&self.db_path, &bytes).await?;
self.durable_generation = generation;
self.durable_frame_num = 0;
self.write_metadata().await?;
Ok(())
}
}

#[derive(serde::Serialize, serde::Deserialize, Debug)]
Expand Down Expand Up @@ -555,6 +667,22 @@ pub async fn sync_offline(
Err(e) => Err(e),
}
} else {
// todo: we are checking with the remote server only during initialisation. ideally,
// we should check everytime we try to sync with the remote server. However, we need to close
// all the ongoing connections since we replace `.db` file and remove the `.db-wal` file
if !sync_ctx.initial_server_sync {
// sync is being called first time. so we will call remote, get the generation information
// if we are lagging behind, then we will call the export API and get to the latest
// generation directly.
let info = sync_ctx.get_remote_info().await?;
sync_ctx
.sync_db_if_needed(info.current_generation)
.await?;
// when sync_ctx is initialised, we set durable_generation to 0. however, once
// sync_db is called, it should be > 0.
assert!(sync_ctx.durable_generation > 0, "generation should be > 0");
sync_ctx.initial_server_sync = true;
}
try_pull(sync_ctx, conn).await
}
.or_else(|err| {
Expand Down
2 changes: 1 addition & 1 deletion libsql/src/sync/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async fn test_sync_context_corrupted_metadata() {

// Verify that the context was reset to default values
assert_eq!(sync_ctx.durable_frame_num(), 0);
assert_eq!(sync_ctx.durable_generation(), 1);
assert_eq!(sync_ctx.durable_generation(), 0);
}

#[tokio::test]
Expand Down
Loading