diff options
Diffstat (limited to 'ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc')
-rw-r--r-- | ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc | 133 |
1 files changed, 120 insertions, 13 deletions
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc index 0b301e4..9a8e385 100644 --- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc +++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc @@ -37,20 +37,49 @@ float EuclideanDistance(const gfx::PointF& a, const gfx::PointF& b) { bool IsEarlyStageSample( const PalmFilterStroke& stroke, const NeuralStylusPalmDetectionFilterModelConfig& config) { - return config.early_stage_sample_counts.find(stroke.samples_seen()) != - config.early_stage_sample_counts.end(); + if (!config.resample_period) { + return config.early_stage_sample_counts.find(stroke.samples_seen()) != + config.early_stage_sample_counts.end(); + } + // Duration is not well-defined for sample_count <= 1, so we handle + // it separately. + if (stroke.samples().empty()) { + return false; + } + if (stroke.samples().size() == 1) { + return config.early_stage_sample_counts.find(1) != + config.early_stage_sample_counts.end(); + } + for (const uint32_t sample_count : config.early_stage_sample_counts) { + const base::TimeDelta duration = config.GetEquivalentDuration(sample_count); + // Previous sample must not have passed the 'duration' threshold, but the + // current sample must pass the threshold + if (stroke.LastSampleCrossed(duration)) { + return true; + } + } + return false; } bool HasDecidedStroke( const PalmFilterStroke& stroke, const NeuralStylusPalmDetectionFilterModelConfig& config) { - return stroke.samples_seen() >= config.max_sample_count; + if (!config.resample_period) { + return stroke.samples_seen() >= config.max_sample_count; + } + const base::TimeDelta max_duration = + config.GetEquivalentDuration(config.max_sample_count); + return stroke.Duration() >= max_duration; } bool IsVeryShortStroke( const PalmFilterStroke& stroke, const NeuralStylusPalmDetectionFilterModelConfig& config) { - return stroke.samples_seen() < config.min_sample_count; + if (!config.resample_period) { + return stroke.samples_seen() < config.min_sample_count; + } + return stroke.Duration() < + config.GetEquivalentDuration(config.min_sample_count); } /** @@ -58,9 +87,15 @@ bool IsVeryShortStroke( * being evaluated. The parameter 'neighbor_min_sample_count' might be different * from the config, depending on the specific usage in the caller. */ -bool HasInsufficientDataAsNeighbor(const PalmFilterStroke& neighbor_stroke, - size_t neighbor_min_sample_count) { - return neighbor_stroke.samples().size() < neighbor_min_sample_count; +bool HasInsufficientDataAsNeighbor( + const PalmFilterStroke& neighbor_stroke, + size_t neighbor_min_sample_count, + const NeuralStylusPalmDetectionFilterModelConfig& config) { + if (!config.resample_period) { + return neighbor_stroke.samples().size() < neighbor_min_sample_count; + } + return neighbor_stroke.Duration() < + config.GetEquivalentDuration(neighbor_min_sample_count); } } // namespace @@ -105,7 +140,8 @@ void NeuralStylusPalmDetectionFilter::FindBiggestNeighborsWithin( if (neighbor.tracking_id() == stroke.tracking_id()) { continue; } - if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count)) { + if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count, + model_->config())) { continue; } float distance = @@ -145,7 +181,8 @@ void NeuralStylusPalmDetectionFilter::FindNearestNeighborsWithin( if (neighbor.tracking_id() == stroke.tracking_id()) { continue; } - if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count)) { + if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count, + model_->config())) { continue; } float distance = @@ -300,15 +337,29 @@ bool NeuralStylusPalmDetectionFilter::ShouldDecideStroke( const PalmFilterStroke& stroke) const { const NeuralStylusPalmDetectionFilterModelConfig& config = model_->config(); // Inference only executed once per stroke - return stroke.samples_seen() == config.max_sample_count; + if (!config.resample_period) { + return stroke.samples_seen() == config.max_sample_count; + } + return stroke.LastSampleCrossed( + config.GetEquivalentDuration(config.max_sample_count)); } bool NeuralStylusPalmDetectionFilter::IsHeuristicPalmStroke( const PalmFilterStroke& stroke) const { const auto& config = model_->config(); - if (stroke.samples().size() >= config.max_sample_count) { - LOG(DFATAL) << "Should not call this method on long strokes."; - return false; + if (config.resample_period) { + if (stroke.Duration() > + config.GetEquivalentDuration(config.max_sample_count)) { + LOG(DFATAL) + << "Should not call this method on long strokes. Got duration = " + << stroke.Duration(); + return false; + } + } else { + if (stroke.samples().size() >= config.max_sample_count) { + LOG(DFATAL) << "Should not call this method on long strokes."; + return false; + } } if (config.heuristic_palm_touch_limit > 0.0) { @@ -401,6 +452,9 @@ std::vector<float> NeuralStylusPalmDetectionFilter::ExtractFeatures( void NeuralStylusPalmDetectionFilter::AppendFeatures( const PalmFilterStroke& stroke, std::vector<float>* features) const { + if (model_->config().resample_period) { + return AppendResampledFeatures(stroke, features); + } const int size = stroke.samples().size(); for (int i = 0; i < size; ++i) { const PalmFilterSample& sample = stroke.samples()[i]; @@ -435,6 +489,59 @@ void NeuralStylusPalmDetectionFilter::AppendFeatures( features->push_back(samples_seen - model_->config().max_sample_count); } } + +/** + * The flow here is similar to 'AppendFeatures' above, but we rely on the + * timing of the samples rather than on the explicit number / position of + * samples. + */ +void NeuralStylusPalmDetectionFilter::AppendResampledFeatures( + const PalmFilterStroke& stroke, + std::vector<float>* features) const { + size_t sample_count = 0; + const base::TimeTicks& first_time = stroke.samples()[0].time; + const base::TimeDelta& resample_period = *model_->config().resample_period; + const base::TimeDelta max_duration = + model_->config().GetEquivalentDuration(model_->config().max_sample_count); + for (auto time = first_time; (time - first_time) <= max_duration && + time <= stroke.samples().back().time; + time += resample_period) { + sample_count++; + const PalmFilterSample& sample = stroke.GetSampleAt(time); + features->push_back(sample.major_radius); + features->push_back(sample.minor_radius <= 0.0 ? sample.major_radius + : sample.minor_radius); + float distance = 0; + if (time != first_time) { + distance = EuclideanDistance( + stroke.GetSampleAt(time - resample_period).point, sample.point); + } + features->push_back(distance); + features->push_back(sample.edge); + features->push_back(1.0); // existence. + } + const int padding = model_->config().max_sample_count - sample_count; + DCHECK_GE(padding, 0); + + for (int i = 0; i < padding * kFeaturesPerSample; ++i) { + features->push_back(0.0); + } + // "fill proportion." + features->push_back(static_cast<float>(sample_count) / + model_->config().max_sample_count); + features->push_back(EuclideanDistance(stroke.samples().front().point, + stroke.samples().back().point)); + + // Start sequence number. 0 is min. + uint32_t samples_seen = + (stroke.Duration() / (*model_->config().resample_period)) + 1; + if (samples_seen < model_->config().max_sample_count) { + features->push_back(0); + } else { + features->push_back(samples_seen - model_->config().max_sample_count); + } +} + void NeuralStylusPalmDetectionFilter::AppendFeaturesAsNeighbor( const PalmFilterStroke& stroke, float distance, |