diff --git a/Cargo.toml b/Cargo.toml index 9d00049..29b7a8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] -libsql = { version = "0.9.18", features = ["encryption"] } +libsql = { version = "0.9.20", features = ["encryption"] } napi = { version = "2", default-features = false, features = ["napi6", "tokio_rt", "async"] } napi-derive = "2" once_cell = "1.18.0" diff --git a/index.d.ts b/index.d.ts index c4834ac..6dc7609 100644 --- a/index.d.ts +++ b/index.d.ts @@ -25,6 +25,12 @@ export interface SyncResult { export declare function databasePrepareSync(db: Database, sql: string): Statement /** Syncs the database in blocking mode. */ export declare function databaseSyncSync(db: Database): SyncResult +/** Executes SQL in blocking mode. */ +export declare function databaseExecSync(db: Database, sql: string): void +/** Gets first row from statement in blocking mode. */ +export declare function statementGetSync(stmt: Statement, params?: unknown | undefined | null): unknown +/** Runs a statement in blocking mode. */ +export declare function statementRunSync(stmt: Statement, params?: unknown | undefined | null): RunResult export declare function statementIterateSync(stmt: Statement, params?: unknown | undefined | null): RowsIterator /** SQLite `run()` result object */ export interface RunResult { @@ -109,7 +115,7 @@ export declare class Database { * * `env` - The environment. * * `sql` - The SQL statement to execute. */ - exec(sql: string): void + exec(sql: string): Promise /** * Syncs the database. * @@ -146,7 +152,7 @@ export declare class Statement { * * * `params` - The parameters to bind to the statement. */ - run(params?: unknown | undefined | null): RunResult + run(params?: unknown | undefined | null): object /** * Executes a SQL statement and returns the first row. * @@ -155,7 +161,7 @@ export declare class Statement { * * `env` - The environment. * * `params` - The parameters to bind to the statement. */ - get(params?: unknown | undefined | null): unknown + get(params?: unknown | undefined | null): object /** * Create an iterator over the rows of a statement. * diff --git a/index.js b/index.js index 4af5a23..95b3e00 100644 --- a/index.js +++ b/index.js @@ -310,12 +310,15 @@ if (!nativeBinding) { throw new Error(`Failed to load native binding`) } -const { Database, databasePrepareSync, databaseSyncSync, Statement, statementIterateSync, RowsIterator, iteratorNextSync, Record } = nativeBinding +const { Database, databasePrepareSync, databaseSyncSync, databaseExecSync, Statement, statementGetSync, statementRunSync, statementIterateSync, RowsIterator, iteratorNextSync, Record } = nativeBinding module.exports.Database = Database module.exports.databasePrepareSync = databasePrepareSync module.exports.databaseSyncSync = databaseSyncSync +module.exports.databaseExecSync = databaseExecSync module.exports.Statement = Statement +module.exports.statementGetSync = statementGetSync +module.exports.statementRunSync = statementRunSync module.exports.statementIterateSync = statementIterateSync module.exports.RowsIterator = RowsIterator module.exports.iteratorNextSync = iteratorNextSync diff --git a/integration-tests/tests/async.test.js b/integration-tests/tests/async.test.js index 7eaafd0..84f77a2 100644 --- a/integration-tests/tests/async.test.js +++ b/integration-tests/tests/async.test.js @@ -47,21 +47,21 @@ test.serial("Statement.run() [positional]", async (t) => { const db = t.context.db; const stmt = await db.prepare("INSERT INTO users(name, email) VALUES (?, ?)"); - const info = stmt.run(["Carol", "carol@example.net"]); + const info = await stmt.run(["Carol", "carol@example.net"]); t.is(info.changes, 1); t.is(info.lastInsertRowid, 3); // Verify that the data is inserted const stmt2 = await db.prepare("SELECT * FROM users WHERE id = 3"); - t.is(stmt2.get().name, "Carol"); - t.is(stmt2.get().email, "carol@example.net"); + t.is((await stmt2.get()).name, "Carol"); + t.is((await stmt2.get()).email, "carol@example.net"); }); test.serial("Statement.get() returns no rows", async (t) => { const db = t.context.db; const stmt = await db.prepare("SELECT * FROM users WHERE id = 0"); - t.is(stmt.get(), undefined); + t.is((await stmt.get()), undefined); }); test.serial("Statement.get() [no parameters]", async (t) => { @@ -70,7 +70,7 @@ test.serial("Statement.get() [no parameters]", async (t) => { var stmt = 0; stmt = await db.prepare("SELECT * FROM users"); - t.is(stmt.get().name, "Alice"); + t.is((await stmt.get()).name, "Alice"); t.deepEqual(await stmt.raw().get(), [1, 'Alice', 'alice@example.org']); }); @@ -80,15 +80,15 @@ test.serial("Statement.get() [positional]", async (t) => { var stmt = 0; stmt = await db.prepare("SELECT * FROM users WHERE id = ?"); - t.is(stmt.get(0), undefined); - t.is(stmt.get([0]), undefined); - t.is(stmt.get(1).name, "Alice"); - t.is(stmt.get(2).name, "Bob"); + t.is((await stmt.get(0)), undefined); + t.is((await stmt.get([0])), undefined); + t.is((await stmt.get(1)).name, "Alice"); + t.is((await stmt.get(2)).name, "Bob"); stmt = await db.prepare("SELECT * FROM users WHERE id = ?1"); - t.is(stmt.get({1: 0}), undefined); - t.is(stmt.get({1: 1}).name, "Alice"); - t.is(stmt.get({1: 2}).name, "Bob"); + t.is((await stmt.get({1: 0})), undefined); + t.is((await stmt.get({1: 1})).name, "Alice"); + t.is((await stmt.get({1: 2})).name, "Bob"); }); test.serial("Statement.get() [named]", async (t) => { @@ -97,19 +97,19 @@ test.serial("Statement.get() [named]", async (t) => { var stmt = undefined; stmt = await db.prepare("SELECT * FROM users WHERE id = :id"); - t.is(stmt.get({ id: 0 }), undefined); - t.is(stmt.get({ id: 1 }).name, "Alice"); - t.is(stmt.get({ id: 2 }).name, "Bob"); + t.is((await stmt.get({ id: 0 })), undefined); + t.is((await stmt.get({ id: 1 })).name, "Alice"); + t.is((await stmt.get({ id: 2 })).name, "Bob"); stmt = await db.prepare("SELECT * FROM users WHERE id = @id"); - t.is(stmt.get({ id: 0 }), undefined); - t.is(stmt.get({ id: 1 }).name, "Alice"); - t.is(stmt.get({ id: 2 }).name, "Bob"); + t.is((await stmt.get({ id: 0 })), undefined); + t.is((await stmt.get({ id: 1 })).name, "Alice"); + t.is((await stmt.get({ id: 2 })).name, "Bob"); stmt = await db.prepare("SELECT * FROM users WHERE id = $id"); - t.is(stmt.get({ id: 0 }), undefined); - t.is(stmt.get({ id: 1 }).name, "Alice"); - t.is(stmt.get({ id: 2 }).name, "Bob"); + t.is((await stmt.get({ id: 0 })), undefined); + t.is((await stmt.get({ id: 1 })).name, "Alice"); + t.is((await stmt.get({ id: 2 })).name, "Bob"); }); @@ -117,7 +117,7 @@ test.serial("Statement.get() [raw]", async (t) => { const db = t.context.db; const stmt = await db.prepare("SELECT * FROM users WHERE id = ?"); - t.deepEqual(stmt.raw().get(1), [1, "Alice", "alice@example.org"]); + t.deepEqual(await stmt.raw().get(1), [1, "Alice", "alice@example.org"]); }); test.serial("Statement.iterate() [empty]", async (t) => { @@ -253,9 +253,9 @@ test.serial("Database.transaction()", async (t) => { "INSERT INTO users(name, email) VALUES (:name, :email)" ); - const insertMany = db.transaction((users) => { + const insertMany = db.transaction(async (users) => { t.is(db.inTransaction, true); - for (const user of users) insert.run(user); + for (const user of users) await insert.run(user); }); t.is(db.inTransaction, false); @@ -267,9 +267,9 @@ test.serial("Database.transaction()", async (t) => { t.is(db.inTransaction, false); const stmt = await db.prepare("SELECT * FROM users WHERE id = ?"); - t.is(stmt.get(3).name, "Joey"); - t.is(stmt.get(4).name, "Sally"); - t.is(stmt.get(5).name, "Junior"); + t.is((await stmt.get(3)).name, "Joey"); + t.is((await stmt.get(4)).name, "Sally"); + t.is((await stmt.get(5)).name, "Junior"); }); test.serial("Database.transaction().immediate()", async (t) => { @@ -277,9 +277,9 @@ test.serial("Database.transaction().immediate()", async (t) => { const insert = await db.prepare( "INSERT INTO users(name, email) VALUES (:name, :email)" ); - const insertMany = db.transaction((users) => { + const insertMany = db.transaction(async (users) => { t.is(db.inTransaction, true); - for (const user of users) insert.run(user); + for (const user of users) await insert.run(user); }); t.is(db.inTransaction, false); await insertMany.immediate([ diff --git a/integration-tests/tests/concurrency.test.js b/integration-tests/tests/concurrency.test.js index 2281606..d4d1bac 100644 --- a/integration-tests/tests/concurrency.test.js +++ b/integration-tests/tests/concurrency.test.js @@ -32,8 +32,8 @@ test("Concurrent reads", async (t) => { const promises = []; for (let i = 0; i < 100; i++) { - promises.push(stmt.get(t.context.aliceId)); - promises.push(stmt.get(t.context.bobId)); + promises.push(await stmt.get(t.context.aliceId)); + promises.push(await stmt.get(t.context.bobId)); } const results = await Promise.all(promises); @@ -63,7 +63,7 @@ test("Concurrent writes", async (t) => { const promises = []; for (let i = 0; i < 50; i++) { - promises.push(stmt.run({ + promises.push(await stmt.run({ id: generateUUID(), name: `User${i}`, email: `user${i}@example.com` @@ -79,49 +79,6 @@ test("Concurrent writes", async (t) => { cleanup(t.context); }); -test("Concurrent transaction isolation", async (t) => { - const db = t.context.db; - - await db.exec(` - DROP TABLE IF EXISTS transaction_users; - CREATE TABLE transaction_users ( - id TEXT PRIMARY KEY, - name TEXT, - email TEXT - ) - `); - - const aliceId = generateUUID(); - const bobId = generateUUID(); - - await db.exec(` - INSERT INTO transaction_users (id, name, email) VALUES - ('${aliceId}', 'Alice', 'alice@example.org'), - ('${bobId}', 'Bob', 'bob@example.com') - `); - - const updateUser = db.transaction(async (id, name, email) => { - const stmt = await db.prepare("UPDATE transaction_users SET name = :name, email = :email WHERE id = :id"); - await stmt.run({ id, name, email }); - }); - - const promises = []; - for (let i = 0; i < 10; i++) { - promises.push(updateUser(aliceId, `Alice${i}`, `alice${i}@example.org`)); - promises.push(updateUser(bobId, `Bob${i}`, `bob${i}@example.com`)); - } - - await Promise.all(promises); - - const stmt = await db.prepare("SELECT * FROM transaction_users ORDER BY name"); - const results = await stmt.all(); - t.is(results.length, 2); - t.truthy(results[0].name.startsWith('Alice')); - t.truthy(results[1].name.startsWith('Bob')); - - cleanup(t.context); -}); - test("Concurrent reads and writes", async (t) => { const db = t.context.db; @@ -146,7 +103,7 @@ test("Concurrent reads and writes", async (t) => { const promises = []; for (let i = 0; i < 20; i++) { promises.push(readStmt.get(aliceId)); - writeStmt.run({ + await writeStmt.run({ id: generateUUID(), name: `User${i}`, email: `user${i}@example.com` diff --git a/promise.js b/promise.js index 80b44d8..37b12e0 100644 --- a/promise.js +++ b/promise.js @@ -88,14 +88,14 @@ class Database { const db = this; const wrapTxn = (mode) => { - return (...bindParameters) => { - db.exec("BEGIN " + mode); + return async (...bindParameters) => { + await db.exec("BEGIN " + mode); try { - const result = fn(...bindParameters); - db.exec("COMMIT"); + const result = await fn(...bindParameters); + await db.exec("COMMIT"); return result; } catch (err) { - db.exec("ROLLBACK"); + await db.exec("ROLLBACK"); throw err; } }; @@ -172,9 +172,9 @@ class Database { * * @param {string} sql - The SQL statement string to execute. */ - exec(sql) { + async exec(sql) { try { - this.db.exec(sql); + await this.db.exec(sql); } catch (err) { throw convertError(err); } @@ -257,9 +257,9 @@ class Statement { /** * Executes the SQL statement and returns an info object. */ - run(...bindParameters) { + async run(...bindParameters) { try { - return this.stmt.run(...bindParameters); + return await this.stmt.run(...bindParameters); } catch (err) { throw convertError(err); } @@ -270,9 +270,9 @@ class Statement { * * @param bindParameters - The bind parameters for executing the statement. */ - get(...bindParameters) { + async get(...bindParameters) { try { - return this.stmt.get(...bindParameters); + return await this.stmt.get(...bindParameters); } catch (err) { throw convertError(err); } diff --git a/src/lib.rs b/src/lib.rs index 20db556..332b055 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,7 +36,7 @@ use std::{ }, time::Duration, }; -use tokio::{runtime::Runtime, sync::Mutex}; +use tokio::runtime::Runtime; use tracing_subscriber::{filter::LevelFilter, EnvFilter}; struct Error(libsql::Error); @@ -219,7 +219,7 @@ pub struct Database { // The libSQL database instance. db: libsql::Database, // The libSQL connection instance. - conn: Option>>, + conn: Option>, // Whether to use safe integers by default. default_safe_integers: AtomicBool, // Whether to use memory-only mode. @@ -334,7 +334,7 @@ impl Database { } Ok(Database { db, - conn: Some(Arc::new(Mutex::new(conn))), + conn: Some(Arc::new(conn)), default_safe_integers, memory, }) @@ -349,15 +349,11 @@ impl Database { /// Returns whether the database is in a transaction. #[napi] pub fn in_transaction(&self) -> Result { - let rt = runtime()?; let conn = match &self.conn { Some(conn) => conn.clone(), None => return Ok(false), }; - Ok(rt.block_on(async move { - let conn = conn.lock().await; - !conn.is_autocommit() - })) + Ok(!conn.is_autocommit()) } /// Prepares a statement for execution. @@ -381,10 +377,7 @@ impl Database { )); } }; - let stmt = { - let conn = conn.lock().await; - conn.prepare(&sql).await.map_err(Error::from)? - }; + let stmt = { conn.prepare(&sql).await.map_err(Error::from)? }; let mode = AccessMode { safe_ints: self.default_safe_integers.load(Ordering::SeqCst).into(), raw: false.into(), @@ -453,10 +446,7 @@ impl Database { let auth_arc = auth_arc.clone(); move |ctx: &libsql::AuthContext| auth_arc.authorize(ctx) }; - let rt = runtime()?; - let guard_conn = rt.block_on(async { conn.lock().await }); - guard_conn - .authorizer(Some(std::sync::Arc::new(closure))) + conn.authorizer(Some(std::sync::Arc::new(closure))) .map_err(Error::from)?; Ok(()) } @@ -482,7 +472,6 @@ impl Database { } }; rt.block_on(async move { - let conn = conn.lock().await; conn.load_extension_enable().map_err(Error::from)?; if let Err(err) = conn.load_extension(&path, entry_point.as_deref()) { let _ = conn.load_extension_disable(); @@ -512,17 +501,18 @@ impl Database { /// * `env` - The environment. /// * `sql` - The SQL statement to execute. #[napi] - pub fn exec(&self, env: Env, sql: String) -> Result<()> { - let rt = runtime()?; + pub async fn exec(&self, sql: String) -> Result<()> { let conn = match &self.conn { Some(conn) => conn.clone(), - None => return Err(throw_database_closed_error(&env).into()), + None => { + return Err(throw_sqlite_error( + "The database connection is not open".to_string(), + "SQLITE_NOTOPEN".to_string(), + 0, + )); + } }; - rt.block_on(async move { - let conn = conn.lock().await; - conn.execute_batch(&sql).await - }) - .map_err(Error::from)?; + conn.execute_batch(&sql).await.map_err(Error::from)?; Ok(()) } @@ -547,16 +537,11 @@ impl Database { /// * `env` - The environment. #[napi] pub fn interrupt(&self, env: Env) -> Result<()> { - let rt = runtime()?; let conn = match &self.conn { Some(conn) => conn.clone(), None => return Err(throw_database_closed_error(&env).into()), }; - rt.block_on(async move { - let conn = conn.lock().await; - conn.interrupt() - }) - .map_err(Error::from)?; + conn.interrupt().map_err(Error::from)?; Ok(()) } @@ -603,6 +588,13 @@ pub fn database_sync_sync(db: &Database) -> Result { rt.block_on(async move { db.sync().await }) } +/// Executes SQL in blocking mode. +#[napi] +pub fn database_exec_sync(db: &Database, sql: String) -> Result<()> { + let rt = runtime()?; + rt.block_on(async move { db.exec(sql).await }) +} + fn is_remote_path(path: &str) -> bool { path.starts_with("libsql://") || path.starts_with("http://") || path.starts_with("https://") } @@ -618,9 +610,9 @@ fn throw_database_closed_error(env: &Env) -> napi::Error { #[napi] pub struct Statement { // The libSQL connection instance. - conn: Arc>, + conn: Arc, // The libSQL statement instance. - stmt: Arc>, + stmt: Arc, // The column names. column_names: Vec, // The access mode. @@ -637,7 +629,7 @@ impl Statement { /// * `stmt` - The libSQL statement instance. /// * `mode` - The access mode. pub(crate) fn new( - conn: Arc>, + conn: Arc, stmt: libsql::Statement, mode: AccessMode, ) -> Self { @@ -646,7 +638,7 @@ impl Statement { .iter() .map(|c| std::ffi::CString::new(c.name().to_string()).unwrap()) .collect(); - let stmt = Arc::new(tokio::sync::Mutex::new(stmt)); + let stmt = Arc::new(stmt); Self { conn, stmt, @@ -661,15 +653,15 @@ impl Statement { /// /// * `params` - The parameters to bind to the statement. #[napi] - pub fn run(&self, params: Option) -> Result { - let rt = runtime()?; - rt.block_on(async move { - let conn = self.conn.lock().await; - let total_changes_before = conn.total_changes(); - let start = std::time::Instant::now(); + pub fn run(&self, env: Env, params: Option) -> Result { + self.stmt.reset(); + let params = map_params(&self.stmt, params)?; + let total_changes_before = self.conn.total_changes(); + let start = std::time::Instant::now(); + let stmt = self.stmt.clone(); + let conn = self.conn.clone(); - let mut stmt = self.stmt.lock().await; - let params = map_params(&stmt, params)?; + let future = async move { stmt.run(params).await.map_err(Error::from)?; let changes = if conn.total_changes() == total_changes_before { 0 @@ -678,13 +670,14 @@ impl Statement { }; let last_insert_row_id = conn.last_insert_rowid(); let duration = start.elapsed().as_secs_f64(); - stmt.reset(); Ok(RunResult { changes: changes as f64, duration, lastInsertRowid: last_insert_row_id, }) - }) + }; + + env.execute_tokio_future(future, move |&mut _env, result| Ok(result)) } /// Executes a SQL statement and returns the first row. @@ -694,36 +687,35 @@ impl Statement { /// * `env` - The environment. /// * `params` - The parameters to bind to the statement. #[napi] - pub fn get(&self, env: Env, params: Option) -> Result { - let rt = runtime()?; - + pub fn get(&self, env: Env, params: Option) -> Result { let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst); let raw = self.mode.raw.load(Ordering::SeqCst); let pluck = self.mode.pluck.load(Ordering::SeqCst); let timed = self.mode.timing.load(Ordering::SeqCst); + let params = map_params(&self.stmt, params)?; + let stmt = self.stmt.clone(); + let column_names = self.column_names.clone(); + let start = if timed { Some(std::time::Instant::now()) } else { None }; - rt.block_on(async move { - let mut stmt = self.stmt.lock().await; - let params = map_params(&stmt, params)?; - let mut rows = stmt.query(params).await.map_err(Error::from)?; + + let stmt_fut = stmt.clone(); + let future = async move { + let mut rows = stmt_fut.query(params).await.map_err(Error::from)?; let row = rows.next().await.map_err(Error::from)?; let duration: Option = start.map(|start| start.elapsed().as_secs_f64()); - let result = Self::get_internal( - &env, - &row, - &self.column_names, - safe_ints, - raw, - pluck, - duration, - ); + Ok((row, duration)) + }; + + env.execute_tokio_future(future, move |&mut env, (row, duration)| { + let result = + Self::get_internal(&env, &row, &column_names, safe_ints, raw, pluck, duration); stmt.reset(); - result + Ok(result) }) } @@ -769,22 +761,15 @@ impl Statement { /// * `params` - The parameters to bind to the statement. #[napi] pub fn iterate(&self, env: Env, params: Option) -> Result { - let rt = runtime()?; let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst); let raw = self.mode.raw.load(Ordering::SeqCst); let pluck = self.mode.pluck.load(Ordering::SeqCst); let stmt = self.stmt.clone(); - let params = { - let stmt = stmt.clone(); - rt.block_on(async move { - let mut stmt = stmt.lock().await; - stmt.reset(); - map_params(&stmt, params).unwrap() - }) - }; + stmt.reset(); + let params = map_params(&stmt, params).unwrap(); let stmt = self.stmt.clone(); let future = async move { - let rows = stmt.lock().await.query(params).await.map_err(Error::from)?; + let rows = stmt.query(params).await.map_err(Error::from)?; Ok::<_, napi::Error>(rows) }; let column_names = self.column_names.clone(); @@ -801,11 +786,7 @@ impl Statement { #[napi] pub fn raw(&self, raw: Option) -> Result<&Self> { - let rt = runtime()?; - let returns_data = rt.block_on(async move { - let stmt = self.stmt.lock().await; - !stmt.columns().is_empty() - }); + let returns_data = !self.stmt.columns().is_empty(); if !returns_data { return Err(napi::Error::from_reason( "The raw() method is only for statements that return data", @@ -833,42 +814,38 @@ impl Statement { #[napi] pub fn columns(&self, env: Env) -> Result { - let rt = runtime()?; - rt.block_on(async move { - let stmt = self.stmt.lock().await; - let columns = stmt.columns(); - let mut js_array = env.create_array(columns.len() as u32)?; - for (i, col) in columns.iter().enumerate() { - let mut js_obj = env.create_object()?; - js_obj.set_named_property("name", env.create_string(col.name())?)?; - // origin_name -> column - if let Some(origin_name) = col.origin_name() { - js_obj.set_named_property("column", env.create_string(origin_name)?)?; - } else { - js_obj.set_named_property("column", env.get_null()?)?; - } - // table_name -> table - if let Some(table_name) = col.table_name() { - js_obj.set_named_property("table", env.create_string(table_name)?)?; - } else { - js_obj.set_named_property("table", env.get_null()?)?; - } - // database_name -> database - if let Some(database_name) = col.database_name() { - js_obj.set_named_property("database", env.create_string(database_name)?)?; - } else { - js_obj.set_named_property("database", env.get_null()?)?; - } - // decl_type -> type - if let Some(decl_type) = col.decl_type() { - js_obj.set_named_property("type", env.create_string(decl_type)?)?; - } else { - js_obj.set_named_property("type", env.get_null()?)?; - } - js_array.set(i as u32, js_obj)?; + let columns = self.stmt.columns(); + let mut js_array = env.create_array(columns.len() as u32)?; + for (i, col) in columns.iter().enumerate() { + let mut js_obj = env.create_object()?; + js_obj.set_named_property("name", env.create_string(col.name())?)?; + // origin_name -> column + if let Some(origin_name) = col.origin_name() { + js_obj.set_named_property("column", env.create_string(origin_name)?)?; + } else { + js_obj.set_named_property("column", env.get_null()?)?; } - Ok(js_array) - }) + // table_name -> table + if let Some(table_name) = col.table_name() { + js_obj.set_named_property("table", env.create_string(table_name)?)?; + } else { + js_obj.set_named_property("table", env.get_null()?)?; + } + // database_name -> database + if let Some(database_name) = col.database_name() { + js_obj.set_named_property("database", env.create_string(database_name)?)?; + } else { + js_obj.set_named_property("database", env.get_null()?)?; + } + // decl_type -> type + if let Some(decl_type) = col.decl_type() { + js_obj.set_named_property("type", env.create_string(decl_type)?)?; + } else { + js_obj.set_named_property("type", env.get_null()?)?; + } + js_array.set(i as u32, js_obj)?; + } + Ok(js_array) } #[napi] @@ -881,16 +858,75 @@ impl Statement { #[napi] pub fn interrupt(&self) -> Result<()> { - let rt = runtime()?; - rt.block_on(async move { - let mut stmt = self.stmt.lock().await; - stmt.interrupt() - }) - .map_err(Error::from)?; + self.stmt.interrupt().map_err(Error::from)?; Ok(()) } } +/// Gets first row from statement in blocking mode. +#[napi] +pub fn statement_get_sync( + stmt: &Statement, + env: Env, + params: Option, +) -> Result { + let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst); + let raw = stmt.mode.raw.load(Ordering::SeqCst); + let pluck = stmt.mode.pluck.load(Ordering::SeqCst); + let timed = stmt.mode.timing.load(Ordering::SeqCst); + + let start = if timed { + Some(std::time::Instant::now()) + } else { + None + }; + + let rt = runtime()?; + rt.block_on(async move { + let params = map_params(&stmt.stmt, params)?; + let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?; + let row = rows.next().await.map_err(Error::from)?; + let duration: Option = start.map(|start| start.elapsed().as_secs_f64()); + let result = Statement::get_internal( + &env, + &row, + &stmt.column_names, + safe_ints, + raw, + pluck, + duration, + ); + stmt.stmt.reset(); + result + }) +} + +/// Runs a statement in blocking mode. +#[napi] +pub fn statement_run_sync(stmt: &Statement, params: Option) -> Result { + stmt.stmt.reset(); + let rt = runtime()?; + rt.block_on(async move { + let params = map_params(&stmt.stmt, params)?; + let total_changes_before = stmt.conn.total_changes(); + let start = std::time::Instant::now(); + + stmt.stmt.run(params).await.map_err(Error::from)?; + let changes = if stmt.conn.total_changes() == total_changes_before { + 0 + } else { + stmt.conn.changes() + }; + let last_insert_row_id = stmt.conn.last_insert_rowid(); + let duration = start.elapsed().as_secs_f64(); + Ok(RunResult { + changes: changes as f64, + duration, + lastInsertRowid: last_insert_row_id, + }) + }) +} + #[napi] pub fn statement_iterate_sync( stmt: &Statement, @@ -903,7 +939,6 @@ pub fn statement_iterate_sync( let pluck = stmt.mode.pluck.load(Ordering::SeqCst); let stmt = stmt.stmt.clone(); let (rows, column_names) = rt.block_on(async move { - let mut stmt = stmt.lock().await; stmt.reset(); let params = map_params(&stmt, params)?; let rows = stmt.query(params).await.map_err(Error::from)?; diff --git a/wrapper.js b/wrapper.js index 732b19f..b220b19 100644 --- a/wrapper.js +++ b/wrapper.js @@ -1,6 +1,6 @@ "use strict"; -const { Database: NativeDb, databasePrepareSync, databaseSyncSync, statementIterateSync, iteratorNextSync } = require("./index.js"); +const { Database: NativeDb, databasePrepareSync, databaseSyncSync, databaseExecSync, statementRunSync, statementGetSync, statementIterateSync, iteratorNextSync } = require("./index.js"); const SqliteError = require("./sqlite-error.js"); const Authorization = require("./auth"); @@ -178,7 +178,7 @@ class Database { */ exec(sql) { try { - this.db.exec(sql); + databaseExecSync(this.db, sql); } catch (err) { throw convertError(err); } @@ -263,7 +263,7 @@ class Statement { */ run(...bindParameters) { try { - return this.stmt.run(...bindParameters); + return statementRunSync(this.stmt, ...bindParameters); } catch (err) { throw convertError(err); } @@ -276,7 +276,7 @@ class Statement { */ get(...bindParameters) { try { - return this.stmt.get(...bindParameters); + return statementGetSync(this.stmt, ...bindParameters); } catch (err) { throw convertError(err); }