aboutsummaryrefslogtreecommitdiff
path: root/icing/icing-search-engine_search_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'icing/icing-search-engine_search_test.cc')
-rw-r--r-- icing/icing-search-engine_search_test.cc 213
1 files changed, 211 insertions, 2 deletions
diff --git a/icing/icing-search-engine_search_test.cc b/icing/icing-search-engine_search_test.cc
index d815f61..a58dbc8 100644
--- a/icing/icing-search-engine_search_test.cc
+++ b/icing/icing-search-engine_search_test.cc
@@ -13,12 +13,14 @@
// limitations under the License.
#include <cstdint>
+#include <initializer_list>
#include <limits>
#include <memory>
#include <string>
+#include <string_view>
#include <utility>
+#include <vector>
-#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/document-builder.h"
@@ -27,7 +29,7 @@
#include "icing/index/lite/term-id-hit-pair.h"
#include "icing/jni/jni-cache.h"
#include "icing/join/join-processor.h"
-#include "icing/portable/endian.h"
+#include "icing/legacy/index/icing-filesystem.h"
#include "icing/portable/equals-proto.h"
#include "icing/portable/platform.h"
#include "icing/proto/debug.pb.h"
@@ -49,11 +51,13 @@
#include "icing/result/result-state-manager.h"
#include "icing/schema-builder.h"
#include "icing/testing/common-matchers.h"
+#include "icing/testing/embedding-test-utils.h"
#include "icing/testing/fake-clock.h"
#include "icing/testing/icu-data-file-helper.h"
#include "icing/testing/jni-test-helpers.h"
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
+#include "icing/util/clock.h"
#include "icing/util/snippet-helpers.h"
namespace icing {
@@ -63,6 +67,7 @@ namespace {
using ::icing::lib::portable_equals_proto::EqualsProto;
using ::testing::DoubleEq;
+using ::testing::DoubleNear;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Gt;
@@ -120,6 +125,8 @@ class IcingSearchEngineSearchTest
// Non-zero value so we don't override it to be the current time
constexpr int64_t kDefaultCreationTimestampMs = 1575492852000;
+constexpr double kEps = 0.000001;  // Tolerance passed to DoubleNear checks.
+
IcingSearchEngineOptions GetDefaultIcingOptions() {
IcingSearchEngineOptions icing_options;
icing_options.set_base_dir(GetTestBaseDir());
@@ -7306,6 +7313,208 @@ TEST_P(IcingSearchEngineSearchTest, HasPropertyQueryNestedDocument) {
EXPECT_THAT(results.results(), IsEmpty());
}
+TEST_P(IcingSearchEngineSearchTest, EmbeddingSearch) {
+ if (GetParam() !=
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ GTEST_SKIP() << "Embedding search is only supported in advanced query.";
+ }
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("Email")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("body")
+ .SetDataTypeString(TERM_MATCH_EXACT,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("embedding1")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("embedding2")
+ .SetDataTypeVector(
+ EMBEDDING_INDEXING_LINEAR_SEARCH)
+ .SetCardinality(CARDINALITY_REPEATED)))
+ .Build();
+ DocumentProto document0 =
+ DocumentBuilder()
+ .SetKey("icing", "uri0")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(1)
+ .AddStringProperty("body", "foo")
+ .AddVectorProperty(
+ "embedding1",
+ CreateVector("my_model_v1", {0.1, 0.2, 0.3, 0.4, 0.5}))
+ .AddVectorProperty(
+ "embedding2",
+ CreateVector("my_model_v1", {-0.1, -0.2, -0.3, 0.4, 0.5}),
+ CreateVector("my_model_v2", {0.6, 0.7, 0.8}))
+ .Build();
+ DocumentProto document1 =
+ DocumentBuilder()
+ .SetKey("icing", "uri1")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(1)
+ .AddVectorProperty(
+ "embedding1",
+ CreateVector("my_model_v1", {-0.1, 0.2, -0.3, -0.4, 0.5}))
+ .AddVectorProperty("embedding2",
+ CreateVector("my_model_v2", {0.6, 0.7, -0.8}))
+ .Build();
+
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document0).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk());
+
+ SearchSpecProto search_spec;
+ search_spec.set_term_match_type(TermMatchType::EXACT_ONLY);
+ search_spec.set_embedding_query_metric_type(
+ SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT);
+ search_spec.add_enabled_features(
+ std::string(kListFilterQueryLanguageFeature));
+ search_spec.add_enabled_features(std::string(kEmbeddingSearchFeature));
+ search_spec.set_search_type(GetParam());
+ // Embedding query 0 ("my_model_v1"). Dot-product semantic scores:
+ // - document 0: -0.5 (embedding1), 0.3 (embedding2)
+ // - document 1: -0.9 (embedding1)
+ *search_spec.add_embedding_query_vectors() =
+ CreateVector("my_model_v1", {1, -1, -1, 1, -1});
+ // Embedding query 1 ("my_model_v2"). Dot-product semantic scores:
+ // - document 0: -0.5 (embedding2)
+ // - document 1: -2.1 (embedding2)
+ *search_spec.add_embedding_query_vectors() =
+ CreateVector("my_model_v2", {-1, -1, 1});
+ ScoringSpecProto scoring_spec = GetDefaultScoringSpec();
+ scoring_spec.set_rank_by(
+ ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION);
+
+ // Match documents that have embeddings whose semantic score against
+ // embedding query 0 is greater than -1.
+ //
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding1), 0.3 (embedding2)
+ // - document 1: -0.9 (embedding1)
+ // Scoring (embedding query 1 is not in the query, so its set is empty):
+ // - document 0: sum({-0.5, 0.3}) + sum({}) = -0.2
+ // - document 1: sum({-0.9}) + sum({}) = -0.9
+ search_spec.set_query("semanticSearch(getSearchSpecEmbedding(0), -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0))) + "
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ SearchResultProto results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5 + 0.3, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9, kEps));
+
+ // Same query as above but restricted to the embedding1 section. Both
+ // documents still match, but the semantic score 0.3 (from embedding2) is
+ // removed from document 0.
+ //
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding1)
+ // - document 1: -0.9 (embedding1)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5}) = -0.5
+ // - document 1: sum({-0.9}) = -0.9
+ search_spec.set_query(
+ "embedding1:semanticSearch(getSearchSpecEmbedding(0), -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9, kEps));
+
+ // Query embedding 1 with a lower bound of -1.5: only document 0 (-0.5)
+ // matches; document 1 scores -2.1.
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5}) = -0.5
+ search_spec.set_query("semanticSearch(getSearchSpecEmbedding(1), -1.5)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(1));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5, kEps));
+
+ // Query embedding 1 with the score range -10 to -1: only document 1 (-2.1)
+ // matches; document 0 scores -0.5.
+ // The matched embeddings for each doc are:
+ // - document 1: -2.1 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 1: sum({-2.1}) = -2.1
+ search_spec.set_query("semanticSearch(getSearchSpecEmbedding(1), -10, -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(1));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-2.1, kEps));
+
+ // OR together semantic searches on both embedding queries, with no score
+ // range given, so that all hits from all documents match.
+ // The matched embeddings for each doc are:
+ // - document 0: -0.5 (embedding1), 0.3 (embedding2), -0.5 (embedding2)
+ // - document 1: -0.9 (embedding1), -2.1 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({-0.5, 0.3}) + sum({-0.5}) = -0.7
+ // - document 1: sum({-0.9}) + sum({-2.1}) = -3
+ search_spec.set_query(
+ "semanticSearch(getSearchSpecEmbedding(0)) OR "
+ "semanticSearch(getSearchSpecEmbedding(1))");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(0))) + "
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ EXPECT_THAT(results.results(0).score(), DoubleNear(-0.5 + 0.3 - 0.5, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-0.9 - 2.1, kEps));
+
+ // Hybrid query: document 0 matches through the term "foo" in its body, and
+ // document 1 matches through embedding-based search.
+ //
+ // The matched embeddings for each doc are:
+ // - document 1: -2.1 (embedding2)
+ // The scoring expression for each doc will be evaluated as:
+ // - document 0: sum({}) = 0
+ // - document 1: sum({-2.1}) = -2.1
+ search_spec.set_query(
+ "foo OR semanticSearch(getSearchSpecEmbedding(1), -10, -1)");
+ scoring_spec.set_advanced_scoring_expression(
+ "sum(this.matchedSemanticScores(getSearchSpecEmbedding(1)))");
+ results = icing.Search(search_spec, scoring_spec,
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoIsOk());
+ EXPECT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document0));
+ // Document 0 has no matched embedding hit, so its score is 0.
+ EXPECT_THAT(results.results(0).score(), DoubleNear(0, kEps));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document1));
+ EXPECT_THAT(results.results(1).score(), DoubleNear(-2.1, kEps));
+}
+
INSTANTIATE_TEST_SUITE_P(
IcingSearchEngineSearchTest, IcingSearchEngineSearchTest,
testing::Values(