aboutsummaryrefslogtreecommitdiff
path: root/tensorflow/core/kernels/checkpoint_callback_manager.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/checkpoint_callback_manager.cc')
-rw-r--r--tensorflow/core/kernels/checkpoint_callback_manager.cc34
1 files changed, 25 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/checkpoint_callback_manager.cc b/tensorflow/core/kernels/checkpoint_callback_manager.cc
index fb94c19dcda..0e0fae0f91d 100644
--- a/tensorflow/core/kernels/checkpoint_callback_manager.cc
+++ b/tensorflow/core/kernels/checkpoint_callback_manager.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/checkpoint_callback_manager.h"
+#include <regex>
#include <string>
#include <utility>
@@ -24,7 +25,8 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
-#include "tensorflow/core/platform/regexp.h"
+// Remove RE2 usage
+// #include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/stringpiece.h"
@@ -38,9 +40,9 @@ const absl::string_view kCheckpointCallbackManagerResourceName =
namespace {
-const absl::string_view kCheckpointFileRegex = "^part-[0-9]*-of-[0-9]*$";
-const absl::string_view kCheckpointTempDirRegex = "-[0-9]*_temp$";
-const absl::string_view kCheckpointDirRegex = "-[0-9]*$";
+const char* kCheckpointFileRegex = "^part-[0-9]*-of-[0-9]*$";
+const char* kCheckpointTempDirRegex = "-[0-9]*_temp$";
+const char* kCheckpointDirRegex = "-[0-9]*$";
const absl::string_view kCheckpointTempDirSuffix = "_temp";
void TriggerSaveCallbackIfFileNotExist(absl::string_view checkpoint_id,
@@ -115,17 +117,26 @@ StatusOr<std::pair<std::string, std::string>>
CheckpointCallbackManager::GetCheckpointIdAndPathFromPrefix(
absl::string_view prefix) {
for (absl::string_view path = prefix;; path = io::Dirname(path)) {
- absl::string_view basename = io::Basename(path);
+ std::string basename = std::string(io::Basename(path));
// Failed to find checkpoint_id
if (basename.empty()) break;
// Skip known checkpoint file: e.g., part-00000-of-00001
- if (RE2::PartialMatch(basename, kCheckpointFileRegex)) continue;
+ // if (RE2::PartialMatch(basename, kCheckpointFileRegex)) continue;
+ std::regex checkpoint_file_regex(kCheckpointFileRegex);
+ if (std::regex_search(basename, checkpoint_file_regex)) continue;
// With _temp suffix: e.g., checkpoint-1_temp
- if (RE2::PartialMatch(basename, kCheckpointTempDirRegex)) {
- // Trim suffix, "_temp".
+ // if (RE2::PartialMatch(basename, kCheckpointTempDirRegex)) {
+ // // Trim suffix, "_temp".
+ // return std::make_pair(
+ // std::string(basename.substr(
+ // 0, basename.length() - kCheckpointTempDirSuffix.length())),
+ // std::string(io::Dirname(path)));
+ // }
+ std::regex checkpoint_temp_dir_regex(kCheckpointTempDirRegex);
+ if (std::regex_search(basename, checkpoint_temp_dir_regex)) {
return std::make_pair(
std::string(basename.substr(
0, basename.length() - kCheckpointTempDirSuffix.length())),
@@ -133,7 +144,12 @@ CheckpointCallbackManager::GetCheckpointIdAndPathFromPrefix(
}
// Without _temp suffix: e.g., checkpoint-1
- if (RE2::PartialMatch(basename, kCheckpointDirRegex)) {
+ // if (RE2::PartialMatch(basename, kCheckpointDirRegex)) {
+ // return std::make_pair(std::string(basename),
+ // std::string(io::Dirname(path)));
+ // }
+ std::regex checkpoint_dir_regex(kCheckpointDirRegex);
+ if (std::regex_search(basename, checkpoint_dir_regex)) {
return std::make_pair(std::string(basename),
std::string(io::Dirname(path)));
}