Skip to content

Commit 884de76

Browse files
committed
Add query timeout option to interrupt long-running queries
A single background tokio task with a min-heap manages all query deadlines efficiently. When a query starts, a TimeoutGuard is acquired; if the deadline expires before the guard is dropped, the connection is interrupted via sqlite3_interrupt().
1 parent 99ff1f9 commit 884de76

5 files changed

Lines changed: 283 additions & 5 deletions

File tree

docs/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ You can use the `options` parameter to specify various options. Options supporte
2222
- `syncPeriod`: synchronize the database periodically every `syncPeriod` seconds.
2323
- `authToken`: authentication token for the provider URL (optional).
2424
- `timeout`: number of milliseconds to wait on locked database before returning `SQLITE_BUSY` error
25+
- `queryTimeout`: maximum number of milliseconds a query is allowed to run before being interrupted with `SQLITE_INTERRUPT` error
2526

2627
The function returns a `Database` object.
2728

integration-tests/tests/async.test.js

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,35 @@ test.serial("Timeout option", async (t) => {
398398
fs.unlinkSync(path);
399399
});
400400

401+
test.serial("Query timeout option interrupts long-running query", async (t) => {
402+
const queryTimeout = 100;
403+
const path = genDatabaseFilename();
404+
const [db, errorType] = await connect(path, { queryTimeout });
405+
const stmt = await db.prepare(
406+
"WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"
407+
);
408+
409+
await t.throwsAsync(async () => {
410+
await stmt.all();
411+
}, {
412+
instanceOf: errorType,
413+
message: "interrupted",
414+
code: "SQLITE_INTERRUPT",
415+
});
416+
417+
db.close();
418+
fs.unlinkSync(path);
419+
});
420+
421+
test.serial("Query timeout option allows short-running query", async (t) => {
422+
const path = genDatabaseFilename();
423+
const [db] = await connect(path, { queryTimeout: 100 });
424+
const stmt = await db.prepare("SELECT 1 AS value");
425+
t.deepEqual(await stmt.get(), { value: 1 });
426+
db.close();
427+
fs.unlinkSync(path);
428+
});
429+
401430
test.serial("Concurrent writes over same connection", async (t) => {
402431
const db = t.context.db;
403432
await db.exec(`

integration-tests/tests/sync.test.js

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,44 @@ test.serial("Timeout option", async (t) => {
457457
fs.unlinkSync(path);
458458
});
459459

460+
test.serial("Query timeout option interrupts long-running query", async (t) => {
461+
if (t.context.provider === "sqlite") {
462+
t.assert(true);
463+
return;
464+
}
465+
466+
const path = genDatabaseFilename();
467+
const [db, errorType] = await connect(path, { queryTimeout: 100 });
468+
const stmt = db.prepare(
469+
"WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"
470+
);
471+
472+
t.throws(() => {
473+
stmt.all();
474+
}, {
475+
instanceOf: errorType,
476+
message: "interrupted",
477+
code: "SQLITE_INTERRUPT",
478+
});
479+
480+
db.close();
481+
fs.unlinkSync(path);
482+
});
483+
484+
test.serial("Query timeout option allows short-running query", async (t) => {
485+
if (t.context.provider === "sqlite") {
486+
t.assert(true);
487+
return;
488+
}
489+
490+
const path = genDatabaseFilename();
491+
const [db] = await connect(path, { queryTimeout: 100 });
492+
const stmt = db.prepare("SELECT 1 AS value");
493+
t.deepEqual(stmt.get(), { value: 1 });
494+
db.close();
495+
fs.unlinkSync(path);
496+
});
497+
460498
test.serial("Statement.reader [SELECT is true]", async (t) => {
461499
const db = t.context.db;
462500
const stmt = db.prepare("SELECT * FROM users WHERE id = ?");

src/lib.rs

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
#![allow(deprecated)]
2222

2323
mod auth;
24+
mod query_timeout;
2425

2526
use napi::{
2627
bindgen_prelude::{Array, FromNapiValue, ToNapiValue},
2728
Env, JsUnknown, Result, ValueType,
2829
};
2930
use napi_derive::napi;
3031
use once_cell::sync::OnceCell;
32+
use query_timeout::{QueryTimeoutManager, TimeoutGuard};
3133
use std::{
3234
str::FromStr,
3335
sync::{
@@ -200,6 +202,8 @@ pub struct Options {
200202
pub encryptionKey: Option<String>,
201203
// Encryption key for remote encryption at rest.
202204
pub remoteEncryptionKey: Option<String>,
205+
// Maximum time in milliseconds that a query is allowed to run.
206+
pub queryTimeout: Option<f64>,
203207
}
204208

205209
/// Access mode.
@@ -224,6 +228,10 @@ pub struct Database {
224228
default_safe_integers: AtomicBool,
225229
// Whether to use memory-only mode.
226230
memory: bool,
231+
// Maximum time in milliseconds that a query is allowed to run.
232+
query_timeout: Option<Duration>,
233+
// Shared timeout manager for efficient query timeout handling.
234+
timeout_manager: Arc<QueryTimeoutManager>,
227235
}
228236

229237
impl Drop for Database {
@@ -321,11 +329,19 @@ pub async fn connect(path: String, opts: Option<Options>) -> Result<Database> {
321329
conn.busy_timeout(Duration::from_millis(timeout as u64))
322330
.map_err(Error::from)?
323331
}
332+
let query_timeout = opts
333+
.as_ref()
334+
.and_then(|o| o.queryTimeout)
335+
.filter(|&t| t > 0.0)
336+
.map(|t| Duration::from_millis(t as u64));
337+
let timeout_manager = Arc::new(QueryTimeoutManager::new());
324338
Ok(Database {
325339
db: Some(db),
326340
conn: Some(Arc::new(conn)),
327341
default_safe_integers,
328342
memory,
343+
query_timeout,
344+
timeout_manager,
329345
})
330346
}
331347

@@ -388,7 +404,13 @@ impl Database {
388404
pluck: false.into(),
389405
timing: false.into(),
390406
};
391-
Ok(Statement::new(conn, stmt, mode))
407+
Ok(Statement::new(
408+
conn,
409+
stmt,
410+
mode,
411+
self.query_timeout,
412+
self.timeout_manager.clone(),
413+
))
392414
}
393415

394416
/// Sets the authorizer for the database.
@@ -520,6 +542,9 @@ impl Database {
520542
));
521543
}
522544
};
545+
let _guard = self
546+
.query_timeout
547+
.map(|t| self.timeout_manager.register(&conn, t));
523548
conn.execute_batch(&sql).await.map_err(Error::from)?;
524549
Ok(())
525550
}
@@ -636,6 +661,10 @@ pub struct Statement {
636661
column_names: Vec<std::ffi::CString>,
637662
// The access mode.
638663
mode: AccessMode,
664+
// Maximum time in milliseconds that a query is allowed to run.
665+
query_timeout: Option<Duration>,
666+
// Shared timeout manager.
667+
timeout_manager: Arc<QueryTimeoutManager>,
639668
}
640669

641670
#[napi]
@@ -651,6 +680,8 @@ impl Statement {
651680
conn: Arc<libsql::Connection>,
652681
stmt: libsql::Statement,
653682
mode: AccessMode,
683+
query_timeout: Option<Duration>,
684+
timeout_manager: Arc<QueryTimeoutManager>,
654685
) -> Self {
655686
let column_names: Vec<std::ffi::CString> = stmt
656687
.columns()
@@ -663,6 +694,8 @@ impl Statement {
663694
stmt,
664695
column_names,
665696
mode,
697+
query_timeout,
698+
timeout_manager,
666699
}
667700
}
668701

@@ -679,8 +712,10 @@ impl Statement {
679712
let start = std::time::Instant::now();
680713
let stmt = self.stmt.clone();
681714
let conn = self.conn.clone();
715+
let guard = self.start_timeout_guard();
682716

683717
let future = async move {
718+
let _guard = guard;
684719
stmt.run(params).await.map_err(Error::from)?;
685720
let changes = if conn.total_changes() == total_changes_before {
686721
0
@@ -723,7 +758,9 @@ impl Statement {
723758
};
724759

725760
let stmt_fut = stmt.clone();
761+
let guard = self.start_timeout_guard();
726762
let future = async move {
763+
let _guard = guard;
727764
let mut rows = stmt_fut.query(params).await.map_err(Error::from)?;
728765
let row = rows.next().await.map_err(Error::from)?;
729766
let duration: Option<f64> = start.map(|start| start.elapsed().as_secs_f64());
@@ -787,6 +824,7 @@ impl Statement {
787824
stmt.reset();
788825
let params = map_params(&stmt, params).unwrap();
789826
let stmt = self.stmt.clone();
827+
let guard = self.start_timeout_guard();
790828
let future = async move {
791829
let rows = stmt.query(params).await.map_err(Error::from)?;
792830
Ok::<_, napi::Error>(rows)
@@ -799,6 +837,7 @@ impl Statement {
799837
safe_ints,
800838
raw,
801839
pluck,
840+
guard,
802841
))
803842
})
804843
}
@@ -882,6 +921,13 @@ impl Statement {
882921
}
883922
}
884923

924+
impl Statement {
925+
fn start_timeout_guard(&self) -> Option<TimeoutGuard> {
926+
self.query_timeout
927+
.map(|t| self.timeout_manager.register(&self.conn, t))
928+
}
929+
}
930+
885931
/// Gets first row from statement in blocking mode.
886932
#[napi]
887933
pub fn statement_get_sync(
@@ -901,6 +947,7 @@ pub fn statement_get_sync(
901947
};
902948

903949
let rt = runtime()?;
950+
let _guard = stmt.start_timeout_guard();
904951
rt.block_on(async move {
905952
let params = map_params(&stmt.stmt, params)?;
906953
let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?;
@@ -925,6 +972,7 @@ pub fn statement_get_sync(
925972
pub fn statement_run_sync(stmt: &Statement, params: Option<napi::JsUnknown>) -> Result<RunResult> {
926973
stmt.stmt.reset();
927974
let rt = runtime()?;
975+
let _guard = stmt.start_timeout_guard();
928976
rt.block_on(async move {
929977
let params = map_params(&stmt.stmt, params)?;
930978
let total_changes_before = stmt.conn.total_changes();
@@ -956,11 +1004,12 @@ pub fn statement_iterate_sync(
9561004
let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst);
9571005
let raw = stmt.mode.raw.load(Ordering::SeqCst);
9581006
let pluck = stmt.mode.pluck.load(Ordering::SeqCst);
959-
let stmt = stmt.stmt.clone();
1007+
let guard = stmt.start_timeout_guard();
1008+
let inner_stmt = stmt.stmt.clone();
9601009
let (rows, column_names) = rt.block_on(async move {
961-
stmt.reset();
962-
let params = map_params(&stmt, params)?;
963-
let rows = stmt.query(params).await.map_err(Error::from)?;
1010+
inner_stmt.reset();
1011+
let params = map_params(&inner_stmt, params)?;
1012+
let rows = inner_stmt.query(params).await.map_err(Error::from)?;
9641013
let mut column_names = Vec::new();
9651014
for i in 0..rows.column_count() {
9661015
column_names
@@ -974,6 +1023,7 @@ pub fn statement_iterate_sync(
9741023
safe_ints,
9751024
raw,
9761025
pluck,
1026+
guard,
9771027
))
9781028
}
9791029

@@ -1120,6 +1170,7 @@ pub struct RowsIterator {
11201170
safe_ints: bool,
11211171
raw: bool,
11221172
pluck: bool,
1173+
_timeout_guard: Option<TimeoutGuard>,
11231174
}
11241175

11251176
#[napi]
@@ -1130,13 +1181,15 @@ impl RowsIterator {
11301181
safe_ints: bool,
11311182
raw: bool,
11321183
pluck: bool,
1184+
timeout_guard: Option<TimeoutGuard>,
11331185
) -> Self {
11341186
Self {
11351187
rows,
11361188
column_names,
11371189
safe_ints,
11381190
raw,
11391191
pluck,
1192+
_timeout_guard: timeout_guard,
11401193
}
11411194
}
11421195

0 commit comments

Comments
 (0)