diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-06-27 16:37:59 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-06-27 16:37:59 +0000 |
commit | 0e42c1c9f44ee8b52d7f59503d4b2f610217228d (patch) | |
tree | 69c5b6c2253d521b1c26f1e6b4bb1cfe571c907e | |
parent | 4f7fa20382e5e06e9b56a477478a8ea921ccfc79 (diff) | |
parent | 8f16167c39d77866a213b0ef042058e858255d47 (diff) | |
download | icing-androidx-sharetarget-release.tar.gz |
Snap for 8755081 from 8f16167c39d77866a213b0ef042058e858255d47 to androidx-sharetarget-releaseandroidx-sharetarget-release
Change-Id: I48cc3b5162a457d89cd505e8ceec7c253106fa87
80 files changed, 3288 insertions, 555 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 01ee8eb..8c8e439 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,8 @@ cmake_minimum_required(VERSION 3.10.2) +project(icing) + add_definitions("-DICING_REVERSE_JNI_SEGMENTATION=1") set(VERSION_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/icing/jni.lds") set(CMAKE_SHARED_LINKER_FLAGS @@ -74,7 +76,7 @@ foreach(FILE ${Icing_PROTO_FILES}) "${Icing_PROTO_GEN_DIR}/${FILE_NOEXT}.pb.h" COMMAND ${Protobuf_PROTOC_PATH} --proto_path "${CMAKE_CURRENT_SOURCE_DIR}/proto" - --cpp_out ${Icing_PROTO_GEN_DIR} + --cpp_out "lite:${Icing_PROTO_GEN_DIR}" ${FILE} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/proto/${FILE} @@ -127,4 +129,4 @@ target_include_directories(icing PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(icing PRIVATE ${Icing_PROTO_GEN_DIR}) target_include_directories(icing PRIVATE "${Protobuf_SOURCE_DIR}/src") target_include_directories(icing PRIVATE "${ICU_SOURCE_DIR}/include") -target_link_libraries(icing protobuf::libprotobuf libandroidicu log) +target_link_libraries(icing protobuf::libprotobuf-lite libandroidicu log z) @@ -0,0 +1,3 @@ +adorokhine@google.com +tjbarron@google.com +dsaadati@google.com diff --git a/build.gradle b/build.gradle index 0f60c5e..5b5f3a6 100644 --- a/build.gradle +++ b/build.gradle @@ -14,8 +14,6 @@ * limitations under the License. */ - -import androidx.build.dependencies.DependenciesKt import static androidx.build.SupportConfig.* buildscript { @@ -65,7 +63,7 @@ dependencies { protobuf { protoc { - artifact = DependenciesKt.getDependencyAsString(libs.protobufCompiler) + artifact = libs.protobufCompiler.get() } generateProtoTasks { diff --git a/icing/file/file-backed-proto-log.h b/icing/file/file-backed-proto-log.h index cf16b4f..686b4fb 100644 --- a/icing/file/file-backed-proto-log.h +++ b/icing/file/file-backed-proto-log.h @@ -40,13 +40,13 @@ #include <string_view> #include "icing/text_classifier/lib3/utils/base/statusor.h" -#include <google/protobuf/io/gzip_stream.h> #include <google/protobuf/io/zero_copy_stream_impl_lite.h> #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" #include "icing/file/filesystem.h" #include "icing/file/memory-mapped-file.h" #include "icing/legacy/core/icing-string-util.h" +#include "icing/portable/gzip_stream.h" #include "icing/portable/platform.h" #include "icing/portable/zlib.h" #include "icing/util/crc32.h" @@ -292,9 +292,6 @@ class FileBackedProtoLog { static_assert(kMaxProtoSize <= 0x00FFFFFF, "kMaxProtoSize doesn't fit in 3 bytes"); - // Level of compression, BEST_SPEED = 1, BEST_COMPRESSION = 9 - static constexpr int kDeflateCompressionLevel = 3; - // Chunks of the file to mmap at a time, so we don't mmap the entire file. // Only used on 32-bit devices static constexpr int kMmapChunkSize = 4 * 1024 * 1024; // 4MiB @@ -306,9 +303,6 @@ class FileBackedProtoLog { }; template <typename ProtoT> -constexpr uint8_t FileBackedProtoLog<ProtoT>::kProtoMagic; - -template <typename ProtoT> FileBackedProtoLog<ProtoT>::FileBackedProtoLog(const Filesystem* filesystem, const std::string& file_path, std::unique_ptr<Header> header) @@ -582,7 +576,7 @@ libtextclassifier3::StatusOr<ProtoT> FileBackedProtoLog<ProtoT>::ReadProto( // Deserialize proto ProtoT proto; if (header_->compress) { - google::protobuf::io::GzipInputStream decompress_stream(&proto_stream); + protobuf_ports::GzipInputStream decompress_stream(&proto_stream); proto.ParseFromZeroCopyStream(&decompress_stream); } else { proto.ParseFromZeroCopyStream(&proto_stream); diff --git a/icing/file/file-backed-vector.h b/icing/file/file-backed-vector.h index 0989935..00bdc7e 100644 --- a/icing/file/file-backed-vector.h +++ b/icing/file/file-backed-vector.h @@ -56,10 +56,9 @@ #ifndef ICING_FILE_FILE_BACKED_VECTOR_H_ #define ICING_FILE_FILE_BACKED_VECTOR_H_ -#include <inttypes.h> -#include <stdint.h> #include <sys/mman.h> +#include <cinttypes> #include <cstdint> #include <memory> #include <string> diff --git a/icing/file/file-backed-vector_test.cc b/icing/file/file-backed-vector_test.cc index b05ce2d..7c02af9 100644 --- a/icing/file/file-backed-vector_test.cc +++ b/icing/file/file-backed-vector_test.cc @@ -14,9 +14,8 @@ #include "icing/file/file-backed-vector.h" -#include <errno.h> - #include <algorithm> +#include <cerrno> #include <cstdint> #include <memory> #include <string_view> diff --git a/icing/file/filesystem.cc b/icing/file/filesystem.cc index 0655cb9..82b8d98 100644 --- a/icing/file/filesystem.cc +++ b/icing/file/filesystem.cc @@ -16,7 +16,6 @@ #include <dirent.h> #include <dlfcn.h> -#include <errno.h> #include <fcntl.h> #include <fnmatch.h> #include <pthread.h> @@ -26,6 +25,7 @@ #include <unistd.h> #include <algorithm> +#include <cerrno> #include <cstdint> #include <unordered_set> diff --git a/icing/file/filesystem.h b/icing/file/filesystem.h index 6bed8e6..ca8c4a8 100644 --- a/icing/file/filesystem.h +++ b/icing/file/filesystem.h @@ -17,11 +17,9 @@ #ifndef ICING_FILE_FILESYSTEM_H_ #define ICING_FILE_FILESYSTEM_H_ -#include <stdint.h> -#include <stdio.h> -#include <string.h> - #include <cstdint> +#include <cstdio> +#include <cstring> #include <memory> #include <string> #include <unordered_set> diff --git a/icing/file/portable-file-backed-proto-log.h b/icing/file/portable-file-backed-proto-log.h index 99b8941..f676dc5 100644 --- a/icing/file/portable-file-backed-proto-log.h +++ b/icing/file/portable-file-backed-proto-log.h @@ -64,7 +64,6 @@ #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" -#include <google/protobuf/io/gzip_stream.h> #include <google/protobuf/io/zero_copy_stream_impl_lite.h> #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" @@ -72,6 +71,7 @@ #include "icing/file/memory-mapped-file.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/portable/endian.h" +#include "icing/portable/gzip_stream.h" #include "icing/portable/platform.h" #include "icing/portable/zlib.h" #include "icing/util/bit-util.h" @@ -576,9 +576,6 @@ class PortableFileBackedProtoLog { }; template <typename ProtoT> -constexpr uint8_t PortableFileBackedProtoLog<ProtoT>::kProtoMagic; - -template <typename ProtoT> PortableFileBackedProtoLog<ProtoT>::PortableFileBackedProtoLog( const Filesystem* filesystem, const std::string& file_path, std::unique_ptr<Header> header) @@ -733,7 +730,7 @@ PortableFileBackedProtoLog<ProtoT>::InitializeExistingFile( return absl_ports::InternalError(IcingStringUtil::StringPrintf( "Failed to truncate '%s' to size %lld", file_path.data(), static_cast<long long>(header->GetRewindOffset()))); - }; + } data_loss = DataLoss::PARTIAL; } @@ -889,12 +886,11 @@ PortableFileBackedProtoLog<ProtoT>::WriteProto(const ProtoT& proto) { google::protobuf::io::StringOutputStream proto_stream(&proto_str); if (header_->GetCompressFlag()) { - google::protobuf::io::GzipOutputStream::Options options; - options.format = google::protobuf::io::GzipOutputStream::ZLIB; + protobuf_ports::GzipOutputStream::Options options; + options.format = protobuf_ports::GzipOutputStream::ZLIB; options.compression_level = kDeflateCompressionLevel; - google::protobuf::io::GzipOutputStream compressing_stream(&proto_stream, - options); + protobuf_ports::GzipOutputStream compressing_stream(&proto_stream, options); bool success = proto.SerializeToZeroCopyStream(&compressing_stream) && compressing_stream.Close(); @@ -974,7 +970,7 @@ PortableFileBackedProtoLog<ProtoT>::ReadProto(int64_t file_offset) const { // Deserialize proto ProtoT proto; if (header_->GetCompressFlag()) { - google::protobuf::io::GzipInputStream decompress_stream(&proto_stream); + protobuf_ports::GzipInputStream decompress_stream(&proto_stream); proto.ParseFromZeroCopyStream(&decompress_stream); } else { proto.ParseFromZeroCopyStream(&proto_stream); diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc index 1b7bd89..9aa833b 100644 --- a/icing/icing-search-engine.cc +++ b/icing/icing-search-engine.cc @@ -47,6 +47,7 @@ #include "icing/proto/search.pb.h" #include "icing/proto/status.pb.h" #include "icing/query/query-processor.h" +#include "icing/query/suggestion-processor.h" #include "icing/result/projection-tree.h" #include "icing/result/projector.h" #include "icing/result/result-retriever.h" @@ -134,14 +135,24 @@ libtextclassifier3::Status ValidateSearchSpec( return libtextclassifier3::Status::OK; } -IndexProcessor::Options CreateIndexProcessorOptions( - const IcingSearchEngineOptions& options) { - IndexProcessor::Options index_processor_options; - index_processor_options.max_tokens_per_document = - options.max_tokens_per_doc(); - index_processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kSuppressError; - return index_processor_options; +libtextclassifier3::Status ValidateSuggestionSpec( + const SuggestionSpecProto& suggestion_spec, + const PerformanceConfiguration& configuration) { + if (suggestion_spec.prefix().empty()) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat("SuggestionSpecProto.prefix is empty!")); + } + if (suggestion_spec.num_to_return() <= 0) { + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + "SuggestionSpecProto.num_to_return must be positive.")); + } + if (suggestion_spec.prefix().size() > configuration.max_query_length) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat("SuggestionSpecProto.prefix is longer than the " + "maximum allowed prefix length: ", + std::to_string(configuration.max_query_length))); + } + return libtextclassifier3::Status::OK; } // Document store files are in a standalone subfolder for easier file @@ -799,9 +810,8 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { } DocumentId document_id = document_id_or.ValueOrDie(); - auto index_processor_or = IndexProcessor::Create( - normalizer_.get(), index_.get(), CreateIndexProcessorOptions(options_), - clock_.get()); + auto index_processor_or = + IndexProcessor::Create(normalizer_.get(), index_.get(), clock_.get()); if (!index_processor_or.ok()) { TransformStatus(index_processor_or.status(), result_status); put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); @@ -812,6 +822,17 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { auto status = index_processor->IndexDocument(tokenized_document, document_id, put_document_stats); + if (!status.ok()) { + // If we encountered a failure while indexing this document, then mark it as + // deleted. + libtextclassifier3::Status delete_status = + document_store_->Delete(document_id); + if (!delete_status.ok()) { + // This is pretty dire (and, hopefully, unlikely). We can't roll back the + // document that we just added. Wipeout the whole index. + ResetInternal(); + } + } TransformStatus(status, result_status); put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); @@ -1397,8 +1418,8 @@ SearchResultProto IcingSearchEngine::Search( component_timer = clock_->GetNewTimer(); // Scores but does not rank the results. libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> - scoring_processor_or = - ScoringProcessor::Create(scoring_spec, document_store_.get()); + scoring_processor_or = ScoringProcessor::Create( + scoring_spec, document_store_.get(), schema_store_.get()); if (!scoring_processor_or.ok()) { TransformStatus(scoring_processor_or.status(), result_status); return result_proto; @@ -1709,9 +1730,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { return {libtextclassifier3::Status::OK, false}; } - auto index_processor_or = IndexProcessor::Create( - normalizer_.get(), index_.get(), CreateIndexProcessorOptions(options_), - clock_.get()); + auto index_processor_or = + IndexProcessor::Create(normalizer_.get(), index_.get(), clock_.get()); if (!index_processor_or.ok()) { return {index_processor_or.status(), true}; } @@ -1789,12 +1809,16 @@ libtextclassifier3::StatusOr<bool> IcingSearchEngine::LostPreviousSchema() { } ResetResultProto IcingSearchEngine::Reset() { + absl_ports::unique_lock l(&mutex_); + return ResetInternal(); +} + +ResetResultProto IcingSearchEngine::ResetInternal() { ICING_VLOG(1) << "Resetting IcingSearchEngine"; ResetResultProto result_proto; StatusProto* result_status = result_proto.mutable_status(); - absl_ports::unique_lock l(&mutex_); initialized_ = false; ResetMembers(); if (!filesystem_->DeleteDirectoryRecursively(options_.base_dir().c_str())) { @@ -1822,5 +1846,62 @@ ResetResultProto IcingSearchEngine::Reset() { return result_proto; } +SuggestionResponse IcingSearchEngine::SearchSuggestions( + const SuggestionSpecProto& suggestion_spec) { + // TODO(b/146008613) Explore ideas to make this function read-only. + absl_ports::unique_lock l(&mutex_); + SuggestionResponse response; + StatusProto* response_status = response.mutable_status(); + if (!initialized_) { + response_status->set_code(StatusProto::FAILED_PRECONDITION); + response_status->set_message("IcingSearchEngine has not been initialized!"); + return response; + } + + libtextclassifier3::Status status = + ValidateSuggestionSpec(suggestion_spec, performance_configuration_); + if (!status.ok()) { + TransformStatus(status, response_status); + return response; + } + + // Create the suggestion processor. + auto suggestion_processor_or = SuggestionProcessor::Create( + index_.get(), language_segmenter_.get(), normalizer_.get()); + if (!suggestion_processor_or.ok()) { + TransformStatus(suggestion_processor_or.status(), response_status); + return response; + } + std::unique_ptr<SuggestionProcessor> suggestion_processor = + std::move(suggestion_processor_or).ValueOrDie(); + + std::vector<NamespaceId> namespace_ids; + namespace_ids.reserve(suggestion_spec.namespace_filters_size()); + for (std::string_view name_space : suggestion_spec.namespace_filters()) { + auto namespace_id_or = document_store_->GetNamespaceId(name_space); + if (!namespace_id_or.ok()) { + continue; + } + namespace_ids.push_back(namespace_id_or.ValueOrDie()); + } + + // Run suggestion based on given SuggestionSpec. + libtextclassifier3::StatusOr<std::vector<TermMetadata>> terms_or = + suggestion_processor->QuerySuggestions(suggestion_spec, namespace_ids); + if (!terms_or.ok()) { + TransformStatus(terms_or.status(), response_status); + return response; + } + + // Convert vector<TermMetaData> into final SuggestionResponse proto. + for (TermMetadata& term : terms_or.ValueOrDie()) { + SuggestionResponse::Suggestion suggestion; + suggestion.set_query(std::move(term.content)); + response.mutable_suggestions()->Add(std::move(suggestion)); + } + response_status->set_code(StatusProto::OK); + return response; +} + } // namespace lib } // namespace icing diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h index 65960a3..0a79714 100644 --- a/icing/icing-search-engine.h +++ b/icing/icing-search-engine.h @@ -302,6 +302,17 @@ class IcingSearchEngine { const ResultSpecProto& result_spec) ICING_LOCKS_EXCLUDED(mutex_); + // Retrieves, scores, ranks and returns the suggested query string according + // to the specs. Results can be empty. + // + // Returns a SuggestionResponse with status: + // OK with results on success + // INVALID_ARGUMENT if any of specs is invalid + // FAILED_PRECONDITION IcingSearchEngine has not been initialized yet + // INTERNAL_ERROR on any other errors + SuggestionResponse SearchSuggestions( + const SuggestionSpecProto& suggestion_spec) ICING_LOCKS_EXCLUDED(mutex_); + // Fetches the next page of results of a previously executed query. Results // can be empty if next-page token is invalid. Invalid next page tokens are // tokens that are either zero or were previously passed to @@ -455,6 +466,10 @@ class IcingSearchEngine { // Resets all members that are created during Initialize. void ResetMembers() ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + // Resets all members that are created during Initialize, deletes all + // underlying files and initializes a fresh index. + ResetResultProto ResetInternal() ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + // Checks for the existence of the init marker file. If the failed init count // exceeds kMaxUnsuccessfulInitAttempts, all data is deleted and the index is // initialized from scratch. The updated count (original failed init count + 1 diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc index 6ad4703..b5206cd 100644 --- a/icing/icing-search-engine_test.cc +++ b/icing/icing-search-engine_test.cc @@ -7439,10 +7439,6 @@ TEST_F(IcingSearchEngineTest, PutDocumentShouldLogIndexingStats) { // No merge should happen. EXPECT_THAT(put_result_proto.put_document_stats().index_merge_latency_ms(), Eq(0)); - // Number of tokens should not exceed. - EXPECT_FALSE(put_result_proto.put_document_stats() - .tokenization_stats() - .exceeded_max_token_num()); // The input document has 2 tokens. EXPECT_THAT(put_result_proto.put_document_stats() .tokenization_stats() @@ -7450,33 +7446,6 @@ TEST_F(IcingSearchEngineTest, PutDocumentShouldLogIndexingStats) { Eq(2)); } -TEST_F(IcingSearchEngineTest, PutDocumentShouldLogWhetherNumTokensExceeds) { - // Create a document with 2 tokens. - DocumentProto document = DocumentBuilder() - .SetKey("icing", "fake_type/0") - .SetSchema("Message") - .AddStringProperty("body", "message body") - .Build(); - - // Create an icing instance with max_tokens_per_doc = 1. - IcingSearchEngineOptions icing_options = GetDefaultIcingOptions(); - icing_options.set_max_tokens_per_doc(1); - IcingSearchEngine icing(icing_options, GetTestJniCache()); - ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); - ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); - - PutResultProto put_result_proto = icing.Put(document); - EXPECT_THAT(put_result_proto.status(), ProtoIsOk()); - // Number of tokens(2) exceeds the max allowed value(1). - EXPECT_TRUE(put_result_proto.put_document_stats() - .tokenization_stats() - .exceeded_max_token_num()); - EXPECT_THAT(put_result_proto.put_document_stats() - .tokenization_stats() - .num_tokens_indexed(), - Eq(1)); -} - TEST_F(IcingSearchEngineTest, PutDocumentShouldLogIndexMergeLatency) { DocumentProto document1 = DocumentBuilder() .SetKey("icing", "fake_type/1") @@ -8044,6 +8013,147 @@ TEST_F(IcingSearchEngineTest, CJKSnippetTest) { EXPECT_THAT(match_proto.exact_match_utf16_length(), Eq(2)); } +TEST_F(IcingSearchEngineTest, PutDocumentIndexFailureDeletion) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); + + // Testing has shown that adding ~600,000 terms generated this way will + // fill up the hit buffer. + std::vector<std::string> terms = GenerateUniqueTerms(600000); + std::string content = absl_ports::StrJoin(terms, " "); + DocumentProto document = DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Message") + .AddStringProperty("body", "foo " + content) + .Build(); + // We failed to add the document to the index fully. This means that we should + // reject the document from Icing entirely. + ASSERT_THAT(icing.Put(document).status(), + ProtoStatusIs(StatusProto::OUT_OF_SPACE)); + + // Make sure that the document isn't searchable. + SearchSpecProto search_spec; + search_spec.set_query("foo"); + search_spec.set_term_match_type(MATCH_PREFIX); + + SearchResultProto search_results = + icing.Search(search_spec, ScoringSpecProto::default_instance(), + ResultSpecProto::default_instance()); + ASSERT_THAT(search_results.status(), ProtoIsOk()); + ASSERT_THAT(search_results.results(), IsEmpty()); + + // Make sure that the document isn't retrievable. + GetResultProto get_result = + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()); + ASSERT_THAT(get_result.status(), ProtoStatusIs(StatusProto::NOT_FOUND)); +} + +TEST_F(IcingSearchEngineTest, SearchSuggestionsTest) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreatePersonAndEmailSchema()).status(), + ProtoIsOk()); + + // Creates and inserts 6 documents, and index 6 termSix, 5 termFive, 4 + // termFour, 3 termThree, 2 termTwo and one termOne. + DocumentProto document1 = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty( + "subject", "termOne termTwo termThree termFour termFive termSix") + .Build(); + DocumentProto document2 = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", + "termTwo termThree termFour termFive termSix") + .Build(); + DocumentProto document3 = + DocumentBuilder() + .SetKey("namespace", "uri3") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", "termThree termFour termFive termSix") + .Build(); + DocumentProto document4 = + DocumentBuilder() + .SetKey("namespace", "uri4") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", "termFour termFive termSix") + .Build(); + DocumentProto document5 = + DocumentBuilder() + .SetKey("namespace", "uri5") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", "termFive termSix") + .Build(); + DocumentProto document6 = DocumentBuilder() + .SetKey("namespace", "uri6") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", "termSix") + .Build(); + ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document2).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document3).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document4).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document5).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document6).status(), ProtoIsOk()); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("t"); + suggestion_spec.set_num_to_return(10); + + // Query all suggestions, and they will be ranked. + SuggestionResponse response = icing.SearchSuggestions(suggestion_spec); + ASSERT_THAT(response.status(), ProtoIsOk()); + ASSERT_THAT(response.suggestions().at(0).query(), "termsix"); + ASSERT_THAT(response.suggestions().at(1).query(), "termfive"); + ASSERT_THAT(response.suggestions().at(2).query(), "termfour"); + ASSERT_THAT(response.suggestions().at(3).query(), "termthree"); + ASSERT_THAT(response.suggestions().at(4).query(), "termtwo"); + ASSERT_THAT(response.suggestions().at(5).query(), "termone"); + + // Query first three suggestions, and they will be ranked. + suggestion_spec.set_num_to_return(3); + response = icing.SearchSuggestions(suggestion_spec); + ASSERT_THAT(response.status(), ProtoIsOk()); + ASSERT_THAT(response.suggestions().at(0).query(), "termsix"); + ASSERT_THAT(response.suggestions().at(1).query(), "termfive"); + ASSERT_THAT(response.suggestions().at(2).query(), "termfour"); +} + +TEST_F(IcingSearchEngineTest, SearchSuggestionsTest_emptyPrefix) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix(""); + suggestion_spec.set_num_to_return(10); + + ASSERT_THAT(icing.SearchSuggestions(suggestion_spec).status(), + ProtoStatusIs(StatusProto::INVALID_ARGUMENT)); +} + +TEST_F(IcingSearchEngineTest, SearchSuggestionsTest_NonPositiveNumToReturn) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("prefix"); + suggestion_spec.set_num_to_return(0); + + ASSERT_THAT(icing.SearchSuggestions(suggestion_spec).status(), + ProtoStatusIs(StatusProto::INVALID_ARGUMENT)); +} + #ifndef ICING_JNI_TEST // We skip this test case when we're running in a jni_test since the data files // will be stored in the android-instrumented storage location, rather than the diff --git a/icing/index/index-processor.cc b/icing/index/index-processor.cc index 6d8632f..1aae732 100644 --- a/icing/index/index-processor.cc +++ b/icing/index/index-processor.cc @@ -43,14 +43,13 @@ namespace lib { libtextclassifier3::StatusOr<std::unique_ptr<IndexProcessor>> IndexProcessor::Create(const Normalizer* normalizer, Index* index, - const IndexProcessor::Options& options, const Clock* clock) { ICING_RETURN_ERROR_IF_NULL(normalizer); ICING_RETURN_ERROR_IF_NULL(index); ICING_RETURN_ERROR_IF_NULL(clock); return std::unique_ptr<IndexProcessor>( - new IndexProcessor(normalizer, index, options, clock)); + new IndexProcessor(normalizer, index, clock)); } libtextclassifier3::Status IndexProcessor::IndexDocument( @@ -66,53 +65,34 @@ libtextclassifier3::Status IndexProcessor::IndexDocument( } index_->set_last_added_document_id(document_id); uint32_t num_tokens = 0; - libtextclassifier3::Status overall_status; + libtextclassifier3::Status status; for (const TokenizedSection& section : tokenized_document.sections()) { // TODO(b/152934343): pass real namespace ids in Index::Editor editor = index_->Edit(document_id, section.metadata.id, section.metadata.term_match_type, /*namespace_id=*/0); for (std::string_view token : section.token_sequence) { - if (++num_tokens > options_.max_tokens_per_document) { - // Index all tokens buffered so far. - editor.IndexAllBufferedTerms(); - if (put_document_stats != nullptr) { - put_document_stats->mutable_tokenization_stats() - ->set_exceeded_max_token_num(true); - put_document_stats->mutable_tokenization_stats() - ->set_num_tokens_indexed(options_.max_tokens_per_document); - } - switch (options_.token_limit_behavior) { - case Options::TokenLimitBehavior::kReturnError: - return absl_ports::ResourceExhaustedError( - "Max number of tokens reached!"); - case Options::TokenLimitBehavior::kSuppressError: - return overall_status; - } - } + ++num_tokens; std::string term = normalizer_.NormalizeTerm(token); - // Add this term to Hit buffer. Even if adding this hit fails, we keep - // trying to add more hits because it's possible that future hits could - // still be added successfully. For instance if the lexicon is full, we - // might fail to add a hit for a new term, but should still be able to - // add hits for terms that are already in the index. - auto status = editor.BufferTerm(term.c_str()); - if (overall_status.ok() && !status.ok()) { - // If we've succeeded to add everything so far, set overall_status to - // represent this new failure. If we've already failed, no need to - // update the status - we're already going to return a resource - // exhausted error. - overall_status = status; + // Add this term to Hit buffer. + status = editor.BufferTerm(term.c_str()); + if (!status.ok()) { + // We've encountered a failure. Bail out. We'll mark this doc as deleted + // and signal a failure to the client. + ICING_LOG(WARNING) << "Failed to buffer term in lite lexicon due to: " + << status.error_message(); + break; } } + if (!status.ok()) { + break; + } // Add all the seen terms to the index with their term frequency. - auto status = editor.IndexAllBufferedTerms(); - if (overall_status.ok() && !status.ok()) { - // If we've succeeded so far, set overall_status to - // represent this new failure. If we've already failed, no need to - // update the status - we're already going to return a resource - // exhausted error. - overall_status = status; + status = editor.IndexAllBufferedTerms(); + if (!status.ok()) { + ICING_LOG(WARNING) << "Failed to add hits in lite index due to: " + << status.error_message(); + break; } } @@ -123,9 +103,11 @@ libtextclassifier3::Status IndexProcessor::IndexDocument( num_tokens); } - // Merge if necessary. - if (overall_status.ok() && index_->WantsMerge()) { - ICING_VLOG(1) << "Merging the index at docid " << document_id << "."; + // If we're either successful or we've hit resource exhausted, then attempt a + // merge. + if ((status.ok() || absl_ports::IsResourceExhausted(status)) && + index_->WantsMerge()) { + ICING_LOG(ERROR) << "Merging the index at docid " << document_id << "."; std::unique_ptr<Timer> merge_timer = clock_.GetNewTimer(); libtextclassifier3::Status merge_status = index_->Merge(); @@ -150,7 +132,7 @@ libtextclassifier3::Status IndexProcessor::IndexDocument( } } - return overall_status; + return status; } } // namespace lib diff --git a/icing/index/index-processor.h b/icing/index/index-processor.h index 6b07c98..c4b77b5 100644 --- a/icing/index/index-processor.h +++ b/icing/index/index-processor.h @@ -32,23 +32,6 @@ namespace lib { class IndexProcessor { public: - struct Options { - int32_t max_tokens_per_document; - - // Indicates how a document exceeding max_tokens_per_document should be - // handled. - enum class TokenLimitBehavior { - // When set, the first max_tokens_per_document will be indexed. If the - // token count exceeds max_tokens_per_document, a ResourceExhausted error - // will be returned. - kReturnError, - // When set, the first max_tokens_per_document will be indexed. If the - // token count exceeds max_tokens_per_document, OK will be returned. - kSuppressError, - }; - TokenLimitBehavior token_limit_behavior; - }; - // Factory function to create an IndexProcessor which does not take ownership // of any input components, and all pointers must refer to valid objects that // outlive the created IndexProcessor instance. @@ -57,8 +40,7 @@ class IndexProcessor { // An IndexProcessor on success // FAILED_PRECONDITION if any of the pointers is null. static libtextclassifier3::StatusOr<std::unique_ptr<IndexProcessor>> Create( - const Normalizer* normalizer, Index* index, const Options& options, - const Clock* clock); + const Normalizer* normalizer, Index* index, const Clock* clock); // Add tokenized document to the index, associated with document_id. If the // number of tokens in the document exceeds max_tokens_per_document, then only @@ -84,18 +66,13 @@ class IndexProcessor { PutDocumentStatsProto* put_document_stats = nullptr); private: - IndexProcessor(const Normalizer* normalizer, Index* index, - const Options& options, const Clock* clock) - : normalizer_(*normalizer), - index_(index), - options_(options), - clock_(*clock) {} + IndexProcessor(const Normalizer* normalizer, Index* index, const Clock* clock) + : normalizer_(*normalizer), index_(index), clock_(*clock) {} std::string NormalizeToken(const Token& token); const Normalizer& normalizer_; Index* const index_; - const Options options_; const Clock& clock_; }; diff --git a/icing/index/index-processor_benchmark.cc b/icing/index/index-processor_benchmark.cc index afeac4d..6e072c7 100644 --- a/icing/index/index-processor_benchmark.cc +++ b/icing/index/index-processor_benchmark.cc @@ -168,17 +168,6 @@ void CleanUp(const Filesystem& filesystem, const std::string& index_dir) { filesystem.DeleteDirectoryRecursively(index_dir.c_str()); } -std::unique_ptr<IndexProcessor> CreateIndexProcessor( - const Normalizer* normalizer, Index* index, const Clock* clock) { - IndexProcessor::Options processor_options{}; - processor_options.max_tokens_per_document = 1024 * 1024 * 10; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - - return IndexProcessor::Create(normalizer, index, processor_options, clock) - .ValueOrDie(); -} - void BM_IndexDocumentWithOneProperty(benchmark::State& state) { bool run_via_adb = absl::GetFlag(FLAGS_adb); if (!run_via_adb) { @@ -200,9 +189,9 @@ void BM_IndexDocumentWithOneProperty(benchmark::State& state) { std::unique_ptr<Normalizer> normalizer = CreateNormalizer(); Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); - std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(normalizer.get(), index.get(), &clock); - + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<IndexProcessor> index_processor, + IndexProcessor::Create(normalizer.get(), index.get(), &clock)); DocumentProto input_document = CreateDocumentWithOneProperty(state.range(0)); TokenizedDocument tokenized_document(std::move( TokenizedDocument::Create(schema_store.get(), language_segmenter.get(), @@ -254,8 +243,9 @@ void BM_IndexDocumentWithTenProperties(benchmark::State& state) { std::unique_ptr<Normalizer> normalizer = CreateNormalizer(); Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); - std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(normalizer.get(), index.get(), &clock); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<IndexProcessor> index_processor, + IndexProcessor::Create(normalizer.get(), index.get(), &clock)); DocumentProto input_document = CreateDocumentWithTenProperties(state.range(0)); @@ -309,8 +299,9 @@ void BM_IndexDocumentWithDiacriticLetters(benchmark::State& state) { std::unique_ptr<Normalizer> normalizer = CreateNormalizer(); Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); - std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(normalizer.get(), index.get(), &clock); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<IndexProcessor> index_processor, + IndexProcessor::Create(normalizer.get(), index.get(), &clock)); DocumentProto input_document = CreateDocumentWithDiacriticLetters(state.range(0)); @@ -364,8 +355,9 @@ void BM_IndexDocumentWithHiragana(benchmark::State& state) { std::unique_ptr<Normalizer> normalizer = CreateNormalizer(); Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); - std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(normalizer.get(), index.get(), &clock); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<IndexProcessor> index_processor, + IndexProcessor::Create(normalizer.get(), index.get(), &clock)); DocumentProto input_document = CreateDocumentWithHiragana(state.range(0)); TokenizedDocument tokenized_document(std::move( diff --git a/icing/index/index-processor_test.cc b/icing/index/index-processor_test.cc index 8a6a9f5..449bc3e 100644 --- a/icing/index/index-processor_test.cc +++ b/icing/index/index-processor_test.cc @@ -27,6 +27,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/absl_ports/str_cat.h" +#include "icing/absl_ports/str_join.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" #include "icing/helpers/icu/icu-data-file-helper.h" @@ -48,6 +49,7 @@ #include "icing/store/document-id.h" #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" +#include "icing/testing/random-string.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" #include "icing/tokenization/language-segmenter-factory.h" @@ -193,15 +195,9 @@ class IndexProcessorTest : public Test { .Build(); ICING_ASSERT_OK(schema_store_->SetSchema(schema)); - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), - processor_options, &fake_clock_)); + IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_)); mock_icing_filesystem_ = std::make_unique<IcingMockFilesystem>(); } @@ -232,17 +228,12 @@ std::vector<DocHitInfo> GetHits(std::unique_ptr<DocHitInfoIterator> iterator) { } TEST_F(IndexProcessorTest, CreationWithNullPointerShouldFail) { - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - EXPECT_THAT(IndexProcessor::Create(/*normalizer=*/nullptr, index_.get(), - processor_options, &fake_clock_), + &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT(IndexProcessor::Create(normalizer_.get(), /*index=*/nullptr, - processor_options, &fake_clock_), + &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } @@ -434,103 +425,68 @@ TEST_F(IndexProcessorTest, DocWithRepeatedProperty) { kDocumentId0, std::vector<SectionId>{kRepeatedSectionId}))); } -TEST_F(IndexProcessorTest, TooManyTokensReturnError) { - // Only allow the first four tokens ("hello", "world", "good", "night") to be - // indexed. - IndexProcessor::Options options; - options.max_tokens_per_document = 4; - options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; +// TODO(b/196771754) This test is disabled on Android because it takes too long +// to generate all of the unique terms and the test times out. Try storing these +// unique terms in a file that the test can read from. +#ifndef __ANDROID__ - ICING_ASSERT_OK_AND_ASSIGN( - index_processor_, IndexProcessor::Create(normalizer_.get(), index_.get(), - options, &fake_clock_)); +TEST_F(IndexProcessorTest, HitBufferExhaustedTest) { + // Testing has shown that adding ~600,000 hits will fill up the hit buffer. + std::vector<std::string> unique_terms_ = GenerateUniqueTerms(200000); + std::string content = absl_ports::StrJoin(unique_terms_, " "); DocumentProto document = DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kExactProperty), "hello world") - .AddStringProperty(std::string(kPrefixedProperty), "good night moon!") + .AddStringProperty(std::string(kExactProperty), content) + .AddStringProperty(std::string(kPrefixedProperty), content) + .AddStringProperty(std::string(kRepeatedProperty), content) .Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), - StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED, + testing::HasSubstr("Hit buffer is full!"))); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); - - // "night" should have been indexed. - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, - index_->GetIterator("night", kSectionIdMaskAll, - TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), - ElementsAre(EqualsDocHitInfo( - kDocumentId0, std::vector<SectionId>{kPrefixedSectionId}))); - - // "moon" should not have been. - ICING_ASSERT_OK_AND_ASSIGN(itr, - index_->GetIterator("moon", kSectionIdMaskAll, - TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), IsEmpty()); } -TEST_F(IndexProcessorTest, TooManyTokensSuppressError) { - // Only allow the first four tokens ("hello", "world", "good", "night") to be - // indexed. - IndexProcessor::Options options; - options.max_tokens_per_document = 4; - options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kSuppressError; - - ICING_ASSERT_OK_AND_ASSIGN( - index_processor_, IndexProcessor::Create(normalizer_.get(), index_.get(), - options, &fake_clock_)); +TEST_F(IndexProcessorTest, LexiconExhaustedTest) { + // Testing has shown that adding ~300,000 terms generated this way will + // fill up the lexicon. + std::vector<std::string> unique_terms_ = GenerateUniqueTerms(300000); + std::string content = absl_ports::StrJoin(unique_terms_, " "); DocumentProto document = DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kExactProperty), "hello world") - .AddStringProperty(std::string(kPrefixedProperty), "good night moon!") + .AddStringProperty(std::string(kExactProperty), content) .Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), - IsOk()); + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED, + testing::HasSubstr("Unable to add term"))); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); - - // "night" should have been indexed. - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, - index_->GetIterator("night", kSectionIdMaskAll, - TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), - ElementsAre(EqualsDocHitInfo( - kDocumentId0, std::vector<SectionId>{kPrefixedSectionId}))); - - // "moon" should not have been. - ICING_ASSERT_OK_AND_ASSIGN(itr, - index_->GetIterator("moon", kSectionIdMaskAll, - TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), IsEmpty()); } +#endif // __ANDROID__ + TEST_F(IndexProcessorTest, TooLongTokens) { // Only allow the tokens of length four, truncating "hello", "world" and // "night". - IndexProcessor::Options options; - options.max_tokens_per_document = 1000; - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Normalizer> normalizer, normalizer_factory::Create( /*max_term_byte_size=*/4)); ICING_ASSERT_OK_AND_ASSIGN( - index_processor_, IndexProcessor::Create(normalizer.get(), index_.get(), - options, &fake_clock_)); + index_processor_, + IndexProcessor::Create(normalizer.get(), index_.get(), &fake_clock_)); DocumentProto document = DocumentBuilder() @@ -692,16 +648,6 @@ TEST_F(IndexProcessorTest, NonAsciiIndexing) { lang_segmenter_, language_segmenter_factory::Create(std::move(segmenter_options))); - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - - ICING_ASSERT_OK_AND_ASSIGN( - index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, - &fake_clock_)); - DocumentProto document = DocumentBuilder() .SetKey("icing", "fake_type/1") @@ -727,23 +673,13 @@ TEST_F(IndexProcessorTest, NonAsciiIndexing) { TEST_F(IndexProcessorTest, LexiconFullIndexesSmallerTokensReturnsResourceExhausted) { - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - - ICING_ASSERT_OK_AND_ASSIGN( - index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, - &fake_clock_)); - // This is the maximum token length that an empty lexicon constructed for a // lite index with merge size of 1MiB can support. constexpr int kMaxTokenLength = 16777217; // Create a string "ppppppp..." with a length that is too large to fit into // the lexicon. std::string enormous_string(kMaxTokenLength + 1, 'p'); - DocumentProto document = + DocumentProto document_one = DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) @@ -754,24 +690,10 @@ TEST_F(IndexProcessorTest, ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), - document)); + document_one)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); - - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DocHitInfoIterator> itr, - index_->GetIterator("foo", kSectionIdMaskAll, TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), - ElementsAre(EqualsDocHitInfo( - kDocumentId0, std::vector<SectionId>{kExactSectionId}))); - - ICING_ASSERT_OK_AND_ASSIGN( - itr, - index_->GetIterator("baz", kSectionIdMaskAll, TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), - ElementsAre(EqualsDocHitInfo( - kDocumentId0, std::vector<SectionId>{kPrefixedSectionId}))); } TEST_F(IndexProcessorTest, IndexingDocAutomaticMerge) { @@ -795,15 +717,9 @@ TEST_F(IndexProcessorTest, IndexingDocAutomaticMerge) { ICING_ASSERT_OK_AND_ASSIGN( index_, Index::Create(options, &filesystem_, &icing_filesystem_)); - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, - &fake_clock_)); + IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_)); DocumentId doc_id = 0; // Have determined experimentally that indexing 3373 documents with this text // will cause the LiteIndex to fill up. Further indexing will fail unless the @@ -857,15 +773,9 @@ TEST_F(IndexProcessorTest, IndexingDocMergeFailureResets) { index_, Index::Create(options, &filesystem_, mock_icing_filesystem_.get())); - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, - &fake_clock_)); + IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_)); // 3. Index one document. This should fit in the LiteIndex without requiring a // merge. diff --git a/icing/index/index.cc b/icing/index/index.cc index db59ad2..1bdab21 100644 --- a/icing/index/index.cc +++ b/icing/index/index.cc @@ -36,6 +36,7 @@ #include "icing/legacy/index/icing-filesystem.h" #include "icing/proto/term.pb.h" #include "icing/schema/section.h" +#include "icing/scoring/ranker.h" #include "icing/store/document-id.h" #include "icing/util/logging.h" #include "icing/util/status-macros.h" @@ -89,20 +90,24 @@ bool IsTermInNamespaces( } enum class MergeAction { kTakeLiteTerm, kTakeMainTerm, kMergeTerms }; -std::vector<TermMetadata> MergeTermMetadatas( + +// Merge the TermMetadata from lite index and main index. If the term exists in +// both index, sum up its hit count and push it to the term heap. +// The heap is a min-heap. So that we can avoid some push operation but the time +// complexity is O(NlgK) which N is total number of term and K is num_to_return. +std::vector<TermMetadata> MergeAndRankTermMetadatas( std::vector<TermMetadata> lite_term_metadata_list, std::vector<TermMetadata> main_term_metadata_list, int num_to_return) { - std::vector<TermMetadata> merged_term_metadata_list; - merged_term_metadata_list.reserve( + std::vector<TermMetadata> merged_term_metadata_heap; + merged_term_metadata_heap.reserve( std::min(lite_term_metadata_list.size() + main_term_metadata_list.size(), static_cast<size_t>(num_to_return))); auto lite_term_itr = lite_term_metadata_list.begin(); auto main_term_itr = main_term_metadata_list.begin(); MergeAction merge_action; - while (merged_term_metadata_list.size() < num_to_return && - (lite_term_itr != lite_term_metadata_list.end() || - main_term_itr != main_term_metadata_list.end())) { + while (lite_term_itr != lite_term_metadata_list.end() || + main_term_itr != main_term_metadata_list.end()) { // Get pointers to the next metadatas in each group, if available // Determine how to merge. if (main_term_itr == main_term_metadata_list.end()) { @@ -119,23 +124,32 @@ std::vector<TermMetadata> MergeTermMetadatas( } switch (merge_action) { case MergeAction::kTakeLiteTerm: - merged_term_metadata_list.push_back(std::move(*lite_term_itr)); + PushToTermHeap(std::move(*lite_term_itr), num_to_return, + merged_term_metadata_heap); ++lite_term_itr; break; case MergeAction::kTakeMainTerm: - merged_term_metadata_list.push_back(std::move(*main_term_itr)); + PushToTermHeap(std::move(*main_term_itr), num_to_return, + merged_term_metadata_heap); ++main_term_itr; break; case MergeAction::kMergeTerms: int total_est_hit_count = lite_term_itr->hit_count + main_term_itr->hit_count; - merged_term_metadata_list.emplace_back( - std::move(lite_term_itr->content), total_est_hit_count); + PushToTermHeap(TermMetadata(std::move(lite_term_itr->content), + total_est_hit_count), + num_to_return, merged_term_metadata_heap); ++lite_term_itr; ++main_term_itr; break; } } + // Reverse the list since we pop them from a min heap and we need to return in + // decreasing order. + std::vector<TermMetadata> merged_term_metadata_list = + PopAllTermsFromHeap(merged_term_metadata_heap); + std::reverse(merged_term_metadata_list.begin(), + merged_term_metadata_list.end()); return merged_term_metadata_list; } @@ -214,8 +228,7 @@ Index::GetIterator(const std::string& term, SectionIdMask section_id_mask, libtextclassifier3::StatusOr<std::vector<TermMetadata>> Index::FindLiteTermsByPrefix(const std::string& prefix, - const std::vector<NamespaceId>& namespace_ids, - int num_to_return) { + const std::vector<NamespaceId>& namespace_ids) { // Finds all the terms that start with the given prefix in the lexicon. IcingDynamicTrie::Iterator term_iterator(lite_index_->lexicon(), prefix.c_str()); @@ -224,7 +237,7 @@ Index::FindLiteTermsByPrefix(const std::string& prefix, IcingDynamicTrie::PropertyReadersAll property_reader(lite_index_->lexicon()); std::vector<TermMetadata> term_metadata_list; - while (term_iterator.IsValid() && term_metadata_list.size() < num_to_return) { + while (term_iterator.IsValid()) { uint32_t term_value_index = term_iterator.GetValueIndex(); // Skips the terms that don't exist in the given namespaces. We won't skip @@ -244,13 +257,6 @@ Index::FindLiteTermsByPrefix(const std::string& prefix, term_iterator.Advance(); } - if (term_iterator.IsValid()) { - // We exited the loop above because we hit the num_to_return limit. - ICING_LOG(WARNING) << "Ran into limit of " << num_to_return - << " retrieving suggestions for " << prefix - << ". Some suggestions may not be returned and others " - "may be misranked."; - } return term_metadata_list; } @@ -264,17 +270,15 @@ Index::FindTermsByPrefix(const std::string& prefix, } // Get results from the LiteIndex. - ICING_ASSIGN_OR_RETURN( - std::vector<TermMetadata> lite_term_metadata_list, - FindLiteTermsByPrefix(prefix, namespace_ids, num_to_return)); - + ICING_ASSIGN_OR_RETURN(std::vector<TermMetadata> lite_term_metadata_list, + FindLiteTermsByPrefix(prefix, namespace_ids)); // Append results from the MainIndex. - ICING_ASSIGN_OR_RETURN( - std::vector<TermMetadata> main_term_metadata_list, - main_index_->FindTermsByPrefix(prefix, namespace_ids, num_to_return)); + ICING_ASSIGN_OR_RETURN(std::vector<TermMetadata> main_term_metadata_list, + main_index_->FindTermsByPrefix(prefix, namespace_ids)); - return MergeTermMetadatas(std::move(lite_term_metadata_list), - std::move(main_term_metadata_list), num_to_return); + return MergeAndRankTermMetadatas(std::move(lite_term_metadata_list), + std::move(main_term_metadata_list), + num_to_return); } IndexStorageInfoProto Index::GetStorageInfo() const { diff --git a/icing/index/index.h b/icing/index/index.h index eab5be8..693cf04 100644 --- a/icing/index/index.h +++ b/icing/index/index.h @@ -267,8 +267,7 @@ class Index { filesystem_(filesystem) {} libtextclassifier3::StatusOr<std::vector<TermMetadata>> FindLiteTermsByPrefix( - const std::string& prefix, const std::vector<NamespaceId>& namespace_ids, - int num_to_return); + const std::string& prefix, const std::vector<NamespaceId>& namespace_ids); std::unique_ptr<LiteIndex> lite_index_; std::unique_ptr<MainIndex> main_index_; diff --git a/icing/index/index_test.cc b/icing/index/index_test.cc index 16593ef..00d5ad6 100644 --- a/icing/index/index_test.cc +++ b/icing/index/index_test.cc @@ -88,6 +88,11 @@ constexpr DocumentId kDocumentId4 = 4; constexpr DocumentId kDocumentId5 = 5; constexpr DocumentId kDocumentId6 = 6; constexpr DocumentId kDocumentId7 = 7; +constexpr DocumentId kDocumentId8 = 8; +constexpr DocumentId kDocumentId9 = 9; +constexpr DocumentId kDocumentId10 = 10; +constexpr DocumentId kDocumentId11 = 11; +constexpr DocumentId kDocumentId12 = 12; constexpr SectionId kSectionId2 = 2; constexpr SectionId kSectionId3 = 3; @@ -1105,11 +1110,10 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnCorrectHitCount) { EXPECT_THAT(edit2.IndexAllBufferedTerms(), IsOk()); // 'foo' has 1 hit, 'fool' has 2 hits. - EXPECT_THAT( - index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, - /*num_to_return=*/10), - IsOkAndHolds(UnorderedElementsAre(EqualsTermMetadata("foo", 1), - EqualsTermMetadata("fool", 2)))); + EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(ElementsAre(EqualsTermMetadata("fool", 2), + EqualsTermMetadata("foo", 1)))); ICING_ASSERT_OK(index_->Merge()); @@ -1122,6 +1126,155 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnCorrectHitCount) { EqualsTermMetadata("fool", kMinSizePlApproxHits)))); } +TEST_F(IndexTest, FindTermByPrefixShouldReturnInOrder) { + // Push 6 term-six, 5 term-five, 4 term-four, 3 term-three, 2 term-two and one + // term-one into lite index. + Index::Editor edit1 = + index_->Edit(kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit1.BufferTerm("term-one"), IsOk()); + EXPECT_THAT(edit1.BufferTerm("term-two"), IsOk()); + EXPECT_THAT(edit1.BufferTerm("term-three"), IsOk()); + EXPECT_THAT(edit1.BufferTerm("term-four"), IsOk()); + EXPECT_THAT(edit1.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit1.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit1.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit2 = + index_->Edit(kDocumentId2, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit2.BufferTerm("term-two"), IsOk()); + EXPECT_THAT(edit2.BufferTerm("term-three"), IsOk()); + EXPECT_THAT(edit2.BufferTerm("term-four"), IsOk()); + EXPECT_THAT(edit2.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit2.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit2.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit3 = + index_->Edit(kDocumentId3, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit3.BufferTerm("term-three"), IsOk()); + EXPECT_THAT(edit3.BufferTerm("term-four"), IsOk()); + EXPECT_THAT(edit3.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit3.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit3.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit4 = + index_->Edit(kDocumentId4, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit4.BufferTerm("term-four"), IsOk()); + EXPECT_THAT(edit4.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit4.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit4.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit5 = + index_->Edit(kDocumentId5, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit5.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit5.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit5.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit6 = + index_->Edit(kDocumentId6, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit6.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit6.IndexAllBufferedTerms(), IsOk()); + + // verify the order in lite index is correct. + EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"t", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(ElementsAre(EqualsTermMetadata("term-six", 6), + EqualsTermMetadata("term-five", 5), + EqualsTermMetadata("term-four", 4), + EqualsTermMetadata("term-three", 3), + EqualsTermMetadata("term-two", 2), + EqualsTermMetadata("term-one", 1)))); + + ICING_ASSERT_OK(index_->Merge()); + + // Since most of term has same approx hit count, we don't verify order in the + // main index. + EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"t", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(UnorderedElementsAre( + EqualsTermMetadata("term-six", kSecondSmallestPlApproxHits), + EqualsTermMetadata("term-five", kSecondSmallestPlApproxHits), + EqualsTermMetadata("term-four", kMinSizePlApproxHits), + EqualsTermMetadata("term-three", kMinSizePlApproxHits), + EqualsTermMetadata("term-two", kMinSizePlApproxHits), + EqualsTermMetadata("term-one", kMinSizePlApproxHits)))); + + // keep push terms to the lite index. For term 1-4, since they has same hit + // count kMinSizePlApproxHits, we will push 4 term-one, 3 term-two, 2 + // term-three and one term-four to make them in reverse order. And for term + // 5 & 6, we will push 2 term-five and one term-six. + Index::Editor edit7 = + index_->Edit(kDocumentId7, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit7.BufferTerm("term-one"), IsOk()); + EXPECT_THAT(edit7.BufferTerm("term-two"), IsOk()); + EXPECT_THAT(edit7.BufferTerm("term-three"), IsOk()); + EXPECT_THAT(edit7.BufferTerm("term-four"), IsOk()); + EXPECT_THAT(edit7.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit8 = + index_->Edit(kDocumentId8, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit8.BufferTerm("term-one"), IsOk()); + EXPECT_THAT(edit8.BufferTerm("term-two"), IsOk()); + EXPECT_THAT(edit8.BufferTerm("term-three"), IsOk()); + EXPECT_THAT(edit8.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit9 = + index_->Edit(kDocumentId9, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit9.BufferTerm("term-one"), IsOk()); + EXPECT_THAT(edit9.BufferTerm("term-two"), IsOk()); + EXPECT_THAT(edit9.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit10 = + index_->Edit(kDocumentId10, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit10.BufferTerm("term-one"), IsOk()); + EXPECT_THAT(edit10.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit11 = + index_->Edit(kDocumentId11, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit11.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit11.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit11.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit12 = + index_->Edit(kDocumentId12, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit12.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit12.IndexAllBufferedTerms(), IsOk()); + + // verify the combination of lite index and main index is in correct order. + EXPECT_THAT( + index_->FindTermsByPrefix(/*prefix=*/"t", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(ElementsAre( + EqualsTermMetadata("term-five", + kSecondSmallestPlApproxHits + 2), // 9 + EqualsTermMetadata("term-six", kSecondSmallestPlApproxHits + 1), // 8 + EqualsTermMetadata("term-one", kMinSizePlApproxHits + 4), // 7 + EqualsTermMetadata("term-two", kMinSizePlApproxHits + 3), // 6 + EqualsTermMetadata("term-three", kMinSizePlApproxHits + 2), // 5 + EqualsTermMetadata("term-four", kMinSizePlApproxHits + 1)))); // 4 + + // Get the first three terms. + EXPECT_THAT( + index_->FindTermsByPrefix(/*prefix=*/"t", /*namespace_ids=*/{0}, + /*num_to_return=*/3), + IsOkAndHolds(ElementsAre( + EqualsTermMetadata("term-five", + kSecondSmallestPlApproxHits + 2), // 9 + EqualsTermMetadata("term-six", kSecondSmallestPlApproxHits + 1), // 8 + EqualsTermMetadata("term-one", kMinSizePlApproxHits + 4)))); // 7 +} + TEST_F(IndexTest, FindTermByPrefixShouldReturnApproximateHitCountForMain) { Index::Editor edit = index_->Edit(kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, @@ -1160,11 +1313,10 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnApproximateHitCountForMain) { EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); // 'foo' has 1 hit, 'fool' has 8 hits. - EXPECT_THAT( - index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, - /*num_to_return=*/10), - IsOkAndHolds(UnorderedElementsAre(EqualsTermMetadata("foo", 1), - EqualsTermMetadata("fool", 8)))); + EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(ElementsAre(EqualsTermMetadata("fool", 8), + EqualsTermMetadata("foo", 1)))); ICING_ASSERT_OK(index_->Merge()); @@ -1195,9 +1347,9 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnCombinedHitCount) { // 1 hit in the lite index. EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, /*num_to_return=*/10), - IsOkAndHolds(UnorderedElementsAre( - EqualsTermMetadata("foo", kMinSizePlApproxHits), - EqualsTermMetadata("fool", kMinSizePlApproxHits + 1)))); + IsOkAndHolds(ElementsAre( + EqualsTermMetadata("fool", kMinSizePlApproxHits + 1), + EqualsTermMetadata("foo", kMinSizePlApproxHits)))); } TEST_F(IndexTest, FindTermByPrefixShouldReturnTermsFromBothIndices) { @@ -1215,11 +1367,11 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnTermsFromBothIndices) { EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); // 'foo' has 1 hit in the main index, 'fool' has 1 hit in the lite index. - EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, - /*num_to_return=*/10), - IsOkAndHolds(UnorderedElementsAre( - EqualsTermMetadata("foo", kMinSizePlApproxHits), - EqualsTermMetadata("fool", 1)))); + EXPECT_THAT( + index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(ElementsAre(EqualsTermMetadata("foo", kMinSizePlApproxHits), + EqualsTermMetadata("fool", 1)))); } TEST_F(IndexTest, GetElementsSize) { diff --git a/icing/index/iterator/doc-hit-info-iterator-and.cc b/icing/index/iterator/doc-hit-info-iterator-and.cc index 39aa969..543e9ef 100644 --- a/icing/index/iterator/doc-hit-info-iterator-and.cc +++ b/icing/index/iterator/doc-hit-info-iterator-and.cc @@ -14,8 +14,7 @@ #include "icing/index/iterator/doc-hit-info-iterator-and.h" -#include <stddef.h> - +#include <cstddef> #include <cstdint> #include <memory> #include <string> diff --git a/icing/index/lite/lite-index.cc b/icing/index/lite/lite-index.cc index fb23934..9e4ac28 100644 --- a/icing/index/lite/lite-index.cc +++ b/icing/index/lite/lite-index.cc @@ -14,12 +14,11 @@ #include "icing/index/lite/lite-index.h" -#include <inttypes.h> -#include <stddef.h> -#include <stdint.h> #include <sys/mman.h> #include <algorithm> +#include <cinttypes> +#include <cstddef> #include <cstdint> #include <memory> #include <string> diff --git a/icing/index/main/flash-index-storage.cc b/icing/index/main/flash-index-storage.cc index f125b6d..3c52375 100644 --- a/icing/index/main/flash-index-storage.cc +++ b/icing/index/main/flash-index-storage.cc @@ -14,11 +14,11 @@ #include "icing/index/main/flash-index-storage.h" -#include <errno.h> -#include <inttypes.h> #include <sys/types.h> #include <algorithm> +#include <cerrno> +#include <cinttypes> #include <cstdint> #include <memory> #include <unordered_set> diff --git a/icing/index/main/flash-index-storage_test.cc b/icing/index/main/flash-index-storage_test.cc index 7e15524..25fcaad 100644 --- a/icing/index/main/flash-index-storage_test.cc +++ b/icing/index/main/flash-index-storage_test.cc @@ -14,10 +14,10 @@ #include "icing/index/main/flash-index-storage.h" -#include <stdlib.h> #include <unistd.h> #include <algorithm> +#include <cstdlib> #include <limits> #include <utility> #include <vector> diff --git a/icing/index/main/index-block.cc b/icing/index/main/index-block.cc index 4590d06..c6ab345 100644 --- a/icing/index/main/index-block.cc +++ b/icing/index/main/index-block.cc @@ -14,9 +14,8 @@ #include "icing/index/main/index-block.h" -#include <inttypes.h> - #include <algorithm> +#include <cinttypes> #include <limits> #include "icing/text_classifier/lib3/utils/base/statusor.h" diff --git a/icing/index/main/index-block.h b/icing/index/main/index-block.h index edf9a79..5d75a2a 100644 --- a/icing/index/main/index-block.h +++ b/icing/index/main/index-block.h @@ -15,10 +15,10 @@ #ifndef ICING_INDEX_MAIN_INDEX_BLOCK_H_ #define ICING_INDEX_MAIN_INDEX_BLOCK_H_ -#include <string.h> #include <sys/mman.h> #include <algorithm> +#include <cstring> #include <limits> #include <memory> #include <string> diff --git a/icing/index/main/main-index.cc b/icing/index/main/main-index.cc index 8ae6b27..b185138 100644 --- a/icing/index/main/main-index.cc +++ b/icing/index/main/main-index.cc @@ -217,8 +217,7 @@ bool IsTermInNamespaces( libtextclassifier3::StatusOr<std::vector<TermMetadata>> MainIndex::FindTermsByPrefix(const std::string& prefix, - const std::vector<NamespaceId>& namespace_ids, - int num_to_return) { + const std::vector<NamespaceId>& namespace_ids) { // Finds all the terms that start with the given prefix in the lexicon. IcingDynamicTrie::Iterator term_iterator(*main_lexicon_, prefix.c_str()); @@ -226,7 +225,7 @@ MainIndex::FindTermsByPrefix(const std::string& prefix, IcingDynamicTrie::PropertyReadersAll property_reader(*main_lexicon_); std::vector<TermMetadata> term_metadata_list; - while (term_iterator.IsValid() && term_metadata_list.size() < num_to_return) { + while (term_iterator.IsValid()) { uint32_t term_value_index = term_iterator.GetValueIndex(); // Skips the terms that don't exist in the given namespaces. We won't skip @@ -250,13 +249,6 @@ MainIndex::FindTermsByPrefix(const std::string& prefix, term_iterator.Advance(); } - if (term_iterator.IsValid()) { - // We exited the loop above because we hit the num_to_return limit. - ICING_LOG(WARNING) << "Ran into limit of " << num_to_return - << " retrieving suggestions for " << prefix - << ". Some suggestions may not be returned and others " - "may be misranked."; - } return term_metadata_list; } diff --git a/icing/index/main/main-index.h b/icing/index/main/main-index.h index 43635ca..919a5c5 100644 --- a/icing/index/main/main-index.h +++ b/icing/index/main/main-index.h @@ -81,8 +81,7 @@ class MainIndex { // A list of TermMetadata on success // INTERNAL_ERROR if failed to access term data. libtextclassifier3::StatusOr<std::vector<TermMetadata>> FindTermsByPrefix( - const std::string& prefix, const std::vector<NamespaceId>& namespace_ids, - int num_to_return); + const std::string& prefix, const std::vector<NamespaceId>& namespace_ids); struct LexiconMergeOutputs { // Maps from main_lexicon tvi for new branching point to the main_lexicon diff --git a/icing/index/main/posting-list-free.h b/icing/index/main/posting-list-free.h index 4f06057..75b99d7 100644 --- a/icing/index/main/posting-list-free.h +++ b/icing/index/main/posting-list-free.h @@ -15,10 +15,10 @@ #ifndef ICING_INDEX_MAIN_POSTING_LIST_FREE_H_ #define ICING_INDEX_MAIN_POSTING_LIST_FREE_H_ -#include <string.h> #include <sys/mman.h> #include <cstdint> +#include <cstring> #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" diff --git a/icing/index/main/posting-list-used.h b/icing/index/main/posting-list-used.h index 1b2e24e..8944034 100644 --- a/icing/index/main/posting-list-used.h +++ b/icing/index/main/posting-list-used.h @@ -15,10 +15,10 @@ #ifndef ICING_INDEX_MAIN_POSTING_LIST_USED_H_ #define ICING_INDEX_MAIN_POSTING_LIST_USED_H_ -#include <string.h> #include <sys/mman.h> #include <algorithm> +#include <cstring> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc index ea2bcf7..51d3423 100644 --- a/icing/jni/icing-search-engine-jni.cc +++ b/icing/jni/icing-search-engine-jni.cc @@ -420,4 +420,23 @@ Java_com_google_android_icing_IcingSearchEngine_nativeReset( return SerializeProtoToJniByteArray(env, reset_result_proto); } +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeSearchSuggestions( + JNIEnv* env, jclass clazz, jobject object, + jbyteArray suggestion_spec_bytes) { + icing::lib::IcingSearchEngine* icing = + GetIcingSearchEnginePointer(env, object); + + icing::lib::SuggestionSpecProto suggestion_spec_proto; + if (!ParseProtoFromJniByteArray(env, suggestion_spec_bytes, + &suggestion_spec_proto)) { + ICING_LOG(ERROR) << "Failed to parse SuggestionSpecProto in nativeSearch"; + return nullptr; + } + icing::lib::SuggestionResponse suggestionResponse = + icing->SearchSuggestions(suggestion_spec_proto); + + return SerializeProtoToJniByteArray(env, suggestionResponse); +} + } // extern "C" diff --git a/icing/legacy/core/icing-core-types.h b/icing/legacy/core/icing-core-types.h index cc12663..7db8408 100644 --- a/icing/legacy/core/icing-core-types.h +++ b/icing/legacy/core/icing-core-types.h @@ -21,9 +21,8 @@ #ifndef ICING_LEGACY_CORE_ICING_CORE_TYPES_H_ #define ICING_LEGACY_CORE_ICING_CORE_TYPES_H_ -#include <stdint.h> - #include <cstddef> // size_t not defined implicitly for all platforms. +#include <cstdint> #include <vector> #include "icing/legacy/core/icing-compat.h" diff --git a/icing/legacy/core/icing-string-util.cc b/icing/legacy/core/icing-string-util.cc index 2eb64ac..ed06e03 100644 --- a/icing/legacy/core/icing-string-util.cc +++ b/icing/legacy/core/icing-string-util.cc @@ -13,12 +13,11 @@ // limitations under the License. #include "icing/legacy/core/icing-string-util.h" -#include <stdarg.h> -#include <stddef.h> -#include <stdint.h> -#include <stdio.h> - #include <algorithm> +#include <cstdarg> +#include <cstddef> +#include <cstdint> +#include <cstdio> #include <string> #include "icing/legacy/portable/icing-zlib.h" diff --git a/icing/legacy/core/icing-string-util.h b/icing/legacy/core/icing-string-util.h index 767e581..e5e4941 100644 --- a/icing/legacy/core/icing-string-util.h +++ b/icing/legacy/core/icing-string-util.h @@ -15,9 +15,8 @@ #ifndef ICING_LEGACY_CORE_ICING_STRING_UTIL_H_ #define ICING_LEGACY_CORE_ICING_STRING_UTIL_H_ -#include <stdarg.h> -#include <stdint.h> - +#include <cstdarg> +#include <cstdint> #include <string> #include "icing/legacy/core/icing-compat.h" diff --git a/icing/legacy/core/icing-timer.h b/icing/legacy/core/icing-timer.h index 49ba9ad..af38912 100644 --- a/icing/legacy/core/icing-timer.h +++ b/icing/legacy/core/icing-timer.h @@ -16,7 +16,8 @@ #define ICING_LEGACY_CORE_ICING_TIMER_H_ #include <sys/time.h> -#include <time.h> + +#include <ctime> namespace icing { namespace lib { diff --git a/icing/legacy/index/icing-array-storage.cc b/icing/legacy/index/icing-array-storage.cc index b462135..4d2ef67 100644 --- a/icing/legacy/index/icing-array-storage.cc +++ b/icing/legacy/index/icing-array-storage.cc @@ -14,10 +14,10 @@ #include "icing/legacy/index/icing-array-storage.h" -#include <inttypes.h> #include <sys/mman.h> #include <algorithm> +#include <cinttypes> #include "icing/legacy/core/icing-string-util.h" #include "icing/legacy/core/icing-timer.h" diff --git a/icing/legacy/index/icing-array-storage.h b/icing/legacy/index/icing-array-storage.h index fad0565..0d93172 100644 --- a/icing/legacy/index/icing-array-storage.h +++ b/icing/legacy/index/icing-array-storage.h @@ -20,8 +20,7 @@ #ifndef ICING_LEGACY_INDEX_ICING_ARRAY_STORAGE_H_ #define ICING_LEGACY_INDEX_ICING_ARRAY_STORAGE_H_ -#include <stdint.h> - +#include <cstdint> #include <string> #include <vector> diff --git a/icing/legacy/index/icing-bit-util.h b/icing/legacy/index/icing-bit-util.h index 3273a68..d0c3f50 100644 --- a/icing/legacy/index/icing-bit-util.h +++ b/icing/legacy/index/icing-bit-util.h @@ -20,9 +20,8 @@ #ifndef ICING_LEGACY_INDEX_ICING_BIT_UTIL_H_ #define ICING_LEGACY_INDEX_ICING_BIT_UTIL_H_ -#include <stdint.h> -#include <stdio.h> - +#include <cstdint> +#include <cstdio> #include <limits> #include <vector> diff --git a/icing/legacy/index/icing-dynamic-trie.cc b/icing/legacy/index/icing-dynamic-trie.cc index 29843ba..baa043a 100644 --- a/icing/legacy/index/icing-dynamic-trie.cc +++ b/icing/legacy/index/icing-dynamic-trie.cc @@ -62,15 +62,15 @@ #include "icing/legacy/index/icing-dynamic-trie.h" -#include <errno.h> #include <fcntl.h> -#include <inttypes.h> -#include <string.h> #include <sys/mman.h> #include <sys/stat.h> #include <unistd.h> #include <algorithm> +#include <cerrno> +#include <cinttypes> +#include <cstring> #include <memory> #include <utility> diff --git a/icing/legacy/index/icing-dynamic-trie.h b/icing/legacy/index/icing-dynamic-trie.h index 7fe290b..8821799 100644 --- a/icing/legacy/index/icing-dynamic-trie.h +++ b/icing/legacy/index/icing-dynamic-trie.h @@ -35,8 +35,7 @@ #ifndef ICING_LEGACY_INDEX_ICING_DYNAMIC_TRIE_H_ #define ICING_LEGACY_INDEX_ICING_DYNAMIC_TRIE_H_ -#include <stdint.h> - +#include <cstdint> #include <memory> #include <string> #include <unordered_map> diff --git a/icing/legacy/index/icing-filesystem.cc b/icing/legacy/index/icing-filesystem.cc index 90e9146..4f5e571 100644 --- a/icing/legacy/index/icing-filesystem.cc +++ b/icing/legacy/index/icing-filesystem.cc @@ -16,7 +16,6 @@ #include <dirent.h> #include <dlfcn.h> -#include <errno.h> #include <fcntl.h> #include <fnmatch.h> #include <pthread.h> @@ -27,6 +26,7 @@ #include <unistd.h> #include <algorithm> +#include <cerrno> #include <unordered_set> #include "icing/absl_ports/str_cat.h" diff --git a/icing/legacy/index/icing-flash-bitmap.h b/icing/legacy/index/icing-flash-bitmap.h index 3b3521a..e3ba0e2 100644 --- a/icing/legacy/index/icing-flash-bitmap.h +++ b/icing/legacy/index/icing-flash-bitmap.h @@ -37,8 +37,7 @@ #ifndef ICING_LEGACY_INDEX_ICING_FLASH_BITMAP_H_ #define ICING_LEGACY_INDEX_ICING_FLASH_BITMAP_H_ -#include <stdint.h> - +#include <cstdint> #include <memory> #include <string> diff --git a/icing/legacy/index/icing-mmapper.cc b/icing/legacy/index/icing-mmapper.cc index 737335c..7946c82 100644 --- a/icing/legacy/index/icing-mmapper.cc +++ b/icing/legacy/index/icing-mmapper.cc @@ -17,10 +17,11 @@ // #include "icing/legacy/index/icing-mmapper.h" -#include <errno.h> -#include <string.h> #include <sys/mman.h> +#include <cerrno> +#include <cstring> + #include "icing/legacy/core/icing-string-util.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/util/logging.h" diff --git a/icing/legacy/index/icing-mock-filesystem.h b/icing/legacy/index/icing-mock-filesystem.h index 75ac62f..122ee7b 100644 --- a/icing/legacy/index/icing-mock-filesystem.h +++ b/icing/legacy/index/icing-mock-filesystem.h @@ -15,16 +15,15 @@ #ifndef ICING_LEGACY_INDEX_ICING_MOCK_FILESYSTEM_H_ #define ICING_LEGACY_INDEX_ICING_MOCK_FILESYSTEM_H_ -#include <stdint.h> -#include <stdio.h> -#include <string.h> - +#include <cstdint> +#include <cstdio> +#include <cstring> #include <memory> #include <string> #include <vector> -#include "icing/legacy/index/icing-filesystem.h" #include "gmock/gmock.h" +#include "icing/legacy/index/icing-filesystem.h" namespace icing { namespace lib { diff --git a/icing/legacy/index/icing-storage-file.cc b/icing/legacy/index/icing-storage-file.cc index b27ec67..35a4418 100644 --- a/icing/legacy/index/icing-storage-file.cc +++ b/icing/legacy/index/icing-storage-file.cc @@ -14,9 +14,9 @@ #include "icing/legacy/index/icing-storage-file.h" -#include <inttypes.h> #include <unistd.h> +#include <cinttypes> #include <string> #include "icing/legacy/core/icing-compat.h" diff --git a/icing/portable/endian.h b/icing/portable/endian.h index 595b956..ecebb15 100644 --- a/icing/portable/endian.h +++ b/icing/portable/endian.h @@ -77,7 +77,7 @@ // The following guarantees declaration of the byte swap functions #ifdef COMPILER_MSVC -#include <stdlib.h> // NOLINT(build/include) +#include <cstdlib> // NOLINT(build/include) #define bswap_16(x) _byteswap_ushort(x) #define bswap_32(x) _byteswap_ulong(x) diff --git a/icing/portable/gzip_stream.cc b/icing/portable/gzip_stream.cc new file mode 100644 index 0000000..f00a993 --- /dev/null +++ b/icing/portable/gzip_stream.cc @@ -0,0 +1,313 @@ +// Copyright (C) 2009 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the implementation of classes GzipInputStream and +// GzipOutputStream. It is forked from protobuf because these classes are only +// provided in libprotobuf-full but we would like to link libicing against the +// smaller libprotobuf-lite instead. + +#include "icing/portable/gzip_stream.h" +#include "icing/util/logging.h" + +namespace icing { +namespace lib { +namespace protobuf_ports { + +static const int kDefaultBufferSize = 65536; + +GzipInputStream::GzipInputStream(ZeroCopyInputStream* sub_stream, Format format, + int buffer_size) + : format_(format), sub_stream_(sub_stream), zerror_(Z_OK), byte_count_(0) { + zcontext_.state = Z_NULL; + zcontext_.zalloc = Z_NULL; + zcontext_.zfree = Z_NULL; + zcontext_.opaque = Z_NULL; + zcontext_.total_out = 0; + zcontext_.next_in = NULL; + zcontext_.avail_in = 0; + zcontext_.total_in = 0; + zcontext_.msg = NULL; + if (buffer_size == -1) { + output_buffer_length_ = kDefaultBufferSize; + } else { + output_buffer_length_ = buffer_size; + } + output_buffer_ = operator new(output_buffer_length_); + zcontext_.next_out = static_cast<Bytef*>(output_buffer_); + zcontext_.avail_out = output_buffer_length_; + output_position_ = output_buffer_; +} +GzipInputStream::~GzipInputStream() { + operator delete(output_buffer_); + zerror_ = inflateEnd(&zcontext_); +} + +static inline int internalInflateInit2(z_stream* zcontext, + GzipInputStream::Format format) { + int windowBitsFormat = 0; + switch (format) { + case GzipInputStream::GZIP: + windowBitsFormat = 16; + break; + case GzipInputStream::AUTO: + windowBitsFormat = 32; + break; + case GzipInputStream::ZLIB: + windowBitsFormat = 0; + break; + } + return inflateInit2(zcontext, /* windowBits */ 15 | windowBitsFormat); +} + +int GzipInputStream::Inflate(int flush) { + if ((zerror_ == Z_OK) && (zcontext_.avail_out == 0)) { + // previous inflate filled output buffer. don't change input params yet. + } else if (zcontext_.avail_in == 0) { + const void* in; + int in_size; + bool first = zcontext_.next_in == NULL; + bool ok = sub_stream_->Next(&in, &in_size); + if (!ok) { + zcontext_.next_out = NULL; + zcontext_.avail_out = 0; + return Z_STREAM_END; + } + zcontext_.next_in = static_cast<Bytef*>(const_cast<void*>(in)); + zcontext_.avail_in = in_size; + if (first) { + int error = internalInflateInit2(&zcontext_, format_); + if (error != Z_OK) { + return error; + } + } + } + zcontext_.next_out = static_cast<Bytef*>(output_buffer_); + zcontext_.avail_out = output_buffer_length_; + output_position_ = output_buffer_; + int error = inflate(&zcontext_, flush); + return error; +} + +void GzipInputStream::DoNextOutput(const void** data, int* size) { + *data = output_position_; + *size = ((uintptr_t)zcontext_.next_out) - ((uintptr_t)output_position_); + output_position_ = zcontext_.next_out; +} + +// implements ZeroCopyInputStream ---------------------------------- +bool GzipInputStream::Next(const void** data, int* size) { + bool ok = (zerror_ == Z_OK) || (zerror_ == Z_STREAM_END) || + (zerror_ == Z_BUF_ERROR); + if ((!ok) || (zcontext_.next_out == NULL)) { + return false; + } + if (zcontext_.next_out != output_position_) { + DoNextOutput(data, size); + return true; + } + if (zerror_ == Z_STREAM_END) { + if (zcontext_.next_out != NULL) { + // sub_stream_ may have concatenated streams to follow + zerror_ = inflateEnd(&zcontext_); + byte_count_ += zcontext_.total_out; + if (zerror_ != Z_OK) { + return false; + } + zerror_ = internalInflateInit2(&zcontext_, format_); + if (zerror_ != Z_OK) { + return false; + } + } else { + *data = NULL; + *size = 0; + return false; + } + } + zerror_ = Inflate(Z_NO_FLUSH); + if ((zerror_ == Z_STREAM_END) && (zcontext_.next_out == NULL)) { + // The underlying stream's Next returned false inside Inflate. + return false; + } + ok = (zerror_ == Z_OK) || (zerror_ == Z_STREAM_END) || + (zerror_ == Z_BUF_ERROR); + if (!ok) { + return false; + } + DoNextOutput(data, size); + return true; +} +void GzipInputStream::BackUp(int count) { + output_position_ = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(output_position_) - count); +} +bool GzipInputStream::Skip(int count) { + const void* data; + int size = 0; + bool ok = Next(&data, &size); + while (ok && (size < count)) { + count -= size; + ok = Next(&data, &size); + } + if (size > count) { + BackUp(size - count); + } + return ok; +} +int64_t GzipInputStream::ByteCount() const { + int64_t ret = byte_count_ + zcontext_.total_out; + if (zcontext_.next_out != NULL && output_position_ != NULL) { + ret += reinterpret_cast<uintptr_t>(zcontext_.next_out) - + reinterpret_cast<uintptr_t>(output_position_); + } + return ret; +} + +// ========================================================================= + +GzipOutputStream::Options::Options() + : format(GZIP), + buffer_size(kDefaultBufferSize), + compression_level(Z_DEFAULT_COMPRESSION), + compression_strategy(Z_DEFAULT_STRATEGY) {} + +GzipOutputStream::GzipOutputStream(ZeroCopyOutputStream* sub_stream) { + Init(sub_stream, Options()); +} + +GzipOutputStream::GzipOutputStream(ZeroCopyOutputStream* sub_stream, + const Options& options) { + Init(sub_stream, options); +} + +void GzipOutputStream::Init(ZeroCopyOutputStream* sub_stream, + const Options& options) { + sub_stream_ = sub_stream; + sub_data_ = NULL; + sub_data_size_ = 0; + + input_buffer_length_ = options.buffer_size; + input_buffer_ = operator new(input_buffer_length_); + + zcontext_.zalloc = Z_NULL; + zcontext_.zfree = Z_NULL; + zcontext_.opaque = Z_NULL; + zcontext_.next_out = NULL; + zcontext_.avail_out = 0; + zcontext_.total_out = 0; + zcontext_.next_in = NULL; + zcontext_.avail_in = 0; + zcontext_.total_in = 0; + zcontext_.msg = NULL; + // default to GZIP format + int windowBitsFormat = 16; + if (options.format == ZLIB) { + windowBitsFormat = 0; + } + zerror_ = + deflateInit2(&zcontext_, options.compression_level, Z_DEFLATED, + /* windowBits */ 15 | windowBitsFormat, + /* memLevel (default) */ 8, options.compression_strategy); +} + +GzipOutputStream::~GzipOutputStream() { + Close(); + operator delete(input_buffer_); +} + +// private +int GzipOutputStream::Deflate(int flush) { + int error = Z_OK; + do { + if ((sub_data_ == NULL) || (zcontext_.avail_out == 0)) { + bool ok = sub_stream_->Next(&sub_data_, &sub_data_size_); + if (!ok) { + sub_data_ = NULL; + sub_data_size_ = 0; + return Z_BUF_ERROR; + } + if (sub_data_size_ <= 0) { + ICING_LOG(FATAL) << "Failed to advance underlying stream"; + } + zcontext_.next_out = static_cast<Bytef*>(sub_data_); + zcontext_.avail_out = sub_data_size_; + } + error = deflate(&zcontext_, flush); + } while (error == Z_OK && zcontext_.avail_out == 0); + if ((flush == Z_FULL_FLUSH) || (flush == Z_FINISH)) { + // Notify lower layer of data. + sub_stream_->BackUp(zcontext_.avail_out); + // We don't own the buffer anymore. + sub_data_ = NULL; + sub_data_size_ = 0; + } + return error; +} + +// implements ZeroCopyOutputStream --------------------------------- +bool GzipOutputStream::Next(void** data, int* size) { + if ((zerror_ != Z_OK) && (zerror_ != Z_BUF_ERROR)) { + return false; + } + if (zcontext_.avail_in != 0) { + zerror_ = Deflate(Z_NO_FLUSH); + if (zerror_ != Z_OK) { + return false; + } + } + if (zcontext_.avail_in == 0) { + // all input was consumed. reset the buffer. + zcontext_.next_in = static_cast<Bytef*>(input_buffer_); + zcontext_.avail_in = input_buffer_length_; + *data = input_buffer_; + *size = input_buffer_length_; + } else { + // The loop in Deflate should consume all avail_in + ICING_LOG(ERROR) << "Deflate left bytes unconsumed"; + } + return true; +} +void GzipOutputStream::BackUp(int count) { + if (zcontext_.avail_in < static_cast<uInt>(count)) { + ICING_LOG(FATAL) << "Not enough data to back up " << count << " bytes"; + } + zcontext_.avail_in -= count; +} +int64_t GzipOutputStream::ByteCount() const { + return zcontext_.total_in + zcontext_.avail_in; +} + +bool GzipOutputStream::Flush() { + zerror_ = Deflate(Z_FULL_FLUSH); + // Return true if the flush succeeded or if it was a no-op. + return (zerror_ == Z_OK) || + (zerror_ == Z_BUF_ERROR && zcontext_.avail_in == 0 && + zcontext_.avail_out != 0); +} + +bool GzipOutputStream::Close() { + if ((zerror_ != Z_OK) && (zerror_ != Z_BUF_ERROR)) { + return false; + } + do { + zerror_ = Deflate(Z_FINISH); + } while (zerror_ == Z_OK); + zerror_ = deflateEnd(&zcontext_); + bool ok = zerror_ == Z_OK; + zerror_ = Z_STREAM_END; + return ok; +} + +} // namespace protobuf_ports +} // namespace lib +} // namespace icing diff --git a/icing/portable/gzip_stream.h b/icing/portable/gzip_stream.h new file mode 100644 index 0000000..602093f --- /dev/null +++ b/icing/portable/gzip_stream.h @@ -0,0 +1,181 @@ +// Copyright (C) 2009 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the definition for classes GzipInputStream and +// GzipOutputStream. It is forked from protobuf because these classes are only +// provided in libprotobuf-full but we would like to link libicing against the +// smaller libprotobuf-lite instead. +// +// GzipInputStream decompresses data from an underlying +// ZeroCopyInputStream and provides the decompressed data as a +// ZeroCopyInputStream. +// +// GzipOutputStream is an ZeroCopyOutputStream that compresses data to +// an underlying ZeroCopyOutputStream. + +#ifndef GOOGLE3_ICING_PORTABLE_GZIP_STREAM_H_ +#define GOOGLE3_ICING_PORTABLE_GZIP_STREAM_H_ + +#include <google/protobuf/io/zero_copy_stream_impl_lite.h> +#include "icing/portable/zlib.h" + +namespace icing { +namespace lib { +namespace protobuf_ports { + +// A ZeroCopyInputStream that reads compressed data through zlib +class GzipInputStream : public google::protobuf::io::ZeroCopyInputStream { + public: + // Format key for constructor + enum Format { + // zlib will autodetect gzip header or deflate stream + AUTO = 0, + + // GZIP streams have some extra header data for file attributes. + GZIP = 1, + + // Simpler zlib stream format. + ZLIB = 2, + }; + + // buffer_size and format may be -1 for default of 64kB and GZIP format + explicit GzipInputStream( + google::protobuf::io::ZeroCopyInputStream* sub_stream, + Format format = AUTO, int buffer_size = -1); + virtual ~GzipInputStream(); + + // Return last error message or NULL if no error. + inline const char* ZlibErrorMessage() const { return zcontext_.msg; } + inline int ZlibErrorCode() const { return zerror_; } + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + private: + Format format_; + + google::protobuf::io::ZeroCopyInputStream* sub_stream_; + + z_stream zcontext_; + int zerror_; + + void* output_buffer_; + void* output_position_; + size_t output_buffer_length_; + int64_t byte_count_; + + int Inflate(int flush); + void DoNextOutput(const void** data, int* size); +}; + +class GzipOutputStream : public google::protobuf::io::ZeroCopyOutputStream { + public: + // Format key for constructor + enum Format { + // GZIP streams have some extra header data for file attributes. + GZIP = 1, + + // Simpler zlib stream format. + ZLIB = 2, + }; + + struct Options { + // Defaults to GZIP. + Format format; + + // What size buffer to use internally. Defaults to 64kB. + int buffer_size; + + // A number between 0 and 9, where 0 is no compression and 9 is best + // compression. Defaults to Z_DEFAULT_COMPRESSION (see zlib.h). + int compression_level; + + // Defaults to Z_DEFAULT_STRATEGY. Can also be set to Z_FILTERED, + // Z_HUFFMAN_ONLY, or Z_RLE. See the documentation for deflateInit2 in + // zlib.h for definitions of these constants. + int compression_strategy; + + Options(); // Initializes with default values. + }; + + // Create a GzipOutputStream with default options. + explicit GzipOutputStream( + google::protobuf::io::ZeroCopyOutputStream* sub_stream); + + // Create a GzipOutputStream with the given options. + GzipOutputStream( + google::protobuf::io::ZeroCopyOutputStream* sub_stream, + const Options& options); + + virtual ~GzipOutputStream(); + + // Return last error message or NULL if no error. + inline const char* ZlibErrorMessage() const { return zcontext_.msg; } + inline int ZlibErrorCode() const { return zerror_; } + + // Flushes data written so far to zipped data in the underlying stream. + // It is the caller's responsibility to flush the underlying stream if + // necessary. + // Compression may be less efficient stopping and starting around flushes. + // Returns true if no error. + // + // Please ensure that block size is > 6. Here is an excerpt from the zlib + // doc that explains why: + // + // In the case of a Z_FULL_FLUSH or Z_SYNC_FLUSH, make sure that avail_out + // is greater than six to avoid repeated flush markers due to + // avail_out == 0 on return. + bool Flush(); + + // Writes out all data and closes the gzip stream. + // It is the caller's responsibility to close the underlying stream if + // necessary. + // Returns true if no error. + bool Close(); + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + google::protobuf::io::ZeroCopyOutputStream* sub_stream_; + // Result from calling Next() on sub_stream_ + void* sub_data_; + int sub_data_size_; + + z_stream zcontext_; + int zerror_; + void* input_buffer_; + size_t input_buffer_length_; + + // Shared constructor code. + void Init( + google::protobuf::io::ZeroCopyOutputStream* sub_stream, + const Options& options); + + // Do some compression. + // Takes zlib flush mode. + // Returns zlib error code. + int Deflate(int flush); +}; + +} // namespace protobuf_ports +} // namespace lib +} // namespace icing + +#endif // GOOGLE3_ICING_PORTABLE_GZIP_STREAM_H_ diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc index 1f937fd..36c76db 100644 --- a/icing/query/query-processor.cc +++ b/icing/query/query-processor.cc @@ -182,7 +182,7 @@ QueryProcessor::ParseRawQuery(const SearchSpecProto& search_spec) { const Token& token = tokens.at(i); std::unique_ptr<DocHitInfoIterator> result_iterator; - // TODO(cassiewang): Handle negation tokens + // TODO(b/202076890): Handle negation tokens switch (token.type) { case Token::Type::QUERY_LEFT_PARENTHESES: { frames.emplace(ParserStateFrame()); diff --git a/icing/query/suggestion-processor.cc b/icing/query/suggestion-processor.cc new file mode 100644 index 0000000..9c60810 --- /dev/null +++ b/icing/query/suggestion-processor.cc @@ -0,0 +1,93 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/query/suggestion-processor.h" + +#include "icing/tokenization/tokenizer-factory.h" +#include "icing/tokenization/tokenizer.h" +#include "icing/transform/normalizer.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>> +SuggestionProcessor::Create(Index* index, + const LanguageSegmenter* language_segmenter, + const Normalizer* normalizer) { + ICING_RETURN_ERROR_IF_NULL(index); + ICING_RETURN_ERROR_IF_NULL(language_segmenter); + + return std::unique_ptr<SuggestionProcessor>( + new SuggestionProcessor(index, language_segmenter, normalizer)); +} + +libtextclassifier3::StatusOr<std::vector<TermMetadata>> +SuggestionProcessor::QuerySuggestions( + const icing::lib::SuggestionSpecProto& suggestion_spec, + const std::vector<NamespaceId>& namespace_ids) { + // We use query tokenizer to tokenize the give prefix, and we only use the + // last token to be the suggestion prefix. + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<Tokenizer> tokenizer, + tokenizer_factory::CreateIndexingTokenizer( + StringIndexingConfig::TokenizerType::PLAIN, &language_segmenter_)); + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer::Iterator> iterator, + tokenizer->Tokenize(suggestion_spec.prefix())); + + // If there are previous tokens, they are prepended to the suggestion, + // separated by spaces. + std::string last_token; + int token_start_pos; + while (iterator->Advance()) { + Token token = iterator->GetToken(); + last_token = token.text; + token_start_pos = token.text.data() - suggestion_spec.prefix().c_str(); + } + + // If the position of the last token is not the end of the prefix, it means + // there should be some operator tokens after it and are ignored by the + // tokenizer. + bool is_last_token = token_start_pos + last_token.length() >= + suggestion_spec.prefix().length(); + + if (!is_last_token || last_token.empty()) { + // We don't have a valid last token, return early. + return std::vector<TermMetadata>(); + } + + std::string query_prefix = + suggestion_spec.prefix().substr(0, token_start_pos); + // Run suggestion based on given SuggestionSpec. + // Normalize token text to lowercase since all tokens in the lexicon are + // lowercase. + ICING_ASSIGN_OR_RETURN( + std::vector<TermMetadata> terms, + index_.FindTermsByPrefix(normalizer_.NormalizeTerm(last_token), + namespace_ids, suggestion_spec.num_to_return())); + + for (TermMetadata& term : terms) { + term.content = query_prefix + term.content; + } + return terms; +} + +SuggestionProcessor::SuggestionProcessor( + Index* index, const LanguageSegmenter* language_segmenter, + const Normalizer* normalizer) + : index_(*index), + language_segmenter_(*language_segmenter), + normalizer_(*normalizer) {} + +} // namespace lib +} // namespace icing
\ No newline at end of file diff --git a/icing/query/suggestion-processor.h b/icing/query/suggestion-processor.h new file mode 100644 index 0000000..b10dc84 --- /dev/null +++ b/icing/query/suggestion-processor.h @@ -0,0 +1,68 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_QUERY_SUGGESTION_PROCESSOR_H_ +#define ICING_QUERY_SUGGESTION_PROCESSOR_H_ + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/index.h" +#include "icing/proto/search.pb.h" +#include "icing/tokenization/language-segmenter.h" +#include "icing/transform/normalizer.h" + +namespace icing { +namespace lib { + +// Processes SuggestionSpecProtos and retrieves the specified TermMedaData that +// satisfies the prefix and its restrictions. This also performs ranking, and +// returns TermMetaData ordered by their hit count. +class SuggestionProcessor { + public: + // Factory function to create a SuggestionProcessor which does not take + // ownership of any input components, and all pointers must refer to valid + // objects that outlive the created SuggestionProcessor instance. + // + // Returns: + // An SuggestionProcessor on success + // FAILED_PRECONDITION if any of the pointers is null. + static libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>> + Create(Index* index, const LanguageSegmenter* language_segmenter, + const Normalizer* normalizer); + + // Query suggestions based on the given SuggestionSpecProto. + // + // Returns: + // On success, + // - One vector that represents the entire TermMetadata + // INTERNAL_ERROR on all other errors + libtextclassifier3::StatusOr<std::vector<TermMetadata>> QuerySuggestions( + const SuggestionSpecProto& suggestion_spec, + const std::vector<NamespaceId>& namespace_ids); + + private: + explicit SuggestionProcessor(Index* index, + const LanguageSegmenter* language_segmenter, + const Normalizer* normalizer); + + // Not const because we could modify/sort the TermMetaData buffer in the lite + // index. + Index& index_; + const LanguageSegmenter& language_segmenter_; + const Normalizer& normalizer_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_QUERY_SUGGESTION_PROCESSOR_H_ diff --git a/icing/query/suggestion-processor_test.cc b/icing/query/suggestion-processor_test.cc new file mode 100644 index 0000000..5e62277 --- /dev/null +++ b/icing/query/suggestion-processor_test.cc @@ -0,0 +1,324 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/query/suggestion-processor.h" + +#include "gmock/gmock.h" +#include "icing/helpers/icu/icu-data-file-helper.h" +#include "icing/store/document-store.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/jni-test-helpers.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/transform/normalizer-factory.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::IsEmpty; +using ::testing::Test; + +class SuggestionProcessorTest : public Test { + protected: + SuggestionProcessorTest() + : test_dir_(GetTestTempDir() + "/icing"), + store_dir_(test_dir_ + "/store"), + index_dir_(test_dir_ + "/index") {} + + void SetUp() override { + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + filesystem_.CreateDirectoryRecursively(index_dir_.c_str()); + filesystem_.CreateDirectoryRecursively(store_dir_.c_str()); + + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + // If we've specified using the reverse-JNI method for segmentation (i.e. + // not ICU), then we won't have the ICU data file included to set up. + // Technically, we could choose to use reverse-JNI for segmentation AND + // include an ICU data file, but that seems unlikely and our current BUILD + // setup doesn't do this. + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. + icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + + Index::Options options(index_dir_, + /*index_merge_size=*/1024 * 1024); + ICING_ASSERT_OK_AND_ASSIGN( + index_, Index::Create(options, &filesystem_, &icing_filesystem_)); + + language_segmenter_factory::SegmenterOptions segmenter_options( + ULOC_US, jni_cache_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + language_segmenter_, + language_segmenter_factory::Create(segmenter_options)); + + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( + /*max_term_byte_size=*/1000)); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); + } + + libtextclassifier3::Status AddTokenToIndex( + DocumentId document_id, SectionId section_id, + TermMatchType::Code term_match_type, const std::string& token) { + Index::Editor editor = index_->Edit(document_id, section_id, + term_match_type, /*namespace_id=*/0); + auto status = editor.BufferTerm(token.c_str()); + return status.ok() ? editor.IndexAllBufferedTerms() : status; + } + + void TearDown() override { + document_store_.reset(); + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + Filesystem filesystem_; + const std::string test_dir_; + const std::string store_dir_; + std::unique_ptr<Index> index_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; + std::unique_ptr<Normalizer> normalizer_; + std::unique_ptr<DocumentStore> document_store_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<const JniCache> jni_cache_ = GetTestJniCache(); + FakeClock fake_clock_; + + private: + IcingFilesystem icing_filesystem_; + const std::string index_dir_; +}; + +constexpr DocumentId kDocumentId0 = 0; +constexpr SectionId kSectionId2 = 2; + +TEST_F(SuggestionProcessorTest, PrependedPrefixTokenTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix( + "prefix token should be prepended to the suggestion f"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms.at(0).content, + "prefix token should be prepended to the suggestion foo"); +} + +TEST_F(SuggestionProcessorTest, NonExistentPrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("nonExistTerm"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + + EXPECT_THAT(terms, IsEmpty()); +} + +TEST_F(SuggestionProcessorTest, PrefixTrailingSpaceTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("f "); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + + EXPECT_THAT(terms, IsEmpty()); +} + +TEST_F(SuggestionProcessorTest, NormalizePrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("F"); + suggestion_spec.set_num_to_return(10); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions(suggestion_spec, + /*namespace_ids=*/{})); + EXPECT_THAT(terms.at(0).content, "foo"); + + suggestion_spec.set_prefix("fO"); + ICING_ASSERT_OK_AND_ASSIGN( + terms, suggestion_processor->QuerySuggestions(suggestion_spec, + /*namespace_ids=*/{})); + EXPECT_THAT(terms.at(0).content, "foo"); + + suggestion_spec.set_prefix("Fo"); + ICING_ASSERT_OK_AND_ASSIGN( + terms, suggestion_processor->QuerySuggestions(suggestion_spec, + /*namespace_ids=*/{})); + EXPECT_THAT(terms.at(0).content, "foo"); + + suggestion_spec.set_prefix("FO"); + ICING_ASSERT_OK_AND_ASSIGN( + terms, suggestion_processor->QuerySuggestions(suggestion_spec, + /*namespace_ids=*/{})); + EXPECT_THAT(terms.at(0).content, "foo"); +} + +TEST_F(SuggestionProcessorTest, OrOperatorPrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "original"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("f OR"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + + // Last Operator token will be used to query suggestion + EXPECT_THAT(terms.at(0).content, "f original"); +} + +TEST_F(SuggestionProcessorTest, ParenthesesOperatorPrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("{f}"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); + + suggestion_spec.set_prefix("[f]"); + ICING_ASSERT_OK_AND_ASSIGN(terms, suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); + + suggestion_spec.set_prefix("(f)"); + ICING_ASSERT_OK_AND_ASSIGN(terms, suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); +} + +TEST_F(SuggestionProcessorTest, OtherSpecialPrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("f:"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); + + suggestion_spec.set_prefix("f-"); + ICING_ASSERT_OK_AND_ASSIGN( + terms, suggestion_processor->QuerySuggestions(suggestion_spec, + /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); +} + +TEST_F(SuggestionProcessorTest, InvalidPrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "original"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("OR OR - :"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/schema/schema-store.cc b/icing/schema/schema-store.cc index 3307638..67528ab 100644 --- a/icing/schema/schema-store.cc +++ b/icing/schema/schema-store.cc @@ -491,5 +491,10 @@ SchemaStoreStorageInfoProto SchemaStore::GetStorageInfo() const { return storage_info; } +libtextclassifier3::StatusOr<const std::vector<SectionMetadata>*> +SchemaStore::GetSectionMetadata(const std::string& schema_type) const { + return section_manager_->GetMetadataList(schema_type); +} + } // namespace lib } // namespace icing diff --git a/icing/schema/schema-store.h b/icing/schema/schema-store.h index b9be6c0..6b6528d 100644 --- a/icing/schema/schema-store.h +++ b/icing/schema/schema-store.h @@ -246,6 +246,12 @@ class SchemaStore { // INTERNAL_ERROR on compute error libtextclassifier3::StatusOr<Crc32> ComputeChecksum() const; + // Returns: + // - On success, the section metadata list for the specified schema type + // - NOT_FOUND if the schema type is not present in the schema + libtextclassifier3::StatusOr<const std::vector<SectionMetadata>*> + GetSectionMetadata(const std::string& schema_type) const; + // Calculates the StorageInfo for the Schema Store. // // If an IO error occurs while trying to calculate the value for a field, then diff --git a/icing/scoring/bm25f-calculator.cc b/icing/scoring/bm25f-calculator.cc index 4822d7f..28d385e 100644 --- a/icing/scoring/bm25f-calculator.cc +++ b/icing/scoring/bm25f-calculator.cc @@ -26,6 +26,7 @@ #include "icing/store/corpus-associated-scoring-data.h" #include "icing/store/corpus-id.h" #include "icing/store/document-associated-score-data.h" +#include "icing/store/document-filter-data.h" #include "icing/store/document-id.h" namespace icing { @@ -42,8 +43,11 @@ constexpr float k1_ = 1.2f; constexpr float b_ = 0.7f; // TODO(b/158603900): add tests for Bm25fCalculator -Bm25fCalculator::Bm25fCalculator(const DocumentStore* document_store) - : document_store_(document_store) {} +Bm25fCalculator::Bm25fCalculator( + const DocumentStore* document_store, + std::unique_ptr<SectionWeights> section_weights) + : document_store_(document_store), + section_weights_(std::move(section_weights)) {} // During initialization, Bm25fCalculator iterates through // hit-iterators for each query term to pre-compute n(q_i) for each corpus under @@ -121,9 +125,9 @@ float Bm25fCalculator::ComputeScore(const DocHitInfoIterator* query_it, // Compute inverse document frequency (IDF) weight for query term in the given // corpus, and cache it in the map. // -// N - n(q_i) + 0.5 -// IDF(q_i) = log(1 + ------------------) -// n(q_i) + 0.5 +// N - n(q_i) + 0.5 +// IDF(q_i) = ln(1 + ------------------) +// n(q_i) + 0.5 // // where N is the number of documents in the corpus, and n(q_i) is the number // of documents in the corpus containing the query term q_i. @@ -149,7 +153,7 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, uint32_t num_docs = csdata.num_docs(); uint32_t nqi = corpus_nqi_map_[corpus_term_info.value]; float idf = - nqi != 0 ? log(1.0f + (num_docs - nqi + 0.5f) / (nqi - 0.5f)) : 0.0f; + nqi != 0 ? log(1.0f + (num_docs - nqi + 0.5f) / (nqi + 0.5f)) : 0.0f; corpus_idf_map_.insert({corpus_term_info.value, idf}); ICING_VLOG(1) << IcingStringUtil::StringPrintf( "corpus_id:%d term:%s N:%d nqi:%d idf:%f", corpus_id, @@ -158,6 +162,11 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, } // Get per corpus average document length and cache the result in the map. +// The average doc length is calculated as: +// +// total_tokens_in_corpus +// Avg Doc Length = ------------------------- +// num_docs_in_corpus + 1 float Bm25fCalculator::GetCorpusAvgDocLength(CorpusId corpus_id) { auto iter = corpus_avgdl_map_.find(corpus_id); if (iter != corpus_avgdl_map_.end()) { @@ -191,8 +200,8 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency( const DocumentAssociatedScoreData& data) { uint32_t dl = data.length_in_tokens(); float avgdl = GetCorpusAvgDocLength(data.corpus_id()); - float f_q = - ComputeTermFrequencyForMatchedSections(data.corpus_id(), term_match_info); + float f_q = ComputeTermFrequencyForMatchedSections( + data.corpus_id(), term_match_info, hit_info.document_id()); float normalized_tf = f_q * (k1_ + 1) / (f_q + k1_ * (1 - b_ + b_ * dl / avgdl)); @@ -202,23 +211,41 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency( return normalized_tf; } -// Note: once we support section weights, we should update this function to -// compute the weighted term frequency. float Bm25fCalculator::ComputeTermFrequencyForMatchedSections( - CorpusId corpus_id, const TermMatchInfo& term_match_info) const { + CorpusId corpus_id, const TermMatchInfo& term_match_info, + DocumentId document_id) const { float sum = 0.0f; SectionIdMask sections = term_match_info.section_ids_mask; + SchemaTypeId schema_type_id = GetSchemaTypeId(document_id); + while (sections != 0) { SectionId section_id = __builtin_ctz(sections); sections &= ~(1u << section_id); Hit::TermFrequency tf = term_match_info.term_frequencies[section_id]; + double weighted_tf = tf * section_weights_->GetNormalizedSectionWeight( + schema_type_id, section_id); if (tf != Hit::kNoTermFrequency) { - sum += tf; + sum += weighted_tf; } } return sum; } +SchemaTypeId Bm25fCalculator::GetSchemaTypeId(DocumentId document_id) const { + auto filter_data_or = document_store_->GetDocumentFilterData(document_id); + if (!filter_data_or.ok()) { + // This should never happen. The only failure case for + // GetDocumentFilterData is if the document_id is outside of the range of + // allocated document_ids, which shouldn't be possible since we're getting + // this document_id from the posting lists. + ICING_LOG(WARNING) << IcingStringUtil::StringPrintf( + "No document filter data for document [%d]", document_id); + return kInvalidSchemaTypeId; + } + DocumentFilterData data = filter_data_or.ValueOrDie(); + return data.schema_type_id(); +} + } // namespace lib } // namespace icing diff --git a/icing/scoring/bm25f-calculator.h b/icing/scoring/bm25f-calculator.h index 91b4f24..05009d8 100644 --- a/icing/scoring/bm25f-calculator.h +++ b/icing/scoring/bm25f-calculator.h @@ -22,6 +22,7 @@ #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/legacy/index/icing-bit-util.h" +#include "icing/scoring/section-weights.h" #include "icing/store/corpus-id.h" #include "icing/store/document-store.h" @@ -62,7 +63,8 @@ namespace lib { // see: glossary/bm25 class Bm25fCalculator { public: - explicit Bm25fCalculator(const DocumentStore *document_store_); + explicit Bm25fCalculator(const DocumentStore *document_store_, + std::unique_ptr<SectionWeights> section_weights_); // Precompute and cache statistics relevant to BM25F. // Populates term_id_map_ and corpus_nqi_map_ for use while scoring other @@ -108,18 +110,43 @@ class Bm25fCalculator { } }; + // Returns idf weight for the term and provided corpus. float GetCorpusIdfWeightForTerm(std::string_view term, CorpusId corpus_id); + + // Returns the average document length for the corpus. The average is + // calculated as the sum of tokens in the corpus' documents over the total + // number of documents plus one. float GetCorpusAvgDocLength(CorpusId corpus_id); + + // Returns the normalized term frequency for the term match and document hit. + // This normalizes the term frequency by applying smoothing parameters and + // factoring document length. float ComputedNormalizedTermFrequency( const TermMatchInfo &term_match_info, const DocHitInfo &hit_info, const DocumentAssociatedScoreData &data); + + // Returns the weighted term frequency for the term match and document. For + // each section the term is present, we scale the term frequency by its + // section weight. We return the sum of the weighted term frequencies over all + // sections. float ComputeTermFrequencyForMatchedSections( - CorpusId corpus_id, const TermMatchInfo &term_match_info) const; + CorpusId corpus_id, const TermMatchInfo &term_match_info, + DocumentId document_id) const; + // Returns the schema type id for the document by retrieving it from the + // DocumentFilterData. + SchemaTypeId GetSchemaTypeId(DocumentId document_id) const; + + // Clears cached scoring data and prepares the calculator for a new scoring + // run. void Clear(); const DocumentStore *document_store_; // Does not own. + // Used for accessing normalized section weights when computing the weighted + // term frequency. + std::unique_ptr<SectionWeights> section_weights_; + // Map from query term to compact term ID. // Necessary as a key to the other maps. // The use of the string_view as key here means that the query_term_iterators @@ -130,7 +157,6 @@ class Bm25fCalculator { // Necessary to calculate the normalized term frequency. // This information is cached in the DocumentStore::CorpusScoreCache std::unordered_map<CorpusId, float> corpus_avgdl_map_; - // Map from <corpus ID, term ID> to number of documents containing term q_i, // called n(q_i). // Necessary to calculate IDF(q_i) (inverse document frequency). diff --git a/icing/scoring/ranker.cc b/icing/scoring/ranker.cc index fecee82..117f44c 100644 --- a/icing/scoring/ranker.cc +++ b/icing/scoring/ranker.cc @@ -32,6 +32,7 @@ namespace { // Helper function to wrap the heapify algorithm, it heapifies the target // subtree node in place. +// TODO(b/152934343) refactor the heapify function and making it into a class. void Heapify( std::vector<ScoredDocumentHit>* scored_document_hits, int target_subtree_root_index, @@ -71,6 +72,80 @@ void Heapify( } } +// Heapify the given term vector from top to bottom. Call it after add or +// replace an element at the front of the vector. +void HeapifyTermDown(std::vector<TermMetadata>& scored_terms, + int target_subtree_root_index) { + int heap_size = scored_terms.size(); + if (target_subtree_root_index >= heap_size) { + return; + } + + // Initializes subtree root as the current minimum node. + int min = target_subtree_root_index; + // If we represent a heap in an array/vector, indices of left and right + // children can be calculated as such. + const int left = target_subtree_root_index * 2 + 1; + const int right = target_subtree_root_index * 2 + 2; + + // If left child is smaller than current minimum. + if (left < heap_size && + scored_terms.at(left).hit_count < scored_terms.at(min).hit_count) { + min = left; + } + + // If right child is smaller than current minimum. + if (right < heap_size && + scored_terms.at(right).hit_count < scored_terms.at(min).hit_count) { + min = right; + } + + // If the minimum is not the subtree root, swap and continue heapifying the + // lower level subtree. + if (min != target_subtree_root_index) { + std::swap(scored_terms.at(min), + scored_terms.at(target_subtree_root_index)); + HeapifyTermDown(scored_terms, min); + } +} + +// Heapify the given term vector from bottom to top. Call it after add an +// element at the end of the vector. +void HeapifyTermUp(std::vector<TermMetadata>& scored_terms, + int target_subtree_child_index) { + // If we represent a heap in an array/vector, indices of root can be + // calculated as such. + const int root = (target_subtree_child_index + 1) / 2 - 1; + + // If the current child is smaller than the root, swap and continue heapifying + // the upper level subtree + if (root >= 0 && scored_terms.at(target_subtree_child_index).hit_count < + scored_terms.at(root).hit_count) { + std::swap(scored_terms.at(root), + scored_terms.at(target_subtree_child_index)); + HeapifyTermUp(scored_terms, root); + } +} + +TermMetadata PopRootTerm(std::vector<TermMetadata>& scored_terms) { + if (scored_terms.empty()) { + // Return an invalid TermMetadata as a sentinel value. + return TermMetadata(/*content_in=*/"", /*hit_count_in=*/-1); + } + + // Steps to extract root from heap: + // 1. copy out root + TermMetadata root = scored_terms.at(0); + const size_t last_node_index = scored_terms.size() - 1; + // 2. swap root and the last node + std::swap(scored_terms.at(0), scored_terms.at(last_node_index)); + // 3. remove last node + scored_terms.pop_back(); + // 4. heapify root + HeapifyTermDown(scored_terms, /*target_subtree_root_index=*/0); + return root; +} + // Helper function to extract the root from the heap. The heap structure will be // maintained. // @@ -115,6 +190,19 @@ void BuildHeapInPlace( } } +void PushToTermHeap(TermMetadata term, int number_to_return, + std::vector<TermMetadata>& scored_terms_heap) { + if (scored_terms_heap.size() < number_to_return) { + scored_terms_heap.push_back(std::move(term)); + // We insert at end, so we should heapify bottom up. + HeapifyTermUp(scored_terms_heap, scored_terms_heap.size() - 1); + } else if (scored_terms_heap.at(0).hit_count < term.hit_count) { + scored_terms_heap.at(0) = std::move(term); + // We insert at root, so we should heapify top down. + HeapifyTermDown(scored_terms_heap, /*target_subtree_root_index=*/0); + } +} + std::vector<ScoredDocumentHit> PopTopResultsFromHeap( std::vector<ScoredDocumentHit>* scored_document_hits_heap, int num_results, const ScoredDocumentHitComparator& scored_document_hit_comparator) { @@ -134,5 +222,15 @@ std::vector<ScoredDocumentHit> PopTopResultsFromHeap( return scored_document_hit_result; } +std::vector<TermMetadata> PopAllTermsFromHeap( + std::vector<TermMetadata>& scored_terms_heap) { + std::vector<TermMetadata> top_term_result; + top_term_result.reserve(scored_terms_heap.size()); + while (!scored_terms_heap.empty()) { + top_term_result.push_back(PopRootTerm(scored_terms_heap)); + } + return top_term_result; +} + } // namespace lib } // namespace icing diff --git a/icing/scoring/ranker.h b/icing/scoring/ranker.h index 785c133..81838f3 100644 --- a/icing/scoring/ranker.h +++ b/icing/scoring/ranker.h @@ -17,6 +17,7 @@ #include <vector> +#include "icing/index/term-metadata.h" #include "icing/scoring/scored-document-hit.h" // Provides functionality to get the top N results from an unsorted vector. @@ -39,6 +40,18 @@ std::vector<ScoredDocumentHit> PopTopResultsFromHeap( std::vector<ScoredDocumentHit>* scored_document_hits_heap, int num_results, const ScoredDocumentHitComparator& scored_document_hit_comparator); +// The heap is a min-heap. So that we can avoid some push operations by +// comparing to the root term, and only pushing if greater than root. The time +// complexity for a single push is O(lgK) which K is the number_to_return. +// REQUIRED: scored_terms_heap is not null. +void PushToTermHeap(TermMetadata term, int number_to_return, + std::vector<TermMetadata>& scored_terms_heap); + +// Return all terms from the given terms heap. And since the heap is a min-heap, +// the output vector will be increasing order. +// REQUIRED: scored_terms_heap is not null. +std::vector<TermMetadata> PopAllTermsFromHeap( + std::vector<TermMetadata>& scored_terms_heap); } // namespace lib } // namespace icing diff --git a/icing/scoring/score-and-rank_benchmark.cc b/icing/scoring/score-and-rank_benchmark.cc index e940e98..cc1d995 100644 --- a/icing/scoring/score-and-rank_benchmark.cc +++ b/icing/scoring/score-and-rank_benchmark.cc @@ -117,7 +117,8 @@ void BM_ScoreAndRankDocumentHitsByDocumentScore(benchmark::State& state) { scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get())); + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -220,7 +221,8 @@ void BM_ScoreAndRankDocumentHitsByCreationTime(benchmark::State& state) { ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get())); + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -322,7 +324,8 @@ void BM_ScoreAndRankDocumentHitsNoScoring(benchmark::State& state) { scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::NONE); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get())); + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -390,6 +393,122 @@ BENCHMARK(BM_ScoreAndRankDocumentHitsNoScoring) ->ArgPair(10000, 18000) ->ArgPair(10000, 20000); +void BM_ScoreAndRankDocumentHitsByRelevanceScoring(benchmark::State& state) { + const std::string base_dir = GetTestTempDir() + "/score_and_rank_benchmark"; + const std::string document_store_dir = base_dir + "/document_store"; + const std::string schema_store_dir = base_dir + "/schema_store"; + + // Creates file directories + Filesystem filesystem; + filesystem.DeleteDirectoryRecursively(base_dir.c_str()); + filesystem.CreateDirectoryRecursively(document_store_dir.c_str()); + filesystem.CreateDirectoryRecursively(schema_store_dir.c_str()); + + Clock clock; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SchemaStore> schema_store, + SchemaStore::Create(&filesystem, base_dir, &clock)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem, document_store_dir, &clock, + schema_store.get())); + std::unique_ptr<DocumentStore> document_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK(schema_store->SetSchema(CreateSchemaWithEmailType())); + + ScoringSpecProto scoring_spec; + scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); + + int num_to_score = state.range(0); + int num_of_documents = state.range(1); + + std::mt19937 random_generator; + std::uniform_int_distribution<int> distribution( + 1, std::numeric_limits<int>::max()); + + SectionId section_id = 0; + SectionIdMask section_id_mask = 1U << section_id; + + // Puts documents into document store + std::vector<DocHitInfo> doc_hit_infos; + for (int i = 0; i < num_of_documents; i++) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store->Put(CreateEmailDocument( + /*id=*/i, /*document_score=*/1, + /*creation_timestamp_ms=*/1), + /*num_tokens=*/10)); + DocHitInfo doc_hit = DocHitInfo(document_id, section_id_mask); + // Set five matches for term "foo" for each document hit. + doc_hit.UpdateSection(section_id, /*hit_term_frequency=*/5); + doc_hit_infos.push_back(doc_hit); + } + + ScoredDocumentHitComparator scored_document_hit_comparator( + /*is_descending=*/true); + + for (auto _ : state) { + // Creates a dummy DocHitInfoIterator with results, we need to pause the + // timer here so that the cost of copying test data is not included. + state.PauseTiming(); + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + // Create a query term iterator that assigns the document hits to term + // "foo". + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + state.ResumeTiming(); + + std::vector<ScoredDocumentHit> scored_document_hits = + scoring_processor->Score(std::move(doc_hit_info_iterator), num_to_score, + &query_term_iterators); + + BuildHeapInPlace(&scored_document_hits, scored_document_hit_comparator); + // Ranks and gets the first page, 20 is a common page size + std::vector<ScoredDocumentHit> results = + PopTopResultsFromHeap(&scored_document_hits, /*num_results=*/20, + scored_document_hit_comparator); + } + + // Clean up + document_store.reset(); + schema_store.reset(); + filesystem.DeleteDirectoryRecursively(base_dir.c_str()); +} +BENCHMARK(BM_ScoreAndRankDocumentHitsByRelevanceScoring) + // num_to_score, num_of_documents in document store + ->ArgPair(1000, 30000) + ->ArgPair(3000, 30000) + ->ArgPair(5000, 30000) + ->ArgPair(7000, 30000) + ->ArgPair(9000, 30000) + ->ArgPair(11000, 30000) + ->ArgPair(13000, 30000) + ->ArgPair(15000, 30000) + ->ArgPair(17000, 30000) + ->ArgPair(19000, 30000) + ->ArgPair(21000, 30000) + ->ArgPair(23000, 30000) + ->ArgPair(25000, 30000) + ->ArgPair(27000, 30000) + ->ArgPair(29000, 30000) + // Starting from this line, we're trying to see if num_of_documents affects + // performance + ->ArgPair(10000, 10000) + ->ArgPair(10000, 12000) + ->ArgPair(10000, 14000) + ->ArgPair(10000, 16000) + ->ArgPair(10000, 18000) + ->ArgPair(10000, 20000); + } // namespace } // namespace lib diff --git a/icing/scoring/scorer.cc b/icing/scoring/scorer.cc index a4734b4..5f33e66 100644 --- a/icing/scoring/scorer.cc +++ b/icing/scoring/scorer.cc @@ -22,6 +22,7 @@ #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/scoring.pb.h" #include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/util/status-macros.h" @@ -156,11 +157,12 @@ class NoScorer : public Scorer { }; libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( - ScoringSpecProto::RankingStrategy::Code rank_by, double default_score, - const DocumentStore* document_store) { + const ScoringSpecProto& scoring_spec, double default_score, + const DocumentStore* document_store, const SchemaStore* schema_store) { ICING_RETURN_ERROR_IF_NULL(document_store); + ICING_RETURN_ERROR_IF_NULL(schema_store); - switch (rank_by) { + switch (scoring_spec.rank_by()) { case ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE: return std::make_unique<DocumentScoreScorer>(document_store, default_score); @@ -168,7 +170,12 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( return std::make_unique<DocumentCreationTimestampScorer>(document_store, default_score); case ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE: { - auto bm25f_calculator = std::make_unique<Bm25fCalculator>(document_store); + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store, scoring_spec)); + + auto bm25f_calculator = std::make_unique<Bm25fCalculator>( + document_store, std::move(section_weights)); return std::make_unique<RelevanceScoreScorer>(std::move(bm25f_calculator), default_score); } @@ -183,8 +190,8 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP: [[fallthrough]]; case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP: - return std::make_unique<UsageScorer>(document_store, rank_by, - default_score); + return std::make_unique<UsageScorer>( + document_store, scoring_spec.rank_by(), default_score); case ScoringSpecProto::RankingStrategy::NONE: return std::make_unique<NoScorer>(default_score); } diff --git a/icing/scoring/scorer.h b/icing/scoring/scorer.h index a22db0f..abdd5ca 100644 --- a/icing/scoring/scorer.h +++ b/icing/scoring/scorer.h @@ -43,8 +43,8 @@ class Scorer { // FAILED_PRECONDITION on any null pointer input // INVALID_ARGUMENT if fails to create an instance static libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( - ScoringSpecProto::RankingStrategy::Code rank_by, double default_score, - const DocumentStore* document_store); + const ScoringSpecProto& scoring_spec, double default_score, + const DocumentStore* document_store, const SchemaStore* schema_store); // Returns a non-negative score of a document. The score can be a // document-associated score which comes from the DocumentProto directly, an diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc index 8b89514..f22a31a 100644 --- a/icing/scoring/scorer_test.cc +++ b/icing/scoring/scorer_test.cc @@ -27,6 +27,7 @@ #include "icing/proto/scoring.pb.h" #include "icing/schema-builder.h" #include "icing/schema/schema-store.h" +#include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" @@ -91,6 +92,8 @@ class ScorerTest : public testing::Test { DocumentStore* document_store() { return document_store_.get(); } + SchemaStore* schema_store() { return schema_store_.get(); } + const FakeClock& fake_clock1() { return fake_clock1_; } const FakeClock& fake_clock2() { return fake_clock2_; } @@ -121,17 +124,37 @@ UsageReport CreateUsageReport(std::string name_space, std::string uri, return usage_report; } -TEST_F(ScorerTest, CreationWithNullPointerShouldFail) { - EXPECT_THAT(Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/0, /*document_store=*/nullptr), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +ScoringSpecProto CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::Code ranking_strategy) { + ScoringSpecProto scoring_spec; + scoring_spec.set_rank_by(ranking_strategy); + return scoring_spec; +} + +TEST_F(ScorerTest, CreationWithNullDocumentStoreShouldFail) { + EXPECT_THAT( + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, /*document_store=*/nullptr, + schema_store()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(ScorerTest, CreationWithNullSchemaStoreShouldFail) { + EXPECT_THAT( + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, document_store(), + /*schema_store=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); // Non existent document id DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1); @@ -153,8 +176,9 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsDeleted) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -185,8 +209,9 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsExpired) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -213,8 +238,9 @@ TEST_F(ScorerTest, ShouldGetDefaultDocumentScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0)); @@ -235,8 +261,9 @@ TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5)); @@ -259,8 +286,9 @@ TEST_F(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); @@ -290,8 +318,9 @@ TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) { document_store()->Put(test_document2)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo1 = DocHitInfo(document_id1); DocHitInfo docHitInfo2 = DocHitInfo(document_id2); @@ -316,16 +345,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -357,16 +389,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -398,16 +433,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -439,19 +477,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -499,19 +540,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -559,19 +603,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -607,8 +654,9 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::NONE, - /*default_score=*/3, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE), + /*default_score=*/3, document_store(), schema_store())); DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0); DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1); @@ -618,8 +666,10 @@ TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(3)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, Scorer::Create(ScoringSpecProto::RankingStrategy::NONE, - /*default_score=*/111, document_store())); + scorer, + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE), + /*default_score=*/111, document_store(), schema_store())); docHitInfo1 = DocHitInfo(/*document_id_in=*/4); docHitInfo2 = DocHitInfo(/*document_id_in=*/5); @@ -643,9 +693,10 @@ TEST_F(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); // Create usage report for the maximum allowable timestamp. diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc index 24480ef..e36f3bb 100644 --- a/icing/scoring/scoring-processor.cc +++ b/icing/scoring/scoring-processor.cc @@ -39,19 +39,20 @@ constexpr double kDefaultScoreInAscendingOrder = libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> ScoringProcessor::Create(const ScoringSpecProto& scoring_spec, - const DocumentStore* document_store) { + const DocumentStore* document_store, + const SchemaStore* schema_store) { ICING_RETURN_ERROR_IF_NULL(document_store); + ICING_RETURN_ERROR_IF_NULL(schema_store); bool is_descending_order = scoring_spec.order_by() == ScoringSpecProto::Order::DESC; ICING_ASSIGN_OR_RETURN( std::unique_ptr<Scorer> scorer, - Scorer::Create(scoring_spec.rank_by(), + Scorer::Create(scoring_spec, is_descending_order ? kDefaultScoreInDescendingOrder : kDefaultScoreInAscendingOrder, - document_store)); - + document_store, schema_store)); // Using `new` to access a non-public constructor. return std::unique_ptr<ScoringProcessor>( new ScoringProcessor(std::move(scorer))); diff --git a/icing/scoring/scoring-processor.h b/icing/scoring/scoring-processor.h index 2289605..e7d09b1 100644 --- a/icing/scoring/scoring-processor.h +++ b/icing/scoring/scoring-processor.h @@ -40,8 +40,8 @@ class ScoringProcessor { // A ScoringProcessor on success // FAILED_PRECONDITION on any null pointer input static libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> Create( - const ScoringSpecProto& scoring_spec, - const DocumentStore* document_store); + const ScoringSpecProto& scoring_spec, const DocumentStore* document_store, + const SchemaStore* schema_store); // Assigns scores to DocHitInfos from the given DocHitInfoIterator and returns // a vector of ScoredDocumentHits. The size of results is no more than diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc index 125e2a7..7e5cb0f 100644 --- a/icing/scoring/scoring-processor_test.cc +++ b/icing/scoring/scoring-processor_test.cc @@ -69,11 +69,24 @@ class ScoringProcessorTest : public testing::Test { // Creates a simple email schema SchemaProto test_email_schema = SchemaBuilder() - .AddType(SchemaTypeConfigBuilder().SetType("email").AddProperty( - PropertyConfigBuilder() - .SetName("subject") - .SetDataType(TYPE_STRING) - .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType(SchemaTypeConfigBuilder() + .SetType("email") + .AddProperty( + PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(TYPE_STRING) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(TYPE_STRING) + .SetCardinality(CARDINALITY_OPTIONAL))) .Build(); ICING_ASSERT_OK(schema_store_->SetSchema(test_email_schema)); } @@ -86,6 +99,8 @@ class ScoringProcessorTest : public testing::Test { DocumentStore* document_store() { return document_store_.get(); } + SchemaStore* schema_store() { return schema_store_.get(); } + private: const std::string test_dir_; const std::string doc_store_dir_; @@ -139,16 +154,46 @@ UsageReport CreateUsageReport(std::string name_space, std::string uri, return usage_report; } -TEST_F(ScoringProcessorTest, CreationWithNullPointerShouldFail) { +TypePropertyWeights CreateTypePropertyWeights( + std::string schema_type, std::vector<PropertyWeight> property_weights) { + TypePropertyWeights type_property_weights; + type_property_weights.set_schema_type(std::move(schema_type)); + type_property_weights.mutable_property_weights()->Reserve( + property_weights.size()); + + for (PropertyWeight& property_weight : property_weights) { + *type_property_weights.add_property_weights() = std::move(property_weight); + } + + return type_property_weights; +} + +PropertyWeight CreatePropertyWeight(std::string path, double weight) { + PropertyWeight property_weight; + property_weight.set_path(std::move(path)); + property_weight.set_weight(weight); + return property_weight; +} + +TEST_F(ScoringProcessorTest, CreationWithNullDocumentStoreShouldFail) { + ScoringSpecProto spec_proto; + EXPECT_THAT(ScoringProcessor::Create(spec_proto, /*document_store=*/nullptr, + schema_store()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(ScoringProcessorTest, CreationWithNullSchemaStoreShouldFail) { ScoringSpecProto spec_proto; - EXPECT_THAT(ScoringProcessor::Create(spec_proto, /*document_store=*/nullptr), + EXPECT_THAT(ScoringProcessor::Create(spec_proto, document_store(), + /*schema_store=*/nullptr), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(ScoringProcessorTest, ShouldCreateInstance) { ScoringSpecProto spec_proto; spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); - ICING_EXPECT_OK(ScoringProcessor::Create(spec_proto, document_store())); + ICING_EXPECT_OK( + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); } TEST_F(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { @@ -163,7 +208,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/5), @@ -189,7 +234,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/-1), @@ -219,7 +264,7 @@ TEST_F(ScoringProcessorTest, ShouldRespectNumToScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/2), @@ -251,7 +296,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByDocumentScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -306,7 +351,7 @@ TEST_F(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -316,11 +361,11 @@ TEST_F(ScoringProcessorTest, // the document's length determines the final score. Document shorter than the // average corpus length are slightly boosted. ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask, - /*score=*/0.255482); + /*score=*/0.187114); ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask, - /*score=*/0.115927); + /*score=*/0.084904); ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask, - /*score=*/0.166435); + /*score=*/0.121896); EXPECT_THAT( scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3, &query_term_iterators), @@ -375,7 +420,7 @@ TEST_F(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -384,11 +429,11 @@ TEST_F(ScoringProcessorTest, // Since the three documents all contain the query term "foo" exactly once // and they have the same length, they will have the same BM25F scoret. ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask, - /*score=*/0.16173716); + /*score=*/0.118455); ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask, - /*score=*/0.16173716); + /*score=*/0.118455); ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask, - /*score=*/0.16173716); + /*score=*/0.118455); EXPECT_THAT( scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3, &query_term_iterators), @@ -448,7 +493,7 @@ TEST_F(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -457,11 +502,11 @@ TEST_F(ScoringProcessorTest, // Since the three documents all have the same length, the score is decided by // the frequency of the query term "foo". ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask1, - /*score=*/0.309497); + /*score=*/0.226674); ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask2, - /*score=*/0.16173716); + /*score=*/0.118455); ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask3, - /*score=*/0.268599); + /*score=*/0.196720); EXPECT_THAT( scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3, &query_term_iterators), @@ -470,6 +515,280 @@ TEST_F(ScoringProcessorTest, EqualsScoredDocumentHit(expected_scored_doc_hit3))); } +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_HitTermWithZeroFrequency) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/10)); + + // Document 1 contains the term "foo" 0 times in the "subject" property + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/0); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1}; + + // Creates a dummy DocHitInfoIterator with 1 result for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask section_id_mask1 = 0b00000001; + + // Since the document hit has zero frequency, expect a score of zero. + ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask1, + /*score=*/0.000000); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/1, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_SameHitFrequencyDifferentPropertyWeights) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/1)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store()->Put(document2, /*num_tokens=*/1)); + + // Document 1 contains the term "foo" 1 time in the "body" property + SectionId body_section_id = 0; + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); + + // Document 2 contains the term "foo" 1 time in the "subject" property + SectionId subject_section_id = 1; + DocHitInfo doc_hit_info2(document_id2); + doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2}; + + // Creates a dummy DocHitInfoIterator with 2 results for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + PropertyWeight body_property_weight = + CreatePropertyWeight(/*path=*/"body", /*weight=*/0.5); + PropertyWeight subject_property_weight = + CreatePropertyWeight(/*path=*/"subject", /*weight=*/2.0); + *spec_proto.add_type_property_weights() = CreateTypePropertyWeights( + /*schema_type=*/"email", {body_property_weight, subject_property_weight}); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask body_section_id_mask = 1U << body_section_id; + SectionIdMask subject_section_id_mask = 1U << subject_section_id; + + // We expect document 2 to have a higher score than document 1 as it matches + // "foo" in the "subject" property, which is weighed higher than the "body" + // property. Final scores are computed with smoothing applied. + ScoredDocumentHit expected_scored_doc_hit1(document_id1, body_section_id_mask, + /*score=*/0.053624); + ScoredDocumentHit expected_scored_doc_hit2(document_id2, + subject_section_id_mask, + /*score=*/0.153094); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/2, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1), + EqualsScoredDocumentHit(expected_scored_doc_hit2))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_WithImplicitPropertyWeight) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/1)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store()->Put(document2, /*num_tokens=*/1)); + + // Document 1 contains the term "foo" 1 time in the "body" property + SectionId body_section_id = 0; + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); + + // Document 2 contains the term "foo" 1 time in the "subject" property + SectionId subject_section_id = 1; + DocHitInfo doc_hit_info2(document_id2); + doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2}; + + // Creates a dummy DocHitInfoIterator with 2 results for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + PropertyWeight body_property_weight = + CreatePropertyWeight(/*path=*/"body", /*weight=*/0.5); + *spec_proto.add_type_property_weights() = CreateTypePropertyWeights( + /*schema_type=*/"email", {body_property_weight}); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask body_section_id_mask = 1U << body_section_id; + SectionIdMask subject_section_id_mask = 1U << subject_section_id; + + // We expect document 2 to have a higher score than document 1 as it matches + // "foo" in the "subject" property, which is weighed higher than the "body" + // property. This is because the "subject" property is implictly given a + // a weight of 1.0, the default weight value. Final scores are computed with + // smoothing applied. + ScoredDocumentHit expected_scored_doc_hit1(document_id1, body_section_id_mask, + /*score=*/0.094601); + ScoredDocumentHit expected_scored_doc_hit2(document_id2, + subject_section_id_mask, + /*score=*/0.153094); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/2, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1), + EqualsScoredDocumentHit(expected_scored_doc_hit2))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_WithDefaultPropertyWeight) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/1)); + + // Document 1 contains the term "foo" 1 time in the "body" property + SectionId body_section_id = 0; + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1}; + + // Creates a dummy DocHitInfoIterator with 1 result for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + *spec_proto.add_type_property_weights() = + CreateTypePropertyWeights(/*schema_type=*/"email", {}); + + // Creates a ScoringProcessor with no explicit weights set. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + ScoringSpecProto spec_proto_with_weights; + spec_proto_with_weights.set_rank_by( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + PropertyWeight body_property_weight = CreatePropertyWeight(/*path=*/"body", + /*weight=*/1.0); + *spec_proto_with_weights.add_type_property_weights() = + CreateTypePropertyWeights(/*schema_type=*/"email", + {body_property_weight}); + + // Creates a ScoringProcessor with default weight set for "body" property. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor_with_weights, + ScoringProcessor::Create(spec_proto_with_weights, document_store(), + schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + // Create a doc hit iterator + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators_scoring_with_weights; + query_term_iterators_scoring_with_weights["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask body_section_id_mask = 1U << body_section_id; + + // We expect document 1 to have the same score whether a weight is explicitly + // set to 1.0 or implictly scored with the default weight. Final scores are + // computed with smoothing applied. + ScoredDocumentHit expected_scored_doc_hit(document_id1, body_section_id_mask, + /*score=*/0.208191); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/1, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit))); + + // Restore ownership of doc hit iterator and query term iterator to test. + doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + EXPECT_THAT(scoring_processor_with_weights->Score( + std::move(doc_hit_info_iterator), + /*num_to_score=*/1, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit))); +} + TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -509,7 +828,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -569,7 +888,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageCount) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -629,7 +948,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageTimestamp) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -665,7 +984,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleNoScores) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/4), ElementsAre(EqualsScoredDocumentHit(scored_document_hit_default), @@ -714,7 +1033,7 @@ TEST_F(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), diff --git a/icing/scoring/section-weights.cc b/icing/scoring/section-weights.cc new file mode 100644 index 0000000..c4afe7f --- /dev/null +++ b/icing/scoring/section-weights.cc @@ -0,0 +1,146 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/section-weights.h" + +#include <cfloat> +#include <unordered_map> +#include <utility> + +#include "icing/proto/scoring.pb.h" +#include "icing/schema/section.h" +#include "icing/util/logging.h" + +namespace icing { +namespace lib { + +namespace { + +// Normalizes all weights in the map to be in range (0.0, 1.0], where the max +// weight is normalized to 1.0. +inline void NormalizeSectionWeights( + double max_weight, std::unordered_map<SectionId, double>& section_weights) { + for (auto& raw_weight : section_weights) { + raw_weight.second = raw_weight.second / max_weight; + } +} +} // namespace + +libtextclassifier3::StatusOr<std::unique_ptr<SectionWeights>> +SectionWeights::Create(const SchemaStore* schema_store, + const ScoringSpecProto& scoring_spec) { + ICING_RETURN_ERROR_IF_NULL(schema_store); + + std::unordered_map<SchemaTypeId, NormalizedSectionWeights> + schema_property_weight_map; + for (const TypePropertyWeights& type_property_weights : + scoring_spec.type_property_weights()) { + std::string_view schema_type = type_property_weights.schema_type(); + auto schema_type_id_or = schema_store->GetSchemaTypeId(schema_type); + if (!schema_type_id_or.ok()) { + ICING_LOG(WARNING) << "No schema type id found for schema type: " + << schema_type; + continue; + } + SchemaTypeId schema_type_id = schema_type_id_or.ValueOrDie(); + auto section_metadata_list_or = + schema_store->GetSectionMetadata(schema_type.data()); + if (!section_metadata_list_or.ok()) { + ICING_LOG(WARNING) << "No metadata found for schema type: " + << schema_type; + continue; + } + + const std::vector<SectionMetadata>* metadata_list = + section_metadata_list_or.ValueOrDie(); + + std::unordered_map<std::string, double> property_paths_weights; + for (const PropertyWeight& property_weight : + type_property_weights.property_weights()) { + double property_path_weight = property_weight.weight(); + + // Return error on negative and zero weights. + if (property_path_weight <= 0.0) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Property weight for property path \"%s\" is negative or zero. " + "Negative and zero weights are invalid.", + property_weight.path().c_str())); + } + property_paths_weights.insert( + {property_weight.path(), property_path_weight}); + } + NormalizedSectionWeights normalized_section_weights = + ExtractNormalizedSectionWeights(property_paths_weights, *metadata_list); + + schema_property_weight_map.insert( + {schema_type_id, + {/*section_weights*/ std::move( + normalized_section_weights.section_weights), + /*default_weight*/ normalized_section_weights.default_weight}}); + } + // Using `new` to access a non-public constructor. + return std::unique_ptr<SectionWeights>( + new SectionWeights(std::move(schema_property_weight_map))); +} + +double SectionWeights::GetNormalizedSectionWeight(SchemaTypeId schema_type_id, + SectionId section_id) const { + auto schema_type_map = schema_section_weight_map_.find(schema_type_id); + if (schema_type_map == schema_section_weight_map_.end()) { + // Return default weight if the schema type has no weights specified. + return kDefaultSectionWeight; + } + + auto section_weight = + schema_type_map->second.section_weights.find(section_id); + if (section_weight == schema_type_map->second.section_weights.end()) { + // If there is no entry for SectionId, the weight is implicitly the + // normalized default weight. + return schema_type_map->second.default_weight; + } + return section_weight->second; +} + +inline SectionWeights::NormalizedSectionWeights +SectionWeights::ExtractNormalizedSectionWeights( + const std::unordered_map<std::string, double>& raw_weights, + const std::vector<SectionMetadata>& metadata_list) { + double max_weight = 0.0; + std::unordered_map<SectionId, double> section_weights; + for (const SectionMetadata& section_metadata : metadata_list) { + std::string_view metadata_path = section_metadata.path; + double section_weight = kDefaultSectionWeight; + auto iter = raw_weights.find(metadata_path.data()); + if (iter != raw_weights.end()) { + section_weight = iter->second; + section_weights.insert({section_metadata.id, section_weight}); + } + // Replace max if we see new max weight. + max_weight = std::max(max_weight, section_weight); + } + + NormalizeSectionWeights(max_weight, section_weights); + // Set normalized default weight to 1.0 in case there is no section + // metadata and max_weight is 0.0 (we should not see this case). + double normalized_default_weight = max_weight == 0.0 + ? kDefaultSectionWeight + : kDefaultSectionWeight / max_weight; + SectionWeights::NormalizedSectionWeights normalized_section_weights = + SectionWeights::NormalizedSectionWeights(); + normalized_section_weights.section_weights = std::move(section_weights); + normalized_section_weights.default_weight = normalized_default_weight; + return normalized_section_weights; +} +} // namespace lib +} // namespace icing diff --git a/icing/scoring/section-weights.h b/icing/scoring/section-weights.h new file mode 100644 index 0000000..23a9188 --- /dev/null +++ b/icing/scoring/section-weights.h @@ -0,0 +1,95 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_SCORING_SECTION_WEIGHTS_H_ +#define ICING_SCORING_SECTION_WEIGHTS_H_ + +#include <unordered_map> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/schema/schema-store.h" +#include "icing/store/document-store.h" + +namespace icing { +namespace lib { + +inline constexpr double kDefaultSectionWeight = 1.0; + +// Provides functions for setting and retrieving section weights for schema +// type properties. Section weights are used to promote and demote term matches +// in sections when scoring results. Section weights are provided by property +// path, and can range from (0, DBL_MAX]. The SectionId is matched to the +// property path by going over the schema type's section metadata. Weights that +// correspond to a valid property path are then normalized against the maxmium +// section weight, and put into map for quick access for scorers. By default, +// a section is given a raw, pre-normalized weight of 1.0. +class SectionWeights { + public: + // SectionWeights instances should not be copied. + SectionWeights(const SectionWeights&) = delete; + SectionWeights& operator=(const SectionWeights&) = delete; + + // Factory function to create a SectionWeights instance. Raw weights are + // provided through the ScoringSpecProto. Provided property paths for weights + // are validated against the schema type's section metadata. If the property + // path doesn't exist, the property weight is ignored. If a weight is 0 or + // negative, an invalid argument error is returned. Raw weights are then + // normalized against the maximum weight for that schema type. + // + // Returns: + // A SectionWeights instance on success + // FAILED_PRECONDITION on any null pointer input + // INVALID_ARGUMENT if a provided weight for a property path is less than or + // equal to 0. + static libtextclassifier3::StatusOr<std::unique_ptr<SectionWeights>> Create( + const SchemaStore* schema_store, const ScoringSpecProto& scoring_spec); + + // Returns the normalized section weight by SchemaTypeId and SectionId. If + // the SchemaTypeId, or the SectionId for a SchemaTypeId, is not found in the + // normalized weights map, the default weight is returned instead. + double GetNormalizedSectionWeight(SchemaTypeId schema_type_id, + SectionId section_id) const; + + private: + // Holds the normalized section weights for a schema type, as well as the + // normalized default weight for sections that have no weight set. + struct NormalizedSectionWeights { + std::unordered_map<SectionId, double> section_weights; + double default_weight; + }; + + explicit SectionWeights( + const std::unordered_map<SchemaTypeId, NormalizedSectionWeights> + schema_section_weight_map) + : schema_section_weight_map_(std::move(schema_section_weight_map)) {} + + // Creates a map of section ids to normalized weights from the raw property + // path weight map and section metadata and calculates the normalized default + // section weight. + static inline SectionWeights::NormalizedSectionWeights + ExtractNormalizedSectionWeights( + const std::unordered_map<std::string, double>& raw_weights, + const std::vector<SectionMetadata>& metadata_list); + + // A map of (SchemaTypeId -> SectionId -> Normalized Weight), allows for fast + // look up of normalized weights. This is precomputed when creating a + // SectionWeights instance. + std::unordered_map<SchemaTypeId, NormalizedSectionWeights> + schema_section_weight_map_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_SECTION_WEIGHTS_H_ diff --git a/icing/scoring/section-weights_test.cc b/icing/scoring/section-weights_test.cc new file mode 100644 index 0000000..b90c3d5 --- /dev/null +++ b/icing/scoring/section-weights_test.cc @@ -0,0 +1,386 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/section-weights.h" + +#include <cfloat> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/proto/scoring.pb.h" +#include "icing/schema-builder.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { +using ::testing::Eq; + +class SectionWeightsTest : public testing::Test { + protected: + SectionWeightsTest() + : test_dir_(GetTestTempDir() + "/icing"), + schema_store_dir_(test_dir_ + "/schema_store") {} + + void SetUp() override { + // Creates file directories + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + + SchemaTypeConfigProto sender_schema = + SchemaTypeConfigBuilder() + .SetType("sender") + .AddProperty(PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .Build(); + SchemaTypeConfigProto email_schema = + SchemaTypeConfigBuilder() + .SetType("email") + .AddProperty( + PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(PropertyConfigProto_DataType_Code_STRING) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(PropertyConfigProto_DataType_Code_STRING) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("sender") + .SetDataTypeDocument( + "sender", /*index_nested_properties=*/true) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .Build(); + SchemaProto schema = + SchemaBuilder().AddType(sender_schema).AddType(email_schema).Build(); + + ICING_ASSERT_OK(schema_store_->SetSchema(schema)); + } + + void TearDown() override { + schema_store_.reset(); + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + SchemaStore *schema_store() { return schema_store_.get(); } + + private: + const std::string test_dir_; + const std::string schema_store_dir_; + Filesystem filesystem_; + FakeClock fake_clock_; + std::unique_ptr<SchemaStore> schema_store_; +}; + +TEST_F(SectionWeightsTest, ShouldNormalizeSinglePropertyWeight) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("sender"); + + PropertyWeight *property_weight = + type_property_weights->add_property_weights(); + property_weight->set_weight(5.0); + property_weight->set_path("name"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id, + schema_store()->GetSchemaTypeId("sender")); + + // section_id 0 corresponds to property "name". + // We expect 1.0 as there is only one property in the "sender" schema type + // so it should take the max normalized weight of 1.0. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id, + /*section_id=*/0), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldAcceptMaxWeightValue) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("sender"); + + PropertyWeight *property_weight = + type_property_weights->add_property_weights(); + property_weight->set_weight(DBL_MAX); + property_weight->set_path("name"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id, + schema_store()->GetSchemaTypeId("sender")); + + // section_id 0 corresponds to property "name". + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id, + /*section_id=*/0), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldFailWithNegativeWeights) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + PropertyWeight *body_propery_weight = + type_property_weights->add_property_weights(); + body_propery_weight->set_weight(-100.0); + body_propery_weight->set_path("body"); + + EXPECT_THAT(SectionWeights::Create(schema_store(), spec_proto).status(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(SectionWeightsTest, ShouldFailWithZeroWeight) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("sender"); + + PropertyWeight *property_weight = + type_property_weights->add_property_weights(); + property_weight->set_weight(0.0); + property_weight->set_path("name"); + + EXPECT_THAT(SectionWeights::Create(schema_store(), spec_proto).status(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(SectionWeightsTest, ShouldReturnDefaultIfTypePropertyWeightsNotSet) { + ScoringSpecProto spec_proto; + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(kDefaultSectionWeight)); +} + +TEST_F(SectionWeightsTest, ShouldSetNestedPropertyWeights) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + PropertyWeight *body_property_weight = + type_property_weights->add_property_weights(); + body_property_weight->set_weight(1.0); + body_property_weight->set_path("body"); + + PropertyWeight *subject_property_weight = + type_property_weights->add_property_weights(); + subject_property_weight->set_weight(100.0); + subject_property_weight->set_path("subject"); + + PropertyWeight *nested_property_weight = + type_property_weights->add_property_weights(); + nested_property_weight->set_weight(50.0); + nested_property_weight->set_path("sender.name"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + // Normalized weight for "body" property. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(0.01)); + // Normalized weight for "sender.name" property (the nested property). + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(0.5)); + // Normalized weight for "subject" property. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldNormalizeIfAllWeightsBelowOne) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + PropertyWeight *body_property_weight = + type_property_weights->add_property_weights(); + body_property_weight->set_weight(0.1); + body_property_weight->set_path("body"); + + PropertyWeight *sender_name_weight = + type_property_weights->add_property_weights(); + sender_name_weight->set_weight(0.2); + sender_name_weight->set_path("sender.name"); + + PropertyWeight *subject_property_weight = + type_property_weights->add_property_weights(); + subject_property_weight->set_weight(0.4); + subject_property_weight->set_path("subject"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + // Normalized weight for "body" property. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(1.0 / 4.0)); + // Normalized weight for "sender.name" property (the nested property). + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(2.0 / 4.0)); + // Normalized weight for "subject" property. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldSetNestedPropertyWeightSeparatelyForTypes) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *email_type_property_weights = + spec_proto.add_type_property_weights(); + email_type_property_weights->set_schema_type("email"); + + PropertyWeight *body_property_weight = + email_type_property_weights->add_property_weights(); + body_property_weight->set_weight(1.0); + body_property_weight->set_path("body"); + + PropertyWeight *subject_property_weight = + email_type_property_weights->add_property_weights(); + subject_property_weight->set_weight(100.0); + subject_property_weight->set_path("subject"); + + PropertyWeight *sender_name_property_weight = + email_type_property_weights->add_property_weights(); + sender_name_property_weight->set_weight(50.0); + sender_name_property_weight->set_path("sender.name"); + + TypePropertyWeights *sender_type_property_weights = + spec_proto.add_type_property_weights(); + sender_type_property_weights->set_schema_type("sender"); + + PropertyWeight *sender_property_weight = + sender_type_property_weights->add_property_weights(); + sender_property_weight->set_weight(25.0); + sender_property_weight->set_path("sender"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id, + schema_store()->GetSchemaTypeId("sender")); + + // Normalized weight for "sender.name" property (the nested property) + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(0.5)); + // Normalized weight for "name" property for "sender" schema type. As it is + // the only property of the type, it should take the max normalized weight of + // 1.0. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldSkipNonExistentPathWhenSettingWeights) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + // If this property weight isn't skipped, then the max property weight would + // be set to 100.0 and all weights would be normalized against the max. + PropertyWeight *non_valid_property_weight = + type_property_weights->add_property_weights(); + non_valid_property_weight->set_weight(100.0); + non_valid_property_weight->set_path("sender.organization"); + + PropertyWeight *subject_property_weight = + type_property_weights->add_property_weights(); + subject_property_weight->set_weight(10.0); + subject_property_weight->set_path("subject"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + // Normalized weight for "body" property. Because the weight is not explicitly + // set, it is set to the default of 1.0 before being normalized. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(0.1)); + // Normalized weight for "sender.name" property (the nested property). Because + // the weight is not explicitly set, it is set to the default of 1.0 before + // being normalized. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(0.1)); + // Normalized weight for "subject" property. Because the invalid property path + // is skipped when assigning weights, subject takes the max normalized weight + // of 1.0 instead. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/testing/random-string.cc b/icing/testing/random-string.cc new file mode 100644 index 0000000..27f83bc --- /dev/null +++ b/icing/testing/random-string.cc @@ -0,0 +1,54 @@ +// Copyright (C) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/testing/random-string.h" + +namespace icing { +namespace lib { + +std::vector<std::string> GenerateUniqueTerms(int num_terms) { + char before_a = 'a' - 1; + std::string term(1, before_a); + std::vector<std::string> terms; + int current_char = 0; + for (int permutation = 0; permutation < num_terms; ++permutation) { + if (term[current_char] != 'z') { + ++term[current_char]; + } else { + if (current_char < term.length() - 1) { + // The string currently looks something like this "zzzaa" + // 1. Find the first char after this one that isn't + current_char = term.find_first_not_of('z', current_char); + if (current_char != std::string::npos) { + // 2. Increment that character + ++term[current_char]; + + // 3. Set every character prior to current_char to 'a' + term.replace(0, current_char, current_char, 'a'); + } else { + // Every character in this string is a 'z'. We need to grow. + term = std::string(term.length() + 1, 'a'); + } + } else { + term = std::string(term.length() + 1, 'a'); + } + current_char = 0; + } + terms.push_back(term); + } + return terms; +} + +} // namespace lib +} // namespace icing diff --git a/icing/testing/random-string.h b/icing/testing/random-string.h index ac36924..3165bf6 100644 --- a/icing/testing/random-string.h +++ b/icing/testing/random-string.h @@ -36,6 +36,10 @@ std::string RandomString(const std::string_view alphabet, size_t len, return result; } +// Returns a vector containing num_terms unique terms. Terms are created in +// non-random order starting with "a" to "z" to "aa" to "zz", etc. +std::vector<std::string> GenerateUniqueTerms(int num_terms); + } // namespace lib } // namespace icing diff --git a/icing/testing/random-string_test.cc b/icing/testing/random-string_test.cc new file mode 100644 index 0000000..759fec0 --- /dev/null +++ b/icing/testing/random-string_test.cc @@ -0,0 +1,54 @@ +// Copyright (C) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/testing/random-string.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; + +namespace icing { +namespace lib { + +namespace { + +TEST(RandomStringTest, GenerateUniqueTerms) { + EXPECT_THAT(GenerateUniqueTerms(0), IsEmpty()); + EXPECT_THAT(GenerateUniqueTerms(1), ElementsAre("a")); + EXPECT_THAT(GenerateUniqueTerms(4), ElementsAre("a", "b", "c", "d")); + EXPECT_THAT(GenerateUniqueTerms(29), + ElementsAre("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", + "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", + "w", "x", "y", "z", "aa", "ba", "ca")); + EXPECT_THAT(GenerateUniqueTerms(56), + ElementsAre("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", + "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", + "w", "x", "y", "z", "aa", "ba", "ca", "da", "ea", + "fa", "ga", "ha", "ia", "ja", "ka", "la", "ma", "na", + "oa", "pa", "qa", "ra", "sa", "ta", "ua", "va", "wa", + "xa", "ya", "za", "ab", "bb", "cb", "db")); + EXPECT_THAT(GenerateUniqueTerms(56).at(54), Eq("cb")); + EXPECT_THAT(GenerateUniqueTerms(26 * 26 * 26).at(26), Eq("aa")); + EXPECT_THAT(GenerateUniqueTerms(26 * 26 * 26).at(26 * 27), Eq("aaa")); + EXPECT_THAT(GenerateUniqueTerms(26 * 26 * 26).at(26 * 27 - 6), Eq("uz")); + EXPECT_THAT(GenerateUniqueTerms(26 * 26 * 26).at(26 * 27 + 5), Eq("faa")); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/tokenization/raw-query-tokenizer.cc b/icing/tokenization/raw-query-tokenizer.cc index 205d3a2..2d461ee 100644 --- a/icing/tokenization/raw-query-tokenizer.cc +++ b/icing/tokenization/raw-query-tokenizer.cc @@ -14,9 +14,8 @@ #include "icing/tokenization/raw-query-tokenizer.h" -#include <stddef.h> - #include <cctype> +#include <cstddef> #include <memory> #include <string> #include <string_view> diff --git a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc index 6b1cb3a..8e1e563 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc @@ -15,10 +15,10 @@ #include "icing/tokenization/reverse_jni/reverse-jni-break-iterator.h" #include <jni.h> -#include <math.h> #include <cassert> #include <cctype> +#include <cmath> #include <map> #include "icing/jni/jni-cache.h" diff --git a/icing/transform/icu/icu-normalizer.cc b/icing/transform/icu/icu-normalizer.cc index 250d6cf..aceb11d 100644 --- a/icing/transform/icu/icu-normalizer.cc +++ b/icing/transform/icu/icu-normalizer.cc @@ -302,14 +302,16 @@ IcuNormalizer::TermTransformer::FindNormalizedNonLatinMatchEndPosition( int32_t c16_length; int32_t limit; - constexpr int kUtf32CharBufferLength = 3; - UChar32 normalized_buffer[kUtf32CharBufferLength]; - int32_t c32_length; + constexpr int kCharBufferLength = 3 * 4; + char normalized_buffer[kCharBufferLength]; + int32_t c8_length; while (char_itr.utf8_index() < term.length() && normalized_char_itr.utf8_index() < normalized_term.length()) { UChar32 c = char_itr.GetCurrentChar(); - u_strFromUTF32(c16, kUtf16CharBufferLength, &c16_length, &c, - /*srcLength=*/1, &status); + int c_lenth = i18n_utils::GetUtf8Length(c); + u_strFromUTF8(c16, kUtf16CharBufferLength, &c16_length, + term.data() + char_itr.utf8_index(), + /*srcLength=*/c_lenth, &status); if (U_FAILURE(status)) { break; } @@ -322,19 +324,20 @@ IcuNormalizer::TermTransformer::FindNormalizedNonLatinMatchEndPosition( break; } - u_strToUTF32(normalized_buffer, kUtf32CharBufferLength, &c32_length, c16, - c16_length, &status); + u_strToUTF8(normalized_buffer, kCharBufferLength, &c8_length, c16, + c16_length, &status); if (U_FAILURE(status)) { break; } - for (int i = 0; i < c32_length; ++i) { - UChar32 normalized_c = normalized_char_itr.GetCurrentChar(); - if (normalized_buffer[i] != normalized_c) { + for (int i = 0; i < c8_length; ++i) { + if (normalized_buffer[i] != + normalized_term[normalized_char_itr.utf8_index() + i]) { return char_itr; } - normalized_char_itr.AdvanceToUtf32(normalized_char_itr.utf32_index() + 1); } + normalized_char_itr.AdvanceToUtf8(normalized_char_itr.utf8_index() + + c8_length); char_itr.AdvanceToUtf32(char_itr.utf32_index() + 1); } if (U_FAILURE(status)) { diff --git a/icing/transform/map/map-normalizer.cc b/icing/transform/map/map-normalizer.cc index 95aa633..61fce65 100644 --- a/icing/transform/map/map-normalizer.cc +++ b/icing/transform/map/map-normalizer.cc @@ -14,8 +14,7 @@ #include "icing/transform/map/map-normalizer.h" -#include <ctype.h> - +#include <cctype> #include <string> #include <string_view> #include <unordered_map> diff --git a/java/src/com/google/android/icing/IcingSearchEngine.java b/java/src/com/google/android/icing/IcingSearchEngine.java index 1f5fb51..95e0c84 100644 --- a/java/src/com/google/android/icing/IcingSearchEngine.java +++ b/java/src/com/google/android/icing/IcingSearchEngine.java @@ -43,6 +43,8 @@ import com.google.android.icing.proto.SearchSpecProto; import com.google.android.icing.proto.SetSchemaResultProto; import com.google.android.icing.proto.StatusProto; import com.google.android.icing.proto.StorageInfoResultProto; +import com.google.android.icing.proto.SuggestionResponse; +import com.google.android.icing.proto.SuggestionSpecProto; import com.google.android.icing.proto.UsageReport; import com.google.protobuf.ExtensionRegistryLite; import com.google.protobuf.InvalidProtocolBufferException; @@ -370,6 +372,26 @@ public class IcingSearchEngine implements Closeable { } @NonNull + public SuggestionResponse searchSuggestions(@NonNull SuggestionSpecProto suggestionSpec) { + byte[] suggestionResponseBytes = nativeSearchSuggestions(this, suggestionSpec.toByteArray()); + if (suggestionResponseBytes == null) { + Log.e(TAG, "Received null suggestionResponseBytes from native."); + return SuggestionResponse.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return SuggestionResponse.parseFrom(suggestionResponseBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing suggestionResponseBytes.", e); + return SuggestionResponse.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull public DeleteByNamespaceResultProto deleteByNamespace(@NonNull String namespace) { throwIfClosed(); @@ -604,4 +626,7 @@ public class IcingSearchEngine implements Closeable { private static native byte[] nativeGetStorageInfo(IcingSearchEngine instance); private static native byte[] nativeReset(IcingSearchEngine instance); + + private static native byte[] nativeSearchSuggestions( + IcingSearchEngine instance, byte[] suggestionSpecBytes); } diff --git a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java index 0cee80c..cb28331 100644 --- a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java +++ b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java @@ -51,6 +51,8 @@ import com.google.android.icing.proto.StatusProto; import com.google.android.icing.proto.StorageInfoResultProto; import com.google.android.icing.proto.StringIndexingConfig; import com.google.android.icing.proto.StringIndexingConfig.TokenizerType; +import com.google.android.icing.proto.SuggestionResponse; +import com.google.android.icing.proto.SuggestionSpecProto; import com.google.android.icing.proto.TermMatchType; import com.google.android.icing.proto.UsageReport; import com.google.android.icing.IcingSearchEngine; @@ -623,6 +625,40 @@ public final class IcingSearchEngineTest { assertThat(match).isEqualTo("𐀂𐀃"); } + @Test + public void testSearchSuggestions() { + assertStatusOk(icingSearchEngine.initialize().getStatus()); + + SchemaTypeConfigProto emailTypeConfig = createEmailTypeConfig(); + SchemaProto schema = SchemaProto.newBuilder().addTypes(emailTypeConfig).build(); + assertThat( + icingSearchEngine + .setSchema(schema, /*ignoreErrorsAndDeleteDocuments=*/ false) + .getStatus() + .getCode()) + .isEqualTo(StatusProto.Code.OK); + + DocumentProto emailDocument1 = + createEmailDocument("namespace", "uri1").toBuilder() + .addProperties(PropertyProto.newBuilder().setName("subject").addStringValues("fo")) + .build(); + DocumentProto emailDocument2 = + createEmailDocument("namespace", "uri2").toBuilder() + .addProperties(PropertyProto.newBuilder().setName("subject").addStringValues("foo")) + .build(); + assertStatusOk(icingSearchEngine.put(emailDocument1).getStatus()); + assertStatusOk(icingSearchEngine.put(emailDocument2).getStatus()); + + SuggestionSpecProto suggestionSpec = + SuggestionSpecProto.newBuilder().setPrefix("f").setNumToReturn(10).build(); + + SuggestionResponse response = icingSearchEngine.searchSuggestions(suggestionSpec); + assertStatusOk(response.getStatus()); + assertThat(response.getSuggestionsList()).hasSize(2); + assertThat(response.getSuggestions(0).getQuery()).isEqualTo("foo"); + assertThat(response.getSuggestions(1).getQuery()).isEqualTo("fo"); + } + private static void assertStatusOk(StatusProto status) { assertWithMessage(status.getMessage()).that(status.getCode()).isEqualTo(StatusProto.Code.OK); } diff --git a/proto/icing/proto/logging.proto b/proto/icing/proto/logging.proto index 7abbf4a..2f1f271 100644 --- a/proto/icing/proto/logging.proto +++ b/proto/icing/proto/logging.proto @@ -118,12 +118,10 @@ message PutDocumentStatsProto { optional int32 document_size = 5; message TokenizationStats { - // Whether the number of tokens to be indexed exceeded the max number of - // tokens per document. - optional bool exceeded_max_token_num = 2; - // Number of tokens added to the index. optional int32 num_tokens_indexed = 1; + + reserved 2; } optional TokenizationStats tokenization_stats = 6; } diff --git a/proto/icing/proto/scoring.proto b/proto/icing/proto/scoring.proto index 6186fde..a3a64df 100644 --- a/proto/icing/proto/scoring.proto +++ b/proto/icing/proto/scoring.proto @@ -23,7 +23,7 @@ option objc_class_prefix = "ICNG"; // Encapsulates the configurations on how Icing should score and rank the search // results. // TODO(b/170347684): Change all timestamps to seconds. -// Next tag: 3 +// Next tag: 4 message ScoringSpecProto { // OPTIONAL: Indicates how the search results will be ranked. message RankingStrategy { @@ -83,4 +83,41 @@ message ScoringSpecProto { } } optional Order.Code order_by = 2; + + // OPTIONAL: Specifies property weights for RELEVANCE_SCORE scoring strategy. + // Property weights are used for promoting or demoting query term matches in a + // document property. When property weights are provided, the term frequency + // is multiplied by the normalized property weight when computing the + // normalized term frequency component of BM25F. To prefer query term matches + // in the "subject" property over the "body" property of "Email" documents, + // set a higher property weight value for "subject" than "body". By default, + // all properties that are not specified are given a raw, pre-normalized + // weight of 1.0 when scoring. + repeated TypePropertyWeights type_property_weights = 3; +} + +// Next tag: 3 +message TypePropertyWeights { + // Schema type to apply property weights to. + optional string schema_type = 1; + + // Property weights to apply to the schema type. + repeated PropertyWeight property_weights = 2; +} + +// Next tag: 3 +message PropertyWeight { + // Property path to assign property weight to. Property paths must be composed + // only of property names and property separators (the '.' character). + // For example, if an "Email" schema type has string property "subject" and + // document property "sender", which has string property "name", the property + // path for the email's subject would just be "subject" and the property path + // for the sender's name would be "sender.name". If an invalid path is + // specified, the property weight is discarded. + optional string path = 1; + + // Property weight, valid values are positive. Zero and negative weights are + // invalid and will result in an error. By default, a property is given a raw, + // pre-normalized weight of 1.0. + optional double weight = 2; } diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto index 544995e..c712ab2 100644 --- a/proto/icing/proto/search.proto +++ b/proto/icing/proto/search.proto @@ -308,3 +308,37 @@ message GetResultSpecProto { // type will be retrieved. repeated TypePropertyMask type_property_masks = 1; } + +// Next tag: 4 +message SuggestionSpecProto { + // REQUIRED: The "raw" prefix string that users may type. For example, "f" + // will search for suggested query that start with "f" like "foo", "fool". + optional string prefix = 1; + + // OPTIONAL: Only search for suggestions that under the specified namespaces. + // If unset, the suggestion will search over all namespaces. Note that this + // applies to the entire 'prefix'. To issue different suggestions for + // different namespaces, separate RunSuggestion()'s will need to be made. + repeated string namespace_filters = 2; + + // REQUIRED: The number of suggestions to be returned. + optional int32 num_to_return = 3; +} + +// Next tag: 3 +message SuggestionResponse { + message Suggestion { + // The suggested query string for client to search for. + optional string query = 1; + } + + // Status code can be one of: + // OK + // FAILED_PRECONDITION + // INTERNAL + // + // See status.proto for more details. + optional StatusProto status = 1; + + repeated Suggestion suggestions = 2; +} diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt index d57de81..7e0431b 100644 --- a/synced_AOSP_CL_number.txt +++ b/synced_AOSP_CL_number.txt @@ -1 +1 @@ -set(synced_AOSP_CL_number=395331611) +set(synced_AOSP_CL_number=404879391) |