aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKrzysztof KosiƄski <krzysio@google.com>2024-03-21 15:33:55 +0000
committerAlexander Dorokhine <adorokhine@google.com>2024-03-21 16:14:28 +0000
commit15170523d0b603a3fc2729695ce4d9740ce5a85e (patch)
tree693704d89a2b49ffda814b0d81b3e9541f48a740
parent42996c97b96f0da75543f0fee670f9e8cc595744 (diff)
parent555cb6e3295cf525baf46358235389ff52c9dcc2 (diff)
downloadicing-15170523d0b603a3fc2729695ce4d9740ce5a85e.tar.gz
Merge remote-tracking branch 'aosp/upstream-master'
Change descriptions: ======================================================================== Integration test for embedding search ======================================================================== Support embedding search in the advanced scoring language ======================================================================== Support embedding search in the advanced query language ======================================================================== Add missing header inclusions. ======================================================================== Assign default values to MemoryMappedFile members to avoid temporary object error during std::move ======================================================================== Use GetEmbeddingVector in EmbeddingIndex::TransferIndex ======================================================================== Refactor SectionRestrictData ======================================================================== Add EmbeddingIndex to IcingSearchEngine and create EmbeddingIndexingHandler ======================================================================== Support optimize for embedding index ======================================================================== Implement embedding search index ======================================================================== Branch posting list accessor and serializer to store embedding hits ======================================================================== Introduce embedding hit ======================================================================== Update schema and document definition ======================================================================== Add tests verifying that Icing will correctly save updated schema description fields. ======================================================================== Allow adding duplicate hits with the same value into the posting list. 
======================================================================== Support an extension to the encoded Hit in Icing's posting list. ======================================================================== Adds description fields to SchemaTypeConfigProto and PropertyConfigProto ======================================================================== Support polymorphism in type property filters ======================================================================== Fix posting list GetMinPostingListSizeToFit size calculation bug ======================================================================== Add instructions to error message for advanced query backward compat. ======================================================================== [Trunk Stable Flags and Files Rebuild #1] Implement v2 version file in version-util ======================================================================== [Trunk Stable Flags and Files Rebuild #2] Integrate v2 version file with IcingSearchEngine ======================================================================== Remove dead code in index initialization for HasPropertyOperator after introducing v2 version file ======================================================================== [Icing version 4] Bump Icing kVersion to 4 for Android V ======================================================================== [Icing][Expand QueryStats][4/x] Add lite index hit buffer info into QueryStats ======================================================================== [Icing][Expand QueryStats][5/x] Add query processor advanced query components latencies into QueryStats::SearchStats ======================================================================== Add instructions to error message for AppSearch advanced query features backward compatibility ======================================================================== BUG: 326987971 BUG: 326656531 BUG: 329747255 BUG: 294274922 BUG: 321107391 BUG: 324908653 BUG: 326987971 Bug: 
314816301 Bug: 309826655 Bug: 305098009 Bug: 329747255 Test: unit test Change-Id: I6f35b52114181d9dfded41cfb3949f062876a2e2
-rw-r--r--METADATA2
-rw-r--r--icing/document-builder.h18
-rw-r--r--icing/file/memory-mapped-file.h15
-rw-r--r--icing/file/posting_list/flash-index-storage-header.h2
-rw-r--r--icing/file/posting_list/flash-index-storage_test.cc139
-rw-r--r--icing/file/posting_list/index-block_test.cc115
-rw-r--r--icing/file/version-util.cc263
-rw-r--r--icing/file/version-util.h188
-rw-r--r--icing/file/version-util_test.cc498
-rw-r--r--icing/icing-search-engine.cc320
-rw-r--r--icing/icing-search-engine.h30
-rw-r--r--icing/icing-search-engine_initialization_test.cc134
-rw-r--r--icing/icing-search-engine_schema_test.cc82
-rw-r--r--icing/icing-search-engine_search_test.cc357
-rw-r--r--icing/index/embed/doc-hit-info-iterator-embedding.cc164
-rw-r--r--icing/index/embed/doc-hit-info-iterator-embedding.h161
-rw-r--r--icing/index/embed/embedding-hit.h67
-rw-r--r--icing/index/embed/embedding-hit_test.cc80
-rw-r--r--icing/index/embed/embedding-index.cc440
-rw-r--r--icing/index/embed/embedding-index.h274
-rw-r--r--icing/index/embed/embedding-index_test.cc582
-rw-r--r--icing/index/embed/embedding-query-results.h72
-rw-r--r--icing/index/embed/embedding-scorer.cc95
-rw-r--r--icing/index/embed/embedding-scorer.h54
-rw-r--r--icing/index/embed/posting-list-embedding-hit-accessor.cc132
-rw-r--r--icing/index/embed/posting-list-embedding-hit-accessor.h106
-rw-r--r--icing/index/embed/posting-list-embedding-hit-accessor_test.cc387
-rw-r--r--icing/index/embed/posting-list-embedding-hit-serializer.cc647
-rw-r--r--icing/index/embed/posting-list-embedding-hit-serializer.h284
-rw-r--r--icing/index/embed/posting-list-embedding-hit-serializer_test.cc864
-rw-r--r--icing/index/embedding-indexing-handler.cc85
-rw-r--r--icing/index/embedding-indexing-handler.h70
-rw-r--r--icing/index/embedding-indexing-handler_test.cc620
-rw-r--r--icing/index/hit/hit.cc96
-rw-r--r--icing/index/hit/hit.h154
-rw-r--r--icing/index/hit/hit_test.cc156
-rw-r--r--icing/index/index.cc6
-rw-r--r--icing/index/index.h20
-rw-r--r--icing/index/index_test.cc40
-rw-r--r--icing/index/iterator/doc-hit-info-iterator-property-in-schema.h6
-rw-r--r--icing/index/iterator/doc-hit-info-iterator-section-restrict.cc34
-rw-r--r--icing/index/iterator/doc-hit-info-iterator.h6
-rw-r--r--icing/index/iterator/section-restrict-data.cc21
-rw-r--r--icing/index/iterator/section-restrict-data.h14
-rw-r--r--icing/index/lite/lite-index-header.h22
-rw-r--r--icing/index/lite/lite-index-options.cc12
-rw-r--r--icing/index/lite/lite-index-options.h4
-rw-r--r--icing/index/lite/lite-index.cc21
-rw-r--r--icing/index/lite/lite-index.h28
-rw-r--r--icing/index/lite/lite-index_test.cc313
-rw-r--r--icing/index/lite/lite-index_thread-safety_test.cc42
-rw-r--r--icing/index/lite/term-id-hit-pair.h64
-rw-r--r--icing/index/lite/term-id-hit-pair_test.cc95
-rw-r--r--icing/index/main/main-index-merger.cc10
-rw-r--r--icing/index/main/main-index-merger_test.cc38
-rw-r--r--icing/index/main/main-index.cc2
-rw-r--r--icing/index/main/main-index_test.cc60
-rw-r--r--icing/index/main/posting-list-hit-accessor_test.cc67
-rw-r--r--icing/index/main/posting-list-hit-serializer.cc254
-rw-r--r--icing/index/main/posting-list-hit-serializer.h64
-rw-r--r--icing/index/main/posting-list-hit-serializer_test.cc828
-rw-r--r--icing/query/advanced_query_parser/function.cc24
-rw-r--r--icing/query/advanced_query_parser/function.h2
-rw-r--r--icing/query/advanced_query_parser/param.h10
-rw-r--r--icing/query/advanced_query_parser/pending-value.cc26
-rw-r--r--icing/query/advanced_query_parser/pending-value.h39
-rw-r--r--icing/query/advanced_query_parser/query-visitor.cc139
-rw-r--r--icing/query/advanced_query_parser/query-visitor.h72
-rw-r--r--icing/query/advanced_query_parser/query-visitor_test.cc1491
-rw-r--r--icing/query/query-features.h8
-rw-r--r--icing/query/query-processor.cc76
-rw-r--r--icing/query/query-processor.h33
-rw-r--r--icing/query/query-processor_benchmark.cc54
-rw-r--r--icing/query/query-processor_test.cc142
-rw-r--r--icing/query/query-results.h7
-rw-r--r--icing/query/suggestion-processor.cc54
-rw-r--r--icing/query/suggestion-processor.h27
-rw-r--r--icing/query/suggestion-processor_test.cc41
-rw-r--r--icing/schema-builder.h28
-rw-r--r--icing/schema/property-util.cc10
-rw-r--r--icing/schema/property-util.h8
-rw-r--r--icing/schema/property-util_test.cc82
-rw-r--r--icing/schema/schema-store.cc2
-rw-r--r--icing/schema/schema-store_test.cc14
-rw-r--r--icing/schema/schema-util.cc15
-rw-r--r--icing/schema/schema-util_test.cc112
-rw-r--r--icing/schema/section-manager.cc14
-rw-r--r--icing/schema/section-manager_test.cc100
-rw-r--r--icing/schema/section.h17
-rw-r--r--icing/scoring/advanced_scoring/advanced-scorer.cc24
-rw-r--r--icing/scoring/advanced_scoring/advanced-scorer.h17
-rw-r--r--icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc17
-rw-r--r--icing/scoring/advanced_scoring/advanced-scorer_test.cc884
-rw-r--r--icing/scoring/advanced_scoring/score-expression.cc128
-rw-r--r--icing/scoring/advanced_scoring/score-expression.h139
-rw-r--r--icing/scoring/advanced_scoring/score-expression_test.cc37
-rw-r--r--icing/scoring/advanced_scoring/scoring-visitor.cc24
-rw-r--r--icing/scoring/advanced_scoring/scoring-visitor.h18
-rw-r--r--icing/scoring/score-and-rank_benchmark.cc47
-rw-r--r--icing/scoring/scorer-factory.cc22
-rw-r--r--icing/scoring/scorer-factory.h11
-rw-r--r--icing/scoring/scorer_test.cc259
-rw-r--r--icing/scoring/scoring-processor.cc23
-rw-r--r--icing/scoring/scoring-processor.h10
-rw-r--r--icing/scoring/scoring-processor_test.cc155
-rw-r--r--icing/testing/embedding-test-utils.h45
-rw-r--r--icing/testing/hit-test-utils.cc84
-rw-r--r--icing/testing/hit-test-utils.h29
-rw-r--r--icing/text_classifier/lib3/utils/base/logging.h5
-rw-r--r--icing/text_classifier/lib3/utils/java/jni-helper.cc10
-rw-r--r--icing/text_classifier/lib3/utils/java/jni-helper.h9
-rw-r--r--icing/util/document-validator.cc22
-rw-r--r--icing/util/document-validator_test.cc76
-rw-r--r--icing/util/embedding-util.h49
-rw-r--r--icing/util/tokenized-document.cc7
-rw-r--r--icing/util/tokenized-document.h11
-rw-r--r--icing/util/tokenized-document_test.cc176
-rw-r--r--java/src/com/google/android/icing/IcingSearchEngine.java7
-rw-r--r--java/src/com/google/android/icing/IcingSearchEngineImpl.java7
-rw-r--r--java/src/com/google/android/icing/IcingSearchEngineInterface.java2
-rw-r--r--proto/icing/proto/document.proto13
-rw-r--r--proto/icing/proto/initialize.proto53
-rw-r--r--proto/icing/proto/logging.proto36
-rw-r--r--proto/icing/proto/schema.proto45
-rw-r--r--proto/icing/proto/search.proto18
-rw-r--r--proto/icing/proto/storage.proto2
-rw-r--r--synced_AOSP_CL_number.txt2
127 files changed, 13123 insertions, 1871 deletions
diff --git a/METADATA b/METADATA
index d350608..f6ad8fa 100644
--- a/METADATA
+++ b/METADATA
@@ -12,6 +12,6 @@ third_party {
type: PIPER
value: "http://google3/third_party/icing/"
}
- last_upgrade_date { year: 2019 month: 12 day: 20 }
+ last_upgrade_date { year: 2024 month: 3 day: 18 }
license_type: NOTICE
}
diff --git a/icing/document-builder.h b/icing/document-builder.h
index 44500f9..5d6f14d 100644
--- a/icing/document-builder.h
+++ b/icing/document-builder.h
@@ -126,6 +126,13 @@ class DocumentBuilder {
return AddDocumentProperty(std::move(property_name), {document_values...});
}
+ // Takes a property name and any number of vector values.
+ template <typename... V>
+ DocumentBuilder& AddVectorProperty(std::string property_name,
+ V... vector_values) {
+ return AddVectorProperty(std::move(property_name), {vector_values...});
+ }
+
DocumentProto Build() const { return document_; }
private:
@@ -183,6 +190,17 @@ class DocumentBuilder {
}
return *this;
}
+
+ DocumentBuilder& AddVectorProperty(
+ std::string property_name,
+ std::initializer_list<PropertyProto::VectorProto> vector_values) {
+ auto property = document_.add_properties();
+ property->set_name(std::move(property_name));
+ for (PropertyProto::VectorProto vector_value : vector_values) {
+ property->mutable_vector_values()->Add(std::move(vector_value));
+ }
+ return *this;
+ }
};
} // namespace lib
diff --git a/icing/file/memory-mapped-file.h b/icing/file/memory-mapped-file.h
index 54507af..185d940 100644
--- a/icing/file/memory-mapped-file.h
+++ b/icing/file/memory-mapped-file.h
@@ -79,7 +79,6 @@
#include <algorithm>
#include <cstdint>
-#include <memory>
#include <string>
#include <string_view>
@@ -314,7 +313,7 @@ class MemoryMappedFile {
// Cached constructor params.
const Filesystem* filesystem_;
std::string file_path_;
- Strategy strategy_;
+ Strategy strategy_ = Strategy::READ_WRITE_AUTO_SYNC;
// Raw file related fields:
// - max_file_size_
@@ -327,7 +326,7 @@ class MemoryMappedFile {
//
// Note: max_file_size_ will be specified in runtime and the caller should
// make sure its value is correct and reasonable.
- int64_t max_file_size_;
+ int64_t max_file_size_ = 0;
// Cached file size to avoid calling system call too frequently. It is only
// used in GrowAndRemapIfNecessary(), the new API that handles underlying file
@@ -336,7 +335,7 @@ class MemoryMappedFile {
// Note: it is guaranteed that file_size_ is smaller or equal to the actual
// file size as long as the underlying file hasn't been truncated or deleted
// externally. See GrowFileSize() for more details.
- int64_t file_size_;
+ int64_t file_size_ = 0;
// Memory mapped related fields:
// - mmap_result_
@@ -345,22 +344,22 @@ class MemoryMappedFile {
// - mmap_size_
// Raw pointer (or error) returned by calls to mmap().
- void* mmap_result_;
+ void* mmap_result_ = nullptr;
// Offset within the file at which the current memory-mapped region starts.
- int64_t file_offset_;
+ int64_t file_offset_ = 0;
// Size that is currently memory-mapped.
// Note that the mmapped size can be larger than the underlying file size. We
// can reduce remapping by pre-mmapping a large memory and grow the file size
// later. See GrowAndRemapIfNecessary().
- int64_t mmap_size_;
+ int64_t mmap_size_ = 0;
// The difference between file_offset_ and the actual adjusted (aligned)
// offset.
// Since mmap requires the offset to be a multiple of system page size, we
// have to align file_offset_ to the last multiple of system page size.
- int64_t alignment_adjustment_;
+ int64_t alignment_adjustment_ = 0;
// E.g. system_page_size = 5, RemapImpl(/*new_file_offset=*/8, mmap_size)
//
diff --git a/icing/file/posting_list/flash-index-storage-header.h b/icing/file/posting_list/flash-index-storage-header.h
index 6bbf1ba..f7b331c 100644
--- a/icing/file/posting_list/flash-index-storage-header.h
+++ b/icing/file/posting_list/flash-index-storage-header.h
@@ -33,7 +33,7 @@ class HeaderBlock {
// The class used to access the actual header.
struct Header {
// A magic used to mark the beginning of a valid header.
- static constexpr int kMagic = 0xb0780cf4;
+ static constexpr int kMagic = 0xd1b7b293;
int magic;
int block_size;
int last_indexed_docid;
diff --git a/icing/file/posting_list/flash-index-storage_test.cc b/icing/file/posting_list/flash-index-storage_test.cc
index ef60037..203041e 100644
--- a/icing/file/posting_list/flash-index-storage_test.cc
+++ b/icing/file/posting_list/flash-index-storage_test.cc
@@ -16,9 +16,9 @@
#include <unistd.h>
-#include <algorithm>
-#include <cstdlib>
-#include <limits>
+#include <cstdint>
+#include <memory>
+#include <string>
#include <utility>
#include <vector>
@@ -27,6 +27,7 @@
#include "gtest/gtest.h"
#include "icing/file/filesystem.h"
#include "icing/file/posting_list/flash-index-storage-header.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
#include "icing/index/hit/hit.h"
#include "icing/index/main/posting-list-hit-serializer.h"
#include "icing/store/document-id.h"
@@ -213,10 +214,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemory) {
EXPECT_THAT(flash_index_storage.empty(), IsFalse());
std::vector<Hit> hits1 = {
- Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12),
- Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19),
- Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100),
- Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197)};
+ Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits1) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder1.posting_list, hit));
@@ -237,10 +242,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemory) {
EXPECT_THAT(flash_index_storage.empty(), IsFalse());
std::vector<Hit> hits2 = {
- Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12),
- Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19),
- Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100),
- Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197)};
+ Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits2) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder2.posting_list, hit));
@@ -273,10 +282,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemory) {
EXPECT_THAT(serializer_->GetHits(&posting_list_holder3.posting_list),
IsOkAndHolds(IsEmpty()));
std::vector<Hit> hits3 = {
- Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62),
- Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45),
- Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12),
- Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74)};
+ Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits3) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder3.posting_list, hit));
@@ -314,10 +327,14 @@ TEST_F(FlashIndexStorageTest, FreeListNotInMemory) {
EXPECT_THAT(flash_index_storage.empty(), IsFalse());
std::vector<Hit> hits1 = {
- Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12),
- Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19),
- Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100),
- Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197)};
+ Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits1) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder1.posting_list, hit));
@@ -338,10 +355,14 @@ TEST_F(FlashIndexStorageTest, FreeListNotInMemory) {
EXPECT_THAT(flash_index_storage.empty(), IsFalse());
std::vector<Hit> hits2 = {
- Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12),
- Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19),
- Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100),
- Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197)};
+ Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits2) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder2.posting_list, hit));
@@ -374,10 +395,14 @@ TEST_F(FlashIndexStorageTest, FreeListNotInMemory) {
EXPECT_THAT(serializer_->GetHits(&posting_list_holder3.posting_list),
IsOkAndHolds(IsEmpty()));
std::vector<Hit> hits3 = {
- Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62),
- Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45),
- Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12),
- Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74)};
+ Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits3) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder3.posting_list, hit));
@@ -417,10 +442,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemoryPersistence) {
EXPECT_THAT(flash_index_storage.empty(), IsFalse());
std::vector<Hit> hits1 = {
- Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12),
- Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19),
- Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100),
- Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197)};
+ Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits1) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder1.posting_list, hit));
@@ -441,10 +470,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemoryPersistence) {
EXPECT_THAT(flash_index_storage.empty(), IsFalse());
std::vector<Hit> hits2 = {
- Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12),
- Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19),
- Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100),
- Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197)};
+ Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits2) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder2.posting_list, hit));
@@ -492,10 +525,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemoryPersistence) {
EXPECT_THAT(serializer_->GetHits(&posting_list_holder3.posting_list),
IsOkAndHolds(IsEmpty()));
std::vector<Hit> hits3 = {
- Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62),
- Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45),
- Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12),
- Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74)};
+ Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits3) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder3.posting_list, hit));
@@ -534,10 +571,14 @@ TEST_F(FlashIndexStorageTest, DifferentSizedPostingLists) {
EXPECT_THAT(flash_index_storage.empty(), IsFalse());
std::vector<Hit> hits1 = {
- Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12),
- Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19),
- Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100),
- Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197)};
+ Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits1) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder1.posting_list, hit));
@@ -561,10 +602,14 @@ TEST_F(FlashIndexStorageTest, DifferentSizedPostingLists) {
EXPECT_THAT(flash_index_storage.empty(), IsFalse());
std::vector<Hit> hits2 = {
- Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12),
- Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19),
- Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100),
- Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197)};
+ Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)};
for (const Hit& hit : hits2) {
ICING_ASSERT_OK(
serializer_->PrependHit(&posting_list_holder2.posting_list, hit));
diff --git a/icing/file/posting_list/index-block_test.cc b/icing/file/posting_list/index-block_test.cc
index ebc9ba4..d841e79 100644
--- a/icing/file/posting_list/index-block_test.cc
+++ b/icing/file/posting_list/index-block_test.cc
@@ -14,11 +14,16 @@
#include "icing/file/posting_list/index-block.h"
+#include <memory>
+#include <string>
+#include <vector>
+
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/file/filesystem.h"
-#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/file/posting_list/posting-list-common.h"
+#include "icing/index/hit/hit.h"
#include "icing/index/main/posting-list-hit-serializer.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/tmp-directory.h"
@@ -67,7 +72,7 @@ class IndexBlockTest : public ::testing::Test {
};
TEST_F(IndexBlockTest, CreateFromUninitializedRegionProducesEmptyBlock) {
- constexpr int kPostingListBytes = 20;
+ constexpr int kPostingListBytes = 24;
{
// Create an IndexBlock from this newly allocated file block.
@@ -80,7 +85,7 @@ TEST_F(IndexBlockTest, CreateFromUninitializedRegionProducesEmptyBlock) {
}
TEST_F(IndexBlockTest, SizeAccessorsWorkCorrectly) {
- constexpr int kPostingListBytes1 = 20;
+ constexpr int kPostingListBytes1 = 24;
// Create an IndexBlock from this newly allocated file block.
ICING_ASSERT_OK_AND_ASSIGN(IndexBlock block,
@@ -88,13 +93,13 @@ TEST_F(IndexBlockTest, SizeAccessorsWorkCorrectly) {
&filesystem_, serializer_.get(), sfd_->get(),
/*offset=*/0, kBlockSize, kPostingListBytes1));
EXPECT_THAT(block.posting_list_bytes(), Eq(kPostingListBytes1));
- // There should be (4096 - 12) / 20 = 204 posting lists
- // (sizeof(BlockHeader)==12). We can store a PostingListIndex of 203 in only 8
+ // There should be (4096 - 12) / 24 = 170 posting lists
+ // (sizeof(BlockHeader)==12). We can store a PostingListIndex of 170 in only 8
// bits.
- EXPECT_THAT(block.max_num_posting_lists(), Eq(204));
+ EXPECT_THAT(block.max_num_posting_lists(), Eq(170));
EXPECT_THAT(block.posting_list_index_bits(), Eq(8));
- constexpr int kPostingListBytes2 = 200;
+ constexpr int kPostingListBytes2 = 240;
// Create an IndexBlock from this newly allocated file block.
ICING_ASSERT_OK_AND_ASSIGN(
@@ -102,22 +107,27 @@ TEST_F(IndexBlockTest, SizeAccessorsWorkCorrectly) {
&filesystem_, serializer_.get(), sfd_->get(), /*offset=*/0,
kBlockSize, kPostingListBytes2));
EXPECT_THAT(block.posting_list_bytes(), Eq(kPostingListBytes2));
- // There should be (4096 - 12) / 200 = 20 posting lists
+ // There should be (4096 - 12) / 240 = 17 posting lists
// (sizeof(BlockHeader)==12). We can store a PostingListIndex of 19 in only 5
// bits.
- EXPECT_THAT(block.max_num_posting_lists(), Eq(20));
+ EXPECT_THAT(block.max_num_posting_lists(), Eq(17));
EXPECT_THAT(block.posting_list_index_bits(), Eq(5));
}
TEST_F(IndexBlockTest, IndexBlockChangesPersistAcrossInstances) {
- constexpr int kPostingListBytes = 2000;
+ constexpr int kPostingListBytes = 2004;
std::vector<Hit> test_hits{
- Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency),
- Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency),
- Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99),
- Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17),
- Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency),
+ Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
};
PostingListIndex allocated_index;
{
@@ -158,21 +168,31 @@ TEST_F(IndexBlockTest, IndexBlockChangesPersistAcrossInstances) {
}
TEST_F(IndexBlockTest, IndexBlockMultiplePostingLists) {
- constexpr int kPostingListBytes = 2000;
+ constexpr int kPostingListBytes = 2004;
std::vector<Hit> hits_in_posting_list1{
- Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency),
- Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency),
- Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99),
- Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17),
- Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency),
+ Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
};
std::vector<Hit> hits_in_posting_list2{
- Hit(/*section_id=*/12, /*document_id=*/220, /*term_frequency=*/88),
- Hit(/*section_id=*/17, /*document_id=*/265, Hit::kDefaultTermFrequency),
- Hit(/*section_id=*/0, /*document_id=*/287, /*term_frequency=*/2),
- Hit(/*section_id=*/11, /*document_id=*/306, /*term_frequency=*/12),
- Hit(/*section_id=*/10, /*document_id=*/306, Hit::kDefaultTermFrequency),
+ Hit(/*section_id=*/12, /*document_id=*/220, /*term_frequency=*/88,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/17, /*document_id=*/265, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/0, /*document_id=*/287, /*term_frequency=*/2,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/11, /*document_id=*/306, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/10, /*document_id=*/306, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
};
PostingListIndex allocated_index_1;
PostingListIndex allocated_index_2;
@@ -242,7 +262,7 @@ TEST_F(IndexBlockTest, IndexBlockMultiplePostingLists) {
}
TEST_F(IndexBlockTest, IndexBlockReallocatingPostingLists) {
- constexpr int kPostingListBytes = 2000;
+ constexpr int kPostingListBytes = 2004;
// Create an IndexBlock from this newly allocated file block.
ICING_ASSERT_OK_AND_ASSIGN(IndexBlock block,
@@ -252,11 +272,16 @@ TEST_F(IndexBlockTest, IndexBlockReallocatingPostingLists) {
// Add hits to the first posting list.
std::vector<Hit> hits_in_posting_list1{
- Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency),
- Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency),
- Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99),
- Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17),
- Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency),
+ Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
};
ICING_ASSERT_OK_AND_ASSIGN(IndexBlock::PostingListAndBlockInfo alloc_info_1,
block.AllocatePostingList());
@@ -270,11 +295,16 @@ TEST_F(IndexBlockTest, IndexBlockReallocatingPostingLists) {
// Add hits to the second posting list.
std::vector<Hit> hits_in_posting_list2{
- Hit(/*section_id=*/12, /*document_id=*/220, /*term_frequency=*/88),
- Hit(/*section_id=*/17, /*document_id=*/265, Hit::kDefaultTermFrequency),
- Hit(/*section_id=*/0, /*document_id=*/287, /*term_frequency=*/2),
- Hit(/*section_id=*/11, /*document_id=*/306, /*term_frequency=*/12),
- Hit(/*section_id=*/10, /*document_id=*/306, Hit::kDefaultTermFrequency),
+ Hit(/*section_id=*/12, /*document_id=*/220, /*term_frequency=*/88,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/17, /*document_id=*/265, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/0, /*document_id=*/287, /*term_frequency=*/2,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/11, /*document_id=*/306, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/10, /*document_id=*/306, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
};
ICING_ASSERT_OK_AND_ASSIGN(IndexBlock::PostingListAndBlockInfo alloc_info_2,
block.AllocatePostingList());
@@ -296,9 +326,12 @@ TEST_F(IndexBlockTest, IndexBlockReallocatingPostingLists) {
EXPECT_THAT(block.HasFreePostingLists(), IsOkAndHolds(IsTrue()));
std::vector<Hit> hits_in_posting_list3{
- Hit(/*section_id=*/12, /*document_id=*/0, /*term_frequency=*/88),
- Hit(/*section_id=*/17, /*document_id=*/1, Hit::kDefaultTermFrequency),
- Hit(/*section_id=*/0, /*document_id=*/2, /*term_frequency=*/2),
+ Hit(/*section_id=*/12, /*document_id=*/0, /*term_frequency=*/88,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/17, /*document_id=*/1, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
+ Hit(/*section_id=*/0, /*document_id=*/2, /*term_frequency=*/2,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false),
};
ICING_ASSERT_OK_AND_ASSIGN(IndexBlock::PostingListAndBlockInfo alloc_info_3,
block.AllocatePostingList());
@@ -317,7 +350,7 @@ TEST_F(IndexBlockTest, IndexBlockReallocatingPostingLists) {
}
TEST_F(IndexBlockTest, IndexBlockNextBlockIndex) {
- constexpr int kPostingListBytes = 2000;
+ constexpr int kPostingListBytes = 2004;
constexpr int kSomeBlockIndex = 22;
{
diff --git a/icing/file/version-util.cc b/icing/file/version-util.cc
index dd233e0..e750b0c 100644
--- a/icing/file/version-util.cc
+++ b/icing/file/version-util.cc
@@ -15,36 +15,47 @@
#include "icing/file/version-util.h"
#include <cstdint>
+#include <memory>
#include <string>
+#include <string_view>
+#include <unordered_set>
#include <utility>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/file/file-backed-proto.h"
#include "icing/file/filesystem.h"
#include "icing/index/index.h"
+#include "icing/proto/initialize.pb.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
namespace version_util {
-libtextclassifier3::StatusOr<VersionInfo> ReadVersion(
- const Filesystem& filesystem, const std::string& version_file_path,
+namespace {
+
+libtextclassifier3::StatusOr<VersionInfo> ReadV1VersionInfo(
+ const Filesystem& filesystem, const std::string& version_file_dir,
const std::string& index_base_dir) {
// 1. Read the version info.
+ const std::string v1_version_filepath =
+ MakeVersionFilePath(version_file_dir, kVersionFilenameV1);
VersionInfo existing_version_info(-1, -1);
- if (filesystem.FileExists(version_file_path.c_str()) &&
- !filesystem.PRead(version_file_path.c_str(), &existing_version_info,
+ if (filesystem.FileExists(v1_version_filepath.c_str()) &&
+ !filesystem.PRead(v1_version_filepath.c_str(), &existing_version_info,
sizeof(VersionInfo), /*offset=*/0)) {
- return absl_ports::InternalError("Fail to read version");
+ return absl_ports::InternalError("Failed to read v1 version file");
}
// 2. Check the Index magic to see if we're actually on version 0.
- libtextclassifier3::StatusOr<int> existing_flash_index_magic_or =
+ libtextclassifier3::StatusOr<int> existing_flash_index_magic =
Index::ReadFlashIndexMagic(&filesystem, index_base_dir);
- if (!existing_flash_index_magic_or.ok()) {
- if (absl_ports::IsNotFound(existing_flash_index_magic_or.status())) {
+ if (!existing_flash_index_magic.ok()) {
+ if (absl_ports::IsNotFound(existing_flash_index_magic.status())) {
// Flash index magic doesn't exist. In this case, we're unable to
// determine the version change state correctly (regardless of the
// existence of the version file), so invalidate VersionInfo by setting
@@ -53,9 +64,9 @@ libtextclassifier3::StatusOr<VersionInfo> ReadVersion(
return existing_version_info;
}
// Real error.
- return std::move(existing_flash_index_magic_or).status();
+ return std::move(existing_flash_index_magic).status();
}
- if (existing_flash_index_magic_or.ValueOrDie() ==
+ if (existing_flash_index_magic.ValueOrDie() ==
kVersionZeroFlashIndexMagic) {
existing_version_info.version = 0;
if (existing_version_info.max_version == -1) {
@@ -66,15 +77,124 @@ libtextclassifier3::StatusOr<VersionInfo> ReadVersion(
return existing_version_info;
}
-libtextclassifier3::Status WriteVersion(const Filesystem& filesystem,
- const std::string& version_file_path,
- const VersionInfo& version_info) {
- ScopedFd scoped_fd(filesystem.OpenForWrite(version_file_path.c_str()));
+libtextclassifier3::StatusOr<IcingSearchEngineVersionProto> ReadV2VersionInfo(
+ const Filesystem& filesystem, const std::string& version_file_dir) {
+ // Read the v2 version file. V2 version file stores the
+ // IcingSearchEngineVersionProto as a file-backed proto.
+ const std::string v2_version_filepath =
+ MakeVersionFilePath(version_file_dir, kVersionFilenameV2);
+ FileBackedProto<IcingSearchEngineVersionProto> v2_version_file(
+ filesystem, v2_version_filepath);
+ ICING_ASSIGN_OR_RETURN(const IcingSearchEngineVersionProto* v2_version_proto,
+ v2_version_file.Read());
+
+ return *v2_version_proto;
+}
+
+} // namespace
+
+libtextclassifier3::StatusOr<IcingSearchEngineVersionProto> ReadVersion(
+ const Filesystem& filesystem, const std::string& version_file_dir,
+ const std::string& index_base_dir) {
+ // 1. Read the v1 version file
+ ICING_ASSIGN_OR_RETURN(
+ VersionInfo v1_version_info,
+ ReadV1VersionInfo(filesystem, version_file_dir, index_base_dir));
+ if (!v1_version_info.IsValid()) {
+ // This happens if IcingLib's state is invalid (e.g. flash index header file
+ // is missing). Return the invalid version numbers in this case.
+ IcingSearchEngineVersionProto version_proto;
+ version_proto.set_version(v1_version_info.version);
+ version_proto.set_max_version(v1_version_info.max_version);
+ return version_proto;
+ }
+
+ // 2. Read the v2 version file
+ auto v2_version_proto = ReadV2VersionInfo(filesystem, version_file_dir);
+ if (!v2_version_proto.ok()) {
+ if (!absl_ports::IsNotFound(v2_version_proto.status())) {
+ // Real error.
+ return std::move(v2_version_proto).status();
+ }
+ // The v2 version file has not been written
+ IcingSearchEngineVersionProto version_proto;
+ if (v1_version_info.version < kFirstV2Version) {
+ // There are two scenarios for this case:
+ // 1. It's the first time that we're upgrading from a lower version to a
+ // version >= kFirstV2Version.
+ // - It's expected that the v2 version file has not been written yet in
+ // this case and we return the v1 version numbers instead.
+ // 2. We're rolling forward from a version < kFirstV2Version, after
+ // rolling back from a previous version >= kFirstV2Version, and for
+ // some unknown reason we lost the v2 version file in the previous
+ // version.
+ // - e.g. version #4 -> version #1 -> version #4, but we lost the v2
+ // file during version #1.
+ // - This is a rollforward case, but it's still fine to return the v1
+ // version number here as ShouldRebuildDerivedFiles can handle
+ // rollforwards correctly.
+ version_proto.set_version(v1_version_info.version);
+ version_proto.set_max_version(v1_version_info.max_version);
+ } else {
+ // Something weird has happened. During last initialization we were
+ // already on a version >= kFirstV2Version, so the v2 version file
+ // should have been written.
+ // Return an invalid version number in this case and trigger rebuilding
+ // everything.
+ version_proto.set_version(-1);
+ version_proto.set_max_version(v1_version_info.max_version);
+ }
+ return version_proto;
+ }
+
+ // 3. Check if versions match. If not, it means that we're rolling forward
+ // from a version < kFirstV2Version. In order to trigger rebuilding
+ // everything, we return an invalid version number in this case.
+ IcingSearchEngineVersionProto v2_version_proto_value =
+ std::move(v2_version_proto).ValueOrDie();
+ if (v1_version_info.version != v2_version_proto_value.version()) {
+ v2_version_proto_value.set_version(-1);
+ v2_version_proto_value.mutable_enabled_features()->Clear();
+ }
+
+ return v2_version_proto_value;
+}
+
+libtextclassifier3::Status WriteV1Version(const Filesystem& filesystem,
+ const std::string& version_file_dir,
+ const VersionInfo& version_info) {
+ ScopedFd scoped_fd(filesystem.OpenForWrite(
+ MakeVersionFilePath(version_file_dir, kVersionFilenameV1).c_str()));
if (!scoped_fd.is_valid() ||
!filesystem.PWrite(scoped_fd.get(), /*offset=*/0, &version_info,
sizeof(VersionInfo)) ||
!filesystem.DataSync(scoped_fd.get())) {
- return absl_ports::InternalError("Fail to write version");
+ return absl_ports::InternalError("Failed to write v1 version file");
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status WriteV2Version(
+ const Filesystem& filesystem, const std::string& version_file_dir,
+ std::unique_ptr<IcingSearchEngineVersionProto> version_proto) {
+ FileBackedProto<IcingSearchEngineVersionProto> v2_version_file(
+ filesystem, MakeVersionFilePath(version_file_dir, kVersionFilenameV2));
+ libtextclassifier3::Status v2_write_status =
+ v2_version_file.Write(std::move(version_proto));
+ if (!v2_write_status.ok()) {
+ return absl_ports::InternalError(absl_ports::StrCat(
+ "Failed to write v2 version file: ", v2_write_status.error_message()));
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status DiscardVersionFiles(
+ const Filesystem& filesystem, std::string_view version_file_dir) {
+ if (!filesystem.DeleteFile(
+ MakeVersionFilePath(version_file_dir, kVersionFilenameV1).c_str()) ||
+ !filesystem.DeleteFile(
+ MakeVersionFilePath(version_file_dir, kVersionFilenameV2).c_str())) {
+ return absl_ports::InternalError("Failed to discard version files");
}
return libtextclassifier3::Status::OK;
}
@@ -102,6 +222,65 @@ StateChange GetVersionStateChange(const VersionInfo& existing_version_info,
}
}
+DerivedFilesRebuildResult CalculateRequiredDerivedFilesRebuild(
+ const IcingSearchEngineVersionProto& prev_version_proto,
+ const IcingSearchEngineVersionProto& curr_version_proto) {
+ // 1. Do version check using version and max_version numbers
+ if (ShouldRebuildDerivedFiles(GetVersionInfoFromProto(prev_version_proto),
+ curr_version_proto.version())) {
+ return DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/true,
+ /*needs_schema_store_derived_files_rebuild=*/true,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/true,
+ /*needs_qualified_id_join_index_rebuild=*/true);
+ }
+
+ // 2. Compare the previous enabled features with the current enabled features
+ // and rebuild if there are differences.
+ std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType>
+ prev_features;
+ for (const auto& feature : prev_version_proto.enabled_features()) {
+ prev_features.insert(feature.feature_type());
+ }
+ std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType>
+ curr_features;
+ for (const auto& feature : curr_version_proto.enabled_features()) {
+ curr_features.insert(feature.feature_type());
+ }
+ DerivedFilesRebuildResult result;
+ for (const auto& prev_feature : prev_features) {
+ // If there is an UNKNOWN feature in the previous feature set (note that we
+ // never use UNKNOWN when writing the version proto), it means that:
+ // - The previous version proto contains a feature enum that is only defined
+ // in a newer version.
+ // - We've now rolled back to an old version that doesn't understand this
+ // new enum value, and proto serialization defaults it to 0 (UNKNOWN).
+ // - In this case we need to rebuild everything.
+ if (prev_feature == IcingSearchEngineFeatureInfoProto::UNKNOWN) {
+ return DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/true,
+ /*needs_schema_store_derived_files_rebuild=*/true,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/true,
+ /*needs_qualified_id_join_index_rebuild=*/true);
+ }
+ if (curr_features.find(prev_feature) == curr_features.end()) {
+ DerivedFilesRebuildResult required_rebuilds =
+ GetFeatureDerivedFilesRebuildResult(prev_feature);
+ result.CombineWithOtherRebuildResultOr(required_rebuilds);
+ }
+ }
+ for (const auto& curr_feature : curr_features) {
+ if (prev_features.find(curr_feature) == prev_features.end()) {
+ DerivedFilesRebuildResult required_rebuilds =
+ GetFeatureDerivedFilesRebuildResult(curr_feature);
+ result.CombineWithOtherRebuildResultOr(required_rebuilds);
+ }
+ }
+ return result;
+}
+
bool ShouldRebuildDerivedFiles(const VersionInfo& existing_version_info,
int32_t curr_version) {
StateChange state_change =
@@ -135,6 +314,10 @@ bool ShouldRebuildDerivedFiles(const VersionInfo& existing_version_info,
// version 2 -> version 3 upgrade, no need to rebuild
break;
}
+ case 3: {
+ // version 3 -> version 4 upgrade, no need to rebuild
+ break;
+ }
default:
// This should not happen. Rebuild anyway if unsure.
should_rebuild |= true;
@@ -144,6 +327,56 @@ bool ShouldRebuildDerivedFiles(const VersionInfo& existing_version_info,
return should_rebuild;
}
+DerivedFilesRebuildResult GetFeatureDerivedFilesRebuildResult(
+ IcingSearchEngineFeatureInfoProto::FlaggedFeatureType feature) {
+ switch (feature) {
+ case IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR: {
+ return DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/false,
+ /*needs_schema_store_derived_files_rebuild=*/false,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/false,
+ /*needs_qualified_id_join_index_rebuild=*/false);
+ }
+ case IcingSearchEngineFeatureInfoProto::UNKNOWN:
+ return DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/true,
+ /*needs_schema_store_derived_files_rebuild=*/true,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/true,
+ /*needs_qualified_id_join_index_rebuild=*/true);
+ }
+}
+
+IcingSearchEngineFeatureInfoProto GetFeatureInfoProto(
+ IcingSearchEngineFeatureInfoProto::FlaggedFeatureType feature) {
+ IcingSearchEngineFeatureInfoProto info;
+ info.set_feature_type(feature);
+
+ DerivedFilesRebuildResult result =
+ GetFeatureDerivedFilesRebuildResult(feature);
+ info.set_needs_document_store_rebuild(
+ result.needs_document_store_derived_files_rebuild);
+ info.set_needs_schema_store_rebuild(
+ result.needs_schema_store_derived_files_rebuild);
+ info.set_needs_term_index_rebuild(result.needs_term_index_rebuild);
+ info.set_needs_integer_index_rebuild(result.needs_integer_index_rebuild);
+ info.set_needs_qualified_id_join_index_rebuild(
+ result.needs_qualified_id_join_index_rebuild);
+
+ return info;
+}
+
+void AddEnabledFeatures(const IcingSearchEngineOptions& options,
+ IcingSearchEngineVersionProto* version_proto) {
+ auto* enabled_features = version_proto->mutable_enabled_features();
+ // HasPropertyOperator feature
+ if (options.build_property_existence_metadata_hits()) {
+ enabled_features->Add(GetFeatureInfoProto(
+ IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR));
+ }
+}
+
} // namespace version_util
} // namespace lib
diff --git a/icing/file/version-util.h b/icing/file/version-util.h
index b2d51df..feadaf6 100644
--- a/icing/file/version-util.h
+++ b/icing/file/version-util.h
@@ -16,11 +16,15 @@
#define ICING_FILE_VERSION_UTIL_H_
#include <cstdint>
+#include <memory>
#include <string>
+#include <string_view>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/str_cat.h"
#include "icing/file/filesystem.h"
+#include "icing/proto/initialize.pb.h"
namespace icing {
namespace lib {
@@ -32,16 +36,26 @@ namespace version_util {
// - Version 2: M-2023-09, M-2023-11, M-2024-01. Schema is compatible with v1.
// (There were no M-2023-10, M-2023-12).
// - Version 3: M-2024-02. Schema is compatible with v1 and v2.
+// - Version 4: Android V base. Schema is compatible with v1, v2 and v3.
//
+// TODO(b/314816301): Bump kVersion to 4 for Android V rollout with v2 version
+// detection
// LINT.IfChange(kVersion)
-inline static constexpr int32_t kVersion = 3;
+inline static constexpr int32_t kVersion = 4;
// LINT.ThenChange(//depot/google3/icing/schema/schema-store.cc:min_overlay_version_compatibility)
inline static constexpr int32_t kVersionOne = 1;
inline static constexpr int32_t kVersionTwo = 2;
inline static constexpr int32_t kVersionThree = 3;
+inline static constexpr int32_t kVersionFour = 4;
+
+// Version at which v2 version file is introduced.
+inline static constexpr int32_t kFirstV2Version = kVersionFour;
inline static constexpr int kVersionZeroFlashIndexMagic = 0x6dfba6ae;
+inline static constexpr std::string_view kVersionFilenameV1 = "version";
+inline static constexpr std::string_view kVersionFilenameV2 = "version2";
+
struct VersionInfo {
int32_t version;
int32_t max_version;
@@ -67,28 +81,144 @@ enum class StateChange {
kVersionZeroRollForward,
};
-// Helper method to read version info (using version file and flash index header
-// magic) from the existing data. If the state is invalid (e.g. flash index
-// header file is missing), then return an invalid VersionInfo.
+// Contains information about which derived files need to be rebuilt.
+//
+// These flags only reflect whether each component should be rebuilt, but do not
+// handle any dependencies. The caller should handle the dependencies by
+// themselves.
+// e.g. - qualified id join index depends on document store derived files, but
+// it's possible to have needs_document_store_derived_files_rebuild =
+// true and needs_qualified_id_join_index_rebuild = false.
+// - The caller should know that join index should also be rebuilt in this
+// case even though needs_qualified_id_join_index_rebuild = false.
+struct DerivedFilesRebuildResult {
+ bool needs_document_store_derived_files_rebuild = false;
+ bool needs_schema_store_derived_files_rebuild = false;
+ bool needs_term_index_rebuild = false;
+ bool needs_integer_index_rebuild = false;
+ bool needs_qualified_id_join_index_rebuild = false;
+
+ DerivedFilesRebuildResult() = default;
+
+ explicit DerivedFilesRebuildResult(
+ bool needs_document_store_derived_files_rebuild_in,
+ bool needs_schema_store_derived_files_rebuild_in,
+ bool needs_term_index_rebuild_in, bool needs_integer_index_rebuild_in,
+ bool needs_qualified_id_join_index_rebuild_in)
+ : needs_document_store_derived_files_rebuild(
+ needs_document_store_derived_files_rebuild_in),
+ needs_schema_store_derived_files_rebuild(
+ needs_schema_store_derived_files_rebuild_in),
+ needs_term_index_rebuild(needs_term_index_rebuild_in),
+ needs_integer_index_rebuild(needs_integer_index_rebuild_in),
+ needs_qualified_id_join_index_rebuild(
+ needs_qualified_id_join_index_rebuild_in) {}
+
+ bool IsRebuildNeeded() const {
+ return needs_document_store_derived_files_rebuild ||
+ needs_schema_store_derived_files_rebuild ||
+ needs_term_index_rebuild || needs_integer_index_rebuild ||
+ needs_qualified_id_join_index_rebuild;
+ }
+
+ bool operator==(const DerivedFilesRebuildResult& other) const {
+ return needs_document_store_derived_files_rebuild ==
+ other.needs_document_store_derived_files_rebuild &&
+ needs_schema_store_derived_files_rebuild ==
+ other.needs_schema_store_derived_files_rebuild &&
+ needs_term_index_rebuild == other.needs_term_index_rebuild &&
+ needs_integer_index_rebuild == other.needs_integer_index_rebuild &&
+ needs_qualified_id_join_index_rebuild ==
+ other.needs_qualified_id_join_index_rebuild;
+ }
+
+ void CombineWithOtherRebuildResultOr(const DerivedFilesRebuildResult& other) {
+ needs_document_store_derived_files_rebuild =
+ needs_document_store_derived_files_rebuild ||
+ other.needs_document_store_derived_files_rebuild;
+ needs_schema_store_derived_files_rebuild =
+ needs_schema_store_derived_files_rebuild ||
+ other.needs_schema_store_derived_files_rebuild;
+ needs_term_index_rebuild =
+ needs_term_index_rebuild || other.needs_term_index_rebuild;
+ needs_integer_index_rebuild =
+ needs_integer_index_rebuild || other.needs_integer_index_rebuild;
+ needs_qualified_id_join_index_rebuild =
+ needs_qualified_id_join_index_rebuild ||
+ other.needs_qualified_id_join_index_rebuild;
+ }
+};
+
+// There are two icing version files:
+// 1. V1 version file contains version and max_version info of the existing
+// data.
+// 2. V2 version file writes the version information using
+// FileBackedProto<IcingSearchEngineVersionProto>. This contains information
+// about the version's enabled trunk stable features in addition to the
+// version numbers written for V1.
+//
+// Both version files must be written to maintain backwards compatibility.
+inline std::string MakeVersionFilePath(std::string_view version_file_dir,
+ std::string_view version_file_name) {
+ return absl_ports::StrCat(version_file_dir, "/", version_file_name);
+}
+
+// Returns a VersionInfo from a given IcingSearchEngineVersionProto.
+inline VersionInfo GetVersionInfoFromProto(
+ const IcingSearchEngineVersionProto& version_proto) {
+ return VersionInfo(version_proto.version(), version_proto.max_version());
+}
+
+
+// Reads the IcingSearchEngineVersionProto from the version files of the
+// existing data.
+//
+// This method reads both the v1 and v2 version files, and returns the v1
+// version numbers in the absence of the v2 version file. If there is a mismatch
+// between the v1 and v2 version numbers, or if the state is invalid (e.g. flash
+// index header file is missing), then an invalid VersionInfo is returned.
//
// RETURNS:
-// - Existing data's VersionInfo on success
+// - Existing data's IcingSearchEngineVersionProto on success
// - INTERNAL_ERROR on I/O errors
-libtextclassifier3::StatusOr<VersionInfo> ReadVersion(
- const Filesystem& filesystem, const std::string& version_file_path,
+libtextclassifier3::StatusOr<IcingSearchEngineVersionProto> ReadVersion(
+ const Filesystem& filesystem, const std::string& version_file_dir,
const std::string& index_base_dir);
-// Helper method to write version file.
+// Writes the v1 version file. V1 version file is written for all versions and
+// contains only Icing's VersionInfo (version number and max_version)
//
// RETURNS:
// - OK on success
// - INTERNAL_ERROR on I/O errors
-libtextclassifier3::Status WriteVersion(const Filesystem& filesystem,
- const std::string& version_file_path,
- const VersionInfo& version_info);
+libtextclassifier3::Status WriteV1Version(const Filesystem& filesystem,
+ const std::string& version_file_dir,
+ const VersionInfo& version_info);
-// Helper method to determine the change state between the existing data version
-// and the current code version.
+// Writes the v2 version file. V2 version file writes the version information
+// using FileBackedProto<IcingSearchEngineVersionProto>.
+//
+// REQUIRES: version_proto.version >= kFirstV2Version. We implement v2 version
+// checking in kFirstV2Version, so callers will always use a version # greater
+// than this.
+//
+// RETURNS:
+// - OK on success
+// - INTERNAL_ERROR on I/O errors
+libtextclassifier3::Status WriteV2Version(
+ const Filesystem& filesystem, const std::string& version_file_dir,
+ std::unique_ptr<IcingSearchEngineVersionProto> version_proto);
+
+// Deletes Icing's version files from version_file_dir.
+//
+// Returns:
+// - OK on success
+// - INTERNAL_ERROR on I/O error
+libtextclassifier3::Status DiscardVersionFiles(
+ const Filesystem& filesystem, std::string_view version_file_dir);
+
+// Determines the change state between the existing data version and the current
+// code version.
//
// REQUIRES: curr_version > 0. We implement version checking in version 1, so
// the callers (except unit tests) will always use a version # greater than 0.
@@ -97,7 +227,19 @@ libtextclassifier3::Status WriteVersion(const Filesystem& filesystem,
StateChange GetVersionStateChange(const VersionInfo& existing_version_info,
int32_t curr_version = kVersion);
-// Helper method to determine whether Icing should rebuild all derived files.
+// Determines the derived files that need to be rebuilt between Icing's existing
+// data based on previous data version and enabled features, and the current
+// code version and enabled features.
+//
+// REQUIRES: curr_version >= kFirstV2Version. We implement v2 version checking
+// in kFirstV2Version, so callers will always use a version # greater than this.
+//
+// RETURNS: DerivedFilesRebuildResult
+DerivedFilesRebuildResult CalculateRequiredDerivedFilesRebuild(
+ const IcingSearchEngineVersionProto& prev_version_proto,
+ const IcingSearchEngineVersionProto& curr_version_proto);
+
+// Determines whether Icing should rebuild all derived files.
// Sometimes it is not required to rebuild derived files when
// roll-forward/upgrading. This function "encodes" upgrade paths and checks if
// the roll-forward/upgrading requires derived files to be rebuilt or not.
@@ -107,6 +249,24 @@ StateChange GetVersionStateChange(const VersionInfo& existing_version_info,
bool ShouldRebuildDerivedFiles(const VersionInfo& existing_version_info,
int32_t curr_version = kVersion);
+// Returns the derived files rebuilds required for a given feature.
+DerivedFilesRebuildResult GetFeatureDerivedFilesRebuildResult(
+ IcingSearchEngineFeatureInfoProto::FlaggedFeatureType feature);
+
+// Constructs the IcingSearchEngineFeatureInfoProto for a given feature.
+IcingSearchEngineFeatureInfoProto GetFeatureInfoProto(
+ IcingSearchEngineFeatureInfoProto::FlaggedFeatureType feature);
+
+// Populates the enabled_features field of an IcingSearchEngineVersionProto
+// based on Icing's initialization flag options.
+//
+// All enabled features are converted into an IcingSearchEngineFeatureInfoProto
+// and returned in IcingSearchEngineVersionProto::enabled_features. A conversion
+// should be added for each trunk stable feature flag defined in
+// IcingSearchEngineOptions.
+void AddEnabledFeatures(const IcingSearchEngineOptions& options,
+ IcingSearchEngineVersionProto* version_proto);
+
} // namespace version_util
} // namespace lib
diff --git a/icing/file/version-util_test.cc b/icing/file/version-util_test.cc
index 9dedb1d..f2922e2 100644
--- a/icing/file/version-util_test.cc
+++ b/icing/file/version-util_test.cc
@@ -14,14 +14,19 @@
#include "icing/file/version-util.h"
+#include <cstdint>
+#include <memory>
#include <optional>
#include <string>
+#include <unordered_set>
#include <utility>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/file/filesystem.h"
#include "icing/file/posting_list/flash-index-storage-header.h"
+#include "icing/portable/equals-proto.h"
+#include "icing/proto/initialize.pb.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/tmp-directory.h"
@@ -32,19 +37,38 @@ namespace version_util {
namespace {
using ::testing::Eq;
+using ::testing::IsEmpty;
using ::testing::IsFalse;
using ::testing::IsTrue;
+IcingSearchEngineVersionProto MakeTestVersionProto(
+ const VersionInfo& version_info,
+ const std::unordered_set<
+ IcingSearchEngineFeatureInfoProto::FlaggedFeatureType>& features_set) {
+ IcingSearchEngineVersionProto version_proto;
+ version_proto.set_version(version_info.version);
+ version_proto.set_max_version(version_info.max_version);
+
+ auto* enabled_features = version_proto.mutable_enabled_features();
+ for (const auto& feature : features_set) {
+ enabled_features->Add(GetFeatureInfoProto(feature));
+ }
+ return version_proto;
+}
+
struct VersionUtilReadVersionTestParam {
- std::optional<VersionInfo> existing_version_info;
+ std::optional<VersionInfo> existing_v1_version_info;
+ std::optional<VersionInfo> existing_v2_version_info;
std::optional<int> existing_flash_index_magic;
VersionInfo expected_version_info;
explicit VersionUtilReadVersionTestParam(
- std::optional<VersionInfo> existing_version_info_in,
+ std::optional<VersionInfo> existing_v1_version_info_in,
+ std::optional<VersionInfo> existing_v2_version_info_in,
std::optional<int> existing_flash_index_magic_in,
VersionInfo expected_version_info_in)
- : existing_version_info(std::move(existing_version_info_in)),
+ : existing_v1_version_info(std::move(existing_v1_version_info_in)),
+ existing_v2_version_info(std::move(existing_v2_version_info_in)),
existing_flash_index_magic(std::move(existing_flash_index_magic_in)),
expected_version_info(std::move(expected_version_info_in)) {}
};
@@ -54,7 +78,6 @@ class VersionUtilReadVersionTest
protected:
void SetUp() override {
base_dir_ = GetTestTempDir() + "/version_util_test";
- version_file_path_ = base_dir_ + "/version";
index_path_ = base_dir_ + "/index";
ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(base_dir_.c_str()));
@@ -68,19 +91,28 @@ class VersionUtilReadVersionTest
Filesystem filesystem_;
std::string base_dir_;
- std::string version_file_path_;
std::string index_path_;
};
TEST_P(VersionUtilReadVersionTest, ReadVersion) {
const VersionUtilReadVersionTestParam& param = GetParam();
+ IcingSearchEngineVersionProto dummy_version_proto;
- // Prepare version file and flash index file.
- if (param.existing_version_info.has_value()) {
- ICING_ASSERT_OK(WriteVersion(filesystem_, version_file_path_,
- param.existing_version_info.value()));
+ if (param.existing_v1_version_info.has_value()) {
+ ICING_ASSERT_OK(WriteV1Version(filesystem_, base_dir_,
+ param.existing_v1_version_info.value()));
+ }
+ if (param.existing_v2_version_info.has_value()) {
+ dummy_version_proto = MakeTestVersionProto(
+ param.existing_v2_version_info.value(),
+ {IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR,
+ IcingSearchEngineFeatureInfoProto::UNKNOWN});
+ ICING_ASSERT_OK(WriteV2Version(
+ filesystem_, base_dir_,
+ std::make_unique<IcingSearchEngineVersionProto>(dummy_version_proto)));
}
+ // Prepare flash index file.
if (param.existing_flash_index_magic.has_value()) {
HeaderBlock header_block(&filesystem_, /*block_size=*/4096);
header_block.header()->magic = param.existing_flash_index_magic.value();
@@ -94,10 +126,22 @@ TEST_P(VersionUtilReadVersionTest, ReadVersion) {
ASSERT_TRUE(header_block.Write(sfd.get()));
}
- ICING_ASSERT_OK_AND_ASSIGN(
- VersionInfo version_info,
- ReadVersion(filesystem_, version_file_path_, index_path_));
- EXPECT_THAT(version_info, Eq(param.expected_version_info));
+ ICING_ASSERT_OK_AND_ASSIGN(IcingSearchEngineVersionProto version_proto,
+ ReadVersion(filesystem_, base_dir_, index_path_));
+ if (param.existing_v2_version_info.has_value() &&
+ param.expected_version_info.version ==
+ param.existing_v2_version_info.value().version) {
+ EXPECT_THAT(version_proto,
+ portable_equals_proto::EqualsProto(dummy_version_proto));
+ } else {
+ // We're returning the version from v1 version file, or an invalid version.
+ // version_proto.enabled_features should be empty in this case.
+ EXPECT_THAT(version_proto.version(),
+ Eq(param.expected_version_info.version));
+ EXPECT_THAT(version_proto.max_version(),
+ Eq(param.expected_version_info.max_version));
+ EXPECT_THAT(version_proto.enabled_features(), IsEmpty());
+ }
}
INSTANTIATE_TEST_SUITE_P(
@@ -107,7 +151,8 @@ INSTANTIATE_TEST_SUITE_P(
// - Flash index doesn't exist
// - Result: version -1, max_version -1 (invalid)
VersionUtilReadVersionTestParam(
- /*existing_version_info_in=*/std::nullopt,
+ /*existing_v1_version_info_in=*/std::nullopt,
+ /*existing_v2_version_info_in=*/std::nullopt,
/*existing_flash_index_magic_in=*/std::nullopt,
/*expected_version_info_in=*/
VersionInfo(/*version_in=*/-1, /*max_version=*/-1)),
@@ -116,7 +161,8 @@ INSTANTIATE_TEST_SUITE_P(
// - Flash index exists with version 0 magic
// - Result: version 0, max_version 0
VersionUtilReadVersionTestParam(
- /*existing_version_info_in=*/std::nullopt,
+ /*existing_v1_version_info_in=*/std::nullopt,
+ /*existing_v2_version_info_in=*/std::nullopt,
/*existing_flash_index_magic_in=*/
std::make_optional<int>(kVersionZeroFlashIndexMagic),
/*expected_version_info_in=*/
@@ -126,65 +172,157 @@ INSTANTIATE_TEST_SUITE_P(
// - Flash index exists with non version 0 magic
// - Result: version -1, max_version -1 (invalid)
VersionUtilReadVersionTestParam(
- /*existing_version_info_in=*/std::nullopt,
+ /*existing_v1_version_info_in=*/std::nullopt,
+ /*existing_v2_version_info_in=*/std::nullopt,
/*existing_flash_index_magic_in=*/
std::make_optional<int>(kVersionZeroFlashIndexMagic + 1),
/*expected_version_info_in=*/
VersionInfo(/*version_in=*/-1, /*max_version=*/-1)),
- // - Version file exists
+ // - Version file v1 exists
// - Flash index doesn't exist
// - Result: version -1, max_version 1 (invalid)
VersionUtilReadVersionTestParam(
- /*existing_version_info_in=*/std::make_optional<VersionInfo>(
+ /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>(
/*version_in=*/1, /*max_version=*/1),
+ /*existing_v2_version_info_in=*/std::nullopt,
/*existing_flash_index_magic_in=*/std::nullopt,
/*expected_version_info_in=*/
VersionInfo(/*version_in=*/-1, /*max_version=*/1)),
- // - Version file exists: version 1, max_version 1
+ // - Version file v1 exists: version 1, max_version 1
// - Flash index exists with version 0 magic
// - Result: version 0, max_version 1
VersionUtilReadVersionTestParam(
- /*existing_version_info_in=*/std::make_optional<VersionInfo>(
+ /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>(
/*version_in=*/1, /*max_version=*/1),
+ /*existing_v2_version_info_in=*/std::nullopt,
/*existing_flash_index_magic_in=*/
std::make_optional<int>(kVersionZeroFlashIndexMagic),
/*expected_version_info_in=*/
VersionInfo(/*version_in=*/0, /*max_version=*/1)),
- // - Version file exists: version 2, max_version 3
+ // - Version file v1 exists: version 2, max_version 3
// - Flash index exists with version 0 magic
// - Result: version 0, max_version 3
VersionUtilReadVersionTestParam(
- /*existing_version_info_in=*/std::make_optional<VersionInfo>(
+ /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>(
/*version_in=*/2, /*max_version=*/3),
+ /*existing_v2_version_info_in=*/std::nullopt,
/*existing_flash_index_magic_in=*/
std::make_optional<int>(kVersionZeroFlashIndexMagic),
/*expected_version_info_in=*/
VersionInfo(/*version_in=*/0, /*max_version=*/3)),
- // - Version file exists: version 1, max_version 1
+ // - Version file v1 exists: version 1, max_version 1
// - Flash index exists with non version 0 magic
// - Result: version 1, max_version 1
VersionUtilReadVersionTestParam(
- /*existing_version_info_in=*/std::make_optional<VersionInfo>(
+ /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>(
/*version_in=*/1, /*max_version=*/1),
+ /*existing_v2_version_info_in=*/std::nullopt,
/*existing_flash_index_magic_in=*/
std::make_optional<int>(kVersionZeroFlashIndexMagic + 1),
/*expected_version_info_in=*/
VersionInfo(/*version_in=*/1, /*max_version=*/1)),
- // - Version file exists: version 2, max_version 3
+ // - Version file v1 exists: version 2, max_version 3
// - Flash index exists with non version 0 magic
// - Result: version 2, max_version 3
VersionUtilReadVersionTestParam(
- /*existing_version_info_in=*/std::make_optional<VersionInfo>(
+ /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>(
/*version_in=*/2, /*max_version=*/3),
+ /*existing_v2_version_info_in=*/std::nullopt,
+ /*existing_flash_index_magic_in=*/
+ std::make_optional<int>(kVersionZeroFlashIndexMagic + 1),
+ /*expected_version_info_in=*/
+ VersionInfo(/*version_in=*/2, /*max_version=*/3)),
+
+ // - Version file v1 exists: version 2, max_version 4
+ // - Version file v2 exists: version 4, max_version 4
+ // - Flash index exists with non version 0 magic
+ // - Result: version -1, max_version 4
+ VersionUtilReadVersionTestParam(
+ /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>(
+ /*version_in=*/2, /*max_version=*/4),
+ /*existing_v2_version_info_in=*/
+ std::make_optional<VersionInfo>(
+ /*version_in=*/4, /*max_version=*/4),
+ /*existing_flash_index_magic_in=*/
+ std::make_optional<int>(kVersionZeroFlashIndexMagic + 1),
+ /*expected_version_info_in=*/
+ VersionInfo(/*version_in=*/-1, /*max_version=*/4)),
+
+ // - Version file v1 exists: version 4, max_version 4
+ // - Version file v2 exists: version 4, max_version 4
+ // - Flash index exists with version 0 magic
+ // - Result: version -1, max_version 4
+ VersionUtilReadVersionTestParam(
+ /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>(
+ /*version_in=*/4, /*max_version=*/4),
+ /*existing_v2_version_info_in=*/
+ std::make_optional<VersionInfo>(
+ /*version_in=*/4, /*max_version=*/4),
+ /*existing_flash_index_magic_in=*/
+ std::make_optional<int>(kVersionZeroFlashIndexMagic),
+ /*expected_version_info_in=*/
+ VersionInfo(/*version_in=*/-1, /*max_version=*/4)),
+
+ // - Version file v1 exists: version 4, max_version 4
+ // - Version file v2 exists: version 4, max_version 4
+ // - Flash index exists with non version 0 magic
+ // - Result: version 4, max_version 4
+ VersionUtilReadVersionTestParam(
+ /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>(
+ /*version_in=*/4, /*max_version=*/4),
+ /*existing_v2_version_info_in=*/
+ std::make_optional<VersionInfo>(
+ /*version_in=*/4, /*max_version=*/4),
+ /*existing_flash_index_magic_in=*/
+ std::make_optional<int>(kVersionZeroFlashIndexMagic + 1),
+ /*expected_version_info_in=*/
+ VersionInfo(/*version_in=*/4, /*max_version=*/4)),
+
+ // - Version file v1 exists: version 4, max_version 4
+ // - Version file v2 does not exist
+ // - Flash index exists with non version 0 magic
+ // - Result: version -1, max_version 4
+ VersionUtilReadVersionTestParam(
+ /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>(
+ /*version_in=*/4, /*max_version=*/4),
+ /*existing_v2_version_info_in=*/std::nullopt,
/*existing_flash_index_magic_in=*/
std::make_optional<int>(kVersionZeroFlashIndexMagic + 1),
/*expected_version_info_in=*/
- VersionInfo(/*version_in=*/2, /*max_version=*/3))));
+ VersionInfo(/*version_in=*/-1, /*max_version=*/4)),
+
+ // - Version file v1 does not exist
+ // - Version file v2 exists: version 4, max_version 4
+ // - Flash index exists with non version 0 magic
+ // - Result: version -1, max_version -1
+ VersionUtilReadVersionTestParam(
+ /*existing_v1_version_info_in=*/std::nullopt,
+ /*existing_v2_version_info_in=*/
+ std::make_optional<VersionInfo>(
+ /*version_in=*/4, /*max_version=*/4),
+ /*existing_flash_index_magic_in=*/
+ std::make_optional<int>(kVersionZeroFlashIndexMagic + 1),
+ /*expected_version_info_in=*/
+ VersionInfo(/*version_in=*/-1, /*max_version=*/-1)),
+
+ // - Version file v1 doesn't exist
+ // - Version file v2 exists: version 4, max_version 4
+ // - Flash index doesn't exist
+ // - Result: version -1, max_version -1 (invalid since flash index
+ // doesn't exist)
+ VersionUtilReadVersionTestParam(
+ /*existing_v1_version_info_in=*/std::nullopt,
+ /*existing_v2_version_info_in=*/
+ std::make_optional<VersionInfo>(
+ /*version_in=*/4, /*max_version=*/4),
+ /*existing_flash_index_magic_in=*/std::nullopt,
+ /*expected_version_info_in=*/
+ VersionInfo(/*version_in=*/-1, /*max_version=*/-1))));
struct VersionUtilStateChangeTestParam {
VersionInfo existing_version_info;
@@ -389,6 +527,240 @@ INSTANTIATE_TEST_SUITE_P(
/*curr_version_in=*/2,
/*expected_state_change_in=*/StateChange::kRollBack)));
+struct VersionUtilDerivedFilesRebuildTestParam {
+ int32_t existing_version;
+ int32_t max_version;
+ std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType>
+ existing_enabled_features;
+ int32_t curr_version;
+ std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType>
+ curr_enabled_features;
+ DerivedFilesRebuildResult expected_derived_files_rebuild_result;
+
+ explicit VersionUtilDerivedFilesRebuildTestParam(
+ int32_t existing_version_in, int32_t max_version_in,
+ std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType>
+ existing_enabled_features_in,
+ int32_t curr_version_in,
+ std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType>
+ curr_enabled_features_in,
+ DerivedFilesRebuildResult expected_derived_files_rebuild_result_in)
+ : existing_version(existing_version_in),
+ max_version(max_version_in),
+ existing_enabled_features(std::move(existing_enabled_features_in)),
+ curr_version(curr_version_in),
+ curr_enabled_features(std::move(curr_enabled_features_in)),
+ expected_derived_files_rebuild_result(
+ std::move(expected_derived_files_rebuild_result_in)) {}
+};
+
+class VersionUtilDerivedFilesRebuildTest
+ : public ::testing::TestWithParam<VersionUtilDerivedFilesRebuildTestParam> {
+};
+
+TEST_P(VersionUtilDerivedFilesRebuildTest,
+ CalculateRequiredDerivedFilesRebuild) {
+ const VersionUtilDerivedFilesRebuildTestParam& param = GetParam();
+
+ EXPECT_THAT(CalculateRequiredDerivedFilesRebuild(
+ /*prev_version_proto=*/MakeTestVersionProto(
+ VersionInfo(param.existing_version, param.max_version),
+ param.existing_enabled_features),
+ /*curr_version_proto=*/
+ MakeTestVersionProto(
+ VersionInfo(param.curr_version, param.max_version),
+ param.curr_enabled_features)),
+ Eq(param.expected_derived_files_rebuild_result));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ VersionUtilDerivedFilesRebuildTest, VersionUtilDerivedFilesRebuildTest,
+ testing::Values(
+ // - Existing version -1, max_version -1 (invalid)
+ // - Existing enabled features = {}
+ // - Current version = 4
+ // - Current enabled features = {}
+ //
+ // - Result: rebuild everything
+ VersionUtilDerivedFilesRebuildTestParam(
+ /*existing_version_in=*/-1, /*max_version_in=*/-1,
+ /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4,
+ /*curr_enabled_features_in=*/{},
+ /*expected_derived_files_rebuild_result_in=*/
+ DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/true,
+ /*needs_schema_store_derived_files_rebuild=*/true,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/true,
+ /*needs_qualified_id_join_index_rebuild=*/true)),
+
+ // - Existing version -1, max_version -1 (invalid)
+ // - Existing enabled features = {}
+ // - Current version = 4
+ // - Current enabled features = {}
+ //
+ // - Result: rebuild everything
+ VersionUtilDerivedFilesRebuildTestParam(
+ /*existing_version_in=*/-1, /*max_version_in=*/-1,
+ /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4,
+ /*curr_enabled_features_in=*/{},
+ /*expected_derived_files_rebuild_result_in=*/
+ DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/true,
+ /*needs_schema_store_derived_files_rebuild=*/true,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/true,
+ /*needs_qualified_id_join_index_rebuild=*/true)),
+
+ // - Existing version 3, max_version 3 (pre v2 version check)
+ // - Existing enabled features = {}
+ // - Current version = 4
+ // - Current enabled features = {}
+ //
+ // - Result: don't rebuild anything
+ VersionUtilDerivedFilesRebuildTestParam(
+ /*existing_version_in=*/3, /*max_version_in=*/3,
+ /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4,
+ /*curr_enabled_features_in=*/{},
+ /*expected_derived_files_rebuild_result_in=*/
+ DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/false,
+ /*needs_schema_store_derived_files_rebuild=*/false,
+ /*needs_term_index_rebuild=*/false,
+ /*needs_integer_index_rebuild=*/false,
+ /*needs_qualified_id_join_index_rebuild=*/false)),
+
+ // - Existing version 3, max_version 3 (pre v2 version check)
+ // - Existing enabled features = {}
+ // - Current version = 4
+ // - Current enabled features = {FEATURE_HAS_PROPERTY_OPERATOR}
+ //
+ // - Result: rebuild term index
+ VersionUtilDerivedFilesRebuildTestParam(
+ /*existing_version_in=*/3, /*max_version_in=*/3,
+ /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4,
+ /*curr_enabled_features_in=*/
+ {IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR},
+ /*expected_derived_files_rebuild_result_in=*/
+ DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/false,
+ /*needs_schema_store_derived_files_rebuild=*/false,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/false,
+ /*needs_qualified_id_join_index_rebuild=*/false)),
+
+ // - Existing version 4, max_version 4
+ // - Existing enabled features = {}
+ // - Current version = 4
+ // - Current enabled features = {}
+ //
+ // - Result: don't rebuild anything
+ VersionUtilDerivedFilesRebuildTestParam(
+ /*existing_version_in=*/4, /*max_version_in=*/4,
+ /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4,
+ /*curr_enabled_features_in=*/{},
+ /*expected_derived_files_rebuild_result_in=*/
+ DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/false,
+ /*needs_schema_store_derived_files_rebuild=*/false,
+ /*needs_term_index_rebuild=*/false,
+ /*needs_integer_index_rebuild=*/false,
+ /*needs_qualified_id_join_index_rebuild=*/false)),
+
+ // - Existing version 4, max_version 5
+ // - Existing enabled features = {}
+ // - Current version = 5
+ // - Current enabled features = {}
+ //
+ // - Result: Rollforward -- rebuild everything
+ VersionUtilDerivedFilesRebuildTestParam(
+ /*existing_version_in=*/4, /*max_version_in=*/5,
+ /*existing_enabled_features_in=*/{}, /*curr_version_in=*/5,
+ /*curr_enabled_features_in=*/{},
+ /*expected_derived_files_rebuild_result_in=*/
+ DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/true,
+ /*needs_schema_store_derived_files_rebuild=*/true,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/true,
+ /*needs_qualified_id_join_index_rebuild=*/true)),
+
+ // - Existing version 5, max_version 5
+ // - Existing enabled features = {}
+ // - Current version = 4
+ // - Current enabled features = {}
+ //
+ // - Result: Rollback -- rebuild everything
+ VersionUtilDerivedFilesRebuildTestParam(
+ /*existing_version_in=*/5, /*max_version_in=*/5,
+ /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4,
+ /*curr_enabled_features_in=*/{},
+ /*expected_derived_files_rebuild_result_in=*/
+ DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/true,
+ /*needs_schema_store_derived_files_rebuild=*/true,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/true,
+ /*needs_qualified_id_join_index_rebuild=*/true)),
+
+ // - Existing version 4, max_version 4
+ // - Existing enabled features = {}
+ // - Current version = 4
+ // - Current enabled features = {FEATURE_HAS_PROPERTY_OPERATOR}
+ //
+ // - Result: rebuild term index
+ VersionUtilDerivedFilesRebuildTestParam(
+ /*existing_version_in=*/4, /*max_version_in=*/4,
+ /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4,
+ /*curr_enabled_features_in=*/
+ {IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR},
+ /*expected_derived_files_rebuild_result_in=*/
+ DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/false,
+ /*needs_schema_store_derived_files_rebuild=*/false,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/false,
+ /*needs_qualified_id_join_index_rebuild=*/false)),
+
+ // - Existing version 4, max_version 4
+ // - Existing enabled features = {FEATURE_HAS_PROPERTY_OPERATOR}
+ // - Current version = 4
+ // - Current enabled features = {}
+ //
+ // - Result: rebuild term index
+ VersionUtilDerivedFilesRebuildTestParam(
+ /*existing_version_in=*/4, /*max_version_in=*/4,
+ /*existing_enabled_features_in=*/
+ {IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR},
+ /*curr_version_in=*/4, /*curr_enabled_features_in=*/{},
+ /*expected_derived_files_rebuild_result_in=*/
+ DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/false,
+ /*needs_schema_store_derived_files_rebuild=*/false,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/false,
+ /*needs_qualified_id_join_index_rebuild=*/false)),
+
+ // - Existing version 4, max_version 4
+ // - Existing enabled features = {UNKNOWN}
+ // - Current version = 4
+ // - Current enabled features = {FEATURE_HAS_PROPERTY_OPERATOR}
+ //
+ // - Result: rebuild everything
+ VersionUtilDerivedFilesRebuildTestParam(
+ /*existing_version_in=*/4, /*max_version_in=*/4,
+ /*existing_enabled_features_in=*/
+ {IcingSearchEngineFeatureInfoProto::UNKNOWN}, /*curr_version_in=*/4,
+ /*curr_enabled_features_in=*/
+ {IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR},
+ /*expected_derived_files_rebuild_result_in=*/
+ DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/true,
+ /*needs_schema_store_derived_files_rebuild=*/true,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/true,
+ /*needs_qualified_id_join_index_rebuild=*/true))));
+
TEST(VersionUtilTest, ShouldRebuildDerivedFilesUndeterminedVersion) {
EXPECT_THAT(
ShouldRebuildDerivedFiles(VersionInfo(-1, -1), /*curr_version=*/1),
@@ -475,8 +847,80 @@ TEST(VersionUtilTest, Upgrade) {
EXPECT_THAT(ShouldRebuildDerivedFiles(VersionInfo(kVersionOne, kVersionOne),
/*curr_version=*/kVersionThree),
IsFalse());
+
+ // kVersionThree -> kVersionFour.
+ EXPECT_THAT(
+ ShouldRebuildDerivedFiles(VersionInfo(kVersionThree, kVersionThree),
+ /*curr_version=*/kVersionFour),
+ IsFalse());
+
+ // kVersionTwo -> kVersionFour
+ EXPECT_THAT(ShouldRebuildDerivedFiles(VersionInfo(kVersionTwo, kVersionTwo),
+ /*curr_version=*/kVersionFour),
+ IsFalse());
+
+ // kVersionOne -> kVersionFour.
+ EXPECT_THAT(ShouldRebuildDerivedFiles(VersionInfo(kVersionOne, kVersionOne),
+ /*curr_version=*/kVersionFour),
+ IsFalse());
+}
+
+TEST(VersionUtilTest, GetFeatureDerivedFilesRebuildResult_unknown) {
+ EXPECT_THAT(GetFeatureDerivedFilesRebuildResult(
+ IcingSearchEngineFeatureInfoProto::UNKNOWN),
+ Eq(DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/true,
+ /*needs_schema_store_derived_files_rebuild=*/true,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/true,
+ /*needs_qualified_id_join_index_rebuild=*/true)));
}
+TEST(VersionUtilTest,
+ GetFeatureDerivedFilesRebuildResult_featureHasPropertyOperator) {
+ EXPECT_THAT(
+ GetFeatureDerivedFilesRebuildResult(
+ IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR),
+ Eq(DerivedFilesRebuildResult(
+ /*needs_document_store_derived_files_rebuild=*/false,
+ /*needs_schema_store_derived_files_rebuild=*/false,
+ /*needs_term_index_rebuild=*/true,
+ /*needs_integer_index_rebuild=*/false,
+ /*needs_qualified_id_join_index_rebuild=*/false)));
+}
+
+class VersionUtilFeatureProtoTest
+ : public ::testing::TestWithParam<
+ IcingSearchEngineFeatureInfoProto::FlaggedFeatureType> {};
+
+TEST_P(VersionUtilFeatureProtoTest, GetFeatureInfoProto) {
+ IcingSearchEngineFeatureInfoProto::FlaggedFeatureType feature_type =
+ GetParam();
+ DerivedFilesRebuildResult rebuild_result =
+ GetFeatureDerivedFilesRebuildResult(feature_type);
+
+ IcingSearchEngineFeatureInfoProto feature_info =
+ GetFeatureInfoProto(feature_type);
+ EXPECT_THAT(feature_info.feature_type(), Eq(feature_type));
+
+ EXPECT_THAT(feature_info.needs_document_store_rebuild(),
+ Eq(rebuild_result.needs_document_store_derived_files_rebuild));
+ EXPECT_THAT(feature_info.needs_schema_store_rebuild(),
+ Eq(rebuild_result.needs_schema_store_derived_files_rebuild));
+ EXPECT_THAT(feature_info.needs_term_index_rebuild(),
+ Eq(rebuild_result.needs_term_index_rebuild));
+ EXPECT_THAT(feature_info.needs_integer_index_rebuild(),
+ Eq(rebuild_result.needs_integer_index_rebuild));
+ EXPECT_THAT(feature_info.needs_qualified_id_join_index_rebuild(),
+ Eq(rebuild_result.needs_qualified_id_join_index_rebuild));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ VersionUtilFeatureProtoTest, VersionUtilFeatureProtoTest,
+ testing::Values(
+ IcingSearchEngineFeatureInfoProto::UNKNOWN,
+ IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR));
+
} // namespace
} // namespace version_util
diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc
index 72be4e9..89caaf1 100644
--- a/icing/icing-search-engine.cc
+++ b/icing/icing-search-engine.cc
@@ -14,7 +14,10 @@
#include "icing/icing-search-engine.h"
+#include <algorithm>
+#include <cstddef>
#include <cstdint>
+#include <functional>
#include <memory>
#include <string>
#include <string_view>
@@ -34,6 +37,8 @@
#include "icing/file/filesystem.h"
#include "icing/file/version-util.h"
#include "icing/index/data-indexing-handler.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embedding-indexing-handler.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/index-processor.h"
#include "icing/index/index.h"
@@ -41,12 +46,16 @@
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/index/numeric/integer-index.h"
#include "icing/index/term-indexing-handler.h"
+#include "icing/index/term-metadata.h"
+#include "icing/jni/jni-cache.h"
+#include "icing/join/join-children-fetcher.h"
#include "icing/join/join-processor.h"
#include "icing/join/qualified-id-join-index-impl-v1.h"
#include "icing/join/qualified-id-join-index-impl-v2.h"
#include "icing/join/qualified-id-join-index.h"
#include "icing/join/qualified-id-join-indexing-handler.h"
#include "icing/legacy/index/icing-filesystem.h"
+#include "icing/performance-configuration.h"
#include "icing/portable/endian.h"
#include "icing/proto/debug.pb.h"
#include "icing/proto/document.pb.h"
@@ -73,9 +82,8 @@
#include "icing/result/projector.h"
#include "icing/result/result-adjustment-info.h"
#include "icing/result/result-retriever-v2.h"
+#include "icing/result/result-state-manager.h"
#include "icing/schema/schema-store.h"
-#include "icing/schema/schema-util.h"
-#include "icing/schema/section.h"
#include "icing/scoring/advanced_scoring/score-expression.h"
#include "icing/scoring/priority-queue-scored-document-hits-ranker.h"
#include "icing/scoring/scored-document-hit.h"
@@ -84,11 +92,8 @@
#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
#include "icing/tokenization/language-segmenter-factory.h"
-#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer-factory.h"
-#include "icing/transform/normalizer.h"
#include "icing/util/clock.h"
-#include "icing/util/crc32.h"
#include "icing/util/data-loss.h"
#include "icing/util/logging.h"
#include "icing/util/status-macros.h"
@@ -100,12 +105,12 @@ namespace lib {
namespace {
-constexpr std::string_view kVersionFilename = "version";
constexpr std::string_view kDocumentSubfolderName = "document_dir";
constexpr std::string_view kIndexSubfolderName = "index_dir";
constexpr std::string_view kIntegerIndexSubfolderName = "integer_index_dir";
constexpr std::string_view kQualifiedIdJoinIndexSubfolderName =
"qualified_id_join_index_dir";
+constexpr std::string_view kEmbeddingIndexSubfolderName = "embedding_index_dir";
constexpr std::string_view kSchemaSubfolderName = "schema_dir";
constexpr std::string_view kSetSchemaMarkerFilename = "set_schema_marker";
constexpr std::string_view kInitMarkerFilename = "init_marker";
@@ -253,12 +258,6 @@ CreateQualifiedIdJoinIndex(const Filesystem& filesystem,
}
}
-// Version file is a single file under base_dir containing version info of the
-// existing data.
-std::string MakeVersionFilePath(const std::string& base_dir) {
- return absl_ports::StrCat(base_dir, "/", kVersionFilename);
-}
-
// Document store files are in a standalone subfolder for easier file
// management. We can delete and recreate the subfolder and not touch/affect
// anything else.
@@ -296,6 +295,11 @@ std::string MakeQualifiedIdJoinIndexWorkingPath(const std::string& base_dir) {
return absl_ports::StrCat(base_dir, "/", kQualifiedIdJoinIndexSubfolderName);
}
+// Working path for embedding index.
+std::string MakeEmbeddingIndexWorkingPath(const std::string& base_dir) {
+ return absl_ports::StrCat(base_dir, "/", kEmbeddingIndexSubfolderName);
+}
+
// SchemaStore files are in a standalone subfolder for easier file management.
// We can delete and recreate the subfolder and not touch/affect anything
// else.
@@ -482,6 +486,7 @@ void IcingSearchEngine::ResetMembers() {
index_.reset();
integer_index_.reset();
qualified_id_join_index_.reset();
+ embedding_index_.reset();
}
libtextclassifier3::Status IcingSearchEngine::CheckInitMarkerFile(
@@ -604,32 +609,49 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
return status;
}
- // Read version file and determine the state change.
- const std::string version_filepath = MakeVersionFilePath(options_.base_dir());
+ // Do version and flags compatibility check
+ // Read version file, determine the state change and rebuild derived files if
+ // needed.
const std::string index_dir = MakeIndexDirectoryPath(options_.base_dir());
ICING_ASSIGN_OR_RETURN(
- version_util::VersionInfo version_info,
- version_util::ReadVersion(*filesystem_, version_filepath, index_dir));
+ IcingSearchEngineVersionProto stored_version_proto,
+ version_util::ReadVersion(
+ *filesystem_, /*version_file_dir=*/options_.base_dir(), index_dir));
+ version_util::VersionInfo stored_version_info =
+ version_util::GetVersionInfoFromProto(stored_version_proto);
version_util::StateChange version_state_change =
- version_util::GetVersionStateChange(version_info);
+ version_util::GetVersionStateChange(stored_version_info);
+
+ // Construct icing's current version proto based on the current code version
+ IcingSearchEngineVersionProto current_version_proto;
+ current_version_proto.set_version(version_util::kVersion);
+ current_version_proto.set_max_version(
+ std::max(stored_version_info.max_version, version_util::kVersion));
+ version_util::AddEnabledFeatures(options_, &current_version_proto);
+
+ // Step 1: If versions are incompatible, migrate schema according to the
+ // version state change.
if (version_state_change != version_util::StateChange::kCompatible) {
- // Step 1: migrate schema according to the version state change.
ICING_RETURN_IF_ERROR(SchemaStore::MigrateSchema(
filesystem_.get(), MakeSchemaDirectoryPath(options_.base_dir()),
version_state_change, version_util::kVersion));
+ }
- // Step 2: discard all derived data if needed rebuild.
- if (version_util::ShouldRebuildDerivedFiles(version_info)) {
- ICING_RETURN_IF_ERROR(DiscardDerivedFiles());
- }
+ // Step 2: Discard derived files that need to be rebuilt
+ version_util::DerivedFilesRebuildResult required_derived_files_rebuild =
+ version_util::CalculateRequiredDerivedFilesRebuild(stored_version_proto,
+ current_version_proto);
+ ICING_RETURN_IF_ERROR(DiscardDerivedFiles(required_derived_files_rebuild));
- // Step 3: update version file
- version_util::VersionInfo new_version_info(
- version_util::kVersion,
- std::max(version_info.max_version, version_util::kVersion));
- ICING_RETURN_IF_ERROR(version_util::WriteVersion(
- *filesystem_, version_filepath, new_version_info));
- }
+ // Step 3: update version files. We need to update both the V1 and V2
+ // version files.
+ ICING_RETURN_IF_ERROR(version_util::WriteV1Version(
+ *filesystem_, /*version_file_dir=*/options_.base_dir(),
+ version_util::GetVersionInfoFromProto(current_version_proto)));
+ ICING_RETURN_IF_ERROR(version_util::WriteV2Version(
+ *filesystem_, /*version_file_dir=*/options_.base_dir(),
+ std::make_unique<IcingSearchEngineVersionProto>(
+ std::move(current_version_proto))));
ICING_RETURN_IF_ERROR(InitializeSchemaStore(initialize_stats));
@@ -655,15 +677,19 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
MakeIntegerIndexWorkingPath(options_.base_dir());
const std::string qualified_id_join_index_dir =
MakeQualifiedIdJoinIndexWorkingPath(options_.base_dir());
+ const std::string embedding_index_dir =
+ MakeEmbeddingIndexWorkingPath(options_.base_dir());
if (!filesystem_->DeleteDirectoryRecursively(doc_store_dir.c_str()) ||
!filesystem_->DeleteDirectoryRecursively(index_dir.c_str()) ||
!IntegerIndex::Discard(*filesystem_, integer_index_dir).ok() ||
!QualifiedIdJoinIndex::Discard(*filesystem_,
qualified_id_join_index_dir)
- .ok()) {
+ .ok() ||
+ !EmbeddingIndex::Discard(*filesystem_, embedding_index_dir).ok()) {
return absl_ports::InternalError(absl_ports::StrCat(
"Could not delete directories: ", index_dir, ", ", integer_index_dir,
- ", ", qualified_id_join_index_dir, " and ", doc_store_dir));
+ ", ", qualified_id_join_index_dir, ", ", embedding_index_dir, " and ",
+ doc_store_dir));
}
ICING_ASSIGN_OR_RETURN(
bool document_store_derived_files_regenerated,
@@ -688,10 +714,9 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
// We're going to need to build the index from scratch. So just delete its
// directory now.
// Discard index directory and instantiate a new one.
- Index::Options index_options(
- index_dir, options_.index_merge_size(),
- options_.lite_index_sort_at_indexing(), options_.lite_index_sort_size(),
- options_.build_property_existence_metadata_hits());
+ Index::Options index_options(index_dir, options_.index_merge_size(),
+ options_.lite_index_sort_at_indexing(),
+ options_.lite_index_sort_size());
if (!filesystem_->DeleteDirectoryRecursively(index_dir.c_str()) ||
!filesystem_->CreateDirectoryRecursively(index_dir.c_str())) {
return absl_ports::InternalError(
@@ -722,6 +747,15 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
CreateQualifiedIdJoinIndex(
*filesystem_, std::move(qualified_id_join_index_dir), options_));
+ // Discard embedding index directory and instantiate a new one.
+ std::string embedding_index_dir =
+ MakeEmbeddingIndexWorkingPath(options_.base_dir());
+ ICING_RETURN_IF_ERROR(
+ EmbeddingIndex::Discard(*filesystem_, embedding_index_dir));
+ ICING_ASSIGN_OR_RETURN(
+ embedding_index_,
+ EmbeddingIndex::Create(filesystem_.get(), embedding_index_dir));
+
std::unique_ptr<Timer> restore_timer = clock_->GetNewTimer();
IndexRestorationResult restore_result = RestoreIndexIfNeeded();
index_init_status = std::move(restore_result.status);
@@ -744,6 +778,8 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC);
initialize_stats->set_qualified_id_join_index_restoration_cause(
InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC);
+ initialize_stats->set_embedding_index_restoration_cause(
+ InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC);
} else if (version_state_change != version_util::StateChange::kCompatible) {
ICING_ASSIGN_OR_RETURN(bool document_store_derived_files_regenerated,
InitializeDocumentStore(
@@ -765,6 +801,8 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
InitializeStatsProto::VERSION_CHANGED);
initialize_stats->set_qualified_id_join_index_restoration_cause(
InitializeStatsProto::VERSION_CHANGED);
+ initialize_stats->set_embedding_index_restoration_cause(
+ InitializeStatsProto::VERSION_CHANGED);
} else {
ICING_ASSIGN_OR_RETURN(
bool document_store_derived_files_regenerated,
@@ -776,6 +814,32 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
if (!index_init_status.ok() && !absl_ports::IsDataLoss(index_init_status)) {
return index_init_status;
}
+
+ // Set recovery cause to FEATURE_FLAG_CHANGED according to the calculated
+ // required_derived_files_rebuild
+ if (required_derived_files_rebuild
+ .needs_document_store_derived_files_rebuild) {
+ initialize_stats->set_document_store_recovery_cause(
+ InitializeStatsProto::FEATURE_FLAG_CHANGED);
+ }
+ if (required_derived_files_rebuild
+ .needs_schema_store_derived_files_rebuild) {
+ initialize_stats->set_schema_store_recovery_cause(
+ InitializeStatsProto::FEATURE_FLAG_CHANGED);
+ }
+ if (required_derived_files_rebuild.needs_term_index_rebuild) {
+ initialize_stats->set_index_restoration_cause(
+ InitializeStatsProto::FEATURE_FLAG_CHANGED);
+ }
+ if (required_derived_files_rebuild.needs_integer_index_rebuild) {
+ initialize_stats->set_integer_index_restoration_cause(
+ InitializeStatsProto::FEATURE_FLAG_CHANGED);
+ }
+ if (required_derived_files_rebuild.needs_qualified_id_join_index_rebuild) {
+ initialize_stats->set_qualified_id_join_index_restoration_cause(
+ InitializeStatsProto::FEATURE_FLAG_CHANGED);
+ }
+ // TODO(b/326656531): Update version-util to consider embedding index.
}
if (status.ok()) {
@@ -842,10 +906,9 @@ libtextclassifier3::Status IcingSearchEngine::InitializeIndex(
return absl_ports::InternalError(
absl_ports::StrCat("Could not create directory: ", index_dir));
}
- Index::Options index_options(
- index_dir, options_.index_merge_size(),
- options_.lite_index_sort_at_indexing(), options_.lite_index_sort_size(),
- options_.build_property_existence_metadata_hits());
+ Index::Options index_options(index_dir, options_.index_merge_size(),
+ options_.lite_index_sort_at_indexing(),
+ options_.lite_index_sort_size());
// Term index
InitializeStatsProto::RecoveryCause index_recovery_cause;
@@ -945,6 +1008,30 @@ libtextclassifier3::Status IcingSearchEngine::InitializeIndex(
}
}
+ // Embedding index
+ const std::string embedding_dir =
+ MakeEmbeddingIndexWorkingPath(options_.base_dir());
+ InitializeStatsProto::RecoveryCause embedding_index_recovery_cause;
+ auto embedding_index_or =
+ EmbeddingIndex::Create(filesystem_.get(), embedding_dir);
+ if (!embedding_index_or.ok()) {
+ ICING_RETURN_IF_ERROR(EmbeddingIndex::Discard(*filesystem_, embedding_dir));
+
+ embedding_index_recovery_cause = InitializeStatsProto::IO_ERROR;
+
+ // Try recreating it from scratch and re-indexing everything.
+ ICING_ASSIGN_OR_RETURN(
+ embedding_index_,
+ EmbeddingIndex::Create(filesystem_.get(), embedding_dir));
+ } else {
+ // Embedding index was created fine.
+ embedding_index_ = std::move(embedding_index_or).ValueOrDie();
+ // If a recovery does have to happen, then it must be because the index is
+ // out of sync with the document store.
+ embedding_index_recovery_cause =
+ InitializeStatsProto::INCONSISTENT_WITH_GROUND_TRUTH;
+ }
+
std::unique_ptr<Timer> restore_timer = clock_->GetNewTimer();
IndexRestorationResult restore_result = RestoreIndexIfNeeded();
if (restore_result.index_needed_restoration ||
@@ -964,6 +1051,10 @@ libtextclassifier3::Status IcingSearchEngine::InitializeIndex(
initialize_stats->set_qualified_id_join_index_restoration_cause(
qualified_id_join_index_recovery_cause);
}
+ if (restore_result.embedding_index_needed_restoration) {
+ initialize_stats->set_embedding_index_restoration_cause(
+ embedding_index_recovery_cause);
+ }
}
return restore_result.status;
}
@@ -1479,8 +1570,9 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery(
std::unique_ptr<Timer> component_timer = clock_->GetNewTimer();
// Gets unordered results from query processor
auto query_processor_or = QueryProcessor::Create(
- index_.get(), integer_index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(), schema_store_.get());
+ index_.get(), integer_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(), document_store_.get(),
+ schema_store_.get(), clock_.get());
if (!query_processor_or.ok()) {
TransformStatus(query_processor_or.status(), result_status);
delete_stats->set_parse_query_latency_ms(
@@ -1676,6 +1768,15 @@ OptimizeResultProto IcingSearchEngine::Optimize() {
<< qualified_id_join_index_optimize_status.error_message();
should_rebuild_index = true;
}
+
+ libtextclassifier3::Status embedding_index_optimize_status =
+ embedding_index_->Optimize(optimize_result.document_id_old_to_new,
+ document_store_->last_added_document_id());
+ if (!embedding_index_optimize_status.ok()) {
+ ICING_LOG(WARNING) << "Failed to optimize embedding index. Error: "
+ << embedding_index_optimize_status.error_message();
+ should_rebuild_index = true;
+ }
}
// If we received a DATA_LOSS error from OptimizeDocumentStore, we have a
// valid document store, but it might be the old one or the new one. So throw
@@ -1902,6 +2003,7 @@ libtextclassifier3::Status IcingSearchEngine::InternalPersistToDisk(
ICING_RETURN_IF_ERROR(index_->PersistToDisk());
ICING_RETURN_IF_ERROR(integer_index_->PersistToDisk());
ICING_RETURN_IF_ERROR(qualified_id_join_index_->PersistToDisk());
+ ICING_RETURN_IF_ERROR(embedding_index_->PersistToDisk());
return libtextclassifier3::Status::OK;
}
@@ -1979,6 +2081,7 @@ SearchResultProto IcingSearchEngine::InternalSearch(
result_status->set_message("IcingSearchEngine has not been initialized!");
return result_proto;
}
+ index_->PublishQueryStats(query_stats);
libtextclassifier3::Status status =
ValidateResultSpec(document_store_.get(), result_spec);
@@ -2182,8 +2285,9 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore(
// Gets unordered results from query processor
auto query_processor_or = QueryProcessor::Create(
- index_.get(), integer_index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(), schema_store_.get());
+ index_.get(), integer_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(), document_store_.get(),
+ schema_store_.get(), clock_.get());
if (!query_processor_or.ok()) {
search_stats->set_parse_query_latency_ms(
component_timer->GetElapsedMilliseconds());
@@ -2198,7 +2302,8 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore(
libtextclassifier3::StatusOr<QueryResults> query_results_or;
if (ranking_strategy_or.ok()) {
query_results_or = query_processor->ParseSearch(
- search_spec, ranking_strategy_or.ValueOrDie(), current_time_ms);
+ search_spec, ranking_strategy_or.ValueOrDie(), current_time_ms,
+ search_stats);
} else {
query_results_or = ranking_strategy_or.status();
}
@@ -2226,8 +2331,10 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore(
// Scores but does not rank the results.
libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>>
scoring_processor_or = ScoringProcessor::Create(
- scoring_spec, document_store_.get(), schema_store_.get(),
- current_time_ms, join_children_fetcher);
+ scoring_spec, /*default_semantic_metric_type=*/
+ search_spec.embedding_query_metric_type(), document_store_.get(),
+ schema_store_.get(), current_time_ms, join_children_fetcher,
+ &query_results.embedding_query_results);
if (!scoring_processor_or.ok()) {
return QueryScoringResults(std::move(scoring_processor_or).status(),
std::move(query_results.query_terms),
@@ -2464,27 +2571,28 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
if (last_stored_document_id == index_->last_added_document_id() &&
last_stored_document_id == integer_index_->last_added_document_id() &&
last_stored_document_id ==
- qualified_id_join_index_->last_added_document_id()) {
+ qualified_id_join_index_->last_added_document_id() &&
+ last_stored_document_id == embedding_index_->last_added_document_id()) {
// No need to recover.
- return {libtextclassifier3::Status::OK, false, false, false};
+ return {libtextclassifier3::Status::OK, false, false, false, false};
}
if (last_stored_document_id == kInvalidDocumentId) {
// Document store is empty but index is not. Clear the index.
- return {ClearAllIndices(), false, false, false};
+ return {ClearAllIndices(), false, false, false, false};
}
// Truncate indices first.
auto truncate_result_or = TruncateIndicesTo(last_stored_document_id);
if (!truncate_result_or.ok()) {
- return {std::move(truncate_result_or).status(), false, false, false};
+ return {std::move(truncate_result_or).status(), false, false, false, false};
}
TruncateIndexResult truncate_result =
std::move(truncate_result_or).ValueOrDie();
if (truncate_result.first_document_to_reindex > last_stored_document_id) {
// Nothing to restore. Just return.
- return {libtextclassifier3::Status::OK, false, false, false};
+ return {libtextclassifier3::Status::OK, false, false, false, false};
}
auto data_indexing_handlers_or = CreateDataIndexingHandlers();
@@ -2492,7 +2600,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
return {data_indexing_handlers_or.status(),
truncate_result.index_needed_restoration,
truncate_result.integer_index_needed_restoration,
- truncate_result.qualified_id_join_index_needed_restoration};
+ truncate_result.qualified_id_join_index_needed_restoration,
+ truncate_result.embedding_index_needed_restoration};
}
// By using recovery_mode for IndexProcessor, we're able to replay documents
// from smaller document id and it will skip documents that are already been
@@ -2519,7 +2628,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
// Returns other errors
return {document_or.status(), truncate_result.index_needed_restoration,
truncate_result.integer_index_needed_restoration,
- truncate_result.qualified_id_join_index_needed_restoration};
+ truncate_result.qualified_id_join_index_needed_restoration,
+ truncate_result.embedding_index_needed_restoration};
}
}
DocumentProto document(std::move(document_or).ValueOrDie());
@@ -2532,7 +2642,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
return {tokenized_document_or.status(),
truncate_result.index_needed_restoration,
truncate_result.integer_index_needed_restoration,
- truncate_result.qualified_id_join_index_needed_restoration};
+ truncate_result.qualified_id_join_index_needed_restoration,
+ truncate_result.embedding_index_needed_restoration};
}
TokenizedDocument tokenized_document(
std::move(tokenized_document_or).ValueOrDie());
@@ -2544,7 +2655,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
// Real error. Stop recovering and pass it up.
return {status, truncate_result.index_needed_restoration,
truncate_result.integer_index_needed_restoration,
- truncate_result.qualified_id_join_index_needed_restoration};
+ truncate_result.qualified_id_join_index_needed_restoration,
+ truncate_result.embedding_index_needed_restoration};
}
// FIXME: why can we skip data loss error here?
// Just a data loss. Keep trying to add the remaining docs, but report the
@@ -2555,7 +2667,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
return {overall_status, truncate_result.index_needed_restoration,
truncate_result.integer_index_needed_restoration,
- truncate_result.qualified_id_join_index_needed_restoration};
+ truncate_result.qualified_id_join_index_needed_restoration,
+ truncate_result.embedding_index_needed_restoration};
}
libtextclassifier3::StatusOr<bool> IcingSearchEngine::LostPreviousSchema() {
@@ -2608,6 +2721,11 @@ IcingSearchEngine::CreateDataIndexingHandlers() {
clock_.get(), document_store_.get(), qualified_id_join_index_.get()));
handlers.push_back(std::move(qualified_id_join_indexing_handler));
+ // Embedding index handler
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<EmbeddingIndexingHandler> embedding_indexing_handler,
+ EmbeddingIndexingHandler::Create(clock_.get(), embedding_index_.get()));
+ handlers.push_back(std::move(embedding_indexing_handler));
return handlers;
}
@@ -2696,53 +2814,104 @@ IcingSearchEngine::TruncateIndicesTo(DocumentId last_stored_document_id) {
first_document_to_reindex = kMinDocumentId;
}
+ // Attempt to truncate embedding index
+ bool embedding_index_needed_restoration = false;
+ DocumentId embedding_index_last_added_document_id =
+ embedding_index_->last_added_document_id();
+ if (embedding_index_last_added_document_id == kInvalidDocumentId ||
+ last_stored_document_id > embedding_index_last_added_document_id) {
+ // If last_stored_document_id is greater than
+ // embedding_index_last_added_document_id, then we only have to replay docs
+ // starting from (embedding_index_last_added_document_id + 1). Also use
+ // std::min since we might need to replay even smaller doc ids for other
+ // components.
+ embedding_index_needed_restoration = true;
+ if (embedding_index_last_added_document_id != kInvalidDocumentId) {
+ first_document_to_reindex =
+ std::min(first_document_to_reindex,
+ embedding_index_last_added_document_id + 1);
+ } else {
+ first_document_to_reindex = kMinDocumentId;
+ }
+ } else if (last_stored_document_id < embedding_index_last_added_document_id) {
+ // Clear the entire embedding index if last_stored_document_id is
+ // smaller than embedding_index_last_added_document_id, because
+ // there is no way to remove data with doc_id > last_stored_document_id from
+ // embedding index efficiently and we have to rebuild.
+ ICING_RETURN_IF_ERROR(embedding_index_->Clear());
+
+ // Since the entire embedding index is discarded, we start to
+ // rebuild it by setting first_document_to_reindex to kMinDocumentId.
+ embedding_index_needed_restoration = true;
+ first_document_to_reindex = kMinDocumentId;
+ }
+
return TruncateIndexResult(first_document_to_reindex,
index_needed_restoration,
integer_index_needed_restoration,
- qualified_id_join_index_needed_restoration);
+ qualified_id_join_index_needed_restoration,
+ embedding_index_needed_restoration);
}
+// Discards only the derived files whose components are flagged in
+// rebuild_result; a no-op when rebuild_result reports nothing to rebuild.
+// Precondition: all store/index members must still be null (i.e. this runs
+// before those instances are created during initialization).
-libtextclassifier3::Status IcingSearchEngine::DiscardDerivedFiles() {
+libtextclassifier3::Status IcingSearchEngine::DiscardDerivedFiles(
+ const version_util::DerivedFilesRebuildResult& rebuild_result) {
+ // Fast path: nothing was flagged for rebuild, so keep all derived files.
+ if (!rebuild_result.IsRebuildNeeded()) {
+ return libtextclassifier3::Status::OK;
+ }
+
 if (schema_store_ != nullptr || document_store_ != nullptr ||
 index_ != nullptr || integer_index_ != nullptr ||
- qualified_id_join_index_ != nullptr) {
+ qualified_id_join_index_ != nullptr || embedding_index_ != nullptr) {
 return absl_ports::FailedPreconditionError(
 "Cannot discard derived files while having valid instances");
 }
 // Schema store
- ICING_RETURN_IF_ERROR(
- SchemaStore::DiscardDerivedFiles(filesystem_.get(), options_.base_dir()));
+ if (rebuild_result.needs_schema_store_derived_files_rebuild) {
+ ICING_RETURN_IF_ERROR(SchemaStore::DiscardDerivedFiles(
+ filesystem_.get(), options_.base_dir()));
+ }
 // Document store
- ICING_RETURN_IF_ERROR(DocumentStore::DiscardDerivedFiles(
- filesystem_.get(), options_.base_dir()));
+ if (rebuild_result.needs_document_store_derived_files_rebuild) {
+ ICING_RETURN_IF_ERROR(DocumentStore::DiscardDerivedFiles(
+ filesystem_.get(), options_.base_dir()));
+ }
 // Term index
- if (!filesystem_->DeleteDirectoryRecursively(
- MakeIndexDirectoryPath(options_.base_dir()).c_str())) {
- return absl_ports::InternalError("Failed to discard index");
+ if (rebuild_result.needs_term_index_rebuild) {
+ if (!filesystem_->DeleteDirectoryRecursively(
+ MakeIndexDirectoryPath(options_.base_dir()).c_str())) {
+ return absl_ports::InternalError("Failed to discard index");
+ }
 }
 // Integer index
- if (!filesystem_->DeleteDirectoryRecursively(
- MakeIntegerIndexWorkingPath(options_.base_dir()).c_str())) {
- return absl_ports::InternalError("Failed to discard integer index");
+ if (rebuild_result.needs_integer_index_rebuild) {
+ if (!filesystem_->DeleteDirectoryRecursively(
+ MakeIntegerIndexWorkingPath(options_.base_dir()).c_str())) {
+ return absl_ports::InternalError("Failed to discard integer index");
+ }
 }
 // Qualified id join index
- if (!filesystem_->DeleteDirectoryRecursively(
- MakeQualifiedIdJoinIndexWorkingPath(options_.base_dir()).c_str())) {
- return absl_ports::InternalError(
- "Failed to discard qualified id join index");
+ if (rebuild_result.needs_qualified_id_join_index_rebuild) {
+ if (!filesystem_->DeleteDirectoryRecursively(
+ MakeQualifiedIdJoinIndexWorkingPath(options_.base_dir()).c_str())) {
+ return absl_ports::InternalError(
+ "Failed to discard qualified id join index");
+ }
 }
+ // NOTE(review): the embedding index is intentionally not discarded here yet;
+ // the TODO below tracks teaching version-util (and this method) about it.
+ // TODO(b/326656531): Update version-util to consider embedding index.
+
 return libtextclassifier3::Status::OK;
 }
+// Resets the term, integer, and (now) embedding search indices to an empty
+// state. Only index data is touched here; this method does not call into the
+// document or schema stores.
 libtextclassifier3::Status IcingSearchEngine::ClearSearchIndices() {
 ICING_RETURN_IF_ERROR(index_->Reset());
 ICING_RETURN_IF_ERROR(integer_index_->Clear());
+ ICING_RETURN_IF_ERROR(embedding_index_->Clear());
 return libtextclassifier3::Status::OK;
 }
@@ -2815,8 +2984,9 @@ SuggestionResponse IcingSearchEngine::SearchSuggestions(
// Create the suggestion processor.
auto suggestion_processor_or = SuggestionProcessor::Create(
- index_.get(), integer_index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(), schema_store_.get());
+ index_.get(), integer_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(), document_store_.get(),
+ schema_store_.get(), clock_.get());
if (!suggestion_processor_or.ok()) {
TransformStatus(suggestion_processor_or.status(), response_status);
return response;
diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h
index d316350..57f0f28 100644
--- a/icing/icing-search-engine.h
+++ b/icing/icing-search-engine.h
@@ -17,7 +17,6 @@
#include <cstdint>
#include <memory>
-#include <string>
#include <string_view>
#include <utility>
#include <vector>
@@ -27,7 +26,9 @@
#include "icing/absl_ports/mutex.h"
#include "icing/absl_ports/thread_annotations.h"
#include "icing/file/filesystem.h"
+#include "icing/file/version-util.h"
#include "icing/index/data-indexing-handler.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
#include "icing/index/numeric/numeric-index.h"
#include "icing/jni/jni-cache.h"
@@ -51,11 +52,11 @@
#include "icing/result/result-state-manager.h"
#include "icing/schema/schema-store.h"
#include "icing/scoring/scored-document-hit.h"
+#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer.h"
#include "icing/util/clock.h"
-#include "icing/util/crc32.h"
namespace icing {
namespace lib {
@@ -483,6 +484,9 @@ class IcingSearchEngine {
std::unique_ptr<QualifiedIdJoinIndex> qualified_id_join_index_
ICING_GUARDED_BY(mutex_);
+ // Storage for all hits of embedding contents from the document store.
+ std::unique_ptr<EmbeddingIndex> embedding_index_ ICING_GUARDED_BY(mutex_);
+
// Pointer to JNI class references
const std::unique_ptr<const JniCache> jni_cache_;
@@ -578,8 +582,8 @@ class IcingSearchEngine {
// read-lock, allowing for parallel non-exclusive operations.
// This implementation is used if search_spec.use_read_only_search is true.
SearchResultProto SearchLockedShared(const SearchSpecProto& search_spec,
- const ScoringSpecProto& scoring_spec,
- const ResultSpecProto& result_spec)
+ const ScoringSpecProto& scoring_spec,
+ const ResultSpecProto& result_spec)
ICING_LOCKS_EXCLUDED(mutex_);
// Implementation of IcingSearchEngine::Search that requires the overall
@@ -587,8 +591,8 @@ class IcingSearchEngine {
// this version is used.
// This implementation is used if search_spec.use_read_only_search is false.
SearchResultProto SearchLockedExclusive(const SearchSpecProto& search_spec,
- const ScoringSpecProto& scoring_spec,
- const ResultSpecProto& result_spec)
+ const ScoringSpecProto& scoring_spec,
+ const ResultSpecProto& result_spec)
ICING_LOCKS_EXCLUDED(mutex_);
// Helper method for the actual work to Search. We need this separate
@@ -641,13 +645,14 @@ class IcingSearchEngine {
libtextclassifier3::Status CheckConsistency()
ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
- // Discards all derived data.
+ // Discards derived data that requires rebuild based on rebuild_result.
//
// Returns:
// OK on success
// FAILED_PRECONDITION_ERROR if those instances are valid (non nullptr)
// INTERNAL_ERROR on any I/O errors
- libtextclassifier3::Status DiscardDerivedFiles()
+ libtextclassifier3::Status DiscardDerivedFiles(
+ const version_util::DerivedFilesRebuildResult& rebuild_result)
ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Repopulates derived data off our ground truths.
@@ -700,6 +705,7 @@ class IcingSearchEngine {
bool index_needed_restoration;
bool integer_index_needed_restoration;
bool qualified_id_join_index_needed_restoration;
+ bool embedding_index_needed_restoration;
};
IndexRestorationResult RestoreIndexIfNeeded()
ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
@@ -740,17 +746,21 @@ class IcingSearchEngine {
bool index_needed_restoration;
bool integer_index_needed_restoration;
bool qualified_id_join_index_needed_restoration;
+ bool embedding_index_needed_restoration;
explicit TruncateIndexResult(
DocumentId first_document_to_reindex_in,
bool index_needed_restoration_in,
bool integer_index_needed_restoration_in,
- bool qualified_id_join_index_needed_restoration_in)
+ bool qualified_id_join_index_needed_restoration_in,
+ bool embedding_index_needed_restoration_in)
: first_document_to_reindex(first_document_to_reindex_in),
index_needed_restoration(index_needed_restoration_in),
integer_index_needed_restoration(integer_index_needed_restoration_in),
qualified_id_join_index_needed_restoration(
- qualified_id_join_index_needed_restoration_in) {}
+ qualified_id_join_index_needed_restoration_in),
+ embedding_index_needed_restoration(
+ embedding_index_needed_restoration_in) {}
};
libtextclassifier3::StatusOr<TruncateIndexResult> TruncateIndicesTo(
DocumentId last_stored_document_id)
diff --git a/icing/icing-search-engine_initialization_test.cc b/icing/icing-search-engine_initialization_test.cc
index 122e4af..d6316d4 100644
--- a/icing/icing-search-engine_initialization_test.cc
+++ b/icing/icing-search-engine_initialization_test.cc
@@ -19,6 +19,7 @@
#include <string>
#include <string_view>
#include <tuple>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -204,7 +205,7 @@ class IcingSearchEngineInitializationTest : public testing::Test {
// Non-zero value so we don't override it to be the current time
constexpr int64_t kDefaultCreationTimestampMs = 1575492852000;
-std::string GetVersionFilename() { return GetTestBaseDir() + "/version"; }
+// Directory holding the version files. Renamed from GetVersionFilename()
+// because version_util now takes a directory (it manages both the V1 file and
+// the V2 proto under it — see WriteV1Version/WriteV2Version call sites).
+std::string GetVersionFileDir() { return GetTestBaseDir(); }
+// Subdirectory of the test base dir used for document data.
 std::string GetDocumentDir() { return GetTestBaseDir() + "/document_dir"; }
@@ -5566,12 +5567,26 @@ INSTANTIATE_TEST_SUITE_P(IcingSearchEngineInitializationSwitchJoinIndexTest,
IcingSearchEngineInitializationSwitchJoinIndexTest,
testing::Values(true, false));
+// Parameter bundle for the version-change initialization tests: the version
+// info to stamp onto the pre-existing data set, plus the set of feature flags
+// recorded as enabled when that data set was (supposedly) written.
+struct IcingSearchEngineInitializationVersionChangeTestParam {
+ // Version info written into the existing data's version files before
+ // re-initializing Icing.
+ version_util::VersionInfo existing_version_info;
+ // Features recorded in the V2 version proto (only written when
+ // existing_version_info.version >= kFirstV2Version — see the test body).
+ std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType>
+ existing_enabled_features;
+
+ explicit IcingSearchEngineInitializationVersionChangeTestParam(
+ version_util::VersionInfo version_info_in,
+ std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType>
+ existing_enabled_features_in)
+ : existing_version_info(std::move(version_info_in)),
+ existing_enabled_features(std::move(existing_enabled_features_in)) {}
+};
+
+// Now parameterized over the struct above (version info + enabled features)
+// instead of a bare VersionInfo, so cases can also exercise flag changes.
 class IcingSearchEngineInitializationVersionChangeTest
 : public IcingSearchEngineInitializationTest,
- public ::testing::WithParamInterface<version_util::VersionInfo> {};
+ public ::testing::WithParamInterface<
+ IcingSearchEngineInitializationVersionChangeTestParam> {};
TEST_P(IcingSearchEngineInitializationVersionChangeTest,
- RecoverFromVersionChange) {
+ RecoverFromVersionChangeOrUnknownFlagChange) {
// TODO(b/280697513): test backup schema migration
// Test the following scenario: version change. All derived data should be
// rebuilt. We test this by manually adding some invalid derived data and
@@ -5725,10 +5740,27 @@ TEST_P(IcingSearchEngineInitializationVersionChangeTest,
std::move(incorrect_message)));
ICING_ASSERT_OK(index_processor.IndexDocument(tokenized_document, doc_id));
- // Change existing data's version file
- const version_util::VersionInfo& existing_version_info = GetParam();
- ICING_ASSERT_OK(version_util::WriteVersion(
- *filesystem(), GetVersionFilename(), existing_version_info));
+ // Rewrite existing data's version files
+ ICING_ASSERT_OK(
+ version_util::DiscardVersionFiles(*filesystem(), GetVersionFileDir()));
+ const version_util::VersionInfo& existing_version_info =
+ GetParam().existing_version_info;
+ ICING_ASSERT_OK(version_util::WriteV1Version(
+ *filesystem(), GetVersionFileDir(), existing_version_info));
+
+ if (existing_version_info.version >= version_util::kFirstV2Version) {
+ IcingSearchEngineVersionProto version_proto;
+ version_proto.set_version(existing_version_info.version);
+ version_proto.set_max_version(existing_version_info.max_version);
+ auto* enabled_features = version_proto.mutable_enabled_features();
+ for (const auto& feature : GetParam().existing_enabled_features) {
+ enabled_features->Add(version_util::GetFeatureInfoProto(feature));
+ }
+ version_util::WriteV2Version(
+ *filesystem(), GetVersionFileDir(),
+ std::make_unique<IcingSearchEngineVersionProto>(
+ std::move(version_proto)));
+ }
}
// Mock filesystem to observe and check the behavior of all indices.
@@ -5738,28 +5770,48 @@ TEST_P(IcingSearchEngineInitializationVersionChangeTest,
std::make_unique<FakeClock>(), GetTestJniCache());
InitializeResultProto initialize_result = icing.Initialize();
EXPECT_THAT(initialize_result.status(), ProtoIsOk());
- // Index Restoration should be triggered here. Incorrect data should be
- // deleted and correct data of message should be indexed.
+
+ // Derived files restoration should be triggered here. Incorrect data should
+ // be deleted and correct data of message should be indexed.
+ // Here we're recovering from a version change or a flag change that requires
+ // rebuilding all derived files.
+ //
+ // TODO(b/314816301): test individual derived files rebuilds due to change
+ // in trunk stable feature flags.
+ // i.e. Test individual rebuilding for each of:
+ // - document store
+ // - schema store
+ // - term index
+ // - numeric index
+ // - qualified id join index
+ InitializeStatsProto::RecoveryCause expected_recovery_cause =
+ GetParam().existing_version_info.version != version_util::kVersion
+ ? InitializeStatsProto::VERSION_CHANGED
+ : InitializeStatsProto::FEATURE_FLAG_CHANGED;
EXPECT_THAT(
initialize_result.initialize_stats().document_store_recovery_cause(),
- Eq(InitializeStatsProto::VERSION_CHANGED));
+ Eq(expected_recovery_cause));
+ EXPECT_THAT(
+ initialize_result.initialize_stats().schema_store_recovery_cause(),
+ Eq(expected_recovery_cause));
EXPECT_THAT(initialize_result.initialize_stats().index_restoration_cause(),
- Eq(InitializeStatsProto::VERSION_CHANGED));
+ Eq(expected_recovery_cause));
EXPECT_THAT(
initialize_result.initialize_stats().integer_index_restoration_cause(),
- Eq(InitializeStatsProto::VERSION_CHANGED));
+ Eq(expected_recovery_cause));
EXPECT_THAT(initialize_result.initialize_stats()
.qualified_id_join_index_restoration_cause(),
- Eq(InitializeStatsProto::VERSION_CHANGED));
+ Eq(expected_recovery_cause));
// Manually check version file
ICING_ASSERT_OK_AND_ASSIGN(
- version_util::VersionInfo version_info_after_init,
- version_util::ReadVersion(*filesystem(), GetVersionFilename(),
+ IcingSearchEngineVersionProto version_proto_after_init,
+ version_util::ReadVersion(*filesystem(), GetVersionFileDir(),
GetIndexDir()));
- EXPECT_THAT(version_info_after_init.version, Eq(version_util::kVersion));
- EXPECT_THAT(version_info_after_init.max_version,
- Eq(std::max(version_util::kVersion, GetParam().max_version)));
+ EXPECT_THAT(version_proto_after_init.version(), Eq(version_util::kVersion));
+ EXPECT_THAT(version_proto_after_init.max_version(),
+ Eq(std::max(version_util::kVersion,
+ GetParam().existing_version_info.max_version)));
SearchResultProto expected_search_result_proto;
expected_search_result_proto.mutable_status()->set_code(StatusProto::OK);
@@ -5836,9 +5888,11 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(
// Manually change existing data set's version to kVersion + 1. When
// initializing, it will detect "rollback".
- version_util::VersionInfo(
- /*version_in=*/version_util::kVersion + 1,
- /*max_version_in=*/version_util::kVersion + 1),
+ IcingSearchEngineInitializationVersionChangeTestParam(
+ version_util::VersionInfo(
+ /*version_in=*/version_util::kVersion + 1,
+ /*max_version_in=*/version_util::kVersion + 1),
+ /*existing_enabled_features_in=*/{}),
// Currently we don't have any "upgrade" that requires rebuild derived
// files, so skip this case until we have a case for it.
@@ -5846,27 +5900,45 @@ INSTANTIATE_TEST_SUITE_P(
// Manually change existing data set's version to kVersion - 1 and
// max_version to kVersion. When initializing, it will detect "roll
// forward".
- version_util::VersionInfo(
- /*version_in=*/version_util::kVersion - 1,
- /*max_version_in=*/version_util::kVersion),
+ IcingSearchEngineInitializationVersionChangeTestParam(
+ version_util::VersionInfo(
+ /*version_in=*/version_util::kVersion - 1,
+ /*max_version_in=*/version_util::kVersion),
+ /*existing_enabled_features_in=*/{}),
// Manually change existing data set's version to 0 and max_version to
// 0. When initializing, it will detect "version 0 upgrade".
//
// Note: in reality, version 0 won't be written into version file, but
// it is ok here since it is hack to simulate version 0 situation.
- version_util::VersionInfo(
- /*version_in=*/0,
- /*max_version_in=*/0),
+ IcingSearchEngineInitializationVersionChangeTestParam(
+ version_util::VersionInfo(
+ /*version_in=*/0,
+ /*max_version_in=*/0),
+ /*existing_enabled_features_in=*/{}),
// Manually change existing data set's version to 0 and max_version to
// kVersion. When initializing, it will detect "version 0 roll forward".
//
// Note: in reality, version 0 won't be written into version file, but
// it is ok here since it is hack to simulate version 0 situation.
- version_util::VersionInfo(
- /*version_in=*/0,
- /*max_version_in=*/version_util::kVersion)));
+ IcingSearchEngineInitializationVersionChangeTestParam(
+ version_util::VersionInfo(
+ /*version_in=*/0,
+ /*max_version_in=*/version_util::kVersion),
+ /*existing_enabled_features_in=*/{}),
+
+ // Manually write an unknown feature in the version proto while keeping
+ // version the same as kVersion.
+ //
+ // Result: this will rebuild all derived files with restoration cause
+ // FEATURE_FLAG_CHANGED
+ IcingSearchEngineInitializationVersionChangeTestParam(
+ version_util::VersionInfo(
+ /*version_in=*/version_util::kVersion,
+ /*max_version_in=*/version_util::kVersion),
+ /*existing_enabled_features_in=*/{
+ IcingSearchEngineFeatureInfoProto::UNKNOWN})));
class IcingSearchEngineInitializationChangePropertyExistenceHitsFlagTest
: public IcingSearchEngineInitializationTest,
@@ -5962,7 +6034,7 @@ TEST_P(IcingSearchEngineInitializationChangePropertyExistenceHitsFlagTest,
ASSERT_THAT(initialize_result.status(), ProtoIsOk());
// Ensure that the term index is rebuilt if the flag is changed.
EXPECT_THAT(initialize_result.initialize_stats().index_restoration_cause(),
- Eq(flag_changed ? InitializeStatsProto::IO_ERROR
+ Eq(flag_changed ? InitializeStatsProto::FEATURE_FLAG_CHANGED
: InitializeStatsProto::NONE));
EXPECT_THAT(
initialize_result.initialize_stats().integer_index_restoration_cause(),
diff --git a/icing/icing-search-engine_schema_test.cc b/icing/icing-search-engine_schema_test.cc
index 49c024e..34081e9 100644
--- a/icing/icing-search-engine_schema_test.cc
+++ b/icing/icing-search-engine_schema_test.cc
@@ -36,7 +36,6 @@
#include "icing/proto/persist.pb.h"
#include "icing/proto/reset.pb.h"
#include "icing/proto/schema.pb.h"
-#include "icing/proto/scoring.pb.h"
#include "icing/proto/search.pb.h"
#include "icing/proto/status.pb.h"
#include "icing/proto/storage.pb.h"
@@ -60,6 +59,7 @@ namespace {
using ::icing::lib::portable_equals_proto::EqualsProto;
using ::testing::Eq;
using ::testing::HasSubstr;
+using ::testing::Not;
using ::testing::Return;
// For mocking purpose, we allow tests to provide a custom Filesystem.
@@ -3154,6 +3154,86 @@ TEST_F(IcingSearchEngineSchemaTest, IcingShouldReturnErrorForExtraSections) {
HasSubstr("Too many properties to be indexed"));
}
+TEST_F(IcingSearchEngineSchemaTest, UpdatedTypeDescriptionIsSaved) {
+  // Create a schema containing a type and a property, each with a description.
+ PropertyConfigProto old_property =
+ PropertyConfigBuilder()
+ .SetName("prop0")
+ .SetDescription("old property description")
+ .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)
+ .Build();
+ SchemaTypeConfigProto old_schema_type_config =
+ SchemaTypeConfigBuilder()
+ .SetType("type")
+ .SetDescription("old description")
+ .AddProperty(old_property)
+ .Build();
+ SchemaProto old_schema =
+ SchemaBuilder().AddType(old_schema_type_config).Build();
+
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ ASSERT_THAT(icing.SetSchema(old_schema).status(), ProtoIsOk());
+
+ // Update the type description
+ SchemaTypeConfigProto new_schema_type_config =
+ SchemaTypeConfigBuilder(old_schema_type_config)
+ .SetDescription("new description")
+ .Build();
+ SchemaProto new_schema =
+ SchemaBuilder().AddType(new_schema_type_config).Build();
+ ASSERT_THAT(icing.SetSchema(new_schema).status(), ProtoIsOk());
+
+ GetSchemaResultProto get_result = icing.GetSchema();
+ ASSERT_THAT(get_result.status(), ProtoIsOk());
+ ASSERT_THAT(get_result.schema(), EqualsProto(new_schema));
+ ASSERT_THAT(get_result.schema(), Not(EqualsProto(old_schema)));
+}
+
+TEST_F(IcingSearchEngineSchemaTest, UpdatedPropertyDescriptionIsSaved) {
+  // Create a schema containing a type and a property, each with a description.
+ PropertyConfigProto old_property =
+ PropertyConfigBuilder()
+ .SetName("prop0")
+ .SetDescription("old property description")
+ .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)
+ .Build();
+ SchemaTypeConfigProto old_schema_type_config =
+ SchemaTypeConfigBuilder()
+ .SetType("type")
+ .SetDescription("old description")
+ .AddProperty(old_property)
+ .Build();
+ SchemaProto old_schema =
+ SchemaBuilder().AddType(old_schema_type_config).Build();
+
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ ASSERT_THAT(icing.SetSchema(old_schema).status(), ProtoIsOk());
+
+ // Update the property description
+ PropertyConfigProto new_property =
+ PropertyConfigBuilder(old_property)
+ .SetDescription("new property description")
+ .Build();
+ SchemaTypeConfigProto new_schema_type_config =
+ SchemaTypeConfigBuilder()
+ .SetType("type")
+ .SetDescription("old description")
+ .AddProperty(new_property)
+ .Build();
+ SchemaProto new_schema =
+ SchemaBuilder().AddType(new_schema_type_config).Build();
+ ASSERT_THAT(icing.SetSchema(new_schema).status(), ProtoIsOk());
+
+ GetSchemaResultProto get_result = icing.GetSchema();
+ ASSERT_THAT(get_result.status(), ProtoIsOk());
+ ASSERT_THAT(get_result.schema(), EqualsProto(new_schema));
+ ASSERT_THAT(get_result.schema(), Not(EqualsProto(old_schema)));
+}
+
} // namespace
} // namespace lib
} // namespace icing
diff --git a/icing/icing-search-engine_search_test.cc b/icing/icing-search-engine_search_test.cc
index 21512c6..a58dbc8 100644
--- a/icing/icing-search-engine_search_test.cc
+++ b/icing/icing-search-engine_search_test.cc
@@ -13,12 +13,14 @@
// limitations under the License.
#include <cstdint>
+#include <initializer_list>
#include <limits>
#include <memory>
#include <string>
+#include <string_view>
#include <utility>
+#include <vector>
-#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
@@ -27,7 +29,7 @@
#include "icing/index/lite/term-id-hit-pair.h"
#include "icing/jni/jni-cache.h"
#include "icing/join/join-processor.h"
-#include "icing/portable/endian.h"
+#include "icing/legacy/index/icing-filesystem.h"
#include "icing/portable/equals-proto.h"
#include "icing/portable/platform.h"
#include "icing/proto/debug.pb.h"
@@ -49,11 +51,13 @@
#include "icing/result/result-state-manager.h"
#include "icing/schema-builder.h"
#include "icing/testing/common-matchers.h"
+#include "icing/testing/embedding-test-utils.h"
#include "icing/testing/fake-clock.h"
#include "icing/testing/icu-data-file-helper.h"
#include "icing/testing/jni-test-helpers.h"
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
+#include "icing/util/clock.h"
#include "icing/util/snippet-helpers.h"
namespace icing {
@@ -63,6 +67,7 @@ namespace {
using ::icing::lib::portable_equals_proto::EqualsProto;
using ::testing::DoubleEq;
+using ::testing::DoubleNear;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Gt;
@@ -120,6 +125,8 @@ class IcingSearchEngineSearchTest
// Non-zero value so we don't override it to be the current time
constexpr int64_t kDefaultCreationTimestampMs = 1575492852000;
+constexpr double kEps = 0.000001;
+
IcingSearchEngineOptions GetDefaultIcingOptions() {
IcingSearchEngineOptions icing_options;
icing_options.set_base_dir(GetTestBaseDir());
@@ -3716,6 +3723,109 @@ TEST_P(IcingSearchEngineSearchTest, SearchWithPropertyFilters) {
EXPECT_THAT(results.results(0).document(), EqualsProto(document_one));
}
+TEST_P(IcingSearchEngineSearchTest, SearchWithPropertyFiltersPolymorphism) {
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("Person")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("name")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("emailAddress")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("Artist")
+ .AddParentType("Person")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("name")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("emailAddress")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("company")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+ ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk());
+
+ // Add a person document and an artist document
+ DocumentProto document_person =
+ DocumentBuilder()
+ .SetKey("namespace", "uri1")
+ .SetCreationTimestampMs(1000)
+ .SetSchema("Person")
+ .AddStringProperty("name", "Meg Ryan")
+ .AddStringProperty("emailAddress", "shopgirl@aol.com")
+ .Build();
+ DocumentProto document_artist =
+ DocumentBuilder()
+ .SetKey("namespace", "uri2")
+ .SetCreationTimestampMs(1000)
+ .SetSchema("Artist")
+ .AddStringProperty("name", "Meg Artist")
+ .AddStringProperty("emailAddress", "artist@aol.com")
+ .AddStringProperty("company", "company")
+ .Build();
+ ASSERT_THAT(icing.Put(document_person).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document_artist).status(), ProtoIsOk());
+
+ // Set a query with property filters of "name" in Person and "emailAddress"
+ // in Artist. By polymorphism, "name" should also apply to Artist.
+ auto search_spec = std::make_unique<SearchSpecProto>();
+ search_spec->set_term_match_type(TermMatchType::PREFIX);
+ search_spec->set_search_type(GetParam());
+ TypePropertyMask* person_type_property_mask =
+ search_spec->add_type_property_filters();
+ person_type_property_mask->set_schema_type("Person");
+ person_type_property_mask->add_paths("name");
+ TypePropertyMask* artist_type_property_mask =
+ search_spec->add_type_property_filters();
+ artist_type_property_mask->set_schema_type("Artist");
+ artist_type_property_mask->add_paths("emailAddress");
+
+ auto result_spec = std::make_unique<ResultSpecProto>();
+ auto scoring_spec = std::make_unique<ScoringSpecProto>();
+ *scoring_spec = GetDefaultScoringSpec();
+
+ // Verify that the property filter for "name" in Person is also applied to
+ // Artist.
+ search_spec->set_query("Meg");
+ SearchResultProto results =
+ icing.Search(*search_spec, *scoring_spec, *result_spec);
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document_person));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document_artist));
+
+ // Verify that the property filter for "emailAddress" in Artist is only
+ // applied to Artist.
+ search_spec->set_query("aol");
+ results = icing.Search(*search_spec, *scoring_spec, *result_spec);
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(1));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document_artist));
+
+ // Verify that the "company" property is filtered out, since it is not
+ // specified in the property filter.
+ search_spec->set_query("company");
+ results = icing.Search(*search_spec, *scoring_spec, *result_spec);
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), IsEmpty());
+}
+
TEST_P(IcingSearchEngineSearchTest, EmptySearchWithPropertyFilter) {
IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
@@ -4496,6 +4606,11 @@ TEST_P(IcingSearchEngineSearchTest, QueryStatsProtoTest) {
exp_stats.set_document_retrieval_latency_ms(5);
exp_stats.set_lock_acquisition_latency_ms(5);
exp_stats.set_num_joined_results_returned_current_page(0);
+ // document4, document5's hits will remain in the lite index (# of hits: 4).
+ exp_stats.set_lite_index_hit_buffer_byte_size(4 *
+ sizeof(TermIdHitPair::Value));
+ exp_stats.set_lite_index_hit_buffer_unsorted_byte_size(
+ 4 * sizeof(TermIdHitPair::Value));
QueryStatsProto::SearchStats* exp_parent_search_stats =
exp_stats.mutable_parent_search_stats();
@@ -4511,6 +4626,14 @@ TEST_P(IcingSearchEngineSearchTest, QueryStatsProtoTest) {
exp_parent_search_stats->set_num_fetched_hits_lite_index(2);
exp_parent_search_stats->set_num_fetched_hits_main_index(3);
exp_parent_search_stats->set_num_fetched_hits_integer_index(0);
+ if (GetParam() ==
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ exp_parent_search_stats->set_query_processor_lexer_extract_token_latency_ms(
+ 5);
+ exp_parent_search_stats
+ ->set_query_processor_parser_consume_query_latency_ms(5);
+ exp_parent_search_stats->set_query_processor_query_visitor_latency_ms(5);
+ }
EXPECT_THAT(search_result.query_stats(), EqualsProto(exp_stats));
@@ -4769,6 +4892,11 @@ TEST_P(IcingSearchEngineSearchTest, JoinQueryStatsProtoTest) {
exp_stats.set_num_joined_results_returned_current_page(3);
exp_stats.set_join_latency_ms(5);
exp_stats.set_is_join_query(true);
+ // person3, email4's hits will remain in the lite index (# of hits: 5).
+ exp_stats.set_lite_index_hit_buffer_byte_size(5 *
+ sizeof(TermIdHitPair::Value));
+ exp_stats.set_lite_index_hit_buffer_unsorted_byte_size(
+ 5 * sizeof(TermIdHitPair::Value));
QueryStatsProto::SearchStats* exp_parent_search_stats =
exp_stats.mutable_parent_search_stats();
@@ -4784,6 +4912,14 @@ TEST_P(IcingSearchEngineSearchTest, JoinQueryStatsProtoTest) {
exp_parent_search_stats->set_num_fetched_hits_lite_index(1);
exp_parent_search_stats->set_num_fetched_hits_main_index(2);
exp_parent_search_stats->set_num_fetched_hits_integer_index(0);
+ if (GetParam() ==
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ exp_parent_search_stats->set_query_processor_lexer_extract_token_latency_ms(
+ 5);
+ exp_parent_search_stats
+ ->set_query_processor_parser_consume_query_latency_ms(5);
+ exp_parent_search_stats->set_query_processor_query_visitor_latency_ms(5);
+ }
QueryStatsProto::SearchStats* exp_child_search_stats =
exp_stats.mutable_child_search_stats();
@@ -4799,6 +4935,14 @@ TEST_P(IcingSearchEngineSearchTest, JoinQueryStatsProtoTest) {
exp_child_search_stats->set_num_fetched_hits_lite_index(1);
exp_child_search_stats->set_num_fetched_hits_main_index(3);
exp_child_search_stats->set_num_fetched_hits_integer_index(0);
+ if (GetParam() ==
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ exp_child_search_stats->set_query_processor_lexer_extract_token_latency_ms(
+ 5);
+ exp_child_search_stats->set_query_processor_parser_consume_query_latency_ms(
+ 5);
+ exp_child_search_stats->set_query_processor_query_visitor_latency_ms(5);
+ }
EXPECT_THAT(search_result.query_stats(), EqualsProto(exp_stats));
@@ -6638,6 +6782,8 @@ TEST_F(IcingSearchEngineSearchTest, NumericFilterQueryStatsProtoTest) {
exp_stats.set_document_retrieval_latency_ms(5);
exp_stats.set_lock_acquisition_latency_ms(5);
exp_stats.set_num_joined_results_returned_current_page(0);
+ exp_stats.set_lite_index_hit_buffer_byte_size(0);
+ exp_stats.set_lite_index_hit_buffer_unsorted_byte_size(0);
QueryStatsProto::SearchStats* exp_parent_search_stats =
exp_stats.mutable_parent_search_stats();
@@ -6656,6 +6802,11 @@ TEST_F(IcingSearchEngineSearchTest, NumericFilterQueryStatsProtoTest) {
// Since we will inspect 1 bucket from "price" in integer index and it
// contains 3 hits, we will fetch 3 hits (but filter out one of them).
exp_parent_search_stats->set_num_fetched_hits_integer_index(3);
+ exp_parent_search_stats->set_query_processor_lexer_extract_token_latency_ms(
+ 5);
+ exp_parent_search_stats->set_query_processor_parser_consume_query_latency_ms(
+ 5);
+ exp_parent_search_stats->set_query_processor_query_visitor_latency_ms(5);
EXPECT_THAT(results.query_stats(), EqualsProto(exp_stats));
}
@@ -7162,6 +7313,208 @@ TEST_P(IcingSearchEngineSearchTest, HasPropertyQueryNestedDocument) {
EXPECT_THAT(results.results(), IsEmpty());
}
+TEST_P(IcingSearchEngineSearchTest, EmbeddingSearch) {
+ if (GetParam() !=
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ GTEST_SKIP() << "Embedding search is only supported in advanced query.";
+ }
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("Email")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("body")
+ .SetDataTypeString(TERM_MATCH_EXACT,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("embedding1")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("embedding2")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_REPEATED)))
+ .Build();
+ DocumentProto document0 =
+ DocumentBuilder()
+ .SetKey("icing", "uri0")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(1)
+ .AddStringProperty("body", "foo")
+ .AddVectorProperty(
+ "embedding1",
+ CreateVector("my_model_v1", {0.1, 0.2, 0.3, 0.4, 0.5}))
+ .AddVectorProperty(
+ "embedding2",
+ CreateVector("my_model_v1", {-0.1, -0.2, -0.3, 0.4, 0.5}),
+ CreateVector("my_model_v2", {0.6, 0.7, 0.8}))
+ .Build();
+ DocumentProto document1 =
+ DocumentBuilder()
+ .SetKey("icing", "uri1")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(1)
+ .AddVectorProperty(
+ "embedding1",
+ CreateVector("my_model_v1", {-0.1, 0.2, -0.3, -0.4, 0.5}))
+ .AddVectorProperty("embedding2",
+ CreateVector("my_model_v2", {0.6, 0.7, -0.8}))
+ .Build();
+
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document0).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk());
+
+ SearchSpecProto search_spec;
+ search_spec.set_term_match_type(TermMatchType::EXACT_ONLY);
+ search_spec.set_embedding_query_metric_type(
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT);
+ search_spec.add_enabled_features(
+ std::string(kListFilterQueryLanguageFeature));
+ search_spec.add_enabled_features(std::string(kEmbeddingSearchFeature));
+ search_spec.set_search_type(GetParam());
+ // Add an embedding query with semantic scores:
+ // - document 0: -0.5 (embedding1), 0.3 (embedding2)
+ // - document 1: -0.9 (embedding1)
+ *search_spec.add_embedding_query_vectors() =
+ CreateVector("my_model_v1", {1, -1, -1, 1, -1});
+ // Add an embedding query with semantic scores:
+ // - document 0: -0.5 (embedding2)
+ // - document 1: -2.1 (embedding2)
+ *search_spec.add_embedding_query_vectors() =
+ CreateVector("my_model_v2", {-1, -1, 1});
+ ScoringSpecProto scoring_spec = GetDefaultScoringSpec();
+ scoring_spec.set_rank_by(
+ ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION);
+
+  // Match documents that have embeddings with a semantic score greater than
+  // -1.
+ //
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding1), 0.3 (embedding2)
+ // - document 1: -0.9 (embedding1)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5, 0.3}) + sum({}) = -0.2
+ // - document 1: sum({-0.9}) + sum({}) = -0.9
+ search_spec.set_query("semanticSearch(getSearchSpecEmbedding(0), -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0))) + "
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ SearchResultProto results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5 + 0.3, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9, kEps));
+
+ // Create a query the same as above but with a section restriction, which
+ // still matches document 0 and document 1 but the semantic score 0.3 should
+ // be removed from document 0.
+ //
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding1)
+ // - document 1: -0.9 (embedding1)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5}) = -0.5
+ // - document 1: sum({-0.9}) = -0.9
+ search_spec.set_query(
+ "embedding1:semanticSearch(getSearchSpecEmbedding(0), -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9, kEps));
+
+ // Create a query that only matches document 0.
+ //
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5}) = -0.5
+ search_spec.set_query("semanticSearch(getSearchSpecEmbedding(1), -1.5)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(1));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5, kEps));
+
+ // Create a query that only matches document 1.
+ //
+ // The matched embeddings for each doc are:
+ // - document 1: -2.1 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 1: sum({-2.1}) = -2.1
+ search_spec.set_query("semanticSearch(getSearchSpecEmbedding(1), -10, -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(1));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-2.1, kEps));
+
+ // Create a complex query that matches all hits from all documents.
+ //
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding1), 0.3 (embedding2), -0.5 (embedding2)
+ // - document 1: -0.9 (embedding1), -2.1 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5, 0.3}) + sum({-0.5}) = -0.7
+ // - document 1: sum({-0.9}) + sum({-2.1}) = -3
+ search_spec.set_query(
+ "semanticSearch(getSearchSpecEmbedding(0)) OR "
+ "semanticSearch(getSearchSpecEmbedding(1))");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0))) + "
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5 + 0.3 - 0.5, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9 - 2.1, kEps));
+
+ // Create a hybrid query that matches document 0 because of term-based search
+ // and document 1 because of embedding-based search.
+ //
+ // The matched embeddings for each doc are:
+ // - document 1: -2.1 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({}) = 0
+ // - document 1: sum({-2.1}) = -2.1
+ search_spec.set_query(
+ "foo OR semanticSearch(getSearchSpecEmbedding(1), -10, -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ // Document 0 has no matched embedding hit, so its score is 0.
+ EXPECT_THAT(results.results(0).score(), DoubleNear(0, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-2.1, kEps));
+}
+
INSTANTIATE_TEST_SUITE_P(
IcingSearchEngineSearchTest, IcingSearchEngineSearchTest,
testing::Values(
diff --git a/icing/index/embed/doc-hit-info-iterator-embedding.cc b/icing/index/embed/doc-hit-info-iterator-embedding.cc
new file mode 100644
index 0000000..8f2bc7d
--- /dev/null
+++ b/icing/index/embed/doc-hit-info-iterator-embedding.cc
@@ -0,0 +1,164 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/doc-hit-info-iterator-embedding.h"
+
+#include <memory>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/embed/embedding-scorer.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/hit/hit.h"
+#include "icing/index/iterator/section-restrict-data.h"
+#include "icing/proto/search.pb.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIteratorEmbedding>>
+DocHitInfoIteratorEmbedding::Create(
+ const PropertyProto::VectorProto* query,
+ std::unique_ptr<SectionRestrictData> section_restrict_data,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ double score_low, double score_high,
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
+ const EmbeddingIndex* embedding_index) {
+ ICING_RETURN_ERROR_IF_NULL(query);
+ ICING_RETURN_ERROR_IF_NULL(embedding_index);
+ ICING_RETURN_ERROR_IF_NULL(score_map);
+
+ libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ pl_accessor_or = embedding_index->GetAccessorForVector(*query);
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
+ if (pl_accessor_or.ok()) {
+ pl_accessor = std::move(pl_accessor_or).ValueOrDie();
+ } else if (absl_ports::IsNotFound(pl_accessor_or.status())) {
+ // A not-found error should be fine, since that means there is no matching
+ // embedding hits in the index.
+ pl_accessor = nullptr;
+ } else {
+ // Otherwise, return the error as is.
+ return pl_accessor_or.status();
+ }
+
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<EmbeddingScorer> embedding_scorer,
+ EmbeddingScorer::Create(metric_type));
+
+ return std::unique_ptr<DocHitInfoIteratorEmbedding>(
+ new DocHitInfoIteratorEmbedding(query, std::move(section_restrict_data),
+ metric_type, std::move(embedding_scorer),
+ score_low, score_high, score_map,
+ embedding_index, std::move(pl_accessor)));
+}
+
+libtextclassifier3::StatusOr<const EmbeddingHit*>
+DocHitInfoIteratorEmbedding::AdvanceToNextEmbeddingHit() {
+ if (cached_embedding_hits_idx_ == cached_embedding_hits_.size()) {
+ ICING_ASSIGN_OR_RETURN(cached_embedding_hits_,
+ posting_list_accessor_->GetNextHitsBatch());
+ cached_embedding_hits_idx_ = 0;
+ if (cached_embedding_hits_.empty()) {
+ no_more_hit_ = true;
+ return nullptr;
+ }
+ }
+ const EmbeddingHit& embedding_hit =
+ cached_embedding_hits_[cached_embedding_hits_idx_];
+ if (doc_hit_info_.document_id() == kInvalidDocumentId) {
+ doc_hit_info_.set_document_id(embedding_hit.basic_hit().document_id());
+ if (section_restrict_data_ != nullptr) {
+ current_allowed_sections_mask_ =
+ section_restrict_data_->ComputeAllowedSectionsMask(
+ doc_hit_info_.document_id());
+ }
+ } else if (doc_hit_info_.document_id() !=
+ embedding_hit.basic_hit().document_id()) {
+ return nullptr;
+ }
+ ++cached_embedding_hits_idx_;
+ return &embedding_hit;
+}
+
+libtextclassifier3::Status DocHitInfoIteratorEmbedding::Advance() {
+ if (no_more_hit_ || posting_list_accessor_ == nullptr) {
+ return absl_ports::ResourceExhaustedError(
+ "No more DocHitInfos in iterator");
+ }
+
+ doc_hit_info_ = DocHitInfo(kInvalidDocumentId, kSectionIdMaskNone);
+ std::vector<double>* matched_scores = nullptr;
+ current_allowed_sections_mask_ = kSectionIdMaskAll;
+ while (true) {
+ ICING_ASSIGN_OR_RETURN(const EmbeddingHit* embedding_hit,
+ AdvanceToNextEmbeddingHit());
+ if (embedding_hit == nullptr) {
+ // No more hits for the current document.
+ break;
+ }
+
+ // Filter out the embedding hit according to the section restriction.
+ if (((UINT64_C(1) << embedding_hit->basic_hit().section_id()) &
+ current_allowed_sections_mask_) == 0) {
+ continue;
+ }
+
+ // Calculate the semantic score.
+ int dimension = query_.values_size();
+ ICING_ASSIGN_OR_RETURN(
+ const float* vector,
+ embedding_index_.GetEmbeddingVector(*embedding_hit, dimension));
+ double semantic_score =
+ embedding_scorer_->Score(dimension,
+ /*v1=*/query_.values().data(),
+ /*v2=*/vector);
+
+ // If the semantic score is within the desired score range, update
+ // doc_hit_info_ and score_map_.
+ if (score_low_ <= semantic_score && semantic_score <= score_high_) {
+ doc_hit_info_.UpdateSection(embedding_hit->basic_hit().section_id());
+ if (matched_scores == nullptr) {
+ matched_scores = &(score_map_[doc_hit_info_.document_id()]);
+ }
+ matched_scores->push_back(semantic_score);
+ }
+ }
+
+ if (doc_hit_info_.document_id() == kInvalidDocumentId) {
+ return absl_ports::ResourceExhaustedError(
+ "No more DocHitInfos in iterator");
+ }
+
+ ++num_advance_calls_;
+
+ // Skip the current document if it has no vector matched.
+ if (doc_hit_info_.hit_section_ids_mask() == kSectionIdMaskNone) {
+ return Advance();
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/doc-hit-info-iterator-embedding.h b/icing/index/embed/doc-hit-info-iterator-embedding.h
new file mode 100644
index 0000000..21180db
--- /dev/null
+++ b/icing/index/embed/doc-hit-info-iterator-embedding.h
@@ -0,0 +1,161 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
+#define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
+
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/embed/embedding-scorer.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/iterator/section-restrict-data.h"
+#include "icing/proto/search.pb.h"
+#include "icing/schema/section.h"
+
+namespace icing {
+namespace lib {
+
+// Leaf iterator for embedding ("semanticSearch") queries. Iterates over all
+// documents that contain at least one embedding hit whose semantic score
+// against query_ falls within [score_low_, score_high_], and records every
+// matched score into score_map_ so that the scorer can consume them later.
+class DocHitInfoIteratorEmbedding : public DocHitInfoLeafIterator {
+ public:
+  // Create a DocHitInfoIterator for iterating through all docs which have an
+  // embedding matched with the provided query with a score in the range of
+  // [score_low, score_high], using the provided metric_type.
+  //
+  // The iterator will store the matched embedding scores in score_map to
+  // prepare for scoring.
+  //
+  // The iterator will handle the section restriction logic internally by the
+  // provided section_restrict_data.
+  //
+  // Returns:
+  //   - a DocHitInfoIteratorEmbedding instance on success.
+  //   - Any error from posting lists.
+  static libtextclassifier3::StatusOr<
+      std::unique_ptr<DocHitInfoIteratorEmbedding>>
+  Create(const PropertyProto::VectorProto* query,
+         std::unique_ptr<SectionRestrictData> section_restrict_data,
+         SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+         double score_low, double score_high,
+         EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
+         const EmbeddingIndex* embedding_index);
+
+  libtextclassifier3::Status Advance() override;
+
+  // The iterator will internally handle the section restriction logic by itself
+  // to have better control, so that it is able to filter out embedding hits
+  // from unwanted sections to avoid retrieving unnecessary vectors and
+  // calculate scores for them.
+  bool full_section_restriction_applied() const override { return true; }
+
+  // Trimming is only meaningful for query suggestions, which are not
+  // supported for embedding queries.
+  libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
+    return absl_ports::InvalidArgumentError(
+        "Query suggestions for the semanticSearch function are not supported");
+  }
+
+  // All advance calls are attributed to the lite index bucket; this iterator
+  // does not inspect index blocks directly.
+  CallStats GetCallStats() const override {
+    return CallStats(
+        /*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_,
+        /*num_leaf_advance_calls_main_index_in=*/0,
+        /*num_leaf_advance_calls_integer_index_in=*/0,
+        /*num_leaf_advance_calls_no_index_in=*/0,
+        /*num_blocks_inspected_in=*/0);
+  }
+
+  std::string ToString() const override { return "embedding_iterator"; }
+
+  // PopulateMatchedTermsStats is not applicable to embedding search.
+  void PopulateMatchedTermsStats(
+      std::vector<TermMatchInfo>* matched_terms_stats,
+      SectionIdMask filtering_section_mask) const override {}
+
+ private:
+  explicit DocHitInfoIteratorEmbedding(
+      const PropertyProto::VectorProto* query,
+      std::unique_ptr<SectionRestrictData> section_restrict_data,
+      SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+      std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low,
+      double score_high,
+      EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
+      const EmbeddingIndex* embedding_index,
+      std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor)
+      : query_(*query),
+        section_restrict_data_(std::move(section_restrict_data)),
+        metric_type_(metric_type),
+        embedding_scorer_(std::move(embedding_scorer)),
+        score_low_(score_low),
+        score_high_(score_high),
+        score_map_(*score_map),
+        embedding_index_(*embedding_index),
+        posting_list_accessor_(std::move(posting_list_accessor)),
+        cached_embedding_hits_idx_(0),
+        current_allowed_sections_mask_(kSectionIdMaskAll),
+        no_more_hit_(false),
+        num_advance_calls_(0) {}
+
+  // Advance to the next embedding hit of the current document. If the current
+  // document id is kInvalidDocumentId, the method will advance to the first
+  // embedding hit of the next document and update doc_hit_info_.
+  //
+  // This method also properly updates cached_embedding_hits_,
+  // cached_embedding_hits_idx_, current_allowed_sections_mask_, and
+  // no_more_hit_ to reflect the current state.
+  //
+  // Returns:
+  //   - a const pointer to the next embedding hit on success.
+  //   - nullptr, if there is no more hit for the current document, or no more
+  //     hit in general if the current document id is kInvalidDocumentId.
+  //   - Any error from posting lists.
+  libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit();
+
+  // Query information
+  const PropertyProto::VectorProto& query_;  // Does not own
+  std::unique_ptr<SectionRestrictData> section_restrict_data_;  // Nullable.
+
+  // Scoring arguments
+  SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
+  std::unique_ptr<EmbeddingScorer> embedding_scorer_;
+  double score_low_;
+  double score_high_;
+
+  // Score map
+  EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map_;  // Does not own
+
+  // Access to embeddings index data
+  const EmbeddingIndex& embedding_index_;
+  std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_;
+
+  // Cached data from the embeddings index
+  std::vector<EmbeddingHit> cached_embedding_hits_;
+  int cached_embedding_hits_idx_;
+  SectionIdMask current_allowed_sections_mask_;
+  bool no_more_hit_;
+
+  // Number of Advance() calls made on this iterator; reported via
+  // GetCallStats().
+  int num_advance_calls_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
diff --git a/icing/index/embed/embedding-hit.h b/icing/index/embed/embedding-hit.h
new file mode 100644
index 0000000..165a123
--- /dev/null
+++ b/icing/index/embed/embedding-hit.h
@@ -0,0 +1,67 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_EMBEDDING_HIT_H_
+#define ICING_INDEX_EMBED_EMBEDDING_HIT_H_
+
+#include <cstdint>
+
+#include "icing/index/hit/hit.h"
+
+namespace icing {
+namespace lib {
+
+// A hit for a single embedding vector: packs a BasicHit (document id +
+// section id) together with the vector's location in the vector storage into
+// a single 64-bit value.
+class EmbeddingHit {
+ public:
+  // Order: basic_hit, location
+  // Value bits layout: 32 basic_hit + 32 location
+  using Value = uint64_t;
+
+  // WARNING: Changing this value will invalidate any pre-existing posting lists
+  // on user devices.
+  //
+  // kInvalidValue contains:
+  // - BasicHit of value 0, which is invalid.
+  // - location of 0 (valid), which is ok because BasicHit is already invalid.
+  static_assert(BasicHit::kInvalidValue == 0);
+  static constexpr Value kInvalidValue = 0;
+
+  explicit EmbeddingHit(BasicHit basic_hit, uint32_t location) {
+    value_ = (static_cast<uint64_t>(basic_hit.value()) << 32) | location;
+  }
+
+  explicit EmbeddingHit(Value value) : value_(value) {}
+
+  // BasicHit contains the document id and the section id.
+  BasicHit basic_hit() const { return BasicHit(value_ >> 32); }
+  // The location of the referred embedding vector in the vector storage.
+  // (Fixed: removed the stray semicolon after the function body, which
+  // triggers -Wextra-semi on pedantic builds.)
+  uint32_t location() const { return value_ & 0xFFFFFFFF; }
+
+  // An EmbeddingHit is valid iff its embedded BasicHit is valid.
+  bool is_valid() const { return basic_hit().is_valid(); }
+  Value value() const { return value_; }
+
+  // Ordered by the raw 64-bit value: BasicHit bits first, then location.
+  bool operator<(const EmbeddingHit& h2) const { return value_ < h2.value_; }
+  bool operator==(const EmbeddingHit& h2) const { return value_ == h2.value_; }
+
+ private:
+  Value value_;
+};
+
+static_assert(sizeof(BasicHit) == 4);
+static_assert(sizeof(EmbeddingHit) == 8);
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_EMBEDDING_HIT_H_
diff --git a/icing/index/embed/embedding-hit_test.cc b/icing/index/embed/embedding-hit_test.cc
new file mode 100644
index 0000000..31a0b6c
--- /dev/null
+++ b/icing/index/embed/embedding-hit_test.cc
@@ -0,0 +1,80 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/embedding-hit.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/index/hit/hit.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::IsFalse;
+
+static constexpr DocumentId kSomeDocumentId = 24;
+static constexpr SectionId kSomeSectionid = 5;
+static constexpr uint32_t kSomeLocation = 123;
+
+// Verifies that the (basic_hit, location) pair packed by the constructor is
+// returned unchanged by the accessors.
+TEST(EmbeddingHitTest, Accessors) {
+  BasicHit basic_hit(kSomeSectionid, kSomeDocumentId);
+  EmbeddingHit embedding_hit(basic_hit, kSomeLocation);
+  EXPECT_THAT(embedding_hit.basic_hit(), Eq(basic_hit));
+  EXPECT_THAT(embedding_hit.location(), Eq(kSomeLocation));
+}
+
+// Verifies that kInvalidValue decodes to an invalid hit whose components take
+// the expected default values (invalid document id, min section id, location
+// 0).
+TEST(EmbeddingHitTest, Invalid) {
+  EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue);
+  EXPECT_THAT(invalid_hit.is_valid(), IsFalse());
+
+  // Also make sure the invalid EmbeddingHit contains an invalid document id.
+  EXPECT_THAT(invalid_hit.basic_hit().document_id(), Eq(kInvalidDocumentId));
+  EXPECT_THAT(invalid_hit.basic_hit().section_id(), Eq(kMinSectionId));
+  EXPECT_THAT(invalid_hit.location(), Eq(0));
+}
+
+// Verifies the total ordering of EmbeddingHit: BasicHit first, then location.
+TEST(EmbeddingHitTest, Comparison) {
+  // Basic hits constructed so that a < b < c.
+  BasicHit a(/*section_id=*/1, /*document_id=*/2409);
+  BasicHit b(/*section_id=*/1, /*document_id=*/243);
+  BasicHit c(/*section_id=*/15, /*document_id=*/243);
+
+  // Embedding hits compare by BasicHit first and then by location, so the
+  // expected ascending order is: on_a < on_b < on_c_low < on_c_high.
+  EmbeddingHit on_c_high(c, /*location=*/10);
+  EmbeddingHit on_c_low(c, /*location=*/0);
+  EmbeddingHit on_a(a, /*location=*/100);
+  EmbeddingHit on_b(b, /*location=*/0);
+
+  std::vector<EmbeddingHit> hits = {on_c_high, on_c_low, on_a, on_b};
+  std::sort(hits.begin(), hits.end());
+  EXPECT_THAT(hits, ElementsAre(on_a, on_b, on_c_low, on_c_high));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/embedding-index.cc b/icing/index/embed/embedding-index.cc
new file mode 100644
index 0000000..2e70f0e
--- /dev/null
+++ b/icing/index/embed/embedding-index.cc
@@ -0,0 +1,440 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/embedding-index.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/file/destructible-directory.h"
+#include "icing/file/file-backed-vector.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/memory-mapped-file.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/hit/hit.h"
+#include "icing/store/document-id.h"
+#include "icing/store/dynamic-trie-key-mapper.h"
+#include "icing/store/key-mapper.h"
+#include "icing/util/crc32.h"
+#include "icing/util/encode-util.h"
+#include "icing/util/logging.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+// Maximum on-disk size of the dynamic trie that maps posting list keys to
+// posting list identifiers.
+constexpr uint32_t kEmbeddingHitListMapperMaxSize =
+    128 * 1024 * 1024;  // 128 MiB
+
+// The maximum length returned by encode_util::EncodeIntToCString is 5 for
+// uint32_t.
+constexpr uint32_t kEncodedDimensionLength = 5;
+
+// Filepaths of the index's sub-storages, all under the working path.
+std::string GetMetadataFilePath(std::string_view working_path) {
+  return absl_ports::StrCat(working_path, "/metadata");
+}
+
+std::string GetFlashIndexStorageFilePath(std::string_view working_path) {
+  return absl_ports::StrCat(working_path, "/flash_index_storage");
+}
+
+std::string GetEmbeddingHitListMapperPath(std::string_view working_path) {
+  return absl_ports::StrCat(working_path, "/embedding_hit_list_mapper");
+}
+
+std::string GetEmbeddingVectorsFilePath(std::string_view working_path) {
+  return absl_ports::StrCat(working_path, "/embedding_vectors");
+}
+
+// An injective function that maps the ordered pair (dimension, model_signature)
+// to a string, which is used to form a key for embedding_posting_list_mapper_.
+std::string GetPostingListKey(uint32_t dimension,
+                              std::string_view model_signature) {
+  std::string encoded_dimension_str =
+      encode_util::EncodeIntToCString(dimension);
+  // Make encoded_dimension_str to fixed kEncodedDimensionLength bytes.
+  while (encoded_dimension_str.size() < kEncodedDimensionLength) {
+    // C string cannot contain 0 bytes, so we append it using 1, just like what
+    // we do in encode_util::EncodeIntToCString.
+    //
+    // The reason that this works is because DecodeIntToString decodes a byte
+    // value of 0x01 as 0x00. When EncodeIntToCString returns an encoded
+    // dimension that is less than 5 bytes, it means that the dimension contains
+    // unencoded leading 0x00. So here we're explicitly encoding those bytes as
+    // 0x01.
+    encoded_dimension_str.push_back(1);
+  }
+  return absl_ports::StrCat(encoded_dimension_str, model_signature);
+}
+
+// Convenience overload: derives the key from the vector's dimension
+// (values_size) and model signature.
+std::string GetPostingListKey(const PropertyProto::VectorProto& vector) {
+  return GetPostingListKey(vector.values_size(), vector.model_signature());
+}
+
+}  // namespace
+
+// Factory for EmbeddingIndex: constructs an instance rooted at working_path
+// and initializes all of its sub-storages before handing ownership to the
+// caller.
+//
+// Returns:
+//   - a fully initialized EmbeddingIndex on success.
+//   - FAILED_PRECONDITION_ERROR if filesystem is null, or any error from
+//     Initialize().
+libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>>
+EmbeddingIndex::Create(const Filesystem* filesystem, std::string working_path) {
+  ICING_RETURN_ERROR_IF_NULL(filesystem);
+
+  // Direct-initialize via `new`; std::make_unique cannot reach the
+  // non-public constructor.
+  std::unique_ptr<EmbeddingIndex> index(
+      new EmbeddingIndex(*filesystem, std::move(working_path)));
+  ICING_RETURN_IF_ERROR(index->Initialize());
+  return index;
+}
+
+// Opens (or creates) all on-disk sub-storages under working_path_: the
+// metadata mmapped file, the flash index storage holding posting lists, the
+// key -> posting-list-id mapper trie, and the raw embedding vector storage.
+// For a brand-new index the metadata header is written; for an existing one
+// the metadata file size and magic value are validated first.
+libtextclassifier3::Status EmbeddingIndex::Initialize() {
+  // A missing metadata file is treated as a brand-new index.
+  bool is_new = false;
+  if (!filesystem_.FileExists(GetMetadataFilePath(working_path_).c_str())) {
+    // Create working directory.
+    if (!filesystem_.CreateDirectoryRecursively(working_path_.c_str())) {
+      return absl_ports::InternalError(
+          absl_ports::StrCat("Failed to create directory: ", working_path_));
+    }
+    is_new = true;
+  }
+
+  ICING_ASSIGN_OR_RETURN(
+      MemoryMappedFile metadata_mmapped_file,
+      MemoryMappedFile::Create(filesystem_, GetMetadataFilePath(working_path_),
+                               MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC,
+                               /*max_file_size=*/kMetadataFileSize,
+                               /*pre_mapping_file_offset=*/0,
+                               /*pre_mapping_mmap_size=*/kMetadataFileSize));
+  metadata_mmapped_file_ =
+      std::make_unique<MemoryMappedFile>(std::move(metadata_mmapped_file));
+
+  ICING_ASSIGN_OR_RETURN(FlashIndexStorage flash_index_storage,
+                         FlashIndexStorage::Create(
+                             GetFlashIndexStorageFilePath(working_path_),
+                             &filesystem_, posting_list_hit_serializer_.get()));
+  flash_index_storage_ =
+      std::make_unique<FlashIndexStorage>(std::move(flash_index_storage));
+
+  ICING_ASSIGN_OR_RETURN(
+      embedding_posting_list_mapper_,
+      DynamicTrieKeyMapper<PostingListIdentifier>::Create(
+          filesystem_, GetEmbeddingHitListMapperPath(working_path_),
+          kEmbeddingHitListMapperMaxSize));
+
+  ICING_ASSIGN_OR_RETURN(
+      embedding_vectors_,
+      FileBackedVector<float>::Create(
+          filesystem_, GetEmbeddingVectorsFilePath(working_path_),
+          MemoryMappedFile::READ_WRITE_AUTO_SYNC));
+
+  if (is_new) {
+    // Initialize the metadata header for a fresh index.
+    ICING_RETURN_IF_ERROR(metadata_mmapped_file_->GrowAndRemapIfNecessary(
+        /*file_offset=*/0, /*mmap_size=*/kMetadataFileSize));
+    info().magic = Info::kMagic;
+    info().last_added_document_id = kInvalidDocumentId;
+    ICING_RETURN_IF_ERROR(InitializeNewStorage());
+  } else {
+    // Validate the existing metadata before trusting any of the storages.
+    if (metadata_mmapped_file_->available_size() != kMetadataFileSize) {
+      return absl_ports::FailedPreconditionError(
+          "Incorrect metadata file size");
+    }
+    if (info().magic != Info::kMagic) {
+      return absl_ports::FailedPreconditionError("Incorrect magic value");
+    }
+    ICING_RETURN_IF_ERROR(InitializeExistingStorage());
+  }
+  return libtextclassifier3::Status::OK;
+}
+
+// Discards all index data (in-memory buffer and on-disk working directory)
+// and re-initializes the index to an empty state.
+libtextclassifier3::Status EmbeddingIndex::Clear() {
+  pending_embedding_hits_.clear();
+  // Release every storage instance before deleting the directory out from
+  // under them.
+  metadata_mmapped_file_.reset();
+  flash_index_storage_.reset();
+  embedding_posting_list_mapper_.reset();
+  embedding_vectors_.reset();
+  if (filesystem_.DirectoryExists(working_path_.c_str())) {
+    ICING_RETURN_IF_ERROR(Discard(filesystem_, working_path_));
+  }
+  is_initialized_ = false;
+  return Initialize();
+}
+
+// Returns a posting list accessor for the hits stored under the key
+// (dimension, model_signature).
+//
+// Returns:
+//   - INVALID_ARGUMENT_ERROR if dimension is 0.
+//   - NOT_FOUND (propagated from the key mapper) if no posting list exists
+//     for this key.
+//   - Any error from the posting list accessor.
+libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+EmbeddingIndex::GetAccessor(uint32_t dimension,
+                            std::string_view model_signature) const {
+  if (dimension == 0) {
+    return absl_ports::InvalidArgumentError("Dimension is 0");
+  }
+  std::string key = GetPostingListKey(dimension, model_signature);
+  ICING_ASSIGN_OR_RETURN(PostingListIdentifier posting_list_id,
+                         embedding_posting_list_mapper_->Get(key));
+  return PostingListEmbeddingHitAccessor::CreateFromExisting(
+      flash_index_storage_.get(), posting_list_hit_serializer_.get(),
+      posting_list_id);
+}
+
+// Appends the vector's data to the vector storage and buffers an
+// EmbeddingHit (keyed by the vector's dimension and model signature) for a
+// later CommitBufferToIndex() call.
+//
+// Returns:
+//   - INVALID_ARGUMENT_ERROR if the vector has no values.
+//   - Any error from the vector storage.
+libtextclassifier3::Status EmbeddingIndex::BufferEmbedding(
+    const BasicHit& basic_hit, const PropertyProto::VectorProto& vector) {
+  if (vector.values_size() == 0) {
+    return absl_ports::InvalidArgumentError("Vector dimension is 0");
+  }
+
+  // Capture the current end of the vector storage BEFORE allocating: this is
+  // where the new vector will start, and it is what the hit's location must
+  // point at.
+  uint32_t location = embedding_vectors_->num_elements();
+  uint32_t dimension = vector.values_size();
+  std::string key = GetPostingListKey(vector);
+
+  // Buffer the embedding hit.
+  pending_embedding_hits_.push_back(
+      {std::move(key), EmbeddingHit(basic_hit, location)});
+
+  // Put vector
+  ICING_ASSIGN_OR_RETURN(FileBackedVector<float>::MutableArrayView mutable_arr,
+                         embedding_vectors_->Allocate(dimension));
+  mutable_arr.SetArray(/*idx=*/0, vector.values().data(), dimension);
+
+  return libtextclassifier3::Status::OK;
+}
+
+// Flushes all buffered (key, hit) pairs into their posting lists. The buffer
+// is sorted and then consumed in REVERSE (largest key/hit first) so that all
+// hits sharing a key form one contiguous range handled in a single pass.
+// NOTE(review): iterating largest-hit-first presumably satisfies PrependHit's
+// ordering requirement on posting lists — confirm against the embedding hit
+// serializer's contract.
+libtextclassifier3::Status EmbeddingIndex::CommitBufferToIndex() {
+  std::sort(pending_embedding_hits_.begin(), pending_embedding_hits_.end());
+  auto iter_curr_key = pending_embedding_hits_.rbegin();
+  while (iter_curr_key != pending_embedding_hits_.rend()) {
+    // In order to batch putting embedding hits with the same key (dimension,
+    // model_signature) to the same posting list, we find the range
+    // [iter_curr_key, iter_next_key) of embedding hits with the same key and
+    // put them into their corresponding posting list together.
+    auto iter_next_key = iter_curr_key;
+    while (iter_next_key != pending_embedding_hits_.rend() &&
+           iter_next_key->first == iter_curr_key->first) {
+      iter_next_key++;
+    }
+
+    const std::string& key = iter_curr_key->first;
+    libtextclassifier3::StatusOr<PostingListIdentifier> posting_list_id_or =
+        embedding_posting_list_mapper_->Get(key);
+    std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
+    if (posting_list_id_or.ok()) {
+      // Existing posting list.
+      ICING_ASSIGN_OR_RETURN(
+          pl_accessor,
+          PostingListEmbeddingHitAccessor::CreateFromExisting(
+              flash_index_storage_.get(), posting_list_hit_serializer_.get(),
+              posting_list_id_or.ValueOrDie()));
+    } else if (absl_ports::IsNotFound(posting_list_id_or.status())) {
+      // New posting list.
+      ICING_ASSIGN_OR_RETURN(
+          pl_accessor,
+          PostingListEmbeddingHitAccessor::Create(
+              flash_index_storage_.get(), posting_list_hit_serializer_.get()));
+    } else {
+      // Errors
+      return std::move(posting_list_id_or).status();
+    }
+
+    // Adding the embedding hits.
+    for (auto iter = iter_curr_key; iter != iter_next_key; ++iter) {
+      ICING_RETURN_IF_ERROR(pl_accessor->PrependHit(iter->second));
+    }
+
+    // Finalize this posting list and add the posting list id in
+    // embedding_posting_list_mapper_.
+    PostingListEmbeddingHitAccessor::FinalizeResult result =
+        std::move(*pl_accessor).Finalize();
+    if (!result.id.is_valid()) {
+      return absl_ports::InternalError("Failed to finalize posting list");
+    }
+    ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->Put(key, result.id));
+
+    // Advance to the next key.
+    iter_curr_key = iter_next_key;
+  }
+  pending_embedding_hits_.clear();
+  return libtextclassifier3::Status::OK;
+}
+
+// Rewrites this index's contents into new_index, remapping document ids
+// through document_id_old_to_new. Hits whose documents map to
+// kInvalidDocumentId are dropped, and surviving hits get fresh locations in
+// new_index's (compacted) vector storage.
+//
+// Returns:
+//   - OK on success.
+//   - INTERNAL_ERROR if a key or hit read from this index is inconsistent or
+//     corrupted, or if finalizing a posting list fails.
+libtextclassifier3::Status EmbeddingIndex::TransferIndex(
+    const std::vector<DocumentId>& document_id_old_to_new,
+    EmbeddingIndex* new_index) const {
+  std::unique_ptr<KeyMapper<PostingListIdentifier>::Iterator> itr =
+      embedding_posting_list_mapper_->GetIterator();
+  while (itr->Advance()) {
+    std::string_view key = itr->GetKey();
+    // This should never happen unless there is an inconsistency, or the index
+    // is corrupted.
+    if (key.size() < kEncodedDimensionLength) {
+      return absl_ports::InternalError(
+          "Got invalid key from embedding posting list mapper.");
+    }
+    uint32_t dimension = encode_util::DecodeIntFromCString(
+        std::string_view(key.begin(), kEncodedDimensionLength));
+
+    // Transfer hits
+    std::vector<EmbeddingHit> new_hits;
+    ICING_ASSIGN_OR_RETURN(
+        std::unique_ptr<PostingListEmbeddingHitAccessor> old_pl_accessor,
+        PostingListEmbeddingHitAccessor::CreateFromExisting(
+            flash_index_storage_.get(), posting_list_hit_serializer_.get(),
+            /*existing_posting_list_id=*/itr->GetValue()));
+    while (true) {
+      ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
+                             old_pl_accessor->GetNextHitsBatch());
+      if (batch.empty()) {
+        break;
+      }
+      for (EmbeddingHit& old_hit : batch) {
+        // Safety checks to add robustness to the codebase, so to make sure
+        // that we never access invalid memory, in case that hit from the
+        // posting list is corrupted.
+        ICING_ASSIGN_OR_RETURN(const float* old_vector,
+                               GetEmbeddingVector(old_hit, dimension));
+        // Cast to size_t only after the >= 0 check, so the signed/unsigned
+        // comparison is well-defined.
+        if (old_hit.basic_hit().document_id() < 0 ||
+            static_cast<size_t>(old_hit.basic_hit().document_id()) >=
+                document_id_old_to_new.size()) {
+          return absl_ports::InternalError(
+              "Embedding hit document id is out of bound. The provided map is "
+              "too small, or the index may have been corrupted.");
+        }
+
+        // Construct transferred hit
+        DocumentId new_document_id =
+            document_id_old_to_new[old_hit.basic_hit().document_id()];
+        if (new_document_id == kInvalidDocumentId) {
+          continue;
+        }
+        uint32_t new_location = new_index->embedding_vectors_->num_elements();
+        new_hits.push_back(EmbeddingHit(
+            BasicHit(old_hit.basic_hit().section_id(), new_document_id),
+            new_location));
+
+        // Copy the embedding vector of the hit to the new index.
+        ICING_ASSIGN_OR_RETURN(
+            FileBackedVector<float>::MutableArrayView mutable_arr,
+            new_index->embedding_vectors_->Allocate(dimension));
+        mutable_arr.SetArray(/*idx=*/0, old_vector, dimension);
+      }
+    }
+    // No hit needs to be added to the new index for THIS key. Continue to the
+    // next key instead of returning early: returning here would silently skip
+    // transferring every remaining posting list.
+    if (new_hits.empty()) {
+      continue;
+    }
+    // Add transferred hits to the new index.
+    ICING_ASSIGN_OR_RETURN(
+        std::unique_ptr<PostingListEmbeddingHitAccessor> hit_accum,
+        PostingListEmbeddingHitAccessor::Create(
+            new_index->flash_index_storage_.get(),
+            new_index->posting_list_hit_serializer_.get()));
+    for (auto new_hit_itr = new_hits.rbegin(); new_hit_itr != new_hits.rend();
+         ++new_hit_itr) {
+      ICING_RETURN_IF_ERROR(hit_accum->PrependHit(*new_hit_itr));
+    }
+    PostingListEmbeddingHitAccessor::FinalizeResult result =
+        std::move(*hit_accum).Finalize();
+    if (!result.id.is_valid()) {
+      return absl_ports::InternalError("Failed to finalize posting list");
+    }
+    ICING_RETURN_IF_ERROR(
+        new_index->embedding_posting_list_mapper_->Put(key, result.id));
+  }
+  return libtextclassifier3::Status::OK;
+}
+
+// Compacts the index after document deletion: rebuilds it into a temporary
+// directory with document ids remapped via document_id_old_to_new, then swaps
+// the temporary directory into place and re-initializes from it.
+libtextclassifier3::Status EmbeddingIndex::Optimize(
+    const std::vector<DocumentId>& document_id_old_to_new,
+    DocumentId new_last_added_document_id) {
+  // This is just for completeness, but this should never be necessary, since we
+  // should never have pending hits at the time when Optimize is run.
+  ICING_RETURN_IF_ERROR(CommitBufferToIndex());
+
+  // Make sure no stale temporary directory from a previous (failed) run
+  // survives.
+  std::string temporary_index_working_path = working_path_ + "_temp";
+  if (!filesystem_.DeleteDirectoryRecursively(
+          temporary_index_working_path.c_str())) {
+    ICING_LOG(ERROR) << "Recursively deleting " << temporary_index_working_path;
+    return absl_ports::InternalError(
+        "Unable to delete temp directory to prepare to build new index.");
+  }
+
+  DestructibleDirectory temporary_index_dir(
+      &filesystem_, std::move(temporary_index_working_path));
+  if (!temporary_index_dir.is_valid()) {
+    return absl_ports::InternalError(
+        "Unable to create temp directory to build new index.");
+  }
+
+  // Scoped so that the new index is fully destructed (after being persisted)
+  // before the directory swap below.
+  {
+    ICING_ASSIGN_OR_RETURN(
+        std::unique_ptr<EmbeddingIndex> new_index,
+        EmbeddingIndex::Create(&filesystem_, temporary_index_dir.dir()));
+    ICING_RETURN_IF_ERROR(
+        TransferIndex(document_id_old_to_new, new_index.get()));
+    new_index->set_last_added_document_id(new_last_added_document_id);
+    ICING_RETURN_IF_ERROR(new_index->PersistToDisk());
+  }
+
+  // Destruct current storage instances to safely swap directories.
+  metadata_mmapped_file_.reset();
+  flash_index_storage_.reset();
+  embedding_posting_list_mapper_.reset();
+  embedding_vectors_.reset();
+
+  if (!filesystem_.SwapFiles(temporary_index_dir.dir().c_str(),
+                             working_path_.c_str())) {
+    return absl_ports::InternalError(
+        "Unable to apply new index due to failed swap!");
+  }
+
+  // Reinitialize the index.
+  is_initialized_ = false;
+  return Initialize();
+}
+
+// Flushes the metadata mmapped file. `force` is accepted for interface
+// compatibility but is not used here.
+libtextclassifier3::Status EmbeddingIndex::PersistMetadataToDisk(bool force) {
+  return metadata_mmapped_file_->PersistToDisk();
+}
+
+// Flushes all non-metadata sub-storages (flash index, posting list mapper,
+// vector storage). `force` is accepted for interface compatibility but is not
+// used here.
+libtextclassifier3::Status EmbeddingIndex::PersistStoragesToDisk(bool force) {
+  if (!flash_index_storage_->PersistToDisk()) {
+    return absl_ports::InternalError("Fail to persist flash index to disk");
+  }
+  ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->PersistToDisk());
+  ICING_RETURN_IF_ERROR(embedding_vectors_->PersistToDisk());
+  return libtextclassifier3::Status::OK;
+}
+
+// Checksum over the Info header (magic + last_added_document_id). `force` is
+// not used here.
+libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeInfoChecksum(
+    bool force) {
+  return info().ComputeChecksum();
+}
+
+// Combined storages checksum: XOR of the posting list mapper's and the vector
+// storage's checksums. Note that flash_index_storage_ does not contribute a
+// checksum here. `force` is not used here.
+libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeStoragesChecksum(
+    bool force) {
+  ICING_ASSIGN_OR_RETURN(Crc32 embedding_posting_list_mapper_crc,
+                         embedding_posting_list_mapper_->ComputeChecksum());
+  ICING_ASSIGN_OR_RETURN(Crc32 embedding_vectors_crc,
+                         embedding_vectors_->ComputeChecksum());
+  return Crc32(embedding_posting_list_mapper_crc.Get() ^
+               embedding_vectors_crc.Get());
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/embedding-index.h b/icing/index/embed/embedding-index.h
new file mode 100644
index 0000000..7318871
--- /dev/null
+++ b/icing/index/embed/embedding-index.h
@@ -0,0 +1,274 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_EMBEDDING_INDEX_H_
+#define ICING_INDEX_EMBED_EMBEDDING_INDEX_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/file-backed-vector.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/memory-mapped-file.h"
+#include "icing/file/persistent-storage.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+#include "icing/index/hit/hit.h"
+#include "icing/store/document-id.h"
+#include "icing/store/key-mapper.h"
+#include "icing/util/crc32.h"
+
+namespace icing {
+namespace lib {
+
+class EmbeddingIndex : public PersistentStorage {
+ public:
+ struct Info {
+ static constexpr int32_t kMagic = 0xfbe13cbb;
+
+ int32_t magic;
+ DocumentId last_added_document_id;
+
+ Crc32 ComputeChecksum() const {
+ return Crc32(
+ std::string_view(reinterpret_cast<const char*>(this), sizeof(Info)));
+ }
+ } __attribute__((packed));
+ static_assert(sizeof(Info) == 8, "");
+
+ // Metadata file layout: <Crcs><Info>
+ static constexpr int32_t kCrcsMetadataBufferOffset = 0;
+ static constexpr int32_t kInfoMetadataBufferOffset =
+ static_cast<int32_t>(sizeof(Crcs));
+ static constexpr int32_t kMetadataFileSize = sizeof(Crcs) + sizeof(Info);
+ static_assert(kMetadataFileSize == 20, "");
+
+ static constexpr WorkingPathType kWorkingPathType =
+ WorkingPathType::kDirectory;
+
+ EmbeddingIndex(const EmbeddingIndex&) = delete;
+ EmbeddingIndex& operator=(const EmbeddingIndex&) = delete;
+
+ // Creates a new EmbeddingIndex instance to index embeddings.
+ //
+ // Returns:
+ // - FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored
+ // checksum.
+ // - INTERNAL_ERROR on I/O errors.
+ // - Any error from MemoryMappedFile, FlashIndexStorage,
+ // DynamicTrieKeyMapper, or FileBackedVector.
+ static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>> Create(
+ const Filesystem* filesystem, std::string working_path);
+
+ static libtextclassifier3::Status Discard(const Filesystem& filesystem,
+ const std::string& working_path) {
+ return PersistentStorage::Discard(filesystem, working_path,
+ kWorkingPathType);
+ }
+
+ libtextclassifier3::Status Clear();
+
+ // Buffer an embedding pending to be added to the index. This is required
+ // since EmbeddingHits added in posting lists must be decreasing, which means
+ // that section ids and location indexes for a single document must be added
+ // decreasingly.
+ //
+ // Returns:
+ // - OK on success
+ // - INVALID_ARGUMENT error if the dimension is 0.
+ // - INTERNAL_ERROR on I/O error
+ libtextclassifier3::Status BufferEmbedding(
+ const BasicHit& basic_hit, const PropertyProto::VectorProto& vector);
+
+ // Commit the embedding hits in the buffer to the index.
+ //
+ // Returns:
+ // - OK on success
+ // - INTERNAL_ERROR on I/O error
+ // - Any error from posting lists
+ libtextclassifier3::Status CommitBufferToIndex();
+
+ // Returns a PostingListEmbeddingHitAccessor for all embedding hits that match
+ // with the provided dimension and signature.
+ //
+ // Returns:
+ // - a PostingListEmbeddingHitAccessor instance on success.
+ // - INVALID_ARGUMENT error if the dimension is 0.
+ // - NOT_FOUND error if there is no matching embedding hit.
+ // - Any error from posting lists.
+ libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ GetAccessor(uint32_t dimension, std::string_view model_signature) const;
+
+ // Returns a PostingListEmbeddingHitAccessor for all embedding hits that match
+ // with the provided vector's dimension and signature.
+ //
+ // Returns:
+ // - a PostingListEmbeddingHitAccessor instance on success.
+ // - INVALID_ARGUMENT error if the dimension is 0.
+ // - NOT_FOUND error if there is no matching embedding hit.
+ // - Any error from posting lists.
+ libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ GetAccessorForVector(const PropertyProto::VectorProto& vector) const {
+ return GetAccessor(vector.values_size(), vector.model_signature());
+ }
+
+ // Reduces internal file sizes by reclaiming space of deleted documents.
+ // new_last_added_document_id will be used to update the last added document
+ // id in the lite index.
+ //
+ // Returns:
+ // - OK on success
+ // - INTERNAL_ERROR on IO error, this indicates that the index may be in an
+ // invalid state and should be cleared.
+ libtextclassifier3::Status Optimize(
+ const std::vector<DocumentId>& document_id_old_to_new,
+ DocumentId new_last_added_document_id);
+
+ libtextclassifier3::StatusOr<const float*> GetEmbeddingVector(
+ const EmbeddingHit& hit, uint32_t dimension) const {
+ if (static_cast<int64_t>(hit.location()) + dimension >
+ GetTotalVectorSize()) {
+ return absl_ports::InternalError(
+ "Got an embedding hit that refers to a vector out of range.");
+ }
+ return embedding_vectors_->array() + hit.location();
+ }
+
+ const float* GetRawEmbeddingData() const {
+ return embedding_vectors_->array();
+ }
+
+ int32_t GetTotalVectorSize() const {
+ return embedding_vectors_->num_elements();
+ }
+
+ DocumentId last_added_document_id() const {
+ return info().last_added_document_id;
+ }
+
+ void set_last_added_document_id(DocumentId document_id) {
+ Info& info_ref = info();
+ if (info_ref.last_added_document_id == kInvalidDocumentId ||
+ document_id > info_ref.last_added_document_id) {
+ info_ref.last_added_document_id = document_id;
+ }
+ }
+
+ private:
+ explicit EmbeddingIndex(const Filesystem& filesystem,
+ std::string working_path)
+ : PersistentStorage(filesystem, std::move(working_path),
+ kWorkingPathType) {}
+
+ libtextclassifier3::Status Initialize();
+
+ // Transfers embedding data and hits from the current index to new_index.
+ //
+ // Returns:
+ // - OK on success
+ // - INTERNAL_ERROR on I/O error. This could potentially leave the storages
+ // in an invalid state and the caller should handle it properly (e.g.
+ // discard and rebuild)
+ libtextclassifier3::Status TransferIndex(
+ const std::vector<DocumentId>& document_id_old_to_new,
+ EmbeddingIndex* new_index) const;
+
+ // Flushes contents of metadata file.
+ //
+ // Returns:
+ // - OK on success
+ // - INTERNAL_ERROR on I/O error
+ libtextclassifier3::Status PersistMetadataToDisk(bool force) override;
+
+ // Flushes contents of all storages to underlying files.
+ //
+ // Returns:
+ // - OK on success
+ // - INTERNAL_ERROR on I/O error
+ libtextclassifier3::Status PersistStoragesToDisk(bool force) override;
+
+ // Computes and returns Info checksum.
+ //
+ // Returns:
+ // - Crc of the Info on success
+ libtextclassifier3::StatusOr<Crc32> ComputeInfoChecksum(bool force) override;
+
+ // Computes and returns all storages checksum.
+ //
+ // Returns:
+ // - Crc of all storages on success
+ // - INTERNAL_ERROR if any data inconsistency
+ libtextclassifier3::StatusOr<Crc32> ComputeStoragesChecksum(
+ bool force) override;
+
+ Crcs& crcs() override {
+ return *reinterpret_cast<Crcs*>(metadata_mmapped_file_->mutable_region() +
+ kCrcsMetadataBufferOffset);
+ }
+
+ const Crcs& crcs() const override {
+ return *reinterpret_cast<const Crcs*>(metadata_mmapped_file_->region() +
+ kCrcsMetadataBufferOffset);
+ }
+
+ Info& info() {
+ return *reinterpret_cast<Info*>(metadata_mmapped_file_->mutable_region() +
+ kInfoMetadataBufferOffset);
+ }
+
+ const Info& info() const {
+ return *reinterpret_cast<const Info*>(metadata_mmapped_file_->region() +
+ kInfoMetadataBufferOffset);
+ }
+
+ // In memory data:
+ // Pending embedding hits with their embedding keys used for
+ // embedding_posting_list_mapper_.
+ std::vector<std::pair<std::string, EmbeddingHit>> pending_embedding_hits_;
+
+ // Metadata
+ std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_;
+
+ // Posting list storage
+ std::unique_ptr<PostingListEmbeddingHitSerializer>
+ posting_list_hit_serializer_ =
+ std::make_unique<PostingListEmbeddingHitSerializer>();
+ std::unique_ptr<FlashIndexStorage> flash_index_storage_;
+
+ // The mapper from embedding keys to the corresponding posting list identifier
+ // that stores all embedding hits with the same key.
+ //
+ // The key for an embedding hit is a one-to-one encoded string of the ordered
+ // pair (dimension, model_signature) corresponding to the embedding.
+ std::unique_ptr<KeyMapper<PostingListIdentifier>>
+ embedding_posting_list_mapper_;
+
+ // A single FileBackedVector that holds all embedding vectors.
+ std::unique_ptr<FileBackedVector<float>> embedding_vectors_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_EMBEDDING_INDEX_H_
diff --git a/icing/index/embed/embedding-index_test.cc b/icing/index/embed/embedding-index_test.cc
new file mode 100644
index 0000000..5980e82
--- /dev/null
+++ b/icing/index/embed/embedding-index_test.cc
@@ -0,0 +1,582 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/embedding-index.h"
+
+#include <unistd.h>
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/filesystem.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/hit/hit.h"
+#include "icing/proto/document.pb.h"
+#include "icing/store/document-id.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/embedding-test-utils.h"
+#include "icing/testing/tmp-directory.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+using ::testing::Test;
+
+class EmbeddingIndexTest : public Test {
+ protected:
+ void SetUp() override {
+ embedding_index_dir_ = GetTestTempDir() + "/embedding_index_test";
+ ICING_ASSERT_OK_AND_ASSIGN(
+ embedding_index_,
+ EmbeddingIndex::Create(&filesystem_, embedding_index_dir_));
+ }
+
+ void TearDown() override {
+ embedding_index_.reset();
+ filesystem_.DeleteDirectoryRecursively(embedding_index_dir_.c_str());
+ }
+
+ libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits(
+ uint32_t dimension, std::string_view model_signature) {
+ std::vector<EmbeddingHit> hits;
+
+ libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ pl_accessor_or =
+ embedding_index_->GetAccessor(dimension, model_signature);
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
+ if (pl_accessor_or.ok()) {
+ pl_accessor = std::move(pl_accessor_or).ValueOrDie();
+ } else if (absl_ports::IsNotFound(pl_accessor_or.status())) {
+ return hits;
+ } else {
+ return std::move(pl_accessor_or).status();
+ }
+
+ while (true) {
+ ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
+ pl_accessor->GetNextHitsBatch());
+ if (batch.empty()) {
+ return hits;
+ }
+ hits.insert(hits.end(), batch.begin(), batch.end());
+ }
+ }
+
+ std::vector<float> GetRawEmbeddingData() {
+ return std::vector<float>(embedding_index_->GetRawEmbeddingData(),
+ embedding_index_->GetRawEmbeddingData() +
+ embedding_index_->GetTotalVectorSize());
+ }
+
+ Filesystem filesystem_;
+ std::string embedding_index_dir_;
+ std::unique_ptr<EmbeddingIndex> embedding_index_;
+};
+
+TEST_F(EmbeddingIndexTest, AddSingleEmbedding) {
+ PropertyProto::VectorProto vector = CreateVector("model", {0.1, 0.2, 0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(0);
+
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(EmbeddingHit(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, AddMultipleEmbeddingsInTheSameSection) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(0);
+
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/3))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, HitsWithLowerSectionIdReturnedFirst) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/5, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/2, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(0);
+
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/2, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(/*section_id=*/5, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, HitsWithHigherDocumentIdReturnedFirst) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/1), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(1);
+
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+}
+
+TEST_F(EmbeddingIndexTest, AddEmbeddingsFromDifferentModels) {
+ PropertyProto::VectorProto vector1 = CreateVector("model1", {0.1, 0.2});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model2", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(0);
+
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/2))));
+ EXPECT_THAT(
+ GetHits(/*dimension=*/5, /*model_signature=*/"non-existent-model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest,
+ AddEmbeddingsWithSameSignatureButDifferentDimension) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(0);
+
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/2))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, ClearIndex) {
+ // Loop the same logic twice to make sure that clear works as expected, and
+ // the index is still valid after clearing.
+ for (int i = 0; i < 2; i++) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/1, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/2, /*document_id=*/1), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(1);
+
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/2, /*document_id=*/1),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+
+ // Check that clear works as expected.
+ ICING_ASSERT_OK(embedding_index_->Clear());
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+ EXPECT_EQ(embedding_index_->last_added_document_id(), kInvalidDocumentId);
+ }
+}
+
+TEST_F(EmbeddingIndexTest, EmptyCommitIsOk) {
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+}
+
+TEST_F(EmbeddingIndexTest, MultipleCommits) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/1, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+}
+
+TEST_F(EmbeddingIndexTest,
+ InvalidCommit_SectionIdCanOnlyDecreaseForSingleDocument) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/1, /*document_id=*/0), vector2));
+ // Posting list with delta encoding can only allow decreasing values.
+ EXPECT_THAT(embedding_index_->CommitBufferToIndex(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(EmbeddingIndexTest, InvalidCommit_DocumentIdCanOnlyIncrease) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/1), vector1));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector2));
+ // Posting list with delta encoding can only allow decreasing values, which
+ // means document ids must be committed increasingly, since document ids are
+ // inverted in hit values.
+ EXPECT_THAT(embedding_index_->CommitBufferToIndex(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(EmbeddingIndexTest, EmptyOptimizeIsOk) {
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{},
+ /*new_last_added_document_id=*/kInvalidDocumentId));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+}
+
+TEST_F(EmbeddingIndexTest, OptimizeSingleEmbeddingSingleDocument) {
+ PropertyProto::VectorProto vector = CreateVector("model", {0.1, 0.2, 0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/2), vector));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(2);
+
+ // Before optimize
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(EmbeddingHit(
+ BasicHit(/*section_id=*/0, /*document_id=*/2), /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 2);
+
+ // Run optimize without deleting any documents, and check that the index is
+ // not changed.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, 1, 2},
+ /*new_last_added_document_id=*/2));
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(EmbeddingHit(
+ BasicHit(/*section_id=*/0, /*document_id=*/2), /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 2);
+
+ // Run optimize to map document id 2 to 1, and check that the index is
+ // updated correctly.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, kInvalidDocumentId, 1},
+ /*new_last_added_document_id=*/1));
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(EmbeddingHit(
+ BasicHit(/*section_id=*/0, /*document_id=*/1), /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+
+ // Run optimize to delete the document.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, kInvalidDocumentId},
+ /*new_last_added_document_id=*/0));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, OptimizeMultipleEmbeddingsSingleDocument) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/2), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/2), vector2));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(2);
+
+ // Before optimize
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2),
+ /*location=*/3))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 2);
+
+ // Run optimize without deleting any documents, and check that the index is
+ // not changed.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, 1, 2},
+ /*new_last_added_document_id=*/2));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2),
+ /*location=*/3))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 2);
+
+ // Run optimize to map document id 2 to 1, and check that the index is
+ // updated correctly.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, kInvalidDocumentId, 1},
+ /*new_last_added_document_id=*/1));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/3))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+
+ // Run optimize to delete the document.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, kInvalidDocumentId},
+ /*new_last_added_document_id=*/0));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, OptimizeMultipleEmbeddingsMultipleDocument) {
+ PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector2 = CreateVector("model", {1, 2, 3});
+ PropertyProto::VectorProto vector3 =
+ CreateVector("model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/1, /*document_id=*/0), vector2));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/1), vector3));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(1);
+
+ // Before optimize
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/6),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/3))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 0.3, 1, 2, 3, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+
+ // Run optimize without deleting any documents. It is expected to see that the
+ // raw embedding data is rearranged, since during index transfer, embedding
+ // vectors from higher document ids are added first.
+ //
+ // Also keep in mind that once the raw data is rearranged, calling another
+ // Optimize subsequently will not change the raw data again.
+ for (int i = 0; i < 2; i++) {
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, 1},
+ /*new_last_added_document_id=*/1));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/6))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(-0.1, -0.2, -0.3, 0.1, 0.2, 0.3, 1, 2, 3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+ }
+
+ // Run optimize to delete document 0, and check that the index is
+ // updated correctly.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{kInvalidDocumentId, 0},
+ /*new_last_added_document_id=*/0));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(-0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+TEST_F(EmbeddingIndexTest, OptimizeEmbeddingsFromDifferentModels) {
+ PropertyProto::VectorProto vector1 = CreateVector("model1", {0.1, 0.2});
+ PropertyProto::VectorProto vector2 = CreateVector("model1", {1, 2});
+ PropertyProto::VectorProto vector3 =
+ CreateVector("model2", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), vector1));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/0, /*document_id=*/1), vector2));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(/*section_id=*/1, /*document_id=*/1), vector3));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+ embedding_index_->set_last_added_document_id(1);
+
+ // Before optimize
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/2),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/1),
+ /*location=*/4))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.1, 0.2, 1, 2, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+
+ // Run optimize without deleting any documents. It is expected to see that the
+ // raw embedding data is rearranged, since during index transfer:
+ // - Embedding vectors with lower keys, which are the string encoded ordered
+ // pairs (dimension, model_signature), are iterated first.
+ // - Embedding vectors from higher document ids are added first.
+ //
+ // Also keep in mind that once the raw data is rearranged, calling another
+ // Optimize subsequently will not change the raw data again.
+ for (int i = 0; i < 2; i++) {
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, 1},
+ /*new_last_added_document_id=*/1));
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/2))));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/1),
+ /*location=*/4))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(1, 2, 0.1, 0.2, -0.1, -0.2, -0.3));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 1);
+ }
+
+ // Run optimize to delete document 1, and check that the index is
+ // updated correctly.
+ ICING_ASSERT_OK(embedding_index_->Optimize(
+ /*document_id_old_to_new=*/{0, kInvalidDocumentId},
+ /*new_last_added_document_id=*/0));
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0))));
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2));
+ EXPECT_EQ(embedding_index_->last_added_document_id(), 0);
+}
+
+} // namespace
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/embedding-query-results.h b/icing/index/embed/embedding-query-results.h
new file mode 100644
index 0000000..de85489
--- /dev/null
+++ b/icing/index/embed/embedding-query-results.h
@@ -0,0 +1,72 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
+#define ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "icing/proto/search.pb.h"
+#include "icing/store/document-id.h"
+
+namespace icing {
+namespace lib {
+
+// A class to store results generated from embedding queries.
+struct EmbeddingQueryResults {
+ // Maps from DocumentId to the list of matched embedding scores for that
+ // document, which will be used in the advanced scoring language to
+ // determine the results for the "this.matchedSemanticScores(...)" function.
+ using EmbeddingQueryScoreMap =
+ std::unordered_map<DocumentId, std::vector<double>>;
+
+ // Maps from (query_vector_index, metric_type) to EmbeddingQueryScoreMap.
+ std::unordered_map<
+ int, std::unordered_map<SearchSpecProto::EmbeddingQueryMetricType::Code,
+ EmbeddingQueryScoreMap>>
+ result_scores;
+
+ // Returns the matched scores for the given query_vector_index, metric_type,
+ // and doc_id. Returns nullptr if (query_vector_index, metric_type) does not
+ // exist in the result_scores map.
+ const std::vector<double>* GetMatchedScoresForDocument(
+ int query_vector_index,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ DocumentId doc_id) const {
+ // Check if a mapping exists for the query_vector_index
+ auto outer_it = result_scores.find(query_vector_index);
+ if (outer_it == result_scores.end()) {
+ return nullptr;
+ }
+ // Check if a mapping exists for the metric_type
+ auto inner_it = outer_it->second.find(metric_type);
+ if (inner_it == outer_it->second.end()) {
+ return nullptr;
+ }
+ const EmbeddingQueryScoreMap& score_map = inner_it->second;
+
+ // Check if the doc_id exists in the score_map
+ auto scores_it = score_map.find(doc_id);
+ if (scores_it == score_map.end()) {
+ return nullptr;
+ }
+ return &scores_it->second;
+ }
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
diff --git a/icing/index/embed/embedding-scorer.cc b/icing/index/embed/embedding-scorer.cc
new file mode 100644
index 0000000..0d84e01
--- /dev/null
+++ b/icing/index/embed/embedding-scorer.cc
@@ -0,0 +1,95 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/embedding-scorer.h"
+
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/proto/search.pb.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+float CalculateDotProduct(int dimension, const float* v1, const float* v2) {
+ float dot_product = 0.0;
+ for (int i = 0; i < dimension; ++i) {
+ dot_product += v1[i] * v2[i];
+ }
+ return dot_product;
+}
+
+float CalculateNorm2(int dimension, const float* v) {
+ return std::sqrt(CalculateDotProduct(dimension, v, v));
+}
+
+float CalculateCosine(int dimension, const float* v1, const float* v2) {
+ float divisor = CalculateNorm2(dimension, v1) * CalculateNorm2(dimension, v2);
+ if (divisor == 0.0) {
+ return 0.0;
+ }
+ return CalculateDotProduct(dimension, v1, v2) / divisor;
+}
+
+float CalculateEuclideanDistance(int dimension, const float* v1,
+ const float* v2) {
+ float result = 0.0;
+ for (int i = 0; i < dimension; ++i) {
+ float diff = v1[i] - v2[i];
+ result += diff * diff;
+ }
+ return std::sqrt(result);
+}
+
+} // namespace
+
+libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingScorer>>
+EmbeddingScorer::Create(
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type) {
+ switch (metric_type) {
+ case SearchSpecProto::EmbeddingQueryMetricType::COSINE:
+ return std::make_unique<CosineEmbeddingScorer>();
+ case SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT:
+ return std::make_unique<DotProductEmbeddingScorer>();
+ case SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN:
+ return std::make_unique<EuclideanDistanceEmbeddingScorer>();
+ default:
+ return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+ "Invalid EmbeddingQueryMetricType: ", std::to_string(metric_type)));
+ }
+}
+
+float CosineEmbeddingScorer::Score(int dimension, const float* v1,
+ const float* v2) const {
+ return CalculateCosine(dimension, v1, v2);
+}
+
+float DotProductEmbeddingScorer::Score(int dimension, const float* v1,
+ const float* v2) const {
+ return CalculateDotProduct(dimension, v1, v2);
+}
+
+float EuclideanDistanceEmbeddingScorer::Score(int dimension, const float* v1,
+ const float* v2) const {
+ return CalculateEuclideanDistance(dimension, v1, v2);
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/embedding-scorer.h b/icing/index/embed/embedding-scorer.h
new file mode 100644
index 0000000..8caf0bc
--- /dev/null
+++ b/icing/index/embed/embedding-scorer.h
@@ -0,0 +1,54 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_EMBEDDING_SCORER_H_
+#define ICING_INDEX_EMBED_EMBEDDING_SCORER_H_
+
+#include <memory>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/proto/search.pb.h"
+
+namespace icing {
+namespace lib {
+
+class EmbeddingScorer {
+ public:
+ static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingScorer>> Create(
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type);
+ virtual float Score(int dimension, const float* v1,
+ const float* v2) const = 0;
+
+ virtual ~EmbeddingScorer() = default;
+};
+
+class CosineEmbeddingScorer : public EmbeddingScorer {
+ public:
+ float Score(int dimension, const float* v1, const float* v2) const override;
+};
+
+class DotProductEmbeddingScorer : public EmbeddingScorer {
+ public:
+ float Score(int dimension, const float* v1, const float* v2) const override;
+};
+
+class EuclideanDistanceEmbeddingScorer : public EmbeddingScorer {
+ public:
+ float Score(int dimension, const float* v1, const float* v2) const override;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_EMBEDDING_SCORER_H_
diff --git a/icing/index/embed/posting-list-embedding-hit-accessor.cc b/icing/index/embed/posting-list-embedding-hit-accessor.cc
new file mode 100644
index 0000000..e154165
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-accessor.cc
@@ -0,0 +1,132 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-common.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+#include "icing/legacy/index/icing-bit-util.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+PostingListEmbeddingHitAccessor::Create(
+ FlashIndexStorage *storage, PostingListEmbeddingHitSerializer *serializer) {
+ uint32_t max_posting_list_bytes = storage->max_posting_list_bytes();
+ ICING_ASSIGN_OR_RETURN(PostingListUsed in_memory_posting_list,
+ PostingListUsed::CreateFromUnitializedRegion(
+ serializer, max_posting_list_bytes));
+ return std::unique_ptr<PostingListEmbeddingHitAccessor>(
+ new PostingListEmbeddingHitAccessor(storage, serializer,
+ std::move(in_memory_posting_list)));
+}
+
+libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+PostingListEmbeddingHitAccessor::CreateFromExisting(
+ FlashIndexStorage *storage, PostingListEmbeddingHitSerializer *serializer,
+ PostingListIdentifier existing_posting_list_id) {
+ // Our in_memory_posting_list_ will start as empty.
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ Create(storage, serializer));
+ ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
+ storage->GetPostingList(existing_posting_list_id));
+ pl_accessor->preexisting_posting_list_ =
+ std::make_unique<PostingListHolder>(std::move(holder));
+ return pl_accessor;
+}
+
+// Returns the next batch of hits for the provided posting list.
+libtextclassifier3::StatusOr<std::vector<EmbeddingHit>>
+PostingListEmbeddingHitAccessor::GetNextHitsBatch() {
+ if (preexisting_posting_list_ == nullptr) {
+ if (has_reached_posting_list_chain_end_) {
+ return std::vector<EmbeddingHit>();
+ }
+ return absl_ports::FailedPreconditionError(
+ "Cannot retrieve hits from a PostingListEmbeddingHitAccessor that was "
+ "not created from a preexisting posting list.");
+ }
+ ICING_ASSIGN_OR_RETURN(
+ std::vector<EmbeddingHit> batch,
+ serializer_->GetHits(&preexisting_posting_list_->posting_list));
+ uint32_t next_block_index = kInvalidBlockIndex;
+ // Posting lists will only be chained when they are max-sized, in which case
+ // next_block_index will point to the next block for the next posting list.
+ // Otherwise, next_block_index can be kInvalidBlockIndex or be used to point
+ // to the next free list block, which is not relevant here.
+ if (preexisting_posting_list_->posting_list.size_in_bytes() ==
+ storage_->max_posting_list_bytes()) {
+ next_block_index = preexisting_posting_list_->next_block_index;
+ }
+
+ if (next_block_index != kInvalidBlockIndex) {
+ // Since we only have to deal with next block for max-sized posting list
+ // block, max_num_posting_lists is 1 and posting_list_index_bits is
+ // BitsToStore(1).
+ PostingListIdentifier next_posting_list_id(
+ next_block_index, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/BitsToStore(1));
+ ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
+ storage_->GetPostingList(next_posting_list_id));
+ preexisting_posting_list_ =
+ std::make_unique<PostingListHolder>(std::move(holder));
+ } else {
+ has_reached_posting_list_chain_end_ = true;
+ preexisting_posting_list_.reset();
+ }
+ return batch;
+}
+
+libtextclassifier3::Status PostingListEmbeddingHitAccessor::PrependHit(
+ const EmbeddingHit &hit) {
+ PostingListUsed &active_pl = (preexisting_posting_list_ != nullptr)
+ ? preexisting_posting_list_->posting_list
+ : in_memory_posting_list_;
+ libtextclassifier3::Status status = serializer_->PrependHit(&active_pl, hit);
+ if (!absl_ports::IsResourceExhausted(status)) {
+ return status;
+ }
+ // There is no more room to add hits to this current posting list! Therefore,
+ // we need to either move those hits to a larger posting list or flush this
+ // posting list and create another max-sized posting list in the chain.
+ if (preexisting_posting_list_ != nullptr) {
+ ICING_RETURN_IF_ERROR(FlushPreexistingPostingList());
+ } else {
+ ICING_RETURN_IF_ERROR(FlushInMemoryPostingList());
+ }
+
+ // Re-add hit. Should always fit since we just cleared
+ // in_memory_posting_list_. It's fine to explicitly reference
+ // in_memory_posting_list_ here because there's no way of reaching this line
+ // while preexisting_posting_list_ is still in use.
+ return serializer_->PrependHit(&in_memory_posting_list_, hit);
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/posting-list-embedding-hit-accessor.h b/icing/index/embed/posting-list-embedding-hit-accessor.h
new file mode 100644
index 0000000..4acb9a3
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-accessor.h
@@ -0,0 +1,106 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_
+#define ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-accessor.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+
+namespace icing {
+namespace lib {
+
+// This class is used to provide a simple abstraction for adding hits to posting
+// lists. PostingListEmbeddingHitAccessor handles 1) selection of properly-sized
+// posting lists for the accumulated hits during Finalize() and 2) chaining of
+// max-sized posting lists.
+class PostingListEmbeddingHitAccessor : public PostingListAccessor {
+ public:
+ // Creates an empty PostingListEmbeddingHitAccessor.
+ //
+ // RETURNS:
+ // - On success, a valid unique_ptr instance of
+ // PostingListEmbeddingHitAccessor
+ // - INVALID_ARGUMENT error if storage has an invalid block_size.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ Create(FlashIndexStorage* storage,
+ PostingListEmbeddingHitSerializer* serializer);
+
+ // Create a PostingListEmbeddingHitAccessor with an existing posting list
+ // identified by existing_posting_list_id.
+ //
+ // The PostingListEmbeddingHitAccessor will add hits to this posting list
+ // until it is necessary either to 1) chain the posting list (if it is
+ // max-sized) or 2) move its hits to a larger posting list.
+ //
+ // RETURNS:
+ // - On success, a valid unique_ptr instance of
+ // PostingListEmbeddingHitAccessor
+ // - INVALID_ARGUMENT if storage has an invalid block_size.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ CreateFromExisting(FlashIndexStorage* storage,
+ PostingListEmbeddingHitSerializer* serializer,
+ PostingListIdentifier existing_posting_list_id);
+
+ PostingListSerializer* GetSerializer() override { return serializer_; }
+
+ // Retrieve the next batch of hits for the posting list chain.
+ //
+ // RETURNS:
+ // - On success, a vector of hits in the posting list chain
+ // - INTERNAL if called on an instance of PostingListEmbeddingHitAccessor
+ // that was created via PostingListEmbeddingHitAccessor::Create, if unable
+ // to read the next posting list in the chain or if the posting list has
+ // been corrupted somehow.
+ libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetNextHitsBatch();
+
+ // Prepend one hit. This may result in flushing the posting list to disk (if
+ // the PostingListEmbeddingHitAccessor holds a max-sized posting list that is
+ // full) or freeing a pre-existing posting list if it is too small to fit all
+ // of the necessary hits.
+ //
+ // RETURNS:
+ // - OK, on success
+ // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the
+ // previously added hit.
+ // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a new
+ // posting list.
+ libtextclassifier3::Status PrependHit(const EmbeddingHit& hit);
+
+ private:
+ explicit PostingListEmbeddingHitAccessor(
+ FlashIndexStorage* storage, PostingListEmbeddingHitSerializer* serializer,
+ PostingListUsed in_memory_posting_list)
+ : PostingListAccessor(storage, std::move(in_memory_posting_list)),
+ serializer_(serializer) {}
+
+ PostingListEmbeddingHitSerializer* serializer_; // Does not own.
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_
diff --git a/icing/index/embed/posting-list-embedding-hit-accessor_test.cc b/icing/index/embed/posting-list-embedding-hit-accessor_test.cc
new file mode 100644
index 0000000..b9ebe87
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-accessor_test.cc
@@ -0,0 +1,387 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-accessor.h"
+#include "icing/file/posting_list/posting-list-common.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+#include "icing/index/hit/hit.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/hit-test-utils.h"
+#include "icing/testing/tmp-directory.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::Eq;
+using ::testing::Lt;
+using ::testing::SizeIs;
+
+class PostingListEmbeddingHitAccessorTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ test_dir_ = GetTestTempDir() + "/test_dir";
+ file_name_ = test_dir_ + "/test_file.idx.index";
+
+ ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()));
+ ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(test_dir_.c_str()));
+
+ serializer_ = std::make_unique<PostingListEmbeddingHitSerializer>();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ FlashIndexStorage flash_index_storage,
+ FlashIndexStorage::Create(file_name_, &filesystem_, serializer_.get()));
+ flash_index_storage_ =
+ std::make_unique<FlashIndexStorage>(std::move(flash_index_storage));
+ }
+
+ void TearDown() override {
+ flash_index_storage_.reset();
+ serializer_.reset();
+ ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()));
+ }
+
+ Filesystem filesystem_;
+ std::string test_dir_;
+ std::string file_name_;
+ std::unique_ptr<PostingListEmbeddingHitSerializer> serializer_;
+ std::unique_ptr<FlashIndexStorage> flash_index_storage_;
+};
+
+TEST_F(PostingListEmbeddingHitAccessorTest, HitsAddAndRetrieveProperly) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ // Add some hits! Any hits!
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit& hit : hits1) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
+ }
+ PostingListAccessor::FinalizeResult result =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result.status);
+ EXPECT_THAT(result.id.block_index(), Eq(1));
+ EXPECT_THAT(result.id.posting_list_index(), Eq(0));
+
+ // Retrieve some hits.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(result.id));
+ EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
+ EXPECT_THAT(pl_holder.next_block_index, Eq(kInvalidBlockIndex));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPLKeepOnSameBlock) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ // Add a single hit. This will fit in a min-sized posting list.
+ EmbeddingHit hit1(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/1);
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result1.status);
+ // Should have been allocated to the first block.
+ EXPECT_THAT(result1.id.block_index(), Eq(1));
+ EXPECT_THAT(result1.id.posting_list_index(), Eq(0));
+
+ // Add one more hit. The minimum size for a posting list must be able to fit
+ // at least two hits, so this should NOT cause the previous pl to be
+ // reallocated.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
+ EmbeddingHit hit2(BasicHit(/*section_id=*/1, /*document_id=*/0),
+ /*location=*/0);
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit2));
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result2.status);
+ // Should have been allocated to the same posting list as the first hit.
+ EXPECT_THAT(result2.id, Eq(result1.id));
+
+ // The posting list at result2.id should hold all of the hits that have been
+ // added.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(result2.id));
+ EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAre(hit2, hit1)));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPLReallocateToLargerPL) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ std::vector<EmbeddingHit> hits =
+ CreateEmbeddingHits(/*num_hits=*/11, /*desired_byte_length=*/1);
+
+ // Add 8 hits with a small posting list of 24 bytes. The first 7 hits will
+ // be compressed to one byte each and will be able to fit in the 8-byte
+ // padded region. The last hit will fit in one of the special hits. The
+ // posting list will be ALMOST_FULL and can fit at most 2 more hits.
+ for (int i = 0; i < 8; ++i) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i]));
+ }
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result1.status);
+ // Should have been allocated to the first block.
+ EXPECT_THAT(result1.id.block_index(), Eq(1));
+ EXPECT_THAT(result1.id.posting_list_index(), Eq(0));
+
+ // Now let's add some more hits!
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
+ // The current posting list can fit at most 2 more hits.
+ for (int i = 8; i < 10; ++i) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i]));
+ }
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result2.status);
+ // The 2 hits should still fit on the first block. NOTE(review): re-checks result1.id; verifying result2.id may be intended.
+ EXPECT_THAT(result1.id.block_index(), Eq(1));
+ EXPECT_THAT(result1.id.posting_list_index(), Eq(0));
+
+ // Add one more hit
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result2.id));
+ // The current posting list should be FULL. Adding more hits should result in
+ // these hits being moved to a larger posting list.
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hits[10]));
+ PostingListAccessor::FinalizeResult result3 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result3.status);
+ // Should have been allocated to the second (new) block because the posting
+ // list should have grown beyond the size that the first block maintains.
+ EXPECT_THAT(result3.id.block_index(), Eq(2));
+ EXPECT_THAT(result3.id.posting_list_index(), Eq(0));
+
+ // The posting list at result3.id should hold all of the hits that have been
+ // added.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(result3.id));
+ EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend())));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, MultiBlockChainsBlocksProperly) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ // Add some hits! Any hits!
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5000, /*desired_byte_length=*/1);
+ for (const EmbeddingHit& hit : hits1) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
+ }
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result1.status);
+ PostingListIdentifier second_block_id = result1.id;
+ // Should have been allocated to the second block, which holds a max-sized
+ // posting list.
+ EXPECT_THAT(second_block_id, Eq(PostingListIdentifier(
+ /*block_index=*/2, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0)));
+
+ // Now let's retrieve them!
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(second_block_id));
+ // This pl_holder will only hold a posting list with the hits that didn't fit
+ // on the first block.
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<EmbeddingHit> second_block_hits,
+ serializer_->GetHits(&pl_holder.posting_list));
+ ASSERT_THAT(second_block_hits, SizeIs(Lt(hits1.size())));
+ auto first_block_hits_start = hits1.rbegin() + second_block_hits.size();
+ EXPECT_THAT(second_block_hits,
+ ElementsAreArray(hits1.rbegin(), first_block_hits_start));
+
+ // Now retrieve all of the hits that were on the first block.
+ uint32_t first_block_id = pl_holder.next_block_index;
+ EXPECT_THAT(first_block_id, Eq(1));
+
+ PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0);
+ ICING_ASSERT_OK_AND_ASSIGN(pl_holder,
+ flash_index_storage_->GetPostingList(pl_id));
+ EXPECT_THAT(
+ serializer_->GetHits(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits1.rend())));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest,
+ PreexistingMultiBlockReusesBlocksProperly) {
+ std::vector<EmbeddingHit> hits =
+ CreateEmbeddingHits(/*num_hits=*/5050, /*desired_byte_length=*/1);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ // Add some hits! Any hits!
+ for (int i = 0; i < 5000; ++i) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i]));
+ }
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result1.status);
+ PostingListIdentifier first_add_id = result1.id;
+ EXPECT_THAT(first_add_id, Eq(PostingListIdentifier(
+ /*block_index=*/2, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0)));
+
+ // Now add a couple more hits. These should fit on the existing, not full
+ // second block.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), first_add_id));
+ for (int i = 5000; i < hits.size(); ++i) {
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i]));
+ }
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result2.status);
+ PostingListIdentifier second_add_id = result2.id;
+ EXPECT_THAT(second_add_id, Eq(first_add_id));
+
+ // We should be able to retrieve all 5050 hits.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(second_add_id));
+ // This pl_holder will only hold a posting list with the hits that didn't fit
+ // on the first block.
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<EmbeddingHit> second_block_hits,
+ serializer_->GetHits(&pl_holder.posting_list));
+ ASSERT_THAT(second_block_hits, SizeIs(Lt(hits.size())));
+ auto first_block_hits_start = hits.rbegin() + second_block_hits.size();
+ EXPECT_THAT(second_block_hits,
+ ElementsAreArray(hits.rbegin(), first_block_hits_start));
+
+ // Now retrieve all of the hits that were on the first block.
+ uint32_t first_block_id = pl_holder.next_block_index;
+ EXPECT_THAT(first_block_id, Eq(1));
+
+ PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0);
+ ICING_ASSERT_OK_AND_ASSIGN(pl_holder,
+ flash_index_storage_->GetPostingList(pl_id));
+ EXPECT_THAT(
+ serializer_->GetHits(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits.rend())));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, InvalidHitReturnsInvalidArgument) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue);
+ EXPECT_THAT(pl_accessor->PrependHit(invalid_hit),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest,
+ HitsNotDecreasingReturnsInvalidArgument) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ EmbeddingHit hit1(BasicHit(/*section_id=*/3, /*document_id=*/1),
+ /*location=*/5);
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
+
+ EmbeddingHit hit2(BasicHit(/*section_id=*/6, /*document_id=*/1),
+ /*location=*/5);
+ EXPECT_THAT(pl_accessor->PrependHit(hit2),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EmbeddingHit hit3(BasicHit(/*section_id=*/2, /*document_id=*/0),
+ /*location=*/5);
+ EXPECT_THAT(pl_accessor->PrependHit(hit3),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EmbeddingHit hit4(BasicHit(/*section_id=*/3, /*document_id=*/1),
+ /*location=*/6);
+ EXPECT_THAT(pl_accessor->PrependHit(hit4),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, NewPostingListNoHitsAdded) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ EXPECT_THAT(result1.status,
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPostingListNoHitsAdded) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ EmbeddingHit hit1(BasicHit(/*section_id=*/3, /*document_id=*/1),
+ /*location=*/5);
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_ASSERT_OK(result1.status);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor2,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor2).Finalize();
+ ICING_ASSERT_OK(result2.status);
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/posting-list-embedding-hit-serializer.cc b/icing/index/embed/posting-list-embedding-hit-serializer.cc
new file mode 100644
index 0000000..9247bcb
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-serializer.cc
@@ -0,0 +1,647 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+
+#include <cinttypes>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/legacy/core/icing-string-util.h"
+#include "icing/legacy/index/icing-bit-util.h"
+#include "icing/util/logging.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+// Returns how many bytes of the posting list currently hold hit data, i.e.
+// everything at or after the start byte offset.
+uint32_t PostingListEmbeddingHitSerializer::GetBytesUsed(
+ const PostingListUsed* posting_list_used) const {
+ // The special hits will be included if they represent actual hits. If they
+ // represent the hit offset or the invalid hit sentinel, they are not
+ // included.
+ return posting_list_used->size_in_bytes() -
+ GetStartByteOffset(posting_list_used);
+}
+
+// Returns the smallest posting list size (in bytes) that is guaranteed to
+// hold all hits currently stored in posting_list_used. The estimate is made
+// for the ALMOST_FULL layout so that the last two added hits never need to
+// be deserialized.
+uint32_t PostingListEmbeddingHitSerializer::GetMinPostingListSizeToFit(
+ const PostingListUsed* posting_list_used) const {
+ if (IsFull(posting_list_used) || IsAlmostFull(posting_list_used)) {
+ // If in either the FULL state or ALMOST_FULL state, this posting list *is*
+ // the minimum size posting list that can fit these hits. So just return the
+ // size of the posting list.
+ return posting_list_used->size_in_bytes();
+ }
+
+ // - In NOT_FULL status, BytesUsed contains no special hits. For a posting
+ // list in the NOT_FULL state with n hits, we would have n-1 compressed hits
+ // and 1 uncompressed hit.
+ // - The minimum sized posting list that would be guaranteed to fit these hits
+ // would be FULL, but calculating the size required for the FULL posting
+ // list would require deserializing the last two added hits, so instead we
+ // will calculate the size of an ALMOST_FULL posting list to fit.
+ // - An ALMOST_FULL posting list would have kInvalidHit in special_hit(0), the
+ // full uncompressed Hit in special_hit(1), and the n-1 compressed hits in
+ // the compressed region.
+ // - Currently BytesUsed contains one uncompressed Hit and n-1 compressed
+ // hits.
+ // - Therefore, fitting these hits into a posting list would require
+ // BytesUsed + one extra full hit.
+ return GetBytesUsed(posting_list_used) + sizeof(EmbeddingHit);
+}
+
+// Empties the posting list by moving the start byte offset to the very end
+// of the buffer (NOT_FULL state holding zero hits).
+void PostingListEmbeddingHitSerializer::Clear(
+ PostingListUsed* posting_list_used) const {
+ // Safe to ignore return value because posting_list_used->size_in_bytes() is
+ // a valid argument.
+ SetStartByteOffset(posting_list_used,
+ /*offset=*/posting_list_used->size_in_bytes());
+}
+
+// Moves all hits from 'src' into 'dst' and clears 'src' on success. 'dst'
+// must be valid and large enough to hold every hit currently in 'src'
+// (GetMinPostingListSizeToFit(src) <= dst->size_in_bytes()); 'src' must be
+// a valid posting list.
+libtextclassifier3::Status PostingListEmbeddingHitSerializer::MoveFrom(
+ PostingListUsed* dst, PostingListUsed* src) const {
+ ICING_RETURN_ERROR_IF_NULL(dst);
+ ICING_RETURN_ERROR_IF_NULL(src);
+ // Computed once up front: the value is needed both for the check and for
+ // the error message.
+ uint32_t src_min_size = GetMinPostingListSizeToFit(src);
+ if (src_min_size > dst->size_in_bytes()) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "src MinPostingListSizeToFit %d must not be larger than dst size %d.",
+ src_min_size, dst->size_in_bytes()));
+ }
+
+ if (!IsPostingListValid(dst)) {
+ return absl_ports::FailedPreconditionError(
+ "Dst posting list is in an invalid state and can't be used!");
+ }
+ if (!IsPostingListValid(src)) {
+ return absl_ports::InvalidArgumentError(
+ "Cannot MoveFrom an invalid src posting list!");
+ }
+
+ // Pop just enough hits that all of src's compressed hits fit in
+ // dst posting_list's compressed area. Then we can memcpy that area.
+ std::vector<EmbeddingHit> hits;
+ while (IsFull(src) || IsAlmostFull(src) ||
+ (dst->size_in_bytes() - kSpecialHitsSize < GetBytesUsed(src))) {
+ if (!GetHitsInternal(src, /*limit=*/1, /*pop=*/true, &hits).ok()) {
+ return absl_ports::AbortedError(
+ "Unable to retrieve hits from src posting list.");
+ }
+ }
+
+ // memcpy the area and set up start byte offset.
+ Clear(dst);
+ memcpy(dst->posting_list_buffer() + dst->size_in_bytes() - GetBytesUsed(src),
+ src->posting_list_buffer() + GetStartByteOffset(src),
+ GetBytesUsed(src));
+ // Because we popped all hits from src outside of the compressed area and we
+ // guaranteed that GetBytesUsed(src) is less than dst->size_in_bytes() -
+ // kSpecialHitSize. This is guaranteed to be a valid byte offset for the
+ // NOT_FULL state, so ignoring the value is safe.
+ SetStartByteOffset(dst, dst->size_in_bytes() - GetBytesUsed(src));
+
+ // Put back remaining hits.
+ for (size_t i = 0; i < hits.size(); i++) {
+ const EmbeddingHit& hit = hits[hits.size() - i - 1];
+ // PrependHit can return either INVALID_ARGUMENT - if hit is invalid or not
+ // less than the previous hit - or RESOURCE_EXHAUSTED. RESOURCE_EXHAUSTED
+ // should be impossible because we've already assured that there is enough
+ // room above.
+ ICING_RETURN_IF_ERROR(PrependHit(dst, hit));
+ }
+
+ Clear(src);
+ return libtextclassifier3::Status::OK;
+}
+
+// Returns the first byte offset at or after 'offset' that does not hold a
+// pad. Pads are varint-encoded zeros, so decoding stops at the first
+// non-zero decoded value or at the end of the buffer.
+uint32_t PostingListEmbeddingHitSerializer::GetPadEnd(
+ const PostingListUsed* posting_list_used, uint32_t offset) const {
+ EmbeddingHit::Value pad;
+ uint32_t pad_end = offset;
+ while (pad_end < posting_list_used->size_in_bytes()) {
+ size_t pad_len = VarInt::Decode(
+ posting_list_used->posting_list_buffer() + pad_end, &pad);
+ if (pad != 0) {
+ // No longer a pad.
+ break;
+ }
+ pad_end += pad_len;
+ }
+ return pad_end;
+}
+
+// Fills the byte range [start, end) with varint zero pads. Returns false
+// (and pads nothing) if 'end' lies past the end of the posting list buffer.
+bool PostingListEmbeddingHitSerializer::PadToEnd(
+ PostingListUsed* posting_list_used, uint32_t start, uint32_t end) const {
+ if (end > posting_list_used->size_in_bytes()) {
+ ICING_LOG(ERROR) << "Cannot pad a region that ends after size!";
+ return false;
+ }
+ // In VarInt a value of 0 encodes to 0.
+ memset(posting_list_used->posting_list_buffer() + start, 0, end - start);
+ return true;
+}
+
+// Prepends 'hit' to an ALMOST_FULL posting list. The delta to the current
+// first hit is written into the pad area if it fits (the list stays
+// ALMOST_FULL); otherwise the new hit goes uncompressed into special
+// position 0 and the list becomes FULL. Returns INVALID_ARGUMENT if 'hit'
+// is not strictly less than the current first hit.
+libtextclassifier3::Status
+PostingListEmbeddingHitSerializer::PrependHitToAlmostFull(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit) const {
+ // Get delta between first hit and the new hit. Try to fit delta
+ // in the padded area and put new hit at the special position 1.
+ // Accessing index 1 is safe because 1 < kNumSpecialData.
+ EmbeddingHit cur = GetSpecialHit(posting_list_used, /*index=*/1);
+ if (cur.value() <= hit.value()) {
+ return absl_ports::InvalidArgumentError(
+ "Hit being prepended must be strictly less than the most recent Hit");
+ }
+ uint64_t delta = cur.value() - hit.value();
+ uint8_t delta_buf[VarInt::kMaxEncodedLen64];
+ size_t delta_len = VarInt::Encode(delta, delta_buf);
+
+ uint32_t pad_end = GetPadEnd(posting_list_used,
+ /*offset=*/kSpecialHitsSize);
+
+ if (pad_end >= kSpecialHitsSize + delta_len) {
+ // Pad area has enough space for delta of existing hit (cur). Write delta at
+ // pad_end - delta_len.
+ uint8_t* delta_offset =
+ posting_list_used->posting_list_buffer() + pad_end - delta_len;
+ memcpy(delta_offset, delta_buf, delta_len);
+
+ // Now first hit is the new hit, at special position 1 (safe because
+ // 1 < kNumSpecialData).
+ SetSpecialHit(posting_list_used, /*index=*/1, hit);
+ // Safe to ignore the return value because sizeof(EmbeddingHit) is a valid
+ // argument.
+ SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit));
+ } else {
+ // No space for delta. We put the new hit at special position 0
+ // and go to the full state (safe because 0 < kNumSpecialData).
+ SetSpecialHit(posting_list_used, /*index=*/0, hit);
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+// Prepends the very first hit to an empty posting list. A min-sized list
+// stores it at special position 1 (ALMOST_FULL); a larger list stores it
+// uncompressed at the end of the compressed region (NOT_FULL).
+void PostingListEmbeddingHitSerializer::PrependHitToEmpty(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit) const {
+ // First hit to be added. Just add verbatim, no compression.
+ if (posting_list_used->size_in_bytes() == kSpecialHitsSize) {
+ // Safe because 1 < kNumSpecialData.
+ SetSpecialHit(posting_list_used, /*index=*/1, hit);
+ // Safe to ignore the return value because sizeof(EmbeddingHit) is a valid
+ // argument.
+ SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit));
+ } else {
+ // Since this is the first hit, size != kSpecialHitsSize and
+ // size % sizeof(EmbeddingHit) == 0, we know that there is room to fit 'hit'
+ // into the compressed region, so ValueOrDie is safe.
+ uint32_t offset =
+ PrependHitUncompressed(posting_list_used, hit,
+ /*offset=*/posting_list_used->size_in_bytes())
+ .ValueOrDie();
+ // Safe to ignore the return value because PrependHitUncompressed is
+ // guaranteed to return a valid offset.
+ SetStartByteOffset(posting_list_used, offset);
+ }
+}
+
+// Prepends 'hit' to a NOT_FULL posting list whose first (uncompressed) hit
+// sits at 'offset'. Depending on how much room the delta encoding needs,
+// the list stays NOT_FULL, becomes ALMOST_FULL, or becomes FULL. Returns
+// INVALID_ARGUMENT if 'hit' is not strictly less than the current first
+// hit; FAILED_PRECONDITION if the list layout is corrupt.
+libtextclassifier3::Status
+PostingListEmbeddingHitSerializer::PrependHitToNotFull(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit,
+ uint32_t offset) const {
+ // First hit in compressed area. It is uncompressed. See if delta
+ // between the first hit and new hit will still fit in the
+ // compressed area.
+ if (offset + sizeof(EmbeddingHit::Value) >
+ posting_list_used->size_in_bytes()) {
+ // The first hit in the compressed region *should* be uncompressed, but
+ // somehow there isn't enough room between offset and the end of the
+ // compressed area to fit an uncompressed hit. This should NEVER happen.
+ return absl_ports::FailedPreconditionError(
+ "Posting list is in an invalid state.");
+ }
+ EmbeddingHit::Value cur_value;
+ memcpy(&cur_value, posting_list_used->posting_list_buffer() + offset,
+ sizeof(EmbeddingHit::Value));
+ if (cur_value <= hit.value()) {
+ // EmbeddingHit::Value is unsigned 64-bit, so format with PRIu64 (PRId64
+ // would print large values as negative).
+ return absl_ports::InvalidArgumentError(
+ IcingStringUtil::StringPrintf("EmbeddingHit %" PRIu64
+ " being prepended must be "
+ "strictly less than the most recent "
+ "EmbeddingHit %" PRIu64,
+ hit.value(), cur_value));
+ }
+ uint64_t delta = cur_value - hit.value();
+ uint8_t delta_buf[VarInt::kMaxEncodedLen64];
+ size_t delta_len = VarInt::Encode(delta, delta_buf);
+
+ // offset now points to one past the end of the first hit.
+ offset += sizeof(EmbeddingHit::Value);
+ if (kSpecialHitsSize + sizeof(EmbeddingHit::Value) + delta_len <= offset) {
+ // Enough space for delta in compressed area.
+
+ // Prepend delta.
+ offset -= delta_len;
+ memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf,
+ delta_len);
+
+ // Prepend new hit. We know that there is room for 'hit' because of the if
+ // statement above, so calling ValueOrDie is safe.
+ offset =
+ PrependHitUncompressed(posting_list_used, hit, offset).ValueOrDie();
+ // offset is guaranteed to be valid here. So it's safe to ignore the return
+ // value. The if above will guarantee that offset >= kSpecialHitSize and <
+ // posting_list_used->size_in_bytes() because the if ensures that there is
+ // enough room between offset and kSpecialHitSize to fit the delta of the
+ // previous hit and the uncompressed hit.
+ SetStartByteOffset(posting_list_used, offset);
+ } else if (kSpecialHitsSize + delta_len <= offset) {
+ // Only have space for delta. The new hit must be put in special
+ // position 1.
+
+ // Prepend delta.
+ offset -= delta_len;
+ memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf,
+ delta_len);
+
+ // Prepend pad. Safe to ignore the return value of PadToEnd because offset
+ // must be less than posting_list_used->size_in_bytes(). Otherwise, this
+ // function already would have returned FAILED_PRECONDITION.
+ PadToEnd(posting_list_used, /*start=*/kSpecialHitsSize,
+ /*end=*/offset);
+
+ // Put new hit in special position 1 (safe because 1 < kNumSpecialData).
+ SetSpecialHit(posting_list_used, /*index=*/1, hit);
+
+ // State almost_full. Safe to ignore the return value because
+ // sizeof(EmbeddingHit) is a valid argument.
+ SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit));
+ } else {
+ // Very rare case where delta is larger than sizeof(EmbeddingHit::Value)
+ // (i.e. varint delta encoding expanded required storage). We
+ // move first hit to special position 1 and put new hit in
+ // special position 0.
+ EmbeddingHit cur(cur_value);
+ // Safe to ignore the return value of PadToEnd because offset must be less
+ // than posting_list_used->size_in_bytes(). Otherwise, this function
+ // already would have returned FAILED_PRECONDITION.
+ PadToEnd(posting_list_used, /*start=*/kSpecialHitsSize,
+ /*end=*/offset);
+ // Indices 0 and 1 are both < kNumSpecialData.
+ SetSpecialHit(posting_list_used, /*index=*/1, cur);
+ SetSpecialHit(posting_list_used, /*index=*/0, hit);
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+// Prepends 'hit' to the posting list, dispatching on the current state:
+// FULL -> RESOURCE_EXHAUSTED; ALMOST_FULL / EMPTY / NOT_FULL -> the
+// state-specific prepend. 'hit' must be valid and strictly less than the
+// most recently prepended hit, otherwise INVALID_ARGUMENT.
+libtextclassifier3::Status PostingListEmbeddingHitSerializer::PrependHit(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit) const {
+ static_assert(
+ sizeof(EmbeddingHit::Value) <= sizeof(uint64_t),
+ "EmbeddingHit::Value cannot be larger than 8 bytes because the delta "
+ "must be able to fit in 8 bytes.");
+ if (!hit.is_valid()) {
+ return absl_ports::InvalidArgumentError("Cannot prepend an invalid hit!");
+ }
+ if (!IsPostingListValid(posting_list_used)) {
+ return absl_ports::FailedPreconditionError(
+ "This PostingListUsed is in an invalid state and can't add any hits!");
+ }
+
+ if (IsFull(posting_list_used)) {
+ // State full: no space left.
+ return absl_ports::ResourceExhaustedError("No more room for hits");
+ } else if (IsAlmostFull(posting_list_used)) {
+ return PrependHitToAlmostFull(posting_list_used, hit);
+ } else if (IsEmpty(posting_list_used)) {
+ PrependHitToEmpty(posting_list_used, hit);
+ return libtextclassifier3::Status::OK;
+ } else {
+ uint32_t offset = GetStartByteOffset(posting_list_used);
+ return PrependHitToNotFull(posting_list_used, hit, offset);
+ }
+}
+
+// Convenience overload: returns all hits in the posting list, most recently
+// prepended first, without modifying the posting list.
+libtextclassifier3::StatusOr<std::vector<EmbeddingHit>>
+PostingListEmbeddingHitSerializer::GetHits(
+ const PostingListUsed* posting_list_used) const {
+ std::vector<EmbeddingHit> hits_out;
+ ICING_RETURN_IF_ERROR(GetHits(posting_list_used, &hits_out));
+ return hits_out;
+}
+
+// Appends every hit in the posting list to *hits_out without modifying the
+// posting list (non-popping traversal with no limit).
+libtextclassifier3::Status PostingListEmbeddingHitSerializer::GetHits(
+ const PostingListUsed* posting_list_used,
+ std::vector<EmbeddingHit>* hits_out) const {
+ return GetHitsInternal(posting_list_used,
+ /*limit=*/std::numeric_limits<uint32_t>::max(),
+ /*pop=*/false, hits_out);
+}
+
+// Removes the first (most recently prepended) num_hits hits. Popping a
+// single hit from a FULL posting list takes a pop-two-then-re-prepend path
+// so that popping is always the exact inverse of PrependHit; the diagrams
+// below walk through why.
+libtextclassifier3::Status PostingListEmbeddingHitSerializer::PopFrontHits(
+ PostingListUsed* posting_list_used, uint32_t num_hits) const {
+ if (num_hits == 1 && IsFull(posting_list_used)) {
+ // The PL is in full status which means that we save 2 uncompressed hits in
+ // the 2 special postions. But full status may be reached by 2 different
+ // statuses.
+ // (1) In "almost full" status
+ // +-----------------+----------------+-------+-----------------+
+ // |Hit::kInvalidVal |1st hit |(pad) |(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+ // When we prepend another hit, we can only put it at the special
+ // position 0. And we get a full PL
+ // +-----------------+----------------+-------+-----------------+
+ // |new 1st hit |original 1st hit|(pad) |(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+ // (2) In "not full" status
+ // +-----------------+----------------+------+-------+------------------+
+ // |hits-start-offset|Hit::kInvalidVal|(pad) |1st hit|(compressed) hits |
+ // +-----------------+----------------+------+-------+------------------+
+ // When we prepend another hit, we can reach any of the 3 following
+ // scenarios:
+ // (2.1) not full
+ // if the space of pad and original 1st hit can accommodate the new 1st hit
+ // and the encoded delta value.
+ // +-----------------+----------------+------+-----------+-----------------+
+ // |hits-start-offset|Hit::kInvalidVal|(pad) |new 1st hit|(compressed) hits|
+ // +-----------------+----------------+------+-----------+-----------------+
+ // (2.2) almost full
+ // If the space of pad and original 1st hit cannot accommodate the new 1st
+ // hit and the encoded delta value but can accommodate the encoded delta
+ // value only. We can put the new 1st hit at special position 1.
+ // +-----------------+----------------+-------+-----------------+
+ // |Hit::kInvalidVal |new 1st hit |(pad) |(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+ // (2.3) full
+ // In very rare case, it cannot even accommodate only the encoded delta
+ // value. we can move the original 1st hit into special position 1 and the
+ // new 1st hit into special position 0. This may happen because we use
+ // VarInt encoding method which may make the encoded value longer (about
+ // 4/3 times of original)
+ // +-----------------+----------------+-------+-----------------+
+ // |new 1st hit |original 1st hit|(pad) |(compressed) hits|
+ // +-----------------+----------------+-------+-----------------+
+ // Suppose now the PL is full. But we don't know whether it arrived to
+ // this status from "not full" like (2.3) or from "almost full" like (1).
+ // We'll return to "almost full" status like (1) if we simply pop the new
+ // 1st hit but we want to make the prepending operation "reversible". So
+ // there should be some way to return to "not full" if possible. A simple
+ // way to do it is to pop 2 hits out of the PL to status "almost full" or
+ // "not full". And add the original 1st hit back. We can return to the
+ // correct original statuses of (2.1) or (1). This makes our prepending
+ // operation reversible.
+ std::vector<EmbeddingHit> out;
+
+ // Popping 2 hits should never fail because we've just ensured that the
+ // posting list is in the FULL state.
+ ICING_RETURN_IF_ERROR(
+ GetHitsInternal(posting_list_used, /*limit=*/2, /*pop=*/true, &out));
+
+ // PrependHit should never fail because out[1] is a valid hit less than
+ // previous hits in the posting list and because there's no way that the
+ // posting list could run out of room because it previously stored this hit
+ // AND another hit.
+ ICING_RETURN_IF_ERROR(PrependHit(posting_list_used, out[1]));
+ } else if (num_hits > 0) {
+ return GetHitsInternal(posting_list_used, /*limit=*/num_hits, /*pop=*/true,
+ nullptr);
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+// Traverses up to 'limit' hits from the front of the posting list. If
+// out != nullptr, traversed hits are appended to *out. If pop is true, the
+// traversed hits are additionally removed (the posting list is mutated via
+// a const_cast) and the remaining first hit is re-anchored.
+libtextclassifier3::Status PostingListEmbeddingHitSerializer::GetHitsInternal(
+ const PostingListUsed* posting_list_used, uint32_t limit, bool pop,
+ std::vector<EmbeddingHit>* out) const {
+ // Put current uncompressed val here.
+ EmbeddingHit::Value val = EmbeddingHit::kInvalidValue;
+ uint32_t offset = GetStartByteOffset(posting_list_used);
+ uint32_t count = 0;
+
+ // First traverse the first two special positions.
+ while (count < limit && offset < kSpecialHitsSize) {
+ // Accessing this index is safe because offset / sizeof(EmbeddingHit) <
+ // kNumSpecialData because of the check above.
+ EmbeddingHit hit = GetSpecialHit(posting_list_used,
+ /*index=*/offset / sizeof(EmbeddingHit));
+ val = hit.value();
+ if (out != nullptr) {
+ out->push_back(hit);
+ }
+ offset += sizeof(EmbeddingHit);
+ count++;
+ }
+
+ // If special position 1 was set then we need to skip padding.
+ if (val != EmbeddingHit::kInvalidValue && offset == kSpecialHitsSize) {
+ offset = GetPadEnd(posting_list_used, offset);
+ }
+
+ while (count < limit && offset < posting_list_used->size_in_bytes()) {
+ if (val == EmbeddingHit::kInvalidValue) {
+ // First hit is in compressed area. Put that in val.
+ memcpy(&val, posting_list_used->posting_list_buffer() + offset,
+ sizeof(EmbeddingHit::Value));
+ offset += sizeof(EmbeddingHit::Value);
+ } else {
+ // Now we have delta encoded subsequent hits. Decode and push.
+ uint64_t delta;
+ offset += VarInt::Decode(
+ posting_list_used->posting_list_buffer() + offset, &delta);
+ val += delta;
+ }
+ EmbeddingHit hit(val);
+ if (out != nullptr) {
+ out->push_back(hit);
+ }
+ count++;
+ }
+
+ if (pop) {
+ PostingListUsed* mutable_posting_list_used =
+ const_cast<PostingListUsed*>(posting_list_used);
+ // Modify the posting list so that we pop all hits actually
+ // traversed.
+ if (offset >= kSpecialHitsSize &&
+ offset < posting_list_used->size_in_bytes()) {
+ // In the compressed area. Pop and reconstruct. offset/val is
+ // the last traversed hit, which we must discard. So move one
+ // more forward.
+ uint64_t delta;
+ offset += VarInt::Decode(
+ posting_list_used->posting_list_buffer() + offset, &delta);
+ val += delta;
+
+ // Now val is the first hit of the new posting list.
+ if (kSpecialHitsSize + sizeof(EmbeddingHit::Value) <= offset) {
+ // val fits in compressed area. Simply copy.
+ offset -= sizeof(EmbeddingHit::Value);
+ memcpy(mutable_posting_list_used->posting_list_buffer() + offset, &val,
+ sizeof(EmbeddingHit::Value));
+ } else {
+ // val won't fit in compressed area.
+ EmbeddingHit hit(val);
+ // Writing index 1 is safe because 1 < kNumSpecialData.
+ SetSpecialHit(mutable_posting_list_used, /*index=*/1, hit);
+
+ // Prepend pad. Safe to ignore the return value of PadToEnd because
+ // offset must be less than posting_list_used->size_in_bytes() thanks to
+ // the if above.
+ PadToEnd(mutable_posting_list_used,
+ /*start=*/kSpecialHitsSize,
+ /*end=*/offset);
+ offset = sizeof(EmbeddingHit);
+ }
+ }
+ // offset is guaranteed to be valid so ignoring the return value of
+ // SetStartByteOffset is safe. It falls into one of four scenarios:
+ // Scenario 1: the above if was false because offset is not <
+ // posting_list_used->size_in_bytes()
+ // In this case, offset must be == posting_list_used->size_in_bytes()
+ // because we reached offset by unwinding hits on the posting list.
+ // Scenario 2: offset is < kSpecialHitSize
+ // In this case, offset is guaranteed to be either 0 or
+ // sizeof(EmbeddingHit) because offset is incremented by
+ // sizeof(EmbeddingHit) within the first while loop.
+ // Scenario 3: offset is within the compressed region and the new first hit
+ // in the posting list (the value that 'val' holds) will fit as an
+ // uncompressed hit in the compressed region. The resulting offset from
+ // decompressing val must be >= kSpecialHitSize because otherwise we'd be
+ // in Scenario 4
+ // Scenario 4: offset is within the compressed region, but the new first hit
+ // in the posting list is too large to fit as an uncompressed hit in the
+ // in the compressed region. Therefore, it must be stored in a special hit
+ // and offset will be sizeof(EmbeddingHit).
+ SetStartByteOffset(mutable_posting_list_used, offset);
+ }
+
+ return libtextclassifier3::Status::OK;
+}
+
+// Reads the uncompressed hit stored at special position 'index'. The caller
+// must guarantee index < kNumSpecialData; this is not checked here.
+EmbeddingHit PostingListEmbeddingHitSerializer::GetSpecialHit(
+ const PostingListUsed* posting_list_used, uint32_t index) const {
+ static_assert(sizeof(EmbeddingHit::Value) >= sizeof(uint32_t), "HitTooSmall");
+ EmbeddingHit val(EmbeddingHit::kInvalidValue);
+ memcpy(&val, posting_list_used->posting_list_buffer() + index * sizeof(val),
+ sizeof(val));
+ return val;
+}
+
+// Writes 'val' uncompressed into special position 'index'. The caller must
+// guarantee index < kNumSpecialData; this is not checked here.
+void PostingListEmbeddingHitSerializer::SetSpecialHit(
+ PostingListUsed* posting_list_used, uint32_t index,
+ const EmbeddingHit& val) const {
+ memcpy(posting_list_used->posting_list_buffer() + index * sizeof(val), &val,
+ sizeof(val));
+}
+
+// Sanity-checks the special hits against the inferred state: ALMOST_FULL
+// requires a valid hit in special position 1; NOT_FULL requires special
+// position 0 to hold an offset within [kSpecialHitsSize, size_in_bytes()].
+// FULL needs no extra check here.
+bool PostingListEmbeddingHitSerializer::IsPostingListValid(
+ const PostingListUsed* posting_list_used) const {
+ if (IsAlmostFull(posting_list_used)) {
+ // Special Hit 1 should hold a Hit. Accessing index 1 is safe because
+ // 1 < kNumSpecialData.
+ if (!GetSpecialHit(posting_list_used, /*index=*/1).is_valid()) {
+ ICING_LOG(ERROR)
+ << "Both special hits cannot be invalid at the same time.";
+ return false;
+ }
+ } else if (!IsFull(posting_list_used)) {
+ // NOT_FULL. Special Hit 0 should hold a valid offset. Accessing index 0
+ // is safe because 0 < kNumSpecialData.
+ if (GetSpecialHit(posting_list_used, /*index=*/0).value() >
+ posting_list_used->size_in_bytes() ||
+ GetSpecialHit(posting_list_used, /*index=*/0).value() <
+ kSpecialHitsSize) {
+ ICING_LOG(ERROR) << "EmbeddingHit: "
+ << GetSpecialHit(posting_list_used, /*index=*/0).value()
+ << " size: " << posting_list_used->size_in_bytes()
+ << " sp size: " << kSpecialHitsSize;
+ return false;
+ }
+ }
+ return true;
+}
+
+// Returns the byte offset where hit data begins: 0 when FULL,
+// sizeof(EmbeddingHit) when ALMOST_FULL, and otherwise the offset stored in
+// special position 0 (NOT_FULL).
+uint32_t PostingListEmbeddingHitSerializer::GetStartByteOffset(
+ const PostingListUsed* posting_list_used) const {
+ if (IsFull(posting_list_used)) {
+ return 0;
+ } else if (IsAlmostFull(posting_list_used)) {
+ return sizeof(EmbeddingHit);
+ } else {
+ // NOT_FULL: the offset lives in special position 0. Accessing index 0 is
+ // safe because 0 < kNumSpecialData.
+ return GetSpecialHit(posting_list_used, /*index=*/0).value();
+ }
+}
+
+// Records 'offset' as the start of hit data and updates the special hits to
+// encode the corresponding state. Valid offsets are 0 (FULL),
+// sizeof(EmbeddingHit) (ALMOST_FULL), or a value in [kSpecialHitsSize,
+// size_in_bytes()] (NOT_FULL); any other value is rejected with false.
+bool PostingListEmbeddingHitSerializer::SetStartByteOffset(
+ PostingListUsed* posting_list_used, uint32_t offset) const {
+ if (offset > posting_list_used->size_in_bytes()) {
+ ICING_LOG(ERROR) << "offset cannot be a value greater than size "
+ << posting_list_used->size_in_bytes() << ". offset is "
+ << offset << ".";
+ return false;
+ }
+ if (offset < kSpecialHitsSize && offset > sizeof(EmbeddingHit)) {
+ ICING_LOG(ERROR) << "offset cannot be a value between ("
+ << sizeof(EmbeddingHit) << ", " << kSpecialHitsSize
+ << "). offset is " << offset << ".";
+ return false;
+ }
+ if (offset < sizeof(EmbeddingHit) && offset != 0) {
+ ICING_LOG(ERROR) << "offset cannot be a value between (0, "
+ << sizeof(EmbeddingHit) << "). offset is " << offset
+ << ".";
+ return false;
+ }
+ if (offset >= kSpecialHitsSize) {
+ // not_full state. Indices 0 and 1 are both < kNumSpecialData.
+ SetSpecialHit(posting_list_used, /*index=*/0, EmbeddingHit(offset));
+ SetSpecialHit(posting_list_used, /*index=*/1,
+ EmbeddingHit(EmbeddingHit::kInvalidValue));
+ } else if (offset == sizeof(EmbeddingHit)) {
+ // almost_full state. Index 0 < kNumSpecialData.
+ SetSpecialHit(posting_list_used, /*index=*/0,
+ EmbeddingHit(EmbeddingHit::kInvalidValue));
+ }
+ // Nothing to do for the FULL state - the offset isn't actually stored
+ // anywhere and both special hits hold valid hits.
+ return true;
+}
+
+// Writes hit.value() uncompressed immediately before 'offset' and returns
+// the new (smaller) offset, or INVALID_ARGUMENT if the write would overlap
+// the special hits region.
+libtextclassifier3::StatusOr<uint32_t>
+PostingListEmbeddingHitSerializer::PrependHitUncompressed(
+ PostingListUsed* posting_list_used, const EmbeddingHit& hit,
+ uint32_t offset) const {
+ if (offset < kSpecialHitsSize + sizeof(EmbeddingHit::Value)) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "Not enough room to prepend EmbeddingHit::Value at offset %d.",
+ offset));
+ }
+ offset -= sizeof(EmbeddingHit::Value);
+ EmbeddingHit::Value val = hit.value();
+ memcpy(posting_list_used->posting_list_buffer() + offset, &val,
+ sizeof(EmbeddingHit::Value));
+ return offset;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embed/posting-list-embedding-hit-serializer.h b/icing/index/embed/posting-list-embedding-hit-serializer.h
new file mode 100644
index 0000000..76198c2
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-serializer.h
@@ -0,0 +1,284 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_
+#define ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_
+
+#include <cstdint>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/file/posting_list/posting-list-common.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+// A serializer class to serialize hits to PostingListUsed. Layout described in
+// comments in posting-list-embedding-hit-serializer.cc.
+class PostingListEmbeddingHitSerializer : public PostingListSerializer {
+ public:
+  // Total bytes occupied by the two special hits at the front of every
+  // posting list.
+  static constexpr uint32_t kSpecialHitsSize =
+      kNumSpecialData * sizeof(EmbeddingHit);
+
+  uint32_t GetDataTypeBytes() const override { return sizeof(EmbeddingHit); }
+
+  uint32_t GetMinPostingListSize() const override {
+    static constexpr uint32_t kMinPostingListSize = kSpecialHitsSize;
+    static_assert(sizeof(PostingListIndex) <= kMinPostingListSize,
+                  "PostingListIndex must be small enough to fit in a "
+                  "minimum-sized Posting List.");
+
+    return kMinPostingListSize;
+  }
+
+  uint32_t GetMinPostingListSizeToFit(
+      const PostingListUsed* posting_list_used) const override;
+
+  uint32_t GetBytesUsed(
+      const PostingListUsed* posting_list_used) const override;
+
+  void Clear(PostingListUsed* posting_list_used) const override;
+
+  libtextclassifier3::Status MoveFrom(PostingListUsed* dst,
+                                      PostingListUsed* src) const override;
+
+  // Prepend a hit to the posting list.
+  //
+  // RETURNS:
+  //   - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the
+  //       previously added hit.
+  //   - RESOURCE_EXHAUSTED if there is no more room to add hit to the posting
+  //       list.
+  libtextclassifier3::Status PrependHit(PostingListUsed* posting_list_used,
+                                        const EmbeddingHit& hit) const;
+
+  // Prepend hits to the posting list. Hits should be sorted in descending
+  // order (as defined by the less than operator for EmbeddingHit).
+  //
+  // Returns the number of hits that could be prepended to the posting list. If
+  // keep_prepended is true, whatever could be prepended is kept, otherwise the
+  // posting list is left in its original state.
+  template <class T, EmbeddingHit (*GetHit)(const T&)>
+  libtextclassifier3::StatusOr<uint32_t> PrependHitArray(
+      PostingListUsed* posting_list_used, const T* array, uint32_t num_hits,
+      bool keep_prepended) const;
+
+  // Retrieves the hits stored in the posting list.
+  //
+  // RETURNS:
+  //   - On success, a vector of hits sorted by the reverse order of
+  //       prepending.
+  //   - INTERNAL_ERROR if the posting list has been corrupted somehow.
+  libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits(
+      const PostingListUsed* posting_list_used) const;
+
+  // Same as GetHits but appends hits to hits_out.
+  //
+  // RETURNS:
+  //   - OK on success, with the hits appended to hits_out sorted by the
+  //       reverse order of prepending.
+  //   - INTERNAL_ERROR if the posting list has been corrupted somehow.
+  libtextclassifier3::Status GetHits(const PostingListUsed* posting_list_used,
+                                     std::vector<EmbeddingHit>* hits_out) const;
+
+  // Undo the last num_hits hits prepended. If num_hits > number of
+  // hits we clear all hits.
+  //
+  // RETURNS:
+  //   - OK on success
+  //   - INTERNAL_ERROR if the posting list has been corrupted somehow.
+  libtextclassifier3::Status PopFrontHits(PostingListUsed* posting_list_used,
+                                          uint32_t num_hits) const;
+
+ private:
+  // Posting list layout formats:
+  //
+  // not_full
+  //
+  // +-----------------+----------------+-------+-----------------+
+  // |hits-start-offset|Hit::kInvalidVal|xxxxxxx|(compressed) hits|
+  // +-----------------+----------------+-------+-----------------+
+  //
+  // almost_full
+  //
+  // +-----------------+----------------+-------+-----------------+
+  // |Hit::kInvalidVal |1st hit         |(pad)  |(compressed) hits|
+  // +-----------------+----------------+-------+-----------------+
+  //
+  // full()
+  //
+  // +-----------------+----------------+-------+-----------------+
+  // |1st hit          |2nd hit         |(pad)  |(compressed) hits|
+  // +-----------------+----------------+-------+-----------------+
+  //
+  // The first two uncompressed hits also implicitly encode information about
+  // the size of the compressed hits region.
+  //
+  // 1. If the posting list is NOT_FULL, then
+  // posting_list_buffer_[0] contains the byte offset of the start of the
+  // compressed hits - and, thus, the size of the compressed hits region is
+  // size_in_bytes - posting_list_buffer_[0].
+  //
+  // 2. If posting list is ALMOST_FULL or FULL, then the compressed hits region
+  // starts somewhere between [kSpecialHitsSize, kSpecialHitsSize +
+  // sizeof(EmbeddingHit) - 1] and ends at size_in_bytes - 1.
+
+  // Helpers to determine what state the posting list is in.
+  bool IsFull(const PostingListUsed* posting_list_used) const {
+    return GetSpecialHit(posting_list_used, /*index=*/0).is_valid() &&
+           GetSpecialHit(posting_list_used, /*index=*/1).is_valid();
+  }
+
+  bool IsAlmostFull(const PostingListUsed* posting_list_used) const {
+    return !GetSpecialHit(posting_list_used, /*index=*/0).is_valid() &&
+           GetSpecialHit(posting_list_used, /*index=*/1).is_valid();
+  }
+
+  bool IsEmpty(const PostingListUsed* posting_list_used) const {
+    return GetSpecialHit(posting_list_used, /*index=*/0).value() ==
+               posting_list_used->size_in_bytes() &&
+           !GetSpecialHit(posting_list_used, /*index=*/1).is_valid();
+  }
+
+  // Returns false if both special hits are invalid or if the offset value
+  // stored in the special hit is less than kSpecialHitsSize or greater than
+  // posting_list_used->size_in_bytes(). Returns true, otherwise.
+  bool IsPostingListValid(const PostingListUsed* posting_list_used) const;
+
+  // Prepend hit to a posting list that is in the ALMOST_FULL state.
+  // RETURNS:
+  //  - OK, if successful
+  //  - INVALID_ARGUMENT if hit is not less than the previously added hit.
+  libtextclassifier3::Status PrependHitToAlmostFull(
+      PostingListUsed* posting_list_used, const EmbeddingHit& hit) const;
+
+  // Prepend hit to a posting list that is in the EMPTY state. This will always
+  // succeed because there are no pre-existing hits and no validly constructed
+  // posting list could fail to fit one hit.
+  void PrependHitToEmpty(PostingListUsed* posting_list_used,
+                         const EmbeddingHit& hit) const;
+
+  // Prepend hit to a posting list that is in the NOT_FULL state.
+  // RETURNS:
+  //  - OK, if successful
+  //  - INVALID_ARGUMENT if hit is not less than the previously added hit.
+  libtextclassifier3::Status PrependHitToNotFull(
+      PostingListUsed* posting_list_used, const EmbeddingHit& hit,
+      uint32_t offset) const;
+
+  // Returns either 0 (full state), sizeof(EmbeddingHit) (almost_full state) or
+  // a byte offset between kSpecialHitsSize and
+  // posting_list_used->size_in_bytes() (inclusive) (not_full state).
+  uint32_t GetStartByteOffset(const PostingListUsed* posting_list_used) const;
+
+  // Sets the special hits to properly reflect what offset is (see layout
+  // comment for further details).
+  //
+  // Returns false if offset > posting_list_used->size_in_bytes() or offset is
+  // in the open interval (sizeof(EmbeddingHit), kSpecialHitsSize) or in
+  // (0, sizeof(EmbeddingHit)). True, otherwise.
+  bool SetStartByteOffset(PostingListUsed* posting_list_used,
+                          uint32_t offset) const;
+
+  // Manipulate padded areas. We never store the same hit value twice
+  // so a delta of 0 is a pad byte.
+
+  // Returns offset of first non-pad byte.
+  uint32_t GetPadEnd(const PostingListUsed* posting_list_used,
+                     uint32_t offset) const;
+
+  // Fill padding between offset start and offset end with 0s.
+  // Returns false if end > posting_list_used->size_in_bytes(). True,
+  // otherwise.
+  bool PadToEnd(PostingListUsed* posting_list_used, uint32_t start,
+                uint32_t end) const;
+
+  // Helper for AppendHits/PopFrontHits. Adds limit number of hits to out or
+  // all hits in the posting list if the posting list contains less than limit
+  // number of hits. out can be NULL.
+  //
+  // NOTE: If called with limit=1, pop=true on a posting list that transitioned
+  // from NOT_FULL directly to FULL, GetHitsInternal will not return the
+  // posting list to NOT_FULL. Instead it will leave it in a valid state, but
+  // it will be ALMOST_FULL.
+  //
+  // RETURNS:
+  //   - OK on success
+  //   - INTERNAL_ERROR if the posting list has been corrupted somehow.
+  libtextclassifier3::Status GetHitsInternal(
+      const PostingListUsed* posting_list_used, uint32_t limit, bool pop,
+      std::vector<EmbeddingHit>* out) const;
+
+  // Retrieves the value stored in the index-th special hit.
+  //
+  // REQUIRES:
+  //   0 <= index < kNumSpecialData.
+  //
+  // RETURNS:
+  //   - The EmbeddingHit stored at the index-th special position. Note that it
+  //       may be invalid (EmbeddingHit::kInvalidValue) depending on the
+  //       posting list state - see the layout comment above.
+  EmbeddingHit GetSpecialHit(const PostingListUsed* posting_list_used,
+                             uint32_t index) const;
+
+  // Sets the value stored in the index-th special hit to val.
+  //
+  // REQUIRES:
+  //   0 <= index < kNumSpecialData.
+  void SetSpecialHit(PostingListUsed* posting_list_used, uint32_t index,
+                     const EmbeddingHit& val) const;
+
+  // Prepends hit, uncompressed, to the memory region
+  // [offset - sizeof(EmbeddingHit::Value), offset) and returns the new
+  // beginning of the in-use region.
+  //
+  // RETURNS:
+  //   - The new beginning of the padded region, if successful.
+  //   - INVALID_ARGUMENT if hit will not fit (uncompressed) between offset and
+  //       kSpecialHitsSize
+  libtextclassifier3::StatusOr<uint32_t> PrependHitUncompressed(
+      PostingListUsed* posting_list_used, const EmbeddingHit& hit,
+      uint32_t offset) const;
+};
+
+// Inlined functions. Implementation details below. Avert eyes!
+template <class T, EmbeddingHit (*GetHit)(const T&)>
+libtextclassifier3::StatusOr<uint32_t>
+PostingListEmbeddingHitSerializer::PrependHitArray(
+    PostingListUsed* posting_list_used, const T* array, uint32_t num_hits,
+    bool keep_prepended) const {
+  if (!IsPostingListValid(posting_list_used)) {
+    // Nothing can be prepended to a corrupted posting list; report zero hits
+    // prepended rather than an error.
+    return 0;
+  }
+
+  // Prepend hits working backwards from array[num_hits - 1].
+  uint32_t i;
+  for (i = 0; i < num_hits; ++i) {
+    if (!PrependHit(posting_list_used, GetHit(array[num_hits - i - 1])).ok()) {
+      break;
+    }
+  }
+  if (i != num_hits && !keep_prepended) {
+    // Didn't fit. Undo everything and check that we have the same offset as
+    // before. PopFrontHits guarantees that it will remove all 'i' hits so long
+    // as there are at least 'i' hits in the posting list, which we know there
+    // are.
+    ICING_RETURN_IF_ERROR(PopFrontHits(posting_list_used, /*num_hits=*/i));
+  }
+  return i;
+}
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_
diff --git a/icing/index/embed/posting-list-embedding-hit-serializer_test.cc b/icing/index/embed/posting-list-embedding-hit-serializer_test.cc
new file mode 100644
index 0000000..f829634
--- /dev/null
+++ b/icing/index/embed/posting-list-embedding-hit-serializer_test.cc
@@ -0,0 +1,864 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/posting-list-embedding-hit-serializer.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <deque>
+#include <iterator>
+#include <limits>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/hit/hit.h"
+#include "icing/legacy/index/icing-bit-util.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/hit-test-utils.h"
+
+using testing::ElementsAre;
+using testing::ElementsAreArray;
+using testing::Eq;
+using testing::IsEmpty;
+using testing::Le;
+using testing::Lt;
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+// Minimal wrapper element used to exercise PrependHitArray, which extracts
+// hits from an arbitrary element type via a function pointer.
+struct HitElt {
+  HitElt() = default;
+  explicit HitElt(const EmbeddingHit &hit_in) : hit(hit_in) {}
+
+  // Extractor passed as the GetHit template argument of PrependHitArray.
+  static EmbeddingHit get_hit(const HitElt &hit_elt) { return hit_elt.hit; }
+
+  EmbeddingHit hit;
+};
+
+TEST(PostingListEmbeddingHitSerializerTest, PostingListUsedPrependHitNotFull) {
+  PostingListEmbeddingHitSerializer serializer;
+
+  static const int kNumHits = 2551;
+  static const size_t kHitsSize = kNumHits * sizeof(EmbeddingHit);
+
+  ICING_ASSERT_OK_AND_ASSIGN(
+      PostingListUsed pl_used,
+      PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize));
+
+  // Add the first hit (EMPTY -> NOT_FULL). The first hit is stored
+  // uncompressed.
+  EmbeddingHit hit0(BasicHit(/*section_id=*/0, /*document_id=*/0),
+                    /*location=*/0);
+  ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit0));
+  int expected_size = sizeof(EmbeddingHit::Value);
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0)));
+
+  // Each subsequent hit costs only its varint-encoded delta from the
+  // previously prepended hit.
+  EmbeddingHit hit1(BasicHit(/*section_id=*/0, /*document_id=*/1),
+                    /*location=*/1);
+  uint64_t delta = hit0.value() - hit1.value();
+  uint8_t delta_buf[VarInt::kMaxEncodedLen64];
+  size_t delta_len = VarInt::Encode(delta, delta_buf);
+  ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit1));
+  expected_size += delta_len;
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAre(hit1, hit0)));
+
+  EmbeddingHit hit2(BasicHit(/*section_id=*/0, /*document_id=*/2),
+                    /*location=*/2);
+  delta = hit1.value() - hit2.value();
+  delta_len = VarInt::Encode(delta, delta_buf);
+  ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit2));
+  expected_size += delta_len;
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAre(hit2, hit1, hit0)));
+
+  EmbeddingHit hit3(BasicHit(/*section_id=*/0, /*document_id=*/3),
+                    /*location=*/3);
+  delta = hit2.value() - hit3.value();
+  delta_len = VarInt::Encode(delta, delta_buf);
+  ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit3));
+  expected_size += delta_len;
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0)));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+     PostingListUsedPrependHitAlmostFull) {
+  PostingListEmbeddingHitSerializer serializer;
+
+  // Posting list size = 32 bytes (2x the 16-byte minimum).
+  int pl_size = 2 * serializer.GetMinPostingListSize();
+  ICING_ASSERT_OK_AND_ASSIGN(
+      PostingListUsed pl_used,
+      PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+  // Fill up the compressed region.
+  // Transitions:
+  // Adding hit0: EMPTY -> NOT_FULL
+  // Adding hit1: NOT_FULL -> NOT_FULL
+  // Adding hit2: NOT_FULL -> NOT_FULL
+  EmbeddingHit hit0(BasicHit(/*section_id=*/0, /*document_id=*/0),
+                    /*location=*/1);
+  EmbeddingHit hit1 = CreateEmbeddingHit(hit0, /*desired_byte_length=*/3);
+  EmbeddingHit hit2 = CreateEmbeddingHit(hit1, /*desired_byte_length=*/3);
+  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0));
+  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1));
+  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2));
+  // Size used will be 8 (hit2) + 3 (hit1-hit2) + 3 (hit0-hit1) = 14 bytes
+  int expected_size = sizeof(EmbeddingHit) + 3 + 3;
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAre(hit2, hit1, hit0)));
+
+  // Add one more hit to transition NOT_FULL -> ALMOST_FULL
+  EmbeddingHit hit3 = CreateEmbeddingHit(hit2, /*desired_byte_length=*/3);
+  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3));
+  // Storing them in the compressed region requires 8 (hit) + 3 (hit2-hit3) +
+  // 3 (hit1-hit2) + 3 (hit0-hit1) = 17 bytes, but there are only 16 bytes in
+  // the compressed region. So instead, the posting list will transition to
+  // ALMOST_FULL. The in-use compressed region will actually shrink from 14
+  // bytes to 9 bytes because the uncompressed version of hit2 will be
+  // overwritten with the compressed delta of hit2. hit3 will be written to
+  // one of the special hits. Because we're in ALMOST_FULL, the expected size
+  // is the size of the pl minus the one hit used to mark the posting list as
+  // ALMOST_FULL.
+  expected_size = pl_size - sizeof(EmbeddingHit);
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0)));
+
+  // Add one more hit to transition ALMOST_FULL -> ALMOST_FULL
+  EmbeddingHit hit4 = CreateEmbeddingHit(hit3, /*desired_byte_length=*/6);
+  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit4));
+  // There are currently 9 bytes in use in the compressed region. Hit3 will
+  // have a 6-byte delta, which fits in the compressed region. Hit3 will be
+  // moved from the special hit to the compressed region (which will have 15
+  // bytes in use after adding hit3). Hit4 will be placed in one of the
+  // special hits and the posting list will remain in ALMOST_FULL.
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAre(hit4, hit3, hit2, hit1, hit0)));
+
+  // Add one more hit to transition ALMOST_FULL -> FULL
+  EmbeddingHit hit5 = CreateEmbeddingHit(hit4, /*desired_byte_length=*/2);
+  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit5));
+  // There are currently 15 bytes in use in the compressed region. Hit4 will
+  // have a 2-byte delta which will not fit in the compressed region. So hit4
+  // will remain in one of the special hits and hit5 will occupy the other,
+  // making the posting list FULL.
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAre(hit5, hit4, hit3, hit2, hit1, hit0)));
+
+  // The posting list is FULL. Adding another hit should fail.
+  EmbeddingHit hit6 = CreateEmbeddingHit(hit5, /*desired_byte_length=*/1);
+  EXPECT_THAT(serializer.PrependHit(&pl_used, hit6),
+              StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest, PostingListUsedMinSize) {
+  PostingListEmbeddingHitSerializer serializer;
+
+  // Minimum posting list size = 16 bytes (just the two 8-byte special hits).
+  ICING_ASSERT_OK_AND_ASSIGN(
+      PostingListUsed pl_used,
+      PostingListUsed::CreateFromUnitializedRegion(
+          &serializer, serializer.GetMinPostingListSize()));
+  // PL State: EMPTY
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0));
+  EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty()));
+
+  // Add a hit, PL should shift to ALMOST_FULL state
+  EmbeddingHit hit0(BasicHit(/*section_id=*/1, /*document_id=*/0),
+                    /*location=*/1);
+  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0));
+  // Size = sizeof(uncompressed hit0)
+  int expected_size = sizeof(EmbeddingHit);
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0)));
+
+  // Add the smallest hit possible with a delta of 0b1. PL should shift to
+  // FULL state.
+  EmbeddingHit hit1(BasicHit(/*section_id=*/1, /*document_id=*/0),
+                    /*location=*/0);
+  ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1));
+  // Size = sizeof(uncompressed hit1) + sizeof(uncompressed hit0)
+  expected_size += sizeof(EmbeddingHit);
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAre(hit1, hit0)));
+
+  // Try to add the smallest hit possible. Should fail: the PL is FULL.
+  EmbeddingHit hit2(BasicHit(/*section_id=*/0, /*document_id=*/0),
+                    /*location=*/0);
+  EXPECT_THAT(serializer.PrependHit(&pl_used, hit2),
+              StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAre(hit1, hit0)));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+     PostingListPrependHitArrayMinSizePostingList) {
+  PostingListEmbeddingHitSerializer serializer;
+
+  // Minimum posting list size = 16 bytes.
+  int pl_size = serializer.GetMinPostingListSize();
+  ICING_ASSERT_OK_AND_ASSIGN(
+      PostingListUsed pl_used,
+      PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+  std::vector<HitElt> hits_in;
+  hits_in.emplace_back(EmbeddingHit(
+      BasicHit(/*section_id=*/1, /*document_id=*/0), /*location=*/1));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+  std::reverse(hits_in.begin(), hits_in.end());
+
+  // Add five hits. The PL is in the empty state and an empty min size PL can
+  // only fit two hits, so PrependHitArray reports that only two fit and,
+  // because keep_prepended is false, rolls them back.
+  ICING_ASSERT_OK_AND_ASSIGN(
+      uint32_t num_can_prepend,
+      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+          &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
+  EXPECT_THAT(num_can_prepend, Eq(2));
+
+  int can_fit_hits = num_can_prepend;
+  // The PL has room for 2 hits. We should be able to add them without any
+  // problem, transitioning the PL from EMPTY -> ALMOST_FULL -> FULL
+  const HitElt *hits_in_ptr = hits_in.data() + (hits_in.size() - 2);
+  ICING_ASSERT_OK_AND_ASSIGN(
+      num_can_prepend,
+      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+          &pl_used, hits_in_ptr, can_fit_hits, /*keep_prepended=*/false)));
+  EXPECT_THAT(num_can_prepend, Eq(can_fit_hits));
+  EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used)));
+  std::deque<EmbeddingHit> hits_pushed;
+  std::transform(hits_in.rbegin(),
+                 hits_in.rend() - hits_in.size() + can_fit_hits,
+                 std::front_inserter(hits_pushed), HitElt::get_hit);
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAreArray(hits_pushed)));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+     PostingListPrependHitArrayPostingList) {
+  PostingListEmbeddingHitSerializer serializer;
+
+  // Posting list size = 48 bytes (3x the 16-byte minimum).
+  int pl_size = 3 * serializer.GetMinPostingListSize();
+  ICING_ASSERT_OK_AND_ASSIGN(
+      PostingListUsed pl_used,
+      PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+  std::vector<HitElt> hits_in;
+  hits_in.emplace_back(EmbeddingHit(
+      BasicHit(/*section_id=*/1, /*document_id=*/0), /*location=*/1));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+  std::reverse(hits_in.begin(), hits_in.end());
+  // The last hit is uncompressed and the four before it should only take one
+  // byte. Total use = 8 bytes.
+  // ----------------------
+  // 47     delta(EmbeddingHit #0)
+  // 46     delta(EmbeddingHit #1)
+  // 45     delta(EmbeddingHit #2)
+  // 44     delta(EmbeddingHit #3)
+  // 43-36  EmbeddingHit #4
+  // 35-16  <unused>
+  // 15-8   kSpecialHit
+  // 7-0    Offset=36
+  // ----------------------
+  int byte_size = sizeof(EmbeddingHit::Value) + hits_in.size() - 1;
+
+  // Add five hits. The PL is in the empty state and should be able to fit all
+  // five hits without issue, transitioning the PL from EMPTY -> NOT_FULL.
+  ICING_ASSERT_OK_AND_ASSIGN(
+      uint32_t num_could_fit,
+      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+          &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
+  EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
+  EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
+  std::deque<EmbeddingHit> hits_pushed;
+  std::transform(hits_in.rbegin(), hits_in.rend(),
+                 std::front_inserter(hits_pushed), HitElt::get_hit);
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+  EmbeddingHit first_hit =
+      CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/1);
+  hits_in.clear();
+  hits_in.emplace_back(first_hit);
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
+  std::reverse(hits_in.begin(), hits_in.end());
+  // Size increased by the deltas of these hits (1+2+1+2+3+2+3) = 14 bytes
+  // ----------------------
+  // 47     delta(EmbeddingHit #0)
+  // 46     delta(EmbeddingHit #1)
+  // 45     delta(EmbeddingHit #2)
+  // 44     delta(EmbeddingHit #3)
+  // 43     delta(EmbeddingHit #4)
+  // 42-41  delta(EmbeddingHit #5)
+  // 40     delta(EmbeddingHit #6)
+  // 39-38  delta(EmbeddingHit #7)
+  // 37-35  delta(EmbeddingHit #8)
+  // 34-33  delta(EmbeddingHit #9)
+  // 32-30  delta(EmbeddingHit #10)
+  // 29-22  EmbeddingHit #11
+  // 21-16  <unused>
+  // 15-8   kSpecialHit
+  // 7-0    Offset=22
+  // ----------------------
+  byte_size += 14;
+
+  // Add these 7 hits. The PL is currently in the NOT_FULL state and should
+  // remain in the NOT_FULL state.
+  ICING_ASSERT_OK_AND_ASSIGN(
+      num_could_fit,
+      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+          &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
+  EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
+  EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
+  // All hits from hits_in were added.
+  std::transform(hits_in.rbegin(), hits_in.rend(),
+                 std::front_inserter(hits_pushed), HitElt::get_hit);
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+  first_hit =
+      CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/8);
+  hits_in.clear();
+  hits_in.emplace_back(first_hit);
+  // ----------------------
+  // 47     delta(EmbeddingHit #0)
+  // 46     delta(EmbeddingHit #1)
+  // 45     delta(EmbeddingHit #2)
+  // 44     delta(EmbeddingHit #3)
+  // 43     delta(EmbeddingHit #4)
+  // 42-41  delta(EmbeddingHit #5)
+  // 40     delta(EmbeddingHit #6)
+  // 39-38  delta(EmbeddingHit #7)
+  // 37-35  delta(EmbeddingHit #8)
+  // 34-33  delta(EmbeddingHit #9)
+  // 32-30  delta(EmbeddingHit #10)
+  // 29-22  delta(EmbeddingHit #11)
+  // 21-16  <unused>
+  // 15-8   EmbeddingHit #12
+  // 7-0    kSpecialHit
+  // ----------------------
+  byte_size = 40;  // 48 - 8
+
+  // Add this 1 hit. The PL is currently in the NOT_FULL state and should
+  // transition to the ALMOST_FULL state - even though there is still some
+  // unused space.
+  ICING_ASSERT_OK_AND_ASSIGN(
+      num_could_fit,
+      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+          &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
+  EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
+  EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
+  // All hits from hits_in were added.
+  std::transform(hits_in.rbegin(), hits_in.rend(),
+                 std::front_inserter(hits_pushed), HitElt::get_hit);
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+  first_hit =
+      CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/5);
+  hits_in.clear();
+  hits_in.emplace_back(first_hit);
+  hits_in.emplace_back(
+      CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
+  std::reverse(hits_in.begin(), hits_in.end());
+  // ----------------------
+  // 47     delta(EmbeddingHit #0)
+  // 46     delta(EmbeddingHit #1)
+  // 45     delta(EmbeddingHit #2)
+  // 44     delta(EmbeddingHit #3)
+  // 43     delta(EmbeddingHit #4)
+  // 42-41  delta(EmbeddingHit #5)
+  // 40     delta(EmbeddingHit #6)
+  // 39-38  delta(EmbeddingHit #7)
+  // 37-35  delta(EmbeddingHit #8)
+  // 34-33  delta(EmbeddingHit #9)
+  // 32-30  delta(EmbeddingHit #10)
+  // 29-22  delta(EmbeddingHit #11)
+  // 21-17  delta(EmbeddingHit #12)
+  // 16     <unused>
+  // 15-8   EmbeddingHit #13
+  // 7-0    EmbeddingHit #14
+  // ----------------------
+
+  // Add these 2 hits.
+  // - The PL is currently in the ALMOST_FULL state. Adding the first hit
+  //   should keep the PL in ALMOST_FULL because the delta between
+  //   EmbeddingHit #12 and EmbeddingHit #13 (5 bytes) can fit in the unused
+  //   area (6 bytes).
+  // - Adding the second hit should transition to the FULL state because the
+  //   delta between EmbeddingHit #13 and EmbeddingHit #14 (3 bytes) is larger
+  //   than the remaining unused area (1 byte).
+  ICING_ASSERT_OK_AND_ASSIGN(
+      num_could_fit,
+      (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+          &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
+  EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
+  EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used)));
+  // All hits from hits_in were added.
+  std::transform(hits_in.rbegin(), hits_in.rend(),
+                 std::front_inserter(hits_pushed), HitElt::get_hit);
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAreArray(hits_pushed)));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ PostingListPrependHitArrayTooManyHits) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ static constexpr int kNumHits = 130;
+ static constexpr int kDeltaSize = 1;
+ static constexpr size_t kHitsSize =
+ ((kNumHits - 2) * kDeltaSize + (2 * sizeof(EmbeddingHit)));
+
+ // Create an array with one too many hits
+ std::vector<HitElt> hit_elts_in_too_many;
+ hit_elts_in_too_many.emplace_back(EmbeddingHit(
+ BasicHit(/*section_id=*/0, /*document_id=*/0), /*location=*/0));
+ for (int i = 0; i < kNumHits; ++i) {
+ hit_elts_in_too_many.emplace_back(CreateEmbeddingHit(
+ hit_elts_in_too_many.back().hit, /*desired_byte_length=*/1));
+ }
+ // Reverse so that hits are inserted in descending order
+ std::reverse(hit_elts_in_too_many.begin(), hit_elts_in_too_many.end());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, serializer.GetMinPostingListSize()));
+ // PrependHitArray should fail because hit_elts_in_too_many is far too large
+ // for the minimum size pl.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ uint32_t num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(),
+ /*keep_prepended=*/false)));
+ ASSERT_THAT(num_could_fit, Eq(2));
+ ASSERT_THAT(num_could_fit, Lt(hit_elts_in_too_many.size()));
+ ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0));
+ ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty()));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize));
+ // PrependHitArray should not fit all hits because hit_elts_in_too_many is
+ // one hit too large for this pl.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(),
+ /*keep_prepended=*/false)));
+ ASSERT_THAT(num_could_fit, Eq(hit_elts_in_too_many.size() - 1));
+ ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0));
+ ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty()));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ PostingListStatusJumpFromNotFullToFullAndBack) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ // Size = 24
+ const uint32_t pl_size = 3 * sizeof(EmbeddingHit);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+ EmbeddingHit max_valued_hit(
+ BasicHit(/*section_id=*/kMaxSectionId, /*document_id=*/kMinDocumentId),
+ /*location=*/std::numeric_limits<uint32_t>::max());
+ ICING_ASSERT_OK(serializer.PrependHit(&pl, max_valued_hit));
+ uint32_t bytes_used = serializer.GetBytesUsed(&pl);
+ ASSERT_THAT(bytes_used, sizeof(EmbeddingHit));
+ // Status not full.
+ ASSERT_THAT(
+ bytes_used,
+ Le(pl_size - PostingListEmbeddingHitSerializer::kSpecialHitsSize));
+
+ EmbeddingHit min_valued_hit(
+ BasicHit(/*section_id=*/kMinSectionId, /*document_id=*/kMaxDocumentId),
+ /*location=*/0);
+ ICING_ASSERT_OK(serializer.PrependHit(&pl, min_valued_hit));
+ EXPECT_THAT(serializer.GetHits(&pl),
+ IsOkAndHolds(ElementsAre(min_valued_hit, max_valued_hit)));
+ // Status should jump to full directly.
+ ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(pl_size));
+ ICING_ASSERT_OK(serializer.PopFrontHits(&pl, 1));
+ EXPECT_THAT(serializer.GetHits(&pl),
+ IsOkAndHolds(ElementsAre(max_valued_hit)));
+ // Status should return to not full as before.
+ ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(bytes_used));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest, DeltaOverflow) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ const uint32_t pl_size = 4 * sizeof(EmbeddingHit);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+ static const EmbeddingHit::Value kMaxHitValue =
+ std::numeric_limits<EmbeddingHit::Value>::max();
+ static const EmbeddingHit::Value kOverflow[4] = {
+ kMaxHitValue >> 2,
+ (kMaxHitValue >> 2) * 2,
+ (kMaxHitValue >> 2) * 3,
+ kMaxHitValue - 1,
+ };
+
+ // Fit at least 4 ordinary values.
+ std::deque<EmbeddingHit> hits_pushed;
+ for (EmbeddingHit::Value v = 0; v < 4; v++) {
+ hits_pushed.push_front(
+ EmbeddingHit(BasicHit(kMaxSectionId, kMaxDocumentId), 4 - v));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl, hits_pushed.front()));
+ EXPECT_THAT(serializer.GetHits(&pl),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+ }
+
+ // Cannot fit 4 overflow values.
+ hits_pushed.clear();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl, PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ for (int i = 3; i >= 1; i--) {
+ hits_pushed.push_front(EmbeddingHit(/*value=*/kOverflow[i]));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl, hits_pushed.front()));
+ EXPECT_THAT(serializer.GetHits(&pl),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+ }
+ EXPECT_THAT(serializer.PrependHit(&pl, EmbeddingHit(/*value=*/kOverflow[0])),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ GetMinPostingListToFitForNotFullPL) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ // Size = 64
+ int pl_size = 4 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ // Create and add some hits to make pl_used NOT_FULL
+ std::vector<EmbeddingHit> hits_in =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ }
+ // ----------------------
+ // 63-62 delta(EmbeddingHit #0)
+ // 61-60 delta(EmbeddingHit #1)
+ // 59-58 delta(EmbeddingHit #2)
+ // 57-56 delta(EmbeddingHit #3)
+ // 55-48 EmbeddingHit #4
+ // 47-16 <unused>
+ // 15-8 kSpecialHit
+ // 7-0 Offset=48
+ // ----------------------
+ int bytes_used = 16;
+
+ // Check that all hits have been inserted
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
+ std::deque<EmbeddingHit> hits_pushed(hits_in.rbegin(), hits_in.rend());
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // Get the min size to fit for the hits in pl_used. Moving the hits in pl_used
+ // into a posting list with this min size should make it ALMOST_FULL, which we
+ // can see should have size = 24.
+ // ----------------------
+ // 23-22 delta(EmbeddingHit #0)
+ // 21-20 delta(EmbeddingHit #1)
+ // 19-18 delta(EmbeddingHit #2)
+ // 17-16 delta(EmbeddingHit #3)
+ // 15-8 EmbeddingHit #4
+ // 7-0 kSpecialHit
+ // ----------------------
+ int expected_min_size = 24;
+ uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used);
+ EXPECT_THAT(min_size_to_fit, Eq(expected_min_size));
+
+ // Also check that this min size to fit posting list actually does fit all the
+ // hits and can only fit one more hit in the ALMOST_FULL state.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListUsed min_size_to_fit_pl,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, min_size_to_fit));
+ for (const EmbeddingHit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit));
+ }
+
+ // Adding another hit to the min-size-to-fit posting list should succeed
+ EmbeddingHit hit =
+ CreateEmbeddingHit(hits_in.back(), /*desired_byte_length=*/1);
+ ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit));
+ // Adding any other hits should fail with RESOURCE_EXHAUSTED error.
+ EXPECT_THAT(
+ serializer.PrependHit(&min_size_to_fit_pl,
+ CreateEmbeddingHit(hit, /*desired_byte_length=*/1)),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+
+ // Check that all hits have been inserted and the min-fit posting list is now
+ // FULL.
+ EXPECT_THAT(serializer.GetBytesUsed(&min_size_to_fit_pl),
+ Eq(min_size_to_fit));
+ hits_pushed.emplace_front(hit);
+ EXPECT_THAT(serializer.GetHits(&min_size_to_fit_pl),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ GetMinPostingListToFitForAlmostFullAndFullPLReturnsSameSize) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 24;
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ // Create and add some hits to make pl_used ALMOST_FULL
+ std::vector<EmbeddingHit> hits_in =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ }
+ // ----------------------
+ // 23-22 delta(EmbeddingHit #0)
+ // 21-20 delta(EmbeddingHit #1)
+ // 19-18 delta(EmbeddingHit #2)
+ // 17-16 delta(EmbeddingHit #3)
+ // 15-8 EmbeddingHit #4
+ // 7-0 kSpecialHit
+ // ----------------------
+ int bytes_used = 16;
+
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
+ std::deque<EmbeddingHit> hits_pushed(hits_in.rbegin(), hits_in.rend());
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // GetMinPostingListSizeToFit should return the same size as pl_used.
+ uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used);
+ EXPECT_THAT(min_size_to_fit, Eq(pl_size));
+
+ // Add another hit to make the posting list FULL
+ EmbeddingHit hit =
+ CreateEmbeddingHit(hits_in.back(), /*desired_byte_length=*/1);
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size));
+ hits_pushed.emplace_front(hit);
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // GetMinPostingListSizeToFit should still be the same size as pl_used.
+ min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used);
+ EXPECT_THAT(min_size_to_fit, Eq(pl_size));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest, MoveFrom) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit &hit : hits1) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
+ }
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used2,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits2 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits2) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit));
+ }
+
+ ICING_ASSERT_OK(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1));
+ EXPECT_THAT(serializer.GetHits(&pl_used2),
+ IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
+ EXPECT_THAT(serializer.GetHits(&pl_used1), IsOkAndHolds(IsEmpty()));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ MoveFromNullArgumentReturnsInvalidArgument) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit &hit : hits) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
+ }
+
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used1, /*src=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(serializer.GetHits(&pl_used1),
+ IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend())));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ MoveFromInvalidPostingListReturnsInvalidArgument) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit &hit : hits1) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
+ }
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used2,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits2 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits2) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit));
+ }
+
+ // Write invalid hits to the beginning of pl_used1 to make it invalid.
+ EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue);
+ EmbeddingHit *first_hit =
+ reinterpret_cast<EmbeddingHit *>(pl_used1.posting_list_buffer());
+ *first_hit = invalid_hit;
+ ++first_hit;
+ *first_hit = invalid_hit;
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(serializer.GetHits(&pl_used2),
+ IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend())));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+ MoveToInvalidPostingListReturnsFailedPrecondition) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit &hit : hits1) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
+ }
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used2,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits2 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits2) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit));
+ }
+
+ // Write invalid hits to the beginning of pl_used2 to make it invalid.
+ EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue);
+ EmbeddingHit *first_hit =
+ reinterpret_cast<EmbeddingHit *>(pl_used2.posting_list_buffer());
+ *first_hit = invalid_hit;
+ ++first_hit;
+ *first_hit = invalid_hit;
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(serializer.GetHits(&pl_used1),
+ IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest, MoveToPostingListTooSmall) {
+ PostingListEmbeddingHitSerializer serializer;
+
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ std::vector<EmbeddingHit> hits1 =
+ CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ for (const EmbeddingHit &hit : hits1) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
+ }
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used2,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, serializer.GetMinPostingListSize()));
+ std::vector<EmbeddingHit> hits2 =
+ CreateEmbeddingHits(/*num_hits=*/1, /*desired_byte_length=*/2);
+ for (const EmbeddingHit &hit : hits2) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit));
+ }
+
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(serializer.GetHits(&pl_used1),
+ IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
+ EXPECT_THAT(serializer.GetHits(&pl_used2),
+ IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend())));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embedding-indexing-handler.cc b/icing/index/embedding-indexing-handler.cc
new file mode 100644
index 0000000..049e307
--- /dev/null
+++ b/icing/index/embedding-indexing-handler.cc
@@ -0,0 +1,85 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embedding-indexing-handler.h"
+
+#include <memory>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/hit/hit.h"
+#include "icing/legacy/core/icing-string-util.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/util/clock.h"
+#include "icing/util/status-macros.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndexingHandler>>
+EmbeddingIndexingHandler::Create(const Clock* clock,
+ EmbeddingIndex* embedding_index) {
+ ICING_RETURN_ERROR_IF_NULL(clock);
+ ICING_RETURN_ERROR_IF_NULL(embedding_index);
+
+ return std::unique_ptr<EmbeddingIndexingHandler>(
+ new EmbeddingIndexingHandler(clock, embedding_index));
+}
+
+libtextclassifier3::Status EmbeddingIndexingHandler::Handle(
+ const TokenizedDocument& tokenized_document, DocumentId document_id,
+ bool recovery_mode, PutDocumentStatsProto* put_document_stats) {
+ std::unique_ptr<Timer> index_timer = clock_.GetNewTimer();
+
+ if (!IsDocumentIdValid(document_id)) {
+ return absl_ports::InvalidArgumentError(
+ IcingStringUtil::StringPrintf("Invalid DocumentId %d", document_id));
+ }
+
+ if (embedding_index_.last_added_document_id() != kInvalidDocumentId &&
+ document_id <= embedding_index_.last_added_document_id()) {
+ if (recovery_mode) {
+ // Skip the document if document_id <= last_added_document_id in
+ // recovery mode without returning an error.
+ return libtextclassifier3::Status::OK;
+ }
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "DocumentId %d must be greater than last added document_id %d",
+ document_id, embedding_index_.last_added_document_id()));
+ }
+ embedding_index_.set_last_added_document_id(document_id);
+
+ for (const Section<PropertyProto::VectorProto>& vector_section :
+ tokenized_document.vector_sections()) {
+ BasicHit hit(/*section_id=*/vector_section.metadata.id, document_id);
+ for (const PropertyProto::VectorProto& vector : vector_section.content) {
+ ICING_RETURN_IF_ERROR(embedding_index_.BufferEmbedding(hit, vector));
+ }
+ }
+ ICING_RETURN_IF_ERROR(embedding_index_.CommitBufferToIndex());
+
+ if (put_document_stats != nullptr) {
+ put_document_stats->set_embedding_index_latency_ms(
+ index_timer->GetElapsedMilliseconds());
+ }
+
+ return libtextclassifier3::Status::OK;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/embedding-indexing-handler.h b/icing/index/embedding-indexing-handler.h
new file mode 100644
index 0000000..f3adf6a
--- /dev/null
+++ b/icing/index/embedding-indexing-handler.h
@@ -0,0 +1,70 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_
+#define ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_
+
+#include <memory>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/data-indexing-handler.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/store/document-id.h"
+#include "icing/util/clock.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+class EmbeddingIndexingHandler : public DataIndexingHandler {
+ public:
+ ~EmbeddingIndexingHandler() override = default;
+
+ // Creates an EmbeddingIndexingHandler instance which does not take
+ // ownership of any input components. All pointers must refer to valid objects
+ // that outlive the created EmbeddingIndexingHandler instance.
+ //
+ // Returns:
+ // - An EmbeddingIndexingHandler instance on success
+ // - FAILED_PRECONDITION_ERROR if any of the input pointers is null
+ static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndexingHandler>>
+ Create(const Clock* clock, EmbeddingIndex* embedding_index);
+
+ // Handles the embedding indexing process: add hits into the embedding index
+ // for all contents in tokenized_document.vector_sections.
+ //
+ // Returns:
+ // - OK on success.
+ // - INVALID_ARGUMENT_ERROR if document_id is invalid OR document_id is less
+ // than or equal to the document_id of a previously indexed document in
+ // non-recovery mode.
+ // - INTERNAL_ERROR if any other errors occur.
+ // - Any embedding index errors.
+ libtextclassifier3::Status Handle(
+ const TokenizedDocument& tokenized_document, DocumentId document_id,
+ bool recovery_mode, PutDocumentStatsProto* put_document_stats) override;
+
+ private:
+ explicit EmbeddingIndexingHandler(const Clock* clock,
+ EmbeddingIndex* embedding_index)
+ : DataIndexingHandler(clock), embedding_index_(*embedding_index) {}
+
+ EmbeddingIndex& embedding_index_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_
diff --git a/icing/index/embedding-indexing-handler_test.cc b/icing/index/embedding-indexing-handler_test.cc
new file mode 100644
index 0000000..ed711c1
--- /dev/null
+++ b/icing/index/embedding-indexing-handler_test.cc
@@ -0,0 +1,620 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embedding-indexing-handler.h"
+
+#include <cstdint>
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/hit/hit.h"
+#include "icing/portable/platform.h"
+#include "icing/proto/document_wrapper.pb.h"
+#include "icing/proto/schema.pb.h"
+#include "icing/schema-builder.h"
+#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/embedding-test-utils.h"
+#include "icing/testing/fake-clock.h"
+#include "icing/testing/icu-data-file-helper.h"
+#include "icing/testing/test-data.h"
+#include "icing/testing/tmp-directory.h"
+#include "icing/tokenization/language-segmenter-factory.h"
+#include "icing/tokenization/language-segmenter.h"
+#include "icing/util/status-macros.h"
+#include "icing/util/tokenized-document.h"
+#include "unicode/uloc.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::IsEmpty;
+using ::testing::IsTrue;
+
+// Indexable properties (section) and section id. Section id is determined by
+// the lexicographical order of indexable property paths.
+// Schema type with indexable properties: FakeType
+// Section id = 0: "body"
+// Section id = 1: "bodyEmbedding"
+// Section id = 2: "title"
+// Section id = 3: "titleEmbedding"
+static constexpr std::string_view kFakeType = "FakeType";
+static constexpr std::string_view kPropertyBody = "body";
+static constexpr std::string_view kPropertyBodyEmbedding = "bodyEmbedding";
+static constexpr std::string_view kPropertyTitle = "title";
+static constexpr std::string_view kPropertyTitleEmbedding = "titleEmbedding";
+static constexpr std::string_view kPropertyNonIndexableEmbedding =
+ "nonIndexableEmbedding";
+
+static constexpr SectionId kSectionIdBodyEmbedding = 1;
+static constexpr SectionId kSectionIdTitleEmbedding = 3;
+
+// Schema type with nested indexable properties: FakeCollectionType
+// Section id = 0: "collection.body"
+// Section id = 1: "collection.bodyEmbedding"
+// Section id = 2: "collection.title"
+// Section id = 3: "collection.titleEmbedding"
+// Section id = 4: "fullDocEmbedding"
+static constexpr std::string_view kFakeCollectionType = "FakeCollectionType";
+static constexpr std::string_view kPropertyCollection = "collection";
+static constexpr std::string_view kPropertyFullDocEmbedding =
+ "fullDocEmbedding";
+
+static constexpr SectionId kSectionIdNestedBodyEmbedding = 1;
+static constexpr SectionId kSectionIdNestedTitleEmbedding = 3;
+static constexpr SectionId kSectionIdFullDocEmbedding = 4;
+
+class EmbeddingIndexingHandlerTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ if (!IsCfStringTokenization() && !IsReverseJniTokenization()) {
+ ICING_ASSERT_OK(
+ // File generated via icu_data_file rule in //icing/BUILD.
+ icu_data_file_helper::SetUpICUDataFile(
+ GetTestFilePath("icing/icu.dat")));
+ }
+
+ base_dir_ = GetTestTempDir() + "/icing_test";
+ ASSERT_THAT(filesystem_.CreateDirectoryRecursively(base_dir_.c_str()),
+ IsTrue());
+
+ embedding_index_working_path_ = base_dir_ + "/embedding_index";
+ schema_store_dir_ = base_dir_ + "/schema_store";
+ document_store_dir_ = base_dir_ + "/document_store";
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ embedding_index_,
+ EmbeddingIndex::Create(&filesystem_, embedding_index_working_path_));
+
+ language_segmenter_factory::SegmenterOptions segmenter_options(ULOC_US);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ lang_segmenter_,
+ language_segmenter_factory::Create(std::move(segmenter_options)));
+
+ ASSERT_THAT(
+ filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()),
+ IsTrue());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ schema_store_,
+ SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(
+ SchemaTypeConfigBuilder()
+ .SetType(kFakeType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kPropertyTitle)
+ .SetDataTypeString(TERM_MATCH_EXACT,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kPropertyBody)
+ .SetDataTypeString(TERM_MATCH_EXACT,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName(kPropertyTitleEmbedding)
+ .SetDataTypeVector(
+ EmbeddingIndexingConfig::EmbeddingIndexingType::
+ LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName(kPropertyBodyEmbedding)
+ .SetDataTypeVector(
+ EmbeddingIndexingConfig::EmbeddingIndexingType::
+ LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kPropertyNonIndexableEmbedding)
+ .SetDataType(TYPE_VECTOR)
+ .SetCardinality(CARDINALITY_REPEATED)))
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kFakeCollectionType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kPropertyCollection)
+ .SetDataTypeDocument(
+ kFakeType,
+ /*index_nested_properties=*/true)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName(kPropertyFullDocEmbedding)
+ .SetDataTypeVector(
+ EmbeddingIndexingConfig::
+ EmbeddingIndexingType::LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+ ICING_ASSERT_OK(schema_store_->SetSchema(
+ schema, /*ignore_errors_and_delete_documents=*/false,
+ /*allow_circular_schema_definitions=*/false));
+
+ ASSERT_TRUE(
+ filesystem_.CreateDirectoryRecursively(document_store_dir_.c_str()));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentStore::CreateResult doc_store_create_result,
+ DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_,
+ schema_store_.get(),
+ /*force_recovery_and_revalidate_documents=*/false,
+ /*namespace_id_fingerprint=*/false,
+ /*pre_mapping_fbv=*/false,
+ /*use_persistent_hash_map=*/false,
+ PortableFileBackedProtoLog<
+ DocumentWrapper>::kDeflateCompressionLevel,
+ /*initialize_stats=*/nullptr));
+ document_store_ = std::move(doc_store_create_result.document_store);
+ }
+
+ void TearDown() override {
+ document_store_.reset();
+ schema_store_.reset();
+ lang_segmenter_.reset();
+ embedding_index_.reset();
+
+ filesystem_.DeleteDirectoryRecursively(base_dir_.c_str());
+ }
+
+ libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits(
+ uint32_t dimension, std::string_view model_signature) {
+ std::vector<EmbeddingHit> hits;
+
+ libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListEmbeddingHitAccessor>>
+ pl_accessor_or =
+ embedding_index_->GetAccessor(dimension, model_signature);
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
+ if (pl_accessor_or.ok()) {
+ pl_accessor = std::move(pl_accessor_or).ValueOrDie();
+ } else if (absl_ports::IsNotFound(pl_accessor_or.status())) {
+ return hits;
+ } else {
+ return std::move(pl_accessor_or).status();
+ }
+
+ while (true) {
+ ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
+ pl_accessor->GetNextHitsBatch());
+ if (batch.empty()) {
+ return hits;
+ }
+ hits.insert(hits.end(), batch.begin(), batch.end());
+ }
+ }
+
+ std::vector<float> GetRawEmbeddingData() {
+ return std::vector<float>(embedding_index_->GetRawEmbeddingData(),
+ embedding_index_->GetRawEmbeddingData() +
+ embedding_index_->GetTotalVectorSize());
+ }
+
+ Filesystem filesystem_;
+ FakeClock fake_clock_;
+ std::string base_dir_;
+ std::string embedding_index_working_path_;
+ std::string schema_store_dir_;
+ std::string document_store_dir_;
+
+ std::unique_ptr<EmbeddingIndex> embedding_index_;
+ std::unique_ptr<LanguageSegmenter> lang_segmenter_;
+ std::unique_ptr<SchemaStore> schema_store_;
+ std::unique_ptr<DocumentStore> document_store_;
+};
+
+} // namespace
+
+TEST_F(EmbeddingIndexingHandlerTest, CreationWithNullPointerShouldFail) {
+ EXPECT_THAT(EmbeddingIndexingHandler::Create(/*clock=*/nullptr,
+ embedding_index_.get()),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+
+ EXPECT_THAT(EmbeddingIndexingHandler::Create(&fake_clock_,
+ /*embedding_index=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+}
+
+TEST_F(EmbeddingIndexingHandlerTest, HandleEmbeddingSection) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {0.1, 0.2, 0.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {0.4, 0.5, 0.6}),
+ CreateVector("model", {0.7, 0.8, 0.9}))
+ .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {1.1, 1.2, 1.3}))
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document)));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store_->Put(tokenized_document.document()));
+
+ ASSERT_THAT(embedding_index_->last_added_document_id(),
+ Eq(kInvalidDocumentId));
+ // Handle document.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<EmbeddingIndexingHandler> handler,
+ EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get()));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false,
+ /*put_document_stats=*/nullptr),
+ IsOk());
+
+ // Check index
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
+ /*location=*/6))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
+}
+
+TEST_F(EmbeddingIndexingHandlerTest, HandleNestedEmbeddingSection) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_collection_type/1")
+ .SetSchema(std::string(kFakeCollectionType))
+ .AddDocumentProperty(
+ std::string(kPropertyCollection),
+ DocumentBuilder()
+ .SetKey("icing", "nested_fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {0.1, 0.2, 0.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {0.4, 0.5, 0.6}),
+ CreateVector("model", {0.7, 0.8, 0.9}))
+ .AddVectorProperty(
+ std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {1.1, 1.2, 1.3}))
+ .Build())
+ .AddVectorProperty(std::string(kPropertyFullDocEmbedding),
+ CreateVector("model", {2.1, 2.2, 2.3}))
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document)));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store_->Put(tokenized_document.document()));
+
+ ASSERT_THAT(embedding_index_->last_added_document_id(),
+ Eq(kInvalidDocumentId));
+ // Handle document.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<EmbeddingIndexingHandler> handler,
+ EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get()));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false,
+ /*put_document_stats=*/nullptr),
+ IsOk());
+
+ // Check index
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(
+ BasicHit(kSectionIdNestedBodyEmbedding, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(
+ BasicHit(kSectionIdNestedBodyEmbedding, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(
+ BasicHit(kSectionIdNestedTitleEmbedding, /*document_id=*/0),
+ /*location=*/6),
+ EmbeddingHit(BasicHit(kSectionIdFullDocEmbedding, /*document_id=*/0),
+ /*location=*/9))));
+ EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
+ 0.1, 0.2, 0.3, 2.1, 2.2, 2.3));
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
+}
+
+TEST_F(EmbeddingIndexingHandlerTest,
+ HandleInvalidDocumentIdShouldReturnInvalidArgumentError) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {0.1, 0.2, 0.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {0.4, 0.5, 0.6}),
+ CreateVector("model", {0.7, 0.8, 0.9}))
+ .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {1.1, 1.2, 1.3}))
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document)));
+ ICING_ASSERT_OK(document_store_->Put(tokenized_document.document()));
+
+ static constexpr DocumentId kCurrentDocumentId = 3;
+ embedding_index_->set_last_added_document_id(kCurrentDocumentId);
+ ASSERT_THAT(embedding_index_->last_added_document_id(),
+ Eq(kCurrentDocumentId));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<EmbeddingIndexingHandler> handler,
+ EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get()));
+
+ // Handling document with kInvalidDocumentId should cause a failure, and both
+ // index data and last_added_document_id should remain unchanged.
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, kInvalidDocumentId,
+ /*recovery_mode=*/false, /*put_document_stats=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(embedding_index_->last_added_document_id(),
+ Eq(kCurrentDocumentId));
+ // Check that the embedding index is empty.
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+
+ // Recovery mode should get the same result.
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, kInvalidDocumentId,
+ /*recovery_mode=*/true, /*put_document_stats=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(embedding_index_->last_added_document_id(),
+ Eq(kCurrentDocumentId));
+ // Check that the embedding index is empty.
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+}
+
+TEST_F(EmbeddingIndexingHandlerTest,
+ HandleOutOfOrderDocumentIdShouldReturnInvalidArgumentError) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {0.1, 0.2, 0.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {0.4, 0.5, 0.6}),
+ CreateVector("model", {0.7, 0.8, 0.9}))
+ .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {1.1, 1.2, 1.3}))
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document)));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store_->Put(tokenized_document.document()));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<EmbeddingIndexingHandler> handler,
+ EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get()));
+
+ // Handling document with document_id == last_added_document_id should cause a
+ // failure, and both index data and last_added_document_id should remain
+ // unchanged.
+ embedding_index_->set_last_added_document_id(document_id);
+ ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false,
+ /*put_document_stats=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
+
+ // Check that the embedding index is empty.
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+
+ // Handling document with document_id < last_added_document_id should cause a
+ // failure, and both index data and last_added_document_id should remain
+ // unchanged.
+ embedding_index_->set_last_added_document_id(document_id + 1);
+ ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id + 1));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false,
+ /*put_document_stats=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id + 1));
+
+ // Check that the embedding index is empty.
+ EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(IsEmpty()));
+ EXPECT_THAT(GetRawEmbeddingData(), IsEmpty());
+}
+
+TEST_F(EmbeddingIndexingHandlerTest,
+ HandleRecoveryModeShouldIgnoreDocsLELastAddedDocId) {
+ DocumentProto document1 =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title one")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {0.1, 0.2, 0.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body one")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {0.4, 0.5, 0.6}),
+ CreateVector("model", {0.7, 0.8, 0.9}))
+ .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {1.1, 1.2, 1.3}))
+ .Build();
+ DocumentProto document2 =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/2")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kPropertyTitle), "title two")
+ .AddVectorProperty(std::string(kPropertyTitleEmbedding),
+ CreateVector("model", {10.1, 10.2, 10.3}))
+ .AddStringProperty(std::string(kPropertyBody), "body two")
+ .AddVectorProperty(std::string(kPropertyBodyEmbedding),
+ CreateVector("model", {10.4, 10.5, 10.6}),
+ CreateVector("model", {10.7, 10.8, 10.9}))
+ .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
+ CreateVector("model", {11.1, 11.2, 11.3}))
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document1,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document1)));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document2,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ std::move(document2)));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id1,
+ document_store_->Put(tokenized_document1.document()));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id2,
+ document_store_->Put(tokenized_document2.document()));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<EmbeddingIndexingHandler> handler,
+ EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get()));
+
+ // Handle document with document_id > last_added_document_id in recovery mode.
+ // The handler should index this document and update last_added_document_id.
+ EXPECT_THAT(
+ handler->Handle(tokenized_document1, document_id1, /*recovery_mode=*/true,
+ /*put_document_stats=*/nullptr),
+ IsOk());
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id1));
+
+ // Check index
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
+ /*location=*/6))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
+
+ // Handle document with document_id == last_added_document_id in recovery
+ // mode. We should not get any error, but the handler should ignore the
+ // document, so both index data and last_added_document_id should remain
+ // unchanged.
+ embedding_index_->set_last_added_document_id(document_id2);
+ ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document2, document_id2, /*recovery_mode=*/true,
+ /*put_document_stats=*/nullptr),
+ IsOk());
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2));
+
+ // Check index
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
+ /*location=*/6))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
+
+ // Handle document with document_id < last_added_document_id in recovery mode.
+ // We should not get any error, but the handler should ignore the document, so
+ // both index data and last_added_document_id should remain unchanged.
+ embedding_index_->set_last_added_document_id(document_id2 + 1);
+ ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2 + 1));
+ EXPECT_THAT(
+ handler->Handle(tokenized_document2, document_id2, /*recovery_mode=*/true,
+ /*put_document_stats=*/nullptr),
+ IsOk());
+ EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2 + 1));
+
+ // Check index
+ EXPECT_THAT(
+ GetHits(/*dimension=*/3, /*model_signature=*/"model"),
+ IsOkAndHolds(ElementsAre(
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/0),
+ EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
+ /*location=*/3),
+ EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
+ /*location=*/6))));
+ EXPECT_THAT(GetRawEmbeddingData(),
+ ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/hit/hit.cc b/icing/index/hit/hit.cc
index 493e62b..7b7f2f6 100644
--- a/icing/index/hit/hit.cc
+++ b/icing/index/hit/hit.cc
@@ -14,42 +14,19 @@
#include "icing/index/hit/hit.h"
+#include <cstring>
+#include <limits>
+
+#include "icing/schema/section.h"
#include "icing/store/document-id.h"
#include "icing/util/bit-util.h"
+#include "icing/util/logging.h"
namespace icing {
namespace lib {
namespace {
-enum FlagOffset {
- // This hit, whether exact or not, came from a prefixed section and will
- // need to be backfilled into branching posting lists if/when those are
- // created.
- kInPrefixSection = 0,
- // This hit represents a prefix of a longer term. If exact matches are
- // required, then this hit should be ignored.
- kPrefixHit = 1,
- // Whether or not the hit has a term_frequency other than
- // kDefaultTermFrequency.
- kHasTermFrequency = 2,
- kNumFlags = 3,
-};
-
-static_assert(kDocumentIdBits + kSectionIdBits + kNumFlags <
- sizeof(Hit::Value) * 8,
- "Hit::kInvalidValue contains risky value and we should have at "
- "least one unused bit to avoid potential bugs. Please follow the "
- "process mentioned in hit.h to correct the value of "
- "Hit::kInvalidValue and remove this static_assert afterwards.");
-
-static_assert(kDocumentIdBits + kSectionIdBits + kNumFlags <=
- sizeof(Hit::Value) * 8,
- "HitOverflow");
-static_assert(kDocumentIdBits == 22, "");
-static_assert(kSectionIdBits == 6, "");
-static_assert(kNumFlags == 3, "");
-
inline DocumentId InvertDocumentId(DocumentId document_id) {
static_assert(kMaxDocumentId <= (std::numeric_limits<DocumentId>::max() - 1),
"(kMaxDocumentId + 1) must not overflow.");
@@ -88,10 +65,28 @@ SectionId BasicHit::section_id() const {
/*len=*/kSectionIdBits);
}
+Hit::Hit(Value value, Flags flags, TermFrequency term_frequency)
+ : flags_(flags), term_frequency_(term_frequency) {
+ memcpy(value_.data(), &value, sizeof(value));
+ if (!CheckFlagsAreConsistent()) {
+ ICING_VLOG(1)
+ << "Creating Hit that has inconsistent flag values across its fields: "
+ << "Hit(value=" << value << ", flags=" << flags
+ << ", term_frequency=" << term_frequency << ")";
+ }
+}
+
Hit::Hit(SectionId section_id, DocumentId document_id,
Hit::TermFrequency term_frequency, bool is_in_prefix_section,
bool is_prefix_hit)
: term_frequency_(term_frequency) {
+ // We compute flags first as the value's has_flags bit depends on the flags_
+ // field.
+ Flags temp_flags = 0;
+ bit_util::BitfieldSet(term_frequency != kDefaultTermFrequency,
+ kHasTermFrequency, /*len=*/1, &temp_flags);
+ flags_ = temp_flags;
+
// Values are stored so that when sorted, they appear in document_id
// descending, section_id ascending, order. Also, all else being
// equal, non-prefix hits sort before prefix hits. So inverted
@@ -99,30 +94,29 @@ Hit::Hit(SectionId section_id, DocumentId document_id,
// (uninverted) section_id.
Value temp_value = 0;
bit_util::BitfieldSet(InvertDocumentId(document_id),
- kSectionIdBits + kNumFlags, kDocumentIdBits,
+ kSectionIdBits + kNumFlagsInValueField, kDocumentIdBits,
+ &temp_value);
+ bit_util::BitfieldSet(section_id, kNumFlagsInValueField, kSectionIdBits,
&temp_value);
- bit_util::BitfieldSet(section_id, kNumFlags, kSectionIdBits, &temp_value);
- bit_util::BitfieldSet(term_frequency != kDefaultTermFrequency,
- kHasTermFrequency, /*len=*/1, &temp_value);
bit_util::BitfieldSet(is_prefix_hit, kPrefixHit, /*len=*/1, &temp_value);
bit_util::BitfieldSet(is_in_prefix_section, kInPrefixSection,
/*len=*/1, &temp_value);
- value_ = temp_value;
+
+ bool has_flags = flags_ != kNoEnabledFlags;
+ bit_util::BitfieldSet(has_flags, kHasFlags, /*len=*/1, &temp_value);
+
+ memcpy(value_.data(), &temp_value, sizeof(temp_value));
}
DocumentId Hit::document_id() const {
DocumentId inverted_document_id = bit_util::BitfieldGet(
- value(), kSectionIdBits + kNumFlags, kDocumentIdBits);
+ value(), kSectionIdBits + kNumFlagsInValueField, kDocumentIdBits);
// Undo the document_id inversion.
return InvertDocumentId(inverted_document_id);
}
SectionId Hit::section_id() const {
- return bit_util::BitfieldGet(value(), kNumFlags, kSectionIdBits);
-}
-
-bool Hit::has_term_frequency() const {
- return bit_util::BitfieldGet(value(), kHasTermFrequency, 1);
+ return bit_util::BitfieldGet(value(), kNumFlagsInValueField, kSectionIdBits);
}
bool Hit::is_prefix_hit() const {
@@ -133,6 +127,27 @@ bool Hit::is_in_prefix_section() const {
return bit_util::BitfieldGet(value(), kInPrefixSection, 1);
}
+bool Hit::has_flags() const {
+ return bit_util::BitfieldGet(value(), kHasFlags, 1);
+}
+
+bool Hit::has_term_frequency() const {
+ return bit_util::BitfieldGet(flags(), kHasTermFrequency, 1);
+}
+
+bool Hit::CheckFlagsAreConsistent() const {
+ bool has_flags = flags_ != kNoEnabledFlags;
+ bool has_flags_enabled_in_value =
+ bit_util::BitfieldGet(value(), kHasFlags, /*len=*/1);
+
+ bool has_term_frequency = term_frequency_ != kDefaultTermFrequency;
+ bool has_term_frequency_enabled_in_flags =
+ bit_util::BitfieldGet(flags(), kHasTermFrequency, /*len=*/1);
+
+ return has_flags == has_flags_enabled_in_value &&
+ has_term_frequency == has_term_frequency_enabled_in_flags;
+}
+
Hit Hit::TranslateHit(Hit old_hit, DocumentId new_document_id) {
return Hit(old_hit.section_id(), new_document_id, old_hit.term_frequency(),
old_hit.is_in_prefix_section(), old_hit.is_prefix_hit());
@@ -140,7 +155,8 @@ Hit Hit::TranslateHit(Hit old_hit, DocumentId new_document_id) {
bool Hit::EqualsDocumentIdAndSectionId::operator()(const Hit& hit1,
const Hit& hit2) const {
- return (hit1.value() >> kNumFlags) == (hit2.value() >> kNumFlags);
+ return (hit1.value() >> kNumFlagsInValueField) ==
+ (hit2.value() >> kNumFlagsInValueField);
}
} // namespace lib
diff --git a/icing/index/hit/hit.h b/icing/index/hit/hit.h
index 111b320..e971016 100644
--- a/icing/index/hit/hit.h
+++ b/icing/index/hit/hit.h
@@ -17,6 +17,7 @@
#include <array>
#include <cstdint>
+#include <cstring>
#include <limits>
#include "icing/legacy/core/icing-packed-pod.h"
@@ -59,6 +60,7 @@ class BasicHit {
explicit BasicHit(SectionId section_id, DocumentId document_id);
explicit BasicHit() : value_(kInvalidValue) {}
+ explicit BasicHit(Value value) : value_(value) {}
bool is_valid() const { return value_ != kInvalidValue; }
Value value() const { return value_; }
@@ -70,6 +72,9 @@ class BasicHit {
private:
// Value bits layout: 4 unused + 22 document_id + 6 section id.
+ //
+ // The Value is guaranteed to be an unsigned integer, but its size and the
+ // information it stores may change if the Hit's encoding format is changed.
Value value_;
} __attribute__((packed));
static_assert(sizeof(BasicHit) == 4, "");
@@ -78,51 +83,83 @@ static_assert(sizeof(BasicHit) == 4, "");
// consists of:
// - a DocumentId
// - a SectionId
-// referring to the document and section that the hit corresponds to, as well as
-// metadata about the hit:
-// - whether the Hit has a TermFrequency other than the default value
-// - whether the Hit does not appear exactly in the document, but instead
-// represents a term that is a prefix of a term in the document
-// - whether the Hit came from a section that has prefix expansion enabled
-// and a term frequency for the hit.
+// - referring to the document and section that the hit corresponds to
+// - Metadata about the hit:
+// - whether the Hit does not appear exactly in the document, but instead
+// represents a term that is a prefix of a term in the document
+// (is_prefix_hit)
+// - whether the Hit came from a section that has prefix expansion enabled
+// (is_in_prefix_section)
+// - whether the Hit has set any bitmask flags (has_flags)
+// - bitmasks in flags fields:
+// - whether the Hit has a TermFrequency other than the default value
+// (has_term_frequency)
+// - a term frequency for the hit
//
// The hit is the most basic unit of the index and, when grouped together by
// term, can be used to encode what terms appear in what documents.
class Hit {
public:
- // The datatype used to encode Hit information: the document_id, section_id
- // and the has_term_frequency, prefix hit and in prefix section flags.
+ // The datatype used to encode Hit information: the document_id, section_id,
+ // and 3 flags: is_prefix_hit, is_in_prefix_section and has_flags.
+ //
+ // The Value is guaranteed to be an unsigned integer, but its size and the
+ // information it stores may change if the Hit's encoding format is changed.
using Value = uint32_t;
// WARNING: Changing this value will invalidate any pre-existing posting lists
// on user devices.
//
- // WARNING:
- // - Hit::kInvalidValue should contain inverted kInvalidDocumentId, which is
- // b'00...0. However, currently we set it as UINT32_MAX and actually it
- // contains b'11...1, which is the inverted document_id 0.
- // - It means Hit::kInvalidValue contains valid (document_id, section_id,
- // flags), so we potentially cannot distinguish if a Hit is invalid or not.
- // The invalidity is an essential feature for posting list since we use it
- // to determine the state of the posting list.
- // - The reason why it won't break the current posting list is because the
- // unused bit(s) are set as 1 for Hit::kInvalidValue and 0 for all valid
- // Hits. In other words, the unused bit(s) are actually serving as "invalid
- // flag".
- // - If we want to exhaust all unused bits in the future, then we have to
- // change Hit::kInvalidValue to set the inverted document_id section
- // correctly (b'00...0, refer to BasicHit::kInvalidValue as an example).
- // - Also this problem is guarded by static_assert in hit.cc. If exhausting
- // all unused bits, then the static_assert will detect and fail. We can
- // safely remove the static_assert check after following the above process
- // to resolve the incorrect Hit::kInvalidValue issue.
- static constexpr Value kInvalidValue = std::numeric_limits<Value>::max();
+ // kInvalidValue contains:
+ // - 0 for unused bits. Note that unused bits are always 0 for both valid and
+ // invalid Hit values.
+ // - Inverted kInvalidDocumentId
+ // - SectionId 0 (valid), which is ok because inverted kInvalidDocumentId has
+ // already invalidated the value. In fact, we currently use all 2^6 section
+ // ids and there is no "invalid section id", so it doesn't matter what
+ // SectionId we set for kInvalidValue.
+ static constexpr Value kInvalidValue = 0;
// Docs are sorted in reverse, and 0 is never used as the inverted
// DocumentId (because it is the inverse of kInvalidValue), so it is always
// the max in a descending sort.
static constexpr Value kMaxDocumentIdSortValue = 0;
+ enum FlagOffsetsInFlagsField {
+ // Whether or not the hit has a term_frequency other than
+ // kDefaultTermFrequency.
+ kHasTermFrequency = 0,
+ kNumFlagsInFlagsField = 1,
+ };
+
+ enum FlagOffsetsInValueField {
+ // Whether or not the hit has a flags value other than kNoEnabledFlags (i.e.
+ // it has flags enabled in the flags field)
+ kHasFlags = 0,
+ // This hit, whether exact or not, came from a prefixed section and will
+ // need to be backfilled into branching posting lists if/when those are
+ // created.
+ kInPrefixSection = 1,
+ // This hit represents a prefix of a longer term. If exact matches are
+ // required, then this hit should be ignored.
+ kPrefixHit = 2,
+ kNumFlagsInValueField = 3,
+ };
+ static_assert(kDocumentIdBits + kSectionIdBits + kNumFlagsInValueField <=
+ sizeof(Hit::Value) * 8,
+ "HitOverflow");
+ static_assert(kDocumentIdBits == 22, "");
+ static_assert(kSectionIdBits == 6, "");
+ static_assert(kNumFlagsInValueField == 3, "");
+
+ // The datatype used to encode additional bit-flags in the Hit.
+ // This is guaranteed to be an unsigned integer, but its size may change if
+ // more flags are introduced in the future and require more bits to encode.
+ using Flags = uint8_t;
+ static constexpr Flags kNoEnabledFlags = 0;
+
// The Term Frequency of a Hit.
+ // This is guaranteed to be an unsigned integer, but its size may change if we
+ // need to expand the max term-frequency.
using TermFrequency = uint8_t;
using TermFrequencyArray = std::array<Hit::TermFrequency, kTotalNumSections>;
// Max TermFrequency is 255.
@@ -131,40 +168,67 @@ class Hit {
static constexpr TermFrequency kDefaultTermFrequency = 1;
static constexpr TermFrequency kNoTermFrequency = 0;
- explicit Hit(Value value = kInvalidValue,
- TermFrequency term_frequency = kDefaultTermFrequency)
- : value_(value), term_frequency_(term_frequency) {}
- Hit(SectionId section_id, DocumentId document_id,
- TermFrequency term_frequency, bool is_in_prefix_section = false,
- bool is_prefix_hit = false);
+ explicit Hit(Value value)
+ : Hit(value, kNoEnabledFlags, kDefaultTermFrequency) {}
+ explicit Hit(Value value, Flags flags, TermFrequency term_frequency);
+ explicit Hit(SectionId section_id, DocumentId document_id,
+ TermFrequency term_frequency, bool is_in_prefix_section,
+ bool is_prefix_hit);
bool is_valid() const { return value() != kInvalidValue; }
- Value value() const { return value_; }
+
+ Value value() const {
+ Value value;
+ memcpy(&value, value_.data(), sizeof(value));
+ return value;
+ }
+
DocumentId document_id() const;
SectionId section_id() const;
+ bool is_prefix_hit() const;
+ bool is_in_prefix_section() const;
+ // Whether or not the hit has any flags set to true.
+ bool has_flags() const;
+
+ Flags flags() const { return flags_; }
// Whether or not the hit contains a valid term frequency.
bool has_term_frequency() const;
+
TermFrequency term_frequency() const { return term_frequency_; }
- bool is_prefix_hit() const;
- bool is_in_prefix_section() const;
+
+ // Returns true if the flags values across the Hit's value_, term_frequency_
+ // and flags_ fields are consistent.
+ bool CheckFlagsAreConsistent() const;
// Creates a new hit based on old_hit but with new_document_id set.
static Hit TranslateHit(Hit old_hit, DocumentId new_document_id);
- bool operator<(const Hit& h2) const { return value() < h2.value(); }
- bool operator==(const Hit& h2) const { return value() == h2.value(); }
+ bool operator<(const Hit& h2) const {
+ if (value() != h2.value()) {
+ return value() < h2.value();
+ }
+ return flags() < h2.flags();
+ }
+ bool operator==(const Hit& h2) const {
+ return value() == h2.value() && flags() == h2.flags();
+ }
struct EqualsDocumentIdAndSectionId {
bool operator()(const Hit& hit1, const Hit& hit2) const;
};
private:
- // Value and TermFrequency must be in this order.
- // Value bits layout: 1 unused + 22 document_id + 6 section id + 3 flags.
- Value value_;
+ // Value, Flags and TermFrequency must be in this order.
+ // Value bits layout: 1 unused + 22 document_id + 6 section_id + 1
+ // is_prefix_hit + 1 is_in_prefix_section + 1 has_flags.
+ std::array<char, sizeof(Value)> value_;
+ // Flags bits layout: 1 reserved + 6 unused + 1 has_term_frequency.
+ // The left-most bit is reserved for chaining additional fields in case of
+ // future hit expansions.
+ Flags flags_;
TermFrequency term_frequency_;
-} __attribute__((packed));
-static_assert(sizeof(Hit) == 5, "");
+};
+static_assert(sizeof(Hit) == 6, "");
// TODO(b/138991332) decide how to remove/replace all is_packed_pod assertions.
static_assert(icing_is_packed_pod<Hit>::value, "go/icing-ubsan");
diff --git a/icing/index/hit/hit_test.cc b/icing/index/hit/hit_test.cc
index 0086d91..1233e00 100644
--- a/icing/index/hit/hit_test.cc
+++ b/icing/index/hit/hit_test.cc
@@ -14,6 +14,9 @@
#include "icing/index/hit/hit.h"
+#include <algorithm>
+#include <vector>
+
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/schema/section.h"
@@ -94,77 +97,127 @@ TEST(BasicHitTest, Comparison) {
}
TEST(HitTest, HasTermFrequencyFlag) {
- Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency);
+ Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
EXPECT_THAT(h1.has_term_frequency(), IsFalse());
EXPECT_THAT(h1.term_frequency(), Eq(Hit::kDefaultTermFrequency));
- Hit h2(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency);
+ Hit h2(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
EXPECT_THAT(h2.has_term_frequency(), IsTrue());
EXPECT_THAT(h2.term_frequency(), Eq(kSomeTermFrequency));
}
TEST(HitTest, IsPrefixHitFlag) {
- Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency);
+ Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
EXPECT_THAT(h1.is_prefix_hit(), IsFalse());
Hit h2(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
EXPECT_THAT(h2.is_prefix_hit(), IsFalse());
Hit h3(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency,
/*is_in_prefix_section=*/false, /*is_prefix_hit=*/true);
EXPECT_THAT(h3.is_prefix_hit(), IsTrue());
+
+ Hit h4(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
+ EXPECT_THAT(h4.is_prefix_hit(), IsFalse());
}
TEST(HitTest, IsInPrefixSectionFlag) {
- Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency);
+ Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
EXPECT_THAT(h1.is_in_prefix_section(), IsFalse());
Hit h2(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
EXPECT_THAT(h2.is_in_prefix_section(), IsFalse());
Hit h3(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
EXPECT_THAT(h3.is_in_prefix_section(), IsTrue());
}
+TEST(HitTest, HasFlags) {
+ Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ EXPECT_THAT(h1.has_flags(), IsFalse());
+
+ Hit h2(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/true);
+ EXPECT_THAT(h2.has_flags(), IsFalse());
+
+ Hit h3(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
+ EXPECT_THAT(h3.has_flags(), IsFalse());
+
+ Hit h4(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ EXPECT_THAT(h4.has_flags(), IsTrue());
+
+ Hit h5(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/true);
+ EXPECT_THAT(h5.has_flags(), IsTrue());
+
+ Hit h6(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/true);
+ EXPECT_THAT(h6.has_flags(), IsTrue());
+
+ Hit h7(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
+ EXPECT_THAT(h7.has_flags(), IsTrue());
+}
+
TEST(HitTest, Accessors) {
- Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency);
+ Hit h1(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/true);
EXPECT_THAT(h1.document_id(), Eq(kSomeDocumentId));
EXPECT_THAT(h1.section_id(), Eq(kSomeSectionid));
+ EXPECT_THAT(h1.term_frequency(), Eq(kSomeTermFrequency));
+ EXPECT_THAT(h1.is_in_prefix_section(), IsFalse());
+ EXPECT_THAT(h1.is_prefix_hit(), IsTrue());
}
TEST(HitTest, Valid) {
- Hit def;
+ Hit def(Hit::kInvalidValue);
EXPECT_THAT(def.is_valid(), IsFalse());
-
- Hit explicit_invalid(Hit::kInvalidValue);
+ Hit explicit_invalid(Hit::kInvalidValue, Hit::kNoEnabledFlags,
+ Hit::kDefaultTermFrequency);
EXPECT_THAT(explicit_invalid.is_valid(), IsFalse());
static constexpr Hit::Value kSomeValue = 65372;
- Hit explicit_valid(kSomeValue);
+ Hit explicit_valid(kSomeValue, Hit::kNoEnabledFlags,
+ Hit::kDefaultTermFrequency);
EXPECT_THAT(explicit_valid.is_valid(), IsTrue());
- Hit maximum_document_id_hit(kSomeSectionid, kMaxDocumentId,
- kSomeTermFrequency);
+ Hit maximum_document_id_hit(
+ kSomeSectionid, kMaxDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
EXPECT_THAT(maximum_document_id_hit.is_valid(), IsTrue());
- Hit maximum_section_id_hit(kMaxSectionId, kSomeDocumentId,
- kSomeTermFrequency);
+ Hit maximum_section_id_hit(kMaxSectionId, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
EXPECT_THAT(maximum_section_id_hit.is_valid(), IsTrue());
- Hit minimum_document_id_hit(kSomeSectionid, 0, kSomeTermFrequency);
+ Hit minimum_document_id_hit(kSomeSectionid, 0, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
EXPECT_THAT(minimum_document_id_hit.is_valid(), IsTrue());
- Hit minimum_section_id_hit(0, kSomeDocumentId, kSomeTermFrequency);
+ Hit minimum_section_id_hit(0, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
EXPECT_THAT(minimum_section_id_hit.is_valid(), IsTrue());
- // We use Hit with value Hit::kMaxDocumentIdSortValue for std::lower_bound in
- // the lite index. Verify that the value of the smallest valid Hit (which
+ // We use Hit with value Hit::kMaxDocumentIdSortValue for std::lower_bound
+ // in the lite index. Verify that the value of the smallest valid Hit (which
// contains kMinSectionId, kMaxDocumentId and 3 flags = false) is >=
// Hit::kMaxDocumentIdSortValue.
- Hit smallest_hit(kMinSectionId, kMaxDocumentId, Hit::kDefaultTermFrequency);
+ Hit smallest_hit(kMinSectionId, kMaxDocumentId, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ASSERT_THAT(smallest_hit.is_valid(), IsTrue());
ASSERT_THAT(smallest_hit.has_term_frequency(), IsFalse());
ASSERT_THAT(smallest_hit.is_prefix_hit(), IsFalse());
@@ -173,39 +226,78 @@ TEST(HitTest, Valid) {
}
TEST(HitTest, Comparison) {
- Hit hit(1, 243, Hit::kDefaultTermFrequency);
+ Hit hit(/*section_id=*/1, /*document_id=*/243, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
// DocumentIds are sorted in ascending order. So a hit with a lower
// document_id should be considered greater than one with a higher
// document_id.
- Hit higher_document_id_hit(1, 2409, Hit::kDefaultTermFrequency);
- Hit higher_section_id_hit(15, 243, Hit::kDefaultTermFrequency);
+ Hit higher_document_id_hit(
+ /*section_id=*/1, /*document_id=*/2409, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ Hit higher_section_id_hit(/*section_id=*/15, /*document_id=*/243,
+ Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
// Whether or not a term frequency was set is considered, but the term
// frequency itself is not.
- Hit term_frequency_hit(1, 243, 12);
- Hit prefix_hit(1, 243, Hit::kDefaultTermFrequency,
+ Hit term_frequency_hit(/*section_id=*/1, 243, /*term_frequency=*/12,
+ /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
+ Hit prefix_hit(/*section_id=*/1, 243, Hit::kDefaultTermFrequency,
/*is_in_prefix_section=*/false,
/*is_prefix_hit=*/true);
- Hit hit_in_prefix_section(1, 243, Hit::kDefaultTermFrequency,
+ Hit hit_in_prefix_section(/*section_id=*/1, 243, Hit::kDefaultTermFrequency,
/*is_in_prefix_section=*/true,
/*is_prefix_hit=*/false);
+ Hit hit_with_all_flags_enabled(/*section_id=*/1, 243, 56,
+ /*is_in_prefix_section=*/true,
+ /*is_prefix_hit=*/true);
std::vector<Hit> hits{hit,
higher_document_id_hit,
higher_section_id_hit,
term_frequency_hit,
prefix_hit,
- hit_in_prefix_section};
+ hit_in_prefix_section,
+ hit_with_all_flags_enabled};
std::sort(hits.begin(), hits.end());
- EXPECT_THAT(
- hits, ElementsAre(higher_document_id_hit, hit, hit_in_prefix_section,
- prefix_hit, term_frequency_hit, higher_section_id_hit));
+ EXPECT_THAT(hits,
+ ElementsAre(higher_document_id_hit, hit, term_frequency_hit,
+ hit_in_prefix_section, prefix_hit,
+ hit_with_all_flags_enabled, higher_section_id_hit));
- Hit higher_term_frequency_hit(1, 243, 108);
+ Hit higher_term_frequency_hit(/*section_id=*/1, 243, /*term_frequency=*/108,
+ /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
// The term frequency value is not considered when comparing hits.
EXPECT_THAT(term_frequency_hit, Not(Lt(higher_term_frequency_hit)));
EXPECT_THAT(higher_term_frequency_hit, Not(Lt(term_frequency_hit)));
}
+TEST(HitTest, CheckFlagsAreConsistent) {
+ Hit::Value value_without_flags = 1 << 30;
+ Hit::Value value_with_flags = value_without_flags + 1;
+ Hit::Flags flags_with_term_freq = 1;
+
+ Hit consistent_hit_no_flags(value_without_flags, Hit::kNoEnabledFlags,
+ Hit::kDefaultTermFrequency);
+ Hit consistent_hit_with_term_frequency(value_with_flags, flags_with_term_freq,
+ kSomeTermFrequency);
+ EXPECT_THAT(consistent_hit_no_flags.CheckFlagsAreConsistent(), IsTrue());
+ EXPECT_THAT(consistent_hit_with_term_frequency.CheckFlagsAreConsistent(),
+ IsTrue());
+
+ Hit inconsistent_hit_1(value_with_flags, Hit::kNoEnabledFlags,
+ Hit::kDefaultTermFrequency);
+ Hit inconsistent_hit_2(value_with_flags, Hit::kNoEnabledFlags,
+ kSomeTermFrequency);
+ Hit inconsistent_hit_3(value_with_flags, flags_with_term_freq,
+ Hit::kDefaultTermFrequency);
+ EXPECT_THAT(inconsistent_hit_1.CheckFlagsAreConsistent(), IsFalse());
+ EXPECT_THAT(inconsistent_hit_2.CheckFlagsAreConsistent(), IsFalse());
+ EXPECT_THAT(inconsistent_hit_3.CheckFlagsAreConsistent(), IsFalse());
+}
+
} // namespace
} // namespace lib
diff --git a/icing/index/index.cc b/icing/index/index.cc
index 98058be..f917aa6 100644
--- a/icing/index/index.cc
+++ b/icing/index/index.cc
@@ -67,8 +67,7 @@ libtextclassifier3::StatusOr<LiteIndex::Options> CreateLiteIndexOptions(
}
return LiteIndex::Options(
options.base_dir + "/idx/lite.", options.index_merge_size,
- options.lite_index_sort_at_indexing, options.lite_index_sort_size,
- options.include_property_existence_metadata_hits);
+ options.lite_index_sort_at_indexing, options.lite_index_sort_size);
}
std::string MakeMainIndexFilepath(const std::string& base_dir) {
@@ -345,7 +344,8 @@ libtextclassifier3::Status Index::Editor::BufferTerm(const char* term) {
libtextclassifier3::Status Index::Editor::IndexAllBufferedTerms() {
for (auto itr = seen_tokens_.begin(); itr != seen_tokens_.end(); itr++) {
Hit hit(section_id_, document_id_, /*term_frequency=*/itr->second,
- term_match_type_ == TermMatchType::PREFIX);
+ /*is_in_prefix_section=*/term_match_type_ == TermMatchType::PREFIX,
+ /*is_prefix_hit=*/false);
ICING_ASSIGN_OR_RETURN(
uint32_t term_id, term_id_codec_->EncodeTvi(itr->first, TviType::LITE));
ICING_RETURN_IF_ERROR(lite_index_->AddHit(term_id, hit));
diff --git a/icing/index/index.h b/icing/index/index.h
index a5d75c4..a09e28f 100644
--- a/icing/index/index.h
+++ b/icing/index/index.h
@@ -35,6 +35,7 @@
#include "icing/index/term-metadata.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/proto/debug.pb.h"
+#include "icing/proto/logging.pb.h"
#include "icing/proto/scoring.pb.h"
#include "icing/proto/storage.pb.h"
#include "icing/proto/term.pb.h"
@@ -72,20 +73,16 @@ class Index {
struct Options {
explicit Options(const std::string& base_dir, uint32_t index_merge_size,
bool lite_index_sort_at_indexing,
- uint32_t lite_index_sort_size,
- bool include_property_existence_metadata_hits = false)
+ uint32_t lite_index_sort_size)
: base_dir(base_dir),
index_merge_size(index_merge_size),
lite_index_sort_at_indexing(lite_index_sort_at_indexing),
- lite_index_sort_size(lite_index_sort_size),
- include_property_existence_metadata_hits(
- include_property_existence_metadata_hits) {}
+ lite_index_sort_size(lite_index_sort_size) {}
std::string base_dir;
int32_t index_merge_size;
bool lite_index_sort_at_indexing;
int32_t lite_index_sort_size;
- bool include_property_existence_metadata_hits;
};
// Creates an instance of Index in the directory pointed by file_dir.
@@ -178,6 +175,13 @@ class Index {
return debug_info;
}
+ void PublishQueryStats(QueryStatsProto* query_stats) const {
+ query_stats->set_lite_index_hit_buffer_byte_size(
+ lite_index_->GetHitBufferByteSize());
+ query_stats->set_lite_index_hit_buffer_unsorted_byte_size(
+ lite_index_->GetHitBufferUnsortedByteSize());
+ }
+
// Returns the byte size of the all the elements held in the index. This
// excludes the size of any internal metadata of the index, e.g. the index's
// header.
@@ -301,9 +305,7 @@ class Index {
}
// Sorts the LiteIndex HitBuffer.
- void SortLiteIndex() {
- lite_index_->SortHits();
- }
+ void SortLiteIndex() { lite_index_->SortHits(); }
// Reduces internal file sizes by reclaiming space of deleted documents.
// new_last_added_document_id will be used to update the last added document
diff --git a/icing/index/index_test.cc b/icing/index/index_test.cc
index 04a6bb7..50b65ad 100644
--- a/icing/index/index_test.cc
+++ b/icing/index/index_test.cc
@@ -33,9 +33,11 @@
#include "icing/file/filesystem.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/lite/term-id-hit-pair.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/legacy/index/icing-mock-filesystem.h"
#include "icing/proto/debug.pb.h"
+#include "icing/proto/logging.pb.h"
#include "icing/proto/storage.pb.h"
#include "icing/proto/term.pb.h"
#include "icing/schema/section.h"
@@ -2734,11 +2736,47 @@ TEST_F(IndexTest, IndexStorageInfoProto) {
EXPECT_THAT(storage_info.main_index_lexicon_size(), Ge(0));
EXPECT_THAT(storage_info.main_index_storage_size(), Ge(0));
EXPECT_THAT(storage_info.main_index_block_size(), Ge(0));
- // There should be 1 block for the header and 1 block for two posting lists.
+ // There should be 1 block for the header and 1 block for three posting lists
+ // ("fo", "foo", "foul").
EXPECT_THAT(storage_info.num_blocks(), Eq(2));
EXPECT_THAT(storage_info.min_free_fraction(), Ge(0));
}
+TEST_F(IndexTest, PublishQueryStats) {
+ // Add two documents to the lite index without merging.
+ Index::Editor edit = index_->Edit(kDocumentId0, kSectionId2,
+ TermMatchType::PREFIX, /*namespace_id=*/0);
+ ASSERT_THAT(edit.BufferTerm("foo"), IsOk());
+ EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk());
+ edit = index_->Edit(kDocumentId1, kSectionId2, TermMatchType::PREFIX,
+ /*namespace_id=*/0);
+ ASSERT_THAT(edit.BufferTerm("foul"), IsOk());
+ EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk());
+
+ // Verify query stats.
+ QueryStatsProto query_stats1;
+ index_->PublishQueryStats(&query_stats1);
+ EXPECT_THAT(query_stats1.lite_index_hit_buffer_byte_size(),
+ Eq(2 * sizeof(TermIdHitPair::Value)));
+ EXPECT_THAT(query_stats1.lite_index_hit_buffer_unsorted_byte_size(),
+ Ge(2 * sizeof(TermIdHitPair::Value)));
+
+ // Sort lite index.
+ index_->SortLiteIndex();
+ QueryStatsProto query_stats2;
+ index_->PublishQueryStats(&query_stats2);
+ EXPECT_THAT(query_stats2.lite_index_hit_buffer_byte_size(),
+ Eq(2 * sizeof(TermIdHitPair::Value)));
+ EXPECT_THAT(query_stats2.lite_index_hit_buffer_unsorted_byte_size(), Eq(0));
+
+ // Merge lite index to main index.
+ ICING_ASSERT_OK(index_->Merge());
+ QueryStatsProto query_stats3;
+ index_->PublishQueryStats(&query_stats3);
+ EXPECT_THAT(query_stats3.lite_index_hit_buffer_byte_size(), Eq(0));
+ EXPECT_THAT(query_stats3.lite_index_hit_buffer_unsorted_byte_size(), Eq(0));
+}
+
} // namespace
} // namespace lib
diff --git a/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h b/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h
index c16a1c4..d766712 100644
--- a/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h
+++ b/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h
@@ -17,13 +17,17 @@
#include <cstdint>
#include <memory>
+#include <set>
#include <string>
-#include <string_view>
#include <utility>
+#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
namespace icing {
diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc
index 35dc0b9..735adaa 100644
--- a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc
+++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc
@@ -113,12 +113,12 @@ DocHitInfoIteratorSectionRestrict::ApplyRestrictions(
const DocumentStore* document_store, const SchemaStore* schema_store,
const SearchSpecProto& search_spec, int64_t current_time_ms) {
std::unordered_map<std::string, std::set<std::string>> type_property_filters;
- // TODO(b/294274922): Add support for polymorphism in type property filters.
- for (const TypePropertyMask& type_property_mask :
- search_spec.type_property_filters()) {
- type_property_filters[type_property_mask.schema_type()] =
- std::set<std::string>(type_property_mask.paths().begin(),
- type_property_mask.paths().end());
+ for (const SchemaStore::ExpandedTypePropertyMask& type_property_mask :
+ schema_store->ExpandTypePropertyMasks(
+ search_spec.type_property_filters())) {
+ type_property_filters[type_property_mask.schema_type] =
+ std::set<std::string>(type_property_mask.paths.begin(),
+ type_property_mask.paths.end());
}
auto data = std::make_unique<SectionRestrictData>(
document_store, schema_store, current_time_ms, type_property_filters);
@@ -134,7 +134,9 @@ DocHitInfoIteratorSectionRestrict::ApplyRestrictions(
ChildrenMapper mapper;
mapper = [&data, &mapper](std::unique_ptr<DocHitInfoIterator> iterator)
-> std::unique_ptr<DocHitInfoIterator> {
- if (iterator->is_leaf()) {
+ if (iterator->full_section_restriction_applied()) {
+ return iterator;
+ } else if (iterator->is_leaf()) {
return std::make_unique<DocHitInfoIteratorSectionRestrict>(
std::move(iterator), data);
} else {
@@ -149,24 +151,8 @@ libtextclassifier3::Status DocHitInfoIteratorSectionRestrict::Advance() {
doc_hit_info_ = DocHitInfo(kInvalidDocumentId);
while (delegate_->Advance().ok()) {
DocumentId document_id = delegate_->doc_hit_info().document_id();
-
- auto data_optional = data_->document_store().GetAliveDocumentFilterData(
- document_id, data_->current_time_ms());
- if (!data_optional) {
- // Ran into some error retrieving information on this hit, skip
- continue;
- }
-
- // Guaranteed that the DocumentFilterData exists at this point
- SchemaTypeId schema_type_id = data_optional.value().schema_type_id();
- auto schema_type_or = data_->schema_store().GetSchemaType(schema_type_id);
- if (!schema_type_or.ok()) {
- // Ran into error retrieving schema type, skip
- continue;
- }
- const std::string* schema_type = std::move(schema_type_or).ValueOrDie();
SectionIdMask allowed_sections_mask =
- data_->ComputeAllowedSectionsMask(*schema_type);
+ data_->ComputeAllowedSectionsMask(document_id);
// A hit can be in multiple sections at once, need to check which of the
// section ids match the sections allowed by type_property_masks_. This can
diff --git a/icing/index/iterator/doc-hit-info-iterator.h b/icing/index/iterator/doc-hit-info-iterator.h
index 728f957..921e4d4 100644
--- a/icing/index/iterator/doc-hit-info-iterator.h
+++ b/icing/index/iterator/doc-hit-info-iterator.h
@@ -214,6 +214,12 @@ class DocHitInfoIterator {
virtual bool is_leaf() { return false; }
+ // Whether the iterator has already been applied all the required section
+ // restrictions.
+ // If true, calling DocHitInfoIteratorSectionRestrict::ApplyRestrictions on
+ // this iterator will have no effect.
+ virtual bool full_section_restriction_applied() const { return false; }
+
virtual ~DocHitInfoIterator() = default;
// Returns:
diff --git a/icing/index/iterator/section-restrict-data.cc b/icing/index/iterator/section-restrict-data.cc
index 085437d..581ac8e 100644
--- a/icing/index/iterator/section-restrict-data.cc
+++ b/icing/index/iterator/section-restrict-data.cc
@@ -22,6 +22,8 @@
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/store/document-id.h"
namespace icing {
namespace lib {
@@ -78,5 +80,24 @@ SectionIdMask SectionRestrictData::ComputeAllowedSectionsMask(
return new_section_id_mask;
}
+SectionIdMask SectionRestrictData::ComputeAllowedSectionsMask(
+ DocumentId document_id) {
+ auto data_optional =
+ document_store_.GetAliveDocumentFilterData(document_id, current_time_ms_);
+ if (!data_optional) {
+ // Ran into some error retrieving information on this document, skip
+ return kSectionIdMaskNone;
+ }
+ // Guaranteed that the DocumentFilterData exists at this point
+ SchemaTypeId schema_type_id = data_optional.value().schema_type_id();
+ auto schema_type_or = schema_store_.GetSchemaType(schema_type_id);
+ if (!schema_type_or.ok()) {
+ // Ran into error retrieving schema type, skip
+ return kSectionIdMaskNone;
+ }
+ const std::string* schema_type = std::move(schema_type_or).ValueOrDie();
+ return ComputeAllowedSectionsMask(*schema_type);
+}
+
} // namespace lib
} // namespace icing
diff --git a/icing/index/iterator/section-restrict-data.h b/icing/index/iterator/section-restrict-data.h
index 26ca597..64d5087 100644
--- a/icing/index/iterator/section-restrict-data.h
+++ b/icing/index/iterator/section-restrict-data.h
@@ -23,6 +23,7 @@
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
namespace icing {
@@ -48,10 +49,21 @@ class SectionRestrictData {
// Returns:
// - If type_property_filters_ has an entry for the given schema type or
// wildcard(*), return a bitwise or of section IDs in the schema type
- // that that are also present in the relevant filter list.
+ // that are also present in the relevant filter list.
// - Otherwise, return kSectionIdMaskAll.
SectionIdMask ComputeAllowedSectionsMask(const std::string& schema_type);
+ // Calculates the section mask of allowed sections(determined by the
+ // property filters map) for the given document id, by retrieving its schema
+ // type name and calling the above method.
+ //
+ // Returns:
+ // - If type_property_filters_ has an entry for the given document's schema
+ // type or wildcard(*), return a bitwise or of section IDs in the schema
+ // type that are also present in the relevant filter list.
+ // - Otherwise, return kSectionIdMaskAll.
+ SectionIdMask ComputeAllowedSectionsMask(DocumentId document_id);
+
const DocumentStore& document_store() const { return document_store_; }
const SchemaStore& schema_store() const { return schema_store_; }
diff --git a/icing/index/lite/lite-index-header.h b/icing/index/lite/lite-index-header.h
index 75de8fa..741d173 100644
--- a/icing/index/lite/lite-index-header.h
+++ b/icing/index/lite/lite-index-header.h
@@ -53,14 +53,7 @@ class LiteIndex_Header {
class LiteIndex_HeaderImpl : public LiteIndex_Header {
public:
struct HeaderData {
- static uint32_t GetCurrentMagic(
- bool include_property_existence_metadata_hits) {
- if (!include_property_existence_metadata_hits) {
- return 0x01c61418;
- } else {
- return 0x56e07d5b;
- }
- }
+ static const uint32_t kMagic = 0xC2EAD682;
uint32_t lite_index_crc;
uint32_t magic;
@@ -76,15 +69,10 @@ class LiteIndex_HeaderImpl : public LiteIndex_Header {
uint32_t searchable_end;
};
- explicit LiteIndex_HeaderImpl(HeaderData *hdr,
- bool include_property_existence_metadata_hits)
- : hdr_(hdr),
- include_property_existence_metadata_hits_(
- include_property_existence_metadata_hits) {}
+ explicit LiteIndex_HeaderImpl(HeaderData *hdr) : hdr_(hdr) {}
bool check_magic() const override {
- return hdr_->magic == HeaderData::GetCurrentMagic(
- include_property_existence_metadata_hits_);
+ return hdr_->magic == HeaderData::kMagic;
}
uint32_t lite_index_crc() const override { return hdr_->lite_index_crc; }
@@ -111,8 +99,7 @@ class LiteIndex_HeaderImpl : public LiteIndex_Header {
void Reset() override {
hdr_->lite_index_crc = 0;
- hdr_->magic =
- HeaderData::GetCurrentMagic(include_property_existence_metadata_hits_);
+ hdr_->magic = HeaderData::kMagic;
hdr_->last_added_docid = kInvalidDocumentId;
hdr_->cur_size = 0;
hdr_->searchable_end = 0;
@@ -120,7 +107,6 @@ class LiteIndex_HeaderImpl : public LiteIndex_Header {
private:
HeaderData *hdr_;
- bool include_property_existence_metadata_hits_;
};
static_assert(24 == sizeof(LiteIndex_HeaderImpl::HeaderData),
"sizeof(HeaderData) != 24");
diff --git a/icing/index/lite/lite-index-options.cc b/icing/index/lite/lite-index-options.cc
index 7e6c076..b4810ea 100644
--- a/icing/index/lite/lite-index-options.cc
+++ b/icing/index/lite/lite-index-options.cc
@@ -69,16 +69,14 @@ IcingDynamicTrie::Options CalculateTrieOptions(uint32_t hit_buffer_size) {
} // namespace
-LiteIndexOptions::LiteIndexOptions(
- const std::string& filename_base, uint32_t hit_buffer_want_merge_bytes,
- bool hit_buffer_sort_at_indexing, uint32_t hit_buffer_sort_threshold_bytes,
- bool include_property_existence_metadata_hits)
+LiteIndexOptions::LiteIndexOptions(const std::string& filename_base,
+ uint32_t hit_buffer_want_merge_bytes,
+ bool hit_buffer_sort_at_indexing,
+ uint32_t hit_buffer_sort_threshold_bytes)
: filename_base(filename_base),
hit_buffer_want_merge_bytes(hit_buffer_want_merge_bytes),
hit_buffer_sort_at_indexing(hit_buffer_sort_at_indexing),
- hit_buffer_sort_threshold_bytes(hit_buffer_sort_threshold_bytes),
- include_property_existence_metadata_hits(
- include_property_existence_metadata_hits) {
+ hit_buffer_sort_threshold_bytes(hit_buffer_sort_threshold_bytes) {
hit_buffer_size = CalculateHitBufferSize(hit_buffer_want_merge_bytes);
lexicon_options = CalculateTrieOptions(hit_buffer_size);
display_mappings_options = CalculateTrieOptions(hit_buffer_size);
diff --git a/icing/index/lite/lite-index-options.h b/icing/index/lite/lite-index-options.h
index 8b03449..5eae5c7 100644
--- a/icing/index/lite/lite-index-options.h
+++ b/icing/index/lite/lite-index-options.h
@@ -32,8 +32,7 @@ struct LiteIndexOptions {
LiteIndexOptions(const std::string& filename_base,
uint32_t hit_buffer_want_merge_bytes,
bool hit_buffer_sort_at_indexing,
- uint32_t hit_buffer_sort_threshold_bytes,
- bool include_property_existence_metadata_hits = false);
+ uint32_t hit_buffer_sort_threshold_bytes);
IcingDynamicTrie::Options lexicon_options;
IcingDynamicTrie::Options display_mappings_options;
@@ -43,7 +42,6 @@ struct LiteIndexOptions {
uint32_t hit_buffer_size = 0;
bool hit_buffer_sort_at_indexing = false;
uint32_t hit_buffer_sort_threshold_bytes = 0;
- bool include_property_existence_metadata_hits = false;
};
} // namespace lib
diff --git a/icing/index/lite/lite-index.cc b/icing/index/lite/lite-index.cc
index 3f9cc93..3aed7e4 100644
--- a/icing/index/lite/lite-index.cc
+++ b/icing/index/lite/lite-index.cc
@@ -74,7 +74,7 @@ size_t header_size() { return sizeof(LiteIndex_HeaderImpl::HeaderData); }
} // namespace
const TermIdHitPair::Value TermIdHitPair::kInvalidValue =
- TermIdHitPair(0, Hit()).value();
+ TermIdHitPair(0, Hit(Hit::kInvalidValue)).value();
libtextclassifier3::StatusOr<std::unique_ptr<LiteIndex>> LiteIndex::Create(
const LiteIndex::Options& options, const IcingFilesystem* filesystem) {
@@ -168,8 +168,7 @@ libtextclassifier3::Status LiteIndex::Initialize() {
header_mmap_.Remap(hit_buffer_fd_.get(), kHeaderFileOffset, header_size());
header_ = std::make_unique<LiteIndex_HeaderImpl>(
reinterpret_cast<LiteIndex_HeaderImpl::HeaderData*>(
- header_mmap_.address()),
- options_.include_property_existence_metadata_hits);
+ header_mmap_.address()));
header_->Reset();
if (!hit_buffer_.Init(hit_buffer_fd_.get(), header_padded_size, true,
@@ -184,8 +183,7 @@ libtextclassifier3::Status LiteIndex::Initialize() {
header_mmap_.Remap(hit_buffer_fd_.get(), kHeaderFileOffset, header_size());
header_ = std::make_unique<LiteIndex_HeaderImpl>(
reinterpret_cast<LiteIndex_HeaderImpl::HeaderData*>(
- header_mmap_.address()),
- options_.include_property_existence_metadata_hits);
+ header_mmap_.address()));
if (!hit_buffer_.Init(hit_buffer_fd_.get(), header_padded_size, true,
sizeof(TermIdHitPair::Value), header_->cur_size(),
@@ -499,7 +497,7 @@ int LiteIndex::FetchHits(
// When disabled, the entire HitBuffer should be sorted already and only
// binary search is needed.
if (options_.hit_buffer_sort_at_indexing) {
- uint32_t unsorted_length = header_->cur_size() - header_->searchable_end();
+ uint32_t unsorted_length = GetHitBufferUnsortedSizeImpl();
for (uint32_t i = 1; i <= unsorted_length; ++i) {
TermIdHitPair term_id_hit_pair = array[header_->cur_size() - i];
if (term_id_hit_pair.term_id() == term_id) {
@@ -519,7 +517,8 @@ int LiteIndex::FetchHits(
// Do binary search over the sorted section and repeat the above steps.
TermIdHitPair target_term_id_hit_pair(
- term_id, Hit(Hit::kMaxDocumentIdSortValue, Hit::kDefaultTermFrequency));
+ term_id, Hit(Hit::kMaxDocumentIdSortValue, Hit::kNoEnabledFlags,
+ Hit::kDefaultTermFrequency));
for (const TermIdHitPair* ptr = std::lower_bound(
array, array + header_->searchable_end(), target_term_id_hit_pair);
ptr < array + header_->searchable_end(); ++ptr) {
@@ -607,13 +606,13 @@ IndexStorageInfoProto LiteIndex::GetStorageInfo(
void LiteIndex::SortHitsImpl() {
// Make searchable by sorting by hit buffer.
- uint32_t sort_len = header_->cur_size() - header_->searchable_end();
- if (sort_len <= 0) {
+ uint32_t need_sort_len = GetHitBufferUnsortedSizeImpl();
+ if (need_sort_len <= 0) {
return;
}
IcingTimer timer;
- auto* array_start =
+ TermIdHitPair::Value* array_start =
hit_buffer_.GetMutableMem<TermIdHitPair::Value>(0, header_->cur_size());
TermIdHitPair::Value* sort_start = array_start + header_->searchable_end();
std::sort(sort_start, array_start + header_->cur_size());
@@ -625,7 +624,7 @@ void LiteIndex::SortHitsImpl() {
std::inplace_merge(array_start, array_start + header_->searchable_end(),
array_start + header_->cur_size());
}
- ICING_VLOG(2) << "Lite index sort and merge " << sort_len << " into "
+ ICING_VLOG(2) << "Lite index sort and merge " << need_sort_len << " into "
<< header_->searchable_end() << " in " << timer.Elapsed() * 1000
<< "ms";
diff --git a/icing/index/lite/lite-index.h b/icing/index/lite/lite-index.h
index 288602a..45dc280 100644
--- a/icing/index/lite/lite-index.h
+++ b/icing/index/lite/lite-index.h
@@ -308,6 +308,24 @@ class LiteIndex {
IndexStorageInfoProto GetStorageInfo(IndexStorageInfoProto storage_info) const
ICING_LOCKS_EXCLUDED(mutex_);
+ // Returns the total byte size of hit_buffer_.
+ uint32_t GetHitBufferByteSize() const ICING_LOCKS_EXCLUDED(mutex_) {
+ absl_ports::shared_lock l(&mutex_);
+ return size_impl() * sizeof(TermIdHitPair::Value);
+ }
+
+ // Returns the size of unsorted part of hit_buffer_.
+ uint32_t GetHitBufferUnsortedSize() const ICING_LOCKS_EXCLUDED(mutex_) {
+ absl_ports::shared_lock l(&mutex_);
+ return GetHitBufferUnsortedSizeImpl();
+ }
+
+ // Returns the byte size of unsorted part of hit_buffer_.
+ uint64_t GetHitBufferUnsortedByteSize() const ICING_LOCKS_EXCLUDED(mutex_) {
+ absl_ports::shared_lock l(&mutex_);
+ return GetHitBufferUnsortedSizeImpl() * sizeof(TermIdHitPair::Value);
+ }
+
// Reduces internal file sizes by reclaiming space of deleted documents.
//
// This method also sets the last_added_docid of the index to
@@ -377,13 +395,13 @@ class LiteIndex {
bool NeedSortAtQuerying() const ICING_SHARED_LOCKS_REQUIRED(mutex_) {
return HasUnsortedHitsExceedingSortThresholdImpl() ||
(!options_.hit_buffer_sort_at_indexing &&
- header_->cur_size() - header_->searchable_end() > 0);
+ GetHitBufferUnsortedSizeImpl() > 0);
}
// Non-locking implementation for HasUnsortedHitsExceedingSortThresholdImpl().
bool HasUnsortedHitsExceedingSortThresholdImpl() const
ICING_SHARED_LOCKS_REQUIRED(mutex_) {
- return header_->cur_size() - header_->searchable_end() >=
+ return GetHitBufferUnsortedSizeImpl() >=
(options_.hit_buffer_sort_threshold_bytes /
sizeof(TermIdHitPair::Value));
}
@@ -408,6 +426,12 @@ class LiteIndex {
std::vector<Hit::TermFrequencyArray>* term_frequency_out) const
ICING_SHARED_LOCKS_REQUIRED(mutex_);
+ // Returns the size of unsorted part of hit_buffer_.
+ uint32_t GetHitBufferUnsortedSizeImpl() const
+ ICING_SHARED_LOCKS_REQUIRED(mutex_) {
+ return header_->cur_size() - header_->searchable_end();
+ }
+
// File descriptor that points to where the header and hit buffer are written
// to.
ScopedFd hit_buffer_fd_ ICING_GUARDED_BY(mutex_);
diff --git a/icing/index/lite/lite-index_test.cc b/icing/index/lite/lite-index_test.cc
index 9811fa2..f8ea94a 100644
--- a/icing/index/lite/lite-index_test.cc
+++ b/icing/index/lite/lite-index_test.cc
@@ -27,7 +27,7 @@
#include "icing/index/hit/hit.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/index/lite/doc-hit-info-iterator-term-lite.h"
-#include "icing/index/lite/lite-index-header.h"
+#include "icing/index/lite/term-id-hit-pair.h"
#include "icing/index/term-id-codec.h"
#include "icing/legacy/index/icing-dynamic-trie.h"
#include "icing/legacy/index/icing-filesystem.h"
@@ -49,6 +49,7 @@ using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::IsFalse;
using ::testing::IsTrue;
+using ::testing::Pointee;
using ::testing::SizeIs;
class LiteIndexTest : public testing::Test {
@@ -71,15 +72,25 @@ class LiteIndexTest : public testing::Test {
constexpr NamespaceId kNamespace0 = 0;
+TEST_F(LiteIndexTest, TermIdHitPairInvalidValue) {
+ TermIdHitPair invalidTermHitPair(TermIdHitPair::kInvalidValue);
+
+ EXPECT_THAT(invalidTermHitPair.term_id(), Eq(0));
+ EXPECT_THAT(invalidTermHitPair.hit().value(), Eq(Hit::kInvalidValue));
+ EXPECT_THAT(invalidTermHitPair.hit().term_frequency(),
+ Eq(Hit::kDefaultTermFrequency));
+}
+
TEST_F(LiteIndexTest,
LiteIndexFetchHits_sortAtQuerying_unsortedHitsBelowSortThreshold) {
// Set up LiteIndex and TermIdCodec
std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index";
- // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs.
- LiteIndex::Options options(lite_index_file_name,
- /*hit_buffer_want_merge_bytes=*/1024 * 1024,
- /*hit_buffer_sort_at_indexing=*/false,
- /*hit_buffer_sort_threshold_bytes=*/64);
+ // Unsorted tail can contain a max of 8 TermIdHitPairs.
+ LiteIndex::Options options(
+ lite_index_file_name,
+ /*hit_buffer_want_merge_bytes=*/1024 * 1024,
+ /*hit_buffer_sort_at_indexing=*/false,
+ /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8);
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index,
LiteIndex::Create(options, &icing_filesystem_));
ICING_ASSERT_OK_AND_ASSIGN(
@@ -95,9 +106,9 @@ TEST_F(LiteIndexTest,
ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id,
term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE));
Hit foo_hit0(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit foo_hit1(/*section_id=*/1, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, foo_hit0));
ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, foo_hit1));
@@ -107,24 +118,18 @@ TEST_F(LiteIndexTest,
ICING_ASSERT_OK_AND_ASSIGN(uint32_t bar_term_id,
term_id_codec_->EncodeTvi(bar_tvi, TviType::LITE));
Hit bar_hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit bar_hit1(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index->AddHit(bar_term_id, bar_hit0));
ICING_ASSERT_OK(lite_index->AddHit(bar_term_id, bar_hit1));
+ // Check the total size and unsorted size of the hit buffer.
+ EXPECT_THAT(lite_index, Pointee(SizeIs(4)));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(4));
// Check that unsorted hits does not exceed the sort threshold.
EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsFalse());
- // Check that hits are unsorted. Persist the data and pread from
- // LiteIndexHeader.
- ASSERT_THAT(lite_index->PersistToDisk(), IsOk());
- LiteIndex_HeaderImpl::HeaderData header_data;
- ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(),
- &header_data, sizeof(header_data),
- LiteIndex::kHeaderFileOffset));
- EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(4));
-
// Query the LiteIndex
std::vector<DocHitInfo> hits1;
lite_index->FetchHits(
@@ -148,28 +153,26 @@ TEST_F(LiteIndexTest,
// checker.
EXPECT_THAT(hits2, IsEmpty());
- // Check that hits are sorted after querying LiteIndex. Persist the data and
- // pread from LiteIndexHeader.
- ASSERT_THAT(lite_index->PersistToDisk(), IsOk());
- ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(),
- &header_data, sizeof(header_data),
- LiteIndex::kHeaderFileOffset));
- EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(0));
+ // Check the total size and unsorted size of the hit buffer. Hits should be
+ // sorted after querying LiteIndex.
+ EXPECT_THAT(lite_index, Pointee(SizeIs(4)));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(0));
}
TEST_F(LiteIndexTest,
LiteIndexFetchHits_sortAtIndexing_unsortedHitsBelowSortThreshold) {
// Set up LiteIndex and TermIdCodec
std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index";
- // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs.
+ // The unsorted tail can contain a max of 8 TermIdHitPairs.
// However note that in these tests we're unable to sort hits after
// indexing, as sorting performed by the string-section-indexing-handler
// after indexing all hits in an entire document, rather than after each
// AddHits() operation.
- LiteIndex::Options options(lite_index_file_name,
- /*hit_buffer_want_merge_bytes=*/1024 * 1024,
- /*hit_buffer_sort_at_indexing=*/true,
- /*hit_buffer_sort_threshold_bytes=*/64);
+ LiteIndex::Options options(
+ lite_index_file_name,
+ /*hit_buffer_want_merge_bytes=*/1024 * 1024,
+ /*hit_buffer_sort_at_indexing=*/true,
+ /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8);
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index,
LiteIndex::Create(options, &icing_filesystem_));
ICING_ASSERT_OK_AND_ASSIGN(
@@ -185,9 +188,9 @@ TEST_F(LiteIndexTest,
ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id,
term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE));
Hit foo_hit0(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit foo_hit1(/*section_id=*/1, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, foo_hit0));
ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, foo_hit1));
@@ -197,24 +200,18 @@ TEST_F(LiteIndexTest,
ICING_ASSERT_OK_AND_ASSIGN(uint32_t bar_term_id,
term_id_codec_->EncodeTvi(bar_tvi, TviType::LITE));
Hit bar_hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit bar_hit1(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index->AddHit(bar_term_id, bar_hit0));
ICING_ASSERT_OK(lite_index->AddHit(bar_term_id, bar_hit1));
+ // Check the total size and unsorted size of the hit buffer.
+ EXPECT_THAT(lite_index, Pointee(SizeIs(4)));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(4));
// Check that unsorted hits does not exceed the sort threshold.
EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsFalse());
- // Check that hits are unsorted. Persist the data and pread from
- // LiteIndexHeader.
- ASSERT_THAT(lite_index->PersistToDisk(), IsOk());
- LiteIndex_HeaderImpl::HeaderData header_data;
- ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(),
- &header_data, sizeof(header_data),
- LiteIndex::kHeaderFileOffset));
- EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(4));
-
// Query the LiteIndex
std::vector<DocHitInfo> hits1;
lite_index->FetchHits(
@@ -238,15 +235,11 @@ TEST_F(LiteIndexTest,
// checker.
EXPECT_THAT(hits2, IsEmpty());
- // Check that hits are still unsorted after querying LiteIndex because the
- // HitBuffer unsorted size is still below the sort threshold, and we've
- // enabled sort_at_indexing.
- // Persist the data and performing a pread on LiteIndexHeader.
- ASSERT_THAT(lite_index->PersistToDisk(), IsOk());
- ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(),
- &header_data, sizeof(header_data),
- LiteIndex::kHeaderFileOffset));
- EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(4));
+ // Check the total size and unsorted size of the hit buffer. Hits should be
+ // still unsorted after querying LiteIndex because the HitBuffer unsorted size
+ // is still below the sort threshold, and we've enabled sort_at_indexing.
+ EXPECT_THAT(lite_index, Pointee(SizeIs(4)));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(4));
}
TEST_F(
@@ -254,15 +247,16 @@ TEST_F(
LiteIndexFetchHits_sortAtQuerying_unsortedHitsExceedingSortAtIndexThreshold) {
// Set up LiteIndex and TermIdCodec
std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index";
- // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs.
+ // The unsorted tail can contain a max of 8 TermIdHitPairs.
// However note that in these tests we're unable to sort hits after
// indexing, as sorting performed by the string-section-indexing-handler
// after indexing all hits in an entire document, rather than after each
// AddHits() operation.
- LiteIndex::Options options(lite_index_file_name,
- /*hit_buffer_want_merge_bytes=*/1024 * 1024,
- /*hit_buffer_sort_at_indexing=*/false,
- /*hit_buffer_sort_threshold_bytes=*/64);
+ LiteIndex::Options options(
+ lite_index_file_name,
+ /*hit_buffer_want_merge_bytes=*/1024 * 1024,
+ /*hit_buffer_sort_at_indexing=*/false,
+ /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8);
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index,
LiteIndex::Create(options, &icing_filesystem_));
ICING_ASSERT_OK_AND_ASSIGN(
@@ -274,36 +268,36 @@ TEST_F(
// Create 4 hits for docs 0-2, and 2 hits for doc 3 -- 14 in total
// Doc 0
Hit doc0_hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc0_hit1(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc0_hit2(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc0_hit3(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
// Doc 1
Hit doc1_hit0(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc1_hit1(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc1_hit2(/*section_id=*/1, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc1_hit3(/*section_id=*/2, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
// Doc 2
Hit doc2_hit0(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc2_hit1(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc2_hit2(/*section_id=*/1, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc2_hit3(/*section_id=*/2, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
// Doc 3
Hit doc3_hit0(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc3_hit1(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
// Create terms
// Foo
@@ -348,7 +342,10 @@ TEST_F(
ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, doc2_hit3));
ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, doc3_hit0));
ICING_ASSERT_OK(lite_index->AddHit(baz_term_id, doc3_hit1));
- // Verify that the HitBuffer has not been sorted.
+ // Check the total size and unsorted size of the hit buffer. The HitBuffer has
+ // not been sorted.
+ EXPECT_THAT(lite_index, Pointee(SizeIs(14)));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(14));
EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsTrue());
// We now have the following in the hit buffer:
@@ -400,7 +397,10 @@ TEST_F(
EXPECT_THAT(hits3[1].document_id(), Eq(0));
EXPECT_THAT(hits3[1].hit_section_ids_mask(), Eq(0b1));
- // Check that the HitBuffer is sorted after the query call.
+ // Check the total size and unsorted size of the hit buffer. The HitBuffer
+ // should be sorted after the query call.
+ EXPECT_THAT(lite_index, Pointee(SizeIs(14)));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(0));
EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsFalse());
}
@@ -409,11 +409,12 @@ TEST_F(
LiteIndexFetchHits_sortAtIndexing_unsortedHitsExceedingSortAtIndexThreshold) {
// Set up LiteIndex and TermIdCodec
std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index";
- // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs.
- LiteIndex::Options options(lite_index_file_name,
- /*hit_buffer_want_merge_bytes=*/1024 * 1024,
- /*hit_buffer_sort_at_indexing=*/true,
- /*hit_buffer_sort_threshold_bytes=*/64);
+ // The unsorted tail can contain a max of 8 TermIdHitPairs.
+ LiteIndex::Options options(
+ lite_index_file_name,
+ /*hit_buffer_want_merge_bytes=*/1024 * 1024,
+ /*hit_buffer_sort_at_indexing=*/true,
+ /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8);
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index,
LiteIndex::Create(options, &icing_filesystem_));
ICING_ASSERT_OK_AND_ASSIGN(
@@ -425,49 +426,49 @@ TEST_F(
// Create 4 hits for docs 0-2, and 2 hits for doc 3 -- 14 in total
// Doc 0
Hit doc0_hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc0_hit1(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc0_hit2(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc0_hit3(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
// Doc 1
Hit doc1_hit0(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc1_hit1(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc1_hit2(/*section_id=*/1, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc1_hit3(/*section_id=*/2, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
// Doc 2
Hit doc2_hit0(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc2_hit1(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc2_hit2(/*section_id=*/1, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc2_hit3(/*section_id=*/2, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
// Doc 3
Hit doc3_hit0(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc3_hit1(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc3_hit2(/*section_id=*/1, /*document_id=*/3, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc3_hit3(/*section_id=*/2, /*document_id=*/3, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
// Doc 4
Hit doc4_hit0(/*section_id=*/0, /*document_id=*/4, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc4_hit1(/*section_id=*/0, /*document_id=*/4, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc4_hit2(/*section_id=*/1, /*document_id=*/4, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc4_hit3(/*section_id=*/2, /*document_id=*/4, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
// Create terms
// Foo
@@ -511,13 +512,11 @@ TEST_F(
// AddHit() itself, we need to invoke SortHits() manually.
EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsTrue());
lite_index->SortHits();
- // Check that the HitBuffer is sorted.
- ASSERT_THAT(lite_index->PersistToDisk(), IsOk());
- LiteIndex_HeaderImpl::HeaderData header_data;
- ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(),
- &header_data, sizeof(header_data),
- LiteIndex::kHeaderFileOffset));
- EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(0));
+ // Check the total size and unsorted size of the hit buffer. The HitBuffer
+ // should be sorted after calling SortHits().
+ EXPECT_THAT(lite_index, Pointee(SizeIs(8)));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(0));
+ EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsFalse());
// Add 12 more hits so that sort threshold is exceeded again.
ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, doc2_hit0));
@@ -536,6 +535,8 @@ TEST_F(
// Adding these hits exceeds the sort threshold. However when sort_at_indexing
// is enabled, sorting is done in the string-section-indexing-handler rather
// than AddHit() itself.
+ EXPECT_THAT(lite_index, Pointee(SizeIs(20)));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(12));
EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsTrue());
// We now have the following in the hit buffer:
@@ -589,25 +590,24 @@ TEST_F(
EXPECT_THAT(hits3[1].document_id(), Eq(0));
EXPECT_THAT(hits3[1].hit_section_ids_mask(), Eq(0b1));
- // Check that the HitBuffer is sorted after the query call. FetchHits should
- // sort before performing binary search if the HitBuffer unsorted size exceeds
- // the sort threshold. Regardless of the sort_at_indexing config.
+ // Check the total size and unsorted size of the hit buffer. FetchHits should
+ // sort before performing search if the HitBuffer unsorted size exceeds the
+ // sort threshold, regardless of the sort_at_indexing config (to avoid
+ // sequential search on an extremely long unsorted tails).
+ EXPECT_THAT(lite_index, Pointee(SizeIs(20)));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(0));
EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsFalse());
- ASSERT_THAT(lite_index->PersistToDisk(), IsOk());
- ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(),
- &header_data, sizeof(header_data),
- LiteIndex::kHeaderFileOffset));
- EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(0));
}
TEST_F(LiteIndexTest, LiteIndexIterator) {
// Set up LiteIndex and TermIdCodec
std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index";
- // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs.
- LiteIndex::Options options(lite_index_file_name,
- /*hit_buffer_want_merge_bytes=*/1024 * 1024,
- /*hit_buffer_sort_at_indexing=*/true,
- /*hit_buffer_sort_threshold_bytes=*/64);
+ // The unsorted tail can contain a max of 8 TermIdHitPairs.
+ LiteIndex::Options options(
+ lite_index_file_name,
+ /*hit_buffer_want_merge_bytes=*/1024 * 1024,
+ /*hit_buffer_sort_at_indexing=*/true,
+ /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8);
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index,
LiteIndex::Create(options, &icing_filesystem_));
ICING_ASSERT_OK_AND_ASSIGN(
@@ -623,9 +623,9 @@ TEST_F(LiteIndexTest, LiteIndexIterator) {
ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id,
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc0_hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/3,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc0_hit1(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/5,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
SectionIdMask doc0_section_id_mask = 0b11;
std::unordered_map<SectionId, Hit::TermFrequency>
expected_section_ids_tf_map0 = {{0, 3}, {1, 5}};
@@ -633,9 +633,9 @@ TEST_F(LiteIndexTest, LiteIndexIterator) {
ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, doc0_hit1));
Hit doc1_hit1(/*section_id=*/1, /*document_id=*/1, /*term_frequency=*/7,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc1_hit2(/*section_id=*/2, /*document_id=*/1, /*term_frequency=*/11,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
SectionIdMask doc1_section_id_mask = 0b110;
std::unordered_map<SectionId, Hit::TermFrequency>
expected_section_ids_tf_map1 = {{1, 7}, {2, 11}};
@@ -671,11 +671,12 @@ TEST_F(LiteIndexTest, LiteIndexIterator) {
TEST_F(LiteIndexTest, LiteIndexIterator_sortAtIndexingDisabled) {
// Set up LiteIndex and TermIdCodec
std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index";
- // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs.
- LiteIndex::Options options(lite_index_file_name,
- /*hit_buffer_want_merge_bytes=*/1024 * 1024,
- /*hit_buffer_sort_at_indexing=*/false,
- /*hit_buffer_sort_threshold_bytes=*/64);
+ // The unsorted tail can contain a max of 8 TermIdHitPairs.
+ LiteIndex::Options options(
+ lite_index_file_name,
+ /*hit_buffer_want_merge_bytes=*/1024 * 1024,
+ /*hit_buffer_sort_at_indexing=*/false,
+ /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8);
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index,
LiteIndex::Create(options, &icing_filesystem_));
ICING_ASSERT_OK_AND_ASSIGN(
@@ -691,9 +692,9 @@ TEST_F(LiteIndexTest, LiteIndexIterator_sortAtIndexingDisabled) {
ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id,
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc0_hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/3,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc0_hit1(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/5,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
SectionIdMask doc0_section_id_mask = 0b11;
std::unordered_map<SectionId, Hit::TermFrequency>
expected_section_ids_tf_map0 = {{0, 3}, {1, 5}};
@@ -701,9 +702,9 @@ TEST_F(LiteIndexTest, LiteIndexIterator_sortAtIndexingDisabled) {
ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, doc0_hit1));
Hit doc1_hit1(/*section_id=*/1, /*document_id=*/1, /*term_frequency=*/7,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit doc1_hit2(/*section_id=*/2, /*document_id=*/1, /*term_frequency=*/11,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
SectionIdMask doc1_section_id_mask = 0b110;
std::unordered_map<SectionId, Hit::TermFrequency>
expected_section_ids_tf_map1 = {{1, 7}, {2, 11}};
@@ -736,6 +737,54 @@ TEST_F(LiteIndexTest, LiteIndexIterator_sortAtIndexingDisabled) {
term, expected_section_ids_tf_map0)));
}
+TEST_F(LiteIndexTest, LiteIndexHitBufferSize) {
+ // Set up LiteIndex and TermIdCodec
+ std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index";
+ // The unsorted tail can contain a max of 8 TermIdHitPairs.
+ LiteIndex::Options options(
+ lite_index_file_name,
+ /*hit_buffer_want_merge_bytes=*/1024 * 1024,
+ /*hit_buffer_sort_at_indexing=*/true,
+ /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8);
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index,
+ LiteIndex::Create(options, &icing_filesystem_));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ term_id_codec_,
+ TermIdCodec::Create(
+ IcingDynamicTrie::max_value_index(IcingDynamicTrie::Options()),
+ IcingDynamicTrie::max_value_index(options.lexicon_options)));
+
+ const std::string term = "foo";
+ ICING_ASSERT_OK_AND_ASSIGN(
+ uint32_t tvi,
+ lite_index->InsertTerm(term, TermMatchType::PREFIX, kNamespace0));
+ ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id,
+ term_id_codec_->EncodeTvi(tvi, TviType::LITE));
+ Hit hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/3,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ Hit hit1(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/5,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, hit0));
+ ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, hit1));
+
+ // Check the total size and byte size of the hit buffer.
+ EXPECT_THAT(lite_index, Pointee(SizeIs(2)));
+ EXPECT_THAT(lite_index->GetHitBufferByteSize(),
+ Eq(2 * sizeof(TermIdHitPair::Value)));
+ // Check the unsorted size and byte size of the hit buffer.
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(2));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedByteSize(),
+ Eq(2 * sizeof(TermIdHitPair::Value)));
+
+ // Sort the hit buffer and check again.
+ lite_index->SortHits();
+ EXPECT_THAT(lite_index, Pointee(SizeIs(2)));
+ EXPECT_THAT(lite_index->GetHitBufferByteSize(),
+ Eq(2 * sizeof(TermIdHitPair::Value)));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(0));
+ EXPECT_THAT(lite_index->GetHitBufferUnsortedByteSize(), Eq(0));
+}
+
} // namespace
} // namespace lib
} // namespace icing
diff --git a/icing/index/lite/lite-index_thread-safety_test.cc b/icing/index/lite/lite-index_thread-safety_test.cc
index 53aa6cd..a73ca28 100644
--- a/icing/index/lite/lite-index_thread-safety_test.cc
+++ b/icing/index/lite/lite-index_thread-safety_test.cc
@@ -13,15 +13,25 @@
// limitations under the License.
#include <array>
+#include <cstdint>
+#include <memory>
#include <string>
+#include <string_view>
#include <thread>
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "icing/file/filesystem.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/hit/hit.h"
#include "icing/index/lite/lite-index.h"
#include "icing/index/term-id-codec.h"
+#include "icing/legacy/index/icing-dynamic-trie.h"
+#include "icing/legacy/index/icing-filesystem.h"
#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/store/namespace-id.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/tmp-directory.h"
@@ -109,9 +119,11 @@ TEST_F(LiteIndexThreadSafetyTest, SimultaneousFetchHits_singleTerm) {
ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id,
term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE));
Hit doc_hit0(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId0,
- Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
Hit doc_hit1(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId1,
- Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit0));
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit1));
@@ -155,9 +167,11 @@ TEST_F(LiteIndexThreadSafetyTest, SimultaneousFetchHits_multipleTerms) {
ICING_ASSERT_OK_AND_ASSIGN(uint32_t term_id,
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc_hit0(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId0,
- Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
Hit doc_hit1(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId1,
- Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit0));
ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit1));
}
@@ -208,7 +222,8 @@ TEST_F(LiteIndexThreadSafetyTest, SimultaneousAddHitAndFetchHits_singleTerm) {
ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id,
term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE));
Hit doc_hit0(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId0,
- Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit0));
// Create kNumThreads threads. Every even-numbered thread calls FetchHits and
@@ -228,7 +243,8 @@ TEST_F(LiteIndexThreadSafetyTest, SimultaneousAddHitAndFetchHits_singleTerm) {
} else {
// Odd-numbered thread calls AddHit.
Hit doc_hit(/*section_id=*/thread_id / 2, /*document_id=*/kDocumentId0,
- Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit));
}
};
@@ -273,7 +289,8 @@ TEST_F(LiteIndexThreadSafetyTest,
ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id,
term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE));
Hit doc_hit0(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId0,
- Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit0));
// Create kNumThreads threads. Every even-numbered thread calls FetchHits and
@@ -302,7 +319,8 @@ TEST_F(LiteIndexThreadSafetyTest,
// Odd-numbered thread calls AddHit.
// AddHit to section 0 of a new doc.
Hit doc_hit(/*section_id=*/kSectionId0, /*document_id=*/thread_id / 2,
- Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit));
}
};
@@ -335,9 +353,11 @@ TEST_F(LiteIndexThreadSafetyTest, ManyAddHitAndOneFetchHits_multipleTerms) {
ICING_ASSERT_OK_AND_ASSIGN(uint32_t term_id,
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc_hit0(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId0,
- Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
Hit doc_hit1(/*section_id=*/kSectionId1, /*document_id=*/kDocumentId0,
- Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit0));
ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit1));
}
@@ -370,7 +390,7 @@ TEST_F(LiteIndexThreadSafetyTest, ManyAddHitAndOneFetchHits_multipleTerms) {
// AddHit to section (thread_id % 5 + 1) of doc 0.
Hit doc_hit(/*section_id=*/thread_id % 5 + 1,
/*document_id=*/kDocumentId0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit));
}
};
diff --git a/icing/index/lite/term-id-hit-pair.h b/icing/index/lite/term-id-hit-pair.h
index 82bd010..760aa76 100644
--- a/icing/index/lite/term-id-hit-pair.h
+++ b/icing/index/lite/term-id-hit-pair.h
@@ -15,25 +15,22 @@
#ifndef ICING_INDEX_TERM_ID_HIT_PAIR_H_
#define ICING_INDEX_TERM_ID_HIT_PAIR_H_
+#include <array>
#include <cstdint>
-#include <limits>
-#include <memory>
-#include <string>
-#include <vector>
#include "icing/index/hit/hit.h"
-#include "icing/util/bit-util.h"
namespace icing {
namespace lib {
class TermIdHitPair {
public:
- // Layout bits: 24 termid + 32 hit value + 8 hit term frequency.
- using Value = uint64_t;
+ // Layout bits: 24 termid + 32 hit value + 8 hit flags + 8 hit term frequency.
+ using Value = std::array<uint8_t, 9>;
static constexpr int kTermIdBits = 24;
static constexpr int kHitValueBits = sizeof(Hit::Value) * 8;
+ static constexpr int kHitFlagsBits = sizeof(Hit::Flags) * 8;
static constexpr int kHitTermFrequencyBits = sizeof(Hit::TermFrequency) * 8;
static const Value kInvalidValue;
@@ -41,33 +38,48 @@ class TermIdHitPair {
explicit TermIdHitPair(Value v = kInvalidValue) : value_(v) {}
TermIdHitPair(uint32_t term_id, const Hit& hit) {
- static_assert(kTermIdBits + kHitValueBits + kHitTermFrequencyBits <=
- sizeof(Value) * 8,
- "TermIdHitPairTooBig");
-
- value_ = 0;
- // Term id goes into the most significant bits because it takes
- // precedent in sorts.
- bit_util::BitfieldSet(term_id, kHitValueBits + kHitTermFrequencyBits,
- kTermIdBits, &value_);
- bit_util::BitfieldSet(hit.value(), kHitTermFrequencyBits, kHitValueBits,
- &value_);
- bit_util::BitfieldSet(hit.term_frequency(), 0, kHitTermFrequencyBits,
- &value_);
+ static_assert(
+ kTermIdBits + kHitValueBits + kHitFlagsBits + kHitTermFrequencyBits <=
+ sizeof(Value) * 8,
+ "TermIdHitPairTooBig");
+
+ // Set termId. Term id takes 3 bytes and goes into value_[0:2] (most
+ // significant bits) because it takes precedent in sorts.
+ value_[0] = static_cast<uint8_t>((term_id >> 16) & 0xff);
+ value_[1] = static_cast<uint8_t>((term_id >> 8) & 0xff);
+ value_[2] = static_cast<uint8_t>((term_id >> 0) & 0xff);
+
+ // Set hit value. Hit value takes 4 bytes and goes into value_[3:6]
+ value_[3] = static_cast<uint8_t>((hit.value() >> 24) & 0xff);
+ value_[4] = static_cast<uint8_t>((hit.value() >> 16) & 0xff);
+ value_[5] = static_cast<uint8_t>((hit.value() >> 8) & 0xff);
+ value_[6] = static_cast<uint8_t>((hit.value() >> 0) & 0xff);
+
+ // Set flags in value_[7].
+ value_[7] = hit.flags();
+
+ // Set term-frequency in value_[8]
+ value_[8] = hit.term_frequency();
}
uint32_t term_id() const {
- return bit_util::BitfieldGet(value_, kHitValueBits + kHitTermFrequencyBits,
- kTermIdBits);
+ return (static_cast<uint32_t>(value_[0]) << 16) |
+ (static_cast<uint32_t>(value_[1]) << 8) |
+ (static_cast<uint32_t>(value_[2]) << 0);
}
Hit hit() const {
- return Hit(
- bit_util::BitfieldGet(value_, kHitTermFrequencyBits, kHitValueBits),
- bit_util::BitfieldGet(value_, 0, kHitTermFrequencyBits));
+ Hit::Value hit_value = (static_cast<uint32_t>(value_[3]) << 24) |
+ (static_cast<uint32_t>(value_[4]) << 16) |
+ (static_cast<uint32_t>(value_[5]) << 8) |
+ (static_cast<uint32_t>(value_[6]) << 0);
+ Hit::Flags hit_flags = value_[7];
+ Hit::TermFrequency term_frequency = value_[8];
+
+ return Hit(hit_value, hit_flags, term_frequency);
}
- Value value() const { return value_; }
+ const Value& value() const { return value_; }
bool operator==(const TermIdHitPair& rhs) const {
return value_ == rhs.value_;
diff --git a/icing/index/lite/term-id-hit-pair_test.cc b/icing/index/lite/term-id-hit-pair_test.cc
new file mode 100644
index 0000000..28855b4
--- /dev/null
+++ b/icing/index/lite/term-id-hit-pair_test.cc
@@ -0,0 +1,95 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/lite/term-id-hit-pair.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/index/hit/hit.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+
+namespace icing {
+namespace lib {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+
+static constexpr DocumentId kSomeDocumentId = 24;
+static constexpr SectionId kSomeSectionid = 5;
+static constexpr Hit::TermFrequency kSomeTermFrequency = 57;
+static constexpr uint32_t kSomeTermId = 129;
+static constexpr uint32_t kSomeSmallerTermId = 1;
+static constexpr uint32_t kSomeLargerTermId = 0b101010101111111100000001;
+
+TEST(TermIdHitPairTest, Accessors) {
+ Hit hit1(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ Hit hit2(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/true);
+ Hit hit3(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ Hit invalid_hit(Hit::kInvalidValue);
+
+ TermIdHitPair term_id_hit_pair_1(kSomeTermId, hit1);
+ EXPECT_THAT(term_id_hit_pair_1.term_id(), Eq(kSomeTermId));
+ EXPECT_THAT(term_id_hit_pair_1.hit(), Eq(hit1));
+
+ TermIdHitPair term_id_hit_pair_2(kSomeLargerTermId, hit2);
+ EXPECT_THAT(term_id_hit_pair_2.term_id(), Eq(kSomeLargerTermId));
+ EXPECT_THAT(term_id_hit_pair_2.hit(), Eq(hit2));
+
+ TermIdHitPair term_id_hit_pair_3(kSomeTermId, invalid_hit);
+ EXPECT_THAT(term_id_hit_pair_3.term_id(), Eq(kSomeTermId));
+ EXPECT_THAT(term_id_hit_pair_3.hit(), Eq(invalid_hit));
+}
+
+TEST(TermIdHitPairTest, Comparison) {
+ Hit hit(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ Hit smaller_hit(/*section_id=*/1, /*document_id=*/100, /*term_frequency=*/1,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+
+ TermIdHitPair term_id_hit_pair(kSomeTermId, hit);
+ TermIdHitPair term_id_hit_pair_equal(kSomeTermId, hit);
+ TermIdHitPair term_id_hit_pair_smaller_hit(kSomeTermId, smaller_hit);
+ TermIdHitPair term_id_hit_pair_smaller_term_id(kSomeSmallerTermId, hit);
+ TermIdHitPair term_id_hit_pair_larger_term_id(kSomeLargerTermId, hit);
+ TermIdHitPair term_id_hit_pair_smaller_term_id_and_hit(kSomeSmallerTermId,
+ smaller_hit);
+
+ std::vector<TermIdHitPair> term_id_hit_pairs{
+ term_id_hit_pair,
+ term_id_hit_pair_equal,
+ term_id_hit_pair_smaller_hit,
+ term_id_hit_pair_smaller_term_id,
+ term_id_hit_pair_larger_term_id,
+ term_id_hit_pair_smaller_term_id_and_hit};
+ std::sort(term_id_hit_pairs.begin(), term_id_hit_pairs.end());
+ EXPECT_THAT(term_id_hit_pairs,
+ ElementsAre(term_id_hit_pair_smaller_term_id_and_hit,
+ term_id_hit_pair_smaller_term_id,
+ term_id_hit_pair_smaller_hit, term_id_hit_pair_equal,
+ term_id_hit_pair, term_id_hit_pair_larger_term_id));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/main/main-index-merger.cc b/icing/index/main/main-index-merger.cc
index c26a6d7..cc130c2 100644
--- a/icing/index/main/main-index-merger.cc
+++ b/icing/index/main/main-index-merger.cc
@@ -14,14 +14,20 @@
#include "icing/index/main/main-index-merger.h"
+#include <algorithm>
#include <cstdint>
#include <cstring>
-#include <memory>
#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
-#include "icing/file/posting_list/index-block.h"
+#include "icing/file/posting_list/posting-list-common.h"
+#include "icing/index/hit/hit.h"
+#include "icing/index/lite/lite-index.h"
#include "icing/index/lite/term-id-hit-pair.h"
+#include "icing/index/main/main-index.h"
#include "icing/index/term-id-codec.h"
#include "icing/legacy/core/icing-string-util.h"
#include "icing/util/logging.h"
diff --git a/icing/index/main/main-index-merger_test.cc b/icing/index/main/main-index-merger_test.cc
index 37e14fc..333e338 100644
--- a/icing/index/main/main-index-merger_test.cc
+++ b/icing/index/main/main-index-merger_test.cc
@@ -13,19 +13,23 @@
// limitations under the License.
#include "icing/index/main/main-index-merger.h"
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-#include "icing/absl_ports/canonical_errors.h"
#include "icing/file/filesystem.h"
-#include "icing/index/iterator/doc-hit-info-iterator.h"
-#include "icing/index/main/doc-hit-info-iterator-term-main.h"
-#include "icing/index/main/main-index-merger.h"
+#include "icing/index/hit/hit.h"
+#include "icing/index/lite/lite-index.h"
+#include "icing/index/lite/term-id-hit-pair.h"
#include "icing/index/main/main-index.h"
#include "icing/index/term-id-codec.h"
-#include "icing/index/term-property-id.h"
#include "icing/legacy/index/icing-dynamic-trie.h"
#include "icing/legacy/index/icing-filesystem.h"
-#include "icing/schema/section.h"
#include "icing/store/namespace-id.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/tmp-directory.h"
@@ -89,10 +93,10 @@ TEST_F(MainIndexMergerTest, TranslateTermNotAdded) {
term_id_codec_->EncodeTvi(fool_tvi, TviType::LITE));
Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/57,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit));
Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc1_hit));
// 2. Build up a fake LexiconMergeOutputs
@@ -128,10 +132,10 @@ TEST_F(MainIndexMergerTest, PrefixExpansion) {
term_id_codec_->EncodeTvi(fool_tvi, TviType::LITE));
Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/57,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit));
Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc1_hit));
// 2. Build up a fake LexiconMergeOutputs
@@ -191,11 +195,11 @@ TEST_F(MainIndexMergerTest, DedupePrefixAndExactWithDifferentTermFrequencies) {
term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE));
Hit foot_doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/57,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, foot_doc0_hit));
Hit foo_doc0_hit(/*section_id=*/0, /*document_id=*/0,
- Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/true,
+ /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, foo_doc0_hit));
// 2. Build up a fake LexiconMergeOutputs
@@ -255,10 +259,10 @@ TEST_F(MainIndexMergerTest, DedupeWithExactSameTermFrequencies) {
term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE));
Hit foot_doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/57,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, foot_doc0_hit));
Hit foo_doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/57,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, foo_doc0_hit));
// The prefix hit should take the sum as term_frequency - 114.
Hit prefix_foo_doc0_hit(/*section_id=*/0, /*document_id=*/0,
@@ -320,11 +324,11 @@ TEST_F(MainIndexMergerTest, DedupePrefixExpansion) {
Hit foot_doc0_hit(/*section_id=*/0, /*document_id=*/0,
/*term_frequency=*/Hit::kMaxTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, foot_doc0_hit));
Hit fool_doc0_hit(/*section_id=*/0, /*document_id=*/0,
Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, fool_doc0_hit));
// 2. Build up a fake LexiconMergeOutputs
diff --git a/icing/index/main/main-index.cc b/icing/index/main/main-index.cc
index aae60c6..85ee4dc 100644
--- a/icing/index/main/main-index.cc
+++ b/icing/index/main/main-index.cc
@@ -651,7 +651,7 @@ libtextclassifier3::Status MainIndex::AddPrefixBackfillHits(
ICING_ASSIGN_OR_RETURN(tmp, backfill_accessor->GetNextHitsBatch());
}
- Hit last_added_hit;
+ Hit last_added_hit(Hit::kInvalidValue);
// The hits in backfill_hits are in the reverse order of how they were added.
// Iterate in reverse to add them to this new posting list in the correct
// order.
diff --git a/icing/index/main/main-index_test.cc b/icing/index/main/main-index_test.cc
index fa96e6c..db9dbe2 100644
--- a/icing/index/main/main-index_test.cc
+++ b/icing/index/main/main-index_test.cc
@@ -14,23 +14,33 @@
#include "icing/index/main/main-index.h"
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-#include "icing/absl_ports/canonical_errors.h"
#include "icing/file/filesystem.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/hit/hit.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/lite/lite-index.h"
#include "icing/index/lite/term-id-hit-pair.h"
#include "icing/index/main/doc-hit-info-iterator-term-main.h"
#include "icing/index/main/main-index-merger.h"
#include "icing/index/term-id-codec.h"
-#include "icing/index/term-property-id.h"
#include "icing/legacy/index/icing-dynamic-trie.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/legacy/index/icing-mock-filesystem.h"
#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
#include "icing/store/namespace-id.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/tmp-directory.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -152,7 +162,7 @@ TEST_F(MainIndexTest, MainIndexGetAccessorForPrefixReturnsValidAccessor) {
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit));
// 2. Create the main index. It should have no entries in its lexicon.
@@ -178,7 +188,7 @@ TEST_F(MainIndexTest, MainIndexGetAccessorForPrefixReturnsNotFound) {
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit));
// 2. Create the main index. It should have no entries in its lexicon.
@@ -217,7 +227,7 @@ TEST_F(MainIndexTest, MainIndexGetAccessorForExactReturnsValidAccessor) {
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit));
// 2. Create the main index. It should have no entries in its lexicon.
@@ -254,18 +264,18 @@ TEST_F(MainIndexTest, MergeIndexToEmpty) {
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit));
ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc0_hit));
ICING_ASSERT_OK(lite_index_->AddHit(far_term_id, doc0_hit));
Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc1_hit));
ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc1_hit));
Hit doc2_hit(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc2_hit));
ICING_ASSERT_OK(lite_index_->AddHit(far_term_id, doc2_hit));
@@ -332,18 +342,18 @@ TEST_F(MainIndexTest, MergeIndexToPreexisting) {
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit));
ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc0_hit));
ICING_ASSERT_OK(lite_index_->AddHit(far_term_id, doc0_hit));
Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc1_hit));
ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc1_hit));
Hit doc2_hit(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc2_hit));
ICING_ASSERT_OK(lite_index_->AddHit(far_term_id, doc2_hit));
@@ -387,14 +397,14 @@ TEST_F(MainIndexTest, MergeIndexToPreexisting) {
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc3_hit(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc3_hit));
ICING_ASSERT_OK(lite_index_->AddHit(four_term_id, doc3_hit));
ICING_ASSERT_OK(lite_index_->AddHit(foul_term_id, doc3_hit));
ICING_ASSERT_OK(lite_index_->AddHit(fall_term_id, doc3_hit));
Hit doc4_hit(/*section_id=*/0, /*document_id=*/4, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(four_term_id, doc4_hit));
ICING_ASSERT_OK(lite_index_->AddHit(foul_term_id, doc4_hit));
@@ -449,15 +459,15 @@ TEST_F(MainIndexTest, ExactRetrievedInPrefixSearch) {
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit));
Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc1_hit));
Hit doc2_hit(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc2_hit));
// 2. Create the main index. It should have no entries in its lexicon.
@@ -500,15 +510,15 @@ TEST_F(MainIndexTest, PrefixNotRetrievedInExactSearch) {
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit));
Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc1_hit));
Hit doc2_hit(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc2_hit));
// 2. Create the main index. It should have no entries in its lexicon.
@@ -554,17 +564,17 @@ TEST_F(MainIndexTest,
document_id % Hit::kMaxTermFrequency + 1);
Hit doc_hit0(
/*section_id=*/0, /*document_id=*/document_id, term_frequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc_hit0));
Hit doc_hit1(
/*section_id=*/1, /*document_id=*/document_id, term_frequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc_hit1));
Hit doc_hit2(
/*section_id=*/2, /*document_id=*/document_id, term_frequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc_hit2));
}
@@ -619,7 +629,7 @@ TEST_F(MainIndexTest, MergeIndexBackfilling) {
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/true);
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc0_hit));
// 2. Create the main index. It should have no entries in its lexicon.
@@ -648,7 +658,7 @@ TEST_F(MainIndexTest, MergeIndexBackfilling) {
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc1_hit));
// 5. Merge the index. The main index should now contain "fool", "foot"
@@ -682,7 +692,7 @@ TEST_F(MainIndexTest, OneHitInTheFirstPageForTwoPagesMainIndex) {
uint32_t num_docs = 2038;
for (DocumentId document_id = 0; document_id < num_docs; ++document_id) {
Hit doc_hit(section_id, document_id, Hit::kDefaultTermFrequency,
- /*is_in_prefix_section=*/false);
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit));
}
diff --git a/icing/index/main/posting-list-hit-accessor_test.cc b/icing/index/main/posting-list-hit-accessor_test.cc
index 1127814..c2460ff 100644
--- a/icing/index/main/posting-list-hit-accessor_test.cc
+++ b/icing/index/main/posting-list-hit-accessor_test.cc
@@ -15,14 +15,19 @@
#include "icing/index/main/posting-list-hit-accessor.h"
#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/file/filesystem.h"
#include "icing/file/posting_list/flash-index-storage.h"
-#include "icing/file/posting_list/index-block.h"
+#include "icing/file/posting_list/posting-list-accessor.h"
+#include "icing/file/posting_list/posting-list-common.h"
#include "icing/file/posting_list/posting-list-identifier.h"
-#include "icing/file/posting_list/posting-list-used.h"
#include "icing/index/hit/hit.h"
#include "icing/index/main/posting-list-hit-serializer.h"
#include "icing/testing/common-matchers.h"
@@ -102,7 +107,8 @@ TEST_F(PostingListHitAccessorTest, PreexistingPLKeepOnSameBlock) {
PostingListHitAccessor::Create(flash_index_storage_.get(),
serializer_.get()));
// Add a single hit. This will fit in a min-sized posting list.
- Hit hit1(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency);
+ Hit hit1(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
PostingListAccessor::FinalizeResult result1 =
std::move(*pl_accessor).Finalize();
@@ -139,12 +145,12 @@ TEST_F(PostingListHitAccessorTest, PreexistingPLReallocateToLargerPL) {
std::unique_ptr<PostingListHitAccessor> pl_accessor,
PostingListHitAccessor::Create(flash_index_storage_.get(),
serializer_.get()));
- // The smallest posting list size is 15 bytes. The first four hits will be
- // compressed to one byte each and will be able to fit in the 5 byte padded
- // region. The last hit will fit in one of the special hits. The posting list
- // will be ALMOST_FULL and can fit at most 2 more hits.
+ // Use a small posting list of 30 bytes. The first 17 hits will be compressed
+ // to one byte each and will be able to fit in the 18 byte padded region. The
+ // last hit will fit in one of the special hits. The posting list will be
+ // ALMOST_FULL and can fit at most 2 more hits.
std::vector<Hit> hits1 =
- CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1);
+ CreateHits(/*num_hits=*/18, /*desired_byte_length=*/1);
for (const Hit& hit : hits1) {
ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
}
@@ -160,10 +166,9 @@ TEST_F(PostingListHitAccessorTest, PreexistingPLReallocateToLargerPL) {
pl_accessor,
PostingListHitAccessor::CreateFromExisting(
flash_index_storage_.get(), serializer_.get(), result1.id));
- // The current posting list can fit at most 2 more hits. Adding 12 more hits
- // should result in these hits being moved to a larger posting list.
+ // The current posting list can fit at most 2 more hits.
std::vector<Hit> hits2 = CreateHits(
- /*start_docid=*/hits1.back().document_id() + 1, /*num_hits=*/12,
+ /*last_hit=*/hits1.back(), /*num_hits=*/2,
/*desired_byte_length=*/1);
for (const Hit& hit : hits2) {
@@ -172,18 +177,36 @@ TEST_F(PostingListHitAccessorTest, PreexistingPLReallocateToLargerPL) {
PostingListAccessor::FinalizeResult result2 =
std::move(*pl_accessor).Finalize();
ICING_EXPECT_OK(result2.status);
+ // The 2 hits should still fit on the first block
+ EXPECT_THAT(result1.id.block_index(), Eq(1));
+ EXPECT_THAT(result1.id.posting_list_index(), Eq(0));
+
+ // Add one more hit
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result2.id));
+ // The current posting list should be FULL. Adding more hits should result in
+ // these hits being moved to a larger posting list.
+ Hit single_hit =
+ CreateHit(/*last_hit=*/hits2.back(), /*desired_byte_length=*/1);
+ ICING_ASSERT_OK(pl_accessor->PrependHit(single_hit));
+ PostingListAccessor::FinalizeResult result3 =
+ std::move(*pl_accessor).Finalize();
+ ICING_EXPECT_OK(result3.status);
// Should have been allocated to the second (new) block because the posting
// list should have grown beyond the size that the first block maintains.
- EXPECT_THAT(result2.id.block_index(), Eq(2));
- EXPECT_THAT(result2.id.posting_list_index(), Eq(0));
+ EXPECT_THAT(result3.id.block_index(), Eq(2));
+ EXPECT_THAT(result3.id.posting_list_index(), Eq(0));
- // The posting list at result2.id should hold all of the hits that have been
+ // The posting list at result3.id should hold all of the hits that have been
// added.
for (const Hit& hit : hits2) {
hits1.push_back(hit);
}
+ hits1.push_back(single_hit);
ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder,
- flash_index_storage_->GetPostingList(result2.id));
+ flash_index_storage_->GetPostingList(result3.id));
EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list),
IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
}
@@ -307,7 +330,7 @@ TEST_F(PostingListHitAccessorTest, InvalidHitReturnsInvalidArgument) {
std::unique_ptr<PostingListHitAccessor> pl_accessor,
PostingListHitAccessor::Create(flash_index_storage_.get(),
serializer_.get()));
- Hit invalid_hit;
+ Hit invalid_hit(Hit::kInvalidValue);
EXPECT_THAT(pl_accessor->PrependHit(invalid_hit),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
@@ -317,14 +340,17 @@ TEST_F(PostingListHitAccessorTest, HitsNotDecreasingReturnsInvalidArgument) {
std::unique_ptr<PostingListHitAccessor> pl_accessor,
PostingListHitAccessor::Create(flash_index_storage_.get(),
serializer_.get()));
- Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency);
+ Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
- Hit hit2(/*section_id=*/6, /*document_id=*/1, Hit::kDefaultTermFrequency);
+ Hit hit2(/*section_id=*/6, /*document_id=*/1, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
EXPECT_THAT(pl_accessor->PrependHit(hit2),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- Hit hit3(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency);
+ Hit hit3(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
EXPECT_THAT(pl_accessor->PrependHit(hit3),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
@@ -345,7 +371,8 @@ TEST_F(PostingListHitAccessorTest, PreexistingPostingListNoHitsAdded) {
std::unique_ptr<PostingListHitAccessor> pl_accessor,
PostingListHitAccessor::Create(flash_index_storage_.get(),
serializer_.get()));
- Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency);
+ Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
PostingListAccessor::FinalizeResult result1 =
std::move(*pl_accessor).Finalize();
diff --git a/icing/index/main/posting-list-hit-serializer.cc b/icing/index/main/posting-list-hit-serializer.cc
index e14a0c0..88c0754 100644
--- a/icing/index/main/posting-list-hit-serializer.cc
+++ b/icing/index/main/posting-list-hit-serializer.cc
@@ -19,8 +19,11 @@
#include <limits>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/hit/hit.h"
#include "icing/legacy/core/icing-string-util.h"
#include "icing/legacy/index/icing-bit-util.h"
#include "icing/util/logging.h"
@@ -35,6 +38,10 @@ uint32_t GetTermFrequencyByteSize(const Hit& hit) {
return hit.has_term_frequency() ? sizeof(Hit::TermFrequency) : 0;
}
+uint32_t GetFlagsByteSize(const Hit& hit) {
+ return hit.has_flags() ? sizeof(Hit::Flags) : 0;
+}
+
} // namespace
uint32_t PostingListHitSerializer::GetBytesUsed(
@@ -55,14 +62,45 @@ uint32_t PostingListHitSerializer::GetMinPostingListSizeToFit(
return posting_list_used->size_in_bytes();
}
- // In NOT_FULL status BytesUsed contains no special hits. The minimum sized
- // posting list that would be guaranteed to fit these hits would be
- // ALMOST_FULL, with kInvalidHit in special_hit(0), the uncompressed Hit in
- // special_hit(1) and the n compressed hits in the compressed region.
- // BytesUsed contains one uncompressed Hit and n compressed hits. Therefore,
- // fitting these hits into a posting list would require BytesUsed plus one
- // extra hit.
- return GetBytesUsed(posting_list_used) + sizeof(Hit);
+ // Edge case: if number of hits <= 2, and posting_list_used is NOT_FULL, we
+ // should return kMinPostingListSize as we know that two hits is able to fit
+ // in the 2 special hits position of a min-sized posting list.
+ // - Checking this scenario requires deserializing posting_list_used, which is
+ // an expensive operation. However this is only possible when
+ // GetBytesUsed(posting_list_used) <=
+ // sizeof(uncompressed hit) + max(delta encoded value size) +
+ // (sizeof(Hit) + (sizeof(Hit) - sizeof(Hit::Value)), so we don't need to do
+ // this when the size exceeds this number.
+ if (GetBytesUsed(posting_list_used) <=
+ 2 * sizeof(Hit) + 5 - sizeof(Hit::Value)) {
+ // Check if we're able to get more than 2 hits from posting_list_used
+ std::vector<Hit> hits;
+ libtextclassifier3::Status status =
+ GetHitsInternal(posting_list_used, /*limit=*/3, /*pop=*/false, &hits);
+ if (status.ok() && hits.size() <= 2) {
+ return GetMinPostingListSize();
+ }
+ }
+
+ // - In NOT_FULL status, BytesUsed contains no special hits. For a posting
+ // list in the NOT_FULL state with n hits, we would have n-1 compressed hits
+ // and 1 uncompressed hit.
+ // - The minimum sized posting list that would be guaranteed to fit these hits
+ // would be FULL, but calculating the size required for the FULL posting
+ // list would require deserializing the last two added hits, so instead we
+ // will calculate the size of an ALMOST_FULL posting list to fit.
+ // - An ALMOST_FULL posting list would have kInvalidHit in special_hit(0), the
+ // full uncompressed Hit in special_hit(1), and the n-1 compressed hits in
+ // the compressed region.
+ // - Currently BytesUsed contains one uncompressed Hit and n-1 compressed
+ // hits. However, it's possible that the uncompressed Hit is not a full hit,
+ // but rather only the Hit::Value (this is the case if
+ // !hit.has_term_frequency()).
+ // - Therefore, fitting these hits into a posting list would require
+ // BytesUsed + one extra full hit + byte difference between a full hit and
+ // Hit::Value. i.e:
+ // ByteUsed + sizeof(Hit) + (sizeof(Hit) - sizeof(Hit::Value)).
+ return GetBytesUsed(posting_list_used) + 2 * sizeof(Hit) - sizeof(Hit::Value);
}
void PostingListHitSerializer::Clear(PostingListUsed* posting_list_used) const {
@@ -160,27 +198,37 @@ libtextclassifier3::Status PostingListHitSerializer::PrependHitToAlmostFull(
// in the padded area and put new hit at the special position 1.
// Calling ValueOrDie is safe here because 1 < kNumSpecialData.
Hit cur = GetSpecialHit(posting_list_used, /*index=*/1).ValueOrDie();
- if (cur.value() <= hit.value()) {
+ if (cur < hit || cur == hit) {
return absl_ports::InvalidArgumentError(
"Hit being prepended must be strictly less than the most recent Hit");
}
- uint64_t delta = cur.value() - hit.value();
uint8_t delta_buf[VarInt::kMaxEncodedLen64];
- size_t delta_len = VarInt::Encode(delta, delta_buf);
+ size_t delta_len =
+ EncodeNextHitValue(/*next_hit_value=*/hit.value(),
+ /*curr_hit_value=*/cur.value(), delta_buf);
+ uint32_t cur_flags_bytes = GetFlagsByteSize(cur);
uint32_t cur_term_frequency_bytes = GetTermFrequencyByteSize(cur);
uint32_t pad_end = GetPadEnd(posting_list_used,
/*offset=*/kSpecialHitsSize);
- if (pad_end >= kSpecialHitsSize + delta_len + cur_term_frequency_bytes) {
- // Pad area has enough space for delta and term_frequency of existing hit
- // (cur). Write delta at pad_end - delta_len - cur_term_frequency_bytes.
+ if (pad_end >= kSpecialHitsSize + delta_len + cur_flags_bytes +
+ cur_term_frequency_bytes) {
+ // Pad area has enough space for delta, flags and term_frequency of existing
+ // hit (cur). Write delta at pad_end - delta_len - cur_flags_bytes -
+ // cur_term_frequency_bytes.
uint8_t* delta_offset = posting_list_used->posting_list_buffer() + pad_end -
- delta_len - cur_term_frequency_bytes;
+ delta_len - cur_flags_bytes -
+ cur_term_frequency_bytes;
memcpy(delta_offset, delta_buf, delta_len);
- // Now copy term_frequency.
+
+ // Now copy flags.
+ Hit::Flags flags = cur.flags();
+ uint8_t* flags_offset = delta_offset + delta_len;
+ memcpy(flags_offset, &flags, cur_flags_bytes);
+ // Copy term_frequency.
Hit::TermFrequency term_frequency = cur.term_frequency();
- uint8_t* term_frequency_offset = delta_offset + delta_len;
+ uint8_t* term_frequency_offset = flags_offset + cur_flags_bytes;
memcpy(term_frequency_offset, &term_frequency, cur_term_frequency_bytes);
// Now first hit is the new hit, at special position 1. Safe to ignore the
@@ -231,23 +279,36 @@ libtextclassifier3::Status PostingListHitSerializer::PrependHitToNotFull(
return absl_ports::FailedPreconditionError(
"Posting list is in an invalid state.");
}
+
+ // Retrieve the last added (cur) hit's value and flags and compare to the hit
+ // we're adding.
Hit::Value cur_value;
- memcpy(&cur_value, posting_list_used->posting_list_buffer() + offset,
- sizeof(Hit::Value));
- if (cur_value <= hit.value()) {
+ uint8_t* cur_value_offset = posting_list_used->posting_list_buffer() + offset;
+ memcpy(&cur_value, cur_value_offset, sizeof(Hit::Value));
+ Hit::Flags cur_flags = Hit::kNoEnabledFlags;
+ if (GetFlagsByteSize(Hit(cur_value)) > 0) {
+ uint8_t* cur_flags_offset = cur_value_offset + sizeof(Hit::Value);
+ memcpy(&cur_flags, cur_flags_offset, sizeof(Hit::Flags));
+ }
+ // Term-frequency is not used for hit comparison so it's ok to pass in the
+ // default term-frequency here.
+ Hit cur(cur_value, cur_flags, Hit::kDefaultTermFrequency);
+ if (cur < hit || cur == hit) {
return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
- "Hit %d being prepended must be strictly less than the most recent "
- "Hit %d",
- hit.value(), cur_value));
+ "Hit (value=%d, flags=%d) being prepended must be strictly less than "
+ "the most recent Hit (value=%d, flags=%d)",
+ hit.value(), hit.flags(), cur_value, cur_flags));
}
- uint64_t delta = cur_value - hit.value();
uint8_t delta_buf[VarInt::kMaxEncodedLen64];
- size_t delta_len = VarInt::Encode(delta, delta_buf);
+ size_t delta_len =
+ EncodeNextHitValue(/*next_hit_value=*/hit.value(),
+ /*curr_hit_value=*/cur.value(), delta_buf);
+ uint32_t hit_flags_bytes = GetFlagsByteSize(hit);
uint32_t hit_term_frequency_bytes = GetTermFrequencyByteSize(hit);
// offset now points to one past the end of the first hit.
offset += sizeof(Hit::Value);
- if (kSpecialHitsSize + sizeof(Hit::Value) + delta_len +
+ if (kSpecialHitsSize + sizeof(Hit::Value) + delta_len + hit_flags_bytes +
hit_term_frequency_bytes <=
offset) {
// Enough space for delta in compressed area.
@@ -257,9 +318,9 @@ libtextclassifier3::Status PostingListHitSerializer::PrependHitToNotFull(
memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf,
delta_len);
- // Prepend new hit with (possibly) its term_frequency. We know that there is
- // room for 'hit' because of the if statement above, so calling ValueOrDie
- // is safe.
+ // Prepend new hit with (possibly) its flags and term_frequency. We know
+ // that there is room for 'hit' because of the if statement above, so
+ // calling ValueOrDie is safe.
offset =
PrependHitUncompressed(posting_list_used, hit, offset).ValueOrDie();
// offset is guaranteed to be valid here. So it's safe to ignore the return
@@ -295,13 +356,13 @@ libtextclassifier3::Status PostingListHitSerializer::PrependHitToNotFull(
// (i.e. varint delta encoding expanded required storage). We
// move first hit to special position 1 and put new hit in
// special position 0.
- Hit cur(cur_value);
- // offset is < kSpecialHitsSize + delta_len. delta_len is at most 5 bytes.
+
+ // Offset is < kSpecialHitsSize + delta_len. delta_len is at most 5 bytes.
// Therefore, offset must be less than kSpecialHitSize + 5. Since posting
- // list size must be divisible by sizeof(Hit) (5), it is guaranteed that
+ // list size must be divisible by sizeof(Hit) (6), it is guaranteed that
// offset < size_in_bytes, so it is safe to ignore the return value here.
- ICING_RETURN_IF_ERROR(
- ConsumeTermFrequencyIfPresent(posting_list_used, &cur, &offset));
+ ICING_RETURN_IF_ERROR(ConsumeFlagsAndTermFrequencyIfPresent(
+ posting_list_used, &cur, &offset));
// Safe to ignore the return value of PadToEnd because offset must be less
// than posting_list_used->size_in_bytes(). Otherwise, this function
// already would have returned FAILED_PRECONDITION.
@@ -463,19 +524,19 @@ libtextclassifier3::Status PostingListHitSerializer::GetHitsInternal(
offset += sizeof(Hit::Value);
} else {
// Now we have delta encoded subsequent hits. Decode and push.
- uint64_t delta;
- offset += VarInt::Decode(
- posting_list_used->posting_list_buffer() + offset, &delta);
- val += delta;
+ DecodedHitInfo decoded_hit_info = DecodeNextHitValue(
+ posting_list_used->posting_list_buffer() + offset, val);
+ offset += decoded_hit_info.encoded_size;
+ val = decoded_hit_info.hit_value;
}
Hit hit(val);
libtextclassifier3::Status status =
- ConsumeTermFrequencyIfPresent(posting_list_used, &hit, &offset);
+ ConsumeFlagsAndTermFrequencyIfPresent(posting_list_used, &hit, &offset);
if (!status.ok()) {
// This posting list has been corrupted somehow. The first hit of the
- // posting list claims to have a term frequency, but there's no more room
- // in the posting list for that term frequency to exist. Return an empty
- // vector and zero to indicate no hits retrieved.
+ // posting list claims to have a term frequency or flag, but there's no
+ // more room in the posting list for that term frequency or flag to exist.
+ // Return an empty vector and zero to indicate no hits retrieved.
if (out != nullptr) {
out->clear();
}
@@ -497,10 +558,10 @@ libtextclassifier3::Status PostingListHitSerializer::GetHitsInternal(
// In the compressed area. Pop and reconstruct. offset/val is
// the last traversed hit, which we must discard. So move one
// more forward.
- uint64_t delta;
- offset += VarInt::Decode(
- posting_list_used->posting_list_buffer() + offset, &delta);
- val += delta;
+ DecodedHitInfo decoded_hit_info = DecodeNextHitValue(
+ posting_list_used->posting_list_buffer() + offset, val);
+ offset += decoded_hit_info.encoded_size;
+ val = decoded_hit_info.hit_value;
// Now val is the first hit of the new posting list.
if (kSpecialHitsSize + sizeof(Hit::Value) <= offset) {
@@ -509,17 +570,18 @@ libtextclassifier3::Status PostingListHitSerializer::GetHitsInternal(
memcpy(mutable_posting_list_used->posting_list_buffer() + offset, &val,
sizeof(Hit::Value));
} else {
- // val won't fit in compressed area. Also see if there is a
+ // val won't fit in compressed area. Also see if there is a flag or
// term_frequency.
Hit hit(val);
libtextclassifier3::Status status =
- ConsumeTermFrequencyIfPresent(posting_list_used, &hit, &offset);
+ ConsumeFlagsAndTermFrequencyIfPresent(posting_list_used, &hit,
+ &offset);
if (!status.ok()) {
// This posting list has been corrupted somehow. The first hit of
- // the posting list claims to have a term frequency, but there's no
- // more room in the posting list for that term frequency to exist.
- // Return an empty vector and zero to indicate no hits retrieved. Do
- // not pop anything.
+ // the posting list claims to have a term frequency or flag, but
+ // there's no more room in the posting list for that term frequency or
+ // flag to exist. Return an empty vector and zero to indicate no hits
+ // retrieved. Do not pop anything.
if (out != nullptr) {
out->clear();
}
@@ -569,7 +631,7 @@ libtextclassifier3::StatusOr<Hit> PostingListHitSerializer::GetSpecialHit(
return absl_ports::InvalidArgumentError(
"Special hits only exist at indices 0 and 1");
}
- Hit val;
+ Hit val(Hit::kInvalidValue);
memcpy(&val, posting_list_used->posting_list_buffer() + index * sizeof(val),
sizeof(val));
return val;
@@ -653,11 +715,11 @@ bool PostingListHitSerializer::SetStartByteOffset(
// not_full state. Safe to ignore the return value because 0 and 1 are both
// < kNumSpecialData.
SetSpecialHit(posting_list_used, /*index=*/0, Hit(offset));
- SetSpecialHit(posting_list_used, /*index=*/1, Hit());
+ SetSpecialHit(posting_list_used, /*index=*/1, Hit(Hit::kInvalidValue));
} else if (offset == sizeof(Hit)) {
// almost_full state. Safe to ignore the return value because 1 is both <
// kNumSpecialData.
- SetSpecialHit(posting_list_used, /*index=*/0, Hit());
+ SetSpecialHit(posting_list_used, /*index=*/0, Hit(Hit::kInvalidValue));
}
// Nothing to do for the FULL state - the offset isn't actually stored
// anywhere and both special hits hold valid hits.
@@ -667,46 +729,72 @@ bool PostingListHitSerializer::SetStartByteOffset(
libtextclassifier3::StatusOr<uint32_t>
PostingListHitSerializer::PrependHitUncompressed(
PostingListUsed* posting_list_used, const Hit& hit, uint32_t offset) const {
- if (hit.has_term_frequency()) {
- if (offset < kSpecialHitsSize + sizeof(Hit)) {
- return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
- "Not enough room to prepend Hit at offset %d.", offset));
- }
- offset -= sizeof(Hit);
- memcpy(posting_list_used->posting_list_buffer() + offset, &hit,
- sizeof(Hit));
- } else {
- if (offset < kSpecialHitsSize + sizeof(Hit::Value)) {
- return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
- "Not enough room to prepend Hit::Value at offset %d.", offset));
- }
- offset -= sizeof(Hit::Value);
- Hit::Value val = hit.value();
- memcpy(posting_list_used->posting_list_buffer() + offset, &val,
- sizeof(Hit::Value));
+ uint32_t hit_bytes_to_prepend = sizeof(Hit::Value) + GetFlagsByteSize(hit) +
+ GetTermFrequencyByteSize(hit);
+
+ if (offset < kSpecialHitsSize + hit_bytes_to_prepend) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "Not enough room to prepend Hit at offset %d.", offset));
}
+
+ if (hit.has_term_frequency()) {
+ offset -= sizeof(Hit::TermFrequency);
+ Hit::TermFrequency term_frequency = hit.term_frequency();
+ memcpy(posting_list_used->posting_list_buffer() + offset, &term_frequency,
+ sizeof(Hit::TermFrequency));
+ }
+ if (hit.has_flags()) {
+ offset -= sizeof(Hit::Flags);
+ Hit::Flags flags = hit.flags();
+ memcpy(posting_list_used->posting_list_buffer() + offset, &flags,
+ sizeof(Hit::Flags));
+ }
+ offset -= sizeof(Hit::Value);
+ Hit::Value val = hit.value();
+ memcpy(posting_list_used->posting_list_buffer() + offset, &val,
+ sizeof(Hit::Value));
return offset;
}
libtextclassifier3::Status
-PostingListHitSerializer::ConsumeTermFrequencyIfPresent(
+PostingListHitSerializer::ConsumeFlagsAndTermFrequencyIfPresent(
const PostingListUsed* posting_list_used, Hit* hit,
uint32_t* offset) const {
- if (!hit->has_term_frequency()) {
- // No term frequency to consume. Everything is fine.
+ if (!hit->has_flags()) {
+ // No flags (and by extension, no term-frequency) to consume. Everything is
+ // fine.
return libtextclassifier3::Status::OK;
}
- if (*offset + sizeof(Hit::TermFrequency) >
- posting_list_used->size_in_bytes()) {
+
+ // Consume flags
+ Hit::Flags flags;
+ if (*offset + sizeof(Hit::Flags) > posting_list_used->size_in_bytes()) {
return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
- "offset %d must not point past the end of the posting list of size %d.",
+ "offset %d must not point past the end of the posting list of size "
+ "%d.",
*offset, posting_list_used->size_in_bytes()));
}
- Hit::TermFrequency term_frequency;
- memcpy(&term_frequency, posting_list_used->posting_list_buffer() + *offset,
- sizeof(Hit::TermFrequency));
- *hit = Hit(hit->value(), term_frequency);
- *offset += sizeof(Hit::TermFrequency);
+ memcpy(&flags, posting_list_used->posting_list_buffer() + *offset,
+ sizeof(Hit::Flags));
+ *hit = Hit(hit->value(), flags, Hit::kDefaultTermFrequency);
+ *offset += sizeof(Hit::Flags);
+
+ if (hit->has_term_frequency()) {
+ // Consume term frequency
+ if (*offset + sizeof(Hit::TermFrequency) >
+ posting_list_used->size_in_bytes()) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "offset %d must not point past the end of the posting list of size "
+ "%d.",
+ *offset, posting_list_used->size_in_bytes()));
+ }
+ Hit::TermFrequency term_frequency;
+ memcpy(&term_frequency, posting_list_used->posting_list_buffer() + *offset,
+ sizeof(Hit::TermFrequency));
+ *hit = Hit(hit->value(), flags, term_frequency);
+ *offset += sizeof(Hit::TermFrequency);
+ }
+
return libtextclassifier3::Status::OK;
}
diff --git a/icing/index/main/posting-list-hit-serializer.h b/icing/index/main/posting-list-hit-serializer.h
index 2986d9c..08c792c 100644
--- a/icing/index/main/posting-list-hit-serializer.h
+++ b/icing/index/main/posting-list-hit-serializer.h
@@ -15,6 +15,7 @@
#ifndef ICING_INDEX_MAIN_POSTING_LIST_HIT_SERIALIZER_H_
#define ICING_INDEX_MAIN_POSTING_LIST_HIT_SERIALIZER_H_
+#include <cstddef>
#include <cstdint>
#include <vector>
@@ -23,6 +24,7 @@
#include "icing/file/posting_list/posting-list-common.h"
#include "icing/file/posting_list/posting-list-used.h"
#include "icing/index/hit/hit.h"
+#include "icing/legacy/index/icing-bit-util.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -34,6 +36,52 @@ class PostingListHitSerializer : public PostingListSerializer {
public:
static constexpr uint32_t kSpecialHitsSize = kNumSpecialData * sizeof(Hit);
+ struct DecodedHitInfo {
+ // The decoded hit value.
+ Hit::Value hit_value;
+
+ // Size of the encoded hit in bytes.
+ size_t encoded_size;
+ };
+
+ // Given the current hit value, encodes the next hit value for serialization
+ // in the posting list.
+ //
+ // The encoded value is the varint-encoded delta between next_hit_value and
+ // curr_hit_value.
+ // - We add 1 to this delta so as to avoid getting a delta value of 0.
+ // - This allows us to add duplicate hits with the same value, which is a
+ // valid case if we need to store hits with different flags that belong in
+ // the same section-id/doc-id combo.
+ // - We cannot have an encoded hit delta with a value of 0 as 0 is currently
+ // used for padding the unused region in the posting list.
+ //
+ // REQUIRES: next_hit_value <= curr_hit_value AND
+ // curr_hit_value - next_hit_value <
+ // std::numeric_limits<Hit::Value>::max()
+ //
+ // RETURNS: next_hit_value's encoded length in bytes and writes the encoded
+ // value directly into encoded_buf_out.
+ static size_t EncodeNextHitValue(Hit::Value next_hit_value,
+ Hit::Value curr_hit_value,
+ uint8_t* encoded_buf_out) {
+ uint64_t delta = curr_hit_value - next_hit_value + 1;
+ return VarInt::Encode(delta, encoded_buf_out);
+ }
+
+ // Given the current hit value, decodes the next hit value from an encoded
+ // byte array buffer.
+ //
+ // RETURNS: DecodedHitInfo containing the decoded hit value and the value's
+ // encoded size in bytes.
+ static DecodedHitInfo DecodeNextHitValue(const uint8_t* encoded_buf_in,
+ Hit::Value curr_hit_value) {
+ uint64_t delta;
+ size_t delta_size = VarInt::Decode(encoded_buf_in, &delta);
+ Hit::Value hit_value = curr_hit_value + delta - 1;
+ return {hit_value, delta_size};
+ }
+
uint32_t GetDataTypeBytes() const override { return sizeof(Hit); }
uint32_t GetMinPostingListSize() const override {
@@ -299,15 +347,19 @@ class PostingListHitSerializer : public PostingListSerializer {
PostingListUsed* posting_list_used, const Hit& hit,
uint32_t offset) const;
- // If hit has a term frequency, consumes the term frequency at offset, updates
- // hit to include the term frequency and updates offset to reflect that the
- // term frequency has been consumed.
+ // If hit has the flags and/or term frequency field, consumes the flags and/or
+ // term frequency at offset, updates hit to include the flag and/or term
+ // frequency and updates offset to reflect that the flag and/or term frequency
+ // fields have been consumed.
//
// RETURNS:
// - OK, if successful
- // - INVALID_ARGUMENT if hit has a term frequency and offset +
- // sizeof(Hit::TermFrequency) >= posting_list_used->size_in_bytes()
- libtextclassifier3::Status ConsumeTermFrequencyIfPresent(
+ // - INVALID_ARGUMENT if hit has a flags and/or term frequency field and
+ // offset + sizeof(Hit's flag) + sizeof(Hit's tf) >=
+ // posting_list_used->size_in_bytes()
+ // i.e. the posting list is not large enough to consume the hit's flags
+ // and term frequency fields
+ libtextclassifier3::Status ConsumeFlagsAndTermFrequencyIfPresent(
const PostingListUsed* posting_list_used, Hit* hit,
uint32_t* offset) const;
};
diff --git a/icing/index/main/posting-list-hit-serializer_test.cc b/icing/index/main/posting-list-hit-serializer_test.cc
index 7f0b945..ea135ef 100644
--- a/icing/index/main/posting-list-hit-serializer_test.cc
+++ b/icing/index/main/posting-list-hit-serializer_test.cc
@@ -14,22 +14,32 @@
#include "icing/index/main/posting-list-hit-serializer.h"
+#include <algorithm>
+#include <cstddef>
#include <cstdint>
#include <deque>
-#include <memory>
+#include <iterator>
+#include <limits>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/hit/hit.h"
+#include "icing/legacy/index/icing-bit-util.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/hit-test-utils.h"
using testing::ElementsAre;
using testing::ElementsAreArray;
using testing::Eq;
+using testing::Gt;
using testing::IsEmpty;
+using testing::IsFalse;
+using testing::IsTrue;
using testing::Le;
using testing::Lt;
@@ -58,103 +68,217 @@ TEST(PostingListHitSerializerTest, PostingListUsedPrependHitNotFull) {
PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize));
// Make used.
- Hit hit0(/*section_id=*/0, 0, /*term_frequency=*/56);
+ Hit hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/56,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit0));
- // Size = sizeof(uncompressed hit0)
- int expected_size = sizeof(Hit);
- EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size));
+ // Size = sizeof(uncompressed hit0::Value)
+ // + sizeof(hit0::Flags)
+ // + sizeof(hit0::TermFrequency)
+ int expected_size =
+ sizeof(Hit::Value) + sizeof(Hit::Flags) + sizeof(Hit::TermFrequency);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0)));
- Hit hit1(/*section_id=*/0, 1, Hit::kDefaultTermFrequency);
+ Hit hit1(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ uint8_t delta_buf[VarInt::kMaxEncodedLen64];
+ size_t delta_len = PostingListHitSerializer::EncodeNextHitValue(
+ /*next_hit_value=*/hit1.value(),
+ /*curr_hit_value=*/hit0.value(), delta_buf);
ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit1));
- // Size = sizeof(uncompressed hit1)
- // + sizeof(hit0-hit1) + sizeof(hit0::term_frequency)
- expected_size += 2 + sizeof(Hit::TermFrequency);
- EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size));
+ // Size = sizeof(uncompressed hit1::Value)
+ // + sizeof(hit0-hit1)
+ // + sizeof(hit0::Flags)
+ // + sizeof(hit0::TermFrequency)
+ expected_size += delta_len;
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
EXPECT_THAT(serializer.GetHits(&pl_used),
IsOkAndHolds(ElementsAre(hit1, hit0)));
- Hit hit2(/*section_id=*/0, 2, /*term_frequency=*/56);
+ Hit hit2(/*section_id=*/0, /*document_id=*/2, /*term_frequency=*/56,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ delta_len = PostingListHitSerializer::EncodeNextHitValue(
+ /*next_hit_value=*/hit2.value(),
+ /*curr_hit_value=*/hit1.value(), delta_buf);
ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit2));
- // Size = sizeof(uncompressed hit2)
+ // Size = sizeof(uncompressed hit2::Value) + sizeof(hit2::Flags)
+ // + sizeof(hit2::TermFrequency)
// + sizeof(hit1-hit2)
- // + sizeof(hit0-hit1) + sizeof(hit0::term_frequency)
- expected_size += 2;
- EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size));
+ // + sizeof(hit0-hit1)
+ // + sizeof(hit0::flags)
+ // + sizeof(hit0::term_frequency)
+ expected_size += delta_len + sizeof(Hit::Flags) + sizeof(Hit::TermFrequency);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
EXPECT_THAT(serializer.GetHits(&pl_used),
IsOkAndHolds(ElementsAre(hit2, hit1, hit0)));
- Hit hit3(/*section_id=*/0, 3, Hit::kDefaultTermFrequency);
+ Hit hit3(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ delta_len = PostingListHitSerializer::EncodeNextHitValue(
+ /*next_hit_value=*/hit3.value(),
+ /*curr_hit_value=*/hit2.value(), delta_buf);
ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit3));
- // Size = sizeof(uncompressed hit3)
- // + sizeof(hit2-hit3) + sizeof(hit2::term_frequency)
+ // Size = sizeof(uncompressed hit3::Value)
+ // + sizeof(hit2-hit3) + sizeof(hit2::Flags)
+ // + sizeof(hit2::TermFrequency)
// + sizeof(hit1-hit2)
- // + sizeof(hit0-hit1) + sizeof(hit0::term_frequency)
- expected_size += 2 + sizeof(Hit::TermFrequency);
- EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size));
+ // + sizeof(hit0-hit1)
+ // + sizeof(hit0::flags)
+ // + sizeof(hit0::term_frequency)
+ expected_size += delta_len;
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0)));
+}
+
+TEST(PostingListHitSerializerTest,
+ PostingListUsedPrependHitAlmostFull_withFlags) {
+ PostingListHitSerializer serializer;
+
+ // Size = 24
+ int pl_size = 2 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+ // Fill up the compressed region.
+ // Transitions:
+ // Adding hit0: EMPTY -> NOT_FULL
+ // Adding hit1: NOT_FULL -> NOT_FULL
+ // Adding hit2: NOT_FULL -> NOT_FULL
+ Hit hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ Hit hit1 = CreateHit(hit0, /*desired_byte_length=*/3);
+ Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/2, /*term_frequency=*/57,
+ /*is_in_prefix_section=*/true,
+ /*is_prefix_hit=*/true);
+ EXPECT_THAT(hit2.has_flags(), IsTrue());
+ EXPECT_THAT(hit2.has_term_frequency(), IsTrue());
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2));
+ // Size used will be 4 (hit2::Value) + 1 (hit2::Flags) + 1
+ // (hit2::TermFrequency) + 2 (hit1-hit2) + 3 (hit0-hit1) = 11 bytes
+ int expected_size = sizeof(Hit::Value) + sizeof(Hit::Flags) +
+ sizeof(Hit::TermFrequency) + 2 + 3;
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit2, hit1, hit0)));
+
+ // Add one more hit to transition NOT_FULL -> ALMOST_FULL
+ Hit hit3 =
+ CreateHit(hit2, /*desired_byte_length=*/3, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ EXPECT_THAT(hit3.has_flags(), IsFalse());
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3));
+ // Storing them in the compressed region requires 4 (hit3::Value) + 3
+ // (hit2-hit3) + 1 (hit2::Flags) + 1 (hit2::TermFrequency) + 2 (hit1-hit2) + 3
+ // (hit0-hit1) = 14 bytes, but there are only 12 bytes in the compressed
+ // region. So instead, the posting list will transition to ALMOST_FULL.
+ // The in-use compressed region will actually shrink from 11 bytes to 10 bytes
+ // because the uncompressed version of hit2 will be overwritten with the
+ // compressed delta of hit2. hit3 will be written to one of the special hits.
+ // Because we're in ALMOST_FULL, the expected size is the size of the pl minus
+ // the one hit used to mark the posting list as ALMOST_FULL.
+ expected_size = pl_size - sizeof(Hit);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
EXPECT_THAT(serializer.GetHits(&pl_used),
IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0)));
+
+ // Add one more hit to transition ALMOST_FULL -> ALMOST_FULL
+ Hit hit4 = CreateHit(hit3, /*desired_byte_length=*/2);
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit4));
+ // There are currently 10 bytes in use in the compressed region. hit3 will
+ // have a 2-byte delta, which fits in the compressed region. Hit3 will be
+ // moved from the special hit to the compressed region (which will have 12
+ // bytes in use after adding hit3). Hit4 will be placed in one of the special
+ // hits and the posting list will remain in ALMOST_FULL.
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit4, hit3, hit2, hit1, hit0)));
+
+ // Add one more hit to transition ALMOST_FULL -> FULL
+ Hit hit5 = CreateHit(hit4, /*desired_byte_length=*/2);
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit5));
+ // There are currently 12 bytes in use in the compressed region. hit4 will
+ // have a 2-byte delta which will not fit in the compressed region. So hit4
+ // will remain in one of the special hits and hit5 will occupy the other,
+ // making the posting list FULL.
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit5, hit4, hit3, hit2, hit1, hit0)));
+
+ // The posting list is FULL. Adding another hit should fail.
+ Hit hit6 = CreateHit(hit5, /*desired_byte_length=*/1);
+ EXPECT_THAT(serializer.PrependHit(&pl_used, hit6),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
}
TEST(PostingListHitSerializerTest, PostingListUsedPrependHitAlmostFull) {
PostingListHitSerializer serializer;
- int size = 2 * serializer.GetMinPostingListSize();
+ // Size = 24
+ int pl_size = 2 * serializer.GetMinPostingListSize();
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
// Fill up the compressed region.
// Transitions:
// Adding hit0: EMPTY -> NOT_FULL
// Adding hit1: NOT_FULL -> NOT_FULL
// Adding hit2: NOT_FULL -> NOT_FULL
- Hit hit0(/*section_id=*/0, 0, Hit::kDefaultTermFrequency);
- Hit hit1 = CreateHit(hit0, /*desired_byte_length=*/2);
- Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/2);
+ Hit hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ Hit hit1 = CreateHit(hit0, /*desired_byte_length=*/3);
+ Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/3);
ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0));
ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1));
ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2));
- // Size used will be 2+2+4=8 bytes
- int expected_size = sizeof(Hit::Value) + 2 + 2;
- EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size));
+ // Size used will be 4 (hit2::Value) + 3 (hit1-hit2) + 3 (hit0-hit1)
+ // = 10 bytes
+ int expected_size = sizeof(Hit::Value) + 3 + 3;
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
EXPECT_THAT(serializer.GetHits(&pl_used),
IsOkAndHolds(ElementsAre(hit2, hit1, hit0)));
// Add one more hit to transition NOT_FULL -> ALMOST_FULL
Hit hit3 = CreateHit(hit2, /*desired_byte_length=*/3);
ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3));
- // Compressed region would be 2+2+3+4=11 bytes, but the compressed region is
- // only 10 bytes. So instead, the posting list will transition to ALMOST_FULL.
- // The in-use compressed region will actually shrink from 8 bytes to 7 bytes
+ // Storing them in the compressed region requires 4 (hit3::Value) + 3
+ // (hit2-hit3) + 3 (hit1-hit2) + 3 (hit0-hit1) = 13 bytes, but there are only
+ // 12 bytes in the compressed region. So instead, the posting list will
+ // transition to ALMOST_FULL.
+ // The in-use compressed region will actually shrink from 10 bytes to 9 bytes
// because the uncompressed version of hit2 will be overwritten with the
// compressed delta of hit2. hit3 will be written to one of the special hits.
// Because we're in ALMOST_FULL, the expected size is the size of the pl minus
// the one hit used to mark the posting list as ALMOST_FULL.
- expected_size = size - sizeof(Hit);
- EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size));
+ expected_size = pl_size - sizeof(Hit);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
EXPECT_THAT(serializer.GetHits(&pl_used),
IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0)));
// Add one more hit to transition ALMOST_FULL -> ALMOST_FULL
Hit hit4 = CreateHit(hit3, /*desired_byte_length=*/2);
ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit4));
- // There are currently 7 bytes in use in the compressed region. hit3 will have
- // a 2-byte delta. That delta will fit in the compressed region (which will
- // now have 9 bytes in use), hit4 will be placed in one of the special hits
- // and the posting list will remain in ALMOST_FULL.
- EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size));
+ // There are currently 9 bytes in use in the compressed region. Hit3 will
+ // have a 2-byte delta, which fits in the compressed region. Hit3 will be
+ // moved from the special hit to the compressed region (which will have 11
+ // bytes in use after adding hit3). Hit4 will be placed in one of the special
+ // hits and the posting list will remain in ALMOST_FULL.
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
EXPECT_THAT(serializer.GetHits(&pl_used),
IsOkAndHolds(ElementsAre(hit4, hit3, hit2, hit1, hit0)));
// Add one more hit to transition ALMOST_FULL -> FULL
Hit hit5 = CreateHit(hit4, /*desired_byte_length=*/2);
ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit5));
- // There are currently 9 bytes in use in the compressed region. hit4 will have
- // a 2-byte delta which will not fit in the compressed region. So hit4 will
- // remain in one of the special hits and hit5 will occupy the other, making
- // the posting list FULL.
- EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(size));
+ // There are currently 11 bytes in use in the compressed region. Hit4 will
+ // have a 2-byte delta which will not fit in the compressed region. So hit4
+ // will remain in one of the special hits and hit5 will occupy the other,
+ // making the posting list FULL.
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size));
EXPECT_THAT(serializer.GetHits(&pl_used),
IsOkAndHolds(ElementsAre(hit5, hit4, hit3, hit2, hit1, hit0)));
@@ -164,9 +288,59 @@ TEST(PostingListHitSerializerTest, PostingListUsedPrependHitAlmostFull) {
StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
}
+TEST(PostingListHitSerializerTest, PrependHitsWithSameValue) {
+ PostingListHitSerializer serializer;
+
+ // Size = 24
+ int pl_size = 2 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+ // Fill up the compressed region.
+ Hit hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ Hit hit1 = CreateHit(hit0, /*desired_byte_length=*/3);
+ Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/2, /*term_frequency=*/57,
+ /*is_in_prefix_section=*/true,
+ /*is_prefix_hit=*/true);
+ // Create hit3 with the same value but different flags as hit2 (hit3_flags
+ // is set to have all currently-defined flags enabled)
+ Hit::Flags hit3_flags = 0;
+ for (int i = 0; i < Hit::kNumFlagsInFlagsField; ++i) {
+ hit3_flags |= (1 << i);
+ }
+ Hit hit3(hit2.value(), /*term_frequency=*/hit2.term_frequency(),
+ /*flags=*/hit3_flags);
+
+ // hit3 is larger than hit2 (its flag value is larger), and so needs to be
+ // prepended first
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3));
+ ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2));
+ // Posting list is now ALMOST_FULL
+ // ----------------------
+ // 23-21 delta(Hit #0)
+ // 20-19 delta(Hit #1)
+ // 18 term-frequency(Hit #2)
+ // 17 flags(Hit #2)
+ // 16 delta(Hit #2) = 0
+ // 15-12 <unused padding = 0>
+ // 11-6 Hit #3
+ // 5-0 kSpecialHit
+ // ----------------------
+ int bytes_used = pl_size - sizeof(Hit);
+
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAre(hit2, hit3, hit1, hit0)));
+}
+
TEST(PostingListHitSerializerTest, PostingListUsedMinSize) {
PostingListHitSerializer serializer;
+ // Min size = 12
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used,
PostingListUsed::CreateFromUnitializedRegion(
@@ -176,7 +350,7 @@ TEST(PostingListHitSerializerTest, PostingListUsedMinSize) {
EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty()));
// Add a hit, PL should shift to ALMOST_FULL state
- Hit hit0(/*section_id=*/0, 0, /*term_frequency=*/0,
+ Hit hit0(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/0,
/*is_in_prefix_section=*/false,
/*is_prefix_hit=*/true);
ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0));
@@ -185,10 +359,10 @@ TEST(PostingListHitSerializerTest, PostingListUsedMinSize) {
EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size));
EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0)));
- // Add the smallest hit possible - no term_frequency and a delta of 1. PL
- // should shift to FULL state.
- Hit hit1(/*section_id=*/0, 0, /*term_frequency=*/0,
- /*is_in_prefix_section=*/true,
+ // Add the smallest hit possible - no term_frequency, non-prefix hit and a
+ // delta of 0b10. PL should shift to FULL state.
+ Hit hit1(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/0,
+ /*is_in_prefix_section=*/false,
/*is_prefix_hit=*/false);
ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1));
// Size = sizeof(uncompressed hit1) + sizeof(uncompressed hit0)
@@ -198,7 +372,7 @@ TEST(PostingListHitSerializerTest, PostingListUsedMinSize) {
IsOkAndHolds(ElementsAre(hit1, hit0)));
// Try to add the smallest hit possible. Should fail
- Hit hit2(/*section_id=*/0, 0, /*term_frequency=*/0,
+ Hit hit2(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/0,
/*is_in_prefix_section=*/false,
/*is_prefix_hit=*/false);
EXPECT_THAT(serializer.PrependHit(&pl_used, hit2),
@@ -212,14 +386,17 @@ TEST(PostingListHitSerializerTest,
PostingListPrependHitArrayMinSizePostingList) {
PostingListHitSerializer serializer;
- // Min Size = 10
- int size = serializer.GetMinPostingListSize();
+ // Min Size = 12
+ int pl_size = serializer.GetMinPostingListSize();
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
std::vector<HitElt> hits_in;
- hits_in.emplace_back(Hit(1, 0, Hit::kDefaultTermFrequency));
+ hits_in.emplace_back(Hit(/*section_id=*/1, /*document_id=*/0,
+ Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false));
hits_in.emplace_back(
CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
hits_in.emplace_back(
@@ -235,7 +412,7 @@ TEST(PostingListHitSerializerTest,
ICING_ASSERT_OK_AND_ASSIGN(
uint32_t num_can_prepend,
(serializer.PrependHitArray<HitElt, HitElt::get_hit>(
- &pl_used, &hits_in[0], hits_in.size(), false)));
+ &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
EXPECT_THAT(num_can_prepend, Eq(2));
int can_fit_hits = num_can_prepend;
@@ -243,10 +420,11 @@ TEST(PostingListHitSerializerTest,
// problem, transitioning the PL from EMPTY -> ALMOST_FULL -> FULL
const HitElt *hits_in_ptr = hits_in.data() + (hits_in.size() - 2);
ICING_ASSERT_OK_AND_ASSIGN(
- num_can_prepend, (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
- &pl_used, hits_in_ptr, can_fit_hits, false)));
+ num_can_prepend,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, hits_in_ptr, can_fit_hits, /*keep_prepended=*/false)));
EXPECT_THAT(num_can_prepend, Eq(can_fit_hits));
- EXPECT_THAT(size, Eq(serializer.GetBytesUsed(&pl_used)));
+ EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used)));
std::deque<Hit> hits_pushed;
std::transform(hits_in.rbegin(),
hits_in.rend() - hits_in.size() + can_fit_hits,
@@ -258,14 +436,17 @@ TEST(PostingListHitSerializerTest,
TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) {
PostingListHitSerializer serializer;
- // Size = 30
- int size = 3 * serializer.GetMinPostingListSize();
+ // Size = 36
+ int pl_size = 3 * serializer.GetMinPostingListSize();
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
std::vector<HitElt> hits_in;
- hits_in.emplace_back(Hit(1, 0, Hit::kDefaultTermFrequency));
+ hits_in.emplace_back(Hit(/*section_id=*/1, /*document_id=*/0,
+ Hit::kDefaultTermFrequency,
+ /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false));
hits_in.emplace_back(
CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1));
hits_in.emplace_back(
@@ -278,14 +459,14 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) {
// The last hit is uncompressed and the four before it should only take one
// byte. Total use = 8 bytes.
// ----------------------
- // 29 delta(Hit #1)
- // 28 delta(Hit #2)
- // 27 delta(Hit #3)
- // 26 delta(Hit #4)
- // 25-22 Hit #5
- // 21-10 <unused>
- // 9-5 kSpecialHit
- // 4-0 Offset=22
+ // 35 delta(Hit #0)
+ // 34 delta(Hit #1)
+ // 33 delta(Hit #2)
+ // 32 delta(Hit #3)
+ // 31-28 Hit #4
+ // 27-12 <unused>
+ // 11-6 kSpecialHit
+ // 5-0 Offset=28
// ----------------------
int byte_size = sizeof(Hit::Value) + hits_in.size() - 1;
@@ -294,7 +475,7 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) {
ICING_ASSERT_OK_AND_ASSIGN(
uint32_t num_could_fit,
(serializer.PrependHitArray<HitElt, HitElt::get_hit>(
- &pl_used, &hits_in[0], hits_in.size(), false)));
+ &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
std::deque<Hit> hits_pushed;
@@ -316,31 +497,35 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) {
CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
hits_in.emplace_back(
CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
+ hits_in.emplace_back(
+ CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
std::reverse(hits_in.begin(), hits_in.end());
- // Size increased by the deltas of these hits (1+2+1+2+3+2) = 11 bytes
+ // Size increased by the deltas of these hits (1+2+1+2+3+2+3) = 14 bytes
// ----------------------
- // 29 delta(Hit #1)
- // 28 delta(Hit #2)
- // 27 delta(Hit #3)
- // 26 delta(Hit #4)
- // 25 delta(Hit #5)
- // 24-23 delta(Hit #6)
- // 22 delta(Hit #7)
- // 21-20 delta(Hit #8)
- // 19-17 delta(Hit #9)
- // 16-15 delta(Hit #10)
- // 14-11 Hit #11
- // 10 <unused>
- // 9-5 kSpecialHit
- // 4-0 Offset=11
+ // 35 delta(Hit #0)
+ // 34 delta(Hit #1)
+ // 33 delta(Hit #2)
+ // 32 delta(Hit #3)
+ // 31 delta(Hit #4)
+ // 30-29 delta(Hit #5)
+ // 28 delta(Hit #6)
+ // 27-26 delta(Hit #7)
+ // 25-23 delta(Hit #8)
+ // 22-21 delta(Hit #9)
+ // 20-18 delta(Hit #10)
+ // 17-14 Hit #11
+ // 13-12 <unused>
+ // 11-6 kSpecialHit
+ // 5-0 Offset=14
// ----------------------
- byte_size += 11;
+ byte_size += 14;
- // Add these 6 hits. The PL is currently in the NOT_FULL state and should
+ // Add these 7 hits. The PL is currently in the NOT_FULL state and should
// remain in the NOT_FULL state.
ICING_ASSERT_OK_AND_ASSIGN(
- num_could_fit, (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
- &pl_used, &hits_in[0], hits_in.size(), false)));
+ num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
// All hits from hits_in were added.
@@ -353,29 +538,32 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) {
hits_in.clear();
hits_in.emplace_back(first_hit);
// ----------------------
- // 29 delta(Hit #1)
- // 28 delta(Hit #2)
- // 27 delta(Hit #3)
- // 26 delta(Hit #4)
- // 25 delta(Hit #5)
- // 24-23 delta(Hit #6)
- // 22 delta(Hit #7)
- // 21-20 delta(Hit #8)
- // 19-17 delta(Hit #9)
- // 16-15 delta(Hit #10)
- // 14-12 delta(Hit #11)
- // 11-10 <unused>
- // 9-5 Hit #12
- // 4-0 kSpecialHit
+ // 35 delta(Hit #0)
+ // 34 delta(Hit #1)
+ // 33 delta(Hit #2)
+ // 32 delta(Hit #3)
+ // 31 delta(Hit #4)
+ // 30-29 delta(Hit #5)
+ // 28 delta(Hit #6)
+ // 27-26 delta(Hit #7)
+ // 25-23 delta(Hit #8)
+ // 22-21 delta(Hit #9)
+ // 20-18 delta(Hit #10)
+ // 17-15 delta(Hit #11)
+ // 14-12 <unused>
+ // 11-6 Hit #12
+ // 5-0 kSpecialHit
// ----------------------
- byte_size = 25;
+ byte_size = 30; // 36 - 6
// Add this 1 hit. The PL is currently in the NOT_FULL state and should
// transition to the ALMOST_FULL state - even though there is still some
- // unused space.
+ // unused space. This is because the unused space (3 bytes) is less than
+ // the size of a uncompressed Hit.
ICING_ASSERT_OK_AND_ASSIGN(
- num_could_fit, (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
- &pl_used, &hits_in[0], hits_in.size(), false)));
+ num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used)));
// All hits from hits_in were added.
@@ -388,37 +576,40 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) {
hits_in.clear();
hits_in.emplace_back(first_hit);
hits_in.emplace_back(
- CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2));
+ CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3));
std::reverse(hits_in.begin(), hits_in.end());
// ----------------------
- // 29 delta(Hit #1)
- // 28 delta(Hit #2)
- // 27 delta(Hit #3)
- // 26 delta(Hit #4)
- // 25 delta(Hit #5)
- // 24-23 delta(Hit #6)
- // 22 delta(Hit #7)
- // 21-20 delta(Hit #8)
- // 19-17 delta(Hit #9)
- // 16-15 delta(Hit #10)
- // 14-12 delta(Hit #11)
- // 11 delta(Hit #12)
- // 10 <unused>
- // 9-5 Hit #13
- // 4-0 Hit #14
+ // 35 delta(Hit #0)
+ // 34 delta(Hit #1)
+ // 33 delta(Hit #2)
+ // 32 delta(Hit #3)
+ // 31 delta(Hit #4)
+ // 30-29 delta(Hit #5)
+ // 28 delta(Hit #6)
+ // 27-26 delta(Hit #7)
+ // 25-23 delta(Hit #8)
+ // 22-21 delta(Hit #9)
+ // 20-18 delta(Hit #10)
+ // 17-15 delta(Hit #11)
+ // 14 delta(Hit #12)
+ // 13-12 <unused>
+ // 11-6 Hit #13
+ // 5-0 Hit #14
// ----------------------
- // Add these 2 hits. The PL is currently in the ALMOST_FULL state. Adding the
- // first hit should keep the PL in ALMOST_FULL because the delta between Hit
- // #12 and Hit #13 (1 byte) can fit in the unused area (2 bytes). Adding the
- // second hit should tranisition to the FULL state because the delta between
- // Hit #13 and Hit #14 (2 bytes) is larger than the remaining unused area
- // (1 byte).
+ // Add these 2 hits.
+ // - The PL is currently in the ALMOST_FULL state. Adding the first hit should
+ // keep the PL in ALMOST_FULL because the delta between
+ // Hit #13 and Hit #14 (1 byte) can fit in the unused area (3 bytes).
+ // - Adding the second hit should transition to the FULL state because the
+ // delta between Hit #14 and Hit #15 (3 bytes) is larger than the remaining
+ // unused area (2 byte).
ICING_ASSERT_OK_AND_ASSIGN(
- num_could_fit, (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
- &pl_used, &hits_in[0], hits_in.size(), false)));
+ num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false)));
EXPECT_THAT(num_could_fit, Eq(hits_in.size()));
- EXPECT_THAT(size, Eq(serializer.GetBytesUsed(&pl_used)));
+ EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used)));
// All hits from hits_in were added.
std::transform(hits_in.rbegin(), hits_in.rend(),
std::front_inserter(hits_pushed), HitElt::get_hit);
@@ -431,9 +622,9 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayTooManyHits) {
static constexpr int kNumHits = 128;
static constexpr int kDeltaSize = 1;
- static constexpr int kTermFrequencySize = 1;
static constexpr size_t kHitsSize =
- ((kNumHits * (kDeltaSize + kTermFrequencySize)) / 5) * 5;
+ ((kNumHits - 2) * kDeltaSize + (2 * sizeof(Hit))) / sizeof(Hit) *
+ sizeof(Hit);
// Create an array with one too many hits
std::vector<Hit> hits_in_too_many =
@@ -442,18 +633,21 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayTooManyHits) {
for (const Hit &hit : hits_in_too_many) {
hit_elts_in_too_many.emplace_back(hit);
}
+ // Reverse so that hits are inserted in descending order
+ std::reverse(hit_elts_in_too_many.begin(), hit_elts_in_too_many.end());
+
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used,
PostingListUsed::CreateFromUnitializedRegion(
&serializer, serializer.GetMinPostingListSize()));
-
// PrependHitArray should fail because hit_elts_in_too_many is far too large
// for the minimum size pl.
ICING_ASSERT_OK_AND_ASSIGN(
uint32_t num_could_fit,
(serializer.PrependHitArray<HitElt, HitElt::get_hit>(
&pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(),
- false)));
+ /*keep_prepended=*/false)));
+ ASSERT_THAT(num_could_fit, Eq(2));
ASSERT_THAT(num_could_fit, Lt(hit_elts_in_too_many.size()));
ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0));
ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty()));
@@ -464,10 +658,11 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayTooManyHits) {
// PrependHitArray should fail because hit_elts_in_too_many is one hit too
// large for this pl.
ICING_ASSERT_OK_AND_ASSIGN(
- num_could_fit, (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
- &pl_used, &hit_elts_in_too_many[0],
- hit_elts_in_too_many.size(), false)));
- ASSERT_THAT(num_could_fit, Lt(hit_elts_in_too_many.size()));
+ num_could_fit,
+ (serializer.PrependHitArray<HitElt, HitElt::get_hit>(
+ &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(),
+ /*keep_prepended=*/false)));
+ ASSERT_THAT(num_could_fit, Eq(hit_elts_in_too_many.size() - 1));
ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0));
ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty()));
}
@@ -476,16 +671,37 @@ TEST(PostingListHitSerializerTest,
PostingListStatusJumpFromNotFullToFullAndBack) {
PostingListHitSerializer serializer;
+ // Size = 18
const uint32_t pl_size = 3 * sizeof(Hit);
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl,
PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
- ICING_ASSERT_OK(serializer.PrependHit(&pl, Hit(Hit::kInvalidValue - 1, 0)));
+
+ Hit max_valued_hit(kMaxSectionId, kMinDocumentId, Hit::kMaxTermFrequency,
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/true);
+ ICING_ASSERT_OK(serializer.PrependHit(&pl, max_valued_hit));
uint32_t bytes_used = serializer.GetBytesUsed(&pl);
+ ASSERT_THAT(bytes_used, sizeof(Hit::Value) + sizeof(Hit::Flags) +
+ sizeof(Hit::TermFrequency));
// Status not full.
ASSERT_THAT(bytes_used,
Le(pl_size - PostingListHitSerializer::kSpecialHitsSize));
- ICING_ASSERT_OK(serializer.PrependHit(&pl, Hit(Hit::kInvalidValue >> 2, 0)));
+
+ Hit min_valued_hit(kMinSectionId, kMaxDocumentId, Hit::kMaxTermFrequency,
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/true);
+ uint8_t delta_buf[VarInt::kMaxEncodedLen64];
+ size_t delta_len = PostingListHitSerializer::EncodeNextHitValue(
+ /*next_hit_value=*/min_valued_hit.value(),
+ /*curr_hit_value=*/max_valued_hit.value(), delta_buf);
+ // The compressed region available is pl_size - 2 * sizeof(specialHits) = 6
+ // We need to also fit max_valued_hit's flags and term-frequency fields, which
+ // each take 1 byte
+ // So we'll jump directly to FULL if the varint-encoded delta of the 2 hits >
+ // 6 - sizeof(Hit::Flags) - sizeof(Hit::TermFrequency) = 4
+ ASSERT_THAT(delta_len, Gt(4));
+ ICING_ASSERT_OK(serializer.PrependHit(
+ &pl, Hit(kMinSectionId, kMaxDocumentId, Hit::kMaxTermFrequency,
+ /*is_in_prefix_section=*/true, /*is_prefix_hit=*/true)));
// Status should jump to full directly.
ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(pl_size));
ICING_ASSERT_OK(serializer.PopFrontHits(&pl, 1));
@@ -501,11 +717,12 @@ TEST(PostingListHitSerializerTest, DeltaOverflow) {
PostingListUsed pl,
PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ static const Hit::Value kMaxHitValue = std::numeric_limits<Hit::Value>::max();
static const Hit::Value kOverflow[4] = {
- Hit::kInvalidValue >> 2,
- (Hit::kInvalidValue >> 2) * 2,
- (Hit::kInvalidValue >> 2) * 3,
- Hit::kInvalidValue - 1,
+ kMaxHitValue >> 2,
+ (kMaxHitValue >> 2) * 2,
+ (kMaxHitValue >> 2) * 3,
+ kMaxHitValue - 1,
};
// Fit at least 4 ordinary values.
@@ -516,22 +733,245 @@ TEST(PostingListHitSerializerTest, DeltaOverflow) {
// Cannot fit 4 overflow values.
ICING_ASSERT_OK_AND_ASSIGN(
pl, PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
- ICING_EXPECT_OK(serializer.PrependHit(&pl, Hit(kOverflow[3])));
- ICING_EXPECT_OK(serializer.PrependHit(&pl, Hit(kOverflow[2])));
+ Hit::Flags has_term_frequency_flags = 0b1;
+ ICING_EXPECT_OK(serializer.PrependHit(
+ &pl, Hit(/*value=*/kOverflow[3], has_term_frequency_flags,
+ /*term_frequency=*/8)));
+ ICING_EXPECT_OK(serializer.PrependHit(
+ &pl, Hit(/*value=*/kOverflow[2], has_term_frequency_flags,
+ /*term_frequency=*/8)));
// Can fit only one more.
- ICING_EXPECT_OK(serializer.PrependHit(&pl, Hit(kOverflow[1])));
- EXPECT_THAT(serializer.PrependHit(&pl, Hit(kOverflow[0])),
+ ICING_EXPECT_OK(serializer.PrependHit(
+ &pl, Hit(/*value=*/kOverflow[1], has_term_frequency_flags,
+ /*term_frequency=*/8)));
+ EXPECT_THAT(serializer.PrependHit(
+ &pl, Hit(/*value=*/kOverflow[0], has_term_frequency_flags,
+ /*term_frequency=*/8)),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST(PostingListHitSerializerTest, GetMinPostingListToFitForNotFullPL) {
+ PostingListHitSerializer serializer;
+
+ // Size = 24
+ int pl_size = 2 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ // Create and add some hits to make pl_used NOT_FULL
+ std::vector<Hit> hits_in =
+ CreateHits(/*num_hits=*/7, /*desired_byte_length=*/1);
+ for (const Hit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ }
+ // ----------------------
+ // 23 delta(Hit #0)
+ // 22 delta(Hit #1)
+ // 21 delta(Hit #2)
+ // 20 delta(Hit #3)
+ // 19 delta(Hit #4)
+ // 18 delta(Hit #5)
+ // 17-14 Hit #6
+ // 13-12 <unused>
+ // 11-6 kSpecialHit
+ // 5-0 Offset=14
+ // ----------------------
+ int bytes_used = 10;
+
+ // Check that all hits have been inserted
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
+ std::deque<Hit> hits_pushed(hits_in.rbegin(), hits_in.rend());
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // Get the min size to fit for the hits in pl_used. Moving the hits in pl_used
+ // into a posting list with this min size should make it ALMOST_FULL, which we
+ // can see should have size = 18.
+ // ----------------------
+ // 17 delta(Hit #0)
+ // 16 delta(Hit #1)
+ // 15 delta(Hit #2)
+ // 14 delta(Hit #3)
+ // 13 delta(Hit #4)
+ // 12 delta(Hit #5)
+ // 11-6 Hit #6
+ // 5-0 kSpecialHit
+ // ----------------------
+ int expected_min_size = 18;
+ uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used);
+ EXPECT_THAT(min_size_to_fit, Eq(expected_min_size));
+
+ // Also check that this min size to fit posting list actually does fit all the
+ // hits and can only hit one more hit in the ALMOST_FULL state.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListUsed min_size_to_fit_pl,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, min_size_to_fit));
+ for (const Hit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit));
+ }
+
+ // Adding another hit to the min-size-to-fit posting list should succeed
+ Hit hit = CreateHit(hits_in.back(), /*desired_byte_length=*/1);
+ ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit));
+ // Adding any other hits should fail with RESOURCE_EXHAUSTED error.
+ EXPECT_THAT(serializer.PrependHit(&min_size_to_fit_pl,
+ CreateHit(hit, /*desired_byte_length=*/1)),
StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+
+ // Check that all hits have been inserted and the min-fit posting list is now
+ // FULL.
+ EXPECT_THAT(serializer.GetBytesUsed(&min_size_to_fit_pl),
+ Eq(min_size_to_fit));
+ hits_pushed.emplace_front(hit);
+ EXPECT_THAT(serializer.GetHits(&min_size_to_fit_pl),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+}
+
+TEST(PostingListHitSerializerTest, GetMinPostingListToFitForTwoHits) {
+ PostingListHitSerializer serializer;
+
+ // Size = 36
+ int pl_size = 3 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+
+ // Create and add 2 hits
+ Hit first_hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/5,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+ std::vector<Hit> hits_in =
+ CreateHits(first_hit, /*num_hits=*/2, /*desired_byte_length=*/4);
+ for (const Hit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ }
+ // ----------------------
+ // 35 term-frequency(Hit #0)
+ // 34 flags(Hit #0)
+ // 33-30 delta(Hit #0)
+ // 29 term-frequency(Hit #1)
+ // 28 flags(Hit #1)
+ // 27-24 Hit #1
+ // 23-12 <unused>
+ // 11-6 kSpecialHit
+ // 5-0 Offset=24
+ // ----------------------
+ int bytes_used = 12;
+
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
+ std::deque<Hit> hits_pushed(hits_in.rbegin(), hits_in.rend());
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // GetMinPostingListSizeToFit should return min posting list size.
+ EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used),
+ Eq(serializer.GetMinPostingListSize()));
+}
+
+TEST(PostingListHitSerializerTest, GetMinPostingListToFitForThreeSmallHits) {
+ PostingListHitSerializer serializer;
+
+ // Size = 24
+ int pl_size = 2 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ // Create and add 3 small hits that fit in the size range where we should be
+ // checking for whether the PL has only 2 hits
+ std::vector<Hit> hits_in =
+ CreateHits(/*num_hits=*/3, /*desired_byte_length=*/1);
+ for (const Hit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ }
+ // ----------------------
+ // 23 delta(Hit #0)
+ // 22 delta(Hit #1)
+ // 21-18 Hit #2
+ // 17-12 <unused>
+ // 11-6 kSpecialHit
+ // 5-0 Offset=18
+ // ----------------------
+ int bytes_used = 6;
+
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
+ std::deque<Hit> hits_pushed(hits_in.rbegin(), hits_in.rend());
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // Get the min size to fit for the hits in pl_used. Moving the hits in pl_used
+ // into a posting list with this min size should make it ALMOST_FULL, which we
+ // can see should have size = 14. This should not return the min posting list
+ // size.
+ // ----------------------
+ // 13 delta(Hit #0)
+ // 12 delta(Hit #1)
+ // 11-6 Hit #2
+ // 5-0 kSpecialHit
+ // ----------------------
+ int expected_min_size = 14;
+
+ EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used),
+ Gt(serializer.GetMinPostingListSize()));
+ EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used),
+ Eq(expected_min_size));
+}
+
+TEST(PostingListHitSerializerTest,
+ GetMinPostingListToFitForAlmostFullAndFullPLReturnsSameSize) {
+ PostingListHitSerializer serializer;
+
+ // Size = 24
+ int pl_size = 2 * serializer.GetMinPostingListSize();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+ // Create and add some hits to make pl_used ALMOST_FULL
+ std::vector<Hit> hits_in =
+ CreateHits(/*num_hits=*/7, /*desired_byte_length=*/2);
+ for (const Hit &hit : hits_in) {
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ }
+ // ----------------------
+ // 23-22 delta(Hit #0)
+ // 21-20 delta(Hit #1)
+ // 19-18 delta(Hit #2)
+ // 17-16 delta(Hit #3)
+ // 15-14 delta(Hit #4)
+ // 13-12 delta(Hit #5)
+ // 11-6 Hit #6
+ // 5-0 kSpecialHit
+ // ----------------------
+ int bytes_used = 18;
+
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
+ std::deque<Hit> hits_pushed(hits_in.rbegin(), hits_in.rend());
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // GetMinPostingListSizeToFit should return the same size as pl_used.
+ uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used);
+ EXPECT_THAT(min_size_to_fit, Eq(pl_size));
+
+ // Add another hit to make the posting list FULL
+ Hit hit = CreateHit(hits_in.back(), /*desired_byte_length=*/1);
+ ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size));
+ hits_pushed.emplace_front(hit);
+ EXPECT_THAT(serializer.GetHits(&pl_used),
+ IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+ // GetMinPostingListSizeToFit should still be the same size as pl_used.
+ min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used);
+ EXPECT_THAT(min_size_to_fit, Eq(pl_size));
}
TEST(PostingListHitSerializerTest, MoveFrom) {
PostingListHitSerializer serializer;
- int size = 3 * serializer.GetMinPostingListSize();
+ int pl_size = 3 * serializer.GetMinPostingListSize();
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used1,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
std::vector<Hit> hits1 =
CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1);
for (const Hit &hit : hits1) {
@@ -540,7 +980,7 @@ TEST(PostingListHitSerializerTest, MoveFrom) {
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used2,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
std::vector<Hit> hits2 =
CreateHits(/*num_hits=*/5, /*desired_byte_length=*/2);
for (const Hit &hit : hits2) {
@@ -556,10 +996,10 @@ TEST(PostingListHitSerializerTest, MoveFrom) {
TEST(PostingListHitSerializerTest, MoveFromNullArgumentReturnsInvalidArgument) {
PostingListHitSerializer serializer;
- int size = 3 * serializer.GetMinPostingListSize();
+ int pl_size = 3 * serializer.GetMinPostingListSize();
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used1,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
std::vector<Hit> hits = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1);
for (const Hit &hit : hits) {
ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
@@ -575,10 +1015,10 @@ TEST(PostingListHitSerializerTest,
MoveFromInvalidPostingListReturnsInvalidArgument) {
PostingListHitSerializer serializer;
- int size = 3 * serializer.GetMinPostingListSize();
+ int pl_size = 3 * serializer.GetMinPostingListSize();
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used1,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
std::vector<Hit> hits1 =
CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1);
for (const Hit &hit : hits1) {
@@ -587,7 +1027,7 @@ TEST(PostingListHitSerializerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used2,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
std::vector<Hit> hits2 =
CreateHits(/*num_hits=*/5, /*desired_byte_length=*/2);
for (const Hit &hit : hits2) {
@@ -595,7 +1035,7 @@ TEST(PostingListHitSerializerTest,
}
// Write invalid hits to the beginning of pl_used1 to make it invalid.
- Hit invalid_hit;
+ Hit invalid_hit(Hit::kInvalidValue);
Hit *first_hit = reinterpret_cast<Hit *>(pl_used1.posting_list_buffer());
*first_hit = invalid_hit;
++first_hit;
@@ -610,10 +1050,10 @@ TEST(PostingListHitSerializerTest,
MoveToInvalidPostingListReturnsFailedPrecondition) {
PostingListHitSerializer serializer;
- int size = 3 * serializer.GetMinPostingListSize();
+ int pl_size = 3 * serializer.GetMinPostingListSize();
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used1,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
std::vector<Hit> hits1 =
CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1);
for (const Hit &hit : hits1) {
@@ -622,7 +1062,7 @@ TEST(PostingListHitSerializerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used2,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
std::vector<Hit> hits2 =
CreateHits(/*num_hits=*/5, /*desired_byte_length=*/2);
for (const Hit &hit : hits2) {
@@ -630,7 +1070,7 @@ TEST(PostingListHitSerializerTest,
}
// Write invalid hits to the beginning of pl_used2 to make it invalid.
- Hit invalid_hit;
+ Hit invalid_hit(Hit::kInvalidValue);
Hit *first_hit = reinterpret_cast<Hit *>(pl_used2.posting_list_buffer());
*first_hit = invalid_hit;
++first_hit;
@@ -644,10 +1084,10 @@ TEST(PostingListHitSerializerTest,
TEST(PostingListHitSerializerTest, MoveToPostingListTooSmall) {
PostingListHitSerializer serializer;
- int size = 3 * serializer.GetMinPostingListSize();
+ int pl_size = 3 * serializer.GetMinPostingListSize();
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used1,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
std::vector<Hit> hits1 =
CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1);
for (const Hit &hit : hits1) {
@@ -672,30 +1112,36 @@ TEST(PostingListHitSerializerTest, MoveToPostingListTooSmall) {
IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend())));
}
-TEST(PostingListHitSerializerTest, PopHitsWithScores) {
+TEST(PostingListHitSerializerTest, PopHitsWithTermFrequenciesAndFlags) {
PostingListHitSerializer serializer;
- int size = 2 * serializer.GetMinPostingListSize();
+ // Size = 24
+ int pl_size = 2 * serializer.GetMinPostingListSize();
ICING_ASSERT_OK_AND_ASSIGN(
PostingListUsed pl_used,
- PostingListUsed::CreateFromUnitializedRegion(&serializer, size));
+ PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
- // This posting list is 20-bytes. Create four hits that will have deltas of
- // two bytes each and all of whom will have a non-default score. This posting
- // list will be almost_full.
+ // This posting list is 24-bytes. Create four hits that will have deltas of
+ // two bytes each and all of whom will have a non-default term-frequency. This
+ // posting list will be almost_full.
//
// ----------------------
- // 19 score(Hit #0)
- // 18-17 delta(Hit #0)
- // 16 score(Hit #1)
- // 15-14 delta(Hit #1)
- // 13 score(Hit #2)
- // 12-11 delta(Hit #2)
- // 10 <unused>
- // 9-5 Hit #3
- // 4-0 kInvalidHitVal
+ // 23 term-frequency(Hit #0)
+ // 22 flags(Hit #0)
+ // 21-20 delta(Hit #0)
+ // 19 term-frequency(Hit #1)
+ // 18 flags(Hit #1)
+ // 17-16 delta(Hit #1)
+ // 15 term-frequency(Hit #2)
+ // 14 flags(Hit #2)
+ // 13-12 delta(Hit #2)
+ // 11-6 Hit #3
+ // 5-0 kInvalidHit
// ----------------------
- Hit hit0(/*section_id=*/0, /*document_id=*/0, /*score=*/5);
+ int bytes_used = 18;
+
+ Hit hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/5,
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
Hit hit1 = CreateHit(hit0, /*desired_byte_length=*/2);
Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/2);
Hit hit3 = CreateHit(hit2, /*desired_byte_length=*/2);
@@ -707,22 +1153,26 @@ TEST(PostingListHitSerializerTest, PopHitsWithScores) {
ICING_ASSERT_OK_AND_ASSIGN(std::vector<Hit> hits_out,
serializer.GetHits(&pl_used));
EXPECT_THAT(hits_out, ElementsAre(hit3, hit2, hit1, hit0));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
// Now, pop the last hit. The posting list should contain the first three
// hits.
//
// ----------------------
- // 19 score(Hit #0)
- // 18-17 delta(Hit #0)
- // 16 score(Hit #1)
- // 15-14 delta(Hit #1)
- // 13-10 <unused>
- // 9-5 Hit #2
- // 4-0 kInvalidHitVal
+ // 23 term-frequency(Hit #0)
+ // 22 flags(Hit #0)
+ // 21-20 delta(Hit #0)
+ // 19 term-frequency(Hit #1)
+ // 18 flags(Hit #1)
+ // 17-16 delta(Hit #1)
+ // 15-12 <unused>
+ // 11-6 Hit #2
+ // 5-0 kInvalidHit
// ----------------------
ICING_ASSERT_OK(serializer.PopFrontHits(&pl_used, 1));
ICING_ASSERT_OK_AND_ASSIGN(hits_out, serializer.GetHits(&pl_used));
EXPECT_THAT(hits_out, ElementsAre(hit2, hit1, hit0));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
}
} // namespace
diff --git a/icing/query/advanced_query_parser/function.cc b/icing/query/advanced_query_parser/function.cc
index e7938db..0ff160b 100644
--- a/icing/query/advanced_query_parser/function.cc
+++ b/icing/query/advanced_query_parser/function.cc
@@ -13,8 +13,15 @@
// limitations under the License.
#include "icing/query/advanced_query_parser/function.h"
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/query/advanced_query_parser/param.h"
+#include "icing/query/advanced_query_parser/pending-value.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -73,5 +80,20 @@ libtextclassifier3::StatusOr<PendingValue> Function::Eval(
return eval_(std::move(args));
}
+libtextclassifier3::StatusOr<DataType> Function::get_param_type(int i) const {
+ if (i < 0 || params_.empty()) {
+ return absl_ports::OutOfRangeError("Invalid argument index.");
+ }
+ const Param* parm;
+ if (i < params_.size()) {
+ parm = &params_.at(i);
+ } else if (params_.back().cardinality == Cardinality::kVariable) {
+ parm = &params_.back();
+ } else {
+ return absl_ports::OutOfRangeError("Invalid argument index.");
+ }
+ return parm->data_type;
+}
+
} // namespace lib
-} // namespace icing \ No newline at end of file
+} // namespace icing
diff --git a/icing/query/advanced_query_parser/function.h b/icing/query/advanced_query_parser/function.h
index 3514878..08cc7e8 100644
--- a/icing/query/advanced_query_parser/function.h
+++ b/icing/query/advanced_query_parser/function.h
@@ -46,6 +46,8 @@ class Function {
libtextclassifier3::StatusOr<PendingValue> Eval(
std::vector<PendingValue>&& args) const;
+ libtextclassifier3::StatusOr<DataType> get_param_type(int i) const;
+
private:
Function(DataType return_type, std::string name, std::vector<Param> params,
EvalFunction eval)
diff --git a/icing/query/advanced_query_parser/param.h b/icing/query/advanced_query_parser/param.h
index 69c46be..9ea1915 100644
--- a/icing/query/advanced_query_parser/param.h
+++ b/icing/query/advanced_query_parser/param.h
@@ -35,13 +35,19 @@ struct Param {
libtextclassifier3::Status Matches(PendingValue& arg) const {
bool matches = arg.data_type() == data_type;
- // Values of type kText could also potentially be valid kLong values. If
- // we're expecting a kLong and we have a kText, try to parse it as a kLong.
+ // Values of type kText could also potentially be valid kLong or kDouble
+ // values. If we're expecting a kLong or kDouble and we have a kText, try to
+ // parse it as what we expect.
if (!matches && data_type == DataType::kLong &&
arg.data_type() == DataType::kText) {
ICING_RETURN_IF_ERROR(arg.ParseInt());
matches = true;
}
+ if (!matches && data_type == DataType::kDouble &&
+ arg.data_type() == DataType::kText) {
+ ICING_RETURN_IF_ERROR(arg.ParseDouble());
+ matches = true;
+ }
return matches ? libtextclassifier3::Status::OK
: absl_ports::InvalidArgumentError(
"Provided arg doesn't match required param type.");
diff --git a/icing/query/advanced_query_parser/pending-value.cc b/icing/query/advanced_query_parser/pending-value.cc
index 67bdc3a..a3f95d9 100644
--- a/icing/query/advanced_query_parser/pending-value.cc
+++ b/icing/query/advanced_query_parser/pending-value.cc
@@ -13,7 +13,11 @@
// limitations under the License.
#include "icing/query/advanced_query_parser/pending-value.h"
+#include <cstdlib>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
namespace icing {
namespace lib {
@@ -40,5 +44,27 @@ libtextclassifier3::Status PendingValue::ParseInt() {
return libtextclassifier3::Status::OK;
}
+libtextclassifier3::Status PendingValue::ParseDouble() {
+ if (data_type_ == DataType::kDouble) {
+ return libtextclassifier3::Status::OK;
+ } else if (data_type_ != DataType::kText) {
+ return absl_ports::InvalidArgumentError("Cannot parse value as double");
+ }
+ if (query_term_.is_prefix_val) {
+ return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+ "Cannot use prefix operator '*' with numeric value: ",
+ query_term_.term));
+ }
+ char* value_end;
+ double_val_ = std::strtod(query_term_.term.c_str(), &value_end);
+ if (value_end != query_term_.term.c_str() + query_term_.term.length()) {
+ return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+ "Unable to parse \"", query_term_.term, "\" as double."));
+ }
+ data_type_ = DataType::kDouble;
+ query_term_ = {/*term=*/"", /*raw_term=*/"", /*is_prefix_val=*/false};
+ return libtextclassifier3::Status::OK;
+}
+
} // namespace lib
} // namespace icing
diff --git a/icing/query/advanced_query_parser/pending-value.h b/icing/query/advanced_query_parser/pending-value.h
index 1a6717e..34912f3 100644
--- a/icing/query/advanced_query_parser/pending-value.h
+++ b/icing/query/advanced_query_parser/pending-value.h
@@ -14,12 +14,16 @@
#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_PENDING_VALUE_H_
#define ICING_QUERY_ADVANCED_QUERY_PARSER_PENDING_VALUE_H_
+#include <cstdint>
#include <memory>
#include <string>
+#include <string_view>
#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/util/status-macros.h"
@@ -30,10 +34,14 @@ namespace lib {
enum class DataType {
kNone,
kLong,
+ kDouble,
kText,
kString,
kStringList,
kDocumentIterator,
+ // TODO(b/326656531): Instead of creating a vector index type, consider
+ // changing it to vector type so that the data is the vector directly.
+ kVectorIndex,
};
struct QueryTerm {
@@ -52,6 +60,10 @@ struct PendingValue {
return PendingValue(std::move(text), DataType::kText);
}
+ static PendingValue CreateVectorIndexPendingValue(int64_t vector_index) {
+ return PendingValue(vector_index, DataType::kVectorIndex);
+ }
+
PendingValue() : data_type_(DataType::kNone) {}
explicit PendingValue(std::unique_ptr<DocHitInfoIterator> iterator)
@@ -111,6 +123,16 @@ struct PendingValue {
return long_val_;
}
+ libtextclassifier3::StatusOr<double> double_val() {
+ ICING_RETURN_IF_ERROR(ParseDouble());
+ return double_val_;
+ }
+
+ libtextclassifier3::StatusOr<int64_t> vector_index_val() const {
+ ICING_RETURN_IF_ERROR(CheckDataType(DataType::kVectorIndex));
+ return long_val_;
+ }
+
// Attempts to interpret the value as an int. A pending value can be parsed as
// an int under two circumstances:
// 1. It holds a kText value which can be parsed to an int
@@ -122,12 +144,26 @@ struct PendingValue {
// - INVALID_ARGUMENT if the value could not be parsed as a long
libtextclassifier3::Status ParseInt();
+ // Attempts to interpret the value as a double. A pending value can be parsed
+ // as a double under two circumstances:
+ // 1. It holds a kText value which can be parsed to a double
+ // 2. It holds a kDouble value
+ // If #1 is true, then the parsed value will be stored in double_val_ and
+ // data_type will be updated to kDouble.
+ // RETURNS:
+ // - OK, if able to successfully parse the value into a double
+ // - INVALID_ARGUMENT if the value could not be parsed as a double
+ libtextclassifier3::Status ParseDouble();
+
DataType data_type() const { return data_type_; }
private:
explicit PendingValue(QueryTerm query_term, DataType data_type)
: query_term_(std::move(query_term)), data_type_(data_type) {}
+ explicit PendingValue(int64_t long_val, DataType data_type)
+ : long_val_(long_val), data_type_(data_type) {}
+
libtextclassifier3::Status CheckDataType(DataType required_data_type) const {
if (data_type_ == required_data_type) {
return libtextclassifier3::Status::OK;
@@ -151,6 +187,9 @@ struct PendingValue {
// long_val_ will be populated when data_type_ is kLong - after a successful
// call to ParseInt.
int64_t long_val_;
+ // double_val_ will be populated when data_type_ is kDouble - after a
+ // successful call to ParseDouble.
+ double double_val_;
DataType data_type_;
};
diff --git a/icing/query/advanced_query_parser/query-visitor.cc b/icing/query/advanced_query_parser/query-visitor.cc
index 31da959..1ac52c5 100644
--- a/icing/query/advanced_query_parser/query-visitor.cc
+++ b/icing/query/advanced_query_parser/query-visitor.cc
@@ -16,20 +16,26 @@
#include <algorithm>
#include <cstdint>
-#include <cstdlib>
#include <iterator>
#include <limits>
#include <memory>
#include <set>
#include <string>
+#include <string_view>
+#include <unordered_map>
#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/absl_ports/str_join.h"
+#include "icing/index/embed/doc-hit-info-iterator-embedding.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/iterator/doc-hit-info-iterator-all-document-id.h"
#include "icing/index/iterator/doc-hit-info-iterator-and.h"
+#include "icing/index/iterator/doc-hit-info-iterator-filter.h"
#include "icing/index/iterator/doc-hit-info-iterator-none.h"
#include "icing/index/iterator/doc-hit-info-iterator-not.h"
#include "icing/index/iterator/doc-hit-info-iterator-or.h"
@@ -37,17 +43,23 @@
#include "icing/index/iterator/doc-hit-info-iterator-property-in-schema.h"
#include "icing/index/iterator/doc-hit-info-iterator-section-restrict.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/iterator/section-restrict-data.h"
#include "icing/index/property-existence-indexing-handler.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/query/advanced_query_parser/function.h"
#include "icing/query/advanced_query_parser/lexer.h"
#include "icing/query/advanced_query_parser/param.h"
#include "icing/query/advanced_query_parser/parser.h"
#include "icing/query/advanced_query_parser/pending-value.h"
#include "icing/query/advanced_query_parser/util/string-util.h"
#include "icing/query/query-features.h"
+#include "icing/query/query-results.h"
#include "icing/schema/property-util.h"
+#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
#include "icing/tokenization/token.h"
#include "icing/tokenization/tokenizer.h"
+#include "icing/util/embedding-util.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -241,6 +253,34 @@ void QueryVisitor::RegisterFunctions() {
.ValueOrDie();
registered_functions_.insert(
{has_property_function.name(), std::move(has_property_function)});
+
+ // vector_index getSearchSpecEmbedding(long);
+ auto get_search_spec_embedding = [](std::vector<PendingValue>&& args) {
+ return PendingValue::CreateVectorIndexPendingValue(
+ args.at(0).long_val().ValueOrDie());
+ };
+ Function get_search_spec_embedding_function =
+ Function::Create(DataType::kVectorIndex, "getSearchSpecEmbedding",
+ {Param(DataType::kLong)},
+ std::move(get_search_spec_embedding))
+ .ValueOrDie();
+ registered_functions_.insert({get_search_spec_embedding_function.name(),
+ std::move(get_search_spec_embedding_function)});
+
+ // DocHitInfoIterator semanticSearch(vector_index, double, double, string);
+ auto semantic_search = [this](std::vector<PendingValue>&& args) {
+ return this->SemanticSearchFunction(std::move(args));
+ };
+ Function semantic_search_function =
+ Function::Create(DataType::kDocumentIterator, "semanticSearch",
+ {Param(DataType::kVectorIndex),
+ Param(DataType::kDouble, Cardinality::kOptional),
+ Param(DataType::kDouble, Cardinality::kOptional),
+ Param(DataType::kString, Cardinality::kOptional)},
+ std::move(semantic_search))
+ .ValueOrDie();
+ registered_functions_.insert(
+ {semantic_search_function.name(), std::move(semantic_search_function)});
}
libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction(
@@ -278,10 +318,11 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction(
document_store_.last_added_document_id());
} else {
QueryVisitor query_visitor(
- &index_, &numeric_index_, &document_store_, &schema_store_,
- &normalizer_, &tokenizer_, query->raw_term, filter_options_,
- match_type_, needs_term_frequency_info_, pending_property_restricts_,
- processing_not_, current_time_ms_);
+ &index_, &numeric_index_, &embedding_index_, &document_store_,
+ &schema_store_, &normalizer_, &tokenizer_, query->raw_term,
+ embedding_query_vectors_, filter_options_, match_type_,
+ embedding_query_metric_type_, needs_term_frequency_info_,
+ pending_property_restricts_, processing_not_, current_time_ms_);
tree_root->Accept(&query_visitor);
ICING_ASSIGN_OR_RETURN(query_result,
std::move(query_visitor).ConsumeResults());
@@ -359,6 +400,57 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::HasPropertyFunction(
return PendingValue(std::move(property_in_document_iterator));
}
+libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SemanticSearchFunction(
+ std::vector<PendingValue>&& args) {
+ features_.insert(kEmbeddingSearchFeature);
+
+ int64_t vector_index = args.at(0).vector_index_val().ValueOrDie();
+ if (embedding_query_vectors_ == nullptr || vector_index < 0 ||
+ vector_index >= embedding_query_vectors_->size()) {
+ return absl_ports::InvalidArgumentError("Got invalid vector search index!");
+ }
+
+ // Handle default values for the optional arguments.
+ double low = -std::numeric_limits<double>::infinity();
+ double high = std::numeric_limits<double>::infinity();
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type =
+ embedding_query_metric_type_;
+ if (args.size() >= 2) {
+ low = args.at(1).double_val().ValueOrDie();
+ }
+ if (args.size() >= 3) {
+ high = args.at(2).double_val().ValueOrDie();
+ }
+ if (args.size() >= 4) {
+ const std::string& metric = args.at(3).string_val().ValueOrDie()->term;
+ ICING_ASSIGN_OR_RETURN(
+ metric_type,
+ embedding_util::GetEmbeddingQueryMetricTypeFromName(metric));
+ }
+
+ // Create SectionRestrictData for section restriction.
+ std::unique_ptr<SectionRestrictData> section_restrict_data = nullptr;
+ if (pending_property_restricts_.has_active_property_restricts()) {
+ std::unordered_map<std::string, std::set<std::string>>
+ type_property_filters;
+ type_property_filters[std::string(SchemaStore::kSchemaTypeWildcard)] =
+ pending_property_restricts_.active_property_restricts();
+ section_restrict_data = std::make_unique<SectionRestrictData>(
+ &document_store_, &schema_store_, current_time_ms_,
+ type_property_filters);
+ }
+
+ // Create and return iterator.
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map =
+ &embedding_query_results_.result_scores[vector_index][metric_type];
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<DocHitInfoIterator> iterator,
+ DocHitInfoIteratorEmbedding::Create(
+ &embedding_query_vectors_->at(vector_index),
+ std::move(section_restrict_data), metric_type, low,
+ high, score_map, &embedding_index_));
+ return PendingValue(std::move(iterator));
+}
+
libtextclassifier3::StatusOr<int64_t> QueryVisitor::PopPendingIntValue() {
if (pending_values_.empty()) {
return absl_ports::InvalidArgumentError("Unable to retrieve int value.");
@@ -435,8 +527,8 @@ QueryVisitor::PopPendingIterator() {
// raw_text, then all of raw_text must correspond to this token.
raw_token = raw_text;
} else {
- ICING_ASSIGN_OR_RETURN(raw_token, string_util::FindEscapedToken(
- raw_text, token.text));
+ ICING_ASSIGN_OR_RETURN(
+ raw_token, string_util::FindEscapedToken(raw_text, token.text));
}
normalized_term = normalizer_.NormalizeTerm(token.text);
QueryTerm term_value{std::move(normalized_term), raw_token,
@@ -570,15 +662,14 @@ libtextclassifier3::Status QueryVisitor::ProcessNegationOperator(
"Visit unary operator child didn't correctly add pending values.");
}
- // 3. We want to preserve the original text of the integer value, append our
- // minus and *then* parse as an int.
- ICING_ASSIGN_OR_RETURN(QueryTerm int_text_val, PopPendingTextValue());
- int_text_val.term = absl_ports::StrCat("-", int_text_val.term);
+ // 3. We want to preserve the original text of the numeric value, append our
+ // minus to the text. It will be parsed as either an int or a double later.
+ ICING_ASSIGN_OR_RETURN(QueryTerm numeric_text_val, PopPendingTextValue());
+ numeric_text_val.term = absl_ports::StrCat("-", numeric_text_val.term);
PendingValue pending_value =
- PendingValue::CreateTextPendingValue(std::move(int_text_val));
- ICING_RETURN_IF_ERROR(pending_value.long_val());
+ PendingValue::CreateTextPendingValue(std::move(numeric_text_val));
- // We've parsed our integer value successfully. Pop our placeholder, push it
+ // We've parsed our numeric value successfully. Pop our placeholder, push it
// on to the stack and return successfully.
if (!pending_values_.top().is_placeholder()) {
return absl_ports::InvalidArgumentError(
@@ -768,7 +859,8 @@ void QueryVisitor::VisitMember(const MemberNode* node) {
end = text_val.raw_term.data() + text_val.raw_term.length();
} else {
start = std::min(start, text_val.raw_term.data());
- end = std::max(end, text_val.raw_term.data() + text_val.raw_term.length());
+ end = std::max(end,
+ text_val.raw_term.data() + text_val.raw_term.length());
}
members.push_back(std::move(text_val.term));
}
@@ -800,13 +892,26 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) {
"Function ", node->function_name()->value(), " is not supported."));
return;
}
+ const Function& function = itr->second;
// 2. Put in a placeholder PendingValue
pending_values_.push(PendingValue());
// 3. Visit the children.
- for (const std::unique_ptr<Node>& arg : node->args()) {
+ expecting_numeric_arg_ = true;
+ for (int i = 0; i < node->args().size(); ++i) {
+ const std::unique_ptr<Node>& arg = node->args()[i];
+ libtextclassifier3::StatusOr<DataType> arg_type_or =
+ function.get_param_type(i);
+ bool current_level_expecting_numeric_arg = expecting_numeric_arg_;
+ // If arg_type_or has an error, we should ignore it for now, since
+ // function.Eval should do the type check and return better error messages.
+ if (arg_type_or.ok() && (arg_type_or.ValueOrDie() == DataType::kLong ||
+ arg_type_or.ValueOrDie() == DataType::kDouble)) {
+ expecting_numeric_arg_ = true;
+ }
arg->Accept(this);
+ expecting_numeric_arg_ = current_level_expecting_numeric_arg;
if (has_pending_error()) {
return;
}
@@ -819,7 +924,6 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) {
pending_values_.pop();
}
std::reverse(args.begin(), args.end());
- const Function& function = itr->second;
auto eval_result = function.Eval(std::move(args));
if (!eval_result.ok()) {
pending_error_ = std::move(eval_result).status();
@@ -955,6 +1059,7 @@ libtextclassifier3::StatusOr<QueryResults> QueryVisitor::ConsumeResults() && {
results.root_iterator = std::move(iterator_or).ValueOrDie();
results.query_term_iterators = std::move(query_term_iterators_);
results.query_terms = std::move(property_query_terms_map_);
+ results.embedding_query_results = std::move(embedding_query_results_);
results.features_in_use = std::move(features_);
return results;
}
diff --git a/icing/query/advanced_query_parser/query-visitor.h b/icing/query/advanced_query_parser/query-visitor.h
index d090b3c..17149f5 100644
--- a/icing/query/advanced_query_parser/query-visitor.h
+++ b/icing/query/advanced_query_parser/query-visitor.h
@@ -17,13 +17,19 @@
#include <cstdint>
#include <memory>
+#include <set>
#include <stack>
#include <string>
+#include <string_view>
+#include <unordered_map>
#include <unordered_set>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-filter.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
@@ -33,10 +39,12 @@
#include "icing/query/advanced_query_parser/pending-value.h"
#include "icing/query/query-features.h"
#include "icing/query/query-results.h"
+#include "icing/query/query-terms.h"
#include "icing/schema/schema-store.h"
#include "icing/store/document-store.h"
#include "icing/tokenization/tokenizer.h"
#include "icing/transform/normalizer.h"
+#include <google/protobuf/repeated_field.h>
namespace icing {
namespace lib {
@@ -45,19 +53,23 @@ namespace lib {
// the parser.
class QueryVisitor : public AbstractSyntaxTreeVisitor {
public:
- explicit QueryVisitor(Index* index,
- const NumericIndex<int64_t>* numeric_index,
- const DocumentStore* document_store,
- const SchemaStore* schema_store,
- const Normalizer* normalizer,
- const Tokenizer* tokenizer,
- std::string_view raw_query_text,
- DocHitInfoIteratorFilter::Options filter_options,
- TermMatchType::Code match_type,
- bool needs_term_frequency_info, int64_t current_time_ms)
- : QueryVisitor(index, numeric_index, document_store, schema_store,
- normalizer, tokenizer, raw_query_text, filter_options,
- match_type, needs_term_frequency_info,
+ explicit QueryVisitor(
+ Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
+ const DocumentStore* document_store, const SchemaStore* schema_store,
+ const Normalizer* normalizer, const Tokenizer* tokenizer,
+ std::string_view raw_query_text,
+ const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>*
+ embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options filter_options,
+ TermMatchType::Code match_type,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ embedding_query_metric_type,
+ bool needs_term_frequency_info, int64_t current_time_ms)
+ : QueryVisitor(index, numeric_index, embedding_index, document_store,
+ schema_store, normalizer, tokenizer, raw_query_text,
+ embedding_query_vectors, filter_options, match_type,
+ embedding_query_metric_type, needs_term_frequency_info,
PendingPropertyRestricts(),
/*processing_not=*/false, current_time_ms) {}
@@ -106,22 +118,31 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor {
explicit QueryVisitor(
Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const DocumentStore* document_store, const SchemaStore* schema_store,
const Normalizer* normalizer, const Tokenizer* tokenizer,
std::string_view raw_query_text,
+ const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>*
+ embedding_query_vectors,
DocHitInfoIteratorFilter::Options filter_options,
- TermMatchType::Code match_type, bool needs_term_frequency_info,
+ TermMatchType::Code match_type,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ embedding_query_metric_type,
+ bool needs_term_frequency_info,
PendingPropertyRestricts pending_property_restricts, bool processing_not,
int64_t current_time_ms)
: index_(*index),
numeric_index_(*numeric_index),
+ embedding_index_(*embedding_index),
document_store_(*document_store),
schema_store_(*schema_store),
normalizer_(*normalizer),
tokenizer_(*tokenizer),
raw_query_text_(raw_query_text),
+ embedding_query_vectors_(embedding_query_vectors),
filter_options_(std::move(filter_options)),
match_type_(match_type),
+ embedding_query_metric_type_(embedding_query_metric_type),
needs_term_frequency_info_(needs_term_frequency_info),
pending_property_restricts_(std::move(pending_property_restricts)),
processing_not_(processing_not),
@@ -264,6 +285,22 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor {
libtextclassifier3::StatusOr<PendingValue> HasPropertyFunction(
std::vector<PendingValue>&& args);
+ // Implementation of the semanticSearch(vector, low, high, metric) custom
+ // function. This function is used for supporting vector search with a
+ // syntax like `semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, "COSINE")`.
+ //
+ // low, high, metric are optional parameters:
+ // - low is default to negative infinity
+ // - high is default to positive infinity
+ // - metric is default to the metric specified in SearchSpec
+ //
+ // Returns:
+ // - a Pending Value of type DocHitIterator that returns all documents with
+ // an embedding vector that has a score within [low, high].
+ // - any errors returned by Lexer::ExtractTokens
+ libtextclassifier3::StatusOr<PendingValue> SemanticSearchFunction(
+ std::vector<PendingValue>&& args);
+
// Handles a NaryOperatorNode where the operator is HAS (':') and pushes an
// iterator with the proper section filter applied. If the current property
// restriction represented by pending_property_restricts and the first child
@@ -292,19 +329,26 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor {
SectionRestrictQueryTermsMap property_query_terms_map_;
QueryTermIteratorsMap query_term_iterators_;
+
+ EmbeddingQueryResults embedding_query_results_;
+
// Set of features invoked in the query.
std::unordered_set<Feature> features_;
Index& index_; // Does not own!
const NumericIndex<int64_t>& numeric_index_; // Does not own!
+ const EmbeddingIndex& embedding_index_; // Does not own!
const DocumentStore& document_store_; // Does not own!
const SchemaStore& schema_store_; // Does not own!
const Normalizer& normalizer_; // Does not own!
const Tokenizer& tokenizer_; // Does not own!
std::string_view raw_query_text_;
+ const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>*
+ embedding_query_vectors_; // Nullable, does not own!
DocHitInfoIteratorFilter::Options filter_options_;
TermMatchType::Code match_type_;
+ SearchSpecProto::EmbeddingQueryMetricType::Code embedding_query_metric_type_;
// Whether or not term_frequency information is needed. This affects:
// - how DocHitInfoIteratorTerms are constructed
// - whether the QueryTermIteratorsMap is populated in the QueryResults.
diff --git a/icing/query/advanced_query_parser/query-visitor_test.cc b/icing/query/advanced_query_parser/query-visitor_test.cc
index 9455baa..c5ba866 100644
--- a/icing/query/advanced_query_parser/query-visitor_test.cc
+++ b/icing/query/advanced_query_parser/query-visitor_test.cc
@@ -15,6 +15,7 @@
#include "icing/query/advanced_query_parser/query-visitor.h"
#include <cstdint>
+#include <initializer_list>
#include <limits>
#include <memory>
#include <string>
@@ -31,6 +32,7 @@
#include "icing/document-builder.h"
#include "icing/file/filesystem.h"
#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/hit/hit.h"
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-filter.h"
@@ -42,6 +44,7 @@
#include "icing/jni/jni-cache.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/portable/platform.h"
+#include "icing/proto/search.pb.h"
#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
#include "icing/query/advanced_query_parser/lexer.h"
#include "icing/query/advanced_query_parser/parser.h"
@@ -54,6 +57,7 @@
#include "icing/store/document-store.h"
#include "icing/store/namespace-id.h"
#include "icing/testing/common-matchers.h"
+#include "icing/testing/embedding-test-utils.h"
#include "icing/testing/icu-data-file-helper.h"
#include "icing/testing/jni-test-helpers.h"
#include "icing/testing/test-data.h"
@@ -67,16 +71,22 @@
#include "icing/util/clock.h"
#include "icing/util/status-macros.h"
#include "unicode/uloc.h"
+#include <google/protobuf/repeated_field.h>
namespace icing {
namespace lib {
namespace {
+using ::testing::DoubleNear;
using ::testing::ElementsAre;
using ::testing::IsEmpty;
+using ::testing::IsNull;
+using ::testing::Pointee;
using ::testing::UnorderedElementsAre;
+constexpr float kEps = 0.000001;
+
constexpr DocumentId kDocumentId0 = 0;
constexpr DocumentId kDocumentId1 = 1;
constexpr DocumentId kDocumentId2 = 2;
@@ -85,6 +95,18 @@ constexpr SectionId kSectionId0 = 0;
constexpr SectionId kSectionId1 = 1;
constexpr SectionId kSectionId2 = 2;
+constexpr SearchSpecProto::EmbeddingQueryMetricType::Code
+ EMBEDDING_METRIC_UNKNOWN =
+ SearchSpecProto::EmbeddingQueryMetricType::UNKNOWN;
+constexpr SearchSpecProto::EmbeddingQueryMetricType::Code
+ EMBEDDING_METRIC_COSINE = SearchSpecProto::EmbeddingQueryMetricType::COSINE;
+constexpr SearchSpecProto::EmbeddingQueryMetricType::Code
+ EMBEDDING_METRIC_DOT_PRODUCT =
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT;
+constexpr SearchSpecProto::EmbeddingQueryMetricType::Code
+ EMBEDDING_METRIC_EUCLIDEAN =
+ SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN;
+
template <typename T, typename U>
std::vector<T> ExtractKeys(const std::unordered_map<T, U>& map) {
std::vector<T> keys;
@@ -106,6 +128,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> {
test_dir_ = GetTestTempDir() + "/icing";
index_dir_ = test_dir_ + "/index";
numeric_index_dir_ = test_dir_ + "/numeric_index";
+ embedding_index_dir_ = test_dir_ + "/embedding_index";
store_dir_ = test_dir_ + "/store";
schema_store_dir_ = test_dir_ + "/schema_store";
filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
@@ -154,6 +177,10 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> {
numeric_index_,
DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ embedding_index_,
+ EmbeddingIndex::Create(&filesystem_, embedding_index_dir_));
+
ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create(
/*max_term_byte_size=*/1000));
@@ -219,6 +246,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> {
std::string test_dir_;
std::string index_dir_;
std::string numeric_index_dir_;
+ std::string embedding_index_dir_;
std::string schema_store_dir_;
std::string store_dir_;
Clock clock_;
@@ -226,6 +254,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> {
std::unique_ptr<DocumentStore> document_store_;
std::unique_ptr<Index> index_;
std::unique_ptr<DummyNumericIndex<int64_t>> numeric_index_;
+ std::unique_ptr<EmbeddingIndex> embedding_index_;
std::unique_ptr<Normalizer> normalizer_;
std::unique_ptr<LanguageSegmenter> language_segmenter_;
std::unique_ptr<Tokenizer> tokenizer_;
@@ -252,9 +281,11 @@ TEST_P(QueryVisitorTest, SimpleLessThan) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -295,9 +326,11 @@ TEST_P(QueryVisitorTest, SimpleLessThanEq) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -338,9 +371,11 @@ TEST_P(QueryVisitorTest, SimpleEqual) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -381,9 +416,11 @@ TEST_P(QueryVisitorTest, SimpleGreaterThanEq) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -424,9 +461,11 @@ TEST_P(QueryVisitorTest, SimpleGreaterThan) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -468,9 +507,11 @@ TEST_P(QueryVisitorTest, IntMinLessThanEqual) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -512,9 +553,11 @@ TEST_P(QueryVisitorTest, IntMaxGreaterThanEqual) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -557,9 +600,11 @@ TEST_P(QueryVisitorTest, NestedPropertyLessThan) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -585,9 +630,11 @@ TEST_P(QueryVisitorTest, IntParsingError) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -599,9 +646,11 @@ TEST_P(QueryVisitorTest, NotEqualsUnsupported) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -647,9 +696,11 @@ TEST_P(QueryVisitorTest, LessThanTooManyOperandsInvalid) {
args.push_back(std::move(extra_value_node));
auto root_node = std::make_unique<NaryOperatorNode>("<", std::move(args));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -674,9 +725,11 @@ TEST_P(QueryVisitorTest, LessThanTooFewOperandsInvalid) {
args.push_back(std::move(member_node));
auto root_node = std::make_unique<NaryOperatorNode>("<", std::move(args));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -705,9 +758,11 @@ TEST_P(QueryVisitorTest, LessThanNonExistentPropertyNotFound) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -727,9 +782,11 @@ TEST_P(QueryVisitorTest, LessThanNonExistentPropertyNotFound) {
TEST_P(QueryVisitorTest, NeverVisitedReturnsInvalid) {
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), "",
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), "", /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
@@ -756,9 +813,11 @@ TEST_P(QueryVisitorTest, IntMinLessThanInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -786,9 +845,11 @@ TEST_P(QueryVisitorTest, IntMaxGreaterThanInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -801,9 +862,11 @@ TEST_P(QueryVisitorTest, NumericComparisonPropertyStringIsInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -865,9 +928,11 @@ TEST_P(QueryVisitorTest, NumericComparatorDoesntAffectLaterTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -908,9 +973,11 @@ TEST_P(QueryVisitorTest, SingleTermTermFrequencyEnabled) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -960,9 +1027,11 @@ TEST_P(QueryVisitorTest, SingleTermTermFrequencyDisabled) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/false, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1012,9 +1081,11 @@ TEST_P(QueryVisitorTest, SingleTermPrefix) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1028,9 +1099,11 @@ TEST_P(QueryVisitorTest, SingleTermPrefix) {
query = CreateQuery("fo*");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -1048,9 +1121,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterPropertyReturnsInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -1062,9 +1137,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterNumericValueReturnsInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -1076,9 +1153,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterPropertyRestrictReturnsInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -1114,9 +1193,11 @@ TEST_P(QueryVisitorTest, SegmentationWithPrefix) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1137,9 +1218,11 @@ TEST_P(QueryVisitorTest, SegmentationWithPrefix) {
query = CreateQuery("ba?fo*");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -1174,9 +1257,11 @@ TEST_P(QueryVisitorTest, SingleVerbatimTerm) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1221,9 +1306,11 @@ TEST_P(QueryVisitorTest, SingleVerbatimTermPrefix) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1274,9 +1361,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingQuote) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1326,9 +1415,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingEscape) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1380,9 +1471,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1407,9 +1500,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) {
query = CreateQuery(R"(("foobar\\y"))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -1462,9 +1557,11 @@ TEST_P(QueryVisitorTest, VerbatimTermNewLine) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1488,9 +1585,11 @@ TEST_P(QueryVisitorTest, VerbatimTermNewLine) {
query = CreateQuery(R"(("foobar\\n"))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -1537,9 +1636,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingComplex) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1596,9 +1697,11 @@ TEST_P(QueryVisitorTest, SingleMinusTerm) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1650,9 +1753,11 @@ TEST_P(QueryVisitorTest, SingleNotTerm) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1705,9 +1810,11 @@ TEST_P(QueryVisitorTest, NestedNotTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1774,9 +1881,11 @@ TEST_P(QueryVisitorTest, DeeplyNestedNotTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1813,9 +1922,11 @@ TEST_P(QueryVisitorTest, ImplicitAndTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1856,9 +1967,11 @@ TEST_P(QueryVisitorTest, ExplicitAndTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1899,9 +2012,11 @@ TEST_P(QueryVisitorTest, OrTerms) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1944,9 +2059,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -1969,9 +2086,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) {
query = CreateQuery("bar OR baz foo");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -1993,9 +2112,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) {
query = CreateQuery("(bar OR baz) foo");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_three(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_three);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2055,9 +2176,11 @@ TEST_P(QueryVisitorTest, AndOrNotPrecedence) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2075,9 +2198,11 @@ TEST_P(QueryVisitorTest, AndOrNotPrecedence) {
query = CreateQuery("foo NOT (bar OR baz)");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2140,9 +2265,11 @@ TEST_P(QueryVisitorTest, PropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2216,9 +2343,11 @@ TEST_F(QueryVisitorTest, MultiPropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2259,9 +2388,11 @@ TEST_P(QueryVisitorTest, PropertyFilterStringIsInvalid) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -2315,9 +2446,11 @@ TEST_P(QueryVisitorTest, PropertyFilterNonNormalized) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2386,9 +2519,11 @@ TEST_P(QueryVisitorTest, PropertyFilterWithGrouping) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2453,9 +2588,11 @@ TEST_P(QueryVisitorTest, ValidNestedPropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2474,9 +2611,11 @@ TEST_P(QueryVisitorTest, ValidNestedPropertyFilter) {
/*property_restrict=*/"prop1");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2540,9 +2679,11 @@ TEST_P(QueryVisitorTest, InvalidNestedPropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2561,9 +2702,11 @@ TEST_P(QueryVisitorTest, InvalidNestedPropertyFilter) {
/*property_restrict=*/"prop1");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2627,9 +2770,11 @@ TEST_P(QueryVisitorTest, NotWithPropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2648,9 +2793,11 @@ TEST_P(QueryVisitorTest, NotWithPropertyFilter) {
"NOT ", CreateQuery("(foo OR bar)", /*property_restrict=*/"prop1"));
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2728,9 +2875,11 @@ TEST_P(QueryVisitorTest, PropertyFilterWithNot) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2755,9 +2904,11 @@ TEST_P(QueryVisitorTest, PropertyFilterWithNot) {
query = CreateQuery("(NOT foo OR bar)", /*property_restrict=*/"prop1");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -2837,9 +2988,11 @@ TEST_P(QueryVisitorTest, SegmentationTest) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -2957,9 +3110,11 @@ TEST_P(QueryVisitorTest, PropertyRestrictsPopCorrectly) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3074,9 +3229,11 @@ TEST_P(QueryVisitorTest, UnsatisfiablePropertyRestrictsPopCorrectly) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3098,9 +3255,11 @@ TEST_F(QueryVisitorTest, UnsupportedFunctionReturnsInvalidArgument) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3112,9 +3271,11 @@ TEST_F(QueryVisitorTest, SearchFunctionTooFewArgumentsReturnsInvalidArgument) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3126,9 +3287,11 @@ TEST_F(QueryVisitorTest, SearchFunctionTooManyArgumentsReturnsInvalidArgument) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3142,9 +3305,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3154,9 +3319,11 @@ TEST_F(QueryVisitorTest,
query = R"(search(createList("subject")))";
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
EXPECT_THAT(std::move(query_visitor_two).ConsumeResults(),
@@ -3170,9 +3337,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3182,9 +3351,11 @@ TEST_F(QueryVisitorTest,
query = R"(search("foo", 7))";
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
EXPECT_THAT(std::move(query_visitor_two).ConsumeResults(),
@@ -3197,9 +3368,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3260,9 +3433,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(level_two_query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3284,9 +3459,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) {
R"(", createList("prop1")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(),
- level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3308,9 +3485,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) {
R"(", createList("prop1")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_four_query));
QueryVisitor query_visitor_three(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(),
- level_four_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_four_query, /*embedding_query_vectors=*/nullptr,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_three);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3430,9 +3609,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(level_one_query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_one_query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_one_query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3462,9 +3643,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) {
R"(", createList("prop6", "prop0", "prop4", "prop2")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_two_query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3488,9 +3671,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) {
R"(", createList("prop0", "prop6")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query));
QueryVisitor query_visitor_three(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(),
- level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_three);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3610,9 +3795,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(level_one_query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_one_query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_one_query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3634,9 +3821,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) {
R"(", createList("prop6", "prop0", "prop4", "prop2")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_two_query));
QueryVisitor query_visitor_two(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_two);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3659,9 +3848,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) {
R"( "prop0", "prop6", "prop4", "prop7")))");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query));
QueryVisitor query_visitor_three(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(),
- level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor_three);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -3685,9 +3876,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3701,9 +3894,11 @@ TEST_F(
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3717,9 +3912,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3732,9 +3929,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3786,9 +3985,11 @@ TEST_P(QueryVisitorTest, PropertyDefinedFunctionReturnsMatchingDocuments) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3839,9 +4040,11 @@ TEST_P(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3890,9 +4093,11 @@ TEST_P(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -3910,9 +4115,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3925,9 +4132,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3941,9 +4150,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -3956,9 +4167,11 @@ TEST_F(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
@@ -4015,9 +4228,11 @@ TEST_P(QueryVisitorTest, HasPropertyFunctionReturnsMatchingDocuments) {
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor1(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor1);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -4033,9 +4248,11 @@ TEST_P(QueryVisitorTest, HasPropertyFunctionReturnsMatchingDocuments) {
query = CreateQuery("bar OR NOT hasProperty(\"price\")");
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
QueryVisitor query_visitor2(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor2);
ICING_ASSERT_OK_AND_ASSIGN(query_results,
@@ -4088,9 +4305,11 @@ TEST_P(QueryVisitorTest,
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
ParseQueryHelper(query));
QueryVisitor query_visitor(
- index_.get(), numeric_index_.get(), document_store_.get(),
- schema_store_.get(), normalizer_.get(), tokenizer_.get(), query,
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr,
DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
/*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
root_node->Accept(&query_visitor);
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
@@ -4102,6 +4321,890 @@ TEST_P(QueryVisitorTest,
EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty());
}
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionWithNoArgumentReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+ std::string query = "semanticSearch()";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionWithIncorrectArgumentTypeReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+ std::string query = "semanticSearch(0)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionWithExtraArgumentReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+ std::string query =
+ "semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, \"COSINE\", 0)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest,
+ GetSearchSpecEmbeddingFunctionWithExtraArgumentReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+  // getSearchSpecEmbedding accepts exactly one argument (the embedding query
+  // index), so passing a second argument is invalid.
+  std::string query = "semanticSearch(getSearchSpecEmbedding(0, 1))";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionWithInvalidIndexReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+ // The embedding query index is invalid, since there are only 2 queries.
+ std::string query = "semanticSearch(getSearchSpecEmbedding(10))";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionWithInvalidMetricReturnsInvalidArgument) {
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3});
+ *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4});
+
+ // The embedding query metric is invalid.
+ std::string query =
+ "semanticSearch(getSearchSpecEmbedding(0), -10, 10, \"UNKNOWN\")";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ // Passing an unknown default metric type without overriding it in the query
+ // expression is also considered invalid.
+ query = "semanticSearch(getSearchSpecEmbedding(0), -10, 10)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_UNKNOWN,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ EXPECT_THAT(std::move(query_visitor2).ConsumeResults(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionSimpleLowerBound) {
+ // Index two embedding vectors.
+ PropertyProto::VectorProto vector0 =
+ CreateVector("my_model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector1 =
+ CreateVector("my_model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0), vector0));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId1), vector1));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create an embedding query that has a semantic score of 1 with vector0 and
+ // -1 with vector1.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3});
+
+ // The query should match vector0 only.
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0), 0.5)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(1, kEps))));
+
+ // The query should match both vector0 and vector1.
+ query = "semanticSearch(getSearchSpecEmbedding(0), -1.5)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId1, kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(1, kEps))));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1),
+ Pointee(UnorderedElementsAre(DoubleNear(-1, kEps))));
+
+ // The query should match nothing, since there is no vector with a
+ // score >= 1.01.
+ query = "semanticSearch(getSearchSpecEmbedding(0), 1.01)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor3(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor3);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor3).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty());
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionSimpleUpperBound) {
+ // Index two embedding vectors.
+ PropertyProto::VectorProto vector0 =
+ CreateVector("my_model", {0.1, 0.2, 0.3});
+ PropertyProto::VectorProto vector1 =
+ CreateVector("my_model", {-0.1, -0.2, -0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0), vector0));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId1), vector1));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create an embedding query that has a semantic score of 1 with vector0 and
+ // -1 with vector1.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3});
+
+ // The query should match vector1 only.
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0), -100, 0.5)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId1));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1),
+ Pointee(UnorderedElementsAre(DoubleNear(-1, kEps))));
+
+ // The query should match both vector0 and vector1.
+ query = "semanticSearch(getSearchSpecEmbedding(0), -100, 1.5)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId1, kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(1, kEps))));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1),
+ Pointee(UnorderedElementsAre(DoubleNear(-1, kEps))));
+
+ // The query should match nothing, since there is no vector with a
+ // score <= -1.01.
+ query = "semanticSearch(getSearchSpecEmbedding(0), -100, -1.01)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor3(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor3);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor3).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty());
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionMetricOverride) {
+  // Index an embedding vector.
+ PropertyProto::VectorProto vector = CreateVector("my_model", {0.1, 0.2, 0.3});
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0), vector));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create an embedding query that has:
+ // - a cosine semantic score of 1
+ // - a dot product semantic score of 0.14
+ // - a euclidean semantic score of 0
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3});
+
+ // Create a query that overrides the metric to COSINE.
+ std::string query =
+ "semanticSearch(getSearchSpecEmbedding(0), 0.95, 1.05, \"COSINE\")";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ // The default metric to be overridden
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(1, kEps))));
+
+ // Create a query that overrides the metric to DOT_PRODUCT.
+ query =
+ "semanticSearch(getSearchSpecEmbedding(0), 0.1, 0.2, \"DOT_PRODUCT\")";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ // The default metric to be overridden
+ EMBEDDING_METRIC_COSINE,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(0.14, kEps))));
+
+ // Create a query that overrides the metric to EUCLIDEAN.
+ query =
+ "semanticSearch(getSearchSpecEmbedding(0), -0.05, 0.05, \"EUCLIDEAN\")";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor3(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ // The default metric to be overridden
+ EMBEDDING_METRIC_UNKNOWN,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor3);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor3).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_EUCLIDEAN, kDocumentId0),
+ Pointee(UnorderedElementsAre(DoubleNear(0, kEps))));
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionMultipleQueries) {
+ // Index 3 embedding vectors for document 0.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0),
+ CreateVector("my_model1", {1, -2, -3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId1, kDocumentId0),
+ CreateVector("my_model1", {-1, -2, -3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId2, kDocumentId0),
+ CreateVector("my_model2", {-1, 2, 3, -4})));
+ // Index 2 embedding vectors for document 1.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId1),
+ CreateVector("my_model1", {-1, -2, 3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId1, kDocumentId1),
+ CreateVector("my_model2", {1, -2, 3, -4})));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create two embedding queries.
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ // Semantic scores for this query:
+ // - document 0: -2 (section 0), 0 (section 1)
+ // - document 1: 6 (section 0)
+ PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add();
+ *embedding_query = CreateVector("my_model1", {-1, -1, 1});
+ // Semantic scores for this query:
+ // - document 0: 4 (section 2)
+ // - document 1: -2 (section 1)
+ embedding_query = embedding_query_vectors.Add();
+ *embedding_query = CreateVector("my_model2", {-1, 1, -1, -1});
+
+ // The query can only match document 0:
+ // - The "semanticSearch(getSearchSpecEmbedding(0), -5)" part should match
+ // semantic scores {-2, 0}.
+ // - The "semanticSearch(getSearchSpecEmbedding(1), 0)" part should match
+ // semantic scores {4}.
+ std::string query =
+ "semanticSearch(getSearchSpecEmbedding(0), -5) AND "
+ "semanticSearch(getSearchSpecEmbedding(1), 0)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ DocHitInfoIterator* itr = query_results.root_iterator.get();
+ // Check results for document 0.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0,
+ std::vector<SectionId>{kSectionId0, kSectionId1,
+ kSectionId2}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(-2, 0)));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(4)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+
+ // The query can match both document 0 and document 1:
+ // For document 0:
+ // - The "semanticSearch(getSearchSpecEmbedding(0), 1)" part should return
+ // semantic scores {}.
+ // - The "semanticSearch(getSearchSpecEmbedding(1), 0.1)" part should return
+ // semantic scores {4}.
+ // For document 1:
+ // - The "semanticSearch(getSearchSpecEmbedding(0), 1)" part should return
+ // semantic scores {6}.
+ // - The "semanticSearch(getSearchSpecEmbedding(1), 0.1)" part should return
+ // semantic scores {}.
+ query =
+ "semanticSearch(getSearchSpecEmbedding(0), 1) OR "
+ "semanticSearch(getSearchSpecEmbedding(1), 0.1)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ itr = query_results.root_iterator.get();
+ // Check results for document 1.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ Pointee(UnorderedElementsAre(6)));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ IsNull());
+ // Check results for document 0.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId2}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ IsNull());
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(4)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST_F(QueryVisitorTest,
+ SemanticSearchFunctionMultipleQueriesScoresMergedRepeat) {
+ // Index 2 embedding vectors for document 0.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0),
+ CreateVector("my_model1", {1, -2, -3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId1, kDocumentId0),
+ CreateVector("my_model1", {-1, -2, -3})));
+ // Index 1 embedding vector for document 1.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId1),
+ CreateVector("my_model1", {-1, -2, 3})));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create a single embedding query, which both halves of each test query
+ // below will reference via getSearchSpecEmbedding(0).
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ // Semantic (dot-product) scores for this query:
+ // - document 0: -2 (section 0), 0 (section 1)
+ // - document 1: 6 (section 0)
+ PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add();
+ *embedding_query = CreateVector("my_model1", {-1, -1, 1});
+
+ // The query should match both document 0 and document 1, since the overall
+ // range is [-10, 10]. The scores in the results should be merged.
+ std::string query =
+ "semanticSearch(getSearchSpecEmbedding(0), -10, 0) OR "
+ "semanticSearch(getSearchSpecEmbedding(0), 0.0001, 10)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ DocHitInfoIterator* itr = query_results.root_iterator.get();
+ // Check results for document 1 (the iterator returns documents in
+ // descending document id order).
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ Pointee(UnorderedElementsAre(6)));
+ // Check results for document 0.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{
+ kSectionId0, kSectionId1}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(-2, 0)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+
+ // The same query appears twice, in which case all the scores in the results
+ // should repeat twice.
+ query =
+ "semanticSearch(getSearchSpecEmbedding(0), -10, 10) OR "
+ "semanticSearch(getSearchSpecEmbedding(0), -10, 10)";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ itr = query_results.root_iterator.get();
+ // Check results for document 1: the single score 6 is duplicated.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ Pointee(UnorderedElementsAre(6, 6)));
+ // Check results for document 0: both section scores are duplicated.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{
+ kSectionId0, kSectionId1}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(-2, 0, -2, 0)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionHybridQueries) {
+ // Index the term "foo" for document 0 and "bar" for document 1.
+ Index::Editor editor = index_->Edit(kDocumentId0, kSectionId1,
+ TERM_MATCH_PREFIX, /*namespace_id=*/0);
+ ICING_ASSERT_OK(editor.BufferTerm("foo"));
+ ICING_ASSERT_OK(editor.IndexAllBufferedTerms());
+ editor = index_->Edit(kDocumentId1, kSectionId1, TERM_MATCH_PREFIX,
+ /*namespace_id=*/0);
+ ICING_ASSERT_OK(editor.BufferTerm("bar"));
+ ICING_ASSERT_OK(editor.IndexAllBufferedTerms());
+
+ // Index one embedding vector for each of document 0 and document 1.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0),
+ CreateVector("my_model1", {1, -2, -3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId1),
+ CreateVector("my_model1", {-1, -2, 3})));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create an embedding query with semantic (dot-product) scores:
+ // - document 0: -2
+ // - document 1: 6
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add();
+ *embedding_query = CreateVector("my_model1", {-1, -1, 1});
+
+ // Perform a hybrid (embedding OR term) search:
+ // - The "semanticSearch(getSearchSpecEmbedding(0), 0)" part only matches
+ // document 1 (score 6 >= 0; document 0's score -2 is below the bound).
+ // - The "foo" part only matches document 0.
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0), 0) OR foo";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators),
+ UnorderedElementsAre("foo"));
+ EXPECT_THAT(ExtractKeys(query_results.query_terms), UnorderedElementsAre(""));
+ EXPECT_THAT(query_results.query_terms[""], UnorderedElementsAre("foo"));
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ DocHitInfoIterator* itr = query_results.root_iterator.get();
+ // Check results for document 1, which matched via the embedding query.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ Pointee(UnorderedElementsAre(6)));
+ // Check results for document 0, which matched only via the term "foo", so
+ // no embedding scores are recorded for it.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId1}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ IsNull());
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+
+ // Perform another hybrid search:
+ // - The "semanticSearch(getSearchSpecEmbedding(0), -5)" part matches both
+ // document 0 and 1.
+ // - The "foo" part only matches document 0.
+ // As a result, only document 0 will be returned (AND semantics).
+ query = "semanticSearch(getSearchSpecEmbedding(0), -5) AND foo";
+ ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
+ QueryVisitor query_visitor2(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor2);
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor2).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators),
+ UnorderedElementsAre("foo"));
+ EXPECT_THAT(ExtractKeys(query_results.query_terms), UnorderedElementsAre(""));
+ EXPECT_THAT(query_results.query_terms[""], UnorderedElementsAre("foo"));
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ itr = query_results.root_iterator.get();
+ // Check results for document 0: it hits both the embedding section (0) and
+ // the term section (1), and its embedding score -2 is recorded.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{
+ kSectionId0, kSectionId1}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(-2)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST_F(QueryVisitorTest, SemanticSearchFunctionSectionRestriction) {
+ // Set a schema with two indexable vector properties, "prop1" and "prop2".
+ ICING_ASSERT_OK(schema_store_->SetSchema(
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("type")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("prop1")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("prop2")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build(),
+ /*ignore_errors_and_delete_documents=*/false,
+ /*allow_circular_schema_definitions=*/false));
+
+ // Create two documents.
+ ICING_ASSERT_OK(document_store_->Put(
+ DocumentBuilder().SetKey("ns", "uri0").SetSchema("type").Build()));
+ ICING_ASSERT_OK(document_store_->Put(
+ DocumentBuilder().SetKey("ns", "uri1").SetSchema("type").Build()));
+ // Add embedding vectors into different sections for the two documents.
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId0, kDocumentId0),
+ CreateVector("my_model1", {1, -2, -3})));
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding(
+ BasicHit(kSectionId1, kDocumentId0),
+ CreateVector("my_model1", {-1, -2, 3})));
+ ICING_ASSERT_OK(
+ embedding_index_->BufferEmbedding(BasicHit(kSectionId0, kDocumentId1),
+ CreateVector("my_model1", {-1, 2, 3})));
+ ICING_ASSERT_OK(
+ embedding_index_->BufferEmbedding(BasicHit(kSectionId1, kDocumentId1),
+ CreateVector("my_model1", {1, 2, -3})));
+ ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex());
+
+ // Create an embedding query with semantic (dot-product) scores:
+ // - document 0: -2 (section 0), 6 (section 1)
+ // - document 1: 2 (section 0), -6 (section 1)
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors;
+ PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add();
+ *embedding_query = CreateVector("my_model1", {-1, -1, 1});
+
+ // An embedding query with section restriction. The scores returned should
+ // only be limited to the section restricted. The low bound of -100 would
+ // otherwise accept every score above.
+ std::string query = "prop1:semanticSearch(getSearchSpecEmbedding(0), -100)";
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ ParseQueryHelper(query));
+ QueryVisitor query_visitor(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ document_store_.get(), schema_store_.get(), normalizer_.get(),
+ tokenizer_.get(), query, &embedding_query_vectors,
+ DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX,
+ EMBEDDING_METRIC_DOT_PRODUCT,
+ /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds());
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty());
+ EXPECT_THAT(query_results.query_terms, IsEmpty());
+ EXPECT_THAT(query_results.features_in_use,
+ UnorderedElementsAre(kListFilterQueryLanguageFeature,
+ kEmbeddingSearchFeature));
+ DocHitInfoIterator* itr = query_results.root_iterator.get();
+ // Check results for document 1: only the kSectionId0 ("prop1") hit and its
+ // score survive the restriction; the section 1 score (-6) is excluded.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1),
+ Pointee(UnorderedElementsAre(2)));
+ // Check results for document 0: likewise, the section 1 score (6) is
+ // excluded by the "prop1" restriction.
+ ICING_ASSERT_OK(itr->Advance());
+ EXPECT_THAT(
+ itr->doc_hit_info(),
+ EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId0}));
+ EXPECT_THAT(
+ query_results.embedding_query_results.GetMatchedScoresForDocument(
+ /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0),
+ Pointee(UnorderedElementsAre(-2)));
+ EXPECT_THAT(itr->Advance(),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
INSTANTIATE_TEST_SUITE_P(QueryVisitorTest, QueryVisitorTest,
testing::Values(QueryType::kPlain,
QueryType::kSearch));
diff --git a/icing/query/query-features.h b/icing/query/query-features.h
index d829cd7..bc3602f 100644
--- a/icing/query/query-features.h
+++ b/icing/query/query-features.h
@@ -52,9 +52,15 @@ constexpr Feature kListFilterQueryLanguageFeature =
constexpr Feature kHasPropertyFunctionFeature =
"HAS_PROPERTY_FUNCTION"; // Features#HAS_PROPERTY_FUNCTION
+// This feature relates to the use of embedding searches in the advanced query
+// language. Ex. `semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, "COSINE")`.
+constexpr Feature kEmbeddingSearchFeature =
+ "EMBEDDING_SEARCH"; // Features#EMBEDDING_SEARCH
+
inline std::unordered_set<Feature> GetQueryFeaturesSet() {
return {kNumericSearchFeature, kVerbatimSearchFeature,
- kListFilterQueryLanguageFeature, kHasPropertyFunctionFeature};
+ kListFilterQueryLanguageFeature, kHasPropertyFunctionFeature,
+ kEmbeddingSearchFeature};
}
} // namespace lib
diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc
index bbfbf3c..6e13001 100644
--- a/icing/query/query-processor.cc
+++ b/icing/query/query-processor.cc
@@ -14,6 +14,7 @@
#include "icing/query/query-processor.h"
+#include <cstdint>
#include <deque>
#include <memory>
#include <stack>
@@ -26,6 +27,7 @@
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-all-document-id.h"
#include "icing/index/iterator/doc-hit-info-iterator-and.h"
@@ -34,13 +36,14 @@
#include "icing/index/iterator/doc-hit-info-iterator-or.h"
#include "icing/index/iterator/doc-hit-info-iterator-section-restrict.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/numeric/numeric-index.h"
+#include "icing/proto/logging.pb.h"
#include "icing/proto/search.pb.h"
#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
#include "icing/query/advanced_query_parser/lexer.h"
#include "icing/query/advanced_query_parser/parser.h"
#include "icing/query/advanced_query_parser/query-visitor.h"
#include "icing/query/query-features.h"
-#include "icing/query/query-processor.h"
#include "icing/query/query-results.h"
#include "icing/query/query-terms.h"
#include "icing/query/query-utils.h"
@@ -49,11 +52,11 @@
#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
#include "icing/tokenization/language-segmenter.h"
-#include "icing/tokenization/raw-query-tokenizer.h"
#include "icing/tokenization/token.h"
#include "icing/tokenization/tokenizer-factory.h"
#include "icing/tokenization/tokenizer.h"
#include "icing/transform/normalizer.h"
+#include "icing/util/clock.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -109,39 +112,46 @@ std::unique_ptr<DocHitInfoIterator> ProcessParserStateFrame(
libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>>
QueryProcessor::Create(Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
- const SchemaStore* schema_store) {
+ const SchemaStore* schema_store, const Clock* clock) {
ICING_RETURN_ERROR_IF_NULL(index);
ICING_RETURN_ERROR_IF_NULL(numeric_index);
+ ICING_RETURN_ERROR_IF_NULL(embedding_index);
ICING_RETURN_ERROR_IF_NULL(language_segmenter);
ICING_RETURN_ERROR_IF_NULL(normalizer);
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
+ ICING_RETURN_ERROR_IF_NULL(clock);
- return std::unique_ptr<QueryProcessor>(
- new QueryProcessor(index, numeric_index, language_segmenter, normalizer,
- document_store, schema_store));
+ return std::unique_ptr<QueryProcessor>(new QueryProcessor(
+ index, numeric_index, embedding_index, language_segmenter, normalizer,
+ document_store, schema_store, clock));
}
QueryProcessor::QueryProcessor(Index* index,
const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
- const SchemaStore* schema_store)
+ const SchemaStore* schema_store,
+ const Clock* clock)
: index_(*index),
numeric_index_(*numeric_index),
+ embedding_index_(*embedding_index),
language_segmenter_(*language_segmenter),
normalizer_(*normalizer),
document_store_(*document_store),
- schema_store_(*schema_store) {}
+ schema_store_(*schema_store),
+ clock_(*clock) {}
libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch(
const SearchSpecProto& search_spec,
ScoringSpecProto::RankingStrategy::Code ranking_strategy,
- int64_t current_time_ms) {
+ int64_t current_time_ms, QueryStatsProto::SearchStats* search_stats) {
if (search_spec.search_type() == SearchSpecProto::SearchType::UNDEFINED) {
return absl_ports::InvalidArgumentError(absl_ports::StrCat(
"Search type ",
@@ -152,9 +162,9 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch(
if (search_spec.search_type() ==
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
ICING_VLOG(1) << "Using EXPERIMENTAL_ICING_ADVANCED_QUERY parser!";
- ICING_ASSIGN_OR_RETURN(
- results,
- ParseAdvancedQuery(search_spec, ranking_strategy, current_time_ms));
+ ICING_ASSIGN_OR_RETURN(results,
+ ParseAdvancedQuery(search_spec, ranking_strategy,
+ current_time_ms, search_stats));
} else {
ICING_ASSIGN_OR_RETURN(
results, ParseRawQuery(search_spec, ranking_strategy, current_time_ms));
@@ -167,8 +177,10 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch(
search_spec.enabled_features().end());
for (const Feature feature : results.features_in_use) {
if (enabled_features.find(feature) == enabled_features.end()) {
- return absl_ports::InvalidArgumentError(
- absl_ports::StrCat("Attempted use of unenabled feature ", feature));
+ return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+ "Attempted use of unenabled feature ", feature,
+ ". Please make sure that you have explicitly set all advanced query "
+ "features used in this query as enabled in the SearchSpec."));
}
}
@@ -188,17 +200,27 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch(
libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseAdvancedQuery(
const SearchSpecProto& search_spec,
ScoringSpecProto::RankingStrategy::Code ranking_strategy,
- int64_t current_time_ms) const {
- QueryResults results;
+ int64_t current_time_ms, QueryStatsProto::SearchStats* search_stats) const {
+ std::unique_ptr<Timer> lexer_timer = clock_.GetNewTimer();
Lexer lexer(search_spec.query(), Lexer::Language::QUERY);
ICING_ASSIGN_OR_RETURN(std::vector<Lexer::LexerToken> lexer_tokens,
lexer.ExtractTokens());
+ if (search_stats != nullptr) {
+ search_stats->set_query_processor_lexer_extract_token_latency_ms(
+ lexer_timer->GetElapsedMilliseconds());
+ }
+ std::unique_ptr<Timer> parser_timer = clock_.GetNewTimer();
Parser parser = Parser::Create(std::move(lexer_tokens));
ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> tree_root,
parser.ConsumeQuery());
+ if (search_stats != nullptr) {
+ search_stats->set_query_processor_parser_consume_query_latency_ms(
+ parser_timer->GetElapsedMilliseconds());
+ }
if (tree_root == nullptr) {
+ QueryResults results;
results.root_iterator = std::make_unique<DocHitInfoIteratorAllDocumentId>(
document_store_.last_added_document_id());
return results;
@@ -210,13 +232,23 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseAdvancedQuery(
DocHitInfoIteratorFilter::Options options = GetFilterOptions(search_spec);
bool needs_term_frequency_info =
ranking_strategy == ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE;
- QueryVisitor query_visitor(&index_, &numeric_index_, &document_store_,
- &schema_store_, &normalizer_,
- plain_tokenizer.get(), search_spec.query(),
- std::move(options), search_spec.term_match_type(),
- needs_term_frequency_info, current_time_ms);
+
+ std::unique_ptr<Timer> query_visitor_timer = clock_.GetNewTimer();
+ QueryVisitor query_visitor(
+ &index_, &numeric_index_, &embedding_index_, &document_store_,
+ &schema_store_, &normalizer_, plain_tokenizer.get(), search_spec.query(),
+ &search_spec.embedding_query_vectors(), std::move(options),
+ search_spec.term_match_type(), search_spec.embedding_query_metric_type(),
+ needs_term_frequency_info, current_time_ms);
tree_root->Accept(&query_visitor);
- return std::move(query_visitor).ConsumeResults();
+ ICING_ASSIGN_OR_RETURN(QueryResults results,
+ std::move(query_visitor).ConsumeResults());
+ if (search_stats != nullptr) {
+ search_stats->set_query_processor_query_visitor_latency_ms(
+ query_visitor_timer->GetElapsedMilliseconds());
+ }
+
+ return results;
}
// TODO(cassiewang): Collect query stats to populate the SearchResultsProto
diff --git a/icing/query/query-processor.h b/icing/query/query-processor.h
index d4c22dd..d90b5f6 100644
--- a/icing/query/query-processor.h
+++ b/icing/query/query-processor.h
@@ -19,17 +19,17 @@
#include <memory>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
-#include "icing/index/iterator/doc-hit-info-iterator-filter.h"
-#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/index/numeric/numeric-index.h"
+#include "icing/proto/logging.pb.h"
#include "icing/proto/search.pb.h"
#include "icing/query/query-results.h"
-#include "icing/query/query-terms.h"
#include "icing/schema/schema-store.h"
#include "icing/store/document-store.h"
#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer.h"
+#include "icing/util/clock.h"
namespace icing {
namespace lib {
@@ -48,8 +48,10 @@ class QueryProcessor {
// FAILED_PRECONDITION if any of the pointers is null.
static libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>> Create(
Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter, const Normalizer* normalizer,
- const DocumentStore* document_store, const SchemaStore* schema_store);
+ const DocumentStore* document_store, const SchemaStore* schema_store,
+ const Clock* clock);
// Parse the search configurations (including the query, any additional
// filters, etc.) in the SearchSpecProto into one DocHitInfoIterator.
@@ -68,15 +70,17 @@ class QueryProcessor {
libtextclassifier3::StatusOr<QueryResults> ParseSearch(
const SearchSpecProto& search_spec,
ScoringSpecProto::RankingStrategy::Code ranking_strategy,
- int64_t current_time_ms);
+ int64_t current_time_ms,
+ QueryStatsProto::SearchStats* search_stats = nullptr);
private:
explicit QueryProcessor(Index* index,
const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
- const SchemaStore* schema_store);
+ const SchemaStore* schema_store, const Clock* clock);
// Parse the query into a one DocHitInfoIterator that represents the root of a
// query tree in our new Advanced Query Language.
@@ -88,7 +92,8 @@ class QueryProcessor {
libtextclassifier3::StatusOr<QueryResults> ParseAdvancedQuery(
const SearchSpecProto& search_spec,
ScoringSpecProto::RankingStrategy::Code ranking_strategy,
- int64_t current_time_ms) const;
+ int64_t current_time_ms,
+ QueryStatsProto::SearchStats* search_stats) const;
// Parse the query into a one DocHitInfoIterator that represents the root of a
// query tree.
@@ -106,12 +111,14 @@ class QueryProcessor {
// Not const because we could modify/sort the hit buffer in the lite index at
// query time.
- Index& index_;
- const NumericIndex<int64_t>& numeric_index_;
- const LanguageSegmenter& language_segmenter_;
- const Normalizer& normalizer_;
- const DocumentStore& document_store_;
- const SchemaStore& schema_store_;
+ Index& index_; // Does not own.
+ const NumericIndex<int64_t>& numeric_index_; // Does not own.
+ const EmbeddingIndex& embedding_index_; // Does not own.
+ const LanguageSegmenter& language_segmenter_; // Does not own.
+ const Normalizer& normalizer_; // Does not own.
+ const DocumentStore& document_store_; // Does not own.
+ const SchemaStore& schema_store_; // Does not own.
+ const Clock& clock_; // Does not own.
};
} // namespace lib
diff --git a/icing/query/query-processor_benchmark.cc b/icing/query/query-processor_benchmark.cc
index 025e8e6..5b3bf99 100644
--- a/icing/query/query-processor_benchmark.cc
+++ b/icing/query/query-processor_benchmark.cc
@@ -12,26 +12,40 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "testing/base/public/benchmark.h"
-#include "gmock/gmock.h"
#include "third_party/absl/flags/flag.h"
+#include "icing/absl_ports/str_cat.h"
#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
#include "icing/index/numeric/dummy-numeric-index.h"
-#include "icing/index/numeric/numeric-index.h"
+#include "icing/legacy/index/icing-filesystem.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/search.pb.h"
#include "icing/proto/term.pb.h"
#include "icing/query/query-processor.h"
+#include "icing/query/query-results.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/icu-data-file-helper.h"
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
#include "icing/tokenization/language-segmenter-factory.h"
+#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer-factory.h"
+#include "icing/transform/normalizer.h"
#include "icing/util/clock.h"
#include "icing/util/logging.h"
#include "unicode/uloc.h"
@@ -118,6 +132,7 @@ void BM_QueryOneTerm(benchmark::State& state) {
const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark";
const std::string index_dir = base_dir + "/index";
const std::string numeric_index_dir = base_dir + "/numeric_index";
+ const std::string embedding_index_dir = base_dir + "/embedding_index";
const std::string schema_dir = base_dir + "/schema";
const std::string doc_store_dir = base_dir + "/store";
@@ -134,6 +149,9 @@ void BM_QueryOneTerm(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
auto numeric_index,
DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ auto embedding_index,
+ EmbeddingIndex::Create(&filesystem, embedding_index_dir));
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
@@ -172,8 +190,9 @@ void BM_QueryOneTerm(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
QueryProcessor::Create(index.get(), numeric_index.get(),
- language_segmenter.get(), normalizer.get(),
- document_store.get(), schema_store.get()));
+ embedding_index.get(), language_segmenter.get(),
+ normalizer.get(), document_store.get(),
+ schema_store.get(), &clock));
SearchSpecProto search_spec;
search_spec.set_query(input_string);
@@ -247,6 +266,7 @@ void BM_QueryFiveTerms(benchmark::State& state) {
const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark";
const std::string index_dir = base_dir + "/index";
const std::string numeric_index_dir = base_dir + "/numeric_index";
+ const std::string embedding_index_dir = base_dir + "/embedding_index";
const std::string schema_dir = base_dir + "/schema";
const std::string doc_store_dir = base_dir + "/store";
@@ -263,6 +283,9 @@ void BM_QueryFiveTerms(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
auto numeric_index,
DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ auto embedding_index,
+ EmbeddingIndex::Create(&filesystem, embedding_index_dir));
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
@@ -315,8 +338,9 @@ void BM_QueryFiveTerms(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
QueryProcessor::Create(index.get(), numeric_index.get(),
- language_segmenter.get(), normalizer.get(),
- document_store.get(), schema_store.get()));
+ embedding_index.get(), language_segmenter.get(),
+ normalizer.get(), document_store.get(),
+ schema_store.get(), &clock));
const std::string query_string = absl_ports::StrCat(
input_string_a, " ", input_string_b, " ", input_string_c, " ",
@@ -394,6 +418,7 @@ void BM_QueryDiacriticTerm(benchmark::State& state) {
const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark";
const std::string index_dir = base_dir + "/index";
const std::string numeric_index_dir = base_dir + "/numeric_index";
+ const std::string embedding_index_dir = base_dir + "/embedding_index";
const std::string schema_dir = base_dir + "/schema";
const std::string doc_store_dir = base_dir + "/store";
@@ -410,6 +435,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
auto numeric_index,
DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ auto embedding_index,
+ EmbeddingIndex::Create(&filesystem, embedding_index_dir));
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
@@ -451,8 +479,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
QueryProcessor::Create(index.get(), numeric_index.get(),
- language_segmenter.get(), normalizer.get(),
- document_store.get(), schema_store.get()));
+ embedding_index.get(), language_segmenter.get(),
+ normalizer.get(), document_store.get(),
+ schema_store.get(), &clock));
SearchSpecProto search_spec;
search_spec.set_query(input_string);
@@ -526,6 +555,7 @@ void BM_QueryHiragana(benchmark::State& state) {
const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark";
const std::string index_dir = base_dir + "/index";
const std::string numeric_index_dir = base_dir + "/numeric_index";
+ const std::string embedding_index_dir = base_dir + "/embedding_index";
const std::string schema_dir = base_dir + "/schema";
const std::string doc_store_dir = base_dir + "/store";
@@ -542,6 +572,9 @@ void BM_QueryHiragana(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
auto numeric_index,
DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ auto embedding_index,
+ EmbeddingIndex::Create(&filesystem, embedding_index_dir));
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
@@ -583,8 +616,9 @@ void BM_QueryHiragana(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
QueryProcessor::Create(index.get(), numeric_index.get(),
- language_segmenter.get(), normalizer.get(),
- document_store.get(), schema_store.get()));
+ embedding_index.get(), language_segmenter.get(),
+ normalizer.get(), document_store.get(),
+ schema_store.get(), &clock));
SearchSpecProto search_spec;
search_spec.set_query(input_string);
diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc
index 53e3035..288c7fb 100644
--- a/icing/query/query-processor_test.cc
+++ b/icing/query/query-processor_test.cc
@@ -14,17 +14,24 @@
#include "icing/query/query-processor.h"
+#include <array>
#include <cstdint>
#include <memory>
#include <string>
+#include <unordered_map>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/hit/hit.h"
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
@@ -33,6 +40,7 @@
#include "icing/jni/jni-cache.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/portable/platform.h"
+#include "icing/proto/logging.pb.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/search.pb.h"
#include "icing/proto/term.pb.h"
@@ -53,6 +61,8 @@
#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer-factory.h"
#include "icing/transform/normalizer.h"
+#include "icing/util/clock.h"
+#include "icing/util/status-macros.h"
#include "unicode/uloc.h"
namespace icing {
@@ -61,6 +71,7 @@ namespace lib {
namespace {
using ::testing::ElementsAre;
+using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::SizeIs;
using ::testing::UnorderedElementsAre;
@@ -85,7 +96,8 @@ class QueryProcessorTest
store_dir_(test_dir_ + "/store"),
schema_store_dir_(test_dir_ + "/schema_store"),
index_dir_(test_dir_ + "/index"),
- numeric_index_dir_(test_dir_ + "/numeric_index") {}
+ numeric_index_dir_(test_dir_ + "/numeric_index"),
+ embedding_index_dir_(test_dir_ + "/embedding_index") {}
void SetUp() override {
filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
@@ -124,6 +136,9 @@ class QueryProcessorTest
ICING_ASSERT_OK_AND_ASSIGN(
numeric_index_,
DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ embedding_index_,
+ EmbeddingIndex::Create(&filesystem_, embedding_index_dir_));
language_segmenter_factory::SegmenterOptions segmenter_options(
ULOC_US, jni_cache_.get());
@@ -136,9 +151,10 @@ class QueryProcessorTest
ICING_ASSERT_OK_AND_ASSIGN(
query_processor_,
- QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(), normalizer_.get(),
- document_store_.get(), schema_store_.get()));
+ QueryProcessor::Create(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_));
}
libtextclassifier3::Status AddTokenToIndex(
@@ -174,10 +190,12 @@ class QueryProcessorTest
IcingFilesystem icing_filesystem_;
const std::string index_dir_;
const std::string numeric_index_dir_;
+ const std::string embedding_index_dir_;
protected:
std::unique_ptr<Index> index_;
std::unique_ptr<NumericIndex<int64_t>> numeric_index_;
+ std::unique_ptr<EmbeddingIndex> embedding_index_;
std::unique_ptr<LanguageSegmenter> language_segmenter_;
std::unique_ptr<Normalizer> normalizer_;
FakeClock fake_clock_;
@@ -190,34 +208,51 @@ class QueryProcessorTest
TEST_P(QueryProcessorTest, CreationWithNullPointerShouldFail) {
EXPECT_THAT(
QueryProcessor::Create(/*index=*/nullptr, numeric_index_.get(),
- language_segmenter_.get(), normalizer_.get(),
- document_store_.get(), schema_store_.get()),
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(
QueryProcessor::Create(index_.get(), /*numeric_index_=*/nullptr,
- language_segmenter_.get(), normalizer_.get(),
- document_store_.get(), schema_store_.get()),
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(),
+ /*embedding_index=*/nullptr,
+ language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(QueryProcessor::Create(
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ /*language_segmenter=*/nullptr, normalizer_.get(),
+ document_store_.get(), schema_store_.get(), &fake_clock_),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(
QueryProcessor::Create(index_.get(), numeric_index_.get(),
- /*language_segmenter=*/nullptr, normalizer_.get(),
- document_store_.get(), schema_store_.get()),
+ embedding_index_.get(), language_segmenter_.get(),
+ /*normalizer=*/nullptr, document_store_.get(),
+ schema_store_.get(), &fake_clock_),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(
QueryProcessor::Create(
- index_.get(), numeric_index_.get(), language_segmenter_.get(),
- /*normalizer=*/nullptr, document_store_.get(), schema_store_.get()),
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(),
+ /*document_store=*/nullptr, schema_store_.get(), &fake_clock_),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(
QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(), normalizer_.get(),
- /*document_store=*/nullptr, schema_store_.get()),
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ /*schema_store=*/nullptr, &fake_clock_),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ QueryProcessor::Create(index_.get(), numeric_index_.get(),
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), /*clock=*/nullptr),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
- EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- /*schema_store=*/nullptr),
- StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
TEST_P(QueryProcessorTest, EmptyGroupMatchAllDocuments) {
@@ -2947,8 +2982,9 @@ TEST_P(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> local_query_processor,
QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(), normalizer_.get(),
- document_store_.get(), schema_store_.get()));
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_));
SearchSpecProto search_spec;
search_spec.set_query("hello");
@@ -3009,8 +3045,9 @@ TEST_P(QueryProcessorTest, DocumentPastTtlFilteredOut) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> local_query_processor,
QueryProcessor::Create(index_.get(), numeric_index_.get(),
- language_segmenter_.get(), normalizer_.get(),
- document_store_.get(), schema_store_.get()));
+ embedding_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_));
SearchSpecProto search_spec;
search_spec.set_query("hello");
@@ -3320,6 +3357,65 @@ TEST_P(QueryProcessorTest, GroupingInSectionRestriction) {
std::vector<SectionId>{prop1_section_id})));
}
+TEST_P(QueryProcessorTest, ParseAdvancedQueryShouldSetSearchStats) {
+ if (GetParam() !=
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ GTEST_SKIP();
+ }
+
+ // Create the schema and document store
+ SchemaProto schema = SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder().SetType("email"))
+ .Build();
+ ASSERT_THAT(schema_store_->SetSchema(
+ schema, /*ignore_errors_and_delete_documents=*/false,
+ /*allow_circular_schema_definitions=*/false),
+ IsOk());
+
+ // These documents don't actually match to the tokens in the index. We're
+ // inserting the documents to get the appropriate number of documents and
+ // namespaces populated.
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id,
+ document_store_->Put(DocumentBuilder()
+ .SetKey("namespace1", "1")
+ .SetSchema("email")
+ .Build()));
+
+ // Populate the index
+ SectionId section_id = 0;
+ TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY;
+
+ EXPECT_THAT(
+ AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
+ IsOk());
+ EXPECT_THAT(
+ AddTokenToIndex(document_id, section_id, term_match_type, "world"),
+ IsOk());
+
+ SearchSpecProto search_spec;
+ search_spec.set_query("hello world");
+ search_spec.set_term_match_type(term_match_type);
+ search_spec.set_search_type(GetParam());
+
+ static constexpr int64_t kSearchStatsLatencyMs = 10;
+ fake_clock_.SetTimerElapsedMilliseconds(kSearchStatsLatencyMs);
+
+ QueryStatsProto::SearchStats search_stats;
+ ICING_ASSERT_OK_AND_ASSIGN(
+ QueryResults results,
+ query_processor_->ParseSearch(
+ search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE,
+ fake_clock_.GetSystemTimeMilliseconds(), &search_stats));
+
+ ASSERT_THAT(results.root_iterator->Advance(), IsOk());
+ EXPECT_THAT(search_stats.query_processor_lexer_extract_token_latency_ms(),
+ Eq(kSearchStatsLatencyMs));
+ EXPECT_THAT(search_stats.query_processor_parser_consume_query_latency_ms(),
+ Eq(kSearchStatsLatencyMs));
+ EXPECT_THAT(search_stats.query_processor_query_visitor_latency_ms(),
+ Eq(kSearchStatsLatencyMs));
+}
+
INSTANTIATE_TEST_SUITE_P(
QueryProcessorTest, QueryProcessorTest,
testing::Values(
diff --git a/icing/query/query-results.h b/icing/query/query-results.h
index 52cdd71..983ab35 100644
--- a/icing/query/query-results.h
+++ b/icing/query/query-results.h
@@ -18,9 +18,10 @@
#include <memory>
#include <unordered_set>
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
-#include "icing/query/query-terms.h"
#include "icing/query/query-features.h"
+#include "icing/query/query-terms.h"
namespace icing {
namespace lib {
@@ -35,6 +36,10 @@ struct QueryResults {
// beginning with root_iterator.
// This will only be populated when ranking_strategy == RELEVANCE_SCORE.
QueryTermIteratorsMap query_term_iterators;
+ // Contains similarity scores from embedding based queries, which will be used
+ // in the advanced scoring language to determine the results for the
+ // "this.matchedSemanticScores(...)" function.
+ EmbeddingQueryResults embedding_query_results;
// Features that are invoked during query execution.
// The list of possible features is defined in query_features.h.
std::unordered_set<Feature> features_in_use;
diff --git a/icing/query/suggestion-processor.cc b/icing/query/suggestion-processor.cc
index eb86e3b..dfebb98 100644
--- a/icing/query/suggestion-processor.cc
+++ b/icing/query/suggestion-processor.cc
@@ -14,14 +14,37 @@
#include "icing/query/suggestion-processor.h"
-#include "icing/proto/schema.pb.h"
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/index.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/numeric/numeric-index.h"
+#include "icing/index/term-metadata.h"
#include "icing/proto/search.pb.h"
#include "icing/query/query-processor.h"
+#include "icing/query/query-results.h"
+#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-filter-data.h"
#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
+#include "icing/store/namespace-id.h"
#include "icing/store/suggestion-result-checker-impl.h"
-#include "icing/tokenization/tokenizer-factory.h"
-#include "icing/tokenization/tokenizer.h"
+#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer.h"
+#include "icing/util/clock.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -29,20 +52,24 @@ namespace lib {
libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>>
SuggestionProcessor::Create(Index* index,
const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
- const SchemaStore* schema_store) {
+ const SchemaStore* schema_store,
+ const Clock* clock) {
ICING_RETURN_ERROR_IF_NULL(index);
ICING_RETURN_ERROR_IF_NULL(numeric_index);
+ ICING_RETURN_ERROR_IF_NULL(embedding_index);
ICING_RETURN_ERROR_IF_NULL(language_segmenter);
ICING_RETURN_ERROR_IF_NULL(normalizer);
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
+ ICING_RETURN_ERROR_IF_NULL(clock);
- return std::unique_ptr<SuggestionProcessor>(
- new SuggestionProcessor(index, numeric_index, language_segmenter,
- normalizer, document_store, schema_store));
+ return std::unique_ptr<SuggestionProcessor>(new SuggestionProcessor(
+ index, numeric_index, embedding_index, language_segmenter, normalizer,
+ document_store, schema_store, clock));
}
libtextclassifier3::StatusOr<
@@ -224,8 +251,9 @@ SuggestionProcessor::QuerySuggestions(
ICING_ASSIGN_OR_RETURN(
std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(&index_, &numeric_index_, &language_segmenter_,
- &normalizer_, &document_store_, &schema_store_));
+ QueryProcessor::Create(&index_, &numeric_index_, &embedding_index_,
+ &language_segmenter_, &normalizer_,
+ &document_store_, &schema_store_, &clock_));
SearchSpecProto search_spec;
search_spec.set_query(suggestion_spec.prefix());
@@ -298,14 +326,18 @@ SuggestionProcessor::QuerySuggestions(
SuggestionProcessor::SuggestionProcessor(
Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter, const Normalizer* normalizer,
- const DocumentStore* document_store, const SchemaStore* schema_store)
+ const DocumentStore* document_store, const SchemaStore* schema_store,
+ const Clock* clock)
: index_(*index),
numeric_index_(*numeric_index),
+ embedding_index_(*embedding_index),
language_segmenter_(*language_segmenter),
normalizer_(*normalizer),
document_store_(*document_store),
- schema_store_(*schema_store) {}
+ schema_store_(*schema_store),
+ clock_(*clock) {}
} // namespace lib
} // namespace icing
diff --git a/icing/query/suggestion-processor.h b/icing/query/suggestion-processor.h
index e100031..cf393b4 100644
--- a/icing/query/suggestion-processor.h
+++ b/icing/query/suggestion-processor.h
@@ -15,7 +15,12 @@
#ifndef ICING_QUERY_SUGGESTION_PROCESSOR_H_
#define ICING_QUERY_SUGGESTION_PROCESSOR_H_
+#include <cstdint>
+#include <memory>
+#include <vector>
+
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-index.h"
#include "icing/index/index.h"
#include "icing/index/numeric/numeric-index.h"
#include "icing/proto/search.pb.h"
@@ -23,6 +28,7 @@
#include "icing/store/document-store.h"
#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer.h"
+#include "icing/util/clock.h"
namespace icing {
namespace lib {
@@ -41,9 +47,10 @@ class SuggestionProcessor {
// FAILED_PRECONDITION if any of the pointers is null.
static libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>>
Create(Index* index, const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer, const DocumentStore* document_store,
- const SchemaStore* schema_store);
+ const SchemaStore* schema_store, const Clock* clock);
// Query suggestions based on the given SuggestionSpecProto.
//
@@ -57,19 +64,23 @@ class SuggestionProcessor {
private:
explicit SuggestionProcessor(Index* index,
const NumericIndex<int64_t>* numeric_index,
+ const EmbeddingIndex* embedding_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
- const SchemaStore* schema_store);
+ const SchemaStore* schema_store,
+ const Clock* clock);
// Not const because we could modify/sort the TermMetaData buffer in the lite
// index.
- Index& index_;
- const NumericIndex<int64_t>& numeric_index_;
- const LanguageSegmenter& language_segmenter_;
- const Normalizer& normalizer_;
- const DocumentStore& document_store_;
- const SchemaStore& schema_store_;
+ Index& index_; // Does not own.
+ const NumericIndex<int64_t>& numeric_index_; // Does not own.
+ const EmbeddingIndex& embedding_index_; // Does not own.
+ const LanguageSegmenter& language_segmenter_; // Does not own.
+ const Normalizer& normalizer_; // Does not own.
+ const DocumentStore& document_store_; // Does not own.
+ const SchemaStore& schema_store_; // Does not own.
+ const Clock& clock_; // Does not own.
};
} // namespace lib
diff --git a/icing/query/suggestion-processor_test.cc b/icing/query/suggestion-processor_test.cc
index 9f9094d..c9bdd8a 100644
--- a/icing/query/suggestion-processor_test.cc
+++ b/icing/query/suggestion-processor_test.cc
@@ -14,14 +14,30 @@
#include "icing/query/suggestion-processor.h"
+#include <cstdint>
+#include <memory>
#include <string>
+#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
+#include "gtest/gtest.h"
#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/index.h"
#include "icing/index/numeric/dummy-numeric-index.h"
+#include "icing/index/numeric/numeric-index.h"
#include "icing/index/term-metadata.h"
+#include "icing/jni/jni-cache.h"
+#include "icing/legacy/index/icing-filesystem.h"
+#include "icing/portable/platform.h"
#include "icing/schema-builder.h"
+#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/fake-clock.h"
@@ -30,7 +46,9 @@
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
#include "icing/tokenization/language-segmenter-factory.h"
+#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer-factory.h"
+#include "icing/transform/normalizer.h"
#include "unicode/uloc.h"
namespace icing {
@@ -59,7 +77,8 @@ class SuggestionProcessorTest : public Test {
store_dir_(test_dir_ + "/store"),
schema_store_dir_(test_dir_ + "/schema_store"),
index_dir_(test_dir_ + "/index"),
- numeric_index_dir_(test_dir_ + "/numeric_index") {}
+ numeric_index_dir_(test_dir_ + "/numeric_index"),
+ embedding_index_dir_(test_dir_ + "/embedding_index") {}
void SetUp() override {
filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
@@ -105,6 +124,9 @@ class SuggestionProcessorTest : public Test {
ICING_ASSERT_OK_AND_ASSIGN(
numeric_index_,
DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ embedding_index_,
+ EmbeddingIndex::Create(&filesystem_, embedding_index_dir_));
language_segmenter_factory::SegmenterOptions segmenter_options(
ULOC_US, jni_cache_.get());
@@ -118,8 +140,9 @@ class SuggestionProcessorTest : public Test {
ICING_ASSERT_OK_AND_ASSIGN(
suggestion_processor_,
SuggestionProcessor::Create(
- index_.get(), numeric_index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(), schema_store_.get()));
+ index_.get(), numeric_index_.get(), embedding_index_.get(),
+ language_segmenter_.get(), normalizer_.get(), document_store_.get(),
+ schema_store_.get(), &fake_clock_));
}
libtextclassifier3::Status AddTokenToIndex(
@@ -146,10 +169,12 @@ class SuggestionProcessorTest : public Test {
IcingFilesystem icing_filesystem_;
const std::string index_dir_;
const std::string numeric_index_dir_;
+ const std::string embedding_index_dir_;
protected:
std::unique_ptr<Index> index_;
std::unique_ptr<NumericIndex<int64_t>> numeric_index_;
+ std::unique_ptr<EmbeddingIndex> embedding_index_;
std::unique_ptr<LanguageSegmenter> language_segmenter_;
std::unique_ptr<Normalizer> normalizer_;
FakeClock fake_clock_;
@@ -673,6 +698,16 @@ TEST_F(SuggestionProcessorTest, OtherSpecialPrefixTest) {
EXPECT_THAT(terms_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
+
+ if (SearchSpecProto::default_instance().search_type() ==
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ suggestion_spec.set_prefix(
+ "bar OR semanticSearch(getSearchSpecEmbedding(0), 0.5, 1)");
+ terms_or = suggestion_processor_->QuerySuggestions(
+ suggestion_spec, fake_clock_.GetSystemTimeMilliseconds());
+ EXPECT_THAT(terms_or,
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ }
}
TEST_F(SuggestionProcessorTest, InvalidPrefixTest) {
diff --git a/icing/schema-builder.h b/icing/schema-builder.h
index c74505e..c702afe 100644
--- a/icing/schema-builder.h
+++ b/icing/schema-builder.h
@@ -56,6 +56,13 @@ constexpr IntegerIndexingConfig::NumericMatchType::Code NUMERIC_MATCH_UNKNOWN =
constexpr IntegerIndexingConfig::NumericMatchType::Code NUMERIC_MATCH_RANGE =
IntegerIndexingConfig::NumericMatchType::RANGE;
+constexpr EmbeddingIndexingConfig::EmbeddingIndexingType::Code
+ EMBEDDING_INDEXING_UNKNOWN =
+ EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN;
+constexpr EmbeddingIndexingConfig::EmbeddingIndexingType::Code
+ EMBEDDING_INDEXING_LINEAR_SEARCH =
+ EmbeddingIndexingConfig::EmbeddingIndexingType::LINEAR_SEARCH;
+
constexpr PropertyConfigProto::DataType::Code TYPE_UNKNOWN =
PropertyConfigProto::DataType::UNKNOWN;
constexpr PropertyConfigProto::DataType::Code TYPE_STRING =
@@ -70,6 +77,8 @@ constexpr PropertyConfigProto::DataType::Code TYPE_BYTES =
PropertyConfigProto::DataType::BYTES;
constexpr PropertyConfigProto::DataType::Code TYPE_DOCUMENT =
PropertyConfigProto::DataType::DOCUMENT;
+constexpr PropertyConfigProto::DataType::Code TYPE_VECTOR =
+ PropertyConfigProto::DataType::VECTOR;
constexpr JoinableConfig::ValueType::Code JOINABLE_VALUE_TYPE_NONE =
JoinableConfig::ValueType::NONE;
@@ -146,6 +155,15 @@ class PropertyConfigBuilder {
return *this;
}
+ PropertyConfigBuilder& SetDataTypeVector(
+ EmbeddingIndexingConfig::EmbeddingIndexingType::Code
+ embedding_indexing_type) {
+ property_.set_data_type(PropertyConfigProto::DataType::VECTOR);
+ property_.mutable_embedding_indexing_config()->set_embedding_indexing_type(
+ embedding_indexing_type);
+ return *this;
+ }
+
PropertyConfigBuilder& SetJoinable(
JoinableConfig::ValueType::Code join_value_type, bool propagate_delete) {
property_.mutable_joinable_config()->set_value_type(join_value_type);
@@ -159,6 +177,11 @@ class PropertyConfigBuilder {
return *this;
}
+ PropertyConfigBuilder& SetDescription(std::string description) {
+ property_.set_description(std::move(description));
+ return *this;
+ }
+
PropertyConfigProto Build() const { return std::move(property_); }
private:
@@ -186,6 +209,11 @@ class SchemaTypeConfigBuilder {
return *this;
}
+ SchemaTypeConfigBuilder& SetDescription(std::string description) {
+ type_config_.set_description(std::move(description));
+ return *this;
+ }
+
SchemaTypeConfigBuilder& AddProperty(PropertyConfigProto property) {
*type_config_.add_properties() = std::move(property);
return *this;
diff --git a/icing/schema/property-util.cc b/icing/schema/property-util.cc
index 67ff748..7c86122 100644
--- a/icing/schema/property-util.cc
+++ b/icing/schema/property-util.cc
@@ -14,6 +14,8 @@
#include "icing/schema/property-util.h"
+#include <cstddef>
+#include <cstdint>
#include <string>
#include <string_view>
#include <vector>
@@ -131,6 +133,14 @@ ExtractPropertyValues<int64_t>(const PropertyProto& property) {
property.int64_values().end());
}
+template <>
+libtextclassifier3::StatusOr<std::vector<PropertyProto::VectorProto>>
+ExtractPropertyValues<PropertyProto::VectorProto>(
+ const PropertyProto& property) {
+ return std::vector<PropertyProto::VectorProto>(
+ property.vector_values().begin(), property.vector_values().end());
+}
+
} // namespace property_util
} // namespace lib
diff --git a/icing/schema/property-util.h b/icing/schema/property-util.h
index 7557879..c409a9d 100644
--- a/icing/schema/property-util.h
+++ b/icing/schema/property-util.h
@@ -15,8 +15,11 @@
#ifndef ICING_SCHEMA_PROPERTY_UTIL_H_
#define ICING_SCHEMA_PROPERTY_UTIL_H_
+#include <cstddef>
+#include <cstdint>
#include <string>
#include <string_view>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
@@ -163,6 +166,11 @@ template <>
libtextclassifier3::StatusOr<std::vector<int64_t>>
ExtractPropertyValues<int64_t>(const PropertyProto& property);
+template <>
+libtextclassifier3::StatusOr<std::vector<PropertyProto::VectorProto>>
+ExtractPropertyValues<PropertyProto::VectorProto>(
+ const PropertyProto& property);
+
template <typename T>
libtextclassifier3::StatusOr<std::vector<T>> ExtractPropertyValuesFromDocument(
const DocumentProto& document, std::string_view property_path) {
diff --git a/icing/schema/property-util_test.cc b/icing/schema/property-util_test.cc
index eddcc84..5e8a430 100644
--- a/icing/schema/property-util_test.cc
+++ b/icing/schema/property-util_test.cc
@@ -22,6 +22,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
+#include "icing/portable/equals-proto.h"
#include "icing/proto/document.pb.h"
#include "icing/testing/common-matchers.h"
@@ -30,6 +31,7 @@ namespace lib {
namespace {
+using ::icing::lib::portable_equals_proto::EqualsProto;
using ::testing::ElementsAre;
using ::testing::IsEmpty;
@@ -38,6 +40,8 @@ static constexpr std::string_view kPropertySingleString = "singleString";
static constexpr std::string_view kPropertyRepeatedString = "repeatedString";
static constexpr std::string_view kPropertySingleInteger = "singleInteger";
static constexpr std::string_view kPropertyRepeatedInteger = "repeatedInteger";
+static constexpr std::string_view kPropertySingleVector = "singleVector";
+static constexpr std::string_view kPropertyRepeatedVector = "repeatedVector";
static constexpr std::string_view kTypeNestedTest = "NestedTest";
static constexpr std::string_view kPropertyStr = "str";
@@ -83,6 +87,28 @@ TEST(PropertyUtilTest, ExtractPropertyValuesTypeInteger) {
IsOkAndHolds(ElementsAre(123, -456, 0)));
}
+TEST(PropertyUtilTest, ExtractPropertyValuesTypeVector) {
+ PropertyProto::VectorProto vector1;
+ vector1.set_model_signature("my_model1");
+ vector1.add_values(1.0f);
+ vector1.add_values(2.0f);
+
+ PropertyProto::VectorProto vector2;
+ vector2.set_model_signature("my_model2");
+ vector2.add_values(-1.0f);
+ vector2.add_values(-2.0f);
+ vector2.add_values(-3.0f);
+
+ PropertyProto property;
+ *property.mutable_vector_values()->Add() = vector1;
+ *property.mutable_vector_values()->Add() = vector2;
+
+ EXPECT_THAT(
+ property_util::ExtractPropertyValues<PropertyProto::VectorProto>(
+ property),
+ IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2))));
+}
+
TEST(PropertyUtilTest, ExtractPropertyValuesMismatchedType) {
PropertyProto property;
property.mutable_int64_values()->Add(123);
@@ -110,6 +136,16 @@ TEST(PropertyUtilTest, ExtractPropertyValuesTypeUnimplemented) {
}
TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) {
+ PropertyProto::VectorProto vector1;
+ vector1.set_model_signature("my_model1");
+ vector1.add_values(1.0f);
+ vector1.add_values(2.0f);
+ PropertyProto::VectorProto vector2;
+ vector2.set_model_signature("my_model2");
+ vector2.add_values(-1.0f);
+ vector2.add_values(-2.0f);
+ vector2.add_values(-3.0f);
+
DocumentProto document =
DocumentBuilder()
.SetKey("icing", "test/1")
@@ -119,6 +155,9 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) {
"repeated2", "repeated3")
.AddInt64Property(std::string(kPropertySingleInteger), 123)
.AddInt64Property(std::string(kPropertyRepeatedInteger), 1, 2, 3)
+ .AddVectorProperty(std::string(kPropertySingleVector), vector1)
+ .AddVectorProperty(std::string(kPropertyRepeatedVector), vector1,
+ vector2)
.Build();
// Single string
@@ -139,9 +178,33 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) {
EXPECT_THAT(property_util::ExtractPropertyValuesFromDocument<int64_t>(
document, /*property_path=*/kPropertyRepeatedInteger),
IsOkAndHolds(ElementsAre(1, 2, 3)));
+ // Single vector
+ EXPECT_THAT(property_util::ExtractPropertyValuesFromDocument<
+ PropertyProto::VectorProto>(
+ document, /*property_path=*/kPropertySingleVector),
+ IsOkAndHolds(ElementsAre(EqualsProto(vector1))));
+ // Repeated vector
+ EXPECT_THAT(
+ property_util::ExtractPropertyValuesFromDocument<
+ PropertyProto::VectorProto>(
+ document, /*property_path=*/kPropertyRepeatedVector),
+ IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2))));
}
TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) {
+ PropertyProto::VectorProto vector1;
+ vector1.set_model_signature("my_model1");
+ vector1.add_values(1.0f);
+ vector1.add_values(2.0f);
+ PropertyProto::VectorProto vector2;
+ vector2.set_model_signature("my_model2");
+ vector2.add_values(-1.0f);
+ vector2.add_values(-2.0f);
+ vector2.add_values(-3.0f);
+ PropertyProto::VectorProto vector3;
+ vector3.set_model_signature("my_model3");
+ vector3.add_values(1.0f);
+
DocumentProto nested_document =
DocumentBuilder()
.SetKey("icing", "nested/1")
@@ -158,6 +221,10 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) {
.AddInt64Property(std::string(kPropertySingleInteger), 123)
.AddInt64Property(std::string(kPropertyRepeatedInteger), 1, 2,
3)
+ .AddVectorProperty(std::string(kPropertySingleVector),
+ vector1)
+ .AddVectorProperty(std::string(kPropertyRepeatedVector),
+ vector1, vector2)
.Build(),
DocumentBuilder()
.SetSchema(std::string(kTypeTest))
@@ -168,6 +235,10 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) {
.AddInt64Property(std::string(kPropertySingleInteger), 456)
.AddInt64Property(std::string(kPropertyRepeatedInteger), 4, 5,
6)
+ .AddVectorProperty(std::string(kPropertySingleVector),
+ vector2)
+ .AddVectorProperty(std::string(kPropertyRepeatedVector),
+ vector2, vector3)
.Build())
.Build();
@@ -189,6 +260,17 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) {
property_util::ExtractPropertyValuesFromDocument<int64_t>(
nested_document, /*property_path=*/"nestedDocument.repeatedInteger"),
IsOkAndHolds(ElementsAre(1, 2, 3, 4, 5, 6)));
+ EXPECT_THAT(
+ property_util::ExtractPropertyValuesFromDocument<
+ PropertyProto::VectorProto>(
+ nested_document, /*property_path=*/"nestedDocument.singleVector"),
+ IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2))));
+ EXPECT_THAT(
+ property_util::ExtractPropertyValuesFromDocument<
+ PropertyProto::VectorProto>(
+ nested_document, /*property_path=*/"nestedDocument.repeatedVector"),
+ IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2),
+ EqualsProto(vector2), EqualsProto(vector3))));
// Test the property at first level
EXPECT_THAT(
diff --git a/icing/schema/schema-store.cc b/icing/schema/schema-store.cc
index e17e388..6830787 100644
--- a/icing/schema/schema-store.cc
+++ b/icing/schema/schema-store.cc
@@ -486,7 +486,7 @@ libtextclassifier3::Status SchemaStore::RegenerateDerivedFiles(
ICING_RETURN_IF_ERROR(schema_file_->Write(std::move(base_schema_ptr)));
// LINT.IfChange(min_overlay_version_compatibility)
- // Although the current version is 3, the schema is compatible with
+ // Although the current version is 4, the schema is compatible with
// version 1, so min_overlay_version_compatibility should be 1.
int32_t min_overlay_version_compatibility = version_util::kVersionOne;
// LINT.ThenChange(//depot/google3/icing/file/version-util.h:kVersion)
diff --git a/icing/schema/schema-store_test.cc b/icing/schema/schema-store_test.cc
index 8cc7008..ca5cdd3 100644
--- a/icing/schema/schema-store_test.cc
+++ b/icing/schema/schema-store_test.cc
@@ -14,8 +14,10 @@
#include "icing/schema/schema-store.h"
+#include <cstdint>
#include <memory>
#include <string>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
@@ -23,6 +25,7 @@
#include "gtest/gtest.h"
#include "icing/absl_ports/str_cat.h"
#include "icing/document-builder.h"
+#include "icing/file/file-backed-proto.h"
#include "icing/file/filesystem.h"
#include "icing/file/mock-filesystem.h"
#include "icing/file/version-util.h"
@@ -32,10 +35,7 @@
#include "icing/proto/logging.pb.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/storage.pb.h"
-#include "icing/proto/term.pb.h"
#include "icing/schema-builder.h"
-#include "icing/schema/schema-util.h"
-#include "icing/schema/section-manager.h"
#include "icing/schema/section.h"
#include "icing/store/document-filter-data.h"
#include "icing/testing/common-matchers.h"
@@ -142,7 +142,7 @@ TEST_F(SchemaStoreTest, SchemaStoreMoveConstructible) {
IsOkAndHolds(Eq(expected_checksum)));
SectionMetadata expected_metadata(/*id_in=*/0, TYPE_STRING, TOKENIZER_PLAIN,
TERM_MATCH_EXACT, NUMERIC_MATCH_UNKNOWN,
- "prop1");
+ EMBEDDING_INDEXING_UNKNOWN, "prop1");
EXPECT_THAT(move_constructed_schema_store.GetSectionMetadata("type_a"),
IsOkAndHolds(Pointee(ElementsAre(expected_metadata))));
}
@@ -193,7 +193,7 @@ TEST_F(SchemaStoreTest, SchemaStoreMoveAssignment) {
IsOkAndHolds(Eq(expected_checksum)));
SectionMetadata expected_metadata(/*id_in=*/0, TYPE_STRING, TOKENIZER_PLAIN,
TERM_MATCH_EXACT, NUMERIC_MATCH_UNKNOWN,
- "prop1");
+ EMBEDDING_INDEXING_UNKNOWN, "prop1");
EXPECT_THAT(move_assigned_schema_store->GetSectionMetadata("type_a"),
IsOkAndHolds(Pointee(ElementsAre(expected_metadata))));
}
@@ -1527,10 +1527,10 @@ TEST_F(SchemaStoreTest, SetSchemaRegenerateDerivedFilesFailure) {
.Build();
SectionMetadata expected_int_prop1_metadata(
/*id_in=*/0, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
- NUMERIC_MATCH_RANGE, "intProp1");
+ NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN, "intProp1");
SectionMetadata expected_string_prop1_metadata(
/*id_in=*/1, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT,
- NUMERIC_MATCH_UNKNOWN, "stringProp1");
+ NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN, "stringProp1");
ICING_ASSERT_OK_AND_ASSIGN(SectionGroup section_group,
schema_store->ExtractSections(document));
ASSERT_THAT(section_group.string_sections, SizeIs(1));
diff --git a/icing/schema/schema-util.cc b/icing/schema/schema-util.cc
index 72287a8..976d1b7 100644
--- a/icing/schema/schema-util.cc
+++ b/icing/schema/schema-util.cc
@@ -115,6 +115,12 @@ bool IsIntegerNumericMatchTypeCompatible(
return old_indexed.numeric_match_type() == new_indexed.numeric_match_type();
}
+bool IsEmbeddingIndexingCompatible(const EmbeddingIndexingConfig& old_indexed,
+ const EmbeddingIndexingConfig& new_indexed) {
+ return old_indexed.embedding_indexing_type() ==
+ new_indexed.embedding_indexing_type();
+}
+
bool IsDocumentIndexingCompatible(const DocumentIndexingConfig& old_indexed,
const DocumentIndexingConfig& new_indexed) {
// TODO(b/265304217): This could mark the new schema as incompatible and
@@ -824,6 +830,10 @@ libtextclassifier3::Status SchemaUtil::ValidateDocumentIndexingConfig(
!property_config.document_indexing_config()
.indexable_nested_properties_list()
.empty();
+ case PropertyConfigProto::DataType::VECTOR:
+ return property_config.embedding_indexing_config()
+ .embedding_indexing_type() !=
+ EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN;
case PropertyConfigProto::DataType::UNKNOWN:
case PropertyConfigProto::DataType::DOUBLE:
case PropertyConfigProto::DataType::BOOLEAN:
@@ -1087,7 +1097,10 @@ const SchemaUtil::SchemaDelta SchemaUtil::ComputeCompatibilityDelta(
new_property_config->integer_indexing_config()) ||
!IsDocumentIndexingCompatible(
old_property_config.document_indexing_config(),
- new_property_config->document_indexing_config())) {
+ new_property_config->document_indexing_config()) ||
+ !IsEmbeddingIndexingCompatible(
+ old_property_config.embedding_indexing_config(),
+ new_property_config->embedding_indexing_config())) {
is_index_incompatible = true;
}
diff --git a/icing/schema/schema-util_test.cc b/icing/schema/schema-util_test.cc
index 82683ba..e77ffab 100644
--- a/icing/schema/schema-util_test.cc
+++ b/icing/schema/schema-util_test.cc
@@ -2900,6 +2900,118 @@ TEST_P(SchemaUtilTest,
IsEmpty());
}
+TEST_P(SchemaUtilTest, ChangingIndexedVectorPropertiesMakesIndexIncompatible) {
+ SchemaProto schema_with_indexed_property =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ SchemaProto schema_with_unindexed_property =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeVector(EMBEDDING_INDEXING_UNKNOWN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ SchemaUtil::SchemaDelta schema_delta;
+ schema_delta.schema_types_index_incompatible.insert(kPersonType);
+
+ // New schema gained a new indexed vector property.
+ SchemaUtil::DependentMap no_dependents_map;
+ EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(
+ schema_with_unindexed_property, schema_with_indexed_property,
+ no_dependents_map),
+ Eq(schema_delta));
+
+ // New schema lost an indexed vector property.
+ EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(
+ schema_with_indexed_property, schema_with_unindexed_property,
+ no_dependents_map),
+ Eq(schema_delta));
+}
+
+TEST_P(SchemaUtilTest, AddingNewIndexedVectorPropertyMakesIndexIncompatible) {
+ // Configure old schema
+ SchemaProto old_schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ // Configure new schema
+ SchemaProto new_schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("NewEmbeddingProperty")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ SchemaUtil::SchemaDelta schema_delta;
+ schema_delta.schema_types_index_incompatible.insert(kPersonType);
+ SchemaUtil::DependentMap no_dependents_map;
+ EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(old_schema, new_schema,
+ no_dependents_map),
+ Eq(schema_delta));
+}
+
+TEST_P(SchemaUtilTest,
+ AddingNewNonIndexedVectorPropertyShouldRemainIndexCompatible) {
+ // Configure old schema
+ SchemaProto old_schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ // Configure new schema
+ SchemaProto new_schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType(kPersonType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("Property")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName("NewProperty")
+ .SetDataTypeVector(EMBEDDING_INDEXING_UNKNOWN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ SchemaUtil::DependentMap no_dependents_map;
+ EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(old_schema, new_schema,
+ no_dependents_map)
+ .schema_types_index_incompatible,
+ IsEmpty());
+}
+
TEST_P(SchemaUtilTest,
AddingNewIndexedDocumentPropertyMakesIndexAndJoinIncompatible) {
SchemaTypeConfigProto nested_schema =
diff --git a/icing/schema/section-manager.cc b/icing/schema/section-manager.cc
index 3d540d6..8689bf2 100644
--- a/icing/schema/section-manager.cc
+++ b/icing/schema/section-manager.cc
@@ -61,6 +61,7 @@ libtextclassifier3::Status AppendNewSectionMetadata(
property_config.string_indexing_config().tokenizer_type(),
property_config.string_indexing_config().term_match_type(),
property_config.integer_indexing_config().numeric_match_type(),
+ property_config.embedding_indexing_config().embedding_indexing_type(),
std::move(concatenated_path)));
return libtextclassifier3::Status::OK;
}
@@ -162,6 +163,19 @@ libtextclassifier3::StatusOr<SectionGroup> SectionManager::ExtractSections(
section_group.integer_sections);
break;
}
+ case PropertyConfigProto::DataType::VECTOR: {
+ if (section_metadata.embedding_indexing_type ==
+ EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN) {
+ // Skip if embedding indexing type is UNKNOWN.
+ break;
+ }
+ AppendSection(
+ section_metadata,
+ property_util::ExtractPropertyValuesFromDocument<
+ PropertyProto::VectorProto>(document, section_metadata.path),
+ section_group.vector_sections);
+ break;
+ }
default: {
// Skip other data types.
break;
diff --git a/icing/schema/section-manager_test.cc b/icing/schema/section-manager_test.cc
index eee78e9..b735fb1 100644
--- a/icing/schema/section-manager_test.cc
+++ b/icing/schema/section-manager_test.cc
@@ -14,10 +14,12 @@
#include "icing/schema/section-manager.h"
+#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
@@ -27,6 +29,8 @@
#include "icing/schema-builder.h"
#include "icing/schema/schema-type-manager.h"
#include "icing/schema/schema-util.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-filter-data.h"
#include "icing/store/dynamic-trie-key-mapper.h"
#include "icing/store/key-mapper.h"
#include "icing/testing/common-matchers.h"
@@ -37,6 +41,7 @@ namespace lib {
namespace {
+using ::icing::lib::portable_equals_proto::EqualsProto;
using ::testing::ElementsAre;
using ::testing::IsEmpty;
using ::testing::Pointee;
@@ -48,6 +53,8 @@ static constexpr std::string_view kTypeEmail = "Email";
static constexpr std::string_view kPropertyRecipientIds = "recipientIds";
static constexpr std::string_view kPropertyRecipients = "recipients";
static constexpr std::string_view kPropertySubject = "subject";
+static constexpr std::string_view kPropertySubjectEmbedding =
+ "subjectEmbedding";
static constexpr std::string_view kPropertyTimestamp = "timestamp";
// non-indexable
static constexpr std::string_view kPropertyAttachment = "attachment";
@@ -109,6 +116,14 @@ PropertyConfigProto CreateSubjectPropertyConfig() {
.Build();
}
+PropertyConfigProto CreateSubjectEmbeddingPropertyConfig() {
+ return PropertyConfigBuilder()
+ .SetName(kPropertySubjectEmbedding)
+ .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL)
+ .Build();
+}
+
PropertyConfigProto CreateTimestampPropertyConfig() {
return PropertyConfigBuilder()
.SetName(kPropertyTimestamp)
@@ -145,6 +160,7 @@ SchemaTypeConfigProto CreateEmailTypeConfig() {
return SchemaTypeConfigBuilder()
.SetType(kTypeEmail)
.AddProperty(CreateSubjectPropertyConfig())
+ .AddProperty(CreateSubjectEmbeddingPropertyConfig())
.AddProperty(PropertyConfigBuilder()
.SetName(kPropertyText)
.SetDataTypeString(TERM_MATCH_UNKNOWN, TOKENIZER_NONE)
@@ -220,11 +236,19 @@ class SectionManagerTest : public ::testing::Test {
ICING_ASSERT_OK(schema_type_mapper_->Put(kTypeConversation, 1));
ICING_ASSERT_OK(schema_type_mapper_->Put(kTypeGroup, 2));
+ email_subject_embedding_ = PropertyProto::VectorProto();
+ email_subject_embedding_.add_values(1.0);
+ email_subject_embedding_.add_values(2.0);
+ email_subject_embedding_.add_values(3.0);
+ email_subject_embedding_.set_model_signature("my_model");
+
email_document_ =
DocumentBuilder()
.SetKey("icing", "email/1")
.SetSchema(std::string(kTypeEmail))
.AddStringProperty(std::string(kPropertySubject), "the subject")
+ .AddVectorProperty(std::string(kPropertySubjectEmbedding),
+ email_subject_embedding_)
.AddStringProperty(std::string(kPropertyText), "the text")
.AddBytesProperty(std::string(kPropertyAttachment),
"attachment bytes")
@@ -265,6 +289,7 @@ class SectionManagerTest : public ::testing::Test {
SchemaUtil::TypeConfigMap type_config_map_;
std::unique_ptr<KeyMapper<SchemaTypeId>> schema_type_mapper_;
+ PropertyProto::VectorProto email_subject_embedding_;
DocumentProto email_document_;
DocumentProto conversation_document_;
DocumentProto group_document_;
@@ -308,11 +333,21 @@ TEST_F(SectionManagerTest, ExtractSections) {
EXPECT_THAT(section_group.integer_sections[0].content, ElementsAre(1, 2, 3));
EXPECT_THAT(section_group.integer_sections[1].metadata,
- EqualsSectionMetadata(/*expected_id=*/3,
+ EqualsSectionMetadata(/*expected_id=*/4,
/*expected_property_path=*/"timestamp",
CreateTimestampPropertyConfig()));
EXPECT_THAT(section_group.integer_sections[1].content,
ElementsAre(kDefaultTimestamp));
+
+ // Vector sections
+ EXPECT_THAT(section_group.vector_sections, SizeIs(1));
+ EXPECT_THAT(
+ section_group.vector_sections[0].metadata,
+ EqualsSectionMetadata(/*expected_id=*/3,
+ /*expected_property_path=*/"subjectEmbedding",
+ CreateSubjectEmbeddingPropertyConfig()));
+ EXPECT_THAT(section_group.vector_sections[0].content,
+ ElementsAre(EqualsProto(email_subject_embedding_)));
}
TEST_F(SectionManagerTest, ExtractSectionsNested) {
@@ -359,11 +394,22 @@ TEST_F(SectionManagerTest, ExtractSectionsNested) {
EXPECT_THAT(
section_group.integer_sections[1].metadata,
- EqualsSectionMetadata(/*expected_id=*/3,
+ EqualsSectionMetadata(/*expected_id=*/4,
/*expected_property_path=*/"emails.timestamp",
CreateTimestampPropertyConfig()));
EXPECT_THAT(section_group.integer_sections[1].content,
ElementsAre(kDefaultTimestamp, kDefaultTimestamp));
+
+ // Vector sections
+ EXPECT_THAT(section_group.vector_sections, SizeIs(1));
+ EXPECT_THAT(section_group.vector_sections[0].metadata,
+ EqualsSectionMetadata(
+ /*expected_id=*/3,
+ /*expected_property_path=*/"emails.subjectEmbedding",
+ CreateSubjectEmbeddingPropertyConfig()));
+ EXPECT_THAT(section_group.vector_sections[0].content,
+ ElementsAre(EqualsProto(email_subject_embedding_),
+ EqualsProto(email_subject_embedding_)));
}
TEST_F(SectionManagerTest, ExtractSectionsIndexableNestedPropertiesList) {
@@ -461,7 +507,8 @@ TEST_F(SectionManagerTest, GetSectionMetadata) {
// 0 -> recipientIds
// 1 -> recipients
// 2 -> subject
- // 3 -> timestamp
+ // 3 -> subjectEmbedding
+ // 4 -> timestamp
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
/*schema_type_id=*/0, /*section_id=*/0),
IsOkAndHolds(Pointee(EqualsSectionMetadata(
@@ -472,13 +519,30 @@ TEST_F(SectionManagerTest, GetSectionMetadata) {
IsOkAndHolds(Pointee(EqualsSectionMetadata(
/*expected_id=*/1, /*expected_property_path=*/"recipients",
CreateRecipientsPropertyConfig()))));
+ EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
+ /*schema_type_id=*/0, /*section_id=*/2),
+ IsOkAndHolds(Pointee(EqualsSectionMetadata(
+ /*expected_id=*/2, /*expected_property_path=*/"subject",
+ CreateSubjectPropertyConfig()))));
+ EXPECT_THAT(
+ schema_type_manager->section_manager().GetSectionMetadata(
+ /*schema_type_id=*/0, /*section_id=*/3),
+ IsOkAndHolds(Pointee(EqualsSectionMetadata(
+ /*expected_id=*/3, /*expected_property_path=*/"subjectEmbedding",
+ CreateSubjectEmbeddingPropertyConfig()))));
+ EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
+ /*schema_type_id=*/0, /*section_id=*/4),
+ IsOkAndHolds(Pointee(EqualsSectionMetadata(
+ /*expected_id=*/4, /*expected_property_path=*/"timestamp",
+ CreateTimestampPropertyConfig()))));
// Conversation (section id -> section property path):
// 0 -> emails.recipientIds
// 1 -> emails.recipients
// 2 -> emails.subject
- // 3 -> emails.timestamp
- // 4 -> name
+ // 3 -> emails.subjectEmbedding
+ // 4 -> emails.timestamp
+ // 5 -> name
EXPECT_THAT(
schema_type_manager->section_manager().GetSectionMetadata(
/*schema_type_id=*/1, /*section_id=*/0),
@@ -497,16 +561,22 @@ TEST_F(SectionManagerTest, GetSectionMetadata) {
IsOkAndHolds(Pointee(EqualsSectionMetadata(
/*expected_id=*/2, /*expected_property_path=*/"emails.subject",
CreateSubjectPropertyConfig()))));
+ EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
+ /*schema_type_id=*/1, /*section_id=*/3),
+ IsOkAndHolds(Pointee(EqualsSectionMetadata(
+ /*expected_id=*/3,
+ /*expected_property_path=*/"emails.subjectEmbedding",
+ CreateSubjectEmbeddingPropertyConfig()))));
EXPECT_THAT(
schema_type_manager->section_manager().GetSectionMetadata(
- /*schema_type_id=*/1, /*section_id=*/3),
+ /*schema_type_id=*/1, /*section_id=*/4),
IsOkAndHolds(Pointee(EqualsSectionMetadata(
- /*expected_id=*/3, /*expected_property_path=*/"emails.timestamp",
+ /*expected_id=*/4, /*expected_property_path=*/"emails.timestamp",
CreateTimestampPropertyConfig()))));
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
- /*schema_type_id=*/1, /*section_id=*/4),
+ /*schema_type_id=*/1, /*section_id=*/5),
IsOkAndHolds(Pointee(EqualsSectionMetadata(
- /*expected_id=*/4, /*expected_property_path=*/"name",
+ /*expected_id=*/5, /*expected_property_path=*/"name",
CreateNamePropertyConfig()))));
// Group (section id -> section property path):
@@ -615,25 +685,27 @@ TEST_F(SectionManagerTest, GetSectionMetadataInvalidSectionId) {
// 0 -> recipientIds
// 1 -> recipients
// 2 -> subject
- // 3 -> timestamp
+ // 3 -> subjectEmbedding
+ // 4 -> timestamp
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
/*schema_type_id=*/0, /*section_id=*/-1),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
- /*schema_type_id=*/0, /*section_id=*/4),
+ /*schema_type_id=*/0, /*section_id=*/5),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
// Conversation (section id -> section property path):
// 0 -> emails.recipientIds
// 1 -> emails.recipients
// 2 -> emails.subject
- // 3 -> emails.timestamp
- // 4 -> name
+ // 3 -> emails.subjectEmbedding
+ // 4 -> emails.timestamp
+ // 5 -> name
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
/*schema_type_id=*/1, /*section_id=*/-1),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata(
- /*schema_type_id=*/1, /*section_id=*/5),
+ /*schema_type_id=*/1, /*section_id=*/6),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
diff --git a/icing/schema/section.h b/icing/schema/section.h
index 3685a29..76e9e57 100644
--- a/icing/schema/section.h
+++ b/icing/schema/section.h
@@ -21,6 +21,7 @@
#include <utility>
#include <vector>
+#include "icing/proto/document.pb.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/term.pb.h"
@@ -89,18 +90,31 @@ struct SectionMetadata {
// Contents will be matched by a range query.
IntegerIndexingConfig::NumericMatchType::Code numeric_match_type;
+ // How vectors in a vector section should be indexed.
+ //
+ // EmbeddingIndexingType::UNKNOWN:
+ // Contents will not be indexed. It is invalid for a vector section
+ // (data_type == 'VECTOR') to have embedding_indexing_type == 'UNKNOWN'.
+ //
+ // EmbeddingIndexingType::LINEAR_SEARCH:
+ // Contents will be indexed for linear search.
+ EmbeddingIndexingConfig::EmbeddingIndexingType::Code embedding_indexing_type;
+
explicit SectionMetadata(
SectionId id_in, PropertyConfigProto::DataType::Code data_type_in,
StringIndexingConfig::TokenizerType::Code tokenizer,
TermMatchType::Code term_match_type_in,
IntegerIndexingConfig::NumericMatchType::Code numeric_match_type_in,
+ EmbeddingIndexingConfig::EmbeddingIndexingType::Code
+ embedding_indexing_type_in,
std::string&& path_in)
: path(std::move(path_in)),
id(id_in),
data_type(data_type_in),
tokenizer(tokenizer),
term_match_type(term_match_type_in),
- numeric_match_type(numeric_match_type_in) {}
+ numeric_match_type(numeric_match_type_in),
+ embedding_indexing_type(embedding_indexing_type_in) {}
SectionMetadata(const SectionMetadata& other) = default;
SectionMetadata& operator=(const SectionMetadata& other) = default;
@@ -144,6 +158,7 @@ struct Section {
struct SectionGroup {
std::vector<Section<std::string_view>> string_sections;
std::vector<Section<int64_t>> integer_sections;
+ std::vector<Section<PropertyProto::VectorProto>> vector_sections;
};
} // namespace lib
diff --git a/icing/scoring/advanced_scoring/advanced-scorer.cc b/icing/scoring/advanced_scoring/advanced-scorer.cc
index 83c1519..e375a8e 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer.cc
+++ b/icing/scoring/advanced_scoring/advanced-scorer.cc
@@ -14,14 +14,25 @@
#include "icing/scoring/advanced_scoring/advanced-scorer.h"
+#include <cstdint>
#include <memory>
+#include <utility>
+#include <vector>
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/join/join-children-fetcher.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
#include "icing/query/advanced_query_parser/lexer.h"
#include "icing/query/advanced_query_parser/parser.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/score-expression.h"
#include "icing/scoring/advanced_scoring/scoring-visitor.h"
#include "icing/scoring/bm25f-calculator.h"
#include "icing/scoring/section-weights.h"
+#include "icing/store/document-store.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -29,11 +40,15 @@ namespace lib {
libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>>
AdvancedScorer::Create(const ScoringSpecProto& scoring_spec,
double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store,
const SchemaStore* schema_store, int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher) {
+ const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results) {
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
+ ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
Lexer lexer(scoring_spec.advanced_scoring_expression(),
Lexer::Language::SCORING);
@@ -48,9 +63,10 @@ AdvancedScorer::Create(const ScoringSpecProto& scoring_spec,
std::unique_ptr<Bm25fCalculator> bm25f_calculator =
std::make_unique<Bm25fCalculator>(document_store, section_weights.get(),
current_time_ms);
- ScoringVisitor visitor(default_score, document_store, schema_store,
- section_weights.get(), bm25f_calculator.get(),
- join_children_fetcher, current_time_ms);
+ ScoringVisitor visitor(default_score, default_semantic_metric_type,
+ document_store, schema_store, section_weights.get(),
+ bm25f_calculator.get(), join_children_fetcher,
+ embedding_query_results, current_time_ms);
tree_root->Accept(&visitor);
ICING_ASSIGN_OR_RETURN(std::unique_ptr<ScoreExpression> expression,
diff --git a/icing/scoring/advanced_scoring/advanced-scorer.h b/icing/scoring/advanced_scoring/advanced-scorer.h
index d69abad..00477a3 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer.h
+++ b/icing/scoring/advanced_scoring/advanced-scorer.h
@@ -15,17 +15,26 @@
#ifndef ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_
#define ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_
+#include <cstdint>
#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/score-expression.h"
#include "icing/scoring/bm25f-calculator.h"
#include "icing/scoring/scorer.h"
+#include "icing/scoring/section-weights.h"
#include "icing/store/document-store.h"
+#include "icing/util/logging.h"
namespace icing {
namespace lib {
@@ -38,9 +47,11 @@ class AdvancedScorer : public Scorer {
// INVALID_ARGUMENT if fails to create an instance
static libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> Create(
const ScoringSpecProto& scoring_spec, double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store, const SchemaStore* schema_store,
- int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher = nullptr);
+ int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results);
double GetScore(const DocHitInfo& hit_info,
const DocHitInfoIterator* query_it) override {
@@ -63,7 +74,7 @@ class AdvancedScorer : public Scorer {
bm25f_calculator_->PrepareToScore(query_term_iterators);
}
- bool is_constant() const { return score_expression_->is_constant_double(); }
+ bool is_constant() const { return score_expression_->is_constant(); }
private:
explicit AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression,
diff --git a/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc
index 3612359..ca081cd 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc
+++ b/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc
@@ -12,11 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <cstddef>
#include <cstdint>
#include <memory>
+#include <string>
#include <string_view>
+#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/advanced-scorer.h"
+#include "icing/store/document-store.h"
#include "icing/testing/fake-clock.h"
#include "icing/testing/tmp-directory.h"
@@ -29,6 +36,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
const std::string test_dir = GetTestTempDir() + "/icing";
const std::string doc_store_dir = test_dir + "/doc_store";
const std::string schema_store_dir = test_dir + "/schema_store";
+ EmbeddingQueryResults empty_embedding_query_results_;
filesystem.DeleteDirectoryRecursively(test_dir.c_str());
filesystem.CreateDirectoryRecursively(doc_store_dir.c_str());
filesystem.CreateDirectoryRecursively(schema_store_dir.c_str());
@@ -54,9 +62,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
scoring_spec.set_advanced_scoring_expression(text);
AdvancedScorer::Create(scoring_spec,
- /*default_score=*/10, document_store.get(),
- schema_store.get(),
- fake_clock.GetSystemTimeMilliseconds());
+ /*default_score=*/10,
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store.get(), schema_store.get(),
+ fake_clock.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_);
// Not able to test the GetScore method of AdvancedScorer, since it will only
// be available after AdvancedScorer is successfully created. However, the
diff --git a/icing/scoring/advanced_scoring/advanced-scorer_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_test.cc
index cc1d413..4b9a46c 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer_test.cc
+++ b/icing/scoring/advanced_scoring/advanced-scorer_test.cc
@@ -15,14 +15,22 @@
#include "icing/scoring/advanced_scoring/advanced-scorer.h"
#include <cmath>
+#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/proto/document.pb.h"
@@ -31,6 +39,8 @@
#include "icing/proto/usage.pb.h"
#include "icing/schema-builder.h"
#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/scoring/scored-document-hit.h"
#include "icing/scoring/scorer-factory.h"
#include "icing/scoring/scorer.h"
#include "icing/store/document-id.h"
@@ -45,6 +55,7 @@ namespace lib {
namespace {
using ::testing::DoubleNear;
using ::testing::Eq;
+using ::testing::HasSubstr;
class AdvancedScorerTest : public testing::Test {
protected:
@@ -124,15 +135,19 @@ class AdvancedScorerTest : public testing::Test {
const std::string test_dir_;
const std::string doc_store_dir_;
const std::string schema_store_dir_;
+ EmbeddingQueryResults empty_embedding_query_results_;
Filesystem filesystem_;
std::unique_ptr<SchemaStore> schema_store_;
std::unique_ptr<DocumentStore> document_store_;
FakeClock fake_clock_;
};
-constexpr double kEps = 0.0000000001;
+constexpr double kEps = 0.0000001;
constexpr int kDefaultScore = 0;
constexpr int64_t kDefaultCreationTimestampMs = 1571100001111;
+constexpr SearchSpecProto::EmbeddingQueryMetricType::Code
+ kDefaultSemanticMetricType =
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT;
DocumentProto CreateDocument(
const std::string& name_space, const std::string& uri,
@@ -193,8 +208,11 @@ TEST_F(AdvancedScorerTest, InvalidAdvancedScoringSpec) {
scoring_spec.set_rank_by(
ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION);
EXPECT_THAT(scorer_factory::Create(scoring_spec, /*default_score=*/10,
+ kDefaultSemanticMetricType,
document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
// Non-empty scoring expression for normal scoring
@@ -202,8 +220,11 @@ TEST_F(AdvancedScorerTest, InvalidAdvancedScoringSpec) {
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE);
scoring_spec.set_advanced_scoring_expression("1");
EXPECT_THAT(scorer_factory::Create(scoring_spec, /*default_score=*/10,
+ kDefaultSemanticMetricType,
document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
@@ -215,9 +236,11 @@ TEST_F(AdvancedScorerTest, SimpleExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("123"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
DocHitInfo docHitInfo = DocHitInfo(document_id);
@@ -233,44 +256,61 @@ TEST_F(AdvancedScorerTest, BasicPureArithmeticExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + 2"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(3));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("-1 + 2"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("-1 + 2"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + -2"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + -2"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 - 2"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 - 2"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 * 2"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 * 2"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(2));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 2"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 2"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0.5));
}
@@ -283,103 +323,131 @@ TEST_F(AdvancedScorerTest, BasicMathFunctionExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("log(10, 1000)"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(3, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("log(2.718281828459045)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(2, 10)"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(2, 10)"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1024));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("max(10, 11, 12, 13, 14)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(14));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("min(10, 11, 12, 13, 14)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("len(10, 11, 12, 13, 14)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("sum(10, 11, 12, 13, 14)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10 + 11 + 12 + 13 + 14));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("avg(10, 11, 12, 13, 14)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq((10 + 11 + 12 + 13 + 14) / 5.));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(2)"),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(2)"),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(sqrt(2), kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(-2) + abs(2)"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(4));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("sin(3.141592653589793)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(0, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("cos(3.141592653589793)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(-1, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("tan(3.141592653589793 / 4)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps));
}
@@ -394,17 +462,21 @@ TEST_F(AdvancedScorerTest, DocumentScoreCreationTimestampFunctionExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("this.documentScore()"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(123));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("this.creationTimestamp()"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(kDefaultCreationTimestampMs));
ICING_ASSERT_OK_AND_ASSIGN(
@@ -412,8 +484,10 @@ TEST_F(AdvancedScorerTest, DocumentScoreCreationTimestampFunctionExpression) {
AdvancedScorer::Create(
CreateAdvancedScoringSpec(
"this.documentScore() + this.creationTimestamp()"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo),
Eq(123 + kDefaultCreationTimestampMs));
}
@@ -429,8 +503,10 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionExpression) {
AdvancedScorer::Create(
CreateAdvancedScoringSpec("this.usageCount(1) + this.usageCount(2) "
"+ this.usageLastUsedTimestamp(3)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0));
ICING_ASSERT_OK(document_store_->ReportUsage(
CreateUsageReport("namespace", "uri", 100000, UsageReport::USAGE_TYPE1)));
@@ -446,22 +522,28 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionExpression) {
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(1)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(100000));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(2)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(200000));
ICING_ASSERT_OK_AND_ASSIGN(
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(3)"),
- /*default_score=*/10, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(300000));
}
@@ -478,24 +560,29 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionOutOfRange) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- AdvancedScorer::Create(CreateAdvancedScoringSpec("this.usageCount(4)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.usageCount(4)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.usageCount(0)"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.usageCount(0)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.usageCount(1.5)"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.usageCount(1.5)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score));
}
@@ -516,9 +603,11 @@ TEST_F(AdvancedScorerTest, RelevanceScoreFunctionScoreExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<AdvancedScorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec("this.relevanceScore()"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
scorer->PrepareToScore(/*query_term_iterators=*/{});
// Should get the default score.
@@ -566,8 +655,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) {
std::unique_ptr<AdvancedScorer> scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("len(this.childrenRankingSignals())"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(), &fetcher));
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ &fetcher, &empty_embedding_query_results_));
// document_id_1 has two children.
EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(2));
// document_id_2 has one child.
@@ -579,8 +669,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) {
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("sum(this.childrenRankingSignals())"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(), &fetcher));
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ &fetcher, &empty_embedding_query_results_));
// document_id_1 has two children with scores 1 and 2.
EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3));
// document_id_2 has one child with score 4.
@@ -592,8 +683,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) {
scorer,
AdvancedScorer::Create(
CreateAdvancedScoringSpec("avg(this.childrenRankingSignals())"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(), &fetcher));
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ &fetcher, &empty_embedding_query_results_));
// document_id_1 has two children with scores 1 and 2.
EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3 / 2.));
// document_id_2 has one child with score 4.
@@ -604,13 +696,15 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) {
Eq(default_score));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(
- CreateAdvancedScoringSpec(
- // Equivalent to "avg(this.childrenRankingSignals())"
- "sum(this.childrenRankingSignals()) / "
- "len(this.childrenRankingSignals())"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(), &fetcher));
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ // Equivalent to "avg(this.childrenRankingSignals())"
+ "sum(this.childrenRankingSignals()) / "
+ "len(this.childrenRankingSignals())"),
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ &fetcher, &empty_embedding_query_results_));
// document_id_1 has two children with scores 1 and 2.
EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3 / 2.));
// document_id_2 has one child with score 4.
@@ -670,9 +764,11 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<AdvancedScorer> scorer,
AdvancedScorer::Create(spec_proto,
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// min([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// min([0.5, 0.8]) = 0.5
@@ -682,10 +778,13 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) {
spec_proto.set_advanced_scoring_expression("max(this.propertyWeights())");
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(spec_proto,
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(spec_proto,
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// max([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// max([0.5, 0.8]) = 0.8
@@ -695,10 +794,13 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) {
spec_proto.set_advanced_scoring_expression("sum(this.propertyWeights())");
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(spec_proto,
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(spec_proto,
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// sum([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// sum([0.5, 0.8]) = 1.3
@@ -747,9 +849,11 @@ TEST_F(AdvancedScorerTest,
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<AdvancedScorer> scorer,
AdvancedScorer::Create(spec_proto,
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// min([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// min([0.5, 1, 0.5]) = 0.5
@@ -757,10 +861,13 @@ TEST_F(AdvancedScorerTest,
spec_proto.set_advanced_scoring_expression("max(this.propertyWeights())");
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(spec_proto,
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(spec_proto,
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// max([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// max([0.5, 1, 0.5]) = 1
@@ -768,10 +875,13 @@ TEST_F(AdvancedScorerTest,
spec_proto.set_advanced_scoring_expression("sum(this.propertyWeights())");
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, AdvancedScorer::Create(spec_proto,
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer,
+ AdvancedScorer::Create(spec_proto,
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
// sum([1]) = 1
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1));
// sum([0.5, 1, 0.5]) = 2
@@ -786,20 +896,22 @@ TEST_F(AdvancedScorerTest, InvalidChildrenScoresFunctionScoreExpression) {
EXPECT_THAT(
AdvancedScorer::Create(
CreateAdvancedScoringSpec("len(this.childrenRankingSignals())"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(),
- /*join_children_fetcher=*/nullptr),
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
// The root expression can only be of double type, but here it is of list
// type.
JoinChildrenFetcher fake_fetcher(JoinSpecProto::default_instance(),
/*map_joinable_qualified_id=*/{});
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.childrenRankingSignals()"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds(), &fake_fetcher),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.childrenRankingSignals()"),
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ &fake_fetcher, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
TEST_F(AdvancedScorerTest, ComplexExpression) {
@@ -821,9 +933,11 @@ TEST_F(AdvancedScorerTest, ComplexExpression) {
"+ 10 * (2 + 10 + this.creationTimestamp()))"
// This should evaluate to default score.
"+ this.relevanceScore()"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_FALSE(scorer->is_constant());
scorer->PrepareToScore(/*query_term_iterators=*/{});
@@ -847,19 +961,24 @@ TEST_F(AdvancedScorerTest, ConstantExpression) {
"pow(sin(2), 2)"
"+ log(2, 122) / 12.34"
"* (10 * pow(2 * 1, sin(2)) + 10 * (2 + 10))"),
- /*default_score=*/10, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_TRUE(scorer->is_constant());
}
// Should be a parsing Error
TEST_F(AdvancedScorerTest, EmptyExpression) {
- EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec(""),
- /*default_score=*/10,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(CreateAdvancedScoringSpec(""),
+ /*default_score=*/10, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) {
@@ -872,30 +991,38 @@ TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- AdvancedScorer::Create(CreateAdvancedScoringSpec("log(0)"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("log(0)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer,
- AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 0"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 0"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(-1)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
ICING_ASSERT_OK_AND_ASSIGN(
scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(-1, 0.5)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()));
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_));
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
}
@@ -904,133 +1031,402 @@ TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) {
TEST_F(AdvancedScorerTest, MathTypeError) {
const double default_score = 0;
- EXPECT_THAT(
- AdvancedScorer::Create(CreateAdvancedScoringSpec("test"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("test"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(
- AdvancedScorer::Create(CreateAdvancedScoringSpec("log()"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log()"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, 2, 3)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, this)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(
- AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(1)"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(1)"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(1, 2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(1, 2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sin(1, 2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("cos(1, 2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("tan(1, 2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(
- AdvancedScorer::Create(CreateAdvancedScoringSpec("this"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(
- AdvancedScorer::Create(CreateAdvancedScoringSpec("-this"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("-this"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + this"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
TEST_F(AdvancedScorerTest, DocumentFunctionTypeError) {
const double default_score = 0;
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("documentScore(1)"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.creationTimestamp(1)"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("documentScore(1)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.creationTimestamp(1)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.usageCount()"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("usageLastUsedTimestamp(1, 1)"),
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("relevanceScore(1)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("documentScore(this)"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("that.documentScore()"), default_score,
+ kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("this.this.creationTimestamp()"),
+ default_score, kDefaultSemanticMetricType, document_store_.get(),
+ schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this.log(2)"),
+ default_score, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results_),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.usageCount()"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+}
+
+TEST_F(AdvancedScorerTest,
+ MatchedSemanticScoresFunctionScoreExpressionTypeError) {
+ libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> scorer_or =
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(matchedSemanticScores(getSearchSpecEmbedding(0)))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("usageLastUsedTimestamp(1, 1)"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("not called with \"this\""));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(0))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("relevanceScore(1)"), default_score,
- document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("got invalid argument type for embedding vector"));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0), 0))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("documentScore(this)"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("Embedding metric can only be given as a string"));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sum(this.matchedSemanticScores("
+ "getSearchSpecEmbedding(0), \"COSINE\", 0))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("that.documentScore()"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("got invalid number of arguments"));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sum(this.matchedSemanticScores("
+ "getSearchSpecEmbedding(0), \"COSIGN\"))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(
- CreateAdvancedScoringSpec("this.this.creationTimestamp()"),
- default_score, document_store_.get(), schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("Unknown metric type: COSIGN"));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(\"0\")))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
- EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this.log(2)"),
- default_score, document_store_.get(),
- schema_store_.get(),
- fake_clock_.GetSystemTimeMilliseconds()),
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("getSearchSpecEmbedding got invalid argument type"));
+
+ scorer_or = AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding()))"),
+ kDefaultSemanticMetricType, kDefaultSemanticMetricType,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_);
+ EXPECT_THAT(scorer_or,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(scorer_or.status().error_message(),
+ HasSubstr("getSearchSpecEmbedding must have 1 argument"));
+}
+
+void AddEntryToEmbeddingQueryScoreMap(
+ EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map,
+ double semantic_score, DocumentId document_id) {
+ score_map[document_id].push_back(semantic_score);
+}
+
+TEST_F(AdvancedScorerTest, MatchedSemanticScoresFunctionScoreExpression) {
+ DocumentId document_id_0 = 0;
+ DocumentId document_id_1 = 1;
+ DocHitInfo doc_hit_info_0(document_id_0);
+ DocHitInfo doc_hit_info_1(document_id_1);
+ EmbeddingQueryResults embedding_query_results;
+
+ // Let the first query assign the following semantic scores:
+ // COSINE:
+ // Document 0: 0.1, 0.2
+ // Document 1: 0.3, 0.4
+ // DOT_PRODUCT:
+ // Document 0: 0.5
+ // Document 1: 0.6
+ // EUCLIDEAN:
+ // Document 0: 0.7
+ // Document 1: 0.8
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map =
+ &embedding_query_results
+ .result_scores[0][SearchSpecProto::EmbeddingQueryMetricType::COSINE];
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.1, document_id_0);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.2, document_id_0);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.3, document_id_1);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.4, document_id_1);
+ score_map = &embedding_query_results.result_scores
+ [0][SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT];
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.5, document_id_0);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.6, document_id_1);
+ score_map =
+ &embedding_query_results
+ .result_scores[0]
+ [SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN];
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.7, document_id_0);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.8, document_id_1);
+
+ // Let the second query only assign DOT_PRODUCT scores:
+ // DOT_PRODUCT:
+ // Document 0: 0.1
+ // Document 1: 0.2
+ score_map = &embedding_query_results.result_scores
+ [1][SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT];
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.1, document_id_0);
+ AddEntryToEmbeddingQueryScoreMap(*score_map,
+ /*semantic_score=*/0.2, document_id_1);
+
+ // Get semantic scores for default metric (DOT_PRODUCT) for the first query.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Scorer> scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))"),
+ kDefaultScore, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &embedding_query_results));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.5, kEps));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.6, kEps));
+
+ // Get semantic scores for a metric overriding the default one for the first
+ // query.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sum(this.matchedSemanticScores("
+ "getSearchSpecEmbedding(0), \"COSINE\"))"),
+ kDefaultScore, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &embedding_query_results));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.1 + 0.2, kEps));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.3 + 0.4, kEps));
+
+ // Get semantic scores for multiple metrics for the first query.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer, AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)"
+ ", \"COSINE\")) + "
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)"
+ ", \"DOT_PRODUCT\")) + "
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)"
+ ", \"EUCLIDEAN\"))"),
+ kDefaultScore, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &embedding_query_results));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_0),
+ DoubleNear(0.1 + 0.2 + 0.5 + 0.7, kEps));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_1),
+ DoubleNear(0.3 + 0.4 + 0.6 + 0.8, kEps));
+
+ // Get semantic scores for the second query.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"),
+ kDefaultScore, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &embedding_query_results));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.1, kEps));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.2, kEps));
+
+ // The second query does not contain cosine scores.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sum(this.matchedSemanticScores("
+ "getSearchSpecEmbedding(1), \"COSINE\"))"),
+ kDefaultScore, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store_.get(), schema_store_.get(),
+ fake_clock_.GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &embedding_query_results));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0, kEps));
+ EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0, kEps));
}
} // namespace
diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc
index e8a2a89..687180a 100644
--- a/icing/scoring/advanced_scoring/score-expression.cc
+++ b/icing/scoring/advanced_scoring/score-expression.cc
@@ -14,10 +14,39 @@
#include "icing/scoring/advanced_scoring/score-expression.h"
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
#include <numeric>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/join/join-children-fetcher.h"
+#include "icing/schema/section.h"
+#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/scored-document-hit.h"
+#include "icing/scoring/section-weights.h"
+#include "icing/store/document-associated-score-data.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
+#include "icing/util/embedding-util.h"
+#include "icing/util/logging.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -49,7 +78,7 @@ OperatorScoreExpression::Create(
return absl_ports::InvalidArgumentError(
"Operators are only supported for double type.");
}
- if (!child->is_constant_double()) {
+ if (!child->is_constant()) {
children_all_constant_double = false;
}
}
@@ -149,7 +178,7 @@ MathFunctionScoreExpression::Create(
"Got an invalid type for the math function. Should expect a double "
"type argument.");
}
- if (!child->is_constant_double()) {
+ if (!child->is_constant()) {
args_all_constant_double = false;
}
}
@@ -517,5 +546,100 @@ SchemaTypeId PropertyWeightsFunctionScoreExpression::GetSchemaTypeId(
return filter_data_optional.value().schema_type_id();
}
+libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
+GetSearchSpecEmbeddingFunctionScoreExpression::Create(
+ std::vector<std::unique_ptr<ScoreExpression>> args) {
+ if (args.size() != 1) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat(kFunctionName, " must have 1 argument."));
+ }
+ if (args[0]->type() != ScoreExpressionType::kDouble) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat(kFunctionName, " got invalid argument type."));
+ }
+ bool is_constant = args[0]->is_constant();
+ std::unique_ptr<ScoreExpression> expression =
+ std::unique_ptr<GetSearchSpecEmbeddingFunctionScoreExpression>(
+ new GetSearchSpecEmbeddingFunctionScoreExpression(
+ std::move(args[0])));
+ if (is_constant) {
+ return ConstantScoreExpression::Create(
+ expression->eval(DocHitInfo(), /*query_it=*/nullptr),
+ expression->type());
+ }
+ return expression;
+}
+
+libtextclassifier3::StatusOr<double>
+GetSearchSpecEmbeddingFunctionScoreExpression::eval(
+ const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
+ ICING_ASSIGN_OR_RETURN(double raw_query_index,
+ arg_->eval(hit_info, query_it));
+ uint32_t query_index = (uint32_t)raw_query_index;
+ if (query_index != raw_query_index) {
+ return absl_ports::InvalidArgumentError(
+ "The index of an embedding query must be an integer.");
+ }
+ return query_index;
+}
+
+libtextclassifier3::StatusOr<
+ std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
+MatchedSemanticScoresFunctionScoreExpression::Create(
+ std::vector<std::unique_ptr<ScoreExpression>> args,
+ SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
+ const EmbeddingQueryResults* embedding_query_results) {
+ ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
+ ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
+
+ if (args.empty() || args[0]->type() != ScoreExpressionType::kDocument) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat(kFunctionName, " is not called with \"this\""));
+ }
+ if (args.size() != 2 && args.size() != 3) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat(kFunctionName, " got invalid number of arguments."));
+ }
+ if (args[1]->type() != ScoreExpressionType::kVectorIndex) {
+ return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+ kFunctionName, " got invalid argument type for embedding vector."));
+ }
+ if (args.size() == 3 && args[2]->type() != ScoreExpressionType::kString) {
+ return absl_ports::InvalidArgumentError(
+ "Embedding metric can only be given as a string.");
+ }
+
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type =
+ default_metric_type;
+ if (args.size() == 3) {
+ if (!args[2]->is_constant()) {
+ return absl_ports::InvalidArgumentError(
+ "Embedding metric can only be given as a constant string.");
+ }
+ ICING_ASSIGN_OR_RETURN(std::string_view metric, args[2]->eval_string());
+ ICING_ASSIGN_OR_RETURN(
+ metric_type,
+ embedding_util::GetEmbeddingQueryMetricTypeFromName(metric));
+ }
+ return std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>(
+ new MatchedSemanticScoresFunctionScoreExpression(
+ std::move(args), metric_type, *embedding_query_results));
+}
+
+libtextclassifier3::StatusOr<std::vector<double>>
+MatchedSemanticScoresFunctionScoreExpression::eval_list(
+ const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
+ ICING_ASSIGN_OR_RETURN(double raw_query_index,
+ args_[1]->eval(hit_info, query_it));
+ uint32_t query_index = (uint32_t)raw_query_index;
+ const std::vector<double>* scores =
+ embedding_query_results_.GetMatchedScoresForDocument(
+ query_index, metric_type_, hit_info.document_id());
+ if (scores == nullptr) {
+ return std::vector<double>();
+ }
+ return *scores;
+}
+
} // namespace lib
} // namespace icing
diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h
index 08d7997..e28fcd7 100644
--- a/icing/scoring/advanced_scoring/score-expression.h
+++ b/icing/scoring/advanced_scoring/score-expression.h
@@ -15,20 +15,26 @@
#ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
#define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
-#include <algorithm>
-#include <cmath>
+#include <cstdint>
#include <memory>
+#include <string>
+#include <string_view>
#include <unordered_map>
#include <unordered_set>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/section-weights.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
-#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -36,7 +42,11 @@ namespace lib {
enum class ScoreExpressionType {
kDouble,
kDoubleList,
- kDocument // Only "this" is considered as document type.
+ kDocument, // Only "this" is considered as document type.
+ // TODO(b/326656531): Instead of creating a vector index type, consider
+ // changing it to vector type so that the data is the vector directly.
+ kVectorIndex,
+ kString,
};
class ScoreExpression {
@@ -75,12 +85,24 @@ class ScoreExpression {
"checking.");
}
+ virtual libtextclassifier3::StatusOr<std::string_view> eval_string() const {
+ if (type() == ScoreExpressionType::kString) {
+ return absl_ports::UnimplementedError(
+ "All ScoreExpressions of type string must provide their own "
+ "implementation of eval_string!");
+ }
+ return absl_ports::InternalError(
+ "Runtime type error: the expression should never be evaluated to a "
+ "string. There must be inconsistencies in the static type checking.");
+ }
+
// Indicate the type to which the current expression will be evaluated.
virtual ScoreExpressionType type() const = 0;
- // Indicate whether the current expression is a constant double.
- // Returns true if and only if the object is of ConstantScoreExpression type.
- virtual bool is_constant_double() const { return false; }
+ // Indicate whether the current expression is a constant.
+ // Returns true if and only if the object is of ConstantScoreExpression or
+ // StringExpression type.
+ virtual bool is_constant() const { return false; }
};
class ThisExpression : public ScoreExpression {
@@ -100,9 +122,10 @@ class ThisExpression : public ScoreExpression {
class ConstantScoreExpression : public ScoreExpression {
public:
static std::unique_ptr<ConstantScoreExpression> Create(
- libtextclassifier3::StatusOr<double> c) {
+ libtextclassifier3::StatusOr<double> c,
+ ScoreExpressionType type = ScoreExpressionType::kDouble) {
return std::unique_ptr<ConstantScoreExpression>(
- new ConstantScoreExpression(c));
+ new ConstantScoreExpression(c, type));
}
libtextclassifier3::StatusOr<double> eval(
@@ -110,17 +133,39 @@ class ConstantScoreExpression : public ScoreExpression {
return c_;
}
- ScoreExpressionType type() const override {
- return ScoreExpressionType::kDouble;
- }
+ ScoreExpressionType type() const override { return type_; }
- bool is_constant_double() const override { return true; }
+ bool is_constant() const override { return true; }
private:
- explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c)
- : c_(c) {}
+ explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c,
+ ScoreExpressionType type)
+ : c_(c), type_(type) {}
libtextclassifier3::StatusOr<double> c_;
+ ScoreExpressionType type_;
+};
+
+class StringExpression : public ScoreExpression {
+ public:
+ static std::unique_ptr<StringExpression> Create(std::string str) {
+ return std::unique_ptr<StringExpression>(
+ new StringExpression(std::move(str)));
+ }
+
+ libtextclassifier3::StatusOr<std::string_view> eval_string() const override {
+ return str_;
+ }
+
+ ScoreExpressionType type() const override {
+ return ScoreExpressionType::kString;
+ }
+
+ bool is_constant() const override { return true; }
+
+ private:
+ explicit StringExpression(std::string str) : str_(std::move(str)) {}
+ std::string str_;
};
class OperatorScoreExpression : public ScoreExpression {
@@ -342,6 +387,70 @@ class PropertyWeightsFunctionScoreExpression : public ScoreExpression {
int64_t current_time_ms_;
};
+class GetSearchSpecEmbeddingFunctionScoreExpression : public ScoreExpression {
+ public:
+ static constexpr std::string_view kFunctionName = "getSearchSpecEmbedding";
+
+ // RETURNS:
+ // - A GetSearchSpecEmbeddingFunctionScoreExpression instance on success if
+ // not simplifiable.
+ // - A ConstantScoreExpression instance on success if simplifiable.
+ // - FAILED_PRECONDITION on any null pointer in children.
+ // - INVALID_ARGUMENT on type errors.
+ static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
+ std::vector<std::unique_ptr<ScoreExpression>> args);
+
+ libtextclassifier3::StatusOr<double> eval(
+ const DocHitInfo& hit_info,
+ const DocHitInfoIterator* query_it) const override;
+
+ ScoreExpressionType type() const override {
+ return ScoreExpressionType::kVectorIndex;
+ }
+
+ private:
+ explicit GetSearchSpecEmbeddingFunctionScoreExpression(
+ std::unique_ptr<ScoreExpression> arg)
+ : arg_(std::move(arg)) {}
+ std::unique_ptr<ScoreExpression> arg_;
+};
+
+class MatchedSemanticScoresFunctionScoreExpression : public ScoreExpression {
+ public:
+ static constexpr std::string_view kFunctionName = "matchedSemanticScores";
+
+ // RETURNS:
+ // - A MatchedSemanticScoresFunctionScoreExpression instance on success.
+ // - FAILED_PRECONDITION on any null pointer in children.
+ // - INVALID_ARGUMENT on type errors.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
+ Create(std::vector<std::unique_ptr<ScoreExpression>> args,
+ SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
+ const EmbeddingQueryResults* embedding_query_results);
+
+ libtextclassifier3::StatusOr<std::vector<double>> eval_list(
+ const DocHitInfo& hit_info,
+ const DocHitInfoIterator* query_it) const override;
+
+ ScoreExpressionType type() const override {
+ return ScoreExpressionType::kDoubleList;
+ }
+
+ private:
+ explicit MatchedSemanticScoresFunctionScoreExpression(
+ std::vector<std::unique_ptr<ScoreExpression>> args,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ const EmbeddingQueryResults& embedding_query_results)
+ : args_(std::move(args)),
+ metric_type_(metric_type),
+ embedding_query_results_(embedding_query_results) {}
+
+ std::vector<std::unique_ptr<ScoreExpression>> args_;
+ const SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
+ const EmbeddingQueryResults& embedding_query_results_;
+};
+
} // namespace lib
} // namespace icing
diff --git a/icing/scoring/advanced_scoring/score-expression_test.cc b/icing/scoring/advanced_scoring/score-expression_test.cc
index 588090d..cd58366 100644
--- a/icing/scoring/advanced_scoring/score-expression_test.cc
+++ b/icing/scoring/advanced_scoring/score-expression_test.cc
@@ -14,15 +14,16 @@
#include "icing/scoring/advanced_scoring/score-expression.h"
-#include <cmath>
#include <memory>
-#include <string>
-#include <string_view>
#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/testing/common-matchers.h"
namespace icing {
@@ -47,7 +48,7 @@ class NonConstantScoreExpression : public ScoreExpression {
return ScoreExpressionType::kDouble;
}
- bool is_constant_double() const override { return false; }
+ bool is_constant() const override { return false; }
};
class ListScoreExpression : public ScoreExpression {
@@ -87,7 +88,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) {
OperatorScoreExpression::OperatorType::kPlus,
MakeChildren(ConstantScoreExpression::Create(1),
ConstantScoreExpression::Create(1))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2)));
// 1 - 2 - 3 = -4
@@ -97,7 +98,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) {
MakeChildren(ConstantScoreExpression::Create(1),
ConstantScoreExpression::Create(2),
ConstantScoreExpression::Create(3))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-4)));
// 1 * 2 * 3 * 4 = 24
@@ -108,7 +109,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) {
ConstantScoreExpression::Create(2),
ConstantScoreExpression::Create(3),
ConstantScoreExpression::Create(4))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(24)));
// 1 / 2 / 4 = 0.125
@@ -118,7 +119,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) {
MakeChildren(ConstantScoreExpression::Create(1),
ConstantScoreExpression::Create(2),
ConstantScoreExpression::Create(4))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(0.125)));
// -(2) = -2
@@ -126,7 +127,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) {
expression, OperatorScoreExpression::Create(
OperatorScoreExpression::OperatorType::kNegative,
MakeChildren(ConstantScoreExpression::Create(2))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-2)));
}
@@ -138,7 +139,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) {
MathFunctionScoreExpression::FunctionType::kPow,
MakeChildren(ConstantScoreExpression::Create(2),
ConstantScoreExpression::Create(2))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(4)));
// abs(-2) = 2
@@ -146,7 +147,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) {
expression, MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::FunctionType::kAbs,
MakeChildren(ConstantScoreExpression::Create(-2))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2)));
// log(e) = 1
@@ -154,7 +155,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) {
expression, MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::FunctionType::kLog,
MakeChildren(ConstantScoreExpression::Create(M_E))));
- ASSERT_TRUE(expression->is_constant_double());
+ ASSERT_TRUE(expression->is_constant());
EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(1)));
}
@@ -166,7 +167,7 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) {
OperatorScoreExpression::OperatorType::kPlus,
MakeChildren(ConstantScoreExpression::Create(1),
NonConstantScoreExpression::Create())));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
// non_constant * non_constant = non_constant
ICING_ASSERT_OK_AND_ASSIGN(
@@ -174,14 +175,14 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) {
OperatorScoreExpression::OperatorType::kTimes,
MakeChildren(NonConstantScoreExpression::Create(),
NonConstantScoreExpression::Create())));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
// -(non_constant) = non_constant
ICING_ASSERT_OK_AND_ASSIGN(
expression, OperatorScoreExpression::Create(
OperatorScoreExpression::OperatorType::kNegative,
MakeChildren(NonConstantScoreExpression::Create())));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
// pow(non_constant, 2) = non_constant
ICING_ASSERT_OK_AND_ASSIGN(
@@ -189,21 +190,21 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) {
MathFunctionScoreExpression::FunctionType::kPow,
MakeChildren(NonConstantScoreExpression::Create(),
ConstantScoreExpression::Create(2))));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
// abs(non_constant) = non_constant
ICING_ASSERT_OK_AND_ASSIGN(
expression, MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::FunctionType::kAbs,
MakeChildren(NonConstantScoreExpression::Create())));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
// log(non_constant) = non_constant
ICING_ASSERT_OK_AND_ASSIGN(
expression, MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::FunctionType::kLog,
MakeChildren(NonConstantScoreExpression::Create())));
- ASSERT_FALSE(expression->is_constant_double());
+ ASSERT_FALSE(expression->is_constant());
}
TEST(ScoreExpressionTest, MathFunctionsWithListTypeArgument) {
diff --git a/icing/scoring/advanced_scoring/scoring-visitor.cc b/icing/scoring/advanced_scoring/scoring-visitor.cc
index e2b24a2..05240c0 100644
--- a/icing/scoring/advanced_scoring/scoring-visitor.cc
+++ b/icing/scoring/advanced_scoring/scoring-visitor.cc
@@ -14,7 +14,17 @@
#include "icing/scoring/advanced_scoring/scoring-visitor.h"
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/scoring/advanced_scoring/score-expression.h"
namespace icing {
namespace lib {
@@ -25,8 +35,7 @@ void ScoringVisitor::VisitFunctionName(const FunctionNameNode* node) {
}
void ScoringVisitor::VisitString(const StringNode* node) {
- pending_error_ =
- absl_ports::InvalidArgumentError("Scoring does not support String!");
+ stack_.push_back(StringExpression::Create(node->value()));
}
void ScoringVisitor::VisitText(const TextNode* node) {
@@ -120,6 +129,17 @@ void ScoringVisitor::VisitFunctionHelper(const FunctionNode* node,
expression = MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::kFunctionNames.at(function_name),
std::move(args));
+ } else if (function_name ==
+ GetSearchSpecEmbeddingFunctionScoreExpression::kFunctionName) {
+ // getSearchSpecEmbedding function
+ expression =
+ GetSearchSpecEmbeddingFunctionScoreExpression::Create(std::move(args));
+ } else if (function_name ==
+ MatchedSemanticScoresFunctionScoreExpression::kFunctionName) {
+ // matchedSemanticScores function
+ expression = MatchedSemanticScoresFunctionScoreExpression::Create(
+ std::move(args), default_semantic_metric_type_,
+ &embedding_query_results_);
}
if (!expression.ok()) {
diff --git a/icing/scoring/advanced_scoring/scoring-visitor.h b/icing/scoring/advanced_scoring/scoring-visitor.h
index cfee25b..bb5e6ba 100644
--- a/icing/scoring/advanced_scoring/scoring-visitor.h
+++ b/icing/scoring/advanced_scoring/scoring-visitor.h
@@ -15,14 +15,22 @@
#ifndef ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_
#define ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/legacy/core/icing-string-util.h"
-#include "icing/proto/scoring.pb.h"
#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/score-expression.h"
#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/section-weights.h"
#include "icing/store/document-store.h"
namespace icing {
@@ -31,18 +39,23 @@ namespace lib {
class ScoringVisitor : public AbstractSyntaxTreeVisitor {
public:
explicit ScoringVisitor(double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store,
const SchemaStore* schema_store,
SectionWeights* section_weights,
Bm25fCalculator* bm25f_calculator,
const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results,
int64_t current_time_ms)
: default_score_(default_score),
+ default_semantic_metric_type_(default_semantic_metric_type),
document_store_(*document_store),
schema_store_(*schema_store),
section_weights_(*section_weights),
bm25f_calculator_(*bm25f_calculator),
join_children_fetcher_(join_children_fetcher),
+ embedding_query_results_(*embedding_query_results),
current_time_ms_(current_time_ms) {}
void VisitFunctionName(const FunctionNameNode* node) override;
@@ -90,12 +103,15 @@ class ScoringVisitor : public AbstractSyntaxTreeVisitor {
}
double default_score_;
+ const SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type_;
const DocumentStore& document_store_;
const SchemaStore& schema_store_;
SectionWeights& section_weights_;
Bm25fCalculator& bm25f_calculator_;
// A non-null join_children_fetcher_ indicates scoring in a join.
const JoinChildrenFetcher* join_children_fetcher_; // Does not own.
+ const EmbeddingQueryResults& embedding_query_results_;
libtextclassifier3::Status pending_error_;
std::vector<std::unique_ptr<ScoreExpression>> stack_;
diff --git a/icing/scoring/score-and-rank_benchmark.cc b/icing/scoring/score-and-rank_benchmark.cc
index 7cb5a95..4da8d1f 100644
--- a/icing/scoring/score-and-rank_benchmark.cc
+++ b/icing/scoring/score-and-rank_benchmark.cc
@@ -12,18 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <random>
#include <string>
+#include <unordered_map>
#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "testing/base/public/benchmark.h"
+#include "gtest/gtest.h"
#include "icing/document-builder.h"
#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
@@ -31,6 +35,7 @@
#include "icing/proto/schema.pb.h"
#include "icing/proto/scoring.pb.h"
#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
#include "icing/scoring/ranker.h"
#include "icing/scoring/scored-document-hit.h"
#include "icing/scoring/scoring-processor.h"
@@ -131,11 +136,15 @@ void BM_ScoreAndRankDocumentHitsByDocumentScore(benchmark::State& state) {
ScoringSpecProto scoring_spec;
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE);
+ EmbeddingQueryResults empty_embedding_query_results;
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get(),
- schema_store.get(),
- clock.GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ scoring_spec, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store.get(), schema_store.get(),
+ clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
@@ -237,11 +246,15 @@ void BM_ScoreAndRankDocumentHitsByCreationTime(benchmark::State& state) {
ScoringSpecProto scoring_spec;
scoring_spec.set_rank_by(
ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP);
+ EmbeddingQueryResults empty_embedding_query_results;
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get(),
- schema_store.get(),
- clock.GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ scoring_spec, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store.get(), schema_store.get(),
+ clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
@@ -344,11 +357,15 @@ void BM_ScoreAndRankDocumentHitsNoScoring(benchmark::State& state) {
ScoringSpecProto scoring_spec;
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::NONE);
+ EmbeddingQueryResults empty_embedding_query_results;
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get(),
- schema_store.get(),
- clock.GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ scoring_spec, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store.get(), schema_store.get(),
+ clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
@@ -446,11 +463,15 @@ void BM_ScoreAndRankDocumentHitsByRelevanceScoring(benchmark::State& state) {
ScoringSpecProto scoring_spec;
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE);
+ EmbeddingQueryResults empty_embedding_query_results;
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get(),
- schema_store.get(),
- clock.GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ scoring_spec, /*default_semantic_metric_type=*/
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT,
+ document_store.get(), schema_store.get(),
+ clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr,
+ &empty_embedding_query_results));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
diff --git a/icing/scoring/scorer-factory.cc b/icing/scoring/scorer-factory.cc
index e56f10c..1d66d7f 100644
--- a/icing/scoring/scorer-factory.cc
+++ b/icing/scoring/scorer-factory.cc
@@ -14,19 +14,26 @@
#include "icing/scoring/scorer-factory.h"
+#include <cstdint>
#include <memory>
+#include <optional>
+#include <string>
#include <unordered_map>
+#include <utility>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/join/join-children-fetcher.h"
#include "icing/proto/scoring.pb.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/advanced-scorer.h"
#include "icing/scoring/bm25f-calculator.h"
#include "icing/scoring/scorer.h"
#include "icing/scoring/section-weights.h"
-#include "icing/store/document-id.h"
+#include "icing/store/document-associated-score-data.h"
#include "icing/store/document-store.h"
#include "icing/util/status-macros.h"
@@ -173,10 +180,14 @@ namespace scorer_factory {
libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
const ScoringSpecProto& scoring_spec, double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store, const SchemaStore* schema_store,
- int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher) {
+ int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results) {
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
+ ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
if (!scoring_spec.advanced_scoring_expression().empty() &&
scoring_spec.rank_by() !=
@@ -223,9 +234,10 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
return absl_ports::InvalidArgumentError(
"Advanced scoring is enabled, but the expression is empty!");
}
- return AdvancedScorer::Create(scoring_spec, default_score, document_store,
- schema_store, current_time_ms,
- join_children_fetcher);
+ return AdvancedScorer::Create(
+ scoring_spec, default_score, default_semantic_metric_type,
+ document_store, schema_store, current_time_ms, join_children_fetcher,
+ embedding_query_results);
case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE:
// Use join aggregate score to rank. Since the aggregation score is
// calculated by child documents after joining (in JoinProcessor), we can
diff --git a/icing/scoring/scorer-factory.h b/icing/scoring/scorer-factory.h
index 659bebd..f5766b3 100644
--- a/icing/scoring/scorer-factory.h
+++ b/icing/scoring/scorer-factory.h
@@ -15,8 +15,13 @@
#ifndef ICING_SCORING_SCORER_FACTORY_H_
#define ICING_SCORING_SCORER_FACTORY_H_
+#include <cstdint>
+#include <memory>
+
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/join/join-children-fetcher.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/scorer.h"
#include "icing/store/document-store.h"
@@ -37,9 +42,11 @@ namespace scorer_factory {
// INVALID_ARGUMENT if fails to create an instance
libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
const ScoringSpecProto& scoring_spec, double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store, const SchemaStore* schema_store,
- int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher = nullptr);
+ int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results);
} // namespace scorer_factory
diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc
index 5194c7f..e22d5f4 100644
--- a/icing/scoring/scorer_test.cc
+++ b/icing/scoring/scorer_test.cc
@@ -14,13 +14,19 @@
#include "icing/scoring/scorer.h"
+#include <cstdint>
+#include <limits>
#include <memory>
#include <string>
+#include <utility>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/proto/document.pb.h"
#include "icing/proto/schema.pb.h"
@@ -30,7 +36,6 @@
#include "icing/schema/schema-store.h"
#include "icing/scoring/scorer-factory.h"
#include "icing/scoring/scorer-test-utils.h"
-#include "icing/scoring/section-weights.h"
#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
#include "icing/testing/common-matchers.h"
@@ -107,6 +112,10 @@ class ScorerTest : public ::testing::TestWithParam<ScorerTestingMode> {
fake_clock1_.SetSystemTimeMilliseconds(new_time);
}
+ SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type =
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT;
+ EmbeddingQueryResults empty_embedding_query_results;
+
private:
const std::string test_dir_;
const std::string doc_store_dir_;
@@ -134,8 +143,10 @@ TEST_P(ScorerTest, CreationWithNullDocumentStoreShouldFail) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()),
- /*default_score=*/0, /*document_store=*/nullptr, schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()),
+ /*default_score=*/0, default_semantic_metric_type,
+ /*document_store=*/nullptr, schema_store(),
+ fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
@@ -144,8 +155,9 @@ TEST_P(ScorerTest, CreationWithNullSchemaStoreShouldFail) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()),
- /*default_score=*/0, document_store(),
- /*schema_store=*/nullptr, fake_clock1().GetSystemTimeMilliseconds()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ /*schema_store=*/nullptr, fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
@@ -155,8 +167,9 @@ TEST_P(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()),
- /*default_score=*/10, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/10, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
// Non existent document id
DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1);
@@ -181,8 +194,9 @@ TEST_P(ScorerTest, ShouldGetDefaultDocumentScore) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()),
- /*default_score=*/10, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/10, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0));
@@ -206,8 +220,9 @@ TEST_P(ScorerTest, ShouldGetCorrectDocumentScore) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5));
@@ -233,8 +248,9 @@ TEST_P(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()),
- /*default_score=*/10, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/10, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10));
@@ -268,8 +284,9 @@ TEST_P(ScorerTest, ShouldGetCorrectCreationTimestampScore) {
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP,
GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo1 = DocHitInfo(document_id1);
DocHitInfo docHitInfo2 = DocHitInfo(document_id2);
@@ -297,22 +314,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -347,22 +367,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -397,22 +420,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()),
- /*default_score=*/0, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -444,31 +470,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE2_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE3_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -516,31 +545,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE2_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE3_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -588,31 +620,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE2_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE3_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -651,8 +686,9 @@ TEST_P(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) {
scorer_factory::Create(
CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::NONE, GetParam()),
- /*default_score=*/3, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ /*default_score=*/3, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0);
DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1);
@@ -662,11 +698,13 @@ TEST_P(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) {
EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(3));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, scorer_factory::Create(
- CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::NONE, GetParam()),
- /*default_score=*/111, document_store(), schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer,
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::NONE, GetParam()),
+ /*default_score=*/111, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
docHitInfo1 = DocHitInfo(/*document_id_in=*/4);
docHitInfo2 = DocHitInfo(/*document_id_in=*/5);
@@ -690,13 +728,14 @@ TEST_P(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- scorer_factory::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP,
- GetParam()),
- /*default_score=*/0, document_store(),
- schema_store(),
- fake_clock1().GetSystemTimeMilliseconds()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP,
+ GetParam()),
+ /*default_score=*/0, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock1().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
DocHitInfo docHitInfo = DocHitInfo(document_id);
// Create usage report for the maximum allowable timestamp.
diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc
index b827bd8..bb62db9 100644
--- a/icing/scoring/scoring-processor.cc
+++ b/icing/scoring/scoring-processor.cc
@@ -14,6 +14,7 @@
#include "icing/scoring/scoring-processor.h"
+#include <cstdint>
#include <limits>
#include <memory>
#include <string>
@@ -23,10 +24,12 @@
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/join/join-children-fetcher.h"
#include "icing/proto/scoring.pb.h"
-#include "icing/scoring/ranker.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/scored-document-hit.h"
#include "icing/scoring/scorer-factory.h"
#include "icing/scoring/scorer.h"
@@ -44,24 +47,28 @@ constexpr double kDefaultScoreInAscendingOrder =
libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>>
ScoringProcessor::Create(const ScoringSpecProto& scoring_spec,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store,
const SchemaStore* schema_store,
int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher) {
+ const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results) {
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
+ ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
bool is_descending_order =
scoring_spec.order_by() == ScoringSpecProto::Order::DESC;
ICING_ASSIGN_OR_RETURN(
std::unique_ptr<Scorer> scorer,
- scorer_factory::Create(scoring_spec,
- is_descending_order
- ? kDefaultScoreInDescendingOrder
- : kDefaultScoreInAscendingOrder,
- document_store, schema_store, current_time_ms,
- join_children_fetcher));
+ scorer_factory::Create(
+ scoring_spec,
+ is_descending_order ? kDefaultScoreInDescendingOrder
+ : kDefaultScoreInAscendingOrder,
+ default_semantic_metric_type, document_store, schema_store,
+ current_time_ms, join_children_fetcher, embedding_query_results));
// Using `new` to access a non-public constructor.
return std::unique_ptr<ScoringProcessor>(
new ScoringProcessor(std::move(scorer)));
diff --git a/icing/scoring/scoring-processor.h b/icing/scoring/scoring-processor.h
index 8634a22..de0b95c 100644
--- a/icing/scoring/scoring-processor.h
+++ b/icing/scoring/scoring-processor.h
@@ -23,6 +23,7 @@
#include <vector>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/proto/logging.pb.h"
@@ -46,9 +47,12 @@ class ScoringProcessor {
// A ScoringProcessor on success
// FAILED_PRECONDITION on any null pointer input
static libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> Create(
- const ScoringSpecProto& scoring_spec, const DocumentStore* document_store,
- const SchemaStore* schema_store, int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher = nullptr);
+ const ScoringSpecProto& scoring_spec,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
+ const DocumentStore* document_store, const SchemaStore* schema_store,
+ int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results);
// Assigns scores to DocHitInfos from the given DocHitInfoIterator and returns
// a vector of ScoredDocumentHits. The size of results is no more than
diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc
index deddff8..e3b70a6 100644
--- a/icing/scoring/scoring-processor_test.cc
+++ b/icing/scoring/scoring-processor_test.cc
@@ -15,22 +15,39 @@
#include "icing/scoring/scoring-processor.h"
#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/portable-file-backed-proto-log.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/proto/document.pb.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/scoring.pb.h"
#include "icing/proto/term.pb.h"
#include "icing/proto/usage.pb.h"
#include "icing/schema-builder.h"
+#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/scoring/scored-document-hit.h"
#include "icing/scoring/scorer-test-utils.h"
+#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/fake-clock.h"
#include "icing/testing/tmp-directory.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -111,6 +128,10 @@ class ScoringProcessorTest
const FakeClock& fake_clock() const { return fake_clock_; }
+ SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type =
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT;
+ EmbeddingQueryResults empty_embedding_query_results;
+
private:
const std::string test_dir_;
const std::string doc_store_dir_;
@@ -187,27 +208,31 @@ PropertyWeight CreatePropertyWeight(std::string path, double weight) {
TEST_F(ScoringProcessorTest, CreationWithNullDocumentStoreShouldFail) {
ScoringSpecProto spec_proto;
- EXPECT_THAT(ScoringProcessor::Create(
- spec_proto, /*document_store=*/nullptr, schema_store(),
- fake_clock().GetSystemTimeMilliseconds()),
- StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, /*document_store=*/nullptr,
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
TEST_F(ScoringProcessorTest, CreationWithNullSchemaStoreShouldFail) {
ScoringSpecProto spec_proto;
EXPECT_THAT(
- ScoringProcessor::Create(spec_proto, document_store(),
- /*schema_store=*/nullptr,
- fake_clock().GetSystemTimeMilliseconds()),
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ /*schema_store=*/nullptr, fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
TEST_P(ScoringProcessorTest, ShouldCreateInstance) {
ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy(
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam());
- ICING_EXPECT_OK(
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ICING_EXPECT_OK(ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
}
TEST_P(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) {
@@ -222,8 +247,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/5),
@@ -249,8 +276,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/-1),
@@ -280,8 +309,10 @@ TEST_P(ScoringProcessorTest, ShouldRespectNumToScore) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/2),
@@ -313,8 +344,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByDocumentScore) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -369,8 +402,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -439,8 +474,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -513,8 +550,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -563,8 +602,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -629,8 +670,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -700,8 +743,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -762,8 +807,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor with no explicit weights set.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
ScoringSpecProto spec_proto_with_weights =
CreateScoringSpecForRankingStrategy(
@@ -778,9 +825,11 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor with default weight set for "body" property.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor_with_weights,
- ScoringProcessor::Create(spec_proto_with_weights, document_store(),
- schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto_with_weights, default_semantic_metric_type,
+ document_store(), schema_store(),
+ fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -866,8 +915,10 @@ TEST_P(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -929,8 +980,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByCreationTimestamp) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -990,8 +1043,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByUsageCount) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -1051,8 +1106,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByUsageTimestamp) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -1088,8 +1145,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleNoScores) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/4),
ElementsAre(EqualsScoredDocumentHit(scored_document_hit_default),
@@ -1138,8 +1197,10 @@ TEST_P(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store(), schema_store(),
- fake_clock().GetSystemTimeMilliseconds()));
+ ScoringProcessor::Create(
+ spec_proto, default_semantic_metric_type, document_store(),
+ schema_store(), fake_clock().GetSystemTimeMilliseconds(),
+ /*join_children_fetcher=*/nullptr, &empty_embedding_query_results));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
diff --git a/icing/testing/embedding-test-utils.h b/icing/testing/embedding-test-utils.h
new file mode 100644
index 0000000..931953e
--- /dev/null
+++ b/icing/testing/embedding-test-utils.h
@@ -0,0 +1,45 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_TESTING_EMBEDDING_TEST_UTILS_H_
+#define ICING_TESTING_EMBEDDING_TEST_UTILS_H_
+
+#include <initializer_list>
+#include <string>
+
+#include "icing/proto/document.pb.h"
+
+namespace icing {
+namespace lib {
+
+inline PropertyProto::VectorProto CreateVector(
+ const std::string& model_signature, std::initializer_list<float> values) {
+ PropertyProto::VectorProto vector;
+ vector.set_model_signature(model_signature);
+ for (float value : values) {
+ vector.add_values(value);
+ }
+ return vector;
+}
+
+template <typename... V>
+inline PropertyProto::VectorProto CreateVector(
+ const std::string& model_signature, V&&... values) {
+  return CreateVector(model_signature, {values...});
+}
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_TESTING_EMBEDDING_TEST_UTILS_H_
diff --git a/icing/testing/hit-test-utils.cc b/icing/testing/hit-test-utils.cc
index 7ad8a64..c235e23 100644
--- a/icing/testing/hit-test-utils.cc
+++ b/icing/testing/hit-test-utils.cc
@@ -14,29 +14,45 @@
#include "icing/testing/hit-test-utils.h"
+#include <cstdint>
+#include <vector>
+
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/hit/hit.h"
+#include "icing/index/main/posting-list-hit-serializer.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+
namespace icing {
namespace lib {
// Returns a hit that has a delta of desired_byte_length from last_hit.
-Hit CreateHit(Hit last_hit, int desired_byte_length) {
- Hit hit = (last_hit.section_id() == kMinSectionId)
- ? Hit(kMaxSectionId, last_hit.document_id() + 1,
- last_hit.term_frequency())
- : Hit(last_hit.section_id() - 1, last_hit.document_id(),
- last_hit.term_frequency());
+Hit CreateHit(const Hit& last_hit, int desired_byte_length) {
+ return CreateHit(last_hit, desired_byte_length, last_hit.term_frequency(),
+ /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false);
+}
+
+// Returns a hit that has a delta of desired_byte_length from last_hit, with
+// the desired term_frequency and flags.
+Hit CreateHit(const Hit& last_hit, int desired_byte_length,
+ Hit::TermFrequency term_frequency, bool is_in_prefix_section,
+ bool is_prefix_hit) {
+ Hit hit = last_hit;
uint8_t buf[5];
- while (VarInt::Encode(last_hit.value() - hit.value(), buf) <
- desired_byte_length) {
+ do {
hit = (hit.section_id() == kMinSectionId)
- ? Hit(kMaxSectionId, hit.document_id() + 1, hit.term_frequency())
- : Hit(hit.section_id() - 1, hit.document_id(),
- hit.term_frequency());
- }
+ ? Hit(kMaxSectionId, hit.document_id() + 1, term_frequency,
+ is_in_prefix_section, is_prefix_hit)
+ : Hit(hit.section_id() - 1, hit.document_id(), term_frequency,
+ is_in_prefix_section, is_prefix_hit);
+ } while (PostingListHitSerializer::EncodeNextHitValue(
+ /*next_hit_value=*/hit.value(),
+ /*curr_hit_value=*/last_hit.value(), buf) < desired_byte_length);
return hit;
}
// Returns a vector of num_hits Hits with the first hit starting at start_docid
-// and with 1-byte deltas.
+// and with deltas of the desired byte length.
std::vector<Hit> CreateHits(DocumentId start_docid, int num_hits,
int desired_byte_length) {
std::vector<Hit> hits;
@@ -44,16 +60,56 @@ std::vector<Hit> CreateHits(DocumentId start_docid, int num_hits,
return hits;
}
hits.push_back(Hit(/*section_id=*/1, /*document_id=*/start_docid,
- Hit::kDefaultTermFrequency));
+ Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false,
+ /*is_prefix_hit=*/false));
while (hits.size() < num_hits) {
hits.push_back(CreateHit(hits.back(), desired_byte_length));
}
return hits;
}
+// Returns a vector of num_hits Hits with the first hit being the desired byte
+// length from last_hit, and with deltas of the same desired byte length.
+std::vector<Hit> CreateHits(const Hit& last_hit, int num_hits,
+ int desired_byte_length) {
+ std::vector<Hit> hits;
+ if (num_hits < 1) {
+ return hits;
+ }
+ hits.reserve(num_hits);
+ for (int i = 0; i < num_hits; ++i) {
+ hits.push_back(
+ CreateHit(hits.empty() ? last_hit : hits.back(), desired_byte_length));
+ }
+ return hits;
+}
+
std::vector<Hit> CreateHits(int num_hits, int desired_byte_length) {
return CreateHits(/*start_docid=*/0, num_hits, desired_byte_length);
}
+EmbeddingHit CreateEmbeddingHit(const EmbeddingHit& last_hit,
+ uint32_t desired_byte_length) {
+ // Create a delta that has (desired_byte_length - 1) * 7 + 1 bits, so that it
+ // can be encoded in desired_byte_length bytes.
+ uint64_t delta = UINT64_C(1) << ((desired_byte_length - 1) * 7);
+ return EmbeddingHit(last_hit.value() - delta);
+}
+
+std::vector<EmbeddingHit> CreateEmbeddingHits(int num_hits,
+ int desired_byte_length) {
+ std::vector<EmbeddingHit> hits;
+ if (num_hits == 0) {
+ return hits;
+ }
+ hits.reserve(num_hits);
+ hits.push_back(EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0),
+ /*location=*/0));
+ for (int i = 1; i < num_hits; ++i) {
+ hits.push_back(CreateEmbeddingHit(hits.back(), desired_byte_length));
+ }
+ return hits;
+}
+
} // namespace lib
} // namespace icing
diff --git a/icing/testing/hit-test-utils.h b/icing/testing/hit-test-utils.h
index e236ec0..e041c22 100644
--- a/icing/testing/hit-test-utils.h
+++ b/icing/testing/hit-test-utils.h
@@ -15,28 +15,51 @@
#ifndef ICING_TESTING_HIT_TEST_UTILS_H_
#define ICING_TESTING_HIT_TEST_UTILS_H_
+#include <cstdint>
#include <vector>
+#include "icing/index/embed/embedding-hit.h"
#include "icing/index/hit/hit.h"
-#include "icing/legacy/index/icing-bit-util.h"
-#include "icing/schema/section.h"
#include "icing/store/document-id.h"
namespace icing {
namespace lib {
// Returns a hit that has a delta of desired_byte_length from last_hit.
-Hit CreateHit(Hit last_hit, int desired_byte_length);
+Hit CreateHit(const Hit& last_hit, int desired_byte_length);
+
+// Returns a hit that has a delta of desired_byte_length from last_hit, with
+// the desired term_frequency and flags.
+Hit CreateHit(const Hit& last_hit, int desired_byte_length,
+ Hit::TermFrequency term_frequency, bool is_in_prefix_section,
+ bool is_prefix_hit);
// Returns a vector of num_hits Hits with the first hit starting at start_docid
// and with desired_byte_length deltas.
std::vector<Hit> CreateHits(DocumentId start_docid, int num_hits,
int desired_byte_length);
+// Returns a vector of num_hits Hits with the first hit being the desired byte
+// length from last_hit, and with deltas of the same desired byte length.
+std::vector<Hit> CreateHits(const Hit& last_hit, int num_hits,
+ int desired_byte_length);
+
// Returns a vector of num_hits Hits with the first hit starting at 0 and each
// with desired_byte_length deltas.
std::vector<Hit> CreateHits(int num_hits, int desired_byte_length);
+// Returns a hit that has a delta of desired_byte_length from last_hit after
+// VarInt encoding.
+// Requires that 0 < desired_byte_length <= VarInt::kMaxEncodedLen64.
+EmbeddingHit CreateEmbeddingHit(const EmbeddingHit& last_hit,
+ uint32_t desired_byte_length);
+
+// Returns a vector of num_hits EmbeddingHits with the first hit starting at
+// document 0 and with a delta of desired_byte_length bytes (after VarInt
+// encoding) between each subsequent hit.
+std::vector<EmbeddingHit> CreateEmbeddingHits(int num_hits,
+ int desired_byte_length);
+
} // namespace lib
} // namespace icing
diff --git a/icing/text_classifier/lib3/utils/base/logging.h b/icing/text_classifier/lib3/utils/base/logging.h
index 92d775e..39e09eb 100644
--- a/icing/text_classifier/lib3/utils/base/logging.h
+++ b/icing/text_classifier/lib3/utils/base/logging.h
@@ -16,9 +16,9 @@
#define ICING_TEXT_CLASSIFIER_LIB3_UTILS_BASE_LOGGING_H_
#include <cassert>
+#include <cstdint>
#include <string>
-#include "icing/text_classifier/lib3/utils/base/integral_types.h"
#include "icing/text_classifier/lib3/utils/base/logging_levels.h"
#include "icing/text_classifier/lib3/utils/base/port.h"
@@ -45,7 +45,8 @@ inline LoggingStringStream& operator<<(LoggingStringStream& stream,
template <typename T>
inline LoggingStringStream& operator<<(LoggingStringStream& stream,
T* const entry) {
- stream.message.append(std::to_string(reinterpret_cast<const uint64>(entry)));
+ stream.message.append(
+ std::to_string(reinterpret_cast<const uint64_t>(entry)));
return stream;
}
diff --git a/icing/text_classifier/lib3/utils/java/jni-helper.cc b/icing/text_classifier/lib3/utils/java/jni-helper.cc
index 60a9dfb..cb0b899 100644
--- a/icing/text_classifier/lib3/utils/java/jni-helper.cc
+++ b/icing/text_classifier/lib3/utils/java/jni-helper.cc
@@ -14,6 +14,8 @@
#include "icing/text_classifier/lib3/utils/java/jni-helper.h"
+#include <cstdint>
+
#include "icing/text_classifier/lib3/utils/base/status_macros.h"
namespace libtextclassifier3 {
@@ -121,8 +123,8 @@ StatusOr<bool> JniHelper::CallBooleanMethod(JNIEnv* env, jobject object,
return result;
}
-StatusOr<int32> JniHelper::CallIntMethod(JNIEnv* env, jobject object,
- jmethodID method_id, ...) {
+StatusOr<int32_t> JniHelper::CallIntMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
va_list args;
va_start(args, method_id);
jint result = env->CallIntMethodV(object, method_id, args);
@@ -132,8 +134,8 @@ StatusOr<int32> JniHelper::CallIntMethod(JNIEnv* env, jobject object,
return result;
}
-StatusOr<int64> JniHelper::CallLongMethod(JNIEnv* env, jobject object,
- jmethodID method_id, ...) {
+StatusOr<int64_t> JniHelper::CallLongMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
va_list args;
va_start(args, method_id);
jlong result = env->CallLongMethodV(object, method_id, args);
diff --git a/icing/text_classifier/lib3/utils/java/jni-helper.h b/icing/text_classifier/lib3/utils/java/jni-helper.h
index 4e548ec..8b57c11 100644
--- a/icing/text_classifier/lib3/utils/java/jni-helper.h
+++ b/icing/text_classifier/lib3/utils/java/jni-helper.h
@@ -20,6 +20,7 @@
#include <jni.h>
+#include <cstdint>
#include <string>
#include "icing/text_classifier/lib3/utils/base/status.h"
@@ -140,10 +141,10 @@ class JniHelper {
...);
static StatusOr<bool> CallBooleanMethod(JNIEnv* env, jobject object,
jmethodID method_id, ...);
- static StatusOr<int32> CallIntMethod(JNIEnv* env, jobject object,
- jmethodID method_id, ...);
- static StatusOr<int64> CallLongMethod(JNIEnv* env, jobject object,
- jmethodID method_id, ...);
+ static StatusOr<int32_t> CallIntMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...);
+ static StatusOr<int64_t> CallLongMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...);
static StatusOr<float> CallFloatMethod(JNIEnv* env, jobject object,
jmethodID method_id, ...);
static StatusOr<double> CallDoubleMethod(JNIEnv* env, jobject object,
diff --git a/icing/util/document-validator.cc b/icing/util/document-validator.cc
index e0880ea..bc75334 100644
--- a/icing/util/document-validator.cc
+++ b/icing/util/document-validator.cc
@@ -15,13 +15,23 @@
#include "icing/util/document-validator.h"
#include <cstdint>
+#include <string>
+#include <string_view>
+#include <unordered_map>
#include <unordered_set>
+#include <utility>
#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/legacy/core/icing-string-util.h"
#include "icing/proto/document.pb.h"
#include "icing/proto/schema.pb.h"
+#include "icing/schema/schema-store.h"
#include "icing/schema/schema-util.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/util/logging.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -123,6 +133,18 @@ libtextclassifier3::Status DocumentValidator::Validate(
} else if (property_config.data_type() ==
PropertyConfigProto::DataType::DOCUMENT) {
value_size = property.document_values_size();
+ } else if (property_config.data_type() ==
+ PropertyConfigProto::DataType::VECTOR) {
+ value_size = property.vector_values_size();
+ for (const PropertyProto::VectorProto& vector_value :
+ property.vector_values()) {
+ if (vector_value.values_size() == 0) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "Property '%s' contains empty vectors for key: (%s, %s).",
+ property.name().c_str(), document.namespace_().c_str(),
+ document.uri().c_str()));
+ }
+ }
}
if (property_config.cardinality() ==
diff --git a/icing/util/document-validator_test.cc b/icing/util/document-validator_test.cc
index 9d10b36..2c366fd 100644
--- a/icing/util/document-validator_test.cc
+++ b/icing/util/document-validator_test.cc
@@ -15,7 +15,11 @@
#include "icing/util/document-validator.h"
#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
@@ -42,6 +46,7 @@ constexpr char kPropertySubject[] = "subject";
constexpr char kPropertyText[] = "text";
constexpr char kPropertyRecipients[] = "recipients";
constexpr char kPropertyNote[] = "note";
+constexpr char kPropertyNoteEmbedding[] = "noteEmbedding";
// type and property names of Conversation
constexpr char kTypeConversation[] = "Conversation";
constexpr char kTypeConversationWithEmailNote[] = "ConversationWithEmailNote";
@@ -92,6 +97,10 @@ class DocumentValidatorTest : public ::testing::Test {
.AddProperty(PropertyConfigBuilder()
.SetName(kPropertyNote)
.SetDataType(TYPE_STRING)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kPropertyNoteEmbedding)
+ .SetDataType(TYPE_VECTOR)
.SetCardinality(CARDINALITY_OPTIONAL)))
.AddType(
SchemaTypeConfigBuilder()
@@ -146,6 +155,12 @@ class DocumentValidatorTest : public ::testing::Test {
}
DocumentBuilder SimpleEmailWithNoteBuilder() {
+ PropertyProto::VectorProto vector;
+ vector.add_values(0.1);
+ vector.add_values(0.2);
+ vector.add_values(0.3);
+ vector.set_model_signature("my_model");
+
return DocumentBuilder()
.SetKey(kDefaultNamespace, "email_with_note/1")
.SetSchema(kTypeEmailWithNote)
@@ -153,7 +168,8 @@ class DocumentValidatorTest : public ::testing::Test {
.AddStringProperty(kPropertyText, kDefaultString)
.AddStringProperty(kPropertyRecipients, kDefaultString, kDefaultString,
kDefaultString)
- .AddStringProperty(kPropertyNote, kDefaultString);
+ .AddStringProperty(kPropertyNote, kDefaultString)
+ .AddVectorProperty(kPropertyNoteEmbedding, vector);
}
DocumentBuilder SimpleConversationBuilder() {
@@ -597,6 +613,64 @@ TEST_F(DocumentValidatorTest, NegativeDocumentTtlMsInvalid) {
HasSubstr("is negative")));
}
+TEST_F(DocumentValidatorTest, ValidateEmbeddingZeroDimensionInvalid) {
+ PropertyProto::VectorProto vector;
+ vector.set_model_signature("my_model");
+ DocumentProto email =
+ DocumentBuilder()
+ .SetKey(kDefaultNamespace, "email_with_note/1")
+ .SetSchema(kTypeEmailWithNote)
+ .AddStringProperty(kPropertySubject, kDefaultString)
+ .AddStringProperty(kPropertyText, kDefaultString)
+ .AddStringProperty(kPropertyRecipients, kDefaultString,
+ kDefaultString, kDefaultString)
+ .AddStringProperty(kPropertyNote, kDefaultString)
+ .AddVectorProperty(kPropertyNoteEmbedding, vector)
+ .Build();
+ EXPECT_THAT(document_validator_->Validate(email),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT,
+ HasSubstr("contains empty vectors")));
+}
+
+TEST_F(DocumentValidatorTest, ValidateEmbeddingEmptySignatureOk) {
+ PropertyProto::VectorProto vector;
+ vector.add_values(0.1);
+ vector.add_values(0.2);
+ vector.add_values(0.3);
+ vector.set_model_signature("");
+ DocumentProto email =
+ DocumentBuilder()
+ .SetKey(kDefaultNamespace, "email_with_note/1")
+ .SetSchema(kTypeEmailWithNote)
+ .AddStringProperty(kPropertySubject, kDefaultString)
+ .AddStringProperty(kPropertyText, kDefaultString)
+ .AddStringProperty(kPropertyRecipients, kDefaultString,
+ kDefaultString, kDefaultString)
+ .AddStringProperty(kPropertyNote, kDefaultString)
+ .AddVectorProperty(kPropertyNoteEmbedding, vector)
+ .Build();
+ ICING_EXPECT_OK(document_validator_->Validate(email));
+}
+
+TEST_F(DocumentValidatorTest, ValidateEmbeddingNoSignatureOk) {
+ PropertyProto::VectorProto vector;
+ vector.add_values(0.1);
+ vector.add_values(0.2);
+ vector.add_values(0.3);
+ DocumentProto email =
+ DocumentBuilder()
+ .SetKey(kDefaultNamespace, "email_with_note/1")
+ .SetSchema(kTypeEmailWithNote)
+ .AddStringProperty(kPropertySubject, kDefaultString)
+ .AddStringProperty(kPropertyText, kDefaultString)
+ .AddStringProperty(kPropertyRecipients, kDefaultString,
+ kDefaultString, kDefaultString)
+ .AddStringProperty(kPropertyNote, kDefaultString)
+ .AddVectorProperty(kPropertyNoteEmbedding, vector)
+ .Build();
+ ICING_EXPECT_OK(document_validator_->Validate(email));
+}
+
} // namespace
} // namespace lib
diff --git a/icing/util/embedding-util.h b/icing/util/embedding-util.h
new file mode 100644
index 0000000..5026051
--- /dev/null
+++ b/icing/util/embedding-util.h
@@ -0,0 +1,49 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_UTIL_EMBEDDING_UTIL_H_
+#define ICING_UTIL_EMBEDDING_UTIL_H_
+
+#include <string_view>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/proto/search.pb.h"
+
+namespace icing {
+namespace lib {
+
+namespace embedding_util {
+
+inline libtextclassifier3::StatusOr<
+ SearchSpecProto::EmbeddingQueryMetricType::Code>
+GetEmbeddingQueryMetricTypeFromName(std::string_view metric_name) {
+ if (metric_name == "COSINE") {
+ return SearchSpecProto::EmbeddingQueryMetricType::COSINE;
+ } else if (metric_name == "DOT_PRODUCT") {
+ return SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT;
+ } else if (metric_name == "EUCLIDEAN") {
+ return SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN;
+ }
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat("Unknown metric type: ", metric_name));
+}
+
+} // namespace embedding_util
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_UTIL_EMBEDDING_UTIL_H_
diff --git a/icing/util/tokenized-document.cc b/icing/util/tokenized-document.cc
index 19aaddf..e10fe25 100644
--- a/icing/util/tokenized-document.cc
+++ b/icing/util/tokenized-document.cc
@@ -14,16 +14,18 @@
#include "icing/util/tokenized-document.h"
-#include <string>
+#include <memory>
#include <string_view>
+#include <utility>
#include <vector>
-#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/proto/document.pb.h"
#include "icing/schema/joinable-property.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
#include "icing/tokenization/language-segmenter.h"
+#include "icing/tokenization/token.h"
#include "icing/tokenization/tokenizer-factory.h"
#include "icing/tokenization/tokenizer.h"
#include "icing/util/document-validator.h"
@@ -85,6 +87,7 @@ TokenizedDocument::Create(const SchemaStore* schema_store,
return TokenizedDocument(std::move(document),
std::move(tokenized_string_sections),
std::move(section_group.integer_sections),
+ std::move(section_group.vector_sections),
std::move(joinable_property_group));
}
diff --git a/icing/util/tokenized-document.h b/icing/util/tokenized-document.h
index 7cc34e3..0337083 100644
--- a/icing/util/tokenized-document.h
+++ b/icing/util/tokenized-document.h
@@ -16,7 +16,8 @@
#define ICING_STORE_TOKENIZED_DOCUMENT_H_
#include <cstdint>
-#include <string>
+#include <string_view>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
@@ -63,6 +64,11 @@ class TokenizedDocument {
return integer_sections_;
}
+ const std::vector<Section<PropertyProto::VectorProto>>& vector_sections()
+ const {
+ return vector_sections_;
+ }
+
const std::vector<JoinableProperty<std::string_view>>&
qualified_id_join_properties() const {
return joinable_property_group_.qualified_id_properties;
@@ -74,15 +80,18 @@ class TokenizedDocument {
DocumentProto&& document,
std::vector<TokenizedSection>&& tokenized_string_sections,
std::vector<Section<int64_t>>&& integer_sections,
+ std::vector<Section<PropertyProto::VectorProto>>&& vector_sections,
JoinablePropertyGroup&& joinable_property_group)
: document_(std::move(document)),
tokenized_string_sections_(std::move(tokenized_string_sections)),
integer_sections_(std::move(integer_sections)),
+ vector_sections_(std::move(vector_sections)),
joinable_property_group_(std::move(joinable_property_group)) {}
DocumentProto document_;
std::vector<TokenizedSection> tokenized_string_sections_;
std::vector<Section<int64_t>> integer_sections_;
+ std::vector<Section<PropertyProto::VectorProto>> vector_sections_;
JoinablePropertyGroup joinable_property_group_;
};
diff --git a/icing/util/tokenized-document_test.cc b/icing/util/tokenized-document_test.cc
index 7c97776..ab7f4b9 100644
--- a/icing/util/tokenized-document_test.cc
+++ b/icing/util/tokenized-document_test.cc
@@ -16,6 +16,8 @@
#include <memory>
#include <string>
+#include <string_view>
+#include <utility>
#include <vector>
#include "gmock/gmock.h"
@@ -25,7 +27,6 @@
#include "icing/portable/platform.h"
#include "icing/proto/document.pb.h"
#include "icing/proto/schema.pb.h"
-#include "icing/proto/term.pb.h"
#include "icing/schema-builder.h"
#include "icing/schema/joinable-property.h"
#include "icing/schema/schema-store.h"
@@ -59,13 +60,19 @@ static constexpr std::string_view kIndexableIntegerProperty1 =
"indexableInteger1";
static constexpr std::string_view kIndexableIntegerProperty2 =
"indexableInteger2";
+static constexpr std::string_view kIndexableVectorProperty1 =
+ "indexableVector1";
+static constexpr std::string_view kIndexableVectorProperty2 =
+ "indexableVector2";
static constexpr std::string_view kStringExactProperty = "stringExact";
static constexpr std::string_view kStringPrefixProperty = "stringPrefix";
static constexpr SectionId kIndexableInteger1SectionId = 0;
static constexpr SectionId kIndexableInteger2SectionId = 1;
-static constexpr SectionId kStringExactSectionId = 2;
-static constexpr SectionId kStringPrefixSectionId = 3;
+static constexpr SectionId kIndexableVector1SectionId = 2;
+static constexpr SectionId kIndexableVector2SectionId = 3;
+static constexpr SectionId kStringExactSectionId = 4;
+static constexpr SectionId kStringPrefixSectionId = 5;
// Joinable properties and joinable property id. Joinable property id is
// determined by the lexicographical order of joinable property path.
@@ -77,19 +84,33 @@ static constexpr JoinablePropertyId kQualifiedId2JoinablePropertyId = 1;
const SectionMetadata kIndexableInteger1SectionMetadata(
kIndexableInteger1SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
- NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty1));
+ NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN,
+ std::string(kIndexableIntegerProperty1));
const SectionMetadata kIndexableInteger2SectionMetadata(
kIndexableInteger2SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
- NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty2));
+ NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN,
+ std::string(kIndexableIntegerProperty2));
+
+const SectionMetadata kIndexableVector1SectionMetadata(
+ kIndexableVector1SectionId, TYPE_VECTOR, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
+ NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_LINEAR_SEARCH,
+ std::string(kIndexableVectorProperty1));
+
+const SectionMetadata kIndexableVector2SectionMetadata(
+ kIndexableVector2SectionId, TYPE_VECTOR, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
+ NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_LINEAR_SEARCH,
+ std::string(kIndexableVectorProperty2));
const SectionMetadata kStringExactSectionMetadata(
kStringExactSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT,
- NUMERIC_MATCH_UNKNOWN, std::string(kStringExactProperty));
+ NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN,
+ std::string(kStringExactProperty));
const SectionMetadata kStringPrefixSectionMetadata(
kStringPrefixSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_PREFIX,
- NUMERIC_MATCH_UNKNOWN, std::string(kStringPrefixProperty));
+ NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN,
+ std::string(kStringPrefixProperty));
const JoinablePropertyMetadata kQualifiedId1JoinablePropertyMetadata(
kQualifiedId1JoinablePropertyId, TYPE_STRING,
@@ -102,6 +123,7 @@ const JoinablePropertyMetadata kQualifiedId2JoinablePropertyMetadata(
// Other non-indexable/joinable properties.
constexpr std::string_view kUnindexedStringProperty = "unindexedString";
constexpr std::string_view kUnindexedIntegerProperty = "unindexedInteger";
+constexpr std::string_view kUnindexedVectorProperty = "unindexedVector";
class TokenizedDocumentTest : public ::testing::Test {
protected:
@@ -140,6 +162,10 @@ class TokenizedDocumentTest : public ::testing::Test {
.SetDataType(TYPE_INT64)
.SetCardinality(CARDINALITY_OPTIONAL))
.AddProperty(PropertyConfigBuilder()
+ .SetName(kUnindexedVectorProperty)
+ .SetDataType(TYPE_VECTOR)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
.SetName(kIndexableIntegerProperty1)
.SetDataTypeInt64(NUMERIC_MATCH_RANGE)
.SetCardinality(CARDINALITY_REPEATED))
@@ -147,6 +173,16 @@ class TokenizedDocumentTest : public ::testing::Test {
.SetName(kIndexableIntegerProperty2)
.SetDataTypeInt64(NUMERIC_MATCH_RANGE)
.SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName(kIndexableVectorProperty1)
+ .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName(kIndexableVectorProperty2)
+ .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_OPTIONAL))
.AddProperty(PropertyConfigBuilder()
.SetName(kStringExactProperty)
.SetDataTypeString(TERM_MATCH_EXACT,
@@ -196,6 +232,16 @@ class TokenizedDocumentTest : public ::testing::Test {
};
TEST_F(TokenizedDocumentTest, CreateAll) {
+ PropertyProto::VectorProto vector1;
+ vector1.set_model_signature("my_model1");
+ vector1.add_values(1.0f);
+ vector1.add_values(2.0f);
+ PropertyProto::VectorProto vector2;
+ vector2.set_model_signature("my_model2");
+ vector2.add_values(-1.0f);
+ vector2.add_values(-2.0f);
+ vector2.add_values(-3.0f);
+
DocumentProto document =
DocumentBuilder()
.SetKey("icing", "fake_type/1")
@@ -208,6 +254,10 @@ TEST_F(TokenizedDocumentTest, CreateAll) {
.AddInt64Property(std::string(kUnindexedIntegerProperty), 789)
.AddInt64Property(std::string(kIndexableIntegerProperty1), 1, 2, 3)
.AddInt64Property(std::string(kIndexableIntegerProperty2), 456)
+ .AddVectorProperty(std::string(kUnindexedVectorProperty), vector1)
+ .AddVectorProperty(std::string(kIndexableVectorProperty1), vector1,
+ vector2)
+ .AddVectorProperty(std::string(kIndexableVectorProperty2), vector1)
.AddStringProperty(std::string(kQualifiedId1), "pkg$db/ns#uri1")
.AddStringProperty(std::string(kQualifiedId2), "pkg$db/ns#uri2")
.Build();
@@ -244,6 +294,17 @@ TEST_F(TokenizedDocumentTest, CreateAll) {
EXPECT_THAT(tokenized_document.integer_sections().at(1).content,
ElementsAre(456));
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), SizeIs(2));
+ EXPECT_THAT(tokenized_document.vector_sections().at(0).metadata,
+ Eq(kIndexableVector1SectionMetadata));
+ EXPECT_THAT(tokenized_document.vector_sections().at(0).content,
+ ElementsAre(EqualsProto(vector1), EqualsProto(vector2)));
+ EXPECT_THAT(tokenized_document.vector_sections().at(1).metadata,
+ Eq(kIndexableVector2SectionMetadata));
+ EXPECT_THAT(tokenized_document.vector_sections().at(1).content,
+ ElementsAre(EqualsProto(vector1)));
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), SizeIs(2));
EXPECT_THAT(tokenized_document.qualified_id_join_properties().at(0).metadata,
@@ -278,6 +339,9 @@ TEST_F(TokenizedDocumentTest, CreateNoIndexableIntegerProperties) {
// integer sections
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
}
@@ -314,6 +378,9 @@ TEST_F(TokenizedDocumentTest, CreateMultipleIndexableIntegerProperties) {
EXPECT_THAT(tokenized_document.integer_sections().at(1).content,
ElementsAre(456));
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
}
@@ -341,6 +408,9 @@ TEST_F(TokenizedDocumentTest, CreateNoIndexableStringProperties) {
// integer sections
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
}
@@ -381,6 +451,92 @@ TEST_F(TokenizedDocumentTest, CreateMultipleIndexableStringProperties) {
// integer sections
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
+ // Qualified id join properties
+ EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
+}
+
+TEST_F(TokenizedDocumentTest, CreateNoIndexableVectorProperties) {
+ PropertyProto::VectorProto vector;
+ vector.set_model_signature("my_model");
+ vector.add_values(1.0f);
+
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddVectorProperty(std::string(kUnindexedVectorProperty), vector)
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+
+ EXPECT_THAT(tokenized_document.document(), EqualsProto(document));
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0));
+
+ // string sections
+ EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty());
+
+ // integer sections
+ EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
+ // Qualified id join properties
+ EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
+}
+
+TEST_F(TokenizedDocumentTest, CreateMultipleIndexableVectorProperties) {
+ PropertyProto::VectorProto vector1;
+ vector1.set_model_signature("my_model1");
+ vector1.add_values(1.0f);
+ vector1.add_values(2.0f);
+ PropertyProto::VectorProto vector2;
+ vector2.set_model_signature("my_model2");
+ vector2.add_values(-1.0f);
+ vector2.add_values(-2.0f);
+ vector2.add_values(-3.0f);
+
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddVectorProperty(std::string(kUnindexedVectorProperty), vector1)
+ .AddVectorProperty(std::string(kIndexableVectorProperty1), vector1,
+ vector2)
+ .AddVectorProperty(std::string(kIndexableVectorProperty2), vector1)
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+
+ EXPECT_THAT(tokenized_document.document(), EqualsProto(document));
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0));
+
+ // string sections
+ EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty());
+
+ // integer sections
+ EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), SizeIs(2));
+ EXPECT_THAT(tokenized_document.vector_sections().at(0).metadata,
+ Eq(kIndexableVector1SectionMetadata));
+ EXPECT_THAT(tokenized_document.vector_sections().at(0).content,
+ ElementsAre(EqualsProto(vector1), EqualsProto(vector2)));
+ EXPECT_THAT(tokenized_document.vector_sections().at(1).metadata,
+ Eq(kIndexableVector2SectionMetadata));
+ EXPECT_THAT(tokenized_document.vector_sections().at(1).content,
+ ElementsAre(EqualsProto(vector1)));
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
}
@@ -408,6 +564,9 @@ TEST_F(TokenizedDocumentTest, CreateNoJoinQualifiedIdProperties) {
// integer sections
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty());
}
@@ -437,6 +596,9 @@ TEST_F(TokenizedDocumentTest, CreateMultipleJoinQualifiedIdProperties) {
// integer sections
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+ // vector sections
+ EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty());
+
// Qualified id join properties
EXPECT_THAT(tokenized_document.qualified_id_join_properties(), SizeIs(2));
EXPECT_THAT(tokenized_document.qualified_id_join_properties().at(0).metadata,
diff --git a/java/src/com/google/android/icing/IcingSearchEngine.java b/java/src/com/google/android/icing/IcingSearchEngine.java
index 79fcdb8..a8a571e 100644
--- a/java/src/com/google/android/icing/IcingSearchEngine.java
+++ b/java/src/com/google/android/icing/IcingSearchEngine.java
@@ -77,13 +77,6 @@ public class IcingSearchEngine implements IcingSearchEngineInterface {
icingSearchEngineImpl.close();
}
- @SuppressWarnings("deprecation")
- @Override
- protected void finalize() throws Throwable {
- icingSearchEngineImpl.close();
- super.finalize();
- }
-
@NonNull
@Override
public InitializeResultProto initialize() {
diff --git a/java/src/com/google/android/icing/IcingSearchEngineImpl.java b/java/src/com/google/android/icing/IcingSearchEngineImpl.java
index 57744c4..7994162 100644
--- a/java/src/com/google/android/icing/IcingSearchEngineImpl.java
+++ b/java/src/com/google/android/icing/IcingSearchEngineImpl.java
@@ -71,13 +71,6 @@ public class IcingSearchEngineImpl implements Closeable {
closed = true;
}
- @SuppressWarnings("deprecation")
- @Override
- protected void finalize() throws Throwable {
- close();
- super.finalize();
- }
-
@Nullable
public byte[] initialize() {
throwIfClosed();
diff --git a/java/src/com/google/android/icing/IcingSearchEngineInterface.java b/java/src/com/google/android/icing/IcingSearchEngineInterface.java
index 0bc58f1..67f60ed 100644
--- a/java/src/com/google/android/icing/IcingSearchEngineInterface.java
+++ b/java/src/com/google/android/icing/IcingSearchEngineInterface.java
@@ -32,7 +32,7 @@ import com.google.android.icing.proto.SuggestionSpecProto;
import com.google.android.icing.proto.UsageReport;
import java.io.Closeable;
-/** A common user-facing interface to expose the funcationalities provided by Icing Library. */
+/** A common user-facing interface to expose the functionalities provided by Icing Library. */
public interface IcingSearchEngineInterface extends Closeable {
/**
* Initializes the current IcingSearchEngine implementation.
diff --git a/proto/icing/proto/document.proto b/proto/icing/proto/document.proto
index 1a501e7..919769e 100644
--- a/proto/icing/proto/document.proto
+++ b/proto/icing/proto/document.proto
@@ -78,7 +78,7 @@ message DocumentProto {
}
// Holds a property field of the Document.
-// Next tag: 8
+// Next tag: 9
message PropertyProto {
// Name of the property.
// See icing.lib.PropertyConfigProto.property_name for details.
@@ -92,6 +92,17 @@ message PropertyProto {
repeated bool boolean_values = 5;
repeated bytes bytes_values = 6;
repeated DocumentProto document_values = 7;
+
+ message VectorProto {
+ // The values of the vector.
+ repeated float values = 1 [packed = true];
+ // The model signature of the vector, which can be any string used to
+ // identify the model, so that embedding searches can be restricted only to
+ // the vectors with the matching target signature.
+ // Eg: "universal-sentence-encoder_v0"
+ optional string model_signature = 2;
+ }
+ repeated VectorProto vector_values = 8;
}
// Result of a call to IcingSearchEngine.Put
diff --git a/proto/icing/proto/initialize.proto b/proto/icing/proto/initialize.proto
index 9dd9e88..7c58d2c 100644
--- a/proto/icing/proto/initialize.proto
+++ b/proto/icing/proto/initialize.proto
@@ -23,6 +23,56 @@ option java_package = "com.google.android.icing.proto";
option java_multiple_files = true;
option objc_class_prefix = "ICNG";
+// Next tag: 7
+message IcingSearchEngineFeatureInfoProto {
+ // REQUIRED: Enum representing an IcingLib feature flagged using
+ // IcingSearchEngineOptions.
+ optional FlaggedFeatureType feature_type = 1;
+
+ enum FlaggedFeatureType {
+ // This value should never purposely be used. This is used for backwards
+ // compatibility reasons.
+ UNKNOWN = 0;
+
+ // Feature for the flag
+ // IcingSearchEngineOptions::build_property_existence_metadata_hits.
+ //
+ // This feature covers the kHasPropertyFunctionFeature advanced query
+ // feature, and related metadata hits indexing used for property existence
+ // check.
+ FEATURE_HAS_PROPERTY_OPERATOR = 1;
+ }
+
+ // Whether the feature requires the document store to be rebuilt.
+ // The default value is false.
+ optional bool needs_document_store_rebuild = 2;
+
+ // Whether the feature requires the schema store to be rebuilt.
+ // The default value is false.
+ optional bool needs_schema_store_rebuild = 3;
+
+ // Whether the feature requires the term index to be rebuilt.
+ // The default value is false.
+ optional bool needs_term_index_rebuild = 4;
+
+ // Whether the feature requires the integer index to be rebuilt.
+ // The default value is false.
+ optional bool needs_integer_index_rebuild = 5;
+
+ // Whether the feature requires the qualified id join index to be rebuilt.
+ // The default value is false.
+ optional bool needs_qualified_id_join_index_rebuild = 6;
+}
+
+// Next tag: 4
+message IcingSearchEngineVersionProto {
+ // version and max_version are from the original version file.
+ optional int32 version = 1;
+ optional int32 max_version = 2;
+ // Features that are enabled in an icing version at initialization.
+ repeated IcingSearchEngineFeatureInfoProto enabled_features = 3;
+}
+
// Next tag: 16
message IcingSearchEngineOptions {
// Directory to persist files for Icing. Required.
@@ -131,9 +181,6 @@ message IcingSearchEngineOptions {
// Whether to build the metadata hits used for property existence check, which
// is required to support the hasProperty function in advanced query.
- //
- // TODO(b/309826655): Implement the feature flag derived files rebuild
- // mechanism to handle index rebuild, instead of using index's magic value.
optional bool build_property_existence_metadata_hits = 15;
reserved 2;
diff --git a/proto/icing/proto/logging.proto b/proto/icing/proto/logging.proto
index fcedeed..46e988e 100644
--- a/proto/icing/proto/logging.proto
+++ b/proto/icing/proto/logging.proto
@@ -23,7 +23,7 @@ option java_multiple_files = true;
option objc_class_prefix = "ICNG";
// Stats of the top-level function IcingSearchEngine::Initialize().
-// Next tag: 14
+// Next tag: 15
message InitializeStatsProto {
// Overall time used for the function call.
optional int32 latency_ms = 1;
@@ -55,6 +55,10 @@ message InitializeStatsProto {
// Any dependencies have changed.
DEPENDENCIES_CHANGED = 7;
+
+ // Change detected in Icing's feature flags since last initialization that
+ // requires recovery.
+ FEATURE_FLAG_CHANGED = 8;
}
// Possible recovery causes for document store:
@@ -117,10 +121,16 @@ message InitializeStatsProto {
// - SCHEMA_CHANGES_OUT_OF_SYNC
// - IO_ERROR
optional RecoveryCause qualified_id_join_index_restoration_cause = 13;
+
+ // Possible recovery causes for embedding index:
+ // - INCONSISTENT_WITH_GROUND_TRUTH
+ // - SCHEMA_CHANGES_OUT_OF_SYNC
+ // - IO_ERROR
+ optional RecoveryCause embedding_index_restoration_cause = 14;
}
// Stats of the top-level function IcingSearchEngine::Put().
-// Next tag: 12
+// Next tag: 13
message PutDocumentStatsProto {
// Overall time used for the function call.
optional int32 latency_ms = 1;
@@ -166,11 +176,14 @@ message PutDocumentStatsProto {
// Time used to index all metadata terms in the document, which can only be
// added by PropertyExistenceIndexingHandler currently.
optional int32 metadata_term_index_latency_ms = 11;
+
+ // Time used to index all embeddings in the document.
+ optional int32 embedding_index_latency_ms = 12;
}
// Stats of the top-level function IcingSearchEngine::Search() and
// IcingSearchEngine::GetNextPage().
-// Next tag: 26
+// Next tag: 28
message QueryStatsProto {
// TODO(b/305098009): deprecate. Use parent_search_stats instead.
// The UTF-8 length of the query string
@@ -252,7 +265,7 @@ message QueryStatsProto {
optional bool is_join_query = 23;
// Stats of the search. Only valid for first page.
- // Next tag: 13
+ // Next tag: 16
message SearchStats {
// The UTF-8 length of the query string
optional int32 query_length = 1;
@@ -290,6 +303,15 @@ message QueryStatsProto {
// Number of hits fetched by integer index before applying any filters.
optional int32 num_fetched_hits_integer_index = 12;
+
+ // Time used in Lexer to extract lexer tokens from the query.
+ optional int32 query_processor_lexer_extract_token_latency_ms = 13;
+
+ // Time used in Parser to consume lexer tokens extracted from the query.
+ optional int32 query_processor_parser_consume_query_latency_ms = 14;
+
+ // Time used in QueryVisitor to visit and build (nested) DocHitInfoIterator.
+ optional int32 query_processor_query_visitor_latency_ms = 15;
}
// Search stats for parent. Only valid for first page.
@@ -298,6 +320,12 @@ message QueryStatsProto {
// Search stats for child.
optional SearchStats child_search_stats = 25;
+ // Byte size of the lite index hit buffer.
+ optional int64 lite_index_hit_buffer_byte_size = 26;
+
+ // Byte size of the unsorted tail of the lite index hit buffer.
+ optional int64 lite_index_hit_buffer_unsorted_byte_size = 27;
+
reserved 9;
}
diff --git a/proto/icing/proto/schema.proto b/proto/icing/proto/schema.proto
index c716dba..99439bb 100644
--- a/proto/icing/proto/schema.proto
+++ b/proto/icing/proto/schema.proto
@@ -34,7 +34,7 @@ option objc_class_prefix = "ICNG";
// TODO(cassiewang) Define a sample proto file that can be used by tests and for
// documentation.
//
-// Next tag: 7
+// Next tag: 8
message SchemaTypeConfigProto {
// REQUIRED: Named type that uniquely identifies the structured, logical
// schema being defined.
@@ -43,6 +43,13 @@ message SchemaTypeConfigProto {
// in http://schema.org. Eg: DigitalDocument, Message, Person, etc.
optional string schema_type = 1;
+ // OPTIONAL: A natural language description of the SchemaTypeConfigProto.
+ //
+ // This string is not used by Icing in any way. It simply exists to allow
+ // users to store semantic information about the SchemaTypeConfigProto for
+ // future retrieval.
+ optional string description = 7;
+
// List of all properties that are supported by Documents of this type.
// An Document should never have properties that are not listed here.
//
@@ -177,6 +184,26 @@ message IntegerIndexingConfig {
optional NumericMatchType.Code numeric_match_type = 1;
}
+// Describes how a vector property should be indexed.
+// Next tag: 3
+message EmbeddingIndexingConfig {
+ // OPTIONAL: Indicates how the vector contents of this property should be
+ // matched.
+ //
+ // The default value is UNKNOWN.
+ message EmbeddingIndexingType {
+ enum Code {
+ // Contents in this property will not be indexed. Useful if the vector
+ // property type is not indexable.
+ UNKNOWN = 0;
+
+ // Contents in this property will be indexed for linear search.
+ LINEAR_SEARCH = 1;
+ }
+ }
+ optional EmbeddingIndexingType.Code embedding_indexing_type = 1;
+}
+
// Describes how a property can be used to join this document with another
// document. See JoinSpecProto (in search.proto) for more details.
// Next tag: 3
@@ -208,7 +235,7 @@ message JoinableConfig {
// Describes the schema of a single property of Documents that belong to a
// specific SchemaTypeConfigProto. These can be considered as a rich, structured
// type for each property of Documents accepted by IcingSearchEngine.
-// Next tag: 9
+// Next tag: 11
message PropertyConfigProto {
// REQUIRED: Name that uniquely identifies a property within an Document of
// a specific SchemaTypeConfigProto.
@@ -219,6 +246,13 @@ message PropertyConfigProto {
// Eg: 'address' for http://schema.org/Place.
optional string property_name = 1;
+ // OPTIONAL: A natural language description of the property.
+ //
+ // This string is not used by Icing in any way. It simply exists to allow
+ // users to store semantic information about the PropertyConfigProto for
+ // future retrieval.
+ optional string description = 9;
+
// REQUIRED: Physical data-types of the contents of the property.
message DataType {
enum Code {
@@ -238,6 +272,9 @@ message PropertyConfigProto {
// a hierarchical Document schema. Any property using this data_type
// MUST have a valid 'schema_type'.
DOCUMENT = 6;
+
+ // A list of floats. Vector type is used for embedding searches.
+ VECTOR = 7;
}
}
optional DataType.Code data_type = 2;
@@ -299,6 +336,10 @@ message PropertyConfigProto {
// - The property itself and any upper-level (nested doc) property should
// contain at most one element (i.e. Cardinality is OPTIONAL or REQUIRED).
optional JoinableConfig joinable_config = 8;
+
+ // OPTIONAL: Describes how vector properties should be indexed. Vector
+ // properties that do not set the indexing config will not be indexed.
+ optional EmbeddingIndexingConfig embedding_indexing_config = 10;
}
// List of all supported types constitutes the schema used by Icing.
diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto
index 7f4fb3e..7e145ce 100644
--- a/proto/icing/proto/search.proto
+++ b/proto/icing/proto/search.proto
@@ -27,7 +27,7 @@ option java_multiple_files = true;
option objc_class_prefix = "ICNG";
// Client-supplied specifications on what documents to retrieve.
-// Next tag: 11
+// Next tag: 13
message SearchSpecProto {
// REQUIRED: The "raw" query string that users may type. For example, "cat"
// will search for documents with the term cat in it.
@@ -112,6 +112,22 @@ message SearchSpecProto {
// (TypePropertyMask for the given schema type has empty paths field), no
// properties of that schema type will be searched.
repeated TypePropertyMask type_property_filters = 10;
+
+ // The vectors to be used in embedding queries.
+ repeated PropertyProto.VectorProto embedding_query_vectors = 11;
+
+ message EmbeddingQueryMetricType {
+ enum Code {
+ UNKNOWN = 0;
+ COSINE = 1;
+ DOT_PRODUCT = 2;
+ EUCLIDEAN = 3;
+ }
+ }
+
+ // The default metric type used to calculate the scores for embedding
+ // queries.
+ optional EmbeddingQueryMetricType.Code embedding_query_metric_type = 12;
}
// Client-supplied specifications on what to include/how to format the search
diff --git a/proto/icing/proto/storage.proto b/proto/icing/proto/storage.proto
index 39dab6b..e0323a1 100644
--- a/proto/icing/proto/storage.proto
+++ b/proto/icing/proto/storage.proto
@@ -22,6 +22,8 @@ option java_package = "com.google.android.icing.proto";
option java_multiple_files = true;
option objc_class_prefix = "ICNG";
+// TODO(b/305098009): fix byte size vs size naming issue.
+
// Next tag: 10
message NamespaceStorageInfoProto {
// Name of the namespace
diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt
index dd08fd1..55c4647 100644
--- a/synced_AOSP_CL_number.txt
+++ b/synced_AOSP_CL_number.txt
@@ -1 +1 @@
-set(synced_AOSP_CL_number=587883838)
+set(synced_AOSP_CL_number=616925123)