aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSamuel Freilich <sfreilich@google.com>2023-03-21 07:58:56 -0700
committerCopybara-Service <copybara-worker@google.com>2023-03-21 07:59:29 -0700
commitab17d1788b46499c5d3cb51061aed2dd1b119b80 (patch)
tree324ae0b5296110943a649be1223bf01f0c8dab14
parent41ef7406b8f6aa14844b7dda1fb00328cd723d1c (diff)
downloadink-stroke-modeler-ab17d1788b46499c5d3cb51061aed2dd1b119b80.tar.gz
Simplify initialization of predictor
Just use std::holds_alternative instead of std::visit followed by something convoluted with if constexpr. PiperOrigin-RevId: 518272650
-rw-r--r--ink_stroke_modeler/BUILD.bazel1
-rw-r--r--ink_stroke_modeler/CMakeLists.txt2
-rw-r--r--ink_stroke_modeler/stroke_modeler.cc41
3 files changed, 16 insertions, 28 deletions
diff --git a/ink_stroke_modeler/BUILD.bazel b/ink_stroke_modeler/BUILD.bazel
index 1172f96..0896246 100644
--- a/ink_stroke_modeler/BUILD.bazel
+++ b/ink_stroke_modeler/BUILD.bazel
@@ -76,7 +76,6 @@ cc_library(
"//ink_stroke_modeler/internal/prediction:input_predictor",
"//ink_stroke_modeler/internal/prediction:kalman_predictor",
"//ink_stroke_modeler/internal/prediction:stroke_end_predictor",
- "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
diff --git a/ink_stroke_modeler/CMakeLists.txt b/ink_stroke_modeler/CMakeLists.txt
index c8712e2..9392be9 100644
--- a/ink_stroke_modeler/CMakeLists.txt
+++ b/ink_stroke_modeler/CMakeLists.txt
@@ -79,11 +79,9 @@ ink_cc_library(
InkStrokeModeler::input_predictor
InkStrokeModeler::kalman_predictor
InkStrokeModeler::stroke_end_predictor
- absl::core_headers
absl::status
absl::statusor
absl::strings
- absl::variant
)
ink_cc_test(
diff --git a/ink_stroke_modeler/stroke_modeler.cc b/ink_stroke_modeler/stroke_modeler.cc
index e5494f0..ae6a28d 100644
--- a/ink_stroke_modeler/stroke_modeler.cc
+++ b/ink_stroke_modeler/stroke_modeler.cc
@@ -17,11 +17,9 @@
#include <iterator>
#include <memory>
#include <optional>
-#include <type_traits>
#include <variant>
#include <vector>
-#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
@@ -67,9 +65,6 @@ absl::StatusOr<int> GetNumberOfSteps(Time start_time, Time end_time,
return n_steps;
}
-template <typename>
-ABSL_ATTRIBUTE_UNUSED inline constexpr bool kAlwaysFalse = false;
-
} // namespace
absl::Status StrokeModeler::Reset(
@@ -85,26 +80,22 @@ absl::Status StrokeModeler::Reset(
stroke_model_params_ = stroke_model_params;
ResetInternal();
- std::visit(
- [this](auto &&params) {
- using ParamType = std::decay_t<decltype(params)>;
- if constexpr (std::is_same_v<ParamType, KalmanPredictorParams>) {
- predictor_ = std::make_unique<KalmanPredictor>(
- params, stroke_model_params_->sampling_params);
- } else if constexpr (std::is_same_v<ParamType,
- StrokeEndPredictorParams>) {
- predictor_ = std::make_unique<StrokeEndPredictor>(
- stroke_model_params_->position_modeler_params,
- stroke_model_params_->sampling_params);
- } else if constexpr (std::is_same_v<ParamType,
- DisabledPredictorParams>) {
- predictor_ = nullptr;
- } else {
- static_assert(kAlwaysFalse<ParamType>,
- "Unknown prediction parameter type");
- }
- },
- stroke_model_params_->prediction_params);
+ const PredictionParams &prediction_params =
+ stroke_model_params_->prediction_params;
+ static_assert(std::variant_size_v<PredictionParams> == 3);
+ if (std::holds_alternative<KalmanPredictorParams>(prediction_params)) {
+ predictor_ = std::make_unique<KalmanPredictor>(
+ std::get<KalmanPredictorParams>(prediction_params),
+ stroke_model_params_->sampling_params);
+ } else if (std::holds_alternative<StrokeEndPredictorParams>(
+ prediction_params)) {
+ predictor_ = std::make_unique<StrokeEndPredictor>(
+ stroke_model_params_->position_modeler_params,
+ stroke_model_params_->sampling_params);
+ } else if (std::holds_alternative<DisabledPredictorParams>(
+ prediction_params)) {
+ predictor_ = nullptr;
+ }
return absl::OkStatus();
}