diff options
Diffstat (limited to 'tensorflow/core/kernels/checkpoint_callback_manager.cc')
-rw-r--r-- | tensorflow/core/kernels/checkpoint_callback_manager.cc | 34 |
1 files changed, 25 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/checkpoint_callback_manager.cc b/tensorflow/core/kernels/checkpoint_callback_manager.cc index fb94c19dcda..0e0fae0f91d 100644 --- a/tensorflow/core/kernels/checkpoint_callback_manager.cc +++ b/tensorflow/core/kernels/checkpoint_callback_manager.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/checkpoint_callback_manager.h" +#include <regex> #include <string> #include <utility> @@ -24,7 +25,8 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/path.h" -#include "tensorflow/core/platform/regexp.h" +// Remove RE2 usage +// #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/platform/stringpiece.h" @@ -38,9 +40,9 @@ const absl::string_view kCheckpointCallbackManagerResourceName = namespace { -const absl::string_view kCheckpointFileRegex = "^part-[0-9]*-of-[0-9]*$"; -const absl::string_view kCheckpointTempDirRegex = "-[0-9]*_temp$"; -const absl::string_view kCheckpointDirRegex = "-[0-9]*$"; +const char* kCheckpointFileRegex = "^part-[0-9]*-of-[0-9]*$"; +const char* kCheckpointTempDirRegex = "-[0-9]*_temp$"; +const char* kCheckpointDirRegex = "-[0-9]*$"; const absl::string_view kCheckpointTempDirSuffix = "_temp"; void TriggerSaveCallbackIfFileNotExist(absl::string_view checkpoint_id, @@ -115,17 +117,26 @@ StatusOr<std::pair<std::string, std::string>> CheckpointCallbackManager::GetCheckpointIdAndPathFromPrefix( absl::string_view prefix) { for (absl::string_view path = prefix;; path = io::Dirname(path)) { - absl::string_view basename = io::Basename(path); + std::string basename = std::string(io::Basename(path)); // Failed to find checkpoint_id if (basename.empty()) break; // Skip known checkpoint file: e.g., part-00000-of-00001 - if (RE2::PartialMatch(basename, kCheckpointFileRegex)) continue; + // if (RE2::PartialMatch(basename, kCheckpointFileRegex)) continue; + std::regex checkpoint_file_regex(kCheckpointFileRegex); + if (std::regex_search(basename, checkpoint_file_regex)) continue; // With _temp suffix: e.g., checkpoint-1_temp - if (RE2::PartialMatch(basename, kCheckpointTempDirRegex)) { - // Trim suffix, "_temp". + // if (RE2::PartialMatch(basename, kCheckpointTempDirRegex)) { + // // Trim suffix, "_temp". + // return std::make_pair( + // std::string(basename.substr( + // 0, basename.length() - kCheckpointTempDirSuffix.length())), + // std::string(io::Dirname(path))); + // } + std::regex checkpoint_temp_dir_regex(kCheckpointTempDirRegex); + if (std::regex_search(basename, checkpoint_temp_dir_regex)) { return std::make_pair( std::string(basename.substr( 0, basename.length() - kCheckpointTempDirSuffix.length())), @@ -133,7 +144,12 @@ CheckpointCallbackManager::GetCheckpointIdAndPathFromPrefix( } // Without _temp suffix: e.g., checkpoint-1 - if (RE2::PartialMatch(basename, kCheckpointDirRegex)) { + // if (RE2::PartialMatch(basename, kCheckpointDirRegex)) { + // return std::make_pair(std::string(basename), + // std::string(io::Dirname(path))); + // } + std::regex checkpoint_dir_regex(kCheckpointDirRegex); + if (std::regex_search(basename, checkpoint_dir_regex)) { return std::make_pair(std::string(basename), std::string(io::Dirname(path))); } |