aboutsummaryrefslogtreecommitdiff
path: root/icing/index/embed/doc-hit-info-iterator-embedding.h
diff options
context:
space:
mode:
Diffstat (limited to 'icing/index/embed/doc-hit-info-iterator-embedding.h')
-rw-r--r--icing/index/embed/doc-hit-info-iterator-embedding.h161
1 files changed, 161 insertions, 0 deletions
diff --git a/icing/index/embed/doc-hit-info-iterator-embedding.h b/icing/index/embed/doc-hit-info-iterator-embedding.h
new file mode 100644
index 0000000..21180db
--- /dev/null
+++ b/icing/index/embed/doc-hit-info-iterator-embedding.h
@@ -0,0 +1,161 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
+#define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
+
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/embedding-index.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/embed/embedding-scorer.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/iterator/section-restrict-data.h"
+#include "icing/proto/search.pb.h"
+#include "icing/schema/section.h"
+
+namespace icing {
+namespace lib {
+
+class DocHitInfoIteratorEmbedding : public DocHitInfoLeafIterator {
+ public:
+ // Create a DocHitInfoIterator for iterating through all docs which have an
+ // embedding matched with the provided query with a score in the range of
+ // [score_low, score_high], using the provided metric_type.
+ //
+ // The iterator will store the matched embedding scores in score_map to
+ // prepare for scoring.
+ //
+ // The iterator will handle the section restriction logic internally by the
+ // provided section_restrict_data.
+ //
+ // Returns:
+ // - a DocHitInfoIteratorEmbedding instance on success.
+ // - Any error from posting lists.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<DocHitInfoIteratorEmbedding>>
+ Create(const PropertyProto::VectorProto* query,
+ std::unique_ptr<SectionRestrictData> section_restrict_data,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ double score_low, double score_high,
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
+ const EmbeddingIndex* embedding_index);
+
+ libtextclassifier3::Status Advance() override;
+
+ // The iterator will internally handle the section restriction logic by itself
+ // to have better control, so that it is able to filter out embedding hits
+ // from unwanted sections to avoid retrieving unnecessary vectors and
+ // calculate scores for them.
+ bool full_section_restriction_applied() const override { return true; }
+
+ libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
+ return absl_ports::InvalidArgumentError(
+ "Query suggestions for the semanticSearch function are not supported");
+ }
+
+ CallStats GetCallStats() const override {
+ return CallStats(
+ /*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_,
+ /*num_leaf_advance_calls_main_index_in=*/0,
+ /*num_leaf_advance_calls_integer_index_in=*/0,
+ /*num_leaf_advance_calls_no_index_in=*/0,
+ /*num_blocks_inspected_in=*/0);
+ }
+
+ std::string ToString() const override { return "embedding_iterator"; }
+
+ // PopulateMatchedTermsStats is not applicable to embedding search.
+ void PopulateMatchedTermsStats(
+ std::vector<TermMatchInfo>* matched_terms_stats,
+ SectionIdMask filtering_section_mask) const override {}
+
+ private:
+ explicit DocHitInfoIteratorEmbedding(
+ const PropertyProto::VectorProto* query,
+ std::unique_ptr<SectionRestrictData> section_restrict_data,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low,
+ double score_high,
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
+ const EmbeddingIndex* embedding_index,
+ std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor)
+ : query_(*query),
+ section_restrict_data_(std::move(section_restrict_data)),
+ metric_type_(metric_type),
+ embedding_scorer_(std::move(embedding_scorer)),
+ score_low_(score_low),
+ score_high_(score_high),
+ score_map_(*score_map),
+ embedding_index_(*embedding_index),
+ posting_list_accessor_(std::move(posting_list_accessor)),
+ cached_embedding_hits_idx_(0),
+ current_allowed_sections_mask_(kSectionIdMaskAll),
+ no_more_hit_(false),
+ num_advance_calls_(0) {}
+
+ // Advance to the next embedding hit of the current document. If the current
+ // document id is kInvalidDocumentId, the method will advance to the first
+ // embedding hit of the next document and update doc_hit_info_.
+ //
+ // This method also properly updates cached_embedding_hits_,
+ // cached_embedding_hits_idx_, current_allowed_sections_mask_, and
+ // no_more_hit_ to reflect the current state.
+ //
+ // Returns:
+ // - a const pointer to the next embedding hit on success.
+ // - nullptr, if there is no more hit for the current document, or no more
+ // hit in general if the current document id is kInvalidDocumentId.
+ // - Any error from posting lists.
+ libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit();
+
+ // Query information
+ const PropertyProto::VectorProto& query_; // Does not own
+ std::unique_ptr<SectionRestrictData> section_restrict_data_; // Nullable.
+
+ // Scoring arguments
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
+ std::unique_ptr<EmbeddingScorer> embedding_scorer_;
+ double score_low_;
+ double score_high_;
+
+ // Score map
+ EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map_; // Does not own
+
+ // Access to embeddings index data
+ const EmbeddingIndex& embedding_index_;
+ std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_;
+
+ // Cached data from the embeddings index
+ std::vector<EmbeddingHit> cached_embedding_hits_;
+ int cached_embedding_hits_idx_;
+ SectionIdMask current_allowed_sections_mask_;
+ bool no_more_hit_;
+
+ int num_advance_calls_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_