diff options
author | Jiayu Hu <hujiayu@google.com> | 2024-03-18 14:23:44 -0700 |
---|---|---|
committer | Jiayu Hu <hujiayu@google.com> | 2024-03-18 15:47:26 -0700 |
commit | 555cb6e3295cf525baf46358235389ff52c9dcc2 (patch) | |
tree | 22dea3a10ec58bc8a54e07fb8ecb08720e4f8b13 | |
parent | 29d4712b67d1ade154739e5fa9a9a7970afe6c0a (diff) | |
download | icing-555cb6e3295cf525baf46358235389ff52c9dcc2.tar.gz |
Update Icing from upstream.
Descriptions:
========================================================================
Integration test for embedding search
========================================================================
Support embedding search in the advanced scoring language
========================================================================
Support embedding search in the advanced query language
========================================================================
Add missing header inclusions.
========================================================================
Assign default values to MemoryMappedFile members to avoid temporary object error during std::move
========================================================================
Use GetEmbeddingVector in EmbeddingIndex::TransferIndex
========================================================================
Refactor SectionRestrictData
========================================================================
Add EmbeddingIndex to IcingSearchEngine and create EmbeddingIndexingHandler
========================================================================
Support optimize for embedding index
========================================================================
Implement embedding search index
========================================================================
Branch posting list accessor and serializer to store embedding hits
========================================================================
Introduce embedding hit
========================================================================
Update schema and document definition
========================================================================
Add tests verifying that Icing will correctly save updated schema description fields.
========================================================================
BUG: 326987971
BUG: 326656531
BUG: 329747255
Change-Id: I84f0a5a3f7fe133ece16c567a1f5f44e7866fd77
88 files changed, 9665 insertions, 875 deletions
diff --git a/icing/document-builder.h b/icing/document-builder.h index 44500f9..5d6f14d 100644 --- a/icing/document-builder.h +++ b/icing/document-builder.h @@ -126,6 +126,13 @@ class DocumentBuilder { return AddDocumentProperty(std::move(property_name), {document_values...}); } + // Takes a property name and any number of vector values. + template <typename... V> + DocumentBuilder& AddVectorProperty(std::string property_name, + V... vector_values) { + return AddVectorProperty(std::move(property_name), {vector_values...}); + } + DocumentProto Build() const { return document_; } private: @@ -183,6 +190,17 @@ class DocumentBuilder { } return *this; } + + DocumentBuilder& AddVectorProperty( + std::string property_name, + std::initializer_list<PropertyProto::VectorProto> vector_values) { + auto property = document_.add_properties(); + property->set_name(std::move(property_name)); + for (PropertyProto::VectorProto vector_value : vector_values) { + property->mutable_vector_values()->Add(std::move(vector_value)); + } + return *this; + } }; } // namespace lib diff --git a/icing/file/memory-mapped-file.h b/icing/file/memory-mapped-file.h index 54507af..185d940 100644 --- a/icing/file/memory-mapped-file.h +++ b/icing/file/memory-mapped-file.h @@ -79,7 +79,6 @@ #include <algorithm> #include <cstdint> -#include <memory> #include <string> #include <string_view> @@ -314,7 +313,7 @@ class MemoryMappedFile { // Cached constructor params. const Filesystem* filesystem_; std::string file_path_; - Strategy strategy_; + Strategy strategy_ = Strategy::READ_WRITE_AUTO_SYNC; // Raw file related fields: // - max_file_size_ @@ -327,7 +326,7 @@ class MemoryMappedFile { // // Note: max_file_size_ will be specified in runtime and the caller should // make sure its value is correct and reasonable. - int64_t max_file_size_; + int64_t max_file_size_ = 0; // Cached file size to avoid calling system call too frequently. 
It is only // used in GrowAndRemapIfNecessary(), the new API that handles underlying file @@ -336,7 +335,7 @@ class MemoryMappedFile { // Note: it is guaranteed that file_size_ is smaller or equal to the actual // file size as long as the underlying file hasn't been truncated or deleted // externally. See GrowFileSize() for more details. - int64_t file_size_; + int64_t file_size_ = 0; // Memory mapped related fields: // - mmap_result_ @@ -345,22 +344,22 @@ class MemoryMappedFile { // - mmap_size_ // Raw pointer (or error) returned by calls to mmap(). - void* mmap_result_; + void* mmap_result_ = nullptr; // Offset within the file at which the current memory-mapped region starts. - int64_t file_offset_; + int64_t file_offset_ = 0; // Size that is currently memory-mapped. // Note that the mmapped size can be larger than the underlying file size. We // can reduce remapping by pre-mmapping a large memory and grow the file size // later. See GrowAndRemapIfNecessary(). - int64_t mmap_size_; + int64_t mmap_size_ = 0; // The difference between file_offset_ and the actual adjusted (aligned) // offset. // Since mmap requires the offset to be a multiple of system page size, we // have to align file_offset_ to the last multiple of system page size. - int64_t alignment_adjustment_; + int64_t alignment_adjustment_ = 0; // E.g. 
system_page_size = 5, RemapImpl(/*new_file_offset=*/8, mmap_size) // diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc index 72e744b..89caaf1 100644 --- a/icing/icing-search-engine.cc +++ b/icing/icing-search-engine.cc @@ -37,6 +37,8 @@ #include "icing/file/filesystem.h" #include "icing/file/version-util.h" #include "icing/index/data-indexing-handler.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/embedding-indexing-handler.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/index-processor.h" #include "icing/index/index.h" @@ -108,6 +110,7 @@ constexpr std::string_view kIndexSubfolderName = "index_dir"; constexpr std::string_view kIntegerIndexSubfolderName = "integer_index_dir"; constexpr std::string_view kQualifiedIdJoinIndexSubfolderName = "qualified_id_join_index_dir"; +constexpr std::string_view kEmbeddingIndexSubfolderName = "embedding_index_dir"; constexpr std::string_view kSchemaSubfolderName = "schema_dir"; constexpr std::string_view kSetSchemaMarkerFilename = "set_schema_marker"; constexpr std::string_view kInitMarkerFilename = "init_marker"; @@ -292,6 +295,11 @@ std::string MakeQualifiedIdJoinIndexWorkingPath(const std::string& base_dir) { return absl_ports::StrCat(base_dir, "/", kQualifiedIdJoinIndexSubfolderName); } +// Working path for embedding index. +std::string MakeEmbeddingIndexWorkingPath(const std::string& base_dir) { + return absl_ports::StrCat(base_dir, "/", kEmbeddingIndexSubfolderName); +} + // SchemaStore files are in a standalone subfolder for easier file management. // We can delete and recreate the subfolder and not touch/affect anything // else. 
@@ -478,6 +486,7 @@ void IcingSearchEngine::ResetMembers() { index_.reset(); integer_index_.reset(); qualified_id_join_index_.reset(); + embedding_index_.reset(); } libtextclassifier3::Status IcingSearchEngine::CheckInitMarkerFile( @@ -668,15 +677,19 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( MakeIntegerIndexWorkingPath(options_.base_dir()); const std::string qualified_id_join_index_dir = MakeQualifiedIdJoinIndexWorkingPath(options_.base_dir()); + const std::string embedding_index_dir = + MakeEmbeddingIndexWorkingPath(options_.base_dir()); if (!filesystem_->DeleteDirectoryRecursively(doc_store_dir.c_str()) || !filesystem_->DeleteDirectoryRecursively(index_dir.c_str()) || !IntegerIndex::Discard(*filesystem_, integer_index_dir).ok() || !QualifiedIdJoinIndex::Discard(*filesystem_, qualified_id_join_index_dir) - .ok()) { + .ok() || + !EmbeddingIndex::Discard(*filesystem_, embedding_index_dir).ok()) { return absl_ports::InternalError(absl_ports::StrCat( "Could not delete directories: ", index_dir, ", ", integer_index_dir, - ", ", qualified_id_join_index_dir, " and ", doc_store_dir)); + ", ", qualified_id_join_index_dir, ", ", embedding_index_dir, " and ", + doc_store_dir)); } ICING_ASSIGN_OR_RETURN( bool document_store_derived_files_regenerated, @@ -734,6 +747,15 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( CreateQualifiedIdJoinIndex( *filesystem_, std::move(qualified_id_join_index_dir), options_)); + // Discard embedding index directory and instantiate a new one. 
+ std::string embedding_index_dir = + MakeEmbeddingIndexWorkingPath(options_.base_dir()); + ICING_RETURN_IF_ERROR( + EmbeddingIndex::Discard(*filesystem_, embedding_index_dir)); + ICING_ASSIGN_OR_RETURN( + embedding_index_, + EmbeddingIndex::Create(filesystem_.get(), embedding_index_dir)); + std::unique_ptr<Timer> restore_timer = clock_->GetNewTimer(); IndexRestorationResult restore_result = RestoreIndexIfNeeded(); index_init_status = std::move(restore_result.status); @@ -756,6 +778,8 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC); initialize_stats->set_qualified_id_join_index_restoration_cause( InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC); + initialize_stats->set_embedding_index_restoration_cause( + InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC); } else if (version_state_change != version_util::StateChange::kCompatible) { ICING_ASSIGN_OR_RETURN(bool document_store_derived_files_regenerated, InitializeDocumentStore( @@ -777,6 +801,8 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( InitializeStatsProto::VERSION_CHANGED); initialize_stats->set_qualified_id_join_index_restoration_cause( InitializeStatsProto::VERSION_CHANGED); + initialize_stats->set_embedding_index_restoration_cause( + InitializeStatsProto::VERSION_CHANGED); } else { ICING_ASSIGN_OR_RETURN( bool document_store_derived_files_regenerated, @@ -813,6 +839,7 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( initialize_stats->set_qualified_id_join_index_restoration_cause( InitializeStatsProto::FEATURE_FLAG_CHANGED); } + // TODO(b/326656531): Update version-util to consider embedding index. 
} if (status.ok()) { @@ -981,6 +1008,30 @@ libtextclassifier3::Status IcingSearchEngine::InitializeIndex( } } + // Embedding index + const std::string embedding_dir = + MakeEmbeddingIndexWorkingPath(options_.base_dir()); + InitializeStatsProto::RecoveryCause embedding_index_recovery_cause; + auto embedding_index_or = + EmbeddingIndex::Create(filesystem_.get(), embedding_dir); + if (!embedding_index_or.ok()) { + ICING_RETURN_IF_ERROR(EmbeddingIndex::Discard(*filesystem_, embedding_dir)); + + embedding_index_recovery_cause = InitializeStatsProto::IO_ERROR; + + // Try recreating it from scratch and re-indexing everything. + ICING_ASSIGN_OR_RETURN( + embedding_index_, + EmbeddingIndex::Create(filesystem_.get(), embedding_dir)); + } else { + // Embedding index was created fine. + embedding_index_ = std::move(embedding_index_or).ValueOrDie(); + // If a recover does have to happen, then it must be because the index is + // out of sync with the document store. + embedding_index_recovery_cause = + InitializeStatsProto::INCONSISTENT_WITH_GROUND_TRUTH; + } + std::unique_ptr<Timer> restore_timer = clock_->GetNewTimer(); IndexRestorationResult restore_result = RestoreIndexIfNeeded(); if (restore_result.index_needed_restoration || @@ -1000,6 +1051,10 @@ libtextclassifier3::Status IcingSearchEngine::InitializeIndex( initialize_stats->set_qualified_id_join_index_restoration_cause( qualified_id_join_index_recovery_cause); } + if (restore_result.embedding_index_needed_restoration) { + initialize_stats->set_embedding_index_restoration_cause( + embedding_index_recovery_cause); + } } return restore_result.status; } @@ -1515,9 +1570,9 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( std::unique_ptr<Timer> component_timer = clock_->GetNewTimer(); // Gets unordered results from query processor auto query_processor_or = QueryProcessor::Create( - index_.get(), integer_index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), schema_store_.get(), - 
clock_.get()); + index_.get(), integer_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), + schema_store_.get(), clock_.get()); if (!query_processor_or.ok()) { TransformStatus(query_processor_or.status(), result_status); delete_stats->set_parse_query_latency_ms( @@ -1713,6 +1768,15 @@ OptimizeResultProto IcingSearchEngine::Optimize() { << qualified_id_join_index_optimize_status.error_message(); should_rebuild_index = true; } + + libtextclassifier3::Status embedding_index_optimize_status = + embedding_index_->Optimize(optimize_result.document_id_old_to_new, + document_store_->last_added_document_id()); + if (!embedding_index_optimize_status.ok()) { + ICING_LOG(WARNING) << "Failed to optimize embedding index. Error: " + << embedding_index_optimize_status.error_message(); + should_rebuild_index = true; + } } // If we received a DATA_LOSS error from OptimizeDocumentStore, we have a // valid document store, but it might be the old one or the new one. 
So throw @@ -1939,6 +2003,7 @@ libtextclassifier3::Status IcingSearchEngine::InternalPersistToDisk( ICING_RETURN_IF_ERROR(index_->PersistToDisk()); ICING_RETURN_IF_ERROR(integer_index_->PersistToDisk()); ICING_RETURN_IF_ERROR(qualified_id_join_index_->PersistToDisk()); + ICING_RETURN_IF_ERROR(embedding_index_->PersistToDisk()); return libtextclassifier3::Status::OK; } @@ -2220,9 +2285,9 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore( // Gets unordered results from query processor auto query_processor_or = QueryProcessor::Create( - index_.get(), integer_index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), schema_store_.get(), - clock_.get()); + index_.get(), integer_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), + schema_store_.get(), clock_.get()); if (!query_processor_or.ok()) { search_stats->set_parse_query_latency_ms( component_timer->GetElapsedMilliseconds()); @@ -2266,8 +2331,10 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore( // Scores but does not rank the results. 
libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> scoring_processor_or = ScoringProcessor::Create( - scoring_spec, document_store_.get(), schema_store_.get(), - current_time_ms, join_children_fetcher); + scoring_spec, /*default_semantic_metric_type=*/ + search_spec.embedding_query_metric_type(), document_store_.get(), + schema_store_.get(), current_time_ms, join_children_fetcher, + &query_results.embedding_query_results); if (!scoring_processor_or.ok()) { return QueryScoringResults(std::move(scoring_processor_or).status(), std::move(query_results.query_terms), @@ -2504,27 +2571,28 @@ IcingSearchEngine::RestoreIndexIfNeeded() { if (last_stored_document_id == index_->last_added_document_id() && last_stored_document_id == integer_index_->last_added_document_id() && last_stored_document_id == - qualified_id_join_index_->last_added_document_id()) { + qualified_id_join_index_->last_added_document_id() && + last_stored_document_id == embedding_index_->last_added_document_id()) { // No need to recover. - return {libtextclassifier3::Status::OK, false, false, false}; + return {libtextclassifier3::Status::OK, false, false, false, false}; } if (last_stored_document_id == kInvalidDocumentId) { // Document store is empty but index is not. Clear the index. - return {ClearAllIndices(), false, false, false}; + return {ClearAllIndices(), false, false, false, false}; } // Truncate indices first. auto truncate_result_or = TruncateIndicesTo(last_stored_document_id); if (!truncate_result_or.ok()) { - return {std::move(truncate_result_or).status(), false, false, false}; + return {std::move(truncate_result_or).status(), false, false, false, false}; } TruncateIndexResult truncate_result = std::move(truncate_result_or).ValueOrDie(); if (truncate_result.first_document_to_reindex > last_stored_document_id) { // Nothing to restore. Just return. 
- return {libtextclassifier3::Status::OK, false, false, false}; + return {libtextclassifier3::Status::OK, false, false, false, false}; } auto data_indexing_handlers_or = CreateDataIndexingHandlers(); @@ -2532,7 +2600,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { return {data_indexing_handlers_or.status(), truncate_result.index_needed_restoration, truncate_result.integer_index_needed_restoration, - truncate_result.qualified_id_join_index_needed_restoration}; + truncate_result.qualified_id_join_index_needed_restoration, + truncate_result.embedding_index_needed_restoration}; } // By using recovery_mode for IndexProcessor, we're able to replay documents // from smaller document id and it will skip documents that are already been @@ -2559,7 +2628,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { // Returns other errors return {document_or.status(), truncate_result.index_needed_restoration, truncate_result.integer_index_needed_restoration, - truncate_result.qualified_id_join_index_needed_restoration}; + truncate_result.qualified_id_join_index_needed_restoration, + truncate_result.embedding_index_needed_restoration}; } } DocumentProto document(std::move(document_or).ValueOrDie()); @@ -2572,7 +2642,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { return {tokenized_document_or.status(), truncate_result.index_needed_restoration, truncate_result.integer_index_needed_restoration, - truncate_result.qualified_id_join_index_needed_restoration}; + truncate_result.qualified_id_join_index_needed_restoration, + truncate_result.embedding_index_needed_restoration}; } TokenizedDocument tokenized_document( std::move(tokenized_document_or).ValueOrDie()); @@ -2584,7 +2655,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { // Real error. Stop recovering and pass it up. 
return {status, truncate_result.index_needed_restoration, truncate_result.integer_index_needed_restoration, - truncate_result.qualified_id_join_index_needed_restoration}; + truncate_result.qualified_id_join_index_needed_restoration, + truncate_result.embedding_index_needed_restoration}; } // FIXME: why can we skip data loss error here? // Just a data loss. Keep trying to add the remaining docs, but report the @@ -2595,7 +2667,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { return {overall_status, truncate_result.index_needed_restoration, truncate_result.integer_index_needed_restoration, - truncate_result.qualified_id_join_index_needed_restoration}; + truncate_result.qualified_id_join_index_needed_restoration, + truncate_result.embedding_index_needed_restoration}; } libtextclassifier3::StatusOr<bool> IcingSearchEngine::LostPreviousSchema() { @@ -2648,6 +2721,11 @@ IcingSearchEngine::CreateDataIndexingHandlers() { clock_.get(), document_store_.get(), qualified_id_join_index_.get())); handlers.push_back(std::move(qualified_id_join_indexing_handler)); + // Embedding index handler + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<EmbeddingIndexingHandler> embedding_indexing_handler, + EmbeddingIndexingHandler::Create(clock_.get(), embedding_index_.get())); + handlers.push_back(std::move(embedding_indexing_handler)); return handlers; } @@ -2736,10 +2814,43 @@ IcingSearchEngine::TruncateIndicesTo(DocumentId last_stored_document_id) { first_document_to_reindex = kMinDocumentId; } + // Attempt to truncate embedding index + bool embedding_index_needed_restoration = false; + DocumentId embedding_index_last_added_document_id = + embedding_index_->last_added_document_id(); + if (embedding_index_last_added_document_id == kInvalidDocumentId || + last_stored_document_id > embedding_index_last_added_document_id) { + // If last_stored_document_id is greater than + // embedding_index_last_added_document_id, then we only have to replay docs + // starting from 
(embedding_index_last_added_document_id + 1). Also use + // std::min since we might need to replay even smaller doc ids for other + // components. + embedding_index_needed_restoration = true; + if (embedding_index_last_added_document_id != kInvalidDocumentId) { + first_document_to_reindex = + std::min(first_document_to_reindex, + embedding_index_last_added_document_id + 1); + } else { + first_document_to_reindex = kMinDocumentId; + } + } else if (last_stored_document_id < embedding_index_last_added_document_id) { + // Clear the entire embedding index if last_stored_document_id is + // smaller than embedding_index_last_added_document_id, because + // there is no way to remove data with doc_id > last_stored_document_id from + // embedding index efficiently and we have to rebuild. + ICING_RETURN_IF_ERROR(embedding_index_->Clear()); + + // Since the entire embedding index is discarded, we start to + // rebuild it by setting first_document_to_reindex to kMinDocumentId. + embedding_index_needed_restoration = true; + first_document_to_reindex = kMinDocumentId; + } + return TruncateIndexResult(first_document_to_reindex, index_needed_restoration, integer_index_needed_restoration, - qualified_id_join_index_needed_restoration); + qualified_id_join_index_needed_restoration, + embedding_index_needed_restoration); } libtextclassifier3::Status IcingSearchEngine::DiscardDerivedFiles( @@ -2750,7 +2861,7 @@ libtextclassifier3::Status IcingSearchEngine::DiscardDerivedFiles( if (schema_store_ != nullptr || document_store_ != nullptr || index_ != nullptr || integer_index_ != nullptr || - qualified_id_join_index_ != nullptr) { + qualified_id_join_index_ != nullptr || embedding_index_ != nullptr) { return absl_ports::FailedPreconditionError( "Cannot discard derived files while having valid instances"); } @@ -2792,12 +2903,15 @@ libtextclassifier3::Status IcingSearchEngine::DiscardDerivedFiles( } } + // TODO(b/326656531): Update version-util to consider embedding index. 
+ return libtextclassifier3::Status::OK; } libtextclassifier3::Status IcingSearchEngine::ClearSearchIndices() { ICING_RETURN_IF_ERROR(index_->Reset()); ICING_RETURN_IF_ERROR(integer_index_->Clear()); + ICING_RETURN_IF_ERROR(embedding_index_->Clear()); return libtextclassifier3::Status::OK; } @@ -2870,9 +2984,9 @@ SuggestionResponse IcingSearchEngine::SearchSuggestions( // Create the suggestion processor. auto suggestion_processor_or = SuggestionProcessor::Create( - index_.get(), integer_index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), schema_store_.get(), - clock_.get()); + index_.get(), integer_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), + schema_store_.get(), clock_.get()); if (!suggestion_processor_or.ok()) { TransformStatus(suggestion_processor_or.status(), response_status); return response; diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h index b9df95d..57f0f28 100644 --- a/icing/icing-search-engine.h +++ b/icing/icing-search-engine.h @@ -28,6 +28,7 @@ #include "icing/file/filesystem.h" #include "icing/file/version-util.h" #include "icing/index/data-indexing-handler.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" #include "icing/index/numeric/numeric-index.h" #include "icing/jni/jni-cache.h" @@ -483,6 +484,9 @@ class IcingSearchEngine { std::unique_ptr<QualifiedIdJoinIndex> qualified_id_join_index_ ICING_GUARDED_BY(mutex_); + // Storage for all hits of embedding contents from the document store. 
+ std::unique_ptr<EmbeddingIndex> embedding_index_ ICING_GUARDED_BY(mutex_); + // Pointer to JNI class references const std::unique_ptr<const JniCache> jni_cache_; @@ -701,6 +705,7 @@ class IcingSearchEngine { bool index_needed_restoration; bool integer_index_needed_restoration; bool qualified_id_join_index_needed_restoration; + bool embedding_index_needed_restoration; }; IndexRestorationResult RestoreIndexIfNeeded() ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); @@ -741,17 +746,21 @@ class IcingSearchEngine { bool index_needed_restoration; bool integer_index_needed_restoration; bool qualified_id_join_index_needed_restoration; + bool embedding_index_needed_restoration; explicit TruncateIndexResult( DocumentId first_document_to_reindex_in, bool index_needed_restoration_in, bool integer_index_needed_restoration_in, - bool qualified_id_join_index_needed_restoration_in) + bool qualified_id_join_index_needed_restoration_in, + bool embedding_index_needed_restoration_in) : first_document_to_reindex(first_document_to_reindex_in), index_needed_restoration(index_needed_restoration_in), integer_index_needed_restoration(integer_index_needed_restoration_in), qualified_id_join_index_needed_restoration( - qualified_id_join_index_needed_restoration_in) {} + qualified_id_join_index_needed_restoration_in), + embedding_index_needed_restoration( + embedding_index_needed_restoration_in) {} }; libtextclassifier3::StatusOr<TruncateIndexResult> TruncateIndicesTo( DocumentId last_stored_document_id) diff --git a/icing/icing-search-engine_schema_test.cc b/icing/icing-search-engine_schema_test.cc index 49c024e..34081e9 100644 --- a/icing/icing-search-engine_schema_test.cc +++ b/icing/icing-search-engine_schema_test.cc @@ -36,7 +36,6 @@ #include "icing/proto/persist.pb.h" #include "icing/proto/reset.pb.h" #include "icing/proto/schema.pb.h" -#include "icing/proto/scoring.pb.h" #include "icing/proto/search.pb.h" #include "icing/proto/status.pb.h" #include "icing/proto/storage.pb.h" @@ -60,6 +59,7 @@ 
namespace { using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::Eq; using ::testing::HasSubstr; +using ::testing::Not; using ::testing::Return; // For mocking purpose, we allow tests to provide a custom Filesystem. @@ -3154,6 +3154,86 @@ TEST_F(IcingSearchEngineSchemaTest, IcingShouldReturnErrorForExtraSections) { HasSubstr("Too many properties to be indexed")); } +TEST_F(IcingSearchEngineSchemaTest, UpdatedTypeDescriptionIsSaved) { + // Create a schema with more sections than allowed. + PropertyConfigProto old_property = + PropertyConfigBuilder() + .SetName("prop0") + .SetDescription("old property description") + .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL) + .Build(); + SchemaTypeConfigProto old_schema_type_config = + SchemaTypeConfigBuilder() + .SetType("type") + .SetDescription("old description") + .AddProperty(old_property) + .Build(); + SchemaProto old_schema = + SchemaBuilder().AddType(old_schema_type_config).Build(); + + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(old_schema).status(), ProtoIsOk()); + + // Update the type description + SchemaTypeConfigProto new_schema_type_config = + SchemaTypeConfigBuilder(old_schema_type_config) + .SetDescription("new description") + .Build(); + SchemaProto new_schema = + SchemaBuilder().AddType(new_schema_type_config).Build(); + ASSERT_THAT(icing.SetSchema(new_schema).status(), ProtoIsOk()); + + GetSchemaResultProto get_result = icing.GetSchema(); + ASSERT_THAT(get_result.status(), ProtoIsOk()); + ASSERT_THAT(get_result.schema(), EqualsProto(new_schema)); + ASSERT_THAT(get_result.schema(), Not(EqualsProto(old_schema))); +} + +TEST_F(IcingSearchEngineSchemaTest, UpdatedPropertyDescriptionIsSaved) { + // Create a schema with more sections than allowed. 
+ PropertyConfigProto old_property = + PropertyConfigBuilder() + .SetName("prop0") + .SetDescription("old property description") + .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL) + .Build(); + SchemaTypeConfigProto old_schema_type_config = + SchemaTypeConfigBuilder() + .SetType("type") + .SetDescription("old description") + .AddProperty(old_property) + .Build(); + SchemaProto old_schema = + SchemaBuilder().AddType(old_schema_type_config).Build(); + + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(old_schema).status(), ProtoIsOk()); + + // Update the property description + PropertyConfigProto new_property = + PropertyConfigBuilder(old_property) + .SetDescription("new property description") + .Build(); + SchemaTypeConfigProto new_schema_type_config = + SchemaTypeConfigBuilder() + .SetType("type") + .SetDescription("old description") + .AddProperty(new_property) + .Build(); + SchemaProto new_schema = + SchemaBuilder().AddType(new_schema_type_config).Build(); + ASSERT_THAT(icing.SetSchema(new_schema).status(), ProtoIsOk()); + + GetSchemaResultProto get_result = icing.GetSchema(); + ASSERT_THAT(get_result.status(), ProtoIsOk()); + ASSERT_THAT(get_result.schema(), EqualsProto(new_schema)); + ASSERT_THAT(get_result.schema(), Not(EqualsProto(old_schema))); +} + } // namespace } // namespace lib } // namespace icing diff --git a/icing/icing-search-engine_search_test.cc b/icing/icing-search-engine_search_test.cc index d815f61..a58dbc8 100644 --- a/icing/icing-search-engine_search_test.cc +++ b/icing/icing-search-engine_search_test.cc @@ -13,12 +13,14 @@ // limitations under the License. 
#include <cstdint> +#include <initializer_list> #include <limits> #include <memory> #include <string> +#include <string_view> #include <utility> +#include <vector> -#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" @@ -27,7 +29,7 @@ #include "icing/index/lite/term-id-hit-pair.h" #include "icing/jni/jni-cache.h" #include "icing/join/join-processor.h" -#include "icing/portable/endian.h" +#include "icing/legacy/index/icing-filesystem.h" #include "icing/portable/equals-proto.h" #include "icing/portable/platform.h" #include "icing/proto/debug.pb.h" @@ -49,11 +51,13 @@ #include "icing/result/result-state-manager.h" #include "icing/schema-builder.h" #include "icing/testing/common-matchers.h" +#include "icing/testing/embedding-test-utils.h" #include "icing/testing/fake-clock.h" #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/jni-test-helpers.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" +#include "icing/util/clock.h" #include "icing/util/snippet-helpers.h" namespace icing { @@ -63,6 +67,7 @@ namespace { using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::DoubleEq; +using ::testing::DoubleNear; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::Gt; @@ -120,6 +125,8 @@ class IcingSearchEngineSearchTest // Non-zero value so we don't override it to be the current time constexpr int64_t kDefaultCreationTimestampMs = 1575492852000; +constexpr double kEps = 0.000001; + IcingSearchEngineOptions GetDefaultIcingOptions() { IcingSearchEngineOptions icing_options; icing_options.set_base_dir(GetTestBaseDir()); @@ -7306,6 +7313,208 @@ TEST_P(IcingSearchEngineSearchTest, HasPropertyQueryNestedDocument) { EXPECT_THAT(results.results(), IsEmpty()); } +TEST_P(IcingSearchEngineSearchTest, EmbeddingSearch) { + if (GetParam() != + SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + 
GTEST_SKIP() << "Embedding search is only supported in advanced query."; + } + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("Email") + .AddProperty(PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString(TERM_MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty(PropertyConfigBuilder() + .SetName("embedding1") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty(PropertyConfigBuilder() + .SetName("embedding2") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_REPEATED))) + .Build(); + DocumentProto document0 = + DocumentBuilder() + .SetKey("icing", "uri0") + .SetSchema("Email") + .SetCreationTimestampMs(1) + .AddStringProperty("body", "foo") + .AddVectorProperty( + "embedding1", + CreateVector("my_model_v1", {0.1, 0.2, 0.3, 0.4, 0.5})) + .AddVectorProperty( + "embedding2", + CreateVector("my_model_v1", {-0.1, -0.2, -0.3, 0.4, 0.5}), + CreateVector("my_model_v2", {0.6, 0.7, 0.8})) + .Build(); + DocumentProto document1 = + DocumentBuilder() + .SetKey("icing", "uri1") + .SetSchema("Email") + .SetCreationTimestampMs(1) + .AddVectorProperty( + "embedding1", + CreateVector("my_model_v1", {-0.1, 0.2, -0.3, -0.4, 0.5})) + .AddVectorProperty("embedding2", + CreateVector("my_model_v2", {0.6, 0.7, -0.8})) + .Build(); + + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document0).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); + search_spec.set_embedding_query_metric_type( + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT); + search_spec.add_enabled_features( + 
std::string(kListFilterQueryLanguageFeature)); + search_spec.add_enabled_features(std::string(kEmbeddingSearchFeature)); + search_spec.set_search_type(GetParam()); + // Add an embedding query with semantic scores: + // - document 0: -0.5 (embedding1), 0.3 (embedding2) + // - document 1: -0.9 (embedding1) + *search_spec.add_embedding_query_vectors() = + CreateVector("my_model_v1", {1, -1, -1, 1, -1}); + // Add an embedding query with semantic scores: + // - document 0: -0.5 (embedding2) + // - document 1: -2.1 (embedding2) + *search_spec.add_embedding_query_vectors() = + CreateVector("my_model_v2", {-1, -1, 1}); + ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); + scoring_spec.set_rank_by( + ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION); + + // Match documents that have embeddings with a similarity closer to 0 that is + // greater than -1. + // + // The matched embeddings for each doc are: + // - document 0: -0.5 (embedding1), 0.3 (embedding2) + // - document 1: -0.9 (embedding1) + // The scoring expression for each doc will be evaluated as: + // - document 0: sum({-0.5, 0.3}) + sum({}) = -0.2 + // - document 1: sum({-0.9}) + sum({}) = -0.9 + search_spec.set_query("semanticSearch(getSearchSpecEmbedding(0), -1)"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0))) + " + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"); + SearchResultProto results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(2)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document0)); + EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5 + 0.3, kEps)); + EXPECT_THAT(results.results(1).document(), EqualsProto(document1)); + EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9, kEps)); + + // Create a query the same as above but with a section restriction, which + 
// still matches document 0 and document 1 but the semantic score 0.3 should + // be removed from document 0. + // + // The matched embeddings for each doc are: + // - document 0: -0.5 (embedding1) + // - document 1: -0.9 (embedding1) + // The scoring expression for each doc will be evaluated as: + // - document 0: sum({-0.5}) = -0.5 + // - document 1: sum({-0.9}) = -0.9 + search_spec.set_query( + "embedding1:semanticSearch(getSearchSpecEmbedding(0), -1)"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))"); + results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(2)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document0)); + EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5, kEps)); + EXPECT_THAT(results.results(1).document(), EqualsProto(document1)); + EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9, kEps)); + + // Create a query that only matches document 0. + // + // The matched embeddings for each doc are: + // - document 0: -0.5 (embedding2) + // The scoring expression for each doc will be evaluated as: + // - document 0: sum({-0.5}) = -0.5 + search_spec.set_query("semanticSearch(getSearchSpecEmbedding(1), -1.5)"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"); + results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(1)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document0)); + EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5, kEps)); + + // Create a query that only matches document 1. 
+ // + // The matched embeddings for each doc are: + // - document 1: -2.1 (embedding2) + // The scoring expression for each doc will be evaluated as: + // - document 1: sum({-2.1}) = -2.1 + search_spec.set_query("semanticSearch(getSearchSpecEmbedding(1), -10, -1)"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"); + results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(1)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document1)); + EXPECT_THAT(results.results(0).score(), DoubleNear(-2.1, kEps)); + + // Create a complex query that matches all hits from all documents. + // + // The matched embeddings for each doc are: + // - document 0: -0.5 (embedding1), 0.3 (embedding2), -0.5 (embedding2) + // - document 1: -0.9 (embedding1), -2.1 (embedding2) + // The scoring expression for each doc will be evaluated as: + // - document 0: sum({-0.5, 0.3}) + sum({-0.5}) = -0.7 + // - document 1: sum({-0.9}) + sum({-2.1}) = -3 + search_spec.set_query( + "semanticSearch(getSearchSpecEmbedding(0)) OR " + "semanticSearch(getSearchSpecEmbedding(1))"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0))) + " + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"); + results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(2)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document0)); + EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5 + 0.3 - 0.5, kEps)); + EXPECT_THAT(results.results(1).document(), EqualsProto(document1)); + EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9 - 2.1, kEps)); + + // Create a hybrid query that matches document 0 because of term-based search + // and 
document 1 because of embedding-based search. + // + // The matched embeddings for each doc are: + // - document 1: -2.1 (embedding2) + // The scoring expression for each doc will be evaluated as: + // - document 0: sum({}) = 0 + // - document 1: sum({-2.1}) = -2.1 + search_spec.set_query( + "foo OR semanticSearch(getSearchSpecEmbedding(1), -10, -1)"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"); + results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(2)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document0)); + // Document 0 has no matched embedding hit, so its score is 0. + EXPECT_THAT(results.results(0).score(), DoubleNear(0, kEps)); + EXPECT_THAT(results.results(1).document(), EqualsProto(document1)); + EXPECT_THAT(results.results(1).score(), DoubleNear(-2.1, kEps)); +} + INSTANTIATE_TEST_SUITE_P( IcingSearchEngineSearchTest, IcingSearchEngineSearchTest, testing::Values( diff --git a/icing/index/embed/doc-hit-info-iterator-embedding.cc b/icing/index/embed/doc-hit-info-iterator-embedding.cc new file mode 100644 index 0000000..8f2bc7d --- /dev/null +++ b/icing/index/embed/doc-hit-info-iterator-embedding.cc @@ -0,0 +1,164 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/doc-hit-info-iterator-embedding.h" + +#include <memory> +#include <string_view> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/index/embed/embedding-scorer.h" +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/hit/hit.h" +#include "icing/index/iterator/section-restrict-data.h" +#include "icing/proto/search.pb.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIteratorEmbedding>> +DocHitInfoIteratorEmbedding::Create( + const PropertyProto::VectorProto* query, + std::unique_ptr<SectionRestrictData> section_restrict_data, + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, + double score_low, double score_high, + EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map, + const EmbeddingIndex* embedding_index) { + ICING_RETURN_ERROR_IF_NULL(query); + ICING_RETURN_ERROR_IF_NULL(embedding_index); + ICING_RETURN_ERROR_IF_NULL(score_map); + + libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> + pl_accessor_or = embedding_index->GetAccessorForVector(*query); + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor; + if (pl_accessor_or.ok()) { + pl_accessor = std::move(pl_accessor_or).ValueOrDie(); + } else if (absl_ports::IsNotFound(pl_accessor_or.status())) { + // A not-found error should be fine, since that means there is no matching + // embedding hits in the index. 
+ pl_accessor = nullptr; + } else { + // Otherwise, return the error as is. + return pl_accessor_or.status(); + } + + ICING_ASSIGN_OR_RETURN(std::unique_ptr<EmbeddingScorer> embedding_scorer, + EmbeddingScorer::Create(metric_type)); + + return std::unique_ptr<DocHitInfoIteratorEmbedding>( + new DocHitInfoIteratorEmbedding(query, std::move(section_restrict_data), + metric_type, std::move(embedding_scorer), + score_low, score_high, score_map, + embedding_index, std::move(pl_accessor))); +} + +libtextclassifier3::StatusOr<const EmbeddingHit*> +DocHitInfoIteratorEmbedding::AdvanceToNextEmbeddingHit() { + if (cached_embedding_hits_idx_ == cached_embedding_hits_.size()) { + ICING_ASSIGN_OR_RETURN(cached_embedding_hits_, + posting_list_accessor_->GetNextHitsBatch()); + cached_embedding_hits_idx_ = 0; + if (cached_embedding_hits_.empty()) { + no_more_hit_ = true; + return nullptr; + } + } + const EmbeddingHit& embedding_hit = + cached_embedding_hits_[cached_embedding_hits_idx_]; + if (doc_hit_info_.document_id() == kInvalidDocumentId) { + doc_hit_info_.set_document_id(embedding_hit.basic_hit().document_id()); + if (section_restrict_data_ != nullptr) { + current_allowed_sections_mask_ = + section_restrict_data_->ComputeAllowedSectionsMask( + doc_hit_info_.document_id()); + } + } else if (doc_hit_info_.document_id() != + embedding_hit.basic_hit().document_id()) { + return nullptr; + } + ++cached_embedding_hits_idx_; + return &embedding_hit; +} + +libtextclassifier3::Status DocHitInfoIteratorEmbedding::Advance() { + if (no_more_hit_ || posting_list_accessor_ == nullptr) { + return absl_ports::ResourceExhaustedError( + "No more DocHitInfos in iterator"); + } + + doc_hit_info_ = DocHitInfo(kInvalidDocumentId, kSectionIdMaskNone); + std::vector<double>* matched_scores = nullptr; + current_allowed_sections_mask_ = kSectionIdMaskAll; + while (true) { + ICING_ASSIGN_OR_RETURN(const EmbeddingHit* embedding_hit, + AdvanceToNextEmbeddingHit()); + if (embedding_hit == nullptr) { + // 
No more hits for the current document. + break; + } + + // Filter out the embedding hit according to the section restriction. + if (((UINT64_C(1) << embedding_hit->basic_hit().section_id()) & + current_allowed_sections_mask_) == 0) { + continue; + } + + // Calculate the semantic score. + int dimension = query_.values_size(); + ICING_ASSIGN_OR_RETURN( + const float* vector, + embedding_index_.GetEmbeddingVector(*embedding_hit, dimension)); + double semantic_score = + embedding_scorer_->Score(dimension, + /*v1=*/query_.values().data(), + /*v2=*/vector); + + // If the semantic score is within the desired score range, update + // doc_hit_info_ and score_map_. + if (score_low_ <= semantic_score && semantic_score <= score_high_) { + doc_hit_info_.UpdateSection(embedding_hit->basic_hit().section_id()); + if (matched_scores == nullptr) { + matched_scores = &(score_map_[doc_hit_info_.document_id()]); + } + matched_scores->push_back(semantic_score); + } + } + + if (doc_hit_info_.document_id() == kInvalidDocumentId) { + return absl_ports::ResourceExhaustedError( + "No more DocHitInfos in iterator"); + } + + ++num_advance_calls_; + + // Skip the current document if it has no vector matched. + if (doc_hit_info_.hit_section_ids_mask() == kSectionIdMaskNone) { + return Advance(); + } + return libtextclassifier3::Status::OK; +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/doc-hit-info-iterator-embedding.h b/icing/index/embed/doc-hit-info-iterator-embedding.h new file mode 100644 index 0000000..21180db --- /dev/null +++ b/icing/index/embed/doc-hit-info-iterator-embedding.h @@ -0,0 +1,161 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_ +#define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_ + +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/index/embed/embedding-scorer.h" +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/iterator/section-restrict-data.h" +#include "icing/proto/search.pb.h" +#include "icing/schema/section.h" + +namespace icing { +namespace lib { + +class DocHitInfoIteratorEmbedding : public DocHitInfoLeafIterator { + public: + // Create a DocHitInfoIterator for iterating through all docs which have an + // embedding matched with the provided query with a score in the range of + // [score_low, score_high], using the provided metric_type. + // + // The iterator will store the matched embedding scores in score_map to + // prepare for scoring. + // + // The iterator will handle the section restriction logic internally by the + // provided section_restrict_data. + // + // Returns: + // - a DocHitInfoIteratorEmbedding instance on success. + // - Any error from posting lists. 
+ static libtextclassifier3::StatusOr< + std::unique_ptr<DocHitInfoIteratorEmbedding>> + Create(const PropertyProto::VectorProto* query, + std::unique_ptr<SectionRestrictData> section_restrict_data, + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, + double score_low, double score_high, + EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map, + const EmbeddingIndex* embedding_index); + + libtextclassifier3::Status Advance() override; + + // The iterator will internally handle the section restriction logic by itself + // to have better control, so that it is able to filter out embedding hits + // from unwanted sections to avoid retrieving unnecessary vectors and + // calculate scores for them. + bool full_section_restriction_applied() const override { return true; } + + libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override { + return absl_ports::InvalidArgumentError( + "Query suggestions for the semanticSearch function are not supported"); + } + + CallStats GetCallStats() const override { + return CallStats( + /*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_, + /*num_leaf_advance_calls_main_index_in=*/0, + /*num_leaf_advance_calls_integer_index_in=*/0, + /*num_leaf_advance_calls_no_index_in=*/0, + /*num_blocks_inspected_in=*/0); + } + + std::string ToString() const override { return "embedding_iterator"; } + + // PopulateMatchedTermsStats is not applicable to embedding search. 
+ void PopulateMatchedTermsStats( + std::vector<TermMatchInfo>* matched_terms_stats, + SectionIdMask filtering_section_mask) const override {} + + private: + explicit DocHitInfoIteratorEmbedding( + const PropertyProto::VectorProto* query, + std::unique_ptr<SectionRestrictData> section_restrict_data, + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, + std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low, + double score_high, + EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map, + const EmbeddingIndex* embedding_index, + std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor) + : query_(*query), + section_restrict_data_(std::move(section_restrict_data)), + metric_type_(metric_type), + embedding_scorer_(std::move(embedding_scorer)), + score_low_(score_low), + score_high_(score_high), + score_map_(*score_map), + embedding_index_(*embedding_index), + posting_list_accessor_(std::move(posting_list_accessor)), + cached_embedding_hits_idx_(0), + current_allowed_sections_mask_(kSectionIdMaskAll), + no_more_hit_(false), + num_advance_calls_(0) {} + + // Advance to the next embedding hit of the current document. If the current + // document id is kInvalidDocumentId, the method will advance to the first + // embedding hit of the next document and update doc_hit_info_. + // + // This method also properly updates cached_embedding_hits_, + // cached_embedding_hits_idx_, current_allowed_sections_mask_, and + // no_more_hit_ to reflect the current state. + // + // Returns: + // - a const pointer to the next embedding hit on success. + // - nullptr, if there is no more hit for the current document, or no more + // hit in general if the current document id is kInvalidDocumentId. + // - Any error from posting lists. 
+ libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit(); + + // Query information + const PropertyProto::VectorProto& query_; // Does not own + std::unique_ptr<SectionRestrictData> section_restrict_data_; // Nullable. + + // Scoring arguments + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_; + std::unique_ptr<EmbeddingScorer> embedding_scorer_; + double score_low_; + double score_high_; + + // Score map + EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map_; // Does not own + + // Access to embeddings index data + const EmbeddingIndex& embedding_index_; + std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_; + + // Cached data from the embeddings index + std::vector<EmbeddingHit> cached_embedding_hits_; + int cached_embedding_hits_idx_; + SectionIdMask current_allowed_sections_mask_; + bool no_more_hit_; + + int num_advance_calls_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_ diff --git a/icing/index/embed/embedding-hit.h b/icing/index/embed/embedding-hit.h new file mode 100644 index 0000000..165a123 --- /dev/null +++ b/icing/index/embed/embedding-hit.h @@ -0,0 +1,67 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_INDEX_EMBED_EMBEDDING_HIT_H_ +#define ICING_INDEX_EMBED_EMBEDDING_HIT_H_ + +#include <cstdint> + +#include "icing/index/hit/hit.h" + +namespace icing { +namespace lib { + +class EmbeddingHit { + public: + // Order: basic_hit, location + // Value bits layout: 32 basic_hit + 32 location + using Value = uint64_t; + + // WARNING: Changing this value will invalidate any pre-existing posting lists + // on user devices. + // + // kInvalidValue contains: + // - BasicHit of value 0, which is invalid. + // - location of 0 (valid), which is ok because BasicHit is already invalid. + static_assert(BasicHit::kInvalidValue == 0); + static constexpr Value kInvalidValue = 0; + + explicit EmbeddingHit(BasicHit basic_hit, uint32_t location) { + value_ = (static_cast<uint64_t>(basic_hit.value()) << 32) | location; + } + + explicit EmbeddingHit(Value value) : value_(value) {} + + // BasicHit contains the document id and the section id. + BasicHit basic_hit() const { return BasicHit(value_ >> 32); } + // The location of the referred embedding vector in the vector storage. 
+ uint32_t location() const { return value_ & 0xFFFFFFFF; }; + + bool is_valid() const { return basic_hit().is_valid(); } + Value value() const { return value_; } + + bool operator<(const EmbeddingHit& h2) const { return value_ < h2.value_; } + bool operator==(const EmbeddingHit& h2) const { return value_ == h2.value_; } + + private: + Value value_; +}; + +static_assert(sizeof(BasicHit) == 4); +static_assert(sizeof(EmbeddingHit) == 8); + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_EMBEDDING_HIT_H_ diff --git a/icing/index/embed/embedding-hit_test.cc b/icing/index/embed/embedding-hit_test.cc new file mode 100644 index 0000000..31a0b6c --- /dev/null +++ b/icing/index/embed/embedding-hit_test.cc @@ -0,0 +1,80 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/embedding-hit.h" + +#include <algorithm> +#include <cstdint> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/index/hit/hit.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsFalse; + +static constexpr DocumentId kSomeDocumentId = 24; +static constexpr SectionId kSomeSectionid = 5; +static constexpr uint32_t kSomeLocation = 123; + +TEST(EmbeddingHitTest, Accessors) { + BasicHit basic_hit(kSomeSectionid, kSomeDocumentId); + EmbeddingHit embedding_hit(basic_hit, kSomeLocation); + EXPECT_THAT(embedding_hit.basic_hit(), Eq(basic_hit)); + EXPECT_THAT(embedding_hit.location(), Eq(kSomeLocation)); +} + +TEST(EmbeddingHitTest, Invalid) { + EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue); + EXPECT_THAT(invalid_hit.is_valid(), IsFalse()); + + // Also make sure the invalid EmbeddingHit contains an invalid document id. + EXPECT_THAT(invalid_hit.basic_hit().document_id(), Eq(kInvalidDocumentId)); + EXPECT_THAT(invalid_hit.basic_hit().section_id(), Eq(kMinSectionId)); + EXPECT_THAT(invalid_hit.location(), Eq(0)); +} + +TEST(EmbeddingHitTest, Comparison) { + // Create basic hits with basic_hit1 < basic_hit2 < basic_hit3. + BasicHit basic_hit1(/*section_id=*/1, /*document_id=*/2409); + BasicHit basic_hit2(/*section_id=*/1, /*document_id=*/243); + BasicHit basic_hit3(/*section_id=*/15, /*document_id=*/243); + + // Embedding hits are sorted by BasicHit first, and then by location. + // So embedding_hit3 < embedding_hit4 < embedding_hit2 < embedding_hit1. 
+ EmbeddingHit embedding_hit1(basic_hit3, /*location=*/10); + EmbeddingHit embedding_hit2(basic_hit3, /*location=*/0); + EmbeddingHit embedding_hit3(basic_hit1, /*location=*/100); + EmbeddingHit embedding_hit4(basic_hit2, /*location=*/0); + + std::vector<EmbeddingHit> embedding_hits{embedding_hit1, embedding_hit2, + embedding_hit3, embedding_hit4}; + std::sort(embedding_hits.begin(), embedding_hits.end()); + EXPECT_THAT(embedding_hits, ElementsAre(embedding_hit3, embedding_hit4, + embedding_hit2, embedding_hit1)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/embedding-index.cc b/icing/index/embed/embedding-index.cc new file mode 100644 index 0000000..2e70f0e --- /dev/null +++ b/icing/index/embed/embedding-index.cc @@ -0,0 +1,440 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/embedding-index.h" + +#include <algorithm> +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/file/destructible-directory.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/filesystem.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" +#include "icing/index/hit/hit.h" +#include "icing/store/document-id.h" +#include "icing/store/dynamic-trie-key-mapper.h" +#include "icing/store/key-mapper.h" +#include "icing/util/crc32.h" +#include "icing/util/encode-util.h" +#include "icing/util/logging.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +namespace { + +constexpr uint32_t kEmbeddingHitListMapperMaxSize = + 128 * 1024 * 1024; // 128 MiB; + +// The maximum length returned by encode_util::EncodeIntToCString is 5 for +// uint32_t. 
+constexpr uint32_t kEncodedDimensionLength = 5; + +std::string GetMetadataFilePath(std::string_view working_path) { + return absl_ports::StrCat(working_path, "/metadata"); +} + +std::string GetFlashIndexStorageFilePath(std::string_view working_path) { + return absl_ports::StrCat(working_path, "/flash_index_storage"); +} + +std::string GetEmbeddingHitListMapperPath(std::string_view working_path) { + return absl_ports::StrCat(working_path, "/embedding_hit_list_mapper"); +} + +std::string GetEmbeddingVectorsFilePath(std::string_view working_path) { + return absl_ports::StrCat(working_path, "/embedding_vectors"); +} + +// An injective function that maps the ordered pair (dimension, model_signature) +// to a string, which is used to form a key for embedding_posting_list_mapper_. +std::string GetPostingListKey(uint32_t dimension, + std::string_view model_signature) { + std::string encoded_dimension_str = + encode_util::EncodeIntToCString(dimension); + // Make encoded_dimension_str to fixed kEncodedDimensionLength bytes. + while (encoded_dimension_str.size() < kEncodedDimensionLength) { + // C string cannot contain 0 bytes, so we append it using 1, just like what + // we do in encode_util::EncodeIntToCString. + // + // The reason that this works is because DecodeIntToString decodes a byte + // value of 0x01 as 0x00. When EncodeIntToCString returns an encoded + // dimension that is less than 5 bytes, it means that the dimension contains + // unencoded leading 0x00. So here we're explicitly encoding those bytes as + // 0x01. 
+ encoded_dimension_str.push_back(1); + } + return absl_ports::StrCat(encoded_dimension_str, model_signature); +} + +std::string GetPostingListKey(const PropertyProto::VectorProto& vector) { + return GetPostingListKey(vector.values_size(), vector.model_signature()); +} + +} // namespace + +libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>> +EmbeddingIndex::Create(const Filesystem* filesystem, std::string working_path) { + ICING_RETURN_ERROR_IF_NULL(filesystem); + + std::unique_ptr<EmbeddingIndex> index = std::unique_ptr<EmbeddingIndex>( + new EmbeddingIndex(*filesystem, std::move(working_path))); + ICING_RETURN_IF_ERROR(index->Initialize()); + return index; +} + +libtextclassifier3::Status EmbeddingIndex::Initialize() { + bool is_new = false; + if (!filesystem_.FileExists(GetMetadataFilePath(working_path_).c_str())) { + // Create working directory. + if (!filesystem_.CreateDirectoryRecursively(working_path_.c_str())) { + return absl_ports::InternalError( + absl_ports::StrCat("Failed to create directory: ", working_path_)); + } + is_new = true; + } + + ICING_ASSIGN_OR_RETURN( + MemoryMappedFile metadata_mmapped_file, + MemoryMappedFile::Create(filesystem_, GetMetadataFilePath(working_path_), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/kMetadataFileSize, + /*pre_mapping_file_offset=*/0, + /*pre_mapping_mmap_size=*/kMetadataFileSize)); + metadata_mmapped_file_ = + std::make_unique<MemoryMappedFile>(std::move(metadata_mmapped_file)); + + ICING_ASSIGN_OR_RETURN(FlashIndexStorage flash_index_storage, + FlashIndexStorage::Create( + GetFlashIndexStorageFilePath(working_path_), + &filesystem_, posting_list_hit_serializer_.get())); + flash_index_storage_ = + std::make_unique<FlashIndexStorage>(std::move(flash_index_storage)); + + ICING_ASSIGN_OR_RETURN( + embedding_posting_list_mapper_, + DynamicTrieKeyMapper<PostingListIdentifier>::Create( + filesystem_, GetEmbeddingHitListMapperPath(working_path_), + kEmbeddingHitListMapperMaxSize)); 
+ + ICING_ASSIGN_OR_RETURN( + embedding_vectors_, + FileBackedVector<float>::Create( + filesystem_, GetEmbeddingVectorsFilePath(working_path_), + MemoryMappedFile::READ_WRITE_AUTO_SYNC)); + + if (is_new) { + ICING_RETURN_IF_ERROR(metadata_mmapped_file_->GrowAndRemapIfNecessary( + /*file_offset=*/0, /*mmap_size=*/kMetadataFileSize)); + info().magic = Info::kMagic; + info().last_added_document_id = kInvalidDocumentId; + ICING_RETURN_IF_ERROR(InitializeNewStorage()); + } else { + if (metadata_mmapped_file_->available_size() != kMetadataFileSize) { + return absl_ports::FailedPreconditionError( + "Incorrect metadata file size"); + } + if (info().magic != Info::kMagic) { + return absl_ports::FailedPreconditionError("Incorrect magic value"); + } + ICING_RETURN_IF_ERROR(InitializeExistingStorage()); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status EmbeddingIndex::Clear() { + pending_embedding_hits_.clear(); + metadata_mmapped_file_.reset(); + flash_index_storage_.reset(); + embedding_posting_list_mapper_.reset(); + embedding_vectors_.reset(); + if (filesystem_.DirectoryExists(working_path_.c_str())) { + ICING_RETURN_IF_ERROR(Discard(filesystem_, working_path_)); + } + is_initialized_ = false; + return Initialize(); +} + +libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> +EmbeddingIndex::GetAccessor(uint32_t dimension, + std::string_view model_signature) const { + if (dimension == 0) { + return absl_ports::InvalidArgumentError("Dimension is 0"); + } + std::string key = GetPostingListKey(dimension, model_signature); + ICING_ASSIGN_OR_RETURN(PostingListIdentifier posting_list_id, + embedding_posting_list_mapper_->Get(key)); + return PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), posting_list_hit_serializer_.get(), + posting_list_id); +} + +libtextclassifier3::Status EmbeddingIndex::BufferEmbedding( + const BasicHit& basic_hit, const PropertyProto::VectorProto& vector) { + if 
(vector.values_size() == 0) { + return absl_ports::InvalidArgumentError("Vector dimension is 0"); + } + + uint32_t location = embedding_vectors_->num_elements(); + uint32_t dimension = vector.values_size(); + std::string key = GetPostingListKey(vector); + + // Buffer the embedding hit. + pending_embedding_hits_.push_back( + {std::move(key), EmbeddingHit(basic_hit, location)}); + + // Put vector + ICING_ASSIGN_OR_RETURN(FileBackedVector<float>::MutableArrayView mutable_arr, + embedding_vectors_->Allocate(dimension)); + mutable_arr.SetArray(/*idx=*/0, vector.values().data(), dimension); + + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status EmbeddingIndex::CommitBufferToIndex() { + std::sort(pending_embedding_hits_.begin(), pending_embedding_hits_.end()); + auto iter_curr_key = pending_embedding_hits_.rbegin(); + while (iter_curr_key != pending_embedding_hits_.rend()) { + // In order to batch putting embedding hits with the same key (dimension, + // model_signature) to the same posting list, we find the range + // [iter_curr_key, iter_next_key) of embedding hits with the same key and + // put them into their corresponding posting list together. + auto iter_next_key = iter_curr_key; + while (iter_next_key != pending_embedding_hits_.rend() && + iter_next_key->first == iter_curr_key->first) { + iter_next_key++; + } + + const std::string& key = iter_curr_key->first; + libtextclassifier3::StatusOr<PostingListIdentifier> posting_list_id_or = + embedding_posting_list_mapper_->Get(key); + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor; + if (posting_list_id_or.ok()) { + // Existing posting list. + ICING_ASSIGN_OR_RETURN( + pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), posting_list_hit_serializer_.get(), + posting_list_id_or.ValueOrDie())); + } else if (absl_ports::IsNotFound(posting_list_id_or.status())) { + // New posting list. 
+ ICING_ASSIGN_OR_RETURN( + pl_accessor, + PostingListEmbeddingHitAccessor::Create( + flash_index_storage_.get(), posting_list_hit_serializer_.get())); + } else { + // Errors + return std::move(posting_list_id_or).status(); + } + + // Adding the embedding hits. + for (auto iter = iter_curr_key; iter != iter_next_key; ++iter) { + ICING_RETURN_IF_ERROR(pl_accessor->PrependHit(iter->second)); + } + + // Finalize this posting list and add the posting list id in + // embedding_posting_list_mapper_. + PostingListEmbeddingHitAccessor::FinalizeResult result = + std::move(*pl_accessor).Finalize(); + if (!result.id.is_valid()) { + return absl_ports::InternalError("Failed to finalize posting list"); + } + ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->Put(key, result.id)); + + // Advance to the next key. + iter_curr_key = iter_next_key; + } + pending_embedding_hits_.clear(); + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status EmbeddingIndex::TransferIndex( + const std::vector<DocumentId>& document_id_old_to_new, + EmbeddingIndex* new_index) const { + std::unique_ptr<KeyMapper<PostingListIdentifier>::Iterator> itr = + embedding_posting_list_mapper_->GetIterator(); + while (itr->Advance()) { + std::string_view key = itr->GetKey(); + // This should never happen unless there is an inconsistency, or the index + // is corrupted. 
+ if (key.size() < kEncodedDimensionLength) { + return absl_ports::InternalError( + "Got invalid key from embedding posting list mapper."); + } + uint32_t dimension = encode_util::DecodeIntFromCString( + std::string_view(key.begin(), kEncodedDimensionLength)); + + // Transfer hits + std::vector<EmbeddingHit> new_hits; + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<PostingListEmbeddingHitAccessor> old_pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), posting_list_hit_serializer_.get(), + /*existing_posting_list_id=*/itr->GetValue())); + while (true) { + ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch, + old_pl_accessor->GetNextHitsBatch()); + if (batch.empty()) { + break; + } + for (EmbeddingHit& old_hit : batch) { + // Safety checks to add robustness to the codebase, so to make sure + // that we never access invalid memory, in case that hit from the + // posting list is corrupted. + ICING_ASSIGN_OR_RETURN(const float* old_vector, + GetEmbeddingVector(old_hit, dimension)); + if (old_hit.basic_hit().document_id() < 0 || + old_hit.basic_hit().document_id() >= + document_id_old_to_new.size()) { + return absl_ports::InternalError( + "Embedding hit document id is out of bound. The provided map is " + "too small, or the index may have been corrupted."); + } + + // Construct transferred hit + DocumentId new_document_id = + document_id_old_to_new[old_hit.basic_hit().document_id()]; + if (new_document_id == kInvalidDocumentId) { + continue; + } + uint32_t new_location = new_index->embedding_vectors_->num_elements(); + new_hits.push_back(EmbeddingHit( + BasicHit(old_hit.basic_hit().section_id(), new_document_id), + new_location)); + + // Copy the embedding vector of the hit to the new index. 
+ ICING_ASSIGN_OR_RETURN( + FileBackedVector<float>::MutableArrayView mutable_arr, + new_index->embedding_vectors_->Allocate(dimension)); + mutable_arr.SetArray(/*idx=*/0, old_vector, dimension); + } + } + // No hit needs to be added to the new index. + if (new_hits.empty()) { + return libtextclassifier3::Status::OK; + } + // Add transferred hits to the new index. + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<PostingListEmbeddingHitAccessor> hit_accum, + PostingListEmbeddingHitAccessor::Create( + new_index->flash_index_storage_.get(), + new_index->posting_list_hit_serializer_.get())); + for (auto new_hit_itr = new_hits.rbegin(); new_hit_itr != new_hits.rend(); + ++new_hit_itr) { + ICING_RETURN_IF_ERROR(hit_accum->PrependHit(*new_hit_itr)); + } + PostingListEmbeddingHitAccessor::FinalizeResult result = + std::move(*hit_accum).Finalize(); + if (!result.id.is_valid()) { + return absl_ports::InternalError("Failed to finalize posting list"); + } + ICING_RETURN_IF_ERROR( + new_index->embedding_posting_list_mapper_->Put(key, result.id)); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status EmbeddingIndex::Optimize( + const std::vector<DocumentId>& document_id_old_to_new, + DocumentId new_last_added_document_id) { + // This is just for completeness, but this should never be necessary, since we + // should never have pending hits at the time when Optimize is run. 
+ ICING_RETURN_IF_ERROR(CommitBufferToIndex()); + + std::string temporary_index_working_path = working_path_ + "_temp"; + if (!filesystem_.DeleteDirectoryRecursively( + temporary_index_working_path.c_str())) { + ICING_LOG(ERROR) << "Recursively deleting " << temporary_index_working_path; + return absl_ports::InternalError( + "Unable to delete temp directory to prepare to build new index."); + } + + DestructibleDirectory temporary_index_dir( + &filesystem_, std::move(temporary_index_working_path)); + if (!temporary_index_dir.is_valid()) { + return absl_ports::InternalError( + "Unable to create temp directory to build new index."); + } + + { + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<EmbeddingIndex> new_index, + EmbeddingIndex::Create(&filesystem_, temporary_index_dir.dir())); + ICING_RETURN_IF_ERROR( + TransferIndex(document_id_old_to_new, new_index.get())); + new_index->set_last_added_document_id(new_last_added_document_id); + ICING_RETURN_IF_ERROR(new_index->PersistToDisk()); + } + + // Destruct current storage instances to safely swap directories. + metadata_mmapped_file_.reset(); + flash_index_storage_.reset(); + embedding_posting_list_mapper_.reset(); + embedding_vectors_.reset(); + + if (!filesystem_.SwapFiles(temporary_index_dir.dir().c_str(), + working_path_.c_str())) { + return absl_ports::InternalError( + "Unable to apply new index due to failed swap!"); + } + + // Reinitialize the index. 
+ is_initialized_ = false; + return Initialize(); +} + +libtextclassifier3::Status EmbeddingIndex::PersistMetadataToDisk(bool force) { + return metadata_mmapped_file_->PersistToDisk(); +} + +libtextclassifier3::Status EmbeddingIndex::PersistStoragesToDisk(bool force) { + if (!flash_index_storage_->PersistToDisk()) { + return absl_ports::InternalError("Fail to persist flash index to disk"); + } + ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->PersistToDisk()); + ICING_RETURN_IF_ERROR(embedding_vectors_->PersistToDisk()); + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeInfoChecksum( + bool force) { + return info().ComputeChecksum(); +} + +libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeStoragesChecksum( + bool force) { + ICING_ASSIGN_OR_RETURN(Crc32 embedding_posting_list_mapper_crc, + embedding_posting_list_mapper_->ComputeChecksum()); + ICING_ASSIGN_OR_RETURN(Crc32 embedding_vectors_crc, + embedding_vectors_->ComputeChecksum()); + return Crc32(embedding_posting_list_mapper_crc.Get() ^ + embedding_vectors_crc.Get()); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/embedding-index.h b/icing/index/embed/embedding-index.h new file mode 100644 index 0000000..7318871 --- /dev/null +++ b/icing/index/embed/embedding-index.h @@ -0,0 +1,274 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_INDEX_EMBED_EMBEDDING_INDEX_H_ +#define ICING_INDEX_EMBED_EMBEDDING_INDEX_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/filesystem.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/file/persistent-storage.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" +#include "icing/index/hit/hit.h" +#include "icing/store/document-id.h" +#include "icing/store/key-mapper.h" +#include "icing/util/crc32.h" + +namespace icing { +namespace lib { + +class EmbeddingIndex : public PersistentStorage { + public: + struct Info { + static constexpr int32_t kMagic = 0xfbe13cbb; + + int32_t magic; + DocumentId last_added_document_id; + + Crc32 ComputeChecksum() const { + return Crc32( + std::string_view(reinterpret_cast<const char*>(this), sizeof(Info))); + } + } __attribute__((packed)); + static_assert(sizeof(Info) == 8, ""); + + // Metadata file layout: <Crcs><Info> + static constexpr int32_t kCrcsMetadataBufferOffset = 0; + static constexpr int32_t kInfoMetadataBufferOffset = + static_cast<int32_t>(sizeof(Crcs)); + static constexpr int32_t kMetadataFileSize = sizeof(Crcs) + sizeof(Info); + static_assert(kMetadataFileSize == 20, ""); + + static constexpr WorkingPathType kWorkingPathType = + WorkingPathType::kDirectory; + + EmbeddingIndex(const EmbeddingIndex&) = delete; + EmbeddingIndex& operator=(const EmbeddingIndex&) = delete; + + // Creates a new EmbeddingIndex instance to index embeddings. 
+ // + // Returns: + // - FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored + // checksum. + // - INTERNAL_ERROR on I/O errors. + // - Any error from MemoryMappedFile, FlashIndexStorage, + // DynamicTrieKeyMapper, or FileBackedVector. + static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>> Create( + const Filesystem* filesystem, std::string working_path); + + static libtextclassifier3::Status Discard(const Filesystem& filesystem, + const std::string& working_path) { + return PersistentStorage::Discard(filesystem, working_path, + kWorkingPathType); + } + + libtextclassifier3::Status Clear(); + + // Buffer an embedding pending to be added to the index. This is required + // since EmbeddingHits added in posting lists must be decreasing, which means + // that section ids and location indexes for a single document must be added + // decreasingly. + // + // Returns: + // - OK on success + // - INVALID_ARGUMENT error if the dimension is 0. + // - INTERNAL_ERROR on I/O error + libtextclassifier3::Status BufferEmbedding( + const BasicHit& basic_hit, const PropertyProto::VectorProto& vector); + + // Commit the embedding hits in the buffer to the index. + // + // Returns: + // - OK on success + // - INTERNAL_ERROR on I/O error + // - Any error from posting lists + libtextclassifier3::Status CommitBufferToIndex(); + + // Returns a PostingListEmbeddingHitAccessor for all embedding hits that match + // with the provided dimension and signature. + // + // Returns: + // - a PostingListEmbeddingHitAccessor instance on success. + // - INVALID_ARGUMENT error if the dimension is 0. + // - NOT_FOUND error if there is no matching embedding hit. + // - Any error from posting lists. 
+ libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> + GetAccessor(uint32_t dimension, std::string_view model_signature) const; + + // Returns a PostingListEmbeddingHitAccessor for all embedding hits that match + // with the provided vector's dimension and signature. + // + // Returns: + // - a PostingListEmbeddingHitAccessor instance on success. + // - INVALID_ARGUMENT error if the dimension is 0. + // - NOT_FOUND error if there is no matching embedding hit. + // - Any error from posting lists. + libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> + GetAccessorForVector(const PropertyProto::VectorProto& vector) const { + return GetAccessor(vector.values_size(), vector.model_signature()); + } + + // Reduces internal file sizes by reclaiming space of deleted documents. + // new_last_added_document_id will be used to update the last added document + // id in the lite index. + // + // Returns: + // - OK on success + // - INTERNAL_ERROR on IO error, this indicates that the index may be in an + // invalid state and should be cleared. 
+ libtextclassifier3::Status Optimize( + const std::vector<DocumentId>& document_id_old_to_new, + DocumentId new_last_added_document_id); + + libtextclassifier3::StatusOr<const float*> GetEmbeddingVector( + const EmbeddingHit& hit, uint32_t dimension) const { + if (static_cast<int64_t>(hit.location()) + dimension > + GetTotalVectorSize()) { + return absl_ports::InternalError( + "Got an embedding hit that refers to a vector out of range."); + } + return embedding_vectors_->array() + hit.location(); + } + + const float* GetRawEmbeddingData() const { + return embedding_vectors_->array(); + } + + int32_t GetTotalVectorSize() const { + return embedding_vectors_->num_elements(); + } + + DocumentId last_added_document_id() const { + return info().last_added_document_id; + } + + void set_last_added_document_id(DocumentId document_id) { + Info& info_ref = info(); + if (info_ref.last_added_document_id == kInvalidDocumentId || + document_id > info_ref.last_added_document_id) { + info_ref.last_added_document_id = document_id; + } + } + + private: + explicit EmbeddingIndex(const Filesystem& filesystem, + std::string working_path) + : PersistentStorage(filesystem, std::move(working_path), + kWorkingPathType) {} + + libtextclassifier3::Status Initialize(); + + // Transfers embedding data and hits from the current index to new_index. + // + // Returns: + // - OK on success + // - INTERNAL_ERROR on I/O error. This could potentially leave the storages + // in an invalid state and the caller should handle it properly (e.g. + // discard and rebuild) + libtextclassifier3::Status TransferIndex( + const std::vector<DocumentId>& document_id_old_to_new, + EmbeddingIndex* new_index) const; + + // Flushes contents of metadata file. + // + // Returns: + // - OK on success + // - INTERNAL_ERROR on I/O error + libtextclassifier3::Status PersistMetadataToDisk(bool force) override; + + // Flushes contents of all storages to underlying files. 
+ // + // Returns: + // - OK on success + // - INTERNAL_ERROR on I/O error + libtextclassifier3::Status PersistStoragesToDisk(bool force) override; + + // Computes and returns Info checksum. + // + // Returns: + // - Crc of the Info on success + libtextclassifier3::StatusOr<Crc32> ComputeInfoChecksum(bool force) override; + + // Computes and returns all storages checksum. + // + // Returns: + // - Crc of all storages on success + // - INTERNAL_ERROR if any data inconsistency + libtextclassifier3::StatusOr<Crc32> ComputeStoragesChecksum( + bool force) override; + + Crcs& crcs() override { + return *reinterpret_cast<Crcs*>(metadata_mmapped_file_->mutable_region() + + kCrcsMetadataBufferOffset); + } + + const Crcs& crcs() const override { + return *reinterpret_cast<const Crcs*>(metadata_mmapped_file_->region() + + kCrcsMetadataBufferOffset); + } + + Info& info() { + return *reinterpret_cast<Info*>(metadata_mmapped_file_->mutable_region() + + kInfoMetadataBufferOffset); + } + + const Info& info() const { + return *reinterpret_cast<const Info*>(metadata_mmapped_file_->region() + + kInfoMetadataBufferOffset); + } + + // In memory data: + // Pending embedding hits with their embedding keys used for + // embedding_posting_list_mapper_. + std::vector<std::pair<std::string, EmbeddingHit>> pending_embedding_hits_; + + // Metadata + std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_; + + // Posting list storage + std::unique_ptr<PostingListEmbeddingHitSerializer> + posting_list_hit_serializer_ = + std::make_unique<PostingListEmbeddingHitSerializer>(); + std::unique_ptr<FlashIndexStorage> flash_index_storage_; + + // The mapper from embedding keys to the corresponding posting list identifier + // that stores all embedding hits with the same key. + // + // The key for an embedding hit is a one-to-one encoded string of the ordered + // pair (dimension, model_signature) corresponding to the embedding. 
+ std::unique_ptr<KeyMapper<PostingListIdentifier>> + embedding_posting_list_mapper_; + + // A single FileBackedVector that holds all embedding vectors. + std::unique_ptr<FileBackedVector<float>> embedding_vectors_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_EMBEDDING_INDEX_H_ diff --git a/icing/index/embed/embedding-index_test.cc b/icing/index/embed/embedding-index_test.cc new file mode 100644 index 0000000..5980e82 --- /dev/null +++ b/icing/index/embed/embedding-index_test.cc @@ -0,0 +1,582 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/embedding-index.h" + +#include <unistd.h> + +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/filesystem.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" +#include "icing/index/hit/hit.h" +#include "icing/proto/document.pb.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/embedding-test-utils.h" +#include "icing/testing/tmp-directory.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Test; + +class EmbeddingIndexTest : public Test { + protected: + void SetUp() override { + embedding_index_dir_ = GetTestTempDir() + "/embedding_index_test"; + ICING_ASSERT_OK_AND_ASSIGN( + embedding_index_, + EmbeddingIndex::Create(&filesystem_, embedding_index_dir_)); + } + + void TearDown() override { + embedding_index_.reset(); + filesystem_.DeleteDirectoryRecursively(embedding_index_dir_.c_str()); + } + + libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits( + uint32_t dimension, std::string_view model_signature) { + std::vector<EmbeddingHit> hits; + + libtextclassifier3::StatusOr< + std::unique_ptr<PostingListEmbeddingHitAccessor>> + pl_accessor_or = + embedding_index_->GetAccessor(dimension, model_signature); + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor; + if (pl_accessor_or.ok()) { + pl_accessor = std::move(pl_accessor_or).ValueOrDie(); + } else if (absl_ports::IsNotFound(pl_accessor_or.status())) { + return hits; + } else { + return 
std::move(pl_accessor_or).status(); + } + + while (true) { + ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch, + pl_accessor->GetNextHitsBatch()); + if (batch.empty()) { + return hits; + } + hits.insert(hits.end(), batch.begin(), batch.end()); + } + } + + std::vector<float> GetRawEmbeddingData() { + return std::vector<float>(embedding_index_->GetRawEmbeddingData(), + embedding_index_->GetRawEmbeddingData() + + embedding_index_->GetTotalVectorSize()); + } + + Filesystem filesystem_; + std::string embedding_index_dir_; + std::unique_ptr<EmbeddingIndex> embedding_index_; +}; + +TEST_F(EmbeddingIndexTest, AddSingleEmbedding) { + PropertyProto::VectorProto vector = CreateVector("model", {0.1, 0.2, 0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(0); + + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/0), /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, AddMultipleEmbeddingsInTheSameSection) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(0); + + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + 
/*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/3)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, HitsWithLowerSectionIdReturnedFirst) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/5, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/2, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(0); + + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/2, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(/*section_id=*/5, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, HitsWithHigherDocumentIdReturnedFirst) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/1), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(1); + + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/3), + 
EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); +} + +TEST_F(EmbeddingIndexTest, AddEmbeddingsFromDifferentModels) { + PropertyProto::VectorProto vector1 = CreateVector("model1", {0.1, 0.2}); + PropertyProto::VectorProto vector2 = + CreateVector("model2", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(0); + + EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/2)))); + EXPECT_THAT( + GetHits(/*dimension=*/5, /*model_signature=*/"non-existent-model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, + AddEmbeddingsWithSameSignatureButDifferentDimension) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(0); + 
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/2)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, ClearIndex) { + // Loop the same logic twice to make sure that clear works as expected, and + // the index is still valid after clearing. + for (int i = 0; i < 2; i++) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/1, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/2, /*document_id=*/1), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(1); + + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/2, /*document_id=*/1), + /*location=*/3), + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + + // Check that clear works as expected. 
+ ICING_ASSERT_OK(embedding_index_->Clear()); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); + EXPECT_EQ(embedding_index_->last_added_document_id(), kInvalidDocumentId); + } +} + +TEST_F(EmbeddingIndexTest, EmptyCommitIsOk) { + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); +} + +TEST_F(EmbeddingIndexTest, MultipleCommits) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/1, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); +} + +TEST_F(EmbeddingIndexTest, + InvalidCommit_SectionIdCanOnlyDecreaseForSingleDocument) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/1, /*document_id=*/0), vector2)); + // Posting list with delta encoding can only allow decreasing values. 
+ EXPECT_THAT(embedding_index_->CommitBufferToIndex(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(EmbeddingIndexTest, InvalidCommit_DocumentIdCanOnlyIncrease) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/1), vector1)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector2)); + // Posting list with delta encoding can only allow decreasing values, which + // means document ids must be committed increasingly, since document ids are + // inverted in hit values. + EXPECT_THAT(embedding_index_->CommitBufferToIndex(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(EmbeddingIndexTest, EmptyOptimizeIsOk) { + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{}, + /*new_last_added_document_id=*/kInvalidDocumentId)); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); +} + +TEST_F(EmbeddingIndexTest, OptimizeSingleEmbeddingSingleDocument) { + PropertyProto::VectorProto vector = CreateVector("model", {0.1, 0.2, 0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/2), vector)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(2); + + // Before optimize + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/2), /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 2); + + // Run optimize without deleting any documents, and check that the index is + // not changed. 
+ ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/2), /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 2); + + // Run optimize to map document id 2 to 1, and check that the index is + // updated correctly. + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId, 1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/1), /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + + // Run optimize to delete the document. 
+ ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId}, + /*new_last_added_document_id=*/0)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, OptimizeMultipleEmbeddingsSingleDocument) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/2), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/2), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(2); + + // Before optimize + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2), + /*location=*/3)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 2); + + // Run optimize without deleting any documents, and check that the index is + // not changed. 
+ ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2), + /*location=*/3)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 2); + + // Run optimize to map document id 2 to 1, and check that the index is + // updated correctly. + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId, 1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/3)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + + // Run optimize to delete the document. 
+ ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId}, + /*new_last_added_document_id=*/0)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, OptimizeMultipleEmbeddingsMultipleDocument) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = CreateVector("model", {1, 2, 3}); + PropertyProto::VectorProto vector3 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/1, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/1), vector3)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(1); + + // Before optimize + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/6), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/3)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, 1, 2, 3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + + // Run optimize without deleting any documents. It is expected to see that the + // raw embedding data is rearranged, since during index transfer, embedding + // vectors from higher document ids are added first. + // + // Also keep in mind that once the raw data is rearranged, calling another + // Optimize subsequently will not change the raw data again. 
+ for (int i = 0; i < 2; i++) { + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, 1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/6)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(-0.1, -0.2, -0.3, 0.1, 0.2, 0.3, 1, 2, 3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + } + + // Run optimize to delete document 0, and check that the index is + // updated correctly. + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{kInvalidDocumentId, 0}, + /*new_last_added_document_id=*/0)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(-0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, OptimizeEmbeddingsFromDifferentModels) { + PropertyProto::VectorProto vector1 = CreateVector("model1", {0.1, 0.2}); + PropertyProto::VectorProto vector2 = CreateVector("model1", {1, 2}); + PropertyProto::VectorProto vector3 = + CreateVector("model2", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/1), vector2)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/1, /*document_id=*/1), vector3)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(1); + + // Before optimize + 
EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/2), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/1), + /*location=*/4)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 1, 2, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + + // Run optimize without deleting any documents. It is expected to see that the + // raw embedding data is rearranged, since during index transfer: + // - Embedding vectors with lower keys, which are the string encoded ordered + // pairs (dimension, model_signature), are iterated first. + // - Embedding vectors from higher document ids are added first. + // + // Also keep in mind that once the raw data is rearranged, calling another + // Optimize subsequently will not change the raw data again. + for (int i = 0; i < 2; i++) { + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, 1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/2)))); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/1), + /*location=*/4)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(1, 2, 0.1, 0.2, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + } + + // Run optimize to delete document 1, and check that the index is + // updated correctly. 
+ ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId}, + /*new_last_added_document_id=*/0)); + EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +} // namespace +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/embedding-query-results.h b/icing/index/embed/embedding-query-results.h new file mode 100644 index 0000000..de85489 --- /dev/null +++ b/icing/index/embed/embedding-query-results.h @@ -0,0 +1,72 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_ +#define ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_ + +#include <unordered_map> +#include <vector> + +#include "icing/proto/search.pb.h" +#include "icing/store/document-id.h" + +namespace icing { +namespace lib { + +// A class to store results generated from embedding queries. 
+struct EmbeddingQueryResults { + // Maps from DocumentId to the list of matched embedding scores for that + // document, which will be used in the advanced scoring language to + // determine the results for the "this.matchedSemanticScores(...)" function. + using EmbeddingQueryScoreMap = + std::unordered_map<DocumentId, std::vector<double>>; + + // Maps from (query_vector_index, metric_type) to EmbeddingQueryScoreMap. + std::unordered_map< + int, std::unordered_map<SearchSpecProto::EmbeddingQueryMetricType::Code, + EmbeddingQueryScoreMap>> + result_scores; + + // Returns the matched scores for the given query_vector_index, metric_type, + // and doc_id. Returns nullptr if (query_vector_index, metric_type) does not + // exist in the result_scores map. + const std::vector<double>* GetMatchedScoresForDocument( + int query_vector_index, + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, + DocumentId doc_id) const { + // Check if a mapping exists for the query_vector_index + auto outer_it = result_scores.find(query_vector_index); + if (outer_it == result_scores.end()) { + return nullptr; + } + // Check if a mapping exists for the metric_type + auto inner_it = outer_it->second.find(metric_type); + if (inner_it == outer_it->second.end()) { + return nullptr; + } + const EmbeddingQueryScoreMap& score_map = inner_it->second; + + // Check if the doc_id exists in the score_map + auto scores_it = score_map.find(doc_id); + if (scores_it == score_map.end()) { + return nullptr; + } + return &scores_it->second; + } +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_ diff --git a/icing/index/embed/embedding-scorer.cc b/icing/index/embed/embedding-scorer.cc new file mode 100644 index 0000000..0d84e01 --- /dev/null +++ b/icing/index/embed/embedding-scorer.cc @@ -0,0 +1,95 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/embed/embedding-scorer.h" + +#include <cmath> +#include <memory> +#include <string> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/proto/search.pb.h" + +namespace icing { +namespace lib { + +namespace { + +float CalculateDotProduct(int dimension, const float* v1, const float* v2) { + float dot_product = 0.0; + for (int i = 0; i < dimension; ++i) { + dot_product += v1[i] * v2[i]; + } + return dot_product; +} + +float CalculateNorm2(int dimension, const float* v) { + return std::sqrt(CalculateDotProduct(dimension, v, v)); +} + +float CalculateCosine(int dimension, const float* v1, const float* v2) { + float divisor = CalculateNorm2(dimension, v1) * CalculateNorm2(dimension, v2); + if (divisor == 0.0) { + return 0.0; + } + return CalculateDotProduct(dimension, v1, v2) / divisor; +} + +float CalculateEuclideanDistance(int dimension, const float* v1, + const float* v2) { + float result = 0.0; + for (int i = 0; i < dimension; ++i) { + float diff = v1[i] - v2[i]; + result += diff * diff; + } + return std::sqrt(result); +} + +} // namespace + +libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingScorer>> +EmbeddingScorer::Create( + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type) { + switch (metric_type) { + case SearchSpecProto::EmbeddingQueryMetricType::COSINE: + return std::make_unique<CosineEmbeddingScorer>(); + case 
SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT: + return std::make_unique<DotProductEmbeddingScorer>(); + case SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN: + return std::make_unique<EuclideanDistanceEmbeddingScorer>(); + default: + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + "Invalid EmbeddingQueryMetricType: ", std::to_string(metric_type))); + } +} + +float CosineEmbeddingScorer::Score(int dimension, const float* v1, + const float* v2) const { + return CalculateCosine(dimension, v1, v2); +} + +float DotProductEmbeddingScorer::Score(int dimension, const float* v1, + const float* v2) const { + return CalculateDotProduct(dimension, v1, v2); +} + +float EuclideanDistanceEmbeddingScorer::Score(int dimension, const float* v1, + const float* v2) const { + return CalculateEuclideanDistance(dimension, v1, v2); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/embedding-scorer.h b/icing/index/embed/embedding-scorer.h new file mode 100644 index 0000000..8caf0bc --- /dev/null +++ b/icing/index/embed/embedding-scorer.h @@ -0,0 +1,54 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_INDEX_EMBED_EMBEDDING_SCORER_H_ +#define ICING_INDEX_EMBED_EMBEDDING_SCORER_H_ + +#include <memory> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/proto/search.pb.h" + +namespace icing { +namespace lib { + +class EmbeddingScorer { + public: + static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingScorer>> Create( + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type); + virtual float Score(int dimension, const float* v1, + const float* v2) const = 0; + + virtual ~EmbeddingScorer() = default; +}; + +class CosineEmbeddingScorer : public EmbeddingScorer { + public: + float Score(int dimension, const float* v1, const float* v2) const override; +}; + +class DotProductEmbeddingScorer : public EmbeddingScorer { + public: + float Score(int dimension, const float* v1, const float* v2) const override; +}; + +class EuclideanDistanceEmbeddingScorer : public EmbeddingScorer { + public: + float Score(int dimension, const float* v1, const float* v2) const override; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_EMBEDDING_SCORER_H_ diff --git a/icing/index/embed/posting-list-embedding-hit-accessor.cc b/icing/index/embed/posting-list-embedding-hit-accessor.cc new file mode 100644 index 0000000..e154165 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-accessor.cc @@ -0,0 +1,132 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" + +#include <cstdint> +#include <memory> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-common.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" +#include "icing/legacy/index/icing-bit-util.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> +PostingListEmbeddingHitAccessor::Create( + FlashIndexStorage *storage, PostingListEmbeddingHitSerializer *serializer) { + uint32_t max_posting_list_bytes = storage->max_posting_list_bytes(); + ICING_ASSIGN_OR_RETURN(PostingListUsed in_memory_posting_list, + PostingListUsed::CreateFromUnitializedRegion( + serializer, max_posting_list_bytes)); + return std::unique_ptr<PostingListEmbeddingHitAccessor>( + new PostingListEmbeddingHitAccessor(storage, serializer, + std::move(in_memory_posting_list))); +} + +libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> +PostingListEmbeddingHitAccessor::CreateFromExisting( + FlashIndexStorage *storage, PostingListEmbeddingHitSerializer *serializer, + PostingListIdentifier existing_posting_list_id) { + // Our in_memory_posting_list_ will start as empty. 
+ ICING_ASSIGN_OR_RETURN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + Create(storage, serializer)); + ICING_ASSIGN_OR_RETURN(PostingListHolder holder, + storage->GetPostingList(existing_posting_list_id)); + pl_accessor->preexisting_posting_list_ = + std::make_unique<PostingListHolder>(std::move(holder)); + return pl_accessor; +} + +// Returns the next batch of hits for the provided posting list. +libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> +PostingListEmbeddingHitAccessor::GetNextHitsBatch() { + if (preexisting_posting_list_ == nullptr) { + if (has_reached_posting_list_chain_end_) { + return std::vector<EmbeddingHit>(); + } + return absl_ports::FailedPreconditionError( + "Cannot retrieve hits from a PostingListEmbeddingHitAccessor that was " + "not created from a preexisting posting list."); + } + ICING_ASSIGN_OR_RETURN( + std::vector<EmbeddingHit> batch, + serializer_->GetHits(&preexisting_posting_list_->posting_list)); + uint32_t next_block_index = kInvalidBlockIndex; + // Posting lists will only be chained when they are max-sized, in which case + // next_block_index will point to the next block for the next posting list. + // Otherwise, next_block_index can be kInvalidBlockIndex or be used to point + // to the next free list block, which is not relevant here. + if (preexisting_posting_list_->posting_list.size_in_bytes() == + storage_->max_posting_list_bytes()) { + next_block_index = preexisting_posting_list_->next_block_index; + } + + if (next_block_index != kInvalidBlockIndex) { + // Since we only have to deal with next block for max-sized posting list + // block, max_num_posting_lists is 1 and posting_list_index_bits is + // BitsToStore(1). 
+ PostingListIdentifier next_posting_list_id( + next_block_index, /*posting_list_index=*/0, + /*posting_list_index_bits=*/BitsToStore(1)); + ICING_ASSIGN_OR_RETURN(PostingListHolder holder, + storage_->GetPostingList(next_posting_list_id)); + preexisting_posting_list_ = + std::make_unique<PostingListHolder>(std::move(holder)); + } else { + has_reached_posting_list_chain_end_ = true; + preexisting_posting_list_.reset(); + } + return batch; +} + +libtextclassifier3::Status PostingListEmbeddingHitAccessor::PrependHit( + const EmbeddingHit &hit) { + PostingListUsed &active_pl = (preexisting_posting_list_ != nullptr) + ? preexisting_posting_list_->posting_list + : in_memory_posting_list_; + libtextclassifier3::Status status = serializer_->PrependHit(&active_pl, hit); + if (!absl_ports::IsResourceExhausted(status)) { + return status; + } + // There is no more room to add hits to this current posting list! Therefore, + // we need to either move those hits to a larger posting list or flush this + // posting list and create another max-sized posting list in the chain. + if (preexisting_posting_list_ != nullptr) { + ICING_RETURN_IF_ERROR(FlushPreexistingPostingList()); + } else { + ICING_RETURN_IF_ERROR(FlushInMemoryPostingList()); + } + + // Re-add hit. Should always fit since we just cleared + // in_memory_posting_list_. It's fine to explicitly reference + // in_memory_posting_list_ here because there's no way of reaching this line + // while preexisting_posting_list_ is still in use. 
+ return serializer_->PrependHit(&in_memory_posting_list_, hit); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/posting-list-embedding-hit-accessor.h b/icing/index/embed/posting-list-embedding-hit-accessor.h new file mode 100644 index 0000000..4acb9a3 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-accessor.h @@ -0,0 +1,106 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_ +#define ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_ + +#include <memory> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-accessor.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" + +namespace icing { +namespace lib { + +// This class is used to provide a simple abstraction for adding hits to posting +// lists. PostingListEmbeddingHitAccessor handles 1) selection of properly-sized +// posting lists for the accumulated hits during Finalize() and 2) chaining of +// max-sized posting lists. 
+class PostingListEmbeddingHitAccessor : public PostingListAccessor { + public: + // Creates an empty PostingListEmbeddingHitAccessor. + // + // RETURNS: + // - On success, a valid unique_ptr instance of + // PostingListEmbeddingHitAccessor + // - INVALID_ARGUMENT error if storage has an invalid block_size. + static libtextclassifier3::StatusOr< + std::unique_ptr<PostingListEmbeddingHitAccessor>> + Create(FlashIndexStorage* storage, + PostingListEmbeddingHitSerializer* serializer); + + // Create a PostingListEmbeddingHitAccessor with an existing posting list + // identified by existing_posting_list_id. + // + // The PostingListEmbeddingHitAccessor will add hits to this posting list + // until it is necessary either to 1) chain the posting list (if it is + // max-sized) or 2) move its hits to a larger posting list. + // + // RETURNS: + // - On success, a valid unique_ptr instance of + // PostingListEmbeddingHitAccessor + // - INVALID_ARGUMENT if storage has an invalid block_size. + static libtextclassifier3::StatusOr< + std::unique_ptr<PostingListEmbeddingHitAccessor>> + CreateFromExisting(FlashIndexStorage* storage, + PostingListEmbeddingHitSerializer* serializer, + PostingListIdentifier existing_posting_list_id); + + PostingListSerializer* GetSerializer() override { return serializer_; } + + // Retrieve the next batch of hits for the posting list chain + // + // RETURNS: + // - On success, a vector of hits in the posting list chain + // - INTERNAL if called on an instance of PostingListEmbeddingHitAccessor + // that was created via PostingListEmbeddingHitAccessor::Create, if unable + // to read the next posting list in the chain or if the posting list has + // been corrupted somehow. + libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetNextHitsBatch(); + + // Prepend one hit. 
This may result in flushing the posting list to disk (if + // the PostingListEmbeddingHitAccessor holds a max-sized posting list that is + // full) or freeing a pre-existing posting list if it is too small to fit all + // hits necessary. + // + // RETURNS: + // - OK, on success + // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the + // previously added hit. + // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a new + // posting list. + libtextclassifier3::Status PrependHit(const EmbeddingHit& hit); + + private: + explicit PostingListEmbeddingHitAccessor( + FlashIndexStorage* storage, PostingListEmbeddingHitSerializer* serializer, + PostingListUsed in_memory_posting_list) + : PostingListAccessor(storage, std::move(in_memory_posting_list)), + serializer_(serializer) {} + + PostingListEmbeddingHitSerializer* serializer_; // Does not own. +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_ diff --git a/icing/index/embed/posting-list-embedding-hit-accessor_test.cc b/icing/index/embed/posting-list-embedding-hit-accessor_test.cc new file mode 100644 index 0000000..b9ebe87 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-accessor_test.cc @@ -0,0 +1,387 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/filesystem.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-accessor.h" +#include "icing/file/posting_list/posting-list-common.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" +#include "icing/index/hit/hit.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/hit-test-utils.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::Eq; +using ::testing::Lt; +using ::testing::SizeIs; + +class PostingListEmbeddingHitAccessorTest : public ::testing::Test { + protected: + void SetUp() override { + test_dir_ = GetTestTempDir() + "/test_dir"; + file_name_ = test_dir_ + "/test_file.idx.index"; + + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str())); + ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(test_dir_.c_str())); + + serializer_ = std::make_unique<PostingListEmbeddingHitSerializer>(); + + ICING_ASSERT_OK_AND_ASSIGN( + FlashIndexStorage flash_index_storage, + FlashIndexStorage::Create(file_name_, &filesystem_, serializer_.get())); + flash_index_storage_ = + std::make_unique<FlashIndexStorage>(std::move(flash_index_storage)); + } + + void TearDown() override { + flash_index_storage_.reset(); + serializer_.reset(); + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str())); + } + + Filesystem filesystem_; + std::string test_dir_; + std::string file_name_; + 
std::unique_ptr<PostingListEmbeddingHitSerializer> serializer_; + std::unique_ptr<FlashIndexStorage> flash_index_storage_; +}; + +TEST_F(PostingListEmbeddingHitAccessorTest, HitsAddAndRetrieveProperly) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Add some hits! Any hits! + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit& hit : hits1) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); + } + PostingListAccessor::FinalizeResult result = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result.status); + EXPECT_THAT(result.id.block_index(), Eq(1)); + EXPECT_THAT(result.id.posting_list_index(), Eq(0)); + + // Retrieve some hits. + ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(result.id)); + EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); + EXPECT_THAT(pl_holder.next_block_index, Eq(kInvalidBlockIndex)); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPLKeepOnSameBlock) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Add a single hit. This will fit in a min-sized posting list. + EmbeddingHit hit1(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/1); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result1.status); + // Should have been allocated to the first block. + EXPECT_THAT(result1.id.block_index(), Eq(1)); + EXPECT_THAT(result1.id.posting_list_index(), Eq(0)); + + // Add one more hit. 
The minimum size for a posting list must be able to fit + // at least two hits, so this should NOT cause the previous pl to be + // reallocated. + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); + EmbeddingHit hit2(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/0); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit2)); + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result2.status); + // Should have been allocated to the same posting list as the first hit. + EXPECT_THAT(result2.id, Eq(result1.id)); + + // The posting list at result2.id should hold all of the hits that have been + // added. + ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(result2.id)); + EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list), + IsOkAndHolds(ElementsAre(hit2, hit1))); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPLReallocateToLargerPL) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + std::vector<EmbeddingHit> hits = + CreateEmbeddingHits(/*num_hits=*/11, /*desired_byte_length=*/1); + + // Add 8 hits with a small posting list of 24 bytes. The first 7 hits will + // be compressed to one byte each and will be able to fit in the 8 byte + // padded region. The last hit will fit in one of the special hits. The + // posting list will be ALMOST_FULL and can fit at most 2 more hits. + for (int i = 0; i < 8; ++i) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i])); + } + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result1.status); + // Should have been allocated to the first block. 
+ EXPECT_THAT(result1.id.block_index(), Eq(1)); + EXPECT_THAT(result1.id.posting_list_index(), Eq(0)); + + // Now let's add some more hits! + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); + // The current posting list can fit at most 2 more hits. + for (int i = 8; i < 10; ++i) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i])); + } + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result2.status); + // The 2 hits should still fit on the first block + EXPECT_THAT(result1.id.block_index(), Eq(1)); + EXPECT_THAT(result1.id.posting_list_index(), Eq(0)); + + // Add one more hit + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result2.id)); + // The current posting list should be FULL. Adding more hits should result in + // these hits being moved to a larger posting list. + ICING_ASSERT_OK(pl_accessor->PrependHit(hits[10])); + PostingListAccessor::FinalizeResult result3 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result3.status); + // Should have been allocated to the second (new) block because the posting + // list should have grown beyond the size that the first block maintains. + EXPECT_THAT(result3.id.block_index(), Eq(2)); + EXPECT_THAT(result3.id.posting_list_index(), Eq(0)); + + // The posting list at result3.id should hold all of the hits that have been + // added. 
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(result3.id)); + EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend()))); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, MultiBlockChainsBlocksProperly) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Add some hits! Any hits! + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5000, /*desired_byte_length=*/1); + for (const EmbeddingHit& hit : hits1) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); + } + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result1.status); + PostingListIdentifier second_block_id = result1.id; + // Should have been allocated to the second block, which holds a max-sized + // posting list. + EXPECT_THAT(second_block_id, Eq(PostingListIdentifier( + /*block_index=*/2, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0))); + + // Now let's retrieve them! + ICING_ASSERT_OK_AND_ASSIGN( + PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(second_block_id)); + // This pl_holder will only hold a posting list with the hits that didn't fit + // on the first block. + ICING_ASSERT_OK_AND_ASSIGN(std::vector<EmbeddingHit> second_block_hits, + serializer_->GetHits(&pl_holder.posting_list)); + ASSERT_THAT(second_block_hits, SizeIs(Lt(hits1.size()))); + auto first_block_hits_start = hits1.rbegin() + second_block_hits.size(); + EXPECT_THAT(second_block_hits, + ElementsAreArray(hits1.rbegin(), first_block_hits_start)); + + // Now retrieve all of the hits that were on the first block. 
+ uint32_t first_block_id = pl_holder.next_block_index; + EXPECT_THAT(first_block_id, Eq(1)); + + PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0); + ICING_ASSERT_OK_AND_ASSIGN(pl_holder, + flash_index_storage_->GetPostingList(pl_id)); + EXPECT_THAT( + serializer_->GetHits(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits1.rend()))); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, + PreexistingMultiBlockReusesBlocksProperly) { + std::vector<EmbeddingHit> hits = + CreateEmbeddingHits(/*num_hits=*/5050, /*desired_byte_length=*/1); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Add some hits! Any hits! + for (int i = 0; i < 5000; ++i) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i])); + } + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result1.status); + PostingListIdentifier first_add_id = result1.id; + EXPECT_THAT(first_add_id, Eq(PostingListIdentifier( + /*block_index=*/2, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0))); + + // Now add a couple more hits. These should fit on the existing, not full + // second block. + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), first_add_id)); + for (int i = 5000; i < hits.size(); ++i) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i])); + } + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result2.status); + PostingListIdentifier second_add_id = result2.id; + EXPECT_THAT(second_add_id, Eq(first_add_id)); + + // We should be able to retrieve all 5050 hits. 
+ ICING_ASSERT_OK_AND_ASSIGN( + PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(second_add_id)); + // This pl_holder will only hold a posting list with the hits that didn't fit + // on the first block. + ICING_ASSERT_OK_AND_ASSIGN(std::vector<EmbeddingHit> second_block_hits, + serializer_->GetHits(&pl_holder.posting_list)); + ASSERT_THAT(second_block_hits, SizeIs(Lt(hits.size()))); + auto first_block_hits_start = hits.rbegin() + second_block_hits.size(); + EXPECT_THAT(second_block_hits, + ElementsAreArray(hits.rbegin(), first_block_hits_start)); + + // Now retrieve all of the hits that were on the first block. + uint32_t first_block_id = pl_holder.next_block_index; + EXPECT_THAT(first_block_id, Eq(1)); + + PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0); + ICING_ASSERT_OK_AND_ASSIGN(pl_holder, + flash_index_storage_->GetPostingList(pl_id)); + EXPECT_THAT( + serializer_->GetHits(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits.rend()))); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, InvalidHitReturnsInvalidArgument) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue); + EXPECT_THAT(pl_accessor->PrependHit(invalid_hit), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, + HitsNotDecreasingReturnsInvalidArgument) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + EmbeddingHit hit1(BasicHit(/*section_id=*/3, /*document_id=*/1), + /*location=*/5); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); + + EmbeddingHit hit2(BasicHit(/*section_id=*/6, /*document_id=*/1), + 
/*location=*/5); + EXPECT_THAT(pl_accessor->PrependHit(hit2), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EmbeddingHit hit3(BasicHit(/*section_id=*/2, /*document_id=*/0), + /*location=*/5); + EXPECT_THAT(pl_accessor->PrependHit(hit3), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EmbeddingHit hit4(BasicHit(/*section_id=*/3, /*document_id=*/1), + /*location=*/6); + EXPECT_THAT(pl_accessor->PrependHit(hit4), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, NewPostingListNoHitsAdded) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + EXPECT_THAT(result1.status, + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPostingListNoHitsAdded) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + EmbeddingHit hit1(BasicHit(/*section_id=*/3, /*document_id=*/1), + /*location=*/5); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_ASSERT_OK(result1.status); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor2, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor2).Finalize(); + ICING_ASSERT_OK(result2.status); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/posting-list-embedding-hit-serializer.cc 
b/icing/index/embed/posting-list-embedding-hit-serializer.cc new file mode 100644 index 0000000..9247bcb --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-serializer.cc @@ -0,0 +1,647 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" + +#include <cinttypes> +#include <cstdint> +#include <cstring> +#include <limits> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/legacy/core/icing-string-util.h" +#include "icing/legacy/index/icing-bit-util.h" +#include "icing/util/logging.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +uint32_t PostingListEmbeddingHitSerializer::GetBytesUsed( + const PostingListUsed* posting_list_used) const { + // The special hits will be included if they represent actual hits. If they + // represent the hit offset or the invalid hit sentinel, they are not + // included. 
+ return posting_list_used->size_in_bytes() - + GetStartByteOffset(posting_list_used); +} + +uint32_t PostingListEmbeddingHitSerializer::GetMinPostingListSizeToFit( + const PostingListUsed* posting_list_used) const { + if (IsFull(posting_list_used) || IsAlmostFull(posting_list_used)) { + // If in either the FULL state or ALMOST_FULL state, this posting list *is* + // the minimum size posting list that can fit these hits. So just return the + // size of the posting list. + return posting_list_used->size_in_bytes(); + } + + // - In NOT_FULL status, BytesUsed contains no special hits. For a posting + // list in the NOT_FULL state with n hits, we would have n-1 compressed hits + // and 1 uncompressed hit. + // - The minimum sized posting list that would be guaranteed to fit these hits + // would be FULL, but calculating the size required for the FULL posting + // list would require deserializing the last two added hits, so instead we + // will calculate the size of an ALMOST_FULL posting list to fit. + // - An ALMOST_FULL posting list would have kInvalidHit in special_hit(0), the + // full uncompressed Hit in special_hit(1), and the n-1 compressed hits in + // the compressed region. + // - Currently BytesUsed contains one uncompressed Hit and n-1 compressed + // hits. + // - Therefore, fitting these hits into a posting list would require + // BytesUsed + one extra full hit. + return GetBytesUsed(posting_list_used) + sizeof(EmbeddingHit); +} + +void PostingListEmbeddingHitSerializer::Clear( + PostingListUsed* posting_list_used) const { + // Safe to ignore return value because posting_list_used->size_in_bytes() is + // a valid argument. 
+ SetStartByteOffset(posting_list_used, + /*offset=*/posting_list_used->size_in_bytes()); +} + +libtextclassifier3::Status PostingListEmbeddingHitSerializer::MoveFrom( + PostingListUsed* dst, PostingListUsed* src) const { + ICING_RETURN_ERROR_IF_NULL(dst); + ICING_RETURN_ERROR_IF_NULL(src); + if (GetMinPostingListSizeToFit(src) > dst->size_in_bytes()) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "src MinPostingListSizeToFit %d must be larger than size %d.", + GetMinPostingListSizeToFit(src), dst->size_in_bytes())); + } + + if (!IsPostingListValid(dst)) { + return absl_ports::FailedPreconditionError( + "Dst posting list is in an invalid state and can't be used!"); + } + if (!IsPostingListValid(src)) { + return absl_ports::InvalidArgumentError( + "Cannot MoveFrom an invalid src posting list!"); + } + + // Pop just enough hits that all of src's compressed hits fit in + // dst posting_list's compressed area. Then we can memcpy that area. + std::vector<EmbeddingHit> hits; + while (IsFull(src) || IsAlmostFull(src) || + (dst->size_in_bytes() - kSpecialHitsSize < GetBytesUsed(src))) { + if (!GetHitsInternal(src, /*limit=*/1, /*pop=*/true, &hits).ok()) { + return absl_ports::AbortedError( + "Unable to retrieve hits from src posting list."); + } + } + + // memcpy the area and set up start byte offset. + Clear(dst); + memcpy(dst->posting_list_buffer() + dst->size_in_bytes() - GetBytesUsed(src), + src->posting_list_buffer() + GetStartByteOffset(src), + GetBytesUsed(src)); + // Because we popped all hits from src outside of the compressed area and we + // guaranteed that GetBytesUsed(src) is less than dst->size_in_bytes() - + // kSpecialHitSize. This is guaranteed to be a valid byte offset for the + // NOT_FULL state, so ignoring the value is safe. + SetStartByteOffset(dst, dst->size_in_bytes() - GetBytesUsed(src)); + + // Put back remaining hits. 
+ for (size_t i = 0; i < hits.size(); i++) { + const EmbeddingHit& hit = hits[hits.size() - i - 1]; + // PrependHit can return either INVALID_ARGUMENT - if hit is invalid or not + // less than the previous hit - or RESOURCE_EXHAUSTED. RESOURCE_EXHAUSTED + // should be impossible because we've already assured that there is enough + // room above. + ICING_RETURN_IF_ERROR(PrependHit(dst, hit)); + } + + Clear(src); + return libtextclassifier3::Status::OK; +} + +uint32_t PostingListEmbeddingHitSerializer::GetPadEnd( + const PostingListUsed* posting_list_used, uint32_t offset) const { + EmbeddingHit::Value pad; + uint32_t pad_end = offset; + while (pad_end < posting_list_used->size_in_bytes()) { + size_t pad_len = VarInt::Decode( + posting_list_used->posting_list_buffer() + pad_end, &pad); + if (pad != 0) { + // No longer a pad. + break; + } + pad_end += pad_len; + } + return pad_end; +} + +bool PostingListEmbeddingHitSerializer::PadToEnd( + PostingListUsed* posting_list_used, uint32_t start, uint32_t end) const { + if (end > posting_list_used->size_in_bytes()) { + ICING_LOG(ERROR) << "Cannot pad a region that ends after size!"; + return false; + } + // In VarInt a value of 0 encodes to 0. + memset(posting_list_used->posting_list_buffer() + start, 0, end - start); + return true; +} + +libtextclassifier3::Status +PostingListEmbeddingHitSerializer::PrependHitToAlmostFull( + PostingListUsed* posting_list_used, const EmbeddingHit& hit) const { + // Get delta between first hit and the new hit. Try to fit delta + // in the padded area and put new hit at the special position 1. + // Calling ValueOrDie is safe here because 1 < kNumSpecialData. 
+ EmbeddingHit cur = GetSpecialHit(posting_list_used, /*index=*/1); + if (cur.value() <= hit.value()) { + return absl_ports::InvalidArgumentError( + "Hit being prepended must be strictly less than the most recent Hit"); + } + uint64_t delta = cur.value() - hit.value(); + uint8_t delta_buf[VarInt::kMaxEncodedLen64]; + size_t delta_len = VarInt::Encode(delta, delta_buf); + + uint32_t pad_end = GetPadEnd(posting_list_used, + /*offset=*/kSpecialHitsSize); + + if (pad_end >= kSpecialHitsSize + delta_len) { + // Pad area has enough space for delta of existing hit (cur). Write delta at + // pad_end - delta_len. + uint8_t* delta_offset = + posting_list_used->posting_list_buffer() + pad_end - delta_len; + memcpy(delta_offset, delta_buf, delta_len); + + // Now first hit is the new hit, at special position 1. Safe to ignore the + // return value because 1 < kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/1, hit); + // Safe to ignore the return value because sizeof(EmbeddingHit) is a valid + // argument. + SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit)); + } else { + // No space for delta. We put the new hit at special position 0 + // and go to the full state. Safe to ignore the return value because 1 < + // kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/0, hit); + } + return libtextclassifier3::Status::OK; +} + +void PostingListEmbeddingHitSerializer::PrependHitToEmpty( + PostingListUsed* posting_list_used, const EmbeddingHit& hit) const { + // First hit to be added. Just add verbatim, no compression. + if (posting_list_used->size_in_bytes() == kSpecialHitsSize) { + // Safe to ignore the return value because 1 < kNumSpecialData + SetSpecialHit(posting_list_used, /*index=*/1, hit); + // Safe to ignore the return value because sizeof(EmbeddingHit) is a valid + // argument. 
+ SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit)); + } else { + // Since this is the first hit, size != kSpecialHitsSize and + // size % sizeof(EmbeddingHit) == 0, we know that there is room to fit 'hit' + // into the compressed region, so ValueOrDie is safe. + uint32_t offset = + PrependHitUncompressed(posting_list_used, hit, + /*offset=*/posting_list_used->size_in_bytes()) + .ValueOrDie(); + // Safe to ignore the return value because PrependHitUncompressed is + // guaranteed to return a valid offset. + SetStartByteOffset(posting_list_used, offset); + } +} + +libtextclassifier3::Status +PostingListEmbeddingHitSerializer::PrependHitToNotFull( + PostingListUsed* posting_list_used, const EmbeddingHit& hit, + uint32_t offset) const { + // First hit in compressed area. It is uncompressed. See if delta + // between the first hit and new hit will still fit in the + // compressed area. + if (offset + sizeof(EmbeddingHit::Value) > + posting_list_used->size_in_bytes()) { + // The first hit in the compressed region *should* be uncompressed, but + // somehow there isn't enough room between offset and the end of the + // compressed area to fit an uncompressed hit. This should NEVER happen. + return absl_ports::FailedPreconditionError( + "Posting list is in an invalid state."); + } + EmbeddingHit::Value cur_value; + memcpy(&cur_value, posting_list_used->posting_list_buffer() + offset, + sizeof(EmbeddingHit::Value)); + if (cur_value <= hit.value()) { + return absl_ports::InvalidArgumentError( + IcingStringUtil::StringPrintf("EmbeddingHit %" PRId64 + " being prepended must be " + "strictly less than the most recent " + "EmbeddingHit %" PRId64, + hit.value(), cur_value)); + } + uint64_t delta = cur_value - hit.value(); + uint8_t delta_buf[VarInt::kMaxEncodedLen64]; + size_t delta_len = VarInt::Encode(delta, delta_buf); + + // offset now points to one past the end of the first hit. 
+ offset += sizeof(EmbeddingHit::Value); + if (kSpecialHitsSize + sizeof(EmbeddingHit::Value) + delta_len <= offset) { + // Enough space for delta in compressed area. + + // Prepend delta. + offset -= delta_len; + memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf, + delta_len); + + // Prepend new hit. We know that there is room for 'hit' because of the if + // statement above, so calling ValueOrDie is safe. + offset = + PrependHitUncompressed(posting_list_used, hit, offset).ValueOrDie(); + // offset is guaranteed to be valid here. So it's safe to ignore the return + // value. The if above will guarantee that offset >= kSpecialHitSize and < + // posting_list_used->size_in_bytes() because the if ensures that there is + // enough room between offset and kSpecialHitSize to fit the delta of the + // previous hit and the uncompressed hit. + SetStartByteOffset(posting_list_used, offset); + } else if (kSpecialHitsSize + delta_len <= offset) { + // Only have space for delta. The new hit must be put in special + // position 1. + + // Prepend delta. + offset -= delta_len; + memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf, + delta_len); + + // Prepend pad. Safe to ignore the return value of PadToEnd because offset + // must be less than posting_list_used->size_in_bytes(). Otherwise, this + // function already would have returned FAILED_PRECONDITION. + PadToEnd(posting_list_used, /*start=*/kSpecialHitsSize, + /*end=*/offset); + + // Put new hit in special position 1. Safe to ignore return value because 1 + // < kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/1, hit); + + // State almost_full. Safe to ignore the return value because + // sizeof(EmbeddingHit) is a valid argument. + SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit)); + } else { + // Very rare case where delta is larger than sizeof(EmbeddingHit::Value) + // (i.e. varint delta encoding expanded required storage). 
We + // move first hit to special position 1 and put new hit in + // special position 0. + EmbeddingHit cur(cur_value); + // Safe to ignore the return value of PadToEnd because offset must be less + // than posting_list_used->size_in_bytes(). Otherwise, this function + // already would have returned FAILED_PRECONDITION. + PadToEnd(posting_list_used, /*start=*/kSpecialHitsSize, + /*end=*/offset); + // Safe to ignore the return value here because 0 and 1 < kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/1, cur); + SetSpecialHit(posting_list_used, /*index=*/0, hit); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status PostingListEmbeddingHitSerializer::PrependHit( + PostingListUsed* posting_list_used, const EmbeddingHit& hit) const { + static_assert( + sizeof(EmbeddingHit::Value) <= sizeof(uint64_t), + "EmbeddingHit::Value cannot be larger than 8 bytes because the delta " + "must be able to fit in 8 bytes."); + if (!hit.is_valid()) { + return absl_ports::InvalidArgumentError("Cannot prepend an invalid hit!"); + } + if (!IsPostingListValid(posting_list_used)) { + return absl_ports::FailedPreconditionError( + "This PostingListUsed is in an invalid state and can't add any hits!"); + } + + if (IsFull(posting_list_used)) { + // State full: no space left. 
+ return absl_ports::ResourceExhaustedError("No more room for hits"); + } else if (IsAlmostFull(posting_list_used)) { + return PrependHitToAlmostFull(posting_list_used, hit); + } else if (IsEmpty(posting_list_used)) { + PrependHitToEmpty(posting_list_used, hit); + return libtextclassifier3::Status::OK; + } else { + uint32_t offset = GetStartByteOffset(posting_list_used); + return PrependHitToNotFull(posting_list_used, hit, offset); + } +} + +libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> +PostingListEmbeddingHitSerializer::GetHits( + const PostingListUsed* posting_list_used) const { + std::vector<EmbeddingHit> hits_out; + ICING_RETURN_IF_ERROR(GetHits(posting_list_used, &hits_out)); + return hits_out; +} + +libtextclassifier3::Status PostingListEmbeddingHitSerializer::GetHits( + const PostingListUsed* posting_list_used, + std::vector<EmbeddingHit>* hits_out) const { + return GetHitsInternal(posting_list_used, + /*limit=*/std::numeric_limits<uint32_t>::max(), + /*pop=*/false, hits_out); +} + +libtextclassifier3::Status PostingListEmbeddingHitSerializer::PopFrontHits( + PostingListUsed* posting_list_used, uint32_t num_hits) const { + if (num_hits == 1 && IsFull(posting_list_used)) { + // The PL is in full status which means that we save 2 uncompressed hits in + // the 2 special postions. But full status may be reached by 2 different + // statuses. + // (1) In "almost full" status + // +-----------------+----------------+-------+-----------------+ + // |Hit::kInvalidVal |1st hit |(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // When we prepend another hit, we can only put it at the special + // position 0. 
And we get a full PL + // +-----------------+----------------+-------+-----------------+ + // |new 1st hit |original 1st hit|(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // (2) In "not full" status + // +-----------------+----------------+------+-------+------------------+ + // |hits-start-offset|Hit::kInvalidVal|(pad) |1st hit|(compressed) hits | + // +-----------------+----------------+------+-------+------------------+ + // When we prepend another hit, we can reach any of the 3 following + // scenarios: + // (2.1) not full + // if the space of pad and original 1st hit can accommodate the new 1st hit + // and the encoded delta value. + // +-----------------+----------------+------+-----------+-----------------+ + // |hits-start-offset|Hit::kInvalidVal|(pad) |new 1st hit|(compressed) hits| + // +-----------------+----------------+------+-----------+-----------------+ + // (2.2) almost full + // If the space of pad and original 1st hit cannot accommodate the new 1st + // hit and the encoded delta value but can accommodate the encoded delta + // value only. We can put the new 1st hit at special position 1. + // +-----------------+----------------+-------+-----------------+ + // |Hit::kInvalidVal |new 1st hit |(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // (2.3) full + // In very rare case, it cannot even accommodate only the encoded delta + // value. we can move the original 1st hit into special position 1 and the + // new 1st hit into special position 0. This may happen because we use + // VarInt encoding method which may make the encoded value longer (about + // 4/3 times of original) + // +-----------------+----------------+-------+-----------------+ + // |new 1st hit |original 1st hit|(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // Suppose now the PL is full. 
But we don't know whether it arrived to + // this status from "not full" like (2.3) or from "almost full" like (1). + // We'll return to "almost full" status like (1) if we simply pop the new + // 1st hit but we want to make the prepending operation "reversible". So + // there should be some way to return to "not full" if possible. A simple + // way to do it is to pop 2 hits out of the PL to status "almost full" or + // "not full". And add the original 1st hit back. We can return to the + // correct original statuses of (2.1) or (1). This makes our prepending + // operation reversible. + std::vector<EmbeddingHit> out; + + // Popping 2 hits should never fail because we've just ensured that the + // posting list is in the FULL state. + ICING_RETURN_IF_ERROR( + GetHitsInternal(posting_list_used, /*limit=*/2, /*pop=*/true, &out)); + + // PrependHit should never fail because out[1] is a valid hit less than + // previous hits in the posting list and because there's no way that the + // posting list could run out of room because it previously stored this hit + // AND another hit. + ICING_RETURN_IF_ERROR(PrependHit(posting_list_used, out[1])); + } else if (num_hits > 0) { + return GetHitsInternal(posting_list_used, /*limit=*/num_hits, /*pop=*/true, + nullptr); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status PostingListEmbeddingHitSerializer::GetHitsInternal( + const PostingListUsed* posting_list_used, uint32_t limit, bool pop, + std::vector<EmbeddingHit>* out) const { + // Put current uncompressed val here. + EmbeddingHit::Value val = EmbeddingHit::kInvalidValue; + uint32_t offset = GetStartByteOffset(posting_list_used); + uint32_t count = 0; + + // First traverse the first two special positions. + while (count < limit && offset < kSpecialHitsSize) { + // Calling ValueOrDie is safe here because offset / sizeof(EmbeddingHit) < + // kNumSpecialData because of the check above. 
+ EmbeddingHit hit = GetSpecialHit(posting_list_used, + /*index=*/offset / sizeof(EmbeddingHit)); + val = hit.value(); + if (out != nullptr) { + out->push_back(hit); + } + offset += sizeof(EmbeddingHit); + count++; + } + + // If special position 1 was set then we need to skip padding. + if (val != EmbeddingHit::kInvalidValue && offset == kSpecialHitsSize) { + offset = GetPadEnd(posting_list_used, offset); + } + + while (count < limit && offset < posting_list_used->size_in_bytes()) { + if (val == EmbeddingHit::kInvalidValue) { + // First hit is in compressed area. Put that in val. + memcpy(&val, posting_list_used->posting_list_buffer() + offset, + sizeof(EmbeddingHit::Value)); + offset += sizeof(EmbeddingHit::Value); + } else { + // Now we have delta encoded subsequent hits. Decode and push. + uint64_t delta; + offset += VarInt::Decode( + posting_list_used->posting_list_buffer() + offset, &delta); + val += delta; + } + EmbeddingHit hit(val); + if (out != nullptr) { + out->push_back(hit); + } + count++; + } + + if (pop) { + PostingListUsed* mutable_posting_list_used = + const_cast<PostingListUsed*>(posting_list_used); + // Modify the posting list so that we pop all hits actually + // traversed. + if (offset >= kSpecialHitsSize && + offset < posting_list_used->size_in_bytes()) { + // In the compressed area. Pop and reconstruct. offset/val is + // the last traversed hit, which we must discard. So move one + // more forward. + uint64_t delta; + offset += VarInt::Decode( + posting_list_used->posting_list_buffer() + offset, &delta); + val += delta; + + // Now val is the first hit of the new posting list. + if (kSpecialHitsSize + sizeof(EmbeddingHit::Value) <= offset) { + // val fits in compressed area. Simply copy. + offset -= sizeof(EmbeddingHit::Value); + memcpy(mutable_posting_list_used->posting_list_buffer() + offset, &val, + sizeof(EmbeddingHit::Value)); + } else { + // val won't fit in compressed area. 
+ EmbeddingHit hit(val); + // Okay to ignore the return value here because 1 < kNumSpecialData. + SetSpecialHit(mutable_posting_list_used, /*index=*/1, hit); + + // Prepend pad. Safe to ignore the return value of PadToEnd because + // offset must be less than posting_list_used->size_in_bytes() thanks to + // the if above. + PadToEnd(mutable_posting_list_used, + /*start=*/kSpecialHitsSize, + /*end=*/offset); + offset = sizeof(EmbeddingHit); + } + } + // offset is guaranteed to be valid so ignoring the return value of + // set_start_byte_offset is safe. It falls into one of four scenarios: + // Scenario 1: the above if was false because offset is not < + // posting_list_used->size_in_bytes() + // In this case, offset must be == posting_list_used->size_in_bytes() + // because we reached offset by unwinding hits on the posting list. + // Scenario 2: offset is < kSpecialHitSize + // In this case, offset is guaranteed to be either 0 or + // sizeof(EmbeddingHit) because offset is incremented by + // sizeof(EmbeddingHit) within the first while loop. + // Scenario 3: offset is within the compressed region and the new first hit + // in the posting list (the value that 'val' holds) will fit as an + // uncompressed hit in the compressed region. The resulting offset from + // decompressing val must be >= kSpecialHitSize because otherwise we'd be + // in Scenario 4 + // Scenario 4: offset is within the compressed region, but the new first hit + // in the posting list is too large to fit as an uncompressed hit in the + // in the compressed region. Therefore, it must be stored in a special hit + // and offset will be sizeof(EmbeddingHit). 
+ SetStartByteOffset(mutable_posting_list_used, offset); + } + + return libtextclassifier3::Status::OK; +} + +EmbeddingHit PostingListEmbeddingHitSerializer::GetSpecialHit( + const PostingListUsed* posting_list_used, uint32_t index) const { + static_assert(sizeof(EmbeddingHit::Value) >= sizeof(uint32_t), "HitTooSmall"); + EmbeddingHit val(EmbeddingHit::kInvalidValue); + memcpy(&val, posting_list_used->posting_list_buffer() + index * sizeof(val), + sizeof(val)); + return val; +} + +void PostingListEmbeddingHitSerializer::SetSpecialHit( + PostingListUsed* posting_list_used, uint32_t index, + const EmbeddingHit& val) const { + memcpy(posting_list_used->posting_list_buffer() + index * sizeof(val), &val, + sizeof(val)); +} + +bool PostingListEmbeddingHitSerializer::IsPostingListValid( + const PostingListUsed* posting_list_used) const { + if (IsAlmostFull(posting_list_used)) { + // Special Hit 1 should hold a Hit. Calling ValueOrDie is safe because we + // know that 1 < kNumSpecialData. + if (!GetSpecialHit(posting_list_used, /*index=*/1).is_valid()) { + ICING_LOG(ERROR) + << "Both special hits cannot be invalid at the same time."; + return false; + } + } else if (!IsFull(posting_list_used)) { + // NOT_FULL. Special Hit 0 should hold a valid offset. Calling ValueOrDie is + // safe because we know that 0 < kNumSpecialData. 
+ if (GetSpecialHit(posting_list_used, /*index=*/0).value() > + posting_list_used->size_in_bytes() || + GetSpecialHit(posting_list_used, /*index=*/0).value() < + kSpecialHitsSize) { + ICING_LOG(ERROR) << "EmbeddingHit: " + << GetSpecialHit(posting_list_used, /*index=*/0).value() + << " size: " << posting_list_used->size_in_bytes() + << " sp size: " << kSpecialHitsSize; + return false; + } + } + return true; +} + +uint32_t PostingListEmbeddingHitSerializer::GetStartByteOffset( + const PostingListUsed* posting_list_used) const { + if (IsFull(posting_list_used)) { + return 0; + } else if (IsAlmostFull(posting_list_used)) { + return sizeof(EmbeddingHit); + } else { + // NOT_FULL, calling ValueOrDie is safe because we know that 0 < + // kNumSpecialData. + return GetSpecialHit(posting_list_used, /*index=*/0).value(); + } +} + +bool PostingListEmbeddingHitSerializer::SetStartByteOffset( + PostingListUsed* posting_list_used, uint32_t offset) const { + if (offset > posting_list_used->size_in_bytes()) { + ICING_LOG(ERROR) << "offset cannot be a value greater than size " + << posting_list_used->size_in_bytes() << ". offset is " + << offset << "."; + return false; + } + if (offset < kSpecialHitsSize && offset > sizeof(EmbeddingHit)) { + ICING_LOG(ERROR) << "offset cannot be a value between (" + << sizeof(EmbeddingHit) << ", " << kSpecialHitsSize + << "). offset is " << offset << "."; + return false; + } + if (offset < sizeof(EmbeddingHit) && offset != 0) { + ICING_LOG(ERROR) << "offset cannot be a value between (0, " + << sizeof(EmbeddingHit) << "). offset is " << offset + << "."; + return false; + } + if (offset >= kSpecialHitsSize) { + // not_full state. Safe to ignore the return value because 0 and 1 are both + // < kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/0, EmbeddingHit(offset)); + SetSpecialHit(posting_list_used, /*index=*/1, + EmbeddingHit(EmbeddingHit::kInvalidValue)); + } else if (offset == sizeof(EmbeddingHit)) { + // almost_full state. 
Safe to ignore the return value because 1 is both < + // kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/0, + EmbeddingHit(EmbeddingHit::kInvalidValue)); + } + // Nothing to do for the FULL state - the offset isn't actually stored + // anywhere and both special hits hold valid hits. + return true; +} + +libtextclassifier3::StatusOr<uint32_t> +PostingListEmbeddingHitSerializer::PrependHitUncompressed( + PostingListUsed* posting_list_used, const EmbeddingHit& hit, + uint32_t offset) const { + if (offset < kSpecialHitsSize + sizeof(EmbeddingHit::Value)) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Not enough room to prepend EmbeddingHit::Value at offset %d.", + offset)); + } + offset -= sizeof(EmbeddingHit::Value); + EmbeddingHit::Value val = hit.value(); + memcpy(posting_list_used->posting_list_buffer() + offset, &val, + sizeof(EmbeddingHit::Value)); + return offset; +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/posting-list-embedding-hit-serializer.h b/icing/index/embed/posting-list-embedding-hit-serializer.h new file mode 100644 index 0000000..76198c2 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-serializer.h @@ -0,0 +1,284 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_ +#define ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_ + +#include <cstdint> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/file/posting_list/posting-list-common.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +// A serializer class to serialize hits to PostingListUsed. Layout described in +// comments in posting-list-embedding-hit-serializer.cc. +class PostingListEmbeddingHitSerializer : public PostingListSerializer { + public: + static constexpr uint32_t kSpecialHitsSize = + kNumSpecialData * sizeof(EmbeddingHit); + + uint32_t GetDataTypeBytes() const override { return sizeof(EmbeddingHit); } + + uint32_t GetMinPostingListSize() const override { + static constexpr uint32_t kMinPostingListSize = kSpecialHitsSize; + static_assert(sizeof(PostingListIndex) <= kMinPostingListSize, + "PostingListIndex must be small enough to fit in a " + "minimum-sized Posting List."); + + return kMinPostingListSize; + } + + uint32_t GetMinPostingListSizeToFit( + const PostingListUsed* posting_list_used) const override; + + uint32_t GetBytesUsed( + const PostingListUsed* posting_list_used) const override; + + void Clear(PostingListUsed* posting_list_used) const override; + + libtextclassifier3::Status MoveFrom(PostingListUsed* dst, + PostingListUsed* src) const override; + + // Prepend a hit to the posting list. + // + // RETURNS: + // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the + // previously added hit. + // - RESOURCE_EXHAUSTED if there is no more room to add hit to the posting + // list. 
+ libtextclassifier3::Status PrependHit(PostingListUsed* posting_list_used, + const EmbeddingHit& hit) const; + + // Prepend hits to the posting list. Hits should be sorted in descending order + // (as defined by the less than operator for Hit) + // + // Returns the number of hits that could be prepended to the posting list. If + // keep_prepended is true, whatever could be prepended is kept, otherwise the + // posting list is left in its original state. + template <class T, EmbeddingHit (*GetHit)(const T&)> + libtextclassifier3::StatusOr<uint32_t> PrependHitArray( + PostingListUsed* posting_list_used, const T* array, uint32_t num_hits, + bool keep_prepended) const; + + // Retrieves the hits stored in the posting list. + // + // RETURNS: + // - On success, a vector of hits sorted by the reverse order of prepending. + // - INTERNAL_ERROR if the posting list has been corrupted somehow. + libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits( + const PostingListUsed* posting_list_used) const; + + // Same as GetHits but appends hits to hits_out. + // + // RETURNS: + // - On success, a vector of hits sorted by the reverse order of prepending. + // - INTERNAL_ERROR if the posting list has been corrupted somehow. + libtextclassifier3::Status GetHits(const PostingListUsed* posting_list_used, + std::vector<EmbeddingHit>* hits_out) const; + + // Undo the last num_hits hits prepended. If num_hits > number of + // hits we clear all hits. + // + // RETURNS: + // - OK on success + // - INTERNAL_ERROR if the posting list has been corrupted somehow. 
+ libtextclassifier3::Status PopFrontHits(PostingListUsed* posting_list_used, + uint32_t num_hits) const; + + private: + // Posting list layout formats: + // + // not_full + // + // +-----------------+----------------+-------+-----------------+ + // |hits-start-offset|Hit::kInvalidVal|xxxxxxx|(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // + // almost_full + // + // +-----------------+----------------+-------+-----------------+ + // |Hit::kInvalidVal |1st hit |(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // + // full() + // + // +-----------------+----------------+-------+-----------------+ + // |1st hit |2nd hit |(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // + // The first two uncompressed hits also implicitly encode information about + // the size of the compressed hits region. + // + // 1. If the posting list is NOT_FULL, then + // posting_list_buffer_[0] contains the byte offset of the start of the + // compressed hits - and, thus, the size of the compressed hits region is + // size_in_bytes - posting_list_buffer_[0]. + // + // 2. If posting list is ALMOST_FULL or FULL, then the compressed hits region + // starts somewhere between [kSpecialHitsSize, kSpecialHitsSize + + // sizeof(EmbeddingHit) - 1] and ends at size_in_bytes - 1. + + // Helpers to determine what state the posting list is in. 
+ bool IsFull(const PostingListUsed* posting_list_used) const { + return GetSpecialHit(posting_list_used, /*index=*/0).is_valid() && + GetSpecialHit(posting_list_used, /*index=*/1).is_valid(); + } + + bool IsAlmostFull(const PostingListUsed* posting_list_used) const { + return !GetSpecialHit(posting_list_used, /*index=*/0).is_valid() && + GetSpecialHit(posting_list_used, /*index=*/1).is_valid(); + } + + bool IsEmpty(const PostingListUsed* posting_list_used) const { + return GetSpecialHit(posting_list_used, /*index=*/0).value() == + posting_list_used->size_in_bytes() && + !GetSpecialHit(posting_list_used, /*index=*/1).is_valid(); + } + + // Returns false if both special hits are invalid or if the offset value + // stored in the special hit is less than kSpecialHitsSize or greater than + // posting_list_used->size_in_bytes(). Returns true, otherwise. + bool IsPostingListValid(const PostingListUsed* posting_list_used) const; + + // Prepend hit to a posting list that is in the ALMOST_FULL state. + // RETURNS: + // - OK, if successful + // - INVALID_ARGUMENT if hit is not less than the previously added hit. + libtextclassifier3::Status PrependHitToAlmostFull( + PostingListUsed* posting_list_used, const EmbeddingHit& hit) const; + + // Prepend hit to a posting list that is in the EMPTY state. This will always + // succeed because there are no pre-existing hits and no validly constructed + // posting list could fail to fit one hit. + void PrependHitToEmpty(PostingListUsed* posting_list_used, + const EmbeddingHit& hit) const; + + // Prepend hit to a posting list that is in the NOT_FULL state. + // RETURNS: + // - OK, if successful + // - INVALID_ARGUMENT if hit is not less than the previously added hit. 
+ libtextclassifier3::Status PrependHitToNotFull( + PostingListUsed* posting_list_used, const EmbeddingHit& hit, + uint32_t offset) const; + + // Returns either 0 (full state), sizeof(EmbeddingHit) (almost_full state) or + // a byte offset between kSpecialHitsSize and + // posting_list_used->size_in_bytes() (inclusive) (not_full state). + uint32_t GetStartByteOffset(const PostingListUsed* posting_list_used) const; + + // Sets the special hits to properly reflect what offset is (see layout + // comment for further details). + // + // Returns false if offset > posting_list_used->size_in_bytes() or offset is + // (kSpecialHitsSize, sizeof(EmbeddingHit)) or offset is + // (sizeof(EmbeddingHit), 0). True, otherwise. + bool SetStartByteOffset(PostingListUsed* posting_list_used, + uint32_t offset) const; + + // Manipulate padded areas. We never store the same hit value twice + // so a delta of 0 is a pad byte. + + // Returns offset of first non-pad byte. + uint32_t GetPadEnd(const PostingListUsed* posting_list_used, + uint32_t offset) const; + + // Fill padding between offset start and offset end with 0s. + // Returns false if end > posting_list_used->size_in_bytes(). True, + // otherwise. + bool PadToEnd(PostingListUsed* posting_list_used, uint32_t start, + uint32_t end) const; + + // Helper for AppendHits/PopFrontHits. Adds limit number of hits to out or all + // hits in the posting list if the posting list contains less than limit + // number of hits. out can be NULL. + // + // NOTE: If called with limit=1, pop=true on a posting list that transitioned + // from NOT_FULL directly to FULL, GetHitsInternal will not return the posting + // list to NOT_FULL. Instead it will leave it in a valid state, but it will be + // ALMOST_FULL. + // + // RETURNS: + // - OK on success + // - INTERNAL_ERROR if the posting list has been corrupted somehow. 
+ libtextclassifier3::Status GetHitsInternal( + const PostingListUsed* posting_list_used, uint32_t limit, bool pop, + std::vector<EmbeddingHit>* out) const; + + // Retrieves the value stored in the index-th special hit. + // + // REQUIRES: + // 0 <= index < kNumSpecialData. + // + // RETURNS: + // - A valid SpecialData<EmbeddingHit>. + EmbeddingHit GetSpecialHit(const PostingListUsed* posting_list_used, + uint32_t index) const; + + // Sets the value stored in the index-th special hit to val. + // + // REQUIRES: + // 0 <= index < kNumSpecialData. + void SetSpecialHit(PostingListUsed* posting_list_used, uint32_t index, + const EmbeddingHit& val) const; + + // Prepends hit to the memory region [offset - sizeof(EmbeddingHit), offset] + // and returns the new beginning of the padded region. + // + // RETURNS: + // - The new beginning of the padded region, if successful. + // - INVALID_ARGUMENT if hit will not fit (uncompressed) between offset and + // kSpecialHitsSize + libtextclassifier3::StatusOr<uint32_t> PrependHitUncompressed( + PostingListUsed* posting_list_used, const EmbeddingHit& hit, + uint32_t offset) const; +}; + +// Inlined functions. Implementation details below. Avert eyes! +template <class T, EmbeddingHit (*GetHit)(const T&)> +libtextclassifier3::StatusOr<uint32_t> +PostingListEmbeddingHitSerializer::PrependHitArray( + PostingListUsed* posting_list_used, const T* array, uint32_t num_hits, + bool keep_prepended) const { + if (!IsPostingListValid(posting_list_used)) { + return 0; + } + + // Prepend hits working backwards from array[num_hits - 1]. + uint32_t i; + for (i = 0; i < num_hits; ++i) { + if (!PrependHit(posting_list_used, GetHit(array[num_hits - i - 1])).ok()) { + break; + } + } + if (i != num_hits && !keep_prepended) { + // Didn't fit. Undo everything and check that we have the same offset as + // before. 
PopFrontHits guarantees that it will remove all 'i' hits so long + // as there are at least 'i' hits in the posting list, which we know there + // are. + ICING_RETURN_IF_ERROR(PopFrontHits(posting_list_used, /*num_hits=*/i)); + } + return i; +} + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_ diff --git a/icing/index/embed/posting-list-embedding-hit-serializer_test.cc b/icing/index/embed/posting-list-embedding-hit-serializer_test.cc new file mode 100644 index 0000000..f829634 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-serializer_test.cc @@ -0,0 +1,864 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 

#include "icing/index/embed/posting-list-embedding-hit-serializer.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <iterator>
#include <limits>
#include <vector>

#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/file/posting_list/posting-list-used.h"
#include "icing/index/embed/embedding-hit.h"
#include "icing/index/hit/hit.h"
#include "icing/legacy/index/icing-bit-util.h"
#include "icing/schema/section.h"
#include "icing/store/document-id.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/hit-test-utils.h"

using testing::ElementsAre;
using testing::ElementsAreArray;
using testing::Eq;
using testing::IsEmpty;
using testing::Le;
using testing::Lt;

namespace icing {
namespace lib {

namespace {

// Thin wrapper around EmbeddingHit used to exercise PrependHitArray's
// <T, GetHit> template parameters with a non-trivial element type.
struct HitElt {
  HitElt() = default;
  explicit HitElt(const EmbeddingHit &hit_in) : hit(hit_in) {}

  static EmbeddingHit get_hit(const HitElt &hit_elt) { return hit_elt.hit; }

  EmbeddingHit hit;
};

// NOTE(review): BasicHit and CreateEmbeddingHit are presumably provided by
// icing/testing/hit-test-utils.h (included above) — confirm against that
// header.
TEST(PostingListEmbeddingHitSerializerTest, PostingListUsedPrependHitNotFull) {
  PostingListEmbeddingHitSerializer serializer;

  static const int kNumHits = 2551;
  static const size_t kHitsSize = kNumHits * sizeof(EmbeddingHit);

  ICING_ASSERT_OK_AND_ASSIGN(
      PostingListUsed pl_used,
      PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize));

  // Make used.
  EmbeddingHit hit0(BasicHit(/*section_id=*/0, /*document_id=*/0),
                    /*location=*/0);
  ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit0));
  // The first hit is stored uncompressed in the compressed region.
  int expected_size = sizeof(EmbeddingHit::Value);
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
  EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0)));

  // Each subsequent (smaller) hit is stored as a varint-encoded delta from
  // the previous front hit.
  EmbeddingHit hit1(BasicHit(/*section_id=*/0, /*document_id=*/1),
                    /*location=*/1);
  uint64_t delta = hit0.value() - hit1.value();
  uint8_t delta_buf[VarInt::kMaxEncodedLen64];
  size_t delta_len = VarInt::Encode(delta, delta_buf);
  ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit1));
  expected_size += delta_len;
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAre(hit1, hit0)));

  EmbeddingHit hit2(BasicHit(/*section_id=*/0, /*document_id=*/2),
                    /*location=*/2);
  delta = hit1.value() - hit2.value();
  delta_len = VarInt::Encode(delta, delta_buf);
  ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit2));
  expected_size += delta_len;
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAre(hit2, hit1, hit0)));

  EmbeddingHit hit3(BasicHit(/*section_id=*/0, /*document_id=*/3),
                    /*location=*/3);
  delta = hit2.value() - hit3.value();
  delta_len = VarInt::Encode(delta, delta_buf);
  ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit3));
  expected_size += delta_len;
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0)));
}

TEST(PostingListEmbeddingHitSerializerTest,
     PostingListUsedPrependHitAlmostFull) {
  PostingListEmbeddingHitSerializer serializer;

  // Size = 32
  int pl_size = 2 * serializer.GetMinPostingListSize();
  ICING_ASSERT_OK_AND_ASSIGN(
      PostingListUsed pl_used,
      PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));

  // Fill up the compressed region.
  // Transitions:
  // Adding hit0: EMPTY -> NOT_FULL
  // Adding hit1: NOT_FULL -> NOT_FULL
  // Adding hit2: NOT_FULL -> NOT_FULL
  EmbeddingHit hit0(BasicHit(/*section_id=*/0, /*document_id=*/0),
                    /*location=*/1);
  EmbeddingHit hit1 = CreateEmbeddingHit(hit0, /*desired_byte_length=*/3);
  EmbeddingHit hit2 = CreateEmbeddingHit(hit1, /*desired_byte_length=*/3);
  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0));
  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1));
  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2));
  // Size used will be 8 (hit2) + 3 (hit1-hit2) + 3 (hit0-hit1) = 14 bytes
  int expected_size = sizeof(EmbeddingHit) + 3 + 3;
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAre(hit2, hit1, hit0)));

  // Add one more hit to transition NOT_FULL -> ALMOST_FULL
  EmbeddingHit hit3 = CreateEmbeddingHit(hit2, /*desired_byte_length=*/3);
  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3));
  // Storing them in the compressed region requires 8 (hit) + 3 (hit2-hit3) +
  // 3 (hit1-hit2) + 3 (hit0-hit1) = 17 bytes, but there are only 16 bytes in
  // the compressed region. So instead, the posting list will transition to
  // ALMOST_FULL. The in-use compressed region will actually shrink from 14
  // bytes to 9 bytes because the uncompressed version of hit2 will be
  // overwritten with the compressed delta of hit2. hit3 will be written to one
  // of the special hits. Because we're in ALMOST_FULL, the expected size is the
  // size of the pl minus the one hit used to mark the posting list as
  // ALMOST_FULL.
  expected_size = pl_size - sizeof(EmbeddingHit);
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0)));

  // Add one more hit to transition ALMOST_FULL -> ALMOST_FULL
  EmbeddingHit hit4 = CreateEmbeddingHit(hit3, /*desired_byte_length=*/6);
  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit4));
  // There are currently 9 bytes in use in the compressed region. Hit3 will
  // have a 6-byte delta, which fits in the compressed region. Hit3 will be
  // moved from the special hit to the compressed region (which will have 15
  // bytes in use after adding hit3). Hit4 will be placed in one of the special
  // hits and the posting list will remain in ALMOST_FULL.
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAre(hit4, hit3, hit2, hit1, hit0)));

  // Add one more hit to transition ALMOST_FULL -> FULL
  EmbeddingHit hit5 = CreateEmbeddingHit(hit4, /*desired_byte_length=*/2);
  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit5));
  // There are currently 15 bytes in use in the compressed region. Hit4 will
  // have a 2-byte delta which will not fit in the compressed region. So hit4
  // will remain in one of the special hits and hit5 will occupy the other,
  // making the posting list FULL.
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size));
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAre(hit5, hit4, hit3, hit2, hit1, hit0)));

  // The posting list is FULL. Adding another hit should fail.
  EmbeddingHit hit6 = CreateEmbeddingHit(hit5, /*desired_byte_length=*/1);
  EXPECT_THAT(serializer.PrependHit(&pl_used, hit6),
              StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
}

TEST(PostingListEmbeddingHitSerializerTest, PostingListUsedMinSize) {
  PostingListEmbeddingHitSerializer serializer;

  // Min size = 16
  ICING_ASSERT_OK_AND_ASSIGN(
      PostingListUsed pl_used,
      PostingListUsed::CreateFromUnitializedRegion(
          &serializer, serializer.GetMinPostingListSize()));
  // PL State: EMPTY
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0));
  EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty()));

  // Add a hit, PL should shift to ALMOST_FULL state
  EmbeddingHit hit0(BasicHit(/*section_id=*/1, /*document_id=*/0),
                    /*location=*/1);
  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0));
  // Size = sizeof(uncompressed hit0)
  int expected_size = sizeof(EmbeddingHit);
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
  EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0)));

  // Add the smallest hit possible with a delta of 0b1. PL should shift to FULL
  // state.
  EmbeddingHit hit1(BasicHit(/*section_id=*/1, /*document_id=*/0),
                    /*location=*/0);
  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1));
  // Size = sizeof(uncompressed hit1) + sizeof(uncompressed hit0)
  expected_size += sizeof(EmbeddingHit);
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAre(hit1, hit0)));

  // Try to add the smallest hit possible. Should fail
  EmbeddingHit hit2(BasicHit(/*section_id=*/0, /*document_id=*/0),
                    /*location=*/0);
  EXPECT_THAT(serializer.PrependHit(&pl_used, hit2),
              StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAre(hit1, hit0)));
}

TEST(PostingListEmbeddingHitSerializerTest,
     PostingListPrependHitArrayMinSizePostingList) {
  PostingListEmbeddingHitSerializer serializer;

  // Min Size = 16
  int pl_size = serializer.GetMinPostingListSize();
  ICING_ASSERT_OK_AND_ASSIGN(
      PostingListUsed pl_used,
      PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));

  std::vector<HitElt> hits_in;
  hits_in.emplace_back(EmbeddingHit(
      BasicHit(/*section_id=*/1, /*document_id=*/0), /*location=*/1));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
  std::reverse(hits_in.begin(), hits_in.end());

  // Add five hits. The PL is in the empty state and an empty min size PL can
  // only fit two hits. So PrependHitArray should fail.
  ICING_ASSERT_OK_AND_ASSIGN(
      uint32_t num_can_prepend,
      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
          &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
  EXPECT_THAT(num_can_prepend, Eq(2));

  int can_fit_hits = num_can_prepend;
  // The PL has room for 2 hits. We should be able to add them without any
  // problem, transitioning the PL from EMPTY -> ALMOST_FULL -> FULL
  const HitElt *hits_in_ptr = hits_in.data() + (hits_in.size() - 2);
  ICING_ASSERT_OK_AND_ASSIGN(
      num_can_prepend,
      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
          &pl_used, hits_in_ptr, can_fit_hits, /*keep_prepended=*/false)));
  EXPECT_THAT(num_can_prepend, Eq(can_fit_hits));
  EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used)));
  std::deque<EmbeddingHit> hits_pushed;
  std::transform(hits_in.rbegin(),
                 hits_in.rend() - hits_in.size() + can_fit_hits,
                 std::front_inserter(hits_pushed), HitElt::get_hit);
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAreArray(hits_pushed)));
}

TEST(PostingListEmbeddingHitSerializerTest,
     PostingListPrependHitArrayPostingList) {
  PostingListEmbeddingHitSerializer serializer;

  // Size = 48
  int pl_size = 3 * serializer.GetMinPostingListSize();
  ICING_ASSERT_OK_AND_ASSIGN(
      PostingListUsed pl_used,
      PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));

  std::vector<HitElt> hits_in;
  hits_in.emplace_back(EmbeddingHit(
      BasicHit(/*section_id=*/1, /*document_id=*/0), /*location=*/1));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
  std::reverse(hits_in.begin(), hits_in.end());
  // The last hit is uncompressed and the four before it should only take one
  // byte. Total use = 8 bytes.
  // ----------------------
  // 47     delta(EmbeddingHit #0)
  // 46     delta(EmbeddingHit #1)
  // 45     delta(EmbeddingHit #2)
  // 44     delta(EmbeddingHit #3)
  // 43-36  EmbeddingHit #4
  // 35-16  <unused>
  // 15-8   kSpecialHit
  // 7-0    Offset=36
  // ----------------------
  int byte_size = sizeof(EmbeddingHit::Value) + hits_in.size() - 1;

  // Add five hits. The PL is in the empty state and should be able to fit all
  // five hits without issue, transitioning the PL from EMPTY -> NOT_FULL.
  ICING_ASSERT_OK_AND_ASSIGN(
      uint32_t num_could_fit,
      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
          &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
  EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
  EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
  std::deque<EmbeddingHit> hits_pushed;
  std::transform(hits_in.rbegin(), hits_in.rend(),
                 std::front_inserter(hits_pushed), HitElt::get_hit);
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAreArray(hits_pushed)));

  EmbeddingHit first_hit =
      CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/1);
  hits_in.clear();
  hits_in.emplace_back(first_hit);
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
  std::reverse(hits_in.begin(), hits_in.end());
  // Size increased by the deltas of these hits (1+2+1+2+3+2+3) = 14 bytes
  // ----------------------
  // 47     delta(EmbeddingHit #0)
  // 46     delta(EmbeddingHit #1)
  // 45     delta(EmbeddingHit #2)
  // 44     delta(EmbeddingHit #3)
  // 43     delta(EmbeddingHit #4)
  // 42-41  delta(EmbeddingHit #5)
  // 40     delta(EmbeddingHit #6)
  // 39-38  delta(EmbeddingHit #7)
  // 37-35  delta(EmbeddingHit #8)
  // 34-33  delta(EmbeddingHit #9)
  // 32-30  delta(EmbeddingHit #10)
  // 29-22  EmbeddingHit #11
  // 21-16  <unused>
  // 15-8   kSpecialHit
  // 7-0    Offset=22
  // ----------------------
  byte_size += 14;

  // Add these 7 hits. The PL is currently in the NOT_FULL state and should
  // remain in the NOT_FULL state.
  ICING_ASSERT_OK_AND_ASSIGN(
      num_could_fit,
      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
          &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
  EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
  EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
  // All hits from hits_in were added.
  std::transform(hits_in.rbegin(), hits_in.rend(),
                 std::front_inserter(hits_pushed), HitElt::get_hit);
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAreArray(hits_pushed)));

  first_hit =
      CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/8);
  hits_in.clear();
  hits_in.emplace_back(first_hit);
  // ----------------------
  // 47     delta(EmbeddingHit #0)
  // 46     delta(EmbeddingHit #1)
  // 45     delta(EmbeddingHit #2)
  // 44     delta(EmbeddingHit #3)
  // 43     delta(EmbeddingHit #4)
  // 42-41  delta(EmbeddingHit #5)
  // 40     delta(EmbeddingHit #6)
  // 39-38  delta(EmbeddingHit #7)
  // 37-35  delta(EmbeddingHit #8)
  // 34-33  delta(EmbeddingHit #9)
  // 32-30  delta(EmbeddingHit #10)
  // 29-22  delta(EmbeddingHit #11)
  // 21-16  <unused>
  // 15-8   EmbeddingHit #12
  // 7-0    kSpecialHit
  // ----------------------
  byte_size = 40;  // 48 - 8

  // Add this 1 hit. The PL is currently in the NOT_FULL state and should
  // transition to the ALMOST_FULL state - even though there is still some
  // unused space.
  ICING_ASSERT_OK_AND_ASSIGN(
      num_could_fit,
      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
          &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
  EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
  EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
  // All hits from hits_in were added.
  std::transform(hits_in.rbegin(), hits_in.rend(),
                 std::front_inserter(hits_pushed), HitElt::get_hit);
  EXPECT_THAT(serializer.GetHits(&pl_used),
              IsOkAndHolds(ElementsAreArray(hits_pushed)));

  first_hit =
      CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/5);
  hits_in.clear();
  hits_in.emplace_back(first_hit);
  hits_in.emplace_back(
      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
  std::reverse(hits_in.begin(), hits_in.end());
  // ----------------------
  // 47     delta(EmbeddingHit #0)
  // 46     delta(EmbeddingHit #1)
  // 45     delta(EmbeddingHit #2)
  // 44     delta(EmbeddingHit #3)
  // 43     delta(EmbeddingHit #4)
  // 42-41  delta(EmbeddingHit #5)
  // 40     delta(EmbeddingHit #6)
  // 39-38  delta(EmbeddingHit #7)
  // 37-35  delta(EmbeddingHit #8)
  // 34-33  delta(EmbeddingHit #9)
  // 32-30  delta(EmbeddingHit #10)
  // 29-22  delta(EmbeddingHit #11)
  // 21-17  delta(EmbeddingHit #12)
  // 16     <unused>
  // 15-8   EmbeddingHit #13
  // 7-0    EmbeddingHit #14
  // ----------------------

  // Add these 2 hits.
  // - The PL is currently in the ALMOST_FULL state. Adding the first hit
  //   should keep the PL in ALMOST_FULL because the delta between
  //   EmbeddingHit #12 and EmbeddingHit #13 (5 byte) can fit in the unused
  //   area (6 bytes).
  // - Adding the second hit should transition to the FULL state because the
  //   delta between EmbeddingHit #13 and EmbeddingHit #14 (3 bytes) is larger
  //   than the remaining unused area (1 byte).
+ ICING_ASSERT_OK_AND_ASSIGN( + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_could_fit, Eq(hits_in.size())); + EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used))); + // All hits from hits_in were added. + std::transform(hits_in.rbegin(), hits_in.rend(), + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListPrependHitArrayTooManyHits) { + PostingListEmbeddingHitSerializer serializer; + + static constexpr int kNumHits = 130; + static constexpr int kDeltaSize = 1; + static constexpr size_t kHitsSize = + ((kNumHits - 2) * kDeltaSize + (2 * sizeof(EmbeddingHit))); + + // Create an array with one too many hits + std::vector<HitElt> hit_elts_in_too_many; + hit_elts_in_too_many.emplace_back(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/0), /*location=*/0)); + for (int i = 0; i < kNumHits; ++i) { + hit_elts_in_too_many.emplace_back(CreateEmbeddingHit( + hit_elts_in_too_many.back().hit, /*desired_byte_length=*/1)); + } + // Reverse so that hits are inserted in descending order + std::reverse(hit_elts_in_too_many.begin(), hit_elts_in_too_many.end()); + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, serializer.GetMinPostingListSize())); + // PrependHitArray should fail because hit_elts_in_too_many is far too large + // for the minimum size pl. 
+ ICING_ASSERT_OK_AND_ASSIGN( + uint32_t num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(), + /*keep_prepended=*/false))); + ASSERT_THAT(num_could_fit, Eq(2)); + ASSERT_THAT(num_could_fit, Lt(hit_elts_in_too_many.size())); + ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); + ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); + + ICING_ASSERT_OK_AND_ASSIGN( + pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize)); + // PrependHitArray should fail because hit_elts_in_too_many is one hit too + // large for this pl. + ICING_ASSERT_OK_AND_ASSIGN( + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(), + /*keep_prepended=*/false))); + ASSERT_THAT(num_could_fit, Eq(hit_elts_in_too_many.size() - 1)); + ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); + ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListStatusJumpFromNotFullToFullAndBack) { + PostingListEmbeddingHitSerializer serializer; + + // Size = 24 + const uint32_t pl_size = 3 * sizeof(EmbeddingHit); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + EmbeddingHit max_valued_hit( + BasicHit(/*section_id=*/kMaxSectionId, /*document_id=*/kMinDocumentId), + /*location=*/std::numeric_limits<uint32_t>::max()); + ICING_ASSERT_OK(serializer.PrependHit(&pl, max_valued_hit)); + uint32_t bytes_used = serializer.GetBytesUsed(&pl); + ASSERT_THAT(bytes_used, sizeof(EmbeddingHit)); + // Status not full. 
+ ASSERT_THAT( + bytes_used, + Le(pl_size - PostingListEmbeddingHitSerializer::kSpecialHitsSize)); + + EmbeddingHit min_valued_hit( + BasicHit(/*section_id=*/kMinSectionId, /*document_id=*/kMaxDocumentId), + /*location=*/0); + ICING_ASSERT_OK(serializer.PrependHit(&pl, min_valued_hit)); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAre(min_valued_hit, max_valued_hit))); + // Status should jump to full directly. + ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(pl_size)); + ICING_ASSERT_OK(serializer.PopFrontHits(&pl, 1)); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAre(max_valued_hit))); + // Status should return to not full as before. + ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(bytes_used)); +} + +TEST(PostingListEmbeddingHitSerializerTest, DeltaOverflow) { + PostingListEmbeddingHitSerializer serializer; + + const uint32_t pl_size = 4 * sizeof(EmbeddingHit); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + static const EmbeddingHit::Value kMaxHitValue = + std::numeric_limits<EmbeddingHit::Value>::max(); + static const EmbeddingHit::Value kOverflow[4] = { + kMaxHitValue >> 2, + (kMaxHitValue >> 2) * 2, + (kMaxHitValue >> 2) * 3, + kMaxHitValue - 1, + }; + + // Fit at least 4 ordinary values. + std::deque<EmbeddingHit> hits_pushed; + for (EmbeddingHit::Value v = 0; v < 4; v++) { + hits_pushed.push_front( + EmbeddingHit(BasicHit(kMaxSectionId, kMaxDocumentId), 4 - v)); + ICING_EXPECT_OK(serializer.PrependHit(&pl, hits_pushed.front())); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + } + + // Cannot fit 4 overflow values. 
+ hits_pushed.clear(); + ICING_ASSERT_OK_AND_ASSIGN( + pl, PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + for (int i = 3; i >= 1; i--) { + hits_pushed.push_front(EmbeddingHit(/*value=*/kOverflow[i])); + ICING_EXPECT_OK(serializer.PrependHit(&pl, hits_pushed.front())); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + } + EXPECT_THAT(serializer.PrependHit(&pl, EmbeddingHit(/*value=*/kOverflow[0])), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST(PostingListEmbeddingHitSerializerTest, + GetMinPostingListToFitForNotFullPL) { + PostingListEmbeddingHitSerializer serializer; + + // Size = 64 + int pl_size = 4 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + // Create and add some hits to make pl_used NOT_FULL + std::vector<EmbeddingHit> hits_in = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + } + // ---------------------- + // 63-62 delta(EmbeddingHit #0) + // 61-60 delta(EmbeddingHit #1) + // 59-58 delta(EmbeddingHit #2) + // 57-56 delta(EmbeddingHit #3) + // 55-48 EmbeddingHit #5 + // 47-16 <unused> + // 15-8 kSpecialHit + // 7-0 Offset=48 + // ---------------------- + int bytes_used = 16; + + // Check that all hits have been inserted + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); + std::deque<EmbeddingHit> hits_pushed(hits_in.rbegin(), hits_in.rend()); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // Get the min size to fit for the hits in pl_used. Moving the hits in pl_used + // into a posting list with this min size should make it ALMOST_FULL, which we + // can see should have size = 24. 
+ // ---------------------- + // 23-22 delta(EmbeddingHit #0) + // 21-20 delta(EmbeddingHit #1) + // 19-18 delta(EmbeddingHit #2) + // 17-16 delta(EmbeddingHit #3) + // 15-8 EmbeddingHit #4 + // 7-0 kSpecialHit + // ---------------------- + int expected_min_size = 24; + uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(expected_min_size)); + + // Also check that this min size to fit posting list actually does fit all the + // hits and can only hit one more hit in the ALMOST_FULL state. + ICING_ASSERT_OK_AND_ASSIGN(PostingListUsed min_size_to_fit_pl, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, min_size_to_fit)); + for (const EmbeddingHit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit)); + } + + // Adding another hit to the min-size-to-fit posting list should succeed + EmbeddingHit hit = + CreateEmbeddingHit(hits_in.back(), /*desired_byte_length=*/1); + ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit)); + // Adding any other hits should fail with RESOURCE_EXHAUSTED error. + EXPECT_THAT( + serializer.PrependHit(&min_size_to_fit_pl, + CreateEmbeddingHit(hit, /*desired_byte_length=*/1)), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + + // Check that all hits have been inserted and the min-fit posting list is now + // FULL. 
+ EXPECT_THAT(serializer.GetBytesUsed(&min_size_to_fit_pl), + Eq(min_size_to_fit)); + hits_pushed.emplace_front(hit); + EXPECT_THAT(serializer.GetHits(&min_size_to_fit_pl), + IsOkAndHolds(ElementsAreArray(hits_pushed))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + GetMinPostingListToFitForAlmostFullAndFullPLReturnsSameSize) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 24; + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + // Create and add some hits to make pl_used ALMOST_FULL + std::vector<EmbeddingHit> hits_in = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + } + // ---------------------- + // 23-22 delta(EmbeddingHit #0) + // 21-20 delta(EmbeddingHit #1) + // 19-18 delta(EmbeddingHit #2) + // 17-16 delta(EmbeddingHit #3) + // 15-8 EmbeddingHit #4 + // 7-0 kSpecialHit + // ---------------------- + int bytes_used = 16; + + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); + std::deque<EmbeddingHit> hits_pushed(hits_in.rbegin(), hits_in.rend()); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // GetMinPostingListSizeToFit should return the same size as pl_used. + uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(pl_size)); + + // Add another hit to make the posting list FULL + EmbeddingHit hit = + CreateEmbeddingHit(hits_in.back(), /*desired_byte_length=*/1); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size)); + hits_pushed.emplace_front(hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // GetMinPostingListSizeToFit should still be the same size as pl_used. 
+ min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(pl_size)); +} + +TEST(PostingListEmbeddingHitSerializerTest, MoveFrom) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + ICING_ASSERT_OK(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1)); + EXPECT_THAT(serializer.GetHits(&pl_used2), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); + EXPECT_THAT(serializer.GetHits(&pl_used1), IsOkAndHolds(IsEmpty())); +} + +TEST(PostingListEmbeddingHitSerializerTest, + MoveFromNullArgumentReturnsInvalidArgument) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used1, /*src=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(serializer.GetHits(&pl_used1), + IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend()))); 
+} + +TEST(PostingListEmbeddingHitSerializerTest, + MoveFromInvalidPostingListReturnsInvalidArgument) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + // Write invalid hits to the beginning of pl_used1 to make it invalid. + EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue); + EmbeddingHit *first_hit = + reinterpret_cast<EmbeddingHit *>(pl_used1.posting_list_buffer()); + *first_hit = invalid_hit; + ++first_hit; + *first_hit = invalid_hit; + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(serializer.GetHits(&pl_used2), + IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend()))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + MoveToInvalidPostingListReturnsFailedPrecondition) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + 
ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + // Write invalid hits to the beginning of pl_used2 to make it invalid. + EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue); + EmbeddingHit *first_hit = + reinterpret_cast<EmbeddingHit *>(pl_used2.posting_list_buffer()); + *first_hit = invalid_hit; + ++first_hit; + *first_hit = invalid_hit; + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(serializer.GetHits(&pl_used1), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); +} + +TEST(PostingListEmbeddingHitSerializerTest, MoveToPostingListTooSmall) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, serializer.GetMinPostingListSize())); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/1, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(serializer.GetHits(&pl_used1), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), 
hits1.rend()))); + EXPECT_THAT(serializer.GetHits(&pl_used2), + IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend()))); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/index/embedding-indexing-handler.cc b/icing/index/embedding-indexing-handler.cc new file mode 100644 index 0000000..049e307 --- /dev/null +++ b/icing/index/embedding-indexing-handler.cc @@ -0,0 +1,85 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embedding-indexing-handler.h" + +#include <memory> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/hit/hit.h" +#include "icing/legacy/core/icing-string-util.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/util/clock.h" +#include "icing/util/status-macros.h" +#include "icing/util/tokenized-document.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndexingHandler>> +EmbeddingIndexingHandler::Create(const Clock* clock, + EmbeddingIndex* embedding_index) { + ICING_RETURN_ERROR_IF_NULL(clock); + ICING_RETURN_ERROR_IF_NULL(embedding_index); + + return std::unique_ptr<EmbeddingIndexingHandler>( + new EmbeddingIndexingHandler(clock, embedding_index)); +} + +libtextclassifier3::Status EmbeddingIndexingHandler::Handle( + const TokenizedDocument& tokenized_document, DocumentId document_id, + bool recovery_mode, PutDocumentStatsProto* put_document_stats) { + std::unique_ptr<Timer> index_timer = clock_.GetNewTimer(); + + if (!IsDocumentIdValid(document_id)) { + return absl_ports::InvalidArgumentError( + IcingStringUtil::StringPrintf("Invalid DocumentId %d", document_id)); + } + + if (embedding_index_.last_added_document_id() != kInvalidDocumentId && + document_id <= embedding_index_.last_added_document_id()) { + if (recovery_mode) { + // Skip the document if document_id <= last_added_document_id in + // recovery mode without returning an error. 
+ return libtextclassifier3::Status::OK; + } + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "DocumentId %d must be greater than last added document_id %d", + document_id, embedding_index_.last_added_document_id())); + } + embedding_index_.set_last_added_document_id(document_id); + + for (const Section<PropertyProto::VectorProto>& vector_section : + tokenized_document.vector_sections()) { + BasicHit hit(/*section_id=*/vector_section.metadata.id, document_id); + for (const PropertyProto::VectorProto& vector : vector_section.content) { + ICING_RETURN_IF_ERROR(embedding_index_.BufferEmbedding(hit, vector)); + } + } + ICING_RETURN_IF_ERROR(embedding_index_.CommitBufferToIndex()); + + if (put_document_stats != nullptr) { + put_document_stats->set_embedding_index_latency_ms( + index_timer->GetElapsedMilliseconds()); + } + + return libtextclassifier3::Status::OK; +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embedding-indexing-handler.h b/icing/index/embedding-indexing-handler.h new file mode 100644 index 0000000..f3adf6a --- /dev/null +++ b/icing/index/embedding-indexing-handler.h @@ -0,0 +1,70 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_ +#define ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_ + +#include <memory> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/data-indexing-handler.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/store/document-id.h" +#include "icing/util/clock.h" +#include "icing/util/tokenized-document.h" + +namespace icing { +namespace lib { + +class EmbeddingIndexingHandler : public DataIndexingHandler { + public: + ~EmbeddingIndexingHandler() override = default; + + // Creates an EmbeddingIndexingHandler instance which does not take + // ownership of any input components. All pointers must refer to valid objects + // that outlive the created EmbeddingIndexingHandler instance. + // + // Returns: + // - An EmbeddingIndexingHandler instance on success + // - FAILED_PRECONDITION_ERROR if any of the input pointer is null + static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndexingHandler>> + Create(const Clock* clock, EmbeddingIndex* embedding_index); + + // Handles the embedding indexing process: add hits into the embedding index + // for all contents in tokenized_document.vector_sections. + // + // Returns: + // - OK on success. + // - INVALID_ARGUMENT_ERROR if document_id is invalid OR document_id is less + // than or equal to the document_id of a previously indexed document in + // non recovery mode. + // - INTERNAL_ERROR if any other errors occur. + // - Any embedding index errors. 
+ libtextclassifier3::Status Handle( + const TokenizedDocument& tokenized_document, DocumentId document_id, + bool recovery_mode, PutDocumentStatsProto* put_document_stats) override; + + private: + explicit EmbeddingIndexingHandler(const Clock* clock, + EmbeddingIndex* embedding_index) + : DataIndexingHandler(clock), embedding_index_(*embedding_index) {} + + EmbeddingIndex& embedding_index_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_ diff --git a/icing/index/embedding-indexing-handler_test.cc b/icing/index/embedding-indexing-handler_test.cc new file mode 100644 index 0000000..ed711c1 --- /dev/null +++ b/icing/index/embedding-indexing-handler_test.cc @@ -0,0 +1,620 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embedding-indexing-handler.h" + +#include <cstdint> +#include <initializer_list> +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" +#include "icing/index/hit/hit.h" +#include "icing/portable/platform.h" +#include "icing/proto/document_wrapper.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/store/document-store.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/embedding-test-utils.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/tokenization/language-segmenter.h" +#include "icing/util/status-macros.h" +#include "icing/util/tokenized-document.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::IsTrue; + +// Indexable properties (section) and section id. Section id is determined by +// the lexicographical order of indexable property paths. 
+// Schema type with indexable properties: FakeType +// Section id = 0: "body" +// Section id = 1: "bodyEmbedding" +// Section id = 2: "title" +// Section id = 3: "titleEmbedding" +static constexpr std::string_view kFakeType = "FakeType"; +static constexpr std::string_view kPropertyBody = "body"; +static constexpr std::string_view kPropertyBodyEmbedding = "bodyEmbedding"; +static constexpr std::string_view kPropertyTitle = "title"; +static constexpr std::string_view kPropertyTitleEmbedding = "titleEmbedding"; +static constexpr std::string_view kPropertyNonIndexableEmbedding = + "nonIndexableEmbedding"; + +static constexpr SectionId kSectionIdBodyEmbedding = 1; +static constexpr SectionId kSectionIdTitleEmbedding = 3; + +// Schema type with nested indexable properties: FakeCollectionType +// Section id = 0: "collection.body" +// Section id = 1: "collection.bodyEmbedding" +// Section id = 2: "collection.title" +// Section id = 3: "collection.titleEmbedding" +// Section id = 4: "fullDocEmbedding" +static constexpr std::string_view kFakeCollectionType = "FakeCollectionType"; +static constexpr std::string_view kPropertyCollection = "collection"; +static constexpr std::string_view kPropertyFullDocEmbedding = + "fullDocEmbedding"; + +static constexpr SectionId kSectionIdNestedBodyEmbedding = 1; +static constexpr SectionId kSectionIdNestedTitleEmbedding = 3; +static constexpr SectionId kSectionIdFullDocEmbedding = 4; + +class EmbeddingIndexingHandlerTest : public ::testing::Test { + protected: + void SetUp() override { + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. 
+ icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + + base_dir_ = GetTestTempDir() + "/icing_test"; + ASSERT_THAT(filesystem_.CreateDirectoryRecursively(base_dir_.c_str()), + IsTrue()); + + embedding_index_working_path_ = base_dir_ + "/embedding_index"; + schema_store_dir_ = base_dir_ + "/schema_store"; + document_store_dir_ = base_dir_ + "/document_store"; + + ICING_ASSERT_OK_AND_ASSIGN( + embedding_index_, + EmbeddingIndex::Create(&filesystem_, embedding_index_working_path_)); + + language_segmenter_factory::SegmenterOptions segmenter_options(ULOC_US); + ICING_ASSERT_OK_AND_ASSIGN( + lang_segmenter_, + language_segmenter_factory::Create(std::move(segmenter_options))); + + ASSERT_THAT( + filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()), + IsTrue()); + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); + SchemaProto schema = + SchemaBuilder() + .AddType( + SchemaTypeConfigBuilder() + .SetType(kFakeType) + .AddProperty(PropertyConfigBuilder() + .SetName(kPropertyTitle) + .SetDataTypeString(TERM_MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName(kPropertyBody) + .SetDataTypeString(TERM_MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty( + PropertyConfigBuilder() + .SetName(kPropertyTitleEmbedding) + .SetDataTypeVector( + EmbeddingIndexingConfig::EmbeddingIndexingType:: + LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName(kPropertyBodyEmbedding) + .SetDataTypeVector( + EmbeddingIndexingConfig::EmbeddingIndexingType:: + LINEAR_SEARCH) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty(PropertyConfigBuilder() + .SetName(kPropertyNonIndexableEmbedding) + .SetDataType(TYPE_VECTOR) + .SetCardinality(CARDINALITY_REPEATED))) + .AddType(SchemaTypeConfigBuilder() + .SetType(kFakeCollectionType) 
+ .AddProperty(PropertyConfigBuilder() + .SetName(kPropertyCollection) + .SetDataTypeDocument( + kFakeType, + /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty( + PropertyConfigBuilder() + .SetName(kPropertyFullDocEmbedding) + .SetDataTypeVector( + EmbeddingIndexingConfig:: + EmbeddingIndexingType::LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ICING_ASSERT_OK(schema_store_->SetSchema( + schema, /*ignore_errors_and_delete_documents=*/false, + /*allow_circular_schema_definitions=*/false)); + + ASSERT_TRUE( + filesystem_.CreateDirectoryRecursively(document_store_dir_.c_str())); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult doc_store_create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get(), + /*force_recovery_and_revalidate_documents=*/false, + /*namespace_id_fingerprint=*/false, + /*pre_mapping_fbv=*/false, + /*use_persistent_hash_map=*/false, + PortableFileBackedProtoLog< + DocumentWrapper>::kDeflateCompressionLevel, + /*initialize_stats=*/nullptr)); + document_store_ = std::move(doc_store_create_result.document_store); + } + + void TearDown() override { + document_store_.reset(); + schema_store_.reset(); + lang_segmenter_.reset(); + embedding_index_.reset(); + + filesystem_.DeleteDirectoryRecursively(base_dir_.c_str()); + } + + libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits( + uint32_t dimension, std::string_view model_signature) { + std::vector<EmbeddingHit> hits; + + libtextclassifier3::StatusOr< + std::unique_ptr<PostingListEmbeddingHitAccessor>> + pl_accessor_or = + embedding_index_->GetAccessor(dimension, model_signature); + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor; + if (pl_accessor_or.ok()) { + pl_accessor = std::move(pl_accessor_or).ValueOrDie(); + } else if (absl_ports::IsNotFound(pl_accessor_or.status())) { + return hits; + } else { + return std::move(pl_accessor_or).status(); + } + + while 
(true) { + ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch, + pl_accessor->GetNextHitsBatch()); + if (batch.empty()) { + return hits; + } + hits.insert(hits.end(), batch.begin(), batch.end()); + } + } + + std::vector<float> GetRawEmbeddingData() { + return std::vector<float>(embedding_index_->GetRawEmbeddingData(), + embedding_index_->GetRawEmbeddingData() + + embedding_index_->GetTotalVectorSize()); + } + + Filesystem filesystem_; + FakeClock fake_clock_; + std::string base_dir_; + std::string embedding_index_working_path_; + std::string schema_store_dir_; + std::string document_store_dir_; + + std::unique_ptr<EmbeddingIndex> embedding_index_; + std::unique_ptr<LanguageSegmenter> lang_segmenter_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<DocumentStore> document_store_; +}; + +} // namespace + +TEST_F(EmbeddingIndexingHandlerTest, CreationWithNullPointerShouldFail) { + EXPECT_THAT(EmbeddingIndexingHandler::Create(/*clock=*/nullptr, + embedding_index_.get()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + + EXPECT_THAT(EmbeddingIndexingHandler::Create(&fake_clock_, + /*embedding_index=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(EmbeddingIndexingHandlerTest, HandleEmbeddingSection) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {0.1, 0.2, 0.3})) + .AddStringProperty(std::string(kPropertyBody), "body") + .AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {0.4, 0.5, 0.6}), + CreateVector("model", {0.7, 0.8, 0.9})) + .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", {1.1, 1.2, 1.3})) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + 
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document))); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(tokenized_document.document())); + + ASSERT_THAT(embedding_index_->last_added_document_id(), + Eq(kInvalidDocumentId)); + // Handle document. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<EmbeddingIndexingHandler> handler, + EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get())); + EXPECT_THAT( + handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false, + /*put_document_stats=*/nullptr), + IsOk()); + + // Check index + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/0), + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0), + /*location=*/6)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3)); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id)); +} + +TEST_F(EmbeddingIndexingHandlerTest, HandleNestedEmbeddingSection) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_collection_type/1") + .SetSchema(std::string(kFakeCollectionType)) + .AddDocumentProperty( + std::string(kPropertyCollection), + DocumentBuilder() + .SetKey("icing", "nested_fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {0.1, 0.2, 0.3})) + .AddStringProperty(std::string(kPropertyBody), "body") + .AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {0.4, 0.5, 0.6}), + CreateVector("model", {0.7, 0.8, 0.9})) + .AddVectorProperty( + std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", 
{1.1, 1.2, 1.3})) + .Build()) + .AddVectorProperty(std::string(kPropertyFullDocEmbedding), + CreateVector("model", {2.1, 2.2, 2.3})) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document))); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(tokenized_document.document())); + + ASSERT_THAT(embedding_index_->last_added_document_id(), + Eq(kInvalidDocumentId)); + // Handle document. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<EmbeddingIndexingHandler> handler, + EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get())); + EXPECT_THAT( + handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false, + /*put_document_stats=*/nullptr), + IsOk()); + + // Check index + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit( + BasicHit(kSectionIdNestedBodyEmbedding, /*document_id=*/0), + /*location=*/0), + EmbeddingHit( + BasicHit(kSectionIdNestedBodyEmbedding, /*document_id=*/0), + /*location=*/3), + EmbeddingHit( + BasicHit(kSectionIdNestedTitleEmbedding, /*document_id=*/0), + /*location=*/6), + EmbeddingHit(BasicHit(kSectionIdFullDocEmbedding, /*document_id=*/0), + /*location=*/9)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, + 0.1, 0.2, 0.3, 2.1, 2.2, 2.3)); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id)); +} + +TEST_F(EmbeddingIndexingHandlerTest, + HandleInvalidDocumentIdShouldReturnInvalidArgumentError) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {0.1, 0.2, 0.3})) + .AddStringProperty(std::string(kPropertyBody), "body") + 
.AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {0.4, 0.5, 0.6}), + CreateVector("model", {0.7, 0.8, 0.9})) + .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", {1.1, 1.2, 1.3})) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document))); + ICING_ASSERT_OK(document_store_->Put(tokenized_document.document())); + + static constexpr DocumentId kCurrentDocumentId = 3; + embedding_index_->set_last_added_document_id(kCurrentDocumentId); + ASSERT_THAT(embedding_index_->last_added_document_id(), + Eq(kCurrentDocumentId)); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<EmbeddingIndexingHandler> handler, + EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get())); + + // Handling document with kInvalidDocumentId should cause a failure, and both + // index data and last_added_document_id should remain unchanged. + EXPECT_THAT( + handler->Handle(tokenized_document, kInvalidDocumentId, + /*recovery_mode=*/false, /*put_document_stats=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(embedding_index_->last_added_document_id(), + Eq(kCurrentDocumentId)); + // Check that the embedding index should be empty + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); + + // Recovery mode should get the same result. 
+ EXPECT_THAT( + handler->Handle(tokenized_document, kInvalidDocumentId, + /*recovery_mode=*/true, /*put_document_stats=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(embedding_index_->last_added_document_id(), + Eq(kCurrentDocumentId)); + // Check that the embedding index should be empty + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); +} + +TEST_F(EmbeddingIndexingHandlerTest, + HandleOutOfOrderDocumentIdShouldReturnInvalidArgumentError) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {0.1, 0.2, 0.3})) + .AddStringProperty(std::string(kPropertyBody), "body") + .AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {0.4, 0.5, 0.6}), + CreateVector("model", {0.7, 0.8, 0.9})) + .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", {1.1, 1.2, 1.3})) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document))); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(tokenized_document.document())); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<EmbeddingIndexingHandler> handler, + EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get())); + + // Handling document with document_id == last_added_document_id should cause a + // failure, and both index data and last_added_document_id should remain + // unchanged. 
+ embedding_index_->set_last_added_document_id(document_id); + ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id)); + EXPECT_THAT( + handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false, + /*put_document_stats=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id)); + + // Check that the embedding index should be empty + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); + + // Handling document with document_id < last_added_document_id should cause a + // failure, and both index data and last_added_document_id should remain + // unchanged. + embedding_index_->set_last_added_document_id(document_id + 1); + ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id + 1)); + EXPECT_THAT( + handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false, + /*put_document_stats=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id + 1)); + + // Check that the embedding index should be empty + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); +} + +TEST_F(EmbeddingIndexingHandlerTest, + HandleRecoveryModeShouldIgnoreDocsLELastAddedDocId) { + DocumentProto document1 = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title one") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {0.1, 0.2, 0.3})) + .AddStringProperty(std::string(kPropertyBody), "body one") + .AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {0.4, 0.5, 0.6}), + CreateVector("model", {0.7, 0.8, 0.9})) + 
.AddVectorProperty(std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", {1.1, 1.2, 1.3})) + .Build(); + DocumentProto document2 = + DocumentBuilder() + .SetKey("icing", "fake_type/2") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title two") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {10.1, 10.2, 10.3})) + .AddStringProperty(std::string(kPropertyBody), "body two") + .AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {10.4, 10.5, 10.6}), + CreateVector("model", {10.7, 10.8, 10.9})) + .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", {11.1, 11.2, 11.3})) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document1, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document1))); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document2, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document2))); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store_->Put(tokenized_document1.document())); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store_->Put(tokenized_document2.document())); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<EmbeddingIndexingHandler> handler, + EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get())); + + // Handle document with document_id > last_added_document_id in recovery mode. + // The handler should index this document and update last_added_document_id. 
+ EXPECT_THAT( + handler->Handle(tokenized_document1, document_id1, /*recovery_mode=*/true, + /*put_document_stats=*/nullptr), + IsOk()); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id1)); + + // Check index + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/0), + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0), + /*location=*/6)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3)); + + // Handle document with document_id == last_added_document_id in recovery + // mode. We should not get any error, but the handler should ignore the + // document, so both index data and last_added_document_id should remain + // unchanged. + embedding_index_->set_last_added_document_id(document_id2); + ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2)); + EXPECT_THAT( + handler->Handle(tokenized_document2, document_id2, /*recovery_mode=*/true, + /*put_document_stats=*/nullptr), + IsOk()); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2)); + + // Check index + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/0), + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0), + /*location=*/6)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3)); + + // Handle document with document_id < last_added_document_id in recovery mode. 
+ // We should not get any error, but the handler should ignore the document, so + // both index data and last_added_document_id should remain unchanged. + embedding_index_->set_last_added_document_id(document_id2 + 1); + ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2 + 1)); + EXPECT_THAT( + handler->Handle(tokenized_document2, document_id2, /*recovery_mode=*/true, + /*put_document_stats=*/nullptr), + IsOk()); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2 + 1)); + + // Check index + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/0), + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0), + /*location=*/6)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3)); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/hit/hit.h b/icing/index/hit/hit.h index 83fff12..e971016 100644 --- a/icing/index/hit/hit.h +++ b/icing/index/hit/hit.h @@ -60,6 +60,7 @@ class BasicHit { explicit BasicHit(SectionId section_id, DocumentId document_id); explicit BasicHit() : value_(kInvalidValue) {} + explicit BasicHit(Value value) : value_(value) {} bool is_valid() const { return value_ != kInvalidValue; } Value value() const { return value_; } diff --git a/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h b/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h index c16a1c4..d766712 100644 --- a/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h +++ b/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h @@ -17,13 +17,17 @@ #include <cstdint> #include <memory> +#include <set> #include <string> -#include <string_view> #include <utility> +#include <vector> #include 
"icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" #include "icing/store/document-store.h" namespace icing { diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc index 682b752..735adaa 100644 --- a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc +++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc @@ -134,7 +134,9 @@ DocHitInfoIteratorSectionRestrict::ApplyRestrictions( ChildrenMapper mapper; mapper = [&data, &mapper](std::unique_ptr<DocHitInfoIterator> iterator) -> std::unique_ptr<DocHitInfoIterator> { - if (iterator->is_leaf()) { + if (iterator->full_section_restriction_applied()) { + return iterator; + } else if (iterator->is_leaf()) { return std::make_unique<DocHitInfoIteratorSectionRestrict>( std::move(iterator), data); } else { @@ -149,24 +151,8 @@ libtextclassifier3::Status DocHitInfoIteratorSectionRestrict::Advance() { doc_hit_info_ = DocHitInfo(kInvalidDocumentId); while (delegate_->Advance().ok()) { DocumentId document_id = delegate_->doc_hit_info().document_id(); - - auto data_optional = data_->document_store().GetAliveDocumentFilterData( - document_id, data_->current_time_ms()); - if (!data_optional) { - // Ran into some error retrieving information on this hit, skip - continue; - } - - // Guaranteed that the DocumentFilterData exists at this point - SchemaTypeId schema_type_id = data_optional.value().schema_type_id(); - auto schema_type_or = data_->schema_store().GetSchemaType(schema_type_id); - if (!schema_type_or.ok()) { - // Ran into error retrieving schema type, skip - continue; - } - const std::string* schema_type = std::move(schema_type_or).ValueOrDie(); SectionIdMask allowed_sections_mask = - 
data_->ComputeAllowedSectionsMask(*schema_type); + data_->ComputeAllowedSectionsMask(document_id); // A hit can be in multiple sections at once, need to check which of the // section ids match the sections allowed by type_property_masks_. This can diff --git a/icing/index/iterator/doc-hit-info-iterator.h b/icing/index/iterator/doc-hit-info-iterator.h index 728f957..921e4d4 100644 --- a/icing/index/iterator/doc-hit-info-iterator.h +++ b/icing/index/iterator/doc-hit-info-iterator.h @@ -214,6 +214,12 @@ class DocHitInfoIterator { virtual bool is_leaf() { return false; } + // Whether the iterator has already been applied all the required section + // restrictions. + // If true, calling DocHitInfoIteratorSectionRestrict::ApplyRestrictions on + // this iterator will have no effect. + virtual bool full_section_restriction_applied() const { return false; } + virtual ~DocHitInfoIterator() = default; // Returns: diff --git a/icing/index/iterator/section-restrict-data.cc b/icing/index/iterator/section-restrict-data.cc index 085437d..581ac8e 100644 --- a/icing/index/iterator/section-restrict-data.cc +++ b/icing/index/iterator/section-restrict-data.cc @@ -22,6 +22,8 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/schema/schema-store.h" #include "icing/schema/section.h" +#include "icing/store/document-filter-data.h" +#include "icing/store/document-id.h" namespace icing { namespace lib { @@ -78,5 +80,24 @@ SectionIdMask SectionRestrictData::ComputeAllowedSectionsMask( return new_section_id_mask; } +SectionIdMask SectionRestrictData::ComputeAllowedSectionsMask( + DocumentId document_id) { + auto data_optional = + document_store_.GetAliveDocumentFilterData(document_id, current_time_ms_); + if (!data_optional) { + // Ran into some error retrieving information on this document, skip + return kSectionIdMaskNone; + } + // Guaranteed that the DocumentFilterData exists at this point + SchemaTypeId schema_type_id = data_optional.value().schema_type_id(); + 
auto schema_type_or = schema_store_.GetSchemaType(schema_type_id); + if (!schema_type_or.ok()) { + // Ran into error retrieving schema type, skip + return kSectionIdMaskNone; + } + const std::string* schema_type = std::move(schema_type_or).ValueOrDie(); + return ComputeAllowedSectionsMask(*schema_type); +} + } // namespace lib } // namespace icing diff --git a/icing/index/iterator/section-restrict-data.h b/icing/index/iterator/section-restrict-data.h index 26ca597..64d5087 100644 --- a/icing/index/iterator/section-restrict-data.h +++ b/icing/index/iterator/section-restrict-data.h @@ -23,6 +23,7 @@ #include "icing/schema/schema-store.h" #include "icing/schema/section.h" +#include "icing/store/document-id.h" #include "icing/store/document-store.h" namespace icing { @@ -48,10 +49,21 @@ class SectionRestrictData { // Returns: // - If type_property_filters_ has an entry for the given schema type or // wildcard(*), return a bitwise or of section IDs in the schema type - // that that are also present in the relevant filter list. + // that are also present in the relevant filter list. // - Otherwise, return kSectionIdMaskAll. SectionIdMask ComputeAllowedSectionsMask(const std::string& schema_type); + // Calculates the section mask of allowed sections(determined by the + // property filters map) for the given document id, by retrieving its schema + // type name and calling the above method. + // + // Returns: + // - If type_property_filters_ has an entry for the given document's schema + // type or wildcard(*), return a bitwise or of section IDs in the schema + // type that are also present in the relevant filter list. + // - Otherwise, return kSectionIdMaskAll. 
+ SectionIdMask ComputeAllowedSectionsMask(DocumentId document_id); + const DocumentStore& document_store() const { return document_store_; } const SchemaStore& schema_store() const { return schema_store_; } diff --git a/icing/query/advanced_query_parser/function.cc b/icing/query/advanced_query_parser/function.cc index e7938db..0ff160b 100644 --- a/icing/query/advanced_query_parser/function.cc +++ b/icing/query/advanced_query_parser/function.cc @@ -13,8 +13,15 @@ // limitations under the License. #include "icing/query/advanced_query_parser/function.h" +#include <string> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/query/advanced_query_parser/param.h" +#include "icing/query/advanced_query_parser/pending-value.h" #include "icing/util/status-macros.h" namespace icing { @@ -73,5 +80,20 @@ libtextclassifier3::StatusOr<PendingValue> Function::Eval( return eval_(std::move(args)); } +libtextclassifier3::StatusOr<DataType> Function::get_param_type(int i) const { + if (i < 0 || params_.empty()) { + return absl_ports::OutOfRangeError("Invalid argument index."); + } + const Param* parm; + if (i < params_.size()) { + parm = ¶ms_.at(i); + } else if (params_.back().cardinality == Cardinality::kVariable) { + parm = ¶ms_.back(); + } else { + return absl_ports::OutOfRangeError("Invalid argument index."); + } + return parm->data_type; +} + } // namespace lib -} // namespace icing
\ No newline at end of file +} // namespace icing diff --git a/icing/query/advanced_query_parser/function.h b/icing/query/advanced_query_parser/function.h index 3514878..08cc7e8 100644 --- a/icing/query/advanced_query_parser/function.h +++ b/icing/query/advanced_query_parser/function.h @@ -46,6 +46,8 @@ class Function { libtextclassifier3::StatusOr<PendingValue> Eval( std::vector<PendingValue>&& args) const; + libtextclassifier3::StatusOr<DataType> get_param_type(int i) const; + private: Function(DataType return_type, std::string name, std::vector<Param> params, EvalFunction eval) diff --git a/icing/query/advanced_query_parser/param.h b/icing/query/advanced_query_parser/param.h index 69c46be..9ea1915 100644 --- a/icing/query/advanced_query_parser/param.h +++ b/icing/query/advanced_query_parser/param.h @@ -35,13 +35,19 @@ struct Param { libtextclassifier3::Status Matches(PendingValue& arg) const { bool matches = arg.data_type() == data_type; - // Values of type kText could also potentially be valid kLong values. If - // we're expecting a kLong and we have a kText, try to parse it as a kLong. + // Values of type kText could also potentially be valid kLong or kDouble + // values. If we're expecting a kLong or kDouble and we have a kText, try to + // parse it as what we expect. if (!matches && data_type == DataType::kLong && arg.data_type() == DataType::kText) { ICING_RETURN_IF_ERROR(arg.ParseInt()); matches = true; } + if (!matches && data_type == DataType::kDouble && + arg.data_type() == DataType::kText) { + ICING_RETURN_IF_ERROR(arg.ParseDouble()); + matches = true; + } return matches ? 
libtextclassifier3::Status::OK : absl_ports::InvalidArgumentError( "Provided arg doesn't match required param type."); diff --git a/icing/query/advanced_query_parser/pending-value.cc b/icing/query/advanced_query_parser/pending-value.cc index 67bdc3a..a3f95d9 100644 --- a/icing/query/advanced_query_parser/pending-value.cc +++ b/icing/query/advanced_query_parser/pending-value.cc @@ -13,7 +13,11 @@ // limitations under the License. #include "icing/query/advanced_query_parser/pending-value.h" +#include <cstdlib> + +#include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" namespace icing { namespace lib { @@ -40,5 +44,27 @@ libtextclassifier3::Status PendingValue::ParseInt() { return libtextclassifier3::Status::OK; } +libtextclassifier3::Status PendingValue::ParseDouble() { + if (data_type_ == DataType::kDouble) { + return libtextclassifier3::Status::OK; + } else if (data_type_ != DataType::kText) { + return absl_ports::InvalidArgumentError("Cannot parse value as double"); + } + if (query_term_.is_prefix_val) { + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + "Cannot use prefix operator '*' with numeric value: ", + query_term_.term)); + } + char* value_end; + double_val_ = std::strtod(query_term_.term.c_str(), &value_end); + if (value_end != query_term_.term.c_str() + query_term_.term.length()) { + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + "Unable to parse \"", query_term_.term, "\" as double.")); + } + data_type_ = DataType::kDouble; + query_term_ = {/*term=*/"", /*raw_term=*/"", /*is_prefix_val=*/false}; + return libtextclassifier3::Status::OK; +} + } // namespace lib } // namespace icing diff --git a/icing/query/advanced_query_parser/pending-value.h b/icing/query/advanced_query_parser/pending-value.h index 1a6717e..34912f3 100644 --- a/icing/query/advanced_query_parser/pending-value.h +++ b/icing/query/advanced_query_parser/pending-value.h @@ 
-14,12 +14,16 @@ #ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_PENDING_VALUE_H_ #define ICING_QUERY_ADVANCED_QUERY_PARSER_PENDING_VALUE_H_ +#include <cstdint> #include <memory> #include <string> +#include <string_view> #include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/util/status-macros.h" @@ -30,10 +34,14 @@ namespace lib { enum class DataType { kNone, kLong, + kDouble, kText, kString, kStringList, kDocumentIterator, + // TODO(b/326656531): Instead of creating a vector index type, consider + // changing it to vector type so that the data is the vector directly. + kVectorIndex, }; struct QueryTerm { @@ -52,6 +60,10 @@ struct PendingValue { return PendingValue(std::move(text), DataType::kText); } + static PendingValue CreateVectorIndexPendingValue(int64_t vector_index) { + return PendingValue(vector_index, DataType::kVectorIndex); + } + PendingValue() : data_type_(DataType::kNone) {} explicit PendingValue(std::unique_ptr<DocHitInfoIterator> iterator) @@ -111,6 +123,16 @@ struct PendingValue { return long_val_; } + libtextclassifier3::StatusOr<double> double_val() { + ICING_RETURN_IF_ERROR(ParseDouble()); + return double_val_; + } + + libtextclassifier3::StatusOr<int64_t> vector_index_val() const { + ICING_RETURN_IF_ERROR(CheckDataType(DataType::kVectorIndex)); + return long_val_; + } + // Attempts to interpret the value as an int. A pending value can be parsed as // an int under two circumstances: // 1. It holds a kText value which can be parsed to an int @@ -122,12 +144,26 @@ struct PendingValue { // - INVALID_ARGUMENT if the value could not be parsed as a long libtextclassifier3::Status ParseInt(); + // Attempts to interpret the value as a double. 
A pending value can be parsed + // as a double under two circumstances: + // 1. It holds a kText value which can be parsed to a double + // 2. It holds a kDouble value + // If #1 is true, then the parsed value will be stored in double_val_ and + // data_type will be updated to kDouble. + // RETURNS: + // - OK, if able to successfully parse the value into a double + // - INVALID_ARGUMENT if the value could not be parsed as a double + libtextclassifier3::Status ParseDouble(); + DataType data_type() const { return data_type_; } private: explicit PendingValue(QueryTerm query_term, DataType data_type) : query_term_(std::move(query_term)), data_type_(data_type) {} + explicit PendingValue(int64_t long_val, DataType data_type) + : long_val_(long_val), data_type_(data_type) {} + libtextclassifier3::Status CheckDataType(DataType required_data_type) const { if (data_type_ == required_data_type) { return libtextclassifier3::Status::OK; @@ -151,6 +187,9 @@ struct PendingValue { // long_val_ will be populated when data_type_ is kLong - after a successful // call to ParseInt. int64_t long_val_; + // double_val_ will be populated when data_type_ is kDouble - after a + // successful call to ParseDouble. 
+ double double_val_; DataType data_type_; }; diff --git a/icing/query/advanced_query_parser/query-visitor.cc b/icing/query/advanced_query_parser/query-visitor.cc index 31da959..1ac52c5 100644 --- a/icing/query/advanced_query_parser/query-visitor.cc +++ b/icing/query/advanced_query_parser/query-visitor.cc @@ -16,20 +16,26 @@ #include <algorithm> #include <cstdint> -#include <cstdlib> #include <iterator> #include <limits> #include <memory> #include <set> #include <string> +#include <string_view> +#include <unordered_map> #include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/absl_ports/str_join.h" +#include "icing/index/embed/doc-hit-info-iterator-embedding.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/iterator/doc-hit-info-iterator-all-document-id.h" #include "icing/index/iterator/doc-hit-info-iterator-and.h" +#include "icing/index/iterator/doc-hit-info-iterator-filter.h" #include "icing/index/iterator/doc-hit-info-iterator-none.h" #include "icing/index/iterator/doc-hit-info-iterator-not.h" #include "icing/index/iterator/doc-hit-info-iterator-or.h" @@ -37,17 +43,23 @@ #include "icing/index/iterator/doc-hit-info-iterator-property-in-schema.h" #include "icing/index/iterator/doc-hit-info-iterator-section-restrict.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/iterator/section-restrict-data.h" #include "icing/index/property-existence-indexing-handler.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/query/advanced_query_parser/function.h" #include "icing/query/advanced_query_parser/lexer.h" #include "icing/query/advanced_query_parser/param.h" #include "icing/query/advanced_query_parser/parser.h" #include "icing/query/advanced_query_parser/pending-value.h" 
#include "icing/query/advanced_query_parser/util/string-util.h" #include "icing/query/query-features.h" +#include "icing/query/query-results.h" #include "icing/schema/property-util.h" +#include "icing/schema/schema-store.h" #include "icing/schema/section.h" #include "icing/tokenization/token.h" #include "icing/tokenization/tokenizer.h" +#include "icing/util/embedding-util.h" #include "icing/util/status-macros.h" namespace icing { @@ -241,6 +253,34 @@ void QueryVisitor::RegisterFunctions() { .ValueOrDie(); registered_functions_.insert( {has_property_function.name(), std::move(has_property_function)}); + + // vector_index getSearchSpecEmbedding(long); + auto get_search_spec_embedding = [](std::vector<PendingValue>&& args) { + return PendingValue::CreateVectorIndexPendingValue( + args.at(0).long_val().ValueOrDie()); + }; + Function get_search_spec_embedding_function = + Function::Create(DataType::kVectorIndex, "getSearchSpecEmbedding", + {Param(DataType::kLong)}, + std::move(get_search_spec_embedding)) + .ValueOrDie(); + registered_functions_.insert({get_search_spec_embedding_function.name(), + std::move(get_search_spec_embedding_function)}); + + // DocHitInfoIterator semanticSearch(vector_index, double, double, string); + auto semantic_search = [this](std::vector<PendingValue>&& args) { + return this->SemanticSearchFunction(std::move(args)); + }; + Function semantic_search_function = + Function::Create(DataType::kDocumentIterator, "semanticSearch", + {Param(DataType::kVectorIndex), + Param(DataType::kDouble, Cardinality::kOptional), + Param(DataType::kDouble, Cardinality::kOptional), + Param(DataType::kString, Cardinality::kOptional)}, + std::move(semantic_search)) + .ValueOrDie(); + registered_functions_.insert( + {semantic_search_function.name(), std::move(semantic_search_function)}); } libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction( @@ -278,10 +318,11 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction( 
document_store_.last_added_document_id()); } else { QueryVisitor query_visitor( - &index_, &numeric_index_, &document_store_, &schema_store_, - &normalizer_, &tokenizer_, query->raw_term, filter_options_, - match_type_, needs_term_frequency_info_, pending_property_restricts_, - processing_not_, current_time_ms_); + &index_, &numeric_index_, &embedding_index_, &document_store_, + &schema_store_, &normalizer_, &tokenizer_, query->raw_term, + embedding_query_vectors_, filter_options_, match_type_, + embedding_query_metric_type_, needs_term_frequency_info_, + pending_property_restricts_, processing_not_, current_time_ms_); tree_root->Accept(&query_visitor); ICING_ASSIGN_OR_RETURN(query_result, std::move(query_visitor).ConsumeResults()); @@ -359,6 +400,57 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::HasPropertyFunction( return PendingValue(std::move(property_in_document_iterator)); } +libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SemanticSearchFunction( + std::vector<PendingValue>&& args) { + features_.insert(kEmbeddingSearchFeature); + + int64_t vector_index = args.at(0).vector_index_val().ValueOrDie(); + if (embedding_query_vectors_ == nullptr || vector_index < 0 || + vector_index >= embedding_query_vectors_->size()) { + return absl_ports::InvalidArgumentError("Got invalid vector search index!"); + } + + // Handle default values for the optional arguments. 
+ double low = -std::numeric_limits<double>::infinity(); + double high = std::numeric_limits<double>::infinity(); + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type = + embedding_query_metric_type_; + if (args.size() >= 2) { + low = args.at(1).double_val().ValueOrDie(); + } + if (args.size() >= 3) { + high = args.at(2).double_val().ValueOrDie(); + } + if (args.size() >= 4) { + const std::string& metric = args.at(3).string_val().ValueOrDie()->term; + ICING_ASSIGN_OR_RETURN( + metric_type, + embedding_util::GetEmbeddingQueryMetricTypeFromName(metric)); + } + + // Create SectionRestrictData for section restriction. + std::unique_ptr<SectionRestrictData> section_restrict_data = nullptr; + if (pending_property_restricts_.has_active_property_restricts()) { + std::unordered_map<std::string, std::set<std::string>> + type_property_filters; + type_property_filters[std::string(SchemaStore::kSchemaTypeWildcard)] = + pending_property_restricts_.active_property_restricts(); + section_restrict_data = std::make_unique<SectionRestrictData>( + &document_store_, &schema_store_, current_time_ms_, + type_property_filters); + } + + // Create and return iterator. + EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map = + &embedding_query_results_.result_scores[vector_index][metric_type]; + ICING_ASSIGN_OR_RETURN(std::unique_ptr<DocHitInfoIterator> iterator, + DocHitInfoIteratorEmbedding::Create( + &embedding_query_vectors_->at(vector_index), + std::move(section_restrict_data), metric_type, low, + high, score_map, &embedding_index_)); + return PendingValue(std::move(iterator)); +} + libtextclassifier3::StatusOr<int64_t> QueryVisitor::PopPendingIntValue() { if (pending_values_.empty()) { return absl_ports::InvalidArgumentError("Unable to retrieve int value."); @@ -435,8 +527,8 @@ QueryVisitor::PopPendingIterator() { // raw_text, then all of raw_text must correspond to this token. 
raw_token = raw_text; } else { - ICING_ASSIGN_OR_RETURN(raw_token, string_util::FindEscapedToken( - raw_text, token.text)); + ICING_ASSIGN_OR_RETURN( + raw_token, string_util::FindEscapedToken(raw_text, token.text)); } normalized_term = normalizer_.NormalizeTerm(token.text); QueryTerm term_value{std::move(normalized_term), raw_token, @@ -570,15 +662,14 @@ libtextclassifier3::Status QueryVisitor::ProcessNegationOperator( "Visit unary operator child didn't correctly add pending values."); } - // 3. We want to preserve the original text of the integer value, append our - // minus and *then* parse as an int. - ICING_ASSIGN_OR_RETURN(QueryTerm int_text_val, PopPendingTextValue()); - int_text_val.term = absl_ports::StrCat("-", int_text_val.term); + // 3. We want to preserve the original text of the numeric value, append our + // minus to the text. It will be parsed as either an int or a double later. + ICING_ASSIGN_OR_RETURN(QueryTerm numeric_text_val, PopPendingTextValue()); + numeric_text_val.term = absl_ports::StrCat("-", numeric_text_val.term); PendingValue pending_value = - PendingValue::CreateTextPendingValue(std::move(int_text_val)); - ICING_RETURN_IF_ERROR(pending_value.long_val()); + PendingValue::CreateTextPendingValue(std::move(numeric_text_val)); - // We've parsed our integer value successfully. Pop our placeholder, push it + // We've parsed our numeric value successfully. Pop our placeholder, push it // on to the stack and return successfully. 
if (!pending_values_.top().is_placeholder()) { return absl_ports::InvalidArgumentError( @@ -768,7 +859,8 @@ void QueryVisitor::VisitMember(const MemberNode* node) { end = text_val.raw_term.data() + text_val.raw_term.length(); } else { start = std::min(start, text_val.raw_term.data()); - end = std::max(end, text_val.raw_term.data() + text_val.raw_term.length()); + end = std::max(end, + text_val.raw_term.data() + text_val.raw_term.length()); } members.push_back(std::move(text_val.term)); } @@ -800,13 +892,26 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) { "Function ", node->function_name()->value(), " is not supported.")); return; } + const Function& function = itr->second; // 2. Put in a placeholder PendingValue pending_values_.push(PendingValue()); // 3. Visit the children. - for (const std::unique_ptr<Node>& arg : node->args()) { + expecting_numeric_arg_ = true; + for (int i = 0; i < node->args().size(); ++i) { + const std::unique_ptr<Node>& arg = node->args()[i]; + libtextclassifier3::StatusOr<DataType> arg_type_or = + function.get_param_type(i); + bool current_level_expecting_numeric_arg = expecting_numeric_arg_; + // If arg_type_or has an error, we should ignore it for now, since + // function.Eval should do the type check and return better error messages. 
+ if (arg_type_or.ok() && (arg_type_or.ValueOrDie() == DataType::kLong || + arg_type_or.ValueOrDie() == DataType::kDouble)) { + expecting_numeric_arg_ = true; + } arg->Accept(this); + expecting_numeric_arg_ = current_level_expecting_numeric_arg; if (has_pending_error()) { return; } @@ -819,7 +924,6 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) { pending_values_.pop(); } std::reverse(args.begin(), args.end()); - const Function& function = itr->second; auto eval_result = function.Eval(std::move(args)); if (!eval_result.ok()) { pending_error_ = std::move(eval_result).status(); @@ -955,6 +1059,7 @@ libtextclassifier3::StatusOr<QueryResults> QueryVisitor::ConsumeResults() && { results.root_iterator = std::move(iterator_or).ValueOrDie(); results.query_term_iterators = std::move(query_term_iterators_); results.query_terms = std::move(property_query_terms_map_); + results.embedding_query_results = std::move(embedding_query_results_); results.features_in_use = std::move(features_); return results; } diff --git a/icing/query/advanced_query_parser/query-visitor.h b/icing/query/advanced_query_parser/query-visitor.h index d090b3c..17149f5 100644 --- a/icing/query/advanced_query_parser/query-visitor.h +++ b/icing/query/advanced_query_parser/query-visitor.h @@ -17,13 +17,19 @@ #include <cstdint> #include <memory> +#include <set> #include <stack> #include <string> +#include <string_view> +#include <unordered_map> #include <unordered_set> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-filter.h" #include "icing/index/iterator/doc-hit-info-iterator.h" @@ -33,10 +39,12 @@ #include "icing/query/advanced_query_parser/pending-value.h" #include "icing/query/query-features.h" 
#include "icing/query/query-results.h" +#include "icing/query/query-terms.h" #include "icing/schema/schema-store.h" #include "icing/store/document-store.h" #include "icing/tokenization/tokenizer.h" #include "icing/transform/normalizer.h" +#include <google/protobuf/repeated_field.h> namespace icing { namespace lib { @@ -45,19 +53,23 @@ namespace lib { // the parser. class QueryVisitor : public AbstractSyntaxTreeVisitor { public: - explicit QueryVisitor(Index* index, - const NumericIndex<int64_t>* numeric_index, - const DocumentStore* document_store, - const SchemaStore* schema_store, - const Normalizer* normalizer, - const Tokenizer* tokenizer, - std::string_view raw_query_text, - DocHitInfoIteratorFilter::Options filter_options, - TermMatchType::Code match_type, - bool needs_term_frequency_info, int64_t current_time_ms) - : QueryVisitor(index, numeric_index, document_store, schema_store, - normalizer, tokenizer, raw_query_text, filter_options, - match_type, needs_term_frequency_info, + explicit QueryVisitor( + Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, + const DocumentStore* document_store, const SchemaStore* schema_store, + const Normalizer* normalizer, const Tokenizer* tokenizer, + std::string_view raw_query_text, + const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>* + embedding_query_vectors, + DocHitInfoIteratorFilter::Options filter_options, + TermMatchType::Code match_type, + SearchSpecProto::EmbeddingQueryMetricType::Code + embedding_query_metric_type, + bool needs_term_frequency_info, int64_t current_time_ms) + : QueryVisitor(index, numeric_index, embedding_index, document_store, + schema_store, normalizer, tokenizer, raw_query_text, + embedding_query_vectors, filter_options, match_type, + embedding_query_metric_type, needs_term_frequency_info, PendingPropertyRestricts(), /*processing_not=*/false, current_time_ms) {} @@ -106,22 +118,31 @@ class QueryVisitor : public 
AbstractSyntaxTreeVisitor { explicit QueryVisitor( Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const DocumentStore* document_store, const SchemaStore* schema_store, const Normalizer* normalizer, const Tokenizer* tokenizer, std::string_view raw_query_text, + const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>* + embedding_query_vectors, DocHitInfoIteratorFilter::Options filter_options, - TermMatchType::Code match_type, bool needs_term_frequency_info, + TermMatchType::Code match_type, + SearchSpecProto::EmbeddingQueryMetricType::Code + embedding_query_metric_type, + bool needs_term_frequency_info, PendingPropertyRestricts pending_property_restricts, bool processing_not, int64_t current_time_ms) : index_(*index), numeric_index_(*numeric_index), + embedding_index_(*embedding_index), document_store_(*document_store), schema_store_(*schema_store), normalizer_(*normalizer), tokenizer_(*tokenizer), raw_query_text_(raw_query_text), + embedding_query_vectors_(embedding_query_vectors), filter_options_(std::move(filter_options)), match_type_(match_type), + embedding_query_metric_type_(embedding_query_metric_type), needs_term_frequency_info_(needs_term_frequency_info), pending_property_restricts_(std::move(pending_property_restricts)), processing_not_(processing_not), @@ -264,6 +285,22 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor { libtextclassifier3::StatusOr<PendingValue> HasPropertyFunction( std::vector<PendingValue>&& args); + // Implementation of the semanticSearch(vector, low, high, metric) custom + // function. This function is used for supporting vector search with a + // syntax like `semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, "COSINE")`. 
+ // + // low, high, metric are optional parameters: + // - low is default to negative infinity + // - high is default to positive infinity + // - metric is default to the metric specified in SearchSpec + // + // Returns: + // - a Pending Value of type DocHitIterator that returns all documents with + // an embedding vector that has a score within [low, high]. + // - any errors returned by Lexer::ExtractTokens + libtextclassifier3::StatusOr<PendingValue> SemanticSearchFunction( + std::vector<PendingValue>&& args); + // Handles a NaryOperatorNode where the operator is HAS (':') and pushes an // iterator with the proper section filter applied. If the current property // restriction represented by pending_property_restricts and the first child @@ -292,19 +329,26 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor { SectionRestrictQueryTermsMap property_query_terms_map_; QueryTermIteratorsMap query_term_iterators_; + + EmbeddingQueryResults embedding_query_results_; + // Set of features invoked in the query. std::unordered_set<Feature> features_; Index& index_; // Does not own! const NumericIndex<int64_t>& numeric_index_; // Does not own! + const EmbeddingIndex& embedding_index_; // Does not own! const DocumentStore& document_store_; // Does not own! const SchemaStore& schema_store_; // Does not own! const Normalizer& normalizer_; // Does not own! const Tokenizer& tokenizer_; // Does not own! std::string_view raw_query_text_; + const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>* + embedding_query_vectors_; // Nullable, does not own! DocHitInfoIteratorFilter::Options filter_options_; TermMatchType::Code match_type_; + SearchSpecProto::EmbeddingQueryMetricType::Code embedding_query_metric_type_; // Whether or not term_frequency information is needed. This affects: // - how DocHitInfoIteratorTerms are constructed // - whether the QueryTermIteratorsMap is populated in the QueryResults. 
diff --git a/icing/query/advanced_query_parser/query-visitor_test.cc b/icing/query/advanced_query_parser/query-visitor_test.cc index 9455baa..c5ba866 100644 --- a/icing/query/advanced_query_parser/query-visitor_test.cc +++ b/icing/query/advanced_query_parser/query-visitor_test.cc @@ -15,6 +15,7 @@ #include "icing/query/advanced_query_parser/query-visitor.h" #include <cstdint> +#include <initializer_list> #include <limits> #include <memory> #include <string> @@ -31,6 +32,7 @@ #include "icing/document-builder.h" #include "icing/file/filesystem.h" #include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/hit/hit.h" #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-filter.h" @@ -42,6 +44,7 @@ #include "icing/jni/jni-cache.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/portable/platform.h" +#include "icing/proto/search.pb.h" #include "icing/query/advanced_query_parser/abstract-syntax-tree.h" #include "icing/query/advanced_query_parser/lexer.h" #include "icing/query/advanced_query_parser/parser.h" @@ -54,6 +57,7 @@ #include "icing/store/document-store.h" #include "icing/store/namespace-id.h" #include "icing/testing/common-matchers.h" +#include "icing/testing/embedding-test-utils.h" #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/jni-test-helpers.h" #include "icing/testing/test-data.h" @@ -67,16 +71,22 @@ #include "icing/util/clock.h" #include "icing/util/status-macros.h" #include "unicode/uloc.h" +#include <google/protobuf/repeated_field.h> namespace icing { namespace lib { namespace { +using ::testing::DoubleNear; using ::testing::ElementsAre; using ::testing::IsEmpty; +using ::testing::IsNull; +using ::testing::Pointee; using ::testing::UnorderedElementsAre; +constexpr float kEps = 0.000001; + constexpr DocumentId kDocumentId0 = 0; constexpr DocumentId kDocumentId1 = 1; constexpr DocumentId kDocumentId2 = 2; @@ -85,6 +95,18 
@@ constexpr SectionId kSectionId0 = 0; constexpr SectionId kSectionId1 = 1; constexpr SectionId kSectionId2 = 2; +constexpr SearchSpecProto::EmbeddingQueryMetricType::Code + EMBEDDING_METRIC_UNKNOWN = + SearchSpecProto::EmbeddingQueryMetricType::UNKNOWN; +constexpr SearchSpecProto::EmbeddingQueryMetricType::Code + EMBEDDING_METRIC_COSINE = SearchSpecProto::EmbeddingQueryMetricType::COSINE; +constexpr SearchSpecProto::EmbeddingQueryMetricType::Code + EMBEDDING_METRIC_DOT_PRODUCT = + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; +constexpr SearchSpecProto::EmbeddingQueryMetricType::Code + EMBEDDING_METRIC_EUCLIDEAN = + SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN; + template <typename T, typename U> std::vector<T> ExtractKeys(const std::unordered_map<T, U>& map) { std::vector<T> keys; @@ -106,6 +128,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> { test_dir_ = GetTestTempDir() + "/icing"; index_dir_ = test_dir_ + "/index"; numeric_index_dir_ = test_dir_ + "/numeric_index"; + embedding_index_dir_ = test_dir_ + "/embedding_index"; store_dir_ = test_dir_ + "/store"; schema_store_dir_ = test_dir_ + "/schema_store"; filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); @@ -154,6 +177,10 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> { numeric_index_, DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_)); + ICING_ASSERT_OK_AND_ASSIGN( + embedding_index_, + EmbeddingIndex::Create(&filesystem_, embedding_index_dir_)); + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( /*max_term_byte_size=*/1000)); @@ -219,6 +246,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> { std::string test_dir_; std::string index_dir_; std::string numeric_index_dir_; + std::string embedding_index_dir_; std::string schema_store_dir_; std::string store_dir_; Clock clock_; @@ -226,6 +254,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> { 
std::unique_ptr<DocumentStore> document_store_; std::unique_ptr<Index> index_; std::unique_ptr<DummyNumericIndex<int64_t>> numeric_index_; + std::unique_ptr<EmbeddingIndex> embedding_index_; std::unique_ptr<Normalizer> normalizer_; std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<Tokenizer> tokenizer_; @@ -252,9 +281,11 @@ TEST_P(QueryVisitorTest, SimpleLessThan) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -295,9 +326,11 @@ TEST_P(QueryVisitorTest, SimpleLessThanEq) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -338,9 +371,11 @@ TEST_P(QueryVisitorTest, SimpleEqual) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - 
index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -381,9 +416,11 @@ TEST_P(QueryVisitorTest, SimpleGreaterThanEq) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -424,9 +461,11 @@ TEST_P(QueryVisitorTest, SimpleGreaterThan) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); 
root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -468,9 +507,11 @@ TEST_P(QueryVisitorTest, IntMinLessThanEqual) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -512,9 +553,11 @@ TEST_P(QueryVisitorTest, IntMaxGreaterThanEqual) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -557,9 +600,11 @@ TEST_P(QueryVisitorTest, NestedPropertyLessThan) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), 
schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -585,9 +630,11 @@ TEST_P(QueryVisitorTest, IntParsingError) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -599,9 +646,11 @@ TEST_P(QueryVisitorTest, NotEqualsUnsupported) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -647,9 +696,11 @@ TEST_P(QueryVisitorTest, LessThanTooManyOperandsInvalid) { args.push_back(std::move(extra_value_node)); auto root_node = 
std::make_unique<NaryOperatorNode>("<", std::move(args)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -674,9 +725,11 @@ TEST_P(QueryVisitorTest, LessThanTooFewOperandsInvalid) { args.push_back(std::move(member_node)); auto root_node = std::make_unique<NaryOperatorNode>("<", std::move(args)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -705,9 +758,11 @@ TEST_P(QueryVisitorTest, LessThanNonExistentPropertyNotFound) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), 
TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -727,9 +782,11 @@ TEST_P(QueryVisitorTest, LessThanNonExistentPropertyNotFound) { TEST_P(QueryVisitorTest, NeverVisitedReturnsInvalid) { QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), "", + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), "", /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); @@ -756,9 +813,11 @@ TEST_P(QueryVisitorTest, IntMinLessThanInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -786,9 +845,11 @@ TEST_P(QueryVisitorTest, IntMaxGreaterThanInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), 
normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -801,9 +862,11 @@ TEST_P(QueryVisitorTest, NumericComparisonPropertyStringIsInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -865,9 +928,11 @@ TEST_P(QueryVisitorTest, NumericComparatorDoesntAffectLaterTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); 
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -908,9 +973,11 @@ TEST_P(QueryVisitorTest, SingleTermTermFrequencyEnabled) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -960,9 +1027,11 @@ TEST_P(QueryVisitorTest, SingleTermTermFrequencyDisabled) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/false, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1012,9 +1081,11 @@ TEST_P(QueryVisitorTest, SingleTermPrefix) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), 
normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1028,9 +1099,11 @@ TEST_P(QueryVisitorTest, SingleTermPrefix) { query = CreateQuery("fo*"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -1048,9 +1121,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterPropertyReturnsInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -1062,9 +1137,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterNumericValueReturnsInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> 
root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -1076,9 +1153,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterPropertyRestrictReturnsInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -1114,9 +1193,11 @@ TEST_P(QueryVisitorTest, SegmentationWithPrefix) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT, + 
EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1137,9 +1218,11 @@ TEST_P(QueryVisitorTest, SegmentationWithPrefix) { query = CreateQuery("ba?fo*"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -1174,9 +1257,11 @@ TEST_P(QueryVisitorTest, SingleVerbatimTerm) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1221,9 +1306,11 @@ TEST_P(QueryVisitorTest, SingleVerbatimTermPrefix) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), 
tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1274,9 +1361,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingQuote) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1326,9 +1415,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingEscape) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1380,9 +1471,11 
@@ TEST_P(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1407,9 +1500,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) { query = CreateQuery(R"(("foobar\\y"))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -1462,9 +1557,11 @@ TEST_P(QueryVisitorTest, VerbatimTermNewLine) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, 
/*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1488,9 +1585,11 @@ TEST_P(QueryVisitorTest, VerbatimTermNewLine) { query = CreateQuery(R"(("foobar\\n"))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -1537,9 +1636,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingComplex) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1596,9 +1697,11 @@ TEST_P(QueryVisitorTest, SingleMinusTerm) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - 
index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1650,9 +1753,11 @@ TEST_P(QueryVisitorTest, SingleNotTerm) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1705,9 +1810,11 @@ TEST_P(QueryVisitorTest, NestedNotTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); 
root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1774,9 +1881,11 @@ TEST_P(QueryVisitorTest, DeeplyNestedNotTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1813,9 +1922,11 @@ TEST_P(QueryVisitorTest, ImplicitAndTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1856,9 +1967,11 @@ TEST_P(QueryVisitorTest, ExplicitAndTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), 
schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1899,9 +2012,11 @@ TEST_P(QueryVisitorTest, OrTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1944,9 +2059,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1969,9 +2086,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) { query = CreateQuery("bar OR baz foo"); ICING_ASSERT_OK_AND_ASSIGN(root_node, 
ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -1993,9 +2112,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) { query = CreateQuery("(bar OR baz) foo"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_three( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_three); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2055,9 +2176,11 @@ TEST_P(QueryVisitorTest, AndOrNotPrecedence) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, 
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2075,9 +2198,11 @@ TEST_P(QueryVisitorTest, AndOrNotPrecedence) { query = CreateQuery("foo NOT (bar OR baz)"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2140,9 +2265,11 @@ TEST_P(QueryVisitorTest, PropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2216,9 +2343,11 @@ TEST_F(QueryVisitorTest, MultiPropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + 
index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2259,9 +2388,11 @@ TEST_P(QueryVisitorTest, PropertyFilterStringIsInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -2315,9 +2446,11 @@ TEST_P(QueryVisitorTest, PropertyFilterNonNormalized) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2386,9 +2519,11 @@ 
TEST_P(QueryVisitorTest, PropertyFilterWithGrouping) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2453,9 +2588,11 @@ TEST_P(QueryVisitorTest, ValidNestedPropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2474,9 +2611,11 @@ TEST_P(QueryVisitorTest, ValidNestedPropertyFilter) { /*property_restrict=*/"prop1"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, 
/*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2540,9 +2679,11 @@ TEST_P(QueryVisitorTest, InvalidNestedPropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2561,9 +2702,11 @@ TEST_P(QueryVisitorTest, InvalidNestedPropertyFilter) { /*property_restrict=*/"prop1"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2627,9 +2770,11 @@ TEST_P(QueryVisitorTest, NotWithPropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - 
index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2648,9 +2793,11 @@ TEST_P(QueryVisitorTest, NotWithPropertyFilter) { "NOT ", CreateQuery("(foo OR bar)", /*property_restrict=*/"prop1")); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2728,9 +2875,11 @@ TEST_P(QueryVisitorTest, PropertyFilterWithNot) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, 
clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2755,9 +2904,11 @@ TEST_P(QueryVisitorTest, PropertyFilterWithNot) { query = CreateQuery("(NOT foo OR bar)", /*property_restrict=*/"prop1"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2837,9 +2988,11 @@ TEST_P(QueryVisitorTest, SegmentationTest) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2957,9 +3110,11 @@ TEST_P(QueryVisitorTest, PropertyRestrictsPopCorrectly) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + 
index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3074,9 +3229,11 @@ TEST_P(QueryVisitorTest, UnsatisfiablePropertyRestrictsPopCorrectly) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3098,9 +3255,11 @@ TEST_F(QueryVisitorTest, UnsupportedFunctionReturnsInvalidArgument) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3112,9 
+3271,11 @@ TEST_F(QueryVisitorTest, SearchFunctionTooFewArgumentsReturnsInvalidArgument) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3126,9 +3287,11 @@ TEST_F(QueryVisitorTest, SearchFunctionTooManyArgumentsReturnsInvalidArgument) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3142,9 +3305,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), 
query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3154,9 +3319,11 @@ TEST_F(QueryVisitorTest, query = R"(search(createList("subject")))"; ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); EXPECT_THAT(std::move(query_visitor_two).ConsumeResults(), @@ -3170,9 +3337,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3182,9 +3351,11 @@ TEST_F(QueryVisitorTest, query = R"(search("foo", 7))"; ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), 
document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); EXPECT_THAT(std::move(query_visitor_two).ConsumeResults(), @@ -3197,9 +3368,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3260,9 +3433,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(level_two_query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); 
root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3284,9 +3459,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) { R"(", createList("prop1")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), - level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3308,9 +3485,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) { R"(", createList("prop1")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_four_query)); QueryVisitor query_visitor_three( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), - level_four_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_four_query, /*embedding_query_vectors=*/nullptr, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_three); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3430,9 +3609,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, 
ParseQueryHelper(level_one_query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_one_query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_one_query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3462,9 +3643,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) { R"(", createList("prop6", "prop0", "prop4", "prop2")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_two_query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3488,9 +3671,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) { R"(", createList("prop0", "prop6")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query)); QueryVisitor query_visitor_three( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), - level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + index_.get(), numeric_index_.get(), embedding_index_.get(), 
+ document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_three); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3610,9 +3795,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(level_one_query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_one_query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_one_query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3634,9 +3821,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) { R"(", createList("prop6", "prop0", "prop4", "prop2")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_two_query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); 
ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3659,9 +3848,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) { R"( "prop0", "prop6", "prop4", "prop7")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query)); QueryVisitor query_visitor_three( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), - level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_three); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3685,9 +3876,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3701,9 +3894,11 @@ TEST_F( ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), 
embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3717,9 +3912,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3732,9 +3929,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3786,9 +3985,11 @@ TEST_P(QueryVisitorTest, PropertyDefinedFunctionReturnsMatchingDocuments) { 
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3839,9 +4040,11 @@ TEST_P(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3890,9 +4093,11 @@ TEST_P(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, 
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3910,9 +4115,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3925,9 +4132,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3941,9 +4150,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), 
schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3956,9 +4167,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -4015,9 +4228,11 @@ TEST_P(QueryVisitorTest, HasPropertyFunctionReturnsMatchingDocuments) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor1( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor1); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -4033,9 +4248,11 @@ TEST_P(QueryVisitorTest, HasPropertyFunctionReturnsMatchingDocuments) { query = CreateQuery("bar OR NOT 
hasProperty(\"price\")"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor2( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor2); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -4088,9 +4305,11 @@ TEST_P(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -4102,6 +4321,890 @@ TEST_P(QueryVisitorTest, EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty()); } +TEST_F(QueryVisitorTest, + SemanticSearchFunctionWithNoArgumentReturnsInvalidArgument) { + // Create two embedding queries. 
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + std::string query = "semanticSearch()"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, + SemanticSearchFunctionWithIncorrectArgumentTypeReturnsInvalidArgument) { + // Create two embedding queries. 
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + std::string query = "semanticSearch(0)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, + SemanticSearchFunctionWithExtraArgumentReturnsInvalidArgument) { + // Create two embedding queries. 
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + std::string query = + "semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, \"COSINE\", 0)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, + GetSearchSpecEmbeddingFunctionWithExtraArgumentReturnsInvalidArgument) { + // Create two embedding queries. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + // The embedding query index is invalid, since there are only 2 queries. 
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0, 1))"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, + SemanticSearchFunctionWithInvalidIndexReturnsInvalidArgument) { + // Create two embedding queries. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + // The embedding query index is invalid, since there are only 2 queries. + std::string query = "semanticSearch(getSearchSpecEmbedding(10))"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, + SemanticSearchFunctionWithInvalidMetricReturnsInvalidArgument) { + // Create two embedding queries. 
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + // The embedding query metric is invalid. + std::string query = + "semanticSearch(getSearchSpecEmbedding(0), -10, 10, \"UNKNOWN\")"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // Passing an unknown default metric type without overriding it in the query + // expression is also considered invalid. + query = "semanticSearch(getSearchSpecEmbedding(0), -10, 10)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + EXPECT_THAT(std::move(query_visitor2).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionSimpleLowerBound) { + // Index two embedding vectors. 
+ PropertyProto::VectorProto vector0 = + CreateVector("my_model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector1 = + CreateVector("my_model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), vector0)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId1), vector1)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query that has a semantic score of 1 with vector0 and + // -1 with vector1. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3}); + + // The query should match vector0 only. + std::string query = "semanticSearch(getSearchSpecEmbedding(0), 0.5)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(1, kEps)))); + + // The 
query should match both vector0 and vector1. + query = "semanticSearch(getSearchSpecEmbedding(0), -1.5)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId1, kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(1, kEps)))); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1), + Pointee(UnorderedElementsAre(DoubleNear(-1, kEps)))); + + // The query should match nothing, since there is no vector with a + // score >= 1.01. 
+ query = "semanticSearch(getSearchSpecEmbedding(0), 1.01)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor3( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor3); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor3).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty()); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionSimpleUpperBound) { + // Index two embedding vectors. + PropertyProto::VectorProto vector0 = + CreateVector("my_model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector1 = + CreateVector("my_model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), vector0)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId1), vector1)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query that has a semantic score of 1 with vector0 and + // -1 with vector1. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3}); + + // The query should match vector1 only. 
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0), -100, 0.5)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId1)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1), + Pointee(UnorderedElementsAre(DoubleNear(-1, kEps)))); + + // The query should match both vector0 and vector1. 
+ query = "semanticSearch(getSearchSpecEmbedding(0), -100, 1.5)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId1, kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(1, kEps)))); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1), + Pointee(UnorderedElementsAre(DoubleNear(-1, kEps)))); + + // The query should match nothing, since there is no vector with a + // score <= -1.01. 
+ query = "semanticSearch(getSearchSpecEmbedding(0), -100, -1.01)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor3( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor3); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor3).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty()); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionMetricOverride) { + // Index a embedding vector. + PropertyProto::VectorProto vector = CreateVector("my_model", {0.1, 0.2, 0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), vector)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query that has: + // - a cosine semantic score of 1 + // - a dot product semantic score of 0.14 + // - a euclidean semantic score of 0 + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3}); + + // Create a query that overrides the metric to COSINE. 
+ std::string query = + "semanticSearch(getSearchSpecEmbedding(0), 0.95, 1.05, \"COSINE\")"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + // The default metric to be overridden + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(1, kEps)))); + + // Create a query that overrides the metric to DOT_PRODUCT. 
+ query = + "semanticSearch(getSearchSpecEmbedding(0), 0.1, 0.2, \"DOT_PRODUCT\")"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + // The default metric to be overridden + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(0.14, kEps)))); + + // Create a query that overrides the metric to EUCLIDEAN. 
+ query = + "semanticSearch(getSearchSpecEmbedding(0), -0.05, 0.05, \"EUCLIDEAN\")"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor3( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + // The default metric to be overridden + EMBEDDING_METRIC_UNKNOWN, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor3); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor3).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_EUCLIDEAN, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(0, kEps)))); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionMultipleQueries) { + // Index 3 embedding vectors for document 0. + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), + CreateVector("my_model1", {1, -2, -3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId1, kDocumentId0), + CreateVector("my_model1", {-1, -2, -3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId2, kDocumentId0), + CreateVector("my_model2", {-1, 2, 3, -4}))); + // Index 2 embedding vectors for document 1. 
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId1), + CreateVector("my_model1", {-1, -2, 3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId1, kDocumentId1), + CreateVector("my_model2", {1, -2, 3, -4}))); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create two embedding queries. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + // Semantic scores for this query: + // - document 0: -2 (section 0), 0 (section 1) + // - document 1: 6 (section 0) + PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add(); + *embedding_query = CreateVector("my_model1", {-1, -1, 1}); + // Semantic scores for this query: + // - document 0: 4 (section 2) + // - document 1: -2 (section 1) + embedding_query = embedding_query_vectors.Add(); + *embedding_query = CreateVector("my_model2", {-1, 1, -1, -1}); + + // The query can only match document 0: + // - The "semanticSearch(getSearchSpecEmbedding(0), -5)" part should match + // semantic scores {-2, 0}. + // - The "semanticSearch(getSearchSpecEmbedding(1), 0)" part should match + // semantic scores {4}. 
+ std::string query = + "semanticSearch(getSearchSpecEmbedding(0), -5) AND " + "semanticSearch(getSearchSpecEmbedding(1), 0)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + DocHitInfoIterator* itr = query_results.root_iterator.get(); + // Check results for document 0. + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT(itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, + std::vector<SectionId>{kSectionId0, kSectionId1, + kSectionId2})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(-2, 0))); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(4))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + + // The query can match both document 0 and document 1: + // For document 0: + // - The "semanticSearch(getSearchSpecEmbedding(0), 1)" part should return + // semantic scores {}. 
+ // - The "semanticSearch(getSearchSpecEmbedding(1), 0.1)" part should return + // semantic scores {4}. + // For document 1: + // - The "semanticSearch(getSearchSpecEmbedding(0), 1)" part should return + // semantic scores {6}. + // - The "semanticSearch(getSearchSpecEmbedding(1), 0.1)" part should return + // semantic scores {}. + query = + "semanticSearch(getSearchSpecEmbedding(0), 1) OR " + "semanticSearch(getSearchSpecEmbedding(1), 0.1)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + itr = query_results.root_iterator.get(); + // Check results for document 1. + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + Pointee(UnorderedElementsAre(6))); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + IsNull()); + // Check results for document 0. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId2})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + IsNull()); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(4))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST_F(QueryVisitorTest, + SemanticSearchFunctionMultipleQueriesScoresMergedRepeat) { + // Index 3 embedding vectors for document 0. + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), + CreateVector("my_model1", {1, -2, -3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId1, kDocumentId0), + CreateVector("my_model1", {-1, -2, -3}))); + // Index 2 embedding vectors for document 1. + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId1), + CreateVector("my_model1", {-1, -2, 3}))); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create two embedding queries. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + // Semantic scores for this query: + // - document 0: -2 (section 0), 0 (section 1) + // - document 1: 6 (section 0) + PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add(); + *embedding_query = CreateVector("my_model1", {-1, -1, 1}); + + // The query should match both document 0 and document 1, since the overall + // range is [-10, 10]. The scores in the results should be merged. 
+ std::string query = + "semanticSearch(getSearchSpecEmbedding(0), -10, 0) OR " + "semanticSearch(getSearchSpecEmbedding(0), 0.0001, 10)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + DocHitInfoIterator* itr = query_results.root_iterator.get(); + // Check results for document 1. + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + Pointee(UnorderedElementsAre(6))); + // Check results for document 0. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT(itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{ + kSectionId0, kSectionId1})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(-2, 0))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + + // The same query appears twice, in which case all the scores in the results + // should repeat twice. + query = + "semanticSearch(getSearchSpecEmbedding(0), -10, 10) OR " + "semanticSearch(getSearchSpecEmbedding(0), -10, 10)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + itr = query_results.root_iterator.get(); + // Check results for document 1. + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + Pointee(UnorderedElementsAre(6, 6))); + // Check results for document 0. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT(itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{ + kSectionId0, kSectionId1})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(-2, 0, -2, 0))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionHybridQueries) { + // Index terms + Index::Editor editor = index_->Edit(kDocumentId0, kSectionId1, + TERM_MATCH_PREFIX, /*namespace_id=*/0); + ICING_ASSERT_OK(editor.BufferTerm("foo")); + ICING_ASSERT_OK(editor.IndexAllBufferedTerms()); + editor = index_->Edit(kDocumentId1, kSectionId1, TERM_MATCH_PREFIX, + /*namespace_id=*/0); + ICING_ASSERT_OK(editor.BufferTerm("bar")); + ICING_ASSERT_OK(editor.IndexAllBufferedTerms()); + + // Index embedding vectors + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), + CreateVector("my_model1", {1, -2, -3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId1), + CreateVector("my_model1", {-1, -2, 3}))); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query with semantic scores: + // - document 0: -2 + // - document 1: 6 + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add(); + *embedding_query = CreateVector("my_model1", {-1, -1, 1}); + + // Perform a hybrid search: + // - The "semanticSearch(getSearchSpecEmbedding(0), 0)" part only matches + // document 1. + // - The "foo" part only matches document 0. 
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0), 0) OR foo"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), + UnorderedElementsAre("foo")); + EXPECT_THAT(ExtractKeys(query_results.query_terms), UnorderedElementsAre("")); + EXPECT_THAT(query_results.query_terms[""], UnorderedElementsAre("foo")); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + DocHitInfoIterator* itr = query_results.root_iterator.get(); + // Check results for document 1. + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + Pointee(UnorderedElementsAre(6))); + // Check results for document 0. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId1})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + IsNull()); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + + // Perform another hybrid search: + // - The "semanticSearch(getSearchSpecEmbedding(0), -5)" part matches both + // document 0 and 1. + // - The "foo" part only matches document 0. + // As a result, only document 0 will be returned. + query = "semanticSearch(getSearchSpecEmbedding(0), -5) AND foo"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), + UnorderedElementsAre("foo")); + EXPECT_THAT(ExtractKeys(query_results.query_terms), UnorderedElementsAre("")); + EXPECT_THAT(query_results.query_terms[""], UnorderedElementsAre("foo")); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + itr = query_results.root_iterator.get(); + // Check results for document 0. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT(itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{ + kSectionId0, kSectionId1})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(-2))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionSectionRestriction) { + ICING_ASSERT_OK(schema_store_->SetSchema( + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("type") + .AddProperty(PropertyConfigBuilder() + .SetName("prop1") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("prop2") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(), + /*ignore_errors_and_delete_documents=*/false, + /*allow_circular_schema_definitions=*/false)); + + // Create two documents. + ICING_ASSERT_OK(document_store_->Put( + DocumentBuilder().SetKey("ns", "uri0").SetSchema("type").Build())); + ICING_ASSERT_OK(document_store_->Put( + DocumentBuilder().SetKey("ns", "uri1").SetSchema("type").Build())); + // Add embedding vectors into different sections for the two documents. 
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), + CreateVector("my_model1", {1, -2, -3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId1, kDocumentId0), + CreateVector("my_model1", {-1, -2, 3}))); + ICING_ASSERT_OK( + embedding_index_->BufferEmbedding(BasicHit(kSectionId0, kDocumentId1), + CreateVector("my_model1", {-1, 2, 3}))); + ICING_ASSERT_OK( + embedding_index_->BufferEmbedding(BasicHit(kSectionId1, kDocumentId1), + CreateVector("my_model1", {1, 2, -3}))); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query with semantic scores: + // - document 0: -2 (section 0), 6 (section 1) + // - document 1: 2 (section 0), -6 (section 1) + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add(); + *embedding_query = CreateVector("my_model1", {-1, -1, 1}); + + // An embedding query with section restriction. The scores returned should + // only be limited to the section restricted. 
+ std::string query = "prop1:semanticSearch(getSearchSpecEmbedding(0), -100)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + DocHitInfoIterator* itr = query_results.root_iterator.get(); + // Check results. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + Pointee(UnorderedElementsAre(2))); + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(-2))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + INSTANTIATE_TEST_SUITE_P(QueryVisitorTest, QueryVisitorTest, testing::Values(QueryType::kPlain, QueryType::kSearch)); diff --git a/icing/query/query-features.h b/icing/query/query-features.h index d829cd7..bc3602f 100644 --- a/icing/query/query-features.h +++ b/icing/query/query-features.h @@ -52,9 +52,15 @@ constexpr Feature kListFilterQueryLanguageFeature = constexpr Feature kHasPropertyFunctionFeature = "HAS_PROPERTY_FUNCTION"; // Features#HAS_PROPERTY_FUNCTION +// This feature relates to the use of embedding searches in the advanced query +// language. Ex. `semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, "COSINE")`. 
+constexpr Feature kEmbeddingSearchFeature = + "EMBEDDING_SEARCH"; // Features#EMBEDDING_SEARCH + inline std::unordered_set<Feature> GetQueryFeaturesSet() { return {kNumericSearchFeature, kVerbatimSearchFeature, - kListFilterQueryLanguageFeature, kHasPropertyFunctionFeature}; + kListFilterQueryLanguageFeature, kHasPropertyFunctionFeature, + kEmbeddingSearchFeature}; } } // namespace lib diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc index bc0917b..6e13001 100644 --- a/icing/query/query-processor.cc +++ b/icing/query/query-processor.cc @@ -27,6 +27,7 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-all-document-id.h" #include "icing/index/iterator/doc-hit-info-iterator-and.h" @@ -111,25 +112,28 @@ std::unique_ptr<DocHitInfoIterator> ProcessParserStateFrame( libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>> QueryProcessor::Create(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, const SchemaStore* schema_store, const Clock* clock) { ICING_RETURN_ERROR_IF_NULL(index); ICING_RETURN_ERROR_IF_NULL(numeric_index); + ICING_RETURN_ERROR_IF_NULL(embedding_index); ICING_RETURN_ERROR_IF_NULL(language_segmenter); ICING_RETURN_ERROR_IF_NULL(normalizer); ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); ICING_RETURN_ERROR_IF_NULL(clock); - return std::unique_ptr<QueryProcessor>( - new QueryProcessor(index, numeric_index, language_segmenter, normalizer, - document_store, schema_store, clock)); + return std::unique_ptr<QueryProcessor>(new QueryProcessor( + index, numeric_index, embedding_index, language_segmenter, 
normalizer, + document_store, schema_store, clock)); } QueryProcessor::QueryProcessor(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, @@ -137,6 +141,7 @@ QueryProcessor::QueryProcessor(Index* index, const Clock* clock) : index_(*index), numeric_index_(*numeric_index), + embedding_index_(*embedding_index), language_segmenter_(*language_segmenter), normalizer_(*normalizer), document_store_(*document_store), @@ -229,11 +234,12 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseAdvancedQuery( ranking_strategy == ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE; std::unique_ptr<Timer> query_visitor_timer = clock_.GetNewTimer(); - QueryVisitor query_visitor(&index_, &numeric_index_, &document_store_, - &schema_store_, &normalizer_, - plain_tokenizer.get(), search_spec.query(), - std::move(options), search_spec.term_match_type(), - needs_term_frequency_info, current_time_ms); + QueryVisitor query_visitor( + &index_, &numeric_index_, &embedding_index_, &document_store_, + &schema_store_, &normalizer_, plain_tokenizer.get(), search_spec.query(), + &search_spec.embedding_query_vectors(), std::move(options), + search_spec.term_match_type(), search_spec.embedding_query_metric_type(), + needs_term_frequency_info, current_time_ms); tree_root->Accept(&query_visitor); ICING_ASSIGN_OR_RETURN(QueryResults results, std::move(query_visitor).ConsumeResults()); diff --git a/icing/query/query-processor.h b/icing/query/query-processor.h index de256ee..d90b5f6 100644 --- a/icing/query/query-processor.h +++ b/icing/query/query-processor.h @@ -19,6 +19,7 @@ #include <memory> #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" #include "icing/index/numeric/numeric-index.h" #include "icing/proto/logging.pb.h" @@ -47,6 +48,7 @@ 
class QueryProcessor { // FAILED_PRECONDITION if any of the pointers is null. static libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>> Create( Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, const SchemaStore* schema_store, const Clock* clock); @@ -74,6 +76,7 @@ class QueryProcessor { private: explicit QueryProcessor(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, @@ -110,6 +113,7 @@ class QueryProcessor { // query time. Index& index_; // Does not own. const NumericIndex<int64_t>& numeric_index_; // Does not own. + const EmbeddingIndex& embedding_index_; // Does not own. const LanguageSegmenter& language_segmenter_; // Does not own. const Normalizer& normalizer_; // Does not own. const DocumentStore& document_store_; // Does not own. diff --git a/icing/query/query-processor_benchmark.cc b/icing/query/query-processor_benchmark.cc index 3be74fd..5b3bf99 100644 --- a/icing/query/query-processor_benchmark.cc +++ b/icing/query/query-processor_benchmark.cc @@ -12,26 +12,40 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include <cstdint> +#include <limits> +#include <memory> +#include <string> +#include <utility> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "testing/base/public/benchmark.h" -#include "gmock/gmock.h" #include "third_party/absl/flags/flag.h" +#include "icing/absl_ports/str_cat.h" #include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" #include "icing/index/numeric/dummy-numeric-index.h" -#include "icing/index/numeric/numeric-index.h" +#include "icing/legacy/index/icing-filesystem.h" #include "icing/proto/schema.pb.h" #include "icing/proto/search.pb.h" #include "icing/proto/term.pb.h" #include "icing/query/query-processor.h" +#include "icing/query/query-results.h" #include "icing/schema/schema-store.h" #include "icing/schema/section.h" #include "icing/store/document-id.h" +#include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" #include "icing/tokenization/language-segmenter-factory.h" +#include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" #include "icing/util/clock.h" #include "icing/util/logging.h" #include "unicode/uloc.h" @@ -118,6 +132,7 @@ void BM_QueryOneTerm(benchmark::State& state) { const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark"; const std::string index_dir = base_dir + "/index"; const std::string numeric_index_dir = base_dir + "/numeric_index"; + const std::string embedding_index_dir = base_dir + "/embedding_index"; const std::string schema_dir = base_dir + "/schema"; const std::string doc_store_dir = base_dir + "/store"; @@ -134,6 +149,9 @@ void BM_QueryOneTerm(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( 
auto numeric_index, DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir)); + ICING_ASSERT_OK_AND_ASSIGN( + auto embedding_index, + EmbeddingIndex::Create(&filesystem, embedding_index_dir)); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = @@ -172,8 +190,9 @@ void BM_QueryOneTerm(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, QueryProcessor::Create(index.get(), numeric_index.get(), - language_segmenter.get(), normalizer.get(), - document_store.get(), schema_store.get(), &clock)); + embedding_index.get(), language_segmenter.get(), + normalizer.get(), document_store.get(), + schema_store.get(), &clock)); SearchSpecProto search_spec; search_spec.set_query(input_string); @@ -247,6 +266,7 @@ void BM_QueryFiveTerms(benchmark::State& state) { const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark"; const std::string index_dir = base_dir + "/index"; const std::string numeric_index_dir = base_dir + "/numeric_index"; + const std::string embedding_index_dir = base_dir + "/embedding_index"; const std::string schema_dir = base_dir + "/schema"; const std::string doc_store_dir = base_dir + "/store"; @@ -263,6 +283,9 @@ void BM_QueryFiveTerms(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( auto numeric_index, DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir)); + ICING_ASSERT_OK_AND_ASSIGN( + auto embedding_index, + EmbeddingIndex::Create(&filesystem, embedding_index_dir)); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = @@ -315,8 +338,9 @@ void BM_QueryFiveTerms(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, QueryProcessor::Create(index.get(), numeric_index.get(), - language_segmenter.get(), normalizer.get(), - document_store.get(), schema_store.get(), &clock)); + 
embedding_index.get(), language_segmenter.get(), + normalizer.get(), document_store.get(), + schema_store.get(), &clock)); const std::string query_string = absl_ports::StrCat( input_string_a, " ", input_string_b, " ", input_string_c, " ", @@ -394,6 +418,7 @@ void BM_QueryDiacriticTerm(benchmark::State& state) { const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark"; const std::string index_dir = base_dir + "/index"; const std::string numeric_index_dir = base_dir + "/numeric_index"; + const std::string embedding_index_dir = base_dir + "/embedding_index"; const std::string schema_dir = base_dir + "/schema"; const std::string doc_store_dir = base_dir + "/store"; @@ -410,6 +435,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( auto numeric_index, DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir)); + ICING_ASSERT_OK_AND_ASSIGN( + auto embedding_index, + EmbeddingIndex::Create(&filesystem, embedding_index_dir)); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = @@ -451,8 +479,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, QueryProcessor::Create(index.get(), numeric_index.get(), - language_segmenter.get(), normalizer.get(), - document_store.get(), schema_store.get(), &clock)); + embedding_index.get(), language_segmenter.get(), + normalizer.get(), document_store.get(), + schema_store.get(), &clock)); SearchSpecProto search_spec; search_spec.set_query(input_string); @@ -526,6 +555,7 @@ void BM_QueryHiragana(benchmark::State& state) { const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark"; const std::string index_dir = base_dir + "/index"; const std::string numeric_index_dir = base_dir + "/numeric_index"; + const std::string embedding_index_dir = base_dir + "/embedding_index"; const std::string schema_dir = base_dir + 
"/schema"; const std::string doc_store_dir = base_dir + "/store"; @@ -542,6 +572,9 @@ void BM_QueryHiragana(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( auto numeric_index, DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir)); + ICING_ASSERT_OK_AND_ASSIGN( + auto embedding_index, + EmbeddingIndex::Create(&filesystem, embedding_index_dir)); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = @@ -583,8 +616,9 @@ void BM_QueryHiragana(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, QueryProcessor::Create(index.get(), numeric_index.get(), - language_segmenter.get(), normalizer.get(), - document_store.get(), schema_store.get(), &clock)); + embedding_index.get(), language_segmenter.get(), + normalizer.get(), document_store.get(), + schema_store.get(), &clock)); SearchSpecProto search_spec; search_spec.set_query(input_string); diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc index 43c9629..288c7fb 100644 --- a/icing/query/query-processor_test.cc +++ b/icing/query/query-processor_test.cc @@ -14,17 +14,24 @@ #include "icing/query/query-processor.h" +#include <array> #include <cstdint> #include <memory> #include <string> +#include <unordered_map> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/hit/doc-hit-info.h" +#include "icing/index/hit/hit.h" #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/index/iterator/doc-hit-info-iterator.h" @@ -54,6 +61,8 @@ #include 
"icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer-factory.h" #include "icing/transform/normalizer.h" +#include "icing/util/clock.h" +#include "icing/util/status-macros.h" #include "unicode/uloc.h" namespace icing { @@ -87,7 +96,8 @@ class QueryProcessorTest store_dir_(test_dir_ + "/store"), schema_store_dir_(test_dir_ + "/schema_store"), index_dir_(test_dir_ + "/index"), - numeric_index_dir_(test_dir_ + "/numeric_index") {} + numeric_index_dir_(test_dir_ + "/numeric_index"), + embedding_index_dir_(test_dir_ + "/embedding_index") {} void SetUp() override { filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); @@ -126,6 +136,9 @@ class QueryProcessorTest ICING_ASSERT_OK_AND_ASSIGN( numeric_index_, DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_)); + ICING_ASSERT_OK_AND_ASSIGN( + embedding_index_, + EmbeddingIndex::Create(&filesystem_, embedding_index_dir_)); language_segmenter_factory::SegmenterOptions segmenter_options( ULOC_US, jni_cache_.get()); @@ -138,10 +151,10 @@ class QueryProcessorTest ICING_ASSERT_OK_AND_ASSIGN( query_processor_, - QueryProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), normalizer_.get(), - document_store_.get(), schema_store_.get(), - &fake_clock_)); + QueryProcessor::Create( + index_.get(), numeric_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_)); } libtextclassifier3::Status AddTokenToIndex( @@ -177,10 +190,12 @@ class QueryProcessorTest IcingFilesystem icing_filesystem_; const std::string index_dir_; const std::string numeric_index_dir_; + const std::string embedding_index_dir_; protected: std::unique_ptr<Index> index_; std::unique_ptr<NumericIndex<int64_t>> numeric_index_; + std::unique_ptr<EmbeddingIndex> embedding_index_; std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<Normalizer> normalizer_; FakeClock fake_clock_; @@ -191,42 
+206,53 @@ class QueryProcessorTest }; TEST_P(QueryProcessorTest, CreationWithNullPointerShouldFail) { - EXPECT_THAT(QueryProcessor::Create(/*index=*/nullptr, numeric_index_.get(), - language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get(), &fake_clock_), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); - EXPECT_THAT(QueryProcessor::Create(index_.get(), /*numeric_index_=*/nullptr, - language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get(), &fake_clock_), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + QueryProcessor::Create(/*index=*/nullptr, numeric_index_.get(), + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + QueryProcessor::Create(index_.get(), /*numeric_index_=*/nullptr, + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(), - /*language_segmenter=*/nullptr, + /*embedding_index=*/nullptr, + language_segmenter_.get(), normalizer_.get(), document_store_.get(), schema_store_.get(), &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT(QueryProcessor::Create( - index_.get(), numeric_index_.get(), language_segmenter_.get(), - /*normalizer=*/nullptr, document_store_.get(), - schema_store_.get(), &fake_clock_), + index_.get(), numeric_index_.get(), embedding_index_.get(), + /*language_segmenter=*/nullptr, normalizer_.get(), + document_store_.get(), schema_store_.get(), &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT( QueryProcessor::Create(index_.get(), numeric_index_.get(), - 
language_segmenter_.get(), normalizer_.get(), - /*document_store=*/nullptr, schema_store_.get(), - &fake_clock_), + embedding_index_.get(), language_segmenter_.get(), + /*normalizer=*/nullptr, document_store_.get(), + schema_store_.get(), &fake_clock_), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + QueryProcessor::Create( + index_.get(), numeric_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), + /*document_store=*/nullptr, schema_store_.get(), &fake_clock_), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + QueryProcessor::Create(index_.get(), numeric_index_.get(), + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + /*schema_store=*/nullptr, &fake_clock_), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + QueryProcessor::Create(index_.get(), numeric_index_.get(), + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), /*clock=*/nullptr), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); - EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - /*schema_store=*/nullptr, &fake_clock_), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); - EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get(), /*clock=*/nullptr), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_P(QueryProcessorTest, EmptyGroupMatchAllDocuments) { @@ -2956,9 +2982,9 @@ TEST_P(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> local_query_processor, QueryProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), 
normalizer_.get(), - document_store_.get(), schema_store_.get(), - &fake_clock_)); + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_)); SearchSpecProto search_spec; search_spec.set_query("hello"); @@ -3019,9 +3045,9 @@ TEST_P(QueryProcessorTest, DocumentPastTtlFilteredOut) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> local_query_processor, QueryProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), normalizer_.get(), - document_store_.get(), schema_store_.get(), - &fake_clock_)); + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_)); SearchSpecProto search_spec; search_spec.set_query("hello"); diff --git a/icing/query/query-results.h b/icing/query/query-results.h index 52cdd71..983ab35 100644 --- a/icing/query/query-results.h +++ b/icing/query/query-results.h @@ -18,9 +18,10 @@ #include <memory> #include <unordered_set> +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/iterator/doc-hit-info-iterator.h" -#include "icing/query/query-terms.h" #include "icing/query/query-features.h" +#include "icing/query/query-terms.h" namespace icing { namespace lib { @@ -35,6 +36,10 @@ struct QueryResults { // beginning with root_iterator. // This will only be populated when ranking_strategy == RELEVANCE_SCORE. QueryTermIteratorsMap query_term_iterators; + // Contains similarity scores from embedding based queries, which will be used + // in the advanced scoring language to determine the results for the + // "this.matchedSemanticScores(...)" function. + EmbeddingQueryResults embedding_query_results; // Features that are invoked during query execution. // The list of possible features is defined in query_features.h. 
std::unordered_set<Feature> features_in_use; diff --git a/icing/query/suggestion-processor.cc b/icing/query/suggestion-processor.cc index 2d73c6b..dfebb98 100644 --- a/icing/query/suggestion-processor.cc +++ b/icing/query/suggestion-processor.cc @@ -26,11 +26,14 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/index/numeric/numeric-index.h" #include "icing/index/term-metadata.h" #include "icing/proto/search.pb.h" #include "icing/query/query-processor.h" +#include "icing/query/query-results.h" #include "icing/schema/schema-store.h" #include "icing/schema/section.h" #include "icing/store/document-filter-data.h" @@ -49,6 +52,7 @@ namespace lib { libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>> SuggestionProcessor::Create(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, @@ -56,15 +60,16 @@ SuggestionProcessor::Create(Index* index, const Clock* clock) { ICING_RETURN_ERROR_IF_NULL(index); ICING_RETURN_ERROR_IF_NULL(numeric_index); + ICING_RETURN_ERROR_IF_NULL(embedding_index); ICING_RETURN_ERROR_IF_NULL(language_segmenter); ICING_RETURN_ERROR_IF_NULL(normalizer); ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); ICING_RETURN_ERROR_IF_NULL(clock); - return std::unique_ptr<SuggestionProcessor>( - new SuggestionProcessor(index, numeric_index, language_segmenter, - normalizer, document_store, schema_store, clock)); + return std::unique_ptr<SuggestionProcessor>(new SuggestionProcessor( + index, numeric_index, embedding_index, language_segmenter, normalizer, + document_store, schema_store, clock)); } 
libtextclassifier3::StatusOr< @@ -246,9 +251,9 @@ SuggestionProcessor::QuerySuggestions( ICING_ASSIGN_OR_RETURN( std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(&index_, &numeric_index_, &language_segmenter_, - &normalizer_, &document_store_, &schema_store_, - &clock_)); + QueryProcessor::Create(&index_, &numeric_index_, &embedding_index_, + &language_segmenter_, &normalizer_, + &document_store_, &schema_store_, &clock_)); SearchSpecProto search_spec; search_spec.set_query(suggestion_spec.prefix()); @@ -321,11 +326,13 @@ SuggestionProcessor::QuerySuggestions( SuggestionProcessor::SuggestionProcessor( Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, const SchemaStore* schema_store, const Clock* clock) : index_(*index), numeric_index_(*numeric_index), + embedding_index_(*embedding_index), language_segmenter_(*language_segmenter), normalizer_(*normalizer), document_store_(*document_store), diff --git a/icing/query/suggestion-processor.h b/icing/query/suggestion-processor.h index fe1437c..cf393b4 100644 --- a/icing/query/suggestion-processor.h +++ b/icing/query/suggestion-processor.h @@ -20,6 +20,7 @@ #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" #include "icing/index/numeric/numeric-index.h" #include "icing/proto/search.pb.h" @@ -46,6 +47,7 @@ class SuggestionProcessor { // FAILED_PRECONDITION if any of the pointers is null. 
static libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>> Create(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, const SchemaStore* schema_store, const Clock* clock); @@ -62,6 +64,7 @@ class SuggestionProcessor { private: explicit SuggestionProcessor(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, @@ -72,6 +75,7 @@ class SuggestionProcessor { // index. Index& index_; // Does not own. const NumericIndex<int64_t>& numeric_index_; // Does not own. + const EmbeddingIndex& embedding_index_; // Does not own. const LanguageSegmenter& language_segmenter_; // Does not own. const Normalizer& normalizer_; // Does not own. const DocumentStore& document_store_; // Does not own. 
diff --git a/icing/query/suggestion-processor_test.cc b/icing/query/suggestion-processor_test.cc index 4c5e4ac..c9bdd8a 100644 --- a/icing/query/suggestion-processor_test.cc +++ b/icing/query/suggestion-processor_test.cc @@ -14,14 +14,30 @@ #include "icing/query/suggestion-processor.h" +#include <cstdint> +#include <memory> #include <string> +#include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" +#include "gtest/gtest.h" #include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/index.h" #include "icing/index/numeric/dummy-numeric-index.h" +#include "icing/index/numeric/numeric-index.h" #include "icing/index/term-metadata.h" +#include "icing/jni/jni-cache.h" +#include "icing/legacy/index/icing-filesystem.h" +#include "icing/portable/platform.h" #include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" @@ -30,7 +46,9 @@ #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" #include "icing/tokenization/language-segmenter-factory.h" +#include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" #include "unicode/uloc.h" namespace icing { @@ -59,7 +77,8 @@ class SuggestionProcessorTest : public Test { store_dir_(test_dir_ + "/store"), schema_store_dir_(test_dir_ + "/schema_store"), index_dir_(test_dir_ + "/index"), - numeric_index_dir_(test_dir_ + "/numeric_index") {} + numeric_index_dir_(test_dir_ + "/numeric_index"), + embedding_index_dir_(test_dir_ + "/embedding_index") {} void SetUp() override { filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); @@ 
-105,6 +124,9 @@ class SuggestionProcessorTest : public Test { ICING_ASSERT_OK_AND_ASSIGN( numeric_index_, DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_)); + ICING_ASSERT_OK_AND_ASSIGN( + embedding_index_, + EmbeddingIndex::Create(&filesystem_, embedding_index_dir_)); language_segmenter_factory::SegmenterOptions segmenter_options( ULOC_US, jni_cache_.get()); @@ -117,10 +139,10 @@ class SuggestionProcessorTest : public Test { ICING_ASSERT_OK_AND_ASSIGN( suggestion_processor_, - SuggestionProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get(), &fake_clock_)); + SuggestionProcessor::Create( + index_.get(), numeric_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_)); } libtextclassifier3::Status AddTokenToIndex( @@ -147,10 +169,12 @@ class SuggestionProcessorTest : public Test { IcingFilesystem icing_filesystem_; const std::string index_dir_; const std::string numeric_index_dir_; + const std::string embedding_index_dir_; protected: std::unique_ptr<Index> index_; std::unique_ptr<NumericIndex<int64_t>> numeric_index_; + std::unique_ptr<EmbeddingIndex> embedding_index_; std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<Normalizer> normalizer_; FakeClock fake_clock_; @@ -674,6 +698,16 @@ TEST_F(SuggestionProcessorTest, OtherSpecialPrefixTest) { EXPECT_THAT(terms_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } + + if (SearchSpecProto::default_instance().search_type() == + SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + suggestion_spec.set_prefix( + "bar OR semanticSearch(getSearchSpecEmbedding(0), 0.5, 1)"); + terms_or = suggestion_processor_->QuerySuggestions( + suggestion_spec, fake_clock_.GetSystemTimeMilliseconds()); + EXPECT_THAT(terms_or, + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + 
} } TEST_F(SuggestionProcessorTest, InvalidPrefixTest) { diff --git a/icing/schema-builder.h b/icing/schema-builder.h index c74505e..c702afe 100644 --- a/icing/schema-builder.h +++ b/icing/schema-builder.h @@ -56,6 +56,13 @@ constexpr IntegerIndexingConfig::NumericMatchType::Code NUMERIC_MATCH_UNKNOWN = constexpr IntegerIndexingConfig::NumericMatchType::Code NUMERIC_MATCH_RANGE = IntegerIndexingConfig::NumericMatchType::RANGE; +constexpr EmbeddingIndexingConfig::EmbeddingIndexingType::Code + EMBEDDING_INDEXING_UNKNOWN = + EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN; +constexpr EmbeddingIndexingConfig::EmbeddingIndexingType::Code + EMBEDDING_INDEXING_LINEAR_SEARCH = + EmbeddingIndexingConfig::EmbeddingIndexingType::LINEAR_SEARCH; + constexpr PropertyConfigProto::DataType::Code TYPE_UNKNOWN = PropertyConfigProto::DataType::UNKNOWN; constexpr PropertyConfigProto::DataType::Code TYPE_STRING = @@ -70,6 +77,8 @@ constexpr PropertyConfigProto::DataType::Code TYPE_BYTES = PropertyConfigProto::DataType::BYTES; constexpr PropertyConfigProto::DataType::Code TYPE_DOCUMENT = PropertyConfigProto::DataType::DOCUMENT; +constexpr PropertyConfigProto::DataType::Code TYPE_VECTOR = + PropertyConfigProto::DataType::VECTOR; constexpr JoinableConfig::ValueType::Code JOINABLE_VALUE_TYPE_NONE = JoinableConfig::ValueType::NONE; @@ -146,6 +155,15 @@ class PropertyConfigBuilder { return *this; } + PropertyConfigBuilder& SetDataTypeVector( + EmbeddingIndexingConfig::EmbeddingIndexingType::Code + embedding_indexing_type) { + property_.set_data_type(PropertyConfigProto::DataType::VECTOR); + property_.mutable_embedding_indexing_config()->set_embedding_indexing_type( + embedding_indexing_type); + return *this; + } + PropertyConfigBuilder& SetJoinable( JoinableConfig::ValueType::Code join_value_type, bool propagate_delete) { property_.mutable_joinable_config()->set_value_type(join_value_type); @@ -159,6 +177,11 @@ class PropertyConfigBuilder { return *this; } + PropertyConfigBuilder& 
SetDescription(std::string description) { + property_.set_description(std::move(description)); + return *this; + } + PropertyConfigProto Build() const { return std::move(property_); } private: @@ -186,6 +209,11 @@ class SchemaTypeConfigBuilder { return *this; } + SchemaTypeConfigBuilder& SetDescription(std::string description) { + type_config_.set_description(std::move(description)); + return *this; + } + SchemaTypeConfigBuilder& AddProperty(PropertyConfigProto property) { *type_config_.add_properties() = std::move(property); return *this; diff --git a/icing/schema/property-util.cc b/icing/schema/property-util.cc index 67ff748..7c86122 100644 --- a/icing/schema/property-util.cc +++ b/icing/schema/property-util.cc @@ -14,6 +14,8 @@ #include "icing/schema/property-util.h" +#include <cstddef> +#include <cstdint> #include <string> #include <string_view> #include <vector> @@ -131,6 +133,14 @@ ExtractPropertyValues<int64_t>(const PropertyProto& property) { property.int64_values().end()); } +template <> +libtextclassifier3::StatusOr<std::vector<PropertyProto::VectorProto>> +ExtractPropertyValues<PropertyProto::VectorProto>( + const PropertyProto& property) { + return std::vector<PropertyProto::VectorProto>( + property.vector_values().begin(), property.vector_values().end()); +} + } // namespace property_util } // namespace lib diff --git a/icing/schema/property-util.h b/icing/schema/property-util.h index 7557879..c409a9d 100644 --- a/icing/schema/property-util.h +++ b/icing/schema/property-util.h @@ -15,8 +15,11 @@ #ifndef ICING_SCHEMA_PROPERTY_UTIL_H_ #define ICING_SCHEMA_PROPERTY_UTIL_H_ +#include <cstddef> +#include <cstdint> #include <string> #include <string_view> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" @@ -163,6 +166,11 @@ template <> libtextclassifier3::StatusOr<std::vector<int64_t>> ExtractPropertyValues<int64_t>(const PropertyProto& property); +template <> 
+libtextclassifier3::StatusOr<std::vector<PropertyProto::VectorProto>> +ExtractPropertyValues<PropertyProto::VectorProto>( + const PropertyProto& property); + template <typename T> libtextclassifier3::StatusOr<std::vector<T>> ExtractPropertyValuesFromDocument( const DocumentProto& document, std::string_view property_path) { diff --git a/icing/schema/property-util_test.cc b/icing/schema/property-util_test.cc index eddcc84..5e8a430 100644 --- a/icing/schema/property-util_test.cc +++ b/icing/schema/property-util_test.cc @@ -22,6 +22,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" +#include "icing/portable/equals-proto.h" #include "icing/proto/document.pb.h" #include "icing/testing/common-matchers.h" @@ -30,6 +31,7 @@ namespace lib { namespace { +using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::ElementsAre; using ::testing::IsEmpty; @@ -38,6 +40,8 @@ static constexpr std::string_view kPropertySingleString = "singleString"; static constexpr std::string_view kPropertyRepeatedString = "repeatedString"; static constexpr std::string_view kPropertySingleInteger = "singleInteger"; static constexpr std::string_view kPropertyRepeatedInteger = "repeatedInteger"; +static constexpr std::string_view kPropertySingleVector = "singleVector"; +static constexpr std::string_view kPropertyRepeatedVector = "repeatedVector"; static constexpr std::string_view kTypeNestedTest = "NestedTest"; static constexpr std::string_view kPropertyStr = "str"; @@ -83,6 +87,28 @@ TEST(PropertyUtilTest, ExtractPropertyValuesTypeInteger) { IsOkAndHolds(ElementsAre(123, -456, 0))); } +TEST(PropertyUtilTest, ExtractPropertyValuesTypeVector) { + PropertyProto::VectorProto vector1; + vector1.set_model_signature("my_model1"); + vector1.add_values(1.0f); + vector1.add_values(2.0f); + + PropertyProto::VectorProto vector2; + vector2.set_model_signature("my_model2"); + vector2.add_values(-1.0f); + vector2.add_values(-2.0f); + vector2.add_values(-3.0f); 
+ + PropertyProto property; + *property.mutable_vector_values()->Add() = vector1; + *property.mutable_vector_values()->Add() = vector2; + + EXPECT_THAT( + property_util::ExtractPropertyValues<PropertyProto::VectorProto>( + property), + IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2)))); +} + TEST(PropertyUtilTest, ExtractPropertyValuesMismatchedType) { PropertyProto property; property.mutable_int64_values()->Add(123); @@ -110,6 +136,16 @@ TEST(PropertyUtilTest, ExtractPropertyValuesTypeUnimplemented) { } TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) { + PropertyProto::VectorProto vector1; + vector1.set_model_signature("my_model1"); + vector1.add_values(1.0f); + vector1.add_values(2.0f); + PropertyProto::VectorProto vector2; + vector2.set_model_signature("my_model2"); + vector2.add_values(-1.0f); + vector2.add_values(-2.0f); + vector2.add_values(-3.0f); + DocumentProto document = DocumentBuilder() .SetKey("icing", "test/1") @@ -119,6 +155,9 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) { "repeated2", "repeated3") .AddInt64Property(std::string(kPropertySingleInteger), 123) .AddInt64Property(std::string(kPropertyRepeatedInteger), 1, 2, 3) + .AddVectorProperty(std::string(kPropertySingleVector), vector1) + .AddVectorProperty(std::string(kPropertyRepeatedVector), vector1, + vector2) .Build(); // Single string @@ -139,9 +178,33 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) { EXPECT_THAT(property_util::ExtractPropertyValuesFromDocument<int64_t>( document, /*property_path=*/kPropertyRepeatedInteger), IsOkAndHolds(ElementsAre(1, 2, 3))); + // Single vector + EXPECT_THAT(property_util::ExtractPropertyValuesFromDocument< + PropertyProto::VectorProto>( + document, /*property_path=*/kPropertySingleVector), + IsOkAndHolds(ElementsAre(EqualsProto(vector1)))); + // Repeated vector + EXPECT_THAT( + property_util::ExtractPropertyValuesFromDocument< + PropertyProto::VectorProto>( + document, 
/*property_path=*/kPropertyRepeatedVector), + IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2)))); } TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) { + PropertyProto::VectorProto vector1; + vector1.set_model_signature("my_model1"); + vector1.add_values(1.0f); + vector1.add_values(2.0f); + PropertyProto::VectorProto vector2; + vector2.set_model_signature("my_model2"); + vector2.add_values(-1.0f); + vector2.add_values(-2.0f); + vector2.add_values(-3.0f); + PropertyProto::VectorProto vector3; + vector3.set_model_signature("my_model3"); + vector3.add_values(1.0f); + DocumentProto nested_document = DocumentBuilder() .SetKey("icing", "nested/1") @@ -158,6 +221,10 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) { .AddInt64Property(std::string(kPropertySingleInteger), 123) .AddInt64Property(std::string(kPropertyRepeatedInteger), 1, 2, 3) + .AddVectorProperty(std::string(kPropertySingleVector), + vector1) + .AddVectorProperty(std::string(kPropertyRepeatedVector), + vector1, vector2) .Build(), DocumentBuilder() .SetSchema(std::string(kTypeTest)) @@ -168,6 +235,10 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) { .AddInt64Property(std::string(kPropertySingleInteger), 456) .AddInt64Property(std::string(kPropertyRepeatedInteger), 4, 5, 6) + .AddVectorProperty(std::string(kPropertySingleVector), + vector2) + .AddVectorProperty(std::string(kPropertyRepeatedVector), + vector2, vector3) .Build()) .Build(); @@ -189,6 +260,17 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) { property_util::ExtractPropertyValuesFromDocument<int64_t>( nested_document, /*property_path=*/"nestedDocument.repeatedInteger"), IsOkAndHolds(ElementsAre(1, 2, 3, 4, 5, 6))); + EXPECT_THAT( + property_util::ExtractPropertyValuesFromDocument< + PropertyProto::VectorProto>( + nested_document, /*property_path=*/"nestedDocument.singleVector"), + IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2)))); + 
EXPECT_THAT( + property_util::ExtractPropertyValuesFromDocument< + PropertyProto::VectorProto>( + nested_document, /*property_path=*/"nestedDocument.repeatedVector"), + IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2), + EqualsProto(vector2), EqualsProto(vector3)))); // Test the property at first level EXPECT_THAT( diff --git a/icing/schema/schema-store_test.cc b/icing/schema/schema-store_test.cc index 8cc7008..ca5cdd3 100644 --- a/icing/schema/schema-store_test.cc +++ b/icing/schema/schema-store_test.cc @@ -14,8 +14,10 @@ #include "icing/schema/schema-store.h" +#include <cstdint> #include <memory> #include <string> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" @@ -23,6 +25,7 @@ #include "gtest/gtest.h" #include "icing/absl_ports/str_cat.h" #include "icing/document-builder.h" +#include "icing/file/file-backed-proto.h" #include "icing/file/filesystem.h" #include "icing/file/mock-filesystem.h" #include "icing/file/version-util.h" @@ -32,10 +35,7 @@ #include "icing/proto/logging.pb.h" #include "icing/proto/schema.pb.h" #include "icing/proto/storage.pb.h" -#include "icing/proto/term.pb.h" #include "icing/schema-builder.h" -#include "icing/schema/schema-util.h" -#include "icing/schema/section-manager.h" #include "icing/schema/section.h" #include "icing/store/document-filter-data.h" #include "icing/testing/common-matchers.h" @@ -142,7 +142,7 @@ TEST_F(SchemaStoreTest, SchemaStoreMoveConstructible) { IsOkAndHolds(Eq(expected_checksum))); SectionMetadata expected_metadata(/*id_in=*/0, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT, NUMERIC_MATCH_UNKNOWN, - "prop1"); + EMBEDDING_INDEXING_UNKNOWN, "prop1"); EXPECT_THAT(move_constructed_schema_store.GetSectionMetadata("type_a"), IsOkAndHolds(Pointee(ElementsAre(expected_metadata)))); } @@ -193,7 +193,7 @@ TEST_F(SchemaStoreTest, SchemaStoreMoveAssignment) { IsOkAndHolds(Eq(expected_checksum))); SectionMetadata expected_metadata(/*id_in=*/0, TYPE_STRING, 
TOKENIZER_PLAIN, TERM_MATCH_EXACT, NUMERIC_MATCH_UNKNOWN, - "prop1"); + EMBEDDING_INDEXING_UNKNOWN, "prop1"); EXPECT_THAT(move_assigned_schema_store->GetSectionMetadata("type_a"), IsOkAndHolds(Pointee(ElementsAre(expected_metadata)))); } @@ -1527,10 +1527,10 @@ TEST_F(SchemaStoreTest, SetSchemaRegenerateDerivedFilesFailure) { .Build(); SectionMetadata expected_int_prop1_metadata( /*id_in=*/0, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, - NUMERIC_MATCH_RANGE, "intProp1"); + NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN, "intProp1"); SectionMetadata expected_string_prop1_metadata( /*id_in=*/1, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT, - NUMERIC_MATCH_UNKNOWN, "stringProp1"); + NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN, "stringProp1"); ICING_ASSERT_OK_AND_ASSIGN(SectionGroup section_group, schema_store->ExtractSections(document)); ASSERT_THAT(section_group.string_sections, SizeIs(1)); diff --git a/icing/schema/schema-util.cc b/icing/schema/schema-util.cc index 72287a8..976d1b7 100644 --- a/icing/schema/schema-util.cc +++ b/icing/schema/schema-util.cc @@ -115,6 +115,12 @@ bool IsIntegerNumericMatchTypeCompatible( return old_indexed.numeric_match_type() == new_indexed.numeric_match_type(); } +bool IsEmbeddingIndexingCompatible(const EmbeddingIndexingConfig& old_indexed, + const EmbeddingIndexingConfig& new_indexed) { + return old_indexed.embedding_indexing_type() == + new_indexed.embedding_indexing_type(); +} + bool IsDocumentIndexingCompatible(const DocumentIndexingConfig& old_indexed, const DocumentIndexingConfig& new_indexed) { // TODO(b/265304217): This could mark the new schema as incompatible and @@ -824,6 +830,10 @@ libtextclassifier3::Status SchemaUtil::ValidateDocumentIndexingConfig( !property_config.document_indexing_config() .indexable_nested_properties_list() .empty(); + case PropertyConfigProto::DataType::VECTOR: + return property_config.embedding_indexing_config() + .embedding_indexing_type() != + 
EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN; case PropertyConfigProto::DataType::UNKNOWN: case PropertyConfigProto::DataType::DOUBLE: case PropertyConfigProto::DataType::BOOLEAN: @@ -1087,7 +1097,10 @@ const SchemaUtil::SchemaDelta SchemaUtil::ComputeCompatibilityDelta( new_property_config->integer_indexing_config()) || !IsDocumentIndexingCompatible( old_property_config.document_indexing_config(), - new_property_config->document_indexing_config())) { + new_property_config->document_indexing_config()) || + !IsEmbeddingIndexingCompatible( + old_property_config.embedding_indexing_config(), + new_property_config->embedding_indexing_config())) { is_index_incompatible = true; } diff --git a/icing/schema/schema-util_test.cc b/icing/schema/schema-util_test.cc index 82683ba..e77ffab 100644 --- a/icing/schema/schema-util_test.cc +++ b/icing/schema/schema-util_test.cc @@ -2900,6 +2900,118 @@ TEST_P(SchemaUtilTest, IsEmpty()); } +TEST_P(SchemaUtilTest, ChangingIndexedVectorPropertiesMakesIndexIncompatible) { + SchemaProto schema_with_indexed_property = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty(PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + SchemaProto schema_with_unindexed_property = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty( + PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeVector(EMBEDDING_INDEXING_UNKNOWN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + SchemaUtil::SchemaDelta schema_delta; + schema_delta.schema_types_index_incompatible.insert(kPersonType); + + // New schema gained a new indexed vector property. 
+ SchemaUtil::DependentMap no_dependents_map; + EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta( + schema_with_unindexed_property, schema_with_indexed_property, + no_dependents_map), + Eq(schema_delta)); + + // New schema lost an indexed vector property. + EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta( + schema_with_indexed_property, schema_with_unindexed_property, + no_dependents_map), + Eq(schema_delta)); +} + +TEST_P(SchemaUtilTest, AddingNewIndexedVectorPropertyMakesIndexIncompatible) { + // Configure old schema + SchemaProto old_schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty(PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + // Configure new schema + SchemaProto new_schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty(PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("NewEmbeddingProperty") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + SchemaUtil::SchemaDelta schema_delta; + schema_delta.schema_types_index_incompatible.insert(kPersonType); + SchemaUtil::DependentMap no_dependents_map; + EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(old_schema, new_schema, + no_dependents_map), + Eq(schema_delta)); +} + +TEST_P(SchemaUtilTest, + AddingNewNonIndexedVectorPropertyShouldRemainIndexCompatible) { + // Configure old schema + SchemaProto old_schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty(PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + // Configure new schema + SchemaProto new_schema = + SchemaBuilder() + 
.AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty(PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("NewProperty") + .SetDataTypeVector(EMBEDDING_INDEXING_UNKNOWN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + SchemaUtil::DependentMap no_dependents_map; + EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(old_schema, new_schema, + no_dependents_map) + .schema_types_index_incompatible, + IsEmpty()); +} + TEST_P(SchemaUtilTest, AddingNewIndexedDocumentPropertyMakesIndexAndJoinIncompatible) { SchemaTypeConfigProto nested_schema = diff --git a/icing/schema/section-manager.cc b/icing/schema/section-manager.cc index 3d540d6..8689bf2 100644 --- a/icing/schema/section-manager.cc +++ b/icing/schema/section-manager.cc @@ -61,6 +61,7 @@ libtextclassifier3::Status AppendNewSectionMetadata( property_config.string_indexing_config().tokenizer_type(), property_config.string_indexing_config().term_match_type(), property_config.integer_indexing_config().numeric_match_type(), + property_config.embedding_indexing_config().embedding_indexing_type(), std::move(concatenated_path))); return libtextclassifier3::Status::OK; } @@ -162,6 +163,19 @@ libtextclassifier3::StatusOr<SectionGroup> SectionManager::ExtractSections( section_group.integer_sections); break; } + case PropertyConfigProto::DataType::VECTOR: { + if (section_metadata.embedding_indexing_type == + EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN) { + // Skip if embedding indexing type is UNKNOWN. + break; + } + AppendSection( + section_metadata, + property_util::ExtractPropertyValuesFromDocument< + PropertyProto::VectorProto>(document, section_metadata.path), + section_group.vector_sections); + break; + } default: { // Skip other data types. 
break; diff --git a/icing/schema/section-manager_test.cc b/icing/schema/section-manager_test.cc index eee78e9..b735fb1 100644 --- a/icing/schema/section-manager_test.cc +++ b/icing/schema/section-manager_test.cc @@ -14,10 +14,12 @@ #include "icing/schema/section-manager.h" +#include <cstdint> #include <memory> #include <string> #include <string_view> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" @@ -27,6 +29,8 @@ #include "icing/schema-builder.h" #include "icing/schema/schema-type-manager.h" #include "icing/schema/schema-util.h" +#include "icing/schema/section.h" +#include "icing/store/document-filter-data.h" #include "icing/store/dynamic-trie-key-mapper.h" #include "icing/store/key-mapper.h" #include "icing/testing/common-matchers.h" @@ -37,6 +41,7 @@ namespace lib { namespace { +using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::ElementsAre; using ::testing::IsEmpty; using ::testing::Pointee; @@ -48,6 +53,8 @@ static constexpr std::string_view kTypeEmail = "Email"; static constexpr std::string_view kPropertyRecipientIds = "recipientIds"; static constexpr std::string_view kPropertyRecipients = "recipients"; static constexpr std::string_view kPropertySubject = "subject"; +static constexpr std::string_view kPropertySubjectEmbedding = + "subjectEmbedding"; static constexpr std::string_view kPropertyTimestamp = "timestamp"; // non-indexable static constexpr std::string_view kPropertyAttachment = "attachment"; @@ -109,6 +116,14 @@ PropertyConfigProto CreateSubjectPropertyConfig() { .Build(); } +PropertyConfigProto CreateSubjectEmbeddingPropertyConfig() { + return PropertyConfigBuilder() + .SetName(kPropertySubjectEmbedding) + .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL) + .Build(); +} + PropertyConfigProto CreateTimestampPropertyConfig() { return PropertyConfigBuilder() .SetName(kPropertyTimestamp) @@ 
-145,6 +160,7 @@ SchemaTypeConfigProto CreateEmailTypeConfig() { return SchemaTypeConfigBuilder() .SetType(kTypeEmail) .AddProperty(CreateSubjectPropertyConfig()) + .AddProperty(CreateSubjectEmbeddingPropertyConfig()) .AddProperty(PropertyConfigBuilder() .SetName(kPropertyText) .SetDataTypeString(TERM_MATCH_UNKNOWN, TOKENIZER_NONE) @@ -220,11 +236,19 @@ class SectionManagerTest : public ::testing::Test { ICING_ASSERT_OK(schema_type_mapper_->Put(kTypeConversation, 1)); ICING_ASSERT_OK(schema_type_mapper_->Put(kTypeGroup, 2)); + email_subject_embedding_ = PropertyProto::VectorProto(); + email_subject_embedding_.add_values(1.0); + email_subject_embedding_.add_values(2.0); + email_subject_embedding_.add_values(3.0); + email_subject_embedding_.set_model_signature("my_model"); + email_document_ = DocumentBuilder() .SetKey("icing", "email/1") .SetSchema(std::string(kTypeEmail)) .AddStringProperty(std::string(kPropertySubject), "the subject") + .AddVectorProperty(std::string(kPropertySubjectEmbedding), + email_subject_embedding_) .AddStringProperty(std::string(kPropertyText), "the text") .AddBytesProperty(std::string(kPropertyAttachment), "attachment bytes") @@ -265,6 +289,7 @@ class SectionManagerTest : public ::testing::Test { SchemaUtil::TypeConfigMap type_config_map_; std::unique_ptr<KeyMapper<SchemaTypeId>> schema_type_mapper_; + PropertyProto::VectorProto email_subject_embedding_; DocumentProto email_document_; DocumentProto conversation_document_; DocumentProto group_document_; @@ -308,11 +333,21 @@ TEST_F(SectionManagerTest, ExtractSections) { EXPECT_THAT(section_group.integer_sections[0].content, ElementsAre(1, 2, 3)); EXPECT_THAT(section_group.integer_sections[1].metadata, - EqualsSectionMetadata(/*expected_id=*/3, + EqualsSectionMetadata(/*expected_id=*/4, /*expected_property_path=*/"timestamp", CreateTimestampPropertyConfig())); EXPECT_THAT(section_group.integer_sections[1].content, ElementsAre(kDefaultTimestamp)); + + // Vector sections + 
EXPECT_THAT(section_group.vector_sections, SizeIs(1)); + EXPECT_THAT( + section_group.vector_sections[0].metadata, + EqualsSectionMetadata(/*expected_id=*/3, + /*expected_property_path=*/"subjectEmbedding", + CreateSubjectEmbeddingPropertyConfig())); + EXPECT_THAT(section_group.vector_sections[0].content, + ElementsAre(EqualsProto(email_subject_embedding_))); } TEST_F(SectionManagerTest, ExtractSectionsNested) { @@ -359,11 +394,22 @@ TEST_F(SectionManagerTest, ExtractSectionsNested) { EXPECT_THAT( section_group.integer_sections[1].metadata, - EqualsSectionMetadata(/*expected_id=*/3, + EqualsSectionMetadata(/*expected_id=*/4, /*expected_property_path=*/"emails.timestamp", CreateTimestampPropertyConfig())); EXPECT_THAT(section_group.integer_sections[1].content, ElementsAre(kDefaultTimestamp, kDefaultTimestamp)); + + // Vector sections + EXPECT_THAT(section_group.vector_sections, SizeIs(1)); + EXPECT_THAT(section_group.vector_sections[0].metadata, + EqualsSectionMetadata( + /*expected_id=*/3, + /*expected_property_path=*/"emails.subjectEmbedding", + CreateSubjectEmbeddingPropertyConfig())); + EXPECT_THAT(section_group.vector_sections[0].content, + ElementsAre(EqualsProto(email_subject_embedding_), + EqualsProto(email_subject_embedding_))); } TEST_F(SectionManagerTest, ExtractSectionsIndexableNestedPropertiesList) { @@ -461,7 +507,8 @@ TEST_F(SectionManagerTest, GetSectionMetadata) { // 0 -> recipientIds // 1 -> recipients // 2 -> subject - // 3 -> timestamp + // 3 -> subjectEmbedding + // 4 -> timestamp EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( /*schema_type_id=*/0, /*section_id=*/0), IsOkAndHolds(Pointee(EqualsSectionMetadata( @@ -472,13 +519,30 @@ TEST_F(SectionManagerTest, GetSectionMetadata) { IsOkAndHolds(Pointee(EqualsSectionMetadata( /*expected_id=*/1, /*expected_property_path=*/"recipients", CreateRecipientsPropertyConfig())))); + EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( + /*schema_type_id=*/0, 
/*section_id=*/2), + IsOkAndHolds(Pointee(EqualsSectionMetadata( + /*expected_id=*/2, /*expected_property_path=*/"subject", + CreateSubjectPropertyConfig())))); + EXPECT_THAT( + schema_type_manager->section_manager().GetSectionMetadata( + /*schema_type_id=*/0, /*section_id=*/3), + IsOkAndHolds(Pointee(EqualsSectionMetadata( + /*expected_id=*/3, /*expected_property_path=*/"subjectEmbedding", + CreateSubjectEmbeddingPropertyConfig())))); + EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( + /*schema_type_id=*/0, /*section_id=*/4), + IsOkAndHolds(Pointee(EqualsSectionMetadata( + /*expected_id=*/4, /*expected_property_path=*/"timestamp", + CreateTimestampPropertyConfig())))); // Conversation (section id -> section property path): // 0 -> emails.recipientIds // 1 -> emails.recipients // 2 -> emails.subject - // 3 -> emails.timestamp - // 4 -> name + // 3 -> emails.subjectEmbedding + // 4 -> emails.timestamp + // 5 -> name EXPECT_THAT( schema_type_manager->section_manager().GetSectionMetadata( /*schema_type_id=*/1, /*section_id=*/0), @@ -497,16 +561,22 @@ TEST_F(SectionManagerTest, GetSectionMetadata) { IsOkAndHolds(Pointee(EqualsSectionMetadata( /*expected_id=*/2, /*expected_property_path=*/"emails.subject", CreateSubjectPropertyConfig())))); + EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( + /*schema_type_id=*/1, /*section_id=*/3), + IsOkAndHolds(Pointee(EqualsSectionMetadata( + /*expected_id=*/3, + /*expected_property_path=*/"emails.subjectEmbedding", + CreateSubjectEmbeddingPropertyConfig())))); EXPECT_THAT( schema_type_manager->section_manager().GetSectionMetadata( - /*schema_type_id=*/1, /*section_id=*/3), + /*schema_type_id=*/1, /*section_id=*/4), IsOkAndHolds(Pointee(EqualsSectionMetadata( - /*expected_id=*/3, /*expected_property_path=*/"emails.timestamp", + /*expected_id=*/4, /*expected_property_path=*/"emails.timestamp", CreateTimestampPropertyConfig())))); 
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( - /*schema_type_id=*/1, /*section_id=*/4), + /*schema_type_id=*/1, /*section_id=*/5), IsOkAndHolds(Pointee(EqualsSectionMetadata( - /*expected_id=*/4, /*expected_property_path=*/"name", + /*expected_id=*/5, /*expected_property_path=*/"name", CreateNamePropertyConfig())))); // Group (section id -> section property path): @@ -615,25 +685,27 @@ TEST_F(SectionManagerTest, GetSectionMetadataInvalidSectionId) { // 0 -> recipientIds // 1 -> recipients // 2 -> subject - // 3 -> timestamp + // 3 -> subjectEmbedding + // 4 -> timestamp EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( /*schema_type_id=*/0, /*section_id=*/-1), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( - /*schema_type_id=*/0, /*section_id=*/4), + /*schema_type_id=*/0, /*section_id=*/5), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); // Conversation (section id -> section property path): // 0 -> emails.recipientIds // 1 -> emails.recipients // 2 -> emails.subject - // 3 -> emails.timestamp - // 4 -> name + // 3 -> emails.subjectEmbedding + // 4 -> emails.timestamp + // 5 -> name EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( /*schema_type_id=*/1, /*section_id=*/-1), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( - /*schema_type_id=*/1, /*section_id=*/5), + /*schema_type_id=*/1, /*section_id=*/6), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } diff --git a/icing/schema/section.h b/icing/schema/section.h index 3685a29..76e9e57 100644 --- a/icing/schema/section.h +++ b/icing/schema/section.h @@ -21,6 +21,7 @@ #include <utility> #include <vector> +#include "icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" #include "icing/proto/term.pb.h" @@ -89,18 +90,31 @@ struct SectionMetadata { // 
Contents will be matched by a range query. IntegerIndexingConfig::NumericMatchType::Code numeric_match_type; + // How vectors in a vector section should be indexed. + // + // EmbeddingIndexingType::UNKNOWN: + // Contents will not be indexed. It is invalid for a vector section + // (data_type == 'VECTOR') to have embedding_indexing_type == 'UNKNOWN'. + // + // EmbeddingIndexingType::LINEAR_SEARCH: + // Contents will be indexed for linear search. + EmbeddingIndexingConfig::EmbeddingIndexingType::Code embedding_indexing_type; + explicit SectionMetadata( SectionId id_in, PropertyConfigProto::DataType::Code data_type_in, StringIndexingConfig::TokenizerType::Code tokenizer, TermMatchType::Code term_match_type_in, IntegerIndexingConfig::NumericMatchType::Code numeric_match_type_in, + EmbeddingIndexingConfig::EmbeddingIndexingType::Code + embedding_indexing_type_in, std::string&& path_in) : path(std::move(path_in)), id(id_in), data_type(data_type_in), tokenizer(tokenizer), term_match_type(term_match_type_in), - numeric_match_type(numeric_match_type_in) {} + numeric_match_type(numeric_match_type_in), + embedding_indexing_type(embedding_indexing_type_in) {} SectionMetadata(const SectionMetadata& other) = default; SectionMetadata& operator=(const SectionMetadata& other) = default; @@ -144,6 +158,7 @@ struct Section { struct SectionGroup { std::vector<Section<std::string_view>> string_sections; std::vector<Section<int64_t>> integer_sections; + std::vector<Section<PropertyProto::VectorProto>> vector_sections; }; } // namespace lib diff --git a/icing/scoring/advanced_scoring/advanced-scorer.cc b/icing/scoring/advanced_scoring/advanced-scorer.cc index 83c1519..e375a8e 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer.cc +++ b/icing/scoring/advanced_scoring/advanced-scorer.cc @@ -14,14 +14,25 @@ #include "icing/scoring/advanced_scoring/advanced-scorer.h" +#include <cstdint> #include <memory> +#include <utility> +#include <vector> +#include 
"icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/join/join-children-fetcher.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" #include "icing/query/advanced_query_parser/lexer.h" #include "icing/query/advanced_query_parser/parser.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/score-expression.h" #include "icing/scoring/advanced_scoring/scoring-visitor.h" #include "icing/scoring/bm25f-calculator.h" #include "icing/scoring/section-weights.h" +#include "icing/store/document-store.h" +#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -29,11 +40,15 @@ namespace lib { libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> AdvancedScorer::Create(const ScoringSpecProto& scoring_spec, double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher) { + const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results) { ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(embedding_query_results); Lexer lexer(scoring_spec.advanced_scoring_expression(), Lexer::Language::SCORING); @@ -48,9 +63,10 @@ AdvancedScorer::Create(const ScoringSpecProto& scoring_spec, std::unique_ptr<Bm25fCalculator> bm25f_calculator = std::make_unique<Bm25fCalculator>(document_store, section_weights.get(), current_time_ms); - ScoringVisitor visitor(default_score, document_store, schema_store, - section_weights.get(), bm25f_calculator.get(), - join_children_fetcher, current_time_ms); + ScoringVisitor visitor(default_score, default_semantic_metric_type, + document_store, schema_store, section_weights.get(), + 
bm25f_calculator.get(), join_children_fetcher, + embedding_query_results, current_time_ms); tree_root->Accept(&visitor); ICING_ASSIGN_OR_RETURN(std::unique_ptr<ScoreExpression> expression, diff --git a/icing/scoring/advanced_scoring/advanced-scorer.h b/icing/scoring/advanced_scoring/advanced-scorer.h index d69abad..00477a3 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer.h +++ b/icing/scoring/advanced_scoring/advanced-scorer.h @@ -15,17 +15,26 @@ #ifndef ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ #define ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ +#include <cstdint> #include <memory> +#include <string> +#include <unordered_map> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/join/join-children-fetcher.h" #include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/score-expression.h" #include "icing/scoring/bm25f-calculator.h" #include "icing/scoring/scorer.h" +#include "icing/scoring/section-weights.h" #include "icing/store/document-store.h" +#include "icing/util/logging.h" namespace icing { namespace lib { @@ -38,9 +47,11 @@ class AdvancedScorer : public Scorer { // INVALID_ARGUMENT if fails to create an instance static libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> Create( const ScoringSpecProto& scoring_spec, double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, - int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher = nullptr); + int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results); double GetScore(const 
DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override { @@ -63,7 +74,7 @@ class AdvancedScorer : public Scorer { bm25f_calculator_->PrepareToScore(query_term_iterators); } - bool is_constant() const { return score_expression_->is_constant_double(); } + bool is_constant() const { return score_expression_->is_constant(); } private: explicit AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression, diff --git a/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc index 3612359..ca081cd 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc +++ b/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc @@ -12,11 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <cstddef> #include <cstdint> #include <memory> +#include <string> #include <string_view> +#include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/advanced-scorer.h" +#include "icing/store/document-store.h" #include "icing/testing/fake-clock.h" #include "icing/testing/tmp-directory.h" @@ -29,6 +36,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { const std::string test_dir = GetTestTempDir() + "/icing"; const std::string doc_store_dir = test_dir + "/doc_store"; const std::string schema_store_dir = test_dir + "/schema_store"; + EmbeddingQueryResults empty_embedding_query_results_; filesystem.DeleteDirectoryRecursively(test_dir.c_str()); filesystem.CreateDirectoryRecursively(doc_store_dir.c_str()); filesystem.CreateDirectoryRecursively(schema_store_dir.c_str()); @@ -54,9 +62,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { scoring_spec.set_advanced_scoring_expression(text); 
AdvancedScorer::Create(scoring_spec, - /*default_score=*/10, document_store.get(), - schema_store.get(), - fake_clock.GetSystemTimeMilliseconds()); + /*default_score=*/10, + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store.get(), schema_store.get(), + fake_clock.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_); // Not able to test the GetScore method of AdvancedScorer, since it will only // be available after AdvancedScorer is successfully created. However, the diff --git a/icing/scoring/advanced_scoring/advanced-scorer_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_test.cc index cc1d413..4b9a46c 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer_test.cc +++ b/icing/scoring/advanced_scoring/advanced-scorer_test.cc @@ -15,14 +15,22 @@ #include "icing/scoring/advanced_scoring/advanced-scorer.h" #include <cmath> +#include <cstdint> #include <memory> #include <string> #include <string_view> +#include <unordered_map> +#include <utility> +#include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/join/join-children-fetcher.h" #include "icing/proto/document.pb.h" @@ -31,6 +39,8 @@ #include "icing/proto/usage.pb.h" #include "icing/schema-builder.h" #include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/scoring/scored-document-hit.h" #include "icing/scoring/scorer-factory.h" #include "icing/scoring/scorer.h" #include "icing/store/document-id.h" @@ -45,6 +55,7 @@ namespace lib { namespace { using ::testing::DoubleNear; using ::testing::Eq; +using ::testing::HasSubstr; class 
AdvancedScorerTest : public testing::Test { protected: @@ -124,15 +135,19 @@ class AdvancedScorerTest : public testing::Test { const std::string test_dir_; const std::string doc_store_dir_; const std::string schema_store_dir_; + EmbeddingQueryResults empty_embedding_query_results_; Filesystem filesystem_; std::unique_ptr<SchemaStore> schema_store_; std::unique_ptr<DocumentStore> document_store_; FakeClock fake_clock_; }; -constexpr double kEps = 0.0000000001; +constexpr double kEps = 0.0000001; constexpr int kDefaultScore = 0; constexpr int64_t kDefaultCreationTimestampMs = 1571100001111; +constexpr SearchSpecProto::EmbeddingQueryMetricType::Code + kDefaultSemanticMetricType = + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; DocumentProto CreateDocument( const std::string& name_space, const std::string& uri, @@ -193,8 +208,11 @@ TEST_F(AdvancedScorerTest, InvalidAdvancedScoringSpec) { scoring_spec.set_rank_by( ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION); EXPECT_THAT(scorer_factory::Create(scoring_spec, /*default_score=*/10, + kDefaultSemanticMetricType, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); // Non-empty scoring expression for normal scoring @@ -202,8 +220,11 @@ TEST_F(AdvancedScorerTest, InvalidAdvancedScoringSpec) { scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); scoring_spec.set_advanced_scoring_expression("1"); EXPECT_THAT(scorer_factory::Create(scoring_spec, /*default_score=*/10, + kDefaultSemanticMetricType, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } @@ -215,9 
+236,11 @@ TEST_F(AdvancedScorerTest, SimpleExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("123"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -233,44 +256,61 @@ TEST_F(AdvancedScorerTest, BasicPureArithmeticExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + 2"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(3)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("-1 + 2"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("-1 + 2"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + -2"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + -2"), + /*default_score=*/10, 
kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 - 2"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 - 2"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 * 2"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 * 2"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(2)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 2"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 2"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0.5)); } @@ -283,103 +323,131 @@ TEST_F(AdvancedScorerTest, BasicMathFunctionExpression) { 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("log(10, 1000)"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(3, kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("log(2.718281828459045)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(2, 10)"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(2, 10)"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1024)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("max(10, 11, 12, 13, 14)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, 
&empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(14)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("min(10, 11, 12, 13, 14)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("len(10, 11, 12, 13, 14)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("sum(10, 11, 12, 13, 14)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10 + 11 + 12 + 13 + 14)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("avg(10, 11, 12, 13, 14)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, 
&empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq((10 + 11 + 12 + 13 + 14) / 5.)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(2)"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(2)"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(sqrt(2), kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(-2) + abs(2)"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(4)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("sin(3.141592653589793)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(0, kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("cos(3.141592653589793)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + 
fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(-1, kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("tan(3.141592653589793 / 4)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps)); } @@ -394,17 +462,21 @@ TEST_F(AdvancedScorerTest, DocumentScoreCreationTimestampFunctionExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("this.documentScore()"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(123)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("this.creationTimestamp()"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(kDefaultCreationTimestampMs)); ICING_ASSERT_OK_AND_ASSIGN( @@ -412,8 +484,10 @@ TEST_F(AdvancedScorerTest, DocumentScoreCreationTimestampFunctionExpression) { AdvancedScorer::Create( CreateAdvancedScoringSpec( 
"this.documentScore() + this.creationTimestamp()"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(123 + kDefaultCreationTimestampMs)); } @@ -429,8 +503,10 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionExpression) { AdvancedScorer::Create( CreateAdvancedScoringSpec("this.usageCount(1) + this.usageCount(2) " "+ this.usageLastUsedTimestamp(3)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0)); ICING_ASSERT_OK(document_store_->ReportUsage( CreateUsageReport("namespace", "uri", 100000, UsageReport::USAGE_TYPE1))); @@ -446,22 +522,28 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionExpression) { scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(1)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(100000)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(2)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, 
kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(200000)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(3)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(300000)); } @@ -478,24 +560,29 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionOutOfRange) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - AdvancedScorer::Create(CreateAdvancedScoringSpec("this.usageCount(4)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.usageCount(4)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.usageCount(0)"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.usageCount(0)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score)); 
ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.usageCount(1.5)"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.usageCount(1.5)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score)); } @@ -516,9 +603,11 @@ TEST_F(AdvancedScorerTest, RelevanceScoreFunctionScoreExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<AdvancedScorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("this.relevanceScore()"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); scorer->PrepareToScore(/*query_term_iterators=*/{}); // Should get the default score. @@ -566,8 +655,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) { std::unique_ptr<AdvancedScorer> scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("len(this.childrenRankingSignals())"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), &fetcher)); + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + &fetcher, &empty_embedding_query_results_)); // document_id_1 has two children. EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(2)); // document_id_2 has one child. 
@@ -579,8 +669,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) { scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("sum(this.childrenRankingSignals())"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), &fetcher)); + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + &fetcher, &empty_embedding_query_results_)); // document_id_1 has two children with scores 1 and 2. EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3)); // document_id_2 has one child with score 4. @@ -592,8 +683,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) { scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("avg(this.childrenRankingSignals())"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), &fetcher)); + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + &fetcher, &empty_embedding_query_results_)); // document_id_1 has two children with scores 1 and 2. EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3 / 2.)); // document_id_2 has one child with score 4. 
@@ -604,13 +696,15 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) { Eq(default_score)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create( - CreateAdvancedScoringSpec( - // Equivalent to "avg(this.childrenRankingSignals())" - "sum(this.childrenRankingSignals()) / " - "len(this.childrenRankingSignals())"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), &fetcher)); + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec( + // Equivalent to "avg(this.childrenRankingSignals())" + "sum(this.childrenRankingSignals()) / " + "len(this.childrenRankingSignals())"), + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + &fetcher, &empty_embedding_query_results_)); // document_id_1 has two children with scores 1 and 2. EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3 / 2.)); // document_id_2 has one child with score 4. 
@@ -670,9 +764,11 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<AdvancedScorer> scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // min([1]) = 1 EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // min([0.5, 0.8]) = 0.5 @@ -682,10 +778,13 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) { spec_proto.set_advanced_scoring_expression("max(this.propertyWeights())"); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(spec_proto, + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // max([1]) = 1 EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // max([0.5, 0.8]) = 0.8 @@ -695,10 +794,13 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) { spec_proto.set_advanced_scoring_expression("sum(this.propertyWeights())"); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(spec_proto, + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // 
sum([1]) = 1 EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // sum([0.5, 0.8]) = 1.3 @@ -747,9 +849,11 @@ TEST_F(AdvancedScorerTest, ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<AdvancedScorer> scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // min([1]) = 1 EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // min([0.5, 1, 0.5]) = 0.5 @@ -757,10 +861,13 @@ TEST_F(AdvancedScorerTest, spec_proto.set_advanced_scoring_expression("max(this.propertyWeights())"); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(spec_proto, + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // max([1]) = 1 EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // max([0.5, 1, 0.5]) = 1 @@ -768,10 +875,13 @@ TEST_F(AdvancedScorerTest, spec_proto.set_advanced_scoring_expression("sum(this.propertyWeights())"); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(spec_proto, + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // sum([1]) = 1 
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // sum([0.5, 1, 0.5]) = 2 @@ -786,20 +896,22 @@ TEST_F(AdvancedScorerTest, InvalidChildrenScoresFunctionScoreExpression) { EXPECT_THAT( AdvancedScorer::Create( CreateAdvancedScoringSpec("len(this.childrenRankingSignals())"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), - /*join_children_fetcher=*/nullptr), + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); // The root expression can only be of double type, but here it is of list // type. JoinChildrenFetcher fake_fetcher(JoinSpecProto::default_instance(), /*map_joinable_qualified_id=*/{}); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.childrenRankingSignals()"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), &fake_fetcher), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.childrenRankingSignals()"), + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + &fake_fetcher, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } TEST_F(AdvancedScorerTest, ComplexExpression) { @@ -821,9 +933,11 @@ TEST_F(AdvancedScorerTest, ComplexExpression) { "+ 10 * (2 + 10 + this.creationTimestamp()))" // This should evaluate to default score. 
"+ this.relevanceScore()"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_FALSE(scorer->is_constant()); scorer->PrepareToScore(/*query_term_iterators=*/{}); @@ -847,19 +961,24 @@ TEST_F(AdvancedScorerTest, ConstantExpression) { "pow(sin(2), 2)" "+ log(2, 122) / 12.34" "* (10 * pow(2 * 1, sin(2)) + 10 * (2 + 10))"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_TRUE(scorer->is_constant()); } // Should be a parsing Error TEST_F(AdvancedScorerTest, EmptyExpression) { - EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec(""), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create(CreateAdvancedScoringSpec(""), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) { @@ -872,30 +991,38 @@ TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - AdvancedScorer::Create(CreateAdvancedScoringSpec("log(0)"), default_score, - document_store_.get(), schema_store_.get(), - 
fake_clock_.GetSystemTimeMilliseconds())); + AdvancedScorer::Create( + CreateAdvancedScoringSpec("log(0)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, - AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 0"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 0"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(-1)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(-1, 0.5)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); } @@ -904,133 +1031,402 @@ TEST_F(AdvancedScorerTest, 
EvaluationErrorShouldReturnDefaultScore) { TEST_F(AdvancedScorerTest, MathTypeError) { const double default_score = 0; - EXPECT_THAT( - AdvancedScorer::Create(CreateAdvancedScoringSpec("test"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("test"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT( - AdvancedScorer::Create(CreateAdvancedScoringSpec("log()"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log()"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, 2, 3)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, this)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), 
schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT( - AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(1)"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(1)"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(1, 2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(1, 2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sin(1, 2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + 
fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("cos(1, 2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("tan(1, 2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT( - AdvancedScorer::Create(CreateAdvancedScoringSpec("this"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT( - AdvancedScorer::Create(CreateAdvancedScoringSpec("-this"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("-this"), + default_score, 
kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + this"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } TEST_F(AdvancedScorerTest, DocumentFunctionTypeError) { const double default_score = 0; - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("documentScore(1)"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.creationTimestamp(1)"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("documentScore(1)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.creationTimestamp(1)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + 
AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.usageCount()"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("usageLastUsedTimestamp(1, 1)"), + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("relevanceScore(1)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("documentScore(this)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("that.documentScore()"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.this.creationTimestamp()"), + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), 
fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this.log(2)"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.usageCount()"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), +} + +TEST_F(AdvancedScorerTest, + MatchedSemanticScoresFunctionScoreExpressionTypeError) { + libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> scorer_or = + AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(matchedSemanticScores(getSearchSpecEmbedding(0)))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("usageLastUsedTimestamp(1, 1)"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("not called with \"this\"")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(0))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, 
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("relevanceScore(1)"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("got invalid argument type for embedding vector")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0), 0))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("documentScore(this)"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("Embedding metric can only be given as a string")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(" + "getSearchSpecEmbedding(0), \"COSINE\", 0))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("that.documentScore()"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("got invalid number of arguments")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(" + "getSearchSpecEmbedding(0), 
\"COSIGN\"))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.this.creationTimestamp()"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("Unknown metric type: COSIGN")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(\"0\")))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this.log(2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("getSearchSpecEmbedding got invalid argument type")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding()))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("getSearchSpecEmbedding must have 1 argument")); +} + +void AddEntryToEmbeddingQueryScoreMap( + 
EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map, + double semantic_score, DocumentId document_id) { + score_map[document_id].push_back(semantic_score); +} + +TEST_F(AdvancedScorerTest, MatchedSemanticScoresFunctionScoreExpression) { + DocumentId document_id_0 = 0; + DocumentId document_id_1 = 1; + DocHitInfo doc_hit_info_0(document_id_0); + DocHitInfo doc_hit_info_1(document_id_1); + EmbeddingQueryResults embedding_query_results; + + // Let the first query assign the following semantic scores: + // COSINE: + // Document 0: 0.1, 0.2 + // Document 1: 0.3, 0.4 + // DOT_PRODUCT: + // Document 0: 0.5 + // Document 1: 0.6 + // EUCLIDEAN: + // Document 0: 0.7 + // Document 1: 0.8 + EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map = + &embedding_query_results + .result_scores[0][SearchSpecProto::EmbeddingQueryMetricType::COSINE]; + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.1, document_id_0); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.2, document_id_0); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.3, document_id_1); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.4, document_id_1); + score_map = &embedding_query_results.result_scores + [0][SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT]; + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.5, document_id_0); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.6, document_id_1); + score_map = + &embedding_query_results + .result_scores[0] + [SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN]; + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.7, document_id_0); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.8, document_id_1); + + // Let the second query only assign DOT_PRODUCT scores: + // DOT_PRODUCT: + // Document 0: 0.1 + // Document 1: 0.2 + score_map = &embedding_query_results.result_scores + 
[1][SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT]; + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.1, document_id_0); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.2, document_id_1); + + // Get semantic scores for default metric (DOT_PRODUCT) for the first query. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Scorer> scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))"), + kDefaultScore, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &embedding_query_results)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.5, kEps)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.6, kEps)); + + // Get semantic scores for a metric overriding the default one for the first + // query. + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(" + "getSearchSpecEmbedding(0), \"COSINE\"))"), + kDefaultScore, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &embedding_query_results)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.1 + 0.2, kEps)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.3 + 0.4, kEps)); + + // Get semantic scores for multiple metrics for the first query. 
+ ICING_ASSERT_OK_AND_ASSIGN( + scorer, AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)" + ", \"COSINE\")) + " + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)" + ", \"DOT_PRODUCT\")) + " + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)" + ", \"EUCLIDEAN\"))"), + kDefaultScore, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &embedding_query_results)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_0), + DoubleNear(0.1 + 0.2 + 0.5 + 0.7, kEps)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_1), + DoubleNear(0.3 + 0.4 + 0.6 + 0.8, kEps)); + + // Get semantic scores for the second query. + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"), + kDefaultScore, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &embedding_query_results)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.1, kEps)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.2, kEps)); + + // The second query does not contain cosine scores. 
+ ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(" + "getSearchSpecEmbedding(1), \"COSINE\"))"), + kDefaultScore, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &embedding_query_results)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0, kEps)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0, kEps)); } } // namespace diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc index e8a2a89..687180a 100644 --- a/icing/scoring/advanced_scoring/score-expression.cc +++ b/icing/scoring/advanced_scoring/score-expression.cc @@ -14,10 +14,39 @@ #include "icing/scoring/advanced_scoring/score-expression.h" +#include <algorithm> +#include <cmath> +#include <cstdint> +#include <cstdlib> +#include <memory> #include <numeric> +#include <optional> +#include <string> +#include <string_view> +#include <unordered_map> +#include <unordered_set> +#include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/join/join-children-fetcher.h" +#include "icing/schema/section.h" +#include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/scoring/section-weights.h" +#include "icing/store/document-associated-score-data.h" +#include "icing/store/document-filter-data.h" +#include "icing/store/document-id.h" +#include "icing/store/document-store.h" +#include 
"icing/util/embedding-util.h" +#include "icing/util/logging.h" +#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -49,7 +78,7 @@ OperatorScoreExpression::Create( return absl_ports::InvalidArgumentError( "Operators are only supported for double type."); } - if (!child->is_constant_double()) { + if (!child->is_constant()) { children_all_constant_double = false; } } @@ -149,7 +178,7 @@ MathFunctionScoreExpression::Create( "Got an invalid type for the math function. Should expect a double " "type argument."); } - if (!child->is_constant_double()) { + if (!child->is_constant()) { args_all_constant_double = false; } } @@ -517,5 +546,100 @@ SchemaTypeId PropertyWeightsFunctionScoreExpression::GetSchemaTypeId( return filter_data_optional.value().schema_type_id(); } +libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> +GetSearchSpecEmbeddingFunctionScoreExpression::Create( + std::vector<std::unique_ptr<ScoreExpression>> args) { + if (args.size() != 1) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat(kFunctionName, " must have 1 argument.")); + } + if (args[0]->type() != ScoreExpressionType::kDouble) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat(kFunctionName, " got invalid argument type.")); + } + bool is_constant = args[0]->is_constant(); + std::unique_ptr<ScoreExpression> expression = + std::unique_ptr<GetSearchSpecEmbeddingFunctionScoreExpression>( + new GetSearchSpecEmbeddingFunctionScoreExpression( + std::move(args[0]))); + if (is_constant) { + return ConstantScoreExpression::Create( + expression->eval(DocHitInfo(), /*query_it=*/nullptr), + expression->type()); + } + return expression; +} + +libtextclassifier3::StatusOr<double> +GetSearchSpecEmbeddingFunctionScoreExpression::eval( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const { + ICING_ASSIGN_OR_RETURN(double raw_query_index, + arg_->eval(hit_info, query_it)); + uint32_t query_index = (uint32_t)raw_query_index; + if 
(query_index != raw_query_index) { + return absl_ports::InvalidArgumentError( + "The index of an embedding query must be an integer."); + } + return query_index; +} + +libtextclassifier3::StatusOr< + std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>> +MatchedSemanticScoresFunctionScoreExpression::Create( + std::vector<std::unique_ptr<ScoreExpression>> args, + SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type, + const EmbeddingQueryResults* embedding_query_results) { + ICING_RETURN_ERROR_IF_NULL(embedding_query_results); + ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args)); + + if (args.empty() || args[0]->type() != ScoreExpressionType::kDocument) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat(kFunctionName, " is not called with \"this\"")); + } + if (args.size() != 2 && args.size() != 3) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat(kFunctionName, " got invalid number of arguments.")); + } + if (args[1]->type() != ScoreExpressionType::kVectorIndex) { + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + kFunctionName, " got invalid argument type for embedding vector.")); + } + if (args.size() == 3 && args[2]->type() != ScoreExpressionType::kString) { + return absl_ports::InvalidArgumentError( + "Embedding metric can only be given as a string."); + } + + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type = + default_metric_type; + if (args.size() == 3) { + if (!args[2]->is_constant()) { + return absl_ports::InvalidArgumentError( + "Embedding metric can only be given as a constant string."); + } + ICING_ASSIGN_OR_RETURN(std::string_view metric, args[2]->eval_string()); + ICING_ASSIGN_OR_RETURN( + metric_type, + embedding_util::GetEmbeddingQueryMetricTypeFromName(metric)); + } + return std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>( + new MatchedSemanticScoresFunctionScoreExpression( + std::move(args), metric_type, *embedding_query_results)); +} + 
+libtextclassifier3::StatusOr<std::vector<double>> +MatchedSemanticScoresFunctionScoreExpression::eval_list( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const { + ICING_ASSIGN_OR_RETURN(double raw_query_index, + args_[1]->eval(hit_info, query_it)); + uint32_t query_index = (uint32_t)raw_query_index; + const std::vector<double>* scores = + embedding_query_results_.GetMatchedScoresForDocument( + query_index, metric_type_, hit_info.document_id()); + if (scores == nullptr) { + return std::vector<double>(); + } + return *scores; +} + } // namespace lib } // namespace icing diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h index 08d7997..e28fcd7 100644 --- a/icing/scoring/advanced_scoring/score-expression.h +++ b/icing/scoring/advanced_scoring/score-expression.h @@ -15,20 +15,26 @@ #ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ #define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ -#include <algorithm> -#include <cmath> +#include <cstdint> #include <memory> +#include <string> +#include <string_view> #include <unordered_map> #include <unordered_set> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/join/join-children-fetcher.h" #include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/section-weights.h" +#include "icing/store/document-filter-data.h" +#include "icing/store/document-id.h" #include "icing/store/document-store.h" -#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -36,7 +42,11 @@ namespace lib { enum class ScoreExpressionType { kDouble, kDoubleList, - kDocument // Only "this" is considered as document type. 
+ kDocument, // Only "this" is considered as document type. + // TODO(b/326656531): Instead of creating a vector index type, consider + // changing it to vector type so that the data is the vector directly. + kVectorIndex, + kString, }; class ScoreExpression { @@ -75,12 +85,24 @@ class ScoreExpression { "checking."); } + virtual libtextclassifier3::StatusOr<std::string_view> eval_string() const { + if (type() == ScoreExpressionType::kString) { + return absl_ports::UnimplementedError( + "All ScoreExpressions of type string must provide their own " + "implementation of eval_string!"); + } + return absl_ports::InternalError( + "Runtime type error: the expression should never be evaluated to a " + "string. There must be inconsistencies in the static type checking."); + } + // Indicate the type to which the current expression will be evaluated. virtual ScoreExpressionType type() const = 0; - // Indicate whether the current expression is a constant double. - // Returns true if and only if the object is of ConstantScoreExpression type. - virtual bool is_constant_double() const { return false; } + // Indicate whether the current expression is a constant. + // Returns true if and only if the object is of ConstantScoreExpression or + // StringExpression type. 
+ virtual bool is_constant() const { return false; } }; class ThisExpression : public ScoreExpression { @@ -100,9 +122,10 @@ class ThisExpression : public ScoreExpression { class ConstantScoreExpression : public ScoreExpression { public: static std::unique_ptr<ConstantScoreExpression> Create( - libtextclassifier3::StatusOr<double> c) { + libtextclassifier3::StatusOr<double> c, + ScoreExpressionType type = ScoreExpressionType::kDouble) { return std::unique_ptr<ConstantScoreExpression>( - new ConstantScoreExpression(c)); + new ConstantScoreExpression(c, type)); } libtextclassifier3::StatusOr<double> eval( @@ -110,17 +133,39 @@ class ConstantScoreExpression : public ScoreExpression { return c_; } - ScoreExpressionType type() const override { - return ScoreExpressionType::kDouble; - } + ScoreExpressionType type() const override { return type_; } - bool is_constant_double() const override { return true; } + bool is_constant() const override { return true; } private: - explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c) - : c_(c) {} + explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c, + ScoreExpressionType type) + : c_(c), type_(type) {} libtextclassifier3::StatusOr<double> c_; + ScoreExpressionType type_; +}; + +class StringExpression : public ScoreExpression { + public: + static std::unique_ptr<StringExpression> Create(std::string str) { + return std::unique_ptr<StringExpression>( + new StringExpression(std::move(str))); + } + + libtextclassifier3::StatusOr<std::string_view> eval_string() const override { + return str_; + } + + ScoreExpressionType type() const override { + return ScoreExpressionType::kString; + } + + bool is_constant() const override { return true; } + + private: + explicit StringExpression(std::string str) : str_(std::move(str)) {} + std::string str_; }; class OperatorScoreExpression : public ScoreExpression { @@ -342,6 +387,70 @@ class PropertyWeightsFunctionScoreExpression : public ScoreExpression { int64_t 
current_time_ms_; }; +class GetSearchSpecEmbeddingFunctionScoreExpression : public ScoreExpression { + public: + static constexpr std::string_view kFunctionName = "getSearchSpecEmbedding"; + + // RETURNS: + // - A GetSearchSpecEmbeddingFunctionScoreExpression instance on success if + // not simplifiable. + // - A ConstantScoreExpression instance on success if simplifiable. + // - FAILED_PRECONDITION on any null pointer in children. + // - INVALID_ARGUMENT on type errors. + static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( + std::vector<std::unique_ptr<ScoreExpression>> args); + + libtextclassifier3::StatusOr<double> eval( + const DocHitInfo& hit_info, + const DocHitInfoIterator* query_it) const override; + + ScoreExpressionType type() const override { + return ScoreExpressionType::kVectorIndex; + } + + private: + explicit GetSearchSpecEmbeddingFunctionScoreExpression( + std::unique_ptr<ScoreExpression> arg) + : arg_(std::move(arg)) {} + std::unique_ptr<ScoreExpression> arg_; +}; + +class MatchedSemanticScoresFunctionScoreExpression : public ScoreExpression { + public: + static constexpr std::string_view kFunctionName = "matchedSemanticScores"; + + // RETURNS: + // - A MatchedSemanticScoresFunctionScoreExpression instance on success. + // - FAILED_PRECONDITION on any null pointer in children. + // - INVALID_ARGUMENT on type errors. 
+ static libtextclassifier3::StatusOr< + std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>> + Create(std::vector<std::unique_ptr<ScoreExpression>> args, + SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type, + const EmbeddingQueryResults* embedding_query_results); + + libtextclassifier3::StatusOr<std::vector<double>> eval_list( + const DocHitInfo& hit_info, + const DocHitInfoIterator* query_it) const override; + + ScoreExpressionType type() const override { + return ScoreExpressionType::kDoubleList; + } + + private: + explicit MatchedSemanticScoresFunctionScoreExpression( + std::vector<std::unique_ptr<ScoreExpression>> args, + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, + const EmbeddingQueryResults& embedding_query_results) + : args_(std::move(args)), + metric_type_(metric_type), + embedding_query_results_(embedding_query_results) {} + + std::vector<std::unique_ptr<ScoreExpression>> args_; + const SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_; + const EmbeddingQueryResults& embedding_query_results_; +}; + } // namespace lib } // namespace icing diff --git a/icing/scoring/advanced_scoring/score-expression_test.cc b/icing/scoring/advanced_scoring/score-expression_test.cc index 588090d..cd58366 100644 --- a/icing/scoring/advanced_scoring/score-expression_test.cc +++ b/icing/scoring/advanced_scoring/score-expression_test.cc @@ -14,15 +14,16 @@ #include "icing/scoring/advanced_scoring/score-expression.h" -#include <cmath> #include <memory> -#include <string> -#include <string_view> #include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/testing/common-matchers.h" namespace icing { @@ -47,7 +48,7 @@ class NonConstantScoreExpression : public 
ScoreExpression { return ScoreExpressionType::kDouble; } - bool is_constant_double() const override { return false; } + bool is_constant() const override { return false; } }; class ListScoreExpression : public ScoreExpression { @@ -87,7 +88,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) { OperatorScoreExpression::OperatorType::kPlus, MakeChildren(ConstantScoreExpression::Create(1), ConstantScoreExpression::Create(1)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2))); // 1 - 2 - 3 = -4 @@ -97,7 +98,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) { MakeChildren(ConstantScoreExpression::Create(1), ConstantScoreExpression::Create(2), ConstantScoreExpression::Create(3)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-4))); // 1 * 2 * 3 * 4 = 24 @@ -108,7 +109,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) { ConstantScoreExpression::Create(2), ConstantScoreExpression::Create(3), ConstantScoreExpression::Create(4)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(24))); // 1 / 2 / 4 = 0.125 @@ -118,7 +119,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) { MakeChildren(ConstantScoreExpression::Create(1), ConstantScoreExpression::Create(2), ConstantScoreExpression::Create(4)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(0.125))); // -(2) = -2 @@ -126,7 +127,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) { expression, OperatorScoreExpression::Create( OperatorScoreExpression::OperatorType::kNegative, MakeChildren(ConstantScoreExpression::Create(2)))); - 
ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-2))); } @@ -138,7 +139,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) { MathFunctionScoreExpression::FunctionType::kPow, MakeChildren(ConstantScoreExpression::Create(2), ConstantScoreExpression::Create(2)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(4))); // abs(-2) = 2 @@ -146,7 +147,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) { expression, MathFunctionScoreExpression::Create( MathFunctionScoreExpression::FunctionType::kAbs, MakeChildren(ConstantScoreExpression::Create(-2)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2))); // log(e) = 1 @@ -154,7 +155,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) { expression, MathFunctionScoreExpression::Create( MathFunctionScoreExpression::FunctionType::kLog, MakeChildren(ConstantScoreExpression::Create(M_E)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(1))); } @@ -166,7 +167,7 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) { OperatorScoreExpression::OperatorType::kPlus, MakeChildren(ConstantScoreExpression::Create(1), NonConstantScoreExpression::Create()))); - ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); // non_constant * non_constant = non_constant ICING_ASSERT_OK_AND_ASSIGN( @@ -174,14 +175,14 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) { OperatorScoreExpression::OperatorType::kTimes, MakeChildren(NonConstantScoreExpression::Create(), NonConstantScoreExpression::Create()))); - 
ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); // -(non_constant) = non_constant ICING_ASSERT_OK_AND_ASSIGN( expression, OperatorScoreExpression::Create( OperatorScoreExpression::OperatorType::kNegative, MakeChildren(NonConstantScoreExpression::Create()))); - ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); // pow(non_constant, 2) = non_constant ICING_ASSERT_OK_AND_ASSIGN( @@ -189,21 +190,21 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) { MathFunctionScoreExpression::FunctionType::kPow, MakeChildren(NonConstantScoreExpression::Create(), ConstantScoreExpression::Create(2)))); - ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); // abs(non_constant) = non_constant ICING_ASSERT_OK_AND_ASSIGN( expression, MathFunctionScoreExpression::Create( MathFunctionScoreExpression::FunctionType::kAbs, MakeChildren(NonConstantScoreExpression::Create()))); - ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); // log(non_constant) = non_constant ICING_ASSERT_OK_AND_ASSIGN( expression, MathFunctionScoreExpression::Create( MathFunctionScoreExpression::FunctionType::kLog, MakeChildren(NonConstantScoreExpression::Create()))); - ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); } TEST(ScoreExpressionTest, MathFunctionsWithListTypeArgument) { diff --git a/icing/scoring/advanced_scoring/scoring-visitor.cc b/icing/scoring/advanced_scoring/scoring-visitor.cc index e2b24a2..05240c0 100644 --- a/icing/scoring/advanced_scoring/scoring-visitor.cc +++ b/icing/scoring/advanced_scoring/scoring-visitor.cc @@ -14,7 +14,17 @@ #include "icing/scoring/advanced_scoring/scoring-visitor.h" +#include <cstdlib> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" 
#include "icing/absl_ports/str_cat.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/scoring/advanced_scoring/score-expression.h" namespace icing { namespace lib { @@ -25,8 +35,7 @@ void ScoringVisitor::VisitFunctionName(const FunctionNameNode* node) { } void ScoringVisitor::VisitString(const StringNode* node) { - pending_error_ = - absl_ports::InvalidArgumentError("Scoring does not support String!"); + stack_.push_back(StringExpression::Create(node->value())); } void ScoringVisitor::VisitText(const TextNode* node) { @@ -120,6 +129,17 @@ void ScoringVisitor::VisitFunctionHelper(const FunctionNode* node, expression = MathFunctionScoreExpression::Create( MathFunctionScoreExpression::kFunctionNames.at(function_name), std::move(args)); + } else if (function_name == + GetSearchSpecEmbeddingFunctionScoreExpression::kFunctionName) { + // getSearchSpecEmbedding function + expression = + GetSearchSpecEmbeddingFunctionScoreExpression::Create(std::move(args)); + } else if (function_name == + MatchedSemanticScoresFunctionScoreExpression::kFunctionName) { + // matchedSemanticScores function + expression = MatchedSemanticScoresFunctionScoreExpression::Create( + std::move(args), default_semantic_metric_type_, + &embedding_query_results_); } if (!expression.ok()) { diff --git a/icing/scoring/advanced_scoring/scoring-visitor.h b/icing/scoring/advanced_scoring/scoring-visitor.h index cfee25b..bb5e6ba 100644 --- a/icing/scoring/advanced_scoring/scoring-visitor.h +++ b/icing/scoring/advanced_scoring/scoring-visitor.h @@ -15,14 +15,22 @@ #ifndef ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_ #define ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_ +#include <cstdint> +#include <memory> +#include <utility> +#include <vector> + #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" 
#include "icing/join/join-children-fetcher.h" #include "icing/legacy/core/icing-string-util.h" -#include "icing/proto/scoring.pb.h" #include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/score-expression.h" #include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/section-weights.h" #include "icing/store/document-store.h" namespace icing { @@ -31,18 +39,23 @@ namespace lib { class ScoringVisitor : public AbstractSyntaxTreeVisitor { public: explicit ScoringVisitor(double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, SectionWeights* section_weights, Bm25fCalculator* bm25f_calculator, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results, int64_t current_time_ms) : default_score_(default_score), + default_semantic_metric_type_(default_semantic_metric_type), document_store_(*document_store), schema_store_(*schema_store), section_weights_(*section_weights), bm25f_calculator_(*bm25f_calculator), join_children_fetcher_(join_children_fetcher), + embedding_query_results_(*embedding_query_results), current_time_ms_(current_time_ms) {} void VisitFunctionName(const FunctionNameNode* node) override; @@ -90,12 +103,15 @@ class ScoringVisitor : public AbstractSyntaxTreeVisitor { } double default_score_; + const SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type_; const DocumentStore& document_store_; const SchemaStore& schema_store_; SectionWeights& section_weights_; Bm25fCalculator& bm25f_calculator_; // A non-null join_children_fetcher_ indicates scoring in a join. const JoinChildrenFetcher* join_children_fetcher_; // Does not own. 
+ const EmbeddingQueryResults& embedding_query_results_; libtextclassifier3::Status pending_error_; std::vector<std::unique_ptr<ScoreExpression>> stack_; diff --git a/icing/scoring/score-and-rank_benchmark.cc b/icing/scoring/score-and-rank_benchmark.cc index 7cb5a95..4da8d1f 100644 --- a/icing/scoring/score-and-rank_benchmark.cc +++ b/icing/scoring/score-and-rank_benchmark.cc @@ -12,18 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <algorithm> #include <cstdint> #include <limits> #include <memory> #include <random> #include <string> +#include <unordered_map> #include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "testing/base/public/benchmark.h" +#include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/index/iterator/doc-hit-info-iterator.h" @@ -31,6 +35,7 @@ #include "icing/proto/schema.pb.h" #include "icing/proto/scoring.pb.h" #include "icing/schema/schema-store.h" +#include "icing/schema/section.h" #include "icing/scoring/ranker.h" #include "icing/scoring/scored-document-hit.h" #include "icing/scoring/scoring-processor.h" @@ -131,11 +136,15 @@ void BM_ScoreAndRankDocumentHitsByDocumentScore(benchmark::State& state) { ScoringSpecProto scoring_spec; scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); + EmbeddingQueryResults empty_embedding_query_results; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get(), - schema_store.get(), - clock.GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + scoring_spec, /*default_semantic_metric_type=*/ + 
SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store.get(), schema_store.get(), + clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results)); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -237,11 +246,15 @@ void BM_ScoreAndRankDocumentHitsByCreationTime(benchmark::State& state) { ScoringSpecProto scoring_spec; scoring_spec.set_rank_by( ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP); + EmbeddingQueryResults empty_embedding_query_results; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get(), - schema_store.get(), - clock.GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + scoring_spec, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store.get(), schema_store.get(), + clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results)); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -344,11 +357,15 @@ void BM_ScoreAndRankDocumentHitsNoScoring(benchmark::State& state) { ScoringSpecProto scoring_spec; scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::NONE); + EmbeddingQueryResults empty_embedding_query_results; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get(), - schema_store.get(), - clock.GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + scoring_spec, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store.get(), schema_store.get(), + clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results)); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -446,11 +463,15 @@ void 
BM_ScoreAndRankDocumentHitsByRelevanceScoring(benchmark::State& state) { ScoringSpecProto scoring_spec; scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + EmbeddingQueryResults empty_embedding_query_results; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get(), - schema_store.get(), - clock.GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + scoring_spec, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store.get(), schema_store.get(), + clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results)); int num_to_score = state.range(0); int num_of_documents = state.range(1); diff --git a/icing/scoring/scorer-factory.cc b/icing/scoring/scorer-factory.cc index e56f10c..1d66d7f 100644 --- a/icing/scoring/scorer-factory.cc +++ b/icing/scoring/scorer-factory.cc @@ -14,19 +14,26 @@ #include "icing/scoring/scorer-factory.h" +#include <cstdint> #include <memory> +#include <optional> +#include <string> #include <unordered_map> +#include <utility> #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/join/join-children-fetcher.h" #include "icing/proto/scoring.pb.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/advanced-scorer.h" #include "icing/scoring/bm25f-calculator.h" #include "icing/scoring/scorer.h" #include "icing/scoring/section-weights.h" -#include "icing/store/document-id.h" +#include "icing/store/document-associated-score-data.h" #include "icing/store/document-store.h" #include "icing/util/status-macros.h" @@ -173,10 +180,14 @@ namespace scorer_factory { 
libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( const ScoringSpecProto& scoring_spec, double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, - int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher) { + int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results) { ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(embedding_query_results); if (!scoring_spec.advanced_scoring_expression().empty() && scoring_spec.rank_by() != @@ -223,9 +234,10 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( return absl_ports::InvalidArgumentError( "Advanced scoring is enabled, but the expression is empty!"); } - return AdvancedScorer::Create(scoring_spec, default_score, document_store, - schema_store, current_time_ms, - join_children_fetcher); + return AdvancedScorer::Create( + scoring_spec, default_score, default_semantic_metric_type, + document_store, schema_store, current_time_ms, join_children_fetcher, + embedding_query_results); case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE: // Use join aggregate score to rank. 
Since the aggregation score is // calculated by child documents after joining (in JoinProcessor), we can diff --git a/icing/scoring/scorer-factory.h b/icing/scoring/scorer-factory.h index 659bebd..f5766b3 100644 --- a/icing/scoring/scorer-factory.h +++ b/icing/scoring/scorer-factory.h @@ -15,8 +15,13 @@ #ifndef ICING_SCORING_SCORER_FACTORY_H_ #define ICING_SCORING_SCORER_FACTORY_H_ +#include <cstdint> +#include <memory> + #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/join/join-children-fetcher.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/scorer.h" #include "icing/store/document-store.h" @@ -37,9 +42,11 @@ namespace scorer_factory { // INVALID_ARGUMENT if fails to create an instance libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( const ScoringSpecProto& scoring_spec, double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, - int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher = nullptr); + int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results); } // namespace scorer_factory diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc index 5194c7f..e22d5f4 100644 --- a/icing/scoring/scorer_test.cc +++ b/icing/scoring/scorer_test.cc @@ -14,13 +14,19 @@ #include "icing/scoring/scorer.h" +#include <cstdint> +#include <limits> #include <memory> #include <string> +#include <utility> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include 
"icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" @@ -30,7 +36,6 @@ #include "icing/schema/schema-store.h" #include "icing/scoring/scorer-factory.h" #include "icing/scoring/scorer-test-utils.h" -#include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" @@ -107,6 +112,10 @@ class ScorerTest : public ::testing::TestWithParam<ScorerTestingMode> { fake_clock1_.SetSystemTimeMilliseconds(new_time); } + SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type = + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; + EmbeddingQueryResults empty_embedding_query_results; + private: const std::string test_dir_; const std::string doc_store_dir_; @@ -134,8 +143,10 @@ TEST_P(ScorerTest, CreationWithNullDocumentStoreShouldFail) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), - /*default_score=*/0, /*document_store=*/nullptr, schema_store(), - fake_clock1().GetSystemTimeMilliseconds()), + /*default_score=*/0, default_semantic_metric_type, + /*document_store=*/nullptr, schema_store(), + fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } @@ -144,8 +155,9 @@ TEST_P(ScorerTest, CreationWithNullSchemaStoreShouldFail) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), - /*default_score=*/0, document_store(), - /*schema_store=*/nullptr, fake_clock1().GetSystemTimeMilliseconds()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + /*schema_store=*/nullptr, fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } @@ -155,8 +167,9 
@@ TEST_P(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), - /*default_score=*/10, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/10, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); // Non existent document id DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1); @@ -181,8 +194,9 @@ TEST_P(ScorerTest, ShouldGetDefaultDocumentScore) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), - /*default_score=*/10, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/10, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0)); @@ -206,8 +220,9 @@ TEST_P(ScorerTest, ShouldGetCorrectDocumentScore) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5)); @@ -233,8 +248,9 @@ TEST_P(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( 
ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()), - /*default_score=*/10, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/10, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); @@ -268,8 +284,9 @@ TEST_P(ScorerTest, ShouldGetCorrectCreationTimestampScore) { CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo1 = DocHitInfo(document_id1); DocHitInfo docHitInfo2 = DocHitInfo(document_id2); @@ -297,22 +314,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), 
fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -347,22 +367,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()), - /*default_score=*/0, 
document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -397,22 +420,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); 
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -444,31 +470,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE2_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE3_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + 
ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -516,31 +545,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE2_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - 
scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE3_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -588,31 +620,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE2_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + 
USAGE_TYPE2_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE3_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -651,8 +686,9 @@ TEST_P(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::NONE, GetParam()), - /*default_score=*/3, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/3, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0); DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1); @@ -662,11 +698,13 @@ TEST_P(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(3)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, scorer_factory::Create( - CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::NONE, GetParam()), - 
/*default_score=*/111, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer, + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE, GetParam()), + /*default_score=*/111, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); docHitInfo1 = DocHitInfo(/*document_id_in=*/4); docHitInfo2 = DocHitInfo(/*document_id_in=*/5); @@ -690,13 +728,14 @@ TEST_P(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); // Create usage report for the maximum allowable timestamp. 
diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc index b827bd8..bb62db9 100644 --- a/icing/scoring/scoring-processor.cc +++ b/icing/scoring/scoring-processor.cc @@ -14,6 +14,7 @@ #include "icing/scoring/scoring-processor.h" +#include <cstdint> #include <limits> #include <memory> #include <string> @@ -23,10 +24,12 @@ #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/join/join-children-fetcher.h" #include "icing/proto/scoring.pb.h" -#include "icing/scoring/ranker.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/scored-document-hit.h" #include "icing/scoring/scorer-factory.h" #include "icing/scoring/scorer.h" @@ -44,24 +47,28 @@ constexpr double kDefaultScoreInAscendingOrder = libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> ScoringProcessor::Create(const ScoringSpecProto& scoring_spec, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher) { + const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results) { ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(embedding_query_results); bool is_descending_order = scoring_spec.order_by() == ScoringSpecProto::Order::DESC; ICING_ASSIGN_OR_RETURN( std::unique_ptr<Scorer> scorer, - scorer_factory::Create(scoring_spec, - is_descending_order - ? kDefaultScoreInDescendingOrder - : kDefaultScoreInAscendingOrder, - document_store, schema_store, current_time_ms, - join_children_fetcher)); + scorer_factory::Create( + scoring_spec, + is_descending_order ? 
kDefaultScoreInDescendingOrder + : kDefaultScoreInAscendingOrder, + default_semantic_metric_type, document_store, schema_store, + current_time_ms, join_children_fetcher, embedding_query_results)); // Using `new` to access a non-public constructor. return std::unique_ptr<ScoringProcessor>( new ScoringProcessor(std::move(scorer))); diff --git a/icing/scoring/scoring-processor.h b/icing/scoring/scoring-processor.h index 8634a22..de0b95c 100644 --- a/icing/scoring/scoring-processor.h +++ b/icing/scoring/scoring-processor.h @@ -23,6 +23,7 @@ #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/join/join-children-fetcher.h" #include "icing/proto/logging.pb.h" @@ -46,9 +47,12 @@ class ScoringProcessor { // A ScoringProcessor on success // FAILED_PRECONDITION on any null pointer input static libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> Create( - const ScoringSpecProto& scoring_spec, const DocumentStore* document_store, - const SchemaStore* schema_store, int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher = nullptr); + const ScoringSpecProto& scoring_spec, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, + const DocumentStore* document_store, const SchemaStore* schema_store, + int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results); // Assigns scores to DocHitInfos from the given DocHitInfoIterator and returns // a vector of ScoredDocumentHits. 
The size of results is no more than diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc index deddff8..e3b70a6 100644 --- a/icing/scoring/scoring-processor_test.cc +++ b/icing/scoring/scoring-processor_test.cc @@ -15,22 +15,39 @@ #include "icing/scoring/scoring-processor.h" #include <cstdint> +#include <memory> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" #include "icing/proto/scoring.pb.h" #include "icing/proto/term.pb.h" #include "icing/proto/usage.pb.h" #include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/scoring/scored-document-hit.h" #include "icing/scoring/scorer-test-utils.h" +#include "icing/store/document-id.h" +#include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" #include "icing/testing/tmp-directory.h" +#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -111,6 +128,10 @@ class ScoringProcessorTest const FakeClock& fake_clock() const { return fake_clock_; } + SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type = + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; + EmbeddingQueryResults empty_embedding_query_results; + private: const std::string test_dir_; const std::string doc_store_dir_; @@ -187,27 
+208,31 @@ PropertyWeight CreatePropertyWeight(std::string path, double weight) { TEST_F(ScoringProcessorTest, CreationWithNullDocumentStoreShouldFail) { ScoringSpecProto spec_proto; - EXPECT_THAT(ScoringProcessor::Create( - spec_proto, /*document_store=*/nullptr, schema_store(), - fake_clock().GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, /*document_store=*/nullptr, + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(ScoringProcessorTest, CreationWithNullSchemaStoreShouldFail) { ScoringSpecProto spec_proto; EXPECT_THAT( - ScoringProcessor::Create(spec_proto, document_store(), - /*schema_store=*/nullptr, - fake_clock().GetSystemTimeMilliseconds()), + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + /*schema_store=*/nullptr, fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_P(ScoringProcessorTest, ShouldCreateInstance) { ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()); - ICING_EXPECT_OK( - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ICING_EXPECT_OK(ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); } TEST_P(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { @@ -222,8 +247,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { // Creates a ScoringProcessor 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/5), @@ -249,8 +276,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/-1), @@ -280,8 +309,10 @@ TEST_P(ScoringProcessorTest, ShouldRespectNumToScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/2), @@ -313,8 +344,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByDocumentScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, 
document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -369,8 +402,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -439,8 +474,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -513,8 +550,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), 
fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -563,8 +602,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -629,8 +670,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -700,8 +743,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -762,8 +807,10 
@@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor with no explicit weights set. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ScoringSpecProto spec_proto_with_weights = CreateScoringSpecForRankingStrategy( @@ -778,9 +825,11 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor with default weight set for "body" property. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor_with_weights, - ScoringProcessor::Create(spec_proto_with_weights, document_store(), - schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto_with_weights, default_semantic_metric_type, + document_store(), schema_store(), + fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -866,8 +915,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -929,8 +980,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { // Creates a ScoringProcessor which ranks in 
descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -990,8 +1043,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByUsageCount) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -1051,8 +1106,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByUsageTimestamp) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -1088,8 +1145,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleNoScores) { // Creates a ScoringProcessor which ranks in descending order 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/4), ElementsAre(EqualsScoredDocumentHit(scored_document_hit_default), @@ -1138,8 +1197,10 @@ TEST_P(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), diff --git a/icing/testing/embedding-test-utils.h b/icing/testing/embedding-test-utils.h new file mode 100644 index 0000000..931953e --- /dev/null +++ b/icing/testing/embedding-test-utils.h @@ -0,0 +1,45 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_TESTING_EMBEDDING_TEST_UTILS_H_ +#define ICING_TESTING_EMBEDDING_TEST_UTILS_H_ + +#include <initializer_list> +#include <string> + +#include "icing/proto/document.pb.h" + +namespace icing { +namespace lib { + +inline PropertyProto::VectorProto CreateVector( + const std::string& model_signature, std::initializer_list<float> values) { + PropertyProto::VectorProto vector; + vector.set_model_signature(model_signature); + for (float value : values) { + vector.add_values(value); + } + return vector; +} + +template <typename... V> +inline PropertyProto::VectorProto CreateVector( + const std::string& model_signature, V&&... values) { + return CreateVector(model_signature, values...); +} + +} // namespace lib +} // namespace icing + +#endif // ICING_TESTING_EMBEDDING_TEST_UTILS_H_ diff --git a/icing/testing/hit-test-utils.cc b/icing/testing/hit-test-utils.cc index 2fd3ac8..c235e23 100644 --- a/icing/testing/hit-test-utils.cc +++ b/icing/testing/hit-test-utils.cc @@ -17,6 +17,7 @@ #include <cstdint> #include <vector> +#include "icing/index/embed/embedding-hit.h" #include "icing/index/hit/hit.h" #include "icing/index/main/posting-list-hit-serializer.h" #include "icing/schema/section.h" @@ -87,5 +88,28 @@ std::vector<Hit> CreateHits(int num_hits, int desired_byte_length) { return CreateHits(/*start_docid=*/0, num_hits, desired_byte_length); } +EmbeddingHit CreateEmbeddingHit(const EmbeddingHit& last_hit, + uint32_t desired_byte_length) { + // Create a delta that has (desired_byte_length - 1) * 7 + 1 bits, so that it + // can be encoded in desired_byte_length bytes. 
+ uint64_t delta = UINT64_C(1) << ((desired_byte_length - 1) * 7); + return EmbeddingHit(last_hit.value() - delta); +} + +std::vector<EmbeddingHit> CreateEmbeddingHits(int num_hits, + int desired_byte_length) { + std::vector<EmbeddingHit> hits; + if (num_hits == 0) { + return hits; + } + hits.reserve(num_hits); + hits.push_back(EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)); + for (int i = 1; i < num_hits; ++i) { + hits.push_back(CreateEmbeddingHit(hits.back(), desired_byte_length)); + } + return hits; +} + } // namespace lib } // namespace icing diff --git a/icing/testing/hit-test-utils.h b/icing/testing/hit-test-utils.h index 2953c5c..e041c22 100644 --- a/icing/testing/hit-test-utils.h +++ b/icing/testing/hit-test-utils.h @@ -15,8 +15,10 @@ #ifndef ICING_TESTING_HIT_TEST_UTILS_H_ #define ICING_TESTING_HIT_TEST_UTILS_H_ +#include <cstdint> #include <vector> +#include "icing/index/embed/embedding-hit.h" #include "icing/index/hit/hit.h" #include "icing/store/document-id.h" @@ -46,6 +48,18 @@ std::vector<Hit> CreateHits(const Hit& last_hit, int num_hits, // with desired_byte_length deltas. std::vector<Hit> CreateHits(int num_hits, int desired_byte_length); +// Returns a hit that has a delta of desired_byte_length from last_hit after +// VarInt encoding. +// Requires that 0 < desired_byte_length <= VarInt::kMaxEncodedLen64. +EmbeddingHit CreateEmbeddingHit(const EmbeddingHit& last_hit, + uint32_t desired_byte_length); + +// Returns a vector of num_hits Hits with the first hit starting at document 0 +// and with a delta of desired_byte_length between each subsequent hit after +// VarInt encoding. 
+std::vector<EmbeddingHit> CreateEmbeddingHits(int num_hits, + int desired_byte_length); + } // namespace lib } // namespace icing diff --git a/icing/util/document-validator.cc b/icing/util/document-validator.cc index e0880ea..bc75334 100644 --- a/icing/util/document-validator.cc +++ b/icing/util/document-validator.cc @@ -15,13 +15,23 @@ #include "icing/util/document-validator.h" #include <cstdint> +#include <string> +#include <string_view> +#include <unordered_map> #include <unordered_set> +#include <utility> #include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/legacy/core/icing-string-util.h" #include "icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" +#include "icing/schema/schema-store.h" #include "icing/schema/schema-util.h" +#include "icing/store/document-filter-data.h" +#include "icing/util/logging.h" #include "icing/util/status-macros.h" namespace icing { @@ -123,6 +133,18 @@ libtextclassifier3::Status DocumentValidator::Validate( } else if (property_config.data_type() == PropertyConfigProto::DataType::DOCUMENT) { value_size = property.document_values_size(); + } else if (property_config.data_type() == + PropertyConfigProto::DataType::VECTOR) { + value_size = property.vector_values_size(); + for (const PropertyProto::VectorProto& vector_value : + property.vector_values()) { + if (vector_value.values_size() == 0) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Property '%s' contains empty vectors for key: (%s, %s).", + property.name().c_str(), document.namespace_().c_str(), + document.uri().c_str())); + } + } } if (property_config.cardinality() == diff --git a/icing/util/document-validator_test.cc b/icing/util/document-validator_test.cc index 9d10b36..2c366fd 100644 --- a/icing/util/document-validator_test.cc +++ 
b/icing/util/document-validator_test.cc @@ -15,7 +15,11 @@ #include "icing/util/document-validator.h" #include <cstdint> +#include <limits> +#include <memory> +#include <string> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" @@ -42,6 +46,7 @@ constexpr char kPropertySubject[] = "subject"; constexpr char kPropertyText[] = "text"; constexpr char kPropertyRecipients[] = "recipients"; constexpr char kPropertyNote[] = "note"; +constexpr char kPropertyNoteEmbedding[] = "noteEmbedding"; // type and property names of Conversation constexpr char kTypeConversation[] = "Conversation"; constexpr char kTypeConversationWithEmailNote[] = "ConversationWithEmailNote"; @@ -92,6 +97,10 @@ class DocumentValidatorTest : public ::testing::Test { .AddProperty(PropertyConfigBuilder() .SetName(kPropertyNote) .SetDataType(TYPE_STRING) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName(kPropertyNoteEmbedding) + .SetDataType(TYPE_VECTOR) .SetCardinality(CARDINALITY_OPTIONAL))) .AddType( SchemaTypeConfigBuilder() @@ -146,6 +155,12 @@ class DocumentValidatorTest : public ::testing::Test { } DocumentBuilder SimpleEmailWithNoteBuilder() { + PropertyProto::VectorProto vector; + vector.add_values(0.1); + vector.add_values(0.2); + vector.add_values(0.3); + vector.set_model_signature("my_model"); + return DocumentBuilder() .SetKey(kDefaultNamespace, "email_with_note/1") .SetSchema(kTypeEmailWithNote) @@ -153,7 +168,8 @@ class DocumentValidatorTest : public ::testing::Test { .AddStringProperty(kPropertyText, kDefaultString) .AddStringProperty(kPropertyRecipients, kDefaultString, kDefaultString, kDefaultString) - .AddStringProperty(kPropertyNote, kDefaultString); + .AddStringProperty(kPropertyNote, kDefaultString) + .AddVectorProperty(kPropertyNoteEmbedding, vector); } DocumentBuilder SimpleConversationBuilder() { @@ -597,6 +613,64 @@ TEST_F(DocumentValidatorTest, 
NegativeDocumentTtlMsInvalid) { HasSubstr("is negative"))); } +TEST_F(DocumentValidatorTest, ValidateEmbeddingZeroDimensionInvalid) { + PropertyProto::VectorProto vector; + vector.set_model_signature("my_model"); + DocumentProto email = + DocumentBuilder() + .SetKey(kDefaultNamespace, "email_with_note/1") + .SetSchema(kTypeEmailWithNote) + .AddStringProperty(kPropertySubject, kDefaultString) + .AddStringProperty(kPropertyText, kDefaultString) + .AddStringProperty(kPropertyRecipients, kDefaultString, + kDefaultString, kDefaultString) + .AddStringProperty(kPropertyNote, kDefaultString) + .AddVectorProperty(kPropertyNoteEmbedding, vector) + .Build(); + EXPECT_THAT(document_validator_->Validate(email), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT, + HasSubstr("contains empty vectors"))); +} + +TEST_F(DocumentValidatorTest, ValidateEmbeddingEmptySignatureOk) { + PropertyProto::VectorProto vector; + vector.add_values(0.1); + vector.add_values(0.2); + vector.add_values(0.3); + vector.set_model_signature(""); + DocumentProto email = + DocumentBuilder() + .SetKey(kDefaultNamespace, "email_with_note/1") + .SetSchema(kTypeEmailWithNote) + .AddStringProperty(kPropertySubject, kDefaultString) + .AddStringProperty(kPropertyText, kDefaultString) + .AddStringProperty(kPropertyRecipients, kDefaultString, + kDefaultString, kDefaultString) + .AddStringProperty(kPropertyNote, kDefaultString) + .AddVectorProperty(kPropertyNoteEmbedding, vector) + .Build(); + ICING_EXPECT_OK(document_validator_->Validate(email)); +} + +TEST_F(DocumentValidatorTest, ValidateEmbeddingNoSignatureOk) { + PropertyProto::VectorProto vector; + vector.add_values(0.1); + vector.add_values(0.2); + vector.add_values(0.3); + DocumentProto email = + DocumentBuilder() + .SetKey(kDefaultNamespace, "email_with_note/1") + .SetSchema(kTypeEmailWithNote) + .AddStringProperty(kPropertySubject, kDefaultString) + .AddStringProperty(kPropertyText, kDefaultString) + .AddStringProperty(kPropertyRecipients, 
kDefaultString, + kDefaultString, kDefaultString) + .AddStringProperty(kPropertyNote, kDefaultString) + .AddVectorProperty(kPropertyNoteEmbedding, vector) + .Build(); + ICING_EXPECT_OK(document_validator_->Validate(email)); +} + } // namespace } // namespace lib diff --git a/icing/util/embedding-util.h b/icing/util/embedding-util.h new file mode 100644 index 0000000..5026051 --- /dev/null +++ b/icing/util/embedding-util.h @@ -0,0 +1,49 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_UTIL_EMBEDDING_UTIL_H_ +#define ICING_UTIL_EMBEDDING_UTIL_H_ + +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/proto/search.pb.h" + +namespace icing { +namespace lib { + +namespace embedding_util { + +inline libtextclassifier3::StatusOr< + SearchSpecProto::EmbeddingQueryMetricType::Code> +GetEmbeddingQueryMetricTypeFromName(std::string_view metric_name) { + if (metric_name == "COSINE") { + return SearchSpecProto::EmbeddingQueryMetricType::COSINE; + } else if (metric_name == "DOT_PRODUCT") { + return SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; + } else if (metric_name == "EUCLIDEAN") { + return SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN; + } + return absl_ports::InvalidArgumentError( + absl_ports::StrCat("Unknown metric type: ", metric_name)); +} + +} // namespace embedding_util + +} // namespace lib +} // namespace icing + +#endif // ICING_UTIL_EMBEDDING_UTIL_H_ diff --git a/icing/util/tokenized-document.cc b/icing/util/tokenized-document.cc index 19aaddf..e10fe25 100644 --- a/icing/util/tokenized-document.cc +++ b/icing/util/tokenized-document.cc @@ -14,16 +14,18 @@ #include "icing/util/tokenized-document.h" -#include <string> +#include <memory> #include <string_view> +#include <utility> #include <vector> -#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/proto/document.pb.h" #include "icing/schema/joinable-property.h" #include "icing/schema/schema-store.h" #include "icing/schema/section.h" #include "icing/tokenization/language-segmenter.h" +#include "icing/tokenization/token.h" #include "icing/tokenization/tokenizer-factory.h" #include "icing/tokenization/tokenizer.h" #include "icing/util/document-validator.h" @@ -85,6 +87,7 @@ TokenizedDocument::Create(const SchemaStore* schema_store, return 
TokenizedDocument(std::move(document), std::move(tokenized_string_sections), std::move(section_group.integer_sections), + std::move(section_group.vector_sections), std::move(joinable_property_group)); } diff --git a/icing/util/tokenized-document.h b/icing/util/tokenized-document.h index 7cc34e3..0337083 100644 --- a/icing/util/tokenized-document.h +++ b/icing/util/tokenized-document.h @@ -16,7 +16,8 @@ #define ICING_STORE_TOKENIZED_DOCUMENT_H_ #include <cstdint> -#include <string> +#include <string_view> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" @@ -63,6 +64,11 @@ class TokenizedDocument { return integer_sections_; } + const std::vector<Section<PropertyProto::VectorProto>>& vector_sections() + const { + return vector_sections_; + } + const std::vector<JoinableProperty<std::string_view>>& qualified_id_join_properties() const { return joinable_property_group_.qualified_id_properties; @@ -74,15 +80,18 @@ class TokenizedDocument { DocumentProto&& document, std::vector<TokenizedSection>&& tokenized_string_sections, std::vector<Section<int64_t>>&& integer_sections, + std::vector<Section<PropertyProto::VectorProto>>&& vector_sections, JoinablePropertyGroup&& joinable_property_group) : document_(std::move(document)), tokenized_string_sections_(std::move(tokenized_string_sections)), integer_sections_(std::move(integer_sections)), + vector_sections_(std::move(vector_sections)), joinable_property_group_(std::move(joinable_property_group)) {} DocumentProto document_; std::vector<TokenizedSection> tokenized_string_sections_; std::vector<Section<int64_t>> integer_sections_; + std::vector<Section<PropertyProto::VectorProto>> vector_sections_; JoinablePropertyGroup joinable_property_group_; }; diff --git a/icing/util/tokenized-document_test.cc b/icing/util/tokenized-document_test.cc index 7c97776..ab7f4b9 100644 --- a/icing/util/tokenized-document_test.cc +++ b/icing/util/tokenized-document_test.cc @@ -16,6 +16,8 @@ #include 
<memory> #include <string> +#include <string_view> +#include <utility> #include <vector> #include "gmock/gmock.h" @@ -25,7 +27,6 @@ #include "icing/portable/platform.h" #include "icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" -#include "icing/proto/term.pb.h" #include "icing/schema-builder.h" #include "icing/schema/joinable-property.h" #include "icing/schema/schema-store.h" @@ -59,13 +60,19 @@ static constexpr std::string_view kIndexableIntegerProperty1 = "indexableInteger1"; static constexpr std::string_view kIndexableIntegerProperty2 = "indexableInteger2"; +static constexpr std::string_view kIndexableVectorProperty1 = + "indexableVector1"; +static constexpr std::string_view kIndexableVectorProperty2 = + "indexableVector2"; static constexpr std::string_view kStringExactProperty = "stringExact"; static constexpr std::string_view kStringPrefixProperty = "stringPrefix"; static constexpr SectionId kIndexableInteger1SectionId = 0; static constexpr SectionId kIndexableInteger2SectionId = 1; -static constexpr SectionId kStringExactSectionId = 2; -static constexpr SectionId kStringPrefixSectionId = 3; +static constexpr SectionId kIndexableVector1SectionId = 2; +static constexpr SectionId kIndexableVector2SectionId = 3; +static constexpr SectionId kStringExactSectionId = 4; +static constexpr SectionId kStringPrefixSectionId = 5; // Joinable properties and joinable property id. Joinable property id is // determined by the lexicographical order of joinable property path. 
@@ -77,19 +84,33 @@ static constexpr JoinablePropertyId kQualifiedId2JoinablePropertyId = 1; const SectionMetadata kIndexableInteger1SectionMetadata( kIndexableInteger1SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, - NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty1)); + NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN, + std::string(kIndexableIntegerProperty1)); const SectionMetadata kIndexableInteger2SectionMetadata( kIndexableInteger2SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, - NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty2)); + NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN, + std::string(kIndexableIntegerProperty2)); + +const SectionMetadata kIndexableVector1SectionMetadata( + kIndexableVector1SectionId, TYPE_VECTOR, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, + NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_LINEAR_SEARCH, + std::string(kIndexableVectorProperty1)); + +const SectionMetadata kIndexableVector2SectionMetadata( + kIndexableVector2SectionId, TYPE_VECTOR, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, + NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_LINEAR_SEARCH, + std::string(kIndexableVectorProperty2)); const SectionMetadata kStringExactSectionMetadata( kStringExactSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT, - NUMERIC_MATCH_UNKNOWN, std::string(kStringExactProperty)); + NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN, + std::string(kStringExactProperty)); const SectionMetadata kStringPrefixSectionMetadata( kStringPrefixSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_PREFIX, - NUMERIC_MATCH_UNKNOWN, std::string(kStringPrefixProperty)); + NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN, + std::string(kStringPrefixProperty)); const JoinablePropertyMetadata kQualifiedId1JoinablePropertyMetadata( kQualifiedId1JoinablePropertyId, TYPE_STRING, @@ -102,6 +123,7 @@ const JoinablePropertyMetadata kQualifiedId2JoinablePropertyMetadata( // Other non-indexable/joinable properties. 
constexpr std::string_view kUnindexedStringProperty = "unindexedString"; constexpr std::string_view kUnindexedIntegerProperty = "unindexedInteger"; +constexpr std::string_view kUnindexedVectorProperty = "unindexedVector"; class TokenizedDocumentTest : public ::testing::Test { protected: @@ -140,6 +162,10 @@ class TokenizedDocumentTest : public ::testing::Test { .SetDataType(TYPE_INT64) .SetCardinality(CARDINALITY_OPTIONAL)) .AddProperty(PropertyConfigBuilder() + .SetName(kUnindexedVectorProperty) + .SetDataType(TYPE_VECTOR) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() .SetName(kIndexableIntegerProperty1) .SetDataTypeInt64(NUMERIC_MATCH_RANGE) .SetCardinality(CARDINALITY_REPEATED)) @@ -147,6 +173,16 @@ class TokenizedDocumentTest : public ::testing::Test { .SetName(kIndexableIntegerProperty2) .SetDataTypeInt64(NUMERIC_MATCH_RANGE) .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName(kIndexableVectorProperty1) + .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty( + PropertyConfigBuilder() + .SetName(kIndexableVectorProperty2) + .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL)) .AddProperty(PropertyConfigBuilder() .SetName(kStringExactProperty) .SetDataTypeString(TERM_MATCH_EXACT, @@ -196,6 +232,16 @@ class TokenizedDocumentTest : public ::testing::Test { }; TEST_F(TokenizedDocumentTest, CreateAll) { + PropertyProto::VectorProto vector1; + vector1.set_model_signature("my_model1"); + vector1.add_values(1.0f); + vector1.add_values(2.0f); + PropertyProto::VectorProto vector2; + vector2.set_model_signature("my_model2"); + vector2.add_values(-1.0f); + vector2.add_values(-2.0f); + vector2.add_values(-3.0f); + DocumentProto document = DocumentBuilder() .SetKey("icing", "fake_type/1") @@ -208,6 +254,10 @@ TEST_F(TokenizedDocumentTest, CreateAll) { 
.AddInt64Property(std::string(kUnindexedIntegerProperty), 789) .AddInt64Property(std::string(kIndexableIntegerProperty1), 1, 2, 3) .AddInt64Property(std::string(kIndexableIntegerProperty2), 456) + .AddVectorProperty(std::string(kUnindexedVectorProperty), vector1) + .AddVectorProperty(std::string(kIndexableVectorProperty1), vector1, + vector2) + .AddVectorProperty(std::string(kIndexableVectorProperty2), vector1) .AddStringProperty(std::string(kQualifiedId1), "pkg$db/ns#uri1") .AddStringProperty(std::string(kQualifiedId2), "pkg$db/ns#uri2") .Build(); @@ -244,6 +294,17 @@ TEST_F(TokenizedDocumentTest, CreateAll) { EXPECT_THAT(tokenized_document.integer_sections().at(1).content, ElementsAre(456)); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), SizeIs(2)); + EXPECT_THAT(tokenized_document.vector_sections().at(0).metadata, + Eq(kIndexableVector1SectionMetadata)); + EXPECT_THAT(tokenized_document.vector_sections().at(0).content, + ElementsAre(EqualsProto(vector1), EqualsProto(vector2))); + EXPECT_THAT(tokenized_document.vector_sections().at(1).metadata, + Eq(kIndexableVector2SectionMetadata)); + EXPECT_THAT(tokenized_document.vector_sections().at(1).content, + ElementsAre(EqualsProto(vector1))); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), SizeIs(2)); EXPECT_THAT(tokenized_document.qualified_id_join_properties().at(0).metadata, @@ -278,6 +339,9 @@ TEST_F(TokenizedDocumentTest, CreateNoIndexableIntegerProperties) { // integer sections EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); } @@ -314,6 +378,9 @@ TEST_F(TokenizedDocumentTest, CreateMultipleIndexableIntegerProperties) { EXPECT_THAT(tokenized_document.integer_sections().at(1).content, ElementsAre(456)); + // vector sections + 
EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); } @@ -341,6 +408,9 @@ TEST_F(TokenizedDocumentTest, CreateNoIndexableStringProperties) { // integer sections EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); } @@ -381,6 +451,92 @@ TEST_F(TokenizedDocumentTest, CreateMultipleIndexableStringProperties) { // integer sections EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + + // Qualified id join properties + EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); +} + +TEST_F(TokenizedDocumentTest, CreateNoIndexableVectorProperties) { + PropertyProto::VectorProto vector; + vector.set_model_signature("my_model"); + vector.add_values(1.0f); + + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddVectorProperty(std::string(kUnindexedVectorProperty), vector) + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + + EXPECT_THAT(tokenized_document.document(), EqualsProto(document)); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0)); + + // string sections + EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty()); + + // integer sections + EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + + // Qualified id join properties + EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); +} + 
+TEST_F(TokenizedDocumentTest, CreateMultipleIndexableVectorProperties) { + PropertyProto::VectorProto vector1; + vector1.set_model_signature("my_model1"); + vector1.add_values(1.0f); + vector1.add_values(2.0f); + PropertyProto::VectorProto vector2; + vector2.set_model_signature("my_model2"); + vector2.add_values(-1.0f); + vector2.add_values(-2.0f); + vector2.add_values(-3.0f); + + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddVectorProperty(std::string(kUnindexedVectorProperty), vector1) + .AddVectorProperty(std::string(kIndexableVectorProperty1), vector1, + vector2) + .AddVectorProperty(std::string(kIndexableVectorProperty2), vector1) + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + + EXPECT_THAT(tokenized_document.document(), EqualsProto(document)); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0)); + + // string sections + EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty()); + + // integer sections + EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), SizeIs(2)); + EXPECT_THAT(tokenized_document.vector_sections().at(0).metadata, + Eq(kIndexableVector1SectionMetadata)); + EXPECT_THAT(tokenized_document.vector_sections().at(0).content, + ElementsAre(EqualsProto(vector1), EqualsProto(vector2))); + EXPECT_THAT(tokenized_document.vector_sections().at(1).metadata, + Eq(kIndexableVector2SectionMetadata)); + EXPECT_THAT(tokenized_document.vector_sections().at(1).content, + ElementsAre(EqualsProto(vector1))); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); } @@ -408,6 +564,9 @@ TEST_F(TokenizedDocumentTest, CreateNoJoinQualifiedIdProperties) { // integer sections 
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); } @@ -437,6 +596,9 @@ TEST_F(TokenizedDocumentTest, CreateMultipleJoinQualifiedIdProperties) { // integer sections EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), SizeIs(2)); EXPECT_THAT(tokenized_document.qualified_id_join_properties().at(0).metadata, diff --git a/proto/icing/proto/document.proto b/proto/icing/proto/document.proto index 1a501e7..919769e 100644 --- a/proto/icing/proto/document.proto +++ b/proto/icing/proto/document.proto @@ -78,7 +78,7 @@ message DocumentProto { } // Holds a property field of the Document. -// Next tag: 8 +// Next tag: 9 message PropertyProto { // Name of the property. // See icing.lib.PropertyConfigProto.property_name for details. @@ -92,6 +92,17 @@ message PropertyProto { repeated bool boolean_values = 5; repeated bytes bytes_values = 6; repeated DocumentProto document_values = 7; + + message VectorProto { + // The values of the vector. + repeated float values = 1 [packed = true]; + // The model signature of the vector, which can be any string used to + // identify the model, so that embedding searches can be restricted only to + // the vectors with the matching target signature. 
+ // Eg: "universal-sentence-encoder_v0" + optional string model_signature = 2; + } + repeated VectorProto vector_values = 8; } // Result of a call to IcingSearchEngine.Put diff --git a/proto/icing/proto/logging.proto b/proto/icing/proto/logging.proto index 4854521..46e988e 100644 --- a/proto/icing/proto/logging.proto +++ b/proto/icing/proto/logging.proto @@ -23,7 +23,7 @@ option java_multiple_files = true; option objc_class_prefix = "ICNG"; // Stats of the top-level function IcingSearchEngine::Initialize(). -// Next tag: 14 +// Next tag: 15 message InitializeStatsProto { // Overall time used for the function call. optional int32 latency_ms = 1; @@ -121,10 +121,16 @@ message InitializeStatsProto { // - SCHEMA_CHANGES_OUT_OF_SYNC // - IO_ERROR optional RecoveryCause qualified_id_join_index_restoration_cause = 13; + + // Possible recovery causes for embedding index: + // - INCONSISTENT_WITH_GROUND_TRUTH + // - SCHEMA_CHANGES_OUT_OF_SYNC + // - IO_ERROR + optional RecoveryCause embedding_index_restoration_cause = 14; } // Stats of the top-level function IcingSearchEngine::Put(). -// Next tag: 12 +// Next tag: 13 message PutDocumentStatsProto { // Overall time used for the function call. optional int32 latency_ms = 1; @@ -170,6 +176,9 @@ message PutDocumentStatsProto { // Time used to index all metadata terms in the document, which can only be // added by PropertyExistenceIndexingHandler currently. optional int32 metadata_term_index_latency_ms = 11; + + // Time used to index all embeddings in the document. + optional int32 embedding_index_latency_ms = 12; } // Stats of the top-level function IcingSearchEngine::Search() and diff --git a/proto/icing/proto/schema.proto b/proto/icing/proto/schema.proto index 78e1588..99439bb 100644 --- a/proto/icing/proto/schema.proto +++ b/proto/icing/proto/schema.proto @@ -184,6 +184,26 @@ message IntegerIndexingConfig { optional NumericMatchType.Code numeric_match_type = 1; } +// Describes how a vector property should be indexed. 
+// Next tag: 3 +message EmbeddingIndexingConfig { + // OPTIONAL: Indicates how the vector contents of this property should be + // matched. + // + // The default value is UNKNOWN. + message EmbeddingIndexingType { + enum Code { + // Contents in this property will not be indexed. Useful if the vector + // property type is not indexable. + UNKNOWN = 0; + + // Contents in this property will be indexed for linear search. + LINEAR_SEARCH = 1; + } + } + optional EmbeddingIndexingType.Code embedding_indexing_type = 1; +} + // Describes how a property can be used to join this document with another // document. See JoinSpecProto (in search.proto) for more details. // Next tag: 3 @@ -215,7 +235,7 @@ message JoinableConfig { // Describes the schema of a single property of Documents that belong to a // specific SchemaTypeConfigProto. These can be considered as a rich, structured // type for each property of Documents accepted by IcingSearchEngine. -// Next tag: 10 +// Next tag: 11 message PropertyConfigProto { // REQUIRED: Name that uniquely identifies a property within an Document of // a specific SchemaTypeConfigProto. @@ -252,6 +272,9 @@ message PropertyConfigProto { // a hierarchical Document schema. Any property using this data_type // MUST have a valid 'schema_type'. DOCUMENT = 6; + + // A list of floats. Vector type is used for embedding searches. + VECTOR = 7; } } optional DataType.Code data_type = 2; @@ -313,6 +336,10 @@ message PropertyConfigProto { // - The property itself and any upper-level (nested doc) property should // contain at most one element (i.e. Cardinality is OPTIONAL or REQUIRED). optional JoinableConfig joinable_config = 8; + + // OPTIONAL: Describes how vector properties should be indexed. Vector + // properties that do not set the indexing config will not be indexed. + optional EmbeddingIndexingConfig embedding_indexing_config = 10; } // List of all supported types constitutes the schema used by Icing. 
diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto index 7f4fb3e..7e145ce 100644 --- a/proto/icing/proto/search.proto +++ b/proto/icing/proto/search.proto @@ -27,7 +27,7 @@ option java_multiple_files = true; option objc_class_prefix = "ICNG"; // Client-supplied specifications on what documents to retrieve. -// Next tag: 11 +// Next tag: 13 message SearchSpecProto { // REQUIRED: The "raw" query string that users may type. For example, "cat" // will search for documents with the term cat in it. @@ -112,6 +112,22 @@ message SearchSpecProto { // (TypePropertyMask for the given schema type has empty paths field), no // properties of that schema type will be searched. repeated TypePropertyMask type_property_filters = 10; + + // The vectors to be used in embedding queries. + repeated PropertyProto.VectorProto embedding_query_vectors = 11; + + message EmbeddingQueryMetricType { + enum Code { + UNKNOWN = 0; + COSINE = 1; + DOT_PRODUCT = 2; + EUCLIDEAN = 3; + } + } + + // The default metric type used to calculate the scores for embedding + // queries. + optional EmbeddingQueryMetricType.Code embedding_query_metric_type = 12; } // Client-supplied specifications on what to include/how to format the search diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt index c22a1e8..55c4647 100644 --- a/synced_AOSP_CL_number.txt +++ b/synced_AOSP_CL_number.txt @@ -1 +1 @@ -set(synced_AOSP_CL_number=613977851) +set(synced_AOSP_CL_number=616925123) |