diff --git a/libsql/src/local/connection.rs b/libsql/src/local/connection.rs index e810e82d30..466d158622 100644 --- a/libsql/src/local/connection.rs +++ b/libsql/src/local/connection.rs @@ -76,9 +76,7 @@ impl Connection { // disabled so that we can sync our changes back to a remote // server. conn.query("PRAGMA journal_mode = WAL", Params::None)?; - unsafe { - ffi::libsql_wal_disable_checkpoint(conn.raw); - } + conn.wal_disable_checkpoint()?; } Ok(conn) } @@ -554,6 +552,16 @@ impl Connection { Ok(buf) } + fn wal_disable_checkpoint(&self) -> Result<()> { + let rc = unsafe { libsql_sys::ffi::libsql_wal_disable_checkpoint(self.handle()) }; + if rc != 0 { + return Err(crate::errors::Error::SqliteFailure( + rc as std::ffi::c_int, + format!("wal_disable_checkpoint failed"), + )); + } + Ok(()) + } fn wal_insert_begin(&self) -> Result<()> { let rc = unsafe { libsql_sys::ffi::libsql_wal_insert_begin(self.handle()) }; if rc != 0 { @@ -576,14 +584,14 @@ impl Connection { Ok(()) } - fn wal_insert_frame(&self, frame: &[u8]) -> Result<()> { + fn wal_insert_frame(&self, frame_no: u32, frame: &[u8]) -> Result<()> { let mut conflict = 0i32; let rc = unsafe { libsql_sys::ffi::libsql_wal_insert_frame( self.handle(), - frame.len() as u32, + frame_no, frame.as_ptr() as *mut std::ffi::c_void, - 0, + frame.len() as u32, &mut conflict, ) }; @@ -658,13 +666,13 @@ unsafe extern "C" fn authorizer_callback( pub(crate) struct WalInsertHandle<'a> { conn: &'a Connection, - in_session: RefCell + in_session: RefCell, } impl WalInsertHandle<'_> { - pub fn insert(&self, frame: &[u8]) -> Result<()> { + pub fn insert_at(&self, frame_no: u32, frame: &[u8]) -> Result<()> { assert!(*self.in_session.borrow()); - self.conn.wal_insert_frame(frame) + self.conn.wal_insert_frame(frame_no, frame) } pub fn begin(&self) -> Result<()> { @@ -698,3 +706,54 @@ impl fmt::Debug for Connection { f.debug_struct("Connection").finish() } } + +#[cfg(test)] +mod tests { + use crate::{ + local::{Connection, Database}, + params::Params, + OpenFlags, + }; + + #[tokio::test] + pub async fn test_kek() { + let temp_dir = tempfile::tempdir().unwrap(); + let path1 = temp_dir.path().join("local1.db"); + let db1 = Database::new(path1.to_str().unwrap().to_string(), OpenFlags::default()); + let conn1 = Connection::connect(&db1).unwrap(); + conn1 + .query("PRAGMA journal_mode = WAL", Params::None) + .unwrap(); + conn1.wal_disable_checkpoint().unwrap(); + + let path2 = temp_dir.path().join("local2.db"); + let db2 = Database::new(path2.to_str().unwrap().to_string(), OpenFlags::default()); + let conn2 = Connection::connect(&db2).unwrap(); + conn2 + .query("PRAGMA journal_mode = WAL", Params::None) + .unwrap(); + conn2.wal_disable_checkpoint().unwrap(); + + conn1.execute("CREATE TABLE t(x)", Params::None).unwrap(); + const CNT: usize = 32; + for _ in 0..CNT { + conn1 + .execute( + "INSERT INTO t VALUES (randomblob(1024 * 1024))", + Params::None, + ) + .unwrap(); + } + let handle = conn2.wal_insert_handle().unwrap(); + let frame_count = conn1.wal_frame_count(); + for frame_no in 0..frame_count { + let frame = conn1.wal_get_frame(frame_no + 1, 4096).unwrap(); + handle.insert_at(frame_no as u32 + 1, &frame).unwrap(); + } + let result = conn2.query("SELECT COUNT(*) FROM t", Params::None).unwrap(); + let row = result.unwrap().next().unwrap().unwrap(); + let column = row.get_value(0).unwrap(); + let cnt = *column.as_integer().unwrap(); + assert_eq!(cnt, 32 as i64); + } +} diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs index 1071646ea2..b713b88bbb 100644 --- a/libsql/src/sync.rs +++ b/libsql/src/sync.rs @@ -974,8 +974,8 @@ pub async fn try_pull( ); return Err(SyncError::InvalidPullFrameBytes(frames.len()).into()); } - for chunk in frames.chunks(FRAME_SIZE) { - let r = insert_handle.insert(&chunk); + for (i, chunk) in frames.chunks(FRAME_SIZE).enumerate() { + let r = insert_handle.insert_at(frame_no + i as u32, &chunk); if let Err(e) = r { tracing::error!( "insert error (frame= {}) : {:?}",