author     Jiayu Hu <hujiayu@google.com>  2024-03-18 14:23:44 -0700
committer  Jiayu Hu <hujiayu@google.com>  2024-03-18 15:47:26 -0700
commit     555cb6e3295cf525baf46358235389ff52c9dcc2 (patch)
tree       22dea3a10ec58bc8a54e07fb8ecb08720e4f8b13
parent     29d4712b67d1ade154739e5fa9a9a7970afe6c0a (diff)
download   icing-555cb6e3295cf525baf46358235389ff52c9dcc2.tar.gz
Update Icing from upstream.
Descriptions:
========================================================================
Integration test for embedding search
========================================================================
Support embedding search in the advanced scoring language
========================================================================
Support embedding search in the advanced query language
========================================================================
Add missing header inclusions.
========================================================================
Assign default values to MemoryMappedFile members to avoid temporary object error during std::move
========================================================================
Use GetEmbeddingVector in EmbeddingIndex::TransferIndex
========================================================================
Refactor SectionRestrictData
========================================================================
Add EmbeddingIndex to IcingSearchEngine and create EmbeddingIndexingHandler
========================================================================
Support optimize for embedding index
========================================================================
Implement embedding search index
========================================================================
Branch posting list accessor and serializer to store embedding hits
========================================================================
Introduce embedding hit
========================================================================
Update schema and document definition
========================================================================
Add tests verifying that Icing will correctly save updated schema description fields.
========================================================================

BUG: 326987971
BUG: 326656531
BUG: 329747255

Change-Id: I84f0a5a3f7fe133ece16c567a1f5f44e7866fd77
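For orientation, here is a minimal sketch of the new embedding query and scoring surface described above, assembled from the EmbeddingSearch test added in icing-search-engine_search_test.cc later on this page (CreateVector comes from the new icing/testing/embedding-test-utils.h); it is illustrative, not a complete program.

    SearchSpecProto search_spec;
    search_spec.set_term_match_type(TermMatchType::EXACT_ONLY);
    search_spec.set_embedding_query_metric_type(
        SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT);
    // The test also enables kListFilterQueryLanguageFeature and the advanced
    // query search type.
    search_spec.add_enabled_features(std::string(kEmbeddingSearchFeature));
    *search_spec.add_embedding_query_vectors() =
        CreateVector("my_model_v1", {1, -1, -1, 1, -1});
    // Match documents whose embeddings score at least -1 against query vector 0.
    search_spec.set_query("semanticSearch(getSearchSpecEmbedding(0), -1)");

    ScoringSpecProto scoring_spec;
    scoring_spec.set_rank_by(
        ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION);
    // Rank each document by the sum of its matched semantic scores.
    scoring_spec.set_advanced_scoring_expression(
        "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))");

    SearchResultProto results = icing.Search(
        search_spec, scoring_spec, ResultSpecProto::default_instance());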
-rw-r--r--  icing/document-builder.h | 18
-rw-r--r--  icing/file/memory-mapped-file.h | 15
-rw-r--r--  icing/icing-search-engine.cc | 164
-rw-r--r--  icing/icing-search-engine.h | 13
-rw-r--r--  icing/icing-search-engine_schema_test.cc | 82
-rw-r--r--  icing/icing-search-engine_search_test.cc | 213
-rw-r--r--  icing/index/embed/doc-hit-info-iterator-embedding.cc | 164
-rw-r--r--  icing/index/embed/doc-hit-info-iterator-embedding.h | 161
-rw-r--r--  icing/index/embed/embedding-hit.h | 67
-rw-r--r--  icing/index/embed/embedding-hit_test.cc | 80
-rw-r--r--  icing/index/embed/embedding-index.cc | 440
-rw-r--r--  icing/index/embed/embedding-index.h | 274
-rw-r--r--  icing/index/embed/embedding-index_test.cc | 582
-rw-r--r--  icing/index/embed/embedding-query-results.h | 72
-rw-r--r--  icing/index/embed/embedding-scorer.cc | 95
-rw-r--r--  icing/index/embed/embedding-scorer.h | 54
-rw-r--r--  icing/index/embed/posting-list-embedding-hit-accessor.cc | 132
-rw-r--r--  icing/index/embed/posting-list-embedding-hit-accessor.h | 106
-rw-r--r--  icing/index/embed/posting-list-embedding-hit-accessor_test.cc | 387
-rw-r--r--  icing/index/embed/posting-list-embedding-hit-serializer.cc | 647
-rw-r--r--  icing/index/embed/posting-list-embedding-hit-serializer.h | 284
-rw-r--r--  icing/index/embed/posting-list-embedding-hit-serializer_test.cc | 864
-rw-r--r--  icing/index/embedding-indexing-handler.cc | 85
-rw-r--r--  icing/index/embedding-indexing-handler.h | 70
-rw-r--r--  icing/index/embedding-indexing-handler_test.cc | 620
-rw-r--r--  icing/index/hit/hit.h | 1
-rw-r--r--  icing/index/iterator/doc-hit-info-iterator-property-in-schema.h | 6
-rw-r--r--  icing/index/iterator/doc-hit-info-iterator-section-restrict.cc | 22
-rw-r--r--  icing/index/iterator/doc-hit-info-iterator.h | 6
-rw-r--r--  icing/index/iterator/section-restrict-data.cc | 21
-rw-r--r--  icing/index/iterator/section-restrict-data.h | 14
-rw-r--r--  icing/query/advanced_query_parser/function.cc | 24
-rw-r--r--  icing/query/advanced_query_parser/function.h | 2
-rw-r--r--  icing/query/advanced_query_parser/param.h | 10
-rw-r--r--  icing/query/advanced_query_parser/pending-value.cc | 26
-rw-r--r--  icing/query/advanced_query_parser/pending-value.h | 39
-rw-r--r--  icing/query/advanced_query_parser/query-visitor.cc | 139
-rw-r--r--  icing/query/advanced_query_parser/query-visitor.h | 72
-rw-r--r--  icing/query/advanced_query_parser/query-visitor_test.cc | 1491
-rw-r--r--  icing/query/query-features.h | 8
-rw-r--r--  icing/query/query-processor.cc | 22
-rw-r--r--  icing/query/query-processor.h | 4
-rw-r--r--  icing/query/query-processor_benchmark.cc | 54
-rw-r--r--  icing/query/query-processor_test.cc | 102
-rw-r--r--  icing/query/query-results.h | 7
-rw-r--r--  icing/query/suggestion-processor.cc | 19
-rw-r--r--  icing/query/suggestion-processor.h | 4
-rw-r--r--  icing/query/suggestion-processor_test.cc | 44
-rw-r--r--  icing/schema-builder.h | 28
-rw-r--r--  icing/schema/property-util.cc | 10
-rw-r--r--  icing/schema/property-util.h | 8
-rw-r--r--  icing/schema/property-util_test.cc | 82
-rw-r--r--  icing/schema/schema-store_test.cc | 14
-rw-r--r--  icing/schema/schema-util.cc | 15
-rw-r--r--  icing/schema/schema-util_test.cc | 112
-rw-r--r--  icing/schema/section-manager.cc | 14
-rw-r--r--  icing/schema/section-manager_test.cc | 100
-rw-r--r--  icing/schema/section.h | 17
-rw-r--r--  icing/scoring/advanced_scoring/advanced-scorer.cc | 24
-rw-r--r--  icing/scoring/advanced_scoring/advanced-scorer.h | 17
-rw-r--r--  icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc | 17
-rw-r--r--  icing/scoring/advanced_scoring/advanced-scorer_test.cc | 884
-rw-r--r--  icing/scoring/advanced_scoring/score-expression.cc | 128
-rw-r--r--  icing/scoring/advanced_scoring/score-expression.h | 139
-rw-r--r--  icing/scoring/advanced_scoring/score-expression_test.cc | 37
-rw-r--r--  icing/scoring/advanced_scoring/scoring-visitor.cc | 24
-rw-r--r--  icing/scoring/advanced_scoring/scoring-visitor.h | 18
-rw-r--r--  icing/scoring/score-and-rank_benchmark.cc | 47
-rw-r--r--  icing/scoring/scorer-factory.cc | 22
-rw-r--r--  icing/scoring/scorer-factory.h | 11
-rw-r--r--  icing/scoring/scorer_test.cc | 259
-rw-r--r--  icing/scoring/scoring-processor.cc | 23
-rw-r--r--  icing/scoring/scoring-processor.h | 10
-rw-r--r--  icing/scoring/scoring-processor_test.cc | 155
-rw-r--r--  icing/testing/embedding-test-utils.h | 45
-rw-r--r--  icing/testing/hit-test-utils.cc | 24
-rw-r--r--  icing/testing/hit-test-utils.h | 14
-rw-r--r--  icing/util/document-validator.cc | 22
-rw-r--r--  icing/util/document-validator_test.cc | 76
-rw-r--r--  icing/util/embedding-util.h | 49
-rw-r--r--  icing/util/tokenized-document.cc | 7
-rw-r--r--  icing/util/tokenized-document.h | 11
-rw-r--r--  icing/util/tokenized-document_test.cc | 176
-rw-r--r--  proto/icing/proto/document.proto | 13
-rw-r--r--  proto/icing/proto/logging.proto | 13
-rw-r--r--  proto/icing/proto/schema.proto | 29
-rw-r--r--  proto/icing/proto/search.proto | 18
-rw-r--r--  synced_AOSP_CL_number.txt | 2
88 files changed, 9665 insertions, 875 deletions
diff --git a/icing/document-builder.h b/icing/document-builder.h
index 44500f9..5d6f14d 100644
--- a/icing/document-builder.h
+++ b/icing/document-builder.h
@@ -126,6 +126,13 @@ class DocumentBuilder {
return AddDocumentProperty(std::move(property_name), {document_values...});
}
+ // Takes a property name and any number of vector values.
+ template <typename... V>
+ DocumentBuilder& AddVectorProperty(std::string property_name,
+ V... vector_values) {
+ return AddVectorProperty(std::move(property_name), {vector_values...});
+ }
+
DocumentProto Build() const { return document_; }
private:
@@ -183,6 +190,17 @@ class DocumentBuilder {
}
return *this;
}
+
+ DocumentBuilder& AddVectorProperty(
+ std::string property_name,
+ std::initializer_list<PropertyProto::VectorProto> vector_values) {
+ auto property = document_.add_properties();
+ property->set_name(std::move(property_name));
+ for (PropertyProto::VectorProto vector_value : vector_values) {
+ property->mutable_vector_values()->Add(std::move(vector_value));
+ }
+ return *this;
+ }
};
} // namespace lib
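A short usage sketch of the new AddVectorProperty overload, mirroring the document setup used in the EmbeddingSearch test later on this page (CreateVector is the helper from the new icing/testing/embedding-test-utils.h); a sketch, not part of the patch itself.

    DocumentProto document =
        DocumentBuilder()
            .SetKey("icing", "uri0")
            .SetSchema("Email")
            // One vector value for "embedding1"...
            .AddVectorProperty(
                "embedding1",
                CreateVector("my_model_v1", {0.1, 0.2, 0.3, 0.4, 0.5}))
            // ...and two vector values for "embedding2".
            .AddVectorProperty(
                "embedding2",
                CreateVector("my_model_v1", {-0.1, -0.2, -0.3, 0.4, 0.5}),
                CreateVector("my_model_v2", {0.6, 0.7, 0.8}))
            .Build();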
diff --git a/icing/file/memory-mapped-file.h b/icing/file/memory-mapped-file.h
index 54507af..185d940 100644
--- a/icing/file/memory-mapped-file.h
+++ b/icing/file/memory-mapped-file.h
@@ -79,7 +79,6 @@
#include <algorithm>
#include <cstdint>
-#include <memory>
#include <string>
#include <string_view>
@@ -314,7 +313,7 @@ class MemoryMappedFile {
// Cached constructor params.
const Filesystem* filesystem_;
std::string file_path_;
- Strategy strategy_;
+ Strategy strategy_ = Strategy::READ_WRITE_AUTO_SYNC;
// Raw file related fields:
// - max_file_size_
@@ -327,7 +326,7 @@ class MemoryMappedFile {
//
// Note: max_file_size_ will be specified at runtime and the caller should
// make sure its value is correct and reasonable.
- int64_t max_file_size_;
+ int64_t max_file_size_ = 0;
// Cached file size to avoid calling system call too frequently. It is only
// used in GrowAndRemapIfNecessary(), the new API that handles underlying file
@@ -336,7 +335,7 @@ class MemoryMappedFile {
// Note: it is guaranteed that file_size_ is smaller or equal to the actual
// file size as long as the underlying file hasn't been truncated or deleted
// externally. See GrowFileSize() for more details.
- int64_t file_size_;
+ int64_t file_size_ = 0;
// Memory mapped related fields:
// - mmap_result_
@@ -345,22 +344,22 @@ class MemoryMappedFile {
// - mmap_size_
// Raw pointer (or error) returned by calls to mmap().
- void* mmap_result_;
+ void* mmap_result_ = nullptr;
// Offset within the file at which the current memory-mapped region starts.
- int64_t file_offset_;
+ int64_t file_offset_ = 0;
// Size that is currently memory-mapped.
// Note that the mmapped size can be larger than the underlying file size. We
// can reduce remapping by pre-mmapping a large region and growing the file
// size later. See GrowAndRemapIfNecessary().
- int64_t mmap_size_;
+ int64_t mmap_size_ = 0;
// The difference between file_offset_ and the actual adjusted (aligned)
// offset.
// Since mmap requires the offset to be a multiple of system page size, we
// have to align file_offset_ to the last multiple of system page size.
- int64_t alignment_adjustment_;
+ int64_t alignment_adjustment_ = 0;
// E.g. system_page_size = 5, RemapImpl(/*new_file_offset=*/8, mmap_size)
//
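The commit description attributes these in-class defaults to avoiding a "temporary object error during std::move". As a hedged illustration of the general pattern (a hypothetical stand-in type, not Icing code): with default member initializers, an instance that is moved before its members are explicitly assigned never exposes indeterminate values.

    #include <cstdint>
    #include <utility>

    struct Mapped {               // hypothetical stand-in, not MemoryMappedFile
      void* result = nullptr;     // analogous to mmap_result_
      int64_t size = 0;           // analogous to mmap_size_
    };

    int main() {
      Mapped a;                   // members have well-defined defaults
      Mapped b = std::move(a);    // moving before any explicit assignment is safe
      return static_cast<int>(b.size);  // 0, no indeterminate reads
    }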
diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc
index 72e744b..89caaf1 100644
--- a/icing/icing-search-engine.cc
+++ b/icing/icing-search-engine.cc
@@ -37,6 +37,8 @@
#include "icing/file/filesystem.h"
#include "icing/file/version-util.h"
#include "icing/index/data-indexing-handler.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embedding-indexing-handler.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/index-processor.h"
#include "icing/index/index.h"
@@ -108,6 +110,7 @@ constexpr std::string_view kIndexSubfolderName = "index_dir";
constexpr std::string_view kIntegerIndexSubfolderName = "integer_index_dir";
constexpr std::string_view kQualifiedIdJoinIndexSubfolderName =
"qualified_id_join_index_dir";
+constexpr std::string_view kEmbeddingIndexSubfolderName = "embedding_index_dir";
constexpr std::string_view kSchemaSubfolderName = "schema_dir";
constexpr std::string_view kSetSchemaMarkerFilename = "set_schema_marker";
constexpr std::string_view kInitMarkerFilename = "init_marker";
@@ -292,6 +295,11 @@ std::string MakeQualifiedIdJoinIndexWorkingPath(const std::string& base_dir) {
return absl_ports::StrCat(base_dir, "/", kQualifiedIdJoinIndexSubfolderName);
}
+// Working path for embedding index.
+std::string MakeEmbeddingIndexWorkingPath(const std::string& base_dir) {
+ return absl_ports::StrCat(base_dir, "/", kEmbeddingIndexSubfolderName);
+}
+
// SchemaStore files are in a standalone subfolder for easier file management.
// We can delete and recreate the subfolder and not touch/affect anything
// else.
@@ -478,6 +486,7 @@ void IcingSearchEngine::ResetMembers() {
index_.reset();
integer_index_.reset();
qualified_id_join_index_.reset();
+ embedding_index_.reset();
}
libtextclassifier3::Status IcingSearchEngine::CheckInitMarkerFile(
@@ -668,15 +677,19 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
MakeIntegerIndexWorkingPath(options_.base_dir());
const std::string qualified_id_join_index_dir =
MakeQualifiedIdJoinIndexWorkingPath(options_.base_dir());
+ const std::string embedding_index_dir =
+ MakeEmbeddingIndexWorkingPath(options_.base_dir());
if (!filesystem_->DeleteDirectoryRecursively(doc_store_dir.c_str()) ||
!filesystem_->DeleteDirectoryRecursively(index_dir.c_str()) ||
!IntegerIndex::Discard(*filesystem_, integer_index_dir).ok() ||
!QualifiedIdJoinIndex::Discard(*filesystem_,
qualified_id_join_index_dir)
- .ok()) {
+ .ok() ||
+ !EmbeddingIndex::Discard(*filesystem_, embedding_index_dir).ok()) {
return absl_ports::InternalError(absl_ports::StrCat(
"Could not delete directories: ", index_dir, ", ", integer_index_dir,
- ", ", qualified_id_join_index_dir, " and ", doc_store_dir));
+ ", ", qualified_id_join_index_dir, ", ", embedding_index_dir, " and ",
+ doc_store_dir));
}
ICING_ASSIGN_OR_RETURN(
bool document_store_derived_files_regenerated,
@@ -734,6 +747,15 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
CreateQualifiedIdJoinIndex(
*filesystem_, std::move(qualified_id_join_index_dir), options_));
+ // Discard embedding index directory and instantiate a new one.
+ std::string embedding_index_dir =
+ MakeEmbeddingIndexWorkingPath(options_.base_dir());
+ ICING_RETURN_IF_ERROR(
+ EmbeddingIndex::Discard(*filesystem_, embedding_index_dir));
+ ICING_ASSIGN_OR_RETURN(
+ embedding_index_,
+ EmbeddingIndex::Create(filesystem_.get(), embedding_index_dir));
+
std::unique_ptr<Timer> restore_timer = clock_->GetNewTimer();
IndexRestorationResult restore_result = RestoreIndexIfNeeded();
index_init_status = std::move(restore_result.status);
@@ -756,6 +778,8 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC);
initialize_stats->set_qualified_id_join_index_restoration_cause(
InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC);
+ initialize_stats->set_embedding_index_restoration_cause(
+ InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC);
} else if (version_state_change != version_util::StateChange::kCompatible) {
ICING_ASSIGN_OR_RETURN(bool document_store_derived_files_regenerated,
InitializeDocumentStore(
@@ -777,6 +801,8 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
InitializeStatsProto::VERSION_CHANGED);
initialize_stats->set_qualified_id_join_index_restoration_cause(
InitializeStatsProto::VERSION_CHANGED);
+ initialize_stats->set_embedding_index_restoration_cause(
+ InitializeStatsProto::VERSION_CHANGED);
} else {
ICING_ASSIGN_OR_RETURN(
bool document_store_derived_files_regenerated,
@@ -813,6 +839,7 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
initialize_stats->set_qualified_id_join_index_restoration_cause(
InitializeStatsProto::FEATURE_FLAG_CHANGED);
}
+ // TODO(b/326656531): Update version-util to consider embedding index.
}
if (status.ok()) {
@@ -981,6 +1008,30 @@ libtextclassifier3::Status IcingSearchEngine::InitializeIndex(
}
}
+ // Embedding index
+ const std::string embedding_dir =
+ MakeEmbeddingIndexWorkingPath(options_.base_dir());
+ InitializeStatsProto::RecoveryCause embedding_index_recovery_cause;
+ auto embedding_index_or =
+ EmbeddingIndex::Create(filesystem_.get(), embedding_dir);
+ if (!embedding_index_or.ok()) {
+ ICING_RETURN_IF_ERROR(EmbeddingIndex::Discard(*filesystem_, embedding_dir));
+
+ embedding_index_recovery_cause = InitializeStatsProto::IO_ERROR;
+
+ // Try recreating it from scratch and re-indexing everything.
+ ICING_ASSIGN_OR_RETURN(
+ embedding_index_,
+ EmbeddingIndex::Create(filesystem_.get(), embedding_dir));
+ } else {
+ // Embedding index was created fine.
+ embedding_index_ = std::move(embedding_index_or).ValueOrDie();
+ // If a recovery does have to happen, then it must be because the index is
+ // out of sync with the document store.
+ embedding_index_recovery_cause =
+ InitializeStatsProto::INCONSISTENT_WITH_GROUND_TRUTH;
+ }
+
std::unique_ptr<Timer> restore_timer = clock_->GetNewTimer();
IndexRestorationResult restore_result = RestoreIndexIfNeeded();
if (restore_result.index_needed_restoration ||
@@ -1000,6 +1051,10 @@ libtextclassifier3::Status IcingSearchEngine::InitializeIndex(
initialize_stats->set_qualified_id_join_index_restoration_cause(
qualified_id_join_index_recovery_cause);
}
+ if (restore_result.embedding_index_needed_restoration) {
+ initialize_stats->set_embedding_index_restoration_cause(
+ embedding_index_recovery_cause);
+ }
}
return restore_result.status;
}
@@ -1515,9 +1570,9 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery(
std::unique_ptr<Timer> component_timer = clock_->GetNewTimer();
// Gets unordered results from query processor
auto query_processor_or = QueryProcessor::Create(
- index_.get(), integer_index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(), schema_store_.get(),
- clock_.get());
+ index_.get(), integer_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(), document_store_.get(),
+ schema_store_.get(), clock_.get());
if (!query_processor_or.ok()) {
TransformStatus(query_processor_or.status(), result_status);
delete_stats->set_parse_query_latency_ms(
@@ -1713,6 +1768,15 @@ OptimizeResultProto IcingSearchEngine::Optimize() {
<< qualified_id_join_index_optimize_status.error_message();
should_rebuild_index = true;
}
+
+ libtextclassifier3::Status embedding_index_optimize_status =
+ embedding_index_->Optimize(optimize_result.document_id_old_to_new,
+ document_store_->last_added_document_id());
+ if (!embedding_index_optimize_status.ok()) {
+ ICING_LOG(WARNING) << "Failed to optimize embedding index. Error: "
+ << embedding_index_optimize_status.error_message();
+ should_rebuild_index = true;
+ }
}
// If we received a DATA_LOSS error from OptimizeDocumentStore, we have a
// valid document store, but it might be the old one or the new one. So throw
@@ -1939,6 +2003,7 @@ libtextclassifier3::Status IcingSearchEngine::InternalPersistToDisk(
ICING_RETURN_IF_ERROR(index_->PersistToDisk());
ICING_RETURN_IF_ERROR(integer_index_->PersistToDisk());
ICING_RETURN_IF_ERROR(qualified_id_join_index_->PersistToDisk());
+ ICING_RETURN_IF_ERROR(embedding_index_->PersistToDisk());
return libtextclassifier3::Status::OK;
}
@@ -2220,9 +2285,9 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore(
// Gets unordered results from query processor
auto query_processor_or = QueryProcessor::Create(
- index_.get(), integer_index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(), schema_store_.get(),
- clock_.get());
+ index_.get(), integer_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(), document_store_.get(),
+ schema_store_.get(), clock_.get());
if (!query_processor_or.ok()) {
search_stats->set_parse_query_latency_ms(
component_timer->GetElapsedMilliseconds());
@@ -2266,8 +2331,10 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore(
// Scores but does not rank the results.
libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>>
scoring_processor_or = ScoringProcessor::Create(
- scoring_spec, document_store_.get(), schema_store_.get(),
- current_time_ms, join_children_fetcher);
+ scoring_spec, /*default_semantic_metric_type=*/
+ search_spec.embedding_query_metric_type(), document_store_.get(),
+ schema_store_.get(), current_time_ms, join_children_fetcher,
+ &query_results.embedding_query_results);
if (!scoring_processor_or.ok()) {
return QueryScoringResults(std::move(scoring_processor_or).status(),
std::move(query_results.query_terms),
@@ -2504,27 +2571,28 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
if (last_stored_document_id == index_->last_added_document_id() &&
last_stored_document_id == integer_index_->last_added_document_id() &&
last_stored_document_id ==
- qualified_id_join_index_->last_added_document_id()) {
+ qualified_id_join_index_->last_added_document_id() &&
+ last_stored_document_id == embedding_index_->last_added_document_id()) {
// No need to recover.
- return {libtextclassifier3::Status::OK, false, false, false};
+ return {libtextclassifier3::Status::OK, false, false, false, false};
}
if (last_stored_document_id == kInvalidDocumentId) {
// Document store is empty but index is not. Clear the index.
- return {ClearAllIndices(), false, false, false};
+ return {ClearAllIndices(), false, false, false, false};
}
// Truncate indices first.
auto truncate_result_or = TruncateIndicesTo(last_stored_document_id);
if (!truncate_result_or.ok()) {
- return {std::move(truncate_result_or).status(), false, false, false};
+ return {std::move(truncate_result_or).status(), false, false, false, false};
}
TruncateIndexResult truncate_result =
std::move(truncate_result_or).ValueOrDie();
if (truncate_result.first_document_to_reindex > last_stored_document_id) {
// Nothing to restore. Just return.
- return {libtextclassifier3::Status::OK, false, false, false};
+ return {libtextclassifier3::Status::OK, false, false, false, false};
}
auto data_indexing_handlers_or = CreateDataIndexingHandlers();
@@ -2532,7 +2600,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
return {data_indexing_handlers_or.status(),
truncate_result.index_needed_restoration,
truncate_result.integer_index_needed_restoration,
- truncate_result.qualified_id_join_index_needed_restoration};
+ truncate_result.qualified_id_join_index_needed_restoration,
+ truncate_result.embedding_index_needed_restoration};
}
// By using recovery_mode for IndexProcessor, we're able to replay documents
// from a smaller document id, and it will skip documents that have already been
@@ -2559,7 +2628,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
// Returns other errors
return {document_or.status(), truncate_result.index_needed_restoration,
truncate_result.integer_index_needed_restoration,
- truncate_result.qualified_id_join_index_needed_restoration};
+ truncate_result.qualified_id_join_index_needed_restoration,
+ truncate_result.embedding_index_needed_restoration};
}
}
DocumentProto document(std::move(document_or).ValueOrDie());
@@ -2572,7 +2642,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
return {tokenized_document_or.status(),
truncate_result.index_needed_restoration,
truncate_result.integer_index_needed_restoration,
- truncate_result.qualified_id_join_index_needed_restoration};
+ truncate_result.qualified_id_join_index_needed_restoration,
+ truncate_result.embedding_index_needed_restoration};
}
TokenizedDocument tokenized_document(
std::move(tokenized_document_or).ValueOrDie());
@@ -2584,7 +2655,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
// Real error. Stop recovering and pass it up.
return {status, truncate_result.index_needed_restoration,
truncate_result.integer_index_needed_restoration,
- truncate_result.qualified_id_join_index_needed_restoration};
+ truncate_result.qualified_id_join_index_needed_restoration,
+ truncate_result.embedding_index_needed_restoration};
}
// FIXME: why can we skip data loss error here?
// Just a data loss. Keep trying to add the remaining docs, but report the
@@ -2595,7 +2667,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
return {overall_status, truncate_result.index_needed_restoration,
truncate_result.integer_index_needed_restoration,
- truncate_result.qualified_id_join_index_needed_restoration};
+ truncate_result.qualified_id_join_index_needed_restoration,
+ truncate_result.embedding_index_needed_restoration};
}
libtextclassifier3::StatusOr<bool> IcingSearchEngine::LostPreviousSchema() {
@@ -2648,6 +2721,11 @@ IcingSearchEngine::CreateDataIndexingHandlers() {
clock_.get(), document_store_.get(), qualified_id_join_index_.get()));
handlers.push_back(std::move(qualified_id_join_indexing_handler));
+ // Embedding index handler
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<EmbeddingIndexingHandler> embedding_indexing_handler,
+ EmbeddingIndexingHandler::Create(clock_.get(), embedding_index_.get()));
+ handlers.push_back(std::move(embedding_indexing_handler));
return handlers;
}
@@ -2736,10 +2814,43 @@ IcingSearchEngine::TruncateIndicesTo(DocumentId last_stored_document_id) {
first_document_to_reindex = kMinDocumentId;
}
+ // Attempt to truncate embedding index
+ bool embedding_index_needed_restoration = false;
+ DocumentId embedding_index_last_added_document_id =
+ embedding_index_->last_added_document_id();
+ if (embedding_index_last_added_document_id == kInvalidDocumentId ||
+ last_stored_document_id > embedding_index_last_added_document_id) {
+ // If last_stored_document_id is greater than
+ // embedding_index_last_added_document_id, then we only have to replay docs
+ // starting from (embedding_index_last_added_document_id + 1). Also use
+ // std::min since we might need to replay even smaller doc ids for other
+ // components.
+ embedding_index_needed_restoration = true;
+ if (embedding_index_last_added_document_id != kInvalidDocumentId) {
+ first_document_to_reindex =
+ std::min(first_document_to_reindex,
+ embedding_index_last_added_document_id + 1);
+ } else {
+ first_document_to_reindex = kMinDocumentId;
+ }
+ } else if (last_stored_document_id < embedding_index_last_added_document_id) {
+ // Clear the entire embedding index if last_stored_document_id is
+ // smaller than embedding_index_last_added_document_id, because
+ // there is no way to remove data with doc_id > last_stored_document_id from
+ // embedding index efficiently and we have to rebuild.
+ ICING_RETURN_IF_ERROR(embedding_index_->Clear());
+
+ // Since the entire embedding index is discarded, we start to
+ // rebuild it by setting first_document_to_reindex to kMinDocumentId.
+ embedding_index_needed_restoration = true;
+ first_document_to_reindex = kMinDocumentId;
+ }
+
return TruncateIndexResult(first_document_to_reindex,
index_needed_restoration,
integer_index_needed_restoration,
- qualified_id_join_index_needed_restoration);
+ qualified_id_join_index_needed_restoration,
+ embedding_index_needed_restoration);
}
libtextclassifier3::Status IcingSearchEngine::DiscardDerivedFiles(
@@ -2750,7 +2861,7 @@ libtextclassifier3::Status IcingSearchEngine::DiscardDerivedFiles(
if (schema_store_ != nullptr || document_store_ != nullptr ||
index_ != nullptr || integer_index_ != nullptr ||
- qualified_id_join_index_ != nullptr) {
+ qualified_id_join_index_ != nullptr || embedding_index_ != nullptr) {
return absl_ports::FailedPreconditionError(
"Cannot discard derived files while having valid instances");
}
@@ -2792,12 +2903,15 @@ libtextclassifier3::Status IcingSearchEngine::DiscardDerivedFiles(
}
}
+ // TODO(b/326656531): Update version-util to consider embedding index.
+
return libtextclassifier3::Status::OK;
}
libtextclassifier3::Status IcingSearchEngine::ClearSearchIndices() {
ICING_RETURN_IF_ERROR(index_->Reset());
ICING_RETURN_IF_ERROR(integer_index_->Clear());
+ ICING_RETURN_IF_ERROR(embedding_index_->Clear());
return libtextclassifier3::Status::OK;
}
@@ -2870,9 +2984,9 @@ SuggestionResponse IcingSearchEngine::SearchSuggestions(
// Create the suggestion processor.
auto suggestion_processor_or = SuggestionProcessor::Create(
- index_.get(), integer_index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(), schema_store_.get(),
- clock_.get());
+ index_.get(), integer_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(), document_store_.get(),
+ schema_store_.get(), clock_.get());
if (!suggestion_processor_or.ok()) {
TransformStatus(suggestion_processor_or.status(), response_status);
return response;
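To make the embedding-index branch added to TruncateIndicesTo above concrete, here is a simplified sketch of its decision rule with illustrative document ids (restoration flags omitted; -1 stands in for kInvalidDocumentId and 0 for kMinDocumentId); the real logic lives in the hunk above.

    #include <algorithm>

    // Returns the first document id to reindex for the embedding index.
    int FirstDocToReindex(int last_stored, int embedding_last_added,
                          int first_so_far) {
      if (embedding_last_added == -1) return 0;           // nothing indexed yet
      if (last_stored > embedding_last_added)             // index is behind
        return std::min(first_so_far, embedding_last_added + 1);
      if (last_stored < embedding_last_added) return 0;   // index is ahead: Clear() and rebuild
      return first_so_far;                                // already in sync
    }

    // Examples: FirstDocToReindex(5, 3, 10) == 4  (replay docs 4..5)
    //           FirstDocToReindex(5, 8, 10) == 0  (clear and rebuild everything)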
diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h
index b9df95d..57f0f28 100644
--- a/icing/icing-search-engine.h
+++ b/icing/icing-search-engine.h
@@ -28,6 +28,7 @@
#include "icing/file/filesystem.h"
#include "icing/file/version-util.h"
#include "icing/index/data-indexing-handler.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
#include "icing/index/numeric/numeric-index.h"
#include "icing/jni/jni-cache.h"
@@ -483,6 +484,9 @@ class IcingSearchEngine {
std::unique_ptr<QualifiedIdJoinIndex> qualified_id_join_index_
ICING_GUARDED_BY(mutex_);
+ // Storage for all hits of embedding contents from the document store.
+ std::unique_ptr<EmbeddingIndex> embedding_index_ ICING_GUARDED_BY(mutex_);
+
// Pointer to JNI class references
const std::unique_ptr<const JniCache> jni_cache_;
@@ -701,6 +705,7 @@ class IcingSearchEngine {
bool index_needed_restoration;
bool integer_index_needed_restoration;
bool qualified_id_join_index_needed_restoration;
+ bool embedding_index_needed_restoration;
};
IndexRestorationResult RestoreIndexIfNeeded()
ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
@@ -741,17 +746,21 @@ class IcingSearchEngine {
bool index_needed_restoration;
bool integer_index_needed_restoration;
bool qualified_id_join_index_needed_restoration;
+ bool embedding_index_needed_restoration;
explicit TruncateIndexResult(
DocumentId first_document_to_reindex_in,
bool index_needed_restoration_in,
bool integer_index_needed_restoration_in,
- bool qualified_id_join_index_needed_restoration_in)
+ bool qualified_id_join_index_needed_restoration_in,
+ bool embedding_index_needed_restoration_in)
: first_document_to_reindex(first_document_to_reindex_in),
index_needed_restoration(index_needed_restoration_in),
integer_index_needed_restoration(integer_index_needed_restoration_in),
qualified_id_join_index_needed_restoration(
- qualified_id_join_index_needed_restoration_in) {}
+ qualified_id_join_index_needed_restoration_in),
+ embedding_index_needed_restoration(
+ embedding_index_needed_restoration_in) {}
};
libtextclassifier3::StatusOr<TruncateIndexResult> TruncateIndicesTo(
DocumentId last_stored_document_id)
diff --git a/icing/icing-search-engine_schema_test.cc b/icing/icing-search-engine_schema_test.cc
index 49c024e..34081e9 100644
--- a/icing/icing-search-engine_schema_test.cc
+++ b/icing/icing-search-engine_schema_test.cc
@@ -36,7 +36,6 @@
#include "icing/proto/persist.pb.h"
#include "icing/proto/reset.pb.h"
#include "icing/proto/schema.pb.h"
-#include "icing/proto/scoring.pb.h"
#include "icing/proto/search.pb.h"
#include "icing/proto/status.pb.h"
#include "icing/proto/storage.pb.h"
@@ -60,6 +59,7 @@ namespace {
using ::icing::lib::portable_equals_proto::EqualsProto;
using ::testing::Eq;
using ::testing::HasSubstr;
+using ::testing::Not;
using ::testing::Return;
// For mocking purpose, we allow tests to provide a custom Filesystem.
@@ -3154,6 +3154,86 @@ TEST_F(IcingSearchEngineSchemaTest, IcingShouldReturnErrorForExtraSections) {
HasSubstr("Too many properties to be indexed"));
}
+TEST_F(IcingSearchEngineSchemaTest, UpdatedTypeDescriptionIsSaved) {
+ // Create a schema with a type description and a property description.
+ PropertyConfigProto old_property =
+ PropertyConfigBuilder()
+ .SetName("prop0")
+ .SetDescription("old property description")
+ .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)
+ .Build();
+ SchemaTypeConfigProto old_schema_type_config =
+ SchemaTypeConfigBuilder()
+ .SetType("type")
+ .SetDescription("old description")
+ .AddProperty(old_property)
+ .Build();
+ SchemaProto old_schema =
+ SchemaBuilder().AddType(old_schema_type_config).Build();
+
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ ASSERT_THAT(icing.SetSchema(old_schema).status(), ProtoIsOk());
+
+ // Update the type description
+ SchemaTypeConfigProto new_schema_type_config =
+ SchemaTypeConfigBuilder(old_schema_type_config)
+ .SetDescription("new description")
+ .Build();
+ SchemaProto new_schema =
+ SchemaBuilder().AddType(new_schema_type_config).Build();
+ ASSERT_THAT(icing.SetSchema(new_schema).status(), ProtoIsOk());
+
+ GetSchemaResultProto get_result = icing.GetSchema();
+ ASSERT_THAT(get_result.status(), ProtoIsOk());
+ ASSERT_THAT(get_result.schema(), EqualsProto(new_schema));
+ ASSERT_THAT(get_result.schema(), Not(EqualsProto(old_schema)));
+}
+
+TEST_F(IcingSearchEngineSchemaTest, UpdatedPropertyDescriptionIsSaved) {
+ // Create a schema with a type description and a property description.
+ PropertyConfigProto old_property =
+ PropertyConfigBuilder()
+ .SetName("prop0")
+ .SetDescription("old property description")
+ .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)
+ .Build();
+ SchemaTypeConfigProto old_schema_type_config =
+ SchemaTypeConfigBuilder()
+ .SetType("type")
+ .SetDescription("old description")
+ .AddProperty(old_property)
+ .Build();
+ SchemaProto old_schema =
+ SchemaBuilder().AddType(old_schema_type_config).Build();
+
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ ASSERT_THAT(icing.SetSchema(old_schema).status(), ProtoIsOk());
+
+ // Update the property description
+ PropertyConfigProto new_property =
+ PropertyConfigBuilder(old_property)
+ .SetDescription("new property description")
+ .Build();
+ SchemaTypeConfigProto new_schema_type_config =
+ SchemaTypeConfigBuilder()
+ .SetType("type")
+ .SetDescription("old description")
+ .AddProperty(new_property)
+ .Build();
+ SchemaProto new_schema =
+ SchemaBuilder().AddType(new_schema_type_config).Build();
+ ASSERT_THAT(icing.SetSchema(new_schema).status(), ProtoIsOk());
+
+ GetSchemaResultProto get_result = icing.GetSchema();
+ ASSERT_THAT(get_result.status(), ProtoIsOk());
+ ASSERT_THAT(get_result.schema(), EqualsProto(new_schema));
+ ASSERT_THAT(get_result.schema(), Not(EqualsProto(old_schema)));
+}
+
} // namespace
} // namespace lib
} // namespace icing
diff --git a/icing/icing-search-engine_search_test.cc b/icing/icing-search-engine_search_test.cc
index d815f61..a58dbc8 100644
--- a/icing/icing-search-engine_search_test.cc
+++ b/icing/icing-search-engine_search_test.cc
@@ -13,12 +13,14 @@
// limitations under the License.
#include <cstdint>
+#include <initializer_list>
#include <limits>
#include <memory>
#include <string>
+#include <string_view>
#include <utility>
+#include <vector>
-#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
@@ -27,7 +29,7 @@
#include "icing/index/lite/term-id-hit-pair.h"
#include "icing/jni/jni-cache.h"
#include "icing/join/join-processor.h"
-#include "icing/portable/endian.h"
+#include "icing/legacy/index/icing-filesystem.h"
#include "icing/portable/equals-proto.h"
#include "icing/portable/platform.h"
#include "icing/proto/debug.pb.h"
@@ -49,11 +51,13 @@
#include "icing/result/result-state-manager.h"
#include "icing/schema-builder.h"
#include "icing/testing/common-matchers.h"
+#include "icing/testing/embedding-test-utils.h"
#include "icing/testing/fake-clock.h"
#include "icing/testing/icu-data-file-helper.h"
#include "icing/testing/jni-test-helpers.h"
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
+#include "icing/util/clock.h"
#include "icing/util/snippet-helpers.h"
namespace icing {
@@ -63,6 +67,7 @@ namespace {
using ::icing::lib::portable_equals_proto::EqualsProto;
using ::testing::DoubleEq;
+using ::testing::DoubleNear;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Gt;
@@ -120,6 +125,8 @@ class IcingSearchEngineSearchTest
// Non-zero value so we don't override it to be the current time
constexpr int64_t kDefaultCreationTimestampMs = 1575492852000;
+constexpr double kEps = 0.000001;
+
IcingSearchEngineOptions GetDefaultIcingOptions() {
IcingSearchEngineOptions icing_options;
icing_options.set_base_dir(GetTestBaseDir());
@@ -7306,6 +7313,208 @@ TEST_P(IcingSearchEngineSearchTest, HasPropertyQueryNestedDocument) {
EXPECT_THAT(results.results(), IsEmpty());
}
+TEST_P(IcingSearchEngineSearchTest, EmbeddingSearch) {
+ if (GetParam() !=
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ GTEST_SKIP() << "Embedding search is only supported in advanced query.";
+ }
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("Email")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("body")
+ .SetDataTypeString(TERM_MATCH_EXACT,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("embedding1")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("embedding2")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_REPEATED)))
+ .Build();
+ DocumentProto document0 =
+ DocumentBuilder()
+ .SetKey("icing", "uri0")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(1)
+ .AddStringProperty("body", "foo")
+ .AddVectorProperty(
+ "embedding1",
+ CreateVector("my_model_v1", {0.1, 0.2, 0.3, 0.4, 0.5}))
+ .AddVectorProperty(
+ "embedding2",
+ CreateVector("my_model_v1", {-0.1, -0.2, -0.3, 0.4, 0.5}),
+ CreateVector("my_model_v2", {0.6, 0.7, 0.8}))
+ .Build();
+ DocumentProto document1 =
+ DocumentBuilder()
+ .SetKey("icing", "uri1")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(1)
+ .AddVectorProperty(
+ "embedding1",
+ CreateVector("my_model_v1", {-0.1, 0.2, -0.3, -0.4, 0.5}))
+ .AddVectorProperty("embedding2",
+ CreateVector("my_model_v2", {0.6, 0.7, -0.8}))
+ .Build();
+
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document0).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk());
+
+ SearchSpecProto search_spec;
+ search_spec.set_term_match_type(TermMatchType::EXACT_ONLY);
+ search_spec.set_embedding_query_metric_type(
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT);
+ search_spec.add_enabled_features(
+ std::string(kListFilterQueryLanguageFeature));
+ search_spec.add_enabled_features(std::string(kEmbeddingSearchFeature));
+ search_spec.set_search_type(GetParam());
+ // Add an embedding query with semantic scores:
+ // - document 0: -0.5 (embedding1), 0.3 (embedding2)
+ // - document 1: -0.9 (embedding1)
+ *search_spec.add_embedding_query_vectors() =
+ CreateVector("my_model_v1", {1, -1, -1, 1, -1});
+ // Add an embedding query with semantic scores:
+ // - document 0: -0.5 (embedding2)
+ // - document 1: -2.1 (embedding2)
+ *search_spec.add_embedding_query_vectors() =
+ CreateVector("my_model_v2", {-1, -1, 1});
+ ScoringSpecProto scoring_spec = GetDefaultScoringSpec();
+ scoring_spec.set_rank_by(
+ ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION);
+
+ // Match documents that have embeddings with a similarity closer to 0 that is
+ // greater than -1.
+ //
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding1), 0.3 (embedding2)
+ // - document 1: -0.9 (embedding1)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5, 0.3}) + sum({}) = -0.2
+ // - document 1: sum({-0.9}) + sum({}) = -0.9
+ search_spec.set_query("semanticSearch(getSearchSpecEmbedding(0), -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0))) + "
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ SearchResultProto results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5 + 0.3, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9, kEps));
+
+ // Create a query the same as above but with a section restriction, which
+ // still matches document 0 and document 1 but the semantic score 0.3 should
+ // be removed from document 0.
+ //
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding1)
+ // - document 1: -0.9 (embedding1)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5}) = -0.5
+ // - document 1: sum({-0.9}) = -0.9
+ search_spec.set_query(
+ "embedding1:semanticSearch(getSearchSpecEmbedding(0), -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9, kEps));
+
+ // Create a query that only matches document 0.
+ //
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5}) = -0.5
+ search_spec.set_query("semanticSearch(getSearchSpecEmbedding(1), -1.5)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(1));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5, kEps));
+
+ // Create a query that only matches document 1.
+ //
+ // The matched embeddings for each doc are:
+ // - document 1: -2.1 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 1: sum({-2.1}) = -2.1
+ search_spec.set_query("semanticSearch(getSearchSpecEmbedding(1), -10, -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(1));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-2.1, kEps));
+
+ // Create a complex query that matches all hits from all documents.
+ //
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding1), 0.3 (embedding2), -0.5 (embedding2)
+ // - document 1: -0.9 (embedding1), -2.1 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5, 0.3}) + sum({-0.5}) = -0.7
+ // - document 1: sum({-0.9}) + sum({-2.1}) = -3
+ search_spec.set_query(
+ "semanticSearch(getSearchSpecEmbedding(0)) OR "
+ "semanticSearch(getSearchSpecEmbedding(1))");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0))) + "
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5 + 0.3 - 0.5, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9 - 2.1, kEps));
+
+ // Create a hybrid query that matches document 0 because of term-based search
+ // and document 1 because of embedding-based search.
+ //
+ // The matched embeddings for each doc are:
+ // - document 1: -2.1 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({}) = 0
+ // - document 1: sum({-2.1}) = -2.1
+ search_spec.set_query(
+ "foo OR semanticSearch(getSearchSpecEmbedding(1), -10, -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ // Document 0 has no matched embedding hit, so its score is 0.
+ EXPECT_THAT(results.results(0).score(), DoubleNear(0, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-2.1, kEps));
+}
+
INSTANTIATE_TEST_SUITE_P(
IcingSearchEngineSearchTest, IcingSearchEngineSearchTest,
testing::Values(
diff --git a/icing/index/embed/doc-hit-info-iterator-embedding.cc b/icing/index/embed/doc-hit-info-iterator-embedding.cc
new file mode 100644
index 0000000..8f2bc7d
--- /dev/null
+++ b/icing/index/embed/doc-hit-info-iterator-embedding.cc
@@ -0,0 +1,164 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/doc-hit-info-iterator-embedding.h"
+
+#include <memory>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/embed/embedding-scorer.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/hit/hit.h"
+#include "icing/index/iterator/section-restrict-data.h"
+#include "icing/proto/search.pb.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIteratorEmbedding>>
+DocHitInfoIteratorEmbedding::Create(
+ const PropertyProto::VectorProto* query,
+ std::unique_ptr<SectionRestrictData> section_restrict_data,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ double score_low, double score_high,
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
+ const EmbeddingIndex* embedding_index) {
+ ICING_RETURN_ERROR_IF_NULL(query);
+ ICING_RETURN_ERROR_IF_NULL(embedding_index);
+ ICING_RETURN_ERROR_IF_NULL(score_map);
+
+ libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ pl_accessor_or = embedding_index->GetAccessorForVector(*query);
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
+ if (pl_accessor_or.ok()) {
+ pl_accessor = std::move(pl_accessor_or).ValueOrDie();
+ } else if (absl_ports::IsNotFound(pl_accessor_or.status())) {
+ // A not-found error should be fine, since that means there are no matching
+ // embedding hits in the index.
+ pl_accessor = nullptr;
+ } else {
+ // Otherwise, return the error as is.
+ return pl_accessor_or.status();
+ }
+
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<EmbeddingScorer> embedding_scorer,
+ EmbeddingScorer::Create(metric_type));
+
+ return std::unique_ptr<DocHitInfoIteratorEmbedding>(
+ new DocHitInfoIteratorEmbedding(query, std::move(section_restrict_data),
+ metric_type, std::move(embedding_scorer),
+ score_low, score_high, score_map,
+ embedding_index, std::move(pl_accessor)));
+}
+
+libtextclassifier3::StatusOr<const EmbeddingHit*>
+DocHitInfoIteratorEmbedding::AdvanceToNextEmbeddingHit() {
+ if (cached_embedding_hits_idx_ == cached_embedding_hits_.size()) {
+ ICING_ASSIGN_OR_RETURN(cached_embedding_hits_,
+ posting_list_accessor_->GetNextHitsBatch());
+ cached_embedding_hits_idx_ = 0;
+ if (cached_embedding_hits_.empty()) {
+ no_more_hit_ = true;
+ return nullptr;
+ }
+ }
+ const EmbeddingHit& embedding_hit =
+ cached_embedding_hits_[cached_embedding_hits_idx_];
+ if (doc_hit_info_.document_id() == kInvalidDocumentId) {
+ doc_hit_info_.set_document_id(embedding_hit.basic_hit().document_id());
+ if (section_restrict_data_ != nullptr) {
+ current_allowed_sections_mask_ =
+ section_restrict_data_->ComputeAllowedSectionsMask(
+ doc_hit_info_.document_id());
+ }
+ } else if (doc_hit_info_.document_id() !=
+ embedding_hit.basic_hit().document_id()) {
+ return nullptr;
+ }
+ ++cached_embedding_hits_idx_;
+ return &embedding_hit;
+}
+
+libtextclassifier3::Status DocHitInfoIteratorEmbedding::Advance() {
+ if (no_more_hit_ || posting_list_accessor_ == nullptr) {
+ return absl_ports::ResourceExhaustedError(
+ "No more DocHitInfos in iterator");
+ }
+
+ doc_hit_info_ = DocHitInfo(kInvalidDocumentId, kSectionIdMaskNone);
+ std::vector<double>* matched_scores = nullptr;
+ current_allowed_sections_mask_ = kSectionIdMaskAll;
+ while (true) {
+ ICING_ASSIGN_OR_RETURN(const EmbeddingHit* embedding_hit,
+ AdvanceToNextEmbeddingHit());
+ if (embedding_hit == nullptr) {
+ // No more hits for the current document.
+ break;
+ }
+
+ // Filter out the embedding hit according to the section restriction.
+ if (((UINT64_C(1) << embedding_hit->basic_hit().section_id()) &
+ current_allowed_sections_mask_) == 0) {
+ continue;
+ }
+
+ // Calculate the semantic score.
+ int dimension = query_.values_size();
+ ICING_ASSIGN_OR_RETURN(
+ const float* vector,
+ embedding_index_.GetEmbeddingVector(*embedding_hit, dimension));
+ double semantic_score =
+ embedding_scorer_->Score(dimension,
+ /*v1=*/query_.values().data(),
+ /*v2=*/vector);
+
+ // If the semantic score is within the desired score range, update
+ // doc_hit_info_ and score_map_.
+ if (score_low_ <= semantic_score && semantic_score <= score_high_) {
+ doc_hit_info_.UpdateSection(embedding_hit->basic_hit().section_id());
+ if (matched_scores == nullptr) {
+ matched_scores = &(score_map_[doc_hit_info_.document_id()]);
+ }
+ matched_scores->push_back(semantic_score);
+ }
+ }
+
+ if (doc_hit_info_.document_id() == kInvalidDocumentId) {
+ return absl_ports::ResourceExhaustedError(
+ "No more DocHitInfos in iterator");
+ }
+
+ ++num_advance_calls_;
+
+ // Skip the current document if it has no vector matched.
+ if (doc_hit_info_.hit_section_ids_mask() == kSectionIdMaskNone) {
+ return Advance();
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+} // namespace lib
+} // namespace icing
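The expected scores in the EmbeddingSearch test earlier on this page follow directly from the DOT_PRODUCT metric that this iterator delegates to EmbeddingScorer; a small self-contained sketch (not the EmbeddingScorer implementation) reproduces the first expected value.

    #include <cstddef>
    #include <vector>

    // DOT_PRODUCT metric: score(q, v) = sum_i q[i] * v[i].
    double DotProduct(const std::vector<float>& q, const std::vector<float>& v) {
      double score = 0.0;
      for (size_t i = 0; i < q.size() && i < v.size(); ++i) score += q[i] * v[i];
      return score;
    }

    // Query 0 = {1, -1, -1, 1, -1} against document 0's "embedding1"
    // {0.1, 0.2, 0.3, 0.4, 0.5}:
    //   0.1 - 0.2 - 0.3 + 0.4 - 0.5 = -0.5,
    // which is the matched semantic score the test asserts for that hit.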
diff --git a/icing/index/embed/doc-hit-info-iterator-embedding.h b/icing/index/embed/doc-hit-info-iterator-embedding.h
new file mode 100644
index 0000000..21180db
--- /dev/null
+++ b/icing/index/embed/doc-hit-info-iterator-embedding.h
@@ -0,0 +1,161 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
+#define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
+
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/embed/embedding-scorer.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/iterator/section-restrict-data.h"
+#include "icing/proto/search.pb.h"
+#include "icing/schema/section.h"
+
+namespace icing {
+namespace lib {
+
+class DocHitInfoIteratorEmbedding : public DocHitInfoLeafIterator {
+ public:
+ // Create a DocHitInfoIterator for iterating through all docs which have an
+ // embedding matched with the provided query with a score in the range of
+ // [score_low, score_high], using the provided metric_type.
+ //
+ // The iterator will store the matched embedding scores in score_map to
+ // prepare for scoring.
+ //
+ // The iterator will handle the section restriction logic internally by the
+ // provided section_restrict_data.
+ //
+ // Returns:
+ // - a DocHitInfoIteratorEmbedding instance on success.
+ // - Any error from posting lists.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<DocHitInfoIteratorEmbedding>>
+ Create(const PropertyProto::VectorProto* query,
+ std::unique_ptr<SectionRestrictData> section_restrict_data,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ double score_low, double score_high,
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
+ const EmbeddingIndex* embedding_index);
+
+ libtextclassifier3::Status Advance() override;
+
+ // The iterator handles the section restriction logic internally for better
+ // control, so that it can filter out embedding hits from unwanted sections
+ // and avoid retrieving unnecessary vectors and calculating scores for them.
+ bool full_section_restriction_applied() const override { return true; }
+
+ libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
+ return absl_ports::InvalidArgumentError(
+ "Query suggestions for the semanticSearch function are not supported");
+ }
+
+ CallStats GetCallStats() const override {
+ return CallStats(
+ /*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_,
+ /*num_leaf_advance_calls_main_index_in=*/0,
+ /*num_leaf_advance_calls_integer_index_in=*/0,
+ /*num_leaf_advance_calls_no_index_in=*/0,
+ /*num_blocks_inspected_in=*/0);
+ }
+
+ std::string ToString() const override { return "embedding_iterator"; }
+
+ // PopulateMatchedTermsStats is not applicable to embedding search.
+ void PopulateMatchedTermsStats(
+ std::vector<TermMatchInfo>* matched_terms_stats,
+ SectionIdMask filtering_section_mask) const override {}
+
+ private:
+ explicit DocHitInfoIteratorEmbedding(
+ const PropertyProto::VectorProto* query,
+ std::unique_ptr<SectionRestrictData> section_restrict_data,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low,
+ double score_high,
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
+ const EmbeddingIndex* embedding_index,
+ std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor)
+ : query_(*query),
+ section_restrict_data_(std::move(section_restrict_data)),
+ metric_type_(metric_type),
+ embedding_scorer_(std::move(embedding_scorer)),
+ score_low_(score_low),
+ score_high_(score_high),
+ score_map_(*score_map),
+ embedding_index_(*embedding_index),
+ posting_list_accessor_(std::move(posting_list_accessor)),
+ cached_embedding_hits_idx_(0),
+ current_allowed_sections_mask_(kSectionIdMaskAll),
+ no_more_hit_(false),
+ num_advance_calls_(0) {}
+
+ // Advance to the next embedding hit of the current document. If the current
+ // document id is kInvalidDocumentId, the method will advance to the first
+ // embedding hit of the next document and update doc_hit_info_.
+ //
+ // This method also properly updates cached_embedding_hits_,
+ // cached_embedding_hits_idx_, current_allowed_sections_mask_, and
+ // no_more_hit_ to reflect the current state.
+ //
+ // Returns:
+ // - a const pointer to the next embedding hit on success.
+ // - nullptr, if there is no more hit for the current document, or no more
+ // hit in general if the current document id is kInvalidDocumentId.
+ // - Any error from posting lists.
+ libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit();
+
+ // Query information
+ const PropertyProto::VectorProto& query_; // Does not own
+ std::unique_ptr<SectionRestrictData> section_restrict_data_; // Nullable.
+
+ // Scoring arguments
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
+ std::unique_ptr<EmbeddingScorer> embedding_scorer_;
+ double score_low_;
+ double score_high_;
+
+ // Score map
+ EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map_; // Does not own
+
+ // Access to embeddings index data
+ const EmbeddingIndex& embedding_index_;
+ std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_;
+
+ // Cached data from the embeddings index
+ std::vector<EmbeddingHit> cached_embedding_hits_;
+ int cached_embedding_hits_idx_;
+ SectionIdMask current_allowed_sections_mask_;
+ bool no_more_hit_;
+
+ int num_advance_calls_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
diff --git a/icing/index/embed/embedding-hit.h b/icing/index/embed/embedding-hit.h
new file mode 100644
index 0000000..165a123
--- /dev/null
+++ b/icing/index/embed/embedding-hit.h
@@ -0,0 +1,67 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_EMBEDDING_HIT_H_
+#define ICING_INDEX_EMBED_EMBEDDING_HIT_H_
+
+#include <cstdint>
+
+#include "icing/index/hit/hit.h"
+
+namespace icing {
+namespace lib {
+
+class EmbeddingHit {
+ public:
+ // Order: basic_hit, location
+ // Value bits layout: 32 basic_hit + 32 location
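+  //
+  // Illustrative example (hypothetical values): a BasicHit whose value is
+  // 0x0000ABCD packed with location 7 yields
+  //   Value = (0x0000ABCD << 32) | 7 = 0x0000ABCD00000007.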
+ using Value = uint64_t;
+
+ // WARNING: Changing this value will invalidate any pre-existing posting lists
+ // on user devices.
+ //
+ // kInvalidValue contains:
+ // - BasicHit of value 0, which is invalid.
+ // - location of 0 (valid), which is ok because BasicHit is already invalid.
+ static_assert(BasicHit::kInvalidValue == 0);
+ static constexpr Value kInvalidValue = 0;
+
+ explicit EmbeddingHit(BasicHit basic_hit, uint32_t location) {
+ value_ = (static_cast<uint64_t>(basic_hit.value()) << 32) | location;
+ }
+
+ explicit EmbeddingHit(Value value) : value_(value) {}
+
+ // BasicHit contains the document id and the section id.
+ BasicHit basic_hit() const { return BasicHit(value_ >> 32); }
+  // The location of the referenced embedding vector in the vector storage.
+  uint32_t location() const { return value_ & 0xFFFFFFFF; }
+
+ bool is_valid() const { return basic_hit().is_valid(); }
+ Value value() const { return value_; }
+
+ bool operator<(const EmbeddingHit& h2) const { return value_ < h2.value_; }
+ bool operator==(const EmbeddingHit& h2) const { return value_ == h2.value_; }
+
+ private:
+ Value value_;
+};
+
+static_assert(sizeof(BasicHit) == 4);
+static_assert(sizeof(EmbeddingHit) == 8);
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_EMBEDDING_HIT_H_
diff --git a/icing/index/embed/embedding-hit_test.cc b/icing/index/embed/embedding-hit_test.cc
new file mode 100644
index 0000000..31a0b6c
--- /dev/null
+++ b/icing/index/embed/embedding-hit_test.cc
@@ -0,0 +1,80 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/embedding-hit.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/index/hit/hit.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::IsFalse;
+
+static constexpr DocumentId kSomeDocumentId = 24;
+static constexpr SectionId kSomeSectionId = 5;
+static constexpr uint32_t kSomeLocation = 123;
+
+TEST(EmbeddingHitTest, Accessors) {
+  BasicHit basic_hit(kSomeSectionId, kSomeDocumentId);
+ EmbeddingHit embedding_hit(basic_hit, kSomeLocation);
+ EXPECT_THAT(embedding_hit.basic_hit(), Eq(basic_hit));
+ EXPECT_THAT(embedding_hit.location(), Eq(kSomeLocation));
+}
+
+TEST(EmbeddingHitTest, Invalid) {
+ EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue);
+ EXPECT_THAT(invalid_hit.is_valid(), IsFalse());
+
+ // Also make sure the invalid EmbeddingHit contains an invalid document id.
+ EXPECT_THAT(invalid_hit.basic_hit().document_id(), Eq(kInvalidDocumentId));
+ EXPECT_THAT(invalid_hit.basic_hit().section_id(), Eq(kMinSectionId));
+ EXPECT_THAT(invalid_hit.location(), Eq(0));
+}
+
+TEST(EmbeddingHitTest, Comparison) {
+ // Create basic hits with basic_hit1 < basic_hit2 < basic_hit3.
+ BasicHit basic_hit1(/*section_id=*/1, /*document_id=*/2409);
+ BasicHit basic_hit2(/*section_id=*/1, /*document_id=*/243);
+ BasicHit basic_hit3(/*section_id=*/15, /*document_id=*/243);
+
+ // Embedding hits are sorted by BasicHit first, and then by location.
+ // So embedding_hit3 < embedding_hit4 < embedding_hit2 < embedding_hit1.
+ EmbeddingHit embedding_hit1(basic_hit3, /*location=*/10);
+ EmbeddingHit embedding_hit2(basic_hit3, /*location=*/0);
+ EmbeddingHit embedding_hit3(basic_hit1, /*location=*/100);
+ EmbeddingHit embedding_hit4(basic_hit2, /*location=*/0);
+
+ std::vector<EmbeddingHit> embedding_hits{embedding_hit1, embedding_hit2,
+ embedding_hit3, embedding_hit4};
+ std::sort(embedding_hits.begin(), embedding_hits.end());
+ EXPECT_THAT(embedding_hits, ElementsAre(embedding_hit3, embedding_hit4,
+ embedding_hit2, embedding_hit1));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/embedding-index.cc b/icing/index/embed/embedding-index.cc
new file mode 100644
index 0000000..2e70f0e
--- /dev/null
+++ b/icing/index/embed/embedding-index.cc
@@ -0,0 +1,440 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/embedding-index.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/file/destructible-directory.h"
+#include "icing/file/file-backed-vector.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/memory-mapped-file.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/hit/hit.h"
+#include "icing/store/document-id.h"
+#include "icing/store/dynamic-trie-key-mapper.h"
+#include "icing/store/key-mapper.h"
+#include "icing/util/crc32.h"
+#include "icing/util/encode-util.h"
+#include "icing/util/logging.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+constexpr uint32_t kEmbeddingHitListMapperMaxSize =
+    128 * 1024 * 1024;  // 128 MiB
+
+// The maximum length returned by encode_util::EncodeIntToCString is 5 for
+// uint32_t.
+constexpr uint32_t kEncodedDimensionLength = 5;
+
+std::string GetMetadataFilePath(std::string_view working_path) {
+ return absl_ports::StrCat(working_path, "/metadata");
+}
+
+std::string GetFlashIndexStorageFilePath(std::string_view working_path) {
+ return absl_ports::StrCat(working_path, "/flash_index_storage");
+}
+
+std::string GetEmbeddingHitListMapperPath(std::string_view working_path) {
+ return absl_ports::StrCat(working_path, "/embedding_hit_list_mapper");
+}
+
+std::string GetEmbeddingVectorsFilePath(std::string_view working_path) {
+ return absl_ports::StrCat(working_path, "/embedding_vectors");
+}
+
+// An injective function that maps the ordered pair (dimension, model_signature)
+// to a string, which is used to form a key for embedding_posting_list_mapper_.
+std::string GetPostingListKey(uint32_t dimension,
+ std::string_view model_signature) {
+ std::string encoded_dimension_str =
+ encode_util::EncodeIntToCString(dimension);
+  // Pad encoded_dimension_str to a fixed length of kEncodedDimensionLength
+  // bytes.
+ while (encoded_dimension_str.size() < kEncodedDimensionLength) {
+    // A C string cannot contain 0 bytes, so we pad with bytes of value 1, just
+    // like what we do in encode_util::EncodeIntToCString.
+    //
+    // This works because DecodeIntFromCString decodes a byte value of 0x01 as
+    // 0x00. When EncodeIntToCString returns an encoded dimension that is less
+    // than 5 bytes, the dimension contains unencoded leading 0x00 bytes, so
+    // here we're explicitly encoding those bytes as 0x01.
+ encoded_dimension_str.push_back(1);
+ }
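+  // For example (illustrative): if the dimension encodes to fewer than
+  // kEncodedDimensionLength bytes, the loop above pads it with 0x01 bytes so
+  // that every key has a fixed 5-byte dimension prefix followed by the model
+  // signature.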
+ return absl_ports::StrCat(encoded_dimension_str, model_signature);
+}
+
+std::string GetPostingListKey(const PropertyProto::VectorProto& vector) {
+ return GetPostingListKey(vector.values_size(), vector.model_signature());
+}
+
+} // namespace
+
+libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>>
+EmbeddingIndex::Create(const Filesystem* filesystem, std::string working_path) {
+ ICING_RETURN_ERROR_IF_NULL(filesystem);
+
+ std::unique_ptr<EmbeddingIndex> index = std::unique_ptr<EmbeddingIndex>(
+ new EmbeddingIndex(*filesystem, std::move(working_path)));
+ ICING_RETURN_IF_ERROR(index->Initialize());
+ return index;
+}
+
+libtextclassifier3::Status EmbeddingIndex::Initialize() {
+ bool is_new = false;
+ if (!filesystem_.FileExists(GetMetadataFilePath(working_path_).c_str())) {
+ // Create working directory.
+ if (!filesystem_.CreateDirectoryRecursively(working_path_.c_str())) {
+ return absl_ports::InternalError(
+ absl_ports::StrCat("Failed to create directory: ", working_path_));
+ }
+ is_new = true;
+ }
+
+ ICING_ASSIGN_OR_RETURN(
+ MemoryMappedFile metadata_mmapped_file,
+ MemoryMappedFile::Create(filesystem_, GetMetadataFilePath(working_path_),
+ MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC,
+ /*max_file_size=*/kMetadataFileSize,
+ /*pre_mapping_file_offset=*/0,
+ /*pre_mapping_mmap_size=*/kMetadataFileSize));
+ metadata_mmapped_file_ =
+ std::make_unique<MemoryMappedFile>(std::move(metadata_mmapped_file));
+
+ ICING_ASSIGN_OR_RETURN(FlashIndexStorage flash_index_storage,
+ FlashIndexStorage::Create(
+ GetFlashIndexStorageFilePath(working_path_),
+ &filesystem_, posting_list_hit_serializer_.get()));
+ flash_index_storage_ =
+ std::make_unique<FlashIndexStorage>(std::move(flash_index_storage));
+
+ ICING_ASSIGN_OR_RETURN(
+ embedding_posting_list_mapper_,
+ DynamicTrieKeyMapper<PostingListIdentifier>::Create(
+ filesystem_, GetEmbeddingHitListMapperPath(working_path_),
+ kEmbeddingHitListMapperMaxSize));
+
+ ICING_ASSIGN_OR_RETURN(
+ embedding_vectors_,
+ FileBackedVector<float>::Create(
+ filesystem_, GetEmbeddingVectorsFilePath(working_path_),
+ MemoryMappedFile::READ_WRITE_AUTO_SYNC));
+
+ if (is_new) {
+ ICING_RETURN_IF_ERROR(metadata_mmapped_file_->GrowAndRemapIfNecessary(
+ /*file_offset=*/0, /*mmap_size=*/kMetadataFileSize));
+ info().magic = Info::kMagic;
+ info().last_added_document_id = kInvalidDocumentId;
+ ICING_RETURN_IF_ERROR(InitializeNewStorage());
+ } else {
+ if (metadata_mmapped_file_->available_size() != kMetadataFileSize) {
+ return absl_ports::FailedPreconditionError(
+ "Incorrect metadata file size");
+ }
+ if (info().magic != Info::kMagic) {
+ return absl_ports::FailedPreconditionError("Incorrect magic value");
+ }
+ ICING_RETURN_IF_ERROR(InitializeExistingStorage());
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status EmbeddingIndex::Clear() {
+ pending_embedding_hits_.clear();
+ metadata_mmapped_file_.reset();
+ flash_index_storage_.reset();
+ embedding_posting_list_mapper_.reset();
+ embedding_vectors_.reset();
+ if (filesystem_.DirectoryExists(working_path_.c_str())) {
+ ICING_RETURN_IF_ERROR(Discard(filesystem_, working_path_));
+ }
+ is_initialized_ = false;
+ return Initialize();
+}
+
+libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+EmbeddingIndex::GetAccessor(uint32_t dimension,
+ std::string_view model_signature) const {
+ if (dimension == 0) {
+ return absl_ports::InvalidArgumentError("Dimension is 0");
+ }
+ std::string key = GetPostingListKey(dimension, model_signature);
+ ICING_ASSIGN_OR_RETURN(PostingListIdentifier posting_list_id,
+ embedding_posting_list_mapper_->Get(key));
+ return PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), posting_list_hit_serializer_.get(),
+ posting_list_id);
+}
+
+libtextclassifier3::Status EmbeddingIndex::BufferEmbedding(
+ const BasicHit& basic_hit, const PropertyProto::VectorProto& vector) {
+ if (vector.values_size() == 0) {
+ return absl_ports::InvalidArgumentError("Vector dimension is 0");
+ }
+
+ uint32_t location = embedding_vectors_->num_elements();
+ uint32_t dimension = vector.values_size();
+ std::string key = GetPostingListKey(vector);
+
+ // Buffer the embedding hit.
+ pending_embedding_hits_.push_back(
+ {std::move(key), EmbeddingHit(basic_hit, location)});
+
+ // Put vector
+ ICING_ASSIGN_OR_RETURN(FileBackedVector<float>::MutableArrayView mutable_arr,
+ embedding_vectors_->Allocate(dimension));
+ mutable_arr.SetArray(/*idx=*/0, vector.values().data(), dimension);
+
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status EmbeddingIndex::CommitBufferToIndex() {
+ std::sort(pending_embedding_hits_.begin(), pending_embedding_hits_.end());
+ auto iter_curr_key = pending_embedding_hits_.rbegin();
+ while (iter_curr_key != pending_embedding_hits_.rend()) {
+    // In order to batch embedding hits with the same key (dimension,
+    // model_signature) into the same posting list, we find the range
+    // [iter_curr_key, iter_next_key) of embedding hits with the same key and
+    // put them into their corresponding posting list together.
+ auto iter_next_key = iter_curr_key;
+ while (iter_next_key != pending_embedding_hits_.rend() &&
+ iter_next_key->first == iter_curr_key->first) {
+ iter_next_key++;
+ }
+
+ const std::string& key = iter_curr_key->first;
+ libtextclassifier3::StatusOr<PostingListIdentifier> posting_list_id_or =
+ embedding_posting_list_mapper_->Get(key);
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
+ if (posting_list_id_or.ok()) {
+ // Existing posting list.
+ ICING_ASSIGN_OR_RETURN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), posting_list_hit_serializer_.get(),
+ posting_list_id_or.ValueOrDie()));
+ } else if (absl_ports::IsNotFound(posting_list_id_or.status())) {
+ // New posting list.
+ ICING_ASSIGN_OR_RETURN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(
+ flash_index_storage_.get(), posting_list_hit_serializer_.get()));
+ } else {
+ // Errors
+ return std::move(posting_list_id_or).status();
+ }
+
+ // Adding the embedding hits.
+ for (auto iter = iter_curr_key; iter != iter_next_key; ++iter) {
+ ICING_RETURN_IF_ERROR(pl_accessor->PrependHit(iter->second));
+ }
+
+    // Finalize this posting list and add the posting list id to
+    // embedding_posting_list_mapper_.
+ PostingListEmbeddingHitAccessor::FinalizeResult result =
+ std::move(*pl_accessor).Finalize();
+ if (!result.id.is_valid()) {
+ return absl_ports::InternalError("Failed to finalize posting list");
+ }
+ ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->Put(key, result.id));
+
+ // Advance to the next key.
+ iter_curr_key = iter_next_key;
+ }
+ pending_embedding_hits_.clear();
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status EmbeddingIndex::TransferIndex(
+ const std::vector<DocumentId>& document_id_old_to_new,
+ EmbeddingIndex* new_index) const {
+ std::unique_ptr<KeyMapper<PostingListIdentifier>::Iterator> itr =
+ embedding_posting_list_mapper_->GetIterator();
+ while (itr->Advance()) {
+ std::string_view key = itr->GetKey();
+ // This should never happen unless there is an inconsistency, or the index
+ // is corrupted.
+ if (key.size() < kEncodedDimensionLength) {
+ return absl_ports::InternalError(
+ "Got invalid key from embedding posting list mapper.");
+ }
+ uint32_t dimension = encode_util::DecodeIntFromCString(
+ std::string_view(key.begin(), kEncodedDimensionLength));
+
+ // Transfer hits
+ std::vector<EmbeddingHit> new_hits;
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> old_pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), posting_list_hit_serializer_.get(),
+ /*existing_posting_list_id=*/itr->GetValue()));
+ while (true) {
+ ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
+ old_pl_accessor->GetNextHitsBatch());
+ if (batch.empty()) {
+ break;
+ }
+ for (EmbeddingHit& old_hit : batch) {
+        // Safety checks to add robustness to the codebase, making sure that we
+        // never access invalid memory in case the hit from the posting list is
+        // corrupted.
+ ICING_ASSIGN_OR_RETURN(const float* old_vector,
+ GetEmbeddingVector(old_hit, dimension));
+ if (old_hit.basic_hit().document_id() < 0 ||
+ old_hit.basic_hit().document_id() >=
+ document_id_old_to_new.size()) {
+ return absl_ports::InternalError(
+ "Embedding hit document id is out of bound. The provided map is "
+ "too small, or the index may have been corrupted.");
+ }
+
+ // Construct transferred hit
+ DocumentId new_document_id =
+ document_id_old_to_new[old_hit.basic_hit().document_id()];
+ if (new_document_id == kInvalidDocumentId) {
+ continue;
+ }
+ uint32_t new_location = new_index->embedding_vectors_->num_elements();
+ new_hits.push_back(EmbeddingHit(
+ BasicHit(old_hit.basic_hit().section_id(), new_document_id),
+ new_location));
+
+ // Copy the embedding vector of the hit to the new index.
+ ICING_ASSIGN_OR_RETURN(
+ FileBackedVector<float>::MutableArrayView mutable_arr,
+ new_index->embedding_vectors_->Allocate(dimension));
+ mutable_arr.SetArray(/*idx=*/0, old_vector, dimension);
+ }
+ }
+    // No hits need to be added to the new index for this key. Move on to the
+    // next key instead of finalizing an empty posting list.
+    if (new_hits.empty()) {
+      continue;
+    }
+ // Add transferred hits to the new index.
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> hit_accum,
+ PostingListEmbeddingHitAccessor::Create(
+ new_index->flash_index_storage_.get(),
+ new_index->posting_list_hit_serializer_.get()));
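+    // Note: hits read back from the old posting list arrive in increasing hit
+    // value order, so they are prepended in reverse (largest first) to satisfy
+    // the posting list requirement that prepended hit values are decreasing.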
+ for (auto new_hit_itr = new_hits.rbegin(); new_hit_itr != new_hits.rend();
+ ++new_hit_itr) {
+ ICING_RETURN_IF_ERROR(hit_accum->PrependHit(*new_hit_itr));
+ }
+ PostingListEmbeddingHitAccessor::FinalizeResult result =
+ std::move(*hit_accum).Finalize();
+ if (!result.id.is_valid()) {
+ return absl_ports::InternalError("Failed to finalize posting list");
+ }
+ ICING_RETURN_IF_ERROR(
+ new_index->embedding_posting_list_mapper_->Put(key, result.id));
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status EmbeddingIndex::Optimize(
+ const std::vector<DocumentId>& document_id_old_to_new,
+ DocumentId new_last_added_document_id) {
+  // This is just for completeness; it should never be necessary, since there
+  // should be no pending hits at the time when Optimize is run.
+ ICING_RETURN_IF_ERROR(CommitBufferToIndex());
+
+ std::string temporary_index_working_path = working_path_ + "_temp";
+ if (!filesystem_.DeleteDirectoryRecursively(
+ temporary_index_working_path.c_str())) {
+ ICING_LOG(ERROR) << "Recursively deleting " << temporary_index_working_path;
+ return absl_ports::InternalError(
+ "Unable to delete temp directory to prepare to build new index.");
+ }
+
+ DestructibleDirectory temporary_index_dir(
+ &filesystem_, std::move(temporary_index_working_path));
+ if (!temporary_index_dir.is_valid()) {
+ return absl_ports::InternalError(
+ "Unable to create temp directory to build new index.");
+ }
+
+ {
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<EmbeddingIndex> new_index,
+ EmbeddingIndex::Create(&filesystem_, temporary_index_dir.dir()));
+ ICING_RETURN_IF_ERROR(
+ TransferIndex(document_id_old_to_new, new_index.get()));
+ new_index->set_last_added_document_id(new_last_added_document_id);
+ ICING_RETURN_IF_ERROR(new_index->PersistToDisk());
+ }
+
+ // Destruct current storage instances to safely swap directories.
+ metadata_mmapped_file_.reset();
+ flash_index_storage_.reset();
+ embedding_posting_list_mapper_.reset();
+ embedding_vectors_.reset();
+
+ if (!filesystem_.SwapFiles(temporary_index_dir.dir().c_str(),
+ working_path_.c_str())) {
+ return absl_ports::InternalError(
+ "Unable to apply new index due to failed swap!");
+ }
+
+ // Reinitialize the index.
+ is_initialized_ = false;
+ return Initialize();
+}
+
+libtextclassifier3::Status EmbeddingIndex::PersistMetadataToDisk(bool force) {
+ return metadata_mmapped_file_->PersistToDisk();
+}
+
+libtextclassifier3::Status EmbeddingIndex::PersistStoragesToDisk(bool force) {
+ if (!flash_index_storage_->PersistToDisk()) {
+ return absl_ports::InternalError("Fail to persist flash index to disk");
+ }
+ ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->PersistToDisk());
+ ICING_RETURN_IF_ERROR(embedding_vectors_->PersistToDisk());
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeInfoChecksum(
+ bool force) {
+ return info().ComputeChecksum();
+}
+
+libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeStoragesChecksum(
+ bool force) {
+ ICING_ASSIGN_OR_RETURN(Crc32 embedding_posting_list_mapper_crc,
+ embedding_posting_list_mapper_->ComputeChecksum());
+ ICING_ASSIGN_OR_RETURN(Crc32 embedding_vectors_crc,
+ embedding_vectors_->ComputeChecksum());
+ return Crc32(embedding_posting_list_mapper_crc.Get() ^
+ embedding_vectors_crc.Get());
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/embedding-index.h b/icing/index/embed/embedding-index.h
new file mode 100644
index 0000000..7318871
--- /dev/null
+++ b/icing/index/embed/embedding-index.h
@@ -0,0 +1,274 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_EMBEDDING_INDEX_H_
+#define ICING_INDEX_EMBED_EMBEDDING_INDEX_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/file-backed-vector.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/memory-mapped-file.h"
+#include "icing/file/persistent-storage.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+#include "icing/index/hit/hit.h"
+#include "icing/store/document-id.h"
+#include "icing/store/key-mapper.h"
+#include "icing/util/crc32.h"
+
+namespace icing {
+namespace lib {
+
+class EmbeddingIndex : public PersistentStorage {
+ public:
+ struct Info {
+ static constexpr int32_t kMagic = 0xfbe13cbb;
+
+ int32_t magic;
+ DocumentId last_added_document_id;
+
+ Crc32 ComputeChecksum() const {
+ return Crc32(
+ std::string_view(reinterpret_cast<const char*>(this), sizeof(Info)));
+ }
+ } __attribute__((packed));
+ static_assert(sizeof(Info) == 8, "");
+
+ // Metadata file layout: <Crcs><Info>
+ static constexpr int32_t kCrcsMetadataBufferOffset = 0;
+ static constexpr int32_t kInfoMetadataBufferOffset =
+ static_cast<int32_t>(sizeof(Crcs));
+ static constexpr int32_t kMetadataFileSize = sizeof(Crcs) + sizeof(Info);
+ static_assert(kMetadataFileSize == 20, "");
+
+ static constexpr WorkingPathType kWorkingPathType =
+ WorkingPathType::kDirectory;
+
+ EmbeddingIndex(const EmbeddingIndex&) = delete;
+ EmbeddingIndex& operator=(const EmbeddingIndex&) = delete;
+
+ // Creates a new EmbeddingIndex instance to index embeddings.
+ //
+ // Returns:
+ // - FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored
+ // checksum.
+ // - INTERNAL_ERROR on I/O errors.
+ // - Any error from MemoryMappedFile, FlashIndexStorage,
+ // DynamicTrieKeyMapper, or FileBackedVector.
+ static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>> Create(
+ const Filesystem* filesystem, std::string working_path);
+
+ static libtextclassifier3::Status Discard(const Filesystem& filesystem,
+ const std::string& working_path) {
+ return PersistentStorage::Discard(filesystem, working_path,
+ kWorkingPathType);
+ }
+
+ libtextclassifier3::Status Clear();
+
+  // Buffer an embedding that is pending to be added to the index. Buffering is
+  // required since EmbeddingHits added to posting lists must be in decreasing
+  // order, which means that section ids and location indexes for a single
+  // document must be added in decreasing order.
+ //
+ // Returns:
+ // - OK on success
+ // - INVALID_ARGUMENT error if the dimension is 0.
+ // - INTERNAL_ERROR on I/O error
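+  //
+  // Typical usage (an illustrative sketch, not mandated by this interface):
+  //   ICING_RETURN_IF_ERROR(index->BufferEmbedding(
+  //       BasicHit(/*section_id=*/0, /*document_id=*/0), vector));
+  //   ... buffer any remaining embedding vectors of the document ...
+  //   ICING_RETURN_IF_ERROR(index->CommitBufferToIndex());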
+ libtextclassifier3::Status BufferEmbedding(
+ const BasicHit& basic_hit, const PropertyProto::VectorProto& vector);
+
+ // Commit the embedding hits in the buffer to the index.
+ //
+ // Returns:
+ // - OK on success
+ // - INTERNAL_ERROR on I/O error
+ // - Any error from posting lists
+ libtextclassifier3::Status CommitBufferToIndex();
+
+  // Returns a PostingListEmbeddingHitAccessor for all embedding hits that
+  // match the provided dimension and model signature.
+ //
+ // Returns:
+ // - a PostingListEmbeddingHitAccessor instance on success.
+ // - INVALID_ARGUMENT error if the dimension is 0.
+ // - NOT_FOUND error if there is no matching embedding hit.
+ // - Any error from posting lists.
+ libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ GetAccessor(uint32_t dimension, std::string_view model_signature) const;
+
+  // Returns a PostingListEmbeddingHitAccessor for all embedding hits that
+  // match the provided vector's dimension and model signature.
+ //
+ // Returns:
+ // - a PostingListEmbeddingHitAccessor instance on success.
+ // - INVALID_ARGUMENT error if the dimension is 0.
+ // - NOT_FOUND error if there is no matching embedding hit.
+ // - Any error from posting lists.
+ libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ GetAccessorForVector(const PropertyProto::VectorProto& vector) const {
+ return GetAccessor(vector.values_size(), vector.model_signature());
+ }
+
+  // Reduces internal file sizes by reclaiming the space of deleted documents.
+  // new_last_added_document_id will be used to update the last added document
+  // id in the embedding index.
+  //
+  // Returns:
+  //   - OK on success
+  //   - INTERNAL_ERROR on I/O error; this indicates that the index may be in
+  //     an invalid state and should be cleared.
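+  //
+  // For example (illustrative): a document_id_old_to_new map of
+  // {0, kInvalidDocumentId, 1} keeps document 0 unchanged, deletes document 1,
+  // and remaps document 2 to document id 1.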
+ libtextclassifier3::Status Optimize(
+ const std::vector<DocumentId>& document_id_old_to_new,
+ DocumentId new_last_added_document_id);
+
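+  // Returns a pointer into the raw vector storage for the vector that the
+  // given hit refers to, or INTERNAL_ERROR if the hit points outside the
+  // storage.
+  //
+  // For example (illustrative): a hit with location 6 and dimension 3 refers
+  // to elements [6, 9) of the raw embedding data.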
+ libtextclassifier3::StatusOr<const float*> GetEmbeddingVector(
+ const EmbeddingHit& hit, uint32_t dimension) const {
+ if (static_cast<int64_t>(hit.location()) + dimension >
+ GetTotalVectorSize()) {
+ return absl_ports::InternalError(
+ "Got an embedding hit that refers to a vector out of range.");
+ }
+ return embedding_vectors_->array() + hit.location();
+ }
+
+ const float* GetRawEmbeddingData() const {
+ return embedding_vectors_->array();
+ }
+
+ int32_t GetTotalVectorSize() const {
+ return embedding_vectors_->num_elements();
+ }
+
+ DocumentId last_added_document_id() const {
+ return info().last_added_document_id;
+ }
+
+ void set_last_added_document_id(DocumentId document_id) {
+ Info& info_ref = info();
+ if (info_ref.last_added_document_id == kInvalidDocumentId ||
+ document_id > info_ref.last_added_document_id) {
+ info_ref.last_added_document_id = document_id;
+ }
+ }
+
+ private:
+ explicit EmbeddingIndex(const Filesystem& filesystem,
+ std::string working_path)
+ : PersistentStorage(filesystem, std::move(working_path),
+ kWorkingPathType) {}
+
+ libtextclassifier3::Status Initialize();
+
+ // Transfers embedding data and hits from the current index to new_index.
+ //
+ // Returns:
+ // - OK on success
+ // - INTERNAL_ERROR on I/O error. This could potentially leave the storages
+ // in an invalid state and the caller should handle it properly (e.g.
+ // discard and rebuild)
+ libtextclassifier3::Status TransferIndex(
+ const std::vector<DocumentId>& document_id_old_to_new,
+ EmbeddingIndex* new_index) const;
+
+ // Flushes contents of metadata file.
+ //
+ // Returns:
+ // - OK on success
+ // - INTERNAL_ERROR on I/O error
+ libtextclassifier3::Status PersistMetadataToDisk(bool force) override;
+
+ // Flushes contents of all storages to underlying files.
+ //
+ // Returns:
+ // - OK on success
+ // - INTERNAL_ERROR on I/O error
+ libtextclassifier3::Status PersistStoragesToDisk(bool force) override;
+
+ // Computes and returns Info checksum.
+ //
+ // Returns:
+ // - Crc of the Info on success
+ libtextclassifier3::StatusOr<Crc32> ComputeInfoChecksum(bool force) override;
+
+ // Computes and returns all storages checksum.
+ //
+ // Returns:
+ // - Crc of all storages on success
+ // - INTERNAL_ERROR if any data inconsistency
+ libtextclassifier3::StatusOr<Crc32> ComputeStoragesChecksum(
+ bool force) override;
+
+ Crcs& crcs() override {
+ return *reinterpret_cast<Crcs*>(metadata_mmapped_file_->mutable_region() +
+ kCrcsMetadataBufferOffset);
+ }
+
+ const Crcs& crcs() const override {
+ return *reinterpret_cast<const Crcs*>(metadata_mmapped_file_->region() +
+ kCrcsMetadataBufferOffset);
+ }
+
+ Info& info() {
+ return *reinterpret_cast<Info*>(metadata_mmapped_file_->mutable_region() +
+ kInfoMetadataBufferOffset);
+ }
+
+ const Info& info() const {
+ return *reinterpret_cast<const Info*>(metadata_mmapped_file_->region() +
+ kInfoMetadataBufferOffset);
+ }
+
+ // In memory data:
+ // Pending embedding hits with their embedding keys used for
+ // embedding_posting_list_mapper_.
+ std::vector<std::pair<std::string, EmbeddingHit>> pending_embedding_hits_;
+
+ // Metadata
+ std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_;
+
+ // Posting list storage
+ std::unique_ptr<PostingListEmbeddingHitSerializer>
+ posting_list_hit_serializer_ =
+ std::make_unique<PostingListEmbeddingHitSerializer>();
+ std::unique_ptr<FlashIndexStorage> flash_index_storage_;
+
+  // The mapper from embedding keys to the identifier of the posting list that
+  // stores all embedding hits with the same key.
+ //
+ // The key for an embedding hit is a one-to-one encoded string of the ordered
+ // pair (dimension, model_signature) corresponding to the embedding.
+ std::unique_ptr<KeyMapper<PostingListIdentifier>>
+ embedding_posting_list_mapper_;
+
+ // A single FileBackedVector that holds all embedding vectors.
+ std::unique_ptr<FileBackedVector<float>> embedding_vectors_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_EMBEDDING_INDEX_H_
diff --git a/icing/index/embed/embedding-index_test.cc b/icing/index/embed/embedding-index_test.cc
new file mode 100644
index 0000000..5980e82
--- /dev/null
+++ b/icing/index/embed/embedding-index_test.cc
@@ -0,0 +1,582 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/embedding-index.h"
+
+#include <unistd.h>
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/filesystem.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/hit/hit.h"
+#include "icing/proto/document.pb.h"
+#include "icing/store/document-id.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/embedding-test-utils.h"
+#include "icing/testing/tmp-directory.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+using ::testing::Test;
+
+class EmbeddingIndexTest : public Test {
+ protected:
+ void SetUp() override {
+ embedding_index_dir_ = GetTestTempDir() + "/embedding_index_test";
+ ICING_ASSERT_OK_AND_ASSIGN(
+ embedding_index_,
+ EmbeddingIndex::Create(&filesystem_, embedding_index_dir_));
+ }
+
+ void TearDown() override {
+ embedding_index_.reset();
+ filesystem_.DeleteDirectoryRecursively(embedding_index_dir_.c_str());
+ }
+
+ libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits(
+ uint32_t dimension, std::string_view model_signature) {
+ std::vector<EmbeddingHit> hits;
+
+ libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ pl_accessor_or =
+ embedding_index_->GetAccessor(dimension, model_signature);
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
+ if (pl_accessor_or.ok()) {
+ pl_accessor = std::move(pl_accessor_or).ValueOrDie();
+ } else if (absl_ports::IsNotFound(pl_accessor_or.status())) {
+ return hits;
+ } else {
+ return std::move(pl_accessor_or).status();
+ }
+
+ while (true) {
+ ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
+ pl_accessor->GetNextHitsBatch());
+ if (batch.empty()) {
+ return hits;
+ }
+ hits.insert(hits.end(), batch.begin(), batch.end());
+ }
+ }
+
+ std::vector<float> GetRawEmbeddingData() {
+ return std::vector<float>(embedding_index_->GetRawEmbeddingData(),
+ embedding_index_->GetRawEmbeddingData() +
+ embedding_index_->GetTotalVectorSize());
+ }
+
+ Filesystem filesystem_;
+ std::string embedding_index_dir_;
+ std::unique_ptr<EmbeddingIndex> embedding_index_;
+};
+
+TEST_F(EmbeddingIndexTest, AddSingleEmbedding) {
+ PropertyProto::VectorProto vector = CreateVector("model", {0.1, 0.2, 0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(0);
+
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(EmbeddingHit(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, AddMultipleEmbeddingsInTheSameSection) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(0);
+
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/3))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, HitsWithLowerSectionIdReturnedFirst) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/5, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/2, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(0);
+
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/2, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(/*section_id=*/5, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, HitsWithHigherDocumentIdReturnedFirst) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/1), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(1);
+
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+}
+
+TEST_F(EmbeddingIndexTest, AddEmbeddingsFromDifferentModels) {
+ PropertyProto::VectorProto vector1 = CreateVector("model1", {0.1, 0.2});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model2", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(0);
+
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/2))));
+ EXPECT_THAT(
+ GetHits(/*dimension=*/5, /*model_signature=*/"non-existent-model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest,
+ AddEmbeddingsWithSameSignatureButDifferentDimension) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(0);
+
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/2))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, ClearIndex) {
+ // Loop the same logic twice to make sure that clear works as expected, and
+ // the index is still valid after clearing.
+ for (int i = 0; i < 2; i++) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/1, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/2, /*document_id=*/1), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(1);
+
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/2, /*document_id=*/1),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+
+ // Check that clear works as expected.
+ ICING_ASSERT_OK(embedding_index_->Clear());
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+ EXPECT_EQ(embedding_index_->last_added_document_id(), kInvalidDocumentId);
+ }
+}
+
+TEST_F(EmbeddingIndexTest, EmptyCommitIsOk) {
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+}
+
+TEST_F(EmbeddingIndexTest, MultipleCommits) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/1, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+}
+
+TEST_F(EmbeddingIndexTest,
+ InvalidCommit_SectionIdCanOnlyDecreaseForSingleDocument) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/1, /*document_id=*/0), vector2));
+ // Posting list with delta encoding can only allow decreasing values.
+ EXPECT_THAT(embedding_index_->CommitBufferToIndex(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(EmbeddingIndexTest, InvalidCommit_DocumentIdCanOnlyIncrease) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/1), vector1));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector2));
+ // Posting list with delta encoding can only allow decreasing values, which
+ // means document ids must be committed increasingly, since document ids are
+ // inverted in hit values.
+ EXPECT_THAT(embedding_index_->CommitBufferToIndex(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(EmbeddingIndexTest, EmptyOptimizeIsOk) {
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{},
+ /*new_last_added_document_id=*/kInvalidDocumentId));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+}
+
+TEST_F(EmbeddingIndexTest, OptimizeSingleEmbeddingSingleDocument) {
+ PropertyProto::VectorProto vector = CreateVector("model", {0.1, 0.2, 0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/2), vector));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(2);
+
+ // Before optimize
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(EmbeddingHit(
+ BasicHit(/*section_id=*/0, /*document_id=*/2), /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 2);
+
+ // Run optimize without deleting any documents, and check that the index is
+ // not changed.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, 1, 2},
+ /*new_last_added_document_id=*/2));
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(EmbeddingHit(
+ BasicHit(/*section_id=*/0, /*document_id=*/2), /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 2);
+
+ // Run optimize to map document id 2 to 1, and check that the index is
+ // updated correctly.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, kInvalidDocumentId, 1},
+ /*new_last_added_document_id=*/1));
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(EmbeddingHit(
+ BasicHit(/*section_id=*/0, /*document_id=*/1), /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+
+ // Run optimize to delete the document.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, kInvalidDocumentId},
+ /*new_last_added_document_id=*/0));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, OptimizeMultipleEmbeddingsSingleDocument) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/2), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/2), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(2);
+
+ // Before optimize
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2),
+ /*location=*/3))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 2);
+
+ // Run optimize without deleting any documents, and check that the index is
+ // not changed.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, 1, 2},
+ /*new_last_added_document_id=*/2));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2),
+ /*location=*/3))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 2);
+
+ // Run optimize to map document id 2 to 1, and check that the index is
+ // updated correctly.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, kInvalidDocumentId, 1},
+ /*new_last_added_document_id=*/1));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/3))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+
+ // Run optimize to delete the document.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, kInvalidDocumentId},
+ /*new_last_added_document_id=*/0));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, OptimizeMultipleEmbeddingsMultipleDocument) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 = CreateVector("model", {1, 2, 3});
+ PropertyProto::VectorProto vector3 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/1, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/1), vector3));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(1);
+
+ // Before optimize
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/6),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/3))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, 1, 2, 3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+
+  // Run optimize without deleting any documents. The raw embedding data is
+  // expected to be rearranged, since during index transfer, embedding vectors
+  // from higher document ids are added first.
+  //
+  // Also keep in mind that once the raw data is rearranged, a subsequent
+  // Optimize call will not change the raw data again.
+ for (int i = 0; i < 2; i++) {
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, 1},
+ /*new_last_added_document_id=*/1));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/6))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(-0.1, -0.2, -0.3, 0.1, 0.2, 0.3, 1, 2, 3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+ }
+
+ // Run optimize to delete document 0, and check that the index is
+ // updated correctly.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{kInvalidDocumentId, 0},
+ /*new_last_added_document_id=*/0));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(-0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, OptimizeEmbeddingsFromDifferentModels) {
+ PropertyProto::VectorProto vector1 = CreateVector("model1", {0.1, 0.2});
+ PropertyProto::VectorProto vector2 = CreateVector("model1", {1, 2});
+ PropertyProto::VectorProto vector3 =
+ CreateVector("model2", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/1), vector2));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/1, /*document_id=*/1), vector3));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(1);
+
+ // Before optimize
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/2),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/1),
+ /*location=*/4))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 1, 2, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+
+  // Run optimize without deleting any documents. The raw embedding data is
+  // expected to be rearranged, since during index transfer:
+ // - Embedding vectors with lower keys, which are the string encoded ordered
+ // pairs (dimension, model_signature), are iterated first.
+ // - Embedding vectors from higher document ids are added first.
+ //
+  // Also keep in mind that once the raw data is rearranged, a subsequent
+  // Optimize call will not change the raw data again.
+ for (int i = 0; i < 2; i++) {
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, 1},
+ /*new_last_added_document_id=*/1));
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/2))));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/1),
+ /*location=*/4))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(1, 2, 0.1, 0.2, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+ }
+
+ // Run optimize to delete document 1, and check that the index is
+ // updated correctly.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, kInvalidDocumentId},
+ /*new_last_added_document_id=*/0));
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+} // namespace
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/embedding-query-results.h b/icing/index/embed/embedding-query-results.h
new file mode 100644
index 0000000..de85489
--- /dev/null
+++ b/icing/index/embed/embedding-query-results.h
@@ -0,0 +1,72 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
+#define ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "icing/proto/search.pb.h"
+#include "icing/store/document-id.h"
+
+namespace icing {
+namespace lib {
+
+// A class to store results generated from embedding queries.
+struct EmbeddingQueryResults {
+ // Maps from DocumentId to the list of matched embedding scores for that
+ // document, which will be used in the advanced scoring language to
+ // determine the results for the "this.matchedSemanticScores(...)" function.
+ using EmbeddingQueryScoreMap =
+ std::unordered_map<DocumentId, std::vector<double>>;
+
+ // Maps from (query_vector_index, metric_type) to EmbeddingQueryScoreMap.
+ std::unordered_map<
+ int, std::unordered_map<SearchSpecProto::EmbeddingQueryMetricType::Code,
+ EmbeddingQueryScoreMap>>
+ result_scores;
+
+ // Returns the matched scores for the given query_vector_index, metric_type,
+ // and doc_id. Returns nullptr if (query_vector_index, metric_type) does not
+ // exist in the result_scores map.
+ const std::vector<double>* GetMatchedScoresForDocument(
+ int query_vector_index,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ DocumentId doc_id) const {
+ // Check if a mapping exists for the query_vector_index
+ auto outer_it = result_scores.find(query_vector_index);
+ if (outer_it == result_scores.end()) {
+ return nullptr;
+ }
+ // Check if a mapping exists for the metric_type
+ auto inner_it = outer_it->second.find(metric_type);
+ if (inner_it == outer_it->second.end()) {
+ return nullptr;
+ }
+ const EmbeddingQueryScoreMap& score_map = inner_it->second;
+
+ // Check if the doc_id exists in the score_map
+ auto scores_it = score_map.find(doc_id);
+ if (scores_it == score_map.end()) {
+ return nullptr;
+ }
+ return &scores_it->second;
+ }
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
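
For orientation, here is a minimal usage sketch of EmbeddingQueryResults. It is
not part of the change above; the query vector index, metric type, document id,
and scores are hypothetical values chosen purely for illustration, and the
snippet assumes the header's includes and the icing::lib namespace.

  // Sketch only: record matched scores for one embedding query and read them
  // back. All concrete values here are hypothetical.
  EmbeddingQueryResults results;
  DocumentId doc_id = 5;
  results.result_scores[/*query_vector_index=*/0]
                       [SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT]
                       [doc_id] = {0.8, 0.3};

  // Returns a pointer to the vector {0.8, 0.3}.
  const std::vector<double>* scores = results.GetMatchedScoresForDocument(
      /*query_vector_index=*/0,
      SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, doc_id);

  // Returns nullptr: no scores were recorded under the COSINE metric.
  const std::vector<double>* missing = results.GetMatchedScoresForDocument(
      /*query_vector_index=*/0,
      SearchSpecProto::EmbeddingQueryMetricType::COSINE, doc_id);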
diff --git a/icing/index/embed/embedding-scorer.cc b/icing/index/embed/embedding-scorer.cc
new file mode 100644
index 0000000..0d84e01
--- /dev/null
+++ b/icing/index/embed/embedding-scorer.cc
@@ -0,0 +1,95 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/embedding-scorer.h"
+
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/proto/search.pb.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+float CalculateDotProduct(int dimension, const float* v1, const float* v2) {
+ float dot_product = 0.0;
+ for (int i = 0; i < dimension; ++i) {
+ dot_product += v1[i] * v2[i];
+ }
+ return dot_product;
+}
+
+float CalculateNorm2(int dimension, const float* v) {
+ return std::sqrt(CalculateDotProduct(dimension, v, v));
+}
+
+float CalculateCosine(int dimension, const float* v1, const float* v2) {
+ float divisor = CalculateNorm2(dimension, v1) * CalculateNorm2(dimension, v2);
+ if (divisor == 0.0) {
+ return 0.0;
+ }
+ return CalculateDotProduct(dimension, v1, v2) / divisor;
+}
+
+float CalculateEuclideanDistance(int dimension, const float* v1,
+ const float* v2) {
+ float result = 0.0;
+ for (int i = 0; i < dimension; ++i) {
+ float diff = v1[i] - v2[i];
+ result += diff * diff;
+ }
+ return std::sqrt(result);
+}
+
+} // namespace
+
+libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingScorer>>
+EmbeddingScorer::Create(
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type) {
+ switch (metric_type) {
+ case SearchSpecProto::EmbeddingQueryMetricType::COSINE:
+ return std::make_unique<CosineEmbeddingScorer>();
+ case SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT:
+ return std::make_unique<DotProductEmbeddingScorer>();
+ case SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN:
+ return std::make_unique<EuclideanDistanceEmbeddingScorer>();
+ default:
+ return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+ "Invalid EmbeddingQueryMetricType: ", std::to_string(metric_type)));
+ }
+}
+
+float CosineEmbeddingScorer::Score(int dimension, const float* v1,
+ const float* v2) const {
+ return CalculateCosine(dimension, v1, v2);
+}
+
+float DotProductEmbeddingScorer::Score(int dimension, const float* v1,
+ const float* v2) const {
+ return CalculateDotProduct(dimension, v1, v2);
+}
+
+float EuclideanDistanceEmbeddingScorer::Score(int dimension, const float* v1,
+ const float* v2) const {
+ return CalculateEuclideanDistance(dimension, v1, v2);
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/embedding-scorer.h b/icing/index/embed/embedding-scorer.h
new file mode 100644
index 0000000..8caf0bc
--- /dev/null
+++ b/icing/index/embed/embedding-scorer.h
@@ -0,0 +1,54 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_EMBEDDING_SCORER_H_
+#define ICING_INDEX_EMBED_EMBEDDING_SCORER_H_
+
+#include <memory>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/proto/search.pb.h"
+
+namespace icing {
+namespace lib {
+
+class EmbeddingScorer {
+ public:
+ static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingScorer>> Create(
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type);
+ virtual float Score(int dimension, const float* v1,
+ const float* v2) const = 0;
+
+ virtual ~EmbeddingScorer() = default;
+};
+
+class CosineEmbeddingScorer : public EmbeddingScorer {
+ public:
+ float Score(int dimension, const float* v1, const float* v2) const override;
+};
+
+class DotProductEmbeddingScorer : public EmbeddingScorer {
+ public:
+ float Score(int dimension, const float* v1, const float* v2) const override;
+};
+
+class EuclideanDistanceEmbeddingScorer : public EmbeddingScorer {
+ public:
+ float Score(int dimension, const float* v1, const float* v2) const override;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_EMBEDDING_SCORER_H_
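
As a quick sanity check of the three metrics implemented above, consider the
following sketch. The vectors and expected values are illustrative only and do
not come from the change; the snippet assumes it runs inside icing::lib, in a
function returning a Status, with the usual status macros available.

  // Sketch only: score two 3-dimensional vectors with each metric.
  // For v1 = {1, 2, 3} and v2 = {4, 5, 6}:
  //   dot product        = 1*4 + 2*5 + 3*6            = 32
  //   cosine             = 32 / (sqrt(14) * sqrt(77)) ~= 0.9746
  //   euclidean distance = sqrt(3^2 + 3^2 + 3^2)      ~= 5.196
  std::vector<float> v1 = {1, 2, 3};
  std::vector<float> v2 = {4, 5, 6};
  ICING_ASSIGN_OR_RETURN(
      std::unique_ptr<EmbeddingScorer> scorer,
      EmbeddingScorer::Create(
          SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT));
  float score = scorer->Score(/*dimension=*/3, v1.data(), v2.data());  // 32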
diff --git a/icing/index/embed/posting-list-embedding-hit-accessor.cc b/icing/index/embed/posting-list-embedding-hit-accessor.cc
new file mode 100644
index 0000000..e154165
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-accessor.cc
@@ -0,0 +1,132 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-common.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+#include "icing/legacy/index/icing-bit-util.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+PostingListEmbeddingHitAccessor::Create(
+ FlashIndexStorage *storage, PostingListEmbeddingHitSerializer *serializer) {
+ uint32_t max_posting_list_bytes = storage->max_posting_list_bytes();
+ ICING_ASSIGN_OR_RETURN(PostingListUsed in_memory_posting_list,
+ PostingListUsed::CreateFromUnitializedRegion(
+ serializer, max_posting_list_bytes));
+ return std::unique_ptr<PostingListEmbeddingHitAccessor>(
+ new PostingListEmbeddingHitAccessor(storage, serializer,
+ std::move(in_memory_posting_list)));
+}
+
+libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+PostingListEmbeddingHitAccessor::CreateFromExisting(
+ FlashIndexStorage *storage, PostingListEmbeddingHitSerializer *serializer,
+ PostingListIdentifier existing_posting_list_id) {
+ // Our in_memory_posting_list_ will start as empty.
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ Create(storage, serializer));
+ ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
+ storage->GetPostingList(existing_posting_list_id));
+ pl_accessor->preexisting_posting_list_ =
+ std::make_unique<PostingListHolder>(std::move(holder));
+ return pl_accessor;
+}
+
+// Returns the next batch of hits for the provided posting list.
+libtextclassifier3::StatusOr<std::vector<EmbeddingHit>>
+PostingListEmbeddingHitAccessor::GetNextHitsBatch() {
+ if (preexisting_posting_list_ == nullptr) {
+ if (has_reached_posting_list_chain_end_) {
+ return std::vector<EmbeddingHit>();
+ }
+ return absl_ports::FailedPreconditionError(
+ "Cannot retrieve hits from a PostingListEmbeddingHitAccessor that was "
+ "not created from a preexisting posting list.");
+ }
+ ICING_ASSIGN_OR_RETURN(
+ std::vector<EmbeddingHit> batch,
+ serializer_->GetHits(&preexisting_posting_list_->posting_list));
+ uint32_t next_block_index = kInvalidBlockIndex;
+ // Posting lists will only be chained when they are max-sized, in which case
+ // next_block_index will point to the next block for the next posting list.
+ // Otherwise, next_block_index can be kInvalidBlockIndex or be used to point
+ // to the next free list block, which is not relevant here.
+ if (preexisting_posting_list_->posting_list.size_in_bytes() ==
+ storage_->max_posting_list_bytes()) {
+ next_block_index = preexisting_posting_list_->next_block_index;
+ }
+
+ if (next_block_index != kInvalidBlockIndex) {
+    // Since we only have to deal with the next block for a max-sized posting
+    // list, max_num_posting_lists is 1 and posting_list_index_bits is
+    // BitsToStore(1).
+ PostingListIdentifier next_posting_list_id(
+ next_block_index, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/BitsToStore(1));
+ ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
+ storage_->GetPostingList(next_posting_list_id));
+ preexisting_posting_list_ =
+ std::make_unique<PostingListHolder>(std::move(holder));
+ } else {
+ has_reached_posting_list_chain_end_ = true;
+ preexisting_posting_list_.reset();
+ }
+ return batch;
+}
+
+libtextclassifier3::Status PostingListEmbeddingHitAccessor::PrependHit(
+ const EmbeddingHit &hit) {
+ PostingListUsed &active_pl = (preexisting_posting_list_ != nullptr)
+ ? preexisting_posting_list_->posting_list
+ : in_memory_posting_list_;
+ libtextclassifier3::Status status = serializer_->PrependHit(&active_pl, hit);
+ if (!absl_ports::IsResourceExhausted(status)) {
+ return status;
+ }
+ // There is no more room to add hits to this current posting list! Therefore,
+ // we need to either move those hits to a larger posting list or flush this
+ // posting list and create another max-sized posting list in the chain.
+ if (preexisting_posting_list_ != nullptr) {
+ ICING_RETURN_IF_ERROR(FlushPreexistingPostingList());
+ } else {
+ ICING_RETURN_IF_ERROR(FlushInMemoryPostingList());
+ }
+
+ // Re-add hit. Should always fit since we just cleared
+ // in_memory_posting_list_. It's fine to explicitly reference
+ // in_memory_posting_list_ here because there's no way of reaching this line
+ // while preexisting_posting_list_ is still in use.
+ return serializer_->PrependHit(&in_memory_posting_list_, hit);
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/posting-list-embedding-hit-accessor.h b/icing/index/embed/posting-list-embedding-hit-accessor.h
new file mode 100644
index 0000000..4acb9a3
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-accessor.h
@@ -0,0 +1,106 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_
+#define ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-accessor.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+
+namespace icing {
+namespace lib {
+
+// This class is used to provide a simple abstraction for adding hits to posting
+// lists. PostingListEmbeddingHitAccessor handles 1) selection of properly-sized
+// posting lists for the accumulated hits during Finalize() and 2) chaining of
+// max-sized posting lists.
+class PostingListEmbeddingHitAccessor : public PostingListAccessor {
+ public:
+ // Creates an empty PostingListEmbeddingHitAccessor.
+ //
+ // RETURNS:
+ // - On success, a valid unique_ptr instance of
+ // PostingListEmbeddingHitAccessor
+ // - INVALID_ARGUMENT error if storage has an invalid block_size.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ Create(FlashIndexStorage* storage,
+ PostingListEmbeddingHitSerializer* serializer);
+
+ // Create a PostingListEmbeddingHitAccessor with an existing posting list
+ // identified by existing_posting_list_id.
+ //
+ // The PostingListEmbeddingHitAccessor will add hits to this posting list
+ // until it is necessary either to 1) chain the posting list (if it is
+ // max-sized) or 2) move its hits to a larger posting list.
+ //
+ // RETURNS:
+ // - On success, a valid unique_ptr instance of
+ // PostingListEmbeddingHitAccessor
+ // - INVALID_ARGUMENT if storage has an invalid block_size.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ CreateFromExisting(FlashIndexStorage* storage,
+ PostingListEmbeddingHitSerializer* serializer,
+ PostingListIdentifier existing_posting_list_id);
+
+ PostingListSerializer* GetSerializer() override { return serializer_; }
+
+ // Retrieve the next batch of hits for the posting list chain
+ //
+ // RETURNS:
+ // - On success, a vector of hits in the posting list chain
+  //   - FAILED_PRECONDITION if called on an instance of
+  //     PostingListEmbeddingHitAccessor that was created via
+  //     PostingListEmbeddingHitAccessor::Create
+  //   - INTERNAL if unable to read the next posting list in the chain or if
+  //     the posting list has been corrupted somehow.
+ libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetNextHitsBatch();
+
+ // Prepend one hit. This may result in flushing the posting list to disk (if
+ // the PostingListEmbeddingHitAccessor holds a max-sized posting list that is
+ // full) or freeing a pre-existing posting list if it is too small to fit all
+ // hits necessary.
+ //
+ // RETURNS:
+ // - OK, on success
+ // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the
+ // previously added hit.
+ // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a new
+ // posting list.
+ libtextclassifier3::Status PrependHit(const EmbeddingHit& hit);
+
+ private:
+ explicit PostingListEmbeddingHitAccessor(
+ FlashIndexStorage* storage, PostingListEmbeddingHitSerializer* serializer,
+ PostingListUsed in_memory_posting_list)
+ : PostingListAccessor(storage, std::move(in_memory_posting_list)),
+ serializer_(serializer) {}
+
+ PostingListEmbeddingHitSerializer* serializer_; // Does not own.
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_
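
The tests in the next file exercise this API in detail; as a compressed
overview, a typical write-then-read flow might look like the sketch below.
The storage and serializer pointers, the concrete hits, and the surrounding
status-returning function are assumptions made for illustration only.

  // Sketch only: prepend hits through a new accessor, finalize, then read them
  // back via an accessor created from the resulting posting list id.
  // Assumes: FlashIndexStorage* storage; PostingListEmbeddingHitSerializer*
  // serializer.
  ICING_ASSIGN_OR_RETURN(
      std::unique_ptr<PostingListEmbeddingHitAccessor> accessor,
      PostingListEmbeddingHitAccessor::Create(storage, serializer));
  // Hits must be prepended in strictly decreasing EmbeddingHit order, so the
  // hit for document 0 goes in before the hit for document 1.
  ICING_RETURN_IF_ERROR(accessor->PrependHit(
      EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
                   /*location=*/3)));
  ICING_RETURN_IF_ERROR(accessor->PrependHit(
      EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
                   /*location=*/0)));
  PostingListAccessor::FinalizeResult result = std::move(*accessor).Finalize();
  ICING_RETURN_IF_ERROR(result.status);

  ICING_ASSIGN_OR_RETURN(
      std::unique_ptr<PostingListEmbeddingHitAccessor> reader,
      PostingListEmbeddingHitAccessor::CreateFromExisting(storage, serializer,
                                                          result.id));
  // Returns the hits most recently prepended first (document 1's hit, then
  // document 0's).
  ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
                         reader->GetNextHitsBatch());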
diff --git a/icing/index/embed/posting-list-embedding-hit-accessor_test.cc b/icing/index/embed/posting-list-embedding-hit-accessor_test.cc
new file mode 100644
index 0000000..b9ebe87
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-accessor_test.cc
@@ -0,0 +1,387 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-accessor.h"
+#include "icing/file/posting_list/posting-list-common.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+#include "icing/index/hit/hit.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/hit-test-utils.h"
+#include "icing/testing/tmp-directory.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::Eq;
+using ::testing::Lt;
+using ::testing::SizeIs;
+
+class PostingListEmbeddingHitAccessorTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ test_dir_ = GetTestTempDir() + "/test_dir";
+ file_name_ = test_dir_ + "/test_file.idx.index";
+
+ ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()));
+ ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(test_dir_.c_str()));
+
+ serializer_ = std::make_unique<PostingListEmbeddingHitSerializer>();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ FlashIndexStorage flash_index_storage,
+ FlashIndexStorage::Create(file_name_, &filesystem_, serializer_.get()));
+ flash_index_storage_ =
+ std::make_unique<FlashIndexStorage>(std::move(flash_index_storage));
+ }
+
+ void TearDown() override {
+ flash_index_storage_.reset();
+ serializer_.reset();
+ ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()));
+ }
+
+ Filesystem filesystem_;
+ std::string test_dir_;
+ std::string file_name_;
+ std::unique_ptr<PostingListEmbeddingHitSerializer> serializer_;
+ std::unique_ptr<FlashIndexStorage> flash_index_storage_;
+};
+
+TEST_F(PostingListEmbeddingHitAccessorTest, HitsAddAndRetrieveProperly) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ // Add some hits! Any hits!
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit& hit : hits1) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
+ }
+ PostingListAccessor::FinalizeResult result =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result.status);
+ EXPECT_THAT(result.id.block_index(), Eq(1));
+ EXPECT_THAT(result.id.posting_list_index(), Eq(0));
+
+ // Retrieve some hits.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(result.id));
+ EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
+ EXPECT_THAT(pl_holder.next_block_index, Eq(kInvalidBlockIndex));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPLKeepOnSameBlock) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ // Add a single hit. This will fit in a min-sized posting list.
+ EmbeddingHit hit1(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/1);
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result1.status);
+ // Should have been allocated to the first block.
+ EXPECT_THAT(result1.id.block_index(), Eq(1));
+ EXPECT_THAT(result1.id.posting_list_index(), Eq(0));
+
+ // Add one more hit. The minimum size for a posting list must be able to fit
+ // at least two hits, so this should NOT cause the previous pl to be
+ // reallocated.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
+ EmbeddingHit hit2(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/0);
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit2));
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result2.status);
+ // Should have been allocated to the same posting list as the first hit.
+ EXPECT_THAT(result2.id, Eq(result1.id));
+
+ // The posting list at result2.id should hold all of the hits that have been
+ // added.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(result2.id));
+ EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAre(hit2, hit1)));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPLReallocateToLargerPL) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ std::vector<EmbeddingHit> hits =
+ CreateEmbeddingHits(/*num_hits=*/11, /*desired_byte_length=*/1);
+
+ // Add 8 hits with a small posting list of 24 bytes. The first 7 hits will
+ // be compressed to one byte each and will be able to fit in the 8 byte
+ // padded region. The last hit will fit in one of the special hits. The
+ // posting list will be ALMOST_FULL and can fit at most 2 more hits.
+ for (int i = 0; i < 8; ++i) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i]));
+ }
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result1.status);
+ // Should have been allocated to the first block.
+ EXPECT_THAT(result1.id.block_index(), Eq(1));
+ EXPECT_THAT(result1.id.posting_list_index(), Eq(0));
+
+ // Now let's add some more hits!
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
+ // The current posting list can fit at most 2 more hits.
+ for (int i = 8; i < 10; ++i) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i]));
+ }
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result2.status);
+  // The 2 hits should still fit on the first block.
+  EXPECT_THAT(result2.id.block_index(), Eq(1));
+  EXPECT_THAT(result2.id.posting_list_index(), Eq(0));
+
+ // Add one more hit
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result2.id));
+ // The current posting list should be FULL. Adding more hits should result in
+ // these hits being moved to a larger posting list.
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hits[10]));
+ PostingListAccessor::FinalizeResult result3 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result3.status);
+ // Should have been allocated to the second (new) block because the posting
+ // list should have grown beyond the size that the first block maintains.
+ EXPECT_THAT(result3.id.block_index(), Eq(2));
+ EXPECT_THAT(result3.id.posting_list_index(), Eq(0));
+
+ // The posting list at result3.id should hold all of the hits that have been
+ // added.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(result3.id));
+ EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend())));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, MultiBlockChainsBlocksProperly) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ // Add some hits! Any hits!
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5000, /*desired_byte_length=*/1);
+ for (const EmbeddingHit& hit : hits1) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
+ }
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result1.status);
+ PostingListIdentifier second_block_id = result1.id;
+ // Should have been allocated to the second block, which holds a max-sized
+ // posting list.
+ EXPECT_THAT(second_block_id, Eq(PostingListIdentifier(
+ /*block_index=*/2, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0)));
+
+ // Now let's retrieve them!
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(second_block_id));
+ // This pl_holder will only hold a posting list with the hits that didn't fit
+ // on the first block.
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<EmbeddingHit> second_block_hits,
+ serializer_->GetHits(&pl_holder.posting_list));
+ ASSERT_THAT(second_block_hits, SizeIs(Lt(hits1.size())));
+ auto first_block_hits_start = hits1.rbegin() + second_block_hits.size();
+ EXPECT_THAT(second_block_hits,
+ ElementsAreArray(hits1.rbegin(), first_block_hits_start));
+
+ // Now retrieve all of the hits that were on the first block.
+ uint32_t first_block_id = pl_holder.next_block_index;
+ EXPECT_THAT(first_block_id, Eq(1));
+
+ PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0);
+ ICING_ASSERT_OK_AND_ASSIGN(pl_holder,
+ flash_index_storage_->GetPostingList(pl_id));
+ EXPECT_THAT(
+ serializer_->GetHits(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits1.rend())));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest,
+ PreexistingMultiBlockReusesBlocksProperly) {
+ std::vector<EmbeddingHit> hits =
+ CreateEmbeddingHits(/*num_hits=*/5050, /*desired_byte_length=*/1);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ // Add some hits! Any hits!
+ for (int i = 0; i < 5000; ++i) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i]));
+ }
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result1.status);
+ PostingListIdentifier first_add_id = result1.id;
+ EXPECT_THAT(first_add_id, Eq(PostingListIdentifier(
+ /*block_index=*/2, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0)));
+
+ // Now add a couple more hits. These should fit on the existing, not full
+ // second block.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), first_add_id));
+ for (int i = 5000; i < hits.size(); ++i) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i]));
+ }
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result2.status);
+ PostingListIdentifier second_add_id = result2.id;
+ EXPECT_THAT(second_add_id, Eq(first_add_id));
+
+ // We should be able to retrieve all 5050 hits.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(second_add_id));
+ // This pl_holder will only hold a posting list with the hits that didn't fit
+ // on the first block.
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<EmbeddingHit> second_block_hits,
+ serializer_->GetHits(&pl_holder.posting_list));
+ ASSERT_THAT(second_block_hits, SizeIs(Lt(hits.size())));
+ auto first_block_hits_start = hits.rbegin() + second_block_hits.size();
+ EXPECT_THAT(second_block_hits,
+ ElementsAreArray(hits.rbegin(), first_block_hits_start));
+
+ // Now retrieve all of the hits that were on the first block.
+ uint32_t first_block_id = pl_holder.next_block_index;
+ EXPECT_THAT(first_block_id, Eq(1));
+
+ PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0);
+ ICING_ASSERT_OK_AND_ASSIGN(pl_holder,
+ flash_index_storage_->GetPostingList(pl_id));
+ EXPECT_THAT(
+ serializer_->GetHits(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits.rend())));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, InvalidHitReturnsInvalidArgument) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue);
+ EXPECT_THAT(pl_accessor->PrependHit(invalid_hit),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest,
+ HitsNotDecreasingReturnsInvalidArgument) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ EmbeddingHit hit1(BasicHit(/*section_id=*/3, /*document_id=*/1),
+ /*location=*/5);
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
+
+ EmbeddingHit hit2(BasicHit(/*section_id=*/6, /*document_id=*/1),
+ /*location=*/5);
+ EXPECT_THAT(pl_accessor->PrependHit(hit2),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EmbeddingHit hit3(BasicHit(/*section_id=*/2, /*document_id=*/0),
+ /*location=*/5);
+ EXPECT_THAT(pl_accessor->PrependHit(hit3),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EmbeddingHit hit4(BasicHit(/*section_id=*/3, /*document_id=*/1),
+ /*location=*/6);
+ EXPECT_THAT(pl_accessor->PrependHit(hit4),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, NewPostingListNoHitsAdded) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ EXPECT_THAT(result1.status,
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPostingListNoHitsAdded) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ EmbeddingHit hit1(BasicHit(/*section_id=*/3, /*document_id=*/1),
+ /*location=*/5);
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_ASSERT_OK(result1.status);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor2,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor2).Finalize();
+ ICING_ASSERT_OK(result2.status);
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/posting-list-embedding-hit-serializer.cc b/icing/index/embed/posting-list-embedding-hit-serializer.cc
new file mode 100644
index 0000000..9247bcb
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-serializer.cc
@@ -0,0 +1,647 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+
+#include <cinttypes>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/legacy/core/icing-string-util.h"
+#include "icing/legacy/index/icing-bit-util.h"
+#include "icing/util/logging.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+uint32_t PostingListEmbeddingHitSerializer::GetBytesUsed(
+ const PostingListUsed* posting_list_used) const {
+ // The special hits will be included if they represent actual hits. If they
+ // represent the hit offset or the invalid hit sentinel, they are not
+ // included.
+ return posting_list_used->size_in_bytes() -
+ GetStartByteOffset(posting_list_used);
+}
+
+uint32_t PostingListEmbeddingHitSerializer::GetMinPostingListSizeToFit(
+ const PostingListUsed* posting_list_used) const {
+ if (IsFull(posting_list_used) || IsAlmostFull(posting_list_used)) {
+ // If in either the FULL state or ALMOST_FULL state, this posting list *is*
+ // the minimum size posting list that can fit these hits. So just return the
+ // size of the posting list.
+ return posting_list_used->size_in_bytes();
+ }
+
+ // - In NOT_FULL status, BytesUsed contains no special hits. For a posting
+ // list in the NOT_FULL state with n hits, we would have n-1 compressed hits
+ // and 1 uncompressed hit.
+ // - The minimum sized posting list that would be guaranteed to fit these hits
+ // would be FULL, but calculating the size required for the FULL posting
+ // list would require deserializing the last two added hits, so instead we
+ // will calculate the size of an ALMOST_FULL posting list to fit.
+ // - An ALMOST_FULL posting list would have kInvalidHit in special_hit(0), the
+ // full uncompressed Hit in special_hit(1), and the n-1 compressed hits in
+ // the compressed region.
+ // - Currently BytesUsed contains one uncompressed Hit and n-1 compressed
+ // hits.
+ // - Therefore, fitting these hits into a posting list would require
+ // BytesUsed + one extra full hit.
+ return GetBytesUsed(posting_list_used) + sizeof(EmbeddingHit);
+}
+
+void PostingListEmbeddingHitSerializer::Clear(
+ PostingListUsed* posting_list_used) const {
+ // Safe to ignore return value because posting_list_used->size_in_bytes() is
+ // a valid argument.
+ SetStartByteOffset(posting_list_used,
+ /*offset=*/posting_list_used->size_in_bytes());
+}
+
+libtextclassifier3::Status PostingListEmbeddingHitSerializer::MoveFrom(
+ PostingListUsed* dst, PostingListUsed* src) const {
+ ICING_RETURN_ERROR_IF_NULL(dst);
+ ICING_RETURN_ERROR_IF_NULL(src);
+ if (GetMinPostingListSizeToFit(src) > dst->size_in_bytes()) {
+    return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+        "src MinPostingListSizeToFit %d must not be larger than dst size %d.",
+        GetMinPostingListSizeToFit(src), dst->size_in_bytes()));
+ }
+
+ if (!IsPostingListValid(dst)) {
+ return absl_ports::FailedPreconditionError(
+ "Dst posting list is in an invalid state and can't be used!");
+ }
+ if (!IsPostingListValid(src)) {
+ return absl_ports::InvalidArgumentError(
+ "Cannot MoveFrom an invalid src posting list!");
+ }
+
+ // Pop just enough hits that all of src's compressed hits fit in
+ // dst posting_list's compressed area. Then we can memcpy that area.
+ std::vector<EmbeddingHit> hits;
+ while (IsFull(src) || IsAlmostFull(src) ||
+ (dst->size_in_bytes() - kSpecialHitsSize < GetBytesUsed(src))) {
+ if (!GetHitsInternal(src, /*limit=*/1, /*pop=*/true, &hits).ok()) {
+ return absl_ports::AbortedError(
+ "Unable to retrieve hits from src posting list.");
+ }
+ }
+
+ // memcpy the area and set up start byte offset.
+ Clear(dst);
+ memcpy(dst->posting_list_buffer() + dst->size_in_bytes() - GetBytesUsed(src),
+ src->posting_list_buffer() + GetStartByteOffset(src),
+ GetBytesUsed(src));
+  // Because we popped all hits from src outside of the compressed area and
+  // guaranteed that GetBytesUsed(src) is less than dst->size_in_bytes() -
+  // kSpecialHitsSize, this is guaranteed to be a valid byte offset for the
+  // NOT_FULL state, so ignoring the return value is safe.
+ SetStartByteOffset(dst, dst->size_in_bytes() - GetBytesUsed(src));
+
+ // Put back remaining hits.
+ for (size_t i = 0; i < hits.size(); i++) {
+ const EmbeddingHit& hit = hits[hits.size() - i - 1];
+ // PrependHit can return either INVALID_ARGUMENT - if hit is invalid or not
+ // less than the previous hit - or RESOURCE_EXHAUSTED. RESOURCE_EXHAUSTED
+ // should be impossible because we've already assured that there is enough
+ // room above.
+ ICING_RETURN_IF_ERROR(PrependHit(dst, hit));
+ }
+
+ Clear(src);
+ return libtextclassifier3::Status::OK;
+}
+
+uint32_t PostingListEmbeddingHitSerializer::GetPadEnd(
+ const PostingListUsed* posting_list_used, uint32_t offset) const {
+ EmbeddingHit::Value pad;
+ uint32_t pad_end = offset;
+ while (pad_end < posting_list_used->size_in_bytes()) {
+ size_t pad_len = VarInt::Decode(
+ posting_list_used->posting_list_buffer() + pad_end, &pad);
+ if (pad != 0) {
+ // No longer a pad.
+ break;
+ }
+ pad_end += pad_len;
+ }
+ return pad_end;
+}
+
+bool PostingListEmbeddingHitSerializer::PadToEnd(
+ PostingListUsed* posting_list_used, uint32_t start, uint32_t end) const {
+ if (end > posting_list_used->size_in_bytes()) {
+ ICING_LOG(ERROR) << "Cannot pad a region that ends after size!";
+ return false;
+ }
+ // In VarInt a value of 0 encodes to 0.
+ memset(posting_list_used->posting_list_buffer() + start, 0, end - start);
+ return true;
+}
+
+libtextclassifier3::Status
+PostingListEmbeddingHitSerializer::PrependHitToAlmostFull(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit) const {
+ // Get delta between first hit and the new hit. Try to fit delta
+ // in the padded area and put new hit at the special position 1.
+  // Reading special position 1 is safe because 1 < kNumSpecialData.
+ EmbeddingHit cur = GetSpecialHit(posting_list_used, /*index=*/1);
+ if (cur.value() <= hit.value()) {
+ return absl_ports::InvalidArgumentError(
+ "Hit being prepended must be strictly less than the most recent Hit");
+ }
+ uint64_t delta = cur.value() - hit.value();
+ uint8_t delta_buf[VarInt::kMaxEncodedLen64];
+ size_t delta_len = VarInt::Encode(delta, delta_buf);
+
+ uint32_t pad_end = GetPadEnd(posting_list_used,
+ /*offset=*/kSpecialHitsSize);
+
+ if (pad_end >= kSpecialHitsSize + delta_len) {
+ // Pad area has enough space for delta of existing hit (cur). Write delta at
+ // pad_end - delta_len.
+ uint8_t* delta_offset =
+ posting_list_used->posting_list_buffer() + pad_end - delta_len;
+ memcpy(delta_offset, delta_buf, delta_len);
+
+ // Now first hit is the new hit, at special position 1. Safe to ignore the
+ // return value because 1 < kNumSpecialData.
+ SetSpecialHit(posting_list_used, /*index=*/1, hit);
+ // Safe to ignore the return value because sizeof(EmbeddingHit) is a valid
+ // argument.
+ SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit));
+ } else {
+ // No space for delta. We put the new hit at special position 0
+ // and go to the full state. Safe to ignore the return value because 1 <
+ // kNumSpecialData.
+ SetSpecialHit(posting_list_used, /*index=*/0, hit);
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+void PostingListEmbeddingHitSerializer::PrependHitToEmpty(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit) const {
+ // First hit to be added. Just add verbatim, no compression.
+ if (posting_list_used->size_in_bytes() == kSpecialHitsSize) {
+ // Safe to ignore the return value because 1 < kNumSpecialData
+ SetSpecialHit(posting_list_used, /*index=*/1, hit);
+ // Safe to ignore the return value because sizeof(EmbeddingHit) is a valid
+ // argument.
+ SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit));
+ } else {
+ // Since this is the first hit, size != kSpecialHitsSize and
+ // size % sizeof(EmbeddingHit) == 0, we know that there is room to fit 'hit'
+ // into the compressed region, so ValueOrDie is safe.
+ uint32_t offset =
+ PrependHitUncompressed(posting_list_used, hit,
+ /*offset=*/posting_list_used->size_in_bytes())
+ .ValueOrDie();
+ // Safe to ignore the return value because PrependHitUncompressed is
+ // guaranteed to return a valid offset.
+ SetStartByteOffset(posting_list_used, offset);
+ }
+}
+
+libtextclassifier3::Status
+PostingListEmbeddingHitSerializer::PrependHitToNotFull(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit,
+ uint32_t offset) const {
+ // First hit in compressed area. It is uncompressed. See if delta
+ // between the first hit and new hit will still fit in the
+ // compressed area.
+ if (offset + sizeof(EmbeddingHit::Value) >
+ posting_list_used->size_in_bytes()) {
+ // The first hit in the compressed region *should* be uncompressed, but
+ // somehow there isn't enough room between offset and the end of the
+ // compressed area to fit an uncompressed hit. This should NEVER happen.
+ return absl_ports::FailedPreconditionError(
+ "Posting list is in an invalid state.");
+ }
+ EmbeddingHit::Value cur_value;
+ memcpy(&cur_value, posting_list_used->posting_list_buffer() + offset,
+ sizeof(EmbeddingHit::Value));
+ if (cur_value <= hit.value()) {
+ return absl_ports::InvalidArgumentError(
+ IcingStringUtil::StringPrintf("EmbeddingHit %" PRId64
+ " being prepended must be "
+ "strictly less than the most recent "
+ "EmbeddingHit %" PRId64,
+ hit.value(), cur_value));
+ }
+ uint64_t delta = cur_value - hit.value();
+ uint8_t delta_buf[VarInt::kMaxEncodedLen64];
+ size_t delta_len = VarInt::Encode(delta, delta_buf);
+
+ // offset now points to one past the end of the first hit.
+ offset += sizeof(EmbeddingHit::Value);
+ if (kSpecialHitsSize + sizeof(EmbeddingHit::Value) + delta_len <= offset) {
+ // Enough space for delta in compressed area.
+
+ // Prepend delta.
+ offset -= delta_len;
+ memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf,
+ delta_len);
+
+ // Prepend new hit. We know that there is room for 'hit' because of the if
+ // statement above, so calling ValueOrDie is safe.
+ offset =
+ PrependHitUncompressed(posting_list_used, hit, offset).ValueOrDie();
+    // offset is guaranteed to be valid here, so it's safe to ignore the return
+    // value. The if above guarantees that offset >= kSpecialHitsSize and <
+    // posting_list_used->size_in_bytes() because it ensures that there is
+    // enough room between offset and kSpecialHitsSize to fit the delta of the
+    // previous hit and the uncompressed hit.
+ SetStartByteOffset(posting_list_used, offset);
+ } else if (kSpecialHitsSize + delta_len <= offset) {
+ // Only have space for delta. The new hit must be put in special
+ // position 1.
+
+ // Prepend delta.
+ offset -= delta_len;
+ memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf,
+ delta_len);
+
+ // Prepend pad. Safe to ignore the return value of PadToEnd because offset
+ // must be less than posting_list_used->size_in_bytes(). Otherwise, this
+ // function already would have returned FAILED_PRECONDITION.
+ PadToEnd(posting_list_used, /*start=*/kSpecialHitsSize,
+ /*end=*/offset);
+
+ // Put new hit in special position 1. Safe to ignore return value because 1
+ // < kNumSpecialData.
+ SetSpecialHit(posting_list_used, /*index=*/1, hit);
+
+ // State almost_full. Safe to ignore the return value because
+ // sizeof(EmbeddingHit) is a valid argument.
+ SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit));
+ } else {
+ // Very rare case where delta is larger than sizeof(EmbeddingHit::Value)
+ // (i.e. varint delta encoding expanded required storage). We
+ // move first hit to special position 1 and put new hit in
+ // special position 0.
+ EmbeddingHit cur(cur_value);
+ // Safe to ignore the return value of PadToEnd because offset must be less
+ // than posting_list_used->size_in_bytes(). Otherwise, this function
+ // already would have returned FAILED_PRECONDITION.
+ PadToEnd(posting_list_used, /*start=*/kSpecialHitsSize,
+ /*end=*/offset);
+ // Safe to ignore the return value here because 0 and 1 < kNumSpecialData.
+ SetSpecialHit(posting_list_used, /*index=*/1, cur);
+ SetSpecialHit(posting_list_used, /*index=*/0, hit);
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status PostingListEmbeddingHitSerializer::PrependHit(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit) const {
+ static_assert(
+ sizeof(EmbeddingHit::Value) <= sizeof(uint64_t),
+ "EmbeddingHit::Value cannot be larger than 8 bytes because the delta "
+ "must be able to fit in 8 bytes.");
+ if (!hit.is_valid()) {
+ return absl_ports::InvalidArgumentError("Cannot prepend an invalid hit!");
+ }
+ if (!IsPostingListValid(posting_list_used)) {
+ return absl_ports::FailedPreconditionError(
+ "This PostingListUsed is in an invalid state and can't add any hits!");
+ }
+
+ if (IsFull(posting_list_used)) {
+ // State full: no space left.
+ return absl_ports::ResourceExhaustedError("No more room for hits");
+ } else if (IsAlmostFull(posting_list_used)) {
+ return PrependHitToAlmostFull(posting_list_used, hit);
+ } else if (IsEmpty(posting_list_used)) {
+ PrependHitToEmpty(posting_list_used, hit);
+ return libtextclassifier3::Status::OK;
+ } else {
+ uint32_t offset = GetStartByteOffset(posting_list_used);
+ return PrependHitToNotFull(posting_list_used, hit, offset);
+ }
+}
+
+libtextclassifier3::StatusOr<std::vector<EmbeddingHit>>
+PostingListEmbeddingHitSerializer::GetHits(
+ const PostingListUsed* posting_list_used) const {
+ std::vector<EmbeddingHit> hits_out;
+ ICING_RETURN_IF_ERROR(GetHits(posting_list_used, &hits_out));
+ return hits_out;
+}
+
+libtextclassifier3::Status PostingListEmbeddingHitSerializer::GetHits(
+ const PostingListUsed* posting_list_used,
+ std::vector<EmbeddingHit>* hits_out) const {
+ return GetHitsInternal(posting_list_used,
+ /*limit=*/std::numeric_limits<uint32_t>::max(),
+ /*pop=*/false, hits_out);
+}
+
+libtextclassifier3::Status PostingListEmbeddingHitSerializer::PopFrontHits(
+ PostingListUsed* posting_list_used, uint32_t num_hits) const {
+ if (num_hits == 1 && IsFull(posting_list_used)) {
+    // The PL is in the full state, which means that we store 2 uncompressed
+    // hits in the 2 special positions. But the full state may be reached from
+    // 2 different states.
+ // (1) In "almost full" status
+ // +-----------------+----------------+-------+-----------------+
+ // |Hit::kInvalidVal |1st hit |(pad) |(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+ // When we prepend another hit, we can only put it at the special
+ // position 0. And we get a full PL
+ // +-----------------+----------------+-------+-----------------+
+ // |new 1st hit |original 1st hit|(pad) |(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+ // (2) In "not full" status
+ // +-----------------+----------------+------+-------+------------------+
+ // |hits-start-offset|Hit::kInvalidVal|(pad) |1st hit|(compressed) hits |
+ // +-----------------+----------------+------+-------+------------------+
+ // When we prepend another hit, we can reach any of the 3 following
+ // scenarios:
+ // (2.1) not full
+    // Reached if the space of the pad and the original 1st hit can
+    // accommodate the new 1st hit and the encoded delta value.
+ // +-----------------+----------------+------+-----------+-----------------+
+ // |hits-start-offset|Hit::kInvalidVal|(pad) |new 1st hit|(compressed) hits|
+ // +-----------------+----------------+------+-----------+-----------------+
+ // (2.2) almost full
+    // If the space of the pad and the original 1st hit cannot accommodate the
+    // new 1st hit together with the encoded delta value, but can accommodate
+    // the encoded delta value alone, we can put the new 1st hit at special
+    // position 1.
+ // +-----------------+----------------+-------+-----------------+
+ // |Hit::kInvalidVal |new 1st hit |(pad) |(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+ // (2.3) full
+    // In the very rare case where the space cannot accommodate even the
+    // encoded delta value alone, we can move the original 1st hit into special
+    // position 1 and the new 1st hit into special position 0. This may happen
+    // because the VarInt encoding can make the encoded value longer (up to
+    // about 4/3 of the original size).
+ // +-----------------+----------------+-------+-----------------+
+ // |new 1st hit |original 1st hit|(pad) |(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+    // Suppose the PL is now full, but we don't know whether it arrived at this
+    // state from "not full" like (2.3) or from "almost full" like (1). Simply
+    // popping the new 1st hit would always leave us in the "almost full" state
+    // like (1), but we want the prepend operation to be "reversible", so there
+    // should be some way to return to "not full" when possible. A simple way
+    // to do this is to pop 2 hits out of the PL, reaching the "almost full" or
+    // "not full" state, and then add the original 1st hit back. This returns
+    // the PL to the correct original state of (2.1) or (1), which makes our
+    // prepend operation reversible.
+ std::vector<EmbeddingHit> out;
+
+ // Popping 2 hits should never fail because we've just ensured that the
+ // posting list is in the FULL state.
+ ICING_RETURN_IF_ERROR(
+ GetHitsInternal(posting_list_used, /*limit=*/2, /*pop=*/true, &out));
+
+ // PrependHit should never fail because out[1] is a valid hit less than
+ // previous hits in the posting list and because there's no way that the
+ // posting list could run out of room because it previously stored this hit
+ // AND another hit.
+ ICING_RETURN_IF_ERROR(PrependHit(posting_list_used, out[1]));
+ } else if (num_hits > 0) {
+ return GetHitsInternal(posting_list_used, /*limit=*/num_hits, /*pop=*/true,
+ nullptr);
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status PostingListEmbeddingHitSerializer::GetHitsInternal(
+ const PostingListUsed* posting_list_used, uint32_t limit, bool pop,
+ std::vector<EmbeddingHit>* out) const {
+ // Put current uncompressed val here.
+ EmbeddingHit::Value val = EmbeddingHit::kInvalidValue;
+ uint32_t offset = GetStartByteOffset(posting_list_used);
+ uint32_t count = 0;
+
+ // First traverse the first two special positions.
+ while (count < limit && offset < kSpecialHitsSize) {
+    // Reading the special hit is safe here because
+    // offset / sizeof(EmbeddingHit) < kNumSpecialData, thanks to the check
+    // above.
+ EmbeddingHit hit = GetSpecialHit(posting_list_used,
+ /*index=*/offset / sizeof(EmbeddingHit));
+ val = hit.value();
+ if (out != nullptr) {
+ out->push_back(hit);
+ }
+ offset += sizeof(EmbeddingHit);
+ count++;
+ }
+
+ // If special position 1 was set then we need to skip padding.
+ if (val != EmbeddingHit::kInvalidValue && offset == kSpecialHitsSize) {
+ offset = GetPadEnd(posting_list_used, offset);
+ }
+
+ while (count < limit && offset < posting_list_used->size_in_bytes()) {
+ if (val == EmbeddingHit::kInvalidValue) {
+ // First hit is in compressed area. Put that in val.
+ memcpy(&val, posting_list_used->posting_list_buffer() + offset,
+ sizeof(EmbeddingHit::Value));
+ offset += sizeof(EmbeddingHit::Value);
+ } else {
+ // Now we have delta encoded subsequent hits. Decode and push.
+ uint64_t delta;
+ offset += VarInt::Decode(
+ posting_list_used->posting_list_buffer() + offset, &delta);
+ val += delta;
+ }
+ EmbeddingHit hit(val);
+ if (out != nullptr) {
+ out->push_back(hit);
+ }
+ count++;
+ }
+
+ if (pop) {
+ PostingListUsed* mutable_posting_list_used =
+ const_cast<PostingListUsed*>(posting_list_used);
+ // Modify the posting list so that we pop all hits actually
+ // traversed.
+ if (offset >= kSpecialHitsSize &&
+ offset < posting_list_used->size_in_bytes()) {
+ // In the compressed area. Pop and reconstruct. offset/val is
+ // the last traversed hit, which we must discard. So move one
+ // more forward.
+ uint64_t delta;
+ offset += VarInt::Decode(
+ posting_list_used->posting_list_buffer() + offset, &delta);
+ val += delta;
+
+ // Now val is the first hit of the new posting list.
+ if (kSpecialHitsSize + sizeof(EmbeddingHit::Value) <= offset) {
+ // val fits in compressed area. Simply copy.
+ offset -= sizeof(EmbeddingHit::Value);
+ memcpy(mutable_posting_list_used->posting_list_buffer() + offset, &val,
+ sizeof(EmbeddingHit::Value));
+ } else {
+ // val won't fit in compressed area.
+ EmbeddingHit hit(val);
+ // Okay to ignore the return value here because 1 < kNumSpecialData.
+ SetSpecialHit(mutable_posting_list_used, /*index=*/1, hit);
+
+ // Prepend pad. Safe to ignore the return value of PadToEnd because
+ // offset must be less than posting_list_used->size_in_bytes() thanks to
+ // the if above.
+ PadToEnd(mutable_posting_list_used,
+ /*start=*/kSpecialHitsSize,
+ /*end=*/offset);
+ offset = sizeof(EmbeddingHit);
+ }
+ }
+    // offset is guaranteed to be valid, so ignoring the return value of
+    // SetStartByteOffset is safe. It falls into one of four scenarios:
+    // Scenario 1: the above if was false because offset is not <
+    // posting_list_used->size_in_bytes()
+    // In this case, offset must be == posting_list_used->size_in_bytes()
+    // because we reached offset by unwinding hits on the posting list.
+    // Scenario 2: offset is < kSpecialHitsSize
+    // In this case, offset is guaranteed to be either 0 or
+    // sizeof(EmbeddingHit) because offset is incremented by
+    // sizeof(EmbeddingHit) within the first while loop.
+    // Scenario 3: offset is within the compressed region and the new first hit
+    // in the posting list (the value that 'val' holds) will fit as an
+    // uncompressed hit in the compressed region. The resulting offset from
+    // decompressing val must be >= kSpecialHitsSize because otherwise we'd be
+    // in Scenario 4.
+    // Scenario 4: offset is within the compressed region, but the new first
+    // hit in the posting list is too large to fit as an uncompressed hit in
+    // the compressed region. Therefore, it must be stored in a special hit
+    // and offset will be sizeof(EmbeddingHit).
+ SetStartByteOffset(mutable_posting_list_used, offset);
+ }
+
+ return libtextclassifier3::Status::OK;
+}
+
+EmbeddingHit PostingListEmbeddingHitSerializer::GetSpecialHit(
+ const PostingListUsed* posting_list_used, uint32_t index) const {
+ static_assert(sizeof(EmbeddingHit::Value) >= sizeof(uint32_t), "HitTooSmall");
+ EmbeddingHit val(EmbeddingHit::kInvalidValue);
+ memcpy(&val, posting_list_used->posting_list_buffer() + index * sizeof(val),
+ sizeof(val));
+ return val;
+}
+
+void PostingListEmbeddingHitSerializer::SetSpecialHit(
+ PostingListUsed* posting_list_used, uint32_t index,
+ const EmbeddingHit& val) const {
+ memcpy(posting_list_used->posting_list_buffer() + index * sizeof(val), &val,
+ sizeof(val));
+}
+
+bool PostingListEmbeddingHitSerializer::IsPostingListValid(
+ const PostingListUsed* posting_list_used) const {
+ if (IsAlmostFull(posting_list_used)) {
+    // Special Hit 1 should hold a valid hit. Reading it is safe because we
+    // know that 1 < kNumSpecialData.
+ if (!GetSpecialHit(posting_list_used, /*index=*/1).is_valid()) {
+ ICING_LOG(ERROR)
+ << "Both special hits cannot be invalid at the same time.";
+ return false;
+ }
+ } else if (!IsFull(posting_list_used)) {
+    // NOT_FULL. Special Hit 0 should hold a valid offset. Reading it is safe
+    // because we know that 0 < kNumSpecialData.
+ if (GetSpecialHit(posting_list_used, /*index=*/0).value() >
+ posting_list_used->size_in_bytes() ||
+ GetSpecialHit(posting_list_used, /*index=*/0).value() <
+ kSpecialHitsSize) {
+ ICING_LOG(ERROR) << "EmbeddingHit: "
+ << GetSpecialHit(posting_list_used, /*index=*/0).value()
+ << " size: " << posting_list_used->size_in_bytes()
+ << " sp size: " << kSpecialHitsSize;
+ return false;
+ }
+ }
+ return true;
+}
+
+uint32_t PostingListEmbeddingHitSerializer::GetStartByteOffset(
+ const PostingListUsed* posting_list_used) const {
+ if (IsFull(posting_list_used)) {
+ return 0;
+ } else if (IsAlmostFull(posting_list_used)) {
+ return sizeof(EmbeddingHit);
+ } else {
+    // NOT_FULL. Reading special hit 0 is safe because we know that
+    // 0 < kNumSpecialData.
+ return GetSpecialHit(posting_list_used, /*index=*/0).value();
+ }
+}
+
+bool PostingListEmbeddingHitSerializer::SetStartByteOffset(
+ PostingListUsed* posting_list_used, uint32_t offset) const {
+ if (offset > posting_list_used->size_in_bytes()) {
+ ICING_LOG(ERROR) << "offset cannot be a value greater than size "
+ << posting_list_used->size_in_bytes() << ". offset is "
+ << offset << ".";
+ return false;
+ }
+ if (offset < kSpecialHitsSize && offset > sizeof(EmbeddingHit)) {
+ ICING_LOG(ERROR) << "offset cannot be a value between ("
+ << sizeof(EmbeddingHit) << ", " << kSpecialHitsSize
+ << "). offset is " << offset << ".";
+ return false;
+ }
+ if (offset < sizeof(EmbeddingHit) && offset != 0) {
+ ICING_LOG(ERROR) << "offset cannot be a value between (0, "
+ << sizeof(EmbeddingHit) << "). offset is " << offset
+ << ".";
+ return false;
+ }
+ if (offset >= kSpecialHitsSize) {
+    // not_full state. Writing to indices 0 and 1 is safe because both are
+    // < kNumSpecialData.
+ SetSpecialHit(posting_list_used, /*index=*/0, EmbeddingHit(offset));
+ SetSpecialHit(posting_list_used, /*index=*/1,
+ EmbeddingHit(EmbeddingHit::kInvalidValue));
+ } else if (offset == sizeof(EmbeddingHit)) {
+    // almost_full state. Writing to index 0 is safe because 0 <
+    // kNumSpecialData.
+ SetSpecialHit(posting_list_used, /*index=*/0,
+ EmbeddingHit(EmbeddingHit::kInvalidValue));
+ }
+ // Nothing to do for the FULL state - the offset isn't actually stored
+ // anywhere and both special hits hold valid hits.
+ return true;
+}
+
+libtextclassifier3::StatusOr<uint32_t>
+PostingListEmbeddingHitSerializer::PrependHitUncompressed(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit,
+ uint32_t offset) const {
+ if (offset < kSpecialHitsSize + sizeof(EmbeddingHit::Value)) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "Not enough room to prepend EmbeddingHit::Value at offset %d.",
+ offset));
+ }
+ offset -= sizeof(EmbeddingHit::Value);
+ EmbeddingHit::Value val = hit.value();
+ memcpy(posting_list_used->posting_list_buffer() + offset, &val,
+ sizeof(EmbeddingHit::Value));
+ return offset;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/posting-list-embedding-hit-serializer.h b/icing/index/embed/posting-list-embedding-hit-serializer.h
new file mode 100644
index 0000000..76198c2
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-serializer.h
@@ -0,0 +1,284 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_
+#define ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_
+
+#include <cstdint>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/file/posting_list/posting-list-common.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+// A serializer class to serialize hits to PostingListUsed. Layout described in
+// comments in posting-list-embedding-hit-serializer.cc.
+class PostingListEmbeddingHitSerializer : public PostingListSerializer {
+ public:
+ static constexpr uint32_t kSpecialHitsSize =
+ kNumSpecialData * sizeof(EmbeddingHit);
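+  // (With kNumSpecialData == 2 and an 8-byte EmbeddingHit, which the tests in
+  // the accompanying test file assume, kSpecialHitsSize works out to
+  // 2 * 8 = 16 bytes.)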
+
+ uint32_t GetDataTypeBytes() const override { return sizeof(EmbeddingHit); }
+
+ uint32_t GetMinPostingListSize() const override {
+ static constexpr uint32_t kMinPostingListSize = kSpecialHitsSize;
+ static_assert(sizeof(PostingListIndex) <= kMinPostingListSize,
+ "PostingListIndex must be small enough to fit in a "
+ "minimum-sized Posting List.");
+
+ return kMinPostingListSize;
+ }
+
+ uint32_t GetMinPostingListSizeToFit(
+ const PostingListUsed* posting_list_used) const override;
+
+ uint32_t GetBytesUsed(
+ const PostingListUsed* posting_list_used) const override;
+
+ void Clear(PostingListUsed* posting_list_used) const override;
+
+ libtextclassifier3::Status MoveFrom(PostingListUsed* dst,
+ PostingListUsed* src) const override;
+
+ // Prepend a hit to the posting list.
+ //
+ // RETURNS:
+ // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the
+ // previously added hit.
+ // - RESOURCE_EXHAUSTED if there is no more room to add hit to the posting
+ // list.
+ libtextclassifier3::Status PrependHit(PostingListUsed* posting_list_used,
+ const EmbeddingHit& hit) const;
+
+  // Prepend hits to the posting list. Hits should be sorted in descending
+  // order (as defined by the less-than operator for EmbeddingHit).
+ //
+ // Returns the number of hits that could be prepended to the posting list. If
+ // keep_prepended is true, whatever could be prepended is kept, otherwise the
+ // posting list is left in its original state.
+ template <class T, EmbeddingHit (*GetHit)(const T&)>
+ libtextclassifier3::StatusOr<uint32_t> PrependHitArray(
+ PostingListUsed* posting_list_used, const T* array, uint32_t num_hits,
+ bool keep_prepended) const;
+
+ // Retrieves the hits stored in the posting list.
+ //
+ // RETURNS:
+ // - On success, a vector of hits sorted by the reverse order of prepending.
+ // - INTERNAL_ERROR if the posting list has been corrupted somehow.
+ libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits(
+ const PostingListUsed* posting_list_used) const;
+
+ // Same as GetHits but appends hits to hits_out.
+ //
+ // RETURNS:
+  //   - OK on success, with the hits appended to hits_out sorted by the
+  //     reverse order of prepending.
+ // - INTERNAL_ERROR if the posting list has been corrupted somehow.
+ libtextclassifier3::Status GetHits(const PostingListUsed* posting_list_used,
+ std::vector<EmbeddingHit>* hits_out) const;
+
+ // Undo the last num_hits hits prepended. If num_hits > number of
+ // hits we clear all hits.
+ //
+ // RETURNS:
+ // - OK on success
+ // - INTERNAL_ERROR if the posting list has been corrupted somehow.
+ libtextclassifier3::Status PopFrontHits(PostingListUsed* posting_list_used,
+ uint32_t num_hits) const;
+
+ private:
+ // Posting list layout formats:
+ //
+ // not_full
+ //
+ // +-----------------+----------------+-------+-----------------+
+ // |hits-start-offset|Hit::kInvalidVal|xxxxxxx|(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+ //
+ // almost_full
+ //
+ // +-----------------+----------------+-------+-----------------+
+ // |Hit::kInvalidVal |1st hit |(pad) |(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+ //
+ // full()
+ //
+ // +-----------------+----------------+-------+-----------------+
+ // |1st hit |2nd hit |(pad) |(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+ //
+ // The first two uncompressed hits also implicitly encode information about
+ // the size of the compressed hits region.
+ //
+ // 1. If the posting list is NOT_FULL, then
+ // posting_list_buffer_[0] contains the byte offset of the start of the
+ // compressed hits - and, thus, the size of the compressed hits region is
+ // size_in_bytes - posting_list_buffer_[0].
+ //
+ // 2. If posting list is ALMOST_FULL or FULL, then the compressed hits region
+ // starts somewhere between [kSpecialHitsSize, kSpecialHitsSize +
+ // sizeof(EmbeddingHit) - 1] and ends at size_in_bytes - 1.
+
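+  // A brief worked example of this bookkeeping, with illustrative numbers:
+  // for a 32-byte posting list in the NOT_FULL state whose special hit 0
+  // stores the offset 20, the compressed hits occupy bytes [20, 32), so the
+  // posting list holds 12 bytes of hit data and GetStartByteOffset() returns
+  // 20.
+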
+ // Helpers to determine what state the posting list is in.
+ bool IsFull(const PostingListUsed* posting_list_used) const {
+ return GetSpecialHit(posting_list_used, /*index=*/0).is_valid() &&
+ GetSpecialHit(posting_list_used, /*index=*/1).is_valid();
+ }
+
+ bool IsAlmostFull(const PostingListUsed* posting_list_used) const {
+ return !GetSpecialHit(posting_list_used, /*index=*/0).is_valid() &&
+ GetSpecialHit(posting_list_used, /*index=*/1).is_valid();
+ }
+
+ bool IsEmpty(const PostingListUsed* posting_list_used) const {
+ return GetSpecialHit(posting_list_used, /*index=*/0).value() ==
+ posting_list_used->size_in_bytes() &&
+ !GetSpecialHit(posting_list_used, /*index=*/1).is_valid();
+ }
+
+ // Returns false if both special hits are invalid or if the offset value
+ // stored in the special hit is less than kSpecialHitsSize or greater than
+ // posting_list_used->size_in_bytes(). Returns true, otherwise.
+ bool IsPostingListValid(const PostingListUsed* posting_list_used) const;
+
+ // Prepend hit to a posting list that is in the ALMOST_FULL state.
+ // RETURNS:
+ // - OK, if successful
+ // - INVALID_ARGUMENT if hit is not less than the previously added hit.
+ libtextclassifier3::Status PrependHitToAlmostFull(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit) const;
+
+ // Prepend hit to a posting list that is in the EMPTY state. This will always
+ // succeed because there are no pre-existing hits and no validly constructed
+ // posting list could fail to fit one hit.
+ void PrependHitToEmpty(PostingListUsed* posting_list_used,
+ const EmbeddingHit& hit) const;
+
+ // Prepend hit to a posting list that is in the NOT_FULL state.
+ // RETURNS:
+ // - OK, if successful
+ // - INVALID_ARGUMENT if hit is not less than the previously added hit.
+ libtextclassifier3::Status PrependHitToNotFull(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit,
+ uint32_t offset) const;
+
+ // Returns either 0 (full state), sizeof(EmbeddingHit) (almost_full state) or
+ // a byte offset between kSpecialHitsSize and
+ // posting_list_used->size_in_bytes() (inclusive) (not_full state).
+ uint32_t GetStartByteOffset(const PostingListUsed* posting_list_used) const;
+
+ // Sets the special hits to properly reflect what offset is (see layout
+ // comment for further details).
+ //
+  // Returns false if offset > posting_list_used->size_in_bytes(), or if offset
+  // is in (sizeof(EmbeddingHit), kSpecialHitsSize) or in
+  // (0, sizeof(EmbeddingHit)). Returns true otherwise.
+ bool SetStartByteOffset(PostingListUsed* posting_list_used,
+ uint32_t offset) const;
+
+ // Manipulate padded areas. We never store the same hit value twice
+ // so a delta of 0 is a pad byte.
+
+ // Returns offset of first non-pad byte.
+ uint32_t GetPadEnd(const PostingListUsed* posting_list_used,
+ uint32_t offset) const;
+
+ // Fill padding between offset start and offset end with 0s.
+ // Returns false if end > posting_list_used->size_in_bytes(). True,
+ // otherwise.
+ bool PadToEnd(PostingListUsed* posting_list_used, uint32_t start,
+ uint32_t end) const;
+
+  // Helper for GetHits/PopFrontHits. Adds limit hits to out, or all hits in
+  // the posting list if the posting list contains fewer than limit hits. out
+  // can be NULL.
+ //
+ // NOTE: If called with limit=1, pop=true on a posting list that transitioned
+ // from NOT_FULL directly to FULL, GetHitsInternal will not return the posting
+ // list to NOT_FULL. Instead it will leave it in a valid state, but it will be
+ // ALMOST_FULL.
+ //
+ // RETURNS:
+ // - OK on success
+ // - INTERNAL_ERROR if the posting list has been corrupted somehow.
+ libtextclassifier3::Status GetHitsInternal(
+ const PostingListUsed* posting_list_used, uint32_t limit, bool pop,
+ std::vector<EmbeddingHit>* out) const;
+
+ // Retrieves the value stored in the index-th special hit.
+ //
+ // REQUIRES:
+ // 0 <= index < kNumSpecialData.
+ //
+ // RETURNS:
+  //   - The EmbeddingHit stored at the index-th special position.
+ EmbeddingHit GetSpecialHit(const PostingListUsed* posting_list_used,
+ uint32_t index) const;
+
+ // Sets the value stored in the index-th special hit to val.
+ //
+ // REQUIRES:
+ // 0 <= index < kNumSpecialData.
+ void SetSpecialHit(PostingListUsed* posting_list_used, uint32_t index,
+ const EmbeddingHit& val) const;
+
+  // Prepends hit, uncompressed, to the memory region
+  // [offset - sizeof(EmbeddingHit::Value), offset) and returns the new offset
+  // at which the hit begins.
+  //
+  // RETURNS:
+  //   - The new offset, if successful.
+ // - INVALID_ARGUMENT if hit will not fit (uncompressed) between offset and
+ // kSpecialHitsSize
+ libtextclassifier3::StatusOr<uint32_t> PrependHitUncompressed(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit,
+ uint32_t offset) const;
+};
+
+// Inlined functions. Implementation details below. Avert eyes!
+template <class T, EmbeddingHit (*GetHit)(const T&)>
+libtextclassifier3::StatusOr<uint32_t>
+PostingListEmbeddingHitSerializer::PrependHitArray(
+ PostingListUsed* posting_list_used, const T* array, uint32_t num_hits,
+ bool keep_prepended) const {
+ if (!IsPostingListValid(posting_list_used)) {
+ return 0;
+ }
+
+ // Prepend hits working backwards from array[num_hits - 1].
+ uint32_t i;
+ for (i = 0; i < num_hits; ++i) {
+ if (!PrependHit(posting_list_used, GetHit(array[num_hits - i - 1])).ok()) {
+ break;
+ }
+ }
+ if (i != num_hits && !keep_prepended) {
+    // Didn't fit. Undo everything that was prepended. PopFrontHits guarantees
+    // that it will remove all 'i' hits so long as there are at least 'i' hits
+    // in the posting list, which we know there are.
+ ICING_RETURN_IF_ERROR(PopFrontHits(posting_list_used, /*num_hits=*/i));
+ }
+ return i;
+}
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_
diff --git a/icing/index/embed/posting-list-embedding-hit-serializer_test.cc b/icing/index/embed/posting-list-embedding-hit-serializer_test.cc
new file mode 100644
index 0000000..f829634
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-serializer_test.cc
@@ -0,0 +1,864 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <deque>
+#include <iterator>
+#include <limits>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/hit/hit.h"
+#include "icing/legacy/index/icing-bit-util.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/hit-test-utils.h"
+
+using testing::ElementsAre;
+using testing::ElementsAreArray;
+using testing::Eq;
+using testing::IsEmpty;
+using testing::Le;
+using testing::Lt;
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+struct HitElt {
+ HitElt() = default;
+ explicit HitElt(const EmbeddingHit &hit_in) : hit(hit_in) {}
+
+ static EmbeddingHit get_hit(const HitElt &hit_elt) { return hit_elt.hit; }
+
+ EmbeddingHit hit;
+};
+
+TEST(PostingListEmbeddingHitSerializerTest, PostingListUsedPrependHitNotFull) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ static const int kNumHits = 2551;
+ static const size_t kHitsSize = kNumHits * sizeof(EmbeddingHit);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize));
+
+ // Make used.
+ EmbeddingHit hit0(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0);
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit0));
+ int expected_size = sizeof(EmbeddingHit::Value);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0)));
+
+ EmbeddingHit hit1(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/1);
+ uint64_t delta = hit0.value() - hit1.value();
+ uint8_t delta_buf[VarInt::kMaxEncodedLen64];
+ size_t delta_len = VarInt::Encode(delta, delta_buf);
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit1));
+ expected_size += delta_len;
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit1, hit0)));
+
+ EmbeddingHit hit2(BasicHit(/*section_id=*/0, /*document_id=*/2),
+ /*location=*/2);
+ delta = hit1.value() - hit2.value();
+ delta_len = VarInt::Encode(delta, delta_buf);
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit2));
+ expected_size += delta_len;
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit2, hit1, hit0)));
+
+ EmbeddingHit hit3(BasicHit(/*section_id=*/0, /*document_id=*/3),
+ /*location=*/3);
+ delta = hit2.value() - hit3.value();
+ delta_len = VarInt::Encode(delta, delta_buf);
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit3));
+ expected_size += delta_len;
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0)));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ PostingListUsedPrependHitAlmostFull) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ // Size = 32
+ int pl_size = 2 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+ // Fill up the compressed region.
+ // Transitions:
+ // Adding hit0: EMPTY -> NOT_FULL
+ // Adding hit1: NOT_FULL -> NOT_FULL
+ // Adding hit2: NOT_FULL -> NOT_FULL
+ EmbeddingHit hit0(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/1);
+ EmbeddingHit hit1 = CreateEmbeddingHit(hit0, /*desired_byte_length=*/3);
+ EmbeddingHit hit2 = CreateEmbeddingHit(hit1, /*desired_byte_length=*/3);
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2));
+ // Size used will be 8 (hit2) + 3 (hit1-hit2) + 3 (hit0-hit1) = 14 bytes
+ int expected_size = sizeof(EmbeddingHit) + 3 + 3;
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit2, hit1, hit0)));
+
+ // Add one more hit to transition NOT_FULL -> ALMOST_FULL
+ EmbeddingHit hit3 = CreateEmbeddingHit(hit2, /*desired_byte_length=*/3);
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3));
+ // Storing them in the compressed region requires 8 (hit) + 3 (hit2-hit3) +
+ // 3 (hit1-hit2) + 3 (hit0-hit1) = 17 bytes, but there are only 16 bytes in
+ // the compressed region. So instead, the posting list will transition to
+ // ALMOST_FULL. The in-use compressed region will actually shrink from 14
+ // bytes to 9 bytes because the uncompressed version of hit2 will be
+ // overwritten with the compressed delta of hit2. hit3 will be written to one
+ // of the special hits. Because we're in ALMOST_FULL, the expected size is the
+ // size of the pl minus the one hit used to mark the posting list as
+ // ALMOST_FULL.
+ expected_size = pl_size - sizeof(EmbeddingHit);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0)));
+
+ // Add one more hit to transition ALMOST_FULL -> ALMOST_FULL
+ EmbeddingHit hit4 = CreateEmbeddingHit(hit3, /*desired_byte_length=*/6);
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit4));
+ // There are currently 9 bytes in use in the compressed region. Hit3 will
+ // have a 6-byte delta, which fits in the compressed region. Hit3 will be
+ // moved from the special hit to the compressed region (which will have 15
+ // bytes in use after adding hit3). Hit4 will be placed in one of the special
+ // hits and the posting list will remain in ALMOST_FULL.
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit4, hit3, hit2, hit1, hit0)));
+
+ // Add one more hit to transition ALMOST_FULL -> FULL
+ EmbeddingHit hit5 = CreateEmbeddingHit(hit4, /*desired_byte_length=*/2);
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit5));
+ // There are currently 15 bytes in use in the compressed region. Hit4 will
+ // have a 2-byte delta which will not fit in the compressed region. So hit4
+ // will remain in one of the special hits and hit5 will occupy the other,
+ // making the posting list FULL.
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit5, hit4, hit3, hit2, hit1, hit0)));
+
+ // The posting list is FULL. Adding another hit should fail.
+ EmbeddingHit hit6 = CreateEmbeddingHit(hit5, /*desired_byte_length=*/1);
+ EXPECT_THAT(serializer.PrependHit(&pl_used, hit6),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest, PostingListUsedMinSize) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ // Min size = 16
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, serializer.GetMinPostingListSize()));
+ // PL State: EMPTY
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0));
+ EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty()));
+
+ // Add a hit, PL should shift to ALMOST_FULL state
+ EmbeddingHit hit0(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/1);
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0));
+ // Size = sizeof(uncompressed hit0)
+ int expected_size = sizeof(EmbeddingHit);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0)));
+
+ // Add the smallest hit possible with a delta of 0b1. PL should shift to FULL
+ // state.
+ EmbeddingHit hit1(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/0);
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1));
+ // Size = sizeof(uncompressed hit1) + sizeof(uncompressed hit0)
+ expected_size += sizeof(EmbeddingHit);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit1, hit0)));
+
+ // Try to add the smallest hit possible. Should fail
+ EmbeddingHit hit2(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0);
+ EXPECT_THAT(serializer.PrependHit(&pl_used, hit2),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit1, hit0)));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ PostingListPrependHitArrayMinSizePostingList) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ // Min Size = 16
+ int pl_size = serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+ std::vector<HitElt> hits_in;
+ hits_in.emplace_back(EmbeddingHit(
+ BasicHit(/*section_id=*/1, /*document_id=*/0), /*location=*/1));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+ std::reverse(hits_in.begin(), hits_in.end());
+
+  // Add five hits. The PL is in the empty state and an empty min-size PL can
+  // only fit two hits, so PrependHitArray should report that only two hits fit
+  // and, because keep_prepended is false, leave the PL unchanged.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ uint32_t num_can_prepend,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
+ EXPECT_THAT(num_can_prepend, Eq(2));
+
+ int can_fit_hits = num_can_prepend;
+ // The PL has room for 2 hits. We should be able to add them without any
+ // problem, transitioning the PL from EMPTY -> ALMOST_FULL -> FULL
+ const HitElt *hits_in_ptr = hits_in.data() + (hits_in.size() - 2);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ num_can_prepend,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, hits_in_ptr, can_fit_hits, /*keep_prepended=*/false)));
+ EXPECT_THAT(num_can_prepend, Eq(can_fit_hits));
+ EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used)));
+ std::deque<EmbeddingHit> hits_pushed;
+ std::transform(hits_in.rbegin(),
+ hits_in.rend() - hits_in.size() + can_fit_hits,
+ std::front_inserter(hits_pushed), HitElt::get_hit);
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ PostingListPrependHitArrayPostingList) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ // Size = 48
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+ std::vector<HitElt> hits_in;
+ hits_in.emplace_back(EmbeddingHit(
+ BasicHit(/*section_id=*/1, /*document_id=*/0), /*location=*/1));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+ std::reverse(hits_in.begin(), hits_in.end());
+  // The last hit is uncompressed (8 bytes) and the four hits before it should
+  // each take only one byte. Total use = 12 bytes.
+ // ----------------------
+ // 47 delta(EmbeddingHit #0)
+ // 46 delta(EmbeddingHit #1)
+ // 45 delta(EmbeddingHit #2)
+ // 44 delta(EmbeddingHit #3)
+ // 43-36 EmbeddingHit #4
+ // 35-16 <unused>
+ // 15-8 kSpecialHit
+ // 7-0 Offset=36
+ // ----------------------
+ int byte_size = sizeof(EmbeddingHit::Value) + hits_in.size() - 1;
+
+ // Add five hits. The PL is in the empty state and should be able to fit all
+ // five hits without issue, transitioning the PL from EMPTY -> NOT_FULL.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ uint32_t num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
+ EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
+ EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
+ std::deque<EmbeddingHit> hits_pushed;
+ std::transform(hits_in.rbegin(), hits_in.rend(),
+ std::front_inserter(hits_pushed), HitElt::get_hit);
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ EmbeddingHit first_hit =
+ CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/1);
+ hits_in.clear();
+ hits_in.emplace_back(first_hit);
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
+ std::reverse(hits_in.begin(), hits_in.end());
+ // Size increased by the deltas of these hits (1+2+1+2+3+2+3) = 14 bytes
+ // ----------------------
+ // 47 delta(EmbeddingHit #0)
+ // 46 delta(EmbeddingHit #1)
+ // 45 delta(EmbeddingHit #2)
+ // 44 delta(EmbeddingHit #3)
+ // 43 delta(EmbeddingHit #4)
+ // 42-41 delta(EmbeddingHit #5)
+ // 40 delta(EmbeddingHit #6)
+ // 39-38 delta(EmbeddingHit #7)
+ // 37-35 delta(EmbeddingHit #8)
+ // 34-33 delta(EmbeddingHit #9)
+ // 32-30 delta(EmbeddingHit #10)
+ // 29-22 EmbeddingHit #11
+ // 21-16 <unused>
+ // 15-8 kSpecialHit
+ // 7-0 Offset=22
+ // ----------------------
+ byte_size += 14;
+
+ // Add these 7 hits. The PL is currently in the NOT_FULL state and should
+ // remain in the NOT_FULL state.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
+ EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
+ EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
+ // All hits from hits_in were added.
+ std::transform(hits_in.rbegin(), hits_in.rend(),
+ std::front_inserter(hits_pushed), HitElt::get_hit);
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ first_hit =
+ CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/8);
+ hits_in.clear();
+ hits_in.emplace_back(first_hit);
+ // ----------------------
+ // 47 delta(EmbeddingHit #0)
+ // 46 delta(EmbeddingHit #1)
+ // 45 delta(EmbeddingHit #2)
+ // 44 delta(EmbeddingHit #3)
+ // 43 delta(EmbeddingHit #4)
+ // 42-41 delta(EmbeddingHit #5)
+ // 40 delta(EmbeddingHit #6)
+ // 39-38 delta(EmbeddingHit #7)
+ // 37-35 delta(EmbeddingHit #8)
+ // 34-33 delta(EmbeddingHit #9)
+ // 32-30 delta(EmbeddingHit #10)
+ // 29-22 delta(EmbeddingHit #11)
+ // 21-16 <unused>
+ // 15-8 EmbeddingHit #12
+ // 7-0 kSpecialHit
+ // ----------------------
+ byte_size = 40; // 48 - 8
+
+ // Add this 1 hit. The PL is currently in the NOT_FULL state and should
+ // transition to the ALMOST_FULL state - even though there is still some
+ // unused space.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
+ EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
+ EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
+ // All hits from hits_in were added.
+ std::transform(hits_in.rbegin(), hits_in.rend(),
+ std::front_inserter(hits_pushed), HitElt::get_hit);
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ first_hit =
+ CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/5);
+ hits_in.clear();
+ hits_in.emplace_back(first_hit);
+ hits_in.emplace_back(
+ CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
+ std::reverse(hits_in.begin(), hits_in.end());
+ // ----------------------
+ // 47 delta(EmbeddingHit #0)
+ // 46 delta(EmbeddingHit #1)
+ // 45 delta(EmbeddingHit #2)
+ // 44 delta(EmbeddingHit #3)
+ // 43 delta(EmbeddingHit #4)
+ // 42-41 delta(EmbeddingHit #5)
+ // 40 delta(EmbeddingHit #6)
+ // 39-38 delta(EmbeddingHit #7)
+ // 37-35 delta(EmbeddingHit #8)
+ // 34-33 delta(EmbeddingHit #9)
+ // 32-30 delta(EmbeddingHit #10)
+ // 29-22 delta(EmbeddingHit #11)
+ // 21-17 delta(EmbeddingHit #12)
+ // 16 <unused>
+ // 15-8 EmbeddingHit #13
+ // 7-0 EmbeddingHit #14
+ // ----------------------
+
+ // Add these 2 hits.
+ // - The PL is currently in the ALMOST_FULL state. Adding the first hit should
+ // keep the PL in ALMOST_FULL because the delta between
+  //   EmbeddingHit #12 and EmbeddingHit #13 (5 bytes) can fit in the unused
+  //   area (6 bytes).
+ // - Adding the second hit should transition to the FULL state because the
+ // delta between EmbeddingHit #13 and EmbeddingHit #14 (3 bytes) is larger
+ // than the remaining unused area (1 byte).
+ ICING_ASSERT_OK_AND_ASSIGN(
+ num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
+ EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
+ EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used)));
+ // All hits from hits_in were added.
+ std::transform(hits_in.rbegin(), hits_in.rend(),
+ std::front_inserter(hits_pushed), HitElt::get_hit);
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ PostingListPrependHitArrayTooManyHits) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ static constexpr int kNumHits = 130;
+ static constexpr int kDeltaSize = 1;
+ static constexpr size_t kHitsSize =
+ ((kNumHits - 2) * kDeltaSize + (2 * sizeof(EmbeddingHit)));
+
+ // Create an array with one too many hits
+ std::vector<HitElt> hit_elts_in_too_many;
+ hit_elts_in_too_many.emplace_back(EmbeddingHit(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), /*location=*/0));
+ for (int i = 0; i < kNumHits; ++i) {
+ hit_elts_in_too_many.emplace_back(CreateEmbeddingHit(
+ hit_elts_in_too_many.back().hit, /*desired_byte_length=*/1));
+ }
+ // Reverse so that hits are inserted in descending order
+ std::reverse(hit_elts_in_too_many.begin(), hit_elts_in_too_many.end());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, serializer.GetMinPostingListSize()));
+  // hit_elts_in_too_many is far too large for the minimum-size pl, so
+  // PrependHitArray should report that only two hits fit and, because
+  // keep_prepended is false, leave the pl unchanged.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ uint32_t num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(),
+ /*keep_prepended=*/false)));
+ ASSERT_THAT(num_could_fit, Eq(2));
+ ASSERT_THAT(num_could_fit, Lt(hit_elts_in_too_many.size()));
+ ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0));
+ ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty()));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize));
+  // hit_elts_in_too_many is one hit too large for this pl, so PrependHitArray
+  // should report fitting all but one hit and, because keep_prepended is
+  // false, leave the pl unchanged.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(),
+ /*keep_prepended=*/false)));
+ ASSERT_THAT(num_could_fit, Eq(hit_elts_in_too_many.size() - 1));
+ ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0));
+ ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty()));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ PostingListStatusJumpFromNotFullToFullAndBack) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ // Size = 24
+ const uint32_t pl_size = 3 * sizeof(EmbeddingHit);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+ EmbeddingHit max_valued_hit(
+ BasicHit(/*section_id=*/kMaxSectionId, /*document_id=*/kMinDocumentId),
+ /*location=*/std::numeric_limits<uint32_t>::max());
+ ICING_ASSERT_OK(serializer.PrependHit(&pl, max_valued_hit));
+ uint32_t bytes_used = serializer.GetBytesUsed(&pl);
+  ASSERT_THAT(bytes_used, Eq(sizeof(EmbeddingHit)));
+ // Status not full.
+ ASSERT_THAT(
+ bytes_used,
+ Le(pl_size - PostingListEmbeddingHitSerializer::kSpecialHitsSize));
+
+ EmbeddingHit min_valued_hit(
+ BasicHit(/*section_id=*/kMinSectionId, /*document_id=*/kMaxDocumentId),
+ /*location=*/0);
+ ICING_ASSERT_OK(serializer.PrependHit(&pl, min_valued_hit));
+ EXPECT_THAT(serializer.GetHits(&pl),
+ IsOkAndHolds(ElementsAre(min_valued_hit, max_valued_hit)));
+ // Status should jump to full directly.
+ ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(pl_size));
+ ICING_ASSERT_OK(serializer.PopFrontHits(&pl, 1));
+ EXPECT_THAT(serializer.GetHits(&pl),
+ IsOkAndHolds(ElementsAre(max_valued_hit)));
+ // Status should return to not full as before.
+ ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(bytes_used));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest, DeltaOverflow) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ const uint32_t pl_size = 4 * sizeof(EmbeddingHit);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+ static const EmbeddingHit::Value kMaxHitValue =
+ std::numeric_limits<EmbeddingHit::Value>::max();
+ static const EmbeddingHit::Value kOverflow[4] = {
+ kMaxHitValue >> 2,
+ (kMaxHitValue >> 2) * 2,
+ (kMaxHitValue >> 2) * 3,
+ kMaxHitValue - 1,
+ };
+
+ // Fit at least 4 ordinary values.
+ std::deque<EmbeddingHit> hits_pushed;
+ for (EmbeddingHit::Value v = 0; v < 4; v++) {
+ hits_pushed.push_front(
+ EmbeddingHit(BasicHit(kMaxSectionId, kMaxDocumentId), 4 - v));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl, hits_pushed.front()));
+ EXPECT_THAT(serializer.GetHits(&pl),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+ }
+
+ // Cannot fit 4 overflow values.
+ hits_pushed.clear();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl, PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ for (int i = 3; i >= 1; i--) {
+ hits_pushed.push_front(EmbeddingHit(/*value=*/kOverflow[i]));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl, hits_pushed.front()));
+ EXPECT_THAT(serializer.GetHits(&pl),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+ }
+ EXPECT_THAT(serializer.PrependHit(&pl, EmbeddingHit(/*value=*/kOverflow[0])),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ GetMinPostingListToFitForNotFullPL) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ // Size = 64
+ int pl_size = 4 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ // Create and add some hits to make pl_used NOT_FULL
+ std::vector<EmbeddingHit> hits_in =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ }
+ // ----------------------
+ // 63-62 delta(EmbeddingHit #0)
+ // 61-60 delta(EmbeddingHit #1)
+ // 59-58 delta(EmbeddingHit #2)
+ // 57-56 delta(EmbeddingHit #3)
+  // 55-48        EmbeddingHit #4
+ // 47-16 <unused>
+ // 15-8 kSpecialHit
+ // 7-0 Offset=48
+ // ----------------------
+ int bytes_used = 16;
+
+ // Check that all hits have been inserted
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
+ std::deque<EmbeddingHit> hits_pushed(hits_in.rbegin(), hits_in.rend());
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // Get the min size to fit for the hits in pl_used. Moving the hits in pl_used
+ // into a posting list with this min size should make it ALMOST_FULL, which we
+ // can see should have size = 24.
+ // ----------------------
+ // 23-22 delta(EmbeddingHit #0)
+ // 21-20 delta(EmbeddingHit #1)
+ // 19-18 delta(EmbeddingHit #2)
+ // 17-16 delta(EmbeddingHit #3)
+ // 15-8 EmbeddingHit #4
+ // 7-0 kSpecialHit
+ // ----------------------
+ int expected_min_size = 24;
+ uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used);
+ EXPECT_THAT(min_size_to_fit, Eq(expected_min_size));
+
+  // Also check that this min-size-to-fit posting list actually does fit all
+  // the hits and can only fit one more hit in the ALMOST_FULL state.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListUsed min_size_to_fit_pl,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, min_size_to_fit));
+ for (const EmbeddingHit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit));
+ }
+
+ // Adding another hit to the min-size-to-fit posting list should succeed
+ EmbeddingHit hit =
+ CreateEmbeddingHit(hits_in.back(), /*desired_byte_length=*/1);
+ ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit));
+ // Adding any other hits should fail with RESOURCE_EXHAUSTED error.
+ EXPECT_THAT(
+ serializer.PrependHit(&min_size_to_fit_pl,
+ CreateEmbeddingHit(hit, /*desired_byte_length=*/1)),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+
+ // Check that all hits have been inserted and the min-fit posting list is now
+ // FULL.
+ EXPECT_THAT(serializer.GetBytesUsed(&min_size_to_fit_pl),
+ Eq(min_size_to_fit));
+ hits_pushed.emplace_front(hit);
+ EXPECT_THAT(serializer.GetHits(&min_size_to_fit_pl),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ GetMinPostingListToFitForAlmostFullAndFullPLReturnsSameSize) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 24;
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ // Create and add some hits to make pl_used ALMOST_FULL
+ std::vector<EmbeddingHit> hits_in =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ }
+ // ----------------------
+ // 23-22 delta(EmbeddingHit #0)
+ // 21-20 delta(EmbeddingHit #1)
+ // 19-18 delta(EmbeddingHit #2)
+ // 17-16 delta(EmbeddingHit #3)
+ // 15-8 EmbeddingHit #4
+ // 7-0 kSpecialHit
+ // ----------------------
+ int bytes_used = 16;
+
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
+ std::deque<EmbeddingHit> hits_pushed(hits_in.rbegin(), hits_in.rend());
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // GetMinPostingListSizeToFit should return the same size as pl_used.
+ uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used);
+ EXPECT_THAT(min_size_to_fit, Eq(pl_size));
+
+ // Add another hit to make the posting list FULL
+ EmbeddingHit hit =
+ CreateEmbeddingHit(hits_in.back(), /*desired_byte_length=*/1);
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size));
+ hits_pushed.emplace_front(hit);
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // GetMinPostingListSizeToFit should still be the same size as pl_used.
+ min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used);
+ EXPECT_THAT(min_size_to_fit, Eq(pl_size));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest, MoveFrom) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit &hit : hits1) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
+ }
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used2,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits2 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits2) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit));
+ }
+
+ ICING_ASSERT_OK(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1));
+ EXPECT_THAT(serializer.GetHits(&pl_used2),
+ IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
+ EXPECT_THAT(serializer.GetHits(&pl_used1), IsOkAndHolds(IsEmpty()));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ MoveFromNullArgumentReturnsInvalidArgument) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit &hit : hits) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
+ }
+
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used1, /*src=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(serializer.GetHits(&pl_used1),
+ IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend())));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ MoveFromInvalidPostingListReturnsInvalidArgument) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit &hit : hits1) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
+ }
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used2,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits2 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits2) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit));
+ }
+
+ // Write invalid hits to the beginning of pl_used1 to make it invalid.
+ EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue);
+ EmbeddingHit *first_hit =
+ reinterpret_cast<EmbeddingHit *>(pl_used1.posting_list_buffer());
+ *first_hit = invalid_hit;
+ ++first_hit;
+ *first_hit = invalid_hit;
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(serializer.GetHits(&pl_used2),
+ IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend())));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ MoveToInvalidPostingListReturnsFailedPrecondition) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit &hit : hits1) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
+ }
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used2,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits2 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits2) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit));
+ }
+
+ // Write invalid hits to the beginning of pl_used2 to make it invalid.
+ EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue);
+ EmbeddingHit *first_hit =
+ reinterpret_cast<EmbeddingHit *>(pl_used2.posting_list_buffer());
+ *first_hit = invalid_hit;
+ ++first_hit;
+ *first_hit = invalid_hit;
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(serializer.GetHits(&pl_used1),
+ IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest, MoveToPostingListTooSmall) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit &hit : hits1) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
+ }
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used2,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, serializer.GetMinPostingListSize()));
+ std::vector<EmbeddingHit> hits2 =
+ CreateEmbeddingHits(/*num_hits=*/1, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits2) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit));
+ }
+
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(serializer.GetHits(&pl_used1),
+ IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
+ EXPECT_THAT(serializer.GetHits(&pl_used2),
+ IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend())));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embedding-indexing-handler.cc b/icing/index/embedding-indexing-handler.cc
new file mode 100644
index 0000000..049e307
--- /dev/null
+++ b/icing/index/embedding-indexing-handler.cc
@@ -0,0 +1,85 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embedding-indexing-handler.h"
+
+#include <memory>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/hit/hit.h"
+#include "icing/legacy/core/icing-string-util.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/util/clock.h"
+#include "icing/util/status-macros.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndexingHandler>>
+EmbeddingIndexingHandler::Create(const Clock* clock,
+ EmbeddingIndex* embedding_index) {
+ ICING_RETURN_ERROR_IF_NULL(clock);
+ ICING_RETURN_ERROR_IF_NULL(embedding_index);
+
+ return std::unique_ptr<EmbeddingIndexingHandler>(
+ new EmbeddingIndexingHandler(clock, embedding_index));
+}
+
+libtextclassifier3::Status EmbeddingIndexingHandler::Handle(
+ const TokenizedDocument& tokenized_document, DocumentId document_id,
+ bool recovery_mode, PutDocumentStatsProto* put_document_stats) {
+ std::unique_ptr<Timer> index_timer = clock_.GetNewTimer();
+
+ if (!IsDocumentIdValid(document_id)) {
+ return absl_ports::InvalidArgumentError(
+ IcingStringUtil::StringPrintf("Invalid DocumentId %d", document_id));
+ }
+
+ if (embedding_index_.last_added_document_id() != kInvalidDocumentId &&
+ document_id <= embedding_index_.last_added_document_id()) {
+ if (recovery_mode) {
+ // Skip the document if document_id <= last_added_document_id in
+ // recovery mode without returning an error.
+ return libtextclassifier3::Status::OK;
+ }
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "DocumentId %d must be greater than last added document_id %d",
+ document_id, embedding_index_.last_added_document_id()));
+ }
+ embedding_index_.set_last_added_document_id(document_id);
+
+ for (const Section<PropertyProto::VectorProto>& vector_section :
+ tokenized_document.vector_sections()) {
+ BasicHit hit(/*section_id=*/vector_section.metadata.id, document_id);
+ for (const PropertyProto::VectorProto& vector : vector_section.content) {
+ ICING_RETURN_IF_ERROR(embedding_index_.BufferEmbedding(hit, vector));
+ }
+ }
+ ICING_RETURN_IF_ERROR(embedding_index_.CommitBufferToIndex());
+
+ if (put_document_stats != nullptr) {
+ put_document_stats->set_embedding_index_latency_ms(
+ index_timer->GetElapsedMilliseconds());
+ }
+
+ return libtextclassifier3::Status::OK;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embedding-indexing-handler.h b/icing/index/embedding-indexing-handler.h
new file mode 100644
index 0000000..f3adf6a
--- /dev/null
+++ b/icing/index/embedding-indexing-handler.h
@@ -0,0 +1,70 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_
+#define ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_
+
+#include <memory>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/data-indexing-handler.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/store/document-id.h"
+#include "icing/util/clock.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+class EmbeddingIndexingHandler : public DataIndexingHandler {
+ public:
+ ~EmbeddingIndexingHandler() override = default;
+
+ // Creates an EmbeddingIndexingHandler instance which does not take
+ // ownership of any input components. All pointers must refer to valid objects
+ // that outlive the created EmbeddingIndexingHandler instance.
+ //
+ // Returns:
+ // - An EmbeddingIndexingHandler instance on success
+  //   - FAILED_PRECONDITION_ERROR if any of the input pointers is null
+ static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndexingHandler>>
+ Create(const Clock* clock, EmbeddingIndex* embedding_index);
+
+  // Handles the embedding indexing process: adds hits into the embedding
+  // index for all content in tokenized_document.vector_sections().
+ //
+ // Returns:
+ // - OK on success.
+ // - INVALID_ARGUMENT_ERROR if document_id is invalid OR document_id is less
+ // than or equal to the document_id of a previously indexed document in
+  //     non-recovery mode.
+ // - INTERNAL_ERROR if any other errors occur.
+ // - Any embedding index errors.
+ libtextclassifier3::Status Handle(
+ const TokenizedDocument& tokenized_document, DocumentId document_id,
+ bool recovery_mode, PutDocumentStatsProto* put_document_stats) override;
+
+ private:
+ explicit EmbeddingIndexingHandler(const Clock* clock,
+ EmbeddingIndex* embedding_index)
+ : DataIndexingHandler(clock), embedding_index_(*embedding_index) {}
+
+ EmbeddingIndex& embedding_index_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_
diff --git a/icing/index/embedding-indexing-handler_test.cc b/icing/index/embedding-indexing-handler_test.cc
new file mode 100644
index 0000000..ed711c1
--- /dev/null
+++ b/icing/index/embedding-indexing-handler_test.cc
@@ -0,0 +1,620 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embedding-indexing-handler.h"
+
+#include <cstdint>
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/hit/hit.h"
+#include "icing/portable/platform.h"
+#include "icing/proto/document_wrapper.pb.h"
+#include "icing/proto/schema.pb.h"
+#include "icing/schema-builder.h"
+#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/embedding-test-utils.h"
+#include "icing/testing/fake-clock.h"
+#include "icing/testing/icu-data-file-helper.h"
+#include "icing/testing/test-data.h"
+#include "icing/testing/tmp-directory.h"
+#include "icing/tokenization/language-segmenter-factory.h"
+#include "icing/tokenization/language-segmenter.h"
+#include "icing/util/status-macros.h"
+#include "icing/util/tokenized-document.h"
+#include "unicode/uloc.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::IsEmpty;
+using ::testing::IsTrue;
+
+// Indexable properties (section) and section id. Section id is determined by
+// the lexicographical order of indexable property paths.
+// Schema type with indexable properties: FakeType
+// Section id = 0: "body"
+// Section id = 1: "bodyEmbedding"
+// Section id = 2: "title"
+// Section id = 3: "titleEmbedding"
+static constexpr std::string_view kFakeType = "FakeType";
+static constexpr std::string_view kPropertyBody = "body";
+static constexpr std::string_view kPropertyBodyEmbedding = "bodyEmbedding";
+static constexpr std::string_view kPropertyTitle = "title";
+static constexpr std::string_view kPropertyTitleEmbedding = "titleEmbedding";
+static constexpr std::string_view kPropertyNonIndexableEmbedding =
+ "nonIndexableEmbedding";
+
+static constexpr SectionId kSectionIdBodyEmbedding = 1;
+static constexpr SectionId kSectionIdTitleEmbedding = 3;
+
+// Schema type with nested indexable properties: FakeCollectionType
+// Section id = 0: "collection.body"
+// Section id = 1: "collection.bodyEmbedding"
+// Section id = 2: "collection.title"
+// Section id = 3: "collection.titleEmbedding"
+// Section id = 4: "fullDocEmbedding"
+static constexpr std::string_view kFakeCollectionType = "FakeCollectionType";
+static constexpr std::string_view kPropertyCollection = "collection";
+static constexpr std::string_view kPropertyFullDocEmbedding =
+ "fullDocEmbedding";
+
+static constexpr SectionId kSectionIdNestedBodyEmbedding = 1;
+static constexpr SectionId kSectionIdNestedTitleEmbedding = 3;
+static constexpr SectionId kSectionIdFullDocEmbedding = 4;
+
+class EmbeddingIndexingHandlerTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ if (!IsCfStringTokenization() && !IsReverseJniTokenization()) {
+ ICING_ASSERT_OK(
+ // File generated via icu_data_file rule in //icing/BUILD.
+ icu_data_file_helper::SetUpICUDataFile(
+ GetTestFilePath("icing/icu.dat")));
+ }
+
+ base_dir_ = GetTestTempDir() + "/icing_test";
+ ASSERT_THAT(filesystem_.CreateDirectoryRecursively(base_dir_.c_str()),
+ IsTrue());
+
+ embedding_index_working_path_ = base_dir_ + "/embedding_index";
+ schema_store_dir_ = base_dir_ + "/schema_store";
+ document_store_dir_ = base_dir_ + "/document_store";
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ embedding_index_,
+ EmbeddingIndex::Create(&filesystem_, embedding_index_working_path_));
+
+ language_segmenter_factory::SegmenterOptions segmenter_options(ULOC_US);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ lang_segmenter_,
+ language_segmenter_factory::Create(std::move(segmenter_options)));
+
+ ASSERT_THAT(
+ filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()),
+ IsTrue());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ schema_store_,
+ SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(
+ SchemaTypeConfigBuilder()
+ .SetType(kFakeType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kPropertyTitle)
+ .SetDataTypeString(TERM_MATCH_EXACT,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kPropertyBody)
+ .SetDataTypeString(TERM_MATCH_EXACT,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName(kPropertyTitleEmbedding)
+ .SetDataTypeVector(
+ EmbeddingIndexingConfig::EmbeddingIndexingType::
+ LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName(kPropertyBodyEmbedding)
+ .SetDataTypeVector(
+ EmbeddingIndexingConfig::EmbeddingIndexingType::
+ LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kPropertyNonIndexableEmbedding)
+ .SetDataType(TYPE_VECTOR)
+ .SetCardinality(CARDINALITY_REPEATED)))
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kFakeCollectionType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kPropertyCollection)
+ .SetDataTypeDocument(
+ kFakeType,
+ /*index_nested_properties=*/true)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName(kPropertyFullDocEmbedding)
+ .SetDataTypeVector(
+ EmbeddingIndexingConfig::
+ EmbeddingIndexingType::LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+ ICING_ASSERT_OK(schema_store_->SetSchema(
+ schema, /*ignore_errors_and_delete_documents=*/false,
+ /*allow_circular_schema_definitions=*/false));
+
+ ASSERT_TRUE(
+ filesystem_.CreateDirectoryRecursively(document_store_dir_.c_str()));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentStore::CreateResult doc_store_create_result,
+ DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_,
+ schema_store_.get(),
+ /*force_recovery_and_revalidate_documents=*/false,
+ /*namespace_id_fingerprint=*/false,
+ /*pre_mapping_fbv=*/false,
+ /*use_persistent_hash_map=*/false,
+ PortableFileBackedProtoLog<
+ DocumentWrapper>::kDeflateCompressionLevel,
+ /*initialize_stats=*/nullptr));
+ document_store_ = std::move(doc_store_create_result.document_store);
+ }
+
+ void TearDown() override {
+ document_store_.reset();
+ schema_store_.reset();
+ lang_segmenter_.reset();
+ embedding_index_.reset();
+
+ filesystem_.DeleteDirectoryRecursively(base_dir_.c_str());
+ }
+
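+  // Test helper: returns all embedding hits currently stored in the index for
+  // the given dimension and model signature, or an empty vector if no posting
+  // list exists for that (dimension, model_signature) key.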
+ libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits(
+ uint32_t dimension, std::string_view model_signature) {
+ std::vector<EmbeddingHit> hits;
+
+ libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ pl_accessor_or =
+ embedding_index_->GetAccessor(dimension, model_signature);
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
+ if (pl_accessor_or.ok()) {
+ pl_accessor = std::move(pl_accessor_or).ValueOrDie();
+ } else if (absl_ports::IsNotFound(pl_accessor_or.status())) {
+ return hits;
+ } else {
+ return std::move(pl_accessor_or).status();
+ }
+
+ while (true) {
+ ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
+ pl_accessor->GetNextHitsBatch());
+ if (batch.empty()) {
+ return hits;
+ }
+ hits.insert(hits.end(), batch.begin(), batch.end());
+ }
+ }
+
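+  // Test helper: copies the index's raw embedding data (all indexed vector
+  // values laid out contiguously) into a std::vector for easy comparison.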
+ std::vector<float> GetRawEmbeddingData() {
+ return std::vector<float>(embedding_index_->GetRawEmbeddingData(),
+ embedding_index_->GetRawEmbeddingData() +
+ embedding_index_->GetTotalVectorSize());
+ }
+
+ Filesystem filesystem_;
+ FakeClock fake_clock_;
+ std::string base_dir_;
+ std::string embedding_index_working_path_;
+ std::string schema_store_dir_;
+ std::string document_store_dir_;
+
+ std::unique_ptr<EmbeddingIndex> embedding_index_;
+ std::unique_ptr<LanguageSegmenter> lang_segmenter_;
+ std::unique_ptr<SchemaStore> schema_store_;
+ std::unique_ptr<DocumentStore> document_store_;
+};
+
+} // namespace
+
+TEST_F(EmbeddingIndexingHandlerTest, CreationWithNullPointerShouldFail) {
+ EXPECT_THAT(EmbeddingIndexingHandler::Create(/*clock=*/nullptr,
+ embedding_index_.get()),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+
+ EXPECT_THAT(EmbeddingIndexingHandler::Create(&fake_clock_,
+ /*embedding_index=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+}
+
+TEST_F(EmbeddingIndexingHandlerTest, HandleEmbeddingSection) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {0.1, 0.2, 0.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {0.4, 0.5, 0.6}),
+ CreateVector("model", {0.7, 0.8, 0.9}))
+ .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {1.1, 1.2, 1.3}))
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document)));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store_->Put(tokenized_document.document()));
+
+ ASSERT_THAT(embedding_index_->last_added_document_id(),
+ Eq(kInvalidDocumentId));
+ // Handle document.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<EmbeddingIndexingHandler> handler,
+ EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get()));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false,
+ /*put_document_stats=*/nullptr),
+ IsOk());
+
+ // Check index
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
+ /*location=*/6))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
+}
+
+TEST_F(EmbeddingIndexingHandlerTest, HandleNestedEmbeddingSection) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_collection_type/1")
+ .SetSchema(std::string(kFakeCollectionType))
+ .AddDocumentProperty(
+ std::string(kPropertyCollection),
+ DocumentBuilder()
+ .SetKey("icing", "nested_fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {0.1, 0.2, 0.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {0.4, 0.5, 0.6}),
+ CreateVector("model", {0.7, 0.8, 0.9}))
+ .AddVectorProperty(
+ std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {1.1, 1.2, 1.3}))
+ .Build())
+ .AddVectorProperty(std::string(kPropertyFullDocEmbedding),
+ CreateVector("model", {2.1, 2.2, 2.3}))
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document)));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store_->Put(tokenized_document.document()));
+
+ ASSERT_THAT(embedding_index_->last_added_document_id(),
+ Eq(kInvalidDocumentId));
+ // Handle document.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<EmbeddingIndexingHandler> handler,
+ EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get()));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false,
+ /*put_document_stats=*/nullptr),
+ IsOk());
+
+ // Check index
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(
+ BasicHit(kSectionIdNestedBodyEmbedding, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(
+ BasicHit(kSectionIdNestedBodyEmbedding, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(
+ BasicHit(kSectionIdNestedTitleEmbedding, /*document_id=*/0),
+ /*location=*/6),
+ EmbeddingHit(BasicHit(kSectionIdFullDocEmbedding, /*document_id=*/0),
+ /*location=*/9))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
+ 0.1, 0.2, 0.3, 2.1, 2.2, 2.3));
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
+}
+
+TEST_F(EmbeddingIndexingHandlerTest,
+ HandleInvalidDocumentIdShouldReturnInvalidArgumentError) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {0.1, 0.2, 0.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {0.4, 0.5, 0.6}),
+ CreateVector("model", {0.7, 0.8, 0.9}))
+ .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {1.1, 1.2, 1.3}))
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document)));
+ ICING_ASSERT_OK(document_store_->Put(tokenized_document.document()));
+
+ static constexpr DocumentId kCurrentDocumentId = 3;
+ embedding_index_->set_last_added_document_id(kCurrentDocumentId);
+ ASSERT_THAT(embedding_index_->last_added_document_id(),
+ Eq(kCurrentDocumentId));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<EmbeddingIndexingHandler> handler,
+ EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get()));
+
+ // Handling document with kInvalidDocumentId should cause a failure, and both
+ // index data and last_added_document_id should remain unchanged.
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, kInvalidDocumentId,
+ /*recovery_mode=*/false, /*put_document_stats=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(embedding_index_->last_added_document_id(),
+ Eq(kCurrentDocumentId));
+  // Check that the embedding index is still empty.
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+
+ // Recovery mode should get the same result.
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, kInvalidDocumentId,
+ /*recovery_mode=*/true, /*put_document_stats=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(embedding_index_->last_added_document_id(),
+ Eq(kCurrentDocumentId));
+  // Check that the embedding index is still empty.
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+}
+
+TEST_F(EmbeddingIndexingHandlerTest,
+ HandleOutOfOrderDocumentIdShouldReturnInvalidArgumentError) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {0.1, 0.2, 0.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {0.4, 0.5, 0.6}),
+ CreateVector("model", {0.7, 0.8, 0.9}))
+ .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {1.1, 1.2, 1.3}))
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document)));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store_->Put(tokenized_document.document()));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<EmbeddingIndexingHandler> handler,
+ EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get()));
+
+ // Handling document with document_id == last_added_document_id should cause a
+ // failure, and both index data and last_added_document_id should remain
+ // unchanged.
+ embedding_index_->set_last_added_document_id(document_id);
+ ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false,
+ /*put_document_stats=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
+
+  // Check that the embedding index is still empty.
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+
+ // Handling document with document_id < last_added_document_id should cause a
+ // failure, and both index data and last_added_document_id should remain
+ // unchanged.
+ embedding_index_->set_last_added_document_id(document_id + 1);
+ ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id + 1));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false,
+ /*put_document_stats=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id + 1));
+
+  // Check that the embedding index is still empty.
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+}
+
+TEST_F(EmbeddingIndexingHandlerTest,
+ HandleRecoveryModeShouldIgnoreDocsLELastAddedDocId) {
+ DocumentProto document1 =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title one")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {0.1, 0.2, 0.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body one")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {0.4, 0.5, 0.6}),
+ CreateVector("model", {0.7, 0.8, 0.9}))
+ .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {1.1, 1.2, 1.3}))
+ .Build();
+ DocumentProto document2 =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/2")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title two")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {10.1, 10.2, 10.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body two")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {10.4, 10.5, 10.6}),
+ CreateVector("model", {10.7, 10.8, 10.9}))
+ .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {11.1, 11.2, 11.3}))
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document1,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document1)));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document2,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document2)));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id1,
+ document_store_->Put(tokenized_document1.document()));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id2,
+ document_store_->Put(tokenized_document2.document()));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<EmbeddingIndexingHandler> handler,
+ EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get()));
+
+ // Handle document with document_id > last_added_document_id in recovery mode.
+ // The handler should index this document and update last_added_document_id.
+ EXPECT_THAT(
+ handler->Handle(tokenized_document1, document_id1, /*recovery_mode=*/true,
+ /*put_document_stats=*/nullptr),
+ IsOk());
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id1));
+
+ // Check index
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
+ /*location=*/6))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
+
+ // Handle document with document_id == last_added_document_id in recovery
+ // mode. We should not get any error, but the handler should ignore the
+ // document, so both index data and last_added_document_id should remain
+ // unchanged.
+ embedding_index_->set_last_added_document_id(document_id2);
+ ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document2, document_id2, /*recovery_mode=*/true,
+ /*put_document_stats=*/nullptr),
+ IsOk());
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2));
+
+ // Check index
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
+ /*location=*/6))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
+
+ // Handle document with document_id < last_added_document_id in recovery mode.
+ // We should not get any error, but the handler should ignore the document, so
+ // both index data and last_added_document_id should remain unchanged.
+ embedding_index_->set_last_added_document_id(document_id2 + 1);
+ ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2 + 1));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document2, document_id2, /*recovery_mode=*/true,
+ /*put_document_stats=*/nullptr),
+ IsOk());
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2 + 1));
+
+ // Check index
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
+ /*location=*/6))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/hit/hit.h b/icing/index/hit/hit.h
index 83fff12..e971016 100644
--- a/icing/index/hit/hit.h
+++ b/icing/index/hit/hit.h
@@ -60,6 +60,7 @@ class BasicHit {
explicit BasicHit(SectionId section_id, DocumentId document_id);
explicit BasicHit() : value_(kInvalidValue) {}
+ explicit BasicHit(Value value) : value_(value) {}
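+ // Creates a BasicHit directly from a packed Value.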
bool is_valid() const { return value_ != kInvalidValue; }
Value value() const { return value_; }
diff --git a/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h b/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h
index c16a1c4..d766712 100644
--- a/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h
+++ b/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h
@@ -17,13 +17,17 @@
#include <cstdint>
#include <memory>
+#include <set>
#include <string>
-#include <string_view>
#include <utility>
+#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
namespace icing {
diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc
index 682b752..735adaa 100644
--- a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc
+++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc
@@ -134,7 +134,9 @@ DocHitInfoIteratorSectionRestrict::ApplyRestrictions(
ChildrenMapper mapper;
mapper = [&data, &mapper](std::unique_ptr<DocHitInfoIterator> iterator)
-> std::unique_ptr<DocHitInfoIterator> {
- if (iterator->is_leaf()) {
+ if (iterator->full_section_restriction_applied()) {
+ return iterator;
+ } else if (iterator->is_leaf()) {
return std::make_unique<DocHitInfoIteratorSectionRestrict>(
std::move(iterator), data);
} else {
@@ -149,24 +151,8 @@ libtextclassifier3::Status DocHitInfoIteratorSectionRestrict::Advance() {
doc_hit_info_ = DocHitInfo(kInvalidDocumentId);
while (delegate_->Advance().ok()) {
DocumentId document_id = delegate_->doc_hit_info().document_id();
-
- auto data_optional = data_->document_store().GetAliveDocumentFilterData(
- document_id, data_->current_time_ms());
- if (!data_optional) {
- // Ran into some error retrieving information on this hit, skip
- continue;
- }
-
- // Guaranteed that the DocumentFilterData exists at this point
- SchemaTypeId schema_type_id = data_optional.value().schema_type_id();
- auto schema_type_or = data_->schema_store().GetSchemaType(schema_type_id);
- if (!schema_type_or.ok()) {
- // Ran into error retrieving schema type, skip
- continue;
- }
- const std::string* schema_type = std::move(schema_type_or).ValueOrDie();
SectionIdMask allowed_sections_mask =
- data_->ComputeAllowedSectionsMask(*schema_type);
+ data_->ComputeAllowedSectionsMask(document_id);
// A hit can be in multiple sections at once, need to check which of the
// section ids match the sections allowed by type_property_masks_. This can
diff --git a/icing/index/iterator/doc-hit-info-iterator.h b/icing/index/iterator/doc-hit-info-iterator.h
index 728f957..921e4d4 100644
--- a/icing/index/iterator/doc-hit-info-iterator.h
+++ b/icing/index/iterator/doc-hit-info-iterator.h
@@ -214,6 +214,12 @@ class DocHitInfoIterator {
virtual bool is_leaf() { return false; }
+  // Whether all the required section restrictions have already been applied
+  // to this iterator.
+ // If true, calling DocHitInfoIteratorSectionRestrict::ApplyRestrictions on
+ // this iterator will have no effect.
+ virtual bool full_section_restriction_applied() const { return false; }
+
virtual ~DocHitInfoIterator() = default;
// Returns:
diff --git a/icing/index/iterator/section-restrict-data.cc b/icing/index/iterator/section-restrict-data.cc
index 085437d..581ac8e 100644
--- a/icing/index/iterator/section-restrict-data.cc
+++ b/icing/index/iterator/section-restrict-data.cc
@@ -22,6 +22,8 @@
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/store/document-id.h"
namespace icing {
namespace lib {
@@ -78,5 +80,24 @@ SectionIdMask SectionRestrictData::ComputeAllowedSectionsMask(
return new_section_id_mask;
}
+SectionIdMask SectionRestrictData::ComputeAllowedSectionsMask(
+ DocumentId document_id) {
+ auto data_optional =
+ document_store_.GetAliveDocumentFilterData(document_id, current_time_ms_);
+ if (!data_optional) {
+ // Ran into some error retrieving information on this document, skip
+ return kSectionIdMaskNone;
+ }
+ // Guaranteed that the DocumentFilterData exists at this point
+ SchemaTypeId schema_type_id = data_optional.value().schema_type_id();
+ auto schema_type_or = schema_store_.GetSchemaType(schema_type_id);
+ if (!schema_type_or.ok()) {
+ // Ran into error retrieving schema type, skip
+ return kSectionIdMaskNone;
+ }
+ const std::string* schema_type = std::move(schema_type_or).ValueOrDie();
+ return ComputeAllowedSectionsMask(*schema_type);
+}
+
} // namespace lib
} // namespace icing
diff --git a/icing/index/iterator/section-restrict-data.h b/icing/index/iterator/section-restrict-data.h
index 26ca597..64d5087 100644
--- a/icing/index/iterator/section-restrict-data.h
+++ b/icing/index/iterator/section-restrict-data.h
@@ -23,6 +23,7 @@
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
namespace icing {
@@ -48,10 +49,21 @@ class SectionRestrictData {
// Returns:
// - If type_property_filters_ has an entry for the given schema type or
// wildcard(*), return a bitwise or of section IDs in the schema type
- // that that are also present in the relevant filter list.
+ // that are also present in the relevant filter list.
// - Otherwise, return kSectionIdMaskAll.
SectionIdMask ComputeAllowedSectionsMask(const std::string& schema_type);
+  // Calculates the section mask of allowed sections (determined by the
+ // property filters map) for the given document id, by retrieving its schema
+ // type name and calling the above method.
+ //
+ // Returns:
+ // - If type_property_filters_ has an entry for the given document's schema
+ // type or wildcard(*), return a bitwise or of section IDs in the schema
+ // type that are also present in the relevant filter list.
+ // - Otherwise, return kSectionIdMaskAll.
+ SectionIdMask ComputeAllowedSectionsMask(DocumentId document_id);
+
const DocumentStore& document_store() const { return document_store_; }
const SchemaStore& schema_store() const { return schema_store_; }
diff --git a/icing/query/advanced_query_parser/function.cc b/icing/query/advanced_query_parser/function.cc
index e7938db..0ff160b 100644
--- a/icing/query/advanced_query_parser/function.cc
+++ b/icing/query/advanced_query_parser/function.cc
@@ -13,8 +13,15 @@
// limitations under the License.
#include "icing/query/advanced_query_parser/function.h"
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/query/advanced_query_parser/param.h"
+#include "icing/query/advanced_query_parser/pending-value.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -73,5 +80,20 @@ libtextclassifier3::StatusOr<PendingValue> Function::Eval(
return eval_(std::move(args));
}
+libtextclassifier3::StatusOr<DataType> Function::get_param_type(int i) const {
+ if (i < 0 || params_.empty()) {
+ return absl_ports::OutOfRangeError("Invalid argument index.");
+ }
+ const Param* parm;
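+  // An index past the end of params_ is valid only if the last declared
+  // parameter is variadic; every such index maps to that last parameter.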
+ if (i < params_.size()) {
+ parm = &params_.at(i);
+ } else if (params_.back().cardinality == Cardinality::kVariable) {
+ parm = &params_.back();
+ } else {
+ return absl_ports::OutOfRangeError("Invalid argument index.");
+ }
+ return parm->data_type;
+}
+
} // namespace lib
-} // namespace icing
\ No newline at end of file
+} // namespace icing
diff --git a/icing/query/advanced_query_parser/function.h b/icing/query/advanced_query_parser/function.h
index 3514878..08cc7e8 100644
--- a/icing/query/advanced_query_parser/function.h
+++ b/icing/query/advanced_query_parser/function.h
@@ -46,6 +46,8 @@ class Function {
libtextclassifier3::StatusOr<PendingValue> Eval(
std::vector<PendingValue>&& args) const;
+ libtextclassifier3::StatusOr<DataType> get_param_type(int i) const;
+
private:
Function(DataType return_type, std::string name, std::vector<Param> params,
EvalFunction eval)
diff --git a/icing/query/advanced_query_parser/param.h b/icing/query/advanced_query_parser/param.h
index 69c46be..9ea1915 100644
--- a/icing/query/advanced_query_parser/param.h
+++ b/icing/query/advanced_query_parser/param.h
@@ -35,13 +35,19 @@ struct Param {
libtextclassifier3::Status Matches(PendingValue& arg) const {
bool matches = arg.data_type() == data_type;
- // Values of type kText could also potentially be valid kLong values. If
- // we're expecting a kLong and we have a kText, try to parse it as a kLong.
+ // Values of type kText could also potentially be valid kLong or kDouble
+ // values. If we're expecting a kLong or kDouble and we have a kText, try to
+ // parse it as what we expect.
if (!matches && data_type == DataType::kLong &&
arg.data_type() == DataType::kText) {
ICING_RETURN_IF_ERROR(arg.ParseInt());
matches = true;
}
+ if (!matches && data_type == DataType::kDouble &&
+ arg.data_type() == DataType::kText) {
+ ICING_RETURN_IF_ERROR(arg.ParseDouble());
+ matches = true;
+ }
return matches ? libtextclassifier3::Status::OK
: absl_ports::InvalidArgumentError(
"Provided arg doesn't match required param type.");
diff --git a/icing/query/advanced_query_parser/pending-value.cc b/icing/query/advanced_query_parser/pending-value.cc
index 67bdc3a..a3f95d9 100644
--- a/icing/query/advanced_query_parser/pending-value.cc
+++ b/icing/query/advanced_query_parser/pending-value.cc
@@ -13,7 +13,11 @@
// limitations under the License.
#include "icing/query/advanced_query_parser/pending-value.h"
+#include <cstdlib>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
namespace icing {
namespace lib {
@@ -40,5 +44,27 @@ libtextclassifier3::Status PendingValue::ParseInt() {
return libtextclassifier3::Status::OK;
}
+libtextclassifier3::Status PendingValue::ParseDouble() {
+ if (data_type_ == DataType::kDouble) {
+ return libtextclassifier3::Status::OK;
+ } else if (data_type_ != DataType::kText) {
+ return absl_ports::InvalidArgumentError("Cannot parse value as double");
+ }
+ if (query_term_.is_prefix_val) {
+ return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+ "Cannot use prefix operator '*' with numeric value: ",
+ query_term_.term));
+ }
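+  // strtod must consume the entire term; any trailing characters mean the
+  // text is not a valid double literal.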
+ char* value_end;
+ double_val_ = std::strtod(query_term_.term.c_str(), &value_end);
+ if (value_end != query_term_.term.c_str() + query_term_.term.length()) {
+ return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+ "Unable to parse \"", query_term_.term, "\" as double."));
+ }
+ data_type_ = DataType::kDouble;
+ query_term_ = {/*term=*/"", /*raw_term=*/"", /*is_prefix_val=*/false};
+ return libtextclassifier3::Status::OK;
+}
+
} // namespace lib
} // namespace icing
diff --git a/icing/query/advanced_query_parser/pending-value.h b/icing/query/advanced_query_parser/pending-value.h
index 1a6717e..34912f3 100644
--- a/icing/query/advanced_query_parser/pending-value.h
+++ b/icing/query/advanced_query_parser/pending-value.h
@@ -14,12 +14,16 @@
#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_PENDING_VALUE_H_
#define ICING_QUERY_ADVANCED_QUERY_PARSER_PENDING_VALUE_H_
+#include <cstdint>
#include <memory>
#include <string>
+#include <string_view>
#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/util/status-macros.h"
@@ -30,10 +34,14 @@ namespace lib {
enum class DataType {
kNone,
kLong,
+ kDouble,
kText,
kString,
kStringList,
kDocumentIterator,
+ // TODO(b/326656531): Instead of creating a vector index type, consider
+ // changing it to vector type so that the data is the vector directly.
+ kVectorIndex,
};
struct QueryTerm {
@@ -52,6 +60,10 @@ struct PendingValue {
return PendingValue(std::move(text), DataType::kText);
}
+ static PendingValue CreateVectorIndexPendingValue(int64_t vector_index) {
+ return PendingValue(vector_index, DataType::kVectorIndex);
+ }
+
PendingValue() : data_type_(DataType::kNone) {}
explicit PendingValue(std::unique_ptr<DocHitInfoIterator> iterator)
@@ -111,6 +123,16 @@ struct PendingValue {
return long_val_;
}
+ libtextclassifier3::StatusOr<double> double_val() {
+ ICING_RETURN_IF_ERROR(ParseDouble());
+ return double_val_;
+ }
+
+ libtextclassifier3::StatusOr<int64_t> vector_index_val() const {
+ ICING_RETURN_IF_ERROR(CheckDataType(DataType::kVectorIndex));
+ return long_val_;
+ }
+
// Attempts to interpret the value as an int. A pending value can be parsed as
// an int under two circumstances:
// 1. It holds a kText value which can be parsed to an int
@@ -122,12 +144,26 @@ struct PendingValue {
// - INVALID_ARGUMENT if the value could not be parsed as a long
libtextclassifier3::Status ParseInt();
+ // Attempts to interpret the value as a double. A pending value can be parsed
+ // as a double under two circumstances:
+ // 1. It holds a kText value which can be parsed to a double
+ // 2. It holds a kDouble value
+ // If #1 is true, then the parsed value will be stored in double_val_ and
+ // data_type will be updated to kDouble.
+ // RETURNS:
+ // - OK, if able to successfully parse the value into a double
+ // - INVALID_ARGUMENT if the value could not be parsed as a double
+ libtextclassifier3::Status ParseDouble();
+
DataType data_type() const { return data_type_; }
private:
explicit PendingValue(QueryTerm query_term, DataType data_type)
: query_term_(std::move(query_term)), data_type_(data_type) {}
+ explicit PendingValue(int64_t long_val, DataType data_type)
+ : long_val_(long_val), data_type_(data_type) {}
+
libtextclassifier3::Status CheckDataType(DataType required_data_type) const {
if (data_type_ == required_data_type) {
return libtextclassifier3::Status::OK;
@@ -151,6 +187,9 @@ struct PendingValue {
// long_val_ will be populated when data_type_ is kLong - after a successful
// call to ParseInt.
int64_t long_val_;
+ // double_val_ will be populated when data_type_ is kDouble - after a
+ // successful call to ParseDouble.
+ double double_val_;
DataType data_type_;
};
diff --git a/icing/query/advanced_query_parser/query-visitor.cc b/icing/query/advanced_query_parser/query-visitor.cc
index 31da959..1ac52c5 100644
--- a/icing/query/advanced_query_parser/query-visitor.cc
+++ b/icing/query/advanced_query_parser/query-visitor.cc
@@ -16,20 +16,26 @@
#include <algorithm>
#include <cstdint>
-#include <cstdlib>
#include <iterator>
#include <limits>
#include <memory>
#include <set>
#include <string>
+#include <string_view>
+#include <unordered_map>
#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/absl_ports/str_join.h"
+#include "icing/index/embed/doc-hit-info-iterator-embedding.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/iterator/doc-hit-info-iterator-all-document-id.h"
#include "icing/index/iterator/doc-hit-info-iterator-and.h"
+#include "icing/index/iterator/doc-hit-info-iterator-filter.h"
#include "icing/index/iterator/doc-hit-info-iterator-none.h"
#include "icing/index/iterator/doc-hit-info-iterator-not.h"
#include "icing/index/iterator/doc-hit-info-iterator-or.h"
@@ -37,17 +43,23 @@
#include "icing/index/iterator/doc-hit-info-iterator-property-in-schema.h"
#include "icing/index/iterator/doc-hit-info-iterator-section-restrict.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/iterator/section-restrict-data.h"
#include "icing/index/property-existence-indexing-handler.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/query/advanced_query_parser/function.h"
#include "icing/query/advanced_query_parser/lexer.h"
#include "icing/query/advanced_query_parser/param.h"
#include "icing/query/advanced_query_parser/parser.h"
#include "icing/query/advanced_query_parser/pending-value.h"
#include "icing/query/advanced_query_parser/util/string-util.h"
#include "icing/query/query-features.h"
+#include "icing/query/query-results.h"
#include "icing/schema/property-util.h"
+#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
#include "icing/tokenization/token.h"
#include "icing/tokenization/tokenizer.h"
+#include "icing/util/embedding-util.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -241,6 +253,34 @@ void QueryVisitor::RegisterFunctions() {
.ValueOrDie();
registered_functions_.insert(
{has_property_function.name(), std::move(has_property_function)});
+
+ // vector_index getSearchSpecEmbedding(long);
+ auto get_search_spec_embedding = [](std::vector<PendingValue>&& args) {
+ return PendingValue::CreateVectorIndexPendingValue(
+ args.at(0).long_val().ValueOrDie());
+ };
+ Function get_search_spec_embedding_function =
+ Function::Create(DataType::kVectorIndex, "getSearchSpecEmbedding",
+ {Param(DataType::kLong)},
+ std::move(get_search_spec_embedding))
+ .ValueOrDie();
+ registered_functions_.insert({get_search_spec_embedding_function.name(),
+ std::move(get_search_spec_embedding_function)});
+
+ // DocHitInfoIterator semanticSearch(vector_index, double, double, string);
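+  // For example, `semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, "COSINE")`
+  // matches documents with an embedding vector whose cosine score against
+  // query vector 0 falls within [0.5, 1].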
+ auto semantic_search = [this](std::vector<PendingValue>&& args) {
+ return this->SemanticSearchFunction(std::move(args));
+ };
+ Function semantic_search_function =
+ Function::Create(DataType::kDocumentIterator, "semanticSearch",
+ {Param(DataType::kVectorIndex),
+ Param(DataType::kDouble, Cardinality::kOptional),
+ Param(DataType::kDouble, Cardinality::kOptional),
+ Param(DataType::kString, Cardinality::kOptional)},
+ std::move(semantic_search))
+ .ValueOrDie();
+ registered_functions_.insert(
+ {semantic_search_function.name(), std::move(semantic_search_function)});
}
libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction(
@@ -278,10 +318,11 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction(
document_store_.last_added_document_id());
} else {
QueryVisitor query_visitor(
- &index_, &numeric_index_, &document_store_, &schema_store_,
- &normalizer_, &tokenizer_, query->raw_term, filter_options_,
- match_type_, needs_term_frequency_info_, pending_property_restricts_,
- processing_not_, current_time_ms_);
+ &index_, &numeric_index_, &embedding_index_, &document_store_,
+ &schema_store_, &normalizer_, &tokenizer_, query->raw_term,
+ embedding_query_vectors_, filter_options_, match_type_,
+ embedding_query_metric_type_, needs_term_frequency_info_,
+ pending_property_restricts_, processing_not_, current_time_ms_);
tree_root->Accept(&query_visitor);
ICING_ASSIGN_OR_RETURN(query_result,
std::move(query_visitor).ConsumeResults());
@@ -359,6 +400,57 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::HasPropertyFunction(
return PendingValue(std::move(property_in_document_iterator));
}
+libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SemanticSearchFunction(
+ std::vector<PendingValue>&& args) {
+ features_.insert(kEmbeddingSearchFeature);
+
+ int64_t vector_index = args.at(0).vector_index_val().ValueOrDie();
+ if (embedding_query_vectors_ == nullptr || vector_index < 0 ||
+ vector_index >= embedding_query_vectors_->size()) {
+ return absl_ports::InvalidArgumentError("Got invalid vector search index!");
+ }
+
+ // Handle default values for the optional arguments.
+ double low = -std::numeric_limits<double>::infinity();
+ double high = std::numeric_limits<double>::infinity();
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type =
+ embedding_query_metric_type_;
+ if (args.size() >= 2) {
+ low = args.at(1).double_val().ValueOrDie();
+ }
+ if (args.size() >= 3) {
+ high = args.at(2).double_val().ValueOrDie();
+ }
+ if (args.size() >= 4) {
+ const std::string& metric = args.at(3).string_val().ValueOrDie()->term;
+ ICING_ASSIGN_OR_RETURN(
+ metric_type,
+ embedding_util::GetEmbeddingQueryMetricTypeFromName(metric));
+ }
+
+ // Create SectionRestrictData for section restriction.
+ std::unique_ptr<SectionRestrictData> section_restrict_data = nullptr;
+ if (pending_property_restricts_.has_active_property_restricts()) {
+ std::unordered_map<std::string, std::set<std::string>>
+ type_property_filters;
+ type_property_filters[std::string(SchemaStore::kSchemaTypeWildcard)] =
+ pending_property_restricts_.active_property_restricts();
+ section_restrict_data = std::make_unique<SectionRestrictData>(
+ &document_store_, &schema_store_, current_time_ms_,
+ type_property_filters);
+ }
+
+ // Create and return iterator.
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map =
+ &embedding_query_results_.result_scores[vector_index][metric_type];
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<DocHitInfoIterator> iterator,
+ DocHitInfoIteratorEmbedding::Create(
+ &embedding_query_vectors_->at(vector_index),
+ std::move(section_restrict_data), metric_type, low,
+ high, score_map, &embedding_index_));
+ return PendingValue(std::move(iterator));
+}
+
libtextclassifier3::StatusOr<int64_t> QueryVisitor::PopPendingIntValue() {
if (pending_values_.empty()) {
return absl_ports::InvalidArgumentError("Unable to retrieve int value.");
@@ -435,8 +527,8 @@ QueryVisitor::PopPendingIterator() {
// raw_text, then all of raw_text must correspond to this token.
raw_token = raw_text;
} else {
- ICING_ASSIGN_OR_RETURN(raw_token, string_util::FindEscapedToken(
- raw_text, token.text));
+ ICING_ASSIGN_OR_RETURN(
+ raw_token, string_util::FindEscapedToken(raw_text, token.text));
}
normalized_term = normalizer_.NormalizeTerm(token.text);
QueryTerm term_value{std::move(normalized_term), raw_token,
@@ -570,15 +662,14 @@ libtextclassifier3::Status QueryVisitor::ProcessNegationOperator(
"Visit unary operator child didn't correctly add pending values.");
}
- // 3. We want to preserve the original text of the integer value, append our
- // minus and *then* parse as an int.
- ICING_ASSIGN_OR_RETURN(QueryTerm int_text_val, PopPendingTextValue());
- int_text_val.term = absl_ports::StrCat("-", int_text_val.term);
+  // 3. We want to preserve the original text of the numeric value and
+  //    append our minus to the text. It will be parsed as either an int or
+  //    a double later.
+ ICING_ASSIGN_OR_RETURN(QueryTerm numeric_text_val, PopPendingTextValue());
+ numeric_text_val.term = absl_ports::StrCat("-", numeric_text_val.term);
PendingValue pending_value =
- PendingValue::CreateTextPendingValue(std::move(int_text_val));
- ICING_RETURN_IF_ERROR(pending_value.long_val());
+ PendingValue::CreateTextPendingValue(std::move(numeric_text_val));
- // We've parsed our integer value successfully. Pop our placeholder, push it
+ // We've parsed our numeric value successfully. Pop our placeholder, push it
// on to the stack and return successfully.
if (!pending_values_.top().is_placeholder()) {
return absl_ports::InvalidArgumentError(
@@ -768,7 +859,8 @@ void QueryVisitor::VisitMember(const MemberNode* node) {
end = text_val.raw_term.data() + text_val.raw_term.length();
} else {
start = std::min(start, text_val.raw_term.data());
- end = std::max(end, text_val.raw_term.data() + text_val.raw_term.length());
+ end = std::max(end,
+ text_val.raw_term.data() + text_val.raw_term.length());
}
members.push_back(std::move(text_val.term));
}
@@ -800,13 +892,26 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) {
"Function ", node->function_name()->value(), " is not supported."));
return;
}
+ const Function& function = itr->second;
// 2. Put in a placeholder PendingValue
pending_values_.push(PendingValue());
// 3. Visit the children.
- for (const std::unique_ptr<Node>& arg : node->args()) {
+ expecting_numeric_arg_ = true;
+ for (int i = 0; i < node->args().size(); ++i) {
+ const std::unique_ptr<Node>& arg = node->args()[i];
+ libtextclassifier3::StatusOr<DataType> arg_type_or =
+ function.get_param_type(i);
+ bool current_level_expecting_numeric_arg = expecting_numeric_arg_;
+ // If arg_type_or has an error, we should ignore it for now, since
+ // function.Eval should do the type check and return better error messages.
+ if (arg_type_or.ok() && (arg_type_or.ValueOrDie() == DataType::kLong ||
+ arg_type_or.ValueOrDie() == DataType::kDouble)) {
+ expecting_numeric_arg_ = true;
+ }
arg->Accept(this);
+ expecting_numeric_arg_ = current_level_expecting_numeric_arg;
if (has_pending_error()) {
return;
}
@@ -819,7 +924,6 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) {
pending_values_.pop();
}
std::reverse(args.begin(), args.end());
- const Function& function = itr->second;
auto eval_result = function.Eval(std::move(args));
if (!eval_result.ok()) {
pending_error_ = std::move(eval_result).status();
@@ -955,6 +1059,7 @@ libtextclassifier3::StatusOr<QueryResults> QueryVisitor::ConsumeResults() && {
results.root_iterator = std::move(iterator_or).ValueOrDie();
results.query_term_iterators = std::move(query_term_iterators_);
results.query_terms = std::move(property_query_terms_map_);
+ results.embedding_query_results = std::move(embedding_query_results_);
results.features_in_use = std::move(features_);
return results;
}
diff --git a/icing/query/advanced_query_parser/query-visitor.h b/icing/query/advanced_query_parser/query-visitor.h
index d090b3c..17149f5 100644
--- a/icing/query/advanced_query_parser/query-visitor.h
+++ b/icing/query/advanced_query_parser/query-visitor.h
@@ -17,13 +17,19 @@
#include <cstdint>
#include <memory>
+#include <set>
#include <stack>
#include <string>
+#include <string_view>
+#include <unordered_map>
#include <unordered_set>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-filter.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
@@ -33,10 +39,12 @@
#include "icing/query/advanced_query_parser/pending-value.h"
#include "icing/query/query-features.h"
#include "icing/query/query-results.h"
+#include "icing/query/query-terms.h"
#include "icing/schema/schema-store.h"
#include "icing/store/document-store.h"
#include "icing/tokenization/tokenizer.h"
#include "icing/transform/normalizer.h"
+#include <google/protobuf/repeated_field.h>
namespace icing {
namespace lib {
@@ -45,19 +53,23 @@ namespace lib {
// the parser.
class QueryVisitor : public AbstractSyntaxTreeVisitor {
public:
- explicit QueryVisitor(Index* index,
- const NumericIndex<int64_t>* numeric_index,
- const DocumentStore* document_store,
- const SchemaStore* schema_store,
- const Normalizer* normalizer,
- const Tokenizer* tokenizer,
- std::string_view raw_query_text,
- DocHitInfoIteratorFilter::Options filter_options,
- TermMatchType::Code match_type,
- bool needs_term_frequency_info, int64_t current_time_ms)
- : QueryVisitor(index, numeric_index, document_store, schema_store,
- normalizer, tokenizer, raw_query_text, filter_options,
- match_type, needs_term_frequency_info,
+ explicit QueryVisitor(
+ Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
+ const DocumentStore* document_store, const SchemaStore* schema_store,
+ const Normalizer* normalizer, const Tokenizer* tokenizer,
+ std::string_view raw_query_text,
+ const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>*
+ embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options filter_options,
+ TermMatchType::Code match_type,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ embedding_query_metric_type,
+ bool needs_term_frequency_info, int64_t current_time_ms)
+ : QueryVisitor(index, numeric_index, embedding_index, document_store,
+ schema_store, normalizer, tokenizer, raw_query_text,
+ embedding_query_vectors, filter_options, match_type,
+ embedding_query_metric_type, needs_term_frequency_info,
PendingPropertyRestricts(),
/*processing_not=*/false, current_time_ms) {}
@@ -106,22 +118,31 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor {
explicit QueryVisitor(
Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const DocumentStore* document_store, const SchemaStore* schema_store,
const Normalizer* normalizer, const Tokenizer* tokenizer,
std::string_view raw_query_text,
+ const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>*
+ embedding_query_vectors,
DocHitInfoIteratorFilter::Options filter_options,
- TermMatchType::Code match_type, bool needs_term_frequency_info,
+ TermMatchType::Code match_type,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ embedding_query_metric_type,
+ bool needs_term_frequency_info,
PendingPropertyRestricts pending_property_restricts, bool processing_not,
int64_t current_time_ms)
: index_(*index),
numeric_index_(*numeric_index),
+ embedding_index_(*embedding_index),
document_store_(*document_store),
schema_store_(*schema_store),
normalizer_(*normalizer),
tokenizer_(*tokenizer),
raw_query_text_(raw_query_text),
+ embedding_query_vectors_(embedding_query_vectors),
filter_options_(std::move(filter_options)),
match_type_(match_type),
+ embedding_query_metric_type_(embedding_query_metric_type),
needs_term_frequency_info_(needs_term_frequency_info),
pending_property_restricts_(std::move(pending_property_restricts)),
processing_not_(processing_not),
@@ -264,6 +285,22 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor {
libtextclassifier3::StatusOr<PendingValue> HasPropertyFunction(
std::vector<PendingValue>&& args);
+ // Implementation of the semanticSearch(vector, low, high, metric) custom
+ // function. This function is used for supporting vector search with a
+ // syntax like `semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, "COSINE")`.
+ //
+  // low, high, and metric are optional parameters:
+  // - low defaults to negative infinity
+  // - high defaults to positive infinity
+  // - metric defaults to the metric specified in SearchSpec
+ //
+ // Returns:
+  //   - a PendingValue holding a DocHitInfoIterator that returns all
+  //     documents with an embedding vector whose score is within [low, high].
+  //   - INVALID_ARGUMENT_ERROR if the vector index is invalid.
+  //   - any errors returned by DocHitInfoIteratorEmbedding::Create.
+ libtextclassifier3::StatusOr<PendingValue> SemanticSearchFunction(
+ std::vector<PendingValue>&& args);
+
// Handles a NaryOperatorNode where the operator is HAS (':') and pushes an
// iterator with the proper section filter applied. If the current property
// restriction represented by pending_property_restricts and the first child
@@ -292,19 +329,26 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor {
SectionRestrictQueryTermsMap property_query_terms_map_;
QueryTermIteratorsMap query_term_iterators_;
+
+ EmbeddingQueryResults embedding_query_results_;
+
// Set of features invoked in the query.
std::unordered_set<Feature> features_;
Index& index_; // Does not own!
const NumericIndex<int64_t>& numeric_index_; // Does not own!
+ const EmbeddingIndex& embedding_index_; // Does not own!
const DocumentStore& document_store_; // Does not own!
const SchemaStore& schema_store_; // Does not own!
const Normalizer& normalizer_; // Does not own!
const Tokenizer& tokenizer_; // Does not own!
std::string_view raw_query_text_;
+ const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>*
+ embedding_query_vectors_; // Nullable, does not own!
DocHitInfoIteratorFilter::Options filter_options_;
TermMatchType::Code match_type_;
+ SearchSpecProto::EmbeddingQueryMetricType::Code embedding_query_metric_type_;
// Whether or not term_frequency information is needed. This affects:
// - how DocHitInfoIteratorTerms are constructed
// - whether the QueryTermIteratorsMap is populated in the QueryResults.
diff --git a/icing/query/advanced_query_parser/query-visitor_test.cc b/icing/query/advanced_query_parser/query-visitor_test.cc
index 9455baa..c5ba866 100644
--- a/icing/query/advanced_query_parser/query-visitor_test.cc
+++ b/icing/query/advanced_query_parser/query-visitor_test.cc
@@ -15,6 +15,7 @@
#include "icing/query/advanced_query_parser/query-visitor.h"
#include <cstdint>
+#include <initializer_list>
#include <limits>
#include <memory>
#include <string>
@@ -31,6 +32,7 @@
#include "icing/document-builder.h"
#include "icing/file/filesystem.h"
#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/hit/hit.h"
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-filter.h"
@@ -42,6 +44,7 @@
#include "icing/jni/jni-cache.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/portable/platform.h"
+#include "icing/proto/search.pb.h"
#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
#include "icing/query/advanced_query_parser/lexer.h"
#include "icing/query/advanced_query_parser/parser.h"
@@ -54,6 +57,7 @@
#include "icing/store/document-store.h"
#include "icing/store/namespace-id.h"
#include "icing/testing/common-matchers.h"
+#include "icing/testing/embedding-test-utils.h"
#include "icing/testing/icu-data-file-helper.h"
#include "icing/testing/jni-test-helpers.h"
#include "icing/testing/test-data.h"
@@ -67,16 +71,22 @@
#include "icing/util/clock.h"
#include "icing/util/status-macros.h"
#include "unicode/uloc.h"
+#include <google/protobuf/repeated_field.h>
namespace icing {
namespace lib {
namespace {
+using ::testing::DoubleNear;
using ::testing::ElementsAre;
using ::testing::IsEmpty;
+using ::testing::IsNull;
+using ::testing::Pointee;
using ::testing::UnorderedElementsAre;
+constexpr float kEps = 0.000001;
+
constexpr DocumentId kDocumentId0 = 0;
constexpr DocumentId kDocumentId1 = 1;
constexpr DocumentId kDocumentId2 = 2;
@@ -85,6 +95,18 @@ constexpr SectionId kSectionId0 = 0;
constexpr SectionId kSectionId1 = 1;
constexpr SectionId kSectionId2 = 2;
+constexpr SearchSpecProto::EmbeddingQueryMetricType::Code
+ EMBEDDING_METRIC_UNKNOWN =
+ SearchSpecProto::EmbeddingQueryMetricType::UNKNOWN;
+constexpr SearchSpecProto::EmbeddingQueryMetricType::Code
+ EMBEDDING_METRIC_COSINE = SearchSpecProto::EmbeddingQueryMetricType::COSINE;
+constexpr SearchSpecProto::EmbeddingQueryMetricType::Code
+ EMBEDDING_METRIC_DOT_PRODUCT =
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT;
+constexpr SearchSpecProto::EmbeddingQueryMetricType::Code
+ EMBEDDING_METRIC_EUCLIDEAN =
+ SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN;
+
template <typename T, typename U>
std::vector<T> ExtractKeys(const std::unordered_map<T, U>& map) {
std::vector<T> keys;
@@ -106,6 +128,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> {
test_dir_ = GetTestTempDir() + "/icing";
index_dir_ = test_dir_ + "/index";
numeric_index_dir_ = test_dir_ + "/numeric_index";
+ embedding_index_dir_ = test_dir_ + "/embedding_index";
store_dir_ = test_dir_ + "/store";
schema_store_dir_ = test_dir_ + "/schema_store";
filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
@@ -154,6 +177,10 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> {
numeric_index_,
DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ embedding_index_,
+ EmbeddingIndex::Create(&filesystem_, embedding_index_dir_));
+
ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create(
/*max_term_byte_size=*/1000));
@@ -219,6 +246,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> {
std::string test_dir_;
std::string index_dir_;
std::string numeric_index_dir_;
+ std::string embedding_index_dir_;
std::string schema_store_dir_;
std::string store_dir_;
Clock clock_;
@@ -226,6 +254,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> {
std::unique_ptr<DocumentStore> document_store_;
std::unique_ptr<Index> index_;
std::unique_ptr<DummyNumericIndex<int64_t>> numeric_index_;
+ std::unique_ptr<EmbeddingIndex> embedding_index_;
std::unique_ptr<Normalizer> normalizer_;
std::unique_ptr<LanguageSegmenter> language_segmenter_;
std::unique_ptr<Tokenizer> tokenizer_;
@@ -252,9 +281,11 @@ TEST_P(QueryVisitorTest, SimpleLessThan) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -295,9 +326,11 @@ TEST_P(QueryVisitorTest, SimpleLessThanEq) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -338,9 +371,11 @@ TEST_P(QueryVisitorTest, SimpleEqual) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -381,9 +416,11 @@ TEST_P(QueryVisitorTest, SimpleGreaterThanEq) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -424,9 +461,11 @@ TEST_P(QueryVisitorTest, SimpleGreaterThan) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -468,9 +507,11 @@ TEST_P(QueryVisitorTest, IntMinLessThanEqual) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -512,9 +553,11 @@ TEST_P(QueryVisitorTest, IntMaxGreaterThanEqual) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -557,9 +600,11 @@ TEST_P(QueryVisitorTest, NestedPropertyLessThan) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -585,9 +630,11 @@ TEST_P(QueryVisitorTest, IntParsingError) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -599,9 +646,11 @@ TEST_P(QueryVisitorTest, NotEqualsUnsupported) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -647,9 +696,11 @@ TEST_P(QueryVisitorTest, LessThanTooManyOperandsInvalid) {
args.push_back(std::move(extra_value_node));
auto root_node = std::make_unique<NaryOperatorNode>("<", std::move(args));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -674,9 +725,11 @@ TEST_P(QueryVisitorTest, LessThanTooFewOperandsInvalid) {
args.push_back(std::move(member_node));
auto root_node = std::make_unique<NaryOperatorNode>("<", std::move(args));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -705,9 +758,11 @@ TEST_P(QueryVisitorTest, LessThanNonExistentPropertyNotFound) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -727,9 +782,11 @@ TEST_P(QueryVisitorTest, LessThanNonExistentPropertyNotFound) {
TEST_P(QueryVisitorTest, NeverVisitedReturnsInvalid) {
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), "",
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), "", /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
@@ -756,9 +813,11 @@ TEST_P(QueryVisitorTest, IntMinLessThanInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -786,9 +845,11 @@ TEST_P(QueryVisitorTest, IntMaxGreaterThanInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -801,9 +862,11 @@ TEST_P(QueryVisitorTest, NumericComparisonPropertyStringIsInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -865,9 +928,11 @@ TEST_P(QueryVisitorTest, NumericComparatorDoesntAffectLaterTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -908,9 +973,11 @@ TEST_P(QueryVisitorTest, SingleTermTermFrequencyEnabled) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -960,9 +1027,11 @@ TEST_P(QueryVisitorTest, SingleTermTermFrequencyDisabled) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/false, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1012,9 +1081,11 @@ TEST_P(QueryVisitorTest, SingleTermPrefix) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1028,9 +1099,11 @@ TEST_P(QueryVisitorTest, SingleTermPrefix) {
query = CreateQuery("fo*");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -1048,9 +1121,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterPropertyReturnsInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -1062,9 +1137,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterNumericValueReturnsInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -1076,9 +1153,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterPropertyRestrictReturnsInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -1114,9 +1193,11 @@ TEST_P(QueryVisitorTest, SegmentationWithPrefix) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1137,9 +1218,11 @@ TEST_P(QueryVisitorTest, SegmentationWithPrefix) {
query = CreateQuery("ba?fo*");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -1174,9 +1257,11 @@ TEST_P(QueryVisitorTest, SingleVerbatimTerm) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1221,9 +1306,11 @@ TEST_P(QueryVisitorTest, SingleVerbatimTermPrefix) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1274,9 +1361,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingQuote) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1326,9 +1415,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingEscape) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1380,9 +1471,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1407,9 +1500,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) {
query = CreateQuery(R"(("foobar\\y"))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -1462,9 +1557,11 @@ TEST_P(QueryVisitorTest, VerbatimTermNewLine) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1488,9 +1585,11 @@ TEST_P(QueryVisitorTest, VerbatimTermNewLine) {
query = CreateQuery(R"(("foobar\\n"))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -1537,9 +1636,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingComplex) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1596,9 +1697,11 @@ TEST_P(QueryVisitorTest, SingleMinusTerm) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1650,9 +1753,11 @@ TEST_P(QueryVisitorTest, SingleNotTerm) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1705,9 +1810,11 @@ TEST_P(QueryVisitorTest, NestedNotTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1774,9 +1881,11 @@ TEST_P(QueryVisitorTest, DeeplyNestedNotTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1813,9 +1922,11 @@ TEST_P(QueryVisitorTest, ImplicitAndTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1856,9 +1967,11 @@ TEST_P(QueryVisitorTest, ExplicitAndTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1899,9 +2012,11 @@ TEST_P(QueryVisitorTest, OrTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1944,9 +2059,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1969,9 +2086,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) {
query = CreateQuery("bar OR baz foo");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -1993,9 +2112,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) {
query = CreateQuery("(bar OR baz) foo");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_three(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_three);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2055,9 +2176,11 @@ TEST_P(QueryVisitorTest, AndOrNotPrecedence) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2075,9 +2198,11 @@ TEST_P(QueryVisitorTest, AndOrNotPrecedence) {
query = CreateQuery("foo NOT (bar OR baz)");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2140,9 +2265,11 @@ TEST_P(QueryVisitorTest, PropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2216,9 +2343,11 @@ TEST_F(QueryVisitorTest, MultiPropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2259,9 +2388,11 @@ TEST_P(QueryVisitorTest, PropertyFilterStringIsInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -2315,9 +2446,11 @@ TEST_P(QueryVisitorTest, PropertyFilterNonNormalized) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2386,9 +2519,11 @@ TEST_P(QueryVisitorTest, PropertyFilterWithGrouping) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2453,9 +2588,11 @@ TEST_P(QueryVisitorTest, ValidNestedPropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2474,9 +2611,11 @@ TEST_P(QueryVisitorTest, ValidNestedPropertyFilter) {
/*property_restrict=*/"prop1");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2540,9 +2679,11 @@ TEST_P(QueryVisitorTest, InvalidNestedPropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2561,9 +2702,11 @@ TEST_P(QueryVisitorTest, InvalidNestedPropertyFilter) {
/*property_restrict=*/"prop1");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2627,9 +2770,11 @@ TEST_P(QueryVisitorTest, NotWithPropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2648,9 +2793,11 @@ TEST_P(QueryVisitorTest, NotWithPropertyFilter) {
"NOT ", CreateQuery("(foo OR bar)", /*property_restrict=*/"prop1"));
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2728,9 +2875,11 @@ TEST_P(QueryVisitorTest, PropertyFilterWithNot) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2755,9 +2904,11 @@ TEST_P(QueryVisitorTest, PropertyFilterWithNot) {
query = CreateQuery("(NOT foo OR bar)", /*property_restrict=*/"prop1");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2837,9 +2988,11 @@ TEST_P(QueryVisitorTest, SegmentationTest) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2957,9 +3110,11 @@ TEST_P(QueryVisitorTest, PropertyRestrictsPopCorrectly) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3074,9 +3229,11 @@ TEST_P(QueryVisitorTest, UnsatisfiablePropertyRestrictsPopCorrectly) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3098,9 +3255,11 @@ TEST_F(QueryVisitorTest, UnsupportedFunctionReturnsInvalidArgument) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3112,9 +3271,11 @@ TEST_F(QueryVisitorTest, SearchFunctionTooFewArgumentsReturnsInvalidArgument) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3126,9 +3287,11 @@ TEST_F(QueryVisitorTest, SearchFunctionTooManyArgumentsReturnsInvalidArgument) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3142,9 +3305,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3154,9 +3319,11 @@ TEST_F(QueryVisitorTest,
query = R"(search(createList("subject")))";
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
EXPECT_THAT(std::move(query_visitor_two).ConsumeResults(),
@@ -3170,9 +3337,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3182,9 +3351,11 @@ TEST_F(QueryVisitorTest,
query = R"(search("foo", 7))";
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
EXPECT_THAT(std::move(query_visitor_two).ConsumeResults(),
@@ -3197,9 +3368,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3260,9 +3433,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(level_two_query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3284,9 +3459,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) {
R"(", createList("prop1")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(),
- level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3308,9 +3485,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) {
R"(", createList("prop1")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_four_query));
QueryVisitor query_visitor_three(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(),
- level_four_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_four_query, /*embedding_query_vectors=*/nullptr,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_three);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3430,9 +3609,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(level_one_query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_one_query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_one_query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3462,9 +3643,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) {
R"(", createList("prop6", "prop0", "prop4", "prop2")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_two_query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3488,9 +3671,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) {
R"(", createList("prop0", "prop6")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query));
QueryVisitor query_visitor_three(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(),
- level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_three);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3610,9 +3795,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(level_one_query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_one_query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_one_query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3634,9 +3821,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) {
R"(", createList("prop6", "prop0", "prop4", "prop2")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_two_query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3659,9 +3848,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) {
R"( "prop0", "prop6", "prop4", "prop7")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query));
QueryVisitor query_visitor_three(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(),
- level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_three);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3685,9 +3876,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3701,9 +3894,11 @@ TEST_F(
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3717,9 +3912,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3732,9 +3929,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3786,9 +3985,11 @@ TEST_P(QueryVisitorTest, PropertyDefinedFunctionReturnsMatchingDocuments) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3839,9 +4040,11 @@ TEST_P(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3890,9 +4093,11 @@ TEST_P(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3910,9 +4115,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3925,9 +4132,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3941,9 +4150,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3956,9 +4167,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -4015,9 +4228,11 @@ TEST_P(QueryVisitorTest, HasPropertyFunctionReturnsMatchingDocuments) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor1(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor1);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -4033,9 +4248,11 @@ TEST_P(QueryVisitorTest, HasPropertyFunctionReturnsMatchingDocuments) {
query = CreateQuery("bar OR NOT hasProperty(\"price\")");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor2(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor2);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -4088,9 +4305,11 @@ TEST_P(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -4102,6 +4321,890 @@ TEST_P(QueryVisitorTest,
EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty());
}
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionWithNoArgumentReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+ std::string query = "semanticSearch()";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionWithIncorrectArgumentTypeReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+ std::string query = "semanticSearch(0)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionWithExtraArgumentReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+ std::string query =
+ "semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, \"COSINE\", 0)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest,
+ GetSearchSpecEmbeddingFunctionWithExtraArgumentReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+  // getSearchSpecEmbedding takes exactly one argument, so passing a second
+  // argument is invalid.
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0, 1))";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionWithInvalidIndexReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+ // The embedding query index is invalid, since there are only 2 queries.
+ std::string query = "semanticSearch(getSearchSpecEmbedding(10))";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionWithInvalidMetricReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+ // The embedding query metric is invalid.
+ std::string query =
+ "semanticSearch(getSearchSpecEmbedding(0), -10, 10, \"UNKNOWN\")";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ // Passing an unknown default metric type without overriding it in the query
+ // expression is also considered invalid.
+ query = "semanticSearch(getSearchSpecEmbedding(0), -10, 10)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ EXPECT_THAT(std::move(query_visitor2).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionSimpleLowerBound) {
+ // Index two embedding vectors.
+ PropertyProto::VectorProto vector0 =
+ CreateVector("my_model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector1 =
+ CreateVector("my_model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0), vector0));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId1), vector1));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create an embedding query that has a semantic score of 1 with vector0 and
+ // -1 with vector1.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3});
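+  // For reference, with the COSINE metric used below, these scores follow
+  // directly from the vectors:
+  //   cosine({0.1, 0.2, 0.3}, {0.1, 0.2, 0.3})    =  1 (parallel to vector0)
+  //   cosine({0.1, 0.2, 0.3}, {-0.1, -0.2, -0.3}) = -1 (anti-parallel to
+  //                                                     vector1)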
+
+ // The query should match vector0 only.
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0), 0.5)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(1, kEps))));
+
+ // The query should match both vector0 and vector1.
+ query = "semanticSearch(getSearchSpecEmbedding(0), -1.5)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId1, kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(1, kEps))));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1),
+ Pointee(UnorderedElementsAre(DoubleNear(-1, kEps))));
+
+ // The query should match nothing, since there is no vector with a
+ // score >= 1.01.
+ query = "semanticSearch(getSearchSpecEmbedding(0), 1.01)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor3(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor3);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor3).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty());
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionSimpleUpperBound) {
+ // Index two embedding vectors.
+ PropertyProto::VectorProto vector0 =
+ CreateVector("my_model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector1 =
+ CreateVector("my_model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0), vector0));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId1), vector1));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create an embedding query that has a semantic score of 1 with vector0 and
+ // -1 with vector1.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3});
+
+ // The query should match vector1 only.
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0), -100, 0.5)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId1));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1),
+ Pointee(UnorderedElementsAre(DoubleNear(-1, kEps))));
+
+ // The query should match both vector0 and vector1.
+ query = "semanticSearch(getSearchSpecEmbedding(0), -100, 1.5)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId1, kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(1, kEps))));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1),
+ Pointee(UnorderedElementsAre(DoubleNear(-1, kEps))));
+
+ // The query should match nothing, since there is no vector with a
+ // score <= -1.01.
+ query = "semanticSearch(getSearchSpecEmbedding(0), -100, -1.01)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor3(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor3);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor3).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty());
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionMetricOverride) {
+  // Index an embedding vector.
+  PropertyProto::VectorProto vector =
+      CreateVector("my_model", {0.1, 0.2, 0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0), vector));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create an embedding query that has:
+ // - a cosine semantic score of 1
+ // - a dot product semantic score of 0.14
+ // - a euclidean semantic score of 0
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3});
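+  // For reference, the query vector is identical to the indexed vector, so:
+  //   cosine    = 1    (same direction)
+  //   dot       = 0.1 * 0.1 + 0.2 * 0.2 + 0.3 * 0.3 = 0.14
+  //   euclidean = 0    (zero distance between identical vectors)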
+
+ // Create a query that overrides the metric to COSINE.
+ std::string query =
+ "semanticSearch(getSearchSpecEmbedding(0), 0.95, 1.05, \"COSINE\")";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ // The default metric to be overridden
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(1, kEps))));
+
+ // Create a query that overrides the metric to DOT_PRODUCT.
+ query =
+ "semanticSearch(getSearchSpecEmbedding(0), 0.1, 0.2, \"DOT_PRODUCT\")";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ // The default metric to be overridden
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(0.14, kEps))));
+
+ // Create a query that overrides the metric to EUCLIDEAN.
+ query =
+ "semanticSearch(getSearchSpecEmbedding(0), -0.05, 0.05, \"EUCLIDEAN\")";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor3(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ // The default metric to be overridden
+ EMBEDDING_METRIC_UNKNOWN,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor3);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor3).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_EUCLIDEAN, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(0, kEps))));
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionMultipleQueries) {
+ // Index 3 embedding vectors for document 0.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0),
+ CreateVector("my_model1", {1, -2, -3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId1, kDocumentId0),
+ CreateVector("my_model1", {-1, -2, -3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId2, kDocumentId0),
+ CreateVector("my_model2", {-1, 2, 3, -4})));
+ // Index 2 embedding vectors for document 1.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId1),
+ CreateVector("my_model1", {-1, -2, 3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId1, kDocumentId1),
+ CreateVector("my_model2", {1, -2, 3, -4})));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ // Semantic scores for this query:
+ // - document 0: -2 (section 0), 0 (section 1)
+ // - document 1: 6 (section 0)
+ PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add();
+ *embedding_query = CreateVector("my_model1", {-1, -1, 1});
+ // Semantic scores for this query:
+ // - document 0: 4 (section 2)
+ // - document 1: -2 (section 1)
+ embedding_query = embedding_query_vectors.Add();
+ *embedding_query = CreateVector("my_model2", {-1, 1, -1, -1});
+
+ // The query can only match document 0:
+ // - The "semanticSearch(getSearchSpecEmbedding(0), -5)" part should match
+ // semantic scores {-2, 0}.
+ // - The "semanticSearch(getSearchSpecEmbedding(1), 0)" part should match
+ // semantic scores {4}.
+ std::string query =
+ "semanticSearch(getSearchSpecEmbedding(0), -5) AND "
+ "semanticSearch(getSearchSpecEmbedding(1), 0)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ DocHitInfoIterator* itr = query_results.root_iterator.get();
+ // Check results for document 0.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0,
+ std::vector<SectionId>{kSectionId0, kSectionId1,
+ kSectionId2}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(-2, 0)));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(4)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+
+ // The query can match both document 0 and document 1:
+ // For document 0:
+ // - The "semanticSearch(getSearchSpecEmbedding(0), 1)" part should return
+ // semantic scores {}.
+ // - The "semanticSearch(getSearchSpecEmbedding(1), 0.1)" part should return
+ // semantic scores {4}.
+ // For document 1:
+ // - The "semanticSearch(getSearchSpecEmbedding(0), 1)" part should return
+ // semantic scores {6}.
+ // - The "semanticSearch(getSearchSpecEmbedding(1), 0.1)" part should return
+ // semantic scores {}.
+ query =
+ "semanticSearch(getSearchSpecEmbedding(0), 1) OR "
+ "semanticSearch(getSearchSpecEmbedding(1), 0.1)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ itr = query_results.root_iterator.get();
+ // Check results for document 1.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ Pointee(UnorderedElementsAre(6)));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ IsNull());
+ // Check results for document 0.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId2}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ IsNull());
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(4)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionMultipleQueriesScoresMergedRepeat) {
+  // Index 2 embedding vectors for document 0.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0),
+ CreateVector("my_model1", {1, -2, -3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId1, kDocumentId0),
+ CreateVector("my_model1", {-1, -2, -3})));
+  // Index 1 embedding vector for document 1.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId1),
+ CreateVector("my_model1", {-1, -2, 3})));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+  // Create an embedding query.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ // Semantic scores for this query:
+ // - document 0: -2 (section 0), 0 (section 1)
+ // - document 1: 6 (section 0)
+ PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add();
+ *embedding_query = CreateVector("my_model1", {-1, -1, 1});
+
+  // The query should match both document 0 and document 1: document 0's
+  // scores {-2, 0} fall in [-10, 0], and document 1's score {6} falls in
+  // [0.0001, 10]. The matched scores from the two branches should be merged
+  // in the results.
+ std::string query =
+ "semanticSearch(getSearchSpecEmbedding(0), -10, 0) OR "
+ "semanticSearch(getSearchSpecEmbedding(0), 0.0001, 10)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ DocHitInfoIterator* itr = query_results.root_iterator.get();
+ // Check results for document 1.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ Pointee(UnorderedElementsAre(6)));
+ // Check results for document 0.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{
+ kSectionId0, kSectionId1}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(-2, 0)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+
+ // The same query appears twice, in which case all the scores in the results
+ // should repeat twice.
+ query =
+ "semanticSearch(getSearchSpecEmbedding(0), -10, 10) OR "
+ "semanticSearch(getSearchSpecEmbedding(0), -10, 10)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ itr = query_results.root_iterator.get();
+ // Check results for document 1.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ Pointee(UnorderedElementsAre(6, 6)));
+ // Check results for document 0.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{
+ kSectionId0, kSectionId1}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(-2, 0, -2, 0)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionHybridQueries) {
+ // Index terms
+ Index::Editor editor = index_->Edit(kDocumentId0, kSectionId1,
+ TERM_MATCH_PREFIX, /*namespace_id=*/0);
+ ICING_ASSERT_OK(editor.BufferTerm("foo"));
+ ICING_ASSERT_OK(editor.IndexAllBufferedTerms());
+ editor = index_->Edit(kDocumentId1, kSectionId1, TERM_MATCH_PREFIX,
+ /*namespace_id=*/0);
+ ICING_ASSERT_OK(editor.BufferTerm("bar"));
+ ICING_ASSERT_OK(editor.IndexAllBufferedTerms());
+
+ // Index embedding vectors
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0),
+ CreateVector("my_model1", {1, -2, -3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId1),
+ CreateVector("my_model1", {-1, -2, 3})));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create an embedding query with semantic scores:
+ // - document 0: -2
+ // - document 1: 6
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add();
+ *embedding_query = CreateVector("my_model1", {-1, -1, 1});
+
+ // Perform a hybrid search:
+ // - The "semanticSearch(getSearchSpecEmbedding(0), 0)" part only matches
+ // document 1.
+ // - The "foo" part only matches document 0.
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0), 0) OR foo";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators),
+ UnorderedElementsAre("foo"));
+ EXPECT_THAT(ExtractKeys(query_results.query_terms), UnorderedElementsAre(""));
+ EXPECT_THAT(query_results.query_terms[""], UnorderedElementsAre("foo"));
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ DocHitInfoIterator* itr = query_results.root_iterator.get();
+ // Check results for document 1.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ Pointee(UnorderedElementsAre(6)));
+ // Check results for document 0.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId1}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ IsNull());
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+
+ // Perform another hybrid search:
+ // - The "semanticSearch(getSearchSpecEmbedding(0), -5)" part matches both
+ // document 0 and 1.
+ // - The "foo" part only matches document 0.
+ // As a result, only document 0 will be returned.
+ query = "semanticSearch(getSearchSpecEmbedding(0), -5) AND foo";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators),
+ UnorderedElementsAre("foo"));
+ EXPECT_THAT(ExtractKeys(query_results.query_terms), UnorderedElementsAre(""));
+ EXPECT_THAT(query_results.query_terms[""], UnorderedElementsAre("foo"));
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ itr = query_results.root_iterator.get();
+ // Check results for document 0.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{
+ kSectionId0, kSectionId1}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(-2)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionSectionRestriction) {
+ ICING_ASSERT_OK(schema_store_->SetSchema(
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("type")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("prop1")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("prop2")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build(),
+ /*ignore_errors_and_delete_documents=*/false,
+ /*allow_circular_schema_definitions=*/false));
+
+ // Create two documents.
+ ICING_ASSERT_OK(document_store_->Put(
+ DocumentBuilder().SetKey("ns", "uri0").SetSchema("type").Build()));
+ ICING_ASSERT_OK(document_store_->Put(
+ DocumentBuilder().SetKey("ns", "uri1").SetSchema("type").Build()));
+ // Add embedding vectors into different sections for the two documents.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0),
+ CreateVector("my_model1", {1, -2, -3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId1, kDocumentId0),
+ CreateVector("my_model1", {-1, -2, 3})));
+ ICING_ASSERT_OK(
+ embedding_index_->BufferEmbedding(BasicHit(kSectionId0, kDocumentId1),
+ CreateVector("my_model1", {-1, 2, 3})));
+ ICING_ASSERT_OK(
+ embedding_index_->BufferEmbedding(BasicHit(kSectionId1, kDocumentId1),
+ CreateVector("my_model1", {1, 2, -3})));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create an embedding query with semantic scores:
+ // - document 0: -2 (section 0), 6 (section 1)
+ // - document 1: 2 (section 0), -6 (section 1)
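+  // (Worked dot products for reference, using the query vector {-1, -1, 1}
+  // created below:
+  //   document 0, section 0: (-1)(1) + (-1)(-2) + (1)(-3) = -2
+  //   document 0, section 1: (-1)(-1) + (-1)(-2) + (1)(3) = 6
+  //   document 1, section 0: (-1)(-1) + (-1)(2) + (1)(3) = 2
+  //   document 1, section 1: (-1)(1) + (-1)(2) + (1)(-3) = -6.)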
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add();
+ *embedding_query = CreateVector("my_model1", {-1, -1, 1});
+
+  // An embedding query with a section restriction. The returned scores should
+  // be limited to the restricted section.
+ std::string query = "prop1:semanticSearch(getSearchSpecEmbedding(0), -100)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ DocHitInfoIterator* itr = query_results.root_iterator.get();
+ // Check results.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ Pointee(UnorderedElementsAre(2)));
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(-2)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
INSTANTIATE_TEST_SUITE_P(QueryVisitorTest, QueryVisitorTest,
testing::Values(QueryType::kPlain,
QueryType::kSearch));
diff --git a/icing/query/query-features.h b/icing/query/query-features.h
index d829cd7..bc3602f 100644
--- a/icing/query/query-features.h
+++ b/icing/query/query-features.h
@@ -52,9 +52,15 @@ constexpr Feature kListFilterQueryLanguageFeature =
constexpr Feature kHasPropertyFunctionFeature =
"HAS_PROPERTY_FUNCTION"; // Features#HAS_PROPERTY_FUNCTION
+// This feature relates to the use of embedding searches in the advanced query
+// language. Ex. `semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, "COSINE")`.
+constexpr Feature kEmbeddingSearchFeature =
+ "EMBEDDING_SEARCH"; // Features#EMBEDDING_SEARCH
+
inline std::unordered_set<Feature> GetQueryFeaturesSet() {
return {kNumericSearchFeature, kVerbatimSearchFeature,
- kListFilterQueryLanguageFeature, kHasPropertyFunctionFeature};
+ kListFilterQueryLanguageFeature, kHasPropertyFunctionFeature,
+ kEmbeddingSearchFeature};
}
} // namespace lib
diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc
index bc0917b..6e13001 100644
--- a/icing/query/query-processor.cc
+++ b/icing/query/query-processor.cc
@@ -27,6 +27,7 @@
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-all-document-id.h"
#include "icing/index/iterator/doc-hit-info-iterator-and.h"
@@ -111,25 +112,28 @@ std::unique_ptr<DocHitInfoIterator> ProcessParserStateFrame(
libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>>
QueryProcessor::Create(Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
const SchemaStore* schema_store, const Clock* clock) {
ICING_RETURN_ERROR_IF_NULL(index);
ICING_RETURN_ERROR_IF_NULL(numeric_index);
+ ICING_RETURN_ERROR_IF_NULL(embedding_index);
ICING_RETURN_ERROR_IF_NULL(language_segmenter);
ICING_RETURN_ERROR_IF_NULL(normalizer);
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
ICING_RETURN_ERROR_IF_NULL(clock);
- return std::unique_ptr<QueryProcessor>(
- new QueryProcessor(index, numeric_index, language_segmenter, normalizer,
- document_store, schema_store, clock));
+ return std::unique_ptr<QueryProcessor>(new QueryProcessor(
+ index, numeric_index, embedding_index, language_segmenter, normalizer,
+ document_store, schema_store, clock));
}
QueryProcessor::QueryProcessor(Index* index,
const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
@@ -137,6 +141,7 @@ QueryProcessor::QueryProcessor(Index* index,
const Clock* clock)
: index_(*index),
numeric_index_(*numeric_index),
+ embedding_index_(*embedding_index),
language_segmenter_(*language_segmenter),
normalizer_(*normalizer),
document_store_(*document_store),
@@ -229,11 +234,12 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseAdvancedQuery(
ranking_strategy == ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE;
std::unique_ptr<Timer> query_visitor_timer = clock_.GetNewTimer();
- QueryVisitor query_visitor(&index_, &numeric_index_, &document_store_,
- &schema_store_, &normalizer_,
- plain_tokenizer.get(), search_spec.query(),
- std::move(options), search_spec.term_match_type(),
- needs_term_frequency_info, current_time_ms);
+ QueryVisitor query_visitor(
+ &index_, &numeric_index_, &embedding_index_, &document_store_,
+ &schema_store_, &normalizer_, plain_tokenizer.get(), search_spec.query(),
+ &search_spec.embedding_query_vectors(), std::move(options),
+ search_spec.term_match_type(), search_spec.embedding_query_metric_type(),
+ needs_term_frequency_info, current_time_ms);
tree_root->Accept(&query_visitor);
ICING_ASSIGN_OR_RETURN(QueryResults results,
std::move(query_visitor).ConsumeResults());
diff --git a/icing/query/query-processor.h b/icing/query/query-processor.h
index de256ee..d90b5f6 100644
--- a/icing/query/query-processor.h
+++ b/icing/query/query-processor.h
@@ -19,6 +19,7 @@
#include <memory>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
#include "icing/index/numeric/numeric-index.h"
#include "icing/proto/logging.pb.h"
@@ -47,6 +48,7 @@ class QueryProcessor {
// FAILED_PRECONDITION if any of the pointers is null.
static libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>> Create(
Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter, const Normalizer* normalizer,
const DocumentStore* document_store, const SchemaStore* schema_store,
const Clock* clock);
@@ -74,6 +76,7 @@ class QueryProcessor {
private:
explicit QueryProcessor(Index* index,
const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
@@ -110,6 +113,7 @@ class QueryProcessor {
// query time.
Index& index_; // Does not own.
const NumericIndex<int64_t>& numeric_index_; // Does not own.
+ const EmbeddingIndex& embedding_index_; // Does not own.
const LanguageSegmenter& language_segmenter_; // Does not own.
const Normalizer& normalizer_; // Does not own.
const DocumentStore& document_store_; // Does not own.
diff --git a/icing/query/query-processor_benchmark.cc b/icing/query/query-processor_benchmark.cc
index 3be74fd..5b3bf99 100644
--- a/icing/query/query-processor_benchmark.cc
+++ b/icing/query/query-processor_benchmark.cc
@@ -12,26 +12,40 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "testing/base/public/benchmark.h"
-#include "gmock/gmock.h"
#include "third_party/absl/flags/flag.h"
+#include "icing/absl_ports/str_cat.h"
#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
#include "icing/index/numeric/dummy-numeric-index.h"
-#include "icing/index/numeric/numeric-index.h"
+#include "icing/legacy/index/icing-filesystem.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/search.pb.h"
#include "icing/proto/term.pb.h"
#include "icing/query/query-processor.h"
+#include "icing/query/query-results.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/icu-data-file-helper.h"
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
#include "icing/tokenization/language-segmenter-factory.h"
+#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer-factory.h"
+#include "icing/transform/normalizer.h"
#include "icing/util/clock.h"
#include "icing/util/logging.h"
#include "unicode/uloc.h"
@@ -118,6 +132,7 @@ void BM_QueryOneTerm(benchmark::State& state) {
const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark";
const std::string index_dir = base_dir + "/index";
const std::string numeric_index_dir = base_dir + "/numeric_index";
+ const std::string embedding_index_dir = base_dir + "/embedding_index";
const std::string schema_dir = base_dir + "/schema";
const std::string doc_store_dir = base_dir + "/store";
@@ -134,6 +149,9 @@ void BM_QueryOneTerm(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
auto numeric_index,
DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ auto embedding_index,
+ EmbeddingIndex::Create(&filesystem, embedding_index_dir));
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
@@ -172,8 +190,9 @@ void BM_QueryOneTerm(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
QueryProcessor::Create(index.get(), numeric_index.get(),
- language_segmenter.get(), normalizer.get(),
- document_store.get(), schema_store.get(), &clock));
+ embedding_index.get(), language_segmenter.get(),
+ normalizer.get(), document_store.get(),
+ schema_store.get(), &clock));
SearchSpecProto search_spec;
search_spec.set_query(input_string);
@@ -247,6 +266,7 @@ void BM_QueryFiveTerms(benchmark::State& state) {
const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark";
const std::string index_dir = base_dir + "/index";
const std::string numeric_index_dir = base_dir + "/numeric_index";
+ const std::string embedding_index_dir = base_dir + "/embedding_index";
const std::string schema_dir = base_dir + "/schema";
const std::string doc_store_dir = base_dir + "/store";
@@ -263,6 +283,9 @@ void BM_QueryFiveTerms(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
auto numeric_index,
DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ auto embedding_index,
+ EmbeddingIndex::Create(&filesystem, embedding_index_dir));
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
@@ -315,8 +338,9 @@ void BM_QueryFiveTerms(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
QueryProcessor::Create(index.get(), numeric_index.get(),
- language_segmenter.get(), normalizer.get(),
- document_store.get(), schema_store.get(), &clock));
+ embedding_index.get(), language_segmenter.get(),
+ normalizer.get(), document_store.get(),
+ schema_store.get(), &clock));
const std::string query_string = absl_ports::StrCat(
input_string_a, " ", input_string_b, " ", input_string_c, " ",
@@ -394,6 +418,7 @@ void BM_QueryDiacriticTerm(benchmark::State& state) {
const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark";
const std::string index_dir = base_dir + "/index";
const std::string numeric_index_dir = base_dir + "/numeric_index";
+ const std::string embedding_index_dir = base_dir + "/embedding_index";
const std::string schema_dir = base_dir + "/schema";
const std::string doc_store_dir = base_dir + "/store";
@@ -410,6 +435,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
auto numeric_index,
DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ auto embedding_index,
+ EmbeddingIndex::Create(&filesystem, embedding_index_dir));
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
@@ -451,8 +479,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
QueryProcessor::Create(index.get(), numeric_index.get(),
- language_segmenter.get(), normalizer.get(),
- document_store.get(), schema_store.get(), &clock));
+ embedding_index.get(), language_segmenter.get(),
+ normalizer.get(), document_store.get(),
+ schema_store.get(), &clock));
SearchSpecProto search_spec;
search_spec.set_query(input_string);
@@ -526,6 +555,7 @@ void BM_QueryHiragana(benchmark::State& state) {
const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark";
const std::string index_dir = base_dir + "/index";
const std::string numeric_index_dir = base_dir + "/numeric_index";
+ const std::string embedding_index_dir = base_dir + "/embedding_index";
const std::string schema_dir = base_dir + "/schema";
const std::string doc_store_dir = base_dir + "/store";
@@ -542,6 +572,9 @@ void BM_QueryHiragana(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
auto numeric_index,
DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ auto embedding_index,
+ EmbeddingIndex::Create(&filesystem, embedding_index_dir));
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
@@ -583,8 +616,9 @@ void BM_QueryHiragana(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
QueryProcessor::Create(index.get(), numeric_index.get(),
- language_segmenter.get(), normalizer.get(),
- document_store.get(), schema_store.get(), &clock));
+ embedding_index.get(), language_segmenter.get(),
+ normalizer.get(), document_store.get(),
+ schema_store.get(), &clock));
SearchSpecProto search_spec;
search_spec.set_query(input_string);
diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc
index 43c9629..288c7fb 100644
--- a/icing/query/query-processor_test.cc
+++ b/icing/query/query-processor_test.cc
@@ -14,17 +14,24 @@
#include "icing/query/query-processor.h"
+#include <array>
#include <cstdint>
#include <memory>
#include <string>
+#include <unordered_map>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/hit/hit.h"
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
@@ -54,6 +61,8 @@
#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer-factory.h"
#include "icing/transform/normalizer.h"
+#include "icing/util/clock.h"
+#include "icing/util/status-macros.h"
#include "unicode/uloc.h"
namespace icing {
@@ -87,7 +96,8 @@ class QueryProcessorTest
store_dir_(test_dir_ + "/store"),
schema_store_dir_(test_dir_ + "/schema_store"),
index_dir_(test_dir_ + "/index"),
- numeric_index_dir_(test_dir_ + "/numeric_index") {}
+ numeric_index_dir_(test_dir_ + "/numeric_index"),
+ embedding_index_dir_(test_dir_ + "/embedding_index") {}
void SetUp() override {
filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
@@ -126,6 +136,9 @@ class QueryProcessorTest
ICING_ASSERT_OK_AND_ASSIGN(
numeric_index_,
DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ embedding_index_,
+ EmbeddingIndex::Create(&filesystem_, embedding_index_dir_));
language_segmenter_factory::SegmenterOptions segmenter_options(
ULOC_US, jni_cache_.get());
@@ -138,10 +151,10 @@ class QueryProcessorTest
ICING_ASSERT_OK_AND_ASSIGN(
query_processor_,
- QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(), normalizer_.get(),
- document_store_.get(), schema_store_.get(),
- &fake_clock_));
+ QueryProcessor::Create(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_));
}
libtextclassifier3::Status AddTokenToIndex(
@@ -177,10 +190,12 @@ class QueryProcessorTest
IcingFilesystem icing_filesystem_;
const std::string index_dir_;
const std::string numeric_index_dir_;
+ const std::string embedding_index_dir_;
protected:
std::unique_ptr<Index> index_;
std::unique_ptr<NumericIndex<int64_t>> numeric_index_;
+ std::unique_ptr<EmbeddingIndex> embedding_index_;
std::unique_ptr<LanguageSegmenter> language_segmenter_;
std::unique_ptr<Normalizer> normalizer_;
FakeClock fake_clock_;
@@ -191,42 +206,53 @@ class QueryProcessorTest
};
TEST_P(QueryProcessorTest, CreationWithNullPointerShouldFail) {
- EXPECT_THAT(QueryProcessor::Create(/*index=*/nullptr, numeric_index_.get(),
- language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get(), &fake_clock_),
- StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
- EXPECT_THAT(QueryProcessor::Create(index_.get(), /*numeric_index_=*/nullptr,
- language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get(), &fake_clock_),
- StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ QueryProcessor::Create(/*index=*/nullptr, numeric_index_.get(),
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ QueryProcessor::Create(index_.get(), /*numeric_index_=*/nullptr,
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(),
- /*language_segmenter=*/nullptr,
+ /*embedding_index=*/nullptr,
+ language_segmenter_.get(),
normalizer_.get(), document_store_.get(),
schema_store_.get(), &fake_clock_),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(QueryProcessor::Create(
- index_.get(), numeric_index_.get(), language_segmenter_.get(),
- /*normalizer=*/nullptr, document_store_.get(),
- schema_store_.get(), &fake_clock_),
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ /*language_segmenter=*/nullptr, normalizer_.get(),
+ document_store_.get(), schema_store_.get(), &fake_clock_),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(
QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(), normalizer_.get(),
- /*document_store=*/nullptr, schema_store_.get(),
- &fake_clock_),
+ embedding_index_.get(), language_segmenter_.get(),
+ /*normalizer=*/nullptr, document_store_.get(),
+ schema_store_.get(), &fake_clock_),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ QueryProcessor::Create(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(),
+ /*document_store=*/nullptr, schema_store_.get(), &fake_clock_),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ QueryProcessor::Create(index_.get(), numeric_index_.get(),
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ /*schema_store=*/nullptr, &fake_clock_),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ QueryProcessor::Create(index_.get(), numeric_index_.get(),
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), /*clock=*/nullptr),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
- EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- /*schema_store=*/nullptr, &fake_clock_),
- StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
- EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get(), /*clock=*/nullptr),
- StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
TEST_P(QueryProcessorTest, EmptyGroupMatchAllDocuments) {
@@ -2956,9 +2982,9 @@ TEST_P(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> local_query_processor,
QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(), normalizer_.get(),
- document_store_.get(), schema_store_.get(),
- &fake_clock_));
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_));
SearchSpecProto search_spec;
search_spec.set_query("hello");
@@ -3019,9 +3045,9 @@ TEST_P(QueryProcessorTest, DocumentPastTtlFilteredOut) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> local_query_processor,
QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(), normalizer_.get(),
- document_store_.get(), schema_store_.get(),
- &fake_clock_));
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_));
SearchSpecProto search_spec;
search_spec.set_query("hello");
diff --git a/icing/query/query-results.h b/icing/query/query-results.h
index 52cdd71..983ab35 100644
--- a/icing/query/query-results.h
+++ b/icing/query/query-results.h
@@ -18,9 +18,10 @@
#include <memory>
#include <unordered_set>
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
-#include "icing/query/query-terms.h"
#include "icing/query/query-features.h"
+#include "icing/query/query-terms.h"
namespace icing {
namespace lib {
@@ -35,6 +36,10 @@ struct QueryResults {
// beginning with root_iterator.
// This will only be populated when ranking_strategy == RELEVANCE_SCORE.
QueryTermIteratorsMap query_term_iterators;
+  // Contains the similarity scores from embedding-based queries, which the
+  // advanced scoring language uses to compute the results of the
+  // "this.matchedSemanticScores(...)" function.
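+  // For example, a scoring expression along the lines of
+  // "this.matchedSemanticScores(getSearchSpecEmbedding(0))" would read the
+  // scores stored here for embedding query vector 0; the exact expression
+  // syntax is defined by the advanced scoring language.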
+ EmbeddingQueryResults embedding_query_results;
// Features that are invoked during query execution.
// The list of possible features is defined in query_features.h.
std::unordered_set<Feature> features_in_use;
diff --git a/icing/query/suggestion-processor.cc b/icing/query/suggestion-processor.cc
index 2d73c6b..dfebb98 100644
--- a/icing/query/suggestion-processor.cc
+++ b/icing/query/suggestion-processor.cc
@@ -26,11 +26,14 @@
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/index/numeric/numeric-index.h"
#include "icing/index/term-metadata.h"
#include "icing/proto/search.pb.h"
#include "icing/query/query-processor.h"
+#include "icing/query/query-results.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
#include "icing/store/document-filter-data.h"
@@ -49,6 +52,7 @@ namespace lib {
libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>>
SuggestionProcessor::Create(Index* index,
const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
@@ -56,15 +60,16 @@ SuggestionProcessor::Create(Index* index,
const Clock* clock) {
ICING_RETURN_ERROR_IF_NULL(index);
ICING_RETURN_ERROR_IF_NULL(numeric_index);
+ ICING_RETURN_ERROR_IF_NULL(embedding_index);
ICING_RETURN_ERROR_IF_NULL(language_segmenter);
ICING_RETURN_ERROR_IF_NULL(normalizer);
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
ICING_RETURN_ERROR_IF_NULL(clock);
- return std::unique_ptr<SuggestionProcessor>(
- new SuggestionProcessor(index, numeric_index, language_segmenter,
- normalizer, document_store, schema_store, clock));
+ return std::unique_ptr<SuggestionProcessor>(new SuggestionProcessor(
+ index, numeric_index, embedding_index, language_segmenter, normalizer,
+ document_store, schema_store, clock));
}
libtextclassifier3::StatusOr<
@@ -246,9 +251,9 @@ SuggestionProcessor::QuerySuggestions(
ICING_ASSIGN_OR_RETURN(
std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(&index_, &numeric_index_, &language_segmenter_,
- &normalizer_, &document_store_, &schema_store_,
- &clock_));
+ QueryProcessor::Create(&index_, &numeric_index_, &embedding_index_,
+ &language_segmenter_, &normalizer_,
+ &document_store_, &schema_store_, &clock_));
SearchSpecProto search_spec;
search_spec.set_query(suggestion_spec.prefix());
@@ -321,11 +326,13 @@ SuggestionProcessor::QuerySuggestions(
SuggestionProcessor::SuggestionProcessor(
Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter, const Normalizer* normalizer,
const DocumentStore* document_store, const SchemaStore* schema_store,
const Clock* clock)
: index_(*index),
numeric_index_(*numeric_index),
+ embedding_index_(*embedding_index),
language_segmenter_(*language_segmenter),
normalizer_(*normalizer),
document_store_(*document_store),
diff --git a/icing/query/suggestion-processor.h b/icing/query/suggestion-processor.h
index fe1437c..cf393b4 100644
--- a/icing/query/suggestion-processor.h
+++ b/icing/query/suggestion-processor.h
@@ -20,6 +20,7 @@
#include <vector>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
#include "icing/index/numeric/numeric-index.h"
#include "icing/proto/search.pb.h"
@@ -46,6 +47,7 @@ class SuggestionProcessor {
// FAILED_PRECONDITION if any of the pointers is null.
static libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>>
Create(Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer, const DocumentStore* document_store,
const SchemaStore* schema_store, const Clock* clock);
@@ -62,6 +64,7 @@ class SuggestionProcessor {
private:
explicit SuggestionProcessor(Index* index,
const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
@@ -72,6 +75,7 @@ class SuggestionProcessor {
// index.
Index& index_; // Does not own.
const NumericIndex<int64_t>& numeric_index_; // Does not own.
+ const EmbeddingIndex& embedding_index_; // Does not own.
const LanguageSegmenter& language_segmenter_; // Does not own.
const Normalizer& normalizer_; // Does not own.
const DocumentStore& document_store_; // Does not own.
diff --git a/icing/query/suggestion-processor_test.cc b/icing/query/suggestion-processor_test.cc
index 4c5e4ac..c9bdd8a 100644
--- a/icing/query/suggestion-processor_test.cc
+++ b/icing/query/suggestion-processor_test.cc
@@ -14,14 +14,30 @@
#include "icing/query/suggestion-processor.h"
+#include <cstdint>
+#include <memory>
#include <string>
+#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
+#include "gtest/gtest.h"
#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/index.h"
#include "icing/index/numeric/dummy-numeric-index.h"
+#include "icing/index/numeric/numeric-index.h"
#include "icing/index/term-metadata.h"
+#include "icing/jni/jni-cache.h"
+#include "icing/legacy/index/icing-filesystem.h"
+#include "icing/portable/platform.h"
#include "icing/schema-builder.h"
+#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/fake-clock.h"
@@ -30,7 +46,9 @@
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
#include "icing/tokenization/language-segmenter-factory.h"
+#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer-factory.h"
+#include "icing/transform/normalizer.h"
#include "unicode/uloc.h"
namespace icing {
@@ -59,7 +77,8 @@ class SuggestionProcessorTest : public Test {
store_dir_(test_dir_ + "/store"),
schema_store_dir_(test_dir_ + "/schema_store"),
index_dir_(test_dir_ + "/index"),
- numeric_index_dir_(test_dir_ + "/numeric_index") {}
+ numeric_index_dir_(test_dir_ + "/numeric_index"),
+ embedding_index_dir_(test_dir_ + "/embedding_index") {}
void SetUp() override {
filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
@@ -105,6 +124,9 @@ class SuggestionProcessorTest : public Test {
ICING_ASSERT_OK_AND_ASSIGN(
numeric_index_,
DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ embedding_index_,
+ EmbeddingIndex::Create(&filesystem_, embedding_index_dir_));
language_segmenter_factory::SegmenterOptions segmenter_options(
ULOC_US, jni_cache_.get());
@@ -117,10 +139,10 @@ class SuggestionProcessorTest : public Test {
ICING_ASSERT_OK_AND_ASSIGN(
suggestion_processor_,
- SuggestionProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get(), &fake_clock_));
+ SuggestionProcessor::Create(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_));
}
libtextclassifier3::Status AddTokenToIndex(
@@ -147,10 +169,12 @@ class SuggestionProcessorTest : public Test {
IcingFilesystem icing_filesystem_;
const std::string index_dir_;
const std::string numeric_index_dir_;
+ const std::string embedding_index_dir_;
protected:
std::unique_ptr<Index> index_;
std::unique_ptr<NumericIndex<int64_t>> numeric_index_;
+ std::unique_ptr<EmbeddingIndex> embedding_index_;
std::unique_ptr<LanguageSegmenter> language_segmenter_;
std::unique_ptr<Normalizer> normalizer_;
FakeClock fake_clock_;
@@ -674,6 +698,16 @@ TEST_F(SuggestionProcessorTest, OtherSpecialPrefixTest) {
EXPECT_THAT(terms_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
+
+ if (SearchSpecProto::default_instance().search_type() ==
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ suggestion_spec.set_prefix(
+ "bar OR semanticSearch(getSearchSpecEmbedding(0), 0.5, 1)");
+ terms_or = suggestion_processor_->QuerySuggestions(
+ suggestion_spec, fake_clock_.GetSystemTimeMilliseconds());
+ EXPECT_THAT(terms_or,
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ }
}
TEST_F(SuggestionProcessorTest, InvalidPrefixTest) {
diff --git a/icing/schema-builder.h b/icing/schema-builder.h
index c74505e..c702afe 100644
--- a/icing/schema-builder.h
+++ b/icing/schema-builder.h
@@ -56,6 +56,13 @@ constexpr IntegerIndexingConfig::NumericMatchType::Code NUMERIC_MATCH_UNKNOWN =
constexpr IntegerIndexingConfig::NumericMatchType::Code NUMERIC_MATCH_RANGE =
IntegerIndexingConfig::NumericMatchType::RANGE;
+constexpr EmbeddingIndexingConfig::EmbeddingIndexingType::Code
+ EMBEDDING_INDEXING_UNKNOWN =
+ EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN;
+constexpr EmbeddingIndexingConfig::EmbeddingIndexingType::Code
+ EMBEDDING_INDEXING_LINEAR_SEARCH =
+ EmbeddingIndexingConfig::EmbeddingIndexingType::LINEAR_SEARCH;
+
constexpr PropertyConfigProto::DataType::Code TYPE_UNKNOWN =
PropertyConfigProto::DataType::UNKNOWN;
constexpr PropertyConfigProto::DataType::Code TYPE_STRING =
@@ -70,6 +77,8 @@ constexpr PropertyConfigProto::DataType::Code TYPE_BYTES =
PropertyConfigProto::DataType::BYTES;
constexpr PropertyConfigProto::DataType::Code TYPE_DOCUMENT =
PropertyConfigProto::DataType::DOCUMENT;
+constexpr PropertyConfigProto::DataType::Code TYPE_VECTOR =
+ PropertyConfigProto::DataType::VECTOR;
constexpr JoinableConfig::ValueType::Code JOINABLE_VALUE_TYPE_NONE =
JoinableConfig::ValueType::NONE;
@@ -146,6 +155,15 @@ class PropertyConfigBuilder {
return *this;
}
+ PropertyConfigBuilder& SetDataTypeVector(
+ EmbeddingIndexingConfig::EmbeddingIndexingType::Code
+ embedding_indexing_type) {
+ property_.set_data_type(PropertyConfigProto::DataType::VECTOR);
+ property_.mutable_embedding_indexing_config()->set_embedding_indexing_type(
+ embedding_indexing_type);
+ return *this;
+ }
+
PropertyConfigBuilder& SetJoinable(
JoinableConfig::ValueType::Code join_value_type, bool propagate_delete) {
property_.mutable_joinable_config()->set_value_type(join_value_type);
@@ -159,6 +177,11 @@ class PropertyConfigBuilder {
return *this;
}
+ PropertyConfigBuilder& SetDescription(std::string description) {
+ property_.set_description(std::move(description));
+ return *this;
+ }
+
PropertyConfigProto Build() const { return std::move(property_); }
private:
@@ -186,6 +209,11 @@ class SchemaTypeConfigBuilder {
return *this;
}
+ SchemaTypeConfigBuilder& SetDescription(std::string description) {
+ type_config_.set_description(std::move(description));
+ return *this;
+ }
+
SchemaTypeConfigBuilder& AddProperty(PropertyConfigProto property) {
*type_config_.add_properties() = std::move(property);
return *this;
diff --git a/icing/schema/property-util.cc b/icing/schema/property-util.cc
index 67ff748..7c86122 100644
--- a/icing/schema/property-util.cc
+++ b/icing/schema/property-util.cc
@@ -14,6 +14,8 @@
#include "icing/schema/property-util.h"
+#include <cstddef>
+#include <cstdint>
#include <string>
#include <string_view>
#include <vector>
@@ -131,6 +133,14 @@ ExtractPropertyValues<int64_t>(const PropertyProto& property) {
property.int64_values().end());
}
+template <>
+libtextclassifier3::StatusOr<std::vector<PropertyProto::VectorProto>>
+ExtractPropertyValues<PropertyProto::VectorProto>(
+ const PropertyProto& property) {
+ return std::vector<PropertyProto::VectorProto>(
+ property.vector_values().begin(), property.vector_values().end());
+}
+
} // namespace property_util
} // namespace lib
diff --git a/icing/schema/property-util.h b/icing/schema/property-util.h
index 7557879..c409a9d 100644
--- a/icing/schema/property-util.h
+++ b/icing/schema/property-util.h
@@ -15,8 +15,11 @@
#ifndef ICING_SCHEMA_PROPERTY_UTIL_H_
#define ICING_SCHEMA_PROPERTY_UTIL_H_
+#include <cstddef>
+#include <cstdint>
#include <string>
#include <string_view>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
@@ -163,6 +166,11 @@ template <>
libtextclassifier3::StatusOr<std::vector<int64_t>>
ExtractPropertyValues<int64_t>(const PropertyProto& property);
+template <>
+libtextclassifier3::StatusOr<std::vector<PropertyProto::VectorProto>>
+ExtractPropertyValues<PropertyProto::VectorProto>(
+ const PropertyProto& property);
+
template <typename T>
libtextclassifier3::StatusOr<std::vector<T>> ExtractPropertyValuesFromDocument(
const DocumentProto& document, std::string_view property_path) {
diff --git a/icing/schema/property-util_test.cc b/icing/schema/property-util_test.cc
index eddcc84..5e8a430 100644
--- a/icing/schema/property-util_test.cc
+++ b/icing/schema/property-util_test.cc
@@ -22,6 +22,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
+#include "icing/portable/equals-proto.h"
#include "icing/proto/document.pb.h"
#include "icing/testing/common-matchers.h"
@@ -30,6 +31,7 @@ namespace lib {
namespace {
+using ::icing::lib::portable_equals_proto::EqualsProto;
using ::testing::ElementsAre;
using ::testing::IsEmpty;
@@ -38,6 +40,8 @@ static constexpr std::string_view kPropertySingleString = "singleString";
static constexpr std::string_view kPropertyRepeatedString = "repeatedString";
static constexpr std::string_view kPropertySingleInteger = "singleInteger";
static constexpr std::string_view kPropertyRepeatedInteger = "repeatedInteger";
+static constexpr std::string_view kPropertySingleVector = "singleVector";
+static constexpr std::string_view kPropertyRepeatedVector = "repeatedVector";
static constexpr std::string_view kTypeNestedTest = "NestedTest";
static constexpr std::string_view kPropertyStr = "str";
@@ -83,6 +87,28 @@ TEST(PropertyUtilTest, ExtractPropertyValuesTypeInteger) {
IsOkAndHolds(ElementsAre(123, -456, 0)));
}
+TEST(PropertyUtilTest, ExtractPropertyValuesTypeVector) {
+ PropertyProto::VectorProto vector1;
+ vector1.set_model_signature("my_model1");
+ vector1.add_values(1.0f);
+ vector1.add_values(2.0f);
+
+ PropertyProto::VectorProto vector2;
+ vector2.set_model_signature("my_model2");
+ vector2.add_values(-1.0f);
+ vector2.add_values(-2.0f);
+ vector2.add_values(-3.0f);
+
+ PropertyProto property;
+ *property.mutable_vector_values()->Add() = vector1;
+ *property.mutable_vector_values()->Add() = vector2;
+
+ EXPECT_THAT(
+ property_util::ExtractPropertyValues<PropertyProto::VectorProto>(
+ property),
+ IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2))));
+}
+
TEST(PropertyUtilTest, ExtractPropertyValuesMismatchedType) {
PropertyProto property;
property.mutable_int64_values()->Add(123);
@@ -110,6 +136,16 @@ TEST(PropertyUtilTest, ExtractPropertyValuesTypeUnimplemented) {
}
TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) {
+ PropertyProto::VectorProto vector1;
+ vector1.set_model_signature("my_model1");
+ vector1.add_values(1.0f);
+ vector1.add_values(2.0f);
+ PropertyProto::VectorProto vector2;
+ vector2.set_model_signature("my_model2");
+ vector2.add_values(-1.0f);
+ vector2.add_values(-2.0f);
+ vector2.add_values(-3.0f);
+
DocumentProto document =
DocumentBuilder()
.SetKey("icing", "test/1")
@@ -119,6 +155,9 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) {
"repeated2", "repeated3")
.AddInt64Property(std::string(kPropertySingleInteger), 123)
.AddInt64Property(std::string(kPropertyRepeatedInteger), 1, 2, 3)
+ .AddVectorProperty(std::string(kPropertySingleVector), vector1)
+ .AddVectorProperty(std::string(kPropertyRepeatedVector), vector1,
+ vector2)
.Build();
// Single string
@@ -139,9 +178,33 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) {
EXPECT_THAT(property_util::ExtractPropertyValuesFromDocument<int64_t>(
document, /*property_path=*/kPropertyRepeatedInteger),
IsOkAndHolds(ElementsAre(1, 2, 3)));
+ // Single vector
+ EXPECT_THAT(property_util::ExtractPropertyValuesFromDocument<
+ PropertyProto::VectorProto>(
+ document, /*property_path=*/kPropertySingleVector),
+ IsOkAndHolds(ElementsAre(EqualsProto(vector1))));
+ // Repeated vector
+ EXPECT_THAT(
+ property_util::ExtractPropertyValuesFromDocument<
+ PropertyProto::VectorProto>(
+ document, /*property_path=*/kPropertyRepeatedVector),
+ IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2))));
}
TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) {
+ PropertyProto::VectorProto vector1;
+ vector1.set_model_signature("my_model1");
+ vector1.add_values(1.0f);
+ vector1.add_values(2.0f);
+ PropertyProto::VectorProto vector2;
+ vector2.set_model_signature("my_model2");
+ vector2.add_values(-1.0f);
+ vector2.add_values(-2.0f);
+ vector2.add_values(-3.0f);
+ PropertyProto::VectorProto vector3;
+ vector3.set_model_signature("my_model3");
+ vector3.add_values(1.0f);
+
DocumentProto nested_document =
DocumentBuilder()
.SetKey("icing", "nested/1")
@@ -158,6 +221,10 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) {
.AddInt64Property(std::string(kPropertySingleInteger), 123)
.AddInt64Property(std::string(kPropertyRepeatedInteger), 1, 2,
3)
+ .AddVectorProperty(std::string(kPropertySingleVector),
+ vector1)
+ .AddVectorProperty(std::string(kPropertyRepeatedVector),
+ vector1, vector2)
.Build(),
DocumentBuilder()
.SetSchema(std::string(kTypeTest))
@@ -168,6 +235,10 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) {
.AddInt64Property(std::string(kPropertySingleInteger), 456)
.AddInt64Property(std::string(kPropertyRepeatedInteger), 4, 5,
6)
+ .AddVectorProperty(std::string(kPropertySingleVector),
+ vector2)
+ .AddVectorProperty(std::string(kPropertyRepeatedVector),
+ vector2, vector3)
.Build())
.Build();
@@ -189,6 +260,17 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) {
property_util::ExtractPropertyValuesFromDocument<int64_t>(
nested_document, /*property_path=*/"nestedDocument.repeatedInteger"),
IsOkAndHolds(ElementsAre(1, 2, 3, 4, 5, 6)));
+ EXPECT_THAT(
+ property_util::ExtractPropertyValuesFromDocument<
+ PropertyProto::VectorProto>(
+ nested_document, /*property_path=*/"nestedDocument.singleVector"),
+ IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2))));
+ EXPECT_THAT(
+ property_util::ExtractPropertyValuesFromDocument<
+ PropertyProto::VectorProto>(
+ nested_document, /*property_path=*/"nestedDocument.repeatedVector"),
+ IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2),
+ EqualsProto(vector2), EqualsProto(vector3))));
// Test the property at first level
EXPECT_THAT(
diff --git a/icing/schema/schema-store_test.cc b/icing/schema/schema-store_test.cc
index 8cc7008..ca5cdd3 100644
--- a/icing/schema/schema-store_test.cc
+++ b/icing/schema/schema-store_test.cc
@@ -14,8 +14,10 @@
#include "icing/schema/schema-store.h"
+#include <cstdint>
#include <memory>
#include <string>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
@@ -23,6 +25,7 @@
#include "gtest/gtest.h"
#include "icing/absl_ports/str_cat.h"
#include "icing/document-builder.h"
+#include "icing/file/file-backed-proto.h"
#include "icing/file/filesystem.h"
#include "icing/file/mock-filesystem.h"
#include "icing/file/version-util.h"
@@ -32,10 +35,7 @@
#include "icing/proto/logging.pb.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/storage.pb.h"
-#include "icing/proto/term.pb.h"
#include "icing/schema-builder.h"
-#include "icing/schema/schema-util.h"
-#include "icing/schema/section-manager.h"
#include "icing/schema/section.h"
#include "icing/store/document-filter-data.h"
#include "icing/testing/common-matchers.h"
@@ -142,7 +142,7 @@ TEST_F(SchemaStoreTest, SchemaStoreMoveConstructible) {
IsOkAndHolds(Eq(expected_checksum)));
SectionMetadata expected_metadata(/*id_in=*/0, TYPE_STRING, TOKENIZER_PLAIN,
TERM_MATCH_EXACT, NUMERIC_MATCH_UNKNOWN,
- "prop1");
+ EMBEDDING_INDEXING_UNKNOWN, "prop1");
EXPECT_THAT(move_constructed_schema_store.GetSectionMetadata("type_a"),
IsOkAndHolds(Pointee(ElementsAre(expected_metadata))));
}
@@ -193,7 +193,7 @@ TEST_F(SchemaStoreTest, SchemaStoreMoveAssignment) {
IsOkAndHolds(Eq(expected_checksum)));
SectionMetadata expected_metadata(/*id_in=*/0, TYPE_STRING, TOKENIZER_PLAIN,
TERM_MATCH_EXACT, NUMERIC_MATCH_UNKNOWN,
- "prop1");
+ EMBEDDING_INDEXING_UNKNOWN, "prop1");
EXPECT_THAT(move_assigned_schema_store->GetSectionMetadata("type_a"),
IsOkAndHolds(Pointee(ElementsAre(expected_metadata))));
}
@@ -1527,10 +1527,10 @@ TEST_F(SchemaStoreTest, SetSchemaRegenerateDerivedFilesFailure) {
.Build();
SectionMetadata expected_int_prop1_metadata(
/*id_in=*/0, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
- NUMERIC_MATCH_RANGE, "intProp1");
+ NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN, "intProp1");
SectionMetadata expected_string_prop1_metadata(
/*id_in=*/1, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT,
- NUMERIC_MATCH_UNKNOWN, "stringProp1");
+ NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN, "stringProp1");
ICING_ASSERT_OK_AND_ASSIGN(SectionGroup section_group,
schema_store->ExtractSections(document));
ASSERT_THAT(section_group.string_sections, SizeIs(1));
diff --git a/icing/schema/schema-util.cc b/icing/schema/schema-util.cc
index 72287a8..976d1b7 100644
--- a/icing/schema/schema-util.cc
+++ b/icing/schema/schema-util.cc
@@ -115,6 +115,12 @@ bool IsIntegerNumericMatchTypeCompatible(
return old_indexed.numeric_match_type() == new_indexed.numeric_match_type();
}
+bool IsEmbeddingIndexingCompatible(const EmbeddingIndexingConfig& old_indexed,
+ const EmbeddingIndexingConfig& new_indexed) {
+ return old_indexed.embedding_indexing_type() ==
+ new_indexed.embedding_indexing_type();
+}
+
bool IsDocumentIndexingCompatible(const DocumentIndexingConfig& old_indexed,
const DocumentIndexingConfig& new_indexed) {
// TODO(b/265304217): This could mark the new schema as incompatible and
@@ -824,6 +830,10 @@ libtextclassifier3::Status SchemaUtil::ValidateDocumentIndexingConfig(
!property_config.document_indexing_config()
.indexable_nested_properties_list()
.empty();
+ case PropertyConfigProto::DataType::VECTOR:
+ return property_config.embedding_indexing_config()
+ .embedding_indexing_type() !=
+ EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN;
case PropertyConfigProto::DataType::UNKNOWN:
case PropertyConfigProto::DataType::DOUBLE:
case PropertyConfigProto::DataType::BOOLEAN:
@@ -1087,7 +1097,10 @@ const SchemaUtil::SchemaDelta SchemaUtil::ComputeCompatibilityDelta(
new_property_config->integer_indexing_config()) ||
!IsDocumentIndexingCompatible(
old_property_config.document_indexing_config(),
- new_property_config->document_indexing_config())) {
+ new_property_config->document_indexing_config()) ||
+ !IsEmbeddingIndexingCompatible(
+ old_property_config.embedding_indexing_config(),
+ new_property_config->embedding_indexing_config())) {
is_index_incompatible = true;
}
diff --git a/icing/schema/schema-util_test.cc b/icing/schema/schema-util_test.cc
index 82683ba..e77ffab 100644
--- a/icing/schema/schema-util_test.cc
+++ b/icing/schema/schema-util_test.cc
@@ -2900,6 +2900,118 @@ TEST_P(SchemaUtilTest,
IsEmpty());
}
+TEST_P(SchemaUtilTest, ChangingIndexedVectorPropertiesMakesIndexIncompatible) {
+ SchemaProto schema_with_indexed_property =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ SchemaProto schema_with_unindexed_property =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeVector(EMBEDDING_INDEXING_UNKNOWN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ SchemaUtil::SchemaDelta schema_delta;
+ schema_delta.schema_types_index_incompatible.insert(kPersonType);
+
+ // New schema gained a new indexed vector property.
+ SchemaUtil::DependentMap no_dependents_map;
+ EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(
+ schema_with_unindexed_property, schema_with_indexed_property,
+ no_dependents_map),
+ Eq(schema_delta));
+
+ // New schema lost an indexed vector property.
+ EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(
+ schema_with_indexed_property, schema_with_unindexed_property,
+ no_dependents_map),
+ Eq(schema_delta));
+}
+
+TEST_P(SchemaUtilTest, AddingNewIndexedVectorPropertyMakesIndexIncompatible) {
+ // Configure old schema
+ SchemaProto old_schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ // Configure new schema
+ SchemaProto new_schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("NewEmbeddingProperty")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ SchemaUtil::SchemaDelta schema_delta;
+ schema_delta.schema_types_index_incompatible.insert(kPersonType);
+ SchemaUtil::DependentMap no_dependents_map;
+ EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(old_schema, new_schema,
+ no_dependents_map),
+ Eq(schema_delta));
+}
+
+TEST_P(SchemaUtilTest,
+ AddingNewNonIndexedVectorPropertyShouldRemainIndexCompatible) {
+ // Configure old schema
+ SchemaProto old_schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ // Configure new schema
+ SchemaProto new_schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName("NewProperty")
+ .SetDataTypeVector(EMBEDDING_INDEXING_UNKNOWN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ SchemaUtil::DependentMap no_dependents_map;
+ EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(old_schema, new_schema,
+ no_dependents_map)
+ .schema_types_index_incompatible,
+ IsEmpty());
+}
+
TEST_P(SchemaUtilTest,
AddingNewIndexedDocumentPropertyMakesIndexAndJoinIncompatible) {
SchemaTypeConfigProto nested_schema =
diff --git a/icing/schema/section-manager.cc b/icing/schema/section-manager.cc
index 3d540d6..8689bf2 100644
--- a/icing/schema/section-manager.cc
+++ b/icing/schema/section-manager.cc
@@ -61,6 +61,7 @@ libtextclassifier3::Status AppendNewSectionMetadata(
property_config.string_indexing_config().tokenizer_type(),
property_config.string_indexing_config().term_match_type(),
property_config.integer_indexing_config().numeric_match_type(),
+ property_config.embedding_indexing_config().embedding_indexing_type(),
std::move(concatenated_path)));
return libtextclassifier3::Status::OK;
}
@@ -162,6 +163,19 @@ libtextclassifier3::StatusOr<SectionGroup> SectionManager::ExtractSections(
section_group.integer_sections);
break;
}
+ case PropertyConfigProto::DataType::VECTOR: {
+ if (section_metadata.embedding_indexing_type ==
+ EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN) {
+ // Skip if embedding indexing type is UNKNOWN.
+ break;
+ }
+ AppendSection(
+ section_metadata,
+ property_util::ExtractPropertyValuesFromDocument<
+ PropertyProto::VectorProto>(document, section_metadata.path),
+ section_group.vector_sections);
+ break;
+ }
default: {
// Skip other data types.
break;
diff --git a/icing/schema/section-manager_test.cc b/icing/schema/section-manager_test.cc
index eee78e9..b735fb1 100644
--- a/icing/schema/section-manager_test.cc
+++ b/icing/schema/section-manager_test.cc
@@ -14,10 +14,12 @@
#include "icing/schema/section-manager.h"
+#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
@@ -27,6 +29,8 @@
#include "icing/schema-builder.h"
#include "icing/schema/schema-type-manager.h"
#include "icing/schema/schema-util.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-filter-data.h"
#include "icing/store/dynamic-trie-key-mapper.h"
#include "icing/store/key-mapper.h"
#include "icing/testing/common-matchers.h"
@@ -37,6 +41,7 @@ namespace lib {
namespace {
+using ::icing::lib::portable_equals_proto::EqualsProto;
using ::testing::ElementsAre;
using ::testing::IsEmpty;
using ::testing::Pointee;
@@ -48,6 +53,8 @@ static constexpr std::string_view kTypeEmail = "Email";
static constexpr std::string_view kPropertyRecipientIds = "recipientIds";
static constexpr std::string_view kPropertyRecipients = "recipients";
static constexpr std::string_view kPropertySubject = "subject";
+static constexpr std::string_view kPropertySubjectEmbedding =
+ "subjectEmbedding";
static constexpr std::string_view kPropertyTimestamp = "timestamp";
// non-indexable
static constexpr std::string_view kPropertyAttachment = "attachment";
@@ -109,6 +116,14 @@ PropertyConfigProto CreateSubjectPropertyConfig() {
.Build();
}
+PropertyConfigProto CreateSubjectEmbeddingPropertyConfig() {
+ return PropertyConfigBuilder()
+ .SetName(kPropertySubjectEmbedding)
+ .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL)
+ .Build();
+}
+
PropertyConfigProto CreateTimestampPropertyConfig() {
return PropertyConfigBuilder()
.SetName(kPropertyTimestamp)
@@ -145,6 +160,7 @@ SchemaTypeConfigProto CreateEmailTypeConfig() {
return SchemaTypeConfigBuilder()
.SetType(kTypeEmail)
.AddProperty(CreateSubjectPropertyConfig())
+ .AddProperty(CreateSubjectEmbeddingPropertyConfig())
.AddProperty(PropertyConfigBuilder()
.SetName(kPropertyText)
.SetDataTypeString(TERM_MATCH_UNKNOWN, TOKENIZER_NONE)
@@ -220,11 +236,19 @@ class SectionManagerTest : public ::testing::Test {
ICING_ASSERT_OK(schema_type_mapper_->Put(kTypeConversation, 1));
ICING_ASSERT_OK(schema_type_mapper_->Put(kTypeGroup, 2));
+ email_subject_embedding_ = PropertyProto::VectorProto();
+ email_subject_embedding_.add_values(1.0);
+ email_subject_embedding_.add_values(2.0);
+ email_subject_embedding_.add_values(3.0);
+ email_subject_embedding_.set_model_signature("my_model");
+
email_document_ =
DocumentBuilder()
.SetKey("icing", "email/1")
.SetSchema(std::string(kTypeEmail))
.AddStringProperty(std::string(kPropertySubject), "the subject")
+ .AddVectorProperty(std::string(kPropertySubjectEmbedding),
+ email_subject_embedding_)
.AddStringProperty(std::string(kPropertyText), "the text")
.AddBytesProperty(std::string(kPropertyAttachment),
"attachment bytes")
@@ -265,6 +289,7 @@ class SectionManagerTest : public ::testing::Test {
SchemaUtil::TypeConfigMap type_config_map_;
std::unique_ptr<KeyMapper<SchemaTypeId>> schema_type_mapper_;
+ PropertyProto::VectorProto email_subject_embedding_;
DocumentProto email_document_;
DocumentProto conversation_document_;
DocumentProto group_document_;
@@ -308,11 +333,21 @@ TEST_F(SectionManagerTest, ExtractSections) {
EXPECT_THAT(section_group.integer_sections[0].content, ElementsAre(1, 2, 3));
EXPECT_THAT(section_group.integer_sections[1].metadata,
- EqualsSectionMetadata(/*expected_id=*/3,
+ EqualsSectionMetadata(/*expected_id=*/4,
/*expected_property_path=*/"timestamp",
CreateTimestampPropertyConfig()));
EXPECT_THAT(section_group.integer_sections[1].content,
ElementsAre(kDefaultTimestamp));
+
+ // Vector sections
+ EXPECT_THAT(section_group.vector_sections, SizeIs(1));
+ EXPECT_THAT(
+ section_group.vector_sections[0].metadata,
+ EqualsSectionMetadata(/*expected_id=*/3,
+ /*expected_property_path=*/"subjectEmbedding",
+ CreateSubjectEmbeddingPropertyConfig()));
+ EXPECT_THAT(section_group.vector_sections[0].content,
+ ElementsAre(EqualsProto(email_subject_embedding_)));
}
TEST_F(SectionManagerTest, ExtractSectionsNested) {
@@ -359,11 +394,22 @@ TEST_F(SectionManagerTest, ExtractSectionsNested) {
EXPECT_THAT(
section_group.integer_sections[1].metadata,
- EqualsSectionMetadata(/*expected_id=*/3,
+ EqualsSectionMetadata(/*expected_id=*/4,
/*expected_property_path=*/"emails.timestamp",
CreateTimestampPropertyConfig()));
EXPECT_THAT(section_group.integer_sections[1].content,
ElementsAre(kDefaultTimestamp, kDefaultTimestamp));
+
+ // Vector sections
+ EXPECT_THAT(section_group.vector_sections, SizeIs(1));
+ EXPECT_THAT(section_group.vector_sections[0].metadata,
+ EqualsSectionMetadata(
+ /*expected_id=*/3,
+ /*expected_property_path=*/"emails.subjectEmbedding",
+ CreateSubjectEmbeddingPropertyConfig()));
+ EXPECT_THAT(section_group.vector_sections[0].content,
+ ElementsAre(EqualsProto(email_subject_embedding_),
+ EqualsProto(email_subject_embedding_)));
}
TEST_F(SectionManagerTest, ExtractSectionsIndexableNestedPropertiesList) {
@@ -461,7 +507,8 @@ TEST_F(SectionManagerTest, GetSectionMetadata) {
// 0 -> recipientIds
// 1 -> recipients
// 2 -> subject
- // 3 -> timestamp
+ // 3 -> subjectEmbedding
+ // 4 -> timestamp
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
/*schema_type_id=*/0, /*section_id=*/0),
IsOkAndHolds(Pointee(EqualsSectionMetadata(
@@ -472,13 +519,30 @@ TEST_F(SectionManagerTest, GetSectionMetadata) {
IsOkAndHolds(Pointee(EqualsSectionMetadata(
/*expected_id=*/1, /*expected_property_path=*/"recipients",
CreateRecipientsPropertyConfig()))));
+ EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
+ /*schema_type_id=*/0, /*section_id=*/2),
+ IsOkAndHolds(Pointee(EqualsSectionMetadata(
+ /*expected_id=*/2, /*expected_property_path=*/"subject",
+ CreateSubjectPropertyConfig()))));
+ EXPECT_THAT(
+ schema_type_manager->section_manager().GetSectionMetadata(
+ /*schema_type_id=*/0, /*section_id=*/3),
+ IsOkAndHolds(Pointee(EqualsSectionMetadata(
+ /*expected_id=*/3, /*expected_property_path=*/"subjectEmbedding",
+ CreateSubjectEmbeddingPropertyConfig()))));
+ EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
+ /*schema_type_id=*/0, /*section_id=*/4),
+ IsOkAndHolds(Pointee(EqualsSectionMetadata(
+ /*expected_id=*/4, /*expected_property_path=*/"timestamp",
+ CreateTimestampPropertyConfig()))));
// Conversation (section id -> section property path):
// 0 -> emails.recipientIds
// 1 -> emails.recipients
// 2 -> emails.subject
- // 3 -> emails.timestamp
- // 4 -> name
+ // 3 -> emails.subjectEmbedding
+ // 4 -> emails.timestamp
+ // 5 -> name
EXPECT_THAT(
schema_type_manager->section_manager().GetSectionMetadata(
/*schema_type_id=*/1, /*section_id=*/0),
@@ -497,16 +561,22 @@ TEST_F(SectionManagerTest, GetSectionMetadata) {
IsOkAndHolds(Pointee(EqualsSectionMetadata(
/*expected_id=*/2, /*expected_property_path=*/"emails.subject",
CreateSubjectPropertyConfig()))));
+ EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
+ /*schema_type_id=*/1, /*section_id=*/3),
+ IsOkAndHolds(Pointee(EqualsSectionMetadata(
+ /*expected_id=*/3,
+ /*expected_property_path=*/"emails.subjectEmbedding",
+ CreateSubjectEmbeddingPropertyConfig()))));
EXPECT_THAT(
schema_type_manager->section_manager().GetSectionMetadata(
- /*schema_type_id=*/1, /*section_id=*/3),
+ /*schema_type_id=*/1, /*section_id=*/4),
IsOkAndHolds(Pointee(EqualsSectionMetadata(
- /*expected_id=*/3, /*expected_property_path=*/"emails.timestamp",
+ /*expected_id=*/4, /*expected_property_path=*/"emails.timestamp",
CreateTimestampPropertyConfig()))));
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
- /*schema_type_id=*/1, /*section_id=*/4),
+ /*schema_type_id=*/1, /*section_id=*/5),
IsOkAndHolds(Pointee(EqualsSectionMetadata(
- /*expected_id=*/4, /*expected_property_path=*/"name",
+ /*expected_id=*/5, /*expected_property_path=*/"name",
CreateNamePropertyConfig()))));
// Group (section id -> section property path):
@@ -615,25 +685,27 @@ TEST_F(SectionManagerTest, GetSectionMetadataInvalidSectionId) {
// 0 -> recipientIds
// 1 -> recipients
// 2 -> subject
- // 3 -> timestamp
+ // 3 -> subjectEmbedding
+ // 4 -> timestamp
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
/*schema_type_id=*/0, /*section_id=*/-1),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
- /*schema_type_id=*/0, /*section_id=*/4),
+ /*schema_type_id=*/0, /*section_id=*/5),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
// Conversation (section id -> section property path):
// 0 -> emails.recipientIds
// 1 -> emails.recipients
// 2 -> emails.subject
- // 3 -> emails.timestamp
- // 4 -> name
+ // 3 -> emails.subjectEmbedding
+ // 4 -> emails.timestamp
+ // 5 -> name
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
/*schema_type_id=*/1, /*section_id=*/-1),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
- /*schema_type_id=*/1, /*section_id=*/5),
+ /*schema_type_id=*/1, /*section_id=*/6),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
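The renumbering above follows from inserting the subjectEmbedding section: the sections that come after it (timestamp moves from 3 to 4, name from 4 to 5, and likewise for the nested paths), which is also why the first invalid section id grows by one for both types. A small sketch of looking up the new vector section's metadata with the same accessor used in the test, with the values the assertions above expect:

    // Sketch: the Email type (schema_type_id 0) now exposes the embedding
    // section at id 3, between "subject" (2) and "timestamp" (4).
    ICING_ASSERT_OK_AND_ASSIGN(
        const SectionMetadata* metadata,
        schema_type_manager->section_manager().GetSectionMetadata(
            /*schema_type_id=*/0, /*section_id=*/3));
    // metadata->path == "subjectEmbedding"
    // metadata->embedding_indexing_type == LINEAR_SEARCH
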
diff --git a/icing/schema/section.h b/icing/schema/section.h
index 3685a29..76e9e57 100644
--- a/icing/schema/section.h
+++ b/icing/schema/section.h
@@ -21,6 +21,7 @@
#include <utility>
#include <vector>
+#include "icing/proto/document.pb.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/term.pb.h"
@@ -89,18 +90,31 @@ struct SectionMetadata {
// Contents will be matched by a range query.
IntegerIndexingConfig::NumericMatchType::Code numeric_match_type;
+ // How vectors in a vector section should be indexed.
+ //
+ // EmbeddingIndexingType::UNKNOWN:
+ // Contents will not be indexed. It is invalid for a vector section
+ // (data_type == 'VECTOR') to have embedding_indexing_type == 'UNKNOWN'.
+ //
+ // EmbeddingIndexingType::LINEAR_SEARCH:
+ // Contents will be indexed for linear search.
+ EmbeddingIndexingConfig::EmbeddingIndexingType::Code embedding_indexing_type;
+
explicit SectionMetadata(
SectionId id_in, PropertyConfigProto::DataType::Code data_type_in,
StringIndexingConfig::TokenizerType::Code tokenizer,
TermMatchType::Code term_match_type_in,
IntegerIndexingConfig::NumericMatchType::Code numeric_match_type_in,
+ EmbeddingIndexingConfig::EmbeddingIndexingType::Code
+ embedding_indexing_type_in,
std::string&& path_in)
: path(std::move(path_in)),
id(id_in),
data_type(data_type_in),
tokenizer(tokenizer),
term_match_type(term_match_type_in),
- numeric_match_type(numeric_match_type_in) {}
+ numeric_match_type(numeric_match_type_in),
+ embedding_indexing_type(embedding_indexing_type_in) {}
SectionMetadata(const SectionMetadata& other) = default;
SectionMetadata& operator=(const SectionMetadata& other) = default;
@@ -144,6 +158,7 @@ struct Section {
struct SectionGroup {
std::vector<Section<std::string_view>> string_sections;
std::vector<Section<int64_t>> integer_sections;
+ std::vector<Section<PropertyProto::VectorProto>> vector_sections;
};
} // namespace lib
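With this change, SectionMetadata records how a vector section is indexed and SectionGroup carries the extracted vectors next to the string and integer sections. A sketch of constructing metadata for an embedding section with the extended constructor; the vector-specific arguments match the enums documented above, while the tokenizer, term-match and numeric-match values are placeholders that a VECTOR section does not use (their exact spellings are an assumption here):

    // Sketch only; the non-vector arguments are irrelevant placeholders.
    SectionMetadata metadata(
        /*id_in=*/3, PropertyConfigProto::DataType::VECTOR,
        StringIndexingConfig::TokenizerType::NONE,
        TermMatchType::UNKNOWN,
        IntegerIndexingConfig::NumericMatchType::UNKNOWN,
        EmbeddingIndexingConfig::EmbeddingIndexingType::LINEAR_SEARCH,
        /*path_in=*/"subjectEmbedding");
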
diff --git a/icing/scoring/advanced_scoring/advanced-scorer.cc b/icing/scoring/advanced_scoring/advanced-scorer.cc
index 83c1519..e375a8e 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer.cc
+++ b/icing/scoring/advanced_scoring/advanced-scorer.cc
@@ -14,14 +14,25 @@
#include "icing/scoring/advanced_scoring/advanced-scorer.h"
+#include <cstdint>
#include <memory>
+#include <utility>
+#include <vector>
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/join/join-children-fetcher.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
#include "icing/query/advanced_query_parser/lexer.h"
#include "icing/query/advanced_query_parser/parser.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/score-expression.h"
#include "icing/scoring/advanced_scoring/scoring-visitor.h"
#include "icing/scoring/bm25f-calculator.h"
#include "icing/scoring/section-weights.h"
+#include "icing/store/document-store.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -29,11 +40,15 @@ namespace lib {
libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>>
AdvancedScorer::Create(const ScoringSpecProto& scoring_spec,
double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store,
const SchemaStore* schema_store, int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher) {
+ const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results) {
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
+ ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
Lexer lexer(scoring_spec.advanced_scoring_expression(),
Lexer::Language::SCORING);
@@ -48,9 +63,10 @@ AdvancedScorer::Create(const ScoringSpecProto& scoring_spec,
std::unique_ptr<Bm25fCalculator> bm25f_calculator =
std::make_unique<Bm25fCalculator>(document_store, section_weights.get(),
current_time_ms);
- ScoringVisitor visitor(default_score, document_store, schema_store,
- section_weights.get(), bm25f_calculator.get(),
- join_children_fetcher, current_time_ms);
+ ScoringVisitor visitor(default_score, default_semantic_metric_type,
+ document_store, schema_store, section_weights.get(),
+ bm25f_calculator.get(), join_children_fetcher,
+ embedding_query_results, current_time_ms);
tree_root->Accept(&visitor);
ICING_ASSIGN_OR_RETURN(std::unique_ptr<ScoreExpression> expression,
diff --git a/icing/scoring/advanced_scoring/advanced-scorer.h b/icing/scoring/advanced_scoring/advanced-scorer.h
index d69abad..00477a3 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer.h
+++ b/icing/scoring/advanced_scoring/advanced-scorer.h
@@ -15,17 +15,26 @@
#ifndef ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_
#define ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_
+#include <cstdint>
#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/score-expression.h"
#include "icing/scoring/bm25f-calculator.h"
#include "icing/scoring/scorer.h"
+#include "icing/scoring/section-weights.h"
#include "icing/store/document-store.h"
+#include "icing/util/logging.h"
namespace icing {
namespace lib {
@@ -38,9 +47,11 @@ class AdvancedScorer : public Scorer {
// INVALID_ARGUMENT if fails to create an instance
static libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> Create(
const ScoringSpecProto& scoring_spec, double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store, const SchemaStore* schema_store,
- int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher = nullptr);
+ int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results);
double GetScore(const DocHitInfo& hit_info,
const DocHitInfoIterator* query_it) override {
@@ -63,7 +74,7 @@ class AdvancedScorer : public Scorer {
bm25f_calculator_->PrepareToScore(query_term_iterators);
}
- bool is_constant() const { return score_expression_->is_constant_double(); }
+ bool is_constant() const { return score_expression_->is_constant(); }
private:
explicit AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression,
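AdvancedScorer::Create now takes two additional required arguments, a default semantic metric type and a non-null EmbeddingQueryResults pointer, and join_children_fetcher is no longer defaulted, so every call site below is updated to pass all eight arguments explicitly. A sketch of the new call shape, assuming the scoring spec, stores and clock come from the caller as in the tests:

    // Sketch: callers without embedding queries still pass a pointer to an
    // empty EmbeddingQueryResults rather than nullptr.
    EmbeddingQueryResults empty_embedding_query_results;
    ICING_ASSIGN_OR_RETURN(
        std::unique_ptr<AdvancedScorer> scorer,
        AdvancedScorer::Create(
            scoring_spec, /*default_score=*/0,
            SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
            document_store, schema_store,
            clock.GetSystemTimeMilliseconds(),
            /*join_children_fetcher=*/nullptr,
            &empty_embedding_query_results));
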
diff --git a/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc
index 3612359..ca081cd 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc
+++ b/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc
@@ -12,11 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <cstddef>
#include <cstdint>
#include <memory>
+#include <string>
#include <string_view>
+#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/advanced-scorer.h"
+#include "icing/store/document-store.h"
#include "icing/testing/fake-clock.h"
#include "icing/testing/tmp-directory.h"
@@ -29,6 +36,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
const std::string test_dir = GetTestTempDir() + "/icing";
const std::string doc_store_dir = test_dir + "/doc_store";
const std::string schema_store_dir = test_dir + "/schema_store";
+ EmbeddingQueryResults empty_embedding_query_results_;
filesystem.DeleteDirectoryRecursively(test_dir.c_str());
filesystem.CreateDirectoryRecursively(doc_store_dir.c_str());
filesystem.CreateDirectoryRecursively(schema_store_dir.c_str());
@@ -54,9 +62,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
scoring_spec.set_advanced_scoring_expression(text);
AdvancedScorer::Create(scoring_spec,
- /*default_score=*/10, document_store.get(),
- schema_store.get(),
- fake_clock.GetSystemTimeMilliseconds());
+ /*default_score=*/10,
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store.get(), schema_store.get(),
+ fake_clock.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_);
// Not able to test the GetScore method of AdvancedScorer, since it will only
// be available after AdvancedScorer is successfully created. However, the
diff --git a/icing/scoring/advanced_scoring/advanced-scorer_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_test.cc
index cc1d413..4b9a46c 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer_test.cc
+++ b/icing/scoring/advanced_scoring/advanced-scorer_test.cc
@@ -15,14 +15,22 @@
#include "icing/scoring/advanced_scoring/advanced-scorer.h"
#include <cmath>
+#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/proto/document.pb.h"
@@ -31,6 +39,8 @@
#include "icing/proto/usage.pb.h"
#include "icing/schema-builder.h"
#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/scoring/scored-document-hit.h"
#include "icing/scoring/scorer-factory.h"
#include "icing/scoring/scorer.h"
#include "icing/store/document-id.h"
@@ -45,6 +55,7 @@ namespace lib {
namespace {
using ::testing::DoubleNear;
using ::testing::Eq;
+using ::testing::HasSubstr;
class AdvancedScorerTest : public testing::Test {
protected:
@@ -124,15 +135,19 @@ class AdvancedScorerTest : public testing::Test {
const std::string test_dir_;
const std::string doc_store_dir_;
const std::string schema_store_dir_;
+ EmbeddingQueryResults empty_embedding_query_results_;
Filesystem filesystem_;
std::unique_ptr<SchemaStore> schema_store_;
std::unique_ptr<DocumentStore> document_store_;
FakeClock fake_clock_;
};
-constexpr double kEps = 0.0000000001;
+constexpr double kEps = 0.0000001;
constexpr int kDefaultScore = 0;
constexpr int64_t kDefaultCreationTimestampMs = 1571100001111;
+constexpr SearchSpecProto::EmbeddingQueryMetricType::Code
+ kDefaultSemanticMetricType =
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT;
DocumentProto CreateDocument(
const std::string& name_space, const std::string& uri,
@@ -193,8 +208,11 @@ TEST_F(AdvancedScorerTest, InvalidAdvancedScoringSpec) {
scoring_spec.set_rank_by(
ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION);
EXPECT_THAT(scorer_factory::Create(scoring_spec, /*default_score=*/10,
+ kDefaultSemanticMetricType,
document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
// Non-empty scoring expression for normal scoring
@@ -202,8 +220,11 @@ TEST_F(AdvancedScorerTest, InvalidAdvancedScoringSpec) {
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE);
scoring_spec.set_advanced_scoring_expression("1");
EXPECT_THAT(scorer_factory::Create(scoring_spec, /*default_score=*/10,
+ kDefaultSemanticMetricType,
document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
@@ -215,9 +236,11 @@ TEST_F(AdvancedScorerTest, SimpleExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("123"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
DocHitInfo docHitInfo = DocHitInfo(document_id);
@@ -233,44 +256,61 @@ TEST_F(AdvancedScorerTest, BasicPureArithmeticExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + 2"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(3));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("-1 + 2"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("-1 + 2"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + -2"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + -2"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 - 2"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 - 2"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 * 2"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 * 2"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(2));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 2"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 2"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0.5));
}
@@ -283,103 +323,131 @@ TEST_F(AdvancedScorerTest, BasicMathFunctionExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("log(10, 1000)"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(3, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("log(2.718281828459045)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(2, 10)"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(2, 10)"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1024));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("max(10, 11, 12, 13, 14)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(14));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("min(10, 11, 12, 13, 14)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("len(10, 11, 12, 13, 14)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("sum(10, 11, 12, 13, 14)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10 + 11 + 12 + 13 + 14));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("avg(10, 11, 12, 13, 14)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq((10 + 11 + 12 + 13 + 14) / 5.));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(2)"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(2)"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(sqrt(2), kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(-2) + abs(2)"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(4));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("sin(3.141592653589793)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(0, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("cos(3.141592653589793)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(-1, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("tan(3.141592653589793 / 4)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps));
}
@@ -394,17 +462,21 @@ TEST_F(AdvancedScorerTest, DocumentScoreCreationTimestampFunctionExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("this.documentScore()"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(123));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("this.creationTimestamp()"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(kDefaultCreationTimestampMs));
ICING_ASSERT_OK_AND_ASSIGN(
@@ -412,8 +484,10 @@ TEST_F(AdvancedScorerTest, DocumentScoreCreationTimestampFunctionExpression) {
AdvancedScorer::Create(
CreateAdvancedScoringSpec(
"this.documentScore() + this.creationTimestamp()"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo),
Eq(123 + kDefaultCreationTimestampMs));
}
@@ -429,8 +503,10 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionExpression) {
AdvancedScorer::Create(
CreateAdvancedScoringSpec("this.usageCount(1) + this.usageCount(2) "
"+ this.usageLastUsedTimestamp(3)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0));
ICING_ASSERT_OK(document_store_->ReportUsage(
CreateUsageReport("namespace", "uri", 100000, UsageReport::USAGE_TYPE1)));
@@ -446,22 +522,28 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionExpression) {
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(1)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(100000));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(2)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(200000));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(3)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(300000));
}
@@ -478,24 +560,29 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionOutOfRange) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- AdvancedScorer::Create(CreateAdvancedScoringSpec("this.usageCount(4)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.usageCount(4)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.usageCount(0)"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.usageCount(0)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.usageCount(1.5)"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.usageCount(1.5)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score));
}
@@ -516,9 +603,11 @@ TEST_F(AdvancedScorerTest, RelevanceScoreFunctionScoreExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<AdvancedScorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("this.relevanceScore()"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
scorer->PrepareToScore(/*query_term_iterators=*/{});
// Should get the default score.
@@ -566,8 +655,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) {
std::unique_ptr<AdvancedScorer> scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("len(this.childrenRankingSignals())"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(), &fetcher));
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ &fetcher, &empty_embedding_query_results_));
// document_id_1 has two children.
EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(2));
// document_id_2 has one child.
@@ -579,8 +669,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) {
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("sum(this.childrenRankingSignals())"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(), &fetcher));
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ &fetcher, &empty_embedding_query_results_));
// document_id_1 has two children with scores 1 and 2.
EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3));
// document_id_2 has one child with score 4.
@@ -592,8 +683,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) {
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("avg(this.childrenRankingSignals())"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(), &fetcher));
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ &fetcher, &empty_embedding_query_results_));
// document_id_1 has two children with scores 1 and 2.
EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3 / 2.));
// document_id_2 has one child with score 4.
@@ -604,13 +696,15 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) {
Eq(default_score));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(
- CreateAdvancedScoringSpec(
- // Equivalent to "avg(this.childrenRankingSignals())"
- "sum(this.childrenRankingSignals()) / "
- "len(this.childrenRankingSignals())"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(), &fetcher));
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ // Equivalent to "avg(this.childrenRankingSignals())"
+ "sum(this.childrenRankingSignals()) / "
+ "len(this.childrenRankingSignals())"),
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ &fetcher, &empty_embedding_query_results_));
// document_id_1 has two children with scores 1 and 2.
EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3 / 2.));
// document_id_2 has one child with score 4.
@@ -670,9 +764,11 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<AdvancedScorer> scorer,
AdvancedScorer::Create(spec_proto,
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// min([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// min([0.5, 0.8]) = 0.5
@@ -682,10 +778,13 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) {
spec_proto.set_advanced_scoring_expression("max(this.propertyWeights())");
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(spec_proto,
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(spec_proto,
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// max([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// max([0.5, 0.8]) = 0.8
@@ -695,10 +794,13 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) {
spec_proto.set_advanced_scoring_expression("sum(this.propertyWeights())");
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(spec_proto,
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(spec_proto,
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// sum([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// sum([0.5, 0.8]) = 1.3
@@ -747,9 +849,11 @@ TEST_F(AdvancedScorerTest,
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<AdvancedScorer> scorer,
AdvancedScorer::Create(spec_proto,
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// min([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// min([0.5, 1, 0.5]) = 0.5
@@ -757,10 +861,13 @@ TEST_F(AdvancedScorerTest,
spec_proto.set_advanced_scoring_expression("max(this.propertyWeights())");
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(spec_proto,
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(spec_proto,
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// max([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// max([0.5, 1, 0.5]) = 1
@@ -768,10 +875,13 @@ TEST_F(AdvancedScorerTest,
spec_proto.set_advanced_scoring_expression("sum(this.propertyWeights())");
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(spec_proto,
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(spec_proto,
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// sum([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// sum([0.5, 1, 0.5]) = 2
@@ -786,20 +896,22 @@ TEST_F(AdvancedScorerTest, InvalidChildrenScoresFunctionScoreExpression) {
EXPECT_THAT(
AdvancedScorer::Create(
CreateAdvancedScoringSpec("len(this.childrenRankingSignals())"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(),
- /*join_children_fetcher=*/nullptr),
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
// The root expression can only be of double type, but here it is of list
// type.
JoinChildrenFetcher fake_fetcher(JoinSpecProto::default_instance(),
/*map_joinable_qualified_id=*/{});
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.childrenRankingSignals()"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(), &fake_fetcher),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.childrenRankingSignals()"),
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ &fake_fetcher, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
TEST_F(AdvancedScorerTest, ComplexExpression) {
@@ -821,9 +933,11 @@ TEST_F(AdvancedScorerTest, ComplexExpression) {
"+ 10 * (2 + 10 + this.creationTimestamp()))"
// This should evaluate to default score.
"+ this.relevanceScore()"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_FALSE(scorer->is_constant());
scorer->PrepareToScore(/*query_term_iterators=*/{});
@@ -847,19 +961,24 @@ TEST_F(AdvancedScorerTest, ConstantExpression) {
"pow(sin(2), 2)"
"+ log(2, 122) / 12.34"
"* (10 * pow(2 * 1, sin(2)) + 10 * (2 + 10))"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_TRUE(scorer->is_constant());
}
// Should be a parsing error.
TEST_F(AdvancedScorerTest, EmptyExpression) {
- EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec(""),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(CreateAdvancedScoringSpec(""),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) {
@@ -872,30 +991,38 @@ TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- AdvancedScorer::Create(CreateAdvancedScoringSpec("log(0)"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("log(0)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer,
- AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 0"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 0"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(-1)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(-1, 0.5)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
}
@@ -904,133 +1031,402 @@ TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) {
TEST_F(AdvancedScorerTest, MathTypeError) {
const double default_score = 0;
- EXPECT_THAT(
- AdvancedScorer::Create(CreateAdvancedScoringSpec("test"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("test"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(
- AdvancedScorer::Create(CreateAdvancedScoringSpec("log()"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log()"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, 2, 3)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, this)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(
- AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(1)"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(1)"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(1, 2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(1, 2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sin(1, 2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("cos(1, 2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("tan(1, 2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(
- AdvancedScorer::Create(CreateAdvancedScoringSpec("this"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(
- AdvancedScorer::Create(CreateAdvancedScoringSpec("-this"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("-this"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + this"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
TEST_F(AdvancedScorerTest, DocumentFunctionTypeError) {
const double default_score = 0;
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("documentScore(1)"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.creationTimestamp(1)"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("documentScore(1)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.creationTimestamp(1)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.usageCount()"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("usageLastUsedTimestamp(1, 1)"),
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("relevanceScore(1)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("documentScore(this)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("that.documentScore()"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.this.creationTimestamp()"),
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this.log(2)"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.usageCount()"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+}
+
+TEST_F(AdvancedScorerTest,
+ MatchedSemanticScoresFunctionScoreExpressionTypeError) {
+ libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> scorer_or =
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(matchedSemanticScores(getSearchSpecEmbedding(0)))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("usageLastUsedTimestamp(1, 1)"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("not called with \"this\""));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(0))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("relevanceScore(1)"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("got invalid argument type for embedding vector"));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0), 0))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("documentScore(this)"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("Embedding metric can only be given as a string"));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sum(this.matchedSemanticScores("
+ "getSearchSpecEmbedding(0), \"COSINE\", 0))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("that.documentScore()"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("got invalid number of arguments"));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sum(this.matchedSemanticScores("
+ "getSearchSpecEmbedding(0), \"COSIGN\"))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.this.creationTimestamp()"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("Unknown metric type: COSIGN"));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(\"0\")))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this.log(2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("getSearchSpecEmbedding got invalid argument type"));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding()))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("getSearchSpecEmbedding must have 1 argument"));
+}
+
+void AddEntryToEmbeddingQueryScoreMap(
+ EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map,
+ double semantic_score, DocumentId document_id) {
+ score_map[document_id].push_back(semantic_score);
+}
+
+TEST_F(AdvancedScorerTest, MatchedSemanticScoresFunctionScoreExpression) {
+ DocumentId document_id_0 = 0;
+ DocumentId document_id_1 = 1;
+ DocHitInfo doc_hit_info_0(document_id_0);
+ DocHitInfo doc_hit_info_1(document_id_1);
+ EmbeddingQueryResults embedding_query_results;
+
+ // Let the first query assign the following semantic scores:
+ // COSINE:
+ // Document 0: 0.1, 0.2
+ // Document 1: 0.3, 0.4
+ // DOT_PRODUCT:
+ // Document 0: 0.5
+ // Document 1: 0.6
+ // EUCLIDEAN:
+ // Document 0: 0.7
+ // Document 1: 0.8
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map =
+ &embedding_query_results
+ .result_scores[0][SearchSpecProto::EmbeddingQueryMetricType::COSINE];
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.1, document_id_0);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.2, document_id_0);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.3, document_id_1);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.4, document_id_1);
+ score_map = &embedding_query_results.result_scores
+ [0][SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT];
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.5, document_id_0);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.6, document_id_1);
+ score_map =
+ &embedding_query_results
+ .result_scores[0]
+ [SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN];
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.7, document_id_0);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.8, document_id_1);
+
+ // Let the second query only assign DOT_PRODUCT scores:
+ // DOT_PRODUCT:
+ // Document 0: 0.1
+ // Document 1: 0.2
+ score_map = &embedding_query_results.result_scores
+ [1][SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT];
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.1, document_id_0);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.2, document_id_1);
+
+ // Get semantic scores for default metric (DOT_PRODUCT) for the first query.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Scorer> scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))"),
+ kDefaultScore, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &embedding_query_results));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.5, kEps));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.6, kEps));
+
+ // Get semantic scores for a metric overriding the default one for the first
+ // query.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sum(this.matchedSemanticScores("
+ "getSearchSpecEmbedding(0), \"COSINE\"))"),
+ kDefaultScore, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &embedding_query_results));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.1 + 0.2, kEps));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.3 + 0.4, kEps));
+
+ // Get semantic scores for multiple metrics for the first query.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer, AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)"
+ ", \"COSINE\")) + "
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)"
+ ", \"DOT_PRODUCT\")) + "
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)"
+ ", \"EUCLIDEAN\"))"),
+ kDefaultScore, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &embedding_query_results));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_0),
+ DoubleNear(0.1 + 0.2 + 0.5 + 0.7, kEps));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_1),
+ DoubleNear(0.3 + 0.4 + 0.6 + 0.8, kEps));
+
+ // Get semantic scores for the second query.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"),
+ kDefaultScore, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &embedding_query_results));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.1, kEps));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.2, kEps));
+
+ // The second query does not contain cosine scores.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sum(this.matchedSemanticScores("
+ "getSearchSpecEmbedding(1), \"COSINE\"))"),
+ kDefaultScore, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &embedding_query_results));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0, kEps));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0, kEps));
}
} // namespace
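The test above leans on the nested layout of EmbeddingQueryResults::result_scores: query index -> metric type -> document id -> list of matched semantic scores. Below is a minimal standalone sketch of that layout using only standard containers; the simplified types are an assumption for illustration, not the actual definition in icing/index/embed/embedding-query-results.h.

#include <iostream>
#include <map>
#include <unordered_map>
#include <vector>

enum class MetricType { kCosine, kDotProduct, kEuclidean };

// document id -> matched semantic scores for one (query, metric) pair.
using ScoreMap = std::unordered_map<int, std::vector<double>>;
// query index -> metric type -> ScoreMap.
using ResultScores = std::map<int, std::map<MetricType, ScoreMap>>;

int main() {
  ResultScores result_scores;
  // Mirror the first query from the test above.
  result_scores[0][MetricType::kCosine][/*document_id=*/0] = {0.1, 0.2};
  result_scores[0][MetricType::kCosine][/*document_id=*/1] = {0.3, 0.4};
  result_scores[0][MetricType::kDotProduct][/*document_id=*/0] = {0.5};

  // sum(this.matchedSemanticScores(getSearchSpecEmbedding(0), "COSINE"))
  // for document 0 then adds up 0.1 + 0.2 = 0.3.
  double sum = 0;
  for (double score : result_scores[0][MetricType::kCosine][0]) sum += score;
  std::cout << sum << std::endl;  // prints 0.3
  return 0;
}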
diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc
index e8a2a89..687180a 100644
--- a/icing/scoring/advanced_scoring/score-expression.cc
+++ b/icing/scoring/advanced_scoring/score-expression.cc
@@ -14,10 +14,39 @@
#include "icing/scoring/advanced_scoring/score-expression.h"
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
#include <numeric>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/join/join-children-fetcher.h"
+#include "icing/schema/section.h"
+#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/scored-document-hit.h"
+#include "icing/scoring/section-weights.h"
+#include "icing/store/document-associated-score-data.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
+#include "icing/util/embedding-util.h"
+#include "icing/util/logging.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -49,7 +78,7 @@ OperatorScoreExpression::Create(
return absl_ports::InvalidArgumentError(
"Operators are only supported for double type.");
}
- if (!child->is_constant_double()) {
+ if (!child->is_constant()) {
children_all_constant_double = false;
}
}
@@ -149,7 +178,7 @@ MathFunctionScoreExpression::Create(
"Got an invalid type for the math function. Should expect a double "
"type argument.");
}
- if (!child->is_constant_double()) {
+ if (!child->is_constant()) {
args_all_constant_double = false;
}
}
@@ -517,5 +546,100 @@ SchemaTypeId PropertyWeightsFunctionScoreExpression::GetSchemaTypeId(
return filter_data_optional.value().schema_type_id();
}
+libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
+GetSearchSpecEmbeddingFunctionScoreExpression::Create(
+ std::vector<std::unique_ptr<ScoreExpression>> args) {
+ if (args.size() != 1) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat(kFunctionName, " must have 1 argument."));
+ }
+ if (args[0]->type() != ScoreExpressionType::kDouble) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat(kFunctionName, " got invalid argument type."));
+ }
+ bool is_constant = args[0]->is_constant();
+ std::unique_ptr<ScoreExpression> expression =
+ std::unique_ptr<GetSearchSpecEmbeddingFunctionScoreExpression>(
+ new GetSearchSpecEmbeddingFunctionScoreExpression(
+ std::move(args[0])));
+ if (is_constant) {
+ return ConstantScoreExpression::Create(
+ expression->eval(DocHitInfo(), /*query_it=*/nullptr),
+ expression->type());
+ }
+ return expression;
+}
+
+libtextclassifier3::StatusOr<double>
+GetSearchSpecEmbeddingFunctionScoreExpression::eval(
+ const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
+ ICING_ASSIGN_OR_RETURN(double raw_query_index,
+ arg_->eval(hit_info, query_it));
+ uint32_t query_index = (uint32_t)raw_query_index;
+ if (query_index != raw_query_index) {
+ return absl_ports::InvalidArgumentError(
+ "The index of an embedding query must be an integer.");
+ }
+ return query_index;
+}
+
+libtextclassifier3::StatusOr<
+ std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
+MatchedSemanticScoresFunctionScoreExpression::Create(
+ std::vector<std::unique_ptr<ScoreExpression>> args,
+ SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
+ const EmbeddingQueryResults* embedding_query_results) {
+ ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
+ ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
+
+ if (args.empty() || args[0]->type() != ScoreExpressionType::kDocument) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat(kFunctionName, " is not called with \"this\""));
+ }
+ if (args.size() != 2 && args.size() != 3) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat(kFunctionName, " got invalid number of arguments."));
+ }
+ if (args[1]->type() != ScoreExpressionType::kVectorIndex) {
+ return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+ kFunctionName, " got invalid argument type for embedding vector."));
+ }
+ if (args.size() == 3 && args[2]->type() != ScoreExpressionType::kString) {
+ return absl_ports::InvalidArgumentError(
+ "Embedding metric can only be given as a string.");
+ }
+
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type =
+ default_metric_type;
+ if (args.size() == 3) {
+ if (!args[2]->is_constant()) {
+ return absl_ports::InvalidArgumentError(
+ "Embedding metric can only be given as a constant string.");
+ }
+ ICING_ASSIGN_OR_RETURN(std::string_view metric, args[2]->eval_string());
+ ICING_ASSIGN_OR_RETURN(
+ metric_type,
+ embedding_util::GetEmbeddingQueryMetricTypeFromName(metric));
+ }
+ return std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>(
+ new MatchedSemanticScoresFunctionScoreExpression(
+ std::move(args), metric_type, *embedding_query_results));
+}
+
+libtextclassifier3::StatusOr<std::vector<double>>
+MatchedSemanticScoresFunctionScoreExpression::eval_list(
+ const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
+ ICING_ASSIGN_OR_RETURN(double raw_query_index,
+ args_[1]->eval(hit_info, query_it));
+ uint32_t query_index = (uint32_t)raw_query_index;
+ const std::vector<double>* scores =
+ embedding_query_results_.GetMatchedScoresForDocument(
+ query_index, metric_type_, hit_info.document_id());
+ if (scores == nullptr) {
+ return std::vector<double>();
+ }
+ return *scores;
+}
+
} // namespace lib
} // namespace icing
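The eval_list() implementation above degrades gracefully: when GetMatchedScoresForDocument() finds no entry for the (query index, metric, document) combination, the expression yields an empty list, so an enclosing sum() contributes 0 instead of failing. A standalone sketch of that contract, with simplified types assumed for illustration only:

#include <numeric>
#include <unordered_map>
#include <vector>

using ScoreMap = std::unordered_map<int, std::vector<double>>;

// Returns nullptr when the document has no matched scores, mirroring the
// behavior eval_list() expects from GetMatchedScoresForDocument().
const std::vector<double>* GetMatchedScores(const ScoreMap& scores,
                                            int document_id) {
  auto it = scores.find(document_id);
  return it == scores.end() ? nullptr : &it->second;
}

double SumMatchedScores(const ScoreMap& scores, int document_id) {
  const std::vector<double>* matched = GetMatchedScores(scores, document_id);
  if (matched == nullptr) {
    return 0.0;  // equivalent to summing an empty score list
  }
  return std::accumulate(matched->begin(), matched->end(), 0.0);
}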
diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h
index 08d7997..e28fcd7 100644
--- a/icing/scoring/advanced_scoring/score-expression.h
+++ b/icing/scoring/advanced_scoring/score-expression.h
@@ -15,20 +15,26 @@
#ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
#define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
-#include <algorithm>
-#include <cmath>
+#include <cstdint>
#include <memory>
+#include <string>
+#include <string_view>
#include <unordered_map>
#include <unordered_set>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/section-weights.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
-#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -36,7 +42,11 @@ namespace lib {
enum class ScoreExpressionType {
kDouble,
kDoubleList,
- kDocument // Only "this" is considered as document type.
+ kDocument, // Only "this" is considered as document type.
+ // TODO(b/326656531): Instead of creating a vector index type, consider
+ // changing it to vector type so that the data is the vector directly.
+ kVectorIndex,
+ kString,
};
class ScoreExpression {
@@ -75,12 +85,24 @@ class ScoreExpression {
"checking.");
}
+ virtual libtextclassifier3::StatusOr<std::string_view> eval_string() const {
+ if (type() == ScoreExpressionType::kString) {
+ return absl_ports::UnimplementedError(
+ "All ScoreExpressions of type string must provide their own "
+ "implementation of eval_string!");
+ }
+ return absl_ports::InternalError(
+ "Runtime type error: the expression should never be evaluated to a "
+ "string. There must be inconsistencies in the static type checking.");
+ }
+
// Indicate the type to which the current expression will be evaluated.
virtual ScoreExpressionType type() const = 0;
- // Indicate whether the current expression is a constant double.
- // Returns true if and only if the object is of ConstantScoreExpression type.
- virtual bool is_constant_double() const { return false; }
+ // Indicate whether the current expression is a constant.
+ // Returns true if and only if the object is of ConstantScoreExpression or
+ // StringExpression type.
+ virtual bool is_constant() const { return false; }
};
class ThisExpression : public ScoreExpression {
@@ -100,9 +122,10 @@ class ThisExpression : public ScoreExpression {
class ConstantScoreExpression : public ScoreExpression {
public:
static std::unique_ptr<ConstantScoreExpression> Create(
- libtextclassifier3::StatusOr<double> c) {
+ libtextclassifier3::StatusOr<double> c,
+ ScoreExpressionType type = ScoreExpressionType::kDouble) {
return std::unique_ptr<ConstantScoreExpression>(
- new ConstantScoreExpression(c));
+ new ConstantScoreExpression(c, type));
}
libtextclassifier3::StatusOr<double> eval(
@@ -110,17 +133,39 @@ class ConstantScoreExpression : public ScoreExpression {
return c_;
}
- ScoreExpressionType type() const override {
- return ScoreExpressionType::kDouble;
- }
+ ScoreExpressionType type() const override { return type_; }
- bool is_constant_double() const override { return true; }
+ bool is_constant() const override { return true; }
private:
- explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c)
- : c_(c) {}
+ explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c,
+ ScoreExpressionType type)
+ : c_(c), type_(type) {}
libtextclassifier3::StatusOr<double> c_;
+ ScoreExpressionType type_;
+};
+
+class StringExpression : public ScoreExpression {
+ public:
+ static std::unique_ptr<StringExpression> Create(std::string str) {
+ return std::unique_ptr<StringExpression>(
+ new StringExpression(std::move(str)));
+ }
+
+ libtextclassifier3::StatusOr<std::string_view> eval_string() const override {
+ return str_;
+ }
+
+ ScoreExpressionType type() const override {
+ return ScoreExpressionType::kString;
+ }
+
+ bool is_constant() const override { return true; }
+
+ private:
+ explicit StringExpression(std::string str) : str_(std::move(str)) {}
+ std::string str_;
};
class OperatorScoreExpression : public ScoreExpression {
@@ -342,6 +387,70 @@ class PropertyWeightsFunctionScoreExpression : public ScoreExpression {
int64_t current_time_ms_;
};
+class GetSearchSpecEmbeddingFunctionScoreExpression : public ScoreExpression {
+ public:
+ static constexpr std::string_view kFunctionName = "getSearchSpecEmbedding";
+
+ // RETURNS:
+ // - A GetSearchSpecEmbeddingFunctionScoreExpression instance on success if
+ // not simplifiable.
+ // - A ConstantScoreExpression instance on success if simplifiable.
+ // - FAILED_PRECONDITION on any null pointer in children.
+ // - INVALID_ARGUMENT on type errors.
+ static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
+ std::vector<std::unique_ptr<ScoreExpression>> args);
+
+ libtextclassifier3::StatusOr<double> eval(
+ const DocHitInfo& hit_info,
+ const DocHitInfoIterator* query_it) const override;
+
+ ScoreExpressionType type() const override {
+ return ScoreExpressionType::kVectorIndex;
+ }
+
+ private:
+ explicit GetSearchSpecEmbeddingFunctionScoreExpression(
+ std::unique_ptr<ScoreExpression> arg)
+ : arg_(std::move(arg)) {}
+ std::unique_ptr<ScoreExpression> arg_;
+};
+
+class MatchedSemanticScoresFunctionScoreExpression : public ScoreExpression {
+ public:
+ static constexpr std::string_view kFunctionName = "matchedSemanticScores";
+
+ // RETURNS:
+ // - A MatchedSemanticScoresFunctionScoreExpression instance on success.
+ // - FAILED_PRECONDITION on any null pointer in children.
+ // - INVALID_ARGUMENT on type errors.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
+ Create(std::vector<std::unique_ptr<ScoreExpression>> args,
+ SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
+ const EmbeddingQueryResults* embedding_query_results);
+
+ libtextclassifier3::StatusOr<std::vector<double>> eval_list(
+ const DocHitInfo& hit_info,
+ const DocHitInfoIterator* query_it) const override;
+
+ ScoreExpressionType type() const override {
+ return ScoreExpressionType::kDoubleList;
+ }
+
+ private:
+ explicit MatchedSemanticScoresFunctionScoreExpression(
+ std::vector<std::unique_ptr<ScoreExpression>> args,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ const EmbeddingQueryResults& embedding_query_results)
+ : args_(std::move(args)),
+ metric_type_(metric_type),
+ embedding_query_results_(embedding_query_results) {}
+
+ std::vector<std::unique_ptr<ScoreExpression>> args_;
+ const SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
+ const EmbeddingQueryResults& embedding_query_results_;
+};
+
} // namespace lib
} // namespace icing
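The two new expression classes give the scoring language a typed pipeline: getSearchSpecEmbedding(i) produces a kVectorIndex, and matchedSemanticScores consumes it (plus an optional constant kString metric name) to produce a kDoubleList that list-aware functions such as sum() can reduce. The strings below are illustrative expressions only, modeled on the tests earlier in this patch; the annotations map sub-expressions to the new ScoreExpressionType values.

// Each sub-expression is annotated with the ScoreExpressionType it evaluates
// to under the extended type system.
const char* kExampleExpressions[] = {
    // getSearchSpecEmbedding(0)        -> kVectorIndex
    // this.matchedSemanticScores(...)  -> kDoubleList
    // sum(...)                         -> kDouble
    "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))",

    // "COSINE" is parsed as a StringExpression (kString); it must be a
    // constant naming a known metric, otherwise Create() fails with
    // INVALID_ARGUMENT.
    "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0), \"COSINE\"))",
};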
diff --git a/icing/scoring/advanced_scoring/score-expression_test.cc b/icing/scoring/advanced_scoring/score-expression_test.cc
index 588090d..cd58366 100644
--- a/icing/scoring/advanced_scoring/score-expression_test.cc
+++ b/icing/scoring/advanced_scoring/score-expression_test.cc
@@ -14,15 +14,16 @@
#include "icing/scoring/advanced_scoring/score-expression.h"
-#include <cmath>
#include <memory>
-#include <string>
-#include <string_view>
#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/testing/common-matchers.h"
namespace icing {
@@ -47,7 +48,7 @@ class NonConstantScoreExpression : public ScoreExpression {
return ScoreExpressionType::kDouble;
}
- bool is_constant_double() const override { return false; }
+ bool is_constant() const override { return false; }
};
class ListScoreExpression : public ScoreExpression {
@@ -87,7 +88,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) {
OperatorScoreExpression::OperatorType::kPlus,
MakeChildren(ConstantScoreExpression::Create(1),
ConstantScoreExpression::Create(1))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2)));
// 1 - 2 - 3 = -4
@@ -97,7 +98,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) {
MakeChildren(ConstantScoreExpression::Create(1),
ConstantScoreExpression::Create(2),
ConstantScoreExpression::Create(3))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-4)));
// 1 * 2 * 3 * 4 = 24
@@ -108,7 +109,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) {
ConstantScoreExpression::Create(2),
ConstantScoreExpression::Create(3),
ConstantScoreExpression::Create(4))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(24)));
// 1 / 2 / 4 = 0.125
@@ -118,7 +119,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) {
MakeChildren(ConstantScoreExpression::Create(1),
ConstantScoreExpression::Create(2),
ConstantScoreExpression::Create(4))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(0.125)));
// -(2) = -2
@@ -126,7 +127,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) {
expression, OperatorScoreExpression::Create(
OperatorScoreExpression::OperatorType::kNegative,
MakeChildren(ConstantScoreExpression::Create(2))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-2)));
}
@@ -138,7 +139,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) {
MathFunctionScoreExpression::FunctionType::kPow,
MakeChildren(ConstantScoreExpression::Create(2),
ConstantScoreExpression::Create(2))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(4)));
// abs(-2) = 2
@@ -146,7 +147,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) {
expression, MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::FunctionType::kAbs,
MakeChildren(ConstantScoreExpression::Create(-2))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2)));
// log(e) = 1
@@ -154,7 +155,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) {
expression, MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::FunctionType::kLog,
MakeChildren(ConstantScoreExpression::Create(M_E))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(1)));
}
@@ -166,7 +167,7 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) {
OperatorScoreExpression::OperatorType::kPlus,
MakeChildren(ConstantScoreExpression::Create(1),
NonConstantScoreExpression::Create())));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
// non_constant * non_constant = non_constant
ICING_ASSERT_OK_AND_ASSIGN(
@@ -174,14 +175,14 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) {
OperatorScoreExpression::OperatorType::kTimes,
MakeChildren(NonConstantScoreExpression::Create(),
NonConstantScoreExpression::Create())));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
// -(non_constant) = non_constant
ICING_ASSERT_OK_AND_ASSIGN(
expression, OperatorScoreExpression::Create(
OperatorScoreExpression::OperatorType::kNegative,
MakeChildren(NonConstantScoreExpression::Create())));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
// pow(non_constant, 2) = non_constant
ICING_ASSERT_OK_AND_ASSIGN(
@@ -189,21 +190,21 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) {
MathFunctionScoreExpression::FunctionType::kPow,
MakeChildren(NonConstantScoreExpression::Create(),
ConstantScoreExpression::Create(2))));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
// abs(non_constant) = non_constant
ICING_ASSERT_OK_AND_ASSIGN(
expression, MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::FunctionType::kAbs,
MakeChildren(NonConstantScoreExpression::Create())));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
// log(non_constant) = non_constant
ICING_ASSERT_OK_AND_ASSIGN(
expression, MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::FunctionType::kLog,
MakeChildren(NonConstantScoreExpression::Create())));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
}
TEST(ScoreExpressionTest, MathFunctionsWithListTypeArgument) {
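The is_constant_double() -> is_constant() rename exercised by these tests reflects the create-time folding used throughout the patch: any operator or math function whose children are all constant is evaluated once and replaced by a constant node. A minimal standalone illustration of that pattern follows; the types are simplified stand-ins, not the icing classes.

#include <memory>
#include <utility>
#include <vector>

struct Expr {
  virtual ~Expr() = default;
  virtual double Eval() const = 0;
  virtual bool is_constant() const { return false; }
};

struct Const : Expr {
  explicit Const(double v) : v(v) {}
  double Eval() const override { return v; }
  bool is_constant() const override { return true; }
  double v;
};

struct Plus : Expr {
  explicit Plus(std::vector<std::unique_ptr<Expr>> c) : children(std::move(c)) {}
  double Eval() const override {
    double sum = 0;
    for (const auto& child : children) sum += child->Eval();
    return sum;
  }
  std::vector<std::unique_ptr<Expr>> children;
};

// Factory that folds constants, mirroring OperatorScoreExpression::Create():
// if every child is constant, evaluate now and return a Const instead.
std::unique_ptr<Expr> MakePlus(std::vector<std::unique_ptr<Expr>> children) {
  bool all_constant = true;
  for (const auto& child : children) {
    if (!child->is_constant()) all_constant = false;
  }
  auto node = std::make_unique<Plus>(std::move(children));
  if (all_constant) return std::make_unique<Const>(node->Eval());
  return node;
}
// MakePlus over {Const(1), Const(2)} folds to Const(3); with any non-constant
// child, the Plus node is kept and evaluated per document at scoring time.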
diff --git a/icing/scoring/advanced_scoring/scoring-visitor.cc b/icing/scoring/advanced_scoring/scoring-visitor.cc
index e2b24a2..05240c0 100644
--- a/icing/scoring/advanced_scoring/scoring-visitor.cc
+++ b/icing/scoring/advanced_scoring/scoring-visitor.cc
@@ -14,7 +14,17 @@
#include "icing/scoring/advanced_scoring/scoring-visitor.h"
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/scoring/advanced_scoring/score-expression.h"
namespace icing {
namespace lib {
@@ -25,8 +35,7 @@ void ScoringVisitor::VisitFunctionName(const FunctionNameNode* node) {
}
void ScoringVisitor::VisitString(const StringNode* node) {
- pending_error_ =
- absl_ports::InvalidArgumentError("Scoring does not support String!");
+ stack_.push_back(StringExpression::Create(node->value()));
}
void ScoringVisitor::VisitText(const TextNode* node) {
@@ -120,6 +129,17 @@ void ScoringVisitor::VisitFunctionHelper(const FunctionNode* node,
expression = MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::kFunctionNames.at(function_name),
std::move(args));
+ } else if (function_name ==
+ GetSearchSpecEmbeddingFunctionScoreExpression::kFunctionName) {
+ // getSearchSpecEmbedding function
+ expression =
+ GetSearchSpecEmbeddingFunctionScoreExpression::Create(std::move(args));
+ } else if (function_name ==
+ MatchedSemanticScoresFunctionScoreExpression::kFunctionName) {
+ // matchedSemanticScores function
+ expression = MatchedSemanticScoresFunctionScoreExpression::Create(
+ std::move(args), default_semantic_metric_type_,
+ &embedding_query_results_);
}
if (!expression.ok()) {
diff --git a/icing/scoring/advanced_scoring/scoring-visitor.h b/icing/scoring/advanced_scoring/scoring-visitor.h
index cfee25b..bb5e6ba 100644
--- a/icing/scoring/advanced_scoring/scoring-visitor.h
+++ b/icing/scoring/advanced_scoring/scoring-visitor.h
@@ -15,14 +15,22 @@
#ifndef ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_
#define ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/legacy/core/icing-string-util.h"
-#include "icing/proto/scoring.pb.h"
#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/score-expression.h"
#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/section-weights.h"
#include "icing/store/document-store.h"
namespace icing {
@@ -31,18 +39,23 @@ namespace lib {
class ScoringVisitor : public AbstractSyntaxTreeVisitor {
public:
explicit ScoringVisitor(double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store,
const SchemaStore* schema_store,
SectionWeights* section_weights,
Bm25fCalculator* bm25f_calculator,
const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results,
int64_t current_time_ms)
: default_score_(default_score),
+ default_semantic_metric_type_(default_semantic_metric_type),
document_store_(*document_store),
schema_store_(*schema_store),
section_weights_(*section_weights),
bm25f_calculator_(*bm25f_calculator),
join_children_fetcher_(join_children_fetcher),
+ embedding_query_results_(*embedding_query_results),
current_time_ms_(current_time_ms) {}
void VisitFunctionName(const FunctionNameNode* node) override;
@@ -90,12 +103,15 @@ class ScoringVisitor : public AbstractSyntaxTreeVisitor {
}
double default_score_;
+ const SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type_;
const DocumentStore& document_store_;
const SchemaStore& schema_store_;
SectionWeights& section_weights_;
Bm25fCalculator& bm25f_calculator_;
// A non-null join_children_fetcher_ indicates scoring in a join.
const JoinChildrenFetcher* join_children_fetcher_; // Does not own.
+ const EmbeddingQueryResults& embedding_query_results_;
libtextclassifier3::Status pending_error_;
std::vector<std::unique_ptr<ScoreExpression>> stack_;
diff --git a/icing/scoring/score-and-rank_benchmark.cc b/icing/scoring/score-and-rank_benchmark.cc
index 7cb5a95..4da8d1f 100644
--- a/icing/scoring/score-and-rank_benchmark.cc
+++ b/icing/scoring/score-and-rank_benchmark.cc
@@ -12,18 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <random>
#include <string>
+#include <unordered_map>
#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "testing/base/public/benchmark.h"
+#include "gtest/gtest.h"
#include "icing/document-builder.h"
#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
@@ -31,6 +35,7 @@
#include "icing/proto/schema.pb.h"
#include "icing/proto/scoring.pb.h"
#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
#include "icing/scoring/ranker.h"
#include "icing/scoring/scored-document-hit.h"
#include "icing/scoring/scoring-processor.h"
@@ -131,11 +136,15 @@ void BM_ScoreAndRankDocumentHitsByDocumentScore(benchmark::State& state) {
ScoringSpecProto scoring_spec;
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE);
+ EmbeddingQueryResults empty_embedding_query_results;
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get(),
- schema_store.get(),
- clock.GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ scoring_spec, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store.get(), schema_store.get(),
+ clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
@@ -237,11 +246,15 @@ void BM_ScoreAndRankDocumentHitsByCreationTime(benchmark::State& state) {
ScoringSpecProto scoring_spec;
scoring_spec.set_rank_by(
ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP);
+ EmbeddingQueryResults empty_embedding_query_results;
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get(),
- schema_store.get(),
- clock.GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ scoring_spec, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store.get(), schema_store.get(),
+ clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
@@ -344,11 +357,15 @@ void BM_ScoreAndRankDocumentHitsNoScoring(benchmark::State& state) {
ScoringSpecProto scoring_spec;
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::NONE);
+ EmbeddingQueryResults empty_embedding_query_results;
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get(),
- schema_store.get(),
- clock.GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ scoring_spec, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store.get(), schema_store.get(),
+ clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
@@ -446,11 +463,15 @@ void BM_ScoreAndRankDocumentHitsByRelevanceScoring(benchmark::State& state) {
ScoringSpecProto scoring_spec;
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE);
+ EmbeddingQueryResults empty_embedding_query_results;
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get(),
- schema_store.get(),
- clock.GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ scoring_spec, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store.get(), schema_store.get(),
+ clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
diff --git a/icing/scoring/scorer-factory.cc b/icing/scoring/scorer-factory.cc
index e56f10c..1d66d7f 100644
--- a/icing/scoring/scorer-factory.cc
+++ b/icing/scoring/scorer-factory.cc
@@ -14,19 +14,26 @@
#include "icing/scoring/scorer-factory.h"
+#include <cstdint>
#include <memory>
+#include <optional>
+#include <string>
#include <unordered_map>
+#include <utility>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/join/join-children-fetcher.h"
#include "icing/proto/scoring.pb.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/advanced-scorer.h"
#include "icing/scoring/bm25f-calculator.h"
#include "icing/scoring/scorer.h"
#include "icing/scoring/section-weights.h"
-#include "icing/store/document-id.h"
+#include "icing/store/document-associated-score-data.h"
#include "icing/store/document-store.h"
#include "icing/util/status-macros.h"
@@ -173,10 +180,14 @@ namespace scorer_factory {
libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
const ScoringSpecProto& scoring_spec, double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store, const SchemaStore* schema_store,
- int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher) {
+ int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results) {
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
+ ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
if (!scoring_spec.advanced_scoring_expression().empty() &&
scoring_spec.rank_by() !=
@@ -223,9 +234,10 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
return absl_ports::InvalidArgumentError(
"Advanced scoring is enabled, but the expression is empty!");
}
- return AdvancedScorer::Create(scoring_spec, default_score, document_store,
- schema_store, current_time_ms,
- join_children_fetcher);
+ return AdvancedScorer::Create(
+ scoring_spec, default_score, default_semantic_metric_type,
+ document_store, schema_store, current_time_ms, join_children_fetcher,
+ embedding_query_results);
case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE:
// Use join aggregate score to rank. Since the aggregation score is
// calculated by child documents after joining (in JoinProcessor), we can
diff --git a/icing/scoring/scorer-factory.h b/icing/scoring/scorer-factory.h
index 659bebd..f5766b3 100644
--- a/icing/scoring/scorer-factory.h
+++ b/icing/scoring/scorer-factory.h
@@ -15,8 +15,13 @@
#ifndef ICING_SCORING_SCORER_FACTORY_H_
#define ICING_SCORING_SCORER_FACTORY_H_
+#include <cstdint>
+#include <memory>
+
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/join/join-children-fetcher.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/scorer.h"
#include "icing/store/document-store.h"
@@ -37,9 +42,11 @@ namespace scorer_factory {
// INVALID_ARGUMENT if fails to create an instance
libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
const ScoringSpecProto& scoring_spec, double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store, const SchemaStore* schema_store,
- int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher = nullptr);
+ int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results);
} // namespace scorer_factory
diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc
index 5194c7f..e22d5f4 100644
--- a/icing/scoring/scorer_test.cc
+++ b/icing/scoring/scorer_test.cc
@@ -14,13 +14,19 @@
#include "icing/scoring/scorer.h"
+#include <cstdint>
+#include <limits>
#include <memory>
#include <string>
+#include <utility>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/proto/document.pb.h"
#include "icing/proto/schema.pb.h"
@@ -30,7 +36,6 @@
#include "icing/schema/schema-store.h"
#include "icing/scoring/scorer-factory.h"
#include "icing/scoring/scorer-test-utils.h"
-#include "icing/scoring/section-weights.h"
#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
#include "icing/testing/common-matchers.h"
@@ -107,6 +112,10 @@ class ScorerTest : public ::testing::TestWithParam<ScorerTestingMode> {
fake_clock1_.SetSystemTimeMilliseconds(new_time);
}
+ SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type =
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT;
+ EmbeddingQueryResults empty_embedding_query_results;
+
private:
const std::string test_dir_;
const std::string doc_store_dir_;
@@ -134,8 +143,10 @@ TEST_P(ScorerTest, CreationWithNullDocumentStoreShouldFail) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()),
- /*default_score=*/0, /*document_store=*/nullptr, schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()),
+ /*default_score=*/0, default_semantic_metric_type,
+ /*document_store=*/nullptr, schema_store(),
+ fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
@@ -144,8 +155,9 @@ TEST_P(ScorerTest, CreationWithNullSchemaStoreShouldFail) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()),
- /*default_score=*/0, document_store(),
- /*schema_store=*/nullptr, fake_clock1().GetSystemTimeMilliseconds()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ /*schema_store=*/nullptr, fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
@@ -155,8 +167,9 @@ TEST_P(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()),
- /*default_score=*/10, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/10, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
// Non existent document id
DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1);
@@ -181,8 +194,9 @@ TEST_P(ScorerTest, ShouldGetDefaultDocumentScore) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()),
- /*default_score=*/10, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/10, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0));
@@ -206,8 +220,9 @@ TEST_P(ScorerTest, ShouldGetCorrectDocumentScore) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5));
@@ -233,8 +248,9 @@ TEST_P(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()),
- /*default_score=*/10, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/10, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10));
@@ -268,8 +284,9 @@ TEST_P(ScorerTest, ShouldGetCorrectCreationTimestampScore) {
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP,
GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo1 = DocHitInfo(document_id1);
DocHitInfo docHitInfo2 = DocHitInfo(document_id2);
@@ -297,22 +314,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -347,22 +367,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -397,22 +420,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -444,31 +470,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE2_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE3_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -516,31 +545,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE2_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE3_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -588,31 +620,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE2_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE3_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -651,8 +686,9 @@ TEST_P(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::NONE, GetParam()),
- /*default_score=*/3, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/3, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0);
DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1);
@@ -662,11 +698,13 @@ TEST_P(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) {
EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(3));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, scorer_factory::Create(
- CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::NONE, GetParam()),
- /*default_score=*/111, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer,
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::NONE, GetParam()),
+ /*default_score=*/111, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
docHitInfo1 = DocHitInfo(/*document_id_in=*/4);
docHitInfo2 = DocHitInfo(/*document_id_in=*/5);
@@ -690,13 +728,14 @@ TEST_P(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
// Create usage report for the maximum allowable timestamp.
diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc
index b827bd8..bb62db9 100644
--- a/icing/scoring/scoring-processor.cc
+++ b/icing/scoring/scoring-processor.cc
@@ -14,6 +14,7 @@
#include "icing/scoring/scoring-processor.h"
+#include <cstdint>
#include <limits>
#include <memory>
#include <string>
@@ -23,10 +24,12 @@
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/join/join-children-fetcher.h"
#include "icing/proto/scoring.pb.h"
-#include "icing/scoring/ranker.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/scored-document-hit.h"
#include "icing/scoring/scorer-factory.h"
#include "icing/scoring/scorer.h"
@@ -44,24 +47,28 @@ constexpr double kDefaultScoreInAscendingOrder =
libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>>
ScoringProcessor::Create(const ScoringSpecProto& scoring_spec,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store,
const SchemaStore* schema_store,
int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher) {
+ const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results) {
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
+ ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
bool is_descending_order =
scoring_spec.order_by() == ScoringSpecProto::Order::DESC;
ICING_ASSIGN_OR_RETURN(
std::unique_ptr<Scorer> scorer,
- scorer_factory::Create(scoring_spec,
- is_descending_order
- ? kDefaultScoreInDescendingOrder
- : kDefaultScoreInAscendingOrder,
- document_store, schema_store, current_time_ms,
- join_children_fetcher));
+ scorer_factory::Create(
+ scoring_spec,
+ is_descending_order ? kDefaultScoreInDescendingOrder
+ : kDefaultScoreInAscendingOrder,
+ default_semantic_metric_type, document_store, schema_store,
+ current_time_ms, join_children_fetcher, embedding_query_results));
// Using `new` to access a non-public constructor.
return std::unique_ptr<ScoringProcessor>(
new ScoringProcessor(std::move(scorer)));
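
For orientation, here is a minimal sketch, not part of the commit, of how a caller would invoke the updated ScoringProcessor::Create with the two new embedding-related arguments; the variable names are illustrative and the call shape mirrors the test updates elsewhere in this change.

    // Sketch only: assumes scoring_spec, document_store, schema_store and
    // current_time_ms already exist, and that this code sits inside a
    // function returning a Status (required by ICING_ASSIGN_OR_RETURN).
    EmbeddingQueryResults embedding_query_results;  // empty if the query has no embeddings
    ICING_ASSIGN_OR_RETURN(
        std::unique_ptr<ScoringProcessor> scoring_processor,
        ScoringProcessor::Create(
            scoring_spec,
            SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
            document_store, schema_store, current_time_ms,
            /*join_children_fetcher=*/nullptr, &embedding_query_results));
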
diff --git a/icing/scoring/scoring-processor.h b/icing/scoring/scoring-processor.h
index 8634a22..de0b95c 100644
--- a/icing/scoring/scoring-processor.h
+++ b/icing/scoring/scoring-processor.h
@@ -23,6 +23,7 @@
#include <vector>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/proto/logging.pb.h"
@@ -46,9 +47,12 @@ class ScoringProcessor {
// A ScoringProcessor on success
// FAILED_PRECONDITION on any null pointer input
static libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> Create(
- const ScoringSpecProto& scoring_spec, const DocumentStore* document_store,
- const SchemaStore* schema_store, int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher = nullptr);
+ const ScoringSpecProto& scoring_spec,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
+ const DocumentStore* document_store, const SchemaStore* schema_store,
+ int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results);
// Assigns scores to DocHitInfos from the given DocHitInfoIterator and returns
// a vector of ScoredDocumentHits. The size of results is no more than
diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc
index deddff8..e3b70a6 100644
--- a/icing/scoring/scoring-processor_test.cc
+++ b/icing/scoring/scoring-processor_test.cc
@@ -15,22 +15,39 @@
#include "icing/scoring/scoring-processor.h"
#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/proto/document.pb.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/scoring.pb.h"
#include "icing/proto/term.pb.h"
#include "icing/proto/usage.pb.h"
#include "icing/schema-builder.h"
+#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/scoring/scored-document-hit.h"
#include "icing/scoring/scorer-test-utils.h"
+#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/fake-clock.h"
#include "icing/testing/tmp-directory.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -111,6 +128,10 @@ class ScoringProcessorTest
const FakeClock& fake_clock() const { return fake_clock_; }
+ SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type =
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT;
+ EmbeddingQueryResults empty_embedding_query_results;
+
private:
const std::string test_dir_;
const std::string doc_store_dir_;
@@ -187,27 +208,31 @@ PropertyWeight CreatePropertyWeight(std::string path, double weight) {
TEST_F(ScoringProcessorTest, CreationWithNullDocumentStoreShouldFail) {
ScoringSpecProto spec_proto;
- EXPECT_THAT(ScoringProcessor::Create(
- spec_proto, /*document_store=*/nullptr, schema_store(),
- fake_clock().GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, /*document_store=*/nullptr,
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
TEST_F(ScoringProcessorTest, CreationWithNullSchemaStoreShouldFail) {
ScoringSpecProto spec_proto;
EXPECT_THAT(
- ScoringProcessor::Create(spec_proto, document_store(),
- /*schema_store=*/nullptr,
- fake_clock().GetSystemTimeMilliseconds()),
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ /*schema_store=*/nullptr, fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
TEST_P(ScoringProcessorTest, ShouldCreateInstance) {
ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam());
- ICING_EXPECT_OK(
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ICING_EXPECT_OK(ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
}
TEST_P(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) {
@@ -222,8 +247,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/5),
@@ -249,8 +276,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/-1),
@@ -280,8 +309,10 @@ TEST_P(ScoringProcessorTest, ShouldRespectNumToScore) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/2),
@@ -313,8 +344,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByDocumentScore) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -369,8 +402,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -439,8 +474,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -513,8 +550,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -563,8 +602,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -629,8 +670,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -700,8 +743,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -762,8 +807,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor with no explicit weights set.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ScoringSpecProto spec_proto_with_weights =
CreateScoringSpecForRankingStrategy(
@@ -778,9 +825,11 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor with default weight set for "body" property.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor_with_weights,
- ScoringProcessor::Create(spec_proto_with_weights, document_store(),
- schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto_with_weights, default_semantic_metric_type,
+ document_store(), schema_store(),
+ fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -866,8 +915,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -929,8 +980,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByCreationTimestamp) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -990,8 +1043,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByUsageCount) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -1051,8 +1106,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByUsageTimestamp) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -1088,8 +1145,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleNoScores) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/4),
ElementsAre(EqualsScoredDocumentHit(scored_document_hit_default),
@@ -1138,8 +1197,10 @@ TEST_P(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
diff --git a/icing/testing/embedding-test-utils.h b/icing/testing/embedding-test-utils.h
new file mode 100644
index 0000000..931953e
--- /dev/null
+++ b/icing/testing/embedding-test-utils.h
@@ -0,0 +1,45 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_TESTING_EMBEDDING_TEST_UTILS_H_
+#define ICING_TESTING_EMBEDDING_TEST_UTILS_H_
+
+#include <initializer_list>
+#include <string>
+
+#include "icing/proto/document.pb.h"
+
+namespace icing {
+namespace lib {
+
+inline PropertyProto::VectorProto CreateVector(
+ const std::string& model_signature, std::initializer_list<float> values) {
+ PropertyProto::VectorProto vector;
+ vector.set_model_signature(model_signature);
+ for (float value : values) {
+ vector.add_values(value);
+ }
+ return vector;
+}
+
+template <typename... V>
+inline PropertyProto::VectorProto CreateVector(
+ const std::string& model_signature, V&&... values) {
+ return CreateVector(model_signature, values...);
+ return CreateVector(model_signature, {values...});
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_TESTING_EMBEDDING_TEST_UTILS_H_
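
A short usage sketch, not part of the commit, showing how the CreateVector helpers above are meant to be called in tests; braced arguments select the initializer_list overload, and the variadic overload forwards to it:

    // Builds a 3-dimensional embedding vector tagged with a model signature.
    PropertyProto::VectorProto v1 = CreateVector("my_model", {0.1f, 0.2f, 0.3f});
    // Equivalent call through the variadic overload.
    PropertyProto::VectorProto v2 = CreateVector("my_model", 0.1f, 0.2f, 0.3f);
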
diff --git a/icing/testing/hit-test-utils.cc b/icing/testing/hit-test-utils.cc
index 2fd3ac8..c235e23 100644
--- a/icing/testing/hit-test-utils.cc
+++ b/icing/testing/hit-test-utils.cc
@@ -17,6 +17,7 @@
#include <cstdint>
#include <vector>
+#include "icing/index/embed/embedding-hit.h"
#include "icing/index/hit/hit.h"
#include "icing/index/main/posting-list-hit-serializer.h"
#include "icing/schema/section.h"
@@ -87,5 +88,28 @@ std::vector<Hit> CreateHits(int num_hits, int desired_byte_length) {
return CreateHits(/*start_docid=*/0, num_hits, desired_byte_length);
}
+EmbeddingHit CreateEmbeddingHit(const EmbeddingHit& last_hit,
+ uint32_t desired_byte_length) {
+ // Create a delta that has (desired_byte_length - 1) * 7 + 1 bits, so that it
+ // can be encoded in desired_byte_length bytes.
+ uint64_t delta = UINT64_C(1) << ((desired_byte_length - 1) * 7);
+ return EmbeddingHit(last_hit.value() - delta);
+}
+
+std::vector<EmbeddingHit> CreateEmbeddingHits(int num_hits,
+ int desired_byte_length) {
+ std::vector<EmbeddingHit> hits;
+ if (num_hits == 0) {
+ return hits;
+ }
+ hits.reserve(num_hits);
+ hits.push_back(EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0));
+ for (int i = 1; i < num_hits; ++i) {
+ hits.push_back(CreateEmbeddingHit(hits.back(), desired_byte_length));
+ }
+ return hits;
+}
+
} // namespace lib
} // namespace icing
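
To make the byte-length reasoning in CreateEmbeddingHit concrete: VarInt stores 7 payload bits per byte, so a delta of 1 << ((n - 1) * 7) has (n - 1) * 7 + 1 significant bits and therefore encodes to exactly n bytes. The following standalone calculation is only an illustration and uses no Icing APIs; it checks the case n = 3:

    int n = 3;
    uint64_t delta = UINT64_C(1) << ((n - 1) * 7);  // 1 << 14 == 16384
    int bits = 64 - __builtin_clzll(delta);         // 15 significant bits
    int encoded_bytes = (bits + 6) / 7;             // ceil(15 / 7) == 3 == n
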
diff --git a/icing/testing/hit-test-utils.h b/icing/testing/hit-test-utils.h
index 2953c5c..e041c22 100644
--- a/icing/testing/hit-test-utils.h
+++ b/icing/testing/hit-test-utils.h
@@ -15,8 +15,10 @@
#ifndef ICING_TESTING_HIT_TEST_UTILS_H_
#define ICING_TESTING_HIT_TEST_UTILS_H_
+#include <cstdint>
#include <vector>
+#include "icing/index/embed/embedding-hit.h"
#include "icing/index/hit/hit.h"
#include "icing/store/document-id.h"
@@ -46,6 +48,18 @@ std::vector<Hit> CreateHits(const Hit& last_hit, int num_hits,
// with desired_byte_length deltas.
std::vector<Hit> CreateHits(int num_hits, int desired_byte_length);
+// Returns an EmbeddingHit whose delta from last_hit encodes to
+// desired_byte_length bytes after VarInt encoding.
+// Requires that 0 < desired_byte_length <= VarInt::kMaxEncodedLen64.
+EmbeddingHit CreateEmbeddingHit(const EmbeddingHit& last_hit,
+ uint32_t desired_byte_length);
+
+// Returns a vector of num_hits EmbeddingHits, with the first hit starting at
+// document 0 and with a delta of desired_byte_length bytes between each
+// subsequent hit after VarInt encoding.
+std::vector<EmbeddingHit> CreateEmbeddingHits(int num_hits,
+ int desired_byte_length);
+
} // namespace lib
} // namespace icing
diff --git a/icing/util/document-validator.cc b/icing/util/document-validator.cc
index e0880ea..bc75334 100644
--- a/icing/util/document-validator.cc
+++ b/icing/util/document-validator.cc
@@ -15,13 +15,23 @@
#include "icing/util/document-validator.h"
#include <cstdint>
+#include <string>
+#include <string_view>
+#include <unordered_map>
#include <unordered_set>
+#include <utility>
#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/legacy/core/icing-string-util.h"
#include "icing/proto/document.pb.h"
#include "icing/proto/schema.pb.h"
+#include "icing/schema/schema-store.h"
#include "icing/schema/schema-util.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/util/logging.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -123,6 +133,18 @@ libtextclassifier3::Status DocumentValidator::Validate(
} else if (property_config.data_type() ==
PropertyConfigProto::DataType::DOCUMENT) {
value_size = property.document_values_size();
+ } else if (property_config.data_type() ==
+ PropertyConfigProto::DataType::VECTOR) {
+ value_size = property.vector_values_size();
+ for (const PropertyProto::VectorProto& vector_value :
+ property.vector_values()) {
+ if (vector_value.values_size() == 0) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "Property '%s' contains empty vectors for key: (%s, %s).",
+ property.name().c_str(), document.namespace_().c_str(),
+ document.uri().c_str()));
+ }
+ }
}
if (property_config.cardinality() ==
diff --git a/icing/util/document-validator_test.cc b/icing/util/document-validator_test.cc
index 9d10b36..2c366fd 100644
--- a/icing/util/document-validator_test.cc
+++ b/icing/util/document-validator_test.cc
@@ -15,7 +15,11 @@
#include "icing/util/document-validator.h"
#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
@@ -42,6 +46,7 @@ constexpr char kPropertySubject[] = "subject";
constexpr char kPropertyText[] = "text";
constexpr char kPropertyRecipients[] = "recipients";
constexpr char kPropertyNote[] = "note";
+constexpr char kPropertyNoteEmbedding[] = "noteEmbedding";
// type and property names of Conversation
constexpr char kTypeConversation[] = "Conversation";
constexpr char kTypeConversationWithEmailNote[] = "ConversationWithEmailNote";
@@ -92,6 +97,10 @@ class DocumentValidatorTest : public ::testing::Test {
.AddProperty(PropertyConfigBuilder()
.SetName(kPropertyNote)
.SetDataType(TYPE_STRING)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kPropertyNoteEmbedding)
+ .SetDataType(TYPE_VECTOR)
.SetCardinality(CARDINALITY_OPTIONAL)))
.AddType(
SchemaTypeConfigBuilder()
@@ -146,6 +155,12 @@ class DocumentValidatorTest : public ::testing::Test {
}
DocumentBuilder SimpleEmailWithNoteBuilder() {
+ PropertyProto::VectorProto vector;
+ vector.add_values(0.1);
+ vector.add_values(0.2);
+ vector.add_values(0.3);
+ vector.set_model_signature("my_model");
+
return DocumentBuilder()
.SetKey(kDefaultNamespace, "email_with_note/1")
.SetSchema(kTypeEmailWithNote)
@@ -153,7 +168,8 @@ class DocumentValidatorTest : public ::testing::Test {
.AddStringProperty(kPropertyText, kDefaultString)
.AddStringProperty(kPropertyRecipients, kDefaultString, kDefaultString,
kDefaultString)
- .AddStringProperty(kPropertyNote, kDefaultString);
+ .AddStringProperty(kPropertyNote, kDefaultString)
+ .AddVectorProperty(kPropertyNoteEmbedding, vector);
}
DocumentBuilder SimpleConversationBuilder() {
@@ -597,6 +613,64 @@ TEST_F(DocumentValidatorTest, NegativeDocumentTtlMsInvalid) {
HasSubstr("is negative")));
}
+TEST_F(DocumentValidatorTest, ValidateEmbeddingZeroDimensionInvalid) {
+ PropertyProto::VectorProto vector;
+ vector.set_model_signature("my_model");
+ DocumentProto email =
+ DocumentBuilder()
+ .SetKey(kDefaultNamespace, "email_with_note/1")
+ .SetSchema(kTypeEmailWithNote)
+ .AddStringProperty(kPropertySubject, kDefaultString)
+ .AddStringProperty(kPropertyText, kDefaultString)
+ .AddStringProperty(kPropertyRecipients, kDefaultString,
+ kDefaultString, kDefaultString)
+ .AddStringProperty(kPropertyNote, kDefaultString)
+ .AddVectorProperty(kPropertyNoteEmbedding, vector)
+ .Build();
+ EXPECT_THAT(document_validator_->Validate(email),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT,
+ HasSubstr("contains empty vectors")));
+}
+
+TEST_F(DocumentValidatorTest, ValidateEmbeddingEmptySignatureOk) {
+ PropertyProto::VectorProto vector;
+ vector.add_values(0.1);
+ vector.add_values(0.2);
+ vector.add_values(0.3);
+ vector.set_model_signature("");
+ DocumentProto email =
+ DocumentBuilder()
+ .SetKey(kDefaultNamespace, "email_with_note/1")
+ .SetSchema(kTypeEmailWithNote)
+ .AddStringProperty(kPropertySubject, kDefaultString)
+ .AddStringProperty(kPropertyText, kDefaultString)
+ .AddStringProperty(kPropertyRecipients, kDefaultString,
+ kDefaultString, kDefaultString)
+ .AddStringProperty(kPropertyNote, kDefaultString)
+ .AddVectorProperty(kPropertyNoteEmbedding, vector)
+ .Build();
+ ICING_EXPECT_OK(document_validator_->Validate(email));
+}
+
+TEST_F(DocumentValidatorTest, ValidateEmbeddingNoSignatureOk) {
+ PropertyProto::VectorProto vector;
+ vector.add_values(0.1);
+ vector.add_values(0.2);
+ vector.add_values(0.3);
+ DocumentProto email =
+ DocumentBuilder()
+ .SetKey(kDefaultNamespace, "email_with_note/1")
+ .SetSchema(kTypeEmailWithNote)
+ .AddStringProperty(kPropertySubject, kDefaultString)
+ .AddStringProperty(kPropertyText, kDefaultString)
+ .AddStringProperty(kPropertyRecipients, kDefaultString,
+ kDefaultString, kDefaultString)
+ .AddStringProperty(kPropertyNote, kDefaultString)
+ .AddVectorProperty(kPropertyNoteEmbedding, vector)
+ .Build();
+ ICING_EXPECT_OK(document_validator_->Validate(email));
+}
+
} // namespace
} // namespace lib
diff --git a/icing/util/embedding-util.h b/icing/util/embedding-util.h
new file mode 100644
index 0000000..5026051
--- /dev/null
+++ b/icing/util/embedding-util.h
@@ -0,0 +1,49 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_UTIL_EMBEDDING_UTIL_H_
+#define ICING_UTIL_EMBEDDING_UTIL_H_
+
+#include <string_view>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/proto/search.pb.h"
+
+namespace icing {
+namespace lib {
+
+namespace embedding_util {
+
+inline libtextclassifier3::StatusOr<
+ SearchSpecProto::EmbeddingQueryMetricType::Code>
+GetEmbeddingQueryMetricTypeFromName(std::string_view metric_name) {
+ if (metric_name == "COSINE") {
+ return SearchSpecProto::EmbeddingQueryMetricType::COSINE;
+ } else if (metric_name == "DOT_PRODUCT") {
+ return SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT;
+ } else if (metric_name == "EUCLIDEAN") {
+ return SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN;
+ }
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat("Unknown metric type: ", metric_name));
+}
+
+} // namespace embedding_util
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_UTIL_EMBEDDING_UTIL_H_
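
A brief usage sketch, not part of the commit, of the helper above; it assumes the call sits inside a function returning a libtextclassifier3::Status, as the ICING_ASSIGN_OR_RETURN macro used throughout Icing requires:

    ICING_ASSIGN_OR_RETURN(
        SearchSpecProto::EmbeddingQueryMetricType::Code metric,
        embedding_util::GetEmbeddingQueryMetricTypeFromName("COSINE"));
    // metric == SearchSpecProto::EmbeddingQueryMetricType::COSINE.
    // Any other name, e.g. "MANHATTAN", produces an INVALID_ARGUMENT error.
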
diff --git a/icing/util/tokenized-document.cc b/icing/util/tokenized-document.cc
index 19aaddf..e10fe25 100644
--- a/icing/util/tokenized-document.cc
+++ b/icing/util/tokenized-document.cc
@@ -14,16 +14,18 @@
#include "icing/util/tokenized-document.h"
-#include <string>
+#include <memory>
#include <string_view>
+#include <utility>
#include <vector>
-#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/proto/document.pb.h"
#include "icing/schema/joinable-property.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
#include "icing/tokenization/language-segmenter.h"
+#include "icing/tokenization/token.h"
#include "icing/tokenization/tokenizer-factory.h"
#include "icing/tokenization/tokenizer.h"
#include "icing/util/document-validator.h"
@@ -85,6 +87,7 @@ TokenizedDocument::Create(const SchemaStore* schema_store,
return TokenizedDocument(std::move(document),
std::move(tokenized_string_sections),
std::move(section_group.integer_sections),
+ std::move(section_group.vector_sections),
std::move(joinable_property_group));
}
diff --git a/icing/util/tokenized-document.h b/icing/util/tokenized-document.h
index 7cc34e3..0337083 100644
--- a/icing/util/tokenized-document.h
+++ b/icing/util/tokenized-document.h
@@ -16,7 +16,8 @@
#define ICING_STORE_TOKENIZED_DOCUMENT_H_
#include <cstdint>
-#include <string>
+#include <string_view>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
@@ -63,6 +64,11 @@ class TokenizedDocument {
return integer_sections_;
}
+ const std::vector<Section<PropertyProto::VectorProto>>& vector_sections()
+ const {
+ return vector_sections_;
+ }
+
const std::vector<JoinableProperty<std::string_view>>&
qualified_id_join_properties() const {
return joinable_property_group_.qualified_id_properties;
@@ -74,15 +80,18 @@ class TokenizedDocument {
DocumentProto&& document,
std::vector<TokenizedSection>&& tokenized_string_sections,
std::vector<Section<int64_t>>&& integer_sections,
+ std::vector<Section<PropertyProto::VectorProto>>&& vector_sections,
JoinablePropertyGroup&& joinable_property_group)
: document_(std::move(document)),
tokenized_string_sections_(std::move(tokenized_string_sections)),
integer_sections_(std::move(integer_sections)),
+ vector_sections_(std::move(vector_sections)),
joinable_property_group_(std::move(joinable_property_group)) {}
DocumentProto document_;
std::vector<TokenizedSection> tokenized_string_sections_;
std::vector<Section<int64_t>> integer_sections_;
+ std::vector<Section<PropertyProto::VectorProto>> vector_sections_;
JoinablePropertyGroup joinable_property_group_;
};
diff --git a/icing/util/tokenized-document_test.cc b/icing/util/tokenized-document_test.cc
index 7c97776..ab7f4b9 100644
--- a/icing/util/tokenized-document_test.cc
+++ b/icing/util/tokenized-document_test.cc
@@ -16,6 +16,8 @@
#include <memory>
#include <string>
+#include <string_view>
+#include <utility>
#include <vector>
#include "gmock/gmock.h"
@@ -25,7 +27,6 @@
#include "icing/portable/platform.h"
#include "icing/proto/document.pb.h"
#include "icing/proto/schema.pb.h"
-#include "icing/proto/term.pb.h"
#include "icing/schema-builder.h"
#include "icing/schema/joinable-property.h"
#include "icing/schema/schema-store.h"
@@ -59,13 +60,19 @@ static constexpr std::string_view kIndexableIntegerProperty1 =
"indexableInteger1";
static constexpr std::string_view kIndexableIntegerProperty2 =
"indexableInteger2";
+static constexpr std::string_view kIndexableVectorProperty1 =
+ "indexableVector1";
+static constexpr std::string_view kIndexableVectorProperty2 =
+ "indexableVector2";
static constexpr std::string_view kStringExactProperty = "stringExact";
static constexpr std::string_view kStringPrefixProperty = "stringPrefix";
static constexpr SectionId kIndexableInteger1SectionId = 0;
static constexpr SectionId kIndexableInteger2SectionId = 1;
-static constexpr SectionId kStringExactSectionId = 2;
-static constexpr SectionId kStringPrefixSectionId = 3;
+static constexpr SectionId kIndexableVector1SectionId = 2;
+static constexpr SectionId kIndexableVector2SectionId = 3;
+static constexpr SectionId kStringExactSectionId = 4;
+static constexpr SectionId kStringPrefixSectionId = 5;
// Joinable properties and joinable property id. Joinable property id is
// determined by the lexicographical order of joinable property path.
@@ -77,19 +84,33 @@ static constexpr JoinablePropertyId kQualifiedId2JoinablePropertyId = 1;
const SectionMetadata kIndexableInteger1SectionMetadata(
kIndexableInteger1SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
- NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty1));
+ NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN,
+ std::string(kIndexableIntegerProperty1));
const SectionMetadata kIndexableInteger2SectionMetadata(
kIndexableInteger2SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
- NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty2));
+ NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN,
+ std::string(kIndexableIntegerProperty2));
+
+const SectionMetadata kIndexableVector1SectionMetadata(
+ kIndexableVector1SectionId, TYPE_VECTOR, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
+ NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_LINEAR_SEARCH,
+ std::string(kIndexableVectorProperty1));
+
+const SectionMetadata kIndexableVector2SectionMetadata(
+ kIndexableVector2SectionId, TYPE_VECTOR, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
+ NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_LINEAR_SEARCH,
+ std::string(kIndexableVectorProperty2));
const SectionMetadata kStringExactSectionMetadata(
kStringExactSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT,
- NUMERIC_MATCH_UNKNOWN, std::string(kStringExactProperty));
+ NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN,
+ std::string(kStringExactProperty));
const SectionMetadata kStringPrefixSectionMetadata(
kStringPrefixSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_PREFIX,
- NUMERIC_MATCH_UNKNOWN, std::string(kStringPrefixProperty));
+ NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN,
+ std::string(kStringPrefixProperty));
const JoinablePropertyMetadata kQualifiedId1JoinablePropertyMetadata(
kQualifiedId1JoinablePropertyId, TYPE_STRING,
@@ -102,6 +123,7 @@ const JoinablePropertyMetadata kQualifiedId2JoinablePropertyMetadata(
// Other non-indexable/joinable properties.
constexpr std::string_view kUnindexedStringProperty = "unindexedString";
constexpr std::string_view kUnindexedIntegerProperty = "unindexedInteger";
+constexpr std::string_view kUnindexedVectorProperty = "unindexedVector";
class TokenizedDocumentTest : public ::testing::Test {
protected:
@@ -140,6 +162,10 @@ class TokenizedDocumentTest : public ::testing::Test {
.SetDataType(TYPE_INT64)
.SetCardinality(CARDINALITY_OPTIONAL))
.AddProperty(PropertyConfigBuilder()
+ .SetName(kUnindexedVectorProperty)
+ .SetDataType(TYPE_VECTOR)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
.SetName(kIndexableIntegerProperty1)
.SetDataTypeInt64(NUMERIC_MATCH_RANGE)
.SetCardinality(CARDINALITY_REPEATED))
@@ -147,6 +173,16 @@ class TokenizedDocumentTest : public ::testing::Test {
.SetName(kIndexableIntegerProperty2)
.SetDataTypeInt64(NUMERIC_MATCH_RANGE)
.SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName(kIndexableVectorProperty1)
+ .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName(kIndexableVectorProperty2)
+ .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL))
.AddProperty(PropertyConfigBuilder()
.SetName(kStringExactProperty)
.SetDataTypeString(TERM_MATCH_EXACT,
@@ -196,6 +232,16 @@ class TokenizedDocumentTest : public ::testing::Test {
};
TEST_F(TokenizedDocumentTest, CreateAll) {
+ PropertyProto::VectorProto vector1;
+ vector1.set_model_signature("my_model1");
+ vector1.add_values(1.0f);
+ vector1.add_values(2.0f);
+ PropertyProto::VectorProto vector2;
+ vector2.set_model_signature("my_model2");
+ vector2.add_values(-1.0f);
+ vector2.add_values(-2.0f);
+ vector2.add_values(-3.0f);
+
DocumentProto document =
DocumentBuilder()
.SetKey("icing", "fake_type/1")
@@ -208,6 +254,10 @@ TEST_F(TokenizedDocumentTest, CreateAll) {
.AddInt64Property(std::string(kUnindexedIntegerProperty), 789)
.AddInt64Property(std::string(kIndexableIntegerProperty1), 1, 2, 3)
.AddInt64Property(std::string(kIndexableIntegerProperty2), 456)
+ .AddVectorProperty(std::string(kUnindexedVectorProperty), vector1)
+ .AddVectorProperty(std::string(kIndexableVectorProperty1), vector1,
+ vector2)
+ .AddVectorProperty(std::string(kIndexableVectorProperty2), vector1)
.AddStringProperty(std::string(kQualifiedId1), "pkg$db/ns#uri1")
.AddStringProperty(std::string(kQualifiedId2), "pkg$db/ns#uri2")
.Build();
@@ -244,6 +294,17 @@ TEST_F(TokenizedDocumentTest, CreateAll) {
EXPECT_THAT(tokenized_document.integer_sections().at(1).content,
ElementsAre(456));
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), SizeIs(2));
+ EXPECT_THAT(tokenized_document.vector_sections().at(0).metadata,
+ Eq(kIndexableVector1SectionMetadata));
+ EXPECT_THAT(tokenized_document.vector_sections().at(0).content,
+ ElementsAre(EqualsProto(vector1), EqualsProto(vector2)));
+ EXPECT_THAT(tokenized_document.vector_sections().at(1).metadata,
+ Eq(kIndexableVector2SectionMetadata));
+ EXPECT_THAT(tokenized_document.vector_sections().at(1).content,
+ ElementsAre(EqualsProto(vector1)));
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), SizeIs(2));
EXPECT_THAT(tokenized_document.qualified_id_join_properties().at(0).metadata,
@@ -278,6 +339,9 @@ TEST_F(TokenizedDocumentTest, CreateNoIndexableIntegerProperties) {
// integer sections
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
}
@@ -314,6 +378,9 @@ TEST_F(TokenizedDocumentTest, CreateMultipleIndexableIntegerProperties) {
EXPECT_THAT(tokenized_document.integer_sections().at(1).content,
ElementsAre(456));
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
}
@@ -341,6 +408,9 @@ TEST_F(TokenizedDocumentTest, CreateNoIndexableStringProperties) {
// integer sections
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
}
@@ -381,6 +451,92 @@ TEST_F(TokenizedDocumentTest, CreateMultipleIndexableStringProperties) {
// integer sections
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
+ // Qualified id join properties
+ EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
+}
+
+TEST_F(TokenizedDocumentTest, CreateNoIndexableVectorProperties) {
+ PropertyProto::VectorProto vector;
+ vector.set_model_signature("my_model");
+ vector.add_values(1.0f);
+
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddVectorProperty(std::string(kUnindexedVectorProperty), vector)
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+
+ EXPECT_THAT(tokenized_document.document(), EqualsProto(document));
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0));
+
+ // string sections
+ EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty());
+
+ // integer sections
+ EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
+ // Qualified id join properties
+ EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
+}
+
+TEST_F(TokenizedDocumentTest, CreateMultipleIndexableVectorProperties) {
+ PropertyProto::VectorProto vector1;
+ vector1.set_model_signature("my_model1");
+ vector1.add_values(1.0f);
+ vector1.add_values(2.0f);
+ PropertyProto::VectorProto vector2;
+ vector2.set_model_signature("my_model2");
+ vector2.add_values(-1.0f);
+ vector2.add_values(-2.0f);
+ vector2.add_values(-3.0f);
+
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddVectorProperty(std::string(kUnindexedVectorProperty), vector1)
+ .AddVectorProperty(std::string(kIndexableVectorProperty1), vector1,
+ vector2)
+ .AddVectorProperty(std::string(kIndexableVectorProperty2), vector1)
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+
+ EXPECT_THAT(tokenized_document.document(), EqualsProto(document));
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0));
+
+ // string sections
+ EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty());
+
+ // integer sections
+ EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), SizeIs(2));
+ EXPECT_THAT(tokenized_document.vector_sections().at(0).metadata,
+ Eq(kIndexableVector1SectionMetadata));
+ EXPECT_THAT(tokenized_document.vector_sections().at(0).content,
+ ElementsAre(EqualsProto(vector1), EqualsProto(vector2)));
+ EXPECT_THAT(tokenized_document.vector_sections().at(1).metadata,
+ Eq(kIndexableVector2SectionMetadata));
+ EXPECT_THAT(tokenized_document.vector_sections().at(1).content,
+ ElementsAre(EqualsProto(vector1)));
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
}
@@ -408,6 +564,9 @@ TEST_F(TokenizedDocumentTest, CreateNoJoinQualifiedIdProperties) {
// integer sections
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
}
@@ -437,6 +596,9 @@ TEST_F(TokenizedDocumentTest, CreateMultipleJoinQualifiedIdProperties) {
// integer sections
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), SizeIs(2));
EXPECT_THAT(tokenized_document.qualified_id_join_properties().at(0).metadata,
diff --git a/proto/icing/proto/document.proto b/proto/icing/proto/document.proto
index 1a501e7..919769e 100644
--- a/proto/icing/proto/document.proto
+++ b/proto/icing/proto/document.proto
@@ -78,7 +78,7 @@ message DocumentProto {
}
// Holds a property field of the Document.
-// Next tag: 8
+// Next tag: 9
message PropertyProto {
// Name of the property.
// See icing.lib.PropertyConfigProto.property_name for details.
@@ -92,6 +92,17 @@ message PropertyProto {
repeated bool boolean_values = 5;
repeated bytes bytes_values = 6;
repeated DocumentProto document_values = 7;
+
+ message VectorProto {
+ // The values of the vector.
+ repeated float values = 1 [packed = true];
+  // The model signature of the vector: an arbitrary string identifying the
+  // model that produced the vector, so that embedding searches can be
+  // restricted to vectors with a matching signature.
+  // E.g. "universal-sentence-encoder_v0"
+ optional string model_signature = 2;
+ }
+ repeated VectorProto vector_values = 8;
}
// Result of a call to IcingSearchEngine.Put
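
The new vector_values field gives PropertyProto a first-class carrier for
embedding vectors. With the protoc-generated C++ accessors for the fields
defined above, populating one directly would look roughly like the sketch
below (the property name is illustrative, and the name accessor is assumed
from the existing PropertyProto field; DocumentBuilder::AddVectorProperty,
used in the tests earlier in this change, wraps the same accessors):

  PropertyProto property;
  property.set_name("titleEmbedding");  // illustrative property name

  // Each VectorProto carries the raw float values plus an optional model
  // signature used to scope embedding searches.
  PropertyProto::VectorProto* vector = property.add_vector_values();
  vector->set_model_signature("universal-sentence-encoder_v0");
  vector->add_values(0.1f);
  vector->add_values(0.2f);
  vector->add_values(0.3f);
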
diff --git a/proto/icing/proto/logging.proto b/proto/icing/proto/logging.proto
index 4854521..46e988e 100644
--- a/proto/icing/proto/logging.proto
+++ b/proto/icing/proto/logging.proto
@@ -23,7 +23,7 @@ option java_multiple_files = true;
option objc_class_prefix = "ICNG";
// Stats of the top-level function IcingSearchEngine::Initialize().
-// Next tag: 14
+// Next tag: 15
message InitializeStatsProto {
// Overall time used for the function call.
optional int32 latency_ms = 1;
@@ -121,10 +121,16 @@ message InitializeStatsProto {
// - SCHEMA_CHANGES_OUT_OF_SYNC
// - IO_ERROR
optional RecoveryCause qualified_id_join_index_restoration_cause = 13;
+
+ // Possible recovery causes for embedding index:
+ // - INCONSISTENT_WITH_GROUND_TRUTH
+ // - SCHEMA_CHANGES_OUT_OF_SYNC
+ // - IO_ERROR
+ optional RecoveryCause embedding_index_restoration_cause = 14;
}
// Stats of the top-level function IcingSearchEngine::Put().
-// Next tag: 12
+// Next tag: 13
message PutDocumentStatsProto {
// Overall time used for the function call.
optional int32 latency_ms = 1;
@@ -170,6 +176,9 @@ message PutDocumentStatsProto {
// Time used to index all metadata terms in the document, which can only be
// added by PropertyExistenceIndexingHandler currently.
optional int32 metadata_term_index_latency_ms = 11;
+
+ // Time used to index all embeddings in the document.
+ optional int32 embedding_index_latency_ms = 12;
}
// Stats of the top-level function IcingSearchEngine::Search() and
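
Both new logging fields follow the pattern of the existing per-index stats, so
reading them out only needs the generated accessors. A small sketch, assuming
the stats protos are obtained from the usual Initialize()/Put() results:

  void ReportEmbeddingStats(const InitializeStatsProto& init_stats,
                            const PutDocumentStatsProto& put_stats) {
    if (init_stats.has_embedding_index_restoration_cause()) {
      // One of INCONSISTENT_WITH_GROUND_TRUTH, SCHEMA_CHANGES_OUT_OF_SYNC or
      // IO_ERROR when the embedding index had to be rebuilt.
      InitializeStatsProto::RecoveryCause cause =
          init_stats.embedding_index_restoration_cause();
      (void)cause;  // e.g. export to a metrics backend
    }
    // Time spent indexing the document's embeddings during Put(), in ms.
    int32_t embedding_ms = put_stats.embedding_index_latency_ms();
    (void)embedding_ms;
  }
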
diff --git a/proto/icing/proto/schema.proto b/proto/icing/proto/schema.proto
index 78e1588..99439bb 100644
--- a/proto/icing/proto/schema.proto
+++ b/proto/icing/proto/schema.proto
@@ -184,6 +184,26 @@ message IntegerIndexingConfig {
optional NumericMatchType.Code numeric_match_type = 1;
}
+// Describes how a vector property should be indexed.
+// Next tag: 3
+message EmbeddingIndexingConfig {
+ // OPTIONAL: Indicates how the vector contents of this property should be
+ // matched.
+ //
+ // The default value is UNKNOWN.
+ message EmbeddingIndexingType {
+ enum Code {
+ // Contents in this property will not be indexed. Useful if the vector
+ // property type is not indexable.
+ UNKNOWN = 0;
+
+ // Contents in this property will be indexed for linear search.
+ LINEAR_SEARCH = 1;
+ }
+ }
+ optional EmbeddingIndexingType.Code embedding_indexing_type = 1;
+}
+
// Describes how a property can be used to join this document with another
// document. See JoinSpecProto (in search.proto) for more details.
// Next tag: 3
@@ -215,7 +235,7 @@ message JoinableConfig {
// Describes the schema of a single property of Documents that belong to a
// specific SchemaTypeConfigProto. These can be considered as a rich, structured
// type for each property of Documents accepted by IcingSearchEngine.
-// Next tag: 10
+// Next tag: 11
message PropertyConfigProto {
  // REQUIRED: Name that uniquely identifies a property within a Document of
// a specific SchemaTypeConfigProto.
@@ -252,6 +272,9 @@ message PropertyConfigProto {
// a hierarchical Document schema. Any property using this data_type
// MUST have a valid 'schema_type'.
DOCUMENT = 6;
+
+ // A list of floats. Vector type is used for embedding searches.
+ VECTOR = 7;
}
}
optional DataType.Code data_type = 2;
@@ -313,6 +336,10 @@ message PropertyConfigProto {
// - The property itself and any upper-level (nested doc) property should
// contain at most one element (i.e. Cardinality is OPTIONAL or REQUIRED).
optional JoinableConfig joinable_config = 8;
+
+ // OPTIONAL: Describes how vector properties should be indexed. Vector
+ // properties that do not set the indexing config will not be indexed.
+ optional EmbeddingIndexingConfig embedding_indexing_config = 10;
}
// List of all supported types constitutes the schema used by Icing.
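
Taken together, the schema-side changes mean a vector property is declared
with the new VECTOR data type and opts into indexing through
embedding_indexing_config; without that config the vectors are stored but not
searchable. A minimal sketch (type and property names are illustrative, and
the accessors are the standard protoc-generated ones for the fields above):

  SchemaProto schema;
  SchemaTypeConfigProto* type = schema.add_types();
  type->set_schema_type("Email");  // illustrative schema type

  PropertyConfigProto* prop = type->add_properties();
  prop->set_property_name("bodyEmbedding");  // illustrative property name
  prop->set_data_type(PropertyConfigProto::DataType::VECTOR);
  prop->set_cardinality(PropertyConfigProto::Cardinality::REPEATED);

  // Opt the vector property into the embedding index (linear search).
  prop->mutable_embedding_indexing_config()->set_embedding_indexing_type(
      EmbeddingIndexingConfig::EmbeddingIndexingType::LINEAR_SEARCH);
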
diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto
index 7f4fb3e..7e145ce 100644
--- a/proto/icing/proto/search.proto
+++ b/proto/icing/proto/search.proto
@@ -27,7 +27,7 @@ option java_multiple_files = true;
option objc_class_prefix = "ICNG";
// Client-supplied specifications on what documents to retrieve.
-// Next tag: 11
+// Next tag: 13
message SearchSpecProto {
// REQUIRED: The "raw" query string that users may type. For example, "cat"
  // will search for documents with the term cat in them.
@@ -112,6 +112,22 @@ message SearchSpecProto {
// (TypePropertyMask for the given schema type has empty paths field), no
// properties of that schema type will be searched.
repeated TypePropertyMask type_property_filters = 10;
+
+ // The vectors to be used in embedding queries.
+ repeated PropertyProto.VectorProto embedding_query_vectors = 11;
+
+ message EmbeddingQueryMetricType {
+ enum Code {
+ UNKNOWN = 0;
+ COSINE = 1;
+ DOT_PRODUCT = 2;
+ EUCLIDEAN = 3;
+ }
+ }
+
+ // The default metric type used to calculate the scores for embedding
+ // queries.
+ optional EmbeddingQueryMetricType.Code embedding_query_metric_type = 12;
}
// Client-supplied specifications on what to include/how to format the search
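
On the query side, the embedding vectors referenced by a search travel in
embedding_query_vectors, and embedding_query_metric_type selects the default
similarity metric used to score them. A rough sketch of populating the spec;
the query string itself must reference the supplied vectors through the
advanced query language, whose embedding syntax is not part of this proto
change, so the string below is only a placeholder:

  SearchSpecProto search_spec;

  // Query vector to score documents against.
  PropertyProto::VectorProto* query_vector =
      search_spec.add_embedding_query_vectors();
  query_vector->set_model_signature("universal-sentence-encoder_v0");
  query_vector->add_values(0.1f);
  query_vector->add_values(0.2f);
  query_vector->add_values(0.3f);

  // Default metric for embedding scores: COSINE, DOT_PRODUCT or EUCLIDEAN.
  search_spec.set_embedding_query_metric_type(
      SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT);

  // Placeholder only; the real advanced-query embedding syntax is defined by
  // the query-language changes elsewhere in this patch.
  search_spec.set_query("<embedding query over embedding_query_vectors[0]>");
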
diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt
index c22a1e8..55c4647 100644
--- a/synced_AOSP_CL_number.txt
+++ b/synced_AOSP_CL_number.txt
@@ -1 +1 @@
-set(synced_AOSP_CL_number=613977851)
+set(synced_AOSP_CL_number=616925123)