summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-09-13 23:10:25 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-09-13 23:10:25 +0000
commitb48b247096f6e71491b8a5e0605bbefd72105aa4 (patch)
tree40aa418887f3eb8c3ad0b5f149e5ef1b46f5fb8e
parent2fca1ae9fd82c79128e10bbf8404d2079211cf53 (diff)
parentcf709f207d4af4c471ea39a0dff18733d35aa86f (diff)
downloadlibpalmrejection-b48b247096f6e71491b8a5e0605bbefd72105aa4.tar.gz
Snap for 9058783 from cf709f207d4af4c471ea39a0dff18733d35aa86f to tm-qpr1-release
Change-Id: Ibe40bb88cb4016e4003878e6f98bc4f498354c56
-rw-r--r--ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc60
1 files changed, 41 insertions, 19 deletions
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc
index 49c2a4e..0b301e4 100644
--- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc
+++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc
@@ -34,6 +34,35 @@ float EuclideanDistance(const gfx::PointF& a, const gfx::PointF& b) {
return (a - b).Length();
}
+bool IsEarlyStageSample(
+ const PalmFilterStroke& stroke,
+ const NeuralStylusPalmDetectionFilterModelConfig& config) {
+ return config.early_stage_sample_counts.find(stroke.samples_seen()) !=
+ config.early_stage_sample_counts.end();
+}
+
+bool HasDecidedStroke(
+ const PalmFilterStroke& stroke,
+ const NeuralStylusPalmDetectionFilterModelConfig& config) {
+ return stroke.samples_seen() >= config.max_sample_count;
+}
+
+bool IsVeryShortStroke(
+ const PalmFilterStroke& stroke,
+ const NeuralStylusPalmDetectionFilterModelConfig& config) {
+ return stroke.samples_seen() < config.min_sample_count;
+}
+
+/**
+ * The provided stroke must be a neighbor stroke rather than a stroke currently
+ * being evaluated. The parameter 'neighbor_min_sample_count' might be different
+ * from the config, depending on the specific usage in the caller.
+ */
+bool HasInsufficientDataAsNeighbor(const PalmFilterStroke& neighbor_stroke,
+ size_t neighbor_min_sample_count) {
+ return neighbor_stroke.samples().size() < neighbor_min_sample_count;
+}
+
} // namespace
NeuralStylusPalmDetectionFilter::NeuralStylusPalmDetectionFilter(
@@ -76,7 +105,7 @@ void NeuralStylusPalmDetectionFilter::FindBiggestNeighborsWithin(
if (neighbor.tracking_id() == stroke.tracking_id()) {
continue;
}
- if (neighbor.samples().size() < neighbor_min_sample_count) {
+ if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count)) {
continue;
}
float distance =
@@ -116,7 +145,7 @@ void NeuralStylusPalmDetectionFilter::FindNearestNeighborsWithin(
if (neighbor.tracking_id() == stroke.tracking_id()) {
continue;
}
- if (neighbor.samples().size() < neighbor_min_sample_count) {
+ if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count)) {
continue;
}
float distance =
@@ -195,7 +224,7 @@ void NeuralStylusPalmDetectionFilter::Filter(
PalmFilterStroke& stroke = stroke_it->second;
if (end_of_stroke) {
// This is a stroke that hasn't had a decision yet, so we force decide.
- if (stroke.samples().size() < config.max_sample_count) {
+ if (!HasDecidedStroke(stroke, config)) {
slots_to_decide.insert(slot);
}
@@ -215,8 +244,7 @@ void NeuralStylusPalmDetectionFilter::Filter(
// Heuristic delay detection.
if (config.heuristic_delay_start_if_palm && !end_of_stroke &&
- stroke.samples_seen() < config.max_sample_count &&
- IsHeuristicPalmStroke(stroke)) {
+ !HasDecidedStroke(stroke, config) && IsHeuristicPalmStroke(stroke)) {
// A stroke that we _think_ may be a palm, but is too short to decide
// yet. So we mark for delay for now.
is_delay_.set(slot, true);
@@ -224,8 +252,7 @@ void NeuralStylusPalmDetectionFilter::Filter(
// Early stage delay detection that marks suspicious palms for delay.
if (!is_delay_.test(slot) && config.nn_delay_start_if_palm &&
- config.early_stage_sample_counts.find(stroke.samples_seen()) !=
- config.early_stage_sample_counts.end()) {
+ IsEarlyStageSample(stroke, config)) {
VLOG(1) << "About to run a early_stage prediction.";
if (DetectSpuriousStroke(ExtractFeatures(tracking_id),
model_->config().output_threshold)) {
@@ -245,7 +272,7 @@ void NeuralStylusPalmDetectionFilter::Filter(
continue;
}
const auto& stroke = lookup->second;
- if (stroke.samples_seen() < model_->config().min_sample_count) {
+ if (IsVeryShortStroke(stroke, model_->config())) {
// in very short strokes: we use a heuristic.
is_palm_.set(slot, IsHeuristicPalmStroke(stroke));
continue;
@@ -272,23 +299,18 @@ void NeuralStylusPalmDetectionFilter::Filter(
bool NeuralStylusPalmDetectionFilter::ShouldDecideStroke(
const PalmFilterStroke& stroke) const {
const NeuralStylusPalmDetectionFilterModelConfig& config = model_->config();
- // Perform inference at most every |max_sample_count| samples.
- if (stroke.samples_seen() % config.max_sample_count != 0)
- return false;
-
- // Only inference at start.
- if (stroke.samples_seen() > config.max_sample_count)
- return false;
- return true;
+ // Inference only executed once per stroke
+ return stroke.samples_seen() == config.max_sample_count;
}
bool NeuralStylusPalmDetectionFilter::IsHeuristicPalmStroke(
const PalmFilterStroke& stroke) const {
- if (stroke.samples().size() >= model_->config().max_sample_count) {
+ const auto& config = model_->config();
+ if (stroke.samples().size() >= config.max_sample_count) {
LOG(DFATAL) << "Should not call this method on long strokes.";
return false;
}
- const auto& config = model_->config();
+
if (config.heuristic_palm_touch_limit > 0.0) {
if (stroke.MaxMajorRadius() >= config.heuristic_palm_touch_limit) {
VLOG(1) << "IsHeuristicPalm: Yes major radius.";
@@ -303,7 +325,7 @@ bool NeuralStylusPalmDetectionFilter::IsHeuristicPalmStroke(
std::vector<std::pair<float, int>> biggest_strokes;
FindBiggestNeighborsWithin(
1 /* neighbors */, 1 /* neighbor min sample count */,
- model_->config().max_neighbor_distance_in_mm, stroke, &biggest_strokes);
+ config.max_neighbor_distance_in_mm, stroke, &biggest_strokes);
if (!biggest_strokes.empty() &&
strokes_.find(biggest_strokes[0].second)->second.BiggestSize() >=
config.heuristic_palm_area_limit) {