diff --git a/integration-tests/package.json b/integration-tests/package.json index 8fc0a998..ef87633d 100644 --- a/integration-tests/package.json +++ b/integration-tests/package.json @@ -3,7 +3,7 @@ "type": "module", "private": true, "scripts": { - "test": "cross-env PROVIDER=sqlite ava tests/sync.test.js && cross-env LIBSQL_JS_DEV=1 PROVIDER=libsql ava tests/sync.test.js && cross-env LIBSQL_JS_DEV=1 ava tests/async.test.js && cross-env LIBSQL_JS_DEV=1 ava tests/extensions.test.js" + "test": "cross-env PROVIDER=sqlite ava tests/sync.test.js && cross-env LIBSQL_JS_DEV=1 PROVIDER=libsql ava tests/sync.test.js && cross-env LIBSQL_JS_DEV=1 ava tests/async.test.js && cross-env LIBSQL_JS_DEV=1 ava tests/extensions.test.js ava tests/concurrency.test.js" }, "devDependencies": { "ava": "^5.3.0", diff --git a/integration-tests/tests/async.test.js b/integration-tests/tests/async.test.js index 1750b3ab..7eaafd06 100644 --- a/integration-tests/tests/async.test.js +++ b/integration-tests/tests/async.test.js @@ -50,6 +50,11 @@ test.serial("Statement.run() [positional]", async (t) => { const info = stmt.run(["Carol", "carol@example.net"]); t.is(info.changes, 1); t.is(info.lastInsertRowid, 3); + + // Verify that the data is inserted + const stmt2 = await db.prepare("SELECT * FROM users WHERE id = 3"); + t.is(stmt2.get().name, "Carol"); + t.is(stmt2.get().email, "carol@example.net"); }); test.serial("Statement.get() returns no rows", async (t) => { @@ -315,7 +320,7 @@ test.serial("errors", async (t) => { test.serial("Database.prepare() after close()", async (t) => { const db = t.context.db; - await db.close(); + db.close(); await t.throwsAsync(async () => { await db.prepare("SELECT 1"); }, { @@ -326,7 +331,7 @@ test.serial("Database.prepare() after close()", async (t) => { test.serial("Database.exec() after close()", async (t) => { const db = t.context.db; - await db.close(); + db.close(); await t.throwsAsync(async () => { await db.exec("SELECT 1"); }, { diff --git a/integration-tests/tests/concurrency.test.js b/integration-tests/tests/concurrency.test.js new file mode 100644 index 00000000..22816063 --- /dev/null +++ b/integration-tests/tests/concurrency.test.js @@ -0,0 +1,207 @@ +import test from "ava"; +import crypto from 'crypto'; +import fs from 'fs'; + +test.beforeEach(async (t) => { + const [db, errorType, path] = await connect(); + + await db.exec(` + DROP TABLE IF EXISTS users; + CREATE TABLE users (id TEXT PRIMARY KEY, name TEXT, email TEXT) + `); + const aliceId = generateUUID(); + const bobId = generateUUID(); + await db.exec( + `INSERT INTO users (id, name, email) VALUES ('${aliceId}', 'Alice', 'alice@example.org')` + ); + await db.exec( + `INSERT INTO users (id, name, email) VALUES ('${bobId}', 'Bob', 'bob@example.com')` + ); + t.context = { + db, + errorType, + aliceId, + bobId, + path + }; +}); + +test("Concurrent reads", async (t) => { + const db = t.context.db; + const stmt = await db.prepare("SELECT * FROM users WHERE id = ?"); + + const promises = []; + for (let i = 0; i < 100; i++) { + promises.push(stmt.get(t.context.aliceId)); + promises.push(stmt.get(t.context.bobId)); + } + + const results = await Promise.all(promises); + + for (let i = 0; i < results.length; i++) { + const result = results[i]; + t.truthy(result); + t.is(typeof result.name, 'string'); + t.is(typeof result.email, 'string'); + } + cleanup(t.context); +}); + +test("Concurrent writes", async (t) => { + const db = t.context.db; + + await db.exec(` + DROP TABLE IF EXISTS concurrent_users; + CREATE TABLE concurrent_users ( + id TEXT PRIMARY KEY, + name TEXT, + email TEXT + ) + `); + + const stmt = await db.prepare("INSERT INTO concurrent_users(id, name, email) VALUES (:id, :name, :email)"); + + const promises = []; + for (let i = 0; i < 50; i++) { + promises.push(stmt.run({ + id: generateUUID(), + name: `User${i}`, + email: `user${i}@example.com` + })); + } + + await Promise.all(promises); + + const countStmt = await db.prepare("SELECT COUNT(*) as count FROM concurrent_users"); + const result = await countStmt.get(); + t.is(result.count, 50); + + cleanup(t.context); +}); + +test("Concurrent transaction isolation", async (t) => { + const db = t.context.db; + + await db.exec(` + DROP TABLE IF EXISTS transaction_users; + CREATE TABLE transaction_users ( + id TEXT PRIMARY KEY, + name TEXT, + email TEXT + ) + `); + + const aliceId = generateUUID(); + const bobId = generateUUID(); + + await db.exec(` + INSERT INTO transaction_users (id, name, email) VALUES + ('${aliceId}', 'Alice', 'alice@example.org'), + ('${bobId}', 'Bob', 'bob@example.com') + `); + + const updateUser = db.transaction(async (id, name, email) => { + const stmt = await db.prepare("UPDATE transaction_users SET name = :name, email = :email WHERE id = :id"); + await stmt.run({ id, name, email }); + }); + + const promises = []; + for (let i = 0; i < 10; i++) { + promises.push(updateUser(aliceId, `Alice${i}`, `alice${i}@example.org`)); + promises.push(updateUser(bobId, `Bob${i}`, `bob${i}@example.com`)); + } + + await Promise.all(promises); + + const stmt = await db.prepare("SELECT * FROM transaction_users ORDER BY name"); + const results = await stmt.all(); + t.is(results.length, 2); + t.truthy(results[0].name.startsWith('Alice')); + t.truthy(results[1].name.startsWith('Bob')); + + cleanup(t.context); +}); + +test("Concurrent reads and writes", async (t) => { + const db = t.context.db; + + await db.exec(` + DROP TABLE IF EXISTS mixed_users; + CREATE TABLE mixed_users ( + id TEXT PRIMARY KEY, + name TEXT, + email TEXT + ) + `); + + const aliceId = generateUUID(); + await db.exec(` + INSERT INTO mixed_users (id, name, email) VALUES + ('${aliceId}', 'Alice', 'alice@example.org') + `); + + const readStmt = await db.prepare("SELECT * FROM mixed_users WHERE id = ?"); + const writeStmt = await db.prepare("INSERT INTO mixed_users(id, name, email) VALUES (:id, :name, :email)"); + + const promises = []; + for (let i = 0; i < 20; i++) { + promises.push(readStmt.get(aliceId)); + writeStmt.run({ + id: generateUUID(), + name: `User${i}`, + email: `user${i}@example.com` + }); + } + await Promise.all(promises); + + const countStmt = await db.prepare("SELECT COUNT(*) as count FROM mixed_users"); + const result = await countStmt.get(); + t.is(result.count, 21); // 1 initial + 20 new records + + await cleanup(t.context); +}); + +test("Concurrent operations with timeout should handle busy database", async (t) => { + const timeout = 1000; + const path = `test-${crypto.randomBytes(8).toString('hex')}.db`; + const [conn1] = await connect(path); + const [conn2] = await connect(path, { timeout }); + + await conn1.exec("CREATE TABLE t(id TEXT PRIMARY KEY, x INTEGER)"); + await conn1.exec("BEGIN IMMEDIATE"); + await conn1.exec(`INSERT INTO t VALUES ('${generateUUID()}', 1)`); + + const start = Date.now(); + try { + await conn2.exec(`INSERT INTO t VALUES ('${generateUUID()}', 2)`); + t.fail("Should have thrown SQLITE_BUSY error"); + } catch (e) { + t.is(e.code, "SQLITE_BUSY"); + const end = Date.now(); + const elapsed = end - start; + t.true(elapsed > timeout / 2, "Timeout should be respected"); + } + + conn1.close(); + conn2.close(); + // FIXME: Fails on Windows because file is still busy. + // fs.unlinkSync(path); +}); + + +const connect = async (path_opt, options = {}) => { + const path = path_opt ?? `test-${crypto.randomBytes(8).toString('hex')}.db`; + const x = await import("libsql/promise"); + const db = new x.default(process.env.LIBSQL_DATABASE ?? path, options); + return [db, x.SqliteError, path]; +}; + +const cleanup = async (context) => { + context.db.close(); + // FIXME: Fails on Windows because file is still busy. + // fs.unlinkSync(context.path); +}; + +const generateUUID = () => { + return crypto.randomUUID(); +}; diff --git a/integration-tests/tests/sync.test.js b/integration-tests/tests/sync.test.js index 06a1c79b..a5add279 100644 --- a/integration-tests/tests/sync.test.js +++ b/integration-tests/tests/sync.test.js @@ -58,6 +58,11 @@ test.serial("Statement.run() [positional]", async (t) => { const info = stmt.run(["Carol", "carol@example.net"]); t.is(info.changes, 1); t.is(info.lastInsertRowid, 3); + + // Verify that the data is inserted + const stmt2 = db.prepare("SELECT * FROM users WHERE id = 3"); + t.is(stmt2.get().name, "Carol"); + t.is(stmt2.get().email, "carol@example.net"); }); test.serial("Statement.run() [named]", async (t) => { diff --git a/src/lib.rs b/src/lib.rs index 1b0b98ca..8d831a9d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -654,9 +654,8 @@ impl Statement { let start = std::time::Instant::now(); let mut stmt = self.stmt.lock().await; - stmt.reset(); let params = map_params(&stmt, params)?; - stmt.query(params).await.map_err(Error::from)?; + stmt.run(params).await.map_err(Error::from)?; let changes = if conn.total_changes() == total_changes_before { 0 } else { @@ -664,6 +663,7 @@ impl Statement { }; let last_insert_row_id = conn.last_insert_rowid(); let duration = start.elapsed().as_secs_f64(); + stmt.reset(); Ok(RunResult { changes: changes as f64, duration,