diff options
author | Krzysztof KosiĆski <krzysio@google.com> | 2024-03-21 15:33:55 +0000 |
---|---|---|
committer | Alexander Dorokhine <adorokhine@google.com> | 2024-03-21 16:14:28 +0000 |
commit | 15170523d0b603a3fc2729695ce4d9740ce5a85e (patch) | |
tree | 693704d89a2b49ffda814b0d81b3e9541f48a740 | |
parent | 42996c97b96f0da75543f0fee670f9e8cc595744 (diff) | |
parent | 555cb6e3295cf525baf46358235389ff52c9dcc2 (diff) | |
download | icing-15170523d0b603a3fc2729695ce4d9740ce5a85e.tar.gz |
Merge remote-tracking branch 'aosp/upstream-master'
Change descriptions:
========================================================================
Integration test for embedding search
========================================================================
Support embedding search in the advanced scoring language
========================================================================
Support embedding search in the advanced query language
========================================================================
Add missing header inclusions.
========================================================================
Assign default values to MemoryMappedFile members to avoid temporary object error during std::move
========================================================================
Use GetEmbeddingVector in EmbeddingIndex::TransferIndex
========================================================================
Refactor SectionRestrictData
========================================================================
Add EmbeddingIndex to IcingSearchEngine and create EmbeddingIndexingHandler
========================================================================
Support optimize for embedding index
========================================================================
Implement embedding search index
========================================================================
Branch posting list accessor and serializer to store embedding hits
========================================================================
Introduce embedding hit
========================================================================
Update schema and document definition
========================================================================
Add tests verifying that Icing will correctly save updated schema description fields.
========================================================================
Allow adding duplicate hits with the same value into the posting list.
========================================================================
Support an extension to the encoded Hit in Icing's posting list.
========================================================================
Adds description fields to SchemaTypeConfigProto and PropertyConfigProto
========================================================================
Support polymorphism in type property filters
========================================================================
Fix posting list GetMinPostingListSizeToFit size calculation bug
========================================================================
Add instructions to error message for advanced query backward compat.
========================================================================
[Trunk Stable Flags and Files Rebuild #1] Implement v2 version file in version-util
========================================================================
[Trunk Stable Flags and Files Rebuild #2] Integrate v2 version file with IcingSearchEngine
========================================================================
Remove dead code in index initialization for HasPropertyOperator after introducing v2 version file
========================================================================
[Icing version 4] Bump Icing kVersion to 4 for Android V
========================================================================
[Icing][Expand QueryStats][4/x] Add lite index hit buffer info into QueryStats
========================================================================
[Icing][Expand QueryStats][5/x] Add query processor advanced query
components latencies into QueryStats::SearchStats
========================================================================
Add instructions to error message for AppSearch advanced query features backward compatibility
========================================================================
BUG: 326987971
BUG: 326656531
BUG: 329747255
BUG: 294274922
BUG: 321107391
BUG: 324908653
BUG: 326987971
Bug: 314816301
Bug: 309826655
Bug: 305098009
Bug: 329747255
Test: unit test
Change-Id: I6f35b52114181d9dfded41cfb3949f062876a2e2
127 files changed, 13123 insertions, 1871 deletions
@@ -12,6 +12,6 @@ third_party { type: PIPER value: "http://google3/third_party/icing/" } - last_upgrade_date { year: 2019 month: 12 day: 20 } + last_upgrade_date { year: 2024 month: 3 day: 18 } license_type: NOTICE } diff --git a/icing/document-builder.h b/icing/document-builder.h index 44500f9..5d6f14d 100644 --- a/icing/document-builder.h +++ b/icing/document-builder.h @@ -126,6 +126,13 @@ class DocumentBuilder { return AddDocumentProperty(std::move(property_name), {document_values...}); } + // Takes a property name and any number of vector values. + template <typename... V> + DocumentBuilder& AddVectorProperty(std::string property_name, + V... vector_values) { + return AddVectorProperty(std::move(property_name), {vector_values...}); + } + DocumentProto Build() const { return document_; } private: @@ -183,6 +190,17 @@ class DocumentBuilder { } return *this; } + + DocumentBuilder& AddVectorProperty( + std::string property_name, + std::initializer_list<PropertyProto::VectorProto> vector_values) { + auto property = document_.add_properties(); + property->set_name(std::move(property_name)); + for (PropertyProto::VectorProto vector_value : vector_values) { + property->mutable_vector_values()->Add(std::move(vector_value)); + } + return *this; + } }; } // namespace lib diff --git a/icing/file/memory-mapped-file.h b/icing/file/memory-mapped-file.h index 54507af..185d940 100644 --- a/icing/file/memory-mapped-file.h +++ b/icing/file/memory-mapped-file.h @@ -79,7 +79,6 @@ #include <algorithm> #include <cstdint> -#include <memory> #include <string> #include <string_view> @@ -314,7 +313,7 @@ class MemoryMappedFile { // Cached constructor params. 
const Filesystem* filesystem_; std::string file_path_; - Strategy strategy_; + Strategy strategy_ = Strategy::READ_WRITE_AUTO_SYNC; // Raw file related fields: // - max_file_size_ @@ -327,7 +326,7 @@ class MemoryMappedFile { // // Note: max_file_size_ will be specified in runtime and the caller should // make sure its value is correct and reasonable. - int64_t max_file_size_; + int64_t max_file_size_ = 0; // Cached file size to avoid calling system call too frequently. It is only // used in GrowAndRemapIfNecessary(), the new API that handles underlying file @@ -336,7 +335,7 @@ class MemoryMappedFile { // Note: it is guaranteed that file_size_ is smaller or equal to the actual // file size as long as the underlying file hasn't been truncated or deleted // externally. See GrowFileSize() for more details. - int64_t file_size_; + int64_t file_size_ = 0; // Memory mapped related fields: // - mmap_result_ @@ -345,22 +344,22 @@ class MemoryMappedFile { // - mmap_size_ // Raw pointer (or error) returned by calls to mmap(). - void* mmap_result_; + void* mmap_result_ = nullptr; // Offset within the file at which the current memory-mapped region starts. - int64_t file_offset_; + int64_t file_offset_ = 0; // Size that is currently memory-mapped. // Note that the mmapped size can be larger than the underlying file size. We // can reduce remapping by pre-mmapping a large memory and grow the file size // later. See GrowAndRemapIfNecessary(). - int64_t mmap_size_; + int64_t mmap_size_ = 0; // The difference between file_offset_ and the actual adjusted (aligned) // offset. // Since mmap requires the offset to be a multiple of system page size, we // have to align file_offset_ to the last multiple of system page size. - int64_t alignment_adjustment_; + int64_t alignment_adjustment_ = 0; // E.g. 
system_page_size = 5, RemapImpl(/*new_file_offset=*/8, mmap_size) // diff --git a/icing/file/posting_list/flash-index-storage-header.h b/icing/file/posting_list/flash-index-storage-header.h index 6bbf1ba..f7b331c 100644 --- a/icing/file/posting_list/flash-index-storage-header.h +++ b/icing/file/posting_list/flash-index-storage-header.h @@ -33,7 +33,7 @@ class HeaderBlock { // The class used to access the actual header. struct Header { // A magic used to mark the beginning of a valid header. - static constexpr int kMagic = 0xb0780cf4; + static constexpr int kMagic = 0xd1b7b293; int magic; int block_size; int last_indexed_docid; diff --git a/icing/file/posting_list/flash-index-storage_test.cc b/icing/file/posting_list/flash-index-storage_test.cc index ef60037..203041e 100644 --- a/icing/file/posting_list/flash-index-storage_test.cc +++ b/icing/file/posting_list/flash-index-storage_test.cc @@ -16,9 +16,9 @@ #include <unistd.h> -#include <algorithm> -#include <cstdlib> -#include <limits> +#include <cstdint> +#include <memory> +#include <string> #include <utility> #include <vector> @@ -27,6 +27,7 @@ #include "gtest/gtest.h" #include "icing/file/filesystem.h" #include "icing/file/posting_list/flash-index-storage-header.h" +#include "icing/file/posting_list/posting-list-identifier.h" #include "icing/index/hit/hit.h" #include "icing/index/main/posting-list-hit-serializer.h" #include "icing/store/document-id.h" @@ -213,10 +214,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemory) { EXPECT_THAT(flash_index_storage.empty(), IsFalse()); std::vector<Hit> hits1 = { - Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12), - Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19), - Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100), - Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197)}; + Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/6, 
/*document_id=*/2, /*term_frequency=*/19, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for (const Hit& hit : hits1) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder1.posting_list, hit)); @@ -237,10 +242,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemory) { EXPECT_THAT(flash_index_storage.empty(), IsFalse()); std::vector<Hit> hits2 = { - Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12), - Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19), - Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100), - Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197)}; + Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for (const Hit& hit : hits2) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder2.posting_list, hit)); @@ -273,10 +282,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemory) { EXPECT_THAT(serializer_->GetHits(&posting_list_holder3.posting_list), IsOkAndHolds(IsEmpty())); std::vector<Hit> hits3 = { - Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62), - Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45), - Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12), - Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74)}; + 
Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for (const Hit& hit : hits3) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder3.posting_list, hit)); @@ -314,10 +327,14 @@ TEST_F(FlashIndexStorageTest, FreeListNotInMemory) { EXPECT_THAT(flash_index_storage.empty(), IsFalse()); std::vector<Hit> hits1 = { - Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12), - Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19), - Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100), - Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197)}; + Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for (const Hit& hit : hits1) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder1.posting_list, hit)); @@ -338,10 +355,14 @@ TEST_F(FlashIndexStorageTest, FreeListNotInMemory) { EXPECT_THAT(flash_index_storage.empty(), IsFalse()); std::vector<Hit> hits2 = { - Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12), - Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19), - Hit(/*section_id=*/9, 
/*document_id=*/7, /*term_frequency=*/100), - Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197)}; + Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for (const Hit& hit : hits2) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder2.posting_list, hit)); @@ -374,10 +395,14 @@ TEST_F(FlashIndexStorageTest, FreeListNotInMemory) { EXPECT_THAT(serializer_->GetHits(&posting_list_holder3.posting_list), IsOkAndHolds(IsEmpty())); std::vector<Hit> hits3 = { - Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62), - Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45), - Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12), - Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74)}; + Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for (const Hit& hit : hits3) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder3.posting_list, hit)); @@ -417,10 +442,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemoryPersistence) { EXPECT_THAT(flash_index_storage.empty(), IsFalse()); std::vector<Hit> hits1 = { - 
Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12), - Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19), - Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100), - Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197)}; + Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for (const Hit& hit : hits1) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder1.posting_list, hit)); @@ -441,10 +470,14 @@ TEST_F(FlashIndexStorageTest, FreeListInMemoryPersistence) { EXPECT_THAT(flash_index_storage.empty(), IsFalse()); std::vector<Hit> hits2 = { - Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12), - Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19), - Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100), - Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197)}; + Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for (const Hit& hit : hits2) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder2.posting_list, hit)); @@ -492,10 +525,14 @@ TEST_F(FlashIndexStorageTest, 
FreeListInMemoryPersistence) { EXPECT_THAT(serializer_->GetHits(&posting_list_holder3.posting_list), IsOkAndHolds(IsEmpty())); std::vector<Hit> hits3 = { - Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62), - Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45), - Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12), - Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74)}; + Hit(/*section_id=*/7, /*document_id=*/1, /*term_frequency=*/62, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/12, /*document_id=*/3, /*term_frequency=*/45, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/11, /*document_id=*/18, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/7, /*document_id=*/100, /*term_frequency=*/74, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for (const Hit& hit : hits3) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder3.posting_list, hit)); @@ -534,10 +571,14 @@ TEST_F(FlashIndexStorageTest, DifferentSizedPostingLists) { EXPECT_THAT(flash_index_storage.empty(), IsFalse()); std::vector<Hit> hits1 = { - Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12), - Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19), - Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100), - Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197)}; + Hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/6, /*document_id=*/2, /*term_frequency=*/19, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/5, /*document_id=*/2, /*term_frequency=*/100, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/8, /*document_id=*/5, /*term_frequency=*/197, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for 
(const Hit& hit : hits1) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder1.posting_list, hit)); @@ -561,10 +602,14 @@ TEST_F(FlashIndexStorageTest, DifferentSizedPostingLists) { EXPECT_THAT(flash_index_storage.empty(), IsFalse()); std::vector<Hit> hits2 = { - Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12), - Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19), - Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100), - Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197)}; + Hit(/*section_id=*/4, /*document_id=*/0, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/8, /*document_id=*/4, /*term_frequency=*/19, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/9, /*document_id=*/7, /*term_frequency=*/100, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/6, /*document_id=*/7, /*term_frequency=*/197, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false)}; for (const Hit& hit : hits2) { ICING_ASSERT_OK( serializer_->PrependHit(&posting_list_holder2.posting_list, hit)); diff --git a/icing/file/posting_list/index-block_test.cc b/icing/file/posting_list/index-block_test.cc index ebc9ba4..d841e79 100644 --- a/icing/file/posting_list/index-block_test.cc +++ b/icing/file/posting_list/index-block_test.cc @@ -14,11 +14,16 @@ #include "icing/file/posting_list/index-block.h" +#include <memory> +#include <string> +#include <vector> + #include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/file/filesystem.h" -#include "icing/file/posting_list/posting-list-used.h" +#include "icing/file/posting_list/posting-list-common.h" +#include "icing/index/hit/hit.h" #include "icing/index/main/posting-list-hit-serializer.h" #include "icing/testing/common-matchers.h" #include "icing/testing/tmp-directory.h" @@ -67,7 +72,7 @@ class IndexBlockTest 
: public ::testing::Test { }; TEST_F(IndexBlockTest, CreateFromUninitializedRegionProducesEmptyBlock) { - constexpr int kPostingListBytes = 20; + constexpr int kPostingListBytes = 24; { // Create an IndexBlock from this newly allocated file block. @@ -80,7 +85,7 @@ TEST_F(IndexBlockTest, CreateFromUninitializedRegionProducesEmptyBlock) { } TEST_F(IndexBlockTest, SizeAccessorsWorkCorrectly) { - constexpr int kPostingListBytes1 = 20; + constexpr int kPostingListBytes1 = 24; // Create an IndexBlock from this newly allocated file block. ICING_ASSERT_OK_AND_ASSIGN(IndexBlock block, @@ -88,13 +93,13 @@ TEST_F(IndexBlockTest, SizeAccessorsWorkCorrectly) { &filesystem_, serializer_.get(), sfd_->get(), /*offset=*/0, kBlockSize, kPostingListBytes1)); EXPECT_THAT(block.posting_list_bytes(), Eq(kPostingListBytes1)); - // There should be (4096 - 12) / 20 = 204 posting lists - // (sizeof(BlockHeader)==12). We can store a PostingListIndex of 203 in only 8 + // There should be (4096 - 12) / 24 = 170 posting lists + // (sizeof(BlockHeader)==12). We can store a PostingListIndex of 170 in only 8 // bits. - EXPECT_THAT(block.max_num_posting_lists(), Eq(204)); + EXPECT_THAT(block.max_num_posting_lists(), Eq(170)); EXPECT_THAT(block.posting_list_index_bits(), Eq(8)); - constexpr int kPostingListBytes2 = 200; + constexpr int kPostingListBytes2 = 240; // Create an IndexBlock from this newly allocated file block. ICING_ASSERT_OK_AND_ASSIGN( @@ -102,22 +107,27 @@ TEST_F(IndexBlockTest, SizeAccessorsWorkCorrectly) { &filesystem_, serializer_.get(), sfd_->get(), /*offset=*/0, kBlockSize, kPostingListBytes2)); EXPECT_THAT(block.posting_list_bytes(), Eq(kPostingListBytes2)); - // There should be (4096 - 12) / 200 = 20 posting lists + // There should be (4096 - 12) / 240 = 17 posting lists // (sizeof(BlockHeader)==12). We can store a PostingListIndex of 19 in only 5 // bits. 
- EXPECT_THAT(block.max_num_posting_lists(), Eq(20)); + EXPECT_THAT(block.max_num_posting_lists(), Eq(17)); EXPECT_THAT(block.posting_list_index_bits(), Eq(5)); } TEST_F(IndexBlockTest, IndexBlockChangesPersistAcrossInstances) { - constexpr int kPostingListBytes = 2000; + constexpr int kPostingListBytes = 2004; std::vector<Hit> test_hits{ - Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency), - Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency), - Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99), - Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17), - Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency), + Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), }; PostingListIndex allocated_index; { @@ -158,21 +168,31 @@ TEST_F(IndexBlockTest, IndexBlockChangesPersistAcrossInstances) { } TEST_F(IndexBlockTest, IndexBlockMultiplePostingLists) { - constexpr int kPostingListBytes = 2000; + constexpr int kPostingListBytes = 2004; std::vector<Hit> hits_in_posting_list1{ - Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency), - Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency), - Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99), - Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17), - Hit(/*section_id=*/10, /*document_id=*/10, 
Hit::kDefaultTermFrequency), + Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), }; std::vector<Hit> hits_in_posting_list2{ - Hit(/*section_id=*/12, /*document_id=*/220, /*term_frequency=*/88), - Hit(/*section_id=*/17, /*document_id=*/265, Hit::kDefaultTermFrequency), - Hit(/*section_id=*/0, /*document_id=*/287, /*term_frequency=*/2), - Hit(/*section_id=*/11, /*document_id=*/306, /*term_frequency=*/12), - Hit(/*section_id=*/10, /*document_id=*/306, Hit::kDefaultTermFrequency), + Hit(/*section_id=*/12, /*document_id=*/220, /*term_frequency=*/88, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/17, /*document_id=*/265, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/0, /*document_id=*/287, /*term_frequency=*/2, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/11, /*document_id=*/306, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/10, /*document_id=*/306, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), }; PostingListIndex allocated_index_1; PostingListIndex allocated_index_2; @@ -242,7 +262,7 @@ TEST_F(IndexBlockTest, IndexBlockMultiplePostingLists) { } TEST_F(IndexBlockTest, IndexBlockReallocatingPostingLists) { - constexpr int kPostingListBytes = 2000; + constexpr 
int kPostingListBytes = 2004; // Create an IndexBlock from this newly allocated file block. ICING_ASSERT_OK_AND_ASSIGN(IndexBlock block, @@ -252,11 +272,16 @@ TEST_F(IndexBlockTest, IndexBlockReallocatingPostingLists) { // Add hits to the first posting list. std::vector<Hit> hits_in_posting_list1{ - Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency), - Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency), - Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99), - Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17), - Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency), + Hit(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/5, /*document_id=*/1, /*term_frequency=*/99, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/3, /*document_id=*/3, /*term_frequency=*/17, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/10, /*document_id=*/10, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), }; ICING_ASSERT_OK_AND_ASSIGN(IndexBlock::PostingListAndBlockInfo alloc_info_1, block.AllocatePostingList()); @@ -270,11 +295,16 @@ TEST_F(IndexBlockTest, IndexBlockReallocatingPostingLists) { // Add hits to the second posting list. 
std::vector<Hit> hits_in_posting_list2{ - Hit(/*section_id=*/12, /*document_id=*/220, /*term_frequency=*/88), - Hit(/*section_id=*/17, /*document_id=*/265, Hit::kDefaultTermFrequency), - Hit(/*section_id=*/0, /*document_id=*/287, /*term_frequency=*/2), - Hit(/*section_id=*/11, /*document_id=*/306, /*term_frequency=*/12), - Hit(/*section_id=*/10, /*document_id=*/306, Hit::kDefaultTermFrequency), + Hit(/*section_id=*/12, /*document_id=*/220, /*term_frequency=*/88, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/17, /*document_id=*/265, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/0, /*document_id=*/287, /*term_frequency=*/2, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/11, /*document_id=*/306, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/10, /*document_id=*/306, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), }; ICING_ASSERT_OK_AND_ASSIGN(IndexBlock::PostingListAndBlockInfo alloc_info_2, block.AllocatePostingList()); @@ -296,9 +326,12 @@ TEST_F(IndexBlockTest, IndexBlockReallocatingPostingLists) { EXPECT_THAT(block.HasFreePostingLists(), IsOkAndHolds(IsTrue())); std::vector<Hit> hits_in_posting_list3{ - Hit(/*section_id=*/12, /*document_id=*/0, /*term_frequency=*/88), - Hit(/*section_id=*/17, /*document_id=*/1, Hit::kDefaultTermFrequency), - Hit(/*section_id=*/0, /*document_id=*/2, /*term_frequency=*/2), + Hit(/*section_id=*/12, /*document_id=*/0, /*term_frequency=*/88, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/17, /*document_id=*/1, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), + Hit(/*section_id=*/0, /*document_id=*/2, /*term_frequency=*/2, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false), }; 
ICING_ASSERT_OK_AND_ASSIGN(IndexBlock::PostingListAndBlockInfo alloc_info_3, block.AllocatePostingList()); @@ -317,7 +350,7 @@ TEST_F(IndexBlockTest, IndexBlockReallocatingPostingLists) { } TEST_F(IndexBlockTest, IndexBlockNextBlockIndex) { - constexpr int kPostingListBytes = 2000; + constexpr int kPostingListBytes = 2004; constexpr int kSomeBlockIndex = 22; { diff --git a/icing/file/version-util.cc b/icing/file/version-util.cc index dd233e0..e750b0c 100644 --- a/icing/file/version-util.cc +++ b/icing/file/version-util.cc @@ -15,36 +15,47 @@ #include "icing/file/version-util.h" #include <cstdint> +#include <memory> #include <string> +#include <string_view> +#include <unordered_set> #include <utility> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/file/file-backed-proto.h" #include "icing/file/filesystem.h" #include "icing/index/index.h" +#include "icing/proto/initialize.pb.h" +#include "icing/util/status-macros.h" namespace icing { namespace lib { namespace version_util { -libtextclassifier3::StatusOr<VersionInfo> ReadVersion( - const Filesystem& filesystem, const std::string& version_file_path, +namespace { + +libtextclassifier3::StatusOr<VersionInfo> ReadV1VersionInfo( + const Filesystem& filesystem, const std::string& version_file_dir, const std::string& index_base_dir) { // 1. Read the version info. 
+ const std::string v1_version_filepath = + MakeVersionFilePath(version_file_dir, kVersionFilenameV1); VersionInfo existing_version_info(-1, -1); - if (filesystem.FileExists(version_file_path.c_str()) && - !filesystem.PRead(version_file_path.c_str(), &existing_version_info, + if (filesystem.FileExists(v1_version_filepath.c_str()) && + !filesystem.PRead(v1_version_filepath.c_str(), &existing_version_info, sizeof(VersionInfo), /*offset=*/0)) { - return absl_ports::InternalError("Fail to read version"); + return absl_ports::InternalError("Failed to read v1 version file"); } // 2. Check the Index magic to see if we're actually on version 0. - libtextclassifier3::StatusOr<int> existing_flash_index_magic_or = + libtextclassifier3::StatusOr<int> existing_flash_index_magic = Index::ReadFlashIndexMagic(&filesystem, index_base_dir); - if (!existing_flash_index_magic_or.ok()) { - if (absl_ports::IsNotFound(existing_flash_index_magic_or.status())) { + if (!existing_flash_index_magic.ok()) { + if (absl_ports::IsNotFound(existing_flash_index_magic.status())) { // Flash index magic doesn't exist. In this case, we're unable to // determine the version change state correctly (regardless of the // existence of the version file), so invalidate VersionInfo by setting @@ -53,9 +64,9 @@ libtextclassifier3::StatusOr<VersionInfo> ReadVersion( return existing_version_info; } // Real error. 
- return std::move(existing_flash_index_magic_or).status(); + return std::move(existing_flash_index_magic).status(); } - if (existing_flash_index_magic_or.ValueOrDie() == + if (existing_flash_index_magic.ValueOrDie() == kVersionZeroFlashIndexMagic) { existing_version_info.version = 0; if (existing_version_info.max_version == -1) { @@ -66,15 +77,124 @@ libtextclassifier3::StatusOr<VersionInfo> ReadVersion( return existing_version_info; } -libtextclassifier3::Status WriteVersion(const Filesystem& filesystem, - const std::string& version_file_path, - const VersionInfo& version_info) { - ScopedFd scoped_fd(filesystem.OpenForWrite(version_file_path.c_str())); +libtextclassifier3::StatusOr<IcingSearchEngineVersionProto> ReadV2VersionInfo( + const Filesystem& filesystem, const std::string& version_file_dir) { + // Read the v2 version file. V2 version file stores the + // IcingSearchEngineVersionProto as a file-backed proto. + const std::string v2_version_filepath = + MakeVersionFilePath(version_file_dir, kVersionFilenameV2); + FileBackedProto<IcingSearchEngineVersionProto> v2_version_file( + filesystem, v2_version_filepath); + ICING_ASSIGN_OR_RETURN(const IcingSearchEngineVersionProto* v2_version_proto, + v2_version_file.Read()); + + return *v2_version_proto; +} + +} // namespace + +libtextclassifier3::StatusOr<IcingSearchEngineVersionProto> ReadVersion( + const Filesystem& filesystem, const std::string& version_file_dir, + const std::string& index_base_dir) { + // 1. Read the v1 version file + ICING_ASSIGN_OR_RETURN( + VersionInfo v1_version_info, + ReadV1VersionInfo(filesystem, version_file_dir, index_base_dir)); + if (!v1_version_info.IsValid()) { + // This happens if IcingLib's state is invalid (e.g. flash index header file + // is missing). Return the invalid version numbers in this case. 
+ IcingSearchEngineVersionProto version_proto; + version_proto.set_version(v1_version_info.version); + version_proto.set_max_version(v1_version_info.max_version); + return version_proto; + } + + // 2. Read the v2 version file + auto v2_version_proto = ReadV2VersionInfo(filesystem, version_file_dir); + if (!v2_version_proto.ok()) { + if (!absl_ports::IsNotFound(v2_version_proto.status())) { + // Real error. + return std::move(v2_version_proto).status(); + } + // The v2 version file has not been written + IcingSearchEngineVersionProto version_proto; + if (v1_version_info.version < kFirstV2Version) { + // There are two scenarios for this case: + // 1. It's the first time that we're upgrading from a lower version to a + // version >= kFirstV2Version. + // - It's expected that the v2 version file has not been written yet in + // this case and we return the v1 version numbers instead. + // 2. We're rolling forward from a version < kFirstV2Version, after + // rolling back from a previous version >= kFirstV2Version, and for + // some unknown reason we lost the v2 version file in the previous + // version. + // - e.g. version #4 -> version #1 -> version #4, but we lost the v2 + // file during version #1. + // - This is a rollforward case, but it's still fine to return the v1 + // version number here as ShouldRebuildDerivedFiles can handle + // rollforwards correctly. + version_proto.set_version(v1_version_info.version); + version_proto.set_max_version(v1_version_info.max_version); + } else { + // Something weird has happened. During last initialization we were + // already on a version >= kFirstV2Version, so the v2 version file + // should have been written. + // Return an invalid version number in this case and trigger rebuilding + // everything. + version_proto.set_version(-1); + version_proto.set_max_version(v1_version_info.max_version); + } + return version_proto; + } + + // 3. Check if versions match. 
If not, it means that we're rolling forward + // from a version < kFirstV2Version. In order to trigger rebuilding + // everything, we return an invalid version number in this case. + IcingSearchEngineVersionProto v2_version_proto_value = + std::move(v2_version_proto).ValueOrDie(); + if (v1_version_info.version != v2_version_proto_value.version()) { + v2_version_proto_value.set_version(-1); + v2_version_proto_value.mutable_enabled_features()->Clear(); + } + + return v2_version_proto_value; +} + +libtextclassifier3::Status WriteV1Version(const Filesystem& filesystem, + const std::string& version_file_dir, + const VersionInfo& version_info) { + ScopedFd scoped_fd(filesystem.OpenForWrite( + MakeVersionFilePath(version_file_dir, kVersionFilenameV1).c_str())); if (!scoped_fd.is_valid() || !filesystem.PWrite(scoped_fd.get(), /*offset=*/0, &version_info, sizeof(VersionInfo)) || !filesystem.DataSync(scoped_fd.get())) { - return absl_ports::InternalError("Fail to write version"); + return absl_ports::InternalError("Failed to write v1 version file"); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status WriteV2Version( + const Filesystem& filesystem, const std::string& version_file_dir, + std::unique_ptr<IcingSearchEngineVersionProto> version_proto) { + FileBackedProto<IcingSearchEngineVersionProto> v2_version_file( + filesystem, MakeVersionFilePath(version_file_dir, kVersionFilenameV2)); + libtextclassifier3::Status v2_write_status = + v2_version_file.Write(std::move(version_proto)); + if (!v2_write_status.ok()) { + return absl_ports::InternalError(absl_ports::StrCat( + "Failed to write v2 version file: ", v2_write_status.error_message())); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status DiscardVersionFiles( + const Filesystem& filesystem, std::string_view version_file_dir) { + if (!filesystem.DeleteFile( + MakeVersionFilePath(version_file_dir, kVersionFilenameV1).c_str()) || + !filesystem.DeleteFile( + 
MakeVersionFilePath(version_file_dir, kVersionFilenameV2).c_str())) { + return absl_ports::InternalError("Failed to discard version files"); } return libtextclassifier3::Status::OK; } @@ -102,6 +222,65 @@ StateChange GetVersionStateChange(const VersionInfo& existing_version_info, } } +DerivedFilesRebuildResult CalculateRequiredDerivedFilesRebuild( + const IcingSearchEngineVersionProto& prev_version_proto, + const IcingSearchEngineVersionProto& curr_version_proto) { + // 1. Do version check using version and max_version numbers + if (ShouldRebuildDerivedFiles(GetVersionInfoFromProto(prev_version_proto), + curr_version_proto.version())) { + return DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/true, + /*needs_schema_store_derived_files_rebuild=*/true, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/true, + /*needs_qualified_id_join_index_rebuild=*/true); + } + + // 2. Compare the previous enabled features with the current enabled features + // and rebuild if there are differences. + std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType> + prev_features; + for (const auto& feature : prev_version_proto.enabled_features()) { + prev_features.insert(feature.feature_type()); + } + std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType> + curr_features; + for (const auto& feature : curr_version_proto.enabled_features()) { + curr_features.insert(feature.feature_type()); + } + DerivedFilesRebuildResult result; + for (const auto& prev_feature : prev_features) { + // If there is an UNKNOWN feature in the previous feature set (note that we + // never use UNKNOWN when writing the version proto), it means that: + // - The previous version proto contains a feature enum that is only defined + // in a newer version. + // - We've now rolled back to an old version that doesn't understand this + // new enum value, and proto serialization defaults it to 0 (UNKNOWN). 
+ // - In this case we need to rebuild everything. + if (prev_feature == IcingSearchEngineFeatureInfoProto::UNKNOWN) { + return DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/true, + /*needs_schema_store_derived_files_rebuild=*/true, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/true, + /*needs_qualified_id_join_index_rebuild=*/true); + } + if (curr_features.find(prev_feature) == curr_features.end()) { + DerivedFilesRebuildResult required_rebuilds = + GetFeatureDerivedFilesRebuildResult(prev_feature); + result.CombineWithOtherRebuildResultOr(required_rebuilds); + } + } + for (const auto& curr_feature : curr_features) { + if (prev_features.find(curr_feature) == prev_features.end()) { + DerivedFilesRebuildResult required_rebuilds = + GetFeatureDerivedFilesRebuildResult(curr_feature); + result.CombineWithOtherRebuildResultOr(required_rebuilds); + } + } + return result; +} + bool ShouldRebuildDerivedFiles(const VersionInfo& existing_version_info, int32_t curr_version) { StateChange state_change = @@ -135,6 +314,10 @@ bool ShouldRebuildDerivedFiles(const VersionInfo& existing_version_info, // version 2 -> version 3 upgrade, no need to rebuild break; } + case 3: { + // version 3 -> version 4 upgrade, no need to rebuild + break; + } default: // This should not happen. Rebuild anyway if unsure. 
should_rebuild |= true; @@ -144,6 +327,56 @@ bool ShouldRebuildDerivedFiles(const VersionInfo& existing_version_info, return should_rebuild; } +DerivedFilesRebuildResult GetFeatureDerivedFilesRebuildResult( + IcingSearchEngineFeatureInfoProto::FlaggedFeatureType feature) { + switch (feature) { + case IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR: { + return DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/false, + /*needs_schema_store_derived_files_rebuild=*/false, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/false, + /*needs_qualified_id_join_index_rebuild=*/false); + } + case IcingSearchEngineFeatureInfoProto::UNKNOWN: + return DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/true, + /*needs_schema_store_derived_files_rebuild=*/true, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/true, + /*needs_qualified_id_join_index_rebuild=*/true); + } +} + +IcingSearchEngineFeatureInfoProto GetFeatureInfoProto( + IcingSearchEngineFeatureInfoProto::FlaggedFeatureType feature) { + IcingSearchEngineFeatureInfoProto info; + info.set_feature_type(feature); + + DerivedFilesRebuildResult result = + GetFeatureDerivedFilesRebuildResult(feature); + info.set_needs_document_store_rebuild( + result.needs_document_store_derived_files_rebuild); + info.set_needs_schema_store_rebuild( + result.needs_schema_store_derived_files_rebuild); + info.set_needs_term_index_rebuild(result.needs_term_index_rebuild); + info.set_needs_integer_index_rebuild(result.needs_integer_index_rebuild); + info.set_needs_qualified_id_join_index_rebuild( + result.needs_qualified_id_join_index_rebuild); + + return info; +} + +void AddEnabledFeatures(const IcingSearchEngineOptions& options, + IcingSearchEngineVersionProto* version_proto) { + auto* enabled_features = version_proto->mutable_enabled_features(); + // HasPropertyOperator feature + if (options.build_property_existence_metadata_hits()) { + 
enabled_features->Add(GetFeatureInfoProto( + IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR)); + } +} + } // namespace version_util } // namespace lib diff --git a/icing/file/version-util.h b/icing/file/version-util.h index b2d51df..feadaf6 100644 --- a/icing/file/version-util.h +++ b/icing/file/version-util.h @@ -16,11 +16,15 @@ #define ICING_FILE_VERSION_UTIL_H_ #include <cstdint> +#include <memory> #include <string> +#include <string_view> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/str_cat.h" #include "icing/file/filesystem.h" +#include "icing/proto/initialize.pb.h" namespace icing { namespace lib { @@ -32,16 +36,26 @@ namespace version_util { // - Version 2: M-2023-09, M-2023-11, M-2024-01. Schema is compatible with v1. // (There were no M-2023-10, M-2023-12). // - Version 3: M-2024-02. Schema is compatible with v1 and v2. +// - Version 4: Android V base. Schema is compatible with v1, v2 and v3. // +// TODO(b/314816301): Bump kVersion to 4 for Android V rollout with v2 version +// detection // LINT.IfChange(kVersion) -inline static constexpr int32_t kVersion = 3; +inline static constexpr int32_t kVersion = 4; // LINT.ThenChange(//depot/google3/icing/schema/schema-store.cc:min_overlay_version_compatibility) inline static constexpr int32_t kVersionOne = 1; inline static constexpr int32_t kVersionTwo = 2; inline static constexpr int32_t kVersionThree = 3; +inline static constexpr int32_t kVersionFour = 4; + +// Version at which v2 version file is introduced. 
+inline static constexpr int32_t kFirstV2Version = kVersionFour; inline static constexpr int kVersionZeroFlashIndexMagic = 0x6dfba6ae; +inline static constexpr std::string_view kVersionFilenameV1 = "version"; +inline static constexpr std::string_view kVersionFilenameV2 = "version2"; + struct VersionInfo { int32_t version; int32_t max_version; @@ -67,28 +81,144 @@ enum class StateChange { kVersionZeroRollForward, }; -// Helper method to read version info (using version file and flash index header -// magic) from the existing data. If the state is invalid (e.g. flash index -// header file is missing), then return an invalid VersionInfo. +// Contains information about which derived files need to be rebuild. +// +// These flags only reflect whether each component should be rebuilt, but do not +// handle any dependencies. The caller should handle the dependencies by +// themselves. +// e.g. - qualified id join index depends on document store derived files, but +// it's possible to have needs_document_store_derived_files_rebuild = +// true and needs_qualified_id_join_index_rebuild = false. +// - The caller should know that join index should also be rebuilt in this +// case even though needs_qualified_id_join_index_rebuild = false. 
+struct DerivedFilesRebuildResult { + bool needs_document_store_derived_files_rebuild = false; + bool needs_schema_store_derived_files_rebuild = false; + bool needs_term_index_rebuild = false; + bool needs_integer_index_rebuild = false; + bool needs_qualified_id_join_index_rebuild = false; + + DerivedFilesRebuildResult() = default; + + explicit DerivedFilesRebuildResult( + bool needs_document_store_derived_files_rebuild_in, + bool needs_schema_store_derived_files_rebuild_in, + bool needs_term_index_rebuild_in, bool needs_integer_index_rebuild_in, + bool needs_qualified_id_join_index_rebuild_in) + : needs_document_store_derived_files_rebuild( + needs_document_store_derived_files_rebuild_in), + needs_schema_store_derived_files_rebuild( + needs_schema_store_derived_files_rebuild_in), + needs_term_index_rebuild(needs_term_index_rebuild_in), + needs_integer_index_rebuild(needs_integer_index_rebuild_in), + needs_qualified_id_join_index_rebuild( + needs_qualified_id_join_index_rebuild_in) {} + + bool IsRebuildNeeded() const { + return needs_document_store_derived_files_rebuild || + needs_schema_store_derived_files_rebuild || + needs_term_index_rebuild || needs_integer_index_rebuild || + needs_qualified_id_join_index_rebuild; + } + + bool operator==(const DerivedFilesRebuildResult& other) const { + return needs_document_store_derived_files_rebuild == + other.needs_document_store_derived_files_rebuild && + needs_schema_store_derived_files_rebuild == + other.needs_schema_store_derived_files_rebuild && + needs_term_index_rebuild == other.needs_term_index_rebuild && + needs_integer_index_rebuild == other.needs_integer_index_rebuild && + needs_qualified_id_join_index_rebuild == + other.needs_qualified_id_join_index_rebuild; + } + + void CombineWithOtherRebuildResultOr(const DerivedFilesRebuildResult& other) { + needs_document_store_derived_files_rebuild = + needs_document_store_derived_files_rebuild || + other.needs_document_store_derived_files_rebuild; + 
needs_schema_store_derived_files_rebuild = + needs_schema_store_derived_files_rebuild || + other.needs_schema_store_derived_files_rebuild; + needs_term_index_rebuild = + needs_term_index_rebuild || other.needs_term_index_rebuild; + needs_integer_index_rebuild = + needs_integer_index_rebuild || other.needs_integer_index_rebuild; + needs_qualified_id_join_index_rebuild = + needs_qualified_id_join_index_rebuild || + other.needs_qualified_id_join_index_rebuild; + } +}; + +// There are two icing version files: +// 1. V1 version file contains version and max_version info of the existing +// data. +// 2. V2 version file writes the version information using +// FileBackedProto<IcingSearchEngineVersionProto>. This contains information +// about the version's enabled trunk stable features in addition to the +// version numbers written for V1. +// +// Both version files must be written to maintain backwards compatibility. +inline std::string MakeVersionFilePath(std::string_view version_file_dir, + std::string_view version_file_name) { + return absl_ports::StrCat(version_file_dir, "/", version_file_name); +} + +// Returns a VersionInfo from a given IcingSearchEngineVersionProto. +inline VersionInfo GetVersionInfoFromProto( + const IcingSearchEngineVersionProto& version_proto) { + return VersionInfo(version_proto.version(), version_proto.max_version()); +} + + +// Reads the IcingSearchEngineVersionProto from the version files of the +// existing data. +// +// This method reads both the v1 and v2 version files, and returns the v1 +// version numbers in the absence of the v2 version file. If there is a mismatch +// between the v1 and v2 version numbers, or if the state is invalid (e.g. flash +// index header file is missing), then an invalid VersionInfo is returned. 
// // RETURNS: -// - Existing data's VersionInfo on success +// - Existing data's IcingSearchEngineVersionProto on success // - INTERNAL_ERROR on I/O errors -libtextclassifier3::StatusOr<VersionInfo> ReadVersion( - const Filesystem& filesystem, const std::string& version_file_path, +libtextclassifier3::StatusOr<IcingSearchEngineVersionProto> ReadVersion( + const Filesystem& filesystem, const std::string& version_file_dir, const std::string& index_base_dir); -// Helper method to write version file. +// Writes the v1 version file. V1 version file is written for all versions and +// contains only Icing's VersionInfo (version number and max_version) // // RETURNS: // - OK on success // - INTERNAL_ERROR on I/O errors -libtextclassifier3::Status WriteVersion(const Filesystem& filesystem, - const std::string& version_file_path, - const VersionInfo& version_info); +libtextclassifier3::Status WriteV1Version(const Filesystem& filesystem, + const std::string& version_file_dir, + const VersionInfo& version_info); -// Helper method to determine the change state between the existing data version -// and the current code version. +// Writes the v2 version file. V2 version file writes the version information +// using FileBackedProto<IcingSearchEngineVersionProto>. +// +// REQUIRES: version_proto.version >= kFirstV2Version. We implement v2 version +// checking in kFirstV2Version, so callers will always use a version # greater +// than this. +// +// RETURNS: +// - OK on success +// - INTERNAL_ERROR on I/O errors +libtextclassifier3::Status WriteV2Version( + const Filesystem& filesystem, const std::string& version_file_dir, + std::unique_ptr<IcingSearchEngineVersionProto> version_proto); + +// Deletes Icing's version files from version_file_dir. 
+// +// Returns: +// - OK on success +// - INTERNAL_ERROR on I/O error +libtextclassifier3::Status DiscardVersionFiles( + const Filesystem& filesystem, std::string_view version_file_dir); + +// Determines the change state between the existing data version and the current +// code version. // // REQUIRES: curr_version > 0. We implement version checking in version 1, so // the callers (except unit tests) will always use a version # greater than 0. @@ -97,7 +227,19 @@ libtextclassifier3::Status WriteVersion(const Filesystem& filesystem, StateChange GetVersionStateChange(const VersionInfo& existing_version_info, int32_t curr_version = kVersion); -// Helper method to determine whether Icing should rebuild all derived files. +// Determines the derived files that need to be rebuilt between Icing's existing +// data based on previous data version and enabled features, and the current +// code version and enabled features. +// +// REQUIRES: curr_version >= kFirstV2Version. We implement v2 version checking +// in kFirstV2Version, so callers will always use a version # greater than this. +// +// RETURNS: DerivedFilesRebuildResult +DerivedFilesRebuildResult CalculateRequiredDerivedFilesRebuild( + const IcingSearchEngineVersionProto& prev_version_proto, + const IcingSearchEngineVersionProto& curr_version_proto); + +// Determines whether Icing should rebuild all derived files. // Sometimes it is not required to rebuild derived files when // roll-forward/upgrading. This function "encodes" upgrade paths and checks if // the roll-forward/upgrading requires derived files to be rebuilt or not. @@ -107,6 +249,24 @@ StateChange GetVersionStateChange(const VersionInfo& existing_version_info, bool ShouldRebuildDerivedFiles(const VersionInfo& existing_version_info, int32_t curr_version = kVersion); +// Returns the derived files rebuilds required for a given feature. 
+DerivedFilesRebuildResult GetFeatureDerivedFilesRebuildResult( + IcingSearchEngineFeatureInfoProto::FlaggedFeatureType feature); + +// Constructs the IcingSearchEngineFeatureInfoProto for a given feature. +IcingSearchEngineFeatureInfoProto GetFeatureInfoProto( + IcingSearchEngineFeatureInfoProto::FlaggedFeatureType feature); + +// Populates the enabled_features field for an IcingSearchEngineFeatureInfoProto +// based on icing's initialization flag options. +// +// All enabled features are converted into an IcingSearchEngineFeatureInfoProto +// and returned in IcingSearchEngineVersionProto::enabled_features. A conversion +// should be added for each trunk stable feature flag defined in +// IcingSearchEngineOptions. +void AddEnabledFeatures(const IcingSearchEngineOptions& options, + IcingSearchEngineVersionProto* version_proto); + } // namespace version_util } // namespace lib diff --git a/icing/file/version-util_test.cc b/icing/file/version-util_test.cc index 9dedb1d..f2922e2 100644 --- a/icing/file/version-util_test.cc +++ b/icing/file/version-util_test.cc @@ -14,14 +14,19 @@ #include "icing/file/version-util.h" +#include <cstdint> +#include <memory> #include <optional> #include <string> +#include <unordered_set> #include <utility> #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/file/filesystem.h" #include "icing/file/posting_list/flash-index-storage-header.h" +#include "icing/portable/equals-proto.h" +#include "icing/proto/initialize.pb.h" #include "icing/testing/common-matchers.h" #include "icing/testing/tmp-directory.h" @@ -32,19 +37,38 @@ namespace version_util { namespace { using ::testing::Eq; +using ::testing::IsEmpty; using ::testing::IsFalse; using ::testing::IsTrue; +IcingSearchEngineVersionProto MakeTestVersionProto( + const VersionInfo& version_info, + const std::unordered_set< + IcingSearchEngineFeatureInfoProto::FlaggedFeatureType>& features_set) { + IcingSearchEngineVersionProto version_proto; + 
version_proto.set_version(version_info.version); + version_proto.set_max_version(version_info.max_version); + + auto* enabled_features = version_proto.mutable_enabled_features(); + for (const auto& feature : features_set) { + enabled_features->Add(GetFeatureInfoProto(feature)); + } + return version_proto; +} + struct VersionUtilReadVersionTestParam { - std::optional<VersionInfo> existing_version_info; + std::optional<VersionInfo> existing_v1_version_info; + std::optional<VersionInfo> existing_v2_version_info; std::optional<int> existing_flash_index_magic; VersionInfo expected_version_info; explicit VersionUtilReadVersionTestParam( - std::optional<VersionInfo> existing_version_info_in, + std::optional<VersionInfo> existing_v1_version_info_in, + std::optional<VersionInfo> existing_v2_version_info_in, std::optional<int> existing_flash_index_magic_in, VersionInfo expected_version_info_in) - : existing_version_info(std::move(existing_version_info_in)), + : existing_v1_version_info(std::move(existing_v1_version_info_in)), + existing_v2_version_info(std::move(existing_v2_version_info_in)), existing_flash_index_magic(std::move(existing_flash_index_magic_in)), expected_version_info(std::move(expected_version_info_in)) {} }; @@ -54,7 +78,6 @@ class VersionUtilReadVersionTest protected: void SetUp() override { base_dir_ = GetTestTempDir() + "/version_util_test"; - version_file_path_ = base_dir_ + "/version"; index_path_ = base_dir_ + "/index"; ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(base_dir_.c_str())); @@ -68,19 +91,28 @@ class VersionUtilReadVersionTest Filesystem filesystem_; std::string base_dir_; - std::string version_file_path_; std::string index_path_; }; TEST_P(VersionUtilReadVersionTest, ReadVersion) { const VersionUtilReadVersionTestParam& param = GetParam(); + IcingSearchEngineVersionProto dummy_version_proto; - // Prepare version file and flash index file. 
- if (param.existing_version_info.has_value()) { - ICING_ASSERT_OK(WriteVersion(filesystem_, version_file_path_, - param.existing_version_info.value())); + if (param.existing_v1_version_info.has_value()) { + ICING_ASSERT_OK(WriteV1Version(filesystem_, base_dir_, + param.existing_v1_version_info.value())); + } + if (param.existing_v2_version_info.has_value()) { + dummy_version_proto = MakeTestVersionProto( + param.existing_v2_version_info.value(), + {IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR, + IcingSearchEngineFeatureInfoProto::UNKNOWN}); + ICING_ASSERT_OK(WriteV2Version( + filesystem_, base_dir_, + std::make_unique<IcingSearchEngineVersionProto>(dummy_version_proto))); } + // Prepare flash index file. if (param.existing_flash_index_magic.has_value()) { HeaderBlock header_block(&filesystem_, /*block_size=*/4096); header_block.header()->magic = param.existing_flash_index_magic.value(); @@ -94,10 +126,22 @@ TEST_P(VersionUtilReadVersionTest, ReadVersion) { ASSERT_TRUE(header_block.Write(sfd.get())); } - ICING_ASSERT_OK_AND_ASSIGN( - VersionInfo version_info, - ReadVersion(filesystem_, version_file_path_, index_path_)); - EXPECT_THAT(version_info, Eq(param.expected_version_info)); + ICING_ASSERT_OK_AND_ASSIGN(IcingSearchEngineVersionProto version_proto, + ReadVersion(filesystem_, base_dir_, index_path_)); + if (param.existing_v2_version_info.has_value() && + param.expected_version_info.version == + param.existing_v2_version_info.value().version) { + EXPECT_THAT(version_proto, + portable_equals_proto::EqualsProto(dummy_version_proto)); + } else { + // We're returning the version from v1 version file, or an invalid version. + // version_proto.enabled_features should be empty in this case. 
+ EXPECT_THAT(version_proto.version(), + Eq(param.expected_version_info.version)); + EXPECT_THAT(version_proto.max_version(), + Eq(param.expected_version_info.max_version)); + EXPECT_THAT(version_proto.enabled_features(), IsEmpty()); + } } INSTANTIATE_TEST_SUITE_P( @@ -107,7 +151,8 @@ INSTANTIATE_TEST_SUITE_P( // - Flash index doesn't exist // - Result: version -1, max_version -1 (invalid) VersionUtilReadVersionTestParam( - /*existing_version_info_in=*/std::nullopt, + /*existing_v1_version_info_in=*/std::nullopt, + /*existing_v2_version_info_in=*/std::nullopt, /*existing_flash_index_magic_in=*/std::nullopt, /*expected_version_info_in=*/ VersionInfo(/*version_in=*/-1, /*max_version=*/-1)), @@ -116,7 +161,8 @@ INSTANTIATE_TEST_SUITE_P( // - Flash index exists with version 0 magic // - Result: version 0, max_version 0 VersionUtilReadVersionTestParam( - /*existing_version_info_in=*/std::nullopt, + /*existing_v1_version_info_in=*/std::nullopt, + /*existing_v2_version_info_in=*/std::nullopt, /*existing_flash_index_magic_in=*/ std::make_optional<int>(kVersionZeroFlashIndexMagic), /*expected_version_info_in=*/ @@ -126,65 +172,157 @@ INSTANTIATE_TEST_SUITE_P( // - Flash index exists with non version 0 magic // - Result: version -1, max_version -1 (invalid) VersionUtilReadVersionTestParam( - /*existing_version_info_in=*/std::nullopt, + /*existing_v1_version_info_in=*/std::nullopt, + /*existing_v2_version_info_in=*/std::nullopt, /*existing_flash_index_magic_in=*/ std::make_optional<int>(kVersionZeroFlashIndexMagic + 1), /*expected_version_info_in=*/ VersionInfo(/*version_in=*/-1, /*max_version=*/-1)), - // - Version file exists + // - Version file v1 exists // - Flash index doesn't exist // - Result: version -1, max_version 1 (invalid) VersionUtilReadVersionTestParam( - /*existing_version_info_in=*/std::make_optional<VersionInfo>( + /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>( /*version_in=*/1, /*max_version=*/1), + 
/*existing_v2_version_info_in=*/std::nullopt, /*existing_flash_index_magic_in=*/std::nullopt, /*expected_version_info_in=*/ VersionInfo(/*version_in=*/-1, /*max_version=*/1)), - // - Version file exists: version 1, max_version 1 + // - Version file v1 exists: version 1, max_version 1 // - Flash index exists with version 0 magic // - Result: version 0, max_version 1 VersionUtilReadVersionTestParam( - /*existing_version_info_in=*/std::make_optional<VersionInfo>( + /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>( /*version_in=*/1, /*max_version=*/1), + /*existing_v2_version_info_in=*/std::nullopt, /*existing_flash_index_magic_in=*/ std::make_optional<int>(kVersionZeroFlashIndexMagic), /*expected_version_info_in=*/ VersionInfo(/*version_in=*/0, /*max_version=*/1)), - // - Version file exists: version 2, max_version 3 + // - Version file v1 exists: version 2, max_version 3 // - Flash index exists with version 0 magic // - Result: version 0, max_version 3 VersionUtilReadVersionTestParam( - /*existing_version_info_in=*/std::make_optional<VersionInfo>( + /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>( /*version_in=*/2, /*max_version=*/3), + /*existing_v2_version_info_in=*/std::nullopt, /*existing_flash_index_magic_in=*/ std::make_optional<int>(kVersionZeroFlashIndexMagic), /*expected_version_info_in=*/ VersionInfo(/*version_in=*/0, /*max_version=*/3)), - // - Version file exists: version 1, max_version 1 + // - Version file v1 exists: version 1, max_version 1 // - Flash index exists with non version 0 magic // - Result: version 1, max_version 1 VersionUtilReadVersionTestParam( - /*existing_version_info_in=*/std::make_optional<VersionInfo>( + /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>( /*version_in=*/1, /*max_version=*/1), + /*existing_v2_version_info_in=*/std::nullopt, /*existing_flash_index_magic_in=*/ std::make_optional<int>(kVersionZeroFlashIndexMagic + 1), /*expected_version_info_in=*/ VersionInfo(/*version_in=*/1, 
/*max_version=*/1)), - // - Version file exists: version 2, max_version 3 + // - Version file v1 exists: version 2, max_version 3 // - Flash index exists with non version 0 magic // - Result: version 2, max_version 3 VersionUtilReadVersionTestParam( - /*existing_version_info_in=*/std::make_optional<VersionInfo>( + /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>( /*version_in=*/2, /*max_version=*/3), + /*existing_v2_version_info_in=*/std::nullopt, + /*existing_flash_index_magic_in=*/ + std::make_optional<int>(kVersionZeroFlashIndexMagic + 1), + /*expected_version_info_in=*/ + VersionInfo(/*version_in=*/2, /*max_version=*/3)), + + // - Version file v1 exists: version 2, max_version 4 + // - Version file v2 exists: version 4, max_version 4 + // - Flash index exists with non version 0 magic + // - Result: version -1, max_version 4 + VersionUtilReadVersionTestParam( + /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>( + /*version_in=*/2, /*max_version=*/4), + /*existing_v2_version_info_in=*/ + std::make_optional<VersionInfo>( + /*version_in=*/4, /*max_version=*/4), + /*existing_flash_index_magic_in=*/ + std::make_optional<int>(kVersionZeroFlashIndexMagic + 1), + /*expected_version_info_in=*/ + VersionInfo(/*version_in=*/-1, /*max_version=*/4)), + + // - Version file v1 exists: version 4, max_version 4 + // - Version file v2 exists: version 4, max_version 4 + // - Flash index exists with version 0 magic + // - Result: version -1, max_version 4 + VersionUtilReadVersionTestParam( + /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>( + /*version_in=*/4, /*max_version=*/4), + /*existing_v2_version_info_in=*/ + std::make_optional<VersionInfo>( + /*version_in=*/4, /*max_version=*/4), + /*existing_flash_index_magic_in=*/ + std::make_optional<int>(kVersionZeroFlashIndexMagic), + /*expected_version_info_in=*/ + VersionInfo(/*version_in=*/-1, /*max_version=*/4)), + + // - Version file v1 exists: version 4, max_version 4 + // - Version 
file v2 exists: version 4, max_version 4 + // - Flash index exists with non version 0 magic + // - Result: version 4, max_version 4 + VersionUtilReadVersionTestParam( + /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>( + /*version_in=*/4, /*max_version=*/4), + /*existing_v2_version_info_in=*/ + std::make_optional<VersionInfo>( + /*version_in=*/4, /*max_version=*/4), + /*existing_flash_index_magic_in=*/ + std::make_optional<int>(kVersionZeroFlashIndexMagic + 1), + /*expected_version_info_in=*/ + VersionInfo(/*version_in=*/4, /*max_version=*/4)), + + // - Version file v1 exists: version 4, max_version 4 + // - Version file v2 does not exist + // - Flash index exists with non version 0 magic + // - Result: version -1, max_version 4 + VersionUtilReadVersionTestParam( + /*existing_v1_version_info_in=*/std::make_optional<VersionInfo>( + /*version_in=*/4, /*max_version=*/4), + /*existing_v2_version_info_in=*/std::nullopt, /*existing_flash_index_magic_in=*/ std::make_optional<int>(kVersionZeroFlashIndexMagic + 1), /*expected_version_info_in=*/ - VersionInfo(/*version_in=*/2, /*max_version=*/3)))); + VersionInfo(/*version_in=*/-1, /*max_version=*/4)), + + // - Version file v1 does not exist + // - Version file v2 exists: version 4, max_version 4 + // - Flash index exists with non version 0 magic + // - Result: version -1, max_version -1 + VersionUtilReadVersionTestParam( + /*existing_v1_version_info_in=*/std::nullopt, + /*existing_v2_version_info_in=*/ + std::make_optional<VersionInfo>( + /*version_in=*/4, /*max_version=*/4), + /*existing_flash_index_magic_in=*/ + std::make_optional<int>(kVersionZeroFlashIndexMagic + 1), + /*expected_version_info_in=*/ + VersionInfo(/*version_in=*/-1, /*max_version=*/-1)), + + // - Version file v1 doesn't exist + // - Version file v2 exists: version 4, max_version 4 + // - Flash index doesn't exist + // - Result: version -1, max_version -1 (invalid since flash index + // doesn't exist) + VersionUtilReadVersionTestParam( + 
/*existing_v1_version_info_in=*/std::nullopt, + /*existing_v2_version_info_in=*/ + std::make_optional<VersionInfo>( + /*version_in=*/4, /*max_version=*/4), + /*existing_flash_index_magic_in=*/std::nullopt, + /*expected_version_info_in=*/ + VersionInfo(/*version_in=*/-1, /*max_version=*/-1)))); struct VersionUtilStateChangeTestParam { VersionInfo existing_version_info; @@ -389,6 +527,240 @@ INSTANTIATE_TEST_SUITE_P( /*curr_version_in=*/2, /*expected_state_change_in=*/StateChange::kRollBack))); +struct VersionUtilDerivedFilesRebuildTestParam { + int32_t existing_version; + int32_t max_version; + std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType> + existing_enabled_features; + int32_t curr_version; + std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType> + curr_enabled_features; + DerivedFilesRebuildResult expected_derived_files_rebuild_result; + + explicit VersionUtilDerivedFilesRebuildTestParam( + int32_t existing_version_in, int32_t max_version_in, + std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType> + existing_enabled_features_in, + int32_t curr_version_in, + std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType> + curr_enabled_features_in, + DerivedFilesRebuildResult expected_derived_files_rebuild_result_in) + : existing_version(existing_version_in), + max_version(max_version_in), + existing_enabled_features(std::move(existing_enabled_features_in)), + curr_version(curr_version_in), + curr_enabled_features(std::move(curr_enabled_features_in)), + expected_derived_files_rebuild_result( + std::move(expected_derived_files_rebuild_result_in)) {} +}; + +class VersionUtilDerivedFilesRebuildTest + : public ::testing::TestWithParam<VersionUtilDerivedFilesRebuildTestParam> { +}; + +TEST_P(VersionUtilDerivedFilesRebuildTest, + CalculateRequiredDerivedFilesRebuild) { + const VersionUtilDerivedFilesRebuildTestParam& param = GetParam(); + + EXPECT_THAT(CalculateRequiredDerivedFilesRebuild( + 
/*prev_version_proto=*/MakeTestVersionProto( + VersionInfo(param.existing_version, param.max_version), + param.existing_enabled_features), + /*curr_version_proto=*/ + MakeTestVersionProto( + VersionInfo(param.curr_version, param.max_version), + param.curr_enabled_features)), + Eq(param.expected_derived_files_rebuild_result)); +} + +INSTANTIATE_TEST_SUITE_P( + VersionUtilDerivedFilesRebuildTest, VersionUtilDerivedFilesRebuildTest, + testing::Values( + // - Existing version -1, max_version -1 (invalid) + // - Existing enabled features = {} + // - Current version = 4 + // - Current enabled features = {} + // + // - Result: rebuild everything + VersionUtilDerivedFilesRebuildTestParam( + /*existing_version_in=*/-1, /*max_version_in=*/-1, + /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4, + /*curr_enabled_features_in=*/{}, + /*expected_derived_files_rebuild_result_in=*/ + DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/true, + /*needs_schema_store_derived_files_rebuild=*/true, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/true, + /*needs_qualified_id_join_index_rebuild=*/true)), + + // - Existing version -1, max_version 2 (invalid) + // - Existing enabled features = {} + // - Current version = 4 + // - Current enabled features = {} + // + // - Result: rebuild everything + VersionUtilDerivedFilesRebuildTestParam( + /*existing_version_in=*/-1, /*max_version_in=*/-1, + /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4, + /*curr_enabled_features_in=*/{}, + /*expected_derived_files_rebuild_result_in=*/ + DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/true, + /*needs_schema_store_derived_files_rebuild=*/true, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/true, + /*needs_qualified_id_join_index_rebuild=*/true)), + + // - Existing version 3, max_version 3 (pre v2 version check) + // - Existing enabled features = {} + // - Current version = 4 + // - 
Current enabled features = {} + // + // - Result: don't rebuild anything + VersionUtilDerivedFilesRebuildTestParam( + /*existing_version_in=*/3, /*max_version_in=*/3, + /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4, + /*curr_enabled_features_in=*/{}, + /*expected_derived_files_rebuild_result_in=*/ + DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/false, + /*needs_schema_store_derived_files_rebuild=*/false, + /*needs_term_index_rebuild=*/false, + /*needs_integer_index_rebuild=*/false, + /*needs_qualified_id_join_index_rebuild=*/false)), + + // - Existing version 3, max_version 3 (pre v2 version check) + // - Existing enabled features = {} + // - Current version = 4 + // - Current enabled features = {FEATURE_HAS_PROPERTY_OPERATOR} + // + // - Result: rebuild term index + VersionUtilDerivedFilesRebuildTestParam( + /*existing_version_in=*/3, /*max_version_in=*/3, + /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4, + /*curr_enabled_features_in=*/ + {IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR}, + /*expected_derived_files_rebuild_result_in=*/ + DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/false, + /*needs_schema_store_derived_files_rebuild=*/false, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/false, + /*needs_qualified_id_join_index_rebuild=*/false)), + + // - Existing version 4, max_version 4 + // - Existing enabled features = {} + // - Current version = 4 + // - Current enabled features = {} + // + // - Result: don't rebuild anything + VersionUtilDerivedFilesRebuildTestParam( + /*existing_version_in=*/4, /*max_version_in=*/4, + /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4, + /*curr_enabled_features_in=*/{}, + /*expected_derived_files_rebuild_result_in=*/ + DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/false, + /*needs_schema_store_derived_files_rebuild=*/false, + /*needs_term_index_rebuild=*/false, 
+ /*needs_integer_index_rebuild=*/false, + /*needs_qualified_id_join_index_rebuild=*/false)), + + // - Existing version 4, max_version 5 + // - Existing enabled features = {} + // - Current version = 5 + // - Current enabled features = {} + // + // - Result: Rollforward -- rebuild everything + VersionUtilDerivedFilesRebuildTestParam( + /*existing_version_in=*/4, /*max_version_in=*/5, + /*existing_enabled_features_in=*/{}, /*curr_version_in=*/5, + /*curr_enabled_features_in=*/{}, + /*expected_derived_files_rebuild_result_in=*/ + DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/true, + /*needs_schema_store_derived_files_rebuild=*/true, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/true, + /*needs_qualified_id_join_index_rebuild=*/true)), + + // - Existing version 5, max_version 5 + // - Existing enabled features = {} + // - Current version = 4 + // - Current enabled features = {} + // + // - Result: Rollback -- rebuild everything + VersionUtilDerivedFilesRebuildTestParam( + /*existing_version_in=*/5, /*max_version_in=*/5, + /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4, + /*curr_enabled_features_in=*/{}, + /*expected_derived_files_rebuild_result_in=*/ + DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/true, + /*needs_schema_store_derived_files_rebuild=*/true, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/true, + /*needs_qualified_id_join_index_rebuild=*/true)), + + // - Existing version 4, max_version 4 + // - Existing enabled features = {} + // - Current version = 4 + // - Current enabled features = {FEATURE_HAS_PROPERTY_OPERATOR} + // + // - Result: rebuild term index + VersionUtilDerivedFilesRebuildTestParam( + /*existing_version_in=*/4, /*max_version_in=*/4, + /*existing_enabled_features_in=*/{}, /*curr_version_in=*/4, + /*curr_enabled_features_in=*/ + {IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR}, + 
/*expected_derived_files_rebuild_result_in=*/ + DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/false, + /*needs_schema_store_derived_files_rebuild=*/false, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/false, + /*needs_qualified_id_join_index_rebuild=*/false)), + + // - Existing version 4, max_version 4 + // - Existing enabled features = {FEATURE_HAS_PROPERTY_OPERATOR} + // - Current version = 4 + // - Current enabled features = {} + // + // - Result: rebuild term index + VersionUtilDerivedFilesRebuildTestParam( + /*existing_version_in=*/4, /*max_version_in=*/4, + /*existing_enabled_features_in=*/ + {IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR}, + /*curr_version_in=*/4, /*curr_enabled_features_in=*/{}, + /*expected_derived_files_rebuild_result_in=*/ + DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/false, + /*needs_schema_store_derived_files_rebuild=*/false, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/false, + /*needs_qualified_id_join_index_rebuild=*/false)), + + // - Existing version 4, max_version 4 + // - Existing enabled features = {UNKNOWN} + // - Current version = 4 + // - Current enabled features = {FEATURE_HAS_PROPERTY_OPERATOR} + // + // - Result: rebuild everything + VersionUtilDerivedFilesRebuildTestParam( + /*existing_version_in=*/4, /*max_version_in=*/4, + /*existing_enabled_features_in=*/ + {IcingSearchEngineFeatureInfoProto::UNKNOWN}, /*curr_version_in=*/4, + /*curr_enabled_features_in=*/ + {IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR}, + /*expected_derived_files_rebuild_result_in=*/ + DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/true, + /*needs_schema_store_derived_files_rebuild=*/true, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/true, + /*needs_qualified_id_join_index_rebuild=*/true)))); + TEST(VersionUtilTest, 
ShouldRebuildDerivedFilesUndeterminedVersion) { EXPECT_THAT( ShouldRebuildDerivedFiles(VersionInfo(-1, -1), /*curr_version=*/1), @@ -475,8 +847,80 @@ TEST(VersionUtilTest, Upgrade) { EXPECT_THAT(ShouldRebuildDerivedFiles(VersionInfo(kVersionOne, kVersionOne), /*curr_version=*/kVersionThree), IsFalse()); + + // kVersionThree -> kVersionFour. + EXPECT_THAT( + ShouldRebuildDerivedFiles(VersionInfo(kVersionThree, kVersionThree), + /*curr_version=*/kVersionFour), + IsFalse()); + + // kVersionTwo -> kVersionFour + EXPECT_THAT(ShouldRebuildDerivedFiles(VersionInfo(kVersionTwo, kVersionTwo), + /*curr_version=*/kVersionFour), + IsFalse()); + + // kVersionOne -> kVersionFour. + EXPECT_THAT(ShouldRebuildDerivedFiles(VersionInfo(kVersionOne, kVersionOne), + /*curr_version=*/kVersionFour), + IsFalse()); +} + +TEST(VersionUtilTest, GetFeatureDerivedFilesRebuildResult_unknown) { + EXPECT_THAT(GetFeatureDerivedFilesRebuildResult( + IcingSearchEngineFeatureInfoProto::UNKNOWN), + Eq(DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/true, + /*needs_schema_store_derived_files_rebuild=*/true, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/true, + /*needs_qualified_id_join_index_rebuild=*/true))); } +TEST(VersionUtilTest, + GetFeatureDerivedFilesRebuildResult_featureHasPropertyOperator) { + EXPECT_THAT( + GetFeatureDerivedFilesRebuildResult( + IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR), + Eq(DerivedFilesRebuildResult( + /*needs_document_store_derived_files_rebuild=*/false, + /*needs_schema_store_derived_files_rebuild=*/false, + /*needs_term_index_rebuild=*/true, + /*needs_integer_index_rebuild=*/false, + /*needs_qualified_id_join_index_rebuild=*/false))); +} + +class VersionUtilFeatureProtoTest + : public ::testing::TestWithParam< + IcingSearchEngineFeatureInfoProto::FlaggedFeatureType> {}; + +TEST_P(VersionUtilFeatureProtoTest, GetFeatureInfoProto) { + IcingSearchEngineFeatureInfoProto::FlaggedFeatureType 
feature_type = + GetParam(); + DerivedFilesRebuildResult rebuild_result = + GetFeatureDerivedFilesRebuildResult(feature_type); + + IcingSearchEngineFeatureInfoProto feature_info = + GetFeatureInfoProto(feature_type); + EXPECT_THAT(feature_info.feature_type(), Eq(feature_type)); + + EXPECT_THAT(feature_info.needs_document_store_rebuild(), + Eq(rebuild_result.needs_document_store_derived_files_rebuild)); + EXPECT_THAT(feature_info.needs_schema_store_rebuild(), + Eq(rebuild_result.needs_schema_store_derived_files_rebuild)); + EXPECT_THAT(feature_info.needs_term_index_rebuild(), + Eq(rebuild_result.needs_term_index_rebuild)); + EXPECT_THAT(feature_info.needs_integer_index_rebuild(), + Eq(rebuild_result.needs_integer_index_rebuild)); + EXPECT_THAT(feature_info.needs_qualified_id_join_index_rebuild(), + Eq(rebuild_result.needs_qualified_id_join_index_rebuild)); +} + +INSTANTIATE_TEST_SUITE_P( + VersionUtilFeatureProtoTest, VersionUtilFeatureProtoTest, + testing::Values( + IcingSearchEngineFeatureInfoProto::UNKNOWN, + IcingSearchEngineFeatureInfoProto::FEATURE_HAS_PROPERTY_OPERATOR)); + } // namespace } // namespace version_util diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc index 72be4e9..89caaf1 100644 --- a/icing/icing-search-engine.cc +++ b/icing/icing-search-engine.cc @@ -14,7 +14,10 @@ #include "icing/icing-search-engine.h" +#include <algorithm> +#include <cstddef> #include <cstdint> +#include <functional> #include <memory> #include <string> #include <string_view> @@ -34,6 +37,8 @@ #include "icing/file/filesystem.h" #include "icing/file/version-util.h" #include "icing/index/data-indexing-handler.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/embedding-indexing-handler.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/index-processor.h" #include "icing/index/index.h" @@ -41,12 +46,16 @@ #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/index/numeric/integer-index.h" #include 
"icing/index/term-indexing-handler.h" +#include "icing/index/term-metadata.h" +#include "icing/jni/jni-cache.h" +#include "icing/join/join-children-fetcher.h" #include "icing/join/join-processor.h" #include "icing/join/qualified-id-join-index-impl-v1.h" #include "icing/join/qualified-id-join-index-impl-v2.h" #include "icing/join/qualified-id-join-index.h" #include "icing/join/qualified-id-join-indexing-handler.h" #include "icing/legacy/index/icing-filesystem.h" +#include "icing/performance-configuration.h" #include "icing/portable/endian.h" #include "icing/proto/debug.pb.h" #include "icing/proto/document.pb.h" @@ -73,9 +82,8 @@ #include "icing/result/projector.h" #include "icing/result/result-adjustment-info.h" #include "icing/result/result-retriever-v2.h" +#include "icing/result/result-state-manager.h" #include "icing/schema/schema-store.h" -#include "icing/schema/schema-util.h" -#include "icing/schema/section.h" #include "icing/scoring/advanced_scoring/score-expression.h" #include "icing/scoring/priority-queue-scored-document-hits-ranker.h" #include "icing/scoring/scored-document-hit.h" @@ -84,11 +92,8 @@ #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/tokenization/language-segmenter-factory.h" -#include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer-factory.h" -#include "icing/transform/normalizer.h" #include "icing/util/clock.h" -#include "icing/util/crc32.h" #include "icing/util/data-loss.h" #include "icing/util/logging.h" #include "icing/util/status-macros.h" @@ -100,12 +105,12 @@ namespace lib { namespace { -constexpr std::string_view kVersionFilename = "version"; constexpr std::string_view kDocumentSubfolderName = "document_dir"; constexpr std::string_view kIndexSubfolderName = "index_dir"; constexpr std::string_view kIntegerIndexSubfolderName = "integer_index_dir"; constexpr std::string_view kQualifiedIdJoinIndexSubfolderName = "qualified_id_join_index_dir"; +constexpr 
std::string_view kEmbeddingIndexSubfolderName = "embedding_index_dir"; constexpr std::string_view kSchemaSubfolderName = "schema_dir"; constexpr std::string_view kSetSchemaMarkerFilename = "set_schema_marker"; constexpr std::string_view kInitMarkerFilename = "init_marker"; @@ -253,12 +258,6 @@ CreateQualifiedIdJoinIndex(const Filesystem& filesystem, } } -// Version file is a single file under base_dir containing version info of the -// existing data. -std::string MakeVersionFilePath(const std::string& base_dir) { - return absl_ports::StrCat(base_dir, "/", kVersionFilename); -} - // Document store files are in a standalone subfolder for easier file // management. We can delete and recreate the subfolder and not touch/affect // anything else. @@ -296,6 +295,11 @@ std::string MakeQualifiedIdJoinIndexWorkingPath(const std::string& base_dir) { return absl_ports::StrCat(base_dir, "/", kQualifiedIdJoinIndexSubfolderName); } +// Working path for embedding index. +std::string MakeEmbeddingIndexWorkingPath(const std::string& base_dir) { + return absl_ports::StrCat(base_dir, "/", kEmbeddingIndexSubfolderName); +} + // SchemaStore files are in a standalone subfolder for easier file management. // We can delete and recreate the subfolder and not touch/affect anything // else. @@ -482,6 +486,7 @@ void IcingSearchEngine::ResetMembers() { index_.reset(); integer_index_.reset(); qualified_id_join_index_.reset(); + embedding_index_.reset(); } libtextclassifier3::Status IcingSearchEngine::CheckInitMarkerFile( @@ -604,32 +609,49 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( return status; } - // Read version file and determine the state change. - const std::string version_filepath = MakeVersionFilePath(options_.base_dir()); + // Do version and flags compatibility check + // Read version file, determine the state change and rebuild derived files if + // needed. 
const std::string index_dir = MakeIndexDirectoryPath(options_.base_dir()); ICING_ASSIGN_OR_RETURN( - version_util::VersionInfo version_info, - version_util::ReadVersion(*filesystem_, version_filepath, index_dir)); + IcingSearchEngineVersionProto stored_version_proto, + version_util::ReadVersion( + *filesystem_, /*version_file_dir=*/options_.base_dir(), index_dir)); + version_util::VersionInfo stored_version_info = + version_util::GetVersionInfoFromProto(stored_version_proto); version_util::StateChange version_state_change = - version_util::GetVersionStateChange(version_info); + version_util::GetVersionStateChange(stored_version_info); + + // Construct icing's current version proto based on the current code version + IcingSearchEngineVersionProto current_version_proto; + current_version_proto.set_version(version_util::kVersion); + current_version_proto.set_max_version( + std::max(stored_version_info.max_version, version_util::kVersion)); + version_util::AddEnabledFeatures(options_, ¤t_version_proto); + + // Step 1: If versions are incompatible, migrate schema according to the + // version state change. if (version_state_change != version_util::StateChange::kCompatible) { - // Step 1: migrate schema according to the version state change. ICING_RETURN_IF_ERROR(SchemaStore::MigrateSchema( filesystem_.get(), MakeSchemaDirectoryPath(options_.base_dir()), version_state_change, version_util::kVersion)); + } - // Step 2: discard all derived data if needed rebuild. 
- if (version_util::ShouldRebuildDerivedFiles(version_info)) { - ICING_RETURN_IF_ERROR(DiscardDerivedFiles()); - } + // Step 2: Discard derived files that need to be rebuilt + version_util::DerivedFilesRebuildResult required_derived_files_rebuild = + version_util::CalculateRequiredDerivedFilesRebuild(stored_version_proto, + current_version_proto); + ICING_RETURN_IF_ERROR(DiscardDerivedFiles(required_derived_files_rebuild)); - // Step 3: update version file - version_util::VersionInfo new_version_info( - version_util::kVersion, - std::max(version_info.max_version, version_util::kVersion)); - ICING_RETURN_IF_ERROR(version_util::WriteVersion( - *filesystem_, version_filepath, new_version_info)); - } + // Step 3: update version files. We need to update both the V1 and V2 + // version files. + ICING_RETURN_IF_ERROR(version_util::WriteV1Version( + *filesystem_, /*version_file_dir=*/options_.base_dir(), + version_util::GetVersionInfoFromProto(current_version_proto))); + ICING_RETURN_IF_ERROR(version_util::WriteV2Version( + *filesystem_, /*version_file_dir=*/options_.base_dir(), + std::make_unique<IcingSearchEngineVersionProto>( + std::move(current_version_proto)))); ICING_RETURN_IF_ERROR(InitializeSchemaStore(initialize_stats)); @@ -655,15 +677,19 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( MakeIntegerIndexWorkingPath(options_.base_dir()); const std::string qualified_id_join_index_dir = MakeQualifiedIdJoinIndexWorkingPath(options_.base_dir()); + const std::string embedding_index_dir = + MakeEmbeddingIndexWorkingPath(options_.base_dir()); if (!filesystem_->DeleteDirectoryRecursively(doc_store_dir.c_str()) || !filesystem_->DeleteDirectoryRecursively(index_dir.c_str()) || !IntegerIndex::Discard(*filesystem_, integer_index_dir).ok() || !QualifiedIdJoinIndex::Discard(*filesystem_, qualified_id_join_index_dir) - .ok()) { + .ok() || + !EmbeddingIndex::Discard(*filesystem_, embedding_index_dir).ok()) { return absl_ports::InternalError(absl_ports::StrCat( 
"Could not delete directories: ", index_dir, ", ", integer_index_dir, - ", ", qualified_id_join_index_dir, " and ", doc_store_dir)); + ", ", qualified_id_join_index_dir, ", ", embedding_index_dir, " and ", + doc_store_dir)); } ICING_ASSIGN_OR_RETURN( bool document_store_derived_files_regenerated, @@ -688,10 +714,9 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( // We're going to need to build the index from scratch. So just delete its // directory now. // Discard index directory and instantiate a new one. - Index::Options index_options( - index_dir, options_.index_merge_size(), - options_.lite_index_sort_at_indexing(), options_.lite_index_sort_size(), - options_.build_property_existence_metadata_hits()); + Index::Options index_options(index_dir, options_.index_merge_size(), + options_.lite_index_sort_at_indexing(), + options_.lite_index_sort_size()); if (!filesystem_->DeleteDirectoryRecursively(index_dir.c_str()) || !filesystem_->CreateDirectoryRecursively(index_dir.c_str())) { return absl_ports::InternalError( @@ -722,6 +747,15 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( CreateQualifiedIdJoinIndex( *filesystem_, std::move(qualified_id_join_index_dir), options_)); + // Discard embedding index directory and instantiate a new one. 
+ std::string embedding_index_dir = + MakeEmbeddingIndexWorkingPath(options_.base_dir()); + ICING_RETURN_IF_ERROR( + EmbeddingIndex::Discard(*filesystem_, embedding_index_dir)); + ICING_ASSIGN_OR_RETURN( + embedding_index_, + EmbeddingIndex::Create(filesystem_.get(), embedding_index_dir)); + std::unique_ptr<Timer> restore_timer = clock_->GetNewTimer(); IndexRestorationResult restore_result = RestoreIndexIfNeeded(); index_init_status = std::move(restore_result.status); @@ -744,6 +778,8 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC); initialize_stats->set_qualified_id_join_index_restoration_cause( InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC); + initialize_stats->set_embedding_index_restoration_cause( + InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC); } else if (version_state_change != version_util::StateChange::kCompatible) { ICING_ASSIGN_OR_RETURN(bool document_store_derived_files_regenerated, InitializeDocumentStore( @@ -765,6 +801,8 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( InitializeStatsProto::VERSION_CHANGED); initialize_stats->set_qualified_id_join_index_restoration_cause( InitializeStatsProto::VERSION_CHANGED); + initialize_stats->set_embedding_index_restoration_cause( + InitializeStatsProto::VERSION_CHANGED); } else { ICING_ASSIGN_OR_RETURN( bool document_store_derived_files_regenerated, @@ -776,6 +814,32 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( if (!index_init_status.ok() && !absl_ports::IsDataLoss(index_init_status)) { return index_init_status; } + + // Set recovery cause to FEATURE_FLAG_CHANGED according to the calculated + // required_derived_files_rebuild + if (required_derived_files_rebuild + .needs_document_store_derived_files_rebuild) { + initialize_stats->set_document_store_recovery_cause( + InitializeStatsProto::FEATURE_FLAG_CHANGED); + } + if (required_derived_files_rebuild + .needs_schema_store_derived_files_rebuild) 
{ + initialize_stats->set_schema_store_recovery_cause( + InitializeStatsProto::FEATURE_FLAG_CHANGED); + } + if (required_derived_files_rebuild.needs_term_index_rebuild) { + initialize_stats->set_index_restoration_cause( + InitializeStatsProto::FEATURE_FLAG_CHANGED); + } + if (required_derived_files_rebuild.needs_integer_index_rebuild) { + initialize_stats->set_integer_index_restoration_cause( + InitializeStatsProto::FEATURE_FLAG_CHANGED); + } + if (required_derived_files_rebuild.needs_qualified_id_join_index_rebuild) { + initialize_stats->set_qualified_id_join_index_restoration_cause( + InitializeStatsProto::FEATURE_FLAG_CHANGED); + } + // TODO(b/326656531): Update version-util to consider embedding index. } if (status.ok()) { @@ -842,10 +906,9 @@ libtextclassifier3::Status IcingSearchEngine::InitializeIndex( return absl_ports::InternalError( absl_ports::StrCat("Could not create directory: ", index_dir)); } - Index::Options index_options( - index_dir, options_.index_merge_size(), - options_.lite_index_sort_at_indexing(), options_.lite_index_sort_size(), - options_.build_property_existence_metadata_hits()); + Index::Options index_options(index_dir, options_.index_merge_size(), + options_.lite_index_sort_at_indexing(), + options_.lite_index_sort_size()); // Term index InitializeStatsProto::RecoveryCause index_recovery_cause; @@ -945,6 +1008,30 @@ libtextclassifier3::Status IcingSearchEngine::InitializeIndex( } } + // Embedding index + const std::string embedding_dir = + MakeEmbeddingIndexWorkingPath(options_.base_dir()); + InitializeStatsProto::RecoveryCause embedding_index_recovery_cause; + auto embedding_index_or = + EmbeddingIndex::Create(filesystem_.get(), embedding_dir); + if (!embedding_index_or.ok()) { + ICING_RETURN_IF_ERROR(EmbeddingIndex::Discard(*filesystem_, embedding_dir)); + + embedding_index_recovery_cause = InitializeStatsProto::IO_ERROR; + + // Try recreating it from scratch and re-indexing everything. 
+ ICING_ASSIGN_OR_RETURN( + embedding_index_, + EmbeddingIndex::Create(filesystem_.get(), embedding_dir)); + } else { + // Embedding index was created fine. + embedding_index_ = std::move(embedding_index_or).ValueOrDie(); + // If a recover does have to happen, then it must be because the index is + // out of sync with the document store. + embedding_index_recovery_cause = + InitializeStatsProto::INCONSISTENT_WITH_GROUND_TRUTH; + } + std::unique_ptr<Timer> restore_timer = clock_->GetNewTimer(); IndexRestorationResult restore_result = RestoreIndexIfNeeded(); if (restore_result.index_needed_restoration || @@ -964,6 +1051,10 @@ libtextclassifier3::Status IcingSearchEngine::InitializeIndex( initialize_stats->set_qualified_id_join_index_restoration_cause( qualified_id_join_index_recovery_cause); } + if (restore_result.embedding_index_needed_restoration) { + initialize_stats->set_embedding_index_restoration_cause( + embedding_index_recovery_cause); + } } return restore_result.status; } @@ -1479,8 +1570,9 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( std::unique_ptr<Timer> component_timer = clock_->GetNewTimer(); // Gets unordered results from query processor auto query_processor_or = QueryProcessor::Create( - index_.get(), integer_index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), schema_store_.get()); + index_.get(), integer_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), + schema_store_.get(), clock_.get()); if (!query_processor_or.ok()) { TransformStatus(query_processor_or.status(), result_status); delete_stats->set_parse_query_latency_ms( @@ -1676,6 +1768,15 @@ OptimizeResultProto IcingSearchEngine::Optimize() { << qualified_id_join_index_optimize_status.error_message(); should_rebuild_index = true; } + + libtextclassifier3::Status embedding_index_optimize_status = + embedding_index_->Optimize(optimize_result.document_id_old_to_new, + 
document_store_->last_added_document_id()); + if (!embedding_index_optimize_status.ok()) { + ICING_LOG(WARNING) << "Failed to optimize embedding index. Error: " + << embedding_index_optimize_status.error_message(); + should_rebuild_index = true; + } } // If we received a DATA_LOSS error from OptimizeDocumentStore, we have a // valid document store, but it might be the old one or the new one. So throw @@ -1902,6 +2003,7 @@ libtextclassifier3::Status IcingSearchEngine::InternalPersistToDisk( ICING_RETURN_IF_ERROR(index_->PersistToDisk()); ICING_RETURN_IF_ERROR(integer_index_->PersistToDisk()); ICING_RETURN_IF_ERROR(qualified_id_join_index_->PersistToDisk()); + ICING_RETURN_IF_ERROR(embedding_index_->PersistToDisk()); return libtextclassifier3::Status::OK; } @@ -1979,6 +2081,7 @@ SearchResultProto IcingSearchEngine::InternalSearch( result_status->set_message("IcingSearchEngine has not been initialized!"); return result_proto; } + index_->PublishQueryStats(query_stats); libtextclassifier3::Status status = ValidateResultSpec(document_store_.get(), result_spec); @@ -2182,8 +2285,9 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore( // Gets unordered results from query processor auto query_processor_or = QueryProcessor::Create( - index_.get(), integer_index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), schema_store_.get()); + index_.get(), integer_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), + schema_store_.get(), clock_.get()); if (!query_processor_or.ok()) { search_stats->set_parse_query_latency_ms( component_timer->GetElapsedMilliseconds()); @@ -2198,7 +2302,8 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore( libtextclassifier3::StatusOr<QueryResults> query_results_or; if (ranking_strategy_or.ok()) { query_results_or = query_processor->ParseSearch( - search_spec, ranking_strategy_or.ValueOrDie(), 
current_time_ms); + search_spec, ranking_strategy_or.ValueOrDie(), current_time_ms, + search_stats); } else { query_results_or = ranking_strategy_or.status(); } @@ -2226,8 +2331,10 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore( // Scores but does not rank the results. libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> scoring_processor_or = ScoringProcessor::Create( - scoring_spec, document_store_.get(), schema_store_.get(), - current_time_ms, join_children_fetcher); + scoring_spec, /*default_semantic_metric_type=*/ + search_spec.embedding_query_metric_type(), document_store_.get(), + schema_store_.get(), current_time_ms, join_children_fetcher, + &query_results.embedding_query_results); if (!scoring_processor_or.ok()) { return QueryScoringResults(std::move(scoring_processor_or).status(), std::move(query_results.query_terms), @@ -2464,27 +2571,28 @@ IcingSearchEngine::RestoreIndexIfNeeded() { if (last_stored_document_id == index_->last_added_document_id() && last_stored_document_id == integer_index_->last_added_document_id() && last_stored_document_id == - qualified_id_join_index_->last_added_document_id()) { + qualified_id_join_index_->last_added_document_id() && + last_stored_document_id == embedding_index_->last_added_document_id()) { // No need to recover. - return {libtextclassifier3::Status::OK, false, false, false}; + return {libtextclassifier3::Status::OK, false, false, false, false}; } if (last_stored_document_id == kInvalidDocumentId) { // Document store is empty but index is not. Clear the index. - return {ClearAllIndices(), false, false, false}; + return {ClearAllIndices(), false, false, false, false}; } // Truncate indices first. 
auto truncate_result_or = TruncateIndicesTo(last_stored_document_id); if (!truncate_result_or.ok()) { - return {std::move(truncate_result_or).status(), false, false, false}; + return {std::move(truncate_result_or).status(), false, false, false, false}; } TruncateIndexResult truncate_result = std::move(truncate_result_or).ValueOrDie(); if (truncate_result.first_document_to_reindex > last_stored_document_id) { // Nothing to restore. Just return. - return {libtextclassifier3::Status::OK, false, false, false}; + return {libtextclassifier3::Status::OK, false, false, false, false}; } auto data_indexing_handlers_or = CreateDataIndexingHandlers(); @@ -2492,7 +2600,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { return {data_indexing_handlers_or.status(), truncate_result.index_needed_restoration, truncate_result.integer_index_needed_restoration, - truncate_result.qualified_id_join_index_needed_restoration}; + truncate_result.qualified_id_join_index_needed_restoration, + truncate_result.embedding_index_needed_restoration}; } // By using recovery_mode for IndexProcessor, we're able to replay documents // from smaller document id and it will skip documents that are already been @@ -2519,7 +2628,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { // Returns other errors return {document_or.status(), truncate_result.index_needed_restoration, truncate_result.integer_index_needed_restoration, - truncate_result.qualified_id_join_index_needed_restoration}; + truncate_result.qualified_id_join_index_needed_restoration, + truncate_result.embedding_index_needed_restoration}; } } DocumentProto document(std::move(document_or).ValueOrDie()); @@ -2532,7 +2642,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { return {tokenized_document_or.status(), truncate_result.index_needed_restoration, truncate_result.integer_index_needed_restoration, - truncate_result.qualified_id_join_index_needed_restoration}; + truncate_result.qualified_id_join_index_needed_restoration, + 
truncate_result.embedding_index_needed_restoration}; } TokenizedDocument tokenized_document( std::move(tokenized_document_or).ValueOrDie()); @@ -2544,7 +2655,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { // Real error. Stop recovering and pass it up. return {status, truncate_result.index_needed_restoration, truncate_result.integer_index_needed_restoration, - truncate_result.qualified_id_join_index_needed_restoration}; + truncate_result.qualified_id_join_index_needed_restoration, + truncate_result.embedding_index_needed_restoration}; } // FIXME: why can we skip data loss error here? // Just a data loss. Keep trying to add the remaining docs, but report the @@ -2555,7 +2667,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { return {overall_status, truncate_result.index_needed_restoration, truncate_result.integer_index_needed_restoration, - truncate_result.qualified_id_join_index_needed_restoration}; + truncate_result.qualified_id_join_index_needed_restoration, + truncate_result.embedding_index_needed_restoration}; } libtextclassifier3::StatusOr<bool> IcingSearchEngine::LostPreviousSchema() { @@ -2608,6 +2721,11 @@ IcingSearchEngine::CreateDataIndexingHandlers() { clock_.get(), document_store_.get(), qualified_id_join_index_.get())); handlers.push_back(std::move(qualified_id_join_indexing_handler)); + // Embedding index handler + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<EmbeddingIndexingHandler> embedding_indexing_handler, + EmbeddingIndexingHandler::Create(clock_.get(), embedding_index_.get())); + handlers.push_back(std::move(embedding_indexing_handler)); return handlers; } @@ -2696,53 +2814,104 @@ IcingSearchEngine::TruncateIndicesTo(DocumentId last_stored_document_id) { first_document_to_reindex = kMinDocumentId; } + // Attempt to truncate embedding index + bool embedding_index_needed_restoration = false; + DocumentId embedding_index_last_added_document_id = + embedding_index_->last_added_document_id(); + if (embedding_index_last_added_document_id == 
kInvalidDocumentId || + last_stored_document_id > embedding_index_last_added_document_id) { + // If last_stored_document_id is greater than + // embedding_index_last_added_document_id, then we only have to replay docs + // starting from (embedding_index_last_added_document_id + 1). Also use + // std::min since we might need to replay even smaller doc ids for other + // components. + embedding_index_needed_restoration = true; + if (embedding_index_last_added_document_id != kInvalidDocumentId) { + first_document_to_reindex = + std::min(first_document_to_reindex, + embedding_index_last_added_document_id + 1); + } else { + first_document_to_reindex = kMinDocumentId; + } + } else if (last_stored_document_id < embedding_index_last_added_document_id) { + // Clear the entire embedding index if last_stored_document_id is + // smaller than embedding_index_last_added_document_id, because + // there is no way to remove data with doc_id > last_stored_document_id from + // embedding index efficiently and we have to rebuild. + ICING_RETURN_IF_ERROR(embedding_index_->Clear()); + + // Since the entire embedding index is discarded, we start to + // rebuild it by setting first_document_to_reindex to kMinDocumentId. 
+ embedding_index_needed_restoration = true; + first_document_to_reindex = kMinDocumentId; + } + return TruncateIndexResult(first_document_to_reindex, index_needed_restoration, integer_index_needed_restoration, - qualified_id_join_index_needed_restoration); + qualified_id_join_index_needed_restoration, + embedding_index_needed_restoration); } -libtextclassifier3::Status IcingSearchEngine::DiscardDerivedFiles() { +libtextclassifier3::Status IcingSearchEngine::DiscardDerivedFiles( + const version_util::DerivedFilesRebuildResult& rebuild_result) { + if (!rebuild_result.IsRebuildNeeded()) { + return libtextclassifier3::Status::OK; + } + if (schema_store_ != nullptr || document_store_ != nullptr || index_ != nullptr || integer_index_ != nullptr || - qualified_id_join_index_ != nullptr) { + qualified_id_join_index_ != nullptr || embedding_index_ != nullptr) { return absl_ports::FailedPreconditionError( "Cannot discard derived files while having valid instances"); } // Schema store - ICING_RETURN_IF_ERROR( - SchemaStore::DiscardDerivedFiles(filesystem_.get(), options_.base_dir())); + if (rebuild_result.needs_schema_store_derived_files_rebuild) { + ICING_RETURN_IF_ERROR(SchemaStore::DiscardDerivedFiles( + filesystem_.get(), options_.base_dir())); + } // Document store - ICING_RETURN_IF_ERROR(DocumentStore::DiscardDerivedFiles( - filesystem_.get(), options_.base_dir())); + if (rebuild_result.needs_document_store_derived_files_rebuild) { + ICING_RETURN_IF_ERROR(DocumentStore::DiscardDerivedFiles( + filesystem_.get(), options_.base_dir())); + } // Term index - if (!filesystem_->DeleteDirectoryRecursively( - MakeIndexDirectoryPath(options_.base_dir()).c_str())) { - return absl_ports::InternalError("Failed to discard index"); + if (rebuild_result.needs_term_index_rebuild) { + if (!filesystem_->DeleteDirectoryRecursively( + MakeIndexDirectoryPath(options_.base_dir()).c_str())) { + return absl_ports::InternalError("Failed to discard index"); + } } // Integer index - if 
(!filesystem_->DeleteDirectoryRecursively( - MakeIntegerIndexWorkingPath(options_.base_dir()).c_str())) { - return absl_ports::InternalError("Failed to discard integer index"); + if (rebuild_result.needs_integer_index_rebuild) { + if (!filesystem_->DeleteDirectoryRecursively( + MakeIntegerIndexWorkingPath(options_.base_dir()).c_str())) { + return absl_ports::InternalError("Failed to discard integer index"); + } } // Qualified id join index - if (!filesystem_->DeleteDirectoryRecursively( - MakeQualifiedIdJoinIndexWorkingPath(options_.base_dir()).c_str())) { - return absl_ports::InternalError( - "Failed to discard qualified id join index"); + if (rebuild_result.needs_qualified_id_join_index_rebuild) { + if (!filesystem_->DeleteDirectoryRecursively( + MakeQualifiedIdJoinIndexWorkingPath(options_.base_dir()).c_str())) { + return absl_ports::InternalError( + "Failed to discard qualified id join index"); + } } + // TODO(b/326656531): Update version-util to consider embedding index. + return libtextclassifier3::Status::OK; } libtextclassifier3::Status IcingSearchEngine::ClearSearchIndices() { ICING_RETURN_IF_ERROR(index_->Reset()); ICING_RETURN_IF_ERROR(integer_index_->Clear()); + ICING_RETURN_IF_ERROR(embedding_index_->Clear()); return libtextclassifier3::Status::OK; } @@ -2815,8 +2984,9 @@ SuggestionResponse IcingSearchEngine::SearchSuggestions( // Create the suggestion processor. 
auto suggestion_processor_or = SuggestionProcessor::Create( - index_.get(), integer_index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), schema_store_.get()); + index_.get(), integer_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), + schema_store_.get(), clock_.get()); if (!suggestion_processor_or.ok()) { TransformStatus(suggestion_processor_or.status(), response_status); return response; diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h index d316350..57f0f28 100644 --- a/icing/icing-search-engine.h +++ b/icing/icing-search-engine.h @@ -17,7 +17,6 @@ #include <cstdint> #include <memory> -#include <string> #include <string_view> #include <utility> #include <vector> @@ -27,7 +26,9 @@ #include "icing/absl_ports/mutex.h" #include "icing/absl_ports/thread_annotations.h" #include "icing/file/filesystem.h" +#include "icing/file/version-util.h" #include "icing/index/data-indexing-handler.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" #include "icing/index/numeric/numeric-index.h" #include "icing/jni/jni-cache.h" @@ -51,11 +52,11 @@ #include "icing/result/result-state-manager.h" #include "icing/schema/schema-store.h" #include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer.h" #include "icing/util/clock.h" -#include "icing/util/crc32.h" namespace icing { namespace lib { @@ -483,6 +484,9 @@ class IcingSearchEngine { std::unique_ptr<QualifiedIdJoinIndex> qualified_id_join_index_ ICING_GUARDED_BY(mutex_); + // Storage for all hits of embedding contents from the document store. 
+ std::unique_ptr<EmbeddingIndex> embedding_index_ ICING_GUARDED_BY(mutex_); + // Pointer to JNI class references const std::unique_ptr<const JniCache> jni_cache_; @@ -578,8 +582,8 @@ class IcingSearchEngine { // read-lock, allowing for parallel non-exclusive operations. // This implementation is used if search_spec.use_read_only_search is true. SearchResultProto SearchLockedShared(const SearchSpecProto& search_spec, - const ScoringSpecProto& scoring_spec, - const ResultSpecProto& result_spec) + const ScoringSpecProto& scoring_spec, + const ResultSpecProto& result_spec) ICING_LOCKS_EXCLUDED(mutex_); // Implementation of IcingSearchEngine::Search that requires the overall @@ -587,8 +591,8 @@ class IcingSearchEngine { // this version is used. // This implementation is used if search_spec.use_read_only_search is false. SearchResultProto SearchLockedExclusive(const SearchSpecProto& search_spec, - const ScoringSpecProto& scoring_spec, - const ResultSpecProto& result_spec) + const ScoringSpecProto& scoring_spec, + const ResultSpecProto& result_spec) ICING_LOCKS_EXCLUDED(mutex_); // Helper method for the actual work to Search. We need this separate @@ -641,13 +645,14 @@ class IcingSearchEngine { libtextclassifier3::Status CheckConsistency() ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Discards all derived data. + // Discards derived data that requires rebuild based on rebuild_result. // // Returns: // OK on success // FAILED_PRECONDITION_ERROR if those instances are valid (non nullptr) // INTERNAL_ERROR on any I/O errors - libtextclassifier3::Status DiscardDerivedFiles() + libtextclassifier3::Status DiscardDerivedFiles( + const version_util::DerivedFilesRebuildResult& rebuild_result) ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Repopulates derived data off our ground truths. 
@@ -700,6 +705,7 @@ class IcingSearchEngine { bool index_needed_restoration; bool integer_index_needed_restoration; bool qualified_id_join_index_needed_restoration; + bool embedding_index_needed_restoration; }; IndexRestorationResult RestoreIndexIfNeeded() ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); @@ -740,17 +746,21 @@ class IcingSearchEngine { bool index_needed_restoration; bool integer_index_needed_restoration; bool qualified_id_join_index_needed_restoration; + bool embedding_index_needed_restoration; explicit TruncateIndexResult( DocumentId first_document_to_reindex_in, bool index_needed_restoration_in, bool integer_index_needed_restoration_in, - bool qualified_id_join_index_needed_restoration_in) + bool qualified_id_join_index_needed_restoration_in, + bool embedding_index_needed_restoration_in) : first_document_to_reindex(first_document_to_reindex_in), index_needed_restoration(index_needed_restoration_in), integer_index_needed_restoration(integer_index_needed_restoration_in), qualified_id_join_index_needed_restoration( - qualified_id_join_index_needed_restoration_in) {} + qualified_id_join_index_needed_restoration_in), + embedding_index_needed_restoration( + embedding_index_needed_restoration_in) {} }; libtextclassifier3::StatusOr<TruncateIndexResult> TruncateIndicesTo( DocumentId last_stored_document_id) diff --git a/icing/icing-search-engine_initialization_test.cc b/icing/icing-search-engine_initialization_test.cc index 122e4af..d6316d4 100644 --- a/icing/icing-search-engine_initialization_test.cc +++ b/icing/icing-search-engine_initialization_test.cc @@ -19,6 +19,7 @@ #include <string> #include <string_view> #include <tuple> +#include <unordered_set> #include <utility> #include <vector> @@ -204,7 +205,7 @@ class IcingSearchEngineInitializationTest : public testing::Test { // Non-zero value so we don't override it to be the current time constexpr int64_t kDefaultCreationTimestampMs = 1575492852000; -std::string GetVersionFilename() { return GetTestBaseDir() + 
"/version"; } +std::string GetVersionFileDir() { return GetTestBaseDir(); } std::string GetDocumentDir() { return GetTestBaseDir() + "/document_dir"; } @@ -5566,12 +5567,26 @@ INSTANTIATE_TEST_SUITE_P(IcingSearchEngineInitializationSwitchJoinIndexTest, IcingSearchEngineInitializationSwitchJoinIndexTest, testing::Values(true, false)); +struct IcingSearchEngineInitializationVersionChangeTestParam { + version_util::VersionInfo existing_version_info; + std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType> + existing_enabled_features; + + explicit IcingSearchEngineInitializationVersionChangeTestParam( + version_util::VersionInfo version_info_in, + std::unordered_set<IcingSearchEngineFeatureInfoProto::FlaggedFeatureType> + existing_enabled_features_in) + : existing_version_info(std::move(version_info_in)), + existing_enabled_features(std::move(existing_enabled_features_in)) {} +}; + class IcingSearchEngineInitializationVersionChangeTest : public IcingSearchEngineInitializationTest, - public ::testing::WithParamInterface<version_util::VersionInfo> {}; + public ::testing::WithParamInterface< + IcingSearchEngineInitializationVersionChangeTestParam> {}; TEST_P(IcingSearchEngineInitializationVersionChangeTest, - RecoverFromVersionChange) { + RecoverFromVersionChangeOrUnknownFlagChange) { // TODO(b/280697513): test backup schema migration // Test the following scenario: version change. All derived data should be // rebuilt. 
We test this by manually adding some invalid derived data and @@ -5725,10 +5740,27 @@ TEST_P(IcingSearchEngineInitializationVersionChangeTest, std::move(incorrect_message))); ICING_ASSERT_OK(index_processor.IndexDocument(tokenized_document, doc_id)); - // Change existing data's version file - const version_util::VersionInfo& existing_version_info = GetParam(); - ICING_ASSERT_OK(version_util::WriteVersion( - *filesystem(), GetVersionFilename(), existing_version_info)); + // Rewrite existing data's version files + ICING_ASSERT_OK( + version_util::DiscardVersionFiles(*filesystem(), GetVersionFileDir())); + const version_util::VersionInfo& existing_version_info = + GetParam().existing_version_info; + ICING_ASSERT_OK(version_util::WriteV1Version( + *filesystem(), GetVersionFileDir(), existing_version_info)); + + if (existing_version_info.version >= version_util::kFirstV2Version) { + IcingSearchEngineVersionProto version_proto; + version_proto.set_version(existing_version_info.version); + version_proto.set_max_version(existing_version_info.max_version); + auto* enabled_features = version_proto.mutable_enabled_features(); + for (const auto& feature : GetParam().existing_enabled_features) { + enabled_features->Add(version_util::GetFeatureInfoProto(feature)); + } + version_util::WriteV2Version( + *filesystem(), GetVersionFileDir(), + std::make_unique<IcingSearchEngineVersionProto>( + std::move(version_proto))); + } } // Mock filesystem to observe and check the behavior of all indices. @@ -5738,28 +5770,48 @@ TEST_P(IcingSearchEngineInitializationVersionChangeTest, std::make_unique<FakeClock>(), GetTestJniCache()); InitializeResultProto initialize_result = icing.Initialize(); EXPECT_THAT(initialize_result.status(), ProtoIsOk()); - // Index Restoration should be triggered here. Incorrect data should be - // deleted and correct data of message should be indexed. + + // Derived files restoration should be triggered here. 
Incorrect data should + // be deleted and correct data of message should be indexed. + // Here we're recovering from a version change or a flag change that requires + // rebuilding all derived files. + // + // TODO(b/314816301): test individual derived files rebuilds due to change + // in trunk stable feature flags. + // i.e. Test individual rebuilding for each of: + // - document store + // - schema store + // - term index + // - numeric index + // - qualified id join index + InitializeStatsProto::RecoveryCause expected_recovery_cause = + GetParam().existing_version_info.version != version_util::kVersion + ? InitializeStatsProto::VERSION_CHANGED + : InitializeStatsProto::FEATURE_FLAG_CHANGED; EXPECT_THAT( initialize_result.initialize_stats().document_store_recovery_cause(), - Eq(InitializeStatsProto::VERSION_CHANGED)); + Eq(expected_recovery_cause)); + EXPECT_THAT( + initialize_result.initialize_stats().schema_store_recovery_cause(), + Eq(expected_recovery_cause)); EXPECT_THAT(initialize_result.initialize_stats().index_restoration_cause(), - Eq(InitializeStatsProto::VERSION_CHANGED)); + Eq(expected_recovery_cause)); EXPECT_THAT( initialize_result.initialize_stats().integer_index_restoration_cause(), - Eq(InitializeStatsProto::VERSION_CHANGED)); + Eq(expected_recovery_cause)); EXPECT_THAT(initialize_result.initialize_stats() .qualified_id_join_index_restoration_cause(), - Eq(InitializeStatsProto::VERSION_CHANGED)); + Eq(expected_recovery_cause)); // Manually check version file ICING_ASSERT_OK_AND_ASSIGN( - version_util::VersionInfo version_info_after_init, - version_util::ReadVersion(*filesystem(), GetVersionFilename(), + IcingSearchEngineVersionProto version_proto_after_init, + version_util::ReadVersion(*filesystem(), GetVersionFileDir(), GetIndexDir())); - EXPECT_THAT(version_info_after_init.version, Eq(version_util::kVersion)); - EXPECT_THAT(version_info_after_init.max_version, - Eq(std::max(version_util::kVersion, GetParam().max_version))); + 
EXPECT_THAT(version_proto_after_init.version(), Eq(version_util::kVersion)); + EXPECT_THAT(version_proto_after_init.max_version(), + Eq(std::max(version_util::kVersion, + GetParam().existing_version_info.max_version))); SearchResultProto expected_search_result_proto; expected_search_result_proto.mutable_status()->set_code(StatusProto::OK); @@ -5836,9 +5888,11 @@ INSTANTIATE_TEST_SUITE_P( testing::Values( // Manually change existing data set's version to kVersion + 1. When // initializing, it will detect "rollback". - version_util::VersionInfo( - /*version_in=*/version_util::kVersion + 1, - /*max_version_in=*/version_util::kVersion + 1), + IcingSearchEngineInitializationVersionChangeTestParam( + version_util::VersionInfo( + /*version_in=*/version_util::kVersion + 1, + /*max_version_in=*/version_util::kVersion + 1), + /*existing_enabled_features_in=*/{}), // Currently we don't have any "upgrade" that requires rebuild derived // files, so skip this case until we have a case for it. @@ -5846,27 +5900,45 @@ INSTANTIATE_TEST_SUITE_P( // Manually change existing data set's version to kVersion - 1 and // max_version to kVersion. When initializing, it will detect "roll // forward". - version_util::VersionInfo( - /*version_in=*/version_util::kVersion - 1, - /*max_version_in=*/version_util::kVersion), + IcingSearchEngineInitializationVersionChangeTestParam( + version_util::VersionInfo( + /*version_in=*/version_util::kVersion - 1, + /*max_version_in=*/version_util::kVersion), + /*existing_enabled_features_in=*/{}), // Manually change existing data set's version to 0 and max_version to // 0. When initializing, it will detect "version 0 upgrade". // // Note: in reality, version 0 won't be written into version file, but // it is ok here since it is hack to simulate version 0 situation. 
- version_util::VersionInfo( - /*version_in=*/0, - /*max_version_in=*/0), + IcingSearchEngineInitializationVersionChangeTestParam( + version_util::VersionInfo( + /*version_in=*/0, + /*max_version_in=*/0), + /*existing_enabled_features_in=*/{}), // Manually change existing data set's version to 0 and max_version to // kVersion. When initializing, it will detect "version 0 roll forward". // // Note: in reality, version 0 won't be written into version file, but // it is ok here since it is hack to simulate version 0 situation. - version_util::VersionInfo( - /*version_in=*/0, - /*max_version_in=*/version_util::kVersion))); + IcingSearchEngineInitializationVersionChangeTestParam( + version_util::VersionInfo( + /*version_in=*/0, + /*max_version_in=*/version_util::kVersion), + /*existing_enabled_features_in=*/{}), + + // Manually write an unknown feature in the version proto while keeping + // version the same as kVersion. + // + // Result: this will rebuild all derived files with restoration cause + // FEATURE_FLAG_CHANGED + IcingSearchEngineInitializationVersionChangeTestParam( + version_util::VersionInfo( + /*version_in=*/version_util::kVersion, + /*max_version_in=*/version_util::kVersion), + /*existing_enabled_features_in=*/{ + IcingSearchEngineFeatureInfoProto::UNKNOWN}))); class IcingSearchEngineInitializationChangePropertyExistenceHitsFlagTest : public IcingSearchEngineInitializationTest, @@ -5962,7 +6034,7 @@ TEST_P(IcingSearchEngineInitializationChangePropertyExistenceHitsFlagTest, ASSERT_THAT(initialize_result.status(), ProtoIsOk()); // Ensure that the term index is rebuilt if the flag is changed. EXPECT_THAT(initialize_result.initialize_stats().index_restoration_cause(), - Eq(flag_changed ? InitializeStatsProto::IO_ERROR + Eq(flag_changed ? 
InitializeStatsProto::FEATURE_FLAG_CHANGED : InitializeStatsProto::NONE)); EXPECT_THAT( initialize_result.initialize_stats().integer_index_restoration_cause(), diff --git a/icing/icing-search-engine_schema_test.cc b/icing/icing-search-engine_schema_test.cc index 49c024e..34081e9 100644 --- a/icing/icing-search-engine_schema_test.cc +++ b/icing/icing-search-engine_schema_test.cc @@ -36,7 +36,6 @@ #include "icing/proto/persist.pb.h" #include "icing/proto/reset.pb.h" #include "icing/proto/schema.pb.h" -#include "icing/proto/scoring.pb.h" #include "icing/proto/search.pb.h" #include "icing/proto/status.pb.h" #include "icing/proto/storage.pb.h" @@ -60,6 +59,7 @@ namespace { using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::Eq; using ::testing::HasSubstr; +using ::testing::Not; using ::testing::Return; // For mocking purpose, we allow tests to provide a custom Filesystem. @@ -3154,6 +3154,86 @@ TEST_F(IcingSearchEngineSchemaTest, IcingShouldReturnErrorForExtraSections) { HasSubstr("Too many properties to be indexed")); } +TEST_F(IcingSearchEngineSchemaTest, UpdatedTypeDescriptionIsSaved) { + // Create a schema with more sections than allowed. 
+ PropertyConfigProto old_property = + PropertyConfigBuilder() + .SetName("prop0") + .SetDescription("old property description") + .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL) + .Build(); + SchemaTypeConfigProto old_schema_type_config = + SchemaTypeConfigBuilder() + .SetType("type") + .SetDescription("old description") + .AddProperty(old_property) + .Build(); + SchemaProto old_schema = + SchemaBuilder().AddType(old_schema_type_config).Build(); + + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(old_schema).status(), ProtoIsOk()); + + // Update the type description + SchemaTypeConfigProto new_schema_type_config = + SchemaTypeConfigBuilder(old_schema_type_config) + .SetDescription("new description") + .Build(); + SchemaProto new_schema = + SchemaBuilder().AddType(new_schema_type_config).Build(); + ASSERT_THAT(icing.SetSchema(new_schema).status(), ProtoIsOk()); + + GetSchemaResultProto get_result = icing.GetSchema(); + ASSERT_THAT(get_result.status(), ProtoIsOk()); + ASSERT_THAT(get_result.schema(), EqualsProto(new_schema)); + ASSERT_THAT(get_result.schema(), Not(EqualsProto(old_schema))); +} + +TEST_F(IcingSearchEngineSchemaTest, UpdatedPropertyDescriptionIsSaved) { + // Create a schema with more sections than allowed. 
+ PropertyConfigProto old_property = + PropertyConfigBuilder() + .SetName("prop0") + .SetDescription("old property description") + .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL) + .Build(); + SchemaTypeConfigProto old_schema_type_config = + SchemaTypeConfigBuilder() + .SetType("type") + .SetDescription("old description") + .AddProperty(old_property) + .Build(); + SchemaProto old_schema = + SchemaBuilder().AddType(old_schema_type_config).Build(); + + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(old_schema).status(), ProtoIsOk()); + + // Update the property description + PropertyConfigProto new_property = + PropertyConfigBuilder(old_property) + .SetDescription("new property description") + .Build(); + SchemaTypeConfigProto new_schema_type_config = + SchemaTypeConfigBuilder() + .SetType("type") + .SetDescription("old description") + .AddProperty(new_property) + .Build(); + SchemaProto new_schema = + SchemaBuilder().AddType(new_schema_type_config).Build(); + ASSERT_THAT(icing.SetSchema(new_schema).status(), ProtoIsOk()); + + GetSchemaResultProto get_result = icing.GetSchema(); + ASSERT_THAT(get_result.status(), ProtoIsOk()); + ASSERT_THAT(get_result.schema(), EqualsProto(new_schema)); + ASSERT_THAT(get_result.schema(), Not(EqualsProto(old_schema))); +} + } // namespace } // namespace lib } // namespace icing diff --git a/icing/icing-search-engine_search_test.cc b/icing/icing-search-engine_search_test.cc index 21512c6..a58dbc8 100644 --- a/icing/icing-search-engine_search_test.cc +++ b/icing/icing-search-engine_search_test.cc @@ -13,12 +13,14 @@ // limitations under the License. 
#include <cstdint> +#include <initializer_list> #include <limits> #include <memory> #include <string> +#include <string_view> #include <utility> +#include <vector> -#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" @@ -27,7 +29,7 @@ #include "icing/index/lite/term-id-hit-pair.h" #include "icing/jni/jni-cache.h" #include "icing/join/join-processor.h" -#include "icing/portable/endian.h" +#include "icing/legacy/index/icing-filesystem.h" #include "icing/portable/equals-proto.h" #include "icing/portable/platform.h" #include "icing/proto/debug.pb.h" @@ -49,11 +51,13 @@ #include "icing/result/result-state-manager.h" #include "icing/schema-builder.h" #include "icing/testing/common-matchers.h" +#include "icing/testing/embedding-test-utils.h" #include "icing/testing/fake-clock.h" #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/jni-test-helpers.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" +#include "icing/util/clock.h" #include "icing/util/snippet-helpers.h" namespace icing { @@ -63,6 +67,7 @@ namespace { using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::DoubleEq; +using ::testing::DoubleNear; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::Gt; @@ -120,6 +125,8 @@ class IcingSearchEngineSearchTest // Non-zero value so we don't override it to be the current time constexpr int64_t kDefaultCreationTimestampMs = 1575492852000; +constexpr double kEps = 0.000001; + IcingSearchEngineOptions GetDefaultIcingOptions() { IcingSearchEngineOptions icing_options; icing_options.set_base_dir(GetTestBaseDir()); @@ -3716,6 +3723,109 @@ TEST_P(IcingSearchEngineSearchTest, SearchWithPropertyFilters) { EXPECT_THAT(results.results(0).document(), EqualsProto(document_one)); } +TEST_P(IcingSearchEngineSearchTest, SearchWithPropertyFiltersPolymorphism) { + IcingSearchEngine icing(GetDefaultIcingOptions(), 
GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("Person") + .AddProperty(PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("emailAddress") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType(SchemaTypeConfigBuilder() + .SetType("Artist") + .AddParentType("Person") + .AddProperty(PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("emailAddress") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("company") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + + // Add a person document and an artist document + DocumentProto document_person = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .AddStringProperty("emailAddress", "shopgirl@aol.com") + .Build(); + DocumentProto document_artist = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Artist") + .AddStringProperty("name", "Meg Artist") + .AddStringProperty("emailAddress", "artist@aol.com") + .AddStringProperty("company", "company") + .Build(); + ASSERT_THAT(icing.Put(document_person).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document_artist).status(), ProtoIsOk()); + + // Set a query with property filters of "name" in Person and "emailAddress" + // in Artist. 
By polymorphism, "name" should also apply to Artist. + auto search_spec = std::make_unique<SearchSpecProto>(); + search_spec->set_term_match_type(TermMatchType::PREFIX); + search_spec->set_search_type(GetParam()); + TypePropertyMask* person_type_property_mask = + search_spec->add_type_property_filters(); + person_type_property_mask->set_schema_type("Person"); + person_type_property_mask->add_paths("name"); + TypePropertyMask* artist_type_property_mask = + search_spec->add_type_property_filters(); + artist_type_property_mask->set_schema_type("Artist"); + artist_type_property_mask->add_paths("emailAddress"); + + auto result_spec = std::make_unique<ResultSpecProto>(); + auto scoring_spec = std::make_unique<ScoringSpecProto>(); + *scoring_spec = GetDefaultScoringSpec(); + + // Verify that the property filter for "name" in Person is also applied to + // Artist. + search_spec->set_query("Meg"); + SearchResultProto results = + icing.Search(*search_spec, *scoring_spec, *result_spec); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(2)); + EXPECT_THAT(results.results(1).document(), EqualsProto(document_person)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document_artist)); + + // Verify that the property filter for "emailAddress" in Artist is only + // applied to Artist. + search_spec->set_query("aol"); + results = icing.Search(*search_spec, *scoring_spec, *result_spec); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(1)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document_artist)); + + // Verify that the "company" property is filtered out, since it is not + // specified in the property filter. 
+ search_spec->set_query("company"); + results = icing.Search(*search_spec, *scoring_spec, *result_spec); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), IsEmpty()); +} + TEST_P(IcingSearchEngineSearchTest, EmptySearchWithPropertyFilter) { IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); @@ -4496,6 +4606,11 @@ TEST_P(IcingSearchEngineSearchTest, QueryStatsProtoTest) { exp_stats.set_document_retrieval_latency_ms(5); exp_stats.set_lock_acquisition_latency_ms(5); exp_stats.set_num_joined_results_returned_current_page(0); + // document4, document5's hits will remain in the lite index (# of hits: 4). + exp_stats.set_lite_index_hit_buffer_byte_size(4 * + sizeof(TermIdHitPair::Value)); + exp_stats.set_lite_index_hit_buffer_unsorted_byte_size( + 4 * sizeof(TermIdHitPair::Value)); QueryStatsProto::SearchStats* exp_parent_search_stats = exp_stats.mutable_parent_search_stats(); @@ -4511,6 +4626,14 @@ TEST_P(IcingSearchEngineSearchTest, QueryStatsProtoTest) { exp_parent_search_stats->set_num_fetched_hits_lite_index(2); exp_parent_search_stats->set_num_fetched_hits_main_index(3); exp_parent_search_stats->set_num_fetched_hits_integer_index(0); + if (GetParam() == + SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + exp_parent_search_stats->set_query_processor_lexer_extract_token_latency_ms( + 5); + exp_parent_search_stats + ->set_query_processor_parser_consume_query_latency_ms(5); + exp_parent_search_stats->set_query_processor_query_visitor_latency_ms(5); + } EXPECT_THAT(search_result.query_stats(), EqualsProto(exp_stats)); @@ -4769,6 +4892,11 @@ TEST_P(IcingSearchEngineSearchTest, JoinQueryStatsProtoTest) { exp_stats.set_num_joined_results_returned_current_page(3); exp_stats.set_join_latency_ms(5); exp_stats.set_is_join_query(true); + // person3, email4's hits will remain in the lite index (# of hits: 5). 
+ exp_stats.set_lite_index_hit_buffer_byte_size(5 * + sizeof(TermIdHitPair::Value)); + exp_stats.set_lite_index_hit_buffer_unsorted_byte_size( + 5 * sizeof(TermIdHitPair::Value)); QueryStatsProto::SearchStats* exp_parent_search_stats = exp_stats.mutable_parent_search_stats(); @@ -4784,6 +4912,14 @@ TEST_P(IcingSearchEngineSearchTest, JoinQueryStatsProtoTest) { exp_parent_search_stats->set_num_fetched_hits_lite_index(1); exp_parent_search_stats->set_num_fetched_hits_main_index(2); exp_parent_search_stats->set_num_fetched_hits_integer_index(0); + if (GetParam() == + SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + exp_parent_search_stats->set_query_processor_lexer_extract_token_latency_ms( + 5); + exp_parent_search_stats + ->set_query_processor_parser_consume_query_latency_ms(5); + exp_parent_search_stats->set_query_processor_query_visitor_latency_ms(5); + } QueryStatsProto::SearchStats* exp_child_search_stats = exp_stats.mutable_child_search_stats(); @@ -4799,6 +4935,14 @@ TEST_P(IcingSearchEngineSearchTest, JoinQueryStatsProtoTest) { exp_child_search_stats->set_num_fetched_hits_lite_index(1); exp_child_search_stats->set_num_fetched_hits_main_index(3); exp_child_search_stats->set_num_fetched_hits_integer_index(0); + if (GetParam() == + SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + exp_child_search_stats->set_query_processor_lexer_extract_token_latency_ms( + 5); + exp_child_search_stats->set_query_processor_parser_consume_query_latency_ms( + 5); + exp_child_search_stats->set_query_processor_query_visitor_latency_ms(5); + } EXPECT_THAT(search_result.query_stats(), EqualsProto(exp_stats)); @@ -6638,6 +6782,8 @@ TEST_F(IcingSearchEngineSearchTest, NumericFilterQueryStatsProtoTest) { exp_stats.set_document_retrieval_latency_ms(5); exp_stats.set_lock_acquisition_latency_ms(5); exp_stats.set_num_joined_results_returned_current_page(0); + exp_stats.set_lite_index_hit_buffer_byte_size(0); + 
exp_stats.set_lite_index_hit_buffer_unsorted_byte_size(0); QueryStatsProto::SearchStats* exp_parent_search_stats = exp_stats.mutable_parent_search_stats(); @@ -6656,6 +6802,11 @@ TEST_F(IcingSearchEngineSearchTest, NumericFilterQueryStatsProtoTest) { // Since we will inspect 1 bucket from "price" in integer index and it // contains 3 hits, we will fetch 3 hits (but filter out one of them). exp_parent_search_stats->set_num_fetched_hits_integer_index(3); + exp_parent_search_stats->set_query_processor_lexer_extract_token_latency_ms( + 5); + exp_parent_search_stats->set_query_processor_parser_consume_query_latency_ms( + 5); + exp_parent_search_stats->set_query_processor_query_visitor_latency_ms(5); EXPECT_THAT(results.query_stats(), EqualsProto(exp_stats)); } @@ -7162,6 +7313,208 @@ TEST_P(IcingSearchEngineSearchTest, HasPropertyQueryNestedDocument) { EXPECT_THAT(results.results(), IsEmpty()); } +TEST_P(IcingSearchEngineSearchTest, EmbeddingSearch) { + if (GetParam() != + SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + GTEST_SKIP() << "Embedding search is only supported in advanced query."; + } + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("Email") + .AddProperty(PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString(TERM_MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty(PropertyConfigBuilder() + .SetName("embedding1") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty(PropertyConfigBuilder() + .SetName("embedding2") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_REPEATED))) + .Build(); + DocumentProto document0 = + DocumentBuilder() + .SetKey("icing", "uri0") + .SetSchema("Email") + .SetCreationTimestampMs(1) + .AddStringProperty("body", "foo") + .AddVectorProperty( + "embedding1", + CreateVector("my_model_v1", {0.1, 0.2, 0.3, 0.4, 0.5})) + 
.AddVectorProperty( + "embedding2", + CreateVector("my_model_v1", {-0.1, -0.2, -0.3, 0.4, 0.5}), + CreateVector("my_model_v2", {0.6, 0.7, 0.8})) + .Build(); + DocumentProto document1 = + DocumentBuilder() + .SetKey("icing", "uri1") + .SetSchema("Email") + .SetCreationTimestampMs(1) + .AddVectorProperty( + "embedding1", + CreateVector("my_model_v1", {-0.1, 0.2, -0.3, -0.4, 0.5})) + .AddVectorProperty("embedding2", + CreateVector("my_model_v2", {0.6, 0.7, -0.8})) + .Build(); + + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document0).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); + search_spec.set_embedding_query_metric_type( + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT); + search_spec.add_enabled_features( + std::string(kListFilterQueryLanguageFeature)); + search_spec.add_enabled_features(std::string(kEmbeddingSearchFeature)); + search_spec.set_search_type(GetParam()); + // Add an embedding query with semantic scores: + // - document 0: -0.5 (embedding1), 0.3 (embedding2) + // - document 1: -0.9 (embedding1) + *search_spec.add_embedding_query_vectors() = + CreateVector("my_model_v1", {1, -1, -1, 1, -1}); + // Add an embedding query with semantic scores: + // - document 0: -0.5 (embedding2) + // - document 1: -2.1 (embedding2) + *search_spec.add_embedding_query_vectors() = + CreateVector("my_model_v2", {-1, -1, 1}); + ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); + scoring_spec.set_rank_by( + ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION); + + // Match documents that have embeddings with a similarity closer to 0 that is + // greater than -1. 
+ // + // The matched embeddings for each doc are: + // - document 0: -0.5 (embedding1), 0.3 (embedding2) + // - document 1: -0.9 (embedding1) + // The scoring expression for each doc will be evaluated as: + // - document 0: sum({-0.5, 0.3}) + sum({}) = -0.2 + // - document 1: sum({-0.9}) + sum({}) = -0.9 + search_spec.set_query("semanticSearch(getSearchSpecEmbedding(0), -1)"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0))) + " + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"); + SearchResultProto results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(2)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document0)); + EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5 + 0.3, kEps)); + EXPECT_THAT(results.results(1).document(), EqualsProto(document1)); + EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9, kEps)); + + // Create a query the same as above but with a section restriction, which + // still matches document 0 and document 1 but the semantic score 0.3 should + // be removed from document 0. 
+ // + // The matched embeddings for each doc are: + // - document 0: -0.5 (embedding1) + // - document 1: -0.9 (embedding1) + // The scoring expression for each doc will be evaluated as: + // - document 0: sum({-0.5}) = -0.5 + // - document 1: sum({-0.9}) = -0.9 + search_spec.set_query( + "embedding1:semanticSearch(getSearchSpecEmbedding(0), -1)"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))"); + results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(2)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document0)); + EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5, kEps)); + EXPECT_THAT(results.results(1).document(), EqualsProto(document1)); + EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9, kEps)); + + // Create a query that only matches document 0. + // + // The matched embeddings for each doc are: + // - document 0: -0.5 (embedding2) + // The scoring expression for each doc will be evaluated as: + // - document 0: sum({-0.5}) = -0.5 + search_spec.set_query("semanticSearch(getSearchSpecEmbedding(1), -1.5)"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"); + results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(1)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document0)); + EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5, kEps)); + + // Create a query that only matches document 1. 
+ // + // The matched embeddings for each doc are: + // - document 1: -2.1 (embedding2) + // The scoring expression for each doc will be evaluated as: + // - document 1: sum({-2.1}) = -2.1 + search_spec.set_query("semanticSearch(getSearchSpecEmbedding(1), -10, -1)"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"); + results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(1)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document1)); + EXPECT_THAT(results.results(0).score(), DoubleNear(-2.1, kEps)); + + // Create a complex query that matches all hits from all documents. + // + // The matched embeddings for each doc are: + // - document 0: -0.5 (embedding1), 0.3 (embedding2), -0.5 (embedding2) + // - document 1: -0.9 (embedding1), -2.1 (embedding2) + // The scoring expression for each doc will be evaluated as: + // - document 0: sum({-0.5, 0.3}) + sum({-0.5}) = -0.7 + // - document 1: sum({-0.9}) + sum({-2.1}) = -3 + search_spec.set_query( + "semanticSearch(getSearchSpecEmbedding(0)) OR " + "semanticSearch(getSearchSpecEmbedding(1))"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0))) + " + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"); + results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(2)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document0)); + EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5 + 0.3 - 0.5, kEps)); + EXPECT_THAT(results.results(1).document(), EqualsProto(document1)); + EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9 - 2.1, kEps)); + + // Create a hybrid query that matches document 0 because of term-based search + // and 
document 1 because of embedding-based search. + // + // The matched embeddings for each doc are: + // - document 1: -2.1 (embedding2) + // The scoring expression for each doc will be evaluated as: + // - document 0: sum({}) = 0 + // - document 1: sum({-2.1}) = -2.1 + search_spec.set_query( + "foo OR semanticSearch(getSearchSpecEmbedding(1), -10, -1)"); + scoring_spec.set_advanced_scoring_expression( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"); + results = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoIsOk()); + EXPECT_THAT(results.results(), SizeIs(2)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document0)); + // Document 0 has no matched embedding hit, so its score is 0. + EXPECT_THAT(results.results(0).score(), DoubleNear(0, kEps)); + EXPECT_THAT(results.results(1).document(), EqualsProto(document1)); + EXPECT_THAT(results.results(1).score(), DoubleNear(-2.1, kEps)); +} + INSTANTIATE_TEST_SUITE_P( IcingSearchEngineSearchTest, IcingSearchEngineSearchTest, testing::Values( diff --git a/icing/index/embed/doc-hit-info-iterator-embedding.cc b/icing/index/embed/doc-hit-info-iterator-embedding.cc new file mode 100644 index 0000000..8f2bc7d --- /dev/null +++ b/icing/index/embed/doc-hit-info-iterator-embedding.cc @@ -0,0 +1,164 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
#include "icing/index/embed/doc-hit-info-iterator-embedding.h"

#include <memory>
#include <string_view>
#include <utility>
#include <vector>

#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/index/embed/embedding-hit.h"
#include "icing/index/embed/embedding-index.h"
#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/embed/embedding-scorer.h"
#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/hit/hit.h"
#include "icing/index/iterator/section-restrict-data.h"
#include "icing/proto/search.pb.h"
#include "icing/schema/section.h"
#include "icing/store/document-id.h"
#include "icing/util/status-macros.h"

namespace icing {
namespace lib {

libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIteratorEmbedding>>
DocHitInfoIteratorEmbedding::Create(
    const PropertyProto::VectorProto* query,
    std::unique_ptr<SectionRestrictData> section_restrict_data,
    SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
    double score_low, double score_high,
    EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
    const EmbeddingIndex* embedding_index) {
  ICING_RETURN_ERROR_IF_NULL(query);
  ICING_RETURN_ERROR_IF_NULL(embedding_index);
  ICING_RETURN_ERROR_IF_NULL(score_map);

  // Look up the posting list that holds all hits for vectors with the same
  // (dimension, model_signature) key as the query.
  libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
      pl_accessor_or = embedding_index->GetAccessorForVector(*query);
  std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
  if (pl_accessor_or.ok()) {
    pl_accessor = std::move(pl_accessor_or).ValueOrDie();
  } else if (absl_ports::IsNotFound(pl_accessor_or.status())) {
    // A not-found error should be fine, since that means there is no matching
    // embedding hits in the index. Advance() checks for a null accessor and
    // immediately reports RESOURCE_EXHAUSTED (i.e. an empty iterator).
    pl_accessor = nullptr;
  } else {
    // Otherwise, return the error as is.
    return pl_accessor_or.status();
  }

  ICING_ASSIGN_OR_RETURN(std::unique_ptr<EmbeddingScorer> embedding_scorer,
                         EmbeddingScorer::Create(metric_type));

  // Private constructor — must go through new + unique_ptr here.
  return std::unique_ptr<DocHitInfoIteratorEmbedding>(
      new DocHitInfoIteratorEmbedding(query, std::move(section_restrict_data),
                                      metric_type, std::move(embedding_scorer),
                                      score_low, score_high, score_map,
                                      embedding_index, std::move(pl_accessor)));
}

libtextclassifier3::StatusOr<const EmbeddingHit*>
DocHitInfoIteratorEmbedding::AdvanceToNextEmbeddingHit() {
  // Refill the batch buffer from the posting list when it is exhausted.
  if (cached_embedding_hits_idx_ == cached_embedding_hits_.size()) {
    ICING_ASSIGN_OR_RETURN(cached_embedding_hits_,
                           posting_list_accessor_->GetNextHitsBatch());
    cached_embedding_hits_idx_ = 0;
    if (cached_embedding_hits_.empty()) {
      // Posting list fully consumed — no more hits at all.
      no_more_hit_ = true;
      return nullptr;
    }
  }
  const EmbeddingHit& embedding_hit =
      cached_embedding_hits_[cached_embedding_hits_idx_];
  if (doc_hit_info_.document_id() == kInvalidDocumentId) {
    // First hit for a new document: latch its document id and compute the
    // allowed-sections mask once per document.
    doc_hit_info_.set_document_id(embedding_hit.basic_hit().document_id());
    if (section_restrict_data_ != nullptr) {
      current_allowed_sections_mask_ =
          section_restrict_data_->ComputeAllowedSectionsMask(
              doc_hit_info_.document_id());
    }
  } else if (doc_hit_info_.document_id() !=
             embedding_hit.basic_hit().document_id()) {
    // Hit belongs to the next document; do not consume it — the caller
    // retrieves it on the next Advance() after resetting doc_hit_info_.
    return nullptr;
  }
  ++cached_embedding_hits_idx_;
  return &embedding_hit;
}

libtextclassifier3::Status DocHitInfoIteratorEmbedding::Advance() {
  if (no_more_hit_ || posting_list_accessor_ == nullptr) {
    return absl_ports::ResourceExhaustedError(
        "No more DocHitInfos in iterator");
  }

  doc_hit_info_ = DocHitInfo(kInvalidDocumentId, kSectionIdMaskNone);
  std::vector<double>* matched_scores = nullptr;
  current_allowed_sections_mask_ = kSectionIdMaskAll;
  // Consume every embedding hit of the next document, scoring each one.
  while (true) {
    ICING_ASSIGN_OR_RETURN(const EmbeddingHit* embedding_hit,
                           AdvanceToNextEmbeddingHit());
    if (embedding_hit == nullptr) {
      // No more hits for the current document.
      break;
    }

    // Filter out the embedding hit according to the section restriction.
    if (((UINT64_C(1) << embedding_hit->basic_hit().section_id()) &
         current_allowed_sections_mask_) == 0) {
      continue;
    }

    // Calculate the semantic score.
    int dimension = query_.values_size();
    ICING_ASSIGN_OR_RETURN(
        const float* vector,
        embedding_index_.GetEmbeddingVector(*embedding_hit, dimension));
    double semantic_score =
        embedding_scorer_->Score(dimension,
                                 /*v1=*/query_.values().data(),
                                 /*v2=*/vector);

    // If the semantic score is within the desired score range, update
    // doc_hit_info_ and score_map_.
    if (score_low_ <= semantic_score && semantic_score <= score_high_) {
      doc_hit_info_.UpdateSection(embedding_hit->basic_hit().section_id());
      if (matched_scores == nullptr) {
        matched_scores = &(score_map_[doc_hit_info_.document_id()]);
      }
      matched_scores->push_back(semantic_score);
    }
  }

  if (doc_hit_info_.document_id() == kInvalidDocumentId) {
    return absl_ports::ResourceExhaustedError(
        "No more DocHitInfos in iterator");
  }

  ++num_advance_calls_;

  // Skip the current document if it has no vector matched.
  // NOTE(review): this is a recursive tail call; depth grows with the number
  // of consecutive documents whose hits all fall outside the score range or
  // section restriction — presumably acceptable, but confirm for large
  // corpora.
  if (doc_hit_info_.hit_section_ids_mask() == kSectionIdMaskNone) {
    return Advance();
  }
  return libtextclassifier3::Status::OK;
}

}  // namespace lib
}  // namespace icing
diff --git a/icing/index/embed/doc-hit-info-iterator-embedding.h b/icing/index/embed/doc-hit-info-iterator-embedding.h
new file mode 100644
index 0000000..21180db
--- /dev/null
+++ b/icing/index/embed/doc-hit-info-iterator-embedding.h
@@ -0,0 +1,161 @@
// Copyright (C) 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
#define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_

#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/index/embed/embedding-hit.h"
#include "icing/index/embed/embedding-index.h"
#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/embed/embedding-scorer.h"
#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/index/iterator/section-restrict-data.h"
#include "icing/proto/search.pb.h"
#include "icing/schema/section.h"

namespace icing {
namespace lib {

// Leaf iterator for embedding-based queries: yields each document that has at
// least one embedding vector whose semantic score against the query falls in
// [score_low, score_high], and records every matched score into score_map so
// that the scorer can consume them later via matchedSemanticScores.
class DocHitInfoIteratorEmbedding : public DocHitInfoLeafIterator {
 public:
  // Create a DocHitInfoIterator for iterating through all docs which have an
  // embedding matched with the provided query with a score in the range of
  // [score_low, score_high], using the provided metric_type.
  //
  // The iterator will store the matched embedding scores in score_map to
  // prepare for scoring.
  //
  // The iterator will handle the section restriction logic internally by the
  // provided section_restrict_data.
  //
  // Returns:
  //   - a DocHitInfoIteratorEmbedding instance on success.
  //   - Any error from posting lists.
  static libtextclassifier3::StatusOr<
      std::unique_ptr<DocHitInfoIteratorEmbedding>>
  Create(const PropertyProto::VectorProto* query,
         std::unique_ptr<SectionRestrictData> section_restrict_data,
         SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
         double score_low, double score_high,
         EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
         const EmbeddingIndex* embedding_index);

  libtextclassifier3::Status Advance() override;

  // The iterator will internally handle the section restriction logic by
  // itself to have better control, so that it is able to filter out embedding
  // hits from unwanted sections to avoid retrieving unnecessary vectors and
  // calculate scores for them.
  bool full_section_restriction_applied() const override { return true; }

  // Embedding queries cannot be trimmed for suggestion purposes.
  libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
    return absl_ports::InvalidArgumentError(
        "Query suggestions for the semanticSearch function are not supported");
  }

  // All embedding advance calls are reported in the lite-index slot; the
  // other counters do not apply to the embedding index.
  CallStats GetCallStats() const override {
    return CallStats(
        /*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_,
        /*num_leaf_advance_calls_main_index_in=*/0,
        /*num_leaf_advance_calls_integer_index_in=*/0,
        /*num_leaf_advance_calls_no_index_in=*/0,
        /*num_blocks_inspected_in=*/0);
  }

  std::string ToString() const override { return "embedding_iterator"; }

  // PopulateMatchedTermsStats is not applicable to embedding search.
  void PopulateMatchedTermsStats(
      std::vector<TermMatchInfo>* matched_terms_stats,
      SectionIdMask filtering_section_mask) const override {}

 private:
  // All pointer arguments must be non-null (checked by Create);
  // posting_list_accessor may be null when the index has no hits for the
  // query's (dimension, model_signature) key.
  explicit DocHitInfoIteratorEmbedding(
      const PropertyProto::VectorProto* query,
      std::unique_ptr<SectionRestrictData> section_restrict_data,
      SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
      std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low,
      double score_high,
      EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
      const EmbeddingIndex* embedding_index,
      std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor)
      : query_(*query),
        section_restrict_data_(std::move(section_restrict_data)),
        metric_type_(metric_type),
        embedding_scorer_(std::move(embedding_scorer)),
        score_low_(score_low),
        score_high_(score_high),
        score_map_(*score_map),
        embedding_index_(*embedding_index),
        posting_list_accessor_(std::move(posting_list_accessor)),
        cached_embedding_hits_idx_(0),
        current_allowed_sections_mask_(kSectionIdMaskAll),
        no_more_hit_(false),
        num_advance_calls_(0) {}

  // Advance to the next embedding hit of the current document. If the current
  // document id is kInvalidDocumentId, the method will advance to the first
  // embedding hit of the next document and update doc_hit_info_.
  //
  // This method also properly updates cached_embedding_hits_,
  // cached_embedding_hits_idx_, current_allowed_sections_mask_, and
  // no_more_hit_ to reflect the current state.
  //
  // Returns:
  //   - a const pointer to the next embedding hit on success.
  //   - nullptr, if there is no more hit for the current document, or no more
  //     hit in general if the current document id is kInvalidDocumentId.
  //   - Any error from posting lists.
  libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit();

  // Query information
  const PropertyProto::VectorProto& query_;  // Does not own
  std::unique_ptr<SectionRestrictData> section_restrict_data_;  // Nullable.

  // Scoring arguments
  SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
  std::unique_ptr<EmbeddingScorer> embedding_scorer_;
  double score_low_;
  double score_high_;

  // Score map
  EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map_;  // Does not own

  // Access to embeddings index data
  const EmbeddingIndex& embedding_index_;
  std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_;

  // Cached data from the embeddings index
  std::vector<EmbeddingHit> cached_embedding_hits_;
  int cached_embedding_hits_idx_;
  SectionIdMask current_allowed_sections_mask_;
  bool no_more_hit_;

  int num_advance_calls_;
};

}  // namespace lib
}  // namespace icing

#endif  // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
diff --git a/icing/index/embed/embedding-hit.h b/icing/index/embed/embedding-hit.h
new file mode 100644
index 0000000..165a123
--- /dev/null
+++ b/icing/index/embed/embedding-hit.h
@@ -0,0 +1,67 @@
// Copyright (C) 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +#ifndef ICING_INDEX_EMBED_EMBEDDING_HIT_H_ +#define ICING_INDEX_EMBED_EMBEDDING_HIT_H_ + +#include <cstdint> + +#include "icing/index/hit/hit.h" + +namespace icing { +namespace lib { + +class EmbeddingHit { + public: + // Order: basic_hit, location + // Value bits layout: 32 basic_hit + 32 location + using Value = uint64_t; + + // WARNING: Changing this value will invalidate any pre-existing posting lists + // on user devices. + // + // kInvalidValue contains: + // - BasicHit of value 0, which is invalid. + // - location of 0 (valid), which is ok because BasicHit is already invalid. + static_assert(BasicHit::kInvalidValue == 0); + static constexpr Value kInvalidValue = 0; + + explicit EmbeddingHit(BasicHit basic_hit, uint32_t location) { + value_ = (static_cast<uint64_t>(basic_hit.value()) << 32) | location; + } + + explicit EmbeddingHit(Value value) : value_(value) {} + + // BasicHit contains the document id and the section id. + BasicHit basic_hit() const { return BasicHit(value_ >> 32); } + // The location of the referred embedding vector in the vector storage. 
+ uint32_t location() const { return value_ & 0xFFFFFFFF; }; + + bool is_valid() const { return basic_hit().is_valid(); } + Value value() const { return value_; } + + bool operator<(const EmbeddingHit& h2) const { return value_ < h2.value_; } + bool operator==(const EmbeddingHit& h2) const { return value_ == h2.value_; } + + private: + Value value_; +}; + +static_assert(sizeof(BasicHit) == 4); +static_assert(sizeof(EmbeddingHit) == 8); + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_EMBEDDING_HIT_H_ diff --git a/icing/index/embed/embedding-hit_test.cc b/icing/index/embed/embedding-hit_test.cc new file mode 100644 index 0000000..31a0b6c --- /dev/null +++ b/icing/index/embed/embedding-hit_test.cc @@ -0,0 +1,80 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/embedding-hit.h" + +#include <algorithm> +#include <cstdint> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/index/hit/hit.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsFalse; + +static constexpr DocumentId kSomeDocumentId = 24; +static constexpr SectionId kSomeSectionid = 5; +static constexpr uint32_t kSomeLocation = 123; + +TEST(EmbeddingHitTest, Accessors) { + BasicHit basic_hit(kSomeSectionid, kSomeDocumentId); + EmbeddingHit embedding_hit(basic_hit, kSomeLocation); + EXPECT_THAT(embedding_hit.basic_hit(), Eq(basic_hit)); + EXPECT_THAT(embedding_hit.location(), Eq(kSomeLocation)); +} + +TEST(EmbeddingHitTest, Invalid) { + EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue); + EXPECT_THAT(invalid_hit.is_valid(), IsFalse()); + + // Also make sure the invalid EmbeddingHit contains an invalid document id. + EXPECT_THAT(invalid_hit.basic_hit().document_id(), Eq(kInvalidDocumentId)); + EXPECT_THAT(invalid_hit.basic_hit().section_id(), Eq(kMinSectionId)); + EXPECT_THAT(invalid_hit.location(), Eq(0)); +} + +TEST(EmbeddingHitTest, Comparison) { + // Create basic hits with basic_hit1 < basic_hit2 < basic_hit3. + BasicHit basic_hit1(/*section_id=*/1, /*document_id=*/2409); + BasicHit basic_hit2(/*section_id=*/1, /*document_id=*/243); + BasicHit basic_hit3(/*section_id=*/15, /*document_id=*/243); + + // Embedding hits are sorted by BasicHit first, and then by location. + // So embedding_hit3 < embedding_hit4 < embedding_hit2 < embedding_hit1. 
+ EmbeddingHit embedding_hit1(basic_hit3, /*location=*/10); + EmbeddingHit embedding_hit2(basic_hit3, /*location=*/0); + EmbeddingHit embedding_hit3(basic_hit1, /*location=*/100); + EmbeddingHit embedding_hit4(basic_hit2, /*location=*/0); + + std::vector<EmbeddingHit> embedding_hits{embedding_hit1, embedding_hit2, + embedding_hit3, embedding_hit4}; + std::sort(embedding_hits.begin(), embedding_hits.end()); + EXPECT_THAT(embedding_hits, ElementsAre(embedding_hit3, embedding_hit4, + embedding_hit2, embedding_hit1)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/embedding-index.cc b/icing/index/embed/embedding-index.cc new file mode 100644 index 0000000..2e70f0e --- /dev/null +++ b/icing/index/embed/embedding-index.cc @@ -0,0 +1,440 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/embedding-index.h" + +#include <algorithm> +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/file/destructible-directory.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/filesystem.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" +#include "icing/index/hit/hit.h" +#include "icing/store/document-id.h" +#include "icing/store/dynamic-trie-key-mapper.h" +#include "icing/store/key-mapper.h" +#include "icing/util/crc32.h" +#include "icing/util/encode-util.h" +#include "icing/util/logging.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +namespace { + +constexpr uint32_t kEmbeddingHitListMapperMaxSize = + 128 * 1024 * 1024; // 128 MiB; + +// The maximum length returned by encode_util::EncodeIntToCString is 5 for +// uint32_t. 
+constexpr uint32_t kEncodedDimensionLength = 5; + +std::string GetMetadataFilePath(std::string_view working_path) { + return absl_ports::StrCat(working_path, "/metadata"); +} + +std::string GetFlashIndexStorageFilePath(std::string_view working_path) { + return absl_ports::StrCat(working_path, "/flash_index_storage"); +} + +std::string GetEmbeddingHitListMapperPath(std::string_view working_path) { + return absl_ports::StrCat(working_path, "/embedding_hit_list_mapper"); +} + +std::string GetEmbeddingVectorsFilePath(std::string_view working_path) { + return absl_ports::StrCat(working_path, "/embedding_vectors"); +} + +// An injective function that maps the ordered pair (dimension, model_signature) +// to a string, which is used to form a key for embedding_posting_list_mapper_. +std::string GetPostingListKey(uint32_t dimension, + std::string_view model_signature) { + std::string encoded_dimension_str = + encode_util::EncodeIntToCString(dimension); + // Make encoded_dimension_str to fixed kEncodedDimensionLength bytes. + while (encoded_dimension_str.size() < kEncodedDimensionLength) { + // C string cannot contain 0 bytes, so we append it using 1, just like what + // we do in encode_util::EncodeIntToCString. + // + // The reason that this works is because DecodeIntToString decodes a byte + // value of 0x01 as 0x00. When EncodeIntToCString returns an encoded + // dimension that is less than 5 bytes, it means that the dimension contains + // unencoded leading 0x00. So here we're explicitly encoding those bytes as + // 0x01. 
+ encoded_dimension_str.push_back(1); + } + return absl_ports::StrCat(encoded_dimension_str, model_signature); +} + +std::string GetPostingListKey(const PropertyProto::VectorProto& vector) { + return GetPostingListKey(vector.values_size(), vector.model_signature()); +} + +} // namespace + +libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>> +EmbeddingIndex::Create(const Filesystem* filesystem, std::string working_path) { + ICING_RETURN_ERROR_IF_NULL(filesystem); + + std::unique_ptr<EmbeddingIndex> index = std::unique_ptr<EmbeddingIndex>( + new EmbeddingIndex(*filesystem, std::move(working_path))); + ICING_RETURN_IF_ERROR(index->Initialize()); + return index; +} + +libtextclassifier3::Status EmbeddingIndex::Initialize() { + bool is_new = false; + if (!filesystem_.FileExists(GetMetadataFilePath(working_path_).c_str())) { + // Create working directory. + if (!filesystem_.CreateDirectoryRecursively(working_path_.c_str())) { + return absl_ports::InternalError( + absl_ports::StrCat("Failed to create directory: ", working_path_)); + } + is_new = true; + } + + ICING_ASSIGN_OR_RETURN( + MemoryMappedFile metadata_mmapped_file, + MemoryMappedFile::Create(filesystem_, GetMetadataFilePath(working_path_), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/kMetadataFileSize, + /*pre_mapping_file_offset=*/0, + /*pre_mapping_mmap_size=*/kMetadataFileSize)); + metadata_mmapped_file_ = + std::make_unique<MemoryMappedFile>(std::move(metadata_mmapped_file)); + + ICING_ASSIGN_OR_RETURN(FlashIndexStorage flash_index_storage, + FlashIndexStorage::Create( + GetFlashIndexStorageFilePath(working_path_), + &filesystem_, posting_list_hit_serializer_.get())); + flash_index_storage_ = + std::make_unique<FlashIndexStorage>(std::move(flash_index_storage)); + + ICING_ASSIGN_OR_RETURN( + embedding_posting_list_mapper_, + DynamicTrieKeyMapper<PostingListIdentifier>::Create( + filesystem_, GetEmbeddingHitListMapperPath(working_path_), + kEmbeddingHitListMapperMaxSize)); 
+ + ICING_ASSIGN_OR_RETURN( + embedding_vectors_, + FileBackedVector<float>::Create( + filesystem_, GetEmbeddingVectorsFilePath(working_path_), + MemoryMappedFile::READ_WRITE_AUTO_SYNC)); + + if (is_new) { + ICING_RETURN_IF_ERROR(metadata_mmapped_file_->GrowAndRemapIfNecessary( + /*file_offset=*/0, /*mmap_size=*/kMetadataFileSize)); + info().magic = Info::kMagic; + info().last_added_document_id = kInvalidDocumentId; + ICING_RETURN_IF_ERROR(InitializeNewStorage()); + } else { + if (metadata_mmapped_file_->available_size() != kMetadataFileSize) { + return absl_ports::FailedPreconditionError( + "Incorrect metadata file size"); + } + if (info().magic != Info::kMagic) { + return absl_ports::FailedPreconditionError("Incorrect magic value"); + } + ICING_RETURN_IF_ERROR(InitializeExistingStorage()); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status EmbeddingIndex::Clear() { + pending_embedding_hits_.clear(); + metadata_mmapped_file_.reset(); + flash_index_storage_.reset(); + embedding_posting_list_mapper_.reset(); + embedding_vectors_.reset(); + if (filesystem_.DirectoryExists(working_path_.c_str())) { + ICING_RETURN_IF_ERROR(Discard(filesystem_, working_path_)); + } + is_initialized_ = false; + return Initialize(); +} + +libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> +EmbeddingIndex::GetAccessor(uint32_t dimension, + std::string_view model_signature) const { + if (dimension == 0) { + return absl_ports::InvalidArgumentError("Dimension is 0"); + } + std::string key = GetPostingListKey(dimension, model_signature); + ICING_ASSIGN_OR_RETURN(PostingListIdentifier posting_list_id, + embedding_posting_list_mapper_->Get(key)); + return PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), posting_list_hit_serializer_.get(), + posting_list_id); +} + +libtextclassifier3::Status EmbeddingIndex::BufferEmbedding( + const BasicHit& basic_hit, const PropertyProto::VectorProto& vector) { + if 
(vector.values_size() == 0) { + return absl_ports::InvalidArgumentError("Vector dimension is 0"); + } + + uint32_t location = embedding_vectors_->num_elements(); + uint32_t dimension = vector.values_size(); + std::string key = GetPostingListKey(vector); + + // Buffer the embedding hit. + pending_embedding_hits_.push_back( + {std::move(key), EmbeddingHit(basic_hit, location)}); + + // Put vector + ICING_ASSIGN_OR_RETURN(FileBackedVector<float>::MutableArrayView mutable_arr, + embedding_vectors_->Allocate(dimension)); + mutable_arr.SetArray(/*idx=*/0, vector.values().data(), dimension); + + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status EmbeddingIndex::CommitBufferToIndex() { + std::sort(pending_embedding_hits_.begin(), pending_embedding_hits_.end()); + auto iter_curr_key = pending_embedding_hits_.rbegin(); + while (iter_curr_key != pending_embedding_hits_.rend()) { + // In order to batch putting embedding hits with the same key (dimension, + // model_signature) to the same posting list, we find the range + // [iter_curr_key, iter_next_key) of embedding hits with the same key and + // put them into their corresponding posting list together. + auto iter_next_key = iter_curr_key; + while (iter_next_key != pending_embedding_hits_.rend() && + iter_next_key->first == iter_curr_key->first) { + iter_next_key++; + } + + const std::string& key = iter_curr_key->first; + libtextclassifier3::StatusOr<PostingListIdentifier> posting_list_id_or = + embedding_posting_list_mapper_->Get(key); + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor; + if (posting_list_id_or.ok()) { + // Existing posting list. + ICING_ASSIGN_OR_RETURN( + pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), posting_list_hit_serializer_.get(), + posting_list_id_or.ValueOrDie())); + } else if (absl_ports::IsNotFound(posting_list_id_or.status())) { + // New posting list. 
+ ICING_ASSIGN_OR_RETURN( + pl_accessor, + PostingListEmbeddingHitAccessor::Create( + flash_index_storage_.get(), posting_list_hit_serializer_.get())); + } else { + // Errors + return std::move(posting_list_id_or).status(); + } + + // Adding the embedding hits. + for (auto iter = iter_curr_key; iter != iter_next_key; ++iter) { + ICING_RETURN_IF_ERROR(pl_accessor->PrependHit(iter->second)); + } + + // Finalize this posting list and add the posting list id in + // embedding_posting_list_mapper_. + PostingListEmbeddingHitAccessor::FinalizeResult result = + std::move(*pl_accessor).Finalize(); + if (!result.id.is_valid()) { + return absl_ports::InternalError("Failed to finalize posting list"); + } + ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->Put(key, result.id)); + + // Advance to the next key. + iter_curr_key = iter_next_key; + } + pending_embedding_hits_.clear(); + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status EmbeddingIndex::TransferIndex( + const std::vector<DocumentId>& document_id_old_to_new, + EmbeddingIndex* new_index) const { + std::unique_ptr<KeyMapper<PostingListIdentifier>::Iterator> itr = + embedding_posting_list_mapper_->GetIterator(); + while (itr->Advance()) { + std::string_view key = itr->GetKey(); + // This should never happen unless there is an inconsistency, or the index + // is corrupted. 
+ if (key.size() < kEncodedDimensionLength) { + return absl_ports::InternalError( + "Got invalid key from embedding posting list mapper."); + } + uint32_t dimension = encode_util::DecodeIntFromCString( + std::string_view(key.begin(), kEncodedDimensionLength)); + + // Transfer hits + std::vector<EmbeddingHit> new_hits; + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<PostingListEmbeddingHitAccessor> old_pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), posting_list_hit_serializer_.get(), + /*existing_posting_list_id=*/itr->GetValue())); + while (true) { + ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch, + old_pl_accessor->GetNextHitsBatch()); + if (batch.empty()) { + break; + } + for (EmbeddingHit& old_hit : batch) { + // Safety checks to add robustness to the codebase, so to make sure + // that we never access invalid memory, in case that hit from the + // posting list is corrupted. + ICING_ASSIGN_OR_RETURN(const float* old_vector, + GetEmbeddingVector(old_hit, dimension)); + if (old_hit.basic_hit().document_id() < 0 || + old_hit.basic_hit().document_id() >= + document_id_old_to_new.size()) { + return absl_ports::InternalError( + "Embedding hit document id is out of bound. The provided map is " + "too small, or the index may have been corrupted."); + } + + // Construct transferred hit + DocumentId new_document_id = + document_id_old_to_new[old_hit.basic_hit().document_id()]; + if (new_document_id == kInvalidDocumentId) { + continue; + } + uint32_t new_location = new_index->embedding_vectors_->num_elements(); + new_hits.push_back(EmbeddingHit( + BasicHit(old_hit.basic_hit().section_id(), new_document_id), + new_location)); + + // Copy the embedding vector of the hit to the new index. 
+ ICING_ASSIGN_OR_RETURN( + FileBackedVector<float>::MutableArrayView mutable_arr, + new_index->embedding_vectors_->Allocate(dimension)); + mutable_arr.SetArray(/*idx=*/0, old_vector, dimension); + } + } + // No hit needs to be added to the new index. + if (new_hits.empty()) { + return libtextclassifier3::Status::OK; + } + // Add transferred hits to the new index. + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<PostingListEmbeddingHitAccessor> hit_accum, + PostingListEmbeddingHitAccessor::Create( + new_index->flash_index_storage_.get(), + new_index->posting_list_hit_serializer_.get())); + for (auto new_hit_itr = new_hits.rbegin(); new_hit_itr != new_hits.rend(); + ++new_hit_itr) { + ICING_RETURN_IF_ERROR(hit_accum->PrependHit(*new_hit_itr)); + } + PostingListEmbeddingHitAccessor::FinalizeResult result = + std::move(*hit_accum).Finalize(); + if (!result.id.is_valid()) { + return absl_ports::InternalError("Failed to finalize posting list"); + } + ICING_RETURN_IF_ERROR( + new_index->embedding_posting_list_mapper_->Put(key, result.id)); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status EmbeddingIndex::Optimize( + const std::vector<DocumentId>& document_id_old_to_new, + DocumentId new_last_added_document_id) { + // This is just for completeness, but this should never be necessary, since we + // should never have pending hits at the time when Optimize is run. 
+ ICING_RETURN_IF_ERROR(CommitBufferToIndex()); + + std::string temporary_index_working_path = working_path_ + "_temp"; + if (!filesystem_.DeleteDirectoryRecursively( + temporary_index_working_path.c_str())) { + ICING_LOG(ERROR) << "Recursively deleting " << temporary_index_working_path; + return absl_ports::InternalError( + "Unable to delete temp directory to prepare to build new index."); + } + + DestructibleDirectory temporary_index_dir( + &filesystem_, std::move(temporary_index_working_path)); + if (!temporary_index_dir.is_valid()) { + return absl_ports::InternalError( + "Unable to create temp directory to build new index."); + } + + { + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<EmbeddingIndex> new_index, + EmbeddingIndex::Create(&filesystem_, temporary_index_dir.dir())); + ICING_RETURN_IF_ERROR( + TransferIndex(document_id_old_to_new, new_index.get())); + new_index->set_last_added_document_id(new_last_added_document_id); + ICING_RETURN_IF_ERROR(new_index->PersistToDisk()); + } + + // Destruct current storage instances to safely swap directories. + metadata_mmapped_file_.reset(); + flash_index_storage_.reset(); + embedding_posting_list_mapper_.reset(); + embedding_vectors_.reset(); + + if (!filesystem_.SwapFiles(temporary_index_dir.dir().c_str(), + working_path_.c_str())) { + return absl_ports::InternalError( + "Unable to apply new index due to failed swap!"); + } + + // Reinitialize the index. 
+ is_initialized_ = false; + return Initialize(); +} + +libtextclassifier3::Status EmbeddingIndex::PersistMetadataToDisk(bool force) { + return metadata_mmapped_file_->PersistToDisk(); +} + +libtextclassifier3::Status EmbeddingIndex::PersistStoragesToDisk(bool force) { + if (!flash_index_storage_->PersistToDisk()) { + return absl_ports::InternalError("Fail to persist flash index to disk"); + } + ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->PersistToDisk()); + ICING_RETURN_IF_ERROR(embedding_vectors_->PersistToDisk()); + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeInfoChecksum( + bool force) { + return info().ComputeChecksum(); +} + +libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeStoragesChecksum( + bool force) { + ICING_ASSIGN_OR_RETURN(Crc32 embedding_posting_list_mapper_crc, + embedding_posting_list_mapper_->ComputeChecksum()); + ICING_ASSIGN_OR_RETURN(Crc32 embedding_vectors_crc, + embedding_vectors_->ComputeChecksum()); + return Crc32(embedding_posting_list_mapper_crc.Get() ^ + embedding_vectors_crc.Get()); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/embedding-index.h b/icing/index/embed/embedding-index.h new file mode 100644 index 0000000..7318871 --- /dev/null +++ b/icing/index/embed/embedding-index.h @@ -0,0 +1,274 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_INDEX_EMBED_EMBEDDING_INDEX_H_ +#define ICING_INDEX_EMBED_EMBEDDING_INDEX_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/filesystem.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/file/persistent-storage.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" +#include "icing/index/hit/hit.h" +#include "icing/store/document-id.h" +#include "icing/store/key-mapper.h" +#include "icing/util/crc32.h" + +namespace icing { +namespace lib { + +class EmbeddingIndex : public PersistentStorage { + public: + struct Info { + static constexpr int32_t kMagic = 0xfbe13cbb; + + int32_t magic; + DocumentId last_added_document_id; + + Crc32 ComputeChecksum() const { + return Crc32( + std::string_view(reinterpret_cast<const char*>(this), sizeof(Info))); + } + } __attribute__((packed)); + static_assert(sizeof(Info) == 8, ""); + + // Metadata file layout: <Crcs><Info> + static constexpr int32_t kCrcsMetadataBufferOffset = 0; + static constexpr int32_t kInfoMetadataBufferOffset = + static_cast<int32_t>(sizeof(Crcs)); + static constexpr int32_t kMetadataFileSize = sizeof(Crcs) + sizeof(Info); + static_assert(kMetadataFileSize == 20, ""); + + static constexpr WorkingPathType kWorkingPathType = + WorkingPathType::kDirectory; + + EmbeddingIndex(const EmbeddingIndex&) = delete; + EmbeddingIndex& operator=(const EmbeddingIndex&) = delete; + + // Creates a new EmbeddingIndex instance to index embeddings. 
+ // + // Returns: + // - FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored + // checksum. + // - INTERNAL_ERROR on I/O errors. + // - Any error from MemoryMappedFile, FlashIndexStorage, + // DynamicTrieKeyMapper, or FileBackedVector. + static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>> Create( + const Filesystem* filesystem, std::string working_path); + + static libtextclassifier3::Status Discard(const Filesystem& filesystem, + const std::string& working_path) { + return PersistentStorage::Discard(filesystem, working_path, + kWorkingPathType); + } + + libtextclassifier3::Status Clear(); + + // Buffer an embedding pending to be added to the index. This is required + // since EmbeddingHits added in posting lists must be decreasing, which means + // that section ids and location indexes for a single document must be added + // decreasingly. + // + // Returns: + // - OK on success + // - INVALID_ARGUMENT error if the dimension is 0. + // - INTERNAL_ERROR on I/O error + libtextclassifier3::Status BufferEmbedding( + const BasicHit& basic_hit, const PropertyProto::VectorProto& vector); + + // Commit the embedding hits in the buffer to the index. + // + // Returns: + // - OK on success + // - INTERNAL_ERROR on I/O error + // - Any error from posting lists + libtextclassifier3::Status CommitBufferToIndex(); + + // Returns a PostingListEmbeddingHitAccessor for all embedding hits that match + // with the provided dimension and signature. + // + // Returns: + // - a PostingListEmbeddingHitAccessor instance on success. + // - INVALID_ARGUMENT error if the dimension is 0. + // - NOT_FOUND error if there is no matching embedding hit. + // - Any error from posting lists. 
+ libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> + GetAccessor(uint32_t dimension, std::string_view model_signature) const; + + // Returns a PostingListEmbeddingHitAccessor for all embedding hits that match + // with the provided vector's dimension and signature. + // + // Returns: + // - a PostingListEmbeddingHitAccessor instance on success. + // - INVALID_ARGUMENT error if the dimension is 0. + // - NOT_FOUND error if there is no matching embedding hit. + // - Any error from posting lists. + libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> + GetAccessorForVector(const PropertyProto::VectorProto& vector) const { + return GetAccessor(vector.values_size(), vector.model_signature()); + } + + // Reduces internal file sizes by reclaiming space of deleted documents. + // new_last_added_document_id will be used to update the last added document + // id in the lite index. + // + // Returns: + // - OK on success + // - INTERNAL_ERROR on IO error, this indicates that the index may be in an + // invalid state and should be cleared. 
+ libtextclassifier3::Status Optimize( + const std::vector<DocumentId>& document_id_old_to_new, + DocumentId new_last_added_document_id); + + libtextclassifier3::StatusOr<const float*> GetEmbeddingVector( + const EmbeddingHit& hit, uint32_t dimension) const { + if (static_cast<int64_t>(hit.location()) + dimension > + GetTotalVectorSize()) { + return absl_ports::InternalError( + "Got an embedding hit that refers to a vector out of range."); + } + return embedding_vectors_->array() + hit.location(); + } + + const float* GetRawEmbeddingData() const { + return embedding_vectors_->array(); + } + + int32_t GetTotalVectorSize() const { + return embedding_vectors_->num_elements(); + } + + DocumentId last_added_document_id() const { + return info().last_added_document_id; + } + + void set_last_added_document_id(DocumentId document_id) { + Info& info_ref = info(); + if (info_ref.last_added_document_id == kInvalidDocumentId || + document_id > info_ref.last_added_document_id) { + info_ref.last_added_document_id = document_id; + } + } + + private: + explicit EmbeddingIndex(const Filesystem& filesystem, + std::string working_path) + : PersistentStorage(filesystem, std::move(working_path), + kWorkingPathType) {} + + libtextclassifier3::Status Initialize(); + + // Transfers embedding data and hits from the current index to new_index. + // + // Returns: + // - OK on success + // - INTERNAL_ERROR on I/O error. This could potentially leave the storages + // in an invalid state and the caller should handle it properly (e.g. + // discard and rebuild) + libtextclassifier3::Status TransferIndex( + const std::vector<DocumentId>& document_id_old_to_new, + EmbeddingIndex* new_index) const; + + // Flushes contents of metadata file. + // + // Returns: + // - OK on success + // - INTERNAL_ERROR on I/O error + libtextclassifier3::Status PersistMetadataToDisk(bool force) override; + + // Flushes contents of all storages to underlying files. 
+ // + // Returns: + // - OK on success + // - INTERNAL_ERROR on I/O error + libtextclassifier3::Status PersistStoragesToDisk(bool force) override; + + // Computes and returns Info checksum. + // + // Returns: + // - Crc of the Info on success + libtextclassifier3::StatusOr<Crc32> ComputeInfoChecksum(bool force) override; + + // Computes and returns all storages checksum. + // + // Returns: + // - Crc of all storages on success + // - INTERNAL_ERROR if any data inconsistency + libtextclassifier3::StatusOr<Crc32> ComputeStoragesChecksum( + bool force) override; + + Crcs& crcs() override { + return *reinterpret_cast<Crcs*>(metadata_mmapped_file_->mutable_region() + + kCrcsMetadataBufferOffset); + } + + const Crcs& crcs() const override { + return *reinterpret_cast<const Crcs*>(metadata_mmapped_file_->region() + + kCrcsMetadataBufferOffset); + } + + Info& info() { + return *reinterpret_cast<Info*>(metadata_mmapped_file_->mutable_region() + + kInfoMetadataBufferOffset); + } + + const Info& info() const { + return *reinterpret_cast<const Info*>(metadata_mmapped_file_->region() + + kInfoMetadataBufferOffset); + } + + // In memory data: + // Pending embedding hits with their embedding keys used for + // embedding_posting_list_mapper_. + std::vector<std::pair<std::string, EmbeddingHit>> pending_embedding_hits_; + + // Metadata + std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_; + + // Posting list storage + std::unique_ptr<PostingListEmbeddingHitSerializer> + posting_list_hit_serializer_ = + std::make_unique<PostingListEmbeddingHitSerializer>(); + std::unique_ptr<FlashIndexStorage> flash_index_storage_; + + // The mapper from embedding keys to the corresponding posting list identifier + // that stores all embedding hits with the same key. + // + // The key for an embedding hit is a one-to-one encoded string of the ordered + // pair (dimension, model_signature) corresponding to the embedding. 
+ std::unique_ptr<KeyMapper<PostingListIdentifier>> + embedding_posting_list_mapper_; + + // A single FileBackedVector that holds all embedding vectors. + std::unique_ptr<FileBackedVector<float>> embedding_vectors_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_EMBEDDING_INDEX_H_ diff --git a/icing/index/embed/embedding-index_test.cc b/icing/index/embed/embedding-index_test.cc new file mode 100644 index 0000000..5980e82 --- /dev/null +++ b/icing/index/embed/embedding-index_test.cc @@ -0,0 +1,582 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/embedding-index.h" + +#include <unistd.h> + +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/filesystem.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" +#include "icing/index/hit/hit.h" +#include "icing/proto/document.pb.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/embedding-test-utils.h" +#include "icing/testing/tmp-directory.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Test; + +class EmbeddingIndexTest : public Test { + protected: + void SetUp() override { + embedding_index_dir_ = GetTestTempDir() + "/embedding_index_test"; + ICING_ASSERT_OK_AND_ASSIGN( + embedding_index_, + EmbeddingIndex::Create(&filesystem_, embedding_index_dir_)); + } + + void TearDown() override { + embedding_index_.reset(); + filesystem_.DeleteDirectoryRecursively(embedding_index_dir_.c_str()); + } + + libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits( + uint32_t dimension, std::string_view model_signature) { + std::vector<EmbeddingHit> hits; + + libtextclassifier3::StatusOr< + std::unique_ptr<PostingListEmbeddingHitAccessor>> + pl_accessor_or = + embedding_index_->GetAccessor(dimension, model_signature); + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor; + if (pl_accessor_or.ok()) { + pl_accessor = std::move(pl_accessor_or).ValueOrDie(); + } else if (absl_ports::IsNotFound(pl_accessor_or.status())) { + return hits; + } else { + return 
std::move(pl_accessor_or).status(); + } + + while (true) { + ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch, + pl_accessor->GetNextHitsBatch()); + if (batch.empty()) { + return hits; + } + hits.insert(hits.end(), batch.begin(), batch.end()); + } + } + + std::vector<float> GetRawEmbeddingData() { + return std::vector<float>(embedding_index_->GetRawEmbeddingData(), + embedding_index_->GetRawEmbeddingData() + + embedding_index_->GetTotalVectorSize()); + } + + Filesystem filesystem_; + std::string embedding_index_dir_; + std::unique_ptr<EmbeddingIndex> embedding_index_; +}; + +TEST_F(EmbeddingIndexTest, AddSingleEmbedding) { + PropertyProto::VectorProto vector = CreateVector("model", {0.1, 0.2, 0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(0); + + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/0), /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, AddMultipleEmbeddingsInTheSameSection) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(0); + + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + 
/*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/3)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, HitsWithLowerSectionIdReturnedFirst) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/5, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/2, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(0); + + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/2, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(/*section_id=*/5, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, HitsWithHigherDocumentIdReturnedFirst) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/1), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(1); + + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/3), + 
EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); +} + +TEST_F(EmbeddingIndexTest, AddEmbeddingsFromDifferentModels) { + PropertyProto::VectorProto vector1 = CreateVector("model1", {0.1, 0.2}); + PropertyProto::VectorProto vector2 = + CreateVector("model2", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(0); + + EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/2)))); + EXPECT_THAT( + GetHits(/*dimension=*/5, /*model_signature=*/"non-existent-model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, + AddEmbeddingsWithSameSignatureButDifferentDimension) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(0); + 
+ EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/2)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, ClearIndex) { + // Loop the same logic twice to make sure that clear works as expected, and + // the index is still valid after clearing. + for (int i = 0; i < 2; i++) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/1, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/2, /*document_id=*/1), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(1); + + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/2, /*document_id=*/1), + /*location=*/3), + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + + // Check that clear works as expected. 
+ ICING_ASSERT_OK(embedding_index_->Clear()); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); + EXPECT_EQ(embedding_index_->last_added_document_id(), kInvalidDocumentId); + } +} + +TEST_F(EmbeddingIndexTest, EmptyCommitIsOk) { + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); +} + +TEST_F(EmbeddingIndexTest, MultipleCommits) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/1, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); +} + +TEST_F(EmbeddingIndexTest, + InvalidCommit_SectionIdCanOnlyDecreaseForSingleDocument) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/1, /*document_id=*/0), vector2)); + // Posting list with delta encoding can only allow decreasing values. 
+ EXPECT_THAT(embedding_index_->CommitBufferToIndex(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(EmbeddingIndexTest, InvalidCommit_DocumentIdCanOnlyIncrease) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/1), vector1)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector2)); + // Posting list with delta encoding can only allow decreasing values, which + // means document ids must be committed increasingly, since document ids are + // inverted in hit values. + EXPECT_THAT(embedding_index_->CommitBufferToIndex(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(EmbeddingIndexTest, EmptyOptimizeIsOk) { + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{}, + /*new_last_added_document_id=*/kInvalidDocumentId)); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); +} + +TEST_F(EmbeddingIndexTest, OptimizeSingleEmbeddingSingleDocument) { + PropertyProto::VectorProto vector = CreateVector("model", {0.1, 0.2, 0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/2), vector)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(2); + + // Before optimize + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/2), /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 2); + + // Run optimize without deleting any documents, and check that the index is + // not changed. 
+ ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/2), /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 2); + + // Run optimize to map document id 2 to 1, and check that the index is + // updated correctly. + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId, 1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/1), /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2, 0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + + // Run optimize to delete the document. 
+ ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId}, + /*new_last_added_document_id=*/0)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, OptimizeMultipleEmbeddingsSingleDocument) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/2), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/2), vector2)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(2); + + // Before optimize + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2), + /*location=*/3)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 2); + + // Run optimize without deleting any documents, and check that the index is + // not changed. 
+ ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/2), + /*location=*/3)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 2); + + // Run optimize to map document id 2 to 1, and check that the index is + // updated correctly. + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId, 1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/3)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + + // Run optimize to delete the document. 
+ ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId}, + /*new_last_added_document_id=*/0)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, OptimizeMultipleEmbeddingsMultipleDocument) { + PropertyProto::VectorProto vector1 = CreateVector("model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector2 = CreateVector("model", {1, 2, 3}); + PropertyProto::VectorProto vector3 = + CreateVector("model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/1, /*document_id=*/0), vector2)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/1), vector3)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(1); + + // Before optimize + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/6), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/3)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 0.3, 1, 2, 3, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + + // Run optimize without deleting any documents. It is expected to see that the + // raw embedding data is rearranged, since during index transfer, embedding + // vectors from higher document ids are added first. + // + // Also keep in mind that once the raw data is rearranged, calling another + // Optimize subsequently will not change the raw data again. 
+ for (int i = 0; i < 2; i++) { + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, 1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/6)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(-0.1, -0.2, -0.3, 0.1, 0.2, 0.3, 1, 2, 3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + } + + // Run optimize to delete document 0, and check that the index is + // updated correctly. + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{kInvalidDocumentId, 0}, + /*new_last_added_document_id=*/0)); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(-0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +TEST_F(EmbeddingIndexTest, OptimizeEmbeddingsFromDifferentModels) { + PropertyProto::VectorProto vector1 = CreateVector("model1", {0.1, 0.2}); + PropertyProto::VectorProto vector2 = CreateVector("model1", {1, 2}); + PropertyProto::VectorProto vector3 = + CreateVector("model2", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/0), vector1)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/0, /*document_id=*/1), vector2)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(/*section_id=*/1, /*document_id=*/1), vector3)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + embedding_index_->set_last_added_document_id(1); + + // Before optimize + 
EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/2), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/1), + /*location=*/4)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.1, 0.2, 1, 2, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + + // Run optimize without deleting any documents. It is expected to see that the + // raw embedding data is rearranged, since during index transfer: + // - Embedding vectors with lower keys, which are the string encoded ordered + // pairs (dimension, model_signature), are iterated first. + // - Embedding vectors from higher document ids are added first. + // + // Also keep in mind that once the raw data is rearranged, calling another + // Optimize subsequently will not change the raw data again. + for (int i = 0; i < 2; i++) { + ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, 1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/0), + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/2)))); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/1, /*document_id=*/1), + /*location=*/4)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(1, 2, 0.1, 0.2, -0.1, -0.2, -0.3)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 1); + } + + // Run optimize to delete document 1, and check that the index is + // updated correctly. 
+ ICING_ASSERT_OK(embedding_index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId}, + /*new_last_added_document_id=*/0)); + EXPECT_THAT(GetHits(/*dimension=*/2, /*model_signature=*/"model1"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)))); + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model2"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.1, 0.2)); + EXPECT_EQ(embedding_index_->last_added_document_id(), 0); +} + +} // namespace +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/embedding-query-results.h b/icing/index/embed/embedding-query-results.h new file mode 100644 index 0000000..de85489 --- /dev/null +++ b/icing/index/embed/embedding-query-results.h @@ -0,0 +1,72 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_ +#define ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_ + +#include <unordered_map> +#include <vector> + +#include "icing/proto/search.pb.h" +#include "icing/store/document-id.h" + +namespace icing { +namespace lib { + +// A class to store results generated from embedding queries. 
+struct EmbeddingQueryResults { + // Maps from DocumentId to the list of matched embedding scores for that + // document, which will be used in the advanced scoring language to + // determine the results for the "this.matchedSemanticScores(...)" function. + using EmbeddingQueryScoreMap = + std::unordered_map<DocumentId, std::vector<double>>; + + // Maps from (query_vector_index, metric_type) to EmbeddingQueryScoreMap. + std::unordered_map< + int, std::unordered_map<SearchSpecProto::EmbeddingQueryMetricType::Code, + EmbeddingQueryScoreMap>> + result_scores; + + // Returns the matched scores for the given query_vector_index, metric_type, + // and doc_id. Returns nullptr if (query_vector_index, metric_type) does not + // exist in the result_scores map. + const std::vector<double>* GetMatchedScoresForDocument( + int query_vector_index, + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, + DocumentId doc_id) const { + // Check if a mapping exists for the query_vector_index + auto outer_it = result_scores.find(query_vector_index); + if (outer_it == result_scores.end()) { + return nullptr; + } + // Check if a mapping exists for the metric_type + auto inner_it = outer_it->second.find(metric_type); + if (inner_it == outer_it->second.end()) { + return nullptr; + } + const EmbeddingQueryScoreMap& score_map = inner_it->second; + + // Check if the doc_id exists in the score_map + auto scores_it = score_map.find(doc_id); + if (scores_it == score_map.end()) { + return nullptr; + } + return &scores_it->second; + } +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_ diff --git a/icing/index/embed/embedding-scorer.cc b/icing/index/embed/embedding-scorer.cc new file mode 100644 index 0000000..0d84e01 --- /dev/null +++ b/icing/index/embed/embedding-scorer.cc @@ -0,0 +1,95 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/embed/embedding-scorer.h" + +#include <cmath> +#include <memory> +#include <string> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/proto/search.pb.h" + +namespace icing { +namespace lib { + +namespace { + +float CalculateDotProduct(int dimension, const float* v1, const float* v2) { + float dot_product = 0.0; + for (int i = 0; i < dimension; ++i) { + dot_product += v1[i] * v2[i]; + } + return dot_product; +} + +float CalculateNorm2(int dimension, const float* v) { + return std::sqrt(CalculateDotProduct(dimension, v, v)); +} + +float CalculateCosine(int dimension, const float* v1, const float* v2) { + float divisor = CalculateNorm2(dimension, v1) * CalculateNorm2(dimension, v2); + if (divisor == 0.0) { + return 0.0; + } + return CalculateDotProduct(dimension, v1, v2) / divisor; +} + +float CalculateEuclideanDistance(int dimension, const float* v1, + const float* v2) { + float result = 0.0; + for (int i = 0; i < dimension; ++i) { + float diff = v1[i] - v2[i]; + result += diff * diff; + } + return std::sqrt(result); +} + +} // namespace + +libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingScorer>> +EmbeddingScorer::Create( + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type) { + switch (metric_type) { + case SearchSpecProto::EmbeddingQueryMetricType::COSINE: + return std::make_unique<CosineEmbeddingScorer>(); + case 
SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT: + return std::make_unique<DotProductEmbeddingScorer>(); + case SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN: + return std::make_unique<EuclideanDistanceEmbeddingScorer>(); + default: + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + "Invalid EmbeddingQueryMetricType: ", std::to_string(metric_type))); + } +} + +float CosineEmbeddingScorer::Score(int dimension, const float* v1, + const float* v2) const { + return CalculateCosine(dimension, v1, v2); +} + +float DotProductEmbeddingScorer::Score(int dimension, const float* v1, + const float* v2) const { + return CalculateDotProduct(dimension, v1, v2); +} + +float EuclideanDistanceEmbeddingScorer::Score(int dimension, const float* v1, + const float* v2) const { + return CalculateEuclideanDistance(dimension, v1, v2); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/embedding-scorer.h b/icing/index/embed/embedding-scorer.h new file mode 100644 index 0000000..8caf0bc --- /dev/null +++ b/icing/index/embed/embedding-scorer.h @@ -0,0 +1,54 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_INDEX_EMBED_EMBEDDING_SCORER_H_ +#define ICING_INDEX_EMBED_EMBEDDING_SCORER_H_ + +#include <memory> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/proto/search.pb.h" + +namespace icing { +namespace lib { + +class EmbeddingScorer { + public: + static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingScorer>> Create( + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type); + virtual float Score(int dimension, const float* v1, + const float* v2) const = 0; + + virtual ~EmbeddingScorer() = default; +}; + +class CosineEmbeddingScorer : public EmbeddingScorer { + public: + float Score(int dimension, const float* v1, const float* v2) const override; +}; + +class DotProductEmbeddingScorer : public EmbeddingScorer { + public: + float Score(int dimension, const float* v1, const float* v2) const override; +}; + +class EuclideanDistanceEmbeddingScorer : public EmbeddingScorer { + public: + float Score(int dimension, const float* v1, const float* v2) const override; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_EMBEDDING_SCORER_H_ diff --git a/icing/index/embed/posting-list-embedding-hit-accessor.cc b/icing/index/embed/posting-list-embedding-hit-accessor.cc new file mode 100644 index 0000000..e154165 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-accessor.cc @@ -0,0 +1,132 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" + +#include <cstdint> +#include <memory> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-common.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" +#include "icing/legacy/index/icing-bit-util.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> +PostingListEmbeddingHitAccessor::Create( + FlashIndexStorage *storage, PostingListEmbeddingHitSerializer *serializer) { + uint32_t max_posting_list_bytes = storage->max_posting_list_bytes(); + ICING_ASSIGN_OR_RETURN(PostingListUsed in_memory_posting_list, + PostingListUsed::CreateFromUnitializedRegion( + serializer, max_posting_list_bytes)); + return std::unique_ptr<PostingListEmbeddingHitAccessor>( + new PostingListEmbeddingHitAccessor(storage, serializer, + std::move(in_memory_posting_list))); +} + +libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>> +PostingListEmbeddingHitAccessor::CreateFromExisting( + FlashIndexStorage *storage, PostingListEmbeddingHitSerializer *serializer, + PostingListIdentifier existing_posting_list_id) { + // Our in_memory_posting_list_ will start as empty. 
+ ICING_ASSIGN_OR_RETURN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + Create(storage, serializer)); + ICING_ASSIGN_OR_RETURN(PostingListHolder holder, + storage->GetPostingList(existing_posting_list_id)); + pl_accessor->preexisting_posting_list_ = + std::make_unique<PostingListHolder>(std::move(holder)); + return pl_accessor; +} + +// Returns the next batch of hits for the provided posting list. +libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> +PostingListEmbeddingHitAccessor::GetNextHitsBatch() { + if (preexisting_posting_list_ == nullptr) { + if (has_reached_posting_list_chain_end_) { + return std::vector<EmbeddingHit>(); + } + return absl_ports::FailedPreconditionError( + "Cannot retrieve hits from a PostingListEmbeddingHitAccessor that was " + "not created from a preexisting posting list."); + } + ICING_ASSIGN_OR_RETURN( + std::vector<EmbeddingHit> batch, + serializer_->GetHits(&preexisting_posting_list_->posting_list)); + uint32_t next_block_index = kInvalidBlockIndex; + // Posting lists will only be chained when they are max-sized, in which case + // next_block_index will point to the next block for the next posting list. + // Otherwise, next_block_index can be kInvalidBlockIndex or be used to point + // to the next free list block, which is not relevant here. + if (preexisting_posting_list_->posting_list.size_in_bytes() == + storage_->max_posting_list_bytes()) { + next_block_index = preexisting_posting_list_->next_block_index; + } + + if (next_block_index != kInvalidBlockIndex) { + // Since we only have to deal with next block for max-sized posting list + // block, max_num_posting_lists is 1 and posting_list_index_bits is + // BitsToStore(1). 
+ PostingListIdentifier next_posting_list_id( + next_block_index, /*posting_list_index=*/0, + /*posting_list_index_bits=*/BitsToStore(1)); + ICING_ASSIGN_OR_RETURN(PostingListHolder holder, + storage_->GetPostingList(next_posting_list_id)); + preexisting_posting_list_ = + std::make_unique<PostingListHolder>(std::move(holder)); + } else { + has_reached_posting_list_chain_end_ = true; + preexisting_posting_list_.reset(); + } + return batch; +} + +libtextclassifier3::Status PostingListEmbeddingHitAccessor::PrependHit( + const EmbeddingHit &hit) { + PostingListUsed &active_pl = (preexisting_posting_list_ != nullptr) + ? preexisting_posting_list_->posting_list + : in_memory_posting_list_; + libtextclassifier3::Status status = serializer_->PrependHit(&active_pl, hit); + if (!absl_ports::IsResourceExhausted(status)) { + return status; + } + // There is no more room to add hits to this current posting list! Therefore, + // we need to either move those hits to a larger posting list or flush this + // posting list and create another max-sized posting list in the chain. + if (preexisting_posting_list_ != nullptr) { + ICING_RETURN_IF_ERROR(FlushPreexistingPostingList()); + } else { + ICING_RETURN_IF_ERROR(FlushInMemoryPostingList()); + } + + // Re-add hit. Should always fit since we just cleared + // in_memory_posting_list_. It's fine to explicitly reference + // in_memory_posting_list_ here because there's no way of reaching this line + // while preexisting_posting_list_ is still in use. 
+ return serializer_->PrependHit(&in_memory_posting_list_, hit); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/posting-list-embedding-hit-accessor.h b/icing/index/embed/posting-list-embedding-hit-accessor.h new file mode 100644 index 0000000..4acb9a3 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-accessor.h @@ -0,0 +1,106 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_ +#define ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_ + +#include <memory> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-accessor.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" + +namespace icing { +namespace lib { + +// This class is used to provide a simple abstraction for adding hits to posting +// lists. PostingListEmbeddingHitAccessor handles 1) selection of properly-sized +// posting lists for the accumulated hits during Finalize() and 2) chaining of +// max-sized posting lists. 
+class PostingListEmbeddingHitAccessor : public PostingListAccessor { + public: + // Creates an empty PostingListEmbeddingHitAccessor. + // + // RETURNS: + // - On success, a valid unique_ptr instance of + // PostingListEmbeddingHitAccessor + // - INVALID_ARGUMENT error if storage has an invalid block_size. + static libtextclassifier3::StatusOr< + std::unique_ptr<PostingListEmbeddingHitAccessor>> + Create(FlashIndexStorage* storage, + PostingListEmbeddingHitSerializer* serializer); + + // Create a PostingListEmbeddingHitAccessor with an existing posting list + // identified by existing_posting_list_id. + // + // The PostingListEmbeddingHitAccessor will add hits to this posting list + // until it is necessary either to 1) chain the posting list (if it is + // max-sized) or 2) move its hits to a larger posting list. + // + // RETURNS: + // - On success, a valid unique_ptr instance of + // PostingListEmbeddingHitAccessor + // - INVALID_ARGUMENT if storage has an invalid block_size. + static libtextclassifier3::StatusOr< + std::unique_ptr<PostingListEmbeddingHitAccessor>> + CreateFromExisting(FlashIndexStorage* storage, + PostingListEmbeddingHitSerializer* serializer, + PostingListIdentifier existing_posting_list_id); + + PostingListSerializer* GetSerializer() override { return serializer_; } + + // Retrieve the next batch of hits for the posting list chain + // + // RETURNS: + // - On success, a vector of hits in the posting list chain + // - INTERNAL if called on an instance of PostingListEmbeddingHitAccessor + // that was created via PostingListEmbeddingHitAccessor::Create, if unable + // to read the next posting list in the chain or if the posting list has + // been corrupted somehow. + libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetNextHitsBatch(); + + // Prepend one hit. 
This may result in flushing the posting list to disk (if + // the PostingListEmbeddingHitAccessor holds a max-sized posting list that is + // full) or freeing a pre-existing posting list if it is too small to fit all + // hits necessary. + // + // RETURNS: + // - OK, on success + // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the + // previously added hit. + // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a new + // posting list. + libtextclassifier3::Status PrependHit(const EmbeddingHit& hit); + + private: + explicit PostingListEmbeddingHitAccessor( + FlashIndexStorage* storage, PostingListEmbeddingHitSerializer* serializer, + PostingListUsed in_memory_posting_list) + : PostingListAccessor(storage, std::move(in_memory_posting_list)), + serializer_(serializer) {} + + PostingListEmbeddingHitSerializer* serializer_; // Does not own. +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_ACCESSOR_H_ diff --git a/icing/index/embed/posting-list-embedding-hit-accessor_test.cc b/icing/index/embed/posting-list-embedding-hit-accessor_test.cc new file mode 100644 index 0000000..b9ebe87 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-accessor_test.cc @@ -0,0 +1,387 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/filesystem.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-accessor.h" +#include "icing/file/posting_list/posting-list-common.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" +#include "icing/index/hit/hit.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/hit-test-utils.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::Eq; +using ::testing::Lt; +using ::testing::SizeIs; + +class PostingListEmbeddingHitAccessorTest : public ::testing::Test { + protected: + void SetUp() override { + test_dir_ = GetTestTempDir() + "/test_dir"; + file_name_ = test_dir_ + "/test_file.idx.index"; + + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str())); + ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(test_dir_.c_str())); + + serializer_ = std::make_unique<PostingListEmbeddingHitSerializer>(); + + ICING_ASSERT_OK_AND_ASSIGN( + FlashIndexStorage flash_index_storage, + FlashIndexStorage::Create(file_name_, &filesystem_, serializer_.get())); + flash_index_storage_ = + std::make_unique<FlashIndexStorage>(std::move(flash_index_storage)); + } + + void TearDown() override { + flash_index_storage_.reset(); + serializer_.reset(); + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str())); + } + + Filesystem filesystem_; + std::string test_dir_; + std::string file_name_; + 
std::unique_ptr<PostingListEmbeddingHitSerializer> serializer_; + std::unique_ptr<FlashIndexStorage> flash_index_storage_; +}; + +TEST_F(PostingListEmbeddingHitAccessorTest, HitsAddAndRetrieveProperly) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Add some hits! Any hits! + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit& hit : hits1) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); + } + PostingListAccessor::FinalizeResult result = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result.status); + EXPECT_THAT(result.id.block_index(), Eq(1)); + EXPECT_THAT(result.id.posting_list_index(), Eq(0)); + + // Retrieve some hits. + ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(result.id)); + EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); + EXPECT_THAT(pl_holder.next_block_index, Eq(kInvalidBlockIndex)); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPLKeepOnSameBlock) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Add a single hit. This will fit in a min-sized posting list. + EmbeddingHit hit1(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/1); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result1.status); + // Should have been allocated to the first block. + EXPECT_THAT(result1.id.block_index(), Eq(1)); + EXPECT_THAT(result1.id.posting_list_index(), Eq(0)); + + // Add one more hit. 
The minimum size for a posting list must be able to fit + // at least two hits, so this should NOT cause the previous pl to be + // reallocated. + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); + EmbeddingHit hit2(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/0); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit2)); + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result2.status); + // Should have been allocated to the same posting list as the first hit. + EXPECT_THAT(result2.id, Eq(result1.id)); + + // The posting list at result2.id should hold all of the hits that have been + // added. + ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(result2.id)); + EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list), + IsOkAndHolds(ElementsAre(hit2, hit1))); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPLReallocateToLargerPL) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + std::vector<EmbeddingHit> hits = + CreateEmbeddingHits(/*num_hits=*/11, /*desired_byte_length=*/1); + + // Add 8 hits with a small posting list of 24 bytes. The first 7 hits will + // be compressed to one byte each and will be able to fit in the 8 byte + // padded region. The last hit will fit in one of the special hits. The + // posting list will be ALMOST_FULL and can fit at most 2 more hits. + for (int i = 0; i < 8; ++i) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i])); + } + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result1.status); + // Should have been allocated to the first block. 
+ EXPECT_THAT(result1.id.block_index(), Eq(1)); + EXPECT_THAT(result1.id.posting_list_index(), Eq(0)); + + // Now let's add some more hits! + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); + // The current posting list can fit at most 2 more hits. + for (int i = 8; i < 10; ++i) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i])); + } + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result2.status); + // The 2 hits should still fit on the first block + EXPECT_THAT(result1.id.block_index(), Eq(1)); + EXPECT_THAT(result1.id.posting_list_index(), Eq(0)); + + // Add one more hit + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result2.id)); + // The current posting list should be FULL. Adding more hits should result in + // these hits being moved to a larger posting list. + ICING_ASSERT_OK(pl_accessor->PrependHit(hits[10])); + PostingListAccessor::FinalizeResult result3 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result3.status); + // Should have been allocated to the second (new) block because the posting + // list should have grown beyond the size that the first block maintains. + EXPECT_THAT(result3.id.block_index(), Eq(2)); + EXPECT_THAT(result3.id.posting_list_index(), Eq(0)); + + // The posting list at result3.id should hold all of the hits that have been + // added. 
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(result3.id)); + EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend()))); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, MultiBlockChainsBlocksProperly) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Add some hits! Any hits! + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5000, /*desired_byte_length=*/1); + for (const EmbeddingHit& hit : hits1) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); + } + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result1.status); + PostingListIdentifier second_block_id = result1.id; + // Should have been allocated to the second block, which holds a max-sized + // posting list. + EXPECT_THAT(second_block_id, Eq(PostingListIdentifier( + /*block_index=*/2, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0))); + + // Now let's retrieve them! + ICING_ASSERT_OK_AND_ASSIGN( + PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(second_block_id)); + // This pl_holder will only hold a posting list with the hits that didn't fit + // on the first block. + ICING_ASSERT_OK_AND_ASSIGN(std::vector<EmbeddingHit> second_block_hits, + serializer_->GetHits(&pl_holder.posting_list)); + ASSERT_THAT(second_block_hits, SizeIs(Lt(hits1.size()))); + auto first_block_hits_start = hits1.rbegin() + second_block_hits.size(); + EXPECT_THAT(second_block_hits, + ElementsAreArray(hits1.rbegin(), first_block_hits_start)); + + // Now retrieve all of the hits that were on the first block. 
+ uint32_t first_block_id = pl_holder.next_block_index; + EXPECT_THAT(first_block_id, Eq(1)); + + PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0); + ICING_ASSERT_OK_AND_ASSIGN(pl_holder, + flash_index_storage_->GetPostingList(pl_id)); + EXPECT_THAT( + serializer_->GetHits(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits1.rend()))); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, + PreexistingMultiBlockReusesBlocksProperly) { + std::vector<EmbeddingHit> hits = + CreateEmbeddingHits(/*num_hits=*/5050, /*desired_byte_length=*/1); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Add some hits! Any hits! + for (int i = 0; i < 5000; ++i) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i])); + } + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result1.status); + PostingListIdentifier first_add_id = result1.id; + EXPECT_THAT(first_add_id, Eq(PostingListIdentifier( + /*block_index=*/2, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0))); + + // Now add a couple more hits. These should fit on the existing, not full + // second block. + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), first_add_id)); + for (int i = 5000; i < hits.size(); ++i) { + ICING_ASSERT_OK(pl_accessor->PrependHit(hits[i])); + } + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result2.status); + PostingListIdentifier second_add_id = result2.id; + EXPECT_THAT(second_add_id, Eq(first_add_id)); + + // We should be able to retrieve all 5050 hits. 
+ ICING_ASSERT_OK_AND_ASSIGN( + PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(second_add_id)); + // This pl_holder will only hold a posting list with the hits that didn't fit + // on the first block. + ICING_ASSERT_OK_AND_ASSIGN(std::vector<EmbeddingHit> second_block_hits, + serializer_->GetHits(&pl_holder.posting_list)); + ASSERT_THAT(second_block_hits, SizeIs(Lt(hits.size()))); + auto first_block_hits_start = hits.rbegin() + second_block_hits.size(); + EXPECT_THAT(second_block_hits, + ElementsAreArray(hits.rbegin(), first_block_hits_start)); + + // Now retrieve all of the hits that were on the first block. + uint32_t first_block_id = pl_holder.next_block_index; + EXPECT_THAT(first_block_id, Eq(1)); + + PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0); + ICING_ASSERT_OK_AND_ASSIGN(pl_holder, + flash_index_storage_->GetPostingList(pl_id)); + EXPECT_THAT( + serializer_->GetHits(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits.rend()))); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, InvalidHitReturnsInvalidArgument) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue); + EXPECT_THAT(pl_accessor->PrependHit(invalid_hit), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, + HitsNotDecreasingReturnsInvalidArgument) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + EmbeddingHit hit1(BasicHit(/*section_id=*/3, /*document_id=*/1), + /*location=*/5); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); + + EmbeddingHit hit2(BasicHit(/*section_id=*/6, /*document_id=*/1), + 
/*location=*/5); + EXPECT_THAT(pl_accessor->PrependHit(hit2), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EmbeddingHit hit3(BasicHit(/*section_id=*/2, /*document_id=*/0), + /*location=*/5); + EXPECT_THAT(pl_accessor->PrependHit(hit3), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EmbeddingHit hit4(BasicHit(/*section_id=*/3, /*document_id=*/1), + /*location=*/6); + EXPECT_THAT(pl_accessor->PrependHit(hit4), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, NewPostingListNoHitsAdded) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + EXPECT_THAT(result1.status, + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(PostingListEmbeddingHitAccessorTest, PreexistingPostingListNoHitsAdded) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor, + PostingListEmbeddingHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + EmbeddingHit hit1(BasicHit(/*section_id=*/3, /*document_id=*/1), + /*location=*/5); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_ASSERT_OK(result1.status); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor2, + PostingListEmbeddingHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor2).Finalize(); + ICING_ASSERT_OK(result2.status); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/posting-list-embedding-hit-serializer.cc 
b/icing/index/embed/posting-list-embedding-hit-serializer.cc new file mode 100644 index 0000000..9247bcb --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-serializer.cc @@ -0,0 +1,647 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" + +#include <cinttypes> +#include <cstdint> +#include <cstring> +#include <limits> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/legacy/core/icing-string-util.h" +#include "icing/legacy/index/icing-bit-util.h" +#include "icing/util/logging.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +uint32_t PostingListEmbeddingHitSerializer::GetBytesUsed( + const PostingListUsed* posting_list_used) const { + // The special hits will be included if they represent actual hits. If they + // represent the hit offset or the invalid hit sentinel, they are not + // included. 
+ return posting_list_used->size_in_bytes() - + GetStartByteOffset(posting_list_used); +} + +uint32_t PostingListEmbeddingHitSerializer::GetMinPostingListSizeToFit( + const PostingListUsed* posting_list_used) const { + if (IsFull(posting_list_used) || IsAlmostFull(posting_list_used)) { + // If in either the FULL state or ALMOST_FULL state, this posting list *is* + // the minimum size posting list that can fit these hits. So just return the + // size of the posting list. + return posting_list_used->size_in_bytes(); + } + + // - In NOT_FULL status, BytesUsed contains no special hits. For a posting + // list in the NOT_FULL state with n hits, we would have n-1 compressed hits + // and 1 uncompressed hit. + // - The minimum sized posting list that would be guaranteed to fit these hits + // would be FULL, but calculating the size required for the FULL posting + // list would require deserializing the last two added hits, so instead we + // will calculate the size of an ALMOST_FULL posting list to fit. + // - An ALMOST_FULL posting list would have kInvalidHit in special_hit(0), the + // full uncompressed Hit in special_hit(1), and the n-1 compressed hits in + // the compressed region. + // - Currently BytesUsed contains one uncompressed Hit and n-1 compressed + // hits. + // - Therefore, fitting these hits into a posting list would require + // BytesUsed + one extra full hit. + return GetBytesUsed(posting_list_used) + sizeof(EmbeddingHit); +} + +void PostingListEmbeddingHitSerializer::Clear( + PostingListUsed* posting_list_used) const { + // Safe to ignore return value because posting_list_used->size_in_bytes() is + // a valid argument. 
+ SetStartByteOffset(posting_list_used, + /*offset=*/posting_list_used->size_in_bytes()); +} + +libtextclassifier3::Status PostingListEmbeddingHitSerializer::MoveFrom( + PostingListUsed* dst, PostingListUsed* src) const { + ICING_RETURN_ERROR_IF_NULL(dst); + ICING_RETURN_ERROR_IF_NULL(src); + if (GetMinPostingListSizeToFit(src) > dst->size_in_bytes()) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "src MinPostingListSizeToFit %d must be larger than size %d.", + GetMinPostingListSizeToFit(src), dst->size_in_bytes())); + } + + if (!IsPostingListValid(dst)) { + return absl_ports::FailedPreconditionError( + "Dst posting list is in an invalid state and can't be used!"); + } + if (!IsPostingListValid(src)) { + return absl_ports::InvalidArgumentError( + "Cannot MoveFrom an invalid src posting list!"); + } + + // Pop just enough hits that all of src's compressed hits fit in + // dst posting_list's compressed area. Then we can memcpy that area. + std::vector<EmbeddingHit> hits; + while (IsFull(src) || IsAlmostFull(src) || + (dst->size_in_bytes() - kSpecialHitsSize < GetBytesUsed(src))) { + if (!GetHitsInternal(src, /*limit=*/1, /*pop=*/true, &hits).ok()) { + return absl_ports::AbortedError( + "Unable to retrieve hits from src posting list."); + } + } + + // memcpy the area and set up start byte offset. + Clear(dst); + memcpy(dst->posting_list_buffer() + dst->size_in_bytes() - GetBytesUsed(src), + src->posting_list_buffer() + GetStartByteOffset(src), + GetBytesUsed(src)); + // Because we popped all hits from src outside of the compressed area and we + // guaranteed that GetBytesUsed(src) is less than dst->size_in_bytes() - + // kSpecialHitSize. This is guaranteed to be a valid byte offset for the + // NOT_FULL state, so ignoring the value is safe. + SetStartByteOffset(dst, dst->size_in_bytes() - GetBytesUsed(src)); + + // Put back remaining hits. 
+ for (size_t i = 0; i < hits.size(); i++) { + const EmbeddingHit& hit = hits[hits.size() - i - 1]; + // PrependHit can return either INVALID_ARGUMENT - if hit is invalid or not + // less than the previous hit - or RESOURCE_EXHAUSTED. RESOURCE_EXHAUSTED + // should be impossible because we've already assured that there is enough + // room above. + ICING_RETURN_IF_ERROR(PrependHit(dst, hit)); + } + + Clear(src); + return libtextclassifier3::Status::OK; +} + +uint32_t PostingListEmbeddingHitSerializer::GetPadEnd( + const PostingListUsed* posting_list_used, uint32_t offset) const { + EmbeddingHit::Value pad; + uint32_t pad_end = offset; + while (pad_end < posting_list_used->size_in_bytes()) { + size_t pad_len = VarInt::Decode( + posting_list_used->posting_list_buffer() + pad_end, &pad); + if (pad != 0) { + // No longer a pad. + break; + } + pad_end += pad_len; + } + return pad_end; +} + +bool PostingListEmbeddingHitSerializer::PadToEnd( + PostingListUsed* posting_list_used, uint32_t start, uint32_t end) const { + if (end > posting_list_used->size_in_bytes()) { + ICING_LOG(ERROR) << "Cannot pad a region that ends after size!"; + return false; + } + // In VarInt a value of 0 encodes to 0. + memset(posting_list_used->posting_list_buffer() + start, 0, end - start); + return true; +} + +libtextclassifier3::Status +PostingListEmbeddingHitSerializer::PrependHitToAlmostFull( + PostingListUsed* posting_list_used, const EmbeddingHit& hit) const { + // Get delta between first hit and the new hit. Try to fit delta + // in the padded area and put new hit at the special position 1. + // Calling ValueOrDie is safe here because 1 < kNumSpecialData. 
+ EmbeddingHit cur = GetSpecialHit(posting_list_used, /*index=*/1); + if (cur.value() <= hit.value()) { + return absl_ports::InvalidArgumentError( + "Hit being prepended must be strictly less than the most recent Hit"); + } + uint64_t delta = cur.value() - hit.value(); + uint8_t delta_buf[VarInt::kMaxEncodedLen64]; + size_t delta_len = VarInt::Encode(delta, delta_buf); + + uint32_t pad_end = GetPadEnd(posting_list_used, + /*offset=*/kSpecialHitsSize); + + if (pad_end >= kSpecialHitsSize + delta_len) { + // Pad area has enough space for delta of existing hit (cur). Write delta at + // pad_end - delta_len. + uint8_t* delta_offset = + posting_list_used->posting_list_buffer() + pad_end - delta_len; + memcpy(delta_offset, delta_buf, delta_len); + + // Now first hit is the new hit, at special position 1. Safe to ignore the + // return value because 1 < kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/1, hit); + // Safe to ignore the return value because sizeof(EmbeddingHit) is a valid + // argument. + SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit)); + } else { + // No space for delta. We put the new hit at special position 0 + // and go to the full state. Safe to ignore the return value because 0 < + // kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/0, hit); + } + return libtextclassifier3::Status::OK; +} + +void PostingListEmbeddingHitSerializer::PrependHitToEmpty( + PostingListUsed* posting_list_used, const EmbeddingHit& hit) const { + // First hit to be added. Just add verbatim, no compression. + if (posting_list_used->size_in_bytes() == kSpecialHitsSize) { + // Safe to ignore the return value because 1 < kNumSpecialData + SetSpecialHit(posting_list_used, /*index=*/1, hit); + // Safe to ignore the return value because sizeof(EmbeddingHit) is a valid + // argument. 
+ SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit)); + } else { + // Since this is the first hit, size != kSpecialHitsSize and + // size % sizeof(EmbeddingHit) == 0, we know that there is room to fit 'hit' + // into the compressed region, so ValueOrDie is safe. + uint32_t offset = + PrependHitUncompressed(posting_list_used, hit, + /*offset=*/posting_list_used->size_in_bytes()) + .ValueOrDie(); + // Safe to ignore the return value because PrependHitUncompressed is + // guaranteed to return a valid offset. + SetStartByteOffset(posting_list_used, offset); + } +} + +libtextclassifier3::Status +PostingListEmbeddingHitSerializer::PrependHitToNotFull( + PostingListUsed* posting_list_used, const EmbeddingHit& hit, + uint32_t offset) const { + // First hit in compressed area. It is uncompressed. See if delta + // between the first hit and new hit will still fit in the + // compressed area. + if (offset + sizeof(EmbeddingHit::Value) > + posting_list_used->size_in_bytes()) { + // The first hit in the compressed region *should* be uncompressed, but + // somehow there isn't enough room between offset and the end of the + // compressed area to fit an uncompressed hit. This should NEVER happen. + return absl_ports::FailedPreconditionError( + "Posting list is in an invalid state."); + } + EmbeddingHit::Value cur_value; + memcpy(&cur_value, posting_list_used->posting_list_buffer() + offset, + sizeof(EmbeddingHit::Value)); + if (cur_value <= hit.value()) { + return absl_ports::InvalidArgumentError( + IcingStringUtil::StringPrintf("EmbeddingHit %" PRId64 + " being prepended must be " + "strictly less than the most recent " + "EmbeddingHit %" PRId64, + hit.value(), cur_value)); + } + uint64_t delta = cur_value - hit.value(); + uint8_t delta_buf[VarInt::kMaxEncodedLen64]; + size_t delta_len = VarInt::Encode(delta, delta_buf); + + // offset now points to one past the end of the first hit. 
+ offset += sizeof(EmbeddingHit::Value); + if (kSpecialHitsSize + sizeof(EmbeddingHit::Value) + delta_len <= offset) { + // Enough space for delta in compressed area. + + // Prepend delta. + offset -= delta_len; + memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf, + delta_len); + + // Prepend new hit. We know that there is room for 'hit' because of the if + // statement above, so calling ValueOrDie is safe. + offset = + PrependHitUncompressed(posting_list_used, hit, offset).ValueOrDie(); + // offset is guaranteed to be valid here. So it's safe to ignore the return + // value. The if above will guarantee that offset >= kSpecialHitSize and < + // posting_list_used->size_in_bytes() because the if ensures that there is + // enough room between offset and kSpecialHitSize to fit the delta of the + // previous hit and the uncompressed hit. + SetStartByteOffset(posting_list_used, offset); + } else if (kSpecialHitsSize + delta_len <= offset) { + // Only have space for delta. The new hit must be put in special + // position 1. + + // Prepend delta. + offset -= delta_len; + memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf, + delta_len); + + // Prepend pad. Safe to ignore the return value of PadToEnd because offset + // must be less than posting_list_used->size_in_bytes(). Otherwise, this + // function already would have returned FAILED_PRECONDITION. + PadToEnd(posting_list_used, /*start=*/kSpecialHitsSize, + /*end=*/offset); + + // Put new hit in special position 1. Safe to ignore return value because 1 + // < kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/1, hit); + + // State almost_full. Safe to ignore the return value because + // sizeof(EmbeddingHit) is a valid argument. + SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit)); + } else { + // Very rare case where delta is larger than sizeof(EmbeddingHit::Value) + // (i.e. varint delta encoding expanded required storage). 
We + // move first hit to special position 1 and put new hit in + // special position 0. + EmbeddingHit cur(cur_value); + // Safe to ignore the return value of PadToEnd because offset must be less + // than posting_list_used->size_in_bytes(). Otherwise, this function + // already would have returned FAILED_PRECONDITION. + PadToEnd(posting_list_used, /*start=*/kSpecialHitsSize, + /*end=*/offset); + // Safe to ignore the return value here because 0 and 1 < kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/1, cur); + SetSpecialHit(posting_list_used, /*index=*/0, hit); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status PostingListEmbeddingHitSerializer::PrependHit( + PostingListUsed* posting_list_used, const EmbeddingHit& hit) const { + static_assert( + sizeof(EmbeddingHit::Value) <= sizeof(uint64_t), + "EmbeddingHit::Value cannot be larger than 8 bytes because the delta " + "must be able to fit in 8 bytes."); + if (!hit.is_valid()) { + return absl_ports::InvalidArgumentError("Cannot prepend an invalid hit!"); + } + if (!IsPostingListValid(posting_list_used)) { + return absl_ports::FailedPreconditionError( + "This PostingListUsed is in an invalid state and can't add any hits!"); + } + + if (IsFull(posting_list_used)) { + // State full: no space left. 
+ return absl_ports::ResourceExhaustedError("No more room for hits"); + } else if (IsAlmostFull(posting_list_used)) { + return PrependHitToAlmostFull(posting_list_used, hit); + } else if (IsEmpty(posting_list_used)) { + PrependHitToEmpty(posting_list_used, hit); + return libtextclassifier3::Status::OK; + } else { + uint32_t offset = GetStartByteOffset(posting_list_used); + return PrependHitToNotFull(posting_list_used, hit, offset); + } +} + +libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> +PostingListEmbeddingHitSerializer::GetHits( + const PostingListUsed* posting_list_used) const { + std::vector<EmbeddingHit> hits_out; + ICING_RETURN_IF_ERROR(GetHits(posting_list_used, &hits_out)); + return hits_out; +} + +libtextclassifier3::Status PostingListEmbeddingHitSerializer::GetHits( + const PostingListUsed* posting_list_used, + std::vector<EmbeddingHit>* hits_out) const { + return GetHitsInternal(posting_list_used, + /*limit=*/std::numeric_limits<uint32_t>::max(), + /*pop=*/false, hits_out); +} + +libtextclassifier3::Status PostingListEmbeddingHitSerializer::PopFrontHits( + PostingListUsed* posting_list_used, uint32_t num_hits) const { + if (num_hits == 1 && IsFull(posting_list_used)) { + // The PL is in full status which means that we save 2 uncompressed hits in + // the 2 special positions. But full status may be reached by 2 different + // statuses. + // (1) In "almost full" status + // +-----------------+----------------+-------+-----------------+ + // |Hit::kInvalidVal |1st hit |(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // When we prepend another hit, we can only put it at the special + // position 0. 
And we get a full PL + // +-----------------+----------------+-------+-----------------+ + // |new 1st hit |original 1st hit|(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // (2) In "not full" status + // +-----------------+----------------+------+-------+------------------+ + // |hits-start-offset|Hit::kInvalidVal|(pad) |1st hit|(compressed) hits | + // +-----------------+----------------+------+-------+------------------+ + // When we prepend another hit, we can reach any of the 3 following + // scenarios: + // (2.1) not full + // if the space of pad and original 1st hit can accommodate the new 1st hit + // and the encoded delta value. + // +-----------------+----------------+------+-----------+-----------------+ + // |hits-start-offset|Hit::kInvalidVal|(pad) |new 1st hit|(compressed) hits| + // +-----------------+----------------+------+-----------+-----------------+ + // (2.2) almost full + // If the space of pad and original 1st hit cannot accommodate the new 1st + // hit and the encoded delta value but can accommodate the encoded delta + // value only. We can put the new 1st hit at special position 1. + // +-----------------+----------------+-------+-----------------+ + // |Hit::kInvalidVal |new 1st hit |(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // (2.3) full + // In very rare case, it cannot even accommodate only the encoded delta + // value. we can move the original 1st hit into special position 1 and the + // new 1st hit into special position 0. This may happen because we use + // VarInt encoding method which may make the encoded value longer (about + // 4/3 times of original) + // +-----------------+----------------+-------+-----------------+ + // |new 1st hit |original 1st hit|(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // Suppose now the PL is full. 
But we don't know whether it arrived at + // this status from "not full" like (2.3) or from "almost full" like (1). + // We'll return to "almost full" status like (1) if we simply pop the new + // 1st hit but we want to make the prepending operation "reversible". So + // there should be some way to return to "not full" if possible. A simple + // way to do it is to pop 2 hits out of the PL to status "almost full" or + // "not full". And add the original 1st hit back. We can return to the + // correct original statuses of (2.1) or (1). This makes our prepending + // operation reversible. + std::vector<EmbeddingHit> out; + + // Popping 2 hits should never fail because we've just ensured that the + // posting list is in the FULL state. + ICING_RETURN_IF_ERROR( + GetHitsInternal(posting_list_used, /*limit=*/2, /*pop=*/true, &out)); + + // PrependHit should never fail because out[1] is a valid hit less than + // previous hits in the posting list and because there's no way that the + // posting list could run out of room because it previously stored this hit + // AND another hit. + ICING_RETURN_IF_ERROR(PrependHit(posting_list_used, out[1])); + } else if (num_hits > 0) { + return GetHitsInternal(posting_list_used, /*limit=*/num_hits, /*pop=*/true, + nullptr); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status PostingListEmbeddingHitSerializer::GetHitsInternal( + const PostingListUsed* posting_list_used, uint32_t limit, bool pop, + std::vector<EmbeddingHit>* out) const { + // Put current uncompressed val here. + EmbeddingHit::Value val = EmbeddingHit::kInvalidValue; + uint32_t offset = GetStartByteOffset(posting_list_used); + uint32_t count = 0; + + // First traverse the first two special positions. + while (count < limit && offset < kSpecialHitsSize) { + // Calling ValueOrDie is safe here because offset / sizeof(EmbeddingHit) < + // kNumSpecialData because of the check above. 
+ EmbeddingHit hit = GetSpecialHit(posting_list_used, + /*index=*/offset / sizeof(EmbeddingHit)); + val = hit.value(); + if (out != nullptr) { + out->push_back(hit); + } + offset += sizeof(EmbeddingHit); + count++; + } + + // If special position 1 was set then we need to skip padding. + if (val != EmbeddingHit::kInvalidValue && offset == kSpecialHitsSize) { + offset = GetPadEnd(posting_list_used, offset); + } + + while (count < limit && offset < posting_list_used->size_in_bytes()) { + if (val == EmbeddingHit::kInvalidValue) { + // First hit is in compressed area. Put that in val. + memcpy(&val, posting_list_used->posting_list_buffer() + offset, + sizeof(EmbeddingHit::Value)); + offset += sizeof(EmbeddingHit::Value); + } else { + // Now we have delta encoded subsequent hits. Decode and push. + uint64_t delta; + offset += VarInt::Decode( + posting_list_used->posting_list_buffer() + offset, &delta); + val += delta; + } + EmbeddingHit hit(val); + if (out != nullptr) { + out->push_back(hit); + } + count++; + } + + if (pop) { + PostingListUsed* mutable_posting_list_used = + const_cast<PostingListUsed*>(posting_list_used); + // Modify the posting list so that we pop all hits actually + // traversed. + if (offset >= kSpecialHitsSize && + offset < posting_list_used->size_in_bytes()) { + // In the compressed area. Pop and reconstruct. offset/val is + // the last traversed hit, which we must discard. So move one + // more forward. + uint64_t delta; + offset += VarInt::Decode( + posting_list_used->posting_list_buffer() + offset, &delta); + val += delta; + + // Now val is the first hit of the new posting list. + if (kSpecialHitsSize + sizeof(EmbeddingHit::Value) <= offset) { + // val fits in compressed area. Simply copy. + offset -= sizeof(EmbeddingHit::Value); + memcpy(mutable_posting_list_used->posting_list_buffer() + offset, &val, + sizeof(EmbeddingHit::Value)); + } else { + // val won't fit in compressed area. 
+ EmbeddingHit hit(val); + // Okay to ignore the return value here because 1 < kNumSpecialData. + SetSpecialHit(mutable_posting_list_used, /*index=*/1, hit); + + // Prepend pad. Safe to ignore the return value of PadToEnd because + // offset must be less than posting_list_used->size_in_bytes() thanks to + // the if above. + PadToEnd(mutable_posting_list_used, + /*start=*/kSpecialHitsSize, + /*end=*/offset); + offset = sizeof(EmbeddingHit); + } + } + // offset is guaranteed to be valid so ignoring the return value of + // set_start_byte_offset is safe. It falls into one of four scenarios: + // Scenario 1: the above if was false because offset is not < + // posting_list_used->size_in_bytes() + // In this case, offset must be == posting_list_used->size_in_bytes() + // because we reached offset by unwinding hits on the posting list. + // Scenario 2: offset is < kSpecialHitSize + // In this case, offset is guaranteed to be either 0 or + // sizeof(EmbeddingHit) because offset is incremented by + // sizeof(EmbeddingHit) within the first while loop. + // Scenario 3: offset is within the compressed region and the new first hit + // in the posting list (the value that 'val' holds) will fit as an + // uncompressed hit in the compressed region. The resulting offset from + // decompressing val must be >= kSpecialHitSize because otherwise we'd be + // in Scenario 4 + // Scenario 4: offset is within the compressed region, but the new first hit + // in the posting list is too large to fit as an uncompressed hit in the + // in the compressed region. Therefore, it must be stored in a special hit + // and offset will be sizeof(EmbeddingHit). 
+ SetStartByteOffset(mutable_posting_list_used, offset); + } + + return libtextclassifier3::Status::OK; +} + +EmbeddingHit PostingListEmbeddingHitSerializer::GetSpecialHit( + const PostingListUsed* posting_list_used, uint32_t index) const { + static_assert(sizeof(EmbeddingHit::Value) >= sizeof(uint32_t), "HitTooSmall"); + EmbeddingHit val(EmbeddingHit::kInvalidValue); + memcpy(&val, posting_list_used->posting_list_buffer() + index * sizeof(val), + sizeof(val)); + return val; +} + +void PostingListEmbeddingHitSerializer::SetSpecialHit( + PostingListUsed* posting_list_used, uint32_t index, + const EmbeddingHit& val) const { + memcpy(posting_list_used->posting_list_buffer() + index * sizeof(val), &val, + sizeof(val)); +} + +bool PostingListEmbeddingHitSerializer::IsPostingListValid( + const PostingListUsed* posting_list_used) const { + if (IsAlmostFull(posting_list_used)) { + // Special Hit 1 should hold a Hit. Calling ValueOrDie is safe because we + // know that 1 < kNumSpecialData. + if (!GetSpecialHit(posting_list_used, /*index=*/1).is_valid()) { + ICING_LOG(ERROR) + << "Both special hits cannot be invalid at the same time."; + return false; + } + } else if (!IsFull(posting_list_used)) { + // NOT_FULL. Special Hit 0 should hold a valid offset. Calling ValueOrDie is + // safe because we know that 0 < kNumSpecialData. 
+ if (GetSpecialHit(posting_list_used, /*index=*/0).value() > + posting_list_used->size_in_bytes() || + GetSpecialHit(posting_list_used, /*index=*/0).value() < + kSpecialHitsSize) { + ICING_LOG(ERROR) << "EmbeddingHit: " + << GetSpecialHit(posting_list_used, /*index=*/0).value() + << " size: " << posting_list_used->size_in_bytes() + << " sp size: " << kSpecialHitsSize; + return false; + } + } + return true; +} + +uint32_t PostingListEmbeddingHitSerializer::GetStartByteOffset( + const PostingListUsed* posting_list_used) const { + if (IsFull(posting_list_used)) { + return 0; + } else if (IsAlmostFull(posting_list_used)) { + return sizeof(EmbeddingHit); + } else { + // NOT_FULL, calling ValueOrDie is safe because we know that 0 < + // kNumSpecialData. + return GetSpecialHit(posting_list_used, /*index=*/0).value(); + } +} + +bool PostingListEmbeddingHitSerializer::SetStartByteOffset( + PostingListUsed* posting_list_used, uint32_t offset) const { + if (offset > posting_list_used->size_in_bytes()) { + ICING_LOG(ERROR) << "offset cannot be a value greater than size " + << posting_list_used->size_in_bytes() << ". offset is " + << offset << "."; + return false; + } + if (offset < kSpecialHitsSize && offset > sizeof(EmbeddingHit)) { + ICING_LOG(ERROR) << "offset cannot be a value between (" + << sizeof(EmbeddingHit) << ", " << kSpecialHitsSize + << "). offset is " << offset << "."; + return false; + } + if (offset < sizeof(EmbeddingHit) && offset != 0) { + ICING_LOG(ERROR) << "offset cannot be a value between (0, " + << sizeof(EmbeddingHit) << "). offset is " << offset + << "."; + return false; + } + if (offset >= kSpecialHitsSize) { + // not_full state. Safe to ignore the return value because 0 and 1 are both + // < kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/0, EmbeddingHit(offset)); + SetSpecialHit(posting_list_used, /*index=*/1, + EmbeddingHit(EmbeddingHit::kInvalidValue)); + } else if (offset == sizeof(EmbeddingHit)) { + // almost_full state. 
Safe to ignore the return value because 1 is both < + // kNumSpecialData. + SetSpecialHit(posting_list_used, /*index=*/0, + EmbeddingHit(EmbeddingHit::kInvalidValue)); + } + // Nothing to do for the FULL state - the offset isn't actually stored + // anywhere and both special hits hold valid hits. + return true; +} + +libtextclassifier3::StatusOr<uint32_t> +PostingListEmbeddingHitSerializer::PrependHitUncompressed( + PostingListUsed* posting_list_used, const EmbeddingHit& hit, + uint32_t offset) const { + if (offset < kSpecialHitsSize + sizeof(EmbeddingHit::Value)) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Not enough room to prepend EmbeddingHit::Value at offset %d.", + offset)); + } + offset -= sizeof(EmbeddingHit::Value); + EmbeddingHit::Value val = hit.value(); + memcpy(posting_list_used->posting_list_buffer() + offset, &val, + sizeof(EmbeddingHit::Value)); + return offset; +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embed/posting-list-embedding-hit-serializer.h b/icing/index/embed/posting-list-embedding-hit-serializer.h new file mode 100644 index 0000000..76198c2 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-serializer.h @@ -0,0 +1,284 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_ +#define ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_ + +#include <cstdint> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/file/posting_list/posting-list-common.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +// A serializer class to serialize hits to PostingListUsed. Layout described in +// comments in posting-list-embedding-hit-serializer.cc. +class PostingListEmbeddingHitSerializer : public PostingListSerializer { + public: + static constexpr uint32_t kSpecialHitsSize = + kNumSpecialData * sizeof(EmbeddingHit); + + uint32_t GetDataTypeBytes() const override { return sizeof(EmbeddingHit); } + + uint32_t GetMinPostingListSize() const override { + static constexpr uint32_t kMinPostingListSize = kSpecialHitsSize; + static_assert(sizeof(PostingListIndex) <= kMinPostingListSize, + "PostingListIndex must be small enough to fit in a " + "minimum-sized Posting List."); + + return kMinPostingListSize; + } + + uint32_t GetMinPostingListSizeToFit( + const PostingListUsed* posting_list_used) const override; + + uint32_t GetBytesUsed( + const PostingListUsed* posting_list_used) const override; + + void Clear(PostingListUsed* posting_list_used) const override; + + libtextclassifier3::Status MoveFrom(PostingListUsed* dst, + PostingListUsed* src) const override; + + // Prepend a hit to the posting list. + // + // RETURNS: + // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the + // previously added hit. + // - RESOURCE_EXHAUSTED if there is no more room to add hit to the posting + // list. 
+ libtextclassifier3::Status PrependHit(PostingListUsed* posting_list_used, + const EmbeddingHit& hit) const; + + // Prepend hits to the posting list. Hits should be sorted in descending order + // (as defined by the less than operator for Hit) + // + // Returns the number of hits that could be prepended to the posting list. If + // keep_prepended is true, whatever could be prepended is kept, otherwise the + // posting list is left in its original state. + template <class T, EmbeddingHit (*GetHit)(const T&)> + libtextclassifier3::StatusOr<uint32_t> PrependHitArray( + PostingListUsed* posting_list_used, const T* array, uint32_t num_hits, + bool keep_prepended) const; + + // Retrieves the hits stored in the posting list. + // + // RETURNS: + // - On success, a vector of hits sorted by the reverse order of prepending. + // - INTERNAL_ERROR if the posting list has been corrupted somehow. + libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits( + const PostingListUsed* posting_list_used) const; + + // Same as GetHits but appends hits to hits_out. + // + // RETURNS: + // - On success, a vector of hits sorted by the reverse order of prepending. + // - INTERNAL_ERROR if the posting list has been corrupted somehow. + libtextclassifier3::Status GetHits(const PostingListUsed* posting_list_used, + std::vector<EmbeddingHit>* hits_out) const; + + // Undo the last num_hits hits prepended. If num_hits > number of + // hits we clear all hits. + // + // RETURNS: + // - OK on success + // - INTERNAL_ERROR if the posting list has been corrupted somehow. 
+ libtextclassifier3::Status PopFrontHits(PostingListUsed* posting_list_used, + uint32_t num_hits) const; + + private: + // Posting list layout formats: + // + // not_full + // + // +-----------------+----------------+-------+-----------------+ + // |hits-start-offset|Hit::kInvalidVal|xxxxxxx|(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // + // almost_full + // + // +-----------------+----------------+-------+-----------------+ + // |Hit::kInvalidVal |1st hit |(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // + // full() + // + // +-----------------+----------------+-------+-----------------+ + // |1st hit |2nd hit |(pad) |(compressed) hits| + // +-----------------+----------------+-------+-----------------+ + // + // The first two uncompressed hits also implicitly encode information about + // the size of the compressed hits region. + // + // 1. If the posting list is NOT_FULL, then + // posting_list_buffer_[0] contains the byte offset of the start of the + // compressed hits - and, thus, the size of the compressed hits region is + // size_in_bytes - posting_list_buffer_[0]. + // + // 2. If posting list is ALMOST_FULL or FULL, then the compressed hits region + // starts somewhere between [kSpecialHitsSize, kSpecialHitsSize + + // sizeof(EmbeddingHit) - 1] and ends at size_in_bytes - 1. + + // Helpers to determine what state the posting list is in. 
+ bool IsFull(const PostingListUsed* posting_list_used) const { + return GetSpecialHit(posting_list_used, /*index=*/0).is_valid() && + GetSpecialHit(posting_list_used, /*index=*/1).is_valid(); + } + + bool IsAlmostFull(const PostingListUsed* posting_list_used) const { + return !GetSpecialHit(posting_list_used, /*index=*/0).is_valid() && + GetSpecialHit(posting_list_used, /*index=*/1).is_valid(); + } + + bool IsEmpty(const PostingListUsed* posting_list_used) const { + return GetSpecialHit(posting_list_used, /*index=*/0).value() == + posting_list_used->size_in_bytes() && + !GetSpecialHit(posting_list_used, /*index=*/1).is_valid(); + } + + // Returns false if both special hits are invalid or if the offset value + // stored in the special hit is less than kSpecialHitsSize or greater than + // posting_list_used->size_in_bytes(). Returns true, otherwise. + bool IsPostingListValid(const PostingListUsed* posting_list_used) const; + + // Prepend hit to a posting list that is in the ALMOST_FULL state. + // RETURNS: + // - OK, if successful + // - INVALID_ARGUMENT if hit is not less than the previously added hit. + libtextclassifier3::Status PrependHitToAlmostFull( + PostingListUsed* posting_list_used, const EmbeddingHit& hit) const; + + // Prepend hit to a posting list that is in the EMPTY state. This will always + // succeed because there are no pre-existing hits and no validly constructed + // posting list could fail to fit one hit. + void PrependHitToEmpty(PostingListUsed* posting_list_used, + const EmbeddingHit& hit) const; + + // Prepend hit to a posting list that is in the NOT_FULL state. + // RETURNS: + // - OK, if successful + // - INVALID_ARGUMENT if hit is not less than the previously added hit. 
+ libtextclassifier3::Status PrependHitToNotFull( + PostingListUsed* posting_list_used, const EmbeddingHit& hit, + uint32_t offset) const; + + // Returns either 0 (full state), sizeof(EmbeddingHit) (almost_full state) or + // a byte offset between kSpecialHitsSize and + // posting_list_used->size_in_bytes() (inclusive) (not_full state). + uint32_t GetStartByteOffset(const PostingListUsed* posting_list_used) const; + + // Sets the special hits to properly reflect what offset is (see layout + // comment for further details). + // + // Returns false if offset > posting_list_used->size_in_bytes() or offset is + // (kSpecialHitsSize, sizeof(EmbeddingHit)) or offset is + // (sizeof(EmbeddingHit), 0). True, otherwise. + bool SetStartByteOffset(PostingListUsed* posting_list_used, + uint32_t offset) const; + + // Manipulate padded areas. We never store the same hit value twice + // so a delta of 0 is a pad byte. + + // Returns offset of first non-pad byte. + uint32_t GetPadEnd(const PostingListUsed* posting_list_used, + uint32_t offset) const; + + // Fill padding between offset start and offset end with 0s. + // Returns false if end > posting_list_used->size_in_bytes(). True, + // otherwise. + bool PadToEnd(PostingListUsed* posting_list_used, uint32_t start, + uint32_t end) const; + + // Helper for AppendHits/PopFrontHits. Adds limit number of hits to out or all + // hits in the posting list if the posting list contains less than limit + // number of hits. out can be NULL. + // + // NOTE: If called with limit=1, pop=true on a posting list that transitioned + // from NOT_FULL directly to FULL, GetHitsInternal will not return the posting + // list to NOT_FULL. Instead it will leave it in a valid state, but it will be + // ALMOST_FULL. + // + // RETURNS: + // - OK on success + // - INTERNAL_ERROR if the posting list has been corrupted somehow. 
+ libtextclassifier3::Status GetHitsInternal( + const PostingListUsed* posting_list_used, uint32_t limit, bool pop, + std::vector<EmbeddingHit>* out) const; + + // Retrieves the value stored in the index-th special hit. + // + // REQUIRES: + // 0 <= index < kNumSpecialData. + // + // RETURNS: + // - A valid SpecialData<EmbeddingHit>. + EmbeddingHit GetSpecialHit(const PostingListUsed* posting_list_used, + uint32_t index) const; + + // Sets the value stored in the index-th special hit to val. + // + // REQUIRES: + // 0 <= index < kNumSpecialData. + void SetSpecialHit(PostingListUsed* posting_list_used, uint32_t index, + const EmbeddingHit& val) const; + + // Prepends hit to the memory region [offset - sizeof(EmbeddingHit), offset] + // and returns the new beginning of the padded region. + // + // RETURNS: + // - The new beginning of the padded region, if successful. + // - INVALID_ARGUMENT if hit will not fit (uncompressed) between offset and + // kSpecialHitsSize + libtextclassifier3::StatusOr<uint32_t> PrependHitUncompressed( + PostingListUsed* posting_list_used, const EmbeddingHit& hit, + uint32_t offset) const; +}; + +// Inlined functions. Implementation details below. Avert eyes! +template <class T, EmbeddingHit (*GetHit)(const T&)> +libtextclassifier3::StatusOr<uint32_t> +PostingListEmbeddingHitSerializer::PrependHitArray( + PostingListUsed* posting_list_used, const T* array, uint32_t num_hits, + bool keep_prepended) const { + if (!IsPostingListValid(posting_list_used)) { + return 0; + } + + // Prepend hits working backwards from array[num_hits - 1]. + uint32_t i; + for (i = 0; i < num_hits; ++i) { + if (!PrependHit(posting_list_used, GetHit(array[num_hits - i - 1])).ok()) { + break; + } + } + if (i != num_hits && !keep_prepended) { + // Didn't fit. Undo everything and check that we have the same offset as + // before. 
PopFrontHits guarantees that it will remove all 'i' hits so long + // as there are at least 'i' hits in the posting list, which we know there + // are. + ICING_RETURN_IF_ERROR(PopFrontHits(posting_list_used, /*num_hits=*/i)); + } + return i; +} + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_ diff --git a/icing/index/embed/posting-list-embedding-hit-serializer_test.cc b/icing/index/embed/posting-list-embedding-hit-serializer_test.cc new file mode 100644 index 0000000..f829634 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-serializer_test.cc @@ -0,0 +1,864 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <deque> +#include <iterator> +#include <limits> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/hit/hit.h" +#include "icing/legacy/index/icing-bit-util.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/hit-test-utils.h" + +using testing::ElementsAre; +using testing::ElementsAreArray; +using testing::Eq; +using testing::IsEmpty; +using testing::Le; +using testing::Lt; + +namespace icing { +namespace lib { + +namespace { + +struct HitElt { + HitElt() = default; + explicit HitElt(const EmbeddingHit &hit_in) : hit(hit_in) {} + + static EmbeddingHit get_hit(const HitElt &hit_elt) { return hit_elt.hit; } + + EmbeddingHit hit; +}; + +TEST(PostingListEmbeddingHitSerializerTest, PostingListUsedPrependHitNotFull) { + PostingListEmbeddingHitSerializer serializer; + + static const int kNumHits = 2551; + static const size_t kHitsSize = kNumHits * sizeof(EmbeddingHit); + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize)); + + // Make used. 
+ EmbeddingHit hit0(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit0)); + int expected_size = sizeof(EmbeddingHit::Value); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0))); + + EmbeddingHit hit1(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/1); + uint64_t delta = hit0.value() - hit1.value(); + uint8_t delta_buf[VarInt::kMaxEncodedLen64]; + size_t delta_len = VarInt::Encode(delta, delta_buf); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit1)); + expected_size += delta_len; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit1, hit0))); + + EmbeddingHit hit2(BasicHit(/*section_id=*/0, /*document_id=*/2), + /*location=*/2); + delta = hit1.value() - hit2.value(); + delta_len = VarInt::Encode(delta, delta_buf); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit2)); + expected_size += delta_len; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit2, hit1, hit0))); + + EmbeddingHit hit3(BasicHit(/*section_id=*/0, /*document_id=*/3), + /*location=*/3); + delta = hit2.value() - hit3.value(); + delta_len = VarInt::Encode(delta, delta_buf); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit3)); + expected_size += delta_len; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListUsedPrependHitAlmostFull) { + PostingListEmbeddingHitSerializer serializer; + + // Size = 32 + int pl_size = 2 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + 
PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + // Fill up the compressed region. + // Transitions: + // Adding hit0: EMPTY -> NOT_FULL + // Adding hit1: NOT_FULL -> NOT_FULL + // Adding hit2: NOT_FULL -> NOT_FULL + EmbeddingHit hit0(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/1); + EmbeddingHit hit1 = CreateEmbeddingHit(hit0, /*desired_byte_length=*/3); + EmbeddingHit hit2 = CreateEmbeddingHit(hit1, /*desired_byte_length=*/3); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0)); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1)); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2)); + // Size used will be 8 (hit2) + 3 (hit1-hit2) + 3 (hit0-hit1) = 14 bytes + int expected_size = sizeof(EmbeddingHit) + 3 + 3; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit2, hit1, hit0))); + + // Add one more hit to transition NOT_FULL -> ALMOST_FULL + EmbeddingHit hit3 = CreateEmbeddingHit(hit2, /*desired_byte_length=*/3); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3)); + // Storing them in the compressed region requires 8 (hit) + 3 (hit2-hit3) + + // 3 (hit1-hit2) + 3 (hit0-hit1) = 17 bytes, but there are only 16 bytes in + // the compressed region. So instead, the posting list will transition to + // ALMOST_FULL. The in-use compressed region will actually shrink from 14 + // bytes to 9 bytes because the uncompressed version of hit2 will be + // overwritten with the compressed delta of hit2. hit3 will be written to one + // of the special hits. Because we're in ALMOST_FULL, the expected size is the + // size of the pl minus the one hit used to mark the posting list as + // ALMOST_FULL. 
+ expected_size = pl_size - sizeof(EmbeddingHit); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0))); + + // Add one more hit to transition ALMOST_FULL -> ALMOST_FULL + EmbeddingHit hit4 = CreateEmbeddingHit(hit3, /*desired_byte_length=*/6); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit4)); + // There are currently 9 bytes in use in the compressed region. Hit3 will + // have a 6-byte delta, which fits in the compressed region. Hit3 will be + // moved from the special hit to the compressed region (which will have 15 + // bytes in use after adding hit3). Hit4 will be placed in one of the special + // hits and the posting list will remain in ALMOST_FULL. + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit4, hit3, hit2, hit1, hit0))); + + // Add one more hit to transition ALMOST_FULL -> FULL + EmbeddingHit hit5 = CreateEmbeddingHit(hit4, /*desired_byte_length=*/2); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit5)); + // There are currently 15 bytes in use in the compressed region. Hit4 will + // have a 2-byte delta which will not fit in the compressed region. So hit4 + // will remain in one of the special hits and hit5 will occupy the other, + // making the posting list FULL. + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit5, hit4, hit3, hit2, hit1, hit0))); + + // The posting list is FULL. Adding another hit should fail. 
+ EmbeddingHit hit6 = CreateEmbeddingHit(hit5, /*desired_byte_length=*/1); + EXPECT_THAT(serializer.PrependHit(&pl_used, hit6), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST(PostingListEmbeddingHitSerializerTest, PostingListUsedMinSize) { + PostingListEmbeddingHitSerializer serializer; + + // Min size = 16 + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, serializer.GetMinPostingListSize())); + // PL State: EMPTY + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); + EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); + + // Add a hit, PL should shift to ALMOST_FULL state + EmbeddingHit hit0(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/1); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0)); + // Size = sizeof(uncompressed hit0) + int expected_size = sizeof(EmbeddingHit); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0))); + + // Add the smallest hit possible with a delta of 0b1. PL should shift to FULL + // state. + EmbeddingHit hit1(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/0); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1)); + // Size = sizeof(uncompressed hit1) + sizeof(uncompressed hit0) + expected_size += sizeof(EmbeddingHit); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit1, hit0))); + + // Try to add the smallest hit possible. 
Should fail + EmbeddingHit hit2(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0); + EXPECT_THAT(serializer.PrependHit(&pl_used, hit2), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit1, hit0))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListPrependHitArrayMinSizePostingList) { + PostingListEmbeddingHitSerializer serializer; + + // Min Size = 16 + int pl_size = serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + std::vector<HitElt> hits_in; + hits_in.emplace_back(EmbeddingHit( + BasicHit(/*section_id=*/1, /*document_id=*/0), /*location=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + std::reverse(hits_in.begin(), hits_in.end()); + + // Add five hits. The PL is in the empty state and an empty min size PL can + // only fit two hits. So PrependHitArray should fail. + ICING_ASSERT_OK_AND_ASSIGN( + uint32_t num_can_prepend, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_can_prepend, Eq(2)); + + int can_fit_hits = num_can_prepend; + // The PL has room for 2 hits. 
We should be able to add them without any + // problem, transitioning the PL from EMPTY -> ALMOST_FULL -> FULL + const HitElt *hits_in_ptr = hits_in.data() + (hits_in.size() - 2); + ICING_ASSERT_OK_AND_ASSIGN( + num_can_prepend, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, hits_in_ptr, can_fit_hits, /*keep_prepended=*/false))); + EXPECT_THAT(num_can_prepend, Eq(can_fit_hits)); + EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used))); + std::deque<EmbeddingHit> hits_pushed; + std::transform(hits_in.rbegin(), + hits_in.rend() - hits_in.size() + can_fit_hits, + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListPrependHitArrayPostingList) { + PostingListEmbeddingHitSerializer serializer; + + // Size = 48 + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + std::vector<HitElt> hits_in; + hits_in.emplace_back(EmbeddingHit( + BasicHit(/*section_id=*/1, /*document_id=*/0), /*location=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + std::reverse(hits_in.begin(), hits_in.end()); + // The last hit is uncompressed and the four before it should only take one + // byte. Total use = 8 bytes. 
+ // ---------------------- + // 47 delta(EmbeddingHit #0) + // 46 delta(EmbeddingHit #1) + // 45 delta(EmbeddingHit #2) + // 44 delta(EmbeddingHit #3) + // 43-36 EmbeddingHit #4 + // 35-16 <unused> + // 15-8 kSpecialHit + // 7-0 Offset=36 + // ---------------------- + int byte_size = sizeof(EmbeddingHit::Value) + hits_in.size() - 1; + + // Add five hits. The PL is in the empty state and should be able to fit all + // five hits without issue, transitioning the PL from EMPTY -> NOT_FULL. + ICING_ASSERT_OK_AND_ASSIGN( + uint32_t num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_could_fit, Eq(hits_in.size())); + EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used))); + std::deque<EmbeddingHit> hits_pushed; + std::transform(hits_in.rbegin(), hits_in.rend(), + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + EmbeddingHit first_hit = + CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/1); + hits_in.clear(); + hits_in.emplace_back(first_hit); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3)); + std::reverse(hits_in.begin(), hits_in.end()); + // Size increased by the deltas of these hits (1+2+1+2+3+2+3) = 14 bytes + // ---------------------- + // 47 delta(EmbeddingHit #0) + // 46 delta(EmbeddingHit #1) + // 45 
delta(EmbeddingHit #2) + // 44 delta(EmbeddingHit #3) + // 43 delta(EmbeddingHit #4) + // 42-41 delta(EmbeddingHit #5) + // 40 delta(EmbeddingHit #6) + // 39-38 delta(EmbeddingHit #7) + // 37-35 delta(EmbeddingHit #8) + // 34-33 delta(EmbeddingHit #9) + // 32-30 delta(EmbeddingHit #10) + // 29-22 EmbeddingHit #11 + // 21-16 <unused> + // 15-8 kSpecialHit + // 7-0 Offset=22 + // ---------------------- + byte_size += 14; + + // Add these 7 hits. The PL is currently in the NOT_FULL state and should + // remain in the NOT_FULL state. + ICING_ASSERT_OK_AND_ASSIGN( + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_could_fit, Eq(hits_in.size())); + EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used))); + // All hits from hits_in were added. + std::transform(hits_in.rbegin(), hits_in.rend(), + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + first_hit = + CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/8); + hits_in.clear(); + hits_in.emplace_back(first_hit); + // ---------------------- + // 47 delta(EmbeddingHit #0) + // 46 delta(EmbeddingHit #1) + // 45 delta(EmbeddingHit #2) + // 44 delta(EmbeddingHit #3) + // 43 delta(EmbeddingHit #4) + // 42-41 delta(EmbeddingHit #5) + // 40 delta(EmbeddingHit #6) + // 39-38 delta(EmbeddingHit #7) + // 37-35 delta(EmbeddingHit #8) + // 34-33 delta(EmbeddingHit #9) + // 32-30 delta(EmbeddingHit #10) + // 29-22 delta(EmbeddingHit #11) + // 21-16 <unused> + // 15-8 EmbeddingHit #12 + // 7-0 kSpecialHit + // ---------------------- + byte_size = 40; // 48 - 8 + + // Add this 1 hit. The PL is currently in the NOT_FULL state and should + // transition to the ALMOST_FULL state - even though there is still some + // unused space. 
+ ICING_ASSERT_OK_AND_ASSIGN( + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_could_fit, Eq(hits_in.size())); + EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used))); + // All hits from hits_in were added. + std::transform(hits_in.rbegin(), hits_in.rend(), + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + first_hit = + CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/5); + hits_in.clear(); + hits_in.emplace_back(first_hit); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3)); + std::reverse(hits_in.begin(), hits_in.end()); + // ---------------------- + // 47 delta(EmbeddingHit #0) + // 46 delta(EmbeddingHit #1) + // 45 delta(EmbeddingHit #2) + // 44 delta(EmbeddingHit #3) + // 43 delta(EmbeddingHit #4) + // 42-41 delta(EmbeddingHit #5) + // 40 delta(EmbeddingHit #6) + // 39-38 delta(EmbeddingHit #7) + // 37-35 delta(EmbeddingHit #8) + // 34-33 delta(EmbeddingHit #9) + // 32-30 delta(EmbeddingHit #10) + // 29-22 delta(EmbeddingHit #11) + // 21-17 delta(EmbeddingHit #12) + // 16 <unused> + // 15-8 EmbeddingHit #13 + // 7-0 EmbeddingHit #14 + // ---------------------- + + // Add these 2 hits. + // - The PL is currently in the ALMOST_FULL state. Adding the first hit should + // keep the PL in ALMOST_FULL because the delta between + // EmbeddingHit #12 and EmbeddingHit #13 (5 byte) can fit in the unused area + // (6 bytes). + // - Adding the second hit should transition to the FULL state because the + // delta between EmbeddingHit #13 and EmbeddingHit #14 (3 bytes) is larger + // than the remaining unused area (1 byte). 
+ ICING_ASSERT_OK_AND_ASSIGN( + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_could_fit, Eq(hits_in.size())); + EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used))); + // All hits from hits_in were added. + std::transform(hits_in.rbegin(), hits_in.rend(), + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListPrependHitArrayTooManyHits) { + PostingListEmbeddingHitSerializer serializer; + + static constexpr int kNumHits = 130; + static constexpr int kDeltaSize = 1; + static constexpr size_t kHitsSize = + ((kNumHits - 2) * kDeltaSize + (2 * sizeof(EmbeddingHit))); + + // Create an array with one too many hits + std::vector<HitElt> hit_elts_in_too_many; + hit_elts_in_too_many.emplace_back(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/0), /*location=*/0)); + for (int i = 0; i < kNumHits; ++i) { + hit_elts_in_too_many.emplace_back(CreateEmbeddingHit( + hit_elts_in_too_many.back().hit, /*desired_byte_length=*/1)); + } + // Reverse so that hits are inserted in descending order + std::reverse(hit_elts_in_too_many.begin(), hit_elts_in_too_many.end()); + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, serializer.GetMinPostingListSize())); + // PrependHitArray should fail because hit_elts_in_too_many is far too large + // for the minimum size pl. 
+ ICING_ASSERT_OK_AND_ASSIGN( + uint32_t num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(), + /*keep_prepended=*/false))); + ASSERT_THAT(num_could_fit, Eq(2)); + ASSERT_THAT(num_could_fit, Lt(hit_elts_in_too_many.size())); + ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); + ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); + + ICING_ASSERT_OK_AND_ASSIGN( + pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize)); + // PrependHitArray should fail because hit_elts_in_too_many is one hit too + // large for this pl. + ICING_ASSERT_OK_AND_ASSIGN( + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(), + /*keep_prepended=*/false))); + ASSERT_THAT(num_could_fit, Eq(hit_elts_in_too_many.size() - 1)); + ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); + ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListStatusJumpFromNotFullToFullAndBack) { + PostingListEmbeddingHitSerializer serializer; + + // Size = 24 + const uint32_t pl_size = 3 * sizeof(EmbeddingHit); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + EmbeddingHit max_valued_hit( + BasicHit(/*section_id=*/kMaxSectionId, /*document_id=*/kMinDocumentId), + /*location=*/std::numeric_limits<uint32_t>::max()); + ICING_ASSERT_OK(serializer.PrependHit(&pl, max_valued_hit)); + uint32_t bytes_used = serializer.GetBytesUsed(&pl); + ASSERT_THAT(bytes_used, sizeof(EmbeddingHit)); + // Status not full. 
+ ASSERT_THAT( + bytes_used, + Le(pl_size - PostingListEmbeddingHitSerializer::kSpecialHitsSize)); + + EmbeddingHit min_valued_hit( + BasicHit(/*section_id=*/kMinSectionId, /*document_id=*/kMaxDocumentId), + /*location=*/0); + ICING_ASSERT_OK(serializer.PrependHit(&pl, min_valued_hit)); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAre(min_valued_hit, max_valued_hit))); + // Status should jump to full directly. + ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(pl_size)); + ICING_ASSERT_OK(serializer.PopFrontHits(&pl, 1)); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAre(max_valued_hit))); + // Status should return to not full as before. + ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(bytes_used)); +} + +TEST(PostingListEmbeddingHitSerializerTest, DeltaOverflow) { + PostingListEmbeddingHitSerializer serializer; + + const uint32_t pl_size = 4 * sizeof(EmbeddingHit); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + static const EmbeddingHit::Value kMaxHitValue = + std::numeric_limits<EmbeddingHit::Value>::max(); + static const EmbeddingHit::Value kOverflow[4] = { + kMaxHitValue >> 2, + (kMaxHitValue >> 2) * 2, + (kMaxHitValue >> 2) * 3, + kMaxHitValue - 1, + }; + + // Fit at least 4 ordinary values. + std::deque<EmbeddingHit> hits_pushed; + for (EmbeddingHit::Value v = 0; v < 4; v++) { + hits_pushed.push_front( + EmbeddingHit(BasicHit(kMaxSectionId, kMaxDocumentId), 4 - v)); + ICING_EXPECT_OK(serializer.PrependHit(&pl, hits_pushed.front())); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + } + + // Cannot fit 4 overflow values. 
+  hits_pushed.clear();
+  ICING_ASSERT_OK_AND_ASSIGN(
+      pl, PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+  for (int i = 3; i >= 1; i--) {
+    hits_pushed.push_front(EmbeddingHit(/*value=*/kOverflow[i]));
+    ICING_EXPECT_OK(serializer.PrependHit(&pl, hits_pushed.front()));
+    EXPECT_THAT(serializer.GetHits(&pl),
+                IsOkAndHolds(ElementsAreArray(hits_pushed)));
+  }
+  EXPECT_THAT(serializer.PrependHit(&pl, EmbeddingHit(/*value=*/kOverflow[0])),
+              StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST(PostingListEmbeddingHitSerializerTest,
+     GetMinPostingListToFitForNotFullPL) {
+  PostingListEmbeddingHitSerializer serializer;
+
+  // Size = 64
+  int pl_size = 4 * serializer.GetMinPostingListSize();
+  ICING_ASSERT_OK_AND_ASSIGN(
+      PostingListUsed pl_used,
+      PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size));
+  // Create and add some hits to make pl_used NOT_FULL
+  std::vector<EmbeddingHit> hits_in =
+      CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2);
+  for (const EmbeddingHit &hit : hits_in) {
+    ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit));
+  }
+  // ----------------------
+  // 63-62  delta(EmbeddingHit #0)
+  // 61-60  delta(EmbeddingHit #1)
+  // 59-58  delta(EmbeddingHit #2)
+  // 57-56  delta(EmbeddingHit #3)
+  // 55-48  EmbeddingHit #4
+  // 47-16  <unused>
+  // 15-8   kSpecialHit
+  // 7-0    Offset=48
+  // ----------------------
+  int bytes_used = 16;
+
+  // Check that all hits have been inserted
+  EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used));
+  std::deque<EmbeddingHit> hits_pushed(hits_in.rbegin(), hits_in.rend());
+  EXPECT_THAT(serializer.GetHits(&pl_used),
+              IsOkAndHolds(ElementsAreArray(hits_pushed)));
+
+  // Get the min size to fit for the hits in pl_used. Moving the hits in pl_used
+  // into a posting list with this min size should make it ALMOST_FULL, which we
+  // can see should have size = 24.
+ // ---------------------- + // 23-22 delta(EmbeddingHit #0) + // 21-20 delta(EmbeddingHit #1) + // 19-18 delta(EmbeddingHit #2) + // 17-16 delta(EmbeddingHit #3) + // 15-8 EmbeddingHit #4 + // 7-0 kSpecialHit + // ---------------------- + int expected_min_size = 24; + uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(expected_min_size)); + + // Also check that this min size to fit posting list actually does fit all the + // hits and can only hit one more hit in the ALMOST_FULL state. + ICING_ASSERT_OK_AND_ASSIGN(PostingListUsed min_size_to_fit_pl, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, min_size_to_fit)); + for (const EmbeddingHit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit)); + } + + // Adding another hit to the min-size-to-fit posting list should succeed + EmbeddingHit hit = + CreateEmbeddingHit(hits_in.back(), /*desired_byte_length=*/1); + ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit)); + // Adding any other hits should fail with RESOURCE_EXHAUSTED error. + EXPECT_THAT( + serializer.PrependHit(&min_size_to_fit_pl, + CreateEmbeddingHit(hit, /*desired_byte_length=*/1)), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + + // Check that all hits have been inserted and the min-fit posting list is now + // FULL. 
+ EXPECT_THAT(serializer.GetBytesUsed(&min_size_to_fit_pl), + Eq(min_size_to_fit)); + hits_pushed.emplace_front(hit); + EXPECT_THAT(serializer.GetHits(&min_size_to_fit_pl), + IsOkAndHolds(ElementsAreArray(hits_pushed))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + GetMinPostingListToFitForAlmostFullAndFullPLReturnsSameSize) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 24; + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + // Create and add some hits to make pl_used ALMOST_FULL + std::vector<EmbeddingHit> hits_in = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + } + // ---------------------- + // 23-22 delta(EmbeddingHit #0) + // 21-20 delta(EmbeddingHit #1) + // 19-18 delta(EmbeddingHit #2) + // 17-16 delta(EmbeddingHit #3) + // 15-8 EmbeddingHit #4 + // 7-0 kSpecialHit + // ---------------------- + int bytes_used = 16; + + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); + std::deque<EmbeddingHit> hits_pushed(hits_in.rbegin(), hits_in.rend()); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // GetMinPostingListSizeToFit should return the same size as pl_used. + uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(pl_size)); + + // Add another hit to make the posting list FULL + EmbeddingHit hit = + CreateEmbeddingHit(hits_in.back(), /*desired_byte_length=*/1); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size)); + hits_pushed.emplace_front(hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // GetMinPostingListSizeToFit should still be the same size as pl_used. 
+ min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(pl_size)); +} + +TEST(PostingListEmbeddingHitSerializerTest, MoveFrom) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + ICING_ASSERT_OK(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1)); + EXPECT_THAT(serializer.GetHits(&pl_used2), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); + EXPECT_THAT(serializer.GetHits(&pl_used1), IsOkAndHolds(IsEmpty())); +} + +TEST(PostingListEmbeddingHitSerializerTest, + MoveFromNullArgumentReturnsInvalidArgument) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used1, /*src=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(serializer.GetHits(&pl_used1), + IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend()))); 
+} + +TEST(PostingListEmbeddingHitSerializerTest, + MoveFromInvalidPostingListReturnsInvalidArgument) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + // Write invalid hits to the beginning of pl_used1 to make it invalid. + EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue); + EmbeddingHit *first_hit = + reinterpret_cast<EmbeddingHit *>(pl_used1.posting_list_buffer()); + *first_hit = invalid_hit; + ++first_hit; + *first_hit = invalid_hit; + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(serializer.GetHits(&pl_used2), + IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend()))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + MoveToInvalidPostingListReturnsFailedPrecondition) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + 
ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + // Write invalid hits to the beginning of pl_used2 to make it invalid. + EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue); + EmbeddingHit *first_hit = + reinterpret_cast<EmbeddingHit *>(pl_used2.posting_list_buffer()); + *first_hit = invalid_hit; + ++first_hit; + *first_hit = invalid_hit; + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(serializer.GetHits(&pl_used1), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); +} + +TEST(PostingListEmbeddingHitSerializerTest, MoveToPostingListTooSmall) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, serializer.GetMinPostingListSize())); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/1, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(serializer.GetHits(&pl_used1), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), 
hits1.rend()))); + EXPECT_THAT(serializer.GetHits(&pl_used2), + IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend()))); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/index/embedding-indexing-handler.cc b/icing/index/embedding-indexing-handler.cc new file mode 100644 index 0000000..049e307 --- /dev/null +++ b/icing/index/embedding-indexing-handler.cc @@ -0,0 +1,85 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embedding-indexing-handler.h" + +#include <memory> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/hit/hit.h" +#include "icing/legacy/core/icing-string-util.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/util/clock.h" +#include "icing/util/status-macros.h" +#include "icing/util/tokenized-document.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndexingHandler>> +EmbeddingIndexingHandler::Create(const Clock* clock, + EmbeddingIndex* embedding_index) { + ICING_RETURN_ERROR_IF_NULL(clock); + ICING_RETURN_ERROR_IF_NULL(embedding_index); + + return std::unique_ptr<EmbeddingIndexingHandler>( + new EmbeddingIndexingHandler(clock, embedding_index)); +} + +libtextclassifier3::Status EmbeddingIndexingHandler::Handle( + const TokenizedDocument& tokenized_document, DocumentId document_id, + bool recovery_mode, PutDocumentStatsProto* put_document_stats) { + std::unique_ptr<Timer> index_timer = clock_.GetNewTimer(); + + if (!IsDocumentIdValid(document_id)) { + return absl_ports::InvalidArgumentError( + IcingStringUtil::StringPrintf("Invalid DocumentId %d", document_id)); + } + + if (embedding_index_.last_added_document_id() != kInvalidDocumentId && + document_id <= embedding_index_.last_added_document_id()) { + if (recovery_mode) { + // Skip the document if document_id <= last_added_document_id in + // recovery mode without returning an error. 
+ return libtextclassifier3::Status::OK; + } + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "DocumentId %d must be greater than last added document_id %d", + document_id, embedding_index_.last_added_document_id())); + } + embedding_index_.set_last_added_document_id(document_id); + + for (const Section<PropertyProto::VectorProto>& vector_section : + tokenized_document.vector_sections()) { + BasicHit hit(/*section_id=*/vector_section.metadata.id, document_id); + for (const PropertyProto::VectorProto& vector : vector_section.content) { + ICING_RETURN_IF_ERROR(embedding_index_.BufferEmbedding(hit, vector)); + } + } + ICING_RETURN_IF_ERROR(embedding_index_.CommitBufferToIndex()); + + if (put_document_stats != nullptr) { + put_document_stats->set_embedding_index_latency_ms( + index_timer->GetElapsedMilliseconds()); + } + + return libtextclassifier3::Status::OK; +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/embedding-indexing-handler.h b/icing/index/embedding-indexing-handler.h new file mode 100644 index 0000000..f3adf6a --- /dev/null +++ b/icing/index/embedding-indexing-handler.h @@ -0,0 +1,70 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_
+#define ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_
+
+#include <memory>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/data-indexing-handler.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/store/document-id.h"
+#include "icing/util/clock.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+class EmbeddingIndexingHandler : public DataIndexingHandler {
+ public:
+  ~EmbeddingIndexingHandler() override = default;
+
+  // Creates an EmbeddingIndexingHandler instance which does not take
+  // ownership of any input components. All pointers must refer to valid objects
+  // that outlive the created EmbeddingIndexingHandler instance.
+  //
+  // Returns:
+  //   - An EmbeddingIndexingHandler instance on success
+  //   - FAILED_PRECONDITION_ERROR if any of the input pointers is null
+  static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndexingHandler>>
+  Create(const Clock* clock, EmbeddingIndex* embedding_index);
+
+  // Handles the embedding indexing process: add hits into the embedding index
+  // for all contents in tokenized_document.vector_sections.
+  //
+  // Returns:
+  //   - OK on success.
+  //   - INVALID_ARGUMENT_ERROR if document_id is invalid OR document_id is less
+  //     than or equal to the document_id of a previously indexed document in
+  //     non recovery mode.
+  //   - INTERNAL_ERROR if any other errors occur.
+  //   - Any embedding index errors.
+ libtextclassifier3::Status Handle( + const TokenizedDocument& tokenized_document, DocumentId document_id, + bool recovery_mode, PutDocumentStatsProto* put_document_stats) override; + + private: + explicit EmbeddingIndexingHandler(const Clock* clock, + EmbeddingIndex* embedding_index) + : DataIndexingHandler(clock), embedding_index_(*embedding_index) {} + + EmbeddingIndex& embedding_index_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_EMBEDDING_INDEXING_HANDLER_H_ diff --git a/icing/index/embedding-indexing-handler_test.cc b/icing/index/embedding-indexing-handler_test.cc new file mode 100644 index 0000000..ed711c1 --- /dev/null +++ b/icing/index/embedding-indexing-handler_test.cc @@ -0,0 +1,620 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/embedding-indexing-handler.h" + +#include <cstdint> +#include <initializer_list> +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/embed/posting-list-embedding-hit-accessor.h" +#include "icing/index/hit/hit.h" +#include "icing/portable/platform.h" +#include "icing/proto/document_wrapper.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/store/document-store.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/embedding-test-utils.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/tokenization/language-segmenter.h" +#include "icing/util/status-macros.h" +#include "icing/util/tokenized-document.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::IsTrue; + +// Indexable properties (section) and section id. Section id is determined by +// the lexicographical order of indexable property paths. 
+// Schema type with indexable properties: FakeType +// Section id = 0: "body" +// Section id = 1: "bodyEmbedding" +// Section id = 2: "title" +// Section id = 3: "titleEmbedding" +static constexpr std::string_view kFakeType = "FakeType"; +static constexpr std::string_view kPropertyBody = "body"; +static constexpr std::string_view kPropertyBodyEmbedding = "bodyEmbedding"; +static constexpr std::string_view kPropertyTitle = "title"; +static constexpr std::string_view kPropertyTitleEmbedding = "titleEmbedding"; +static constexpr std::string_view kPropertyNonIndexableEmbedding = + "nonIndexableEmbedding"; + +static constexpr SectionId kSectionIdBodyEmbedding = 1; +static constexpr SectionId kSectionIdTitleEmbedding = 3; + +// Schema type with nested indexable properties: FakeCollectionType +// Section id = 0: "collection.body" +// Section id = 1: "collection.bodyEmbedding" +// Section id = 2: "collection.title" +// Section id = 3: "collection.titleEmbedding" +// Section id = 4: "fullDocEmbedding" +static constexpr std::string_view kFakeCollectionType = "FakeCollectionType"; +static constexpr std::string_view kPropertyCollection = "collection"; +static constexpr std::string_view kPropertyFullDocEmbedding = + "fullDocEmbedding"; + +static constexpr SectionId kSectionIdNestedBodyEmbedding = 1; +static constexpr SectionId kSectionIdNestedTitleEmbedding = 3; +static constexpr SectionId kSectionIdFullDocEmbedding = 4; + +class EmbeddingIndexingHandlerTest : public ::testing::Test { + protected: + void SetUp() override { + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. 
+ icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + + base_dir_ = GetTestTempDir() + "/icing_test"; + ASSERT_THAT(filesystem_.CreateDirectoryRecursively(base_dir_.c_str()), + IsTrue()); + + embedding_index_working_path_ = base_dir_ + "/embedding_index"; + schema_store_dir_ = base_dir_ + "/schema_store"; + document_store_dir_ = base_dir_ + "/document_store"; + + ICING_ASSERT_OK_AND_ASSIGN( + embedding_index_, + EmbeddingIndex::Create(&filesystem_, embedding_index_working_path_)); + + language_segmenter_factory::SegmenterOptions segmenter_options(ULOC_US); + ICING_ASSERT_OK_AND_ASSIGN( + lang_segmenter_, + language_segmenter_factory::Create(std::move(segmenter_options))); + + ASSERT_THAT( + filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()), + IsTrue()); + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); + SchemaProto schema = + SchemaBuilder() + .AddType( + SchemaTypeConfigBuilder() + .SetType(kFakeType) + .AddProperty(PropertyConfigBuilder() + .SetName(kPropertyTitle) + .SetDataTypeString(TERM_MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName(kPropertyBody) + .SetDataTypeString(TERM_MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty( + PropertyConfigBuilder() + .SetName(kPropertyTitleEmbedding) + .SetDataTypeVector( + EmbeddingIndexingConfig::EmbeddingIndexingType:: + LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName(kPropertyBodyEmbedding) + .SetDataTypeVector( + EmbeddingIndexingConfig::EmbeddingIndexingType:: + LINEAR_SEARCH) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty(PropertyConfigBuilder() + .SetName(kPropertyNonIndexableEmbedding) + .SetDataType(TYPE_VECTOR) + .SetCardinality(CARDINALITY_REPEATED))) + .AddType(SchemaTypeConfigBuilder() + .SetType(kFakeCollectionType) 
+ .AddProperty(PropertyConfigBuilder() + .SetName(kPropertyCollection) + .SetDataTypeDocument( + kFakeType, + /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty( + PropertyConfigBuilder() + .SetName(kPropertyFullDocEmbedding) + .SetDataTypeVector( + EmbeddingIndexingConfig:: + EmbeddingIndexingType::LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ICING_ASSERT_OK(schema_store_->SetSchema( + schema, /*ignore_errors_and_delete_documents=*/false, + /*allow_circular_schema_definitions=*/false)); + + ASSERT_TRUE( + filesystem_.CreateDirectoryRecursively(document_store_dir_.c_str())); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult doc_store_create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get(), + /*force_recovery_and_revalidate_documents=*/false, + /*namespace_id_fingerprint=*/false, + /*pre_mapping_fbv=*/false, + /*use_persistent_hash_map=*/false, + PortableFileBackedProtoLog< + DocumentWrapper>::kDeflateCompressionLevel, + /*initialize_stats=*/nullptr)); + document_store_ = std::move(doc_store_create_result.document_store); + } + + void TearDown() override { + document_store_.reset(); + schema_store_.reset(); + lang_segmenter_.reset(); + embedding_index_.reset(); + + filesystem_.DeleteDirectoryRecursively(base_dir_.c_str()); + } + + libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits( + uint32_t dimension, std::string_view model_signature) { + std::vector<EmbeddingHit> hits; + + libtextclassifier3::StatusOr< + std::unique_ptr<PostingListEmbeddingHitAccessor>> + pl_accessor_or = + embedding_index_->GetAccessor(dimension, model_signature); + std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor; + if (pl_accessor_or.ok()) { + pl_accessor = std::move(pl_accessor_or).ValueOrDie(); + } else if (absl_ports::IsNotFound(pl_accessor_or.status())) { + return hits; + } else { + return std::move(pl_accessor_or).status(); + } + + while 
(true) { + ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch, + pl_accessor->GetNextHitsBatch()); + if (batch.empty()) { + return hits; + } + hits.insert(hits.end(), batch.begin(), batch.end()); + } + } + + std::vector<float> GetRawEmbeddingData() { + return std::vector<float>(embedding_index_->GetRawEmbeddingData(), + embedding_index_->GetRawEmbeddingData() + + embedding_index_->GetTotalVectorSize()); + } + + Filesystem filesystem_; + FakeClock fake_clock_; + std::string base_dir_; + std::string embedding_index_working_path_; + std::string schema_store_dir_; + std::string document_store_dir_; + + std::unique_ptr<EmbeddingIndex> embedding_index_; + std::unique_ptr<LanguageSegmenter> lang_segmenter_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<DocumentStore> document_store_; +}; + +} // namespace + +TEST_F(EmbeddingIndexingHandlerTest, CreationWithNullPointerShouldFail) { + EXPECT_THAT(EmbeddingIndexingHandler::Create(/*clock=*/nullptr, + embedding_index_.get()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + + EXPECT_THAT(EmbeddingIndexingHandler::Create(&fake_clock_, + /*embedding_index=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(EmbeddingIndexingHandlerTest, HandleEmbeddingSection) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {0.1, 0.2, 0.3})) + .AddStringProperty(std::string(kPropertyBody), "body") + .AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {0.4, 0.5, 0.6}), + CreateVector("model", {0.7, 0.8, 0.9})) + .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", {1.1, 1.2, 1.3})) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + 
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document))); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(tokenized_document.document())); + + ASSERT_THAT(embedding_index_->last_added_document_id(), + Eq(kInvalidDocumentId)); + // Handle document. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<EmbeddingIndexingHandler> handler, + EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get())); + EXPECT_THAT( + handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false, + /*put_document_stats=*/nullptr), + IsOk()); + + // Check index + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/0), + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0), + /*location=*/6)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3)); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id)); +} + +TEST_F(EmbeddingIndexingHandlerTest, HandleNestedEmbeddingSection) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_collection_type/1") + .SetSchema(std::string(kFakeCollectionType)) + .AddDocumentProperty( + std::string(kPropertyCollection), + DocumentBuilder() + .SetKey("icing", "nested_fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {0.1, 0.2, 0.3})) + .AddStringProperty(std::string(kPropertyBody), "body") + .AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {0.4, 0.5, 0.6}), + CreateVector("model", {0.7, 0.8, 0.9})) + .AddVectorProperty( + std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", 
{1.1, 1.2, 1.3})) + .Build()) + .AddVectorProperty(std::string(kPropertyFullDocEmbedding), + CreateVector("model", {2.1, 2.2, 2.3})) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document))); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(tokenized_document.document())); + + ASSERT_THAT(embedding_index_->last_added_document_id(), + Eq(kInvalidDocumentId)); + // Handle document. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<EmbeddingIndexingHandler> handler, + EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get())); + EXPECT_THAT( + handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false, + /*put_document_stats=*/nullptr), + IsOk()); + + // Check index + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit( + BasicHit(kSectionIdNestedBodyEmbedding, /*document_id=*/0), + /*location=*/0), + EmbeddingHit( + BasicHit(kSectionIdNestedBodyEmbedding, /*document_id=*/0), + /*location=*/3), + EmbeddingHit( + BasicHit(kSectionIdNestedTitleEmbedding, /*document_id=*/0), + /*location=*/6), + EmbeddingHit(BasicHit(kSectionIdFullDocEmbedding, /*document_id=*/0), + /*location=*/9)))); + EXPECT_THAT(GetRawEmbeddingData(), ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, + 0.1, 0.2, 0.3, 2.1, 2.2, 2.3)); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id)); +} + +TEST_F(EmbeddingIndexingHandlerTest, + HandleInvalidDocumentIdShouldReturnInvalidArgumentError) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {0.1, 0.2, 0.3})) + .AddStringProperty(std::string(kPropertyBody), "body") + 
.AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {0.4, 0.5, 0.6}), + CreateVector("model", {0.7, 0.8, 0.9})) + .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", {1.1, 1.2, 1.3})) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document))); + ICING_ASSERT_OK(document_store_->Put(tokenized_document.document())); + + static constexpr DocumentId kCurrentDocumentId = 3; + embedding_index_->set_last_added_document_id(kCurrentDocumentId); + ASSERT_THAT(embedding_index_->last_added_document_id(), + Eq(kCurrentDocumentId)); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<EmbeddingIndexingHandler> handler, + EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get())); + + // Handling document with kInvalidDocumentId should cause a failure, and both + // index data and last_added_document_id should remain unchanged. + EXPECT_THAT( + handler->Handle(tokenized_document, kInvalidDocumentId, + /*recovery_mode=*/false, /*put_document_stats=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(embedding_index_->last_added_document_id(), + Eq(kCurrentDocumentId)); + // Check that the embedding index should be empty + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); + + // Recovery mode should get the same result. 
+ EXPECT_THAT( + handler->Handle(tokenized_document, kInvalidDocumentId, + /*recovery_mode=*/true, /*put_document_stats=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(embedding_index_->last_added_document_id(), + Eq(kCurrentDocumentId)); + // Check that the embedding index should be empty + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); +} + +TEST_F(EmbeddingIndexingHandlerTest, + HandleOutOfOrderDocumentIdShouldReturnInvalidArgumentError) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {0.1, 0.2, 0.3})) + .AddStringProperty(std::string(kPropertyBody), "body") + .AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {0.4, 0.5, 0.6}), + CreateVector("model", {0.7, 0.8, 0.9})) + .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", {1.1, 1.2, 1.3})) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document))); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(tokenized_document.document())); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<EmbeddingIndexingHandler> handler, + EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get())); + + // Handling document with document_id == last_added_document_id should cause a + // failure, and both index data and last_added_document_id should remain + // unchanged. 
+ embedding_index_->set_last_added_document_id(document_id); + ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id)); + EXPECT_THAT( + handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false, + /*put_document_stats=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id)); + + // Check that the embedding index should be empty + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); + + // Handling document with document_id < last_added_document_id should cause a + // failure, and both index data and last_added_document_id should remain + // unchanged. + embedding_index_->set_last_added_document_id(document_id + 1); + ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id + 1)); + EXPECT_THAT( + handler->Handle(tokenized_document, document_id, /*recovery_mode=*/false, + /*put_document_stats=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id + 1)); + + // Check that the embedding index should be empty + EXPECT_THAT(GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetRawEmbeddingData(), IsEmpty()); +} + +TEST_F(EmbeddingIndexingHandlerTest, + HandleRecoveryModeShouldIgnoreDocsLELastAddedDocId) { + DocumentProto document1 = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title one") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {0.1, 0.2, 0.3})) + .AddStringProperty(std::string(kPropertyBody), "body one") + .AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {0.4, 0.5, 0.6}), + CreateVector("model", {0.7, 0.8, 0.9})) + 
.AddVectorProperty(std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", {1.1, 1.2, 1.3})) + .Build(); + DocumentProto document2 = + DocumentBuilder() + .SetKey("icing", "fake_type/2") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kPropertyTitle), "title two") + .AddVectorProperty(std::string(kPropertyTitleEmbedding), + CreateVector("model", {10.1, 10.2, 10.3})) + .AddStringProperty(std::string(kPropertyBody), "body two") + .AddVectorProperty(std::string(kPropertyBodyEmbedding), + CreateVector("model", {10.4, 10.5, 10.6}), + CreateVector("model", {10.7, 10.8, 10.9})) + .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding), + CreateVector("model", {11.1, 11.2, 11.3})) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document1, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document1))); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document2, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + std::move(document2))); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store_->Put(tokenized_document1.document())); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store_->Put(tokenized_document2.document())); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<EmbeddingIndexingHandler> handler, + EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get())); + + // Handle document with document_id > last_added_document_id in recovery mode. + // The handler should index this document and update last_added_document_id. 
+ EXPECT_THAT( + handler->Handle(tokenized_document1, document_id1, /*recovery_mode=*/true, + /*put_document_stats=*/nullptr), + IsOk()); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id1)); + + // Check index + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/0), + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0), + /*location=*/6)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3)); + + // Handle document with document_id == last_added_document_id in recovery + // mode. We should not get any error, but the handler should ignore the + // document, so both index data and last_added_document_id should remain + // unchanged. + embedding_index_->set_last_added_document_id(document_id2); + ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2)); + EXPECT_THAT( + handler->Handle(tokenized_document2, document_id2, /*recovery_mode=*/true, + /*put_document_stats=*/nullptr), + IsOk()); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2)); + + // Check index + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/0), + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0), + /*location=*/6)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3)); + + // Handle document with document_id < last_added_document_id in recovery mode. 
+ // We should not get any error, but the handler should ignore the document, so + // both index data and last_added_document_id should remain unchanged. + embedding_index_->set_last_added_document_id(document_id2 + 1); + ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2 + 1)); + EXPECT_THAT( + handler->Handle(tokenized_document2, document_id2, /*recovery_mode=*/true, + /*put_document_stats=*/nullptr), + IsOk()); + EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2 + 1)); + + // Check index + EXPECT_THAT( + GetHits(/*dimension=*/3, /*model_signature=*/"model"), + IsOkAndHolds(ElementsAre( + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/0), + EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0), + /*location=*/3), + EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0), + /*location=*/6)))); + EXPECT_THAT(GetRawEmbeddingData(), + ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3)); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/hit/hit.cc b/icing/index/hit/hit.cc index 493e62b..7b7f2f6 100644 --- a/icing/index/hit/hit.cc +++ b/icing/index/hit/hit.cc @@ -14,42 +14,19 @@ #include "icing/index/hit/hit.h" +#include <cstring> +#include <limits> + +#include "icing/schema/section.h" #include "icing/store/document-id.h" #include "icing/util/bit-util.h" +#include "icing/util/logging.h" namespace icing { namespace lib { namespace { -enum FlagOffset { - // This hit, whether exact or not, came from a prefixed section and will - // need to be backfilled into branching posting lists if/when those are - // created. - kInPrefixSection = 0, - // This hit represents a prefix of a longer term. If exact matches are - // required, then this hit should be ignored. - kPrefixHit = 1, - // Whether or not the hit has a term_frequency other than - // kDefaultTermFrequency. 
- kHasTermFrequency = 2, - kNumFlags = 3, -}; - -static_assert(kDocumentIdBits + kSectionIdBits + kNumFlags < - sizeof(Hit::Value) * 8, - "Hit::kInvalidValue contains risky value and we should have at " - "least one unused bit to avoid potential bugs. Please follow the " - "process mentioned in hit.h to correct the value of " - "Hit::kInvalidValue and remove this static_assert afterwards."); - -static_assert(kDocumentIdBits + kSectionIdBits + kNumFlags <= - sizeof(Hit::Value) * 8, - "HitOverflow"); -static_assert(kDocumentIdBits == 22, ""); -static_assert(kSectionIdBits == 6, ""); -static_assert(kNumFlags == 3, ""); - inline DocumentId InvertDocumentId(DocumentId document_id) { static_assert(kMaxDocumentId <= (std::numeric_limits<DocumentId>::max() - 1), "(kMaxDocumentId + 1) must not overflow."); @@ -88,10 +65,28 @@ SectionId BasicHit::section_id() const { /*len=*/kSectionIdBits); } +Hit::Hit(Value value, Flags flags, TermFrequency term_frequency) + : flags_(flags), term_frequency_(term_frequency) { + memcpy(value_.data(), &value, sizeof(value)); + if (!CheckFlagsAreConsistent()) { + ICING_VLOG(1) + << "Creating Hit that has inconsistent flag values across its fields: " + << "Hit(value=" << value << ", flags=" << flags + << "term_frequency=" << term_frequency << ")"; + } +} + Hit::Hit(SectionId section_id, DocumentId document_id, Hit::TermFrequency term_frequency, bool is_in_prefix_section, bool is_prefix_hit) : term_frequency_(term_frequency) { + // We compute flags first as the value's has_flags bit depends on the flags_ + // field. + Flags temp_flags = 0; + bit_util::BitfieldSet(term_frequency != kDefaultTermFrequency, + kHasTermFrequency, /*len=*/1, &temp_flags); + flags_ = temp_flags; + // Values are stored so that when sorted, they appear in document_id // descending, section_id ascending, order. Also, all else being // equal, non-prefix hits sort before prefix hits. 
So inverted @@ -99,30 +94,29 @@ Hit::Hit(SectionId section_id, DocumentId document_id, // (uninverted) section_id. Value temp_value = 0; bit_util::BitfieldSet(InvertDocumentId(document_id), - kSectionIdBits + kNumFlags, kDocumentIdBits, + kSectionIdBits + kNumFlagsInValueField, kDocumentIdBits, + &temp_value); + bit_util::BitfieldSet(section_id, kNumFlagsInValueField, kSectionIdBits, &temp_value); - bit_util::BitfieldSet(section_id, kNumFlags, kSectionIdBits, &temp_value); - bit_util::BitfieldSet(term_frequency != kDefaultTermFrequency, - kHasTermFrequency, /*len=*/1, &temp_value); bit_util::BitfieldSet(is_prefix_hit, kPrefixHit, /*len=*/1, &temp_value); bit_util::BitfieldSet(is_in_prefix_section, kInPrefixSection, /*len=*/1, &temp_value); - value_ = temp_value; + + bool has_flags = flags_ != kNoEnabledFlags; + bit_util::BitfieldSet(has_flags, kHasFlags, /*len=*/1, &temp_value); + + memcpy(value_.data(), &temp_value, sizeof(temp_value)); } DocumentId Hit::document_id() const { DocumentId inverted_document_id = bit_util::BitfieldGet( - value(), kSectionIdBits + kNumFlags, kDocumentIdBits); + value(), kSectionIdBits + kNumFlagsInValueField, kDocumentIdBits); // Undo the document_id inversion. 
return InvertDocumentId(inverted_document_id); } SectionId Hit::section_id() const { - return bit_util::BitfieldGet(value(), kNumFlags, kSectionIdBits); -} - -bool Hit::has_term_frequency() const { - return bit_util::BitfieldGet(value(), kHasTermFrequency, 1); + return bit_util::BitfieldGet(value(), kNumFlagsInValueField, kSectionIdBits); } bool Hit::is_prefix_hit() const { @@ -133,6 +127,27 @@ bool Hit::is_in_prefix_section() const { return bit_util::BitfieldGet(value(), kInPrefixSection, 1); } +bool Hit::has_flags() const { + return bit_util::BitfieldGet(value(), kHasFlags, 1); +} + +bool Hit::has_term_frequency() const { + return bit_util::BitfieldGet(flags(), kHasTermFrequency, 1); +} + +bool Hit::CheckFlagsAreConsistent() const { + bool has_flags = flags_ != kNoEnabledFlags; + bool has_flags_enabled_in_value = + bit_util::BitfieldGet(value(), kHasFlags, /*len=*/1); + + bool has_term_frequency = term_frequency_ != kDefaultTermFrequency; + bool has_term_frequency_enabled_in_flags = + bit_util::BitfieldGet(flags(), kHasTermFrequency, /*len=*/1); + + return has_flags == has_flags_enabled_in_value && + has_term_frequency == has_term_frequency_enabled_in_flags; +} + Hit Hit::TranslateHit(Hit old_hit, DocumentId new_document_id) { return Hit(old_hit.section_id(), new_document_id, old_hit.term_frequency(), old_hit.is_in_prefix_section(), old_hit.is_prefix_hit()); @@ -140,7 +155,8 @@ Hit Hit::TranslateHit(Hit old_hit, DocumentId new_document_id) { bool Hit::EqualsDocumentIdAndSectionId::operator()(const Hit& hit1, const Hit& hit2) const { - return (hit1.value() >> kNumFlags) == (hit2.value() >> kNumFlags); + return (hit1.value() >> kNumFlagsInValueField) == + (hit2.value() >> kNumFlagsInValueField); } } // namespace lib diff --git a/icing/index/hit/hit.h b/icing/index/hit/hit.h index 111b320..e971016 100644 --- a/icing/index/hit/hit.h +++ b/icing/index/hit/hit.h @@ -17,6 +17,7 @@ #include <array> #include <cstdint> +#include <cstring> #include <limits> #include 
"icing/legacy/core/icing-packed-pod.h" @@ -59,6 +60,7 @@ class BasicHit { explicit BasicHit(SectionId section_id, DocumentId document_id); explicit BasicHit() : value_(kInvalidValue) {} + explicit BasicHit(Value value) : value_(value) {} bool is_valid() const { return value_ != kInvalidValue; } Value value() const { return value_; } @@ -70,6 +72,9 @@ class BasicHit { private: // Value bits layout: 4 unused + 22 document_id + 6 section id. + // + // The Value is guaranteed to be an unsigned integer, but its size and the + // information it stores may change if the Hit's encoding format is changed. Value value_; } __attribute__((packed)); static_assert(sizeof(BasicHit) == 4, ""); @@ -78,51 +83,83 @@ static_assert(sizeof(BasicHit) == 4, ""); // consists of: // - a DocumentId // - a SectionId -// referring to the document and section that the hit corresponds to, as well as -// metadata about the hit: -// - whether the Hit has a TermFrequency other than the default value -// - whether the Hit does not appear exactly in the document, but instead -// represents a term that is a prefix of a term in the document -// - whether the Hit came from a section that has prefix expansion enabled -// and a term frequency for the hit. +// - referring to the document and section that the hit corresponds to +// - Metadata about the hit: +// - whether the Hit does not appear exactly in the document, but instead +// represents a term that is a prefix of a term in the document +// (is_prefix_hit) +// - whether the Hit came from a section that has prefix expansion enabled +// (is_in_prefix_section) +// - whether the Hit has set any bitmask flags (has_flags) +// - bitmasks in flags fields: +// - whether the Hit has a TermFrequency other than the default value +// (has_term_frequency) +// - a term frequency for the hit // // The hit is the most basic unit of the index and, when grouped together by // term, can be used to encode what terms appear in what documents. 
class Hit { public: - // The datatype used to encode Hit information: the document_id, section_id - // and the has_term_frequency, prefix hit and in prefix section flags. + // The datatype used to encode Hit information: the document_id, section_id, + // and 3 flags: is_prefix_hit, is_hit_in_prefix_section and has_flags flag. + // + // The Value is guaranteed to be an unsigned integer, but its size and the + // information it stores may change if the Hit's encoding format is changed. using Value = uint32_t; // WARNING: Changing this value will invalidate any pre-existing posting lists // on user devices. // - // WARNING: - // - Hit::kInvalidValue should contain inverted kInvalidDocumentId, which is - // b'00...0. However, currently we set it as UINT32_MAX and actually it - // contains b'11...1, which is the inverted document_id 0. - // - It means Hit::kInvalidValue contains valid (document_id, section_id, - // flags), so we potentially cannot distinguish if a Hit is invalid or not. - // The invalidity is an essential feature for posting list since we use it - // to determine the state of the posting list. - // - The reason why it won't break the current posting list is because the - // unused bit(s) are set as 1 for Hit::kInvalidValue and 0 for all valid - // Hits. In other words, the unused bit(s) are actually serving as "invalid - // flag". - // - If we want to exhaust all unused bits in the future, then we have to - // change Hit::kInvalidValue to set the inverted document_id section - // correctly (b'00...0, refer to BasicHit::kInvalidValue as an example). - // - Also this problem is guarded by static_assert in hit.cc. If exhausting - // all unused bits, then the static_assert will detect and fail. We can - // safely remove the static_assert check after following the above process - // to resolve the incorrect Hit::kInvalidValue issue. 
- static constexpr Value kInvalidValue = std::numeric_limits<Value>::max(); + // kInvalidValue contains: + // - 0 for unused bits. Note that unused bits are always 0 for both valid and + // invalid Hit values. + // - Inverted kInvalidDocumentId + // - SectionId 0 (valid), which is ok because inverted kInvalidDocumentId has + // already invalidated the value. In fact, we currently use all 2^6 section + // ids and there is no "invalid section id", so it doesn't matter what + // SectionId we set for kInvalidValue. + static constexpr Value kInvalidValue = 0; // Docs are sorted in reverse, and 0 is never used as the inverted // DocumentId (because it is the inverse of kInvalidValue), so it is always // the max in a descending sort. static constexpr Value kMaxDocumentIdSortValue = 0; + enum FlagOffsetsInFlagsField { + // Whether or not the hit has a term_frequency other than + // kDefaultTermFrequency. + kHasTermFrequency = 0, + kNumFlagsInFlagsField = 1, + }; + + enum FlagOffsetsInValueField { + // Whether or not the hit has a flags value other than kNoEnabledFlags (i.e. + // it has flags enabled in the flags field) + kHasFlags = 0, + // This hit, whether exact or not, came from a prefixed section and will + // need to be backfilled into branching posting lists if/when those are + // created. + kInPrefixSection = 1, + // This hit represents a prefix of a longer term. If exact matches are + // required, then this hit should be ignored. + kPrefixHit = 2, + kNumFlagsInValueField = 3, + }; + static_assert(kDocumentIdBits + kSectionIdBits + kNumFlagsInValueField <= + sizeof(Hit::Value) * 8, + "HitOverflow"); + static_assert(kDocumentIdBits == 22, ""); + static_assert(kSectionIdBits == 6, ""); + static_assert(kNumFlagsInValueField == 3, ""); + + // The datatype used to encode additional bit-flags in the Hit. + // This is guaranteed to be an unsigned integer, but its size may change if + // more flags are introduced in the future and require more bits to encode. 
+ using Flags = uint8_t; + static constexpr Flags kNoEnabledFlags = 0; + // The Term Frequency of a Hit. + // This is guaranteed to be an unsigned integer, but its size may change if we + // need to expand the max term-frequency. using TermFrequency = uint8_t; using TermFrequencyArray = std::array<Hit::TermFrequency, kTotalNumSections>; // Max TermFrequency is 255. @@ -131,40 +168,67 @@ class Hit { static constexpr TermFrequency kDefaultTermFrequency = 1; static constexpr TermFrequency kNoTermFrequency = 0; - explicit Hit(Value value = kInvalidValue, - TermFrequency term_frequency = kDefaultTermFrequency) - : value_(value), term_frequency_(term_frequency) {} - Hit(SectionId section_id, DocumentId document_id, - TermFrequency term_frequency, bool is_in_prefix_section = false, - bool is_prefix_hit = false); + explicit Hit(Value value) + : Hit(value, kNoEnabledFlags, kDefaultTermFrequency) {} + explicit Hit(Value value, Flags flags, TermFrequency term_frequency); + explicit Hit(SectionId section_id, DocumentId document_id, + TermFrequency term_frequency, bool is_in_prefix_section, + bool is_prefix_hit); bool is_valid() const { return value() != kInvalidValue; } - Value value() const { return value_; } + + Value value() const { + Value value; + memcpy(&value, value_.data(), sizeof(value)); + return value; + } + DocumentId document_id() const; SectionId section_id() const; + bool is_prefix_hit() const; + bool is_in_prefix_section() const; + // Whether or not the hit has any flags set to true. + bool has_flags() const; + + Flags flags() const { return flags_; } // Whether or not the hit contains a valid term frequency. bool has_term_frequency() const; + TermFrequency term_frequency() const { return term_frequency_; } - bool is_prefix_hit() const; - bool is_in_prefix_section() const; + + // Returns true if the flags values across the Hit's value_, term_frequency_ + // and flags_ fields are consistent. 
+ bool CheckFlagsAreConsistent() const; // Creates a new hit based on old_hit but with new_document_id set. static Hit TranslateHit(Hit old_hit, DocumentId new_document_id); - bool operator<(const Hit& h2) const { return value() < h2.value(); } - bool operator==(const Hit& h2) const { return value() == h2.value(); } + bool operator<(const Hit& h2) const { + if (value() != h2.value()) { + return value() < h2.value(); + } + return flags() < h2.flags(); + } + bool operator==(const Hit& h2) const { + return value() == h2.value() && flags() == h2.flags(); + } struct EqualsDocumentIdAndSectionId { bool operator()(const Hit& hit1, const Hit& hit2) const; }; private: - // Value and TermFrequency must be in this order. - // Value bits layout: 1 unused + 22 document_id + 6 section id + 3 flags. - Value value_; + // Value, Flags and TermFrequency must be in this order. + // Value bits layout: 1 unused + 22 document_id + 6 section_id + 1 + // is_prefix_hit + 1 is_in_prefix_section + 1 has_flags. + std::array<char, sizeof(Value)> value_; + // Flags bits layout: 1 reserved + 6 unused + 1 has_term_frequency. + // The left-most bit is reserved for chaining additional fields in case of + // future hit expansions. + Flags flags_; TermFrequency term_frequency_; -} __attribute__((packed)); -static_assert(sizeof(Hit) == 5, ""); +}; +static_assert(sizeof(Hit) == 6, ""); // TODO(b/138991332) decide how to remove/replace all is_packed_pod assertions. 
static_assert(icing_is_packed_pod<Hit>::value, "go/icing-ubsan"); diff --git a/icing/index/hit/hit_test.cc b/icing/index/hit/hit_test.cc index 0086d91..1233e00 100644 --- a/icing/index/hit/hit_test.cc +++ b/icing/index/hit/hit_test.cc @@ -14,6 +14,9 @@ #include "icing/index/hit/hit.h" +#include <algorithm> +#include <vector> + #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/schema/section.h" @@ -94,77 +97,127 @@ TEST(BasicHitTest, Comparison) { } TEST(HitTest, HasTermFrequencyFlag) { - Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency); + Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); EXPECT_THAT(h1.has_term_frequency(), IsFalse()); EXPECT_THAT(h1.term_frequency(), Eq(Hit::kDefaultTermFrequency)); - Hit h2(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency); + Hit h2(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); EXPECT_THAT(h2.has_term_frequency(), IsTrue()); EXPECT_THAT(h2.term_frequency(), Eq(kSomeTermFrequency)); } TEST(HitTest, IsPrefixHitFlag) { - Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency); + Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); EXPECT_THAT(h1.is_prefix_hit(), IsFalse()); Hit h2(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); EXPECT_THAT(h2.is_prefix_hit(), IsFalse()); Hit h3(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, /*is_prefix_hit=*/true); EXPECT_THAT(h3.is_prefix_hit(), IsTrue()); + + Hit h4(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); + EXPECT_THAT(h4.is_prefix_hit(), IsFalse()); } TEST(HitTest, 
IsInPrefixSectionFlag) { - Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency); + Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); EXPECT_THAT(h1.is_in_prefix_section(), IsFalse()); Hit h2(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); EXPECT_THAT(h2.is_in_prefix_section(), IsFalse()); Hit h3(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); EXPECT_THAT(h3.is_in_prefix_section(), IsTrue()); } +TEST(HitTest, HasFlags) { + Hit h1(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + EXPECT_THAT(h1.has_flags(), IsFalse()); + + Hit h2(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/true); + EXPECT_THAT(h2.has_flags(), IsFalse()); + + Hit h3(kSomeSectionid, kSomeDocumentId, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); + EXPECT_THAT(h3.has_flags(), IsFalse()); + + Hit h4(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + EXPECT_THAT(h4.has_flags(), IsTrue()); + + Hit h5(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_prefix_hit=*/true, /*is_in_prefix_section=*/true); + EXPECT_THAT(h5.has_flags(), IsTrue()); + + Hit h6(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_prefix_hit=*/false, /*is_in_prefix_section=*/true); + EXPECT_THAT(h6.has_flags(), IsTrue()); + + Hit h7(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_prefix_hit=*/true, /*is_in_prefix_section=*/false); + EXPECT_THAT(h7.has_flags(), IsTrue()); +} + TEST(HitTest, Accessors) { - Hit h1(kSomeSectionid, kSomeDocumentId, 
Hit::kDefaultTermFrequency); + Hit h1(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/true); EXPECT_THAT(h1.document_id(), Eq(kSomeDocumentId)); EXPECT_THAT(h1.section_id(), Eq(kSomeSectionid)); + EXPECT_THAT(h1.term_frequency(), Eq(kSomeTermFrequency)); + EXPECT_THAT(h1.is_in_prefix_section(), IsFalse()); + EXPECT_THAT(h1.is_prefix_hit(), IsTrue()); } TEST(HitTest, Valid) { - Hit def; + Hit def(Hit::kInvalidValue); EXPECT_THAT(def.is_valid(), IsFalse()); - - Hit explicit_invalid(Hit::kInvalidValue); + Hit explicit_invalid(Hit::kInvalidValue, Hit::kNoEnabledFlags, + Hit::kDefaultTermFrequency); EXPECT_THAT(explicit_invalid.is_valid(), IsFalse()); static constexpr Hit::Value kSomeValue = 65372; - Hit explicit_valid(kSomeValue); + Hit explicit_valid(kSomeValue, Hit::kNoEnabledFlags, + Hit::kDefaultTermFrequency); EXPECT_THAT(explicit_valid.is_valid(), IsTrue()); - Hit maximum_document_id_hit(kSomeSectionid, kMaxDocumentId, - kSomeTermFrequency); + Hit maximum_document_id_hit( + kSomeSectionid, kMaxDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); EXPECT_THAT(maximum_document_id_hit.is_valid(), IsTrue()); - Hit maximum_section_id_hit(kMaxSectionId, kSomeDocumentId, - kSomeTermFrequency); + Hit maximum_section_id_hit(kMaxSectionId, kSomeDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); EXPECT_THAT(maximum_section_id_hit.is_valid(), IsTrue()); - Hit minimum_document_id_hit(kSomeSectionid, 0, kSomeTermFrequency); + Hit minimum_document_id_hit(kSomeSectionid, 0, kSomeTermFrequency, + /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); EXPECT_THAT(minimum_document_id_hit.is_valid(), IsTrue()); - Hit minimum_section_id_hit(0, kSomeDocumentId, kSomeTermFrequency); + Hit minimum_section_id_hit(0, kSomeDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); 
EXPECT_THAT(minimum_section_id_hit.is_valid(), IsTrue()); - // We use Hit with value Hit::kMaxDocumentIdSortValue for std::lower_bound in - // the lite index. Verify that the value of the smallest valid Hit (which + // We use Hit with value Hit::kMaxDocumentIdSortValue for std::lower_bound + // in the lite index. Verify that the value of the smallest valid Hit (which // contains kMinSectionId, kMaxDocumentId and 3 flags = false) is >= // Hit::kMaxDocumentIdSortValue. - Hit smallest_hit(kMinSectionId, kMaxDocumentId, Hit::kDefaultTermFrequency); + Hit smallest_hit(kMinSectionId, kMaxDocumentId, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ASSERT_THAT(smallest_hit.is_valid(), IsTrue()); ASSERT_THAT(smallest_hit.has_term_frequency(), IsFalse()); ASSERT_THAT(smallest_hit.is_prefix_hit(), IsFalse()); @@ -173,39 +226,78 @@ TEST(HitTest, Valid) { } TEST(HitTest, Comparison) { - Hit hit(1, 243, Hit::kDefaultTermFrequency); + Hit hit(/*section_id=*/1, /*document_id=*/243, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); // DocumentIds are sorted in ascending order. So a hit with a lower // document_id should be considered greater than one with a higher // document_id. - Hit higher_document_id_hit(1, 2409, Hit::kDefaultTermFrequency); - Hit higher_section_id_hit(15, 243, Hit::kDefaultTermFrequency); + Hit higher_document_id_hit( + /*section_id=*/1, /*document_id=*/2409, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + Hit higher_section_id_hit(/*section_id=*/15, /*document_id=*/243, + Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); // Whether or not a term frequency was set is considered, but the term // frequency itself is not. 
- Hit term_frequency_hit(1, 243, 12); - Hit prefix_hit(1, 243, Hit::kDefaultTermFrequency, + Hit term_frequency_hit(/*section_id=*/1, 243, /*term_frequency=*/12, + /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); + Hit prefix_hit(/*section_id=*/1, 243, Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, /*is_prefix_hit=*/true); - Hit hit_in_prefix_section(1, 243, Hit::kDefaultTermFrequency, + Hit hit_in_prefix_section(/*section_id=*/1, 243, Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); + Hit hit_with_all_flags_enabled(/*section_id=*/1, 243, 56, + /*is_in_prefix_section=*/true, + /*is_prefix_hit=*/true); std::vector<Hit> hits{hit, higher_document_id_hit, higher_section_id_hit, term_frequency_hit, prefix_hit, - hit_in_prefix_section}; + hit_in_prefix_section, + hit_with_all_flags_enabled}; std::sort(hits.begin(), hits.end()); - EXPECT_THAT( - hits, ElementsAre(higher_document_id_hit, hit, hit_in_prefix_section, - prefix_hit, term_frequency_hit, higher_section_id_hit)); + EXPECT_THAT(hits, + ElementsAre(higher_document_id_hit, hit, term_frequency_hit, + hit_in_prefix_section, prefix_hit, + hit_with_all_flags_enabled, higher_section_id_hit)); - Hit higher_term_frequency_hit(1, 243, 108); + Hit higher_term_frequency_hit(/*section_id=*/1, 243, /*term_frequency=*/108, + /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); // The term frequency value is not considered when comparing hits. 
EXPECT_THAT(term_frequency_hit, Not(Lt(higher_term_frequency_hit))); EXPECT_THAT(higher_term_frequency_hit, Not(Lt(term_frequency_hit))); } +TEST(HitTest, CheckFlagsAreConsistent) { + Hit::Value value_without_flags = 1 << 30; + Hit::Value value_with_flags = value_without_flags + 1; + Hit::Flags flags_with_term_freq = 1; + + Hit consistent_hit_no_flags(value_without_flags, Hit::kNoEnabledFlags, + Hit::kDefaultTermFrequency); + Hit consistent_hit_with_term_frequency(value_with_flags, flags_with_term_freq, + kSomeTermFrequency); + EXPECT_THAT(consistent_hit_no_flags.CheckFlagsAreConsistent(), IsTrue()); + EXPECT_THAT(consistent_hit_with_term_frequency.CheckFlagsAreConsistent(), + IsTrue()); + + Hit inconsistent_hit_1(value_with_flags, Hit::kNoEnabledFlags, + Hit::kDefaultTermFrequency); + Hit inconsistent_hit_2(value_with_flags, Hit::kNoEnabledFlags, + kSomeTermFrequency); + Hit inconsistent_hit_3(value_with_flags, flags_with_term_freq, + Hit::kDefaultTermFrequency); + EXPECT_THAT(inconsistent_hit_1.CheckFlagsAreConsistent(), IsFalse()); + EXPECT_THAT(inconsistent_hit_2.CheckFlagsAreConsistent(), IsFalse()); + EXPECT_THAT(inconsistent_hit_3.CheckFlagsAreConsistent(), IsFalse()); +} + } // namespace } // namespace lib diff --git a/icing/index/index.cc b/icing/index/index.cc index 98058be..f917aa6 100644 --- a/icing/index/index.cc +++ b/icing/index/index.cc @@ -67,8 +67,7 @@ libtextclassifier3::StatusOr<LiteIndex::Options> CreateLiteIndexOptions( } return LiteIndex::Options( options.base_dir + "/idx/lite.", options.index_merge_size, - options.lite_index_sort_at_indexing, options.lite_index_sort_size, - options.include_property_existence_metadata_hits); + options.lite_index_sort_at_indexing, options.lite_index_sort_size); } std::string MakeMainIndexFilepath(const std::string& base_dir) { @@ -345,7 +344,8 @@ libtextclassifier3::Status Index::Editor::BufferTerm(const char* term) { libtextclassifier3::Status Index::Editor::IndexAllBufferedTerms() { for (auto itr = 
seen_tokens_.begin(); itr != seen_tokens_.end(); itr++) { Hit hit(section_id_, document_id_, /*term_frequency=*/itr->second, - term_match_type_ == TermMatchType::PREFIX); + /*is_in_prefix_section=*/term_match_type_ == TermMatchType::PREFIX, + /*is_prefix_hit=*/false); ICING_ASSIGN_OR_RETURN( uint32_t term_id, term_id_codec_->EncodeTvi(itr->first, TviType::LITE)); ICING_RETURN_IF_ERROR(lite_index_->AddHit(term_id, hit)); diff --git a/icing/index/index.h b/icing/index/index.h index a5d75c4..a09e28f 100644 --- a/icing/index/index.h +++ b/icing/index/index.h @@ -35,6 +35,7 @@ #include "icing/index/term-metadata.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/proto/debug.pb.h" +#include "icing/proto/logging.pb.h" #include "icing/proto/scoring.pb.h" #include "icing/proto/storage.pb.h" #include "icing/proto/term.pb.h" @@ -72,20 +73,16 @@ class Index { struct Options { explicit Options(const std::string& base_dir, uint32_t index_merge_size, bool lite_index_sort_at_indexing, - uint32_t lite_index_sort_size, - bool include_property_existence_metadata_hits = false) + uint32_t lite_index_sort_size) : base_dir(base_dir), index_merge_size(index_merge_size), lite_index_sort_at_indexing(lite_index_sort_at_indexing), - lite_index_sort_size(lite_index_sort_size), - include_property_existence_metadata_hits( - include_property_existence_metadata_hits) {} + lite_index_sort_size(lite_index_sort_size) {} std::string base_dir; int32_t index_merge_size; bool lite_index_sort_at_indexing; int32_t lite_index_sort_size; - bool include_property_existence_metadata_hits; }; // Creates an instance of Index in the directory pointed by file_dir. 
@@ -178,6 +175,13 @@ class Index { return debug_info; } + void PublishQueryStats(QueryStatsProto* query_stats) const { + query_stats->set_lite_index_hit_buffer_byte_size( + lite_index_->GetHitBufferByteSize()); + query_stats->set_lite_index_hit_buffer_unsorted_byte_size( + lite_index_->GetHitBufferUnsortedByteSize()); + } + // Returns the byte size of the all the elements held in the index. This // excludes the size of any internal metadata of the index, e.g. the index's // header. @@ -301,9 +305,7 @@ class Index { } // Sorts the LiteIndex HitBuffer. - void SortLiteIndex() { - lite_index_->SortHits(); - } + void SortLiteIndex() { lite_index_->SortHits(); } // Reduces internal file sizes by reclaiming space of deleted documents. // new_last_added_document_id will be used to update the last added document diff --git a/icing/index/index_test.cc b/icing/index/index_test.cc index 04a6bb7..50b65ad 100644 --- a/icing/index/index_test.cc +++ b/icing/index/index_test.cc @@ -33,9 +33,11 @@ #include "icing/file/filesystem.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/lite/term-id-hit-pair.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/legacy/index/icing-mock-filesystem.h" #include "icing/proto/debug.pb.h" +#include "icing/proto/logging.pb.h" #include "icing/proto/storage.pb.h" #include "icing/proto/term.pb.h" #include "icing/schema/section.h" @@ -2734,11 +2736,47 @@ TEST_F(IndexTest, IndexStorageInfoProto) { EXPECT_THAT(storage_info.main_index_lexicon_size(), Ge(0)); EXPECT_THAT(storage_info.main_index_storage_size(), Ge(0)); EXPECT_THAT(storage_info.main_index_block_size(), Ge(0)); - // There should be 1 block for the header and 1 block for two posting lists. + // There should be 1 block for the header and 1 block for three posting lists + // ("fo", "foo", "foul"). 
EXPECT_THAT(storage_info.num_blocks(), Eq(2)); EXPECT_THAT(storage_info.min_free_fraction(), Ge(0)); } +TEST_F(IndexTest, PublishQueryStats) { + // Add two documents to the lite index without merging. + Index::Editor edit = index_->Edit(kDocumentId0, kSectionId2, + TermMatchType::PREFIX, /*namespace_id=*/0); + ASSERT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + edit = index_->Edit(kDocumentId1, kSectionId2, TermMatchType::PREFIX, + /*namespace_id=*/0); + ASSERT_THAT(edit.BufferTerm("foul"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + + // Verify query stats. + QueryStatsProto query_stats1; + index_->PublishQueryStats(&query_stats1); + EXPECT_THAT(query_stats1.lite_index_hit_buffer_byte_size(), + Eq(2 * sizeof(TermIdHitPair::Value))); + EXPECT_THAT(query_stats1.lite_index_hit_buffer_unsorted_byte_size(), + Ge(2 * sizeof(TermIdHitPair::Value))); + + // Sort lite index. + index_->SortLiteIndex(); + QueryStatsProto query_stats2; + index_->PublishQueryStats(&query_stats2); + EXPECT_THAT(query_stats2.lite_index_hit_buffer_byte_size(), + Eq(2 * sizeof(TermIdHitPair::Value))); + EXPECT_THAT(query_stats2.lite_index_hit_buffer_unsorted_byte_size(), Eq(0)); + + // Merge lite index to main index. 
+ ICING_ASSERT_OK(index_->Merge()); + QueryStatsProto query_stats3; + index_->PublishQueryStats(&query_stats3); + EXPECT_THAT(query_stats3.lite_index_hit_buffer_byte_size(), Eq(0)); + EXPECT_THAT(query_stats3.lite_index_hit_buffer_unsorted_byte_size(), Eq(0)); +} + } // namespace } // namespace lib diff --git a/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h b/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h index c16a1c4..d766712 100644 --- a/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h +++ b/icing/index/iterator/doc-hit-info-iterator-property-in-schema.h @@ -17,13 +17,17 @@ #include <cstdint> #include <memory> +#include <set> #include <string> -#include <string_view> #include <utility> +#include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" #include "icing/store/document-store.h" namespace icing { diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc index 35dc0b9..735adaa 100644 --- a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc +++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc @@ -113,12 +113,12 @@ DocHitInfoIteratorSectionRestrict::ApplyRestrictions( const DocumentStore* document_store, const SchemaStore* schema_store, const SearchSpecProto& search_spec, int64_t current_time_ms) { std::unordered_map<std::string, std::set<std::string>> type_property_filters; - // TODO(b/294274922): Add support for polymorphism in type property filters. 
- for (const TypePropertyMask& type_property_mask : - search_spec.type_property_filters()) { - type_property_filters[type_property_mask.schema_type()] = - std::set<std::string>(type_property_mask.paths().begin(), - type_property_mask.paths().end()); + for (const SchemaStore::ExpandedTypePropertyMask& type_property_mask : + schema_store->ExpandTypePropertyMasks( + search_spec.type_property_filters())) { + type_property_filters[type_property_mask.schema_type] = + std::set<std::string>(type_property_mask.paths.begin(), + type_property_mask.paths.end()); } auto data = std::make_unique<SectionRestrictData>( document_store, schema_store, current_time_ms, type_property_filters); @@ -134,7 +134,9 @@ DocHitInfoIteratorSectionRestrict::ApplyRestrictions( ChildrenMapper mapper; mapper = [&data, &mapper](std::unique_ptr<DocHitInfoIterator> iterator) -> std::unique_ptr<DocHitInfoIterator> { - if (iterator->is_leaf()) { + if (iterator->full_section_restriction_applied()) { + return iterator; + } else if (iterator->is_leaf()) { return std::make_unique<DocHitInfoIteratorSectionRestrict>( std::move(iterator), data); } else { @@ -149,24 +151,8 @@ libtextclassifier3::Status DocHitInfoIteratorSectionRestrict::Advance() { doc_hit_info_ = DocHitInfo(kInvalidDocumentId); while (delegate_->Advance().ok()) { DocumentId document_id = delegate_->doc_hit_info().document_id(); - - auto data_optional = data_->document_store().GetAliveDocumentFilterData( - document_id, data_->current_time_ms()); - if (!data_optional) { - // Ran into some error retrieving information on this hit, skip - continue; - } - - // Guaranteed that the DocumentFilterData exists at this point - SchemaTypeId schema_type_id = data_optional.value().schema_type_id(); - auto schema_type_or = data_->schema_store().GetSchemaType(schema_type_id); - if (!schema_type_or.ok()) { - // Ran into error retrieving schema type, skip - continue; - } - const std::string* schema_type = std::move(schema_type_or).ValueOrDie(); SectionIdMask 
allowed_sections_mask = - data_->ComputeAllowedSectionsMask(*schema_type); + data_->ComputeAllowedSectionsMask(document_id); // A hit can be in multiple sections at once, need to check which of the // section ids match the sections allowed by type_property_masks_. This can diff --git a/icing/index/iterator/doc-hit-info-iterator.h b/icing/index/iterator/doc-hit-info-iterator.h index 728f957..921e4d4 100644 --- a/icing/index/iterator/doc-hit-info-iterator.h +++ b/icing/index/iterator/doc-hit-info-iterator.h @@ -214,6 +214,12 @@ class DocHitInfoIterator { virtual bool is_leaf() { return false; } + // Whether the iterator has already been applied all the required section + // restrictions. + // If true, calling DocHitInfoIteratorSectionRestrict::ApplyRestrictions on + // this iterator will have no effect. + virtual bool full_section_restriction_applied() const { return false; } + virtual ~DocHitInfoIterator() = default; // Returns: diff --git a/icing/index/iterator/section-restrict-data.cc b/icing/index/iterator/section-restrict-data.cc index 085437d..581ac8e 100644 --- a/icing/index/iterator/section-restrict-data.cc +++ b/icing/index/iterator/section-restrict-data.cc @@ -22,6 +22,8 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/schema/schema-store.h" #include "icing/schema/section.h" +#include "icing/store/document-filter-data.h" +#include "icing/store/document-id.h" namespace icing { namespace lib { @@ -78,5 +80,24 @@ SectionIdMask SectionRestrictData::ComputeAllowedSectionsMask( return new_section_id_mask; } +SectionIdMask SectionRestrictData::ComputeAllowedSectionsMask( + DocumentId document_id) { + auto data_optional = + document_store_.GetAliveDocumentFilterData(document_id, current_time_ms_); + if (!data_optional) { + // Ran into some error retrieving information on this document, skip + return kSectionIdMaskNone; + } + // Guaranteed that the DocumentFilterData exists at this point + SchemaTypeId schema_type_id = 
data_optional.value().schema_type_id(); + auto schema_type_or = schema_store_.GetSchemaType(schema_type_id); + if (!schema_type_or.ok()) { + // Ran into error retrieving schema type, skip + return kSectionIdMaskNone; + } + const std::string* schema_type = std::move(schema_type_or).ValueOrDie(); + return ComputeAllowedSectionsMask(*schema_type); +} + } // namespace lib } // namespace icing diff --git a/icing/index/iterator/section-restrict-data.h b/icing/index/iterator/section-restrict-data.h index 26ca597..64d5087 100644 --- a/icing/index/iterator/section-restrict-data.h +++ b/icing/index/iterator/section-restrict-data.h @@ -23,6 +23,7 @@ #include "icing/schema/schema-store.h" #include "icing/schema/section.h" +#include "icing/store/document-id.h" #include "icing/store/document-store.h" namespace icing { @@ -48,10 +49,21 @@ class SectionRestrictData { // Returns: // - If type_property_filters_ has an entry for the given schema type or // wildcard(*), return a bitwise or of section IDs in the schema type - // that that are also present in the relevant filter list. + // that are also present in the relevant filter list. // - Otherwise, return kSectionIdMaskAll. SectionIdMask ComputeAllowedSectionsMask(const std::string& schema_type); + // Calculates the section mask of allowed sections(determined by the + // property filters map) for the given document id, by retrieving its schema + // type name and calling the above method. + // + // Returns: + // - If type_property_filters_ has an entry for the given document's schema + // type or wildcard(*), return a bitwise or of section IDs in the schema + // type that are also present in the relevant filter list. + // - Otherwise, return kSectionIdMaskAll. 
+ SectionIdMask ComputeAllowedSectionsMask(DocumentId document_id); + const DocumentStore& document_store() const { return document_store_; } const SchemaStore& schema_store() const { return schema_store_; } diff --git a/icing/index/lite/lite-index-header.h b/icing/index/lite/lite-index-header.h index 75de8fa..741d173 100644 --- a/icing/index/lite/lite-index-header.h +++ b/icing/index/lite/lite-index-header.h @@ -53,14 +53,7 @@ class LiteIndex_Header { class LiteIndex_HeaderImpl : public LiteIndex_Header { public: struct HeaderData { - static uint32_t GetCurrentMagic( - bool include_property_existence_metadata_hits) { - if (!include_property_existence_metadata_hits) { - return 0x01c61418; - } else { - return 0x56e07d5b; - } - } + static const uint32_t kMagic = 0xC2EAD682; uint32_t lite_index_crc; uint32_t magic; @@ -76,15 +69,10 @@ class LiteIndex_HeaderImpl : public LiteIndex_Header { uint32_t searchable_end; }; - explicit LiteIndex_HeaderImpl(HeaderData *hdr, - bool include_property_existence_metadata_hits) - : hdr_(hdr), - include_property_existence_metadata_hits_( - include_property_existence_metadata_hits) {} + explicit LiteIndex_HeaderImpl(HeaderData *hdr) : hdr_(hdr) {} bool check_magic() const override { - return hdr_->magic == HeaderData::GetCurrentMagic( - include_property_existence_metadata_hits_); + return hdr_->magic == HeaderData::kMagic; } uint32_t lite_index_crc() const override { return hdr_->lite_index_crc; } @@ -111,8 +99,7 @@ class LiteIndex_HeaderImpl : public LiteIndex_Header { void Reset() override { hdr_->lite_index_crc = 0; - hdr_->magic = - HeaderData::GetCurrentMagic(include_property_existence_metadata_hits_); + hdr_->magic = HeaderData::kMagic; hdr_->last_added_docid = kInvalidDocumentId; hdr_->cur_size = 0; hdr_->searchable_end = 0; @@ -120,7 +107,6 @@ class LiteIndex_HeaderImpl : public LiteIndex_Header { private: HeaderData *hdr_; - bool include_property_existence_metadata_hits_; }; static_assert(24 == 
sizeof(LiteIndex_HeaderImpl::HeaderData), "sizeof(HeaderData) != 24"); diff --git a/icing/index/lite/lite-index-options.cc b/icing/index/lite/lite-index-options.cc index 7e6c076..b4810ea 100644 --- a/icing/index/lite/lite-index-options.cc +++ b/icing/index/lite/lite-index-options.cc @@ -69,16 +69,14 @@ IcingDynamicTrie::Options CalculateTrieOptions(uint32_t hit_buffer_size) { } // namespace -LiteIndexOptions::LiteIndexOptions( - const std::string& filename_base, uint32_t hit_buffer_want_merge_bytes, - bool hit_buffer_sort_at_indexing, uint32_t hit_buffer_sort_threshold_bytes, - bool include_property_existence_metadata_hits) +LiteIndexOptions::LiteIndexOptions(const std::string& filename_base, + uint32_t hit_buffer_want_merge_bytes, + bool hit_buffer_sort_at_indexing, + uint32_t hit_buffer_sort_threshold_bytes) : filename_base(filename_base), hit_buffer_want_merge_bytes(hit_buffer_want_merge_bytes), hit_buffer_sort_at_indexing(hit_buffer_sort_at_indexing), - hit_buffer_sort_threshold_bytes(hit_buffer_sort_threshold_bytes), - include_property_existence_metadata_hits( - include_property_existence_metadata_hits) { + hit_buffer_sort_threshold_bytes(hit_buffer_sort_threshold_bytes) { hit_buffer_size = CalculateHitBufferSize(hit_buffer_want_merge_bytes); lexicon_options = CalculateTrieOptions(hit_buffer_size); display_mappings_options = CalculateTrieOptions(hit_buffer_size); diff --git a/icing/index/lite/lite-index-options.h b/icing/index/lite/lite-index-options.h index 8b03449..5eae5c7 100644 --- a/icing/index/lite/lite-index-options.h +++ b/icing/index/lite/lite-index-options.h @@ -32,8 +32,7 @@ struct LiteIndexOptions { LiteIndexOptions(const std::string& filename_base, uint32_t hit_buffer_want_merge_bytes, bool hit_buffer_sort_at_indexing, - uint32_t hit_buffer_sort_threshold_bytes, - bool include_property_existence_metadata_hits = false); + uint32_t hit_buffer_sort_threshold_bytes); IcingDynamicTrie::Options lexicon_options; IcingDynamicTrie::Options 
display_mappings_options; @@ -43,7 +42,6 @@ struct LiteIndexOptions { uint32_t hit_buffer_size = 0; bool hit_buffer_sort_at_indexing = false; uint32_t hit_buffer_sort_threshold_bytes = 0; - bool include_property_existence_metadata_hits = false; }; } // namespace lib diff --git a/icing/index/lite/lite-index.cc b/icing/index/lite/lite-index.cc index 3f9cc93..3aed7e4 100644 --- a/icing/index/lite/lite-index.cc +++ b/icing/index/lite/lite-index.cc @@ -74,7 +74,7 @@ size_t header_size() { return sizeof(LiteIndex_HeaderImpl::HeaderData); } } // namespace const TermIdHitPair::Value TermIdHitPair::kInvalidValue = - TermIdHitPair(0, Hit()).value(); + TermIdHitPair(0, Hit(Hit::kInvalidValue)).value(); libtextclassifier3::StatusOr<std::unique_ptr<LiteIndex>> LiteIndex::Create( const LiteIndex::Options& options, const IcingFilesystem* filesystem) { @@ -168,8 +168,7 @@ libtextclassifier3::Status LiteIndex::Initialize() { header_mmap_.Remap(hit_buffer_fd_.get(), kHeaderFileOffset, header_size()); header_ = std::make_unique<LiteIndex_HeaderImpl>( reinterpret_cast<LiteIndex_HeaderImpl::HeaderData*>( - header_mmap_.address()), - options_.include_property_existence_metadata_hits); + header_mmap_.address())); header_->Reset(); if (!hit_buffer_.Init(hit_buffer_fd_.get(), header_padded_size, true, @@ -184,8 +183,7 @@ libtextclassifier3::Status LiteIndex::Initialize() { header_mmap_.Remap(hit_buffer_fd_.get(), kHeaderFileOffset, header_size()); header_ = std::make_unique<LiteIndex_HeaderImpl>( reinterpret_cast<LiteIndex_HeaderImpl::HeaderData*>( - header_mmap_.address()), - options_.include_property_existence_metadata_hits); + header_mmap_.address())); if (!hit_buffer_.Init(hit_buffer_fd_.get(), header_padded_size, true, sizeof(TermIdHitPair::Value), header_->cur_size(), @@ -499,7 +497,7 @@ int LiteIndex::FetchHits( // When disabled, the entire HitBuffer should be sorted already and only // binary search is needed. 
if (options_.hit_buffer_sort_at_indexing) { - uint32_t unsorted_length = header_->cur_size() - header_->searchable_end(); + uint32_t unsorted_length = GetHitBufferUnsortedSizeImpl(); for (uint32_t i = 1; i <= unsorted_length; ++i) { TermIdHitPair term_id_hit_pair = array[header_->cur_size() - i]; if (term_id_hit_pair.term_id() == term_id) { @@ -519,7 +517,8 @@ int LiteIndex::FetchHits( // Do binary search over the sorted section and repeat the above steps. TermIdHitPair target_term_id_hit_pair( - term_id, Hit(Hit::kMaxDocumentIdSortValue, Hit::kDefaultTermFrequency)); + term_id, Hit(Hit::kMaxDocumentIdSortValue, Hit::kNoEnabledFlags, + Hit::kDefaultTermFrequency)); for (const TermIdHitPair* ptr = std::lower_bound( array, array + header_->searchable_end(), target_term_id_hit_pair); ptr < array + header_->searchable_end(); ++ptr) { @@ -607,13 +606,13 @@ IndexStorageInfoProto LiteIndex::GetStorageInfo( void LiteIndex::SortHitsImpl() { // Make searchable by sorting by hit buffer. - uint32_t sort_len = header_->cur_size() - header_->searchable_end(); - if (sort_len <= 0) { + uint32_t need_sort_len = GetHitBufferUnsortedSizeImpl(); + if (need_sort_len <= 0) { return; } IcingTimer timer; - auto* array_start = + TermIdHitPair::Value* array_start = hit_buffer_.GetMutableMem<TermIdHitPair::Value>(0, header_->cur_size()); TermIdHitPair::Value* sort_start = array_start + header_->searchable_end(); std::sort(sort_start, array_start + header_->cur_size()); @@ -625,7 +624,7 @@ void LiteIndex::SortHitsImpl() { std::inplace_merge(array_start, array_start + header_->searchable_end(), array_start + header_->cur_size()); } - ICING_VLOG(2) << "Lite index sort and merge " << sort_len << " into " + ICING_VLOG(2) << "Lite index sort and merge " << need_sort_len << " into " << header_->searchable_end() << " in " << timer.Elapsed() * 1000 << "ms"; diff --git a/icing/index/lite/lite-index.h b/icing/index/lite/lite-index.h index 288602a..45dc280 100644 --- a/icing/index/lite/lite-index.h +++ 
b/icing/index/lite/lite-index.h @@ -308,6 +308,24 @@ class LiteIndex { IndexStorageInfoProto GetStorageInfo(IndexStorageInfoProto storage_info) const ICING_LOCKS_EXCLUDED(mutex_); + // Returns the size of unsorted part of hit_buffer_. + uint32_t GetHitBufferByteSize() const ICING_LOCKS_EXCLUDED(mutex_) { + absl_ports::shared_lock l(&mutex_); + return size_impl() * sizeof(TermIdHitPair::Value); + } + + // Returns the size of unsorted part of hit_buffer_. + uint32_t GetHitBufferUnsortedSize() const ICING_LOCKS_EXCLUDED(mutex_) { + absl_ports::shared_lock l(&mutex_); + return GetHitBufferUnsortedSizeImpl(); + } + + // Returns the byte size of unsorted part of hit_buffer_. + uint64_t GetHitBufferUnsortedByteSize() const ICING_LOCKS_EXCLUDED(mutex_) { + absl_ports::shared_lock l(&mutex_); + return GetHitBufferUnsortedSizeImpl() * sizeof(TermIdHitPair::Value); + } + // Reduces internal file sizes by reclaiming space of deleted documents. // // This method also sets the last_added_docid of the index to @@ -377,13 +395,13 @@ class LiteIndex { bool NeedSortAtQuerying() const ICING_SHARED_LOCKS_REQUIRED(mutex_) { return HasUnsortedHitsExceedingSortThresholdImpl() || (!options_.hit_buffer_sort_at_indexing && - header_->cur_size() - header_->searchable_end() > 0); + GetHitBufferUnsortedSizeImpl() > 0); } // Non-locking implementation for HasUnsortedHitsExceedingSortThresholdImpl(). bool HasUnsortedHitsExceedingSortThresholdImpl() const ICING_SHARED_LOCKS_REQUIRED(mutex_) { - return header_->cur_size() - header_->searchable_end() >= + return GetHitBufferUnsortedSizeImpl() >= (options_.hit_buffer_sort_threshold_bytes / sizeof(TermIdHitPair::Value)); } @@ -408,6 +426,12 @@ class LiteIndex { std::vector<Hit::TermFrequencyArray>* term_frequency_out) const ICING_SHARED_LOCKS_REQUIRED(mutex_); + // Returns the size of unsorted part of hit_buffer_. 
+ uint32_t GetHitBufferUnsortedSizeImpl() const + ICING_SHARED_LOCKS_REQUIRED(mutex_) { + return header_->cur_size() - header_->searchable_end(); + } + // File descriptor that points to where the header and hit buffer are written // to. ScopedFd hit_buffer_fd_ ICING_GUARDED_BY(mutex_); diff --git a/icing/index/lite/lite-index_test.cc b/icing/index/lite/lite-index_test.cc index 9811fa2..f8ea94a 100644 --- a/icing/index/lite/lite-index_test.cc +++ b/icing/index/lite/lite-index_test.cc @@ -27,7 +27,7 @@ #include "icing/index/hit/hit.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/index/lite/doc-hit-info-iterator-term-lite.h" -#include "icing/index/lite/lite-index-header.h" +#include "icing/index/lite/term-id-hit-pair.h" #include "icing/index/term-id-codec.h" #include "icing/legacy/index/icing-dynamic-trie.h" #include "icing/legacy/index/icing-filesystem.h" @@ -49,6 +49,7 @@ using ::testing::Eq; using ::testing::IsEmpty; using ::testing::IsFalse; using ::testing::IsTrue; +using ::testing::Pointee; using ::testing::SizeIs; class LiteIndexTest : public testing::Test { @@ -71,15 +72,25 @@ class LiteIndexTest : public testing::Test { constexpr NamespaceId kNamespace0 = 0; +TEST_F(LiteIndexTest, TermIdHitPairInvalidValue) { + TermIdHitPair invalidTermHitPair(TermIdHitPair::kInvalidValue); + + EXPECT_THAT(invalidTermHitPair.term_id(), Eq(0)); + EXPECT_THAT(invalidTermHitPair.hit().value(), Eq(Hit::kInvalidValue)); + EXPECT_THAT(invalidTermHitPair.hit().term_frequency(), + Eq(Hit::kDefaultTermFrequency)); +} + TEST_F(LiteIndexTest, LiteIndexFetchHits_sortAtQuerying_unsortedHitsBelowSortThreshold) { // Set up LiteIndex and TermIdCodec std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index"; - // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs. 
- LiteIndex::Options options(lite_index_file_name, - /*hit_buffer_want_merge_bytes=*/1024 * 1024, - /*hit_buffer_sort_at_indexing=*/false, - /*hit_buffer_sort_threshold_bytes=*/64); + // Unsorted tail can contain a max of 8 TermIdHitPairs. + LiteIndex::Options options( + lite_index_file_name, + /*hit_buffer_want_merge_bytes=*/1024 * 1024, + /*hit_buffer_sort_at_indexing=*/false, + /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index, LiteIndex::Create(options, &icing_filesystem_)); ICING_ASSERT_OK_AND_ASSIGN( @@ -95,9 +106,9 @@ TEST_F(LiteIndexTest, ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id, term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE)); Hit foo_hit0(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit foo_hit1(/*section_id=*/1, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, foo_hit0)); ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, foo_hit1)); @@ -107,24 +118,18 @@ TEST_F(LiteIndexTest, ICING_ASSERT_OK_AND_ASSIGN(uint32_t bar_term_id, term_id_codec_->EncodeTvi(bar_tvi, TviType::LITE)); Hit bar_hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit bar_hit1(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index->AddHit(bar_term_id, bar_hit0)); ICING_ASSERT_OK(lite_index->AddHit(bar_term_id, bar_hit1)); + // Check the total size and unsorted size of the hit buffer. 
+ EXPECT_THAT(lite_index, Pointee(SizeIs(4))); + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(4)); // Check that unsorted hits does not exceed the sort threshold. EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsFalse()); - // Check that hits are unsorted. Persist the data and pread from - // LiteIndexHeader. - ASSERT_THAT(lite_index->PersistToDisk(), IsOk()); - LiteIndex_HeaderImpl::HeaderData header_data; - ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(), - &header_data, sizeof(header_data), - LiteIndex::kHeaderFileOffset)); - EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(4)); - // Query the LiteIndex std::vector<DocHitInfo> hits1; lite_index->FetchHits( @@ -148,28 +153,26 @@ TEST_F(LiteIndexTest, // checker. EXPECT_THAT(hits2, IsEmpty()); - // Check that hits are sorted after querying LiteIndex. Persist the data and - // pread from LiteIndexHeader. - ASSERT_THAT(lite_index->PersistToDisk(), IsOk()); - ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(), - &header_data, sizeof(header_data), - LiteIndex::kHeaderFileOffset)); - EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(0)); + // Check the total size and unsorted size of the hit buffer. Hits should be + // sorted after querying LiteIndex. + EXPECT_THAT(lite_index, Pointee(SizeIs(4))); + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(0)); } TEST_F(LiteIndexTest, LiteIndexFetchHits_sortAtIndexing_unsortedHitsBelowSortThreshold) { // Set up LiteIndex and TermIdCodec std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index"; - // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs. + // The unsorted tail can contain a max of 8 TermIdHitPairs. 
// However note that in these tests we're unable to sort hits after // indexing, as sorting performed by the string-section-indexing-handler // after indexing all hits in an entire document, rather than after each // AddHits() operation. - LiteIndex::Options options(lite_index_file_name, - /*hit_buffer_want_merge_bytes=*/1024 * 1024, - /*hit_buffer_sort_at_indexing=*/true, - /*hit_buffer_sort_threshold_bytes=*/64); + LiteIndex::Options options( + lite_index_file_name, + /*hit_buffer_want_merge_bytes=*/1024 * 1024, + /*hit_buffer_sort_at_indexing=*/true, + /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index, LiteIndex::Create(options, &icing_filesystem_)); ICING_ASSERT_OK_AND_ASSIGN( @@ -185,9 +188,9 @@ TEST_F(LiteIndexTest, ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id, term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE)); Hit foo_hit0(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit foo_hit1(/*section_id=*/1, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, foo_hit0)); ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, foo_hit1)); @@ -197,24 +200,18 @@ TEST_F(LiteIndexTest, ICING_ASSERT_OK_AND_ASSIGN(uint32_t bar_term_id, term_id_codec_->EncodeTvi(bar_tvi, TviType::LITE)); Hit bar_hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit bar_hit1(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index->AddHit(bar_term_id, bar_hit0)); ICING_ASSERT_OK(lite_index->AddHit(bar_term_id, bar_hit1)); + // Check 
the total size and unsorted size of the hit buffer. + EXPECT_THAT(lite_index, Pointee(SizeIs(4))); + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(4)); // Check that unsorted hits does not exceed the sort threshold. EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsFalse()); - // Check that hits are unsorted. Persist the data and pread from - // LiteIndexHeader. - ASSERT_THAT(lite_index->PersistToDisk(), IsOk()); - LiteIndex_HeaderImpl::HeaderData header_data; - ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(), - &header_data, sizeof(header_data), - LiteIndex::kHeaderFileOffset)); - EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(4)); - // Query the LiteIndex std::vector<DocHitInfo> hits1; lite_index->FetchHits( @@ -238,15 +235,11 @@ TEST_F(LiteIndexTest, // checker. EXPECT_THAT(hits2, IsEmpty()); - // Check that hits are still unsorted after querying LiteIndex because the - // HitBuffer unsorted size is still below the sort threshold, and we've - // enabled sort_at_indexing. - // Persist the data and performing a pread on LiteIndexHeader. - ASSERT_THAT(lite_index->PersistToDisk(), IsOk()); - ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(), - &header_data, sizeof(header_data), - LiteIndex::kHeaderFileOffset)); - EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(4)); + // Check the total size and unsorted size of the hit buffer. Hits should be + // still unsorted after querying LiteIndex because the HitBuffer unsorted size + // is still below the sort threshold, and we've enabled sort_at_indexing. 
+ EXPECT_THAT(lite_index, Pointee(SizeIs(4))); + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(4)); } TEST_F( @@ -254,15 +247,16 @@ TEST_F( LiteIndexFetchHits_sortAtQuerying_unsortedHitsExceedingSortAtIndexThreshold) { // Set up LiteIndex and TermIdCodec std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index"; - // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs. + // The unsorted tail can contain a max of 8 TermIdHitPairs. // However note that in these tests we're unable to sort hits after // indexing, as sorting performed by the string-section-indexing-handler // after indexing all hits in an entire document, rather than after each // AddHits() operation. - LiteIndex::Options options(lite_index_file_name, - /*hit_buffer_want_merge_bytes=*/1024 * 1024, - /*hit_buffer_sort_at_indexing=*/false, - /*hit_buffer_sort_threshold_bytes=*/64); + LiteIndex::Options options( + lite_index_file_name, + /*hit_buffer_want_merge_bytes=*/1024 * 1024, + /*hit_buffer_sort_at_indexing=*/false, + /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index, LiteIndex::Create(options, &icing_filesystem_)); ICING_ASSERT_OK_AND_ASSIGN( @@ -274,36 +268,36 @@ TEST_F( // Create 4 hits for docs 0-2, and 2 hits for doc 3 -- 14 in total // Doc 0 Hit doc0_hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc0_hit1(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc0_hit2(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc0_hit3(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + 
/*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); // Doc 1 Hit doc1_hit0(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc1_hit1(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc1_hit2(/*section_id=*/1, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc1_hit3(/*section_id=*/2, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); // Doc 2 Hit doc2_hit0(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc2_hit1(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc2_hit2(/*section_id=*/1, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc2_hit3(/*section_id=*/2, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); // Doc 3 Hit doc3_hit0(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc3_hit1(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); // Create terms // Foo @@ -348,7 +342,10 @@ TEST_F( ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, doc2_hit3)); ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, doc3_hit0)); 
ICING_ASSERT_OK(lite_index->AddHit(baz_term_id, doc3_hit1)); - // Verify that the HitBuffer has not been sorted. + // Check the total size and unsorted size of the hit buffer. The HitBuffer has + // not been sorted. + EXPECT_THAT(lite_index, Pointee(SizeIs(14))); + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(14)); EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsTrue()); // We now have the following in the hit buffer: @@ -400,7 +397,10 @@ TEST_F( EXPECT_THAT(hits3[1].document_id(), Eq(0)); EXPECT_THAT(hits3[1].hit_section_ids_mask(), Eq(0b1)); - // Check that the HitBuffer is sorted after the query call. + // Check the total size and unsorted size of the hit buffer. The HitBuffer + // should be sorted after the query call. + EXPECT_THAT(lite_index, Pointee(SizeIs(14))); + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(0)); EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsFalse()); } @@ -409,11 +409,12 @@ TEST_F( LiteIndexFetchHits_sortAtIndexing_unsortedHitsExceedingSortAtIndexThreshold) { // Set up LiteIndex and TermIdCodec std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index"; - // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs. - LiteIndex::Options options(lite_index_file_name, - /*hit_buffer_want_merge_bytes=*/1024 * 1024, - /*hit_buffer_sort_at_indexing=*/true, - /*hit_buffer_sort_threshold_bytes=*/64); + // The unsorted tail can contain a max of 8 TermIdHitPairs. 
+ LiteIndex::Options options( + lite_index_file_name, + /*hit_buffer_want_merge_bytes=*/1024 * 1024, + /*hit_buffer_sort_at_indexing=*/true, + /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index, LiteIndex::Create(options, &icing_filesystem_)); ICING_ASSERT_OK_AND_ASSIGN( @@ -425,49 +426,49 @@ TEST_F( // Create 4 hits for docs 0-2, and 2 hits for doc 3 -- 14 in total // Doc 0 Hit doc0_hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc0_hit1(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc0_hit2(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc0_hit3(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); // Doc 1 Hit doc1_hit0(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc1_hit1(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc1_hit2(/*section_id=*/1, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc1_hit3(/*section_id=*/2, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); // Doc 2 Hit doc2_hit0(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + 
/*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc2_hit1(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc2_hit2(/*section_id=*/1, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc2_hit3(/*section_id=*/2, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); // Doc 3 Hit doc3_hit0(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc3_hit1(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc3_hit2(/*section_id=*/1, /*document_id=*/3, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc3_hit3(/*section_id=*/2, /*document_id=*/3, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); // Doc 4 Hit doc4_hit0(/*section_id=*/0, /*document_id=*/4, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc4_hit1(/*section_id=*/0, /*document_id=*/4, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc4_hit2(/*section_id=*/1, /*document_id=*/4, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc4_hit3(/*section_id=*/2, /*document_id=*/4, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); // Create 
terms // Foo @@ -511,13 +512,11 @@ TEST_F( // AddHit() itself, we need to invoke SortHits() manually. EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsTrue()); lite_index->SortHits(); - // Check that the HitBuffer is sorted. - ASSERT_THAT(lite_index->PersistToDisk(), IsOk()); - LiteIndex_HeaderImpl::HeaderData header_data; - ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(), - &header_data, sizeof(header_data), - LiteIndex::kHeaderFileOffset)); - EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(0)); + // Check the total size and unsorted size of the hit buffer. The HitBuffer + // should be sorted after calling SortHits(). + EXPECT_THAT(lite_index, Pointee(SizeIs(8))); + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(0)); + EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsFalse()); // Add 12 more hits so that sort threshold is exceeded again. ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, doc2_hit0)); @@ -536,6 +535,8 @@ TEST_F( // Adding these hits exceeds the sort threshold. However when sort_at_indexing // is enabled, sorting is done in the string-section-indexing-handler rather // than AddHit() itself. + EXPECT_THAT(lite_index, Pointee(SizeIs(20))); + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(12)); EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsTrue()); // We now have the following in the hit buffer: @@ -589,25 +590,24 @@ TEST_F( EXPECT_THAT(hits3[1].document_id(), Eq(0)); EXPECT_THAT(hits3[1].hit_section_ids_mask(), Eq(0b1)); - // Check that the HitBuffer is sorted after the query call. FetchHits should - // sort before performing binary search if the HitBuffer unsorted size exceeds - // the sort threshold. Regardless of the sort_at_indexing config. + // Check the total size and unsorted size of the hit buffer. 
FetchHits should + // sort before performing search if the HitBuffer unsorted size exceeds the + // sort threshold, regardless of the sort_at_indexing config (to avoid + // sequential search on an extremely long unsorted tails). + EXPECT_THAT(lite_index, Pointee(SizeIs(20))); + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(0)); EXPECT_THAT(lite_index->HasUnsortedHitsExceedingSortThreshold(), IsFalse()); - ASSERT_THAT(lite_index->PersistToDisk(), IsOk()); - ASSERT_TRUE(filesystem_.PRead((lite_index_file_name + "hb").c_str(), - &header_data, sizeof(header_data), - LiteIndex::kHeaderFileOffset)); - EXPECT_THAT(header_data.cur_size - header_data.searchable_end, Eq(0)); } TEST_F(LiteIndexTest, LiteIndexIterator) { // Set up LiteIndex and TermIdCodec std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index"; - // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs. - LiteIndex::Options options(lite_index_file_name, - /*hit_buffer_want_merge_bytes=*/1024 * 1024, - /*hit_buffer_sort_at_indexing=*/true, - /*hit_buffer_sort_threshold_bytes=*/64); + // The unsorted tail can contain a max of 8 TermIdHitPairs. 
+ LiteIndex::Options options( + lite_index_file_name, + /*hit_buffer_want_merge_bytes=*/1024 * 1024, + /*hit_buffer_sort_at_indexing=*/true, + /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index, LiteIndex::Create(options, &icing_filesystem_)); ICING_ASSERT_OK_AND_ASSIGN( @@ -623,9 +623,9 @@ TEST_F(LiteIndexTest, LiteIndexIterator) { ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id, term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc0_hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/3, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc0_hit1(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/5, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); SectionIdMask doc0_section_id_mask = 0b11; std::unordered_map<SectionId, Hit::TermFrequency> expected_section_ids_tf_map0 = {{0, 3}, {1, 5}}; @@ -633,9 +633,9 @@ TEST_F(LiteIndexTest, LiteIndexIterator) { ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, doc0_hit1)); Hit doc1_hit1(/*section_id=*/1, /*document_id=*/1, /*term_frequency=*/7, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc1_hit2(/*section_id=*/2, /*document_id=*/1, /*term_frequency=*/11, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); SectionIdMask doc1_section_id_mask = 0b110; std::unordered_map<SectionId, Hit::TermFrequency> expected_section_ids_tf_map1 = {{1, 7}, {2, 11}}; @@ -671,11 +671,12 @@ TEST_F(LiteIndexTest, LiteIndexIterator) { TEST_F(LiteIndexTest, LiteIndexIterator_sortAtIndexingDisabled) { // Set up LiteIndex and TermIdCodec std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index"; - // At 64 bytes the unsorted tail can contain a max of 8 TermHitPairs. 
- LiteIndex::Options options(lite_index_file_name, - /*hit_buffer_want_merge_bytes=*/1024 * 1024, - /*hit_buffer_sort_at_indexing=*/false, - /*hit_buffer_sort_threshold_bytes=*/64); + // The unsorted tail can contain a max of 8 TermIdHitPairs. + LiteIndex::Options options( + lite_index_file_name, + /*hit_buffer_want_merge_bytes=*/1024 * 1024, + /*hit_buffer_sort_at_indexing=*/false, + /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index, LiteIndex::Create(options, &icing_filesystem_)); ICING_ASSERT_OK_AND_ASSIGN( @@ -691,9 +692,9 @@ TEST_F(LiteIndexTest, LiteIndexIterator_sortAtIndexingDisabled) { ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id, term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc0_hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/3, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc0_hit1(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/5, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); SectionIdMask doc0_section_id_mask = 0b11; std::unordered_map<SectionId, Hit::TermFrequency> expected_section_ids_tf_map0 = {{0, 3}, {1, 5}}; @@ -701,9 +702,9 @@ TEST_F(LiteIndexTest, LiteIndexIterator_sortAtIndexingDisabled) { ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, doc0_hit1)); Hit doc1_hit1(/*section_id=*/1, /*document_id=*/1, /*term_frequency=*/7, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit doc1_hit2(/*section_id=*/2, /*document_id=*/1, /*term_frequency=*/11, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); SectionIdMask doc1_section_id_mask = 0b110; std::unordered_map<SectionId, Hit::TermFrequency> expected_section_ids_tf_map1 = {{1, 7}, {2, 11}}; @@ -736,6 +737,54 @@ TEST_F(LiteIndexTest, LiteIndexIterator_sortAtIndexingDisabled) { term, 
expected_section_ids_tf_map0))); } +TEST_F(LiteIndexTest, LiteIndexHitBufferSize) { + // Set up LiteIndex and TermIdCodec + std::string lite_index_file_name = index_dir_ + "/test_file.lite-idx.index"; + // The unsorted tail can contain a max of 8 TermIdHitPairs. + LiteIndex::Options options( + lite_index_file_name, + /*hit_buffer_want_merge_bytes=*/1024 * 1024, + /*hit_buffer_sort_at_indexing=*/true, + /*hit_buffer_sort_threshold_bytes=*/sizeof(TermIdHitPair) * 8); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LiteIndex> lite_index, + LiteIndex::Create(options, &icing_filesystem_)); + ICING_ASSERT_OK_AND_ASSIGN( + term_id_codec_, + TermIdCodec::Create( + IcingDynamicTrie::max_value_index(IcingDynamicTrie::Options()), + IcingDynamicTrie::max_value_index(options.lexicon_options))); + + const std::string term = "foo"; + ICING_ASSERT_OK_AND_ASSIGN( + uint32_t tvi, + lite_index->InsertTerm(term, TermMatchType::PREFIX, kNamespace0)); + ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id, + term_id_codec_->EncodeTvi(tvi, TviType::LITE)); + Hit hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/3, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + Hit hit1(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/5, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, hit0)); + ICING_ASSERT_OK(lite_index->AddHit(foo_term_id, hit1)); + + // Check the total size and byte size of the hit buffer. + EXPECT_THAT(lite_index, Pointee(SizeIs(2))); + EXPECT_THAT(lite_index->GetHitBufferByteSize(), + Eq(2 * sizeof(TermIdHitPair::Value))); + // Check the unsorted size and byte size of the hit buffer. + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(2)); + EXPECT_THAT(lite_index->GetHitBufferUnsortedByteSize(), + Eq(2 * sizeof(TermIdHitPair::Value))); + + // Sort the hit buffer and check again. 
+ lite_index->SortHits(); + EXPECT_THAT(lite_index, Pointee(SizeIs(2))); + EXPECT_THAT(lite_index->GetHitBufferByteSize(), + Eq(2 * sizeof(TermIdHitPair::Value))); + EXPECT_THAT(lite_index->GetHitBufferUnsortedSize(), Eq(0)); + EXPECT_THAT(lite_index->GetHitBufferUnsortedByteSize(), Eq(0)); +} + } // namespace } // namespace lib } // namespace icing diff --git a/icing/index/lite/lite-index_thread-safety_test.cc b/icing/index/lite/lite-index_thread-safety_test.cc index 53aa6cd..a73ca28 100644 --- a/icing/index/lite/lite-index_thread-safety_test.cc +++ b/icing/index/lite/lite-index_thread-safety_test.cc @@ -13,15 +13,25 @@ // limitations under the License. #include <array> +#include <cstdint> +#include <memory> #include <string> +#include <string_view> #include <thread> #include <vector> #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "icing/file/filesystem.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/hit/hit.h" #include "icing/index/lite/lite-index.h" #include "icing/index/term-id-codec.h" +#include "icing/legacy/index/icing-dynamic-trie.h" +#include "icing/legacy/index/icing-filesystem.h" #include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/store/namespace-id.h" #include "icing/testing/common-matchers.h" #include "icing/testing/tmp-directory.h" @@ -109,9 +119,11 @@ TEST_F(LiteIndexThreadSafetyTest, SimultaneousFetchHits_singleTerm) { ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id, term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE)); Hit doc_hit0(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId0, - Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); Hit doc_hit1(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId1, - Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); 
ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit0)); ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit1)); @@ -155,9 +167,11 @@ TEST_F(LiteIndexThreadSafetyTest, SimultaneousFetchHits_multipleTerms) { ICING_ASSERT_OK_AND_ASSIGN(uint32_t term_id, term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc_hit0(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId0, - Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); Hit doc_hit1(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId1, - Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit0)); ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit1)); } @@ -208,7 +222,8 @@ TEST_F(LiteIndexThreadSafetyTest, SimultaneousAddHitAndFetchHits_singleTerm) { ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id, term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE)); Hit doc_hit0(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId0, - Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit0)); // Create kNumThreads threads. Every even-numbered thread calls FetchHits and @@ -228,7 +243,8 @@ TEST_F(LiteIndexThreadSafetyTest, SimultaneousAddHitAndFetchHits_singleTerm) { } else { // Odd-numbered thread calls AddHit. 
Hit doc_hit(/*section_id=*/thread_id / 2, /*document_id=*/kDocumentId0, - Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit)); } }; @@ -273,7 +289,8 @@ TEST_F(LiteIndexThreadSafetyTest, ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id, term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE)); Hit doc_hit0(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId0, - Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit0)); // Create kNumThreads threads. Every even-numbered thread calls FetchHits and @@ -302,7 +319,8 @@ TEST_F(LiteIndexThreadSafetyTest, // Odd-numbered thread calls AddHit. // AddHit to section 0 of a new doc. Hit doc_hit(/*section_id=*/kSectionId0, /*document_id=*/thread_id / 2, - Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit)); } }; @@ -335,9 +353,11 @@ TEST_F(LiteIndexThreadSafetyTest, ManyAddHitAndOneFetchHits_multipleTerms) { ICING_ASSERT_OK_AND_ASSIGN(uint32_t term_id, term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc_hit0(/*section_id=*/kSectionId0, /*document_id=*/kDocumentId0, - Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); Hit doc_hit1(/*section_id=*/kSectionId1, /*document_id=*/kDocumentId0, - Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit0)); ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit1)); } @@ -370,7 +390,7 @@ 
TEST_F(LiteIndexThreadSafetyTest, ManyAddHitAndOneFetchHits_multipleTerms) { // AddHit to section (thread_id % 5 + 1) of doc 0. Hit doc_hit(/*section_id=*/thread_id % 5 + 1, /*document_id=*/kDocumentId0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(term_id, doc_hit)); } }; diff --git a/icing/index/lite/term-id-hit-pair.h b/icing/index/lite/term-id-hit-pair.h index 82bd010..760aa76 100644 --- a/icing/index/lite/term-id-hit-pair.h +++ b/icing/index/lite/term-id-hit-pair.h @@ -15,25 +15,22 @@ #ifndef ICING_INDEX_TERM_ID_HIT_PAIR_H_ #define ICING_INDEX_TERM_ID_HIT_PAIR_H_ +#include <array> #include <cstdint> -#include <limits> -#include <memory> -#include <string> -#include <vector> #include "icing/index/hit/hit.h" -#include "icing/util/bit-util.h" namespace icing { namespace lib { class TermIdHitPair { public: - // Layout bits: 24 termid + 32 hit value + 8 hit term frequency. - using Value = uint64_t; + // Layout bits: 24 termid + 32 hit value + 8 hit flags + 8 hit term frequency. + using Value = std::array<uint8_t, 9>; static constexpr int kTermIdBits = 24; static constexpr int kHitValueBits = sizeof(Hit::Value) * 8; + static constexpr int kHitFlagsBits = sizeof(Hit::Flags) * 8; static constexpr int kHitTermFrequencyBits = sizeof(Hit::TermFrequency) * 8; static const Value kInvalidValue; @@ -41,33 +38,48 @@ class TermIdHitPair { explicit TermIdHitPair(Value v = kInvalidValue) : value_(v) {} TermIdHitPair(uint32_t term_id, const Hit& hit) { - static_assert(kTermIdBits + kHitValueBits + kHitTermFrequencyBits <= - sizeof(Value) * 8, - "TermIdHitPairTooBig"); - - value_ = 0; - // Term id goes into the most significant bits because it takes - // precedent in sorts. 
- bit_util::BitfieldSet(term_id, kHitValueBits + kHitTermFrequencyBits, - kTermIdBits, &value_); - bit_util::BitfieldSet(hit.value(), kHitTermFrequencyBits, kHitValueBits, - &value_); - bit_util::BitfieldSet(hit.term_frequency(), 0, kHitTermFrequencyBits, - &value_); + static_assert( + kTermIdBits + kHitValueBits + kHitFlagsBits + kHitTermFrequencyBits <= + sizeof(Value) * 8, + "TermIdHitPairTooBig"); + + // Set termId. Term id takes 3 bytes and goes into value_[0:2] (most + // significant bits) because it takes precedent in sorts. + value_[0] = static_cast<uint8_t>((term_id >> 16) & 0xff); + value_[1] = static_cast<uint8_t>((term_id >> 8) & 0xff); + value_[2] = static_cast<uint8_t>((term_id >> 0) & 0xff); + + // Set hit value. Hit value takes 4 bytes and goes into value_[3:6] + value_[3] = static_cast<uint8_t>((hit.value() >> 24) & 0xff); + value_[4] = static_cast<uint8_t>((hit.value() >> 16) & 0xff); + value_[5] = static_cast<uint8_t>((hit.value() >> 8) & 0xff); + value_[6] = static_cast<uint8_t>((hit.value() >> 0) & 0xff); + + // Set flags in value_[7]. 
+ value_[7] = hit.flags(); + + // Set term-frequency in value_[8] + value_[8] = hit.term_frequency(); } uint32_t term_id() const { - return bit_util::BitfieldGet(value_, kHitValueBits + kHitTermFrequencyBits, - kTermIdBits); + return (static_cast<uint32_t>(value_[0]) << 16) | + (static_cast<uint32_t>(value_[1]) << 8) | + (static_cast<uint32_t>(value_[2]) << 0); } Hit hit() const { - return Hit( - bit_util::BitfieldGet(value_, kHitTermFrequencyBits, kHitValueBits), - bit_util::BitfieldGet(value_, 0, kHitTermFrequencyBits)); + Hit::Value hit_value = (static_cast<uint32_t>(value_[3]) << 24) | + (static_cast<uint32_t>(value_[4]) << 16) | + (static_cast<uint32_t>(value_[5]) << 8) | + (static_cast<uint32_t>(value_[6]) << 0); + Hit::Flags hit_flags = value_[7]; + Hit::TermFrequency term_frequency = value_[8]; + + return Hit(hit_value, hit_flags, term_frequency); } - Value value() const { return value_; } + const Value& value() const { return value_; } bool operator==(const TermIdHitPair& rhs) const { return value_ == rhs.value_; diff --git a/icing/index/lite/term-id-hit-pair_test.cc b/icing/index/lite/term-id-hit-pair_test.cc new file mode 100644 index 0000000..28855b4 --- /dev/null +++ b/icing/index/lite/term-id-hit-pair_test.cc @@ -0,0 +1,95 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/lite/term-id-hit-pair.h" + +#include <algorithm> +#include <cstdint> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/index/hit/hit.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" + +namespace icing { +namespace lib { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; + +static constexpr DocumentId kSomeDocumentId = 24; +static constexpr SectionId kSomeSectionid = 5; +static constexpr Hit::TermFrequency kSomeTermFrequency = 57; +static constexpr uint32_t kSomeTermId = 129; +static constexpr uint32_t kSomeSmallerTermId = 1; +static constexpr uint32_t kSomeLargerTermId = 0b101010101111111100000001; + +TEST(TermIdHitPairTest, Accessors) { + Hit hit1(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + Hit hit2(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/true); + Hit hit3(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + Hit invalid_hit(Hit::kInvalidValue); + + TermIdHitPair term_id_hit_pair_1(kSomeTermId, hit1); + EXPECT_THAT(term_id_hit_pair_1.term_id(), Eq(kSomeTermId)); + EXPECT_THAT(term_id_hit_pair_1.hit(), Eq(hit1)); + + TermIdHitPair term_id_hit_pair_2(kSomeLargerTermId, hit2); + EXPECT_THAT(term_id_hit_pair_2.term_id(), Eq(kSomeLargerTermId)); + EXPECT_THAT(term_id_hit_pair_2.hit(), Eq(hit2)); + + TermIdHitPair term_id_hit_pair_3(kSomeTermId, invalid_hit); + EXPECT_THAT(term_id_hit_pair_3.term_id(), Eq(kSomeTermId)); + EXPECT_THAT(term_id_hit_pair_3.hit(), Eq(invalid_hit)); +} + +TEST(TermIdHitPairTest, Comparison) { + Hit hit(kSomeSectionid, kSomeDocumentId, kSomeTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + Hit smaller_hit(/*section_id=*/1, /*document_id=*/100, /*term_frequency=*/1, + /*is_in_prefix_section=*/false, 
/*is_prefix_hit=*/false); + + TermIdHitPair term_id_hit_pair(kSomeTermId, hit); + TermIdHitPair term_id_hit_pair_equal(kSomeTermId, hit); + TermIdHitPair term_id_hit_pair_smaller_hit(kSomeTermId, smaller_hit); + TermIdHitPair term_id_hit_pair_smaller_term_id(kSomeSmallerTermId, hit); + TermIdHitPair term_id_hit_pair_larger_term_id(kSomeLargerTermId, hit); + TermIdHitPair term_id_hit_pair_smaller_term_id_and_hit(kSomeSmallerTermId, + smaller_hit); + + std::vector<TermIdHitPair> term_id_hit_pairs{ + term_id_hit_pair, + term_id_hit_pair_equal, + term_id_hit_pair_smaller_hit, + term_id_hit_pair_smaller_term_id, + term_id_hit_pair_larger_term_id, + term_id_hit_pair_smaller_term_id_and_hit}; + std::sort(term_id_hit_pairs.begin(), term_id_hit_pairs.end()); + EXPECT_THAT(term_id_hit_pairs, + ElementsAre(term_id_hit_pair_smaller_term_id_and_hit, + term_id_hit_pair_smaller_term_id, + term_id_hit_pair_smaller_hit, term_id_hit_pair_equal, + term_id_hit_pair, term_id_hit_pair_larger_term_id)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/index/main/main-index-merger.cc b/icing/index/main/main-index-merger.cc index c26a6d7..cc130c2 100644 --- a/icing/index/main/main-index-merger.cc +++ b/icing/index/main/main-index-merger.cc @@ -14,14 +14,20 @@ #include "icing/index/main/main-index-merger.h" +#include <algorithm> #include <cstdint> #include <cstring> -#include <memory> #include <unordered_map> +#include <utility> +#include <vector> +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" -#include "icing/file/posting_list/index-block.h" +#include "icing/file/posting_list/posting-list-common.h" +#include "icing/index/hit/hit.h" +#include "icing/index/lite/lite-index.h" #include "icing/index/lite/term-id-hit-pair.h" +#include "icing/index/main/main-index.h" #include "icing/index/term-id-codec.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/util/logging.h" diff --git 
a/icing/index/main/main-index-merger_test.cc b/icing/index/main/main-index-merger_test.cc index 37e14fc..333e338 100644 --- a/icing/index/main/main-index-merger_test.cc +++ b/icing/index/main/main-index-merger_test.cc @@ -13,19 +13,23 @@ // limitations under the License. #include "icing/index/main/main-index-merger.h" +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "icing/absl_ports/canonical_errors.h" #include "icing/file/filesystem.h" -#include "icing/index/iterator/doc-hit-info-iterator.h" -#include "icing/index/main/doc-hit-info-iterator-term-main.h" -#include "icing/index/main/main-index-merger.h" +#include "icing/index/hit/hit.h" +#include "icing/index/lite/lite-index.h" +#include "icing/index/lite/term-id-hit-pair.h" #include "icing/index/main/main-index.h" #include "icing/index/term-id-codec.h" -#include "icing/index/term-property-id.h" #include "icing/legacy/index/icing-dynamic-trie.h" #include "icing/legacy/index/icing-filesystem.h" -#include "icing/schema/section.h" #include "icing/store/namespace-id.h" #include "icing/testing/common-matchers.h" #include "icing/testing/tmp-directory.h" @@ -89,10 +93,10 @@ TEST_F(MainIndexMergerTest, TranslateTermNotAdded) { term_id_codec_->EncodeTvi(fool_tvi, TviType::LITE)); Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/57, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit)); Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc1_hit)); // 2. 
Build up a fake LexiconMergeOutputs @@ -128,10 +132,10 @@ TEST_F(MainIndexMergerTest, PrefixExpansion) { term_id_codec_->EncodeTvi(fool_tvi, TviType::LITE)); Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/57, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit)); Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc1_hit)); // 2. Build up a fake LexiconMergeOutputs @@ -191,11 +195,11 @@ TEST_F(MainIndexMergerTest, DedupePrefixAndExactWithDifferentTermFrequencies) { term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE)); Hit foot_doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/57, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, foot_doc0_hit)); Hit foo_doc0_hit(/*section_id=*/0, /*document_id=*/0, - Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/true, + /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, foo_doc0_hit)); // 2. 
Build up a fake LexiconMergeOutputs @@ -255,10 +259,10 @@ TEST_F(MainIndexMergerTest, DedupeWithExactSameTermFrequencies) { term_id_codec_->EncodeTvi(foo_tvi, TviType::LITE)); Hit foot_doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/57, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, foot_doc0_hit)); Hit foo_doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/57, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, foo_doc0_hit)); // The prefix hit should take the sum as term_frequency - 114. Hit prefix_foo_doc0_hit(/*section_id=*/0, /*document_id=*/0, @@ -320,11 +324,11 @@ TEST_F(MainIndexMergerTest, DedupePrefixExpansion) { Hit foot_doc0_hit(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/Hit::kMaxTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, foot_doc0_hit)); Hit fool_doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, fool_doc0_hit)); // 2. Build up a fake LexiconMergeOutputs diff --git a/icing/index/main/main-index.cc b/icing/index/main/main-index.cc index aae60c6..85ee4dc 100644 --- a/icing/index/main/main-index.cc +++ b/icing/index/main/main-index.cc @@ -651,7 +651,7 @@ libtextclassifier3::Status MainIndex::AddPrefixBackfillHits( ICING_ASSIGN_OR_RETURN(tmp, backfill_accessor->GetNextHitsBatch()); } - Hit last_added_hit; + Hit last_added_hit(Hit::kInvalidValue); // The hits in backfill_hits are in the reverse order of how they were added. // Iterate in reverse to add them to this new posting list in the correct // order. 
diff --git a/icing/index/main/main-index_test.cc b/icing/index/main/main-index_test.cc index fa96e6c..db9dbe2 100644 --- a/icing/index/main/main-index_test.cc +++ b/icing/index/main/main-index_test.cc @@ -14,23 +14,33 @@ #include "icing/index/main/main-index.h" +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "icing/absl_ports/canonical_errors.h" #include "icing/file/filesystem.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/hit/hit.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/lite/lite-index.h" #include "icing/index/lite/term-id-hit-pair.h" #include "icing/index/main/doc-hit-info-iterator-term-main.h" #include "icing/index/main/main-index-merger.h" #include "icing/index/term-id-codec.h" -#include "icing/index/term-property-id.h" #include "icing/legacy/index/icing-dynamic-trie.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/legacy/index/icing-mock-filesystem.h" #include "icing/schema/section.h" +#include "icing/store/document-id.h" #include "icing/store/namespace-id.h" #include "icing/testing/common-matchers.h" #include "icing/testing/tmp-directory.h" +#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -152,7 +162,7 @@ TEST_F(MainIndexTest, MainIndexGetAccessorForPrefixReturnsValidAccessor) { term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit)); // 2. Create the main index. It should have no entries in its lexicon. 
@@ -178,7 +188,7 @@ TEST_F(MainIndexTest, MainIndexGetAccessorForPrefixReturnsNotFound) { term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit)); // 2. Create the main index. It should have no entries in its lexicon. @@ -217,7 +227,7 @@ TEST_F(MainIndexTest, MainIndexGetAccessorForExactReturnsValidAccessor) { term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit)); // 2. Create the main index. It should have no entries in its lexicon. @@ -254,18 +264,18 @@ TEST_F(MainIndexTest, MergeIndexToEmpty) { term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit)); ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc0_hit)); ICING_ASSERT_OK(lite_index_->AddHit(far_term_id, doc0_hit)); Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc1_hit)); ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc1_hit)); Hit doc2_hit(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc2_hit)); ICING_ASSERT_OK(lite_index_->AddHit(far_term_id, doc2_hit)); @@ -332,18 +342,18 @@ TEST_F(MainIndexTest, 
MergeIndexToPreexisting) { term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit)); ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc0_hit)); ICING_ASSERT_OK(lite_index_->AddHit(far_term_id, doc0_hit)); Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc1_hit)); ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc1_hit)); Hit doc2_hit(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc2_hit)); ICING_ASSERT_OK(lite_index_->AddHit(far_term_id, doc2_hit)); @@ -387,14 +397,14 @@ TEST_F(MainIndexTest, MergeIndexToPreexisting) { term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc3_hit(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc3_hit)); ICING_ASSERT_OK(lite_index_->AddHit(four_term_id, doc3_hit)); ICING_ASSERT_OK(lite_index_->AddHit(foul_term_id, doc3_hit)); ICING_ASSERT_OK(lite_index_->AddHit(fall_term_id, doc3_hit)); Hit doc4_hit(/*section_id=*/0, /*document_id=*/4, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(four_term_id, doc4_hit)); ICING_ASSERT_OK(lite_index_->AddHit(foul_term_id, doc4_hit)); @@ -449,15 +459,15 @@ TEST_F(MainIndexTest, ExactRetrievedInPrefixSearch) { term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit 
doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit)); Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc1_hit)); Hit doc2_hit(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc2_hit)); // 2. Create the main index. It should have no entries in its lexicon. @@ -500,15 +510,15 @@ TEST_F(MainIndexTest, PrefixNotRetrievedInExactSearch) { term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc0_hit)); Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc1_hit)); Hit doc2_hit(/*section_id=*/0, /*document_id=*/2, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc2_hit)); // 2. Create the main index. It should have no entries in its lexicon. 
@@ -554,17 +564,17 @@ TEST_F(MainIndexTest, document_id % Hit::kMaxTermFrequency + 1); Hit doc_hit0( /*section_id=*/0, /*document_id=*/document_id, term_frequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc_hit0)); Hit doc_hit1( /*section_id=*/1, /*document_id=*/document_id, term_frequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc_hit1)); Hit doc_hit2( /*section_id=*/2, /*document_id=*/document_id, term_frequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc_hit2)); } @@ -619,7 +629,7 @@ TEST_F(MainIndexTest, MergeIndexBackfilling) { term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc0_hit(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/true); + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(fool_term_id, doc0_hit)); // 2. Create the main index. It should have no entries in its lexicon. @@ -648,7 +658,7 @@ TEST_F(MainIndexTest, MergeIndexBackfilling) { term_id_codec_->EncodeTvi(tvi, TviType::LITE)); Hit doc1_hit(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foot_term_id, doc1_hit)); // 5. Merge the index. 
The main index should now contain "fool", "foot" @@ -682,7 +692,7 @@ TEST_F(MainIndexTest, OneHitInTheFirstPageForTwoPagesMainIndex) { uint32_t num_docs = 2038; for (DocumentId document_id = 0; document_id < num_docs; ++document_id) { Hit doc_hit(section_id, document_id, Hit::kDefaultTermFrequency, - /*is_in_prefix_section=*/false); + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit)); } diff --git a/icing/index/main/posting-list-hit-accessor_test.cc b/icing/index/main/posting-list-hit-accessor_test.cc index 1127814..c2460ff 100644 --- a/icing/index/main/posting-list-hit-accessor_test.cc +++ b/icing/index/main/posting-list-hit-accessor_test.cc @@ -15,14 +15,19 @@ #include "icing/index/main/posting-list-hit-accessor.h" #include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/file/filesystem.h" #include "icing/file/posting_list/flash-index-storage.h" -#include "icing/file/posting_list/index-block.h" +#include "icing/file/posting_list/posting-list-accessor.h" +#include "icing/file/posting_list/posting-list-common.h" #include "icing/file/posting_list/posting-list-identifier.h" -#include "icing/file/posting_list/posting-list-used.h" #include "icing/index/hit/hit.h" #include "icing/index/main/posting-list-hit-serializer.h" #include "icing/testing/common-matchers.h" @@ -102,7 +107,8 @@ TEST_F(PostingListHitAccessorTest, PreexistingPLKeepOnSameBlock) { PostingListHitAccessor::Create(flash_index_storage_.get(), serializer_.get())); // Add a single hit. This will fit in a min-sized posting list. 
- Hit hit1(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency); + Hit hit1(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); PostingListAccessor::FinalizeResult result1 = std::move(*pl_accessor).Finalize(); @@ -139,12 +145,12 @@ TEST_F(PostingListHitAccessorTest, PreexistingPLReallocateToLargerPL) { std::unique_ptr<PostingListHitAccessor> pl_accessor, PostingListHitAccessor::Create(flash_index_storage_.get(), serializer_.get())); - // The smallest posting list size is 15 bytes. The first four hits will be - // compressed to one byte each and will be able to fit in the 5 byte padded - // region. The last hit will fit in one of the special hits. The posting list - // will be ALMOST_FULL and can fit at most 2 more hits. + // Use a small posting list of 30 bytes. The first 17 hits will be compressed + // to one byte each and will be able to fit in the 18 byte padded region. The + // last hit will fit in one of the special hits. The posting list will be + // ALMOST_FULL and can fit at most 2 more hits. std::vector<Hit> hits1 = - CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1); + CreateHits(/*num_hits=*/18, /*desired_byte_length=*/1); for (const Hit& hit : hits1) { ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); } @@ -160,10 +166,9 @@ TEST_F(PostingListHitAccessorTest, PreexistingPLReallocateToLargerPL) { pl_accessor, PostingListHitAccessor::CreateFromExisting( flash_index_storage_.get(), serializer_.get(), result1.id)); - // The current posting list can fit at most 2 more hits. Adding 12 more hits - // should result in these hits being moved to a larger posting list. + // The current posting list can fit at most 2 more hits. 
std::vector<Hit> hits2 = CreateHits( - /*start_docid=*/hits1.back().document_id() + 1, /*num_hits=*/12, + /*last_hit=*/hits1.back(), /*num_hits=*/2, /*desired_byte_length=*/1); for (const Hit& hit : hits2) { @@ -172,18 +177,36 @@ TEST_F(PostingListHitAccessorTest, PreexistingPLReallocateToLargerPL) { PostingListAccessor::FinalizeResult result2 = std::move(*pl_accessor).Finalize(); ICING_EXPECT_OK(result2.status); + // The 2 hits should still fit on the first block + EXPECT_THAT(result1.id.block_index(), Eq(1)); + EXPECT_THAT(result1.id.posting_list_index(), Eq(0)); + + // Add one more hit + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result2.id)); + // The current posting list should be FULL. Adding more hits should result in + // these hits being moved to a larger posting list. + Hit single_hit = + CreateHit(/*last_hit=*/hits2.back(), /*desired_byte_length=*/1); + ICING_ASSERT_OK(pl_accessor->PrependHit(single_hit)); + PostingListAccessor::FinalizeResult result3 = + std::move(*pl_accessor).Finalize(); + ICING_EXPECT_OK(result3.status); // Should have been allocated to the second (new) block because the posting // list should have grown beyond the size that the first block maintains. - EXPECT_THAT(result2.id.block_index(), Eq(2)); - EXPECT_THAT(result2.id.posting_list_index(), Eq(0)); + EXPECT_THAT(result3.id.block_index(), Eq(2)); + EXPECT_THAT(result3.id.posting_list_index(), Eq(0)); - // The posting list at result2.id should hold all of the hits that have been + // The posting list at result3.id should hold all of the hits that have been // added. 
for (const Hit& hit : hits2) { hits1.push_back(hit); } + hits1.push_back(single_hit); ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder, - flash_index_storage_->GetPostingList(result2.id)); + flash_index_storage_->GetPostingList(result3.id)); EXPECT_THAT(serializer_->GetHits(&pl_holder.posting_list), IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); } @@ -307,7 +330,7 @@ TEST_F(PostingListHitAccessorTest, InvalidHitReturnsInvalidArgument) { std::unique_ptr<PostingListHitAccessor> pl_accessor, PostingListHitAccessor::Create(flash_index_storage_.get(), serializer_.get())); - Hit invalid_hit; + Hit invalid_hit(Hit::kInvalidValue); EXPECT_THAT(pl_accessor->PrependHit(invalid_hit), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } @@ -317,14 +340,17 @@ TEST_F(PostingListHitAccessorTest, HitsNotDecreasingReturnsInvalidArgument) { std::unique_ptr<PostingListHitAccessor> pl_accessor, PostingListHitAccessor::Create(flash_index_storage_.get(), serializer_.get())); - Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency); + Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); - Hit hit2(/*section_id=*/6, /*document_id=*/1, Hit::kDefaultTermFrequency); + Hit hit2(/*section_id=*/6, /*document_id=*/1, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); EXPECT_THAT(pl_accessor->PrependHit(hit2), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - Hit hit3(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency); + Hit hit3(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); EXPECT_THAT(pl_accessor->PrependHit(hit3), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } @@ -345,7 +371,8 @@ TEST_F(PostingListHitAccessorTest, PreexistingPostingListNoHitsAdded) { 
std::unique_ptr<PostingListHitAccessor> pl_accessor, PostingListHitAccessor::Create(flash_index_storage_.get(), serializer_.get())); - Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency); + Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); PostingListAccessor::FinalizeResult result1 = std::move(*pl_accessor).Finalize(); diff --git a/icing/index/main/posting-list-hit-serializer.cc b/icing/index/main/posting-list-hit-serializer.cc index e14a0c0..88c0754 100644 --- a/icing/index/main/posting-list-hit-serializer.cc +++ b/icing/index/main/posting-list-hit-serializer.cc @@ -19,8 +19,11 @@ #include <limits> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/hit/hit.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/legacy/index/icing-bit-util.h" #include "icing/util/logging.h" @@ -35,6 +38,10 @@ uint32_t GetTermFrequencyByteSize(const Hit& hit) { return hit.has_term_frequency() ? sizeof(Hit::TermFrequency) : 0; } +uint32_t GetFlagsByteSize(const Hit& hit) { + return hit.has_flags() ? sizeof(Hit::Flags) : 0; +} + } // namespace uint32_t PostingListHitSerializer::GetBytesUsed( @@ -55,14 +62,45 @@ uint32_t PostingListHitSerializer::GetMinPostingListSizeToFit( return posting_list_used->size_in_bytes(); } - // In NOT_FULL status BytesUsed contains no special hits. The minimum sized - // posting list that would be guaranteed to fit these hits would be - // ALMOST_FULL, with kInvalidHit in special_hit(0), the uncompressed Hit in - // special_hit(1) and the n compressed hits in the compressed region. - // BytesUsed contains one uncompressed Hit and n compressed hits. 
Therefore, - // fitting these hits into a posting list would require BytesUsed plus one - // extra hit. - return GetBytesUsed(posting_list_used) + sizeof(Hit); + // Edge case: if number of hits <= 2, and posting_list_used is NOT_FULL, we + // should return kMinPostingListSize as we know that two hits is able to fit + // in the 2 special hits position of a min-sized posting list. + // - Checking this scenario requires deserializing posting_list_used, which is + // an expensive operation. However this is only possible when + // GetBytesUsed(posting_list_used) <= + // sizeof(uncompressed hit) + max(delta encoded value size) + + // (sizeof(Hit) + (sizeof(Hit) - sizeof(Hit::Value)), so we don't need to do + // this when the size exceeds this number. + if (GetBytesUsed(posting_list_used) <= + 2 * sizeof(Hit) + 5 - sizeof(Hit::Value)) { + // Check if we're able to get more than 2 hits from posting_list_used + std::vector<Hit> hits; + libtextclassifier3::Status status = + GetHitsInternal(posting_list_used, /*limit=*/3, /*pop=*/false, &hits); + if (status.ok() && hits.size() <= 2) { + return GetMinPostingListSize(); + } + } + + // - In NOT_FULL status, BytesUsed contains no special hits. For a posting + // list in the NOT_FULL state with n hits, we would have n-1 compressed hits + // and 1 uncompressed hit. + // - The minimum sized posting list that would be guaranteed to fit these hits + // would be FULL, but calculating the size required for the FULL posting + // list would require deserializing the last two added hits, so instead we + // will calculate the size of an ALMOST_FULL posting list to fit. + // - An ALMOST_FULL posting list would have kInvalidHit in special_hit(0), the + // full uncompressed Hit in special_hit(1), and the n-1 compressed hits in + // the compressed region. + // - Currently BytesUsed contains one uncompressed Hit and n-1 compressed + // hits. 
However, it's possible that the uncompressed Hit is not a full hit, + // but rather only the Hit::Value (this is the case if + // !hit.has_term_frequency()). + // - Therefore, fitting these hits into a posting list would require + // BytesUsed + one extra full hit + byte difference between a full hit and + // Hit::Value. i.e: + // ByteUsed + sizeof(Hit) + (sizeof(Hit) - sizeof(Hit::Value)). + return GetBytesUsed(posting_list_used) + 2 * sizeof(Hit) - sizeof(Hit::Value); } void PostingListHitSerializer::Clear(PostingListUsed* posting_list_used) const { @@ -160,27 +198,37 @@ libtextclassifier3::Status PostingListHitSerializer::PrependHitToAlmostFull( // in the padded area and put new hit at the special position 1. // Calling ValueOrDie is safe here because 1 < kNumSpecialData. Hit cur = GetSpecialHit(posting_list_used, /*index=*/1).ValueOrDie(); - if (cur.value() <= hit.value()) { + if (cur < hit || cur == hit) { return absl_ports::InvalidArgumentError( "Hit being prepended must be strictly less than the most recent Hit"); } - uint64_t delta = cur.value() - hit.value(); uint8_t delta_buf[VarInt::kMaxEncodedLen64]; - size_t delta_len = VarInt::Encode(delta, delta_buf); + size_t delta_len = + EncodeNextHitValue(/*next_hit_value=*/hit.value(), + /*curr_hit_value=*/cur.value(), delta_buf); + uint32_t cur_flags_bytes = GetFlagsByteSize(cur); uint32_t cur_term_frequency_bytes = GetTermFrequencyByteSize(cur); uint32_t pad_end = GetPadEnd(posting_list_used, /*offset=*/kSpecialHitsSize); - if (pad_end >= kSpecialHitsSize + delta_len + cur_term_frequency_bytes) { - // Pad area has enough space for delta and term_frequency of existing hit - // (cur). Write delta at pad_end - delta_len - cur_term_frequency_bytes. + if (pad_end >= kSpecialHitsSize + delta_len + cur_flags_bytes + + cur_term_frequency_bytes) { + // Pad area has enough space for delta, flags and term_frequency of existing + // hit (cur). 
Write delta at pad_end - delta_len - cur_flags_bytes - + // cur_term_frequency_bytes. uint8_t* delta_offset = posting_list_used->posting_list_buffer() + pad_end - - delta_len - cur_term_frequency_bytes; + delta_len - cur_flags_bytes - + cur_term_frequency_bytes; memcpy(delta_offset, delta_buf, delta_len); - // Now copy term_frequency. + + // Now copy flags. + Hit::Flags flags = cur.flags(); + uint8_t* flags_offset = delta_offset + delta_len; + memcpy(flags_offset, &flags, cur_flags_bytes); + // Copy term_frequency. Hit::TermFrequency term_frequency = cur.term_frequency(); - uint8_t* term_frequency_offset = delta_offset + delta_len; + uint8_t* term_frequency_offset = flags_offset + cur_flags_bytes; memcpy(term_frequency_offset, &term_frequency, cur_term_frequency_bytes); // Now first hit is the new hit, at special position 1. Safe to ignore the @@ -231,23 +279,36 @@ libtextclassifier3::Status PostingListHitSerializer::PrependHitToNotFull( return absl_ports::FailedPreconditionError( "Posting list is in an invalid state."); } + + // Retrieve the last added (cur) hit's value and flags and compare to the hit + // we're adding. Hit::Value cur_value; - memcpy(&cur_value, posting_list_used->posting_list_buffer() + offset, - sizeof(Hit::Value)); - if (cur_value <= hit.value()) { + uint8_t* cur_value_offset = posting_list_used->posting_list_buffer() + offset; + memcpy(&cur_value, cur_value_offset, sizeof(Hit::Value)); + Hit::Flags cur_flags = Hit::kNoEnabledFlags; + if (GetFlagsByteSize(Hit(cur_value)) > 0) { + uint8_t* cur_flags_offset = cur_value_offset + sizeof(Hit::Value); + memcpy(&cur_flags, cur_flags_offset, sizeof(Hit::Flags)); + } + // Term-frequency is not used for hit comparison so it's ok to pass in the + // default term-frequency here. 
+ Hit cur(cur_value, cur_flags, Hit::kDefaultTermFrequency); + if (cur < hit || cur == hit) { return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( - "Hit %d being prepended must be strictly less than the most recent " - "Hit %d", - hit.value(), cur_value)); + "Hit (value=%d, flags=%d) being prepended must be strictly less than " + "the most recent Hit (value=%d, flags=%d)", + hit.value(), hit.flags(), cur_value, cur_flags)); } - uint64_t delta = cur_value - hit.value(); uint8_t delta_buf[VarInt::kMaxEncodedLen64]; - size_t delta_len = VarInt::Encode(delta, delta_buf); + size_t delta_len = + EncodeNextHitValue(/*next_hit_value=*/hit.value(), + /*curr_hit_value=*/cur.value(), delta_buf); + uint32_t hit_flags_bytes = GetFlagsByteSize(hit); uint32_t hit_term_frequency_bytes = GetTermFrequencyByteSize(hit); // offset now points to one past the end of the first hit. offset += sizeof(Hit::Value); - if (kSpecialHitsSize + sizeof(Hit::Value) + delta_len + + if (kSpecialHitsSize + sizeof(Hit::Value) + delta_len + hit_flags_bytes + hit_term_frequency_bytes <= offset) { // Enough space for delta in compressed area. @@ -257,9 +318,9 @@ libtextclassifier3::Status PostingListHitSerializer::PrependHitToNotFull( memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf, delta_len); - // Prepend new hit with (possibly) its term_frequency. We know that there is - // room for 'hit' because of the if statement above, so calling ValueOrDie - // is safe. + // Prepend new hit with (possibly) its flags and term_frequency. We know + // that there is room for 'hit' because of the if statement above, so + // calling ValueOrDie is safe. offset = PrependHitUncompressed(posting_list_used, hit, offset).ValueOrDie(); // offset is guaranteed to be valid here. So it's safe to ignore the return @@ -295,13 +356,13 @@ libtextclassifier3::Status PostingListHitSerializer::PrependHitToNotFull( // (i.e. varint delta encoding expanded required storage). 
We // move first hit to special position 1 and put new hit in // special position 0. - Hit cur(cur_value); - // offset is < kSpecialHitsSize + delta_len. delta_len is at most 5 bytes. + + // Offset is < kSpecialHitsSize + delta_len. delta_len is at most 5 bytes. // Therefore, offset must be less than kSpecialHitSize + 5. Since posting - // list size must be divisible by sizeof(Hit) (5), it is guaranteed that + // list size must be divisible by sizeof(Hit) (6), it is guaranteed that // offset < size_in_bytes, so it is safe to ignore the return value here. - ICING_RETURN_IF_ERROR( - ConsumeTermFrequencyIfPresent(posting_list_used, &cur, &offset)); + ICING_RETURN_IF_ERROR(ConsumeFlagsAndTermFrequencyIfPresent( + posting_list_used, &cur, &offset)); // Safe to ignore the return value of PadToEnd because offset must be less // than posting_list_used->size_in_bytes(). Otherwise, this function // already would have returned FAILED_PRECONDITION. @@ -463,19 +524,19 @@ libtextclassifier3::Status PostingListHitSerializer::GetHitsInternal( offset += sizeof(Hit::Value); } else { // Now we have delta encoded subsequent hits. Decode and push. - uint64_t delta; - offset += VarInt::Decode( - posting_list_used->posting_list_buffer() + offset, &delta); - val += delta; + DecodedHitInfo decoded_hit_info = DecodeNextHitValue( + posting_list_used->posting_list_buffer() + offset, val); + offset += decoded_hit_info.encoded_size; + val = decoded_hit_info.hit_value; } Hit hit(val); libtextclassifier3::Status status = - ConsumeTermFrequencyIfPresent(posting_list_used, &hit, &offset); + ConsumeFlagsAndTermFrequencyIfPresent(posting_list_used, &hit, &offset); if (!status.ok()) { // This posting list has been corrupted somehow. The first hit of the - // posting list claims to have a term frequency, but there's no more room - // in the posting list for that term frequency to exist. Return an empty - // vector and zero to indicate no hits retrieved. 
+ // posting list claims to have a term frequency or flag, but there's no + // more room in the posting list for that term frequency or flag to exist. + // Return an empty vector and zero to indicate no hits retrieved. if (out != nullptr) { out->clear(); } @@ -497,10 +558,10 @@ libtextclassifier3::Status PostingListHitSerializer::GetHitsInternal( // In the compressed area. Pop and reconstruct. offset/val is // the last traversed hit, which we must discard. So move one // more forward. - uint64_t delta; - offset += VarInt::Decode( - posting_list_used->posting_list_buffer() + offset, &delta); - val += delta; + DecodedHitInfo decoded_hit_info = DecodeNextHitValue( + posting_list_used->posting_list_buffer() + offset, val); + offset += decoded_hit_info.encoded_size; + val = decoded_hit_info.hit_value; // Now val is the first hit of the new posting list. if (kSpecialHitsSize + sizeof(Hit::Value) <= offset) { @@ -509,17 +570,18 @@ libtextclassifier3::Status PostingListHitSerializer::GetHitsInternal( memcpy(mutable_posting_list_used->posting_list_buffer() + offset, &val, sizeof(Hit::Value)); } else { - // val won't fit in compressed area. Also see if there is a + // val won't fit in compressed area. Also see if there is a flag or // term_frequency. Hit hit(val); libtextclassifier3::Status status = - ConsumeTermFrequencyIfPresent(posting_list_used, &hit, &offset); + ConsumeFlagsAndTermFrequencyIfPresent(posting_list_used, &hit, + &offset); if (!status.ok()) { // This posting list has been corrupted somehow. The first hit of - // the posting list claims to have a term frequency, but there's no - // more room in the posting list for that term frequency to exist. - // Return an empty vector and zero to indicate no hits retrieved. Do - // not pop anything. + // the posting list claims to have a term frequency or flag, but + // there's no more room in the posting list for that term frequency or + // flag to exist. 
Return an empty vector and zero to indicate no hits + // retrieved. Do not pop anything. if (out != nullptr) { out->clear(); } @@ -569,7 +631,7 @@ libtextclassifier3::StatusOr<Hit> PostingListHitSerializer::GetSpecialHit( return absl_ports::InvalidArgumentError( "Special hits only exist at indices 0 and 1"); } - Hit val; + Hit val(Hit::kInvalidValue); memcpy(&val, posting_list_used->posting_list_buffer() + index * sizeof(val), sizeof(val)); return val; @@ -653,11 +715,11 @@ bool PostingListHitSerializer::SetStartByteOffset( // not_full state. Safe to ignore the return value because 0 and 1 are both // < kNumSpecialData. SetSpecialHit(posting_list_used, /*index=*/0, Hit(offset)); - SetSpecialHit(posting_list_used, /*index=*/1, Hit()); + SetSpecialHit(posting_list_used, /*index=*/1, Hit(Hit::kInvalidValue)); } else if (offset == sizeof(Hit)) { // almost_full state. Safe to ignore the return value because 1 is both < // kNumSpecialData. - SetSpecialHit(posting_list_used, /*index=*/0, Hit()); + SetSpecialHit(posting_list_used, /*index=*/0, Hit(Hit::kInvalidValue)); } // Nothing to do for the FULL state - the offset isn't actually stored // anywhere and both special hits hold valid hits. 
@@ -667,46 +729,72 @@ bool PostingListHitSerializer::SetStartByteOffset( libtextclassifier3::StatusOr<uint32_t> PostingListHitSerializer::PrependHitUncompressed( PostingListUsed* posting_list_used, const Hit& hit, uint32_t offset) const { - if (hit.has_term_frequency()) { - if (offset < kSpecialHitsSize + sizeof(Hit)) { - return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( - "Not enough room to prepend Hit at offset %d.", offset)); - } - offset -= sizeof(Hit); - memcpy(posting_list_used->posting_list_buffer() + offset, &hit, - sizeof(Hit)); - } else { - if (offset < kSpecialHitsSize + sizeof(Hit::Value)) { - return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( - "Not enough room to prepend Hit::Value at offset %d.", offset)); - } - offset -= sizeof(Hit::Value); - Hit::Value val = hit.value(); - memcpy(posting_list_used->posting_list_buffer() + offset, &val, - sizeof(Hit::Value)); + uint32_t hit_bytes_to_prepend = sizeof(Hit::Value) + GetFlagsByteSize(hit) + + GetTermFrequencyByteSize(hit); + + if (offset < kSpecialHitsSize + hit_bytes_to_prepend) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Not enough room to prepend Hit at offset %d.", offset)); } + + if (hit.has_term_frequency()) { + offset -= sizeof(Hit::TermFrequency); + Hit::TermFrequency term_frequency = hit.term_frequency(); + memcpy(posting_list_used->posting_list_buffer() + offset, &term_frequency, + sizeof(Hit::TermFrequency)); + } + if (hit.has_flags()) { + offset -= sizeof(Hit::Flags); + Hit::Flags flags = hit.flags(); + memcpy(posting_list_used->posting_list_buffer() + offset, &flags, + sizeof(Hit::Flags)); + } + offset -= sizeof(Hit::Value); + Hit::Value val = hit.value(); + memcpy(posting_list_used->posting_list_buffer() + offset, &val, + sizeof(Hit::Value)); return offset; } libtextclassifier3::Status -PostingListHitSerializer::ConsumeTermFrequencyIfPresent( +PostingListHitSerializer::ConsumeFlagsAndTermFrequencyIfPresent( const 
PostingListUsed* posting_list_used, Hit* hit, uint32_t* offset) const { - if (!hit->has_term_frequency()) { - // No term frequency to consume. Everything is fine. + if (!hit->has_flags()) { + // No flags (and by extension, no term-frequency) to consume. Everything is + // fine. return libtextclassifier3::Status::OK; } - if (*offset + sizeof(Hit::TermFrequency) > - posting_list_used->size_in_bytes()) { + + // Consume flags + Hit::Flags flags; + if (*offset + sizeof(Hit::Flags) > posting_list_used->size_in_bytes()) { return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( - "offset %d must not point past the end of the posting list of size %d.", + "offset %d must not point past the end of the posting list of size " + "%d.", *offset, posting_list_used->size_in_bytes())); } - Hit::TermFrequency term_frequency; - memcpy(&term_frequency, posting_list_used->posting_list_buffer() + *offset, - sizeof(Hit::TermFrequency)); - *hit = Hit(hit->value(), term_frequency); - *offset += sizeof(Hit::TermFrequency); + memcpy(&flags, posting_list_used->posting_list_buffer() + *offset, + sizeof(Hit::Flags)); + *hit = Hit(hit->value(), flags, Hit::kDefaultTermFrequency); + *offset += sizeof(Hit::Flags); + + if (hit->has_term_frequency()) { + // Consume term frequency + if (*offset + sizeof(Hit::TermFrequency) > + posting_list_used->size_in_bytes()) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "offset %d must not point past the end of the posting list of size " + "%d.", + *offset, posting_list_used->size_in_bytes())); + } + Hit::TermFrequency term_frequency; + memcpy(&term_frequency, posting_list_used->posting_list_buffer() + *offset, + sizeof(Hit::TermFrequency)); + *hit = Hit(hit->value(), flags, term_frequency); + *offset += sizeof(Hit::TermFrequency); + } + return libtextclassifier3::Status::OK; } diff --git a/icing/index/main/posting-list-hit-serializer.h b/icing/index/main/posting-list-hit-serializer.h index 2986d9c..08c792c 100644 --- 
a/icing/index/main/posting-list-hit-serializer.h +++ b/icing/index/main/posting-list-hit-serializer.h @@ -15,6 +15,7 @@ #ifndef ICING_INDEX_MAIN_POSTING_LIST_HIT_SERIALIZER_H_ #define ICING_INDEX_MAIN_POSTING_LIST_HIT_SERIALIZER_H_ +#include <cstddef> #include <cstdint> #include <vector> @@ -23,6 +24,7 @@ #include "icing/file/posting_list/posting-list-common.h" #include "icing/file/posting_list/posting-list-used.h" #include "icing/index/hit/hit.h" +#include "icing/legacy/index/icing-bit-util.h" #include "icing/util/status-macros.h" namespace icing { @@ -34,6 +36,52 @@ class PostingListHitSerializer : public PostingListSerializer { public: static constexpr uint32_t kSpecialHitsSize = kNumSpecialData * sizeof(Hit); + struct DecodedHitInfo { + // The decoded hit value. + Hit::Value hit_value; + + // Size of the encoded hit in bytes. + size_t encoded_size; + }; + + // Given the current hit value, encodes the next hit value for serialization + // in the posting list. + // + // The encoded value is the varint-encoded delta between next_hit_value and + // curr_hit_value. + // - We add 1 to this delta so as to avoid getting a delta value of 0. + // - This allows us to add duplicate hits with the same value, which is a + // valid case if we need to store hits with different flags that belong in + // the same section-id/doc-id combo. + // - We cannot have an encoded hit delta with a value of 0 as 0 is currently + // used for padding the unused region in the posting list. + // + // REQUIRES: next_hit_value <= curr_hit_value AND + // curr_hit_value - next_hit_value < + // std::numeric_limits<Hit::Value>::max() + // + // RETURNS: next_hit_value's encoded length in bytes and writes the encoded + // value directly into encoded_buf_out. 
+ static size_t EncodeNextHitValue(Hit::Value next_hit_value, + Hit::Value curr_hit_value, + uint8_t* encoded_buf_out) { + uint64_t delta = curr_hit_value - next_hit_value + 1; + return VarInt::Encode(delta, encoded_buf_out); + } + + // Given the current hit value, decodes the next hit value from an encoded + // byte array buffer. + // + // RETURNS: DecodedHitInfo containing the decoded hit value and the value's + // encoded size in bytes. + static DecodedHitInfo DecodeNextHitValue(const uint8_t* encoded_buf_in, + Hit::Value curr_hit_value) { + uint64_t delta; + size_t delta_size = VarInt::Decode(encoded_buf_in, &delta); + Hit::Value hit_value = curr_hit_value + delta - 1; + return {hit_value, delta_size}; + } + uint32_t GetDataTypeBytes() const override { return sizeof(Hit); } uint32_t GetMinPostingListSize() const override { @@ -299,15 +347,19 @@ class PostingListHitSerializer : public PostingListSerializer { PostingListUsed* posting_list_used, const Hit& hit, uint32_t offset) const; - // If hit has a term frequency, consumes the term frequency at offset, updates - // hit to include the term frequency and updates offset to reflect that the - // term frequency has been consumed. + // If hit has the flags and/or term frequency field, consumes the flags and/or + // term frequency at offset, updates hit to include the flag and/or term + // frequency and updates offset to reflect that the flag and/or term frequency + // fields have been consumed. // // RETURNS: // - OK, if successful - // - INVALID_ARGUMENT if hit has a term frequency and offset + - // sizeof(Hit::TermFrequency) >= posting_list_used->size_in_bytes() - libtextclassifier3::Status ConsumeTermFrequencyIfPresent( + // - INVALID_ARGUMENT if hit has a flags and/or term frequency field and + // offset + sizeof(Hit's flag) + sizeof(Hit's tf) >= + // posting_list_used->size_in_bytes() + // i.e. 
the posting list is not large enough to consume the hit's flags + // and term frequency fields + libtextclassifier3::Status ConsumeFlagsAndTermFrequencyIfPresent( const PostingListUsed* posting_list_used, Hit* hit, uint32_t* offset) const; }; diff --git a/icing/index/main/posting-list-hit-serializer_test.cc b/icing/index/main/posting-list-hit-serializer_test.cc index 7f0b945..ea135ef 100644 --- a/icing/index/main/posting-list-hit-serializer_test.cc +++ b/icing/index/main/posting-list-hit-serializer_test.cc @@ -14,22 +14,32 @@ #include "icing/index/main/posting-list-hit-serializer.h" +#include <algorithm> +#include <cstddef> #include <cstdint> #include <deque> -#include <memory> +#include <iterator> +#include <limits> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/hit/hit.h" +#include "icing/legacy/index/icing-bit-util.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" #include "icing/testing/common-matchers.h" #include "icing/testing/hit-test-utils.h" using testing::ElementsAre; using testing::ElementsAreArray; using testing::Eq; +using testing::Gt; using testing::IsEmpty; +using testing::IsFalse; +using testing::IsTrue; using testing::Le; using testing::Lt; @@ -58,103 +68,217 @@ TEST(PostingListHitSerializerTest, PostingListUsedPrependHitNotFull) { PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize)); // Make used. 
- Hit hit0(/*section_id=*/0, 0, /*term_frequency=*/56); + Hit hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/56, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit0)); - // Size = sizeof(uncompressed hit0) - int expected_size = sizeof(Hit); - EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size)); + // Size = sizeof(uncompressed hit0::Value) + // + sizeof(hit0::Flags) + // + sizeof(hit0::TermFrequency) + int expected_size = + sizeof(Hit::Value) + sizeof(Hit::Flags) + sizeof(Hit::TermFrequency); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0))); - Hit hit1(/*section_id=*/0, 1, Hit::kDefaultTermFrequency); + Hit hit1(/*section_id=*/0, /*document_id=*/1, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + uint8_t delta_buf[VarInt::kMaxEncodedLen64]; + size_t delta_len = PostingListHitSerializer::EncodeNextHitValue( + /*next_hit_value=*/hit1.value(), + /*curr_hit_value=*/hit0.value(), delta_buf); ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit1)); - // Size = sizeof(uncompressed hit1) - // + sizeof(hit0-hit1) + sizeof(hit0::term_frequency) - expected_size += 2 + sizeof(Hit::TermFrequency); - EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size)); + // Size = sizeof(uncompressed hit1::Value) + // + sizeof(hit0-hit1) + // + sizeof(hit0::Flags) + // + sizeof(hit0::TermFrequency) + expected_size += delta_len; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit1, hit0))); - Hit hit2(/*section_id=*/0, 2, /*term_frequency=*/56); + Hit hit2(/*section_id=*/0, /*document_id=*/2, /*term_frequency=*/56, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + delta_len = PostingListHitSerializer::EncodeNextHitValue( + /*next_hit_value=*/hit2.value(), + 
/*curr_hit_value=*/hit1.value(), delta_buf); ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit2)); - // Size = sizeof(uncompressed hit2) + // Size = sizeof(uncompressed hit2::Value) + sizeof(hit2::Flags) + // + sizeof(hit2::TermFrequency) // + sizeof(hit1-hit2) - // + sizeof(hit0-hit1) + sizeof(hit0::term_frequency) - expected_size += 2; - EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size)); + // + sizeof(hit0-hit1) + // + sizeof(hit0::flags) + // + sizeof(hit0::term_frequency) + expected_size += delta_len + sizeof(Hit::Flags) + sizeof(Hit::TermFrequency); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit2, hit1, hit0))); - Hit hit3(/*section_id=*/0, 3, Hit::kDefaultTermFrequency); + Hit hit3(/*section_id=*/0, /*document_id=*/3, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + delta_len = PostingListHitSerializer::EncodeNextHitValue( + /*next_hit_value=*/hit3.value(), + /*curr_hit_value=*/hit2.value(), delta_buf); ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit3)); - // Size = sizeof(uncompressed hit3) - // + sizeof(hit2-hit3) + sizeof(hit2::term_frequency) + // Size = sizeof(uncompressed hit3::Value) + // + sizeof(hit2-hit3) + sizeof(hit2::Flags) + // + sizeof(hit2::TermFrequency) // + sizeof(hit1-hit2) - // + sizeof(hit0-hit1) + sizeof(hit0::term_frequency) - expected_size += 2 + sizeof(Hit::TermFrequency); - EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size)); + // + sizeof(hit0-hit1) + // + sizeof(hit0::flags) + // + sizeof(hit0::term_frequency) + expected_size += delta_len; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0))); +} + +TEST(PostingListHitSerializerTest, + PostingListUsedPrependHitAlmostFull_withFlags) { + PostingListHitSerializer serializer; + + // Size = 24 + int pl_size = 
2 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + // Fill up the compressed region. + // Transitions: + // Adding hit0: EMPTY -> NOT_FULL + // Adding hit1: NOT_FULL -> NOT_FULL + // Adding hit2: NOT_FULL -> NOT_FULL + Hit hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + Hit hit1 = CreateHit(hit0, /*desired_byte_length=*/3); + Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/2, /*term_frequency=*/57, + /*is_in_prefix_section=*/true, + /*is_prefix_hit=*/true); + EXPECT_THAT(hit2.has_flags(), IsTrue()); + EXPECT_THAT(hit2.has_term_frequency(), IsTrue()); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0)); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1)); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2)); + // Size used will be 4 (hit2::Value) + 1 (hit2::Flags) + 1 + // (hit2::TermFrequency) + 2 (hit1-hit2) + 3 (hit0-hit1) = 11 bytes + int expected_size = sizeof(Hit::Value) + sizeof(Hit::Flags) + + sizeof(Hit::TermFrequency) + 2 + 3; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit2, hit1, hit0))); + + // Add one more hit to transition NOT_FULL -> ALMOST_FULL + Hit hit3 = + CreateHit(hit2, /*desired_byte_length=*/3, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + EXPECT_THAT(hit3.has_flags(), IsFalse()); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3)); + // Storing them in the compressed region requires 4 (hit3::Value) + 3 + // (hit2-hit3) + 1 (hit2::Flags) + 1 (hit2::TermFrequency) + 2 (hit1-hit2) + 3 + // (hit0-hit1) = 14 bytes, but there are only 12 bytes in the compressed + // region. So instead, the posting list will transition to ALMOST_FULL. 
+ // The in-use compressed region will actually shrink from 11 bytes to 10 bytes + // because the uncompressed version of hit2 will be overwritten with the + // compressed delta of hit2. hit3 will be written to one of the special hits. + // Because we're in ALMOST_FULL, the expected size is the size of the pl minus + // the one hit used to mark the posting list as ALMOST_FULL. + expected_size = pl_size - sizeof(Hit); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0))); + + // Add one more hit to transition ALMOST_FULL -> ALMOST_FULL + Hit hit4 = CreateHit(hit3, /*desired_byte_length=*/2); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit4)); + // There are currently 10 bytes in use in the compressed region. hit3 will + // have a 2-byte delta, which fits in the compressed region. Hit3 will be + // moved from the special hit to the compressed region (which will have 12 + // bytes in use after adding hit3). Hit4 will be placed in one of the special + // hits and the posting list will remain in ALMOST_FULL. + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit4, hit3, hit2, hit1, hit0))); + + // Add one more hit to transition ALMOST_FULL -> FULL + Hit hit5 = CreateHit(hit4, /*desired_byte_length=*/2); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit5)); + // There are currently 12 bytes in use in the compressed region. hit4 will + // have a 2-byte delta which will not fit in the compressed region. So hit4 + // will remain in one of the special hits and hit5 will occupy the other, + // making the posting list FULL. + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit5, hit4, hit3, hit2, hit1, hit0))); + + // The posting list is FULL. Adding another hit should fail. 
+ Hit hit6 = CreateHit(hit5, /*desired_byte_length=*/1); + EXPECT_THAT(serializer.PrependHit(&pl_used, hit6), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); } TEST(PostingListHitSerializerTest, PostingListUsedPrependHitAlmostFull) { PostingListHitSerializer serializer; - int size = 2 * serializer.GetMinPostingListSize(); + // Size = 24 + int pl_size = 2 * serializer.GetMinPostingListSize(); ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); // Fill up the compressed region. // Transitions: // Adding hit0: EMPTY -> NOT_FULL // Adding hit1: NOT_FULL -> NOT_FULL // Adding hit2: NOT_FULL -> NOT_FULL - Hit hit0(/*section_id=*/0, 0, Hit::kDefaultTermFrequency); - Hit hit1 = CreateHit(hit0, /*desired_byte_length=*/2); - Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/2); + Hit hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + Hit hit1 = CreateHit(hit0, /*desired_byte_length=*/3); + Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/3); ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0)); ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1)); ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2)); - // Size used will be 2+2+4=8 bytes - int expected_size = sizeof(Hit::Value) + 2 + 2; - EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size)); + // Size used will be 4 (hit2::Value) + 3 (hit1-hit2) + 3 (hit0-hit1) + // = 10 bytes + int expected_size = sizeof(Hit::Value) + 3 + 3; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit2, hit1, hit0))); // Add one more hit to transition NOT_FULL -> ALMOST_FULL Hit hit3 = CreateHit(hit2, /*desired_byte_length=*/3); ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3)); - // Compressed region 
would be 2+2+3+4=11 bytes, but the compressed region is - // only 10 bytes. So instead, the posting list will transition to ALMOST_FULL. - // The in-use compressed region will actually shrink from 8 bytes to 7 bytes + // Storing them in the compressed region requires 4 (hit3::Value) + 3 + // (hit2-hit3) + 3 (hit1-hit2) + 3 (hit0-hit1) = 13 bytes, but there are only + // 12 bytes in the compressed region. So instead, the posting list will + // transition to ALMOST_FULL. + // The in-use compressed region will actually shrink from 10 bytes to 9 bytes // because the uncompressed version of hit2 will be overwritten with the // compressed delta of hit2. hit3 will be written to one of the special hits. // Because we're in ALMOST_FULL, the expected size is the size of the pl minus // the one hit used to mark the posting list as ALMOST_FULL. - expected_size = size - sizeof(Hit); - EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size)); + expected_size = pl_size - sizeof(Hit); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0))); // Add one more hit to transition ALMOST_FULL -> ALMOST_FULL Hit hit4 = CreateHit(hit3, /*desired_byte_length=*/2); ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit4)); - // There are currently 7 bytes in use in the compressed region. hit3 will have - // a 2-byte delta. That delta will fit in the compressed region (which will - // now have 9 bytes in use), hit4 will be placed in one of the special hits - // and the posting list will remain in ALMOST_FULL. - EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size)); + // There are currently 9 bytes in use in the compressed region. Hit3 will + // have a 2-byte delta, which fits in the compressed region. Hit3 will be + // moved from the special hit to the compressed region (which will have 11 + // bytes in use after adding hit3). 
Hit4 will be placed in one of the special + // hits and the posting list will remain in ALMOST_FULL. + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit4, hit3, hit2, hit1, hit0))); // Add one more hit to transition ALMOST_FULL -> FULL Hit hit5 = CreateHit(hit4, /*desired_byte_length=*/2); ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit5)); - // There are currently 9 bytes in use in the compressed region. hit4 will have - // a 2-byte delta which will not fit in the compressed region. So hit4 will - // remain in one of the special hits and hit5 will occupy the other, making - // the posting list FULL. - EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(size)); + // There are currently 11 bytes in use in the compressed region. Hit4 will + // have a 2-byte delta which will not fit in the compressed region. So hit4 + // will remain in one of the special hits and hit5 will occupy the other, + // making the posting list FULL. + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size)); EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit5, hit4, hit3, hit2, hit1, hit0))); @@ -164,9 +288,59 @@ TEST(PostingListHitSerializerTest, PostingListUsedPrependHitAlmostFull) { StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); } +TEST(PostingListHitSerializerTest, PrependHitsWithSameValue) { + PostingListHitSerializer serializer; + + // Size = 24 + int pl_size = 2 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + // Fill up the compressed region. 
+ Hit hit0(/*section_id=*/0, /*document_id=*/0, Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + Hit hit1 = CreateHit(hit0, /*desired_byte_length=*/3); + Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/2, /*term_frequency=*/57, + /*is_in_prefix_section=*/true, + /*is_prefix_hit=*/true); + // Create hit3 with the same value but different flags as hit2 (hit3_flags + // is set to have all currently-defined flags enabled) + Hit::Flags hit3_flags = 0; + for (int i = 0; i < Hit::kNumFlagsInFlagsField; ++i) { + hit3_flags |= (1 << i); + } + Hit hit3(hit2.value(), /*term_frequency=*/hit2.term_frequency(), + /*flags=*/hit3_flags); + + // hit3 is larger than hit2 (its flag value is larger), and so needs to be + // prepended first + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0)); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1)); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3)); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2)); + // Posting list is now ALMOST_FULL + // ---------------------- + // 23-21 delta(Hit #0) + // 20-19 delta(Hit #1) + // 18 term-frequency(Hit #2) + // 17 flags(Hit #2) + // 16 delta(Hit #2) = 0 + // 15-12 <unused padding = 0> + // 11-6 Hit #3 + // 5-0 kSpecialHit + // ---------------------- + int bytes_used = pl_size - sizeof(Hit); + + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit2, hit3, hit1, hit0))); +} + TEST(PostingListHitSerializerTest, PostingListUsedMinSize) { PostingListHitSerializer serializer; + // Min size = 12 ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used, PostingListUsed::CreateFromUnitializedRegion( @@ -176,7 +350,7 @@ TEST(PostingListHitSerializerTest, PostingListUsedMinSize) { EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); // Add a hit, PL should shift to ALMOST_FULL state - Hit hit0(/*section_id=*/0, 0, /*term_frequency=*/0, + Hit 
hit0(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/0, /*is_in_prefix_section=*/false, /*is_prefix_hit=*/true); ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0)); @@ -185,10 +359,10 @@ TEST(PostingListHitSerializerTest, PostingListUsedMinSize) { EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Le(expected_size)); EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0))); - // Add the smallest hit possible - no term_frequency and a delta of 1. PL - // should shift to FULL state. - Hit hit1(/*section_id=*/0, 0, /*term_frequency=*/0, - /*is_in_prefix_section=*/true, + // Add the smallest hit possible - no term_frequency, non-prefix hit and a + // delta of 0b10. PL should shift to FULL state. + Hit hit1(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/0, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1)); // Size = sizeof(uncompressed hit1) + sizeof(uncompressed hit0) @@ -198,7 +372,7 @@ TEST(PostingListHitSerializerTest, PostingListUsedMinSize) { IsOkAndHolds(ElementsAre(hit1, hit0))); // Try to add the smallest hit possible. 
Should fail - Hit hit2(/*section_id=*/0, 0, /*term_frequency=*/0, + Hit hit2(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/0, /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); EXPECT_THAT(serializer.PrependHit(&pl_used, hit2), @@ -212,14 +386,17 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayMinSizePostingList) { PostingListHitSerializer serializer; - // Min Size = 10 - int size = serializer.GetMinPostingListSize(); + // Min Size = 12 + int pl_size = serializer.GetMinPostingListSize(); ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); std::vector<HitElt> hits_in; - hits_in.emplace_back(Hit(1, 0, Hit::kDefaultTermFrequency)); + hits_in.emplace_back(Hit(/*section_id=*/1, /*document_id=*/0, + Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false)); hits_in.emplace_back( CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); hits_in.emplace_back( @@ -235,7 +412,7 @@ TEST(PostingListHitSerializerTest, ICING_ASSERT_OK_AND_ASSIGN( uint32_t num_can_prepend, (serializer.PrependHitArray<HitElt, HitElt::get_hit>( - &pl_used, &hits_in[0], hits_in.size(), false))); + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); EXPECT_THAT(num_can_prepend, Eq(2)); int can_fit_hits = num_can_prepend; @@ -243,10 +420,11 @@ TEST(PostingListHitSerializerTest, // problem, transitioning the PL from EMPTY -> ALMOST_FULL -> FULL const HitElt *hits_in_ptr = hits_in.data() + (hits_in.size() - 2); ICING_ASSERT_OK_AND_ASSIGN( - num_can_prepend, (serializer.PrependHitArray<HitElt, HitElt::get_hit>( - &pl_used, hits_in_ptr, can_fit_hits, false))); + num_can_prepend, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, hits_in_ptr, can_fit_hits, /*keep_prepended=*/false))); EXPECT_THAT(num_can_prepend, Eq(can_fit_hits)); - EXPECT_THAT(size, 
Eq(serializer.GetBytesUsed(&pl_used))); + EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used))); std::deque<Hit> hits_pushed; std::transform(hits_in.rbegin(), hits_in.rend() - hits_in.size() + can_fit_hits, @@ -258,14 +436,17 @@ TEST(PostingListHitSerializerTest, TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) { PostingListHitSerializer serializer; - // Size = 30 - int size = 3 * serializer.GetMinPostingListSize(); + // Size = 36 + int pl_size = 3 * serializer.GetMinPostingListSize(); ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); std::vector<HitElt> hits_in; - hits_in.emplace_back(Hit(1, 0, Hit::kDefaultTermFrequency)); + hits_in.emplace_back(Hit(/*section_id=*/1, /*document_id=*/0, + Hit::kDefaultTermFrequency, + /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false)); hits_in.emplace_back( CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); hits_in.emplace_back( @@ -278,14 +459,14 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) { // The last hit is uncompressed and the four before it should only take one // byte. Total use = 8 bytes. 
// ---------------------- - // 29 delta(Hit #1) - // 28 delta(Hit #2) - // 27 delta(Hit #3) - // 26 delta(Hit #4) - // 25-22 Hit #5 - // 21-10 <unused> - // 9-5 kSpecialHit - // 4-0 Offset=22 + // 35 delta(Hit #0) + // 34 delta(Hit #1) + // 33 delta(Hit #2) + // 32 delta(Hit #3) + // 31-28 Hit #4 + // 27-12 <unused> + // 11-6 kSpecialHit + // 5-0 Offset=28 // ---------------------- int byte_size = sizeof(Hit::Value) + hits_in.size() - 1; @@ -294,7 +475,7 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) { ICING_ASSERT_OK_AND_ASSIGN( uint32_t num_could_fit, (serializer.PrependHitArray<HitElt, HitElt::get_hit>( - &pl_used, &hits_in[0], hits_in.size(), false))); + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); EXPECT_THAT(num_could_fit, Eq(hits_in.size())); EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used))); std::deque<Hit> hits_pushed; @@ -316,31 +497,35 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) { CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3)); hits_in.emplace_back( CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2)); + hits_in.emplace_back( + CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3)); std::reverse(hits_in.begin(), hits_in.end()); - // Size increased by the deltas of these hits (1+2+1+2+3+2) = 11 bytes + // Size increased by the deltas of these hits (1+2+1+2+3+2+3) = 14 bytes // ---------------------- - // 29 delta(Hit #1) - // 28 delta(Hit #2) - // 27 delta(Hit #3) - // 26 delta(Hit #4) - // 25 delta(Hit #5) - // 24-23 delta(Hit #6) - // 22 delta(Hit #7) - // 21-20 delta(Hit #8) - // 19-17 delta(Hit #9) - // 16-15 delta(Hit #10) - // 14-11 Hit #11 - // 10 <unused> - // 9-5 kSpecialHit - // 4-0 Offset=11 + // 35 delta(Hit #0) + // 34 delta(Hit #1) + // 33 delta(Hit #2) + // 32 delta(Hit #3) + // 31 delta(Hit #4) + // 30-29 delta(Hit #5) + // 28 delta(Hit #6) + // 27-26 delta(Hit #7) + // 25-23 delta(Hit #8) + // 22-21 delta(Hit #9) + 
// 20-18 delta(Hit #10) + // 17-14 Hit #11 + // 13-12 <unused> + // 11-6 kSpecialHit + // 5-0 Offset=14 // ---------------------- - byte_size += 11; + byte_size += 14; - // Add these 6 hits. The PL is currently in the NOT_FULL state and should + // Add these 7 hits. The PL is currently in the NOT_FULL state and should // remain in the NOT_FULL state. ICING_ASSERT_OK_AND_ASSIGN( - num_could_fit, (serializer.PrependHitArray<HitElt, HitElt::get_hit>( - &pl_used, &hits_in[0], hits_in.size(), false))); + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); EXPECT_THAT(num_could_fit, Eq(hits_in.size())); EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used))); // All hits from hits_in were added. @@ -353,29 +538,32 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) { hits_in.clear(); hits_in.emplace_back(first_hit); // ---------------------- - // 29 delta(Hit #1) - // 28 delta(Hit #2) - // 27 delta(Hit #3) - // 26 delta(Hit #4) - // 25 delta(Hit #5) - // 24-23 delta(Hit #6) - // 22 delta(Hit #7) - // 21-20 delta(Hit #8) - // 19-17 delta(Hit #9) - // 16-15 delta(Hit #10) - // 14-12 delta(Hit #11) - // 11-10 <unused> - // 9-5 Hit #12 - // 4-0 kSpecialHit + // 35 delta(Hit #0) + // 34 delta(Hit #1) + // 33 delta(Hit #2) + // 32 delta(Hit #3) + // 31 delta(Hit #4) + // 30-29 delta(Hit #5) + // 28 delta(Hit #6) + // 27-26 delta(Hit #7) + // 25-23 delta(Hit #8) + // 22-21 delta(Hit #9) + // 20-18 delta(Hit #10) + // 17-15 delta(Hit #11) + // 14-12 <unused> + // 11-6 Hit #12 + // 5-0 kSpecialHit // ---------------------- - byte_size = 25; + byte_size = 30; // 36 - 6 // Add this 1 hit. The PL is currently in the NOT_FULL state and should // transition to the ALMOST_FULL state - even though there is still some - // unused space. + // unused space. This is because the unused space (3 bytes) is less than + // the size of a uncompressed Hit. 
ICING_ASSERT_OK_AND_ASSIGN( - num_could_fit, (serializer.PrependHitArray<HitElt, HitElt::get_hit>( - &pl_used, &hits_in[0], hits_in.size(), false))); + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); EXPECT_THAT(num_could_fit, Eq(hits_in.size())); EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used))); // All hits from hits_in were added. @@ -388,37 +576,40 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayPostingList) { hits_in.clear(); hits_in.emplace_back(first_hit); hits_in.emplace_back( - CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2)); + CreateHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3)); std::reverse(hits_in.begin(), hits_in.end()); // ---------------------- - // 29 delta(Hit #1) - // 28 delta(Hit #2) - // 27 delta(Hit #3) - // 26 delta(Hit #4) - // 25 delta(Hit #5) - // 24-23 delta(Hit #6) - // 22 delta(Hit #7) - // 21-20 delta(Hit #8) - // 19-17 delta(Hit #9) - // 16-15 delta(Hit #10) - // 14-12 delta(Hit #11) - // 11 delta(Hit #12) - // 10 <unused> - // 9-5 Hit #13 - // 4-0 Hit #14 + // 35 delta(Hit #0) + // 34 delta(Hit #1) + // 33 delta(Hit #2) + // 32 delta(Hit #3) + // 31 delta(Hit #4) + // 30-29 delta(Hit #5) + // 28 delta(Hit #6) + // 27-26 delta(Hit #7) + // 25-23 delta(Hit #8) + // 22-21 delta(Hit #9) + // 20-18 delta(Hit #10) + // 17-15 delta(Hit #11) + // 14 delta(Hit #12) + // 13-12 <unused> + // 11-6 Hit #13 + // 5-0 Hit #14 // ---------------------- - // Add these 2 hits. The PL is currently in the ALMOST_FULL state. Adding the - // first hit should keep the PL in ALMOST_FULL because the delta between Hit - // #12 and Hit #13 (1 byte) can fit in the unused area (2 bytes). Adding the - // second hit should tranisition to the FULL state because the delta between - // Hit #13 and Hit #14 (2 bytes) is larger than the remaining unused area - // (1 byte). + // Add these 2 hits. 
+ // - The PL is currently in the ALMOST_FULL state. Adding the first hit should + // keep the PL in ALMOST_FULL because the delta between + // Hit #13 and Hit #14 (1 byte) can fit in the unused area (3 bytes). + // - Adding the second hit should transition to the FULL state because the + // delta between Hit #14 and Hit #15 (3 bytes) is larger than the remaining + // unused area (2 byte). ICING_ASSERT_OK_AND_ASSIGN( - num_could_fit, (serializer.PrependHitArray<HitElt, HitElt::get_hit>( - &pl_used, &hits_in[0], hits_in.size(), false))); + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); EXPECT_THAT(num_could_fit, Eq(hits_in.size())); - EXPECT_THAT(size, Eq(serializer.GetBytesUsed(&pl_used))); + EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used))); // All hits from hits_in were added. std::transform(hits_in.rbegin(), hits_in.rend(), std::front_inserter(hits_pushed), HitElt::get_hit); @@ -431,9 +622,9 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayTooManyHits) { static constexpr int kNumHits = 128; static constexpr int kDeltaSize = 1; - static constexpr int kTermFrequencySize = 1; static constexpr size_t kHitsSize = - ((kNumHits * (kDeltaSize + kTermFrequencySize)) / 5) * 5; + ((kNumHits - 2) * kDeltaSize + (2 * sizeof(Hit))) / sizeof(Hit) * + sizeof(Hit); // Create an array with one too many hits std::vector<Hit> hits_in_too_many = @@ -442,18 +633,21 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayTooManyHits) { for (const Hit &hit : hits_in_too_many) { hit_elts_in_too_many.emplace_back(hit); } + // Reverse so that hits are inserted in descending order + std::reverse(hit_elts_in_too_many.begin(), hit_elts_in_too_many.end()); + ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used, PostingListUsed::CreateFromUnitializedRegion( &serializer, serializer.GetMinPostingListSize())); - // PrependHitArray should fail because hit_elts_in_too_many 
is far too large // for the minimum size pl. ICING_ASSERT_OK_AND_ASSIGN( uint32_t num_could_fit, (serializer.PrependHitArray<HitElt, HitElt::get_hit>( &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(), - false))); + /*keep_prepended=*/false))); + ASSERT_THAT(num_could_fit, Eq(2)); ASSERT_THAT(num_could_fit, Lt(hit_elts_in_too_many.size())); ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); @@ -464,10 +658,11 @@ TEST(PostingListHitSerializerTest, PostingListPrependHitArrayTooManyHits) { // PrependHitArray should fail because hit_elts_in_too_many is one hit too // large for this pl. ICING_ASSERT_OK_AND_ASSIGN( - num_could_fit, (serializer.PrependHitArray<HitElt, HitElt::get_hit>( - &pl_used, &hit_elts_in_too_many[0], - hit_elts_in_too_many.size(), false))); - ASSERT_THAT(num_could_fit, Lt(hit_elts_in_too_many.size())); + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(), + /*keep_prepended=*/false))); + ASSERT_THAT(num_could_fit, Eq(hit_elts_in_too_many.size() - 1)); ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); } @@ -476,16 +671,37 @@ TEST(PostingListHitSerializerTest, PostingListStatusJumpFromNotFullToFullAndBack) { PostingListHitSerializer serializer; + // Size = 18 const uint32_t pl_size = 3 * sizeof(Hit); ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl, PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); - ICING_ASSERT_OK(serializer.PrependHit(&pl, Hit(Hit::kInvalidValue - 1, 0))); + + Hit max_valued_hit(kMaxSectionId, kMinDocumentId, Hit::kMaxTermFrequency, + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/true); + ICING_ASSERT_OK(serializer.PrependHit(&pl, max_valued_hit)); uint32_t bytes_used = serializer.GetBytesUsed(&pl); + ASSERT_THAT(bytes_used, sizeof(Hit::Value) + sizeof(Hit::Flags) + + 
sizeof(Hit::TermFrequency)); // Status not full. ASSERT_THAT(bytes_used, Le(pl_size - PostingListHitSerializer::kSpecialHitsSize)); - ICING_ASSERT_OK(serializer.PrependHit(&pl, Hit(Hit::kInvalidValue >> 2, 0))); + + Hit min_valued_hit(kMinSectionId, kMaxDocumentId, Hit::kMaxTermFrequency, + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/true); + uint8_t delta_buf[VarInt::kMaxEncodedLen64]; + size_t delta_len = PostingListHitSerializer::EncodeNextHitValue( + /*next_hit_value=*/min_valued_hit.value(), + /*curr_hit_value=*/max_valued_hit.value(), delta_buf); + // The compressed region available is pl_size - 2 * sizeof(specialHits) = 6 + // We need to also fit max_valued_hit's flags and term-frequency fields, which + // each take 1 byte + // So we'll jump directly to FULL if the varint-encoded delta of the 2 hits > + // 6 - sizeof(Hit::Flags) - sizeof(Hit::TermFrequency) = 4 + ASSERT_THAT(delta_len, Gt(4)); + ICING_ASSERT_OK(serializer.PrependHit( + &pl, Hit(kMinSectionId, kMaxDocumentId, Hit::kMaxTermFrequency, + /*is_in_prefix_section=*/true, /*is_prefix_hit=*/true))); // Status should jump to full directly. ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(pl_size)); ICING_ASSERT_OK(serializer.PopFrontHits(&pl, 1)); @@ -501,11 +717,12 @@ TEST(PostingListHitSerializerTest, DeltaOverflow) { PostingListUsed pl, PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + static const Hit::Value kMaxHitValue = std::numeric_limits<Hit::Value>::max(); static const Hit::Value kOverflow[4] = { - Hit::kInvalidValue >> 2, - (Hit::kInvalidValue >> 2) * 2, - (Hit::kInvalidValue >> 2) * 3, - Hit::kInvalidValue - 1, + kMaxHitValue >> 2, + (kMaxHitValue >> 2) * 2, + (kMaxHitValue >> 2) * 3, + kMaxHitValue - 1, }; // Fit at least 4 ordinary values. @@ -516,22 +733,245 @@ TEST(PostingListHitSerializerTest, DeltaOverflow) { // Cannot fit 4 overflow values. 
ICING_ASSERT_OK_AND_ASSIGN( pl, PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); - ICING_EXPECT_OK(serializer.PrependHit(&pl, Hit(kOverflow[3]))); - ICING_EXPECT_OK(serializer.PrependHit(&pl, Hit(kOverflow[2]))); + Hit::Flags has_term_frequency_flags = 0b1; + ICING_EXPECT_OK(serializer.PrependHit( + &pl, Hit(/*value=*/kOverflow[3], has_term_frequency_flags, + /*term_frequency=*/8))); + ICING_EXPECT_OK(serializer.PrependHit( + &pl, Hit(/*value=*/kOverflow[2], has_term_frequency_flags, + /*term_frequency=*/8))); // Can fit only one more. - ICING_EXPECT_OK(serializer.PrependHit(&pl, Hit(kOverflow[1]))); - EXPECT_THAT(serializer.PrependHit(&pl, Hit(kOverflow[0])), + ICING_EXPECT_OK(serializer.PrependHit( + &pl, Hit(/*value=*/kOverflow[1], has_term_frequency_flags, + /*term_frequency=*/8))); + EXPECT_THAT(serializer.PrependHit( + &pl, Hit(/*value=*/kOverflow[0], has_term_frequency_flags, + /*term_frequency=*/8)), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST(PostingListHitSerializerTest, GetMinPostingListToFitForNotFullPL) { + PostingListHitSerializer serializer; + + // Size = 24 + int pl_size = 2 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + // Create and add some hits to make pl_used NOT_FULL + std::vector<Hit> hits_in = + CreateHits(/*num_hits=*/7, /*desired_byte_length=*/1); + for (const Hit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + } + // ---------------------- + // 23 delta(Hit #0) + // 22 delta(Hit #1) + // 21 delta(Hit #2) + // 20 delta(Hit #3) + // 19 delta(Hit #4) + // 18 delta(Hit #5) + // 17-14 Hit #6 + // 13-12 <unused> + // 11-6 kSpecialHit + // 5-0 Offset=14 + // ---------------------- + int bytes_used = 10; + + // Check that all hits have been inserted + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); + std::deque<Hit> 
hits_pushed(hits_in.rbegin(), hits_in.rend()); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // Get the min size to fit for the hits in pl_used. Moving the hits in pl_used + // into a posting list with this min size should make it ALMOST_FULL, which we + // can see should have size = 18. + // ---------------------- + // 17 delta(Hit #0) + // 16 delta(Hit #1) + // 15 delta(Hit #2) + // 14 delta(Hit #3) + // 13 delta(Hit #4) + // 12 delta(Hit #5) + // 11-6 Hit #6 + // 5-0 kSpecialHit + // ---------------------- + int expected_min_size = 18; + uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(expected_min_size)); + + // Also check that this min size to fit posting list actually does fit all the + // hits and can only hit one more hit in the ALMOST_FULL state. + ICING_ASSERT_OK_AND_ASSIGN(PostingListUsed min_size_to_fit_pl, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, min_size_to_fit)); + for (const Hit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit)); + } + + // Adding another hit to the min-size-to-fit posting list should succeed + Hit hit = CreateHit(hits_in.back(), /*desired_byte_length=*/1); + ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit)); + // Adding any other hits should fail with RESOURCE_EXHAUSTED error. + EXPECT_THAT(serializer.PrependHit(&min_size_to_fit_pl, + CreateHit(hit, /*desired_byte_length=*/1)), StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + + // Check that all hits have been inserted and the min-fit posting list is now + // FULL. 
+ EXPECT_THAT(serializer.GetBytesUsed(&min_size_to_fit_pl), + Eq(min_size_to_fit)); + hits_pushed.emplace_front(hit); + EXPECT_THAT(serializer.GetHits(&min_size_to_fit_pl), + IsOkAndHolds(ElementsAreArray(hits_pushed))); +} + +TEST(PostingListHitSerializerTest, GetMinPostingListToFitForTwoHits) { + PostingListHitSerializer serializer; + + // Size = 36 + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + // Create and add 2 hits + Hit first_hit(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/5, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); + std::vector<Hit> hits_in = + CreateHits(first_hit, /*num_hits=*/2, /*desired_byte_length=*/4); + for (const Hit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + } + // ---------------------- + // 35 term-frequency(Hit #0) + // 34 flags(Hit #0) + // 33-30 delta(Hit #0) + // 29 term-frequency(Hit #1) + // 28 flags(Hit #1) + // 27-24 Hit #1 + // 23-12 <unused> + // 11-6 kSpecialHit + // 5-0 Offset=24 + // ---------------------- + int bytes_used = 12; + + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); + std::deque<Hit> hits_pushed(hits_in.rbegin(), hits_in.rend()); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // GetMinPostingListSizeToFit should return min posting list size. 
+ EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used), + Eq(serializer.GetMinPostingListSize())); +} + +TEST(PostingListHitSerializerTest, GetMinPostingListToFitForThreeSmallHits) { + PostingListHitSerializer serializer; + + // Size = 24 + int pl_size = 2 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + // Create and add 3 small hits that fit in the size range where we should be + // checking for whether the PL has only 2 hits + std::vector<Hit> hits_in = + CreateHits(/*num_hits=*/3, /*desired_byte_length=*/1); + for (const Hit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + } + // ---------------------- + // 23 delta(Hit #0) + // 22 delta(Hit #1) + // 21-18 Hit #2 + // 17-12 <unused> + // 11-6 kSpecialHit + // 5-0 Offset=18 + // ---------------------- + int bytes_used = 6; + + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); + std::deque<Hit> hits_pushed(hits_in.rbegin(), hits_in.rend()); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // Get the min size to fit for the hits in pl_used. Moving the hits in pl_used + // into a posting list with this min size should make it ALMOST_FULL, which we + // can see should have size = 14. This should not return the min posting list + // size. 
+ // ---------------------- + // 13 delta(Hit #0) + // 12 delta(Hit #1) + // 11-6 Hit #2 + // 5-0 kSpecialHit + // ---------------------- + int expected_min_size = 14; + + EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used), + Gt(serializer.GetMinPostingListSize())); + EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used), + Eq(expected_min_size)); +} + +TEST(PostingListHitSerializerTest, + GetMinPostingListToFitForAlmostFullAndFullPLReturnsSameSize) { + PostingListHitSerializer serializer; + + // Size = 24 + int pl_size = 2 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + // Create and add some hits to make pl_used ALMOST_FULL + std::vector<Hit> hits_in = + CreateHits(/*num_hits=*/7, /*desired_byte_length=*/2); + for (const Hit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + } + // ---------------------- + // 23-22 delta(Hit #0) + // 21-20 delta(Hit #1) + // 19-18 delta(Hit #2) + // 17-16 delta(Hit #3) + // 15-14 delta(Hit #4) + // 13-12 delta(Hit #5) + // 11-6 Hit #6 + // 5-0 kSpecialHit + // ---------------------- + int bytes_used = 18; + + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); + std::deque<Hit> hits_pushed(hits_in.rbegin(), hits_in.rend()); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // GetMinPostingListSizeToFit should return the same size as pl_used. 
+ uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(pl_size)); + + // Add another hit to make the posting list FULL + Hit hit = CreateHit(hits_in.back(), /*desired_byte_length=*/1); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size)); + hits_pushed.emplace_front(hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // GetMinPostingListSizeToFit should still be the same size as pl_used. + min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(pl_size)); } TEST(PostingListHitSerializerTest, MoveFrom) { PostingListHitSerializer serializer; - int size = 3 * serializer.GetMinPostingListSize(); + int pl_size = 3 * serializer.GetMinPostingListSize(); ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used1, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); std::vector<Hit> hits1 = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1); for (const Hit &hit : hits1) { @@ -540,7 +980,7 @@ TEST(PostingListHitSerializerTest, MoveFrom) { ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used2, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); std::vector<Hit> hits2 = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/2); for (const Hit &hit : hits2) { @@ -556,10 +996,10 @@ TEST(PostingListHitSerializerTest, MoveFrom) { TEST(PostingListHitSerializerTest, MoveFromNullArgumentReturnsInvalidArgument) { PostingListHitSerializer serializer; - int size = 3 * serializer.GetMinPostingListSize(); + int pl_size = 3 * serializer.GetMinPostingListSize(); ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used1, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + 
PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); std::vector<Hit> hits = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1); for (const Hit &hit : hits) { ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); @@ -575,10 +1015,10 @@ TEST(PostingListHitSerializerTest, MoveFromInvalidPostingListReturnsInvalidArgument) { PostingListHitSerializer serializer; - int size = 3 * serializer.GetMinPostingListSize(); + int pl_size = 3 * serializer.GetMinPostingListSize(); ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used1, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); std::vector<Hit> hits1 = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1); for (const Hit &hit : hits1) { @@ -587,7 +1027,7 @@ TEST(PostingListHitSerializerTest, ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used2, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); std::vector<Hit> hits2 = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/2); for (const Hit &hit : hits2) { @@ -595,7 +1035,7 @@ TEST(PostingListHitSerializerTest, } // Write invalid hits to the beginning of pl_used1 to make it invalid. 
- Hit invalid_hit; + Hit invalid_hit(Hit::kInvalidValue); Hit *first_hit = reinterpret_cast<Hit *>(pl_used1.posting_list_buffer()); *first_hit = invalid_hit; ++first_hit; @@ -610,10 +1050,10 @@ TEST(PostingListHitSerializerTest, MoveToInvalidPostingListReturnsFailedPrecondition) { PostingListHitSerializer serializer; - int size = 3 * serializer.GetMinPostingListSize(); + int pl_size = 3 * serializer.GetMinPostingListSize(); ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used1, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); std::vector<Hit> hits1 = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1); for (const Hit &hit : hits1) { @@ -622,7 +1062,7 @@ TEST(PostingListHitSerializerTest, ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used2, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); std::vector<Hit> hits2 = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/2); for (const Hit &hit : hits2) { @@ -630,7 +1070,7 @@ TEST(PostingListHitSerializerTest, } // Write invalid hits to the beginning of pl_used2 to make it invalid. 
- Hit invalid_hit; + Hit invalid_hit(Hit::kInvalidValue); Hit *first_hit = reinterpret_cast<Hit *>(pl_used2.posting_list_buffer()); *first_hit = invalid_hit; ++first_hit; @@ -644,10 +1084,10 @@ TEST(PostingListHitSerializerTest, TEST(PostingListHitSerializerTest, MoveToPostingListTooSmall) { PostingListHitSerializer serializer; - int size = 3 * serializer.GetMinPostingListSize(); + int pl_size = 3 * serializer.GetMinPostingListSize(); ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used1, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); std::vector<Hit> hits1 = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1); for (const Hit &hit : hits1) { @@ -672,30 +1112,36 @@ TEST(PostingListHitSerializerTest, MoveToPostingListTooSmall) { IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend()))); } -TEST(PostingListHitSerializerTest, PopHitsWithScores) { +TEST(PostingListHitSerializerTest, PopHitsWithTermFrequenciesAndFlags) { PostingListHitSerializer serializer; - int size = 2 * serializer.GetMinPostingListSize(); + // Size = 24 + int pl_size = 2 * serializer.GetMinPostingListSize(); ICING_ASSERT_OK_AND_ASSIGN( PostingListUsed pl_used, - PostingListUsed::CreateFromUnitializedRegion(&serializer, size)); + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); - // This posting list is 20-bytes. Create four hits that will have deltas of - // two bytes each and all of whom will have a non-default score. This posting - // list will be almost_full. + // This posting list is 24-bytes. Create four hits that will have deltas of + // two bytes each and all of whom will have a non-default term-frequency. This + // posting list will be almost_full. 
// // ---------------------- - // 19 score(Hit #0) - // 18-17 delta(Hit #0) - // 16 score(Hit #1) - // 15-14 delta(Hit #1) - // 13 score(Hit #2) - // 12-11 delta(Hit #2) - // 10 <unused> - // 9-5 Hit #3 - // 4-0 kInvalidHitVal + // 23 term-frequency(Hit #0) + // 22 flags(Hit #0) + // 21-20 delta(Hit #0) + // 19 term-frequency(Hit #1) + // 18 flags(Hit #1) + // 17-16 delta(Hit #1) + // 15 term-frequency(Hit #2) + // 14 flags(Hit #2) + // 13-12 delta(Hit #2) + // 11-6 Hit #3 + // 5-0 kInvalidHit // ---------------------- - Hit hit0(/*section_id=*/0, /*document_id=*/0, /*score=*/5); + int bytes_used = 18; + + Hit hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/5, + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); Hit hit1 = CreateHit(hit0, /*desired_byte_length=*/2); Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/2); Hit hit3 = CreateHit(hit2, /*desired_byte_length=*/2); @@ -707,22 +1153,26 @@ TEST(PostingListHitSerializerTest, PopHitsWithScores) { ICING_ASSERT_OK_AND_ASSIGN(std::vector<Hit> hits_out, serializer.GetHits(&pl_used)); EXPECT_THAT(hits_out, ElementsAre(hit3, hit2, hit1, hit0)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); // Now, pop the last hit. The posting list should contain the first three // hits. 
// // ---------------------- - // 19 score(Hit #0) - // 18-17 delta(Hit #0) - // 16 score(Hit #1) - // 15-14 delta(Hit #1) - // 13-10 <unused> - // 9-5 Hit #2 - // 4-0 kInvalidHitVal + // 23 term-frequency(Hit #0) + // 22 flags(Hit #0) + // 21-20 delta(Hit #0) + // 19 term-frequency(Hit #1) + // 18 flags(Hit #1) + // 17-16 delta(Hit #1) + // 15-12 <unused> + // 11-6 Hit #2 + // 5-0 kInvalidHit // ---------------------- ICING_ASSERT_OK(serializer.PopFrontHits(&pl_used, 1)); ICING_ASSERT_OK_AND_ASSIGN(hits_out, serializer.GetHits(&pl_used)); EXPECT_THAT(hits_out, ElementsAre(hit2, hit1, hit0)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); } } // namespace diff --git a/icing/query/advanced_query_parser/function.cc b/icing/query/advanced_query_parser/function.cc index e7938db..0ff160b 100644 --- a/icing/query/advanced_query_parser/function.cc +++ b/icing/query/advanced_query_parser/function.cc @@ -13,8 +13,15 @@ // limitations under the License. #include "icing/query/advanced_query_parser/function.h" +#include <string> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/query/advanced_query_parser/param.h" +#include "icing/query/advanced_query_parser/pending-value.h" #include "icing/util/status-macros.h" namespace icing { @@ -73,5 +80,20 @@ libtextclassifier3::StatusOr<PendingValue> Function::Eval( return eval_(std::move(args)); } +libtextclassifier3::StatusOr<DataType> Function::get_param_type(int i) const { + if (i < 0 || params_.empty()) { + return absl_ports::OutOfRangeError("Invalid argument index."); + } + const Param* parm; + if (i < params_.size()) { + parm = ¶ms_.at(i); + } else if (params_.back().cardinality == Cardinality::kVariable) { + parm = ¶ms_.back(); + } else { + return absl_ports::OutOfRangeError("Invalid argument index."); + } + return parm->data_type; +} + } // namespace lib -} 
// namespace icing
\ No newline at end of file +} // namespace icing diff --git a/icing/query/advanced_query_parser/function.h b/icing/query/advanced_query_parser/function.h index 3514878..08cc7e8 100644 --- a/icing/query/advanced_query_parser/function.h +++ b/icing/query/advanced_query_parser/function.h @@ -46,6 +46,8 @@ class Function { libtextclassifier3::StatusOr<PendingValue> Eval( std::vector<PendingValue>&& args) const; + libtextclassifier3::StatusOr<DataType> get_param_type(int i) const; + private: Function(DataType return_type, std::string name, std::vector<Param> params, EvalFunction eval) diff --git a/icing/query/advanced_query_parser/param.h b/icing/query/advanced_query_parser/param.h index 69c46be..9ea1915 100644 --- a/icing/query/advanced_query_parser/param.h +++ b/icing/query/advanced_query_parser/param.h @@ -35,13 +35,19 @@ struct Param { libtextclassifier3::Status Matches(PendingValue& arg) const { bool matches = arg.data_type() == data_type; - // Values of type kText could also potentially be valid kLong values. If - // we're expecting a kLong and we have a kText, try to parse it as a kLong. + // Values of type kText could also potentially be valid kLong or kDouble + // values. If we're expecting a kLong or kDouble and we have a kText, try to + // parse it as what we expect. if (!matches && data_type == DataType::kLong && arg.data_type() == DataType::kText) { ICING_RETURN_IF_ERROR(arg.ParseInt()); matches = true; } + if (!matches && data_type == DataType::kDouble && + arg.data_type() == DataType::kText) { + ICING_RETURN_IF_ERROR(arg.ParseDouble()); + matches = true; + } return matches ? 
libtextclassifier3::Status::OK : absl_ports::InvalidArgumentError( "Provided arg doesn't match required param type."); diff --git a/icing/query/advanced_query_parser/pending-value.cc b/icing/query/advanced_query_parser/pending-value.cc index 67bdc3a..a3f95d9 100644 --- a/icing/query/advanced_query_parser/pending-value.cc +++ b/icing/query/advanced_query_parser/pending-value.cc @@ -13,7 +13,11 @@ // limitations under the License. #include "icing/query/advanced_query_parser/pending-value.h" +#include <cstdlib> + +#include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" namespace icing { namespace lib { @@ -40,5 +44,27 @@ libtextclassifier3::Status PendingValue::ParseInt() { return libtextclassifier3::Status::OK; } +libtextclassifier3::Status PendingValue::ParseDouble() { + if (data_type_ == DataType::kDouble) { + return libtextclassifier3::Status::OK; + } else if (data_type_ != DataType::kText) { + return absl_ports::InvalidArgumentError("Cannot parse value as double"); + } + if (query_term_.is_prefix_val) { + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + "Cannot use prefix operator '*' with numeric value: ", + query_term_.term)); + } + char* value_end; + double_val_ = std::strtod(query_term_.term.c_str(), &value_end); + if (value_end != query_term_.term.c_str() + query_term_.term.length()) { + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + "Unable to parse \"", query_term_.term, "\" as double.")); + } + data_type_ = DataType::kDouble; + query_term_ = {/*term=*/"", /*raw_term=*/"", /*is_prefix_val=*/false}; + return libtextclassifier3::Status::OK; +} + } // namespace lib } // namespace icing diff --git a/icing/query/advanced_query_parser/pending-value.h b/icing/query/advanced_query_parser/pending-value.h index 1a6717e..34912f3 100644 --- a/icing/query/advanced_query_parser/pending-value.h +++ b/icing/query/advanced_query_parser/pending-value.h @@ 
-14,12 +14,16 @@ #ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_PENDING_VALUE_H_ #define ICING_QUERY_ADVANCED_QUERY_PARSER_PENDING_VALUE_H_ +#include <cstdint> #include <memory> #include <string> +#include <string_view> #include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/util/status-macros.h" @@ -30,10 +34,14 @@ namespace lib { enum class DataType { kNone, kLong, + kDouble, kText, kString, kStringList, kDocumentIterator, + // TODO(b/326656531): Instead of creating a vector index type, consider + // changing it to vector type so that the data is the vector directly. + kVectorIndex, }; struct QueryTerm { @@ -52,6 +60,10 @@ struct PendingValue { return PendingValue(std::move(text), DataType::kText); } + static PendingValue CreateVectorIndexPendingValue(int64_t vector_index) { + return PendingValue(vector_index, DataType::kVectorIndex); + } + PendingValue() : data_type_(DataType::kNone) {} explicit PendingValue(std::unique_ptr<DocHitInfoIterator> iterator) @@ -111,6 +123,16 @@ struct PendingValue { return long_val_; } + libtextclassifier3::StatusOr<double> double_val() { + ICING_RETURN_IF_ERROR(ParseDouble()); + return double_val_; + } + + libtextclassifier3::StatusOr<int64_t> vector_index_val() const { + ICING_RETURN_IF_ERROR(CheckDataType(DataType::kVectorIndex)); + return long_val_; + } + // Attempts to interpret the value as an int. A pending value can be parsed as // an int under two circumstances: // 1. It holds a kText value which can be parsed to an int @@ -122,12 +144,26 @@ struct PendingValue { // - INVALID_ARGUMENT if the value could not be parsed as a long libtextclassifier3::Status ParseInt(); + // Attempts to interpret the value as a double. 
A pending value can be parsed + // as a double under two circumstances: + // 1. It holds a kText value which can be parsed to a double + // 2. It holds a kDouble value + // If #1 is true, then the parsed value will be stored in double_val_ and + // data_type will be updated to kDouble. + // RETURNS: + // - OK, if able to successfully parse the value into a double + // - INVALID_ARGUMENT if the value could not be parsed as a double + libtextclassifier3::Status ParseDouble(); + DataType data_type() const { return data_type_; } private: explicit PendingValue(QueryTerm query_term, DataType data_type) : query_term_(std::move(query_term)), data_type_(data_type) {} + explicit PendingValue(int64_t long_val, DataType data_type) + : long_val_(long_val), data_type_(data_type) {} + libtextclassifier3::Status CheckDataType(DataType required_data_type) const { if (data_type_ == required_data_type) { return libtextclassifier3::Status::OK; @@ -151,6 +187,9 @@ struct PendingValue { // long_val_ will be populated when data_type_ is kLong - after a successful // call to ParseInt. int64_t long_val_; + // double_val_ will be populated when data_type_ is kDouble - after a + // successful call to ParseDouble. 
+ double double_val_; DataType data_type_; }; diff --git a/icing/query/advanced_query_parser/query-visitor.cc b/icing/query/advanced_query_parser/query-visitor.cc index 31da959..1ac52c5 100644 --- a/icing/query/advanced_query_parser/query-visitor.cc +++ b/icing/query/advanced_query_parser/query-visitor.cc @@ -16,20 +16,26 @@ #include <algorithm> #include <cstdint> -#include <cstdlib> #include <iterator> #include <limits> #include <memory> #include <set> #include <string> +#include <string_view> +#include <unordered_map> #include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/absl_ports/str_join.h" +#include "icing/index/embed/doc-hit-info-iterator-embedding.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/iterator/doc-hit-info-iterator-all-document-id.h" #include "icing/index/iterator/doc-hit-info-iterator-and.h" +#include "icing/index/iterator/doc-hit-info-iterator-filter.h" #include "icing/index/iterator/doc-hit-info-iterator-none.h" #include "icing/index/iterator/doc-hit-info-iterator-not.h" #include "icing/index/iterator/doc-hit-info-iterator-or.h" @@ -37,17 +43,23 @@ #include "icing/index/iterator/doc-hit-info-iterator-property-in-schema.h" #include "icing/index/iterator/doc-hit-info-iterator-section-restrict.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/iterator/section-restrict-data.h" #include "icing/index/property-existence-indexing-handler.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/query/advanced_query_parser/function.h" #include "icing/query/advanced_query_parser/lexer.h" #include "icing/query/advanced_query_parser/param.h" #include "icing/query/advanced_query_parser/parser.h" #include "icing/query/advanced_query_parser/pending-value.h" 
#include "icing/query/advanced_query_parser/util/string-util.h" #include "icing/query/query-features.h" +#include "icing/query/query-results.h" #include "icing/schema/property-util.h" +#include "icing/schema/schema-store.h" #include "icing/schema/section.h" #include "icing/tokenization/token.h" #include "icing/tokenization/tokenizer.h" +#include "icing/util/embedding-util.h" #include "icing/util/status-macros.h" namespace icing { @@ -241,6 +253,34 @@ void QueryVisitor::RegisterFunctions() { .ValueOrDie(); registered_functions_.insert( {has_property_function.name(), std::move(has_property_function)}); + + // vector_index getSearchSpecEmbedding(long); + auto get_search_spec_embedding = [](std::vector<PendingValue>&& args) { + return PendingValue::CreateVectorIndexPendingValue( + args.at(0).long_val().ValueOrDie()); + }; + Function get_search_spec_embedding_function = + Function::Create(DataType::kVectorIndex, "getSearchSpecEmbedding", + {Param(DataType::kLong)}, + std::move(get_search_spec_embedding)) + .ValueOrDie(); + registered_functions_.insert({get_search_spec_embedding_function.name(), + std::move(get_search_spec_embedding_function)}); + + // DocHitInfoIterator semanticSearch(vector_index, double, double, string); + auto semantic_search = [this](std::vector<PendingValue>&& args) { + return this->SemanticSearchFunction(std::move(args)); + }; + Function semantic_search_function = + Function::Create(DataType::kDocumentIterator, "semanticSearch", + {Param(DataType::kVectorIndex), + Param(DataType::kDouble, Cardinality::kOptional), + Param(DataType::kDouble, Cardinality::kOptional), + Param(DataType::kString, Cardinality::kOptional)}, + std::move(semantic_search)) + .ValueOrDie(); + registered_functions_.insert( + {semantic_search_function.name(), std::move(semantic_search_function)}); } libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction( @@ -278,10 +318,11 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction( 
document_store_.last_added_document_id()); } else { QueryVisitor query_visitor( - &index_, &numeric_index_, &document_store_, &schema_store_, - &normalizer_, &tokenizer_, query->raw_term, filter_options_, - match_type_, needs_term_frequency_info_, pending_property_restricts_, - processing_not_, current_time_ms_); + &index_, &numeric_index_, &embedding_index_, &document_store_, + &schema_store_, &normalizer_, &tokenizer_, query->raw_term, + embedding_query_vectors_, filter_options_, match_type_, + embedding_query_metric_type_, needs_term_frequency_info_, + pending_property_restricts_, processing_not_, current_time_ms_); tree_root->Accept(&query_visitor); ICING_ASSIGN_OR_RETURN(query_result, std::move(query_visitor).ConsumeResults()); @@ -359,6 +400,57 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::HasPropertyFunction( return PendingValue(std::move(property_in_document_iterator)); } +libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SemanticSearchFunction( + std::vector<PendingValue>&& args) { + features_.insert(kEmbeddingSearchFeature); + + int64_t vector_index = args.at(0).vector_index_val().ValueOrDie(); + if (embedding_query_vectors_ == nullptr || vector_index < 0 || + vector_index >= embedding_query_vectors_->size()) { + return absl_ports::InvalidArgumentError("Got invalid vector search index!"); + } + + // Handle default values for the optional arguments. 
+ double low = -std::numeric_limits<double>::infinity(); + double high = std::numeric_limits<double>::infinity(); + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type = + embedding_query_metric_type_; + if (args.size() >= 2) { + low = args.at(1).double_val().ValueOrDie(); + } + if (args.size() >= 3) { + high = args.at(2).double_val().ValueOrDie(); + } + if (args.size() >= 4) { + const std::string& metric = args.at(3).string_val().ValueOrDie()->term; + ICING_ASSIGN_OR_RETURN( + metric_type, + embedding_util::GetEmbeddingQueryMetricTypeFromName(metric)); + } + + // Create SectionRestrictData for section restriction. + std::unique_ptr<SectionRestrictData> section_restrict_data = nullptr; + if (pending_property_restricts_.has_active_property_restricts()) { + std::unordered_map<std::string, std::set<std::string>> + type_property_filters; + type_property_filters[std::string(SchemaStore::kSchemaTypeWildcard)] = + pending_property_restricts_.active_property_restricts(); + section_restrict_data = std::make_unique<SectionRestrictData>( + &document_store_, &schema_store_, current_time_ms_, + type_property_filters); + } + + // Create and return iterator. + EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map = + &embedding_query_results_.result_scores[vector_index][metric_type]; + ICING_ASSIGN_OR_RETURN(std::unique_ptr<DocHitInfoIterator> iterator, + DocHitInfoIteratorEmbedding::Create( + &embedding_query_vectors_->at(vector_index), + std::move(section_restrict_data), metric_type, low, + high, score_map, &embedding_index_)); + return PendingValue(std::move(iterator)); +} + libtextclassifier3::StatusOr<int64_t> QueryVisitor::PopPendingIntValue() { if (pending_values_.empty()) { return absl_ports::InvalidArgumentError("Unable to retrieve int value."); @@ -435,8 +527,8 @@ QueryVisitor::PopPendingIterator() { // raw_text, then all of raw_text must correspond to this token. 
raw_token = raw_text; } else { - ICING_ASSIGN_OR_RETURN(raw_token, string_util::FindEscapedToken( - raw_text, token.text)); + ICING_ASSIGN_OR_RETURN( + raw_token, string_util::FindEscapedToken(raw_text, token.text)); } normalized_term = normalizer_.NormalizeTerm(token.text); QueryTerm term_value{std::move(normalized_term), raw_token, @@ -570,15 +662,14 @@ libtextclassifier3::Status QueryVisitor::ProcessNegationOperator( "Visit unary operator child didn't correctly add pending values."); } - // 3. We want to preserve the original text of the integer value, append our - // minus and *then* parse as an int. - ICING_ASSIGN_OR_RETURN(QueryTerm int_text_val, PopPendingTextValue()); - int_text_val.term = absl_ports::StrCat("-", int_text_val.term); + // 3. We want to preserve the original text of the numeric value, append our + // minus to the text. It will be parsed as either an int or a double later. + ICING_ASSIGN_OR_RETURN(QueryTerm numeric_text_val, PopPendingTextValue()); + numeric_text_val.term = absl_ports::StrCat("-", numeric_text_val.term); PendingValue pending_value = - PendingValue::CreateTextPendingValue(std::move(int_text_val)); - ICING_RETURN_IF_ERROR(pending_value.long_val()); + PendingValue::CreateTextPendingValue(std::move(numeric_text_val)); - // We've parsed our integer value successfully. Pop our placeholder, push it + // We've parsed our numeric value successfully. Pop our placeholder, push it // on to the stack and return successfully. 
if (!pending_values_.top().is_placeholder()) { return absl_ports::InvalidArgumentError( @@ -768,7 +859,8 @@ void QueryVisitor::VisitMember(const MemberNode* node) { end = text_val.raw_term.data() + text_val.raw_term.length(); } else { start = std::min(start, text_val.raw_term.data()); - end = std::max(end, text_val.raw_term.data() + text_val.raw_term.length()); + end = std::max(end, + text_val.raw_term.data() + text_val.raw_term.length()); } members.push_back(std::move(text_val.term)); } @@ -800,13 +892,26 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) { "Function ", node->function_name()->value(), " is not supported.")); return; } + const Function& function = itr->second; // 2. Put in a placeholder PendingValue pending_values_.push(PendingValue()); // 3. Visit the children. - for (const std::unique_ptr<Node>& arg : node->args()) { + expecting_numeric_arg_ = true; + for (int i = 0; i < node->args().size(); ++i) { + const std::unique_ptr<Node>& arg = node->args()[i]; + libtextclassifier3::StatusOr<DataType> arg_type_or = + function.get_param_type(i); + bool current_level_expecting_numeric_arg = expecting_numeric_arg_; + // If arg_type_or has an error, we should ignore it for now, since + // function.Eval should do the type check and return better error messages. 
+ if (arg_type_or.ok() && (arg_type_or.ValueOrDie() == DataType::kLong || + arg_type_or.ValueOrDie() == DataType::kDouble)) { + expecting_numeric_arg_ = true; + } arg->Accept(this); + expecting_numeric_arg_ = current_level_expecting_numeric_arg; if (has_pending_error()) { return; } @@ -819,7 +924,6 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) { pending_values_.pop(); } std::reverse(args.begin(), args.end()); - const Function& function = itr->second; auto eval_result = function.Eval(std::move(args)); if (!eval_result.ok()) { pending_error_ = std::move(eval_result).status(); @@ -955,6 +1059,7 @@ libtextclassifier3::StatusOr<QueryResults> QueryVisitor::ConsumeResults() && { results.root_iterator = std::move(iterator_or).ValueOrDie(); results.query_term_iterators = std::move(query_term_iterators_); results.query_terms = std::move(property_query_terms_map_); + results.embedding_query_results = std::move(embedding_query_results_); results.features_in_use = std::move(features_); return results; } diff --git a/icing/query/advanced_query_parser/query-visitor.h b/icing/query/advanced_query_parser/query-visitor.h index d090b3c..17149f5 100644 --- a/icing/query/advanced_query_parser/query-visitor.h +++ b/icing/query/advanced_query_parser/query-visitor.h @@ -17,13 +17,19 @@ #include <cstdint> #include <memory> +#include <set> #include <stack> #include <string> +#include <string_view> +#include <unordered_map> #include <unordered_set> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-filter.h" #include "icing/index/iterator/doc-hit-info-iterator.h" @@ -33,10 +39,12 @@ #include "icing/query/advanced_query_parser/pending-value.h" #include "icing/query/query-features.h" 
#include "icing/query/query-results.h" +#include "icing/query/query-terms.h" #include "icing/schema/schema-store.h" #include "icing/store/document-store.h" #include "icing/tokenization/tokenizer.h" #include "icing/transform/normalizer.h" +#include <google/protobuf/repeated_field.h> namespace icing { namespace lib { @@ -45,19 +53,23 @@ namespace lib { // the parser. class QueryVisitor : public AbstractSyntaxTreeVisitor { public: - explicit QueryVisitor(Index* index, - const NumericIndex<int64_t>* numeric_index, - const DocumentStore* document_store, - const SchemaStore* schema_store, - const Normalizer* normalizer, - const Tokenizer* tokenizer, - std::string_view raw_query_text, - DocHitInfoIteratorFilter::Options filter_options, - TermMatchType::Code match_type, - bool needs_term_frequency_info, int64_t current_time_ms) - : QueryVisitor(index, numeric_index, document_store, schema_store, - normalizer, tokenizer, raw_query_text, filter_options, - match_type, needs_term_frequency_info, + explicit QueryVisitor( + Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, + const DocumentStore* document_store, const SchemaStore* schema_store, + const Normalizer* normalizer, const Tokenizer* tokenizer, + std::string_view raw_query_text, + const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>* + embedding_query_vectors, + DocHitInfoIteratorFilter::Options filter_options, + TermMatchType::Code match_type, + SearchSpecProto::EmbeddingQueryMetricType::Code + embedding_query_metric_type, + bool needs_term_frequency_info, int64_t current_time_ms) + : QueryVisitor(index, numeric_index, embedding_index, document_store, + schema_store, normalizer, tokenizer, raw_query_text, + embedding_query_vectors, filter_options, match_type, + embedding_query_metric_type, needs_term_frequency_info, PendingPropertyRestricts(), /*processing_not=*/false, current_time_ms) {} @@ -106,22 +118,31 @@ class QueryVisitor : public 
AbstractSyntaxTreeVisitor { explicit QueryVisitor( Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const DocumentStore* document_store, const SchemaStore* schema_store, const Normalizer* normalizer, const Tokenizer* tokenizer, std::string_view raw_query_text, + const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>* + embedding_query_vectors, DocHitInfoIteratorFilter::Options filter_options, - TermMatchType::Code match_type, bool needs_term_frequency_info, + TermMatchType::Code match_type, + SearchSpecProto::EmbeddingQueryMetricType::Code + embedding_query_metric_type, + bool needs_term_frequency_info, PendingPropertyRestricts pending_property_restricts, bool processing_not, int64_t current_time_ms) : index_(*index), numeric_index_(*numeric_index), + embedding_index_(*embedding_index), document_store_(*document_store), schema_store_(*schema_store), normalizer_(*normalizer), tokenizer_(*tokenizer), raw_query_text_(raw_query_text), + embedding_query_vectors_(embedding_query_vectors), filter_options_(std::move(filter_options)), match_type_(match_type), + embedding_query_metric_type_(embedding_query_metric_type), needs_term_frequency_info_(needs_term_frequency_info), pending_property_restricts_(std::move(pending_property_restricts)), processing_not_(processing_not), @@ -264,6 +285,22 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor { libtextclassifier3::StatusOr<PendingValue> HasPropertyFunction( std::vector<PendingValue>&& args); + // Implementation of the semanticSearch(vector, low, high, metric) custom + // function. This function is used for supporting vector search with a + // syntax like `semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, "COSINE")`. 
+ // + // low, high, metric are optional parameters: + // - low is default to negative infinity + // - high is default to positive infinity + // - metric is default to the metric specified in SearchSpec + // + // Returns: + // - a Pending Value of type DocHitIterator that returns all documents with + // an embedding vector that has a score within [low, high]. + // - any errors returned by Lexer::ExtractTokens + libtextclassifier3::StatusOr<PendingValue> SemanticSearchFunction( + std::vector<PendingValue>&& args); + // Handles a NaryOperatorNode where the operator is HAS (':') and pushes an // iterator with the proper section filter applied. If the current property // restriction represented by pending_property_restricts and the first child @@ -292,19 +329,26 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor { SectionRestrictQueryTermsMap property_query_terms_map_; QueryTermIteratorsMap query_term_iterators_; + + EmbeddingQueryResults embedding_query_results_; + // Set of features invoked in the query. std::unordered_set<Feature> features_; Index& index_; // Does not own! const NumericIndex<int64_t>& numeric_index_; // Does not own! + const EmbeddingIndex& embedding_index_; // Does not own! const DocumentStore& document_store_; // Does not own! const SchemaStore& schema_store_; // Does not own! const Normalizer& normalizer_; // Does not own! const Tokenizer& tokenizer_; // Does not own! std::string_view raw_query_text_; + const google::protobuf::RepeatedPtrField<PropertyProto::VectorProto>* + embedding_query_vectors_; // Nullable, does not own! DocHitInfoIteratorFilter::Options filter_options_; TermMatchType::Code match_type_; + SearchSpecProto::EmbeddingQueryMetricType::Code embedding_query_metric_type_; // Whether or not term_frequency information is needed. This affects: // - how DocHitInfoIteratorTerms are constructed // - whether the QueryTermIteratorsMap is populated in the QueryResults. 
diff --git a/icing/query/advanced_query_parser/query-visitor_test.cc b/icing/query/advanced_query_parser/query-visitor_test.cc index 9455baa..c5ba866 100644 --- a/icing/query/advanced_query_parser/query-visitor_test.cc +++ b/icing/query/advanced_query_parser/query-visitor_test.cc @@ -15,6 +15,7 @@ #include "icing/query/advanced_query_parser/query-visitor.h" #include <cstdint> +#include <initializer_list> #include <limits> #include <memory> #include <string> @@ -31,6 +32,7 @@ #include "icing/document-builder.h" #include "icing/file/filesystem.h" #include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/hit/hit.h" #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-filter.h" @@ -42,6 +44,7 @@ #include "icing/jni/jni-cache.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/portable/platform.h" +#include "icing/proto/search.pb.h" #include "icing/query/advanced_query_parser/abstract-syntax-tree.h" #include "icing/query/advanced_query_parser/lexer.h" #include "icing/query/advanced_query_parser/parser.h" @@ -54,6 +57,7 @@ #include "icing/store/document-store.h" #include "icing/store/namespace-id.h" #include "icing/testing/common-matchers.h" +#include "icing/testing/embedding-test-utils.h" #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/jni-test-helpers.h" #include "icing/testing/test-data.h" @@ -67,16 +71,22 @@ #include "icing/util/clock.h" #include "icing/util/status-macros.h" #include "unicode/uloc.h" +#include <google/protobuf/repeated_field.h> namespace icing { namespace lib { namespace { +using ::testing::DoubleNear; using ::testing::ElementsAre; using ::testing::IsEmpty; +using ::testing::IsNull; +using ::testing::Pointee; using ::testing::UnorderedElementsAre; +constexpr float kEps = 0.000001; + constexpr DocumentId kDocumentId0 = 0; constexpr DocumentId kDocumentId1 = 1; constexpr DocumentId kDocumentId2 = 2; @@ -85,6 +95,18 
@@ constexpr SectionId kSectionId0 = 0; constexpr SectionId kSectionId1 = 1; constexpr SectionId kSectionId2 = 2; +constexpr SearchSpecProto::EmbeddingQueryMetricType::Code + EMBEDDING_METRIC_UNKNOWN = + SearchSpecProto::EmbeddingQueryMetricType::UNKNOWN; +constexpr SearchSpecProto::EmbeddingQueryMetricType::Code + EMBEDDING_METRIC_COSINE = SearchSpecProto::EmbeddingQueryMetricType::COSINE; +constexpr SearchSpecProto::EmbeddingQueryMetricType::Code + EMBEDDING_METRIC_DOT_PRODUCT = + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; +constexpr SearchSpecProto::EmbeddingQueryMetricType::Code + EMBEDDING_METRIC_EUCLIDEAN = + SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN; + template <typename T, typename U> std::vector<T> ExtractKeys(const std::unordered_map<T, U>& map) { std::vector<T> keys; @@ -106,6 +128,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> { test_dir_ = GetTestTempDir() + "/icing"; index_dir_ = test_dir_ + "/index"; numeric_index_dir_ = test_dir_ + "/numeric_index"; + embedding_index_dir_ = test_dir_ + "/embedding_index"; store_dir_ = test_dir_ + "/store"; schema_store_dir_ = test_dir_ + "/schema_store"; filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); @@ -154,6 +177,10 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> { numeric_index_, DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_)); + ICING_ASSERT_OK_AND_ASSIGN( + embedding_index_, + EmbeddingIndex::Create(&filesystem_, embedding_index_dir_)); + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( /*max_term_byte_size=*/1000)); @@ -219,6 +246,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> { std::string test_dir_; std::string index_dir_; std::string numeric_index_dir_; + std::string embedding_index_dir_; std::string schema_store_dir_; std::string store_dir_; Clock clock_; @@ -226,6 +254,7 @@ class QueryVisitorTest : public ::testing::TestWithParam<QueryType> { 
std::unique_ptr<DocumentStore> document_store_; std::unique_ptr<Index> index_; std::unique_ptr<DummyNumericIndex<int64_t>> numeric_index_; + std::unique_ptr<EmbeddingIndex> embedding_index_; std::unique_ptr<Normalizer> normalizer_; std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<Tokenizer> tokenizer_; @@ -252,9 +281,11 @@ TEST_P(QueryVisitorTest, SimpleLessThan) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -295,9 +326,11 @@ TEST_P(QueryVisitorTest, SimpleLessThanEq) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -338,9 +371,11 @@ TEST_P(QueryVisitorTest, SimpleEqual) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - 
index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -381,9 +416,11 @@ TEST_P(QueryVisitorTest, SimpleGreaterThanEq) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -424,9 +461,11 @@ TEST_P(QueryVisitorTest, SimpleGreaterThan) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); 
root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -468,9 +507,11 @@ TEST_P(QueryVisitorTest, IntMinLessThanEqual) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -512,9 +553,11 @@ TEST_P(QueryVisitorTest, IntMaxGreaterThanEqual) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -557,9 +600,11 @@ TEST_P(QueryVisitorTest, NestedPropertyLessThan) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), 
schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -585,9 +630,11 @@ TEST_P(QueryVisitorTest, IntParsingError) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -599,9 +646,11 @@ TEST_P(QueryVisitorTest, NotEqualsUnsupported) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -647,9 +696,11 @@ TEST_P(QueryVisitorTest, LessThanTooManyOperandsInvalid) { args.push_back(std::move(extra_value_node)); auto root_node = 
std::make_unique<NaryOperatorNode>("<", std::move(args)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -674,9 +725,11 @@ TEST_P(QueryVisitorTest, LessThanTooFewOperandsInvalid) { args.push_back(std::move(member_node)); auto root_node = std::make_unique<NaryOperatorNode>("<", std::move(args)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -705,9 +758,11 @@ TEST_P(QueryVisitorTest, LessThanNonExistentPropertyNotFound) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), 
TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -727,9 +782,11 @@ TEST_P(QueryVisitorTest, LessThanNonExistentPropertyNotFound) { TEST_P(QueryVisitorTest, NeverVisitedReturnsInvalid) { QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), "", + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), "", /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); @@ -756,9 +813,11 @@ TEST_P(QueryVisitorTest, IntMinLessThanInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -786,9 +845,11 @@ TEST_P(QueryVisitorTest, IntMaxGreaterThanInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), 
normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -801,9 +862,11 @@ TEST_P(QueryVisitorTest, NumericComparisonPropertyStringIsInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -865,9 +928,11 @@ TEST_P(QueryVisitorTest, NumericComparatorDoesntAffectLaterTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); 
ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -908,9 +973,11 @@ TEST_P(QueryVisitorTest, SingleTermTermFrequencyEnabled) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -960,9 +1027,11 @@ TEST_P(QueryVisitorTest, SingleTermTermFrequencyDisabled) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/false, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1012,9 +1081,11 @@ TEST_P(QueryVisitorTest, SingleTermPrefix) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), 
normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1028,9 +1099,11 @@ TEST_P(QueryVisitorTest, SingleTermPrefix) { query = CreateQuery("fo*"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -1048,9 +1121,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterPropertyReturnsInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -1062,9 +1137,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterNumericValueReturnsInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> 
root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -1076,9 +1153,11 @@ TEST_P(QueryVisitorTest, PrefixOperatorAfterPropertyRestrictReturnsInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -1114,9 +1193,11 @@ TEST_P(QueryVisitorTest, SegmentationWithPrefix) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT, + 
EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1137,9 +1218,11 @@ TEST_P(QueryVisitorTest, SegmentationWithPrefix) { query = CreateQuery("ba?fo*"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -1174,9 +1257,11 @@ TEST_P(QueryVisitorTest, SingleVerbatimTerm) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1221,9 +1306,11 @@ TEST_P(QueryVisitorTest, SingleVerbatimTermPrefix) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), 
tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_EXACT, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1274,9 +1361,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingQuote) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1326,9 +1415,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingEscape) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1380,9 +1471,11 
@@ TEST_P(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1407,9 +1500,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) { query = CreateQuery(R"(("foobar\\y"))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -1462,9 +1557,11 @@ TEST_P(QueryVisitorTest, VerbatimTermNewLine) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, 
/*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1488,9 +1585,11 @@ TEST_P(QueryVisitorTest, VerbatimTermNewLine) { query = CreateQuery(R"(("foobar\\n"))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -1537,9 +1636,11 @@ TEST_P(QueryVisitorTest, VerbatimTermEscapingComplex) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1596,9 +1697,11 @@ TEST_P(QueryVisitorTest, SingleMinusTerm) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - 
index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1650,9 +1753,11 @@ TEST_P(QueryVisitorTest, SingleNotTerm) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1705,9 +1810,11 @@ TEST_P(QueryVisitorTest, NestedNotTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); 
root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1774,9 +1881,11 @@ TEST_P(QueryVisitorTest, DeeplyNestedNotTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1813,9 +1922,11 @@ TEST_P(QueryVisitorTest, ImplicitAndTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1856,9 +1967,11 @@ TEST_P(QueryVisitorTest, ExplicitAndTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), 
schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1899,9 +2012,11 @@ TEST_P(QueryVisitorTest, OrTerms) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1944,9 +2059,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -1969,9 +2086,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) { query = CreateQuery("bar OR baz foo"); ICING_ASSERT_OK_AND_ASSIGN(root_node, 
ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -1993,9 +2112,11 @@ TEST_P(QueryVisitorTest, AndOrTermPrecedence) { query = CreateQuery("(bar OR baz) foo"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_three( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_three); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2055,9 +2176,11 @@ TEST_P(QueryVisitorTest, AndOrNotPrecedence) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, 
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2075,9 +2198,11 @@ TEST_P(QueryVisitorTest, AndOrNotPrecedence) { query = CreateQuery("foo NOT (bar OR baz)"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2140,9 +2265,11 @@ TEST_P(QueryVisitorTest, PropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2216,9 +2343,11 @@ TEST_F(QueryVisitorTest, MultiPropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + 
index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2259,9 +2388,11 @@ TEST_P(QueryVisitorTest, PropertyFilterStringIsInvalid) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -2315,9 +2446,11 @@ TEST_P(QueryVisitorTest, PropertyFilterNonNormalized) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2386,9 +2519,11 @@ 
TEST_P(QueryVisitorTest, PropertyFilterWithGrouping) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2453,9 +2588,11 @@ TEST_P(QueryVisitorTest, ValidNestedPropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2474,9 +2611,11 @@ TEST_P(QueryVisitorTest, ValidNestedPropertyFilter) { /*property_restrict=*/"prop1"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, 
/*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2540,9 +2679,11 @@ TEST_P(QueryVisitorTest, InvalidNestedPropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2561,9 +2702,11 @@ TEST_P(QueryVisitorTest, InvalidNestedPropertyFilter) { /*property_restrict=*/"prop1"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2627,9 +2770,11 @@ TEST_P(QueryVisitorTest, NotWithPropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - 
index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2648,9 +2793,11 @@ TEST_P(QueryVisitorTest, NotWithPropertyFilter) { "NOT ", CreateQuery("(foo OR bar)", /*property_restrict=*/"prop1")); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2728,9 +2875,11 @@ TEST_P(QueryVisitorTest, PropertyFilterWithNot) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, 
clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2755,9 +2904,11 @@ TEST_P(QueryVisitorTest, PropertyFilterWithNot) { query = CreateQuery("(NOT foo OR bar)", /*property_restrict=*/"prop1"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -2837,9 +2988,11 @@ TEST_P(QueryVisitorTest, SegmentationTest) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -2957,9 +3110,11 @@ TEST_P(QueryVisitorTest, PropertyRestrictsPopCorrectly) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + 
index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3074,9 +3229,11 @@ TEST_P(QueryVisitorTest, UnsatisfiablePropertyRestrictsPopCorrectly) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3098,9 +3255,11 @@ TEST_F(QueryVisitorTest, UnsupportedFunctionReturnsInvalidArgument) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3112,9 
+3271,11 @@ TEST_F(QueryVisitorTest, SearchFunctionTooFewArgumentsReturnsInvalidArgument) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3126,9 +3287,11 @@ TEST_F(QueryVisitorTest, SearchFunctionTooManyArgumentsReturnsInvalidArgument) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3142,9 +3305,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), 
query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3154,9 +3319,11 @@ TEST_F(QueryVisitorTest, query = R"(search(createList("subject")))"; ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); EXPECT_THAT(std::move(query_visitor_two).ConsumeResults(), @@ -3170,9 +3337,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3182,9 +3351,11 @@ TEST_F(QueryVisitorTest, query = R"(search("foo", 7))"; ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), 
document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); EXPECT_THAT(std::move(query_visitor_two).ConsumeResults(), @@ -3197,9 +3368,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3260,9 +3433,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(level_two_query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); 
root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3284,9 +3459,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) { R"(", createList("prop1")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), - level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3308,9 +3485,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedFunctionCalls) { R"(", createList("prop1")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_four_query)); QueryVisitor query_visitor_three( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), - level_four_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_four_query, /*embedding_query_vectors=*/nullptr, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_three); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3430,9 +3609,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, 
ParseQueryHelper(level_one_query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_one_query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_one_query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3462,9 +3643,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) { R"(", createList("prop6", "prop0", "prop4", "prop2")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_two_query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3488,9 +3671,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsNarrowing) { R"(", createList("prop0", "prop6")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query)); QueryVisitor query_visitor_three( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), - level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + index_.get(), numeric_index_.get(), embedding_index_.get(), 
+ document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_three); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3610,9 +3795,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(level_one_query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_one_query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_one_query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3634,9 +3821,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) { R"(", createList("prop6", "prop0", "prop4", "prop2")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_two_query)); QueryVisitor query_visitor_two( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), level_two_query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_two_query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_two); 
ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3659,9 +3848,11 @@ TEST_F(QueryVisitorTest, SearchFunctionNestedPropertyRestrictsExpanding) { R"( "prop0", "prop6", "prop4", "prop7")))"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(level_three_query)); QueryVisitor query_visitor_three( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), - level_three_query, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), level_three_query, /*embedding_query_vectors=*/nullptr, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor_three); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -3685,9 +3876,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3701,9 +3894,11 @@ TEST_F( ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), 
embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3717,9 +3912,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3732,9 +3929,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3786,9 +3985,11 @@ TEST_P(QueryVisitorTest, PropertyDefinedFunctionReturnsMatchingDocuments) { 
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3839,9 +4040,11 @@ TEST_P(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3890,9 +4093,11 @@ TEST_P(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, 
/*needs_term_frequency_info_=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -3910,9 +4115,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3925,9 +4132,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3941,9 +4150,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), 
schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -3956,9 +4167,11 @@ TEST_F(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); EXPECT_THAT(std::move(query_visitor).ConsumeResults(), @@ -4015,9 +4228,11 @@ TEST_P(QueryVisitorTest, HasPropertyFunctionReturnsMatchingDocuments) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor1( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor1); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -4033,9 +4248,11 @@ TEST_P(QueryVisitorTest, HasPropertyFunctionReturnsMatchingDocuments) { query = CreateQuery("bar OR NOT 
hasProperty(\"price\")"); ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); QueryVisitor query_visitor2( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor2); ICING_ASSERT_OK_AND_ASSIGN(query_results, @@ -4088,9 +4305,11 @@ TEST_P(QueryVisitorTest, ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, ParseQueryHelper(query)); QueryVisitor query_visitor( - index_.get(), numeric_index_.get(), document_store_.get(), - schema_store_.get(), normalizer_.get(), tokenizer_.get(), query, + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, /*embedding_query_vectors=*/nullptr, DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); root_node->Accept(&query_visitor); ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, @@ -4102,6 +4321,890 @@ TEST_P(QueryVisitorTest, EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty()); } +TEST_F(QueryVisitorTest, + SemanticSearchFunctionWithNoArgumentReturnsInvalidArgument) { + // Create two embedding queries. 
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + std::string query = "semanticSearch()"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, + SemanticSearchFunctionWithIncorrectArgumentTypeReturnsInvalidArgument) { + // Create two embedding queries. 
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + std::string query = "semanticSearch(0)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, + SemanticSearchFunctionWithExtraArgumentReturnsInvalidArgument) { + // Create two embedding queries. 
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + std::string query = + "semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, \"COSINE\", 0)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, + GetSearchSpecEmbeddingFunctionWithExtraArgumentReturnsInvalidArgument) { + // Create two embedding queries. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + // The embedding query index is invalid, since there are only 2 queries. 
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0, 1))"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, + SemanticSearchFunctionWithInvalidIndexReturnsInvalidArgument) { + // Create two embedding queries. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + // The embedding query index is invalid, since there are only 2 queries. + std::string query = "semanticSearch(getSearchSpecEmbedding(10))"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, + SemanticSearchFunctionWithInvalidMetricReturnsInvalidArgument) { + // Create two embedding queries. 
+ google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model1", {0.1, 0.2, 0.3}); + *embedding_query_vectors.Add() = CreateVector("my_model2", {-1, 2, -3, 4}); + + // The embedding query metric is invalid. + std::string query = + "semanticSearch(getSearchSpecEmbedding(0), -10, 10, \"UNKNOWN\")"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // Passing an unknown default metric type without overriding it in the query + // expression is also considered invalid. + query = "semanticSearch(getSearchSpecEmbedding(0), -10, 10)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_UNKNOWN, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + EXPECT_THAT(std::move(query_visitor2).ConsumeResults(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionSimpleLowerBound) { + // Index two embedding vectors. 
+ PropertyProto::VectorProto vector0 = + CreateVector("my_model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector1 = + CreateVector("my_model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), vector0)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId1), vector1)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query that has a semantic score of 1 with vector0 and + // -1 with vector1. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3}); + + // The query should match vector0 only. + std::string query = "semanticSearch(getSearchSpecEmbedding(0), 0.5)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(1, kEps)))); + + // The 
query should match both vector0 and vector1. + query = "semanticSearch(getSearchSpecEmbedding(0), -1.5)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId1, kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(1, kEps)))); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1), + Pointee(UnorderedElementsAre(DoubleNear(-1, kEps)))); + + // The query should match nothing, since there is no vector with a + // score >= 1.01. 
+ query = "semanticSearch(getSearchSpecEmbedding(0), 1.01)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor3( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor3); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor3).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty()); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionSimpleUpperBound) { + // Index two embedding vectors. + PropertyProto::VectorProto vector0 = + CreateVector("my_model", {0.1, 0.2, 0.3}); + PropertyProto::VectorProto vector1 = + CreateVector("my_model", {-0.1, -0.2, -0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), vector0)); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId1), vector1)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query that has a semantic score of 1 with vector0 and + // -1 with vector1. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3}); + + // The query should match vector1 only. 
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0), -100, 0.5)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId1)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1), + Pointee(UnorderedElementsAre(DoubleNear(-1, kEps)))); + + // The query should match both vector0 and vector1. 
+ query = "semanticSearch(getSearchSpecEmbedding(0), -100, 1.5)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId1, kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(1, kEps)))); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId1), + Pointee(UnorderedElementsAre(DoubleNear(-1, kEps)))); + + // The query should match nothing, since there is no vector with a + // score <= -1.01. 
+ query = "semanticSearch(getSearchSpecEmbedding(0), -100, -1.01)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor3( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor3); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor3).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), IsEmpty()); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionMetricOverride) { + // Index a embedding vector. + PropertyProto::VectorProto vector = CreateVector("my_model", {0.1, 0.2, 0.3}); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), vector)); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query that has: + // - a cosine semantic score of 1 + // - a dot product semantic score of 0.14 + // - a euclidean semantic score of 0 + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + *embedding_query_vectors.Add() = CreateVector("my_model", {0.1, 0.2, 0.3}); + + // Create a query that overrides the metric to COSINE. 
+ std::string query = + "semanticSearch(getSearchSpecEmbedding(0), 0.95, 1.05, \"COSINE\")"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + // The default metric to be overridden + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_COSINE, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(1, kEps)))); + + // Create a query that overrides the metric to DOT_PRODUCT. 
+ query = + "semanticSearch(getSearchSpecEmbedding(0), 0.1, 0.2, \"DOT_PRODUCT\")"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + // The default metric to be overridden + EMBEDDING_METRIC_COSINE, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(0.14, kEps)))); + + // Create a query that overrides the metric to EUCLIDEAN. 
+ query = + "semanticSearch(getSearchSpecEmbedding(0), -0.05, 0.05, \"EUCLIDEAN\")"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor3( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + // The default metric to be overridden + EMBEDDING_METRIC_UNKNOWN, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor3); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor3).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_EUCLIDEAN, kDocumentId0), + Pointee(UnorderedElementsAre(DoubleNear(0, kEps)))); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionMultipleQueries) { + // Index 3 embedding vectors for document 0. + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), + CreateVector("my_model1", {1, -2, -3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId1, kDocumentId0), + CreateVector("my_model1", {-1, -2, -3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId2, kDocumentId0), + CreateVector("my_model2", {-1, 2, 3, -4}))); + // Index 2 embedding vectors for document 1. 
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId1), + CreateVector("my_model1", {-1, -2, 3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId1, kDocumentId1), + CreateVector("my_model2", {1, -2, 3, -4}))); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create two embedding queries. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + // Semantic scores for this query: + // - document 0: -2 (section 0), 0 (section 1) + // - document 1: 6 (section 0) + PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add(); + *embedding_query = CreateVector("my_model1", {-1, -1, 1}); + // Semantic scores for this query: + // - document 0: 4 (section 2) + // - document 1: -2 (section 1) + embedding_query = embedding_query_vectors.Add(); + *embedding_query = CreateVector("my_model2", {-1, 1, -1, -1}); + + // The query can only match document 0: + // - The "semanticSearch(getSearchSpecEmbedding(0), -5)" part should match + // semantic scores {-2, 0}. + // - The "semanticSearch(getSearchSpecEmbedding(1), 0)" part should match + // semantic scores {4}. 
+ std::string query = + "semanticSearch(getSearchSpecEmbedding(0), -5) AND " + "semanticSearch(getSearchSpecEmbedding(1), 0)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + DocHitInfoIterator* itr = query_results.root_iterator.get(); + // Check results for document 0. + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT(itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, + std::vector<SectionId>{kSectionId0, kSectionId1, + kSectionId2})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(-2, 0))); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(4))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + + // The query can match both document 0 and document 1: + // For document 0: + // - The "semanticSearch(getSearchSpecEmbedding(0), 1)" part should return + // semantic scores {}. 
+ // - The "semanticSearch(getSearchSpecEmbedding(1), 0.1)" part should return + // semantic scores {4}. + // For document 1: + // - The "semanticSearch(getSearchSpecEmbedding(0), 1)" part should return + // semantic scores {6}. + // - The "semanticSearch(getSearchSpecEmbedding(1), 0.1)" part should return + // semantic scores {}. + query = + "semanticSearch(getSearchSpecEmbedding(0), 1) OR " + "semanticSearch(getSearchSpecEmbedding(1), 0.1)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + itr = query_results.root_iterator.get(); + // Check results for document 1. + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + Pointee(UnorderedElementsAre(6))); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + IsNull()); + // Check results for document 0. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId2})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + IsNull()); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/1, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(4))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST_F(QueryVisitorTest, + SemanticSearchFunctionMultipleQueriesScoresMergedRepeat) { + // Index 2 embedding vectors for document 0. + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), + CreateVector("my_model1", {1, -2, -3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId1, kDocumentId0), + CreateVector("my_model1", {-1, -2, -3}))); + // Index 1 embedding vector for document 1. + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId1), + CreateVector("my_model1", {-1, -2, 3}))); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query. + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + // Semantic scores for this query: + // - document 0: -2 (section 0), 0 (section 1) + // - document 1: 6 (section 0) + PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add(); + *embedding_query = CreateVector("my_model1", {-1, -1, 1}); + + // The query should match both document 0 and document 1, since the overall + // range is [-10, 10]. The scores in the results should be merged. 
+ std::string query = + "semanticSearch(getSearchSpecEmbedding(0), -10, 0) OR " + "semanticSearch(getSearchSpecEmbedding(0), 0.0001, 10)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + DocHitInfoIterator* itr = query_results.root_iterator.get(); + // Check results for document 1. + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + Pointee(UnorderedElementsAre(6))); + // Check results for document 0. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT(itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{ + kSectionId0, kSectionId1})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(-2, 0))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + + // The same query appears twice, in which case all the scores in the results + // should repeat twice. + query = + "semanticSearch(getSearchSpecEmbedding(0), -10, 10) OR " + "semanticSearch(getSearchSpecEmbedding(0), -10, 10)"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + itr = query_results.root_iterator.get(); + // Check results for document 1. + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + Pointee(UnorderedElementsAre(6, 6))); + // Check results for document 0. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT(itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{ + kSectionId0, kSectionId1})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(-2, 0, -2, 0))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionHybridQueries) { + // Index terms + Index::Editor editor = index_->Edit(kDocumentId0, kSectionId1, + TERM_MATCH_PREFIX, /*namespace_id=*/0); + ICING_ASSERT_OK(editor.BufferTerm("foo")); + ICING_ASSERT_OK(editor.IndexAllBufferedTerms()); + editor = index_->Edit(kDocumentId1, kSectionId1, TERM_MATCH_PREFIX, + /*namespace_id=*/0); + ICING_ASSERT_OK(editor.BufferTerm("bar")); + ICING_ASSERT_OK(editor.IndexAllBufferedTerms()); + + // Index embedding vectors + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), + CreateVector("my_model1", {1, -2, -3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId1), + CreateVector("my_model1", {-1, -2, 3}))); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query with semantic scores: + // - document 0: -2 + // - document 1: 6 + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add(); + *embedding_query = CreateVector("my_model1", {-1, -1, 1}); + + // Perform a hybrid search: + // - The "semanticSearch(getSearchSpecEmbedding(0), 0)" part only matches + // document 1. + // - The "foo" part only matches document 0. 
+ std::string query = "semanticSearch(getSearchSpecEmbedding(0), 0) OR foo"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), + UnorderedElementsAre("foo")); + EXPECT_THAT(ExtractKeys(query_results.query_terms), UnorderedElementsAre("")); + EXPECT_THAT(query_results.query_terms[""], UnorderedElementsAre("foo")); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + DocHitInfoIterator* itr = query_results.root_iterator.get(); + // Check results for document 1. + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + Pointee(UnorderedElementsAre(6))); + // Check results for document 0. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId1})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + IsNull()); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + + // Perform another hybrid search: + // - The "semanticSearch(getSearchSpecEmbedding(0), -5)" part matches both + // document 0 and 1. + // - The "foo" part only matches document 0. + // As a result, only document 0 will be returned. + query = "semanticSearch(getSearchSpecEmbedding(0), -5) AND foo"; + ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); + QueryVisitor query_visitor2( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor2); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor2).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), + UnorderedElementsAre("foo")); + EXPECT_THAT(ExtractKeys(query_results.query_terms), UnorderedElementsAre("")); + EXPECT_THAT(query_results.query_terms[""], UnorderedElementsAre("foo")); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + itr = query_results.root_iterator.get(); + // Check results for document 0. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT(itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{ + kSectionId0, kSectionId1})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(-2))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST_F(QueryVisitorTest, SemanticSearchFunctionSectionRestriction) { + ICING_ASSERT_OK(schema_store_->SetSchema( + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("type") + .AddProperty(PropertyConfigBuilder() + .SetName("prop1") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("prop2") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(), + /*ignore_errors_and_delete_documents=*/false, + /*allow_circular_schema_definitions=*/false)); + + // Create two documents. + ICING_ASSERT_OK(document_store_->Put( + DocumentBuilder().SetKey("ns", "uri0").SetSchema("type").Build())); + ICING_ASSERT_OK(document_store_->Put( + DocumentBuilder().SetKey("ns", "uri1").SetSchema("type").Build())); + // Add embedding vectors into different sections for the two documents. 
+ ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId0, kDocumentId0), + CreateVector("my_model1", {1, -2, -3}))); + ICING_ASSERT_OK(embedding_index_->BufferEmbedding( + BasicHit(kSectionId1, kDocumentId0), + CreateVector("my_model1", {-1, -2, 3}))); + ICING_ASSERT_OK( + embedding_index_->BufferEmbedding(BasicHit(kSectionId0, kDocumentId1), + CreateVector("my_model1", {-1, 2, 3}))); + ICING_ASSERT_OK( + embedding_index_->BufferEmbedding(BasicHit(kSectionId1, kDocumentId1), + CreateVector("my_model1", {1, 2, -3}))); + ICING_ASSERT_OK(embedding_index_->CommitBufferToIndex()); + + // Create an embedding query with semantic scores: + // - document 0: -2 (section 0), 6 (section 1) + // - document 1: 2 (section 0), -6 (section 1) + google::protobuf::RepeatedPtrField<PropertyProto::VectorProto> embedding_query_vectors; + PropertyProto::VectorProto* embedding_query = embedding_query_vectors.Add(); + *embedding_query = CreateVector("my_model1", {-1, -1, 1}); + + // An embedding query with section restriction. The scores returned should + // only be limited to the section restricted. 
+ std::string query = "prop1:semanticSearch(getSearchSpecEmbedding(0), -100)"; + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + ParseQueryHelper(query)); + QueryVisitor query_visitor( + index_.get(), numeric_index_.get(), embedding_index_.get(), + document_store_.get(), schema_store_.get(), normalizer_.get(), + tokenizer_.get(), query, &embedding_query_vectors, + DocHitInfoIteratorFilter::Options(), TERM_MATCH_PREFIX, + EMBEDDING_METRIC_DOT_PRODUCT, + /*needs_term_frequency_info=*/true, clock_.GetSystemTimeMilliseconds()); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(ExtractKeys(query_results.query_term_iterators), IsEmpty()); + EXPECT_THAT(query_results.query_terms, IsEmpty()); + EXPECT_THAT(query_results.features_in_use, + UnorderedElementsAre(kListFilterQueryLanguageFeature, + kEmbeddingSearchFeature)); + DocHitInfoIterator* itr = query_results.root_iterator.get(); + // Check results. 
+ ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId1), + Pointee(UnorderedElementsAre(2))); + ICING_ASSERT_OK(itr->Advance()); + EXPECT_THAT( + itr->doc_hit_info(), + EqualsDocHitInfo(kDocumentId0, std::vector<SectionId>{kSectionId0})); + EXPECT_THAT( + query_results.embedding_query_results.GetMatchedScoresForDocument( + /*query_vector_index=*/0, EMBEDDING_METRIC_DOT_PRODUCT, kDocumentId0), + Pointee(UnorderedElementsAre(-2))); + EXPECT_THAT(itr->Advance(), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + INSTANTIATE_TEST_SUITE_P(QueryVisitorTest, QueryVisitorTest, testing::Values(QueryType::kPlain, QueryType::kSearch)); diff --git a/icing/query/query-features.h b/icing/query/query-features.h index d829cd7..bc3602f 100644 --- a/icing/query/query-features.h +++ b/icing/query/query-features.h @@ -52,9 +52,15 @@ constexpr Feature kListFilterQueryLanguageFeature = constexpr Feature kHasPropertyFunctionFeature = "HAS_PROPERTY_FUNCTION"; // Features#HAS_PROPERTY_FUNCTION +// This feature relates to the use of embedding searches in the advanced query +// language. Ex. `semanticSearch(getSearchSpecEmbedding(0), 0.5, 1, "COSINE")`. 
+constexpr Feature kEmbeddingSearchFeature = + "EMBEDDING_SEARCH"; // Features#EMBEDDING_SEARCH + inline std::unordered_set<Feature> GetQueryFeaturesSet() { return {kNumericSearchFeature, kVerbatimSearchFeature, - kListFilterQueryLanguageFeature, kHasPropertyFunctionFeature}; + kListFilterQueryLanguageFeature, kHasPropertyFunctionFeature, + kEmbeddingSearchFeature}; } } // namespace lib diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc index bbfbf3c..6e13001 100644 --- a/icing/query/query-processor.cc +++ b/icing/query/query-processor.cc @@ -14,6 +14,7 @@ #include "icing/query/query-processor.h" +#include <cstdint> #include <deque> #include <memory> #include <stack> @@ -26,6 +27,7 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-all-document-id.h" #include "icing/index/iterator/doc-hit-info-iterator-and.h" @@ -34,13 +36,14 @@ #include "icing/index/iterator/doc-hit-info-iterator-or.h" #include "icing/index/iterator/doc-hit-info-iterator-section-restrict.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/numeric/numeric-index.h" +#include "icing/proto/logging.pb.h" #include "icing/proto/search.pb.h" #include "icing/query/advanced_query_parser/abstract-syntax-tree.h" #include "icing/query/advanced_query_parser/lexer.h" #include "icing/query/advanced_query_parser/parser.h" #include "icing/query/advanced_query_parser/query-visitor.h" #include "icing/query/query-features.h" -#include "icing/query/query-processor.h" #include "icing/query/query-results.h" #include "icing/query/query-terms.h" #include "icing/query/query-utils.h" @@ -49,11 +52,11 @@ #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/tokenization/language-segmenter.h" -#include 
"icing/tokenization/raw-query-tokenizer.h" #include "icing/tokenization/token.h" #include "icing/tokenization/tokenizer-factory.h" #include "icing/tokenization/tokenizer.h" #include "icing/transform/normalizer.h" +#include "icing/util/clock.h" #include "icing/util/status-macros.h" namespace icing { @@ -109,39 +112,46 @@ std::unique_ptr<DocHitInfoIterator> ProcessParserStateFrame( libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>> QueryProcessor::Create(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, - const SchemaStore* schema_store) { + const SchemaStore* schema_store, const Clock* clock) { ICING_RETURN_ERROR_IF_NULL(index); ICING_RETURN_ERROR_IF_NULL(numeric_index); + ICING_RETURN_ERROR_IF_NULL(embedding_index); ICING_RETURN_ERROR_IF_NULL(language_segmenter); ICING_RETURN_ERROR_IF_NULL(normalizer); ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(clock); - return std::unique_ptr<QueryProcessor>( - new QueryProcessor(index, numeric_index, language_segmenter, normalizer, - document_store, schema_store)); + return std::unique_ptr<QueryProcessor>(new QueryProcessor( + index, numeric_index, embedding_index, language_segmenter, normalizer, + document_store, schema_store, clock)); } QueryProcessor::QueryProcessor(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, - const SchemaStore* schema_store) + const SchemaStore* schema_store, + const Clock* clock) : index_(*index), numeric_index_(*numeric_index), + embedding_index_(*embedding_index), language_segmenter_(*language_segmenter), normalizer_(*normalizer), document_store_(*document_store), - schema_store_(*schema_store) {} + 
schema_store_(*schema_store), + clock_(*clock) {} libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch( const SearchSpecProto& search_spec, ScoringSpecProto::RankingStrategy::Code ranking_strategy, - int64_t current_time_ms) { + int64_t current_time_ms, QueryStatsProto::SearchStats* search_stats) { if (search_spec.search_type() == SearchSpecProto::SearchType::UNDEFINED) { return absl_ports::InvalidArgumentError(absl_ports::StrCat( "Search type ", @@ -152,9 +162,9 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch( if (search_spec.search_type() == SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { ICING_VLOG(1) << "Using EXPERIMENTAL_ICING_ADVANCED_QUERY parser!"; - ICING_ASSIGN_OR_RETURN( - results, - ParseAdvancedQuery(search_spec, ranking_strategy, current_time_ms)); + ICING_ASSIGN_OR_RETURN(results, + ParseAdvancedQuery(search_spec, ranking_strategy, + current_time_ms, search_stats)); } else { ICING_ASSIGN_OR_RETURN( results, ParseRawQuery(search_spec, ranking_strategy, current_time_ms)); @@ -167,8 +177,10 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch( search_spec.enabled_features().end()); for (const Feature feature : results.features_in_use) { if (enabled_features.find(feature) == enabled_features.end()) { - return absl_ports::InvalidArgumentError( - absl_ports::StrCat("Attempted use of unenabled feature ", feature)); + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + "Attempted use of unenabled feature ", feature, + ". 
Please make sure that you have explicitly set all advanced query " + "features used in this query as enabled in the SearchSpec.")); } } @@ -188,17 +200,27 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch( libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseAdvancedQuery( const SearchSpecProto& search_spec, ScoringSpecProto::RankingStrategy::Code ranking_strategy, - int64_t current_time_ms) const { - QueryResults results; + int64_t current_time_ms, QueryStatsProto::SearchStats* search_stats) const { + std::unique_ptr<Timer> lexer_timer = clock_.GetNewTimer(); Lexer lexer(search_spec.query(), Lexer::Language::QUERY); ICING_ASSIGN_OR_RETURN(std::vector<Lexer::LexerToken> lexer_tokens, lexer.ExtractTokens()); + if (search_stats != nullptr) { + search_stats->set_query_processor_lexer_extract_token_latency_ms( + lexer_timer->GetElapsedMilliseconds()); + } + std::unique_ptr<Timer> parser_timer = clock_.GetNewTimer(); Parser parser = Parser::Create(std::move(lexer_tokens)); ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> tree_root, parser.ConsumeQuery()); + if (search_stats != nullptr) { + search_stats->set_query_processor_parser_consume_query_latency_ms( + parser_timer->GetElapsedMilliseconds()); + } if (tree_root == nullptr) { + QueryResults results; results.root_iterator = std::make_unique<DocHitInfoIteratorAllDocumentId>( document_store_.last_added_document_id()); return results; @@ -210,13 +232,23 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseAdvancedQuery( DocHitInfoIteratorFilter::Options options = GetFilterOptions(search_spec); bool needs_term_frequency_info = ranking_strategy == ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE; - QueryVisitor query_visitor(&index_, &numeric_index_, &document_store_, - &schema_store_, &normalizer_, - plain_tokenizer.get(), search_spec.query(), - std::move(options), search_spec.term_match_type(), - needs_term_frequency_info, current_time_ms); + + std::unique_ptr<Timer> 
query_visitor_timer = clock_.GetNewTimer(); + QueryVisitor query_visitor( + &index_, &numeric_index_, &embedding_index_, &document_store_, + &schema_store_, &normalizer_, plain_tokenizer.get(), search_spec.query(), + &search_spec.embedding_query_vectors(), std::move(options), + search_spec.term_match_type(), search_spec.embedding_query_metric_type(), + needs_term_frequency_info, current_time_ms); tree_root->Accept(&query_visitor); - return std::move(query_visitor).ConsumeResults(); + ICING_ASSIGN_OR_RETURN(QueryResults results, + std::move(query_visitor).ConsumeResults()); + if (search_stats != nullptr) { + search_stats->set_query_processor_query_visitor_latency_ms( + query_visitor_timer->GetElapsedMilliseconds()); + } + + return results; } // TODO(cassiewang): Collect query stats to populate the SearchResultsProto diff --git a/icing/query/query-processor.h b/icing/query/query-processor.h index d4c22dd..d90b5f6 100644 --- a/icing/query/query-processor.h +++ b/icing/query/query-processor.h @@ -19,17 +19,17 @@ #include <memory> #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" -#include "icing/index/iterator/doc-hit-info-iterator-filter.h" -#include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/index/numeric/numeric-index.h" +#include "icing/proto/logging.pb.h" #include "icing/proto/search.pb.h" #include "icing/query/query-results.h" -#include "icing/query/query-terms.h" #include "icing/schema/schema-store.h" #include "icing/store/document-store.h" #include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer.h" +#include "icing/util/clock.h" namespace icing { namespace lib { @@ -48,8 +48,10 @@ class QueryProcessor { // FAILED_PRECONDITION if any of the pointers is null. 
static libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>> Create( Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, - const DocumentStore* document_store, const SchemaStore* schema_store); + const DocumentStore* document_store, const SchemaStore* schema_store, + const Clock* clock); // Parse the search configurations (including the query, any additional // filters, etc.) in the SearchSpecProto into one DocHitInfoIterator. @@ -68,15 +70,17 @@ class QueryProcessor { libtextclassifier3::StatusOr<QueryResults> ParseSearch( const SearchSpecProto& search_spec, ScoringSpecProto::RankingStrategy::Code ranking_strategy, - int64_t current_time_ms); + int64_t current_time_ms, + QueryStatsProto::SearchStats* search_stats = nullptr); private: explicit QueryProcessor(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, - const SchemaStore* schema_store); + const SchemaStore* schema_store, const Clock* clock); // Parse the query into a one DocHitInfoIterator that represents the root of a // query tree in our new Advanced Query Language. @@ -88,7 +92,8 @@ class QueryProcessor { libtextclassifier3::StatusOr<QueryResults> ParseAdvancedQuery( const SearchSpecProto& search_spec, ScoringSpecProto::RankingStrategy::Code ranking_strategy, - int64_t current_time_ms) const; + int64_t current_time_ms, + QueryStatsProto::SearchStats* search_stats) const; // Parse the query into a one DocHitInfoIterator that represents the root of a // query tree. @@ -106,12 +111,14 @@ class QueryProcessor { // Not const because we could modify/sort the hit buffer in the lite index at // query time. 
- Index& index_; - const NumericIndex<int64_t>& numeric_index_; - const LanguageSegmenter& language_segmenter_; - const Normalizer& normalizer_; - const DocumentStore& document_store_; - const SchemaStore& schema_store_; + Index& index_; // Does not own. + const NumericIndex<int64_t>& numeric_index_; // Does not own. + const EmbeddingIndex& embedding_index_; // Does not own. + const LanguageSegmenter& language_segmenter_; // Does not own. + const Normalizer& normalizer_; // Does not own. + const DocumentStore& document_store_; // Does not own. + const SchemaStore& schema_store_; // Does not own. + const Clock& clock_; // Does not own. }; } // namespace lib diff --git a/icing/query/query-processor_benchmark.cc b/icing/query/query-processor_benchmark.cc index 025e8e6..5b3bf99 100644 --- a/icing/query/query-processor_benchmark.cc +++ b/icing/query/query-processor_benchmark.cc @@ -12,26 +12,40 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include <cstdint> +#include <limits> +#include <memory> +#include <string> +#include <utility> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "testing/base/public/benchmark.h" -#include "gmock/gmock.h" #include "third_party/absl/flags/flag.h" +#include "icing/absl_ports/str_cat.h" #include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" #include "icing/index/numeric/dummy-numeric-index.h" -#include "icing/index/numeric/numeric-index.h" +#include "icing/legacy/index/icing-filesystem.h" #include "icing/proto/schema.pb.h" #include "icing/proto/search.pb.h" #include "icing/proto/term.pb.h" #include "icing/query/query-processor.h" +#include "icing/query/query-results.h" #include "icing/schema/schema-store.h" #include "icing/schema/section.h" #include "icing/store/document-id.h" +#include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" #include "icing/tokenization/language-segmenter-factory.h" +#include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" #include "icing/util/clock.h" #include "icing/util/logging.h" #include "unicode/uloc.h" @@ -118,6 +132,7 @@ void BM_QueryOneTerm(benchmark::State& state) { const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark"; const std::string index_dir = base_dir + "/index"; const std::string numeric_index_dir = base_dir + "/numeric_index"; + const std::string embedding_index_dir = base_dir + "/embedding_index"; const std::string schema_dir = base_dir + "/schema"; const std::string doc_store_dir = base_dir + "/store"; @@ -134,6 +149,9 @@ void BM_QueryOneTerm(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( 
auto numeric_index, DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir)); + ICING_ASSERT_OK_AND_ASSIGN( + auto embedding_index, + EmbeddingIndex::Create(&filesystem, embedding_index_dir)); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = @@ -172,8 +190,9 @@ void BM_QueryOneTerm(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, QueryProcessor::Create(index.get(), numeric_index.get(), - language_segmenter.get(), normalizer.get(), - document_store.get(), schema_store.get())); + embedding_index.get(), language_segmenter.get(), + normalizer.get(), document_store.get(), + schema_store.get(), &clock)); SearchSpecProto search_spec; search_spec.set_query(input_string); @@ -247,6 +266,7 @@ void BM_QueryFiveTerms(benchmark::State& state) { const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark"; const std::string index_dir = base_dir + "/index"; const std::string numeric_index_dir = base_dir + "/numeric_index"; + const std::string embedding_index_dir = base_dir + "/embedding_index"; const std::string schema_dir = base_dir + "/schema"; const std::string doc_store_dir = base_dir + "/store"; @@ -263,6 +283,9 @@ void BM_QueryFiveTerms(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( auto numeric_index, DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir)); + ICING_ASSERT_OK_AND_ASSIGN( + auto embedding_index, + EmbeddingIndex::Create(&filesystem, embedding_index_dir)); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = @@ -315,8 +338,9 @@ void BM_QueryFiveTerms(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, QueryProcessor::Create(index.get(), numeric_index.get(), - language_segmenter.get(), normalizer.get(), - document_store.get(), schema_store.get())); + embedding_index.get(), 
language_segmenter.get(), + normalizer.get(), document_store.get(), + schema_store.get(), &clock)); const std::string query_string = absl_ports::StrCat( input_string_a, " ", input_string_b, " ", input_string_c, " ", @@ -394,6 +418,7 @@ void BM_QueryDiacriticTerm(benchmark::State& state) { const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark"; const std::string index_dir = base_dir + "/index"; const std::string numeric_index_dir = base_dir + "/numeric_index"; + const std::string embedding_index_dir = base_dir + "/embedding_index"; const std::string schema_dir = base_dir + "/schema"; const std::string doc_store_dir = base_dir + "/store"; @@ -410,6 +435,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( auto numeric_index, DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir)); + ICING_ASSERT_OK_AND_ASSIGN( + auto embedding_index, + EmbeddingIndex::Create(&filesystem, embedding_index_dir)); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = @@ -451,8 +479,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, QueryProcessor::Create(index.get(), numeric_index.get(), - language_segmenter.get(), normalizer.get(), - document_store.get(), schema_store.get())); + embedding_index.get(), language_segmenter.get(), + normalizer.get(), document_store.get(), + schema_store.get(), &clock)); SearchSpecProto search_spec; search_spec.set_query(input_string); @@ -526,6 +555,7 @@ void BM_QueryHiragana(benchmark::State& state) { const std::string base_dir = GetTestTempDir() + "/query_processor_benchmark"; const std::string index_dir = base_dir + "/index"; const std::string numeric_index_dir = base_dir + "/numeric_index"; + const std::string embedding_index_dir = base_dir + "/embedding_index"; const std::string schema_dir = base_dir + "/schema"; const std::string 
doc_store_dir = base_dir + "/store"; @@ -542,6 +572,9 @@ void BM_QueryHiragana(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( auto numeric_index, DummyNumericIndex<int64_t>::Create(filesystem, numeric_index_dir)); + ICING_ASSERT_OK_AND_ASSIGN( + auto embedding_index, + EmbeddingIndex::Create(&filesystem, embedding_index_dir)); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = @@ -583,8 +616,9 @@ void BM_QueryHiragana(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, QueryProcessor::Create(index.get(), numeric_index.get(), - language_segmenter.get(), normalizer.get(), - document_store.get(), schema_store.get())); + embedding_index.get(), language_segmenter.get(), + normalizer.get(), document_store.get(), + schema_store.get(), &clock)); SearchSpecProto search_spec; search_spec.set_query(input_string); diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc index 53e3035..288c7fb 100644 --- a/icing/query/query-processor_test.cc +++ b/icing/query/query-processor_test.cc @@ -14,17 +14,24 @@ #include "icing/query/query-processor.h" +#include <array> #include <cstdint> #include <memory> #include <string> +#include <unordered_map> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/hit/doc-hit-info.h" +#include "icing/index/hit/hit.h" #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/index/iterator/doc-hit-info-iterator.h" @@ -33,6 +40,7 @@ #include "icing/jni/jni-cache.h" #include 
"icing/legacy/index/icing-filesystem.h" #include "icing/portable/platform.h" +#include "icing/proto/logging.pb.h" #include "icing/proto/schema.pb.h" #include "icing/proto/search.pb.h" #include "icing/proto/term.pb.h" @@ -53,6 +61,8 @@ #include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer-factory.h" #include "icing/transform/normalizer.h" +#include "icing/util/clock.h" +#include "icing/util/status-macros.h" #include "unicode/uloc.h" namespace icing { @@ -61,6 +71,7 @@ namespace lib { namespace { using ::testing::ElementsAre; +using ::testing::Eq; using ::testing::IsEmpty; using ::testing::SizeIs; using ::testing::UnorderedElementsAre; @@ -85,7 +96,8 @@ class QueryProcessorTest store_dir_(test_dir_ + "/store"), schema_store_dir_(test_dir_ + "/schema_store"), index_dir_(test_dir_ + "/index"), - numeric_index_dir_(test_dir_ + "/numeric_index") {} + numeric_index_dir_(test_dir_ + "/numeric_index"), + embedding_index_dir_(test_dir_ + "/embedding_index") {} void SetUp() override { filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); @@ -124,6 +136,9 @@ class QueryProcessorTest ICING_ASSERT_OK_AND_ASSIGN( numeric_index_, DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_)); + ICING_ASSERT_OK_AND_ASSIGN( + embedding_index_, + EmbeddingIndex::Create(&filesystem_, embedding_index_dir_)); language_segmenter_factory::SegmenterOptions segmenter_options( ULOC_US, jni_cache_.get()); @@ -136,9 +151,10 @@ class QueryProcessorTest ICING_ASSERT_OK_AND_ASSIGN( query_processor_, - QueryProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), normalizer_.get(), - document_store_.get(), schema_store_.get())); + QueryProcessor::Create( + index_.get(), numeric_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_)); } libtextclassifier3::Status AddTokenToIndex( @@ -174,10 +190,12 @@ class QueryProcessorTest 
IcingFilesystem icing_filesystem_; const std::string index_dir_; const std::string numeric_index_dir_; + const std::string embedding_index_dir_; protected: std::unique_ptr<Index> index_; std::unique_ptr<NumericIndex<int64_t>> numeric_index_; + std::unique_ptr<EmbeddingIndex> embedding_index_; std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<Normalizer> normalizer_; FakeClock fake_clock_; @@ -190,34 +208,51 @@ class QueryProcessorTest TEST_P(QueryProcessorTest, CreationWithNullPointerShouldFail) { EXPECT_THAT( QueryProcessor::Create(/*index=*/nullptr, numeric_index_.get(), - language_segmenter_.get(), normalizer_.get(), - document_store_.get(), schema_store_.get()), + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT( QueryProcessor::Create(index_.get(), /*numeric_index_=*/nullptr, - language_segmenter_.get(), normalizer_.get(), - document_store_.get(), schema_store_.get()), + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(), + /*embedding_index=*/nullptr, + language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(QueryProcessor::Create( + index_.get(), numeric_index_.get(), embedding_index_.get(), + /*language_segmenter=*/nullptr, normalizer_.get(), + document_store_.get(), schema_store_.get(), &fake_clock_), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT( QueryProcessor::Create(index_.get(), numeric_index_.get(), - /*language_segmenter=*/nullptr, normalizer_.get(), - document_store_.get(), 
schema_store_.get()), + embedding_index_.get(), language_segmenter_.get(), + /*normalizer=*/nullptr, document_store_.get(), + schema_store_.get(), &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT( QueryProcessor::Create( - index_.get(), numeric_index_.get(), language_segmenter_.get(), - /*normalizer=*/nullptr, document_store_.get(), schema_store_.get()), + index_.get(), numeric_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), + /*document_store=*/nullptr, schema_store_.get(), &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT( QueryProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), normalizer_.get(), - /*document_store=*/nullptr, schema_store_.get()), + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + /*schema_store=*/nullptr, &fake_clock_), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + QueryProcessor::Create(index_.get(), numeric_index_.get(), + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), /*clock=*/nullptr), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); - EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - /*schema_store=*/nullptr), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_P(QueryProcessorTest, EmptyGroupMatchAllDocuments) { @@ -2947,8 +2982,9 @@ TEST_P(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> local_query_processor, QueryProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), normalizer_.get(), - document_store_.get(), schema_store_.get())); + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), 
document_store_.get(), + schema_store_.get(), &fake_clock_)); SearchSpecProto search_spec; search_spec.set_query("hello"); @@ -3009,8 +3045,9 @@ TEST_P(QueryProcessorTest, DocumentPastTtlFilteredOut) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> local_query_processor, QueryProcessor::Create(index_.get(), numeric_index_.get(), - language_segmenter_.get(), normalizer_.get(), - document_store_.get(), schema_store_.get())); + embedding_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_)); SearchSpecProto search_spec; search_spec.set_query("hello"); @@ -3320,6 +3357,65 @@ TEST_P(QueryProcessorTest, GroupingInSectionRestriction) { std::vector<SectionId>{prop1_section_id}))); } +TEST_P(QueryProcessorTest, ParseAdvancedQueryShouldSetSearchStats) { + if (GetParam() != + SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + GTEST_SKIP(); + } + + // Create the schema and document store + SchemaProto schema = SchemaBuilder() + .AddType(SchemaTypeConfigBuilder().SetType("email")) + .Build(); + ASSERT_THAT(schema_store_->SetSchema( + schema, /*ignore_errors_and_delete_documents=*/false, + /*allow_circular_schema_definitions=*/false), + IsOk()); + + // These documents don't actually match to the tokens in the index. We're + // inserting the documents to get the appropriate number of documents and + // namespaces populated. 
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, + document_store_->Put(DocumentBuilder() + .SetKey("namespace1", "1") + .SetSchema("email") + .Build())); + + // Populate the index + SectionId section_id = 0; + TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; + + EXPECT_THAT( + AddTokenToIndex(document_id, section_id, term_match_type, "hello"), + IsOk()); + EXPECT_THAT( + AddTokenToIndex(document_id, section_id, term_match_type, "world"), + IsOk()); + + SearchSpecProto search_spec; + search_spec.set_query("hello world"); + search_spec.set_term_match_type(term_match_type); + search_spec.set_search_type(GetParam()); + + static constexpr int64_t kSearchStatsLatencyMs = 10; + fake_clock_.SetTimerElapsedMilliseconds(kSearchStatsLatencyMs); + + QueryStatsProto::SearchStats search_stats; + ICING_ASSERT_OK_AND_ASSIGN( + QueryResults results, + query_processor_->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, + fake_clock_.GetSystemTimeMilliseconds(), &search_stats)); + + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_THAT(search_stats.query_processor_lexer_extract_token_latency_ms(), + Eq(kSearchStatsLatencyMs)); + EXPECT_THAT(search_stats.query_processor_parser_consume_query_latency_ms(), + Eq(kSearchStatsLatencyMs)); + EXPECT_THAT(search_stats.query_processor_query_visitor_latency_ms(), + Eq(kSearchStatsLatencyMs)); +} + INSTANTIATE_TEST_SUITE_P( QueryProcessorTest, QueryProcessorTest, testing::Values( diff --git a/icing/query/query-results.h b/icing/query/query-results.h index 52cdd71..983ab35 100644 --- a/icing/query/query-results.h +++ b/icing/query/query-results.h @@ -18,9 +18,10 @@ #include <memory> #include <unordered_set> +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/iterator/doc-hit-info-iterator.h" -#include "icing/query/query-terms.h" #include "icing/query/query-features.h" +#include "icing/query/query-terms.h" namespace icing { namespace lib { @@ -35,6 +36,10 
@@ struct QueryResults { // beginning with root_iterator. // This will only be populated when ranking_strategy == RELEVANCE_SCORE. QueryTermIteratorsMap query_term_iterators; + // Contains similarity scores from embedding based queries, which will be used + // in the advanced scoring language to determine the results for the + // "this.matchedSemanticScores(...)" function. + EmbeddingQueryResults embedding_query_results; // Features that are invoked during query execution. // The list of possible features is defined in query_features.h. std::unordered_set<Feature> features_in_use; diff --git a/icing/query/suggestion-processor.cc b/icing/query/suggestion-processor.cc index eb86e3b..dfebb98 100644 --- a/icing/query/suggestion-processor.cc +++ b/icing/query/suggestion-processor.cc @@ -14,14 +14,37 @@ #include "icing/query/suggestion-processor.h" -#include "icing/proto/schema.pb.h" +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/index.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/numeric/numeric-index.h" +#include "icing/index/term-metadata.h" #include "icing/proto/search.pb.h" #include "icing/query/query-processor.h" +#include "icing/query/query-results.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/store/document-filter-data.h" #include "icing/store/document-id.h" +#include "icing/store/document-store.h" +#include "icing/store/namespace-id.h" #include "icing/store/suggestion-result-checker-impl.h" -#include "icing/tokenization/tokenizer-factory.h" -#include "icing/tokenization/tokenizer.h" +#include "icing/tokenization/language-segmenter.h" 
#include "icing/transform/normalizer.h" +#include "icing/util/clock.h" +#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -29,20 +52,24 @@ namespace lib { libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>> SuggestionProcessor::Create(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, - const SchemaStore* schema_store) { + const SchemaStore* schema_store, + const Clock* clock) { ICING_RETURN_ERROR_IF_NULL(index); ICING_RETURN_ERROR_IF_NULL(numeric_index); + ICING_RETURN_ERROR_IF_NULL(embedding_index); ICING_RETURN_ERROR_IF_NULL(language_segmenter); ICING_RETURN_ERROR_IF_NULL(normalizer); ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(clock); - return std::unique_ptr<SuggestionProcessor>( - new SuggestionProcessor(index, numeric_index, language_segmenter, - normalizer, document_store, schema_store)); + return std::unique_ptr<SuggestionProcessor>(new SuggestionProcessor( + index, numeric_index, embedding_index, language_segmenter, normalizer, + document_store, schema_store, clock)); } libtextclassifier3::StatusOr< @@ -224,8 +251,9 @@ SuggestionProcessor::QuerySuggestions( ICING_ASSIGN_OR_RETURN( std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(&index_, &numeric_index_, &language_segmenter_, - &normalizer_, &document_store_, &schema_store_)); + QueryProcessor::Create(&index_, &numeric_index_, &embedding_index_, + &language_segmenter_, &normalizer_, + &document_store_, &schema_store_, &clock_)); SearchSpecProto search_spec; search_spec.set_query(suggestion_spec.prefix()); @@ -298,14 +326,18 @@ SuggestionProcessor::QuerySuggestions( SuggestionProcessor::SuggestionProcessor( Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const 
LanguageSegmenter* language_segmenter, const Normalizer* normalizer, - const DocumentStore* document_store, const SchemaStore* schema_store) + const DocumentStore* document_store, const SchemaStore* schema_store, + const Clock* clock) : index_(*index), numeric_index_(*numeric_index), + embedding_index_(*embedding_index), language_segmenter_(*language_segmenter), normalizer_(*normalizer), document_store_(*document_store), - schema_store_(*schema_store) {} + schema_store_(*schema_store), + clock_(*clock) {} } // namespace lib } // namespace icing diff --git a/icing/query/suggestion-processor.h b/icing/query/suggestion-processor.h index e100031..cf393b4 100644 --- a/icing/query/suggestion-processor.h +++ b/icing/query/suggestion-processor.h @@ -15,7 +15,12 @@ #ifndef ICING_QUERY_SUGGESTION_PROCESSOR_H_ #define ICING_QUERY_SUGGESTION_PROCESSOR_H_ +#include <cstdint> +#include <memory> +#include <vector> + #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-index.h" #include "icing/index/index.h" #include "icing/index/numeric/numeric-index.h" #include "icing/proto/search.pb.h" @@ -23,6 +28,7 @@ #include "icing/store/document-store.h" #include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer.h" +#include "icing/util/clock.h" namespace icing { namespace lib { @@ -41,9 +47,10 @@ class SuggestionProcessor { // FAILED_PRECONDITION if any of the pointers is null. static libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>> Create(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, - const SchemaStore* schema_store); + const SchemaStore* schema_store, const Clock* clock); // Query suggestions based on the given SuggestionSpecProto. 
// @@ -57,19 +64,23 @@ class SuggestionProcessor { private: explicit SuggestionProcessor(Index* index, const NumericIndex<int64_t>* numeric_index, + const EmbeddingIndex* embedding_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, - const SchemaStore* schema_store); + const SchemaStore* schema_store, + const Clock* clock); // Not const because we could modify/sort the TermMetaData buffer in the lite // index. - Index& index_; - const NumericIndex<int64_t>& numeric_index_; - const LanguageSegmenter& language_segmenter_; - const Normalizer& normalizer_; - const DocumentStore& document_store_; - const SchemaStore& schema_store_; + Index& index_; // Does not own. + const NumericIndex<int64_t>& numeric_index_; // Does not own. + const EmbeddingIndex& embedding_index_; // Does not own. + const LanguageSegmenter& language_segmenter_; // Does not own. + const Normalizer& normalizer_; // Does not own. + const DocumentStore& document_store_; // Does not own. + const SchemaStore& schema_store_; // Does not own. + const Clock& clock_; // Does not own. 
}; } // namespace lib diff --git a/icing/query/suggestion-processor_test.cc b/icing/query/suggestion-processor_test.cc index 9f9094d..c9bdd8a 100644 --- a/icing/query/suggestion-processor_test.cc +++ b/icing/query/suggestion-processor_test.cc @@ -14,14 +14,30 @@ #include "icing/query/suggestion-processor.h" +#include <cstdint> +#include <memory> #include <string> +#include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" +#include "gtest/gtest.h" #include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-index.h" +#include "icing/index/index.h" #include "icing/index/numeric/dummy-numeric-index.h" +#include "icing/index/numeric/numeric-index.h" #include "icing/index/term-metadata.h" +#include "icing/jni/jni-cache.h" +#include "icing/legacy/index/icing-filesystem.h" +#include "icing/portable/platform.h" #include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" @@ -30,7 +46,9 @@ #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" #include "icing/tokenization/language-segmenter-factory.h" +#include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" #include "unicode/uloc.h" namespace icing { @@ -59,7 +77,8 @@ class SuggestionProcessorTest : public Test { store_dir_(test_dir_ + "/store"), schema_store_dir_(test_dir_ + "/schema_store"), index_dir_(test_dir_ + "/index"), - numeric_index_dir_(test_dir_ + "/numeric_index") {} + numeric_index_dir_(test_dir_ + "/numeric_index"), + embedding_index_dir_(test_dir_ + "/embedding_index") {} void SetUp() override { 
filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); @@ -105,6 +124,9 @@ class SuggestionProcessorTest : public Test { ICING_ASSERT_OK_AND_ASSIGN( numeric_index_, DummyNumericIndex<int64_t>::Create(filesystem_, numeric_index_dir_)); + ICING_ASSERT_OK_AND_ASSIGN( + embedding_index_, + EmbeddingIndex::Create(&filesystem_, embedding_index_dir_)); language_segmenter_factory::SegmenterOptions segmenter_options( ULOC_US, jni_cache_.get()); @@ -118,8 +140,9 @@ class SuggestionProcessorTest : public Test { ICING_ASSERT_OK_AND_ASSIGN( suggestion_processor_, SuggestionProcessor::Create( - index_.get(), numeric_index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), schema_store_.get())); + index_.get(), numeric_index_.get(), embedding_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_)); } libtextclassifier3::Status AddTokenToIndex( @@ -146,10 +169,12 @@ class SuggestionProcessorTest : public Test { IcingFilesystem icing_filesystem_; const std::string index_dir_; const std::string numeric_index_dir_; + const std::string embedding_index_dir_; protected: std::unique_ptr<Index> index_; std::unique_ptr<NumericIndex<int64_t>> numeric_index_; + std::unique_ptr<EmbeddingIndex> embedding_index_; std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<Normalizer> normalizer_; FakeClock fake_clock_; @@ -673,6 +698,16 @@ TEST_F(SuggestionProcessorTest, OtherSpecialPrefixTest) { EXPECT_THAT(terms_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } + + if (SearchSpecProto::default_instance().search_type() == + SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + suggestion_spec.set_prefix( + "bar OR semanticSearch(getSearchSpecEmbedding(0), 0.5, 1)"); + terms_or = suggestion_processor_->QuerySuggestions( + suggestion_spec, fake_clock_.GetSystemTimeMilliseconds()); + EXPECT_THAT(terms_or, + 
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + } } TEST_F(SuggestionProcessorTest, InvalidPrefixTest) { diff --git a/icing/schema-builder.h b/icing/schema-builder.h index c74505e..c702afe 100644 --- a/icing/schema-builder.h +++ b/icing/schema-builder.h @@ -56,6 +56,13 @@ constexpr IntegerIndexingConfig::NumericMatchType::Code NUMERIC_MATCH_UNKNOWN = constexpr IntegerIndexingConfig::NumericMatchType::Code NUMERIC_MATCH_RANGE = IntegerIndexingConfig::NumericMatchType::RANGE; +constexpr EmbeddingIndexingConfig::EmbeddingIndexingType::Code + EMBEDDING_INDEXING_UNKNOWN = + EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN; +constexpr EmbeddingIndexingConfig::EmbeddingIndexingType::Code + EMBEDDING_INDEXING_LINEAR_SEARCH = + EmbeddingIndexingConfig::EmbeddingIndexingType::LINEAR_SEARCH; + constexpr PropertyConfigProto::DataType::Code TYPE_UNKNOWN = PropertyConfigProto::DataType::UNKNOWN; constexpr PropertyConfigProto::DataType::Code TYPE_STRING = @@ -70,6 +77,8 @@ constexpr PropertyConfigProto::DataType::Code TYPE_BYTES = PropertyConfigProto::DataType::BYTES; constexpr PropertyConfigProto::DataType::Code TYPE_DOCUMENT = PropertyConfigProto::DataType::DOCUMENT; +constexpr PropertyConfigProto::DataType::Code TYPE_VECTOR = + PropertyConfigProto::DataType::VECTOR; constexpr JoinableConfig::ValueType::Code JOINABLE_VALUE_TYPE_NONE = JoinableConfig::ValueType::NONE; @@ -146,6 +155,15 @@ class PropertyConfigBuilder { return *this; } + PropertyConfigBuilder& SetDataTypeVector( + EmbeddingIndexingConfig::EmbeddingIndexingType::Code + embedding_indexing_type) { + property_.set_data_type(PropertyConfigProto::DataType::VECTOR); + property_.mutable_embedding_indexing_config()->set_embedding_indexing_type( + embedding_indexing_type); + return *this; + } + PropertyConfigBuilder& SetJoinable( JoinableConfig::ValueType::Code join_value_type, bool propagate_delete) { property_.mutable_joinable_config()->set_value_type(join_value_type); @@ -159,6 +177,11 @@ class 
PropertyConfigBuilder { return *this; } + PropertyConfigBuilder& SetDescription(std::string description) { + property_.set_description(std::move(description)); + return *this; + } + PropertyConfigProto Build() const { return std::move(property_); } private: @@ -186,6 +209,11 @@ class SchemaTypeConfigBuilder { return *this; } + SchemaTypeConfigBuilder& SetDescription(std::string description) { + type_config_.set_description(std::move(description)); + return *this; + } + SchemaTypeConfigBuilder& AddProperty(PropertyConfigProto property) { *type_config_.add_properties() = std::move(property); return *this; diff --git a/icing/schema/property-util.cc b/icing/schema/property-util.cc index 67ff748..7c86122 100644 --- a/icing/schema/property-util.cc +++ b/icing/schema/property-util.cc @@ -14,6 +14,8 @@ #include "icing/schema/property-util.h" +#include <cstddef> +#include <cstdint> #include <string> #include <string_view> #include <vector> @@ -131,6 +133,14 @@ ExtractPropertyValues<int64_t>(const PropertyProto& property) { property.int64_values().end()); } +template <> +libtextclassifier3::StatusOr<std::vector<PropertyProto::VectorProto>> +ExtractPropertyValues<PropertyProto::VectorProto>( + const PropertyProto& property) { + return std::vector<PropertyProto::VectorProto>( + property.vector_values().begin(), property.vector_values().end()); +} + } // namespace property_util } // namespace lib diff --git a/icing/schema/property-util.h b/icing/schema/property-util.h index 7557879..c409a9d 100644 --- a/icing/schema/property-util.h +++ b/icing/schema/property-util.h @@ -15,8 +15,11 @@ #ifndef ICING_SCHEMA_PROPERTY_UTIL_H_ #define ICING_SCHEMA_PROPERTY_UTIL_H_ +#include <cstddef> +#include <cstdint> #include <string> #include <string_view> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" @@ -163,6 +166,11 @@ template <> libtextclassifier3::StatusOr<std::vector<int64_t>> ExtractPropertyValues<int64_t>(const PropertyProto& 
property); +template <> +libtextclassifier3::StatusOr<std::vector<PropertyProto::VectorProto>> +ExtractPropertyValues<PropertyProto::VectorProto>( + const PropertyProto& property); + template <typename T> libtextclassifier3::StatusOr<std::vector<T>> ExtractPropertyValuesFromDocument( const DocumentProto& document, std::string_view property_path) { diff --git a/icing/schema/property-util_test.cc b/icing/schema/property-util_test.cc index eddcc84..5e8a430 100644 --- a/icing/schema/property-util_test.cc +++ b/icing/schema/property-util_test.cc @@ -22,6 +22,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" +#include "icing/portable/equals-proto.h" #include "icing/proto/document.pb.h" #include "icing/testing/common-matchers.h" @@ -30,6 +31,7 @@ namespace lib { namespace { +using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::ElementsAre; using ::testing::IsEmpty; @@ -38,6 +40,8 @@ static constexpr std::string_view kPropertySingleString = "singleString"; static constexpr std::string_view kPropertyRepeatedString = "repeatedString"; static constexpr std::string_view kPropertySingleInteger = "singleInteger"; static constexpr std::string_view kPropertyRepeatedInteger = "repeatedInteger"; +static constexpr std::string_view kPropertySingleVector = "singleVector"; +static constexpr std::string_view kPropertyRepeatedVector = "repeatedVector"; static constexpr std::string_view kTypeNestedTest = "NestedTest"; static constexpr std::string_view kPropertyStr = "str"; @@ -83,6 +87,28 @@ TEST(PropertyUtilTest, ExtractPropertyValuesTypeInteger) { IsOkAndHolds(ElementsAre(123, -456, 0))); } +TEST(PropertyUtilTest, ExtractPropertyValuesTypeVector) { + PropertyProto::VectorProto vector1; + vector1.set_model_signature("my_model1"); + vector1.add_values(1.0f); + vector1.add_values(2.0f); + + PropertyProto::VectorProto vector2; + vector2.set_model_signature("my_model2"); + vector2.add_values(-1.0f); + vector2.add_values(-2.0f); + 
vector2.add_values(-3.0f); + + PropertyProto property; + *property.mutable_vector_values()->Add() = vector1; + *property.mutable_vector_values()->Add() = vector2; + + EXPECT_THAT( + property_util::ExtractPropertyValues<PropertyProto::VectorProto>( + property), + IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2)))); +} + TEST(PropertyUtilTest, ExtractPropertyValuesMismatchedType) { PropertyProto property; property.mutable_int64_values()->Add(123); @@ -110,6 +136,16 @@ TEST(PropertyUtilTest, ExtractPropertyValuesTypeUnimplemented) { } TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) { + PropertyProto::VectorProto vector1; + vector1.set_model_signature("my_model1"); + vector1.add_values(1.0f); + vector1.add_values(2.0f); + PropertyProto::VectorProto vector2; + vector2.set_model_signature("my_model2"); + vector2.add_values(-1.0f); + vector2.add_values(-2.0f); + vector2.add_values(-3.0f); + DocumentProto document = DocumentBuilder() .SetKey("icing", "test/1") @@ -119,6 +155,9 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) { "repeated2", "repeated3") .AddInt64Property(std::string(kPropertySingleInteger), 123) .AddInt64Property(std::string(kPropertyRepeatedInteger), 1, 2, 3) + .AddVectorProperty(std::string(kPropertySingleVector), vector1) + .AddVectorProperty(std::string(kPropertyRepeatedVector), vector1, + vector2) .Build(); // Single string @@ -139,9 +178,33 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocument) { EXPECT_THAT(property_util::ExtractPropertyValuesFromDocument<int64_t>( document, /*property_path=*/kPropertyRepeatedInteger), IsOkAndHolds(ElementsAre(1, 2, 3))); + // Single vector + EXPECT_THAT(property_util::ExtractPropertyValuesFromDocument< + PropertyProto::VectorProto>( + document, /*property_path=*/kPropertySingleVector), + IsOkAndHolds(ElementsAre(EqualsProto(vector1)))); + // Repeated vector + EXPECT_THAT( + property_util::ExtractPropertyValuesFromDocument< + PropertyProto::VectorProto>( + document, 
/*property_path=*/kPropertyRepeatedVector), + IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2)))); } TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) { + PropertyProto::VectorProto vector1; + vector1.set_model_signature("my_model1"); + vector1.add_values(1.0f); + vector1.add_values(2.0f); + PropertyProto::VectorProto vector2; + vector2.set_model_signature("my_model2"); + vector2.add_values(-1.0f); + vector2.add_values(-2.0f); + vector2.add_values(-3.0f); + PropertyProto::VectorProto vector3; + vector3.set_model_signature("my_model3"); + vector3.add_values(1.0f); + DocumentProto nested_document = DocumentBuilder() .SetKey("icing", "nested/1") @@ -158,6 +221,10 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) { .AddInt64Property(std::string(kPropertySingleInteger), 123) .AddInt64Property(std::string(kPropertyRepeatedInteger), 1, 2, 3) + .AddVectorProperty(std::string(kPropertySingleVector), + vector1) + .AddVectorProperty(std::string(kPropertyRepeatedVector), + vector1, vector2) .Build(), DocumentBuilder() .SetSchema(std::string(kTypeTest)) @@ -168,6 +235,10 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) { .AddInt64Property(std::string(kPropertySingleInteger), 456) .AddInt64Property(std::string(kPropertyRepeatedInteger), 4, 5, 6) + .AddVectorProperty(std::string(kPropertySingleVector), + vector2) + .AddVectorProperty(std::string(kPropertyRepeatedVector), + vector2, vector3) .Build()) .Build(); @@ -189,6 +260,17 @@ TEST(PropertyUtilTest, ExtractPropertyValuesFromDocumentNested) { property_util::ExtractPropertyValuesFromDocument<int64_t>( nested_document, /*property_path=*/"nestedDocument.repeatedInteger"), IsOkAndHolds(ElementsAre(1, 2, 3, 4, 5, 6))); + EXPECT_THAT( + property_util::ExtractPropertyValuesFromDocument< + PropertyProto::VectorProto>( + nested_document, /*property_path=*/"nestedDocument.singleVector"), + IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2)))); + 
EXPECT_THAT( + property_util::ExtractPropertyValuesFromDocument< + PropertyProto::VectorProto>( + nested_document, /*property_path=*/"nestedDocument.repeatedVector"), + IsOkAndHolds(ElementsAre(EqualsProto(vector1), EqualsProto(vector2), + EqualsProto(vector2), EqualsProto(vector3)))); // Test the property at first level EXPECT_THAT( diff --git a/icing/schema/schema-store.cc b/icing/schema/schema-store.cc index e17e388..6830787 100644 --- a/icing/schema/schema-store.cc +++ b/icing/schema/schema-store.cc @@ -486,7 +486,7 @@ libtextclassifier3::Status SchemaStore::RegenerateDerivedFiles( ICING_RETURN_IF_ERROR(schema_file_->Write(std::move(base_schema_ptr))); // LINT.IfChange(min_overlay_version_compatibility) - // Although the current version is 3, the schema is compatible with + // Although the current version is 4, the schema is compatible with // version 1, so min_overlay_version_compatibility should be 1. int32_t min_overlay_version_compatibility = version_util::kVersionOne; // LINT.ThenChange(//depot/google3/icing/file/version-util.h:kVersion) diff --git a/icing/schema/schema-store_test.cc b/icing/schema/schema-store_test.cc index 8cc7008..ca5cdd3 100644 --- a/icing/schema/schema-store_test.cc +++ b/icing/schema/schema-store_test.cc @@ -14,8 +14,10 @@ #include "icing/schema/schema-store.h" +#include <cstdint> #include <memory> #include <string> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" @@ -23,6 +25,7 @@ #include "gtest/gtest.h" #include "icing/absl_ports/str_cat.h" #include "icing/document-builder.h" +#include "icing/file/file-backed-proto.h" #include "icing/file/filesystem.h" #include "icing/file/mock-filesystem.h" #include "icing/file/version-util.h" @@ -32,10 +35,7 @@ #include "icing/proto/logging.pb.h" #include "icing/proto/schema.pb.h" #include "icing/proto/storage.pb.h" -#include "icing/proto/term.pb.h" #include "icing/schema-builder.h" -#include "icing/schema/schema-util.h" -#include 
"icing/schema/section-manager.h" #include "icing/schema/section.h" #include "icing/store/document-filter-data.h" #include "icing/testing/common-matchers.h" @@ -142,7 +142,7 @@ TEST_F(SchemaStoreTest, SchemaStoreMoveConstructible) { IsOkAndHolds(Eq(expected_checksum))); SectionMetadata expected_metadata(/*id_in=*/0, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT, NUMERIC_MATCH_UNKNOWN, - "prop1"); + EMBEDDING_INDEXING_UNKNOWN, "prop1"); EXPECT_THAT(move_constructed_schema_store.GetSectionMetadata("type_a"), IsOkAndHolds(Pointee(ElementsAre(expected_metadata)))); } @@ -193,7 +193,7 @@ TEST_F(SchemaStoreTest, SchemaStoreMoveAssignment) { IsOkAndHolds(Eq(expected_checksum))); SectionMetadata expected_metadata(/*id_in=*/0, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT, NUMERIC_MATCH_UNKNOWN, - "prop1"); + EMBEDDING_INDEXING_UNKNOWN, "prop1"); EXPECT_THAT(move_assigned_schema_store->GetSectionMetadata("type_a"), IsOkAndHolds(Pointee(ElementsAre(expected_metadata)))); } @@ -1527,10 +1527,10 @@ TEST_F(SchemaStoreTest, SetSchemaRegenerateDerivedFilesFailure) { .Build(); SectionMetadata expected_int_prop1_metadata( /*id_in=*/0, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, - NUMERIC_MATCH_RANGE, "intProp1"); + NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN, "intProp1"); SectionMetadata expected_string_prop1_metadata( /*id_in=*/1, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT, - NUMERIC_MATCH_UNKNOWN, "stringProp1"); + NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN, "stringProp1"); ICING_ASSERT_OK_AND_ASSIGN(SectionGroup section_group, schema_store->ExtractSections(document)); ASSERT_THAT(section_group.string_sections, SizeIs(1)); diff --git a/icing/schema/schema-util.cc b/icing/schema/schema-util.cc index 72287a8..976d1b7 100644 --- a/icing/schema/schema-util.cc +++ b/icing/schema/schema-util.cc @@ -115,6 +115,12 @@ bool IsIntegerNumericMatchTypeCompatible( return old_indexed.numeric_match_type() == new_indexed.numeric_match_type(); } +bool 
IsEmbeddingIndexingCompatible(const EmbeddingIndexingConfig& old_indexed, + const EmbeddingIndexingConfig& new_indexed) { + return old_indexed.embedding_indexing_type() == + new_indexed.embedding_indexing_type(); +} + bool IsDocumentIndexingCompatible(const DocumentIndexingConfig& old_indexed, const DocumentIndexingConfig& new_indexed) { // TODO(b/265304217): This could mark the new schema as incompatible and @@ -824,6 +830,10 @@ libtextclassifier3::Status SchemaUtil::ValidateDocumentIndexingConfig( !property_config.document_indexing_config() .indexable_nested_properties_list() .empty(); + case PropertyConfigProto::DataType::VECTOR: + return property_config.embedding_indexing_config() + .embedding_indexing_type() != + EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN; case PropertyConfigProto::DataType::UNKNOWN: case PropertyConfigProto::DataType::DOUBLE: case PropertyConfigProto::DataType::BOOLEAN: @@ -1087,7 +1097,10 @@ const SchemaUtil::SchemaDelta SchemaUtil::ComputeCompatibilityDelta( new_property_config->integer_indexing_config()) || !IsDocumentIndexingCompatible( old_property_config.document_indexing_config(), - new_property_config->document_indexing_config())) { + new_property_config->document_indexing_config()) || + !IsEmbeddingIndexingCompatible( + old_property_config.embedding_indexing_config(), + new_property_config->embedding_indexing_config())) { is_index_incompatible = true; } diff --git a/icing/schema/schema-util_test.cc b/icing/schema/schema-util_test.cc index 82683ba..e77ffab 100644 --- a/icing/schema/schema-util_test.cc +++ b/icing/schema/schema-util_test.cc @@ -2900,6 +2900,118 @@ TEST_P(SchemaUtilTest, IsEmpty()); } +TEST_P(SchemaUtilTest, ChangingIndexedVectorPropertiesMakesIndexIncompatible) { + SchemaProto schema_with_indexed_property = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty(PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + 
.SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + SchemaProto schema_with_unindexed_property = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty( + PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeVector(EMBEDDING_INDEXING_UNKNOWN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + SchemaUtil::SchemaDelta schema_delta; + schema_delta.schema_types_index_incompatible.insert(kPersonType); + + // New schema gained a new indexed vector property. + SchemaUtil::DependentMap no_dependents_map; + EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta( + schema_with_unindexed_property, schema_with_indexed_property, + no_dependents_map), + Eq(schema_delta)); + + // New schema lost an indexed vector property. + EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta( + schema_with_indexed_property, schema_with_unindexed_property, + no_dependents_map), + Eq(schema_delta)); +} + +TEST_P(SchemaUtilTest, AddingNewIndexedVectorPropertyMakesIndexIncompatible) { + // Configure old schema + SchemaProto old_schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty(PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + // Configure new schema + SchemaProto new_schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty(PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("NewEmbeddingProperty") + .SetDataTypeVector( + EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + SchemaUtil::SchemaDelta schema_delta; + schema_delta.schema_types_index_incompatible.insert(kPersonType); + SchemaUtil::DependentMap no_dependents_map; + EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(old_schema, 
new_schema, + no_dependents_map), + Eq(schema_delta)); +} + +TEST_P(SchemaUtilTest, + AddingNewNonIndexedVectorPropertyShouldRemainIndexCompatible) { + // Configure old schema + SchemaProto old_schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty(PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + // Configure new schema + SchemaProto new_schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType(kPersonType) + .AddProperty(PropertyConfigBuilder() + .SetName("Property") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("NewProperty") + .SetDataTypeVector(EMBEDDING_INDEXING_UNKNOWN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + SchemaUtil::DependentMap no_dependents_map; + EXPECT_THAT(SchemaUtil::ComputeCompatibilityDelta(old_schema, new_schema, + no_dependents_map) + .schema_types_index_incompatible, + IsEmpty()); +} + TEST_P(SchemaUtilTest, AddingNewIndexedDocumentPropertyMakesIndexAndJoinIncompatible) { SchemaTypeConfigProto nested_schema = diff --git a/icing/schema/section-manager.cc b/icing/schema/section-manager.cc index 3d540d6..8689bf2 100644 --- a/icing/schema/section-manager.cc +++ b/icing/schema/section-manager.cc @@ -61,6 +61,7 @@ libtextclassifier3::Status AppendNewSectionMetadata( property_config.string_indexing_config().tokenizer_type(), property_config.string_indexing_config().term_match_type(), property_config.integer_indexing_config().numeric_match_type(), + property_config.embedding_indexing_config().embedding_indexing_type(), std::move(concatenated_path))); return libtextclassifier3::Status::OK; } @@ -162,6 +163,19 @@ libtextclassifier3::StatusOr<SectionGroup> SectionManager::ExtractSections( section_group.integer_sections); break; } + case PropertyConfigProto::DataType::VECTOR: { + if 
(section_metadata.embedding_indexing_type == + EmbeddingIndexingConfig::EmbeddingIndexingType::UNKNOWN) { + // Skip if embedding indexing type is UNKNOWN. + break; + } + AppendSection( + section_metadata, + property_util::ExtractPropertyValuesFromDocument< + PropertyProto::VectorProto>(document, section_metadata.path), + section_group.vector_sections); + break; + } default: { // Skip other data types. break; diff --git a/icing/schema/section-manager_test.cc b/icing/schema/section-manager_test.cc index eee78e9..b735fb1 100644 --- a/icing/schema/section-manager_test.cc +++ b/icing/schema/section-manager_test.cc @@ -14,10 +14,12 @@ #include "icing/schema/section-manager.h" +#include <cstdint> #include <memory> #include <string> #include <string_view> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" @@ -27,6 +29,8 @@ #include "icing/schema-builder.h" #include "icing/schema/schema-type-manager.h" #include "icing/schema/schema-util.h" +#include "icing/schema/section.h" +#include "icing/store/document-filter-data.h" #include "icing/store/dynamic-trie-key-mapper.h" #include "icing/store/key-mapper.h" #include "icing/testing/common-matchers.h" @@ -37,6 +41,7 @@ namespace lib { namespace { +using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::ElementsAre; using ::testing::IsEmpty; using ::testing::Pointee; @@ -48,6 +53,8 @@ static constexpr std::string_view kTypeEmail = "Email"; static constexpr std::string_view kPropertyRecipientIds = "recipientIds"; static constexpr std::string_view kPropertyRecipients = "recipients"; static constexpr std::string_view kPropertySubject = "subject"; +static constexpr std::string_view kPropertySubjectEmbedding = + "subjectEmbedding"; static constexpr std::string_view kPropertyTimestamp = "timestamp"; // non-indexable static constexpr std::string_view kPropertyAttachment = "attachment"; @@ -109,6 +116,14 @@ PropertyConfigProto 
CreateSubjectPropertyConfig() { .Build(); } +PropertyConfigProto CreateSubjectEmbeddingPropertyConfig() { + return PropertyConfigBuilder() + .SetName(kPropertySubjectEmbedding) + .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL) + .Build(); +} + PropertyConfigProto CreateTimestampPropertyConfig() { return PropertyConfigBuilder() .SetName(kPropertyTimestamp) @@ -145,6 +160,7 @@ SchemaTypeConfigProto CreateEmailTypeConfig() { return SchemaTypeConfigBuilder() .SetType(kTypeEmail) .AddProperty(CreateSubjectPropertyConfig()) + .AddProperty(CreateSubjectEmbeddingPropertyConfig()) .AddProperty(PropertyConfigBuilder() .SetName(kPropertyText) .SetDataTypeString(TERM_MATCH_UNKNOWN, TOKENIZER_NONE) @@ -220,11 +236,19 @@ class SectionManagerTest : public ::testing::Test { ICING_ASSERT_OK(schema_type_mapper_->Put(kTypeConversation, 1)); ICING_ASSERT_OK(schema_type_mapper_->Put(kTypeGroup, 2)); + email_subject_embedding_ = PropertyProto::VectorProto(); + email_subject_embedding_.add_values(1.0); + email_subject_embedding_.add_values(2.0); + email_subject_embedding_.add_values(3.0); + email_subject_embedding_.set_model_signature("my_model"); + email_document_ = DocumentBuilder() .SetKey("icing", "email/1") .SetSchema(std::string(kTypeEmail)) .AddStringProperty(std::string(kPropertySubject), "the subject") + .AddVectorProperty(std::string(kPropertySubjectEmbedding), + email_subject_embedding_) .AddStringProperty(std::string(kPropertyText), "the text") .AddBytesProperty(std::string(kPropertyAttachment), "attachment bytes") @@ -265,6 +289,7 @@ class SectionManagerTest : public ::testing::Test { SchemaUtil::TypeConfigMap type_config_map_; std::unique_ptr<KeyMapper<SchemaTypeId>> schema_type_mapper_; + PropertyProto::VectorProto email_subject_embedding_; DocumentProto email_document_; DocumentProto conversation_document_; DocumentProto group_document_; @@ -308,11 +333,21 @@ TEST_F(SectionManagerTest, ExtractSections) { 
EXPECT_THAT(section_group.integer_sections[0].content, ElementsAre(1, 2, 3)); EXPECT_THAT(section_group.integer_sections[1].metadata, - EqualsSectionMetadata(/*expected_id=*/3, + EqualsSectionMetadata(/*expected_id=*/4, /*expected_property_path=*/"timestamp", CreateTimestampPropertyConfig())); EXPECT_THAT(section_group.integer_sections[1].content, ElementsAre(kDefaultTimestamp)); + + // Vector sections + EXPECT_THAT(section_group.vector_sections, SizeIs(1)); + EXPECT_THAT( + section_group.vector_sections[0].metadata, + EqualsSectionMetadata(/*expected_id=*/3, + /*expected_property_path=*/"subjectEmbedding", + CreateSubjectEmbeddingPropertyConfig())); + EXPECT_THAT(section_group.vector_sections[0].content, + ElementsAre(EqualsProto(email_subject_embedding_))); } TEST_F(SectionManagerTest, ExtractSectionsNested) { @@ -359,11 +394,22 @@ TEST_F(SectionManagerTest, ExtractSectionsNested) { EXPECT_THAT( section_group.integer_sections[1].metadata, - EqualsSectionMetadata(/*expected_id=*/3, + EqualsSectionMetadata(/*expected_id=*/4, /*expected_property_path=*/"emails.timestamp", CreateTimestampPropertyConfig())); EXPECT_THAT(section_group.integer_sections[1].content, ElementsAre(kDefaultTimestamp, kDefaultTimestamp)); + + // Vector sections + EXPECT_THAT(section_group.vector_sections, SizeIs(1)); + EXPECT_THAT(section_group.vector_sections[0].metadata, + EqualsSectionMetadata( + /*expected_id=*/3, + /*expected_property_path=*/"emails.subjectEmbedding", + CreateSubjectEmbeddingPropertyConfig())); + EXPECT_THAT(section_group.vector_sections[0].content, + ElementsAre(EqualsProto(email_subject_embedding_), + EqualsProto(email_subject_embedding_))); } TEST_F(SectionManagerTest, ExtractSectionsIndexableNestedPropertiesList) { @@ -461,7 +507,8 @@ TEST_F(SectionManagerTest, GetSectionMetadata) { // 0 -> recipientIds // 1 -> recipients // 2 -> subject - // 3 -> timestamp + // 3 -> subjectEmbedding + // 4 -> timestamp 
EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( /*schema_type_id=*/0, /*section_id=*/0), IsOkAndHolds(Pointee(EqualsSectionMetadata( @@ -472,13 +519,30 @@ TEST_F(SectionManagerTest, GetSectionMetadata) { IsOkAndHolds(Pointee(EqualsSectionMetadata( /*expected_id=*/1, /*expected_property_path=*/"recipients", CreateRecipientsPropertyConfig())))); + EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( + /*schema_type_id=*/0, /*section_id=*/2), + IsOkAndHolds(Pointee(EqualsSectionMetadata( + /*expected_id=*/2, /*expected_property_path=*/"subject", + CreateSubjectPropertyConfig())))); + EXPECT_THAT( + schema_type_manager->section_manager().GetSectionMetadata( + /*schema_type_id=*/0, /*section_id=*/3), + IsOkAndHolds(Pointee(EqualsSectionMetadata( + /*expected_id=*/3, /*expected_property_path=*/"subjectEmbedding", + CreateSubjectEmbeddingPropertyConfig())))); + EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( + /*schema_type_id=*/0, /*section_id=*/4), + IsOkAndHolds(Pointee(EqualsSectionMetadata( + /*expected_id=*/4, /*expected_property_path=*/"timestamp", + CreateTimestampPropertyConfig())))); // Conversation (section id -> section property path): // 0 -> emails.recipientIds // 1 -> emails.recipients // 2 -> emails.subject - // 3 -> emails.timestamp - // 4 -> name + // 3 -> emails.subjectEmbedding + // 4 -> emails.timestamp + // 5 -> name EXPECT_THAT( schema_type_manager->section_manager().GetSectionMetadata( /*schema_type_id=*/1, /*section_id=*/0), @@ -497,16 +561,22 @@ TEST_F(SectionManagerTest, GetSectionMetadata) { IsOkAndHolds(Pointee(EqualsSectionMetadata( /*expected_id=*/2, /*expected_property_path=*/"emails.subject", CreateSubjectPropertyConfig())))); + EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( + /*schema_type_id=*/1, /*section_id=*/3), + IsOkAndHolds(Pointee(EqualsSectionMetadata( + /*expected_id=*/3, + /*expected_property_path=*/"emails.subjectEmbedding", + 
CreateSubjectEmbeddingPropertyConfig())))); EXPECT_THAT( schema_type_manager->section_manager().GetSectionMetadata( - /*schema_type_id=*/1, /*section_id=*/3), + /*schema_type_id=*/1, /*section_id=*/4), IsOkAndHolds(Pointee(EqualsSectionMetadata( - /*expected_id=*/3, /*expected_property_path=*/"emails.timestamp", + /*expected_id=*/4, /*expected_property_path=*/"emails.timestamp", CreateTimestampPropertyConfig())))); EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( - /*schema_type_id=*/1, /*section_id=*/4), + /*schema_type_id=*/1, /*section_id=*/5), IsOkAndHolds(Pointee(EqualsSectionMetadata( - /*expected_id=*/4, /*expected_property_path=*/"name", + /*expected_id=*/5, /*expected_property_path=*/"name", CreateNamePropertyConfig())))); // Group (section id -> section property path): @@ -615,25 +685,27 @@ TEST_F(SectionManagerTest, GetSectionMetadataInvalidSectionId) { // 0 -> recipientIds // 1 -> recipients // 2 -> subject - // 3 -> timestamp + // 3 -> subjectEmbedding + // 4 -> timestamp EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( /*schema_type_id=*/0, /*section_id=*/-1), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( - /*schema_type_id=*/0, /*section_id=*/4), + /*schema_type_id=*/0, /*section_id=*/5), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); // Conversation (section id -> section property path): // 0 -> emails.recipientIds // 1 -> emails.recipients // 2 -> emails.subject - // 3 -> emails.timestamp - // 4 -> name + // 3 -> emails.subjectEmbedding + // 4 -> emails.timestamp + // 5 -> name EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( /*schema_type_id=*/1, /*section_id=*/-1), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(schema_type_manager->section_manager().GetSectionMetadata( - /*schema_type_id=*/1, /*section_id=*/5), + /*schema_type_id=*/1, /*section_id=*/6), 
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } diff --git a/icing/schema/section.h b/icing/schema/section.h index 3685a29..76e9e57 100644 --- a/icing/schema/section.h +++ b/icing/schema/section.h @@ -21,6 +21,7 @@ #include <utility> #include <vector> +#include "icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" #include "icing/proto/term.pb.h" @@ -89,18 +90,31 @@ struct SectionMetadata { // Contents will be matched by a range query. IntegerIndexingConfig::NumericMatchType::Code numeric_match_type; + // How vectors in a vector section should be indexed. + // + // EmbeddingIndexingType::UNKNOWN: + // Contents will not be indexed. It is invalid for a vector section + // (data_type == 'VECTOR') to have embedding_indexing_type == 'UNKNOWN'. + // + // EmbeddingIndexingType::LINEAR_SEARCH: + // Contents will be indexed for linear search. + EmbeddingIndexingConfig::EmbeddingIndexingType::Code embedding_indexing_type; + explicit SectionMetadata( SectionId id_in, PropertyConfigProto::DataType::Code data_type_in, StringIndexingConfig::TokenizerType::Code tokenizer, TermMatchType::Code term_match_type_in, IntegerIndexingConfig::NumericMatchType::Code numeric_match_type_in, + EmbeddingIndexingConfig::EmbeddingIndexingType::Code + embedding_indexing_type_in, std::string&& path_in) : path(std::move(path_in)), id(id_in), data_type(data_type_in), tokenizer(tokenizer), term_match_type(term_match_type_in), - numeric_match_type(numeric_match_type_in) {} + numeric_match_type(numeric_match_type_in), + embedding_indexing_type(embedding_indexing_type_in) {} SectionMetadata(const SectionMetadata& other) = default; SectionMetadata& operator=(const SectionMetadata& other) = default; @@ -144,6 +158,7 @@ struct Section { struct SectionGroup { std::vector<Section<std::string_view>> string_sections; std::vector<Section<int64_t>> integer_sections; + std::vector<Section<PropertyProto::VectorProto>> vector_sections; }; } // namespace lib diff --git 
a/icing/scoring/advanced_scoring/advanced-scorer.cc b/icing/scoring/advanced_scoring/advanced-scorer.cc index 83c1519..e375a8e 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer.cc +++ b/icing/scoring/advanced_scoring/advanced-scorer.cc @@ -14,14 +14,25 @@ #include "icing/scoring/advanced_scoring/advanced-scorer.h" +#include <cstdint> #include <memory> +#include <utility> +#include <vector> +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/join/join-children-fetcher.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" #include "icing/query/advanced_query_parser/lexer.h" #include "icing/query/advanced_query_parser/parser.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/score-expression.h" #include "icing/scoring/advanced_scoring/scoring-visitor.h" #include "icing/scoring/bm25f-calculator.h" #include "icing/scoring/section-weights.h" +#include "icing/store/document-store.h" +#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -29,11 +40,15 @@ namespace lib { libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> AdvancedScorer::Create(const ScoringSpecProto& scoring_spec, double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher) { + const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results) { ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(embedding_query_results); Lexer lexer(scoring_spec.advanced_scoring_expression(), Lexer::Language::SCORING); @@ -48,9 +63,10 @@ AdvancedScorer::Create(const ScoringSpecProto& scoring_spec, 
std::unique_ptr<Bm25fCalculator> bm25f_calculator = std::make_unique<Bm25fCalculator>(document_store, section_weights.get(), current_time_ms); - ScoringVisitor visitor(default_score, document_store, schema_store, - section_weights.get(), bm25f_calculator.get(), - join_children_fetcher, current_time_ms); + ScoringVisitor visitor(default_score, default_semantic_metric_type, + document_store, schema_store, section_weights.get(), + bm25f_calculator.get(), join_children_fetcher, + embedding_query_results, current_time_ms); tree_root->Accept(&visitor); ICING_ASSIGN_OR_RETURN(std::unique_ptr<ScoreExpression> expression, diff --git a/icing/scoring/advanced_scoring/advanced-scorer.h b/icing/scoring/advanced_scoring/advanced-scorer.h index d69abad..00477a3 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer.h +++ b/icing/scoring/advanced_scoring/advanced-scorer.h @@ -15,17 +15,26 @@ #ifndef ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ #define ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ +#include <cstdint> #include <memory> +#include <string> +#include <unordered_map> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/join/join-children-fetcher.h" #include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/score-expression.h" #include "icing/scoring/bm25f-calculator.h" #include "icing/scoring/scorer.h" +#include "icing/scoring/section-weights.h" #include "icing/store/document-store.h" +#include "icing/util/logging.h" namespace icing { namespace lib { @@ -38,9 +47,11 @@ class AdvancedScorer : public Scorer { // INVALID_ARGUMENT if fails to create an instance static libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> Create( const ScoringSpecProto& 
scoring_spec, double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, - int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher = nullptr); + int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results); double GetScore(const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override { @@ -63,7 +74,7 @@ class AdvancedScorer : public Scorer { bm25f_calculator_->PrepareToScore(query_term_iterators); } - bool is_constant() const { return score_expression_->is_constant_double(); } + bool is_constant() const { return score_expression_->is_constant(); } private: explicit AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression, diff --git a/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc index 3612359..ca081cd 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc +++ b/icing/scoring/advanced_scoring/advanced-scorer_fuzz_test.cc @@ -12,11 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include <cstddef> #include <cstdint> #include <memory> +#include <string> #include <string_view> +#include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/advanced-scorer.h" +#include "icing/store/document-store.h" #include "icing/testing/fake-clock.h" #include "icing/testing/tmp-directory.h" @@ -29,6 +36,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { const std::string test_dir = GetTestTempDir() + "/icing"; const std::string doc_store_dir = test_dir + "/doc_store"; const std::string schema_store_dir = test_dir + "/schema_store"; + EmbeddingQueryResults empty_embedding_query_results_; filesystem.DeleteDirectoryRecursively(test_dir.c_str()); filesystem.CreateDirectoryRecursively(doc_store_dir.c_str()); filesystem.CreateDirectoryRecursively(schema_store_dir.c_str()); @@ -54,9 +62,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { scoring_spec.set_advanced_scoring_expression(text); AdvancedScorer::Create(scoring_spec, - /*default_score=*/10, document_store.get(), - schema_store.get(), - fake_clock.GetSystemTimeMilliseconds()); + /*default_score=*/10, + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store.get(), schema_store.get(), + fake_clock.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_); // Not able to test the GetScore method of AdvancedScorer, since it will only // be available after AdvancedScorer is successfully created. 
However, the diff --git a/icing/scoring/advanced_scoring/advanced-scorer_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_test.cc index cc1d413..4b9a46c 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer_test.cc +++ b/icing/scoring/advanced_scoring/advanced-scorer_test.cc @@ -15,14 +15,22 @@ #include "icing/scoring/advanced_scoring/advanced-scorer.h" #include <cmath> +#include <cstdint> #include <memory> #include <string> #include <string_view> +#include <unordered_map> +#include <utility> +#include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/join/join-children-fetcher.h" #include "icing/proto/document.pb.h" @@ -31,6 +39,8 @@ #include "icing/proto/usage.pb.h" #include "icing/schema-builder.h" #include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/scoring/scored-document-hit.h" #include "icing/scoring/scorer-factory.h" #include "icing/scoring/scorer.h" #include "icing/store/document-id.h" @@ -45,6 +55,7 @@ namespace lib { namespace { using ::testing::DoubleNear; using ::testing::Eq; +using ::testing::HasSubstr; class AdvancedScorerTest : public testing::Test { protected: @@ -124,15 +135,19 @@ class AdvancedScorerTest : public testing::Test { const std::string test_dir_; const std::string doc_store_dir_; const std::string schema_store_dir_; + EmbeddingQueryResults empty_embedding_query_results_; Filesystem filesystem_; std::unique_ptr<SchemaStore> schema_store_; std::unique_ptr<DocumentStore> document_store_; FakeClock fake_clock_; }; -constexpr double kEps = 0.0000000001; +constexpr double kEps = 0.0000001; constexpr int kDefaultScore = 0; 
constexpr int64_t kDefaultCreationTimestampMs = 1571100001111; +constexpr SearchSpecProto::EmbeddingQueryMetricType::Code + kDefaultSemanticMetricType = + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; DocumentProto CreateDocument( const std::string& name_space, const std::string& uri, @@ -193,8 +208,11 @@ TEST_F(AdvancedScorerTest, InvalidAdvancedScoringSpec) { scoring_spec.set_rank_by( ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION); EXPECT_THAT(scorer_factory::Create(scoring_spec, /*default_score=*/10, + kDefaultSemanticMetricType, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); // Non-empty scoring expression for normal scoring @@ -202,8 +220,11 @@ TEST_F(AdvancedScorerTest, InvalidAdvancedScoringSpec) { scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); scoring_spec.set_advanced_scoring_expression("1"); EXPECT_THAT(scorer_factory::Create(scoring_spec, /*default_score=*/10, + kDefaultSemanticMetricType, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } @@ -215,9 +236,11 @@ TEST_F(AdvancedScorerTest, SimpleExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("123"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); DocHitInfo docHitInfo = 
DocHitInfo(document_id); @@ -233,44 +256,61 @@ TEST_F(AdvancedScorerTest, BasicPureArithmeticExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + 2"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(3)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("-1 + 2"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("-1 + 2"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + -2"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + -2"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 - 2"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 - 2"), + 
/*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 * 2"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 * 2"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(2)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 2"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 2"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0.5)); } @@ -283,103 +323,131 @@ TEST_F(AdvancedScorerTest, BasicMathFunctionExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("log(10, 1000)"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(3, kEps)); ICING_ASSERT_OK_AND_ASSIGN( 
scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("log(2.718281828459045)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(2, 10)"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(2, 10)"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1024)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("max(10, 11, 12, 13, 14)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(14)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("min(10, 11, 12, 13, 14)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); 
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("len(10, 11, 12, 13, 14)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("sum(10, 11, 12, 13, 14)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10 + 11 + 12 + 13 + 14)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("avg(10, 11, 12, 13, 14)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq((10 + 11 + 12 + 13 + 14) / 5.)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(2)"), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(2)"), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + 
/*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(sqrt(2), kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(-2) + abs(2)"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(4)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("sin(3.141592653589793)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(0, kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("cos(3.141592653589793)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(-1, kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("tan(3.141592653589793 / 4)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + 
fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps)); } @@ -394,17 +462,21 @@ TEST_F(AdvancedScorerTest, DocumentScoreCreationTimestampFunctionExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("this.documentScore()"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(123)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("this.creationTimestamp()"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(kDefaultCreationTimestampMs)); ICING_ASSERT_OK_AND_ASSIGN( @@ -412,8 +484,10 @@ TEST_F(AdvancedScorerTest, DocumentScoreCreationTimestampFunctionExpression) { AdvancedScorer::Create( CreateAdvancedScoringSpec( "this.documentScore() + this.creationTimestamp()"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(123 + kDefaultCreationTimestampMs)); } @@ -429,8 +503,10 @@ 
TEST_F(AdvancedScorerTest, DocumentUsageFunctionExpression) { AdvancedScorer::Create( CreateAdvancedScoringSpec("this.usageCount(1) + this.usageCount(2) " "+ this.usageLastUsedTimestamp(3)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0)); ICING_ASSERT_OK(document_store_->ReportUsage( CreateUsageReport("namespace", "uri", 100000, UsageReport::USAGE_TYPE1))); @@ -446,22 +522,28 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionExpression) { scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(1)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(100000)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(2)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(200000)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("this.usageLastUsedTimestamp(3)"), - /*default_score=*/10, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + 
/*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(300000)); } @@ -478,24 +560,29 @@ TEST_F(AdvancedScorerTest, DocumentUsageFunctionOutOfRange) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - AdvancedScorer::Create(CreateAdvancedScoringSpec("this.usageCount(4)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.usageCount(4)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.usageCount(0)"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.usageCount(0)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.usageCount(1.5)"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.usageCount(1.5)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, 
&empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(default_score)); } @@ -516,9 +603,11 @@ TEST_F(AdvancedScorerTest, RelevanceScoreFunctionScoreExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<AdvancedScorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("this.relevanceScore()"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); scorer->PrepareToScore(/*query_term_iterators=*/{}); // Should get the default score. @@ -566,8 +655,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) { std::unique_ptr<AdvancedScorer> scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("len(this.childrenRankingSignals())"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), &fetcher)); + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + &fetcher, &empty_embedding_query_results_)); // document_id_1 has two children. EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(2)); // document_id_2 has one child. @@ -579,8 +669,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) { scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("sum(this.childrenRankingSignals())"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), &fetcher)); + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + &fetcher, &empty_embedding_query_results_)); // document_id_1 has two children with scores 1 and 2. 
EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3)); // document_id_2 has one child with score 4. @@ -592,8 +683,9 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) { scorer, AdvancedScorer::Create( CreateAdvancedScoringSpec("avg(this.childrenRankingSignals())"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), &fetcher)); + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + &fetcher, &empty_embedding_query_results_)); // document_id_1 has two children with scores 1 and 2. EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3 / 2.)); // document_id_2 has one child with score 4. @@ -604,13 +696,15 @@ TEST_F(AdvancedScorerTest, ChildrenScoresFunctionScoreExpression) { Eq(default_score)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create( - CreateAdvancedScoringSpec( - // Equivalent to "avg(this.childrenRankingSignals())" - "sum(this.childrenRankingSignals()) / " - "len(this.childrenRankingSignals())"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), &fetcher)); + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec( + // Equivalent to "avg(this.childrenRankingSignals())" + "sum(this.childrenRankingSignals()) / " + "len(this.childrenRankingSignals())"), + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + &fetcher, &empty_embedding_query_results_)); // document_id_1 has two children with scores 1 and 2. EXPECT_THAT(scorer->GetScore(docHitInfo1, /*query_it=*/nullptr), Eq(3 / 2.)); // document_id_2 has one child with score 4. 
@@ -670,9 +764,11 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<AdvancedScorer> scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // min([1]) = 1 EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // min([0.5, 0.8]) = 0.5 @@ -682,10 +778,13 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) { spec_proto.set_advanced_scoring_expression("max(this.propertyWeights())"); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(spec_proto, + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // max([1]) = 1 EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // max([0.5, 0.8]) = 0.8 @@ -695,10 +794,13 @@ TEST_F(AdvancedScorerTest, PropertyWeightsFunctionScoreExpression) { spec_proto.set_advanced_scoring_expression("sum(this.propertyWeights())"); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(spec_proto, + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // 
sum([1]) = 1 EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // sum([0.5, 0.8]) = 1.3 @@ -747,9 +849,11 @@ TEST_F(AdvancedScorerTest, ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<AdvancedScorer> scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // min([1]) = 1 EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // min([0.5, 1, 0.5]) = 0.5 @@ -757,10 +861,13 @@ TEST_F(AdvancedScorerTest, spec_proto.set_advanced_scoring_expression("max(this.propertyWeights())"); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(spec_proto, + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // max([1]) = 1 EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // max([0.5, 1, 0.5]) = 1 @@ -768,10 +875,13 @@ TEST_F(AdvancedScorerTest, spec_proto.set_advanced_scoring_expression("sum(this.propertyWeights())"); ICING_ASSERT_OK_AND_ASSIGN( - scorer, AdvancedScorer::Create(spec_proto, - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, + AdvancedScorer::Create(spec_proto, + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); // sum([1]) = 1 
EXPECT_THAT(scorer->GetScore(doc_hit_info_1, /*query_it=*/nullptr), Eq(1)); // sum([0.5, 1, 0.5]) = 2 @@ -786,20 +896,22 @@ TEST_F(AdvancedScorerTest, InvalidChildrenScoresFunctionScoreExpression) { EXPECT_THAT( AdvancedScorer::Create( CreateAdvancedScoringSpec("len(this.childrenRankingSignals())"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), - /*join_children_fetcher=*/nullptr), + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); // The root expression can only be of double type, but here it is of list // type. JoinChildrenFetcher fake_fetcher(JoinSpecProto::default_instance(), /*map_joinable_qualified_id=*/{}); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.childrenRankingSignals()"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds(), &fake_fetcher), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.childrenRankingSignals()"), + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + &fake_fetcher, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } TEST_F(AdvancedScorerTest, ComplexExpression) { @@ -821,9 +933,11 @@ TEST_F(AdvancedScorerTest, ComplexExpression) { "+ 10 * (2 + 10 + this.creationTimestamp()))" // This should evaluate to default score. 
"+ this.relevanceScore()"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_FALSE(scorer->is_constant()); scorer->PrepareToScore(/*query_term_iterators=*/{}); @@ -847,19 +961,24 @@ TEST_F(AdvancedScorerTest, ConstantExpression) { "pow(sin(2), 2)" "+ log(2, 122) / 12.34" "* (10 * pow(2 * 1, sin(2)) + 10 * (2 + 10))"), - /*default_score=*/10, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_TRUE(scorer->is_constant()); } // Should be a parsing Error TEST_F(AdvancedScorerTest, EmptyExpression) { - EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec(""), - /*default_score=*/10, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create(CreateAdvancedScoringSpec(""), + /*default_score=*/10, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) { @@ -872,30 +991,38 @@ TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - AdvancedScorer::Create(CreateAdvancedScoringSpec("log(0)"), default_score, - document_store_.get(), schema_store_.get(), - 
fake_clock_.GetSystemTimeMilliseconds())); + AdvancedScorer::Create( + CreateAdvancedScoringSpec("log(0)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, - AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 0"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 0"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(-1)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); ICING_ASSERT_OK_AND_ASSIGN( scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(-1, 0.5)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds())); + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_)); EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); } @@ -904,133 +1031,402 @@ TEST_F(AdvancedScorerTest, 
EvaluationErrorShouldReturnDefaultScore) { TEST_F(AdvancedScorerTest, MathTypeError) { const double default_score = 0; - EXPECT_THAT( - AdvancedScorer::Create(CreateAdvancedScoringSpec("test"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("test"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT( - AdvancedScorer::Create(CreateAdvancedScoringSpec("log()"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log()"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, 2, 3)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, this)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), 
schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT( - AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(1)"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(1)"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(1, 2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(1, 2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sin(1, 2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + 
fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("cos(1, 2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("tan(1, 2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT( - AdvancedScorer::Create(CreateAdvancedScoringSpec("this"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT( - AdvancedScorer::Create(CreateAdvancedScoringSpec("-this"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("-this"), + default_score, 
kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + this"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } TEST_F(AdvancedScorerTest, DocumentFunctionTypeError) { const double default_score = 0; - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("documentScore(1)"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.creationTimestamp(1)"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("documentScore(1)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.creationTimestamp(1)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + 
AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.usageCount()"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("usageLastUsedTimestamp(1, 1)"), + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("relevanceScore(1)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("documentScore(this)"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("that.documentScore()"), default_score, + kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.this.creationTimestamp()"), + default_score, kDefaultSemanticMetricType, document_store_.get(), + schema_store_.get(), 
fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this.log(2)"), + default_score, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results_), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.usageCount()"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), +} + +TEST_F(AdvancedScorerTest, + MatchedSemanticScoresFunctionScoreExpressionTypeError) { + libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> scorer_or = + AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(matchedSemanticScores(getSearchSpecEmbedding(0)))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("usageLastUsedTimestamp(1, 1)"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("not called with \"this\"")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(0))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, 
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("relevanceScore(1)"), default_score, - document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("got invalid argument type for embedding vector")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0), 0))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("documentScore(this)"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("Embedding metric can only be given as a string")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(" + "getSearchSpecEmbedding(0), \"COSINE\", 0))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("that.documentScore()"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("got invalid number of arguments")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(" + "getSearchSpecEmbedding(0), 
\"COSIGN\"))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create( - CreateAdvancedScoringSpec("this.this.creationTimestamp()"), - default_score, document_store_.get(), schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("Unknown metric type: COSIGN")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(\"0\")))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this.log(2)"), - default_score, document_store_.get(), - schema_store_.get(), - fake_clock_.GetSystemTimeMilliseconds()), + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("getSearchSpecEmbedding got invalid argument type")); + + scorer_or = AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding()))"), + kDefaultSemanticMetricType, kDefaultSemanticMetricType, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results_); + EXPECT_THAT(scorer_or, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(scorer_or.status().error_message(), + HasSubstr("getSearchSpecEmbedding must have 1 argument")); +} + +void AddEntryToEmbeddingQueryScoreMap( + 
EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map, + double semantic_score, DocumentId document_id) { + score_map[document_id].push_back(semantic_score); +} + +TEST_F(AdvancedScorerTest, MatchedSemanticScoresFunctionScoreExpression) { + DocumentId document_id_0 = 0; + DocumentId document_id_1 = 1; + DocHitInfo doc_hit_info_0(document_id_0); + DocHitInfo doc_hit_info_1(document_id_1); + EmbeddingQueryResults embedding_query_results; + + // Let the first query assign the following semantic scores: + // COSINE: + // Document 0: 0.1, 0.2 + // Document 1: 0.3, 0.4 + // DOT_PRODUCT: + // Document 0: 0.5 + // Document 1: 0.6 + // EUCLIDEAN: + // Document 0: 0.7 + // Document 1: 0.8 + EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map = + &embedding_query_results + .result_scores[0][SearchSpecProto::EmbeddingQueryMetricType::COSINE]; + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.1, document_id_0); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.2, document_id_0); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.3, document_id_1); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.4, document_id_1); + score_map = &embedding_query_results.result_scores + [0][SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT]; + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.5, document_id_0); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.6, document_id_1); + score_map = + &embedding_query_results + .result_scores[0] + [SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN]; + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.7, document_id_0); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.8, document_id_1); + + // Let the second query only assign DOT_PRODUCT scores: + // DOT_PRODUCT: + // Document 0: 0.1 + // Document 1: 0.2 + score_map = &embedding_query_results.result_scores + 
[1][SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT]; + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.1, document_id_0); + AddEntryToEmbeddingQueryScoreMap(*score_map, + /*semantic_score=*/0.2, document_id_1); + + // Get semantic scores for default metric (DOT_PRODUCT) for the first query. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Scorer> scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))"), + kDefaultScore, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &embedding_query_results)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.5, kEps)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.6, kEps)); + + // Get semantic scores for a metric overriding the default one for the first + // query. + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(" + "getSearchSpecEmbedding(0), \"COSINE\"))"), + kDefaultScore, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &embedding_query_results)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.1 + 0.2, kEps)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.3 + 0.4, kEps)); + + // Get semantic scores for multiple metrics for the first query. 
+ ICING_ASSERT_OK_AND_ASSIGN( + scorer, AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)" + ", \"COSINE\")) + " + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)" + ", \"DOT_PRODUCT\")) + " + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)" + ", \"EUCLIDEAN\"))"), + kDefaultScore, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &embedding_query_results)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_0), + DoubleNear(0.1 + 0.2 + 0.5 + 0.7, kEps)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_1), + DoubleNear(0.3 + 0.4 + 0.6 + 0.8, kEps)); + + // Get semantic scores for the second query. + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))"), + kDefaultScore, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &embedding_query_results)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0.1, kEps)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0.2, kEps)); + + // The second query does not contain cosine scores. 
+ ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("sum(this.matchedSemanticScores(" + "getSearchSpecEmbedding(1), \"COSINE\"))"), + kDefaultScore, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store_.get(), schema_store_.get(), + fake_clock_.GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &embedding_query_results)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_0), DoubleNear(0, kEps)); + EXPECT_THAT(scorer->GetScore(doc_hit_info_1), DoubleNear(0, kEps)); } } // namespace diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc index e8a2a89..687180a 100644 --- a/icing/scoring/advanced_scoring/score-expression.cc +++ b/icing/scoring/advanced_scoring/score-expression.cc @@ -14,10 +14,39 @@ #include "icing/scoring/advanced_scoring/score-expression.h" +#include <algorithm> +#include <cmath> +#include <cstdint> +#include <cstdlib> +#include <memory> #include <numeric> +#include <optional> +#include <string> +#include <string_view> +#include <unordered_map> +#include <unordered_set> +#include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/join/join-children-fetcher.h" +#include "icing/schema/section.h" +#include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/scoring/section-weights.h" +#include "icing/store/document-associated-score-data.h" +#include "icing/store/document-filter-data.h" +#include "icing/store/document-id.h" +#include "icing/store/document-store.h" +#include 
"icing/util/embedding-util.h" +#include "icing/util/logging.h" +#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -49,7 +78,7 @@ OperatorScoreExpression::Create( return absl_ports::InvalidArgumentError( "Operators are only supported for double type."); } - if (!child->is_constant_double()) { + if (!child->is_constant()) { children_all_constant_double = false; } } @@ -149,7 +178,7 @@ MathFunctionScoreExpression::Create( "Got an invalid type for the math function. Should expect a double " "type argument."); } - if (!child->is_constant_double()) { + if (!child->is_constant()) { args_all_constant_double = false; } } @@ -517,5 +546,100 @@ SchemaTypeId PropertyWeightsFunctionScoreExpression::GetSchemaTypeId( return filter_data_optional.value().schema_type_id(); } +libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> +GetSearchSpecEmbeddingFunctionScoreExpression::Create( + std::vector<std::unique_ptr<ScoreExpression>> args) { + if (args.size() != 1) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat(kFunctionName, " must have 1 argument.")); + } + if (args[0]->type() != ScoreExpressionType::kDouble) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat(kFunctionName, " got invalid argument type.")); + } + bool is_constant = args[0]->is_constant(); + std::unique_ptr<ScoreExpression> expression = + std::unique_ptr<GetSearchSpecEmbeddingFunctionScoreExpression>( + new GetSearchSpecEmbeddingFunctionScoreExpression( + std::move(args[0]))); + if (is_constant) { + return ConstantScoreExpression::Create( + expression->eval(DocHitInfo(), /*query_it=*/nullptr), + expression->type()); + } + return expression; +} + +libtextclassifier3::StatusOr<double> +GetSearchSpecEmbeddingFunctionScoreExpression::eval( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const { + ICING_ASSIGN_OR_RETURN(double raw_query_index, + arg_->eval(hit_info, query_it)); + uint32_t query_index = (uint32_t)raw_query_index; + if 
(query_index != raw_query_index) { + return absl_ports::InvalidArgumentError( + "The index of an embedding query must be an integer."); + } + return query_index; +} + +libtextclassifier3::StatusOr< + std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>> +MatchedSemanticScoresFunctionScoreExpression::Create( + std::vector<std::unique_ptr<ScoreExpression>> args, + SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type, + const EmbeddingQueryResults* embedding_query_results) { + ICING_RETURN_ERROR_IF_NULL(embedding_query_results); + ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args)); + + if (args.empty() || args[0]->type() != ScoreExpressionType::kDocument) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat(kFunctionName, " is not called with \"this\"")); + } + if (args.size() != 2 && args.size() != 3) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat(kFunctionName, " got invalid number of arguments.")); + } + if (args[1]->type() != ScoreExpressionType::kVectorIndex) { + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + kFunctionName, " got invalid argument type for embedding vector.")); + } + if (args.size() == 3 && args[2]->type() != ScoreExpressionType::kString) { + return absl_ports::InvalidArgumentError( + "Embedding metric can only be given as a string."); + } + + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type = + default_metric_type; + if (args.size() == 3) { + if (!args[2]->is_constant()) { + return absl_ports::InvalidArgumentError( + "Embedding metric can only be given as a constant string."); + } + ICING_ASSIGN_OR_RETURN(std::string_view metric, args[2]->eval_string()); + ICING_ASSIGN_OR_RETURN( + metric_type, + embedding_util::GetEmbeddingQueryMetricTypeFromName(metric)); + } + return std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>( + new MatchedSemanticScoresFunctionScoreExpression( + std::move(args), metric_type, *embedding_query_results)); +} + 
+libtextclassifier3::StatusOr<std::vector<double>> +MatchedSemanticScoresFunctionScoreExpression::eval_list( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const { + ICING_ASSIGN_OR_RETURN(double raw_query_index, + args_[1]->eval(hit_info, query_it)); + uint32_t query_index = (uint32_t)raw_query_index; + const std::vector<double>* scores = + embedding_query_results_.GetMatchedScoresForDocument( + query_index, metric_type_, hit_info.document_id()); + if (scores == nullptr) { + return std::vector<double>(); + } + return *scores; +} + } // namespace lib } // namespace icing diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h index 08d7997..e28fcd7 100644 --- a/icing/scoring/advanced_scoring/score-expression.h +++ b/icing/scoring/advanced_scoring/score-expression.h @@ -15,20 +15,26 @@ #ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ #define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ -#include <algorithm> -#include <cmath> +#include <cstdint> #include <memory> +#include <string> +#include <string_view> #include <unordered_map> #include <unordered_set> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/join/join-children-fetcher.h" #include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/section-weights.h" +#include "icing/store/document-filter-data.h" +#include "icing/store/document-id.h" #include "icing/store/document-store.h" -#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -36,7 +42,11 @@ namespace lib { enum class ScoreExpressionType { kDouble, kDoubleList, - kDocument // Only "this" is considered as document type. 
+ kDocument, // Only "this" is considered as document type. + // TODO(b/326656531): Instead of creating a vector index type, consider + // changing it to vector type so that the data is the vector directly. + kVectorIndex, + kString, }; class ScoreExpression { @@ -75,12 +85,24 @@ class ScoreExpression { "checking."); } + virtual libtextclassifier3::StatusOr<std::string_view> eval_string() const { + if (type() == ScoreExpressionType::kString) { + return absl_ports::UnimplementedError( + "All ScoreExpressions of type string must provide their own " + "implementation of eval_string!"); + } + return absl_ports::InternalError( + "Runtime type error: the expression should never be evaluated to a " + "string. There must be inconsistencies in the static type checking."); + } + // Indicate the type to which the current expression will be evaluated. virtual ScoreExpressionType type() const = 0; - // Indicate whether the current expression is a constant double. - // Returns true if and only if the object is of ConstantScoreExpression type. - virtual bool is_constant_double() const { return false; } + // Indicate whether the current expression is a constant. + // Returns true if and only if the object is of ConstantScoreExpression or + // StringExpression type. 
+ virtual bool is_constant() const { return false; } }; class ThisExpression : public ScoreExpression { @@ -100,9 +122,10 @@ class ThisExpression : public ScoreExpression { class ConstantScoreExpression : public ScoreExpression { public: static std::unique_ptr<ConstantScoreExpression> Create( - libtextclassifier3::StatusOr<double> c) { + libtextclassifier3::StatusOr<double> c, + ScoreExpressionType type = ScoreExpressionType::kDouble) { return std::unique_ptr<ConstantScoreExpression>( - new ConstantScoreExpression(c)); + new ConstantScoreExpression(c, type)); } libtextclassifier3::StatusOr<double> eval( @@ -110,17 +133,39 @@ class ConstantScoreExpression : public ScoreExpression { return c_; } - ScoreExpressionType type() const override { - return ScoreExpressionType::kDouble; - } + ScoreExpressionType type() const override { return type_; } - bool is_constant_double() const override { return true; } + bool is_constant() const override { return true; } private: - explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c) - : c_(c) {} + explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c, + ScoreExpressionType type) + : c_(c), type_(type) {} libtextclassifier3::StatusOr<double> c_; + ScoreExpressionType type_; +}; + +class StringExpression : public ScoreExpression { + public: + static std::unique_ptr<StringExpression> Create(std::string str) { + return std::unique_ptr<StringExpression>( + new StringExpression(std::move(str))); + } + + libtextclassifier3::StatusOr<std::string_view> eval_string() const override { + return str_; + } + + ScoreExpressionType type() const override { + return ScoreExpressionType::kString; + } + + bool is_constant() const override { return true; } + + private: + explicit StringExpression(std::string str) : str_(std::move(str)) {} + std::string str_; }; class OperatorScoreExpression : public ScoreExpression { @@ -342,6 +387,70 @@ class PropertyWeightsFunctionScoreExpression : public ScoreExpression { int64_t 
current_time_ms_; }; +class GetSearchSpecEmbeddingFunctionScoreExpression : public ScoreExpression { + public: + static constexpr std::string_view kFunctionName = "getSearchSpecEmbedding"; + + // RETURNS: + // - A GetSearchSpecEmbeddingFunctionScoreExpression instance on success if + // not simplifiable. + // - A ConstantScoreExpression instance on success if simplifiable. + // - FAILED_PRECONDITION on any null pointer in children. + // - INVALID_ARGUMENT on type errors. + static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( + std::vector<std::unique_ptr<ScoreExpression>> args); + + libtextclassifier3::StatusOr<double> eval( + const DocHitInfo& hit_info, + const DocHitInfoIterator* query_it) const override; + + ScoreExpressionType type() const override { + return ScoreExpressionType::kVectorIndex; + } + + private: + explicit GetSearchSpecEmbeddingFunctionScoreExpression( + std::unique_ptr<ScoreExpression> arg) + : arg_(std::move(arg)) {} + std::unique_ptr<ScoreExpression> arg_; +}; + +class MatchedSemanticScoresFunctionScoreExpression : public ScoreExpression { + public: + static constexpr std::string_view kFunctionName = "matchedSemanticScores"; + + // RETURNS: + // - A MatchedSemanticScoresFunctionScoreExpression instance on success. + // - FAILED_PRECONDITION on any null pointer in children. + // - INVALID_ARGUMENT on type errors. 
+ static libtextclassifier3::StatusOr< + std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>> + Create(std::vector<std::unique_ptr<ScoreExpression>> args, + SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type, + const EmbeddingQueryResults* embedding_query_results); + + libtextclassifier3::StatusOr<std::vector<double>> eval_list( + const DocHitInfo& hit_info, + const DocHitInfoIterator* query_it) const override; + + ScoreExpressionType type() const override { + return ScoreExpressionType::kDoubleList; + } + + private: + explicit MatchedSemanticScoresFunctionScoreExpression( + std::vector<std::unique_ptr<ScoreExpression>> args, + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, + const EmbeddingQueryResults& embedding_query_results) + : args_(std::move(args)), + metric_type_(metric_type), + embedding_query_results_(embedding_query_results) {} + + std::vector<std::unique_ptr<ScoreExpression>> args_; + const SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_; + const EmbeddingQueryResults& embedding_query_results_; +}; + } // namespace lib } // namespace icing diff --git a/icing/scoring/advanced_scoring/score-expression_test.cc b/icing/scoring/advanced_scoring/score-expression_test.cc index 588090d..cd58366 100644 --- a/icing/scoring/advanced_scoring/score-expression_test.cc +++ b/icing/scoring/advanced_scoring/score-expression_test.cc @@ -14,15 +14,16 @@ #include "icing/scoring/advanced_scoring/score-expression.h" -#include <cmath> #include <memory> -#include <string> -#include <string_view> #include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/testing/common-matchers.h" namespace icing { @@ -47,7 +48,7 @@ class NonConstantScoreExpression : public 
ScoreExpression { return ScoreExpressionType::kDouble; } - bool is_constant_double() const override { return false; } + bool is_constant() const override { return false; } }; class ListScoreExpression : public ScoreExpression { @@ -87,7 +88,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) { OperatorScoreExpression::OperatorType::kPlus, MakeChildren(ConstantScoreExpression::Create(1), ConstantScoreExpression::Create(1)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2))); // 1 - 2 - 3 = -4 @@ -97,7 +98,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) { MakeChildren(ConstantScoreExpression::Create(1), ConstantScoreExpression::Create(2), ConstantScoreExpression::Create(3)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-4))); // 1 * 2 * 3 * 4 = 24 @@ -108,7 +109,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) { ConstantScoreExpression::Create(2), ConstantScoreExpression::Create(3), ConstantScoreExpression::Create(4)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(24))); // 1 / 2 / 4 = 0.125 @@ -118,7 +119,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) { MakeChildren(ConstantScoreExpression::Create(1), ConstantScoreExpression::Create(2), ConstantScoreExpression::Create(4)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(0.125))); // -(2) = -2 @@ -126,7 +127,7 @@ TEST(ScoreExpressionTest, OperatorSimplification) { expression, OperatorScoreExpression::Create( OperatorScoreExpression::OperatorType::kNegative, MakeChildren(ConstantScoreExpression::Create(2)))); - 
ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-2))); } @@ -138,7 +139,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) { MathFunctionScoreExpression::FunctionType::kPow, MakeChildren(ConstantScoreExpression::Create(2), ConstantScoreExpression::Create(2)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(4))); // abs(-2) = 2 @@ -146,7 +147,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) { expression, MathFunctionScoreExpression::Create( MathFunctionScoreExpression::FunctionType::kAbs, MakeChildren(ConstantScoreExpression::Create(-2)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2))); // log(e) = 1 @@ -154,7 +155,7 @@ TEST(ScoreExpressionTest, MathFunctionSimplification) { expression, MathFunctionScoreExpression::Create( MathFunctionScoreExpression::FunctionType::kLog, MakeChildren(ConstantScoreExpression::Create(M_E)))); - ASSERT_TRUE(expression->is_constant_double()); + ASSERT_TRUE(expression->is_constant()); EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(1))); } @@ -166,7 +167,7 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) { OperatorScoreExpression::OperatorType::kPlus, MakeChildren(ConstantScoreExpression::Create(1), NonConstantScoreExpression::Create()))); - ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); // non_constant * non_constant = non_constant ICING_ASSERT_OK_AND_ASSIGN( @@ -174,14 +175,14 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) { OperatorScoreExpression::OperatorType::kTimes, MakeChildren(NonConstantScoreExpression::Create(), NonConstantScoreExpression::Create()))); - 
ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); // -(non_constant) = non_constant ICING_ASSERT_OK_AND_ASSIGN( expression, OperatorScoreExpression::Create( OperatorScoreExpression::OperatorType::kNegative, MakeChildren(NonConstantScoreExpression::Create()))); - ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); // pow(non_constant, 2) = non_constant ICING_ASSERT_OK_AND_ASSIGN( @@ -189,21 +190,21 @@ TEST(ScoreExpressionTest, CannotSimplifyNonConstant) { MathFunctionScoreExpression::FunctionType::kPow, MakeChildren(NonConstantScoreExpression::Create(), ConstantScoreExpression::Create(2)))); - ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); // abs(non_constant) = non_constant ICING_ASSERT_OK_AND_ASSIGN( expression, MathFunctionScoreExpression::Create( MathFunctionScoreExpression::FunctionType::kAbs, MakeChildren(NonConstantScoreExpression::Create()))); - ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); // log(non_constant) = non_constant ICING_ASSERT_OK_AND_ASSIGN( expression, MathFunctionScoreExpression::Create( MathFunctionScoreExpression::FunctionType::kLog, MakeChildren(NonConstantScoreExpression::Create()))); - ASSERT_FALSE(expression->is_constant_double()); + ASSERT_FALSE(expression->is_constant()); } TEST(ScoreExpressionTest, MathFunctionsWithListTypeArgument) { diff --git a/icing/scoring/advanced_scoring/scoring-visitor.cc b/icing/scoring/advanced_scoring/scoring-visitor.cc index e2b24a2..05240c0 100644 --- a/icing/scoring/advanced_scoring/scoring-visitor.cc +++ b/icing/scoring/advanced_scoring/scoring-visitor.cc @@ -14,7 +14,17 @@ #include "icing/scoring/advanced_scoring/scoring-visitor.h" +#include <cstdlib> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" 
#include "icing/absl_ports/str_cat.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/scoring/advanced_scoring/score-expression.h" namespace icing { namespace lib { @@ -25,8 +35,7 @@ void ScoringVisitor::VisitFunctionName(const FunctionNameNode* node) { } void ScoringVisitor::VisitString(const StringNode* node) { - pending_error_ = - absl_ports::InvalidArgumentError("Scoring does not support String!"); + stack_.push_back(StringExpression::Create(node->value())); } void ScoringVisitor::VisitText(const TextNode* node) { @@ -120,6 +129,17 @@ void ScoringVisitor::VisitFunctionHelper(const FunctionNode* node, expression = MathFunctionScoreExpression::Create( MathFunctionScoreExpression::kFunctionNames.at(function_name), std::move(args)); + } else if (function_name == + GetSearchSpecEmbeddingFunctionScoreExpression::kFunctionName) { + // getSearchSpecEmbedding function + expression = + GetSearchSpecEmbeddingFunctionScoreExpression::Create(std::move(args)); + } else if (function_name == + MatchedSemanticScoresFunctionScoreExpression::kFunctionName) { + // matchedSemanticScores function + expression = MatchedSemanticScoresFunctionScoreExpression::Create( + std::move(args), default_semantic_metric_type_, + &embedding_query_results_); } if (!expression.ok()) { diff --git a/icing/scoring/advanced_scoring/scoring-visitor.h b/icing/scoring/advanced_scoring/scoring-visitor.h index cfee25b..bb5e6ba 100644 --- a/icing/scoring/advanced_scoring/scoring-visitor.h +++ b/icing/scoring/advanced_scoring/scoring-visitor.h @@ -15,14 +15,22 @@ #ifndef ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_ #define ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_ +#include <cstdint> +#include <memory> +#include <utility> +#include <vector> + #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" 
#include "icing/join/join-children-fetcher.h" #include "icing/legacy/core/icing-string-util.h" -#include "icing/proto/scoring.pb.h" #include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/score-expression.h" #include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/section-weights.h" #include "icing/store/document-store.h" namespace icing { @@ -31,18 +39,23 @@ namespace lib { class ScoringVisitor : public AbstractSyntaxTreeVisitor { public: explicit ScoringVisitor(double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, SectionWeights* section_weights, Bm25fCalculator* bm25f_calculator, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results, int64_t current_time_ms) : default_score_(default_score), + default_semantic_metric_type_(default_semantic_metric_type), document_store_(*document_store), schema_store_(*schema_store), section_weights_(*section_weights), bm25f_calculator_(*bm25f_calculator), join_children_fetcher_(join_children_fetcher), + embedding_query_results_(*embedding_query_results), current_time_ms_(current_time_ms) {} void VisitFunctionName(const FunctionNameNode* node) override; @@ -90,12 +103,15 @@ class ScoringVisitor : public AbstractSyntaxTreeVisitor { } double default_score_; + const SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type_; const DocumentStore& document_store_; const SchemaStore& schema_store_; SectionWeights& section_weights_; Bm25fCalculator& bm25f_calculator_; // A non-null join_children_fetcher_ indicates scoring in a join. const JoinChildrenFetcher* join_children_fetcher_; // Does not own. 
+ const EmbeddingQueryResults& embedding_query_results_; libtextclassifier3::Status pending_error_; std::vector<std::unique_ptr<ScoreExpression>> stack_; diff --git a/icing/scoring/score-and-rank_benchmark.cc b/icing/scoring/score-and-rank_benchmark.cc index 7cb5a95..4da8d1f 100644 --- a/icing/scoring/score-and-rank_benchmark.cc +++ b/icing/scoring/score-and-rank_benchmark.cc @@ -12,18 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <algorithm> #include <cstdint> #include <limits> #include <memory> #include <random> #include <string> +#include <unordered_map> #include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "testing/base/public/benchmark.h" +#include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/index/iterator/doc-hit-info-iterator.h" @@ -31,6 +35,7 @@ #include "icing/proto/schema.pb.h" #include "icing/proto/scoring.pb.h" #include "icing/schema/schema-store.h" +#include "icing/schema/section.h" #include "icing/scoring/ranker.h" #include "icing/scoring/scored-document-hit.h" #include "icing/scoring/scoring-processor.h" @@ -131,11 +136,15 @@ void BM_ScoreAndRankDocumentHitsByDocumentScore(benchmark::State& state) { ScoringSpecProto scoring_spec; scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); + EmbeddingQueryResults empty_embedding_query_results; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get(), - schema_store.get(), - clock.GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + scoring_spec, /*default_semantic_metric_type=*/ + 
SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store.get(), schema_store.get(), + clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results)); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -237,11 +246,15 @@ void BM_ScoreAndRankDocumentHitsByCreationTime(benchmark::State& state) { ScoringSpecProto scoring_spec; scoring_spec.set_rank_by( ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP); + EmbeddingQueryResults empty_embedding_query_results; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get(), - schema_store.get(), - clock.GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + scoring_spec, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store.get(), schema_store.get(), + clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results)); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -344,11 +357,15 @@ void BM_ScoreAndRankDocumentHitsNoScoring(benchmark::State& state) { ScoringSpecProto scoring_spec; scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::NONE); + EmbeddingQueryResults empty_embedding_query_results; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get(), - schema_store.get(), - clock.GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + scoring_spec, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store.get(), schema_store.get(), + clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results)); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -446,11 +463,15 @@ void 
BM_ScoreAndRankDocumentHitsByRelevanceScoring(benchmark::State& state) { ScoringSpecProto scoring_spec; scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + EmbeddingQueryResults empty_embedding_query_results; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get(), - schema_store.get(), - clock.GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + scoring_spec, /*default_semantic_metric_type=*/ + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT, + document_store.get(), schema_store.get(), + clock.GetSystemTimeMilliseconds(), /*join_children_fetcher=*/nullptr, + &empty_embedding_query_results)); int num_to_score = state.range(0); int num_of_documents = state.range(1); diff --git a/icing/scoring/scorer-factory.cc b/icing/scoring/scorer-factory.cc index e56f10c..1d66d7f 100644 --- a/icing/scoring/scorer-factory.cc +++ b/icing/scoring/scorer-factory.cc @@ -14,19 +14,26 @@ #include "icing/scoring/scorer-factory.h" +#include <cstdint> #include <memory> +#include <optional> +#include <string> #include <unordered_map> +#include <utility> #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/join/join-children-fetcher.h" #include "icing/proto/scoring.pb.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/advanced-scorer.h" #include "icing/scoring/bm25f-calculator.h" #include "icing/scoring/scorer.h" #include "icing/scoring/section-weights.h" -#include "icing/store/document-id.h" +#include "icing/store/document-associated-score-data.h" #include "icing/store/document-store.h" #include "icing/util/status-macros.h" @@ -173,10 +180,14 @@ namespace scorer_factory { 
libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( const ScoringSpecProto& scoring_spec, double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, - int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher) { + int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results) { ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(embedding_query_results); if (!scoring_spec.advanced_scoring_expression().empty() && scoring_spec.rank_by() != @@ -223,9 +234,10 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( return absl_ports::InvalidArgumentError( "Advanced scoring is enabled, but the expression is empty!"); } - return AdvancedScorer::Create(scoring_spec, default_score, document_store, - schema_store, current_time_ms, - join_children_fetcher); + return AdvancedScorer::Create( + scoring_spec, default_score, default_semantic_metric_type, + document_store, schema_store, current_time_ms, join_children_fetcher, + embedding_query_results); case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE: // Use join aggregate score to rank. 
Since the aggregation score is // calculated by child documents after joining (in JoinProcessor), we can diff --git a/icing/scoring/scorer-factory.h b/icing/scoring/scorer-factory.h index 659bebd..f5766b3 100644 --- a/icing/scoring/scorer-factory.h +++ b/icing/scoring/scorer-factory.h @@ -15,8 +15,13 @@ #ifndef ICING_SCORING_SCORER_FACTORY_H_ #define ICING_SCORING_SCORER_FACTORY_H_ +#include <cstdint> +#include <memory> + #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/join/join-children-fetcher.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/scorer.h" #include "icing/store/document-store.h" @@ -37,9 +42,11 @@ namespace scorer_factory { // INVALID_ARGUMENT if fails to create an instance libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( const ScoringSpecProto& scoring_spec, double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, - int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher = nullptr); + int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results); } // namespace scorer_factory diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc index 5194c7f..e22d5f4 100644 --- a/icing/scoring/scorer_test.cc +++ b/icing/scoring/scorer_test.cc @@ -14,13 +14,19 @@ #include "icing/scoring/scorer.h" +#include <cstdint> +#include <limits> #include <memory> #include <string> +#include <utility> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include 
"icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" @@ -30,7 +36,6 @@ #include "icing/schema/schema-store.h" #include "icing/scoring/scorer-factory.h" #include "icing/scoring/scorer-test-utils.h" -#include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" @@ -107,6 +112,10 @@ class ScorerTest : public ::testing::TestWithParam<ScorerTestingMode> { fake_clock1_.SetSystemTimeMilliseconds(new_time); } + SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type = + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; + EmbeddingQueryResults empty_embedding_query_results; + private: const std::string test_dir_; const std::string doc_store_dir_; @@ -134,8 +143,10 @@ TEST_P(ScorerTest, CreationWithNullDocumentStoreShouldFail) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), - /*default_score=*/0, /*document_store=*/nullptr, schema_store(), - fake_clock1().GetSystemTimeMilliseconds()), + /*default_score=*/0, default_semantic_metric_type, + /*document_store=*/nullptr, schema_store(), + fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } @@ -144,8 +155,9 @@ TEST_P(ScorerTest, CreationWithNullSchemaStoreShouldFail) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), - /*default_score=*/0, document_store(), - /*schema_store=*/nullptr, fake_clock1().GetSystemTimeMilliseconds()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + /*schema_store=*/nullptr, fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } @@ -155,8 +167,9 
@@ TEST_P(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), - /*default_score=*/10, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/10, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); // Non existent document id DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1); @@ -181,8 +194,9 @@ TEST_P(ScorerTest, ShouldGetDefaultDocumentScore) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), - /*default_score=*/10, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/10, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0)); @@ -206,8 +220,9 @@ TEST_P(ScorerTest, ShouldGetCorrectDocumentScore) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5)); @@ -233,8 +248,9 @@ TEST_P(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( 
ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()), - /*default_score=*/10, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/10, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); @@ -268,8 +284,9 @@ TEST_P(ScorerTest, ShouldGetCorrectCreationTimestampScore) { CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo1 = DocHitInfo(document_id1); DocHitInfo docHitInfo2 = DocHitInfo(document_id2); @@ -297,22 +314,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), 
fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -347,22 +367,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()), - /*default_score=*/0, 
document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -397,22 +420,25 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()), - /*default_score=*/0, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); 
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -444,31 +470,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE2_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE3_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + 
ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -516,31 +545,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE2_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - 
scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE3_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -588,31 +620,34 @@ TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE2_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + 
USAGE_TYPE2_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE3_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -651,8 +686,9 @@ TEST_P(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { scorer_factory::Create( CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::NONE, GetParam()), - /*default_score=*/3, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + /*default_score=*/3, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0); DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1); @@ -662,11 +698,13 @@ TEST_P(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(3)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, scorer_factory::Create( - CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::NONE, GetParam()), - 
/*default_score=*/111, document_store(), schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer, + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE, GetParam()), + /*default_score=*/111, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); docHitInfo1 = DocHitInfo(/*document_id_in=*/4); docHitInfo2 = DocHitInfo(/*document_id_in=*/5); @@ -690,13 +728,14 @@ TEST_P(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP, - GetParam()), - /*default_score=*/0, document_store(), - schema_store(), - fake_clock1().GetSystemTimeMilliseconds())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP, + GetParam()), + /*default_score=*/0, default_semantic_metric_type, document_store(), + schema_store(), fake_clock1().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); DocHitInfo docHitInfo = DocHitInfo(document_id); // Create usage report for the maximum allowable timestamp. 
diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc index b827bd8..bb62db9 100644 --- a/icing/scoring/scoring-processor.cc +++ b/icing/scoring/scoring-processor.cc @@ -14,6 +14,7 @@ #include "icing/scoring/scoring-processor.h" +#include <cstdint> #include <limits> #include <memory> #include <string> @@ -23,10 +24,12 @@ #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/join/join-children-fetcher.h" #include "icing/proto/scoring.pb.h" -#include "icing/scoring/ranker.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/scored-document-hit.h" #include "icing/scoring/scorer-factory.h" #include "icing/scoring/scorer.h" @@ -44,24 +47,28 @@ constexpr double kDefaultScoreInAscendingOrder = libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> ScoringProcessor::Create(const ScoringSpecProto& scoring_spec, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher) { + const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results) { ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(embedding_query_results); bool is_descending_order = scoring_spec.order_by() == ScoringSpecProto::Order::DESC; ICING_ASSIGN_OR_RETURN( std::unique_ptr<Scorer> scorer, - scorer_factory::Create(scoring_spec, - is_descending_order - ? kDefaultScoreInDescendingOrder - : kDefaultScoreInAscendingOrder, - document_store, schema_store, current_time_ms, - join_children_fetcher)); + scorer_factory::Create( + scoring_spec, + is_descending_order ? 
kDefaultScoreInDescendingOrder + : kDefaultScoreInAscendingOrder, + default_semantic_metric_type, document_store, schema_store, + current_time_ms, join_children_fetcher, embedding_query_results)); // Using `new` to access a non-public constructor. return std::unique_ptr<ScoringProcessor>( new ScoringProcessor(std::move(scorer))); diff --git a/icing/scoring/scoring-processor.h b/icing/scoring/scoring-processor.h index 8634a22..de0b95c 100644 --- a/icing/scoring/scoring-processor.h +++ b/icing/scoring/scoring-processor.h @@ -23,6 +23,7 @@ #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/join/join-children-fetcher.h" #include "icing/proto/logging.pb.h" @@ -46,9 +47,12 @@ class ScoringProcessor { // A ScoringProcessor on success // FAILED_PRECONDITION on any null pointer input static libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> Create( - const ScoringSpecProto& scoring_spec, const DocumentStore* document_store, - const SchemaStore* schema_store, int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher = nullptr); + const ScoringSpecProto& scoring_spec, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, + const DocumentStore* document_store, const SchemaStore* schema_store, + int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results); // Assigns scores to DocHitInfos from the given DocHitInfoIterator and returns // a vector of ScoredDocumentHits. 
The size of results is no more than diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc index deddff8..e3b70a6 100644 --- a/icing/scoring/scoring-processor_test.cc +++ b/icing/scoring/scoring-processor_test.cc @@ -15,22 +15,39 @@ #include "icing/scoring/scoring-processor.h" #include <cstdint> +#include <memory> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/file/portable-file-backed-proto-log.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" #include "icing/proto/scoring.pb.h" #include "icing/proto/term.pb.h" #include "icing/proto/usage.pb.h" #include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/scoring/scored-document-hit.h" #include "icing/scoring/scorer-test-utils.h" +#include "icing/store/document-id.h" +#include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" #include "icing/testing/tmp-directory.h" +#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -111,6 +128,10 @@ class ScoringProcessorTest const FakeClock& fake_clock() const { return fake_clock_; } + SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type = + SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; + EmbeddingQueryResults empty_embedding_query_results; + private: const std::string test_dir_; const std::string doc_store_dir_; @@ -187,27 
+208,31 @@ PropertyWeight CreatePropertyWeight(std::string path, double weight) { TEST_F(ScoringProcessorTest, CreationWithNullDocumentStoreShouldFail) { ScoringSpecProto spec_proto; - EXPECT_THAT(ScoringProcessor::Create( - spec_proto, /*document_store=*/nullptr, schema_store(), - fake_clock().GetSystemTimeMilliseconds()), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, /*document_store=*/nullptr, + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(ScoringProcessorTest, CreationWithNullSchemaStoreShouldFail) { ScoringSpecProto spec_proto; EXPECT_THAT( - ScoringProcessor::Create(spec_proto, document_store(), - /*schema_store=*/nullptr, - fake_clock().GetSystemTimeMilliseconds()), + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + /*schema_store=*/nullptr, fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_P(ScoringProcessorTest, ShouldCreateInstance) { ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()); - ICING_EXPECT_OK( - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ICING_EXPECT_OK(ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); } TEST_P(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { @@ -222,8 +247,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { // Creates a ScoringProcessor 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/5), @@ -249,8 +276,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/-1), @@ -280,8 +309,10 @@ TEST_P(ScoringProcessorTest, ShouldRespectNumToScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/2), @@ -313,8 +344,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByDocumentScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, 
document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -369,8 +402,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -439,8 +474,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -513,8 +550,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), 
fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -563,8 +602,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -629,8 +670,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -700,8 +743,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -762,8 +807,10 
@@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor with no explicit weights set. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); ScoringSpecProto spec_proto_with_weights = CreateScoringSpecForRankingStrategy( @@ -778,9 +825,11 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor with default weight set for "body" property. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor_with_weights, - ScoringProcessor::Create(spec_proto_with_weights, document_store(), - schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto_with_weights, default_semantic_metric_type, + document_store(), schema_store(), + fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -866,8 +915,10 @@ TEST_P(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -929,8 +980,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { // Creates a ScoringProcessor which ranks in 
descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -990,8 +1043,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByUsageCount) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -1051,8 +1106,10 @@ TEST_P(ScoringProcessorTest, ShouldScoreByUsageTimestamp) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -1088,8 +1145,10 @@ TEST_P(ScoringProcessorTest, ShouldHandleNoScores) { // Creates a ScoringProcessor which ranks in descending order 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/4), ElementsAre(EqualsScoredDocumentHit(scored_document_hit_default), @@ -1138,8 +1197,10 @@ TEST_P(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store(), schema_store(), - fake_clock().GetSystemTimeMilliseconds())); + ScoringProcessor::Create( + spec_proto, default_semantic_metric_type, document_store(), + schema_store(), fake_clock().GetSystemTimeMilliseconds(), + /*join_children_fetcher=*/nullptr, &empty_embedding_query_results)); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), diff --git a/icing/testing/embedding-test-utils.h b/icing/testing/embedding-test-utils.h new file mode 100644 index 0000000..931953e --- /dev/null +++ b/icing/testing/embedding-test-utils.h @@ -0,0 +1,45 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_TESTING_EMBEDDING_TEST_UTILS_H_ +#define ICING_TESTING_EMBEDDING_TEST_UTILS_H_ + +#include <initializer_list> +#include <string> + +#include "icing/proto/document.pb.h" + +namespace icing { +namespace lib { + +inline PropertyProto::VectorProto CreateVector( + const std::string& model_signature, std::initializer_list<float> values) { + PropertyProto::VectorProto vector; + vector.set_model_signature(model_signature); + for (float value : values) { + vector.add_values(value); + } + return vector; +} + +template <typename... V> +inline PropertyProto::VectorProto CreateVector( + const std::string& model_signature, V&&... values) { + return CreateVector(model_signature, values...); +} + +} // namespace lib +} // namespace icing + +#endif // ICING_TESTING_EMBEDDING_TEST_UTILS_H_ diff --git a/icing/testing/hit-test-utils.cc b/icing/testing/hit-test-utils.cc index 7ad8a64..c235e23 100644 --- a/icing/testing/hit-test-utils.cc +++ b/icing/testing/hit-test-utils.cc @@ -14,29 +14,45 @@ #include "icing/testing/hit-test-utils.h" +#include <cstdint> +#include <vector> + +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/hit/hit.h" +#include "icing/index/main/posting-list-hit-serializer.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" + namespace icing { namespace lib { // Returns a hit that has a delta of desired_byte_length from last_hit. -Hit CreateHit(Hit last_hit, int desired_byte_length) { - Hit hit = (last_hit.section_id() == kMinSectionId) - ? 
Hit(kMaxSectionId, last_hit.document_id() + 1, - last_hit.term_frequency()) - : Hit(last_hit.section_id() - 1, last_hit.document_id(), - last_hit.term_frequency()); +Hit CreateHit(const Hit& last_hit, int desired_byte_length) { + return CreateHit(last_hit, desired_byte_length, last_hit.term_frequency(), + /*is_in_prefix_section=*/false, /*is_prefix_hit=*/false); +} + +// Returns a hit that has a delta of desired_byte_length from last_hit, with +// the desired term_frequency and flags +Hit CreateHit(const Hit& last_hit, int desired_byte_length, + Hit::TermFrequency term_frequency, bool is_in_prefix_section, + bool is_prefix_hit) { + Hit hit = last_hit; uint8_t buf[5]; - while (VarInt::Encode(last_hit.value() - hit.value(), buf) < - desired_byte_length) { + do { hit = (hit.section_id() == kMinSectionId) - ? Hit(kMaxSectionId, hit.document_id() + 1, hit.term_frequency()) - : Hit(hit.section_id() - 1, hit.document_id(), - hit.term_frequency()); - } + ? Hit(kMaxSectionId, hit.document_id() + 1, term_frequency, + is_in_prefix_section, is_prefix_hit) + : Hit(hit.section_id() - 1, hit.document_id(), term_frequency, + is_in_prefix_section, is_prefix_hit); + } while (PostingListHitSerializer::EncodeNextHitValue( + /*next_hit_value=*/hit.value(), + /*curr_hit_value=*/last_hit.value(), buf) < desired_byte_length); return hit; } // Returns a vector of num_hits Hits with the first hit starting at start_docid -// and with 1-byte deltas. +// and with deltas of the desired byte length. 
std::vector<Hit> CreateHits(DocumentId start_docid, int num_hits, int desired_byte_length) { std::vector<Hit> hits; @@ -44,16 +60,56 @@ std::vector<Hit> CreateHits(DocumentId start_docid, int num_hits, return hits; } hits.push_back(Hit(/*section_id=*/1, /*document_id=*/start_docid, - Hit::kDefaultTermFrequency)); + Hit::kDefaultTermFrequency, /*is_in_prefix_section=*/false, + /*is_prefix_hit=*/false)); while (hits.size() < num_hits) { hits.push_back(CreateHit(hits.back(), desired_byte_length)); } return hits; } +// Returns a vector of num_hits Hits with the first hit being the desired byte +// length from last_hit, and with deltas of the same desired byte length. +std::vector<Hit> CreateHits(const Hit& last_hit, int num_hits, + int desired_byte_length) { + std::vector<Hit> hits; + if (num_hits < 1) { + return hits; + } + hits.reserve(num_hits); + for (int i = 0; i < num_hits; ++i) { + hits.push_back( + CreateHit(hits.empty() ? last_hit : hits.back(), desired_byte_length)); + } + return hits; +} + std::vector<Hit> CreateHits(int num_hits, int desired_byte_length) { return CreateHits(/*start_docid=*/0, num_hits, desired_byte_length); } +EmbeddingHit CreateEmbeddingHit(const EmbeddingHit& last_hit, + uint32_t desired_byte_length) { + // Create a delta that has (desired_byte_length - 1) * 7 + 1 bits, so that it + // can be encoded in desired_byte_length bytes. 
+ uint64_t delta = UINT64_C(1) << ((desired_byte_length - 1) * 7); + return EmbeddingHit(last_hit.value() - delta); +} + +std::vector<EmbeddingHit> CreateEmbeddingHits(int num_hits, + int desired_byte_length) { + std::vector<EmbeddingHit> hits; + if (num_hits == 0) { + return hits; + } + hits.reserve(num_hits); + hits.push_back(EmbeddingHit(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0)); + for (int i = 1; i < num_hits; ++i) { + hits.push_back(CreateEmbeddingHit(hits.back(), desired_byte_length)); + } + return hits; +} + } // namespace lib } // namespace icing diff --git a/icing/testing/hit-test-utils.h b/icing/testing/hit-test-utils.h index e236ec0..e041c22 100644 --- a/icing/testing/hit-test-utils.h +++ b/icing/testing/hit-test-utils.h @@ -15,28 +15,51 @@ #ifndef ICING_TESTING_HIT_TEST_UTILS_H_ #define ICING_TESTING_HIT_TEST_UTILS_H_ +#include <cstdint> #include <vector> +#include "icing/index/embed/embedding-hit.h" #include "icing/index/hit/hit.h" -#include "icing/legacy/index/icing-bit-util.h" -#include "icing/schema/section.h" #include "icing/store/document-id.h" namespace icing { namespace lib { // Returns a hit that has a delta of desired_byte_length from last_hit. -Hit CreateHit(Hit last_hit, int desired_byte_length); +Hit CreateHit(const Hit& last_hit, int desired_byte_length); + +// Returns a hit that has a delta of desired_byte_length from last_hit, with +// the desired term_frequency and flags +Hit CreateHit(const Hit& last_hit, int desired_byte_length, + Hit::TermFrequency term_frequency, bool is_in_prefix_section, + bool is_prefix_hit); // Returns a vector of num_hits Hits with the first hit starting at start_docid // and with desired_byte_length deltas. std::vector<Hit> CreateHits(DocumentId start_docid, int num_hits, int desired_byte_length); +// Returns a vector of num_hits Hits with the first hit being the desired byte +// length from last_hit, and with deltas of the same desired byte length. 
+std::vector<Hit> CreateHits(const Hit& last_hit, int num_hits, + int desired_byte_length); + // Returns a vector of num_hits Hits with the first hit starting at 0 and each // with desired_byte_length deltas. std::vector<Hit> CreateHits(int num_hits, int desired_byte_length); +// Returns a hit that has a delta of desired_byte_length from last_hit after +// VarInt encoding. +// Requires that 0 < desired_byte_length <= VarInt::kMaxEncodedLen64. +EmbeddingHit CreateEmbeddingHit(const EmbeddingHit& last_hit, + uint32_t desired_byte_length); + +// Returns a vector of num_hits Hits with the first hit starting at document 0 +// and with a delta of desired_byte_length between each subsequent hit after +// VarInt encoding. +std::vector<EmbeddingHit> CreateEmbeddingHits(int num_hits, + int desired_byte_length); + } // namespace lib } // namespace icing diff --git a/icing/text_classifier/lib3/utils/base/logging.h b/icing/text_classifier/lib3/utils/base/logging.h index 92d775e..39e09eb 100644 --- a/icing/text_classifier/lib3/utils/base/logging.h +++ b/icing/text_classifier/lib3/utils/base/logging.h @@ -16,9 +16,9 @@ #define ICING_TEXT_CLASSIFIER_LIB3_UTILS_BASE_LOGGING_H_ #include <cassert> +#include <cstdint> #include <string> -#include "icing/text_classifier/lib3/utils/base/integral_types.h" #include "icing/text_classifier/lib3/utils/base/logging_levels.h" #include "icing/text_classifier/lib3/utils/base/port.h" @@ -45,7 +45,8 @@ inline LoggingStringStream& operator<<(LoggingStringStream& stream, template <typename T> inline LoggingStringStream& operator<<(LoggingStringStream& stream, T* const entry) { - stream.message.append(std::to_string(reinterpret_cast<const uint64>(entry))); + stream.message.append( + std::to_string(reinterpret_cast<const uint64_t>(entry))); return stream; } diff --git a/icing/text_classifier/lib3/utils/java/jni-helper.cc b/icing/text_classifier/lib3/utils/java/jni-helper.cc index 60a9dfb..cb0b899 100644 --- 
a/icing/text_classifier/lib3/utils/java/jni-helper.cc +++ b/icing/text_classifier/lib3/utils/java/jni-helper.cc @@ -14,6 +14,8 @@ #include "icing/text_classifier/lib3/utils/java/jni-helper.h" +#include <cstdint> + #include "icing/text_classifier/lib3/utils/base/status_macros.h" namespace libtextclassifier3 { @@ -121,8 +123,8 @@ StatusOr<bool> JniHelper::CallBooleanMethod(JNIEnv* env, jobject object, return result; } -StatusOr<int32> JniHelper::CallIntMethod(JNIEnv* env, jobject object, - jmethodID method_id, ...) { +StatusOr<int32_t> JniHelper::CallIntMethod(JNIEnv* env, jobject object, + jmethodID method_id, ...) { va_list args; va_start(args, method_id); jint result = env->CallIntMethodV(object, method_id, args); @@ -132,8 +134,8 @@ StatusOr<int32> JniHelper::CallIntMethod(JNIEnv* env, jobject object, return result; } -StatusOr<int64> JniHelper::CallLongMethod(JNIEnv* env, jobject object, - jmethodID method_id, ...) { +StatusOr<int64_t> JniHelper::CallLongMethod(JNIEnv* env, jobject object, + jmethodID method_id, ...) 
{ va_list args; va_start(args, method_id); jlong result = env->CallLongMethodV(object, method_id, args); diff --git a/icing/text_classifier/lib3/utils/java/jni-helper.h b/icing/text_classifier/lib3/utils/java/jni-helper.h index 4e548ec..8b57c11 100644 --- a/icing/text_classifier/lib3/utils/java/jni-helper.h +++ b/icing/text_classifier/lib3/utils/java/jni-helper.h @@ -20,6 +20,7 @@ #include <jni.h> +#include <cstdint> #include <string> #include "icing/text_classifier/lib3/utils/base/status.h" @@ -140,10 +141,10 @@ class JniHelper { ...); static StatusOr<bool> CallBooleanMethod(JNIEnv* env, jobject object, jmethodID method_id, ...); - static StatusOr<int32> CallIntMethod(JNIEnv* env, jobject object, - jmethodID method_id, ...); - static StatusOr<int64> CallLongMethod(JNIEnv* env, jobject object, - jmethodID method_id, ...); + static StatusOr<int32_t> CallIntMethod(JNIEnv* env, jobject object, + jmethodID method_id, ...); + static StatusOr<int64_t> CallLongMethod(JNIEnv* env, jobject object, + jmethodID method_id, ...); static StatusOr<float> CallFloatMethod(JNIEnv* env, jobject object, jmethodID method_id, ...); static StatusOr<double> CallDoubleMethod(JNIEnv* env, jobject object, diff --git a/icing/util/document-validator.cc b/icing/util/document-validator.cc index e0880ea..bc75334 100644 --- a/icing/util/document-validator.cc +++ b/icing/util/document-validator.cc @@ -15,13 +15,23 @@ #include "icing/util/document-validator.h" #include <cstdint> +#include <string> +#include <string_view> +#include <unordered_map> #include <unordered_set> +#include <utility> #include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/legacy/core/icing-string-util.h" #include "icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" +#include "icing/schema/schema-store.h" #include "icing/schema/schema-util.h" +#include 
"icing/store/document-filter-data.h" +#include "icing/util/logging.h" #include "icing/util/status-macros.h" namespace icing { @@ -123,6 +133,18 @@ libtextclassifier3::Status DocumentValidator::Validate( } else if (property_config.data_type() == PropertyConfigProto::DataType::DOCUMENT) { value_size = property.document_values_size(); + } else if (property_config.data_type() == + PropertyConfigProto::DataType::VECTOR) { + value_size = property.vector_values_size(); + for (const PropertyProto::VectorProto& vector_value : + property.vector_values()) { + if (vector_value.values_size() == 0) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Property '%s' contains empty vectors for key: (%s, %s).", + property.name().c_str(), document.namespace_().c_str(), + document.uri().c_str())); + } + } } if (property_config.cardinality() == diff --git a/icing/util/document-validator_test.cc b/icing/util/document-validator_test.cc index 9d10b36..2c366fd 100644 --- a/icing/util/document-validator_test.cc +++ b/icing/util/document-validator_test.cc @@ -15,7 +15,11 @@ #include "icing/util/document-validator.h" #include <cstdint> +#include <limits> +#include <memory> +#include <string> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" @@ -42,6 +46,7 @@ constexpr char kPropertySubject[] = "subject"; constexpr char kPropertyText[] = "text"; constexpr char kPropertyRecipients[] = "recipients"; constexpr char kPropertyNote[] = "note"; +constexpr char kPropertyNoteEmbedding[] = "noteEmbedding"; // type and property names of Conversation constexpr char kTypeConversation[] = "Conversation"; constexpr char kTypeConversationWithEmailNote[] = "ConversationWithEmailNote"; @@ -92,6 +97,10 @@ class DocumentValidatorTest : public ::testing::Test { .AddProperty(PropertyConfigBuilder() .SetName(kPropertyNote) .SetDataType(TYPE_STRING) + .SetCardinality(CARDINALITY_OPTIONAL)) + 
.AddProperty(PropertyConfigBuilder() + .SetName(kPropertyNoteEmbedding) + .SetDataType(TYPE_VECTOR) .SetCardinality(CARDINALITY_OPTIONAL))) .AddType( SchemaTypeConfigBuilder() @@ -146,6 +155,12 @@ class DocumentValidatorTest : public ::testing::Test { } DocumentBuilder SimpleEmailWithNoteBuilder() { + PropertyProto::VectorProto vector; + vector.add_values(0.1); + vector.add_values(0.2); + vector.add_values(0.3); + vector.set_model_signature("my_model"); + return DocumentBuilder() .SetKey(kDefaultNamespace, "email_with_note/1") .SetSchema(kTypeEmailWithNote) @@ -153,7 +168,8 @@ class DocumentValidatorTest : public ::testing::Test { .AddStringProperty(kPropertyText, kDefaultString) .AddStringProperty(kPropertyRecipients, kDefaultString, kDefaultString, kDefaultString) - .AddStringProperty(kPropertyNote, kDefaultString); + .AddStringProperty(kPropertyNote, kDefaultString) + .AddVectorProperty(kPropertyNoteEmbedding, vector); } DocumentBuilder SimpleConversationBuilder() { @@ -597,6 +613,64 @@ TEST_F(DocumentValidatorTest, NegativeDocumentTtlMsInvalid) { HasSubstr("is negative"))); } +TEST_F(DocumentValidatorTest, ValidateEmbeddingZeroDimensionInvalid) { + PropertyProto::VectorProto vector; + vector.set_model_signature("my_model"); + DocumentProto email = + DocumentBuilder() + .SetKey(kDefaultNamespace, "email_with_note/1") + .SetSchema(kTypeEmailWithNote) + .AddStringProperty(kPropertySubject, kDefaultString) + .AddStringProperty(kPropertyText, kDefaultString) + .AddStringProperty(kPropertyRecipients, kDefaultString, + kDefaultString, kDefaultString) + .AddStringProperty(kPropertyNote, kDefaultString) + .AddVectorProperty(kPropertyNoteEmbedding, vector) + .Build(); + EXPECT_THAT(document_validator_->Validate(email), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT, + HasSubstr("contains empty vectors"))); +} + +TEST_F(DocumentValidatorTest, ValidateEmbeddingEmptySignatureOk) { + PropertyProto::VectorProto vector; + vector.add_values(0.1); + 
vector.add_values(0.2); + vector.add_values(0.3); + vector.set_model_signature(""); + DocumentProto email = + DocumentBuilder() + .SetKey(kDefaultNamespace, "email_with_note/1") + .SetSchema(kTypeEmailWithNote) + .AddStringProperty(kPropertySubject, kDefaultString) + .AddStringProperty(kPropertyText, kDefaultString) + .AddStringProperty(kPropertyRecipients, kDefaultString, + kDefaultString, kDefaultString) + .AddStringProperty(kPropertyNote, kDefaultString) + .AddVectorProperty(kPropertyNoteEmbedding, vector) + .Build(); + ICING_EXPECT_OK(document_validator_->Validate(email)); +} + +TEST_F(DocumentValidatorTest, ValidateEmbeddingNoSignatureOk) { + PropertyProto::VectorProto vector; + vector.add_values(0.1); + vector.add_values(0.2); + vector.add_values(0.3); + DocumentProto email = + DocumentBuilder() + .SetKey(kDefaultNamespace, "email_with_note/1") + .SetSchema(kTypeEmailWithNote) + .AddStringProperty(kPropertySubject, kDefaultString) + .AddStringProperty(kPropertyText, kDefaultString) + .AddStringProperty(kPropertyRecipients, kDefaultString, + kDefaultString, kDefaultString) + .AddStringProperty(kPropertyNote, kDefaultString) + .AddVectorProperty(kPropertyNoteEmbedding, vector) + .Build(); + ICING_EXPECT_OK(document_validator_->Validate(email)); +} + } // namespace } // namespace lib diff --git a/icing/util/embedding-util.h b/icing/util/embedding-util.h new file mode 100644 index 0000000..5026051 --- /dev/null +++ b/icing/util/embedding-util.h @@ -0,0 +1,49 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_UTIL_EMBEDDING_UTIL_H_ +#define ICING_UTIL_EMBEDDING_UTIL_H_ + +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/proto/search.pb.h" + +namespace icing { +namespace lib { + +namespace embedding_util { + +inline libtextclassifier3::StatusOr< + SearchSpecProto::EmbeddingQueryMetricType::Code> +GetEmbeddingQueryMetricTypeFromName(std::string_view metric_name) { + if (metric_name == "COSINE") { + return SearchSpecProto::EmbeddingQueryMetricType::COSINE; + } else if (metric_name == "DOT_PRODUCT") { + return SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; + } else if (metric_name == "EUCLIDEAN") { + return SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN; + } + return absl_ports::InvalidArgumentError( + absl_ports::StrCat("Unknown metric type: ", metric_name)); +} + +} // namespace embedding_util + +} // namespace lib +} // namespace icing + +#endif // ICING_UTIL_EMBEDDING_UTIL_H_ diff --git a/icing/util/tokenized-document.cc b/icing/util/tokenized-document.cc index 19aaddf..e10fe25 100644 --- a/icing/util/tokenized-document.cc +++ b/icing/util/tokenized-document.cc @@ -14,16 +14,18 @@ #include "icing/util/tokenized-document.h" -#include <string> +#include <memory> #include <string_view> +#include <utility> #include <vector> -#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/proto/document.pb.h" #include "icing/schema/joinable-property.h" #include "icing/schema/schema-store.h" #include "icing/schema/section.h" #include "icing/tokenization/language-segmenter.h" +#include "icing/tokenization/token.h" #include "icing/tokenization/tokenizer-factory.h" #include "icing/tokenization/tokenizer.h" #include 
"icing/util/document-validator.h" @@ -85,6 +87,7 @@ TokenizedDocument::Create(const SchemaStore* schema_store, return TokenizedDocument(std::move(document), std::move(tokenized_string_sections), std::move(section_group.integer_sections), + std::move(section_group.vector_sections), std::move(joinable_property_group)); } diff --git a/icing/util/tokenized-document.h b/icing/util/tokenized-document.h index 7cc34e3..0337083 100644 --- a/icing/util/tokenized-document.h +++ b/icing/util/tokenized-document.h @@ -16,7 +16,8 @@ #define ICING_STORE_TOKENIZED_DOCUMENT_H_ #include <cstdint> -#include <string> +#include <string_view> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" @@ -63,6 +64,11 @@ class TokenizedDocument { return integer_sections_; } + const std::vector<Section<PropertyProto::VectorProto>>& vector_sections() + const { + return vector_sections_; + } + const std::vector<JoinableProperty<std::string_view>>& qualified_id_join_properties() const { return joinable_property_group_.qualified_id_properties; @@ -74,15 +80,18 @@ class TokenizedDocument { DocumentProto&& document, std::vector<TokenizedSection>&& tokenized_string_sections, std::vector<Section<int64_t>>&& integer_sections, + std::vector<Section<PropertyProto::VectorProto>>&& vector_sections, JoinablePropertyGroup&& joinable_property_group) : document_(std::move(document)), tokenized_string_sections_(std::move(tokenized_string_sections)), integer_sections_(std::move(integer_sections)), + vector_sections_(std::move(vector_sections)), joinable_property_group_(std::move(joinable_property_group)) {} DocumentProto document_; std::vector<TokenizedSection> tokenized_string_sections_; std::vector<Section<int64_t>> integer_sections_; + std::vector<Section<PropertyProto::VectorProto>> vector_sections_; JoinablePropertyGroup joinable_property_group_; }; diff --git a/icing/util/tokenized-document_test.cc b/icing/util/tokenized-document_test.cc index 7c97776..ab7f4b9 
100644 --- a/icing/util/tokenized-document_test.cc +++ b/icing/util/tokenized-document_test.cc @@ -16,6 +16,8 @@ #include <memory> #include <string> +#include <string_view> +#include <utility> #include <vector> #include "gmock/gmock.h" @@ -25,7 +27,6 @@ #include "icing/portable/platform.h" #include "icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" -#include "icing/proto/term.pb.h" #include "icing/schema-builder.h" #include "icing/schema/joinable-property.h" #include "icing/schema/schema-store.h" @@ -59,13 +60,19 @@ static constexpr std::string_view kIndexableIntegerProperty1 = "indexableInteger1"; static constexpr std::string_view kIndexableIntegerProperty2 = "indexableInteger2"; +static constexpr std::string_view kIndexableVectorProperty1 = + "indexableVector1"; +static constexpr std::string_view kIndexableVectorProperty2 = + "indexableVector2"; static constexpr std::string_view kStringExactProperty = "stringExact"; static constexpr std::string_view kStringPrefixProperty = "stringPrefix"; static constexpr SectionId kIndexableInteger1SectionId = 0; static constexpr SectionId kIndexableInteger2SectionId = 1; -static constexpr SectionId kStringExactSectionId = 2; -static constexpr SectionId kStringPrefixSectionId = 3; +static constexpr SectionId kIndexableVector1SectionId = 2; +static constexpr SectionId kIndexableVector2SectionId = 3; +static constexpr SectionId kStringExactSectionId = 4; +static constexpr SectionId kStringPrefixSectionId = 5; // Joinable properties and joinable property id. Joinable property id is // determined by the lexicographical order of joinable property path. 
@@ -77,19 +84,33 @@ static constexpr JoinablePropertyId kQualifiedId2JoinablePropertyId = 1; const SectionMetadata kIndexableInteger1SectionMetadata( kIndexableInteger1SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, - NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty1)); + NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN, + std::string(kIndexableIntegerProperty1)); const SectionMetadata kIndexableInteger2SectionMetadata( kIndexableInteger2SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, - NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty2)); + NUMERIC_MATCH_RANGE, EMBEDDING_INDEXING_UNKNOWN, + std::string(kIndexableIntegerProperty2)); + +const SectionMetadata kIndexableVector1SectionMetadata( + kIndexableVector1SectionId, TYPE_VECTOR, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, + NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_LINEAR_SEARCH, + std::string(kIndexableVectorProperty1)); + +const SectionMetadata kIndexableVector2SectionMetadata( + kIndexableVector2SectionId, TYPE_VECTOR, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, + NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_LINEAR_SEARCH, + std::string(kIndexableVectorProperty2)); const SectionMetadata kStringExactSectionMetadata( kStringExactSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT, - NUMERIC_MATCH_UNKNOWN, std::string(kStringExactProperty)); + NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN, + std::string(kStringExactProperty)); const SectionMetadata kStringPrefixSectionMetadata( kStringPrefixSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_PREFIX, - NUMERIC_MATCH_UNKNOWN, std::string(kStringPrefixProperty)); + NUMERIC_MATCH_UNKNOWN, EMBEDDING_INDEXING_UNKNOWN, + std::string(kStringPrefixProperty)); const JoinablePropertyMetadata kQualifiedId1JoinablePropertyMetadata( kQualifiedId1JoinablePropertyId, TYPE_STRING, @@ -102,6 +123,7 @@ const JoinablePropertyMetadata kQualifiedId2JoinablePropertyMetadata( // Other non-indexable/joinable properties. 
constexpr std::string_view kUnindexedStringProperty = "unindexedString"; constexpr std::string_view kUnindexedIntegerProperty = "unindexedInteger"; +constexpr std::string_view kUnindexedVectorProperty = "unindexedVector"; class TokenizedDocumentTest : public ::testing::Test { protected: @@ -140,6 +162,10 @@ class TokenizedDocumentTest : public ::testing::Test { .SetDataType(TYPE_INT64) .SetCardinality(CARDINALITY_OPTIONAL)) .AddProperty(PropertyConfigBuilder() + .SetName(kUnindexedVectorProperty) + .SetDataType(TYPE_VECTOR) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() .SetName(kIndexableIntegerProperty1) .SetDataTypeInt64(NUMERIC_MATCH_RANGE) .SetCardinality(CARDINALITY_REPEATED)) @@ -147,6 +173,16 @@ class TokenizedDocumentTest : public ::testing::Test { .SetName(kIndexableIntegerProperty2) .SetDataTypeInt64(NUMERIC_MATCH_RANGE) .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName(kIndexableVectorProperty1) + .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty( + PropertyConfigBuilder() + .SetName(kIndexableVectorProperty2) + .SetDataTypeVector(EMBEDDING_INDEXING_LINEAR_SEARCH) + .SetCardinality(CARDINALITY_OPTIONAL)) .AddProperty(PropertyConfigBuilder() .SetName(kStringExactProperty) .SetDataTypeString(TERM_MATCH_EXACT, @@ -196,6 +232,16 @@ class TokenizedDocumentTest : public ::testing::Test { }; TEST_F(TokenizedDocumentTest, CreateAll) { + PropertyProto::VectorProto vector1; + vector1.set_model_signature("my_model1"); + vector1.add_values(1.0f); + vector1.add_values(2.0f); + PropertyProto::VectorProto vector2; + vector2.set_model_signature("my_model2"); + vector2.add_values(-1.0f); + vector2.add_values(-2.0f); + vector2.add_values(-3.0f); + DocumentProto document = DocumentBuilder() .SetKey("icing", "fake_type/1") @@ -208,6 +254,10 @@ TEST_F(TokenizedDocumentTest, CreateAll) { 
.AddInt64Property(std::string(kUnindexedIntegerProperty), 789) .AddInt64Property(std::string(kIndexableIntegerProperty1), 1, 2, 3) .AddInt64Property(std::string(kIndexableIntegerProperty2), 456) + .AddVectorProperty(std::string(kUnindexedVectorProperty), vector1) + .AddVectorProperty(std::string(kIndexableVectorProperty1), vector1, + vector2) + .AddVectorProperty(std::string(kIndexableVectorProperty2), vector1) .AddStringProperty(std::string(kQualifiedId1), "pkg$db/ns#uri1") .AddStringProperty(std::string(kQualifiedId2), "pkg$db/ns#uri2") .Build(); @@ -244,6 +294,17 @@ TEST_F(TokenizedDocumentTest, CreateAll) { EXPECT_THAT(tokenized_document.integer_sections().at(1).content, ElementsAre(456)); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), SizeIs(2)); + EXPECT_THAT(tokenized_document.vector_sections().at(0).metadata, + Eq(kIndexableVector1SectionMetadata)); + EXPECT_THAT(tokenized_document.vector_sections().at(0).content, + ElementsAre(EqualsProto(vector1), EqualsProto(vector2))); + EXPECT_THAT(tokenized_document.vector_sections().at(1).metadata, + Eq(kIndexableVector2SectionMetadata)); + EXPECT_THAT(tokenized_document.vector_sections().at(1).content, + ElementsAre(EqualsProto(vector1))); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), SizeIs(2)); EXPECT_THAT(tokenized_document.qualified_id_join_properties().at(0).metadata, @@ -278,6 +339,9 @@ TEST_F(TokenizedDocumentTest, CreateNoIndexableIntegerProperties) { // integer sections EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); } @@ -314,6 +378,9 @@ TEST_F(TokenizedDocumentTest, CreateMultipleIndexableIntegerProperties) { EXPECT_THAT(tokenized_document.integer_sections().at(1).content, ElementsAre(456)); + // vector sections + 
EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); } @@ -341,6 +408,9 @@ TEST_F(TokenizedDocumentTest, CreateNoIndexableStringProperties) { // integer sections EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); } @@ -381,6 +451,92 @@ TEST_F(TokenizedDocumentTest, CreateMultipleIndexableStringProperties) { // integer sections EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + + // Qualified id join properties + EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); +} + +TEST_F(TokenizedDocumentTest, CreateNoIndexableVectorProperties) { + PropertyProto::VectorProto vector; + vector.set_model_signature("my_model"); + vector.add_values(1.0f); + + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddVectorProperty(std::string(kUnindexedVectorProperty), vector) + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + + EXPECT_THAT(tokenized_document.document(), EqualsProto(document)); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0)); + + // string sections + EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty()); + + // integer sections + EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + + // Qualified id join properties + EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); +} + 
+TEST_F(TokenizedDocumentTest, CreateMultipleIndexableVectorProperties) { + PropertyProto::VectorProto vector1; + vector1.set_model_signature("my_model1"); + vector1.add_values(1.0f); + vector1.add_values(2.0f); + PropertyProto::VectorProto vector2; + vector2.set_model_signature("my_model2"); + vector2.add_values(-1.0f); + vector2.add_values(-2.0f); + vector2.add_values(-3.0f); + + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddVectorProperty(std::string(kUnindexedVectorProperty), vector1) + .AddVectorProperty(std::string(kIndexableVectorProperty1), vector1, + vector2) + .AddVectorProperty(std::string(kIndexableVectorProperty2), vector1) + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + + EXPECT_THAT(tokenized_document.document(), EqualsProto(document)); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0)); + + // string sections + EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty()); + + // integer sections + EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), SizeIs(2)); + EXPECT_THAT(tokenized_document.vector_sections().at(0).metadata, + Eq(kIndexableVector1SectionMetadata)); + EXPECT_THAT(tokenized_document.vector_sections().at(0).content, + ElementsAre(EqualsProto(vector1), EqualsProto(vector2))); + EXPECT_THAT(tokenized_document.vector_sections().at(1).metadata, + Eq(kIndexableVector2SectionMetadata)); + EXPECT_THAT(tokenized_document.vector_sections().at(1).content, + ElementsAre(EqualsProto(vector1))); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); } @@ -408,6 +564,9 @@ TEST_F(TokenizedDocumentTest, CreateNoJoinQualifiedIdProperties) { // integer sections 
EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), IsEmpty()); } @@ -437,6 +596,9 @@ TEST_F(TokenizedDocumentTest, CreateMultipleJoinQualifiedIdProperties) { // integer sections EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); + // vector sections + EXPECT_THAT(tokenized_document.vector_sections(), IsEmpty()); + // Qualified id join properties EXPECT_THAT(tokenized_document.qualified_id_join_properties(), SizeIs(2)); EXPECT_THAT(tokenized_document.qualified_id_join_properties().at(0).metadata, diff --git a/java/src/com/google/android/icing/IcingSearchEngine.java b/java/src/com/google/android/icing/IcingSearchEngine.java index 79fcdb8..a8a571e 100644 --- a/java/src/com/google/android/icing/IcingSearchEngine.java +++ b/java/src/com/google/android/icing/IcingSearchEngine.java @@ -77,13 +77,6 @@ public class IcingSearchEngine implements IcingSearchEngineInterface { icingSearchEngineImpl.close(); } - @SuppressWarnings("deprecation") - @Override - protected void finalize() throws Throwable { - icingSearchEngineImpl.close(); - super.finalize(); - } - @NonNull @Override public InitializeResultProto initialize() { diff --git a/java/src/com/google/android/icing/IcingSearchEngineImpl.java b/java/src/com/google/android/icing/IcingSearchEngineImpl.java index 57744c4..7994162 100644 --- a/java/src/com/google/android/icing/IcingSearchEngineImpl.java +++ b/java/src/com/google/android/icing/IcingSearchEngineImpl.java @@ -71,13 +71,6 @@ public class IcingSearchEngineImpl implements Closeable { closed = true; } - @SuppressWarnings("deprecation") - @Override - protected void finalize() throws Throwable { - close(); - super.finalize(); - } - @Nullable public byte[] initialize() { throwIfClosed(); diff --git a/java/src/com/google/android/icing/IcingSearchEngineInterface.java 
b/java/src/com/google/android/icing/IcingSearchEngineInterface.java index 0bc58f1..67f60ed 100644 --- a/java/src/com/google/android/icing/IcingSearchEngineInterface.java +++ b/java/src/com/google/android/icing/IcingSearchEngineInterface.java @@ -32,7 +32,7 @@ import com.google.android.icing.proto.SuggestionSpecProto; import com.google.android.icing.proto.UsageReport; import java.io.Closeable; -/** A common user-facing interface to expose the funcationalities provided by Icing Library. */ +/** A common user-facing interface to expose the functionalities provided by Icing Library. */ public interface IcingSearchEngineInterface extends Closeable { /** * Initializes the current IcingSearchEngine implementation. diff --git a/proto/icing/proto/document.proto b/proto/icing/proto/document.proto index 1a501e7..919769e 100644 --- a/proto/icing/proto/document.proto +++ b/proto/icing/proto/document.proto @@ -78,7 +78,7 @@ message DocumentProto { } // Holds a property field of the Document. -// Next tag: 8 +// Next tag: 9 message PropertyProto { // Name of the property. // See icing.lib.PropertyConfigProto.property_name for details. @@ -92,6 +92,17 @@ message PropertyProto { repeated bool boolean_values = 5; repeated bytes bytes_values = 6; repeated DocumentProto document_values = 7; + + message VectorProto { + // The values of the vector. + repeated float values = 1 [packed = true]; + // The model signature of the vector, which can be any string used to + // identify the model, so that embedding searches can be restricted only to + // the vectors with the matching target signature. 
+ // Eg: "universal-sentence-encoder_v0" + optional string model_signature = 2; + } + repeated VectorProto vector_values = 8; } // Result of a call to IcingSearchEngine.Put diff --git a/proto/icing/proto/initialize.proto b/proto/icing/proto/initialize.proto index 9dd9e88..7c58d2c 100644 --- a/proto/icing/proto/initialize.proto +++ b/proto/icing/proto/initialize.proto @@ -23,6 +23,56 @@ option java_package = "com.google.android.icing.proto"; option java_multiple_files = true; option objc_class_prefix = "ICNG"; +// Next tag: 7 +message IcingSearchEngineFeatureInfoProto { + // REQUIRED: Enum representing an IcingLib feature flagged using + // IcingSearchEngineOptions + optional FlaggedFeatureType feature_type = 1; + + enum FlaggedFeatureType { + // This value should never purposely be used. This is used for backwards + // compatibility reasons. + UNKNOWN = 0; + + // Feature for flag + // IcingSearchEngineOptions::build_property_existence_metadata_hits. + // + // This feature covers the kHasPropertyFunctionFeature advanced query + // feature, and related metadata hits indexing used for property existence + // check. + FEATURE_HAS_PROPERTY_OPERATOR = 1; + } + + // Whether the feature requires the document store to be rebuilt. + // The default value is false. + optional bool needs_document_store_rebuild = 2; + + // Whether the feature requires the schema store to be rebuilt. + // The default value is false. + optional bool needs_schema_store_rebuild = 3; + + // Whether the feature requires the term index to be rebuilt. + // The default value is false. + optional bool needs_term_index_rebuild = 4; + + // Whether the feature requires the integer index to be rebuilt. + // The default value is false. + optional bool needs_integer_index_rebuild = 5; + + // Whether the feature requires the qualified id join index to be rebuilt. + // The default value is false. 
+ optional bool needs_qualified_id_join_index_rebuild = 6; +} + +// Next tag: 4 +message IcingSearchEngineVersionProto { + // version and max_version are from the original version file. + optional int32 version = 1; + optional int32 max_version = 2; + // Features that are enabled in an icing version at initialization. + repeated IcingSearchEngineFeatureInfoProto enabled_features = 3; +} + // Next tag: 16 message IcingSearchEngineOptions { // Directory to persist files for Icing. Required. @@ -131,9 +181,6 @@ message IcingSearchEngineOptions { // Whether to build the metadata hits used for property existence check, which // is required to support the hasProperty function in advanced query. - // - // TODO(b/309826655): Implement the feature flag derived files rebuild - // mechanism to handle index rebuild, instead of using index's magic value. optional bool build_property_existence_metadata_hits = 15; reserved 2; diff --git a/proto/icing/proto/logging.proto b/proto/icing/proto/logging.proto index fcedeed..46e988e 100644 --- a/proto/icing/proto/logging.proto +++ b/proto/icing/proto/logging.proto @@ -23,7 +23,7 @@ option java_multiple_files = true; option objc_class_prefix = "ICNG"; // Stats of the top-level function IcingSearchEngine::Initialize(). -// Next tag: 14 +// Next tag: 15 message InitializeStatsProto { // Overall time used for the function call. optional int32 latency_ms = 1; @@ -55,6 +55,10 @@ message InitializeStatsProto { // Any dependencies have changed. DEPENDENCIES_CHANGED = 7; + + // Change detected in Icing's feature flags since last initialization that + // requires recovery. 
+ FEATURE_FLAG_CHANGED = 8; } // Possible recovery causes for document store: @@ -117,10 +121,16 @@ message InitializeStatsProto { // - SCHEMA_CHANGES_OUT_OF_SYNC // - IO_ERROR optional RecoveryCause qualified_id_join_index_restoration_cause = 13; + + // Possible recovery causes for embedding index: + // - INCONSISTENT_WITH_GROUND_TRUTH + // - SCHEMA_CHANGES_OUT_OF_SYNC + // - IO_ERROR + optional RecoveryCause embedding_index_restoration_cause = 14; } // Stats of the top-level function IcingSearchEngine::Put(). -// Next tag: 12 +// Next tag: 13 message PutDocumentStatsProto { // Overall time used for the function call. optional int32 latency_ms = 1; @@ -166,11 +176,14 @@ message PutDocumentStatsProto { // Time used to index all metadata terms in the document, which can only be // added by PropertyExistenceIndexingHandler currently. optional int32 metadata_term_index_latency_ms = 11; + + // Time used to index all embeddings in the document. + optional int32 embedding_index_latency_ms = 12; } // Stats of the top-level function IcingSearchEngine::Search() and // IcingSearchEngine::GetNextPage(). -// Next tag: 26 +// Next tag: 28 message QueryStatsProto { // TODO(b/305098009): deprecate. Use parent_search_stats instead. // The UTF-8 length of the query string @@ -252,7 +265,7 @@ message QueryStatsProto { optional bool is_join_query = 23; // Stats of the search. Only valid for first page. - // Next tag: 13 + // Next tag: 16 message SearchStats { // The UTF-8 length of the query string optional int32 query_length = 1; @@ -290,6 +303,15 @@ message QueryStatsProto { // Number of hits fetched by integer index before applying any filters. optional int32 num_fetched_hits_integer_index = 12; + + // Time used in Lexer to extract lexer tokens from the query. + optional int32 query_processor_lexer_extract_token_latency_ms = 13; + + // Time used in Parser to consume lexer tokens extracted from the query. 
+ optional int32 query_processor_parser_consume_query_latency_ms = 14; + + // Time used in QueryVisitor to visit and build (nested) DocHitInfoIterator. + optional int32 query_processor_query_visitor_latency_ms = 15; } // Search stats for parent. Only valid for first page. @@ -298,6 +320,12 @@ message QueryStatsProto { // Search stats for child. optional SearchStats child_search_stats = 25; + // Byte size of the lite index hit buffer. + optional int64 lite_index_hit_buffer_byte_size = 26; + + // Byte size of the unsorted tail of the lite index hit buffer. + optional int64 lite_index_hit_buffer_unsorted_byte_size = 27; + reserved 9; } diff --git a/proto/icing/proto/schema.proto b/proto/icing/proto/schema.proto index c716dba..99439bb 100644 --- a/proto/icing/proto/schema.proto +++ b/proto/icing/proto/schema.proto @@ -34,7 +34,7 @@ option objc_class_prefix = "ICNG"; // TODO(cassiewang) Define a sample proto file that can be used by tests and for // documentation. // -// Next tag: 7 +// Next tag: 8 message SchemaTypeConfigProto { // REQUIRED: Named type that uniquely identifies the structured, logical // schema being defined. @@ -43,6 +43,13 @@ message SchemaTypeConfigProto { // in http://schema.org. Eg: DigitalDocument, Message, Person, etc. optional string schema_type = 1; + // OPTIONAL: A natural language description of the SchemaTypeConfigProto. + // + // This string is not used by Icing in any way. It simply exists to allow + // users to store semantic information about the SchemaTypeConfigProto for + // future retrieval. + optional string description = 7; + // List of all properties that are supported by Documents of this type. // An Document should never have properties that are not listed here. // @@ -177,6 +184,26 @@ message IntegerIndexingConfig { optional NumericMatchType.Code numeric_match_type = 1; } +// Describes how a vector property should be indexed. 
+// Next tag: 3 +message EmbeddingIndexingConfig { + // OPTIONAL: Indicates how the vector contents of this property should be + // matched. + // + // The default value is UNKNOWN. + message EmbeddingIndexingType { + enum Code { + // Contents in this property will not be indexed. Useful if the vector + // property type is not indexable. + UNKNOWN = 0; + + // Contents in this property will be indexed for linear search. + LINEAR_SEARCH = 1; + } + } + optional EmbeddingIndexingType.Code embedding_indexing_type = 1; +} + // Describes how a property can be used to join this document with another // document. See JoinSpecProto (in search.proto) for more details. // Next tag: 3 @@ -208,7 +235,7 @@ message JoinableConfig { // Describes the schema of a single property of Documents that belong to a // specific SchemaTypeConfigProto. These can be considered as a rich, structured // type for each property of Documents accepted by IcingSearchEngine. -// Next tag: 9 +// Next tag: 11 message PropertyConfigProto { // REQUIRED: Name that uniquely identifies a property within an Document of // a specific SchemaTypeConfigProto. @@ -219,6 +246,13 @@ message PropertyConfigProto { // Eg: 'address' for http://schema.org/Place. optional string property_name = 1; + // OPTIONAL: A natural language description of the property. + // + // This string is not used by Icing in any way. It simply exists to allow + // users to store semantic information about the PropertyConfigProto for + // future retrieval. + optional string description = 9; + // REQUIRED: Physical data-types of the contents of the property. message DataType { enum Code { @@ -238,6 +272,9 @@ message PropertyConfigProto { // a hierarchical Document schema. Any property using this data_type // MUST have a valid 'schema_type'. DOCUMENT = 6; + + // A list of floats. Vector type is used for embedding searches. 
+ VECTOR = 7; } } optional DataType.Code data_type = 2; @@ -299,6 +336,10 @@ message PropertyConfigProto { // - The property itself and any upper-level (nested doc) property should // contain at most one element (i.e. Cardinality is OPTIONAL or REQUIRED). optional JoinableConfig joinable_config = 8; + + // OPTIONAL: Describes how vector properties should be indexed. Vector + // properties that do not set the indexing config will not be indexed. + optional EmbeddingIndexingConfig embedding_indexing_config = 10; } // List of all supported types constitutes the schema used by Icing. diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto index 7f4fb3e..7e145ce 100644 --- a/proto/icing/proto/search.proto +++ b/proto/icing/proto/search.proto @@ -27,7 +27,7 @@ option java_multiple_files = true; option objc_class_prefix = "ICNG"; // Client-supplied specifications on what documents to retrieve. -// Next tag: 11 +// Next tag: 13 message SearchSpecProto { // REQUIRED: The "raw" query string that users may type. For example, "cat" // will search for documents with the term cat in it. @@ -112,6 +112,22 @@ message SearchSpecProto { // (TypePropertyMask for the given schema type has empty paths field), no // properties of that schema type will be searched. repeated TypePropertyMask type_property_filters = 10; + + // The vectors to be used in embedding queries. + repeated PropertyProto.VectorProto embedding_query_vectors = 11; + + message EmbeddingQueryMetricType { + enum Code { + UNKNOWN = 0; + COSINE = 1; + DOT_PRODUCT = 2; + EUCLIDEAN = 3; + } + } + + // The default metric type used to calculate the scores for embedding + // queries. 
+ optional EmbeddingQueryMetricType.Code embedding_query_metric_type = 12; } // Client-supplied specifications on what to include/how to format the search diff --git a/proto/icing/proto/storage.proto b/proto/icing/proto/storage.proto index 39dab6b..e0323a1 100644 --- a/proto/icing/proto/storage.proto +++ b/proto/icing/proto/storage.proto @@ -22,6 +22,8 @@ option java_package = "com.google.android.icing.proto"; option java_multiple_files = true; option objc_class_prefix = "ICNG"; +// TODO(b/305098009): fix byte size vs size naming issue. + // Next tag: 10 message NamespaceStorageInfoProto { // Name of the namespace diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt index dd08fd1..55c4647 100644 --- a/synced_AOSP_CL_number.txt +++ b/synced_AOSP_CL_number.txt @@ -1 +1 @@ -set(synced_AOSP_CL_number=587883838) +set(synced_AOSP_CL_number=616925123) |