aboutsummaryrefslogtreecommitdiff
path: root/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
diff options
context:
space:
mode:
Diffstat (limited to 'libgav1/src/dsp/arm/distance_weighted_blend_neon.cc')
-rw-r--r-- libgav1/src/dsp/arm/distance_weighted_blend_neon.cc 105
1 files changed, 41 insertions, 64 deletions
diff --git a/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc b/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
index 7d287c8..6087276 100644
--- a/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
+++ b/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
@@ -36,44 +36,48 @@ constexpr int kInterPostRoundBit = 4;
namespace low_bitdepth {
namespace {
-inline int16x8_t ComputeWeightedAverage8(const int16x8_t pred0,
+inline uint8x8_t ComputeWeightedAverage8(const int16x8_t pred0,
const int16x8_t pred1,
- const int16x4_t weights[2]) {
- // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
- const int32x4_t wpred0_lo = vmull_s16(weights[0], vget_low_s16(pred0));
- const int32x4_t wpred0_hi = vmull_s16(weights[0], vget_high_s16(pred0));
- const int32x4_t blended_lo =
- vmlal_s16(wpred0_lo, weights[1], vget_low_s16(pred1));
- const int32x4_t blended_hi =
- vmlal_s16(wpred0_hi, weights[1], vget_high_s16(pred1));
-
- return vcombine_s16(vqrshrn_n_s32(blended_lo, kInterPostRoundBit + 4),
- vqrshrn_n_s32(blended_hi, kInterPostRoundBit + 4));
+ const int16x8_t weight) {
+ // Given: p0,p1 in range [-5132,9212] and w0 = 16 - w1, w1 = 16 - w0
+ // Output: (p0 * w0 + p1 * w1 + 128(=rounding bit)) >>
+ // 8(=kInterPostRoundBit + 4)
+ // The formula is manipulated to avoid lengthening to 32 bits.
+ // p0 * w0 + p1 * w1 = p0 * w0 + (16 - w0) * p1
+ // = (p0 - p1) * w0 + 16 * p1
+ // Maximum value of p0 - p1 is 9212 + 5132 = 0x3808.
+ const int16x8_t diff = vsubq_s16(pred0, pred1);
+ // vqdmulh: (2 * (p0 - p1) * (w0 << 11)) >> 16 = ((p0 - p1) * w0) >> 4; the (16 * p1) >> 4 = p1 term is added next.
+ const int16x8_t weighted_diff = vqdmulhq_s16(diff, weight);
+ // ((p0 - p1) * w0 >> 4) + p1
+ const int16x8_t upscaled_average = vaddq_s16(weighted_diff, pred1);
+ // (((p0 - p1) * w0 >> 4) + p1 + (128 >> 4)) >> 4
+ return vqrshrun_n_s16(upscaled_average, kInterPostRoundBit);
}
-template <int width, int height>
+template <int width>
inline void DistanceWeightedBlendSmall_NEON(
const int16_t* LIBGAV1_RESTRICT prediction_0,
- const int16_t* LIBGAV1_RESTRICT prediction_1, const int16x4_t weights[2],
- void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
+ const int16_t* LIBGAV1_RESTRICT prediction_1, const int height,
+ const int16x8_t weight, void* LIBGAV1_RESTRICT const dest,
+ const ptrdiff_t dest_stride) {
auto* dst = static_cast<uint8_t*>(dest);
constexpr int step = 16 / width;
- for (int y = 0; y < height; y += step) {
+ int y = height;
+ do {
const int16x8_t src_00 = vld1q_s16(prediction_0);
const int16x8_t src_10 = vld1q_s16(prediction_1);
prediction_0 += 8;
prediction_1 += 8;
- const int16x8_t res0 = ComputeWeightedAverage8(src_00, src_10, weights);
+ const uint8x8_t result0 = ComputeWeightedAverage8(src_00, src_10, weight);
const int16x8_t src_01 = vld1q_s16(prediction_0);
const int16x8_t src_11 = vld1q_s16(prediction_1);
prediction_0 += 8;
prediction_1 += 8;
- const int16x8_t res1 = ComputeWeightedAverage8(src_01, src_11, weights);
+ const uint8x8_t result1 = ComputeWeightedAverage8(src_01, src_11, weight);
- const uint8x8_t result0 = vqmovun_s16(res0);
- const uint8x8_t result1 = vqmovun_s16(res1);
if (width == 4) {
StoreLo4(dst, result0);
dst += dest_stride;
@@ -90,12 +94,13 @@ inline void DistanceWeightedBlendSmall_NEON(
vst1_u8(dst, result1);
dst += dest_stride;
}
- }
+ y -= step;
+ } while (y != 0);
}
inline void DistanceWeightedBlendLarge_NEON(
const int16_t* LIBGAV1_RESTRICT prediction_0,
- const int16_t* LIBGAV1_RESTRICT prediction_1, const int16x4_t weights[2],
+ const int16_t* LIBGAV1_RESTRICT prediction_1, const int16x8_t weight,
const int width, const int height, void* LIBGAV1_RESTRICT const dest,
const ptrdiff_t dest_stride) {
auto* dst = static_cast<uint8_t*>(dest);
@@ -106,16 +111,15 @@ inline void DistanceWeightedBlendLarge_NEON(
do {
const int16x8_t src0_lo = vld1q_s16(prediction_0 + x);
const int16x8_t src1_lo = vld1q_s16(prediction_1 + x);
- const int16x8_t res_lo =
- ComputeWeightedAverage8(src0_lo, src1_lo, weights);
+ const uint8x8_t res_lo =
+ ComputeWeightedAverage8(src0_lo, src1_lo, weight);
const int16x8_t src0_hi = vld1q_s16(prediction_0 + x + 8);
const int16x8_t src1_hi = vld1q_s16(prediction_1 + x + 8);
- const int16x8_t res_hi =
- ComputeWeightedAverage8(src0_hi, src1_hi, weights);
+ const uint8x8_t res_hi =
+ ComputeWeightedAverage8(src0_hi, src1_hi, weight);
- const uint8x16_t result =
- vcombine_u8(vqmovun_s16(res_lo), vqmovun_s16(res_hi));
+ const uint8x16_t result = vcombine_u8(res_lo, res_hi);
vst1q_u8(dst + x, result);
x += 16;
} while (x < width);
@@ -128,52 +132,25 @@ inline void DistanceWeightedBlendLarge_NEON(
inline void DistanceWeightedBlend_NEON(
const void* LIBGAV1_RESTRICT prediction_0,
const void* LIBGAV1_RESTRICT prediction_1, const uint8_t weight_0,
- const uint8_t weight_1, const int width, const int height,
+ const uint8_t /*weight_1*/, const int width, const int height,
void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
- int16x4_t weights[2] = {vdup_n_s16(weight_0), vdup_n_s16(weight_1)};
- // TODO(johannkoenig): Investigate the branching. May be fine to call with a
- // variable height.
+ // Upscale the weight for vqdmulh.
+ const int16x8_t weight = vdupq_n_s16(weight_0 << 11);
if (width == 4) {
- if (height == 4) {
- DistanceWeightedBlendSmall_NEON<4, 4>(pred_0, pred_1, weights, dest,
- dest_stride);
- } else if (height == 8) {
- DistanceWeightedBlendSmall_NEON<4, 8>(pred_0, pred_1, weights, dest,
- dest_stride);
- } else {
- assert(height == 16);
- DistanceWeightedBlendSmall_NEON<4, 16>(pred_0, pred_1, weights, dest,
- dest_stride);
- }
+ DistanceWeightedBlendSmall_NEON<4>(pred_0, pred_1, height, weight, dest,
+ dest_stride);
return;
}
if (width == 8) {
- switch (height) {
- case 4:
- DistanceWeightedBlendSmall_NEON<8, 4>(pred_0, pred_1, weights, dest,
- dest_stride);
- return;
- case 8:
- DistanceWeightedBlendSmall_NEON<8, 8>(pred_0, pred_1, weights, dest,
- dest_stride);
- return;
- case 16:
- DistanceWeightedBlendSmall_NEON<8, 16>(pred_0, pred_1, weights, dest,
- dest_stride);
- return;
- default:
- assert(height == 32);
- DistanceWeightedBlendSmall_NEON<8, 32>(pred_0, pred_1, weights, dest,
- dest_stride);
-
- return;
- }
+ DistanceWeightedBlendSmall_NEON<8>(pred_0, pred_1, height, weight, dest,
+ dest_stride);
+ return;
}
- DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weights, width, height, dest,
+ DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weight, width, height, dest,
dest_stride);
}