diff --git a/libsql/src/sync/connection.rs b/libsql/src/sync/connection.rs index 815226657c..dbd85d0cdd 100644 --- a/libsql/src/sync/connection.rs +++ b/libsql/src/sync/connection.rs @@ -3,6 +3,7 @@ use crate::{ hrana::{connection::HttpConnection, hyper::HttpSender}, local::{self, impls::LibsqlStmt}, params::Params, + parser, replication::connection::State, sync::SyncContext, BatchRows, Error, Result, Statement, Transaction, TransactionBehavior, @@ -33,7 +34,66 @@ impl SyncedConnection { let mut state = self.state.lock().await; - crate::replication::connection::should_execute_local(&mut state, stmts.as_slice()) + if !self.remote.is_autocommit() { + *state = State::Txn; + } + + { + let predicted_end_state = { + let mut state = state.clone(); + + stmts.iter().for_each(|parser::Statement { kind, .. }| { + state = state.step(*kind); + }); + + state + }; + + let should_execute_local = match (*state, predicted_end_state) { + (State::Init, State::Init) => stmts.iter().all(parser::Statement::is_read_only), + + (State::Init, State::TxnReadOnly) | (State::TxnReadOnly, State::TxnReadOnly) => { + let is_read_only = stmts.iter().all(parser::Statement::is_read_only); + + if !is_read_only { + return Err(Error::Misuse( + "Invalid write in a readonly transaction".into(), + )); + } + + *state = State::TxnReadOnly; + true + } + + (State::TxnReadOnly, State::Init) => { + let is_read_only = stmts.iter().all(parser::Statement::is_read_only); + + if !is_read_only { + return Err(Error::Misuse( + "Invalid write in a readonly transaction".into(), + )); + } + + *state = State::Init; + true + } + + (init, State::Invalid) => { + let err = Err(Error::InvalidParserState(format!("{:?}", init))); + + // Reset state always back to init so the user can start over + *state = State::Init; + + return err; + } + _ => { + *state = predicted_end_state; + false + }, + }; + + Ok(should_execute_local) + } } }