summaryrefslogtreecommitdiff
path: root/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc')
-rw-r--r--ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc133
1 files changed, 120 insertions, 13 deletions
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc
index 0b301e4..9a8e385 100644
--- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc
+++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc
@@ -37,20 +37,49 @@ float EuclideanDistance(const gfx::PointF& a, const gfx::PointF& b) {
bool IsEarlyStageSample(
const PalmFilterStroke& stroke,
const NeuralStylusPalmDetectionFilterModelConfig& config) {
- return config.early_stage_sample_counts.find(stroke.samples_seen()) !=
- config.early_stage_sample_counts.end();
+ if (!config.resample_period) {
+ return config.early_stage_sample_counts.find(stroke.samples_seen()) !=
+ config.early_stage_sample_counts.end();
+ }
+ // Duration is not well-defined for sample_count <= 1, so we handle
+ // it separately.
+ if (stroke.samples().empty()) {
+ return false;
+ }
+ if (stroke.samples().size() == 1) {
+ return config.early_stage_sample_counts.find(1) !=
+ config.early_stage_sample_counts.end();
+ }
+ for (const uint32_t sample_count : config.early_stage_sample_counts) {
+ const base::TimeDelta duration = config.GetEquivalentDuration(sample_count);
+ // Previous sample must not have passed the 'duration' threshold, but the
+ // current sample must pass the threshold
+ if (stroke.LastSampleCrossed(duration)) {
+ return true;
+ }
+ }
+ return false;
}
bool HasDecidedStroke(
const PalmFilterStroke& stroke,
const NeuralStylusPalmDetectionFilterModelConfig& config) {
- return stroke.samples_seen() >= config.max_sample_count;
+ if (!config.resample_period) {
+ return stroke.samples_seen() >= config.max_sample_count;
+ }
+ const base::TimeDelta max_duration =
+ config.GetEquivalentDuration(config.max_sample_count);
+ return stroke.Duration() >= max_duration;
}
bool IsVeryShortStroke(
const PalmFilterStroke& stroke,
const NeuralStylusPalmDetectionFilterModelConfig& config) {
- return stroke.samples_seen() < config.min_sample_count;
+ if (!config.resample_period) {
+ return stroke.samples_seen() < config.min_sample_count;
+ }
+ return stroke.Duration() <
+ config.GetEquivalentDuration(config.min_sample_count);
}
/**
@@ -58,9 +87,15 @@ bool IsVeryShortStroke(
* being evaluated. The parameter 'neighbor_min_sample_count' might be different
* from the config, depending on the specific usage in the caller.
*/
-bool HasInsufficientDataAsNeighbor(const PalmFilterStroke& neighbor_stroke,
- size_t neighbor_min_sample_count) {
- return neighbor_stroke.samples().size() < neighbor_min_sample_count;
+bool HasInsufficientDataAsNeighbor(
+ const PalmFilterStroke& neighbor_stroke,
+ size_t neighbor_min_sample_count,
+ const NeuralStylusPalmDetectionFilterModelConfig& config) {
+ if (!config.resample_period) {
+ return neighbor_stroke.samples().size() < neighbor_min_sample_count;
+ }
+ return neighbor_stroke.Duration() <
+ config.GetEquivalentDuration(neighbor_min_sample_count);
}
} // namespace
@@ -105,7 +140,8 @@ void NeuralStylusPalmDetectionFilter::FindBiggestNeighborsWithin(
if (neighbor.tracking_id() == stroke.tracking_id()) {
continue;
}
- if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count)) {
+ if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count,
+ model_->config())) {
continue;
}
float distance =
@@ -145,7 +181,8 @@ void NeuralStylusPalmDetectionFilter::FindNearestNeighborsWithin(
if (neighbor.tracking_id() == stroke.tracking_id()) {
continue;
}
- if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count)) {
+ if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count,
+ model_->config())) {
continue;
}
float distance =
@@ -300,15 +337,29 @@ bool NeuralStylusPalmDetectionFilter::ShouldDecideStroke(
const PalmFilterStroke& stroke) const {
const NeuralStylusPalmDetectionFilterModelConfig& config = model_->config();
// Inference only executed once per stroke
- return stroke.samples_seen() == config.max_sample_count;
+ if (!config.resample_period) {
+ return stroke.samples_seen() == config.max_sample_count;
+ }
+ return stroke.LastSampleCrossed(
+ config.GetEquivalentDuration(config.max_sample_count));
}
bool NeuralStylusPalmDetectionFilter::IsHeuristicPalmStroke(
const PalmFilterStroke& stroke) const {
const auto& config = model_->config();
- if (stroke.samples().size() >= config.max_sample_count) {
- LOG(DFATAL) << "Should not call this method on long strokes.";
- return false;
+ if (config.resample_period) {
+ if (stroke.Duration() >
+ config.GetEquivalentDuration(config.max_sample_count)) {
+ LOG(DFATAL)
+ << "Should not call this method on long strokes. Got duration = "
+ << stroke.Duration();
+ return false;
+ }
+ } else {
+ if (stroke.samples().size() >= config.max_sample_count) {
+ LOG(DFATAL) << "Should not call this method on long strokes.";
+ return false;
+ }
}
if (config.heuristic_palm_touch_limit > 0.0) {
@@ -401,6 +452,9 @@ std::vector<float> NeuralStylusPalmDetectionFilter::ExtractFeatures(
void NeuralStylusPalmDetectionFilter::AppendFeatures(
const PalmFilterStroke& stroke,
std::vector<float>* features) const {
+ if (model_->config().resample_period) {
+ return AppendResampledFeatures(stroke, features);
+ }
const int size = stroke.samples().size();
for (int i = 0; i < size; ++i) {
const PalmFilterSample& sample = stroke.samples()[i];
@@ -435,6 +489,59 @@ void NeuralStylusPalmDetectionFilter::AppendFeatures(
features->push_back(samples_seen - model_->config().max_sample_count);
}
}
+
+/**
+ * The flow here is similar to 'AppendFeatures' above, but we rely on the
+ * timing of the samples rather than on the explicit number / position of
+ * samples.
+ */
+void NeuralStylusPalmDetectionFilter::AppendResampledFeatures(
+ const PalmFilterStroke& stroke,
+ std::vector<float>* features) const {
+ size_t sample_count = 0;
+ const base::TimeTicks& first_time = stroke.samples()[0].time;
+ const base::TimeDelta& resample_period = *model_->config().resample_period;
+ const base::TimeDelta max_duration =
+ model_->config().GetEquivalentDuration(model_->config().max_sample_count);
+ for (auto time = first_time; (time - first_time) <= max_duration &&
+ time <= stroke.samples().back().time;
+ time += resample_period) {
+ sample_count++;
+ const PalmFilterSample& sample = stroke.GetSampleAt(time);
+ features->push_back(sample.major_radius);
+ features->push_back(sample.minor_radius <= 0.0 ? sample.major_radius
+ : sample.minor_radius);
+ float distance = 0;
+ if (time != first_time) {
+ distance = EuclideanDistance(
+ stroke.GetSampleAt(time - resample_period).point, sample.point);
+ }
+ features->push_back(distance);
+ features->push_back(sample.edge);
+ features->push_back(1.0); // existence.
+ }
+ const int padding = model_->config().max_sample_count - sample_count;
+ DCHECK_GE(padding, 0);
+
+ for (int i = 0; i < padding * kFeaturesPerSample; ++i) {
+ features->push_back(0.0);
+ }
+ // "fill proportion."
+ features->push_back(static_cast<float>(sample_count) /
+ model_->config().max_sample_count);
+ features->push_back(EuclideanDistance(stroke.samples().front().point,
+ stroke.samples().back().point));
+
+ // Start sequence number. 0 is min.
+ uint32_t samples_seen =
+ (stroke.Duration() / (*model_->config().resample_period)) + 1;
+ if (samples_seen < model_->config().max_sample_count) {
+ features->push_back(0);
+ } else {
+ features->push_back(samples_seen - model_->config().max_sample_count);
+ }
+}
+
void NeuralStylusPalmDetectionFilter::AppendFeaturesAsNeighbor(
const PalmFilterStroke& stroke,
float distance,