From 45b40c34e0b475af5b84afd4c8aecac2e7ca1114 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Wed, 9 Apr 2025 08:41:23 +0300 Subject: [PATCH] libsql: Add authorizer support This adds support for `sqlite3_set_authorizer()`. The API follows rusqlite and the enumerations are lifted straight from rusqlite for compatibility. --- libsql/src/auth.rs | 263 ++++++++++++++++++++++++++++++ libsql/src/connection.rs | 11 ++ libsql/src/errors.rs | 2 + libsql/src/lib.rs | 4 +- libsql/src/local/connection.rs | 93 +++++++++++ libsql/src/local/impls.rs | 6 +- libsql/tests/integration_tests.rs | 73 ++++++++- 7 files changed, 449 insertions(+), 3 deletions(-) create mode 100644 libsql/src/auth.rs diff --git a/libsql/src/auth.rs b/libsql/src/auth.rs new file mode 100644 index 0000000000..927cfad9bb --- /dev/null +++ b/libsql/src/auth.rs @@ -0,0 +1,263 @@ +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct AuthContext<'a> { + pub action: AuthAction<'a>, + + pub database_name: Option<&'a str>, + + pub accessor: Option<&'a str>, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum AuthAction<'a> { + Unknown { + code: i32, + arg1: Option<&'a str>, + arg2: Option<&'a str>, + }, + CreateIndex { + index_name: &'a str, + table_name: &'a str, + }, + CreateTable { + table_name: &'a str, + }, + CreateTempIndex { + index_name: &'a str, + table_name: &'a str, + }, + CreateTempTable { + table_name: &'a str, + }, + CreateTempTrigger { + trigger_name: &'a str, + table_name: &'a str, + }, + CreateTempView { + view_name: &'a str, + }, + CreateTrigger { + trigger_name: &'a str, + table_name: &'a str, + }, + CreateView { + view_name: &'a str, + }, + Delete { + table_name: &'a str, + }, + DropIndex { + index_name: &'a str, + table_name: &'a str, + }, + DropTable { + table_name: &'a str, + }, + DropTempIndex { + index_name: &'a str, + table_name: &'a str, + }, + DropTempTable { + table_name: &'a str, + }, + DropTempTrigger { + trigger_name: &'a str, + table_name: &'a str, + }, + DropTempView { + view_name: &'a str, + }, + DropTrigger { + trigger_name: &'a str, + table_name: &'a str, + }, + DropView { + view_name: &'a str, + }, + Insert { + table_name: &'a str, + }, + Pragma { + pragma_name: &'a str, + pragma_value: Option<&'a str>, + }, + Read { + table_name: &'a str, + column_name: &'a str, + }, + Select, + Transaction { + operation: TransactionOperation, + }, + Update { + table_name: &'a str, + column_name: &'a str, + }, + Attach { + filename: &'a str, + }, + Detach { + database_name: &'a str, + }, + AlterTable { + database_name: &'a str, + table_name: &'a str, + }, + Reindex { + index_name: &'a str, + }, + Analyze { + table_name: &'a str, + }, + CreateVtable { + table_name: &'a str, + module_name: &'a str, + }, + DropVtable { + table_name: &'a str, + module_name: &'a str, + }, + Function { + function_name: &'a str, + }, + Savepoint { + operation: TransactionOperation, + savepoint_name: &'a str, + }, + Recursive, +} + +#[cfg(feature = "core")] +impl<'a> AuthAction<'a> { + pub(crate) fn from_raw(code: i32, arg1: Option<&'a str>, arg2: Option<&'a str>) -> Self { + use libsql_sys::ffi; + + match (code, arg1, arg2) { + (ffi::SQLITE_CREATE_INDEX, Some(index_name), Some(table_name)) => Self::CreateIndex { + index_name, + table_name, + }, + (ffi::SQLITE_CREATE_TABLE, Some(table_name), _) => Self::CreateTable { table_name }, + (ffi::SQLITE_CREATE_TEMP_INDEX, Some(index_name), Some(table_name)) => { + Self::CreateTempIndex { + index_name, + table_name, + } + } + (ffi::SQLITE_CREATE_TEMP_TABLE, Some(table_name), _) => { + Self::CreateTempTable { table_name } + } + (ffi::SQLITE_CREATE_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => { + Self::CreateTempTrigger { + trigger_name, + table_name, + } + } + (ffi::SQLITE_CREATE_TEMP_VIEW, Some(view_name), _) => { + Self::CreateTempView { view_name } + } + (ffi::SQLITE_CREATE_TRIGGER, Some(trigger_name), Some(table_name)) => { + Self::CreateTrigger { + trigger_name, + table_name, + } + } + (ffi::SQLITE_CREATE_VIEW, Some(view_name), _) => Self::CreateView { view_name }, + (ffi::SQLITE_DELETE, Some(table_name), None) => Self::Delete { table_name }, + (ffi::SQLITE_DROP_INDEX, Some(index_name), Some(table_name)) => Self::DropIndex { + index_name, + table_name, + }, + (ffi::SQLITE_DROP_TABLE, Some(table_name), _) => Self::DropTable { table_name }, + (ffi::SQLITE_DROP_TEMP_INDEX, Some(index_name), Some(table_name)) => { + Self::DropTempIndex { + index_name, + table_name, + } + } + (ffi::SQLITE_DROP_TEMP_TABLE, Some(table_name), _) => { + Self::DropTempTable { table_name } + } + (ffi::SQLITE_DROP_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => { + Self::DropTempTrigger { + trigger_name, + table_name, + } + } + (ffi::SQLITE_DROP_TEMP_VIEW, Some(view_name), _) => Self::DropTempView { view_name }, + (ffi::SQLITE_DROP_TRIGGER, Some(trigger_name), Some(table_name)) => Self::DropTrigger { + trigger_name, + table_name, + }, + (ffi::SQLITE_DROP_VIEW, Some(view_name), _) => Self::DropView { view_name }, + (ffi::SQLITE_INSERT, Some(table_name), _) => Self::Insert { table_name }, + (ffi::SQLITE_PRAGMA, Some(pragma_name), pragma_value) => Self::Pragma { + pragma_name, + pragma_value, + }, + (ffi::SQLITE_READ, Some(table_name), Some(column_name)) => Self::Read { + table_name, + column_name, + }, + (ffi::SQLITE_SELECT, ..) => Self::Select, + (ffi::SQLITE_TRANSACTION, Some(operation_str), _) => Self::Transaction { + operation: TransactionOperation::from_str(operation_str), + }, + (ffi::SQLITE_UPDATE, Some(table_name), Some(column_name)) => Self::Update { + table_name, + column_name, + }, + (ffi::SQLITE_ATTACH, Some(filename), _) => Self::Attach { filename }, + (ffi::SQLITE_DETACH, Some(database_name), _) => Self::Detach { database_name }, + (ffi::SQLITE_ALTER_TABLE, Some(database_name), Some(table_name)) => Self::AlterTable { + database_name, + table_name, + }, + (ffi::SQLITE_REINDEX, Some(index_name), _) => Self::Reindex { index_name }, + (ffi::SQLITE_ANALYZE, Some(table_name), _) => Self::Analyze { table_name }, + (ffi::SQLITE_CREATE_VTABLE, Some(table_name), Some(module_name)) => { + Self::CreateVtable { + table_name, + module_name, + } + } + (ffi::SQLITE_DROP_VTABLE, Some(table_name), Some(module_name)) => Self::DropVtable { + table_name, + module_name, + }, + (ffi::SQLITE_FUNCTION, _, Some(function_name)) => Self::Function { function_name }, + (ffi::SQLITE_SAVEPOINT, Some(operation_str), Some(savepoint_name)) => Self::Savepoint { + operation: TransactionOperation::from_str(operation_str), + savepoint_name, + }, + (ffi::SQLITE_RECURSIVE, ..) => Self::Recursive, + (code, arg1, arg2) => Self::Unknown { code, arg1, arg2 }, + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum TransactionOperation { + Unknown, + Begin, + Release, + Rollback, +} + +#[cfg(feature = "core")] +impl TransactionOperation { + fn from_str(op_str: &str) -> Self { + match op_str { + "BEGIN" => Self::Begin, + "RELEASE" => Self::Release, + "ROLLBACK" => Self::Rollback, + _ => Self::Unknown, + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum Authorization { + Allow, + Ignore, + Deny, +} diff --git a/libsql/src/connection.rs b/libsql/src/connection.rs index 6dcc5869fa..c61348c8a6 100644 --- a/libsql/src/connection.rs +++ b/libsql/src/connection.rs @@ -4,12 +4,15 @@ use std::path::Path; use std::sync::Arc; use std::time::Duration; +use crate::auth::{AuthContext, Authorization}; use crate::params::{IntoParams, Params}; use crate::rows::Rows; use crate::statement::Statement; use crate::transaction::Transaction; use crate::{Result, TransactionBehavior}; +pub type AuthHook = Arc Authorization>; + #[async_trait::async_trait] pub(crate) trait Conn { async fn execute(&self, sql: &str, params: Params) -> Result; @@ -43,6 +46,10 @@ pub(crate) trait Conn { fn load_extension(&self, _dylib_path: &Path, _entry_point: Option<&str>) -> Result<()> { Err(crate::Error::LoadExtensionNotSupported) } + + fn authorizer(&self, _hook: Option) -> Result<()> { + Err(crate::Error::AuthorizerNotSupported) + } } /// A set of rows returned from `execute_batch`/`execute_transactional_batch`. It is essentially @@ -258,6 +265,10 @@ impl Connection { ) -> Result<()> { self.conn.load_extension(dylib_path.as_ref(), entry_point) } + + pub fn authorizer(&self, hook: Option) -> Result<()> { + self.conn.authorizer(hook) + } } impl fmt::Debug for Connection { diff --git a/libsql/src/errors.rs b/libsql/src/errors.rs index 8d3ed0e581..ad230039c5 100644 --- a/libsql/src/errors.rs +++ b/libsql/src/errors.rs @@ -21,6 +21,8 @@ pub enum Error { SyncNotSupported(String), // Not in rusqlite #[error("Loading extension is only supported in local databases.")] LoadExtensionNotSupported, // Not in rusqlite + #[error("Authorizer is only supported in local databases.")] + AuthorizerNotSupported, // Not in rusqlite #[error("Column not found: {0}")] ColumnNotFound(i32), // Not in rusqlite #[error("Hrana: `{0}`")] diff --git a/libsql/src/lib.rs b/libsql/src/lib.rs index 823ab89c84..15f98d8869 100644 --- a/libsql/src/lib.rs +++ b/libsql/src/lib.rs @@ -153,6 +153,7 @@ pub use errors::Error; pub use params::params_from_iter; +mod auth; mod connection; mod database; mod load_extension_guard; @@ -176,7 +177,8 @@ cfg_hrana! { } pub use self::{ - connection::{BatchRows, Connection}, + auth::{AuthAction, AuthContext, Authorization}, + connection::{AuthHook, BatchRows, Connection}, database::{Builder, Database}, load_extension_guard::LoadExtensionGuard, rows::{Column, Row, Rows}, diff --git a/libsql/src/local/connection.rs b/libsql/src/local/connection.rs index b9a2e20150..8e0bfec579 100644 --- a/libsql/src/local/connection.rs +++ b/libsql/src/local/connection.rs @@ -1,5 +1,7 @@ #![allow(dead_code)] +use crate::auth::{AuthAction, AuthContext, Authorization}; +use crate::connection::AuthHook; use crate::local::rows::BatchedRows; use crate::params::Params; use crate::{connection::BatchRows, errors}; @@ -22,6 +24,8 @@ pub struct Connection { #[cfg(feature = "replication")] pub(crate) writer: Option, + + authorizer: RefCell>, } impl Drop for Connection { @@ -64,6 +68,7 @@ impl Connection { drop_ref: Arc::new(()), #[cfg(feature = "replication")] writer: db.writer()?, + authorizer: RefCell::new(None), }; #[cfg(feature = "sync")] if let Some(_) = db.sync_ctx { @@ -90,11 +95,19 @@ impl Connection { drop_ref: Arc::new(()), #[cfg(feature = "replication")] writer: None, + authorizer: RefCell::new(None), } } /// Disconnect from the database. pub fn disconnect(&mut self) { + // Clean up the authorizer before closing + unsafe { + let rc = libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), None, std::ptr::null_mut()); + if rc != ffi::SQLITE_OK { + tracing::error!("Failed to clear authorizer during disconnect"); + } + } if Arc::get_mut(&mut self.drop_ref).is_some() { unsafe { libsql_sys::ffi::sqlite3_close_v2(self.raw) }; } @@ -458,6 +471,38 @@ impl Connection { } } + pub fn authorizer(&self, hook: Option) -> Result<()> { + unsafe { + let rc = libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), None, std::ptr::null_mut()); + if rc != ffi::SQLITE_OK { + return Err(crate::errors::Error::SqliteFailure( + rc as std::ffi::c_int, + "Failed to clear authorizer".to_string(), + )); + } + } + + *self.authorizer.borrow_mut() = hook.clone(); + + let (callback, user_data) = match hook { + Some(_) => { + let callback = authorizer_callback as unsafe extern "C" fn(_, _, _, _, _, _) -> _; + let user_data = self as *const Connection as *mut ::std::os::raw::c_void; + (Some(callback), user_data) + }, + None => (None, std::ptr::null_mut()), + }; + + let rc = unsafe { libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), callback, user_data) }; + if rc != ffi::SQLITE_OK { + return Err(crate::errors::Error::SqliteFailure( + rc as std::ffi::c_int, + "Failed to set authorizer".to_string(), + )); + } + Ok(()) + } + pub(crate) fn wal_checkpoint(&self, truncate: bool) -> Result<()> { let rc = unsafe { libsql_sys::ffi::sqlite3_wal_checkpoint_v2(self.handle(), std::ptr::null(), truncate as i32, std::ptr::null_mut(), std::ptr::null_mut()) }; if rc != 0 { @@ -570,6 +615,54 @@ impl Connection { } } +unsafe extern "C" fn authorizer_callback( + user_data: *mut ::std::os::raw::c_void, + code: ::std::os::raw::c_int, + arg1: *const ::std::os::raw::c_char, + arg2: *const ::std::os::raw::c_char, + database_name: *const ::std::os::raw::c_char, + accessor: *const ::std::os::raw::c_char, +) -> ::std::os::raw::c_int { + let conn = user_data as *const Connection; + let hook = unsafe { (*conn).authorizer.borrow() }; + let hook = match &*hook { + Some(hook) => hook, + None => return ffi::SQLITE_OK, + }; + let arg1 = if arg1.is_null() { + None + } else { + unsafe { std::ffi::CStr::from_ptr(arg1).to_str().ok() } + }; + + let arg2 = if arg2.is_null() { + None + } else { + unsafe { std::ffi::CStr::from_ptr(arg2).to_str().ok() } + }; + let database_name = if database_name.is_null() { + None + } else { + unsafe { std::ffi::CStr::from_ptr(database_name).to_str().ok() } + }; + let accessor = if accessor.is_null() { + None + } else { + unsafe { std::ffi::CStr::from_ptr(accessor).to_str().ok() } + }; + let action = AuthAction::from_raw(code, arg1, arg2); + let auth_context = AuthContext { + action, + database_name, + accessor, + }; + match hook(&auth_context) { + Authorization::Allow => ffi::SQLITE_OK, + Authorization::Deny => ffi::SQLITE_DENY, + Authorization::Ignore => ffi::SQLITE_IGNORE, + } +} + pub(crate) struct WalInsertHandle<'a> { conn: &'a Connection, in_session: RefCell diff --git a/libsql/src/local/impls.rs b/libsql/src/local/impls.rs index 34852ab196..30d219fb6a 100644 --- a/libsql/src/local/impls.rs +++ b/libsql/src/local/impls.rs @@ -4,7 +4,7 @@ use std::time::Duration; use crate::connection::BatchRows; use crate::{ - connection::Conn, + connection::{AuthHook, Conn}, params::Params, rows::{ColumnsInner, RowInner, RowsInner}, statement::Stmt, @@ -88,6 +88,10 @@ impl Conn for LibsqlConnection { fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> { self.conn.load_extension(dylib_path, entry_point) } + + fn authorizer(&self, hook: Option) -> Result<()> { + self.conn.authorizer(hook) + } } impl Drop for LibsqlConnection { diff --git a/libsql/tests/integration_tests.rs b/libsql/tests/integration_tests.rs index 92d8d358d8..2101a2e1ea 100644 --- a/libsql/tests/integration_tests.rs +++ b/libsql/tests/integration_tests.rs @@ -4,11 +4,12 @@ use futures::{StreamExt, TryStreamExt}; use libsql::{ named_params, params, params::{IntoParams, IntoValue}, - Connection, Database, Value, + AuthAction, Authorization, Connection, Database, Result, Value, }; use rand::distributions::Uniform; use rand::prelude::*; use std::collections::HashSet; +use std::sync::Arc; async fn setup() -> Connection { let db = Database::open(":memory:").unwrap(); @@ -783,3 +784,73 @@ async fn vector_fuzz_test() { let _ = conn.execute("REINDEX users;", ()).await.unwrap(); } } + +#[tokio::test] +async fn test_deny_authorizer() { + let db = Database::open(":memory:").unwrap(); + let conn = db.connect().unwrap(); + conn.execute("CREATE TABLE users (id INTEGER, name TEXT)", ()) + .await + .unwrap(); + conn.authorizer(Some(Arc::new(|ctx| { + assert_eq!( + ctx.action, + AuthAction::Insert { + table_name: "users" + } + ); + assert_eq!(ctx.database_name, Some("main")); + assert_eq!(ctx.accessor, None); + Authorization::Deny + }))) + .unwrap(); + let res = conn + .execute("INSERT INTO users (id, name) VALUES (1, 'Alice')", ()) + .await; + assert_sqlite_error(res, libsql::ffi::SQLITE_AUTH); + conn.authorizer(None).unwrap(); + conn.execute("INSERT INTO users (id, name) VALUES (1, 'Alice')", ()) + .await + .unwrap(); +} + +#[tokio::test] +async fn test_ignore_authorizer() { + let db = Database::open(":memory:").unwrap(); + let conn = db.connect().unwrap(); + conn.execute("CREATE TABLE users (id INTEGER, name TEXT)", ()) + .await + .unwrap(); + conn.authorizer(Some(Arc::new(|ctx| { + assert_eq!( + ctx.action, + AuthAction::Insert { + table_name: "users" + } + ); + assert_eq!(ctx.database_name, Some("main")); + assert_eq!(ctx.accessor, None); + Authorization::Ignore + }))) + .unwrap(); + conn.execute("INSERT INTO users (id, name) VALUES (1, 'Alice')", ()) + .await + .unwrap(); + conn.authorizer(None).unwrap(); + let rows = conn.query("SELECT * FROM users", ()).await.unwrap(); + // There should be no rows + assert_eq!(rows.into_stream().count().await, 0); +} + +fn assert_sqlite_error(res: Result, code: i32) { + match res { + Ok(_) => panic!("Expected error, got Ok"), + Err(e) => { + if let libsql::Error::SqliteFailure(c, _) = e { + assert!(c == code, "Expected error code {}, got {}", code, c); + } else { + panic!("Expected SqliteFailure, got {:?}", e); + } + } + } +}