@@ -60,6 +60,10 @@ pub enum SyncError {
6060 RedirectHeader ( http:: header:: ToStrError ) ,
6161 #[ error( "redirect response with no location header" ) ]
6262 NoRedirectLocationHeader ,
63+ #[ error( "failed to pull db export: status={0}, error={1}" ) ]
64+ PullDb ( StatusCode , String ) ,
65+ #[ error( "server returned a lower generation than local: local={0}, remote={1}" ) ]
66+ InvalidLocalGeneration ( u32 , u32 ) ,
6367}
6468
6569impl SyncError {
@@ -86,6 +90,11 @@ pub enum PullResult {
8690 EndOfGeneration { max_generation : u32 } ,
8791}
8892
93+ #[ derive( serde:: Deserialize ) ]
94+ struct InfoResult {
95+ current_generation : u32 ,
96+ }
97+
8998pub struct SyncContext {
9099 db_path : String ,
91100 client : hyper:: Client < ConnectorService , Body > ,
@@ -97,6 +106,9 @@ pub struct SyncContext {
97106 durable_generation : u32 ,
98107 /// Represents the max_frame_no from the server.
99108 durable_frame_num : u32 ,
109+ /// whenever sync is called very first time, we will call the remote server
110+ /// to get the generation information and sync the db file if needed
111+ initial_server_sync : bool ,
100112}
101113
102114impl SyncContext {
@@ -125,6 +137,7 @@ impl SyncContext {
125137 client,
126138 durable_generation : 1 ,
127139 durable_frame_num : 0 ,
140+ initial_server_sync : false ,
128141 } ;
129142
130143 if let Err ( e) = me. read_metadata ( ) . await {
@@ -458,6 +471,114 @@ impl SyncContext {
458471
459472 Ok ( ( ) )
460473 }
474+
475+ /// get_remote_info calls the remote server to get the current generation information.
476+ async fn get_remote_info ( & self ) -> Result < InfoResult > {
477+ let uri = format ! ( "{}/info" , self . sync_url) ;
478+ let mut req = http:: Request :: builder ( ) . method ( "GET" ) . uri ( & uri) ;
479+
480+ if let Some ( auth_token) = & self . auth_token {
481+ req = req. header ( "Authorization" , auth_token) ;
482+ }
483+
484+ let req = req. body ( Body :: empty ( ) ) . expect ( "valid request" ) ;
485+
486+ let res = self
487+ . client
488+ . request ( req)
489+ . await
490+ . map_err ( SyncError :: HttpDispatch ) ?;
491+
492+ if !res. status ( ) . is_success ( ) {
493+ let status = res. status ( ) ;
494+ let body = hyper:: body:: to_bytes ( res. into_body ( ) )
495+ . await
496+ . map_err ( SyncError :: HttpBody ) ?;
497+ return Err (
498+ SyncError :: PullDb ( status, String :: from_utf8_lossy ( & body) . to_string ( ) ) . into ( ) ,
499+ ) ;
500+ }
501+
502+ let body = hyper:: body:: to_bytes ( res. into_body ( ) )
503+ . await
504+ . map_err ( SyncError :: HttpBody ) ?;
505+
506+ let info = serde_json:: from_slice ( & body) . map_err ( SyncError :: JsonDecode ) ?;
507+
508+ Ok ( info)
509+ }
510+
511+ async fn sync_db_if_needed ( & mut self , generation : u32 ) -> Result < ( ) > {
512+ // we will get the export file only if the remote generation is different from the one we have
513+ if generation == self . durable_generation {
514+ return Ok ( ( ) ) ;
515+ }
516+ // somehow we are ahead of the remote in generations. following should not happen because
517+ // we checkpoint only if the remote server tells us to do so.
518+ if self . durable_generation > generation {
519+ tracing:: error!(
520+ "server returned a lower generation than what we have: sent={}, got={}" ,
521+ self . durable_generation,
522+ generation
523+ ) ;
524+ return Err (
525+ SyncError :: InvalidLocalGeneration ( self . durable_generation , generation) . into ( ) ,
526+ ) ;
527+ }
528+ self . sync_db ( generation) . await
529+ }
530+
531+ /// sync_db will download the db file from the remote server and replace the local file.
532+ async fn sync_db ( & mut self , generation : u32 ) -> Result < ( ) > {
533+ let uri = format ! ( "{}/export/{}" , self . sync_url, generation) ;
534+ let mut req = http:: Request :: builder ( ) . method ( "GET" ) . uri ( & uri) ;
535+
536+ if let Some ( auth_token) = & self . auth_token {
537+ req = req. header ( "Authorization" , auth_token) ;
538+ }
539+
540+ let req = req. body ( Body :: empty ( ) ) . expect ( "valid request" ) ;
541+
542+ let res = self
543+ . client
544+ . request ( req)
545+ . await
546+ . map_err ( SyncError :: HttpDispatch ) ?;
547+
548+ if !res. status ( ) . is_success ( ) {
549+ let status = res. status ( ) ;
550+ let body = hyper:: body:: to_bytes ( res. into_body ( ) )
551+ . await
552+ . map_err ( SyncError :: HttpBody ) ?;
553+ return Err (
554+ SyncError :: PullFrame ( status, String :: from_utf8_lossy ( & body) . to_string ( ) ) . into ( ) ,
555+ ) ;
556+ }
557+
558+ // todo: do streaming write to the disk
559+ let bytes = hyper:: body:: to_bytes ( res. into_body ( ) )
560+ . await
561+ . map_err ( SyncError :: HttpBody ) ?;
562+
563+
564+ // since we are starting from new generation, we need to clean up and remove the old files
565+ // we will remove `.db-wal` and `.db-shm` manually. The `.db` file will be replaced by the
566+ // atomic write.
567+ let db_path_wal = format ! ( "{}-wal" , self . db_path) ;
568+ let db_path_shm = format ! ( "{}-shm" , self . db_path) ;
569+ tokio:: fs:: remove_file ( db_path_wal)
570+ . await
571+ . map_err ( SyncError :: io ( "remove wal file" ) ) ?;
572+ tokio:: fs:: remove_file ( db_path_shm)
573+ . await
574+ . map_err ( SyncError :: io ( "remove shm file" ) ) ?;
575+
576+ atomic_write ( & self . db_path , & bytes) . await ?;
577+ self . durable_generation = generation;
578+ self . durable_frame_num = 0 ;
579+ self . write_metadata ( ) . await ?;
580+ Ok ( ( ) )
581+ }
461582}
462583
463584#[ derive( serde:: Serialize , serde:: Deserialize , Debug ) ]
@@ -555,6 +676,19 @@ pub async fn sync_offline(
555676 Err ( e) => Err ( e) ,
556677 }
557678 } else {
679+ // todo: we are checking with the remote server only during initialisation. ideally,
680+ // we should check everytime we try to sync with the remote server. However, we need to close
681+ // all the ongoing connections since we replace `.db` file and remove the `.db-wal` file
682+ if !sync_ctx. initial_server_sync {
683+ // sync is being called first time. so we will call remote, get the generation information
684+ // if we are lagging behind, then we will call the export API and get to the latest
685+ // generation directly.
686+ let info = sync_ctx. get_remote_info ( ) . await ?;
687+ sync_ctx
688+ . sync_db_if_needed ( info. current_generation )
689+ . await ?;
690+ sync_ctx. initial_server_sync = true ;
691+ }
558692 try_pull ( sync_ctx, conn) . await
559693 }
560694 . or_else ( |err| {
0 commit comments