diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs index 32177cf803..a1d6b742f9 100644 --- a/libsql/src/sync.rs +++ b/libsql/src/sync.rs @@ -52,6 +52,8 @@ pub enum SyncError { InvalidPushFrameNoLow(u32, u32), #[error("server returned a higher frame_no: sent={0}, got={1}")] InvalidPushFrameNoHigh(u32, u32), + #[error("server returned a conflict: sent={0}, got={1}")] + InvalidPushFrameConflict(u32, u32), #[error("failed to pull frame: status={0}, error={1}")] PullFrame(StatusCode, String), #[error("failed to get location header for redirect: {0}")] @@ -66,6 +68,17 @@ impl SyncError { } } +pub struct PushResult { + status: PushStatus, + generation: u32, + max_frame_no: u32, +} + +pub enum PushStatus { + Ok, + Conflict, +} + pub enum PullResult { /// A frame was successfully pulled. Frame(Bytes), @@ -162,7 +175,16 @@ impl SyncContext { ); tracing::debug!("pushing frame"); - let (generation, durable_frame_num) = self.push_with_retry(uri, frames, self.max_retries).await?; + let result = self.push_with_retry(uri, frames, self.max_retries).await?; + + match result.status { + PushStatus::Conflict => { + return Err(SyncError::InvalidPushFrameConflict(frame_no, result.max_frame_no).into()); + } + _ => {} + } + let generation = result.generation; + let durable_frame_num = result.max_frame_no; if durable_frame_num > frame_no + frames_count - 1 { tracing::error!( @@ -198,7 +220,7 @@ impl SyncContext { Ok(durable_frame_num) } - async fn push_with_retry(&self, mut uri: String, body: Bytes, max_retries: usize) -> Result<(u32, u32)> { + async fn push_with_retry(&self, mut uri: String, body: Bytes, max_retries: usize) -> Result { let mut nr_retries = 0; loop { let mut req = http::Request::post(uri.clone()); @@ -228,6 +250,14 @@ impl SyncContext { let resp = serde_json::from_slice::(&res_body[..]) .map_err(SyncError::JsonDecode)?; + let status = resp + .get("status") + .ok_or_else(|| SyncError::JsonValue(resp.clone()))?; + + let status = status + .as_str() + .ok_or_else(|| SyncError::JsonValue(status.clone()))?; + let generation = resp .get("generation") .ok_or_else(|| SyncError::JsonValue(resp.clone()))?; @@ -244,7 +274,14 @@ impl SyncContext { .as_u64() .ok_or_else(|| SyncError::JsonValue(max_frame_no.clone()))?; - return Ok((generation as u32, max_frame_no as u32)); + let status = match status { + "ok" => PushStatus::Ok, + "conflict" => PushStatus::Conflict, + _ => return Err(SyncError::JsonValue(resp.clone()).into()), + }; + let generation = generation as u32; + let max_frame_no = max_frame_no as u32; + return Ok(PushResult { status, generation, max_frame_no }); } if res.status().is_redirection() { diff --git a/libsql/src/sync/test.rs b/libsql/src/sync/test.rs index 84f7c4980c..38af3902fa 100644 --- a/libsql/src/sync/test.rs +++ b/libsql/src/sync/test.rs @@ -376,6 +376,7 @@ impl MockServer { if req.uri().path().contains("/sync/") { // Return the max_frame_no that has been accepted let response = serde_json::json!({ + "status": "ok", "generation": 1, "max_frame_no": current_count });