aboutsummaryrefslogtreecommitdiff
path: root/icing/index/embed/embedding-query-results.h
diff options
context:
space:
mode:
Diffstat (limited to 'icing/index/embed/embedding-query-results.h')
-rw-r--r--icing/index/embed/embedding-query-results.h72
1 files changed, 72 insertions, 0 deletions
diff --git a/icing/index/embed/embedding-query-results.h b/icing/index/embed/embedding-query-results.h
new file mode 100644
index 0000000..de85489
--- /dev/null
+++ b/icing/index/embed/embedding-query-results.h
@@ -0,0 +1,72 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
+#define ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "icing/proto/search.pb.h"
+#include "icing/store/document-id.h"
+
+namespace icing {
+namespace lib {
+
+// A class to store results generated from embedding queries.
+struct EmbeddingQueryResults {
+ // Maps from DocumentId to the list of matched embedding scores for that
+ // document, which will be used in the advanced scoring language to
+ // determine the results for the "this.matchedSemanticScores(...)" function.
+ using EmbeddingQueryScoreMap =
+ std::unordered_map<DocumentId, std::vector<double>>;
+
+ // Maps from (query_vector_index, metric_type) to EmbeddingQueryScoreMap.
+ std::unordered_map<
+ int, std::unordered_map<SearchSpecProto::EmbeddingQueryMetricType::Code,
+ EmbeddingQueryScoreMap>>
+ result_scores;
+
+ // Returns the matched scores for the given query_vector_index, metric_type,
+ // and doc_id. Returns nullptr if (query_vector_index, metric_type) does not
+ // exist in the result_scores map.
+ const std::vector<double>* GetMatchedScoresForDocument(
+ int query_vector_index,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ DocumentId doc_id) const {
+ // Check if a mapping exists for the query_vector_index
+ auto outer_it = result_scores.find(query_vector_index);
+ if (outer_it == result_scores.end()) {
+ return nullptr;
+ }
+ // Check if a mapping exists for the metric_type
+ auto inner_it = outer_it->second.find(metric_type);
+ if (inner_it == outer_it->second.end()) {
+ return nullptr;
+ }
+ const EmbeddingQueryScoreMap& score_map = inner_it->second;
+
+ // Check if the doc_id exists in the score_map
+ auto scores_it = score_map.find(doc_id);
+ if (scores_it == score_map.end()) {
+ return nullptr;
+ }
+ return &scores_it->second;
+ }
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_