aboutsummaryrefslogtreecommitdiff
path: root/networking/arm/chksum_simd.c
blob: 7f69adfc963c375221bf1d661f2b6f37e5fc56c9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
/*
 * Armv7-A specific checksum implementation using NEON
 *
 * Copyright (c) 2020, Arm Limited.
 * SPDX-License-Identifier: MIT
 */

#include "networking.h"
#include "../chksum_common.h"

#ifndef __ARM_NEON
#pragma GCC target("+simd")
#endif

#include <arm_neon.h>

/*
 * Return the 16-bit ones'-complement sum of the NBYTES bytes at PTR,
 * computed with Armv7-A NEON.  The result is folded to 16 bits but is NOT
 * complemented here.
 *
 * Strategy:
 *  - Inputs shorter than 40 bytes are handed to the scalar slurp_small()
 *    helper (from chksum_common.h).
 *  - Otherwise PTR is rounded down to an 8-byte boundary.  The first partial
 *    word is loaded whole and the bytes before PTR are masked out.  The
 *    final 1..7 bytes are likewise loaded as one aligned 8-byte word and the
 *    bytes past the end are masked out.  These aligned loads may touch bytes
 *    outside [PTR, PTR+NBYTES), but never leave the aligned 8-byte word that
 *    holds valid data.
 *  - The aligned middle is summed 64 bytes per iteration into four
 *    independent 2x64-bit accumulators, followed by 16-byte and 8-byte tails.
 *  - Partial sums are folded with end-around carry (adding the high half to
 *    the low half).  Since 2^16 == 1 (mod 0xffff), this preserves the value
 *    modulo 0xffff.
 *
 * All words are summed at their aligned positions, so an odd PTR yields a
 * byte-swapped 16-bit result.  That is corrected at the end.
 */
unsigned short
__chksum_arm_simd(const void *ptr, unsigned int nbytes)
{
    bool swap = (uintptr_t) ptr & 1;
    uint64x1_t vsum = { 0 };

    if (unlikely(nbytes < 40))
    {
	/* Short input: not worth the NEON setup, use the scalar helper */
	uint64_t sum = slurp_small(ptr, nbytes);
	return fold_and_swap(sum, false);
    }

    /* 8-byte align pointer */
    /* Inline slurp_head-like code since we use NEON here */
    Assert(nbytes >= 8);
    uint32_t off = (uintptr_t) ptr & 7;
    if (likely(off != 0))
    {
	/* Load the whole aligned 8-byte word containing the first bytes */
	const uint64_t *may_alias ptr64 = align_ptr(ptr, 8);
	uint64x1_t vword64 = vld1_u64(ptr64);
	/*
	 * Clear bytes 0..off-1, which precede PTR.  A positive VSHL count
	 * shifts the all-ones mask left.  Clearing the low-order bytes
	 * assumes little-endian byte order.
	 */
	uint64x1_t vmask = vdup_n_u64(ALL_ONES);
	int64x1_t vshiftl = vdup_n_s64((int64_t) (CHAR_BIT * off));
	vmask = vshl_u64(vmask, vshiftl);
	vword64 = vand_u64(vword64, vmask);
	uint32x2_t vtmp = vreinterpret_u32_u64(vword64);
	/* Set accumulator to lo32 + hi32 of the masked word */
	vsum = vpaddl_u32(vtmp);
	/* Update pointer and remaining size */
	ptr = (char *) ptr64 + 8;
	nbytes -= 8 - off;
    }
    Assert(((uintptr_t) ptr & 7) == 0);

    /*
     * Sum groups of 64 bytes.  Each iteration adds at most 2 * (2^32 - 1)
     * to every 64-bit lane.  With a 32-bit nbytes there are fewer than 2^26
     * iterations, so the lanes cannot overflow.  Four independent
     * accumulators break the dependency chain between iterations.
     */
    uint64x2_t vsum0 = { 0, 0 };
    uint64x2_t vsum1 = { 0, 0 };
    uint64x2_t vsum2 = { 0, 0 };
    uint64x2_t vsum3 = { 0, 0 };
    const uint32_t *may_alias ptr32 = ptr;
    for (uint32_t i = 0; i < nbytes / 64; i++)
    {
	uint32x4_t vtmp0 = vld1q_u32(ptr32);
	uint32x4_t vtmp1 = vld1q_u32(ptr32 + 4);
	uint32x4_t vtmp2 = vld1q_u32(ptr32 + 8);
	uint32x4_t vtmp3 = vld1q_u32(ptr32 + 12);
	vsum0 = vpadalq_u32(vsum0, vtmp0);
	vsum1 = vpadalq_u32(vsum1, vtmp1);
	vsum2 = vpadalq_u32(vsum2, vtmp2);
	vsum3 = vpadalq_u32(vsum3, vtmp3);
	ptr32 += 16;
    }
    nbytes %= 64;

    /*
     * Fold vsum1/vsum2/vsum3 into vsum0.  Treating each 64-bit lane as two
     * 32-bit halves and adding them is an end-around-carry fold.
     */
    vsum0 = vpadalq_u32(vsum0, vreinterpretq_u32_u64(vsum2));
    vsum1 = vpadalq_u32(vsum1, vreinterpretq_u32_u64(vsum3));
    vsum0 = vpadalq_u32(vsum0, vreinterpretq_u32_u64(vsum1));

    /* Add any trailing 16-byte groups */
    while (likely(nbytes >= 16))
    {
	uint32x4_t vtmp0 = vld1q_u32(ptr32);
	vsum0 = vpadalq_u32(vsum0, vtmp0);
	ptr32 += 4;
	nbytes -= 16;
    }
    Assert(nbytes < 16);

    /* Fold vsum0 (2 x 64-bit lanes) into the 64-bit accumulator vsum */
    {
	/* Each 64-bit lane -> lo32 + hi32, at most 33 bits */
	vsum0 = vpaddlq_u32(vreinterpretq_u32_u64(vsum0));
	/* Each lane -> lo32 + carry bit, which now fits in 32 bits */
	vsum0 = vpaddlq_u32(vreinterpretq_u32_u64(vsum0));
	Assert((vgetq_lane_u64(vsum0, 0) >> 32) == 0);
	Assert((vgetq_lane_u64(vsum0, 1) >> 32) == 0);
	/* Narrow 2xu64 -> 2xu32; the upper halves are zero */
	uint32x2_t vtmp = vmovn_u64(vsum0);
	/* Add both 32-bit lanes to the accumulator */
	vsum = vpadal_u32(vsum, vtmp);
    }

    /* Add any trailing group of 8 bytes */
    if (nbytes & 8)
    {
	uint32x2_t vtmp = vld1_u32(ptr32);
	/* Add to accumulator */
	vsum = vpadal_u32(vsum, vtmp);
	ptr32 += 2;
	nbytes -= 8;
    }
    Assert(nbytes < 8);

    /* Handle any trailing 1..7 bytes */
    if (likely(nbytes != 0))
    {
	Assert(((uintptr_t) ptr32 & 7) == 0);
	Assert(nbytes < 8);
	/* Load the whole aligned 8-byte word holding the last bytes */
	uint64x1_t vword64 = vld1_u64((const uint64_t *) ptr32);
	/*
	 * Keep bytes 0..nbytes-1 and clear bytes nbytes..7, which lie past
	 * the end of the buffer.  A negative VSHL count shifts the mask
	 * right.
	 *
	 * The count is computed in a signed 64-bit type so it really is
	 * negative.  Written as "-CHAR_BIT * (8 - nbytes)" it would be
	 * evaluated in unsigned int and wrap to a large positive value.  That
	 * only works because the instruction happens to read just the low
	 * byte of the count.
	 */
	uint64x1_t vmask = vdup_n_u64(ALL_ONES);
	int64x1_t vshiftr = vdup_n_s64(-(int64_t) (CHAR_BIT * (8 - nbytes)));
	vmask = vshl_u64(vmask, vshiftr);
	vword64 = vand_u64(vword64, vmask);
	/* Fold the 64-bit word to lo32 + hi32 (at most 33 bits) */
	vword64 = vpaddl_u32(vreinterpret_u32_u64(vword64));
	/* Add to accumulator */
	vsum = vpadal_u32(vsum, vreinterpret_u32_u64(vword64));
    }

    /* Fold 64-bit vsum to 32 bits (two end-around-carry steps) */
    vsum = vpaddl_u32(vreinterpret_u32_u64(vsum));
    vsum = vpaddl_u32(vreinterpret_u32_u64(vsum));
    Assert(vget_lane_u32(vreinterpret_u32_u64(vsum), 1) == 0);

    /* Fold 32-bit vsum to 16 bits (two end-around-carry steps) */
    uint32x2_t vsum32 = vreinterpret_u32_u64(vsum);
    vsum32 = vpaddl_u16(vreinterpret_u16_u32(vsum32));
    vsum32 = vpaddl_u16(vreinterpret_u16_u32(vsum32));
    Assert(vget_lane_u16(vreinterpret_u16_u32(vsum32), 1) == 0);
    Assert(vget_lane_u16(vreinterpret_u16_u32(vsum32), 2) == 0);
    Assert(vget_lane_u16(vreinterpret_u16_u32(vsum32), 3) == 0);

    /* Convert to 16-bit scalar */
    uint16_t sum = vget_lane_u16(vreinterpret_u16_u32(vsum32), 0);

    /*
     * Words were summed at their aligned positions.  For an odd base
     * pointer the 16-bit result is byte-swapped, so undo that.
     */
    if (unlikely(swap))
    {
	sum = bswap16(sum);
    }
    return sum;
}