Skip to content

Commit a8d8b02

Browse files
committed
add simple vector_add / vector_sub functions
1 parent 7fbef83 commit a8d8b02

6 files changed

Lines changed: 449 additions & 15 deletions

File tree

libsql-ffi/bundled/bindings/bindgen.rs

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ extern "C" {
2323
) -> ::std::os::raw::c_int;
2424
}
2525

26-
pub const __GNUC_VA_LIST: i32 = 1;
2726
pub const SQLITE_VERSION: &[u8; 7] = b"3.45.1\0";
2827
pub const SQLITE_VERSION_NUMBER: i32 = 3045001;
2928
pub const SQLITE_SOURCE_ID: &[u8; 85] =
@@ -502,8 +501,8 @@ pub const FTS5_TOKENIZE_DOCUMENT: i32 = 4;
502501
pub const FTS5_TOKENIZE_AUX: i32 = 8;
503502
pub const FTS5_TOKEN_COLOCATED: i32 = 1;
504503
pub const WAL_SAVEPOINT_NDATA: i32 = 4;
505-
pub type va_list = __builtin_va_list;
506-
pub type __gnuc_va_list = __builtin_va_list;
504+
pub type __gnuc_va_list = [u64; 4usize];
505+
pub type va_list = [u64; 4usize];
507506
extern "C" {
508507
pub static sqlite3_version: [::std::os::raw::c_char; 0usize];
509508
}
@@ -940,7 +939,7 @@ extern "C" {
940939
extern "C" {
941940
pub fn sqlite3_vmprintf(
942941
arg1: *const ::std::os::raw::c_char,
943-
arg2: *mut __va_list_tag,
942+
arg2: va_list,
944943
) -> *mut ::std::os::raw::c_char;
945944
}
946945
extern "C" {
@@ -956,7 +955,7 @@ extern "C" {
956955
arg1: ::std::os::raw::c_int,
957956
arg2: *mut ::std::os::raw::c_char,
958957
arg3: *const ::std::os::raw::c_char,
959-
arg4: *mut __va_list_tag,
958+
arg4: va_list,
960959
) -> *mut ::std::os::raw::c_char;
961960
}
962961
extern "C" {
@@ -2506,7 +2505,7 @@ extern "C" {
25062505
pub fn sqlite3_str_vappendf(
25072506
arg1: *mut sqlite3_str,
25082507
zFormat: *const ::std::os::raw::c_char,
2509-
arg2: *mut __va_list_tag,
2508+
arg2: va_list,
25102509
);
25112510
}
25122511
extern "C" {
@@ -3574,12 +3573,3 @@ extern "C" {
35743573
extern "C" {
35753574
pub static sqlite3_wal_manager: libsql_wal_manager;
35763575
}
3577-
pub type __builtin_va_list = [__va_list_tag; 1usize];
3578-
#[repr(C)]
3579-
#[derive(Debug, Copy, Clone)]
3580-
pub struct __va_list_tag {
3581-
pub gp_offset: ::std::os::raw::c_uint,
3582-
pub fp_offset: ::std::os::raw::c_uint,
3583-
pub overflow_arg_area: *mut ::std::os::raw::c_void,
3584-
pub reg_save_area: *mut ::std::os::raw::c_void,
3585-
}

libsql-ffi/bundled/src/sqlite3.c

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85618,6 +85618,16 @@ float vectorFB16DistanceL2(const Vector *, const Vector *);
8561885618
float vectorF32DistanceL2 (const Vector *, const Vector *);
8561985619
double vectorF64DistanceL2(const Vector *, const Vector *);
8562085620

85621+
/*
85622+
* Add/sub vectors
85623+
*/
85624+
void vectorAdd (Vector *, const Vector *);
85625+
void vectorF32Add(Vector *, const Vector *);
85626+
void vectorF64Add(Vector *, const Vector *);
85627+
void vectorSub (Vector *, const Vector *);
85628+
void vectorF32Sub(Vector *, const Vector *);
85629+
void vectorF64Sub(Vector *, const Vector *);
85630+
8562185631
/*
8562285632
* Serializes vector to the sqlite_blob in little-endian format according to the IEEE-754 standard
8562385633
* LibSQL can append one trailing byte in the end of final blob. This byte will be later used to determine type of the blob
@@ -211619,6 +211629,32 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){
211619211629
return 0;
211620211630
}
211621211631

211632+
void vectorAdd(Vector *pVector1, const Vector *pVector2){
211633+
assert( pVector1->type == pVector2->type );
211634+
switch (pVector1->type) {
211635+
case VECTOR_TYPE_FLOAT32:
211636+
return vectorF32Add(pVector1, pVector2);
211637+
case VECTOR_TYPE_FLOAT64:
211638+
return vectorF64Add(pVector1, pVector2);
211639+
default:
211640+
assert(0);
211641+
}
211642+
return;
211643+
}
211644+
211645+
void vectorSub(Vector *pVector1, const Vector *pVector2){
211646+
assert( pVector1->type == pVector2->type );
211647+
switch (pVector1->type) {
211648+
case VECTOR_TYPE_FLOAT32:
211649+
return vectorF32Sub(pVector1, pVector2);
211650+
case VECTOR_TYPE_FLOAT64:
211651+
return vectorF64Sub(pVector1, pVector2);
211652+
default:
211653+
assert(0);
211654+
}
211655+
return;
211656+
}
211657+
211622211658
SQLITE_API const char *sqlite3_type_repr(int type){
211623211659
switch( type ){
211624211660
case SQLITE_NULL:
@@ -212709,6 +212745,130 @@ static void vectorDistanceL2Func(sqlite3_context *context, int argc, sqlite3_val
212709212745
vectorDistanceFunc(context, argc, argv, vectorDistanceL2);
212710212746
}
212711212747

212748+
/*
212749+
** Implementation of vector_add(X, Y) function.
212750+
*/
212751+
static void vectorAddFn(sqlite3_context *context, int argc, sqlite3_value **argv){
212752+
char *pzErrMsg = NULL;
212753+
Vector *pVector1 = NULL, *pVector2 = NULL;
212754+
int type1, type2;
212755+
int dims1, dims2;
212756+
if( argc < 2 ) {
212757+
return;
212758+
}
212759+
if( detectVectorParameters(argv[0], 0, &type1, &dims1, &pzErrMsg) != 0 ){
212760+
sqlite3_result_error(context, pzErrMsg, -1);
212761+
sqlite3_free(pzErrMsg);
212762+
goto out_free;
212763+
}
212764+
if( detectVectorParameters(argv[1], 0, &type2, &dims2, &pzErrMsg) != 0 ){
212765+
sqlite3_result_error(context, pzErrMsg, -1);
212766+
sqlite3_free(pzErrMsg);
212767+
goto out_free;
212768+
}
212769+
if( type1 != type2 ){
212770+
pzErrMsg = sqlite3_mprintf("vector_add: vectors must have the same type: %d != %d", type1, type2);
212771+
sqlite3_result_error(context, pzErrMsg, -1);
212772+
sqlite3_free(pzErrMsg);
212773+
goto out_free;
212774+
}
212775+
if( dims1 != dims2 ){
212776+
pzErrMsg = sqlite3_mprintf("vector_add: vectors must have the same length: %d != %d", dims1, dims2);
212777+
sqlite3_result_error(context, pzErrMsg, -1);
212778+
sqlite3_free(pzErrMsg);
212779+
goto out_free;
212780+
}
212781+
pVector1 = vectorContextAlloc(context, type1, dims1);
212782+
if( pVector1==NULL ){
212783+
goto out_free;
212784+
}
212785+
pVector2 = vectorContextAlloc(context, type2, dims2);
212786+
if( pVector2==NULL ){
212787+
goto out_free;
212788+
}
212789+
if( vectorParseWithType(argv[0], pVector1, &pzErrMsg)<0 ){
212790+
sqlite3_result_error(context, pzErrMsg, -1);
212791+
sqlite3_free(pzErrMsg);
212792+
goto out_free;
212793+
}
212794+
if( vectorParseWithType(argv[1], pVector2, &pzErrMsg)<0 ){
212795+
sqlite3_result_error(context, pzErrMsg, -1);
212796+
sqlite3_free(pzErrMsg);
212797+
goto out_free;
212798+
}
212799+
vectorAdd(pVector1, pVector2);
212800+
vectorSerializeWithMeta(context, pVector1);
212801+
out_free:
212802+
if( pVector2 ){
212803+
vectorFree(pVector2);
212804+
}
212805+
if( pVector1 ){
212806+
vectorFree(pVector1);
212807+
}
212808+
}
212809+
212810+
/*
212811+
** Implementation of vector_sub(X, Y) function.
212812+
*/
212813+
static void vectorSubFn(sqlite3_context *context, int argc, sqlite3_value **argv){
212814+
char *pzErrMsg = NULL;
212815+
Vector *pVector1 = NULL, *pVector2 = NULL;
212816+
int type1, type2;
212817+
int dims1, dims2;
212818+
if( argc < 2 ) {
212819+
return;
212820+
}
212821+
if( detectVectorParameters(argv[0], 0, &type1, &dims1, &pzErrMsg) != 0 ){
212822+
sqlite3_result_error(context, pzErrMsg, -1);
212823+
sqlite3_free(pzErrMsg);
212824+
goto out_free;
212825+
}
212826+
if( detectVectorParameters(argv[1], 0, &type2, &dims2, &pzErrMsg) != 0 ){
212827+
sqlite3_result_error(context, pzErrMsg, -1);
212828+
sqlite3_free(pzErrMsg);
212829+
goto out_free;
212830+
}
212831+
if( type1 != type2 ){
212832+
pzErrMsg = sqlite3_mprintf("vector_add: vectors must have the same type: %d != %d", type1, type2);
212833+
sqlite3_result_error(context, pzErrMsg, -1);
212834+
sqlite3_free(pzErrMsg);
212835+
goto out_free;
212836+
}
212837+
if( dims1 != dims2 ){
212838+
pzErrMsg = sqlite3_mprintf("vector_add: vectors must have the same length: %d != %d", dims1, dims2);
212839+
sqlite3_result_error(context, pzErrMsg, -1);
212840+
sqlite3_free(pzErrMsg);
212841+
goto out_free;
212842+
}
212843+
pVector1 = vectorContextAlloc(context, type1, dims1);
212844+
if( pVector1==NULL ){
212845+
goto out_free;
212846+
}
212847+
pVector2 = vectorContextAlloc(context, type2, dims2);
212848+
if( pVector2==NULL ){
212849+
goto out_free;
212850+
}
212851+
if( vectorParseWithType(argv[0], pVector1, &pzErrMsg)<0 ){
212852+
sqlite3_result_error(context, pzErrMsg, -1);
212853+
sqlite3_free(pzErrMsg);
212854+
goto out_free;
212855+
}
212856+
if( vectorParseWithType(argv[1], pVector2, &pzErrMsg)<0 ){
212857+
sqlite3_result_error(context, pzErrMsg, -1);
212858+
sqlite3_free(pzErrMsg);
212859+
goto out_free;
212860+
}
212861+
vectorSub(pVector1, pVector2);
212862+
vectorSerializeWithMeta(context, pVector1);
212863+
out_free:
212864+
if( pVector2 ){
212865+
vectorFree(pVector2);
212866+
}
212867+
if( pVector1 ){
212868+
vectorFree(pVector1);
212869+
}
212870+
}
212871+
212712212872
/*
212713212873
* Marker function which is used in index creation syntax: CREATE INDEX idx ON t(libsql_vector_idx(emb));
212714212874
*/
@@ -212732,6 +212892,8 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){
212732212892
FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc),
212733212893
FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc),
212734212894
FUNCTION(vector_distance_l2, 2, 0, 0, vectorDistanceL2Func),
212895+
FUNCTION(vector_add, 2, 0, 0, vectorAddFn),
212896+
FUNCTION(vector_sub, 2, 0, 0, vectorSubFn),
212735212897

212736212898
FUNCTION(libsql_vector_idx, -1, 0, 0, libsqlVectorIdx),
212737212899
};
@@ -214828,6 +214990,36 @@ float vectorF32DistanceL2(const Vector *v1, const Vector *v2){
214828214990
return sqrt(sum);
214829214991
}
214830214992

214993+
void vectorF32Add(Vector *v1, const Vector *v2){
214994+
float sum = 0;
214995+
float *e1 = v1->data;
214996+
float *e2 = v2->data;
214997+
int i;
214998+
214999+
assert( v1->dims == v2->dims );
215000+
assert( v1->type == VECTOR_TYPE_FLOAT32 );
215001+
assert( v2->type == VECTOR_TYPE_FLOAT32 );
215002+
215003+
for(i = 0; i < v1->dims; i++){
215004+
e1[i] += e2[i];
215005+
}
215006+
}
215007+
215008+
void vectorF32Sub(Vector *v1, const Vector *v2){
215009+
float sum = 0;
215010+
float *e1 = v1->data;
215011+
float *e2 = v2->data;
215012+
int i;
215013+
215014+
assert( v1->dims == v2->dims );
215015+
assert( v1->type == VECTOR_TYPE_FLOAT32 );
215016+
assert( v2->type == VECTOR_TYPE_FLOAT32 );
215017+
215018+
for(i = 0; i < v1->dims; i++){
215019+
e1[i] -= e2[i];
215020+
}
215021+
}
215022+
214831215023
void vectorF32DeserializeFromBlob(
214832215024
Vector *pVector,
214833215025
const unsigned char *pBlob,
@@ -215026,6 +215218,36 @@ double vectorF64DistanceL2(const Vector *v1, const Vector *v2){
215026215218
return sqrt(sum);
215027215219
}
215028215220

215221+
void vectorF64Add(Vector *v1, const Vector *v2){
215222+
double sum = 0;
215223+
double *e1 = v1->data;
215224+
double *e2 = v2->data;
215225+
int i;
215226+
215227+
assert( v1->dims == v2->dims );
215228+
assert( v1->type == VECTOR_TYPE_FLOAT64 );
215229+
assert( v2->type == VECTOR_TYPE_FLOAT64 );
215230+
215231+
for(i = 0; i < v1->dims; i++){
215232+
e1[i] += e2[i];
215233+
}
215234+
}
215235+
215236+
void vectorF64Sub(Vector *v1, const Vector *v2){
215237+
double sum = 0;
215238+
double *e1 = v1->data;
215239+
double *e2 = v2->data;
215240+
int i;
215241+
215242+
assert( v1->dims == v2->dims );
215243+
assert( v1->type == VECTOR_TYPE_FLOAT64 );
215244+
assert( v2->type == VECTOR_TYPE_FLOAT64 );
215245+
215246+
for(i = 0; i < v1->dims; i++){
215247+
e1[i] -= e2[i];
215248+
}
215249+
}
215250+
215029215251
void vectorF64DeserializeFromBlob(
215030215252
Vector *pVector,
215031215253
const unsigned char *pBlob,

0 commit comments

Comments
 (0)