22
33using System ;
44using System . Collections . Generic ;
5- using System . Data . Common ;
65using System . Diagnostics ;
76using System . Diagnostics . CodeAnalysis ;
87using System . Linq ;
@@ -263,11 +262,9 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
263262 translator . Translate ( appendWhere : false ) ;
264263
265264 using var connection = await this . GetConnectionAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
266- DbCommand ? command = null ;
267265
268- if ( options . IncludeVectors )
269- {
270- command = SqliteCommandBuilder . BuildSelectInnerJoinCommand (
266+ using var command = options . IncludeVectors
267+ ? SqliteCommandBuilder . BuildSelectInnerJoinCommand (
271268 connection ,
272269 this . _vectorTableName ,
273270 this . _dataTableName ,
@@ -279,11 +276,8 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
279276 translator . Clause . ToString ( ) ,
280277 translator . Parameters ,
281278 top : top ,
282- skip : options . Skip ) ;
283- }
284- else
285- {
286- command = SqliteCommandBuilder . BuildSelectDataCommand (
279+ skip : options . Skip )
280+ : SqliteCommandBuilder . BuildSelectDataCommand (
287281 connection ,
288282 this . _dataTableName ,
289283 this . _model ,
@@ -293,28 +287,21 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
293287 translator . Parameters ,
294288 top : top ,
295289 skip : options . Skip ) ;
296- }
297290
298- using ( command )
299- {
300- const string OperationName = "Get" ;
291+ const string OperationName = "Get" ;
301292
302- using var reader = await connection . ExecuteWithErrorHandlingAsync (
303- this . _collectionMetadata ,
304- OperationName ,
305- ( ) => command . ExecuteReaderAsync ( cancellationToken ) ,
306- cancellationToken ) . ConfigureAwait ( false ) ;
293+ using var reader = await connection . ExecuteWithErrorHandlingAsync (
294+ this . _collectionMetadata ,
295+ OperationName ,
296+ ( ) => command . ExecuteReaderAsync ( cancellationToken ) ,
297+ cancellationToken ) . ConfigureAwait ( false ) ;
307298
308- while ( await reader . ReadWithErrorHandlingAsync (
309- this . _collectionMetadata ,
310- OperationName ,
311- cancellationToken ) . ConfigureAwait ( false ) )
312- {
313- yield return this . GetAndMapRecord (
314- reader ,
315- this . _model . Properties ,
316- options . IncludeVectors ) ;
317- }
299+ while ( await reader . ReadWithErrorHandlingAsync (
300+ this . _collectionMetadata ,
301+ OperationName ,
302+ cancellationToken ) . ConfigureAwait ( false ) )
303+ {
304+ yield return this . _mapper . MapFromStorageToDataModel ( reader , options . IncludeVectors ) ;
318305 }
319306 }
320307
@@ -363,7 +350,7 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
363350 {
364351 Verify . NotNull ( record ) ;
365352
366- IReadOnlyList < Embedding > ? [ ] ? generatedEmbeddings = null ;
353+ Dictionary < VectorPropertyModel , IReadOnlyList < Embedding < float > > > ? generatedEmbeddings = null ;
367354
368355 var vectorPropertyCount = this . _model . VectorProperties . Count ;
369356 for ( var i = 0 ; i < vectorPropertyCount ; i ++ )
@@ -382,8 +369,8 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
382369 // and generate embeddings for them in a single batch. That's some more complexity though.
383370 if ( vectorProperty . TryGenerateEmbedding < TRecord , Embedding < float > > ( record , cancellationToken , out var floatTask ) )
384371 {
385- generatedEmbeddings ??= new IReadOnlyList < Embedding > ? [ vectorPropertyCount ] ;
386- generatedEmbeddings [ i ] = [ await floatTask . ConfigureAwait ( false ) ] ;
372+ generatedEmbeddings ??= new Dictionary < VectorPropertyModel , IReadOnlyList < Embedding < float > > > ( vectorPropertyCount ) ;
373+ generatedEmbeddings [ vectorProperty ] = [ await floatTask . ConfigureAwait ( false ) ] ;
387374 }
388375 else
389376 {
@@ -394,16 +381,7 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
394381
395382 using var connection = await this . GetConnectionAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
396383
397- var storageModel = this . _mapper . MapFromDataToStorageModel ( record , recordIndex : 0 , generatedEmbeddings ) ;
398-
399- var key = storageModel [ this . _keyStorageName ] ;
400-
401- Verify . NotNull ( key ) ;
402-
403- var condition = new SqliteWhereEqualsCondition ( this . _keyStorageName , key ) ;
404-
405- await this . InternalUpsertBatchAsync ( connection , [ storageModel ] , condition , cancellationToken )
406- . ConfigureAwait ( false ) ;
384+ await this . InternalUpsertBatchAsync ( connection , [ record ] , generatedEmbeddings , cancellationToken ) . ConfigureAwait ( false ) ;
407385 }
408386
409387 /// <inheritdoc />
@@ -414,7 +392,7 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
414392 IReadOnlyList < TRecord > ? recordsList = null ;
415393
416394 // If an embedding generator is defined, invoke it once per property for all records.
417- IReadOnlyList < Embedding > ? [ ] ? generatedEmbeddings = null ;
395+ Dictionary < VectorPropertyModel , IReadOnlyList < Embedding < float > > > ? generatedEmbeddings = null ;
418396
419397 var vectorPropertyCount = this . _model . VectorProperties . Count ;
420398 for ( var i = 0 ; i < vectorPropertyCount ; i ++ )
@@ -447,8 +425,8 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
447425 // and generate embeddings for them in a single batch. That's some more complexity though.
448426 if ( vectorProperty . TryGenerateEmbeddings < TRecord , Embedding < float > > ( records , cancellationToken , out var floatTask ) )
449427 {
450- generatedEmbeddings ??= new IReadOnlyList < Embedding > ? [ vectorPropertyCount ] ;
451- generatedEmbeddings [ i ] = ( IReadOnlyList < Embedding < float > > ) await floatTask . ConfigureAwait ( false ) ;
428+ generatedEmbeddings ??= new Dictionary < VectorPropertyModel , IReadOnlyList < Embedding < float > > > ( vectorPropertyCount ) ;
429+ generatedEmbeddings [ vectorProperty ] = await floatTask . ConfigureAwait ( false ) ;
452430 }
453431 else
454432 {
@@ -457,19 +435,9 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
457435 }
458436 }
459437
460- var storageModels = records . Select ( ( r , i ) => this . _mapper . MapFromDataToStorageModel ( r , i , generatedEmbeddings ) ) . ToList ( ) ;
461-
462- if ( storageModels . Count == 0 )
463- {
464- return ;
465- }
466-
467- var keys = storageModels . Select ( model => model [ this . _keyStorageName ] ! ) . ToList ( ) ;
468-
469438 using var connection = await this . GetConnectionAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
470- var condition = new SqliteWhereInCondition ( this . _keyStorageName , keys ) ;
471439
472- await this . InternalUpsertBatchAsync ( connection , storageModels , condition , cancellationToken ) . ConfigureAwait ( false ) ;
440+ await this . InternalUpsertBatchAsync ( connection , records , generatedEmbeddings , cancellationToken ) . ConfigureAwait ( false ) ;
473441 }
474442
475443 /// <inheritdoc />
@@ -557,11 +525,7 @@ private async IAsyncEnumerable<VectorSearchResult<TRecord>> EnumerateAndMapSearc
557525 if ( recordCounter >= searchOptions . Skip )
558526 {
559527 var score = SqlitePropertyMapping . GetPropertyValue < double > ( reader , SqliteCommandBuilder . DistancePropertyName ) ;
560-
561- var record = this . GetAndMapRecord (
562- reader ,
563- this . _model . Properties ,
564- searchOptions . IncludeVectors ) ;
528+ var record = this . _mapper . MapFromStorageToDataModel ( reader , searchOptions . IncludeVectors ) ;
565529
566530 yield return new VectorSearchResult < TRecord > ( record , score ) ;
567531 }
@@ -632,69 +596,67 @@ private async IAsyncEnumerable<TRecord> InternalGetBatchAsync(
632596 const string OperationName = "Select" ;
633597
634598 bool includeVectors = options ? . IncludeVectors is true && this . _vectorPropertiesExist ;
635-
636- DbCommand command ;
637-
638- if ( includeVectors )
599+ if ( includeVectors && this . _model . EmbeddingGenerationRequired )
639600 {
640- if ( this . _model . EmbeddingGenerationRequired )
641- {
642- throw new NotSupportedException ( VectorDataStrings . IncludeVectorsNotSupportedWithEmbeddingGeneration ) ;
643- }
601+ throw new NotSupportedException ( VectorDataStrings . IncludeVectorsNotSupportedWithEmbeddingGeneration ) ;
602+ }
644603
645- command = SqliteCommandBuilder . BuildSelectInnerJoinCommand < TRecord > (
604+ var command = includeVectors
605+ ? SqliteCommandBuilder . BuildSelectInnerJoinCommand < TRecord > (
646606 connection ,
647607 this . _vectorTableName ,
648608 this . _dataTableName ,
649609 this . _keyStorageName ,
650610 this . _model ,
651611 [ condition ] ,
652- includeDistance : false ) ;
653- }
654- else
655- {
656- command = SqliteCommandBuilder . BuildSelectDataCommand < TRecord > (
612+ includeDistance : false )
613+ : SqliteCommandBuilder . BuildSelectDataCommand < TRecord > (
657614 connection ,
658615 this . _dataTableName ,
659616 this . _model ,
660617 [ condition ] ) ;
661- }
662618
663- using ( command )
664- {
665- using var reader = await connection . ExecuteWithErrorHandlingAsync (
666- this . _collectionMetadata ,
667- OperationName ,
668- ( ) => command . ExecuteReaderAsync ( cancellationToken ) ,
669- cancellationToken ) . ConfigureAwait ( false ) ;
619+ using var reader = await connection . ExecuteWithErrorHandlingAsync (
620+ this . _collectionMetadata ,
621+ OperationName ,
622+ ( ) => command . ExecuteReaderAsync ( cancellationToken ) ,
623+ cancellationToken ) . ConfigureAwait ( false ) ;
670624
671- while ( await reader . ReadAsync ( cancellationToken ) . ConfigureAwait ( false ) )
672- {
673- yield return this . GetAndMapRecord (
674- reader ,
675- this . _model . Properties ,
676- includeVectors ) ;
677- }
625+ while ( await reader . ReadAsync ( cancellationToken ) . ConfigureAwait ( false ) )
626+ {
627+ yield return this . _mapper . MapFromStorageToDataModel ( reader , includeVectors ) ;
678628 }
679629 }
680630
681- private async Task < IReadOnlyList < TKey > > InternalUpsertBatchAsync (
631+ private async Task InternalUpsertBatchAsync (
682632 SqliteConnection connection ,
683- List < Dictionary < string , object ? > > storageModels ,
684- SqliteWhereCondition condition ,
633+ IEnumerable < TRecord > records ,
634+ Dictionary < VectorPropertyModel , IReadOnlyList < Embedding < float > > > ? generatedEmbeddings ,
685635 CancellationToken cancellationToken )
686636 {
687- Verify . NotNull ( storageModels ) ;
688- Verify . True ( storageModels . Count > 0 , "Number of provided records should be greater than zero." ) ;
637+ Verify . NotNull ( records ) ;
689638
690639 if ( this . _vectorPropertiesExist )
691640 {
641+ // We're going to have to traverse the records multiple times, so materialize the enumerable if needed.
642+ var recordsList = records is IReadOnlyList < TRecord > r ? r : records . ToList ( ) ;
643+
644+ if ( recordsList . Count == 0 )
645+ {
646+ return ;
647+ }
648+
649+ records = recordsList ;
650+
651+ var keyProperty = this . _model . KeyProperty ;
652+ var keys = recordsList . Select ( r => keyProperty . GetValueAsObject ( r ) ! ) . ToList ( ) ;
653+
692654 // Deleting vector records first since current version of vector search extension
693655 // doesn't support Upsert operation, only Delete/Insert.
694656 using var vectorDeleteCommand = SqliteCommandBuilder . BuildDeleteCommand (
695657 connection ,
696658 this . _vectorTableName ,
697- [ condition ] ) ;
659+ [ new SqliteWhereInCondition ( this . _keyStorageName , keys ) ] ) ;
698660
699661 await connection . ExecuteWithErrorHandlingAsync (
700662 this . _collectionMetadata ,
@@ -706,8 +668,9 @@ await connection.ExecuteWithErrorHandlingAsync(
706668 connection ,
707669 this . _vectorTableName ,
708670 this . _keyStorageName ,
709- this . _model . Properties ,
710- storageModels ,
671+ this . _model ,
672+ records ,
673+ generatedEmbeddings ,
711674 data : false ) ;
712675
713676 await connection . ExecuteWithErrorHandlingAsync (
@@ -721,8 +684,9 @@ await connection.ExecuteWithErrorHandlingAsync(
721684 connection ,
722685 this . _dataTableName ,
723686 this . _keyStorageName ,
724- this . _model . Properties ,
725- storageModels ,
687+ this . _model ,
688+ records ,
689+ generatedEmbeddings ,
726690 data : true ,
727691 replaceIfExists : true ) ;
728692
@@ -732,18 +696,14 @@ await connection.ExecuteWithErrorHandlingAsync(
732696 ( ) => dataCommand . ExecuteReaderAsync ( cancellationToken ) ,
733697 cancellationToken ) . ConfigureAwait ( false ) ;
734698
735- var keys = new List < TKey > ( ) ;
736-
737699 while ( await reader . ReadAsync ( cancellationToken ) . ConfigureAwait ( false ) )
738700 {
739701 var key = reader . GetFieldValue < TKey > ( 0 ) ;
740702
741- keys . Add ( key ) ;
703+ // TODO: Inject the generated keys into the record for autogenerated keys.
742704
743705 await reader . NextResultAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
744706 }
745-
746- return keys ;
747707 }
748708
749709 private Task InternalDeleteBatchAsync ( SqliteConnection connection , SqliteWhereCondition condition , CancellationToken cancellationToken )
@@ -778,25 +738,6 @@ private Task InternalDeleteBatchAsync(SqliteConnection connection, SqliteWhereCo
778738 return Task . WhenAll ( tasks ) ;
779739 }
780740
781- private TRecord GetAndMapRecord (
782- DbDataReader reader ,
783- IReadOnlyList < PropertyModel > properties ,
784- bool includeVectors )
785- {
786- var storageModel = new Dictionary < string , object ? > ( ) ;
787-
788- foreach ( var property in properties )
789- {
790- if ( includeVectors || property is not VectorPropertyModel )
791- {
792- var propertyValue = SqlitePropertyMapping . GetPropertyValue ( reader , property . StorageName , property . Type ) ;
793- storageModel . Add ( property . StorageName , propertyValue ) ;
794- }
795- }
796-
797- return this . _mapper . MapFromStorageToDataModel ( storageModel , includeVectors ) ;
798- }
799-
800741#pragma warning disable CS0618 // VectorSearchFilter is obsolete
801742 private List < SqliteWhereCondition > ? GetFilterConditions ( VectorSearchFilter ? filter , string ? tableName = null )
802743 {
0 commit comments