aboutsummaryrefslogtreecommitdiff
path: root/icing/index/embed/embedding-index.cc
diff options
context:
space:
mode:
Diffstat (limited to 'icing/index/embed/embedding-index.cc')
-rw-r--r--icing/index/embed/embedding-index.cc440
1 files changed, 440 insertions, 0 deletions
diff --git a/icing/index/embed/embedding-index.cc b/icing/index/embed/embedding-index.cc
new file mode 100644
index 0000000..2e70f0e
--- /dev/null
+++ b/icing/index/embed/embedding-index.cc
@@ -0,0 +1,440 @@
+// Copyright (C) 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/embed/embedding-index.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/file/destructible-directory.h"
+#include "icing/file/file-backed-vector.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/memory-mapped-file.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/index/embed/embedding-hit.h"
+#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
+#include "icing/index/hit/hit.h"
+#include "icing/store/document-id.h"
+#include "icing/store/dynamic-trie-key-mapper.h"
+#include "icing/store/key-mapper.h"
+#include "icing/util/crc32.h"
+#include "icing/util/encode-util.h"
+#include "icing/util/logging.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+constexpr uint32_t kEmbeddingHitListMapperMaxSize =
+ 128 * 1024 * 1024; // 128 MiB;
+
+// The maximum length returned by encode_util::EncodeIntToCString is 5 for
+// uint32_t.
+constexpr uint32_t kEncodedDimensionLength = 5;
+
+std::string GetMetadataFilePath(std::string_view working_path) {
+ return absl_ports::StrCat(working_path, "/metadata");
+}
+
+std::string GetFlashIndexStorageFilePath(std::string_view working_path) {
+ return absl_ports::StrCat(working_path, "/flash_index_storage");
+}
+
+std::string GetEmbeddingHitListMapperPath(std::string_view working_path) {
+ return absl_ports::StrCat(working_path, "/embedding_hit_list_mapper");
+}
+
+std::string GetEmbeddingVectorsFilePath(std::string_view working_path) {
+ return absl_ports::StrCat(working_path, "/embedding_vectors");
+}
+
+// An injective function that maps the ordered pair (dimension, model_signature)
+// to a string, which is used to form a key for embedding_posting_list_mapper_.
+std::string GetPostingListKey(uint32_t dimension,
+ std::string_view model_signature) {
+ std::string encoded_dimension_str =
+ encode_util::EncodeIntToCString(dimension);
+ // Make encoded_dimension_str to fixed kEncodedDimensionLength bytes.
+ while (encoded_dimension_str.size() < kEncodedDimensionLength) {
+ // C string cannot contain 0 bytes, so we append it using 1, just like what
+ // we do in encode_util::EncodeIntToCString.
+ //
+ // The reason that this works is because DecodeIntToString decodes a byte
+ // value of 0x01 as 0x00. When EncodeIntToCString returns an encoded
+ // dimension that is less than 5 bytes, it means that the dimension contains
+ // unencoded leading 0x00. So here we're explicitly encoding those bytes as
+ // 0x01.
+ encoded_dimension_str.push_back(1);
+ }
+ return absl_ports::StrCat(encoded_dimension_str, model_signature);
+}
+
+std::string GetPostingListKey(const PropertyProto::VectorProto& vector) {
+ return GetPostingListKey(vector.values_size(), vector.model_signature());
+}
+
+} // namespace
+
+libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>>
+EmbeddingIndex::Create(const Filesystem* filesystem, std::string working_path) {
+ ICING_RETURN_ERROR_IF_NULL(filesystem);
+
+ std::unique_ptr<EmbeddingIndex> index = std::unique_ptr<EmbeddingIndex>(
+ new EmbeddingIndex(*filesystem, std::move(working_path)));
+ ICING_RETURN_IF_ERROR(index->Initialize());
+ return index;
+}
+
+libtextclassifier3::Status EmbeddingIndex::Initialize() {
+ bool is_new = false;
+ if (!filesystem_.FileExists(GetMetadataFilePath(working_path_).c_str())) {
+ // Create working directory.
+ if (!filesystem_.CreateDirectoryRecursively(working_path_.c_str())) {
+ return absl_ports::InternalError(
+ absl_ports::StrCat("Failed to create directory: ", working_path_));
+ }
+ is_new = true;
+ }
+
+ ICING_ASSIGN_OR_RETURN(
+ MemoryMappedFile metadata_mmapped_file,
+ MemoryMappedFile::Create(filesystem_, GetMetadataFilePath(working_path_),
+ MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC,
+ /*max_file_size=*/kMetadataFileSize,
+ /*pre_mapping_file_offset=*/0,
+ /*pre_mapping_mmap_size=*/kMetadataFileSize));
+ metadata_mmapped_file_ =
+ std::make_unique<MemoryMappedFile>(std::move(metadata_mmapped_file));
+
+ ICING_ASSIGN_OR_RETURN(FlashIndexStorage flash_index_storage,
+ FlashIndexStorage::Create(
+ GetFlashIndexStorageFilePath(working_path_),
+ &filesystem_, posting_list_hit_serializer_.get()));
+ flash_index_storage_ =
+ std::make_unique<FlashIndexStorage>(std::move(flash_index_storage));
+
+ ICING_ASSIGN_OR_RETURN(
+ embedding_posting_list_mapper_,
+ DynamicTrieKeyMapper<PostingListIdentifier>::Create(
+ filesystem_, GetEmbeddingHitListMapperPath(working_path_),
+ kEmbeddingHitListMapperMaxSize));
+
+ ICING_ASSIGN_OR_RETURN(
+ embedding_vectors_,
+ FileBackedVector<float>::Create(
+ filesystem_, GetEmbeddingVectorsFilePath(working_path_),
+ MemoryMappedFile::READ_WRITE_AUTO_SYNC));
+
+ if (is_new) {
+ ICING_RETURN_IF_ERROR(metadata_mmapped_file_->GrowAndRemapIfNecessary(
+ /*file_offset=*/0, /*mmap_size=*/kMetadataFileSize));
+ info().magic = Info::kMagic;
+ info().last_added_document_id = kInvalidDocumentId;
+ ICING_RETURN_IF_ERROR(InitializeNewStorage());
+ } else {
+ if (metadata_mmapped_file_->available_size() != kMetadataFileSize) {
+ return absl_ports::FailedPreconditionError(
+ "Incorrect metadata file size");
+ }
+ if (info().magic != Info::kMagic) {
+ return absl_ports::FailedPreconditionError("Incorrect magic value");
+ }
+ ICING_RETURN_IF_ERROR(InitializeExistingStorage());
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status EmbeddingIndex::Clear() {
+ pending_embedding_hits_.clear();
+ metadata_mmapped_file_.reset();
+ flash_index_storage_.reset();
+ embedding_posting_list_mapper_.reset();
+ embedding_vectors_.reset();
+ if (filesystem_.DirectoryExists(working_path_.c_str())) {
+ ICING_RETURN_IF_ERROR(Discard(filesystem_, working_path_));
+ }
+ is_initialized_ = false;
+ return Initialize();
+}
+
+libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
+EmbeddingIndex::GetAccessor(uint32_t dimension,
+ std::string_view model_signature) const {
+ if (dimension == 0) {
+ return absl_ports::InvalidArgumentError("Dimension is 0");
+ }
+ std::string key = GetPostingListKey(dimension, model_signature);
+ ICING_ASSIGN_OR_RETURN(PostingListIdentifier posting_list_id,
+ embedding_posting_list_mapper_->Get(key));
+ return PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), posting_list_hit_serializer_.get(),
+ posting_list_id);
+}
+
+libtextclassifier3::Status EmbeddingIndex::BufferEmbedding(
+ const BasicHit& basic_hit, const PropertyProto::VectorProto& vector) {
+ if (vector.values_size() == 0) {
+ return absl_ports::InvalidArgumentError("Vector dimension is 0");
+ }
+
+ uint32_t location = embedding_vectors_->num_elements();
+ uint32_t dimension = vector.values_size();
+ std::string key = GetPostingListKey(vector);
+
+ // Buffer the embedding hit.
+ pending_embedding_hits_.push_back(
+ {std::move(key), EmbeddingHit(basic_hit, location)});
+
+ // Put vector
+ ICING_ASSIGN_OR_RETURN(FileBackedVector<float>::MutableArrayView mutable_arr,
+ embedding_vectors_->Allocate(dimension));
+ mutable_arr.SetArray(/*idx=*/0, vector.values().data(), dimension);
+
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status EmbeddingIndex::CommitBufferToIndex() {
+ std::sort(pending_embedding_hits_.begin(), pending_embedding_hits_.end());
+ auto iter_curr_key = pending_embedding_hits_.rbegin();
+ while (iter_curr_key != pending_embedding_hits_.rend()) {
+ // In order to batch putting embedding hits with the same key (dimension,
+ // model_signature) to the same posting list, we find the range
+ // [iter_curr_key, iter_next_key) of embedding hits with the same key and
+ // put them into their corresponding posting list together.
+ auto iter_next_key = iter_curr_key;
+ while (iter_next_key != pending_embedding_hits_.rend() &&
+ iter_next_key->first == iter_curr_key->first) {
+ iter_next_key++;
+ }
+
+ const std::string& key = iter_curr_key->first;
+ libtextclassifier3::StatusOr<PostingListIdentifier> posting_list_id_or =
+ embedding_posting_list_mapper_->Get(key);
+ std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
+ if (posting_list_id_or.ok()) {
+ // Existing posting list.
+ ICING_ASSIGN_OR_RETURN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), posting_list_hit_serializer_.get(),
+ posting_list_id_or.ValueOrDie()));
+ } else if (absl_ports::IsNotFound(posting_list_id_or.status())) {
+ // New posting list.
+ ICING_ASSIGN_OR_RETURN(
+ pl_accessor,
+ PostingListEmbeddingHitAccessor::Create(
+ flash_index_storage_.get(), posting_list_hit_serializer_.get()));
+ } else {
+ // Errors
+ return std::move(posting_list_id_or).status();
+ }
+
+ // Adding the embedding hits.
+ for (auto iter = iter_curr_key; iter != iter_next_key; ++iter) {
+ ICING_RETURN_IF_ERROR(pl_accessor->PrependHit(iter->second));
+ }
+
+ // Finalize this posting list and add the posting list id in
+ // embedding_posting_list_mapper_.
+ PostingListEmbeddingHitAccessor::FinalizeResult result =
+ std::move(*pl_accessor).Finalize();
+ if (!result.id.is_valid()) {
+ return absl_ports::InternalError("Failed to finalize posting list");
+ }
+ ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->Put(key, result.id));
+
+ // Advance to the next key.
+ iter_curr_key = iter_next_key;
+ }
+ pending_embedding_hits_.clear();
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status EmbeddingIndex::TransferIndex(
+ const std::vector<DocumentId>& document_id_old_to_new,
+ EmbeddingIndex* new_index) const {
+ std::unique_ptr<KeyMapper<PostingListIdentifier>::Iterator> itr =
+ embedding_posting_list_mapper_->GetIterator();
+ while (itr->Advance()) {
+ std::string_view key = itr->GetKey();
+ // This should never happen unless there is an inconsistency, or the index
+ // is corrupted.
+ if (key.size() < kEncodedDimensionLength) {
+ return absl_ports::InternalError(
+ "Got invalid key from embedding posting list mapper.");
+ }
+ uint32_t dimension = encode_util::DecodeIntFromCString(
+ std::string_view(key.begin(), kEncodedDimensionLength));
+
+ // Transfer hits
+ std::vector<EmbeddingHit> new_hits;
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> old_pl_accessor,
+ PostingListEmbeddingHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), posting_list_hit_serializer_.get(),
+ /*existing_posting_list_id=*/itr->GetValue()));
+ while (true) {
+ ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
+ old_pl_accessor->GetNextHitsBatch());
+ if (batch.empty()) {
+ break;
+ }
+ for (EmbeddingHit& old_hit : batch) {
+ // Safety checks to add robustness to the codebase, so to make sure
+ // that we never access invalid memory, in case that hit from the
+ // posting list is corrupted.
+ ICING_ASSIGN_OR_RETURN(const float* old_vector,
+ GetEmbeddingVector(old_hit, dimension));
+ if (old_hit.basic_hit().document_id() < 0 ||
+ old_hit.basic_hit().document_id() >=
+ document_id_old_to_new.size()) {
+ return absl_ports::InternalError(
+ "Embedding hit document id is out of bound. The provided map is "
+ "too small, or the index may have been corrupted.");
+ }
+
+ // Construct transferred hit
+ DocumentId new_document_id =
+ document_id_old_to_new[old_hit.basic_hit().document_id()];
+ if (new_document_id == kInvalidDocumentId) {
+ continue;
+ }
+ uint32_t new_location = new_index->embedding_vectors_->num_elements();
+ new_hits.push_back(EmbeddingHit(
+ BasicHit(old_hit.basic_hit().section_id(), new_document_id),
+ new_location));
+
+ // Copy the embedding vector of the hit to the new index.
+ ICING_ASSIGN_OR_RETURN(
+ FileBackedVector<float>::MutableArrayView mutable_arr,
+ new_index->embedding_vectors_->Allocate(dimension));
+ mutable_arr.SetArray(/*idx=*/0, old_vector, dimension);
+ }
+ }
+ // No hit needs to be added to the new index.
+ if (new_hits.empty()) {
+ return libtextclassifier3::Status::OK;
+ }
+ // Add transferred hits to the new index.
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<PostingListEmbeddingHitAccessor> hit_accum,
+ PostingListEmbeddingHitAccessor::Create(
+ new_index->flash_index_storage_.get(),
+ new_index->posting_list_hit_serializer_.get()));
+ for (auto new_hit_itr = new_hits.rbegin(); new_hit_itr != new_hits.rend();
+ ++new_hit_itr) {
+ ICING_RETURN_IF_ERROR(hit_accum->PrependHit(*new_hit_itr));
+ }
+ PostingListEmbeddingHitAccessor::FinalizeResult result =
+ std::move(*hit_accum).Finalize();
+ if (!result.id.is_valid()) {
+ return absl_ports::InternalError("Failed to finalize posting list");
+ }
+ ICING_RETURN_IF_ERROR(
+ new_index->embedding_posting_list_mapper_->Put(key, result.id));
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status EmbeddingIndex::Optimize(
+ const std::vector<DocumentId>& document_id_old_to_new,
+ DocumentId new_last_added_document_id) {
+ // This is just for completeness, but this should never be necessary, since we
+ // should never have pending hits at the time when Optimize is run.
+ ICING_RETURN_IF_ERROR(CommitBufferToIndex());
+
+ std::string temporary_index_working_path = working_path_ + "_temp";
+ if (!filesystem_.DeleteDirectoryRecursively(
+ temporary_index_working_path.c_str())) {
+ ICING_LOG(ERROR) << "Recursively deleting " << temporary_index_working_path;
+ return absl_ports::InternalError(
+ "Unable to delete temp directory to prepare to build new index.");
+ }
+
+ DestructibleDirectory temporary_index_dir(
+ &filesystem_, std::move(temporary_index_working_path));
+ if (!temporary_index_dir.is_valid()) {
+ return absl_ports::InternalError(
+ "Unable to create temp directory to build new index.");
+ }
+
+ {
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<EmbeddingIndex> new_index,
+ EmbeddingIndex::Create(&filesystem_, temporary_index_dir.dir()));
+ ICING_RETURN_IF_ERROR(
+ TransferIndex(document_id_old_to_new, new_index.get()));
+ new_index->set_last_added_document_id(new_last_added_document_id);
+ ICING_RETURN_IF_ERROR(new_index->PersistToDisk());
+ }
+
+ // Destruct current storage instances to safely swap directories.
+ metadata_mmapped_file_.reset();
+ flash_index_storage_.reset();
+ embedding_posting_list_mapper_.reset();
+ embedding_vectors_.reset();
+
+ if (!filesystem_.SwapFiles(temporary_index_dir.dir().c_str(),
+ working_path_.c_str())) {
+ return absl_ports::InternalError(
+ "Unable to apply new index due to failed swap!");
+ }
+
+ // Reinitialize the index.
+ is_initialized_ = false;
+ return Initialize();
+}
+
+libtextclassifier3::Status EmbeddingIndex::PersistMetadataToDisk(bool force) {
+ return metadata_mmapped_file_->PersistToDisk();
+}
+
+libtextclassifier3::Status EmbeddingIndex::PersistStoragesToDisk(bool force) {
+ if (!flash_index_storage_->PersistToDisk()) {
+ return absl_ports::InternalError("Fail to persist flash index to disk");
+ }
+ ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->PersistToDisk());
+ ICING_RETURN_IF_ERROR(embedding_vectors_->PersistToDisk());
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeInfoChecksum(
+ bool force) {
+ return info().ComputeChecksum();
+}
+
+libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeStoragesChecksum(
+ bool force) {
+ ICING_ASSIGN_OR_RETURN(Crc32 embedding_posting_list_mapper_crc,
+ embedding_posting_list_mapper_->ComputeChecksum());
+ ICING_ASSIGN_OR_RETURN(Crc32 embedding_vectors_crc,
+ embedding_vectors_->ComputeChecksum());
+ return Crc32(embedding_posting_list_mapper_crc.Get() ^
+ embedding_vectors_crc.Get());
+}
+
+} // namespace lib
+} // namespace icing