author     android-build-team Robot <android-build-team-robot@google.com>  2020-02-19 03:10:42 +0000
committer  android-build-team Robot <android-build-team-robot@google.com>  2020-02-19 03:10:42 +0000
commit     8b23c3bfd1d8d5cfb431576ad1fb798c44df4d64 (patch)
tree       d1ca88d7bef60a4ac8e70481861ad0e024555f9f
parent     33e4796a20c318d3a7e9be15650a34a7d0d85539 (diff)
parent     5fa0858b57c36722f0ab2606c606e927b5a40448 (diff)
download   XNNPACK-android11-d1-s1-release.tar.gz
Change-Id: Ibf363d121c5ecd3b06809b46550e5e09ffedd166
-rw-r--r--  Android.bp                                              |    3
-rw-r--r--  BUILD.bazel                                             |   91
-rw-r--r--  CMakeLists.txt                                          |    5
-rw-r--r--  METADATA                                                |   17
-rw-r--r--  README.md                                               |    6
-rw-r--r--  bench/f32-gemm-e2e.cc                                   |    6
-rw-r--r--  bench/f32-igemm.cc                                      |   20
-rw-r--r--  build_defs.bzl                                          |   33
-rw-r--r--  include/xnnpack.h                                       |   74
-rwxr-xr-x  scripts/generate-f32-igemm.sh                           |    4
-rw-r--r--  src/f32-dwconv/gen/up16x25-avx512f-acc2.c               |   52
-rw-r--r--  src/f32-dwconv/gen/up16x25-avx512f.c                    |   52
-rw-r--r--  src/f32-dwconv/gen/up16x4-avx512f-acc2.c                |   10
-rw-r--r--  src/f32-dwconv/gen/up16x4-avx512f.c                     |   10
-rw-r--r--  src/f32-dwconv/gen/up16x9-avx512f-acc2.c                |   20
-rw-r--r--  src/f32-dwconv/gen/up16x9-avx512f.c                     |   20
-rw-r--r--  src/f32-dwconv/gen/up32x25-avx512f-acc2.c               |   52
-rw-r--r--  src/f32-dwconv/gen/up32x25-avx512f.c                    |   52
-rw-r--r--  src/f32-dwconv/gen/up32x4-avx512f-acc2.c                |   10
-rw-r--r--  src/f32-dwconv/gen/up32x4-avx512f.c                     |   10
-rw-r--r--  src/f32-dwconv/gen/up32x9-avx512f-acc2.c                |   20
-rw-r--r--  src/f32-dwconv/gen/up32x9-avx512f.c                     |   20
-rw-r--r--  src/f32-dwconv/up-avx512.c.in                           |    4
-rw-r--r--  src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in       |   15
-rw-r--r--  src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in       |   13
-rw-r--r--  src/f32-gemm/4x8-aarch32-neon-ld64.S                    |    3
-rw-r--r--  src/f32-gemm/4x8-aarch64-neonfma-cortex-a53.S.in        |   18
-rw-r--r--  src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in             |  101
-rw-r--r--  src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in              |   40
-rw-r--r--  src/f32-gemm/6x8-aarch64-neonfma-cortex-a53.S.in        |   18
-rw-r--r--  src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in             |   45
-rw-r--r--  src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in              |    4
-rw-r--r--  src/f32-gemm/gen-inc/1x12-aarch64-neonfma-cortex-a53.S  |   15
-rw-r--r--  src/f32-gemm/gen-inc/4x12-aarch64-neonfma-cortex-a53.S  |   13
-rw-r--r--  src/f32-gemm/gen-inc/4x8-aarch64-neonfma-cortex-a53.S   |   18
-rw-r--r--  src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld128.S        |   89
-rw-r--r--  src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld64.S         |   40
-rw-r--r--  src/f32-gemm/gen-inc/6x8-aarch64-neonfma-cortex-a53.S   |   18
-rw-r--r--  src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld128.S        |   45
-rw-r--r--  src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld64.S         |    4
-rw-r--r--  src/f32-gemm/gen/1x12-aarch64-neonfma-cortex-a53.S      |   15
-rw-r--r--  src/f32-gemm/gen/4x12-aarch64-neonfma-cortex-a53.S      |   13
-rw-r--r--  src/f32-gemm/gen/4x8-aarch64-neonfma-cortex-a53.S       |   18
-rw-r--r--  src/f32-gemm/gen/4x8-aarch64-neonfma-ld128.S            |   89
-rw-r--r--  src/f32-gemm/gen/4x8-aarch64-neonfma-ld64.S             |   40
-rw-r--r--  src/f32-gemm/gen/6x8-aarch64-neonfma-cortex-a53.S       |   18
-rw-r--r--  src/f32-gemm/gen/6x8-aarch64-neonfma-ld128.S            |   45
-rw-r--r--  src/f32-gemm/gen/6x8-aarch64-neonfma-ld64.S             |    4
-rw-r--r--  src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S         |   26
-rw-r--r--  src/f32-igemm/1x8-aarch64-neonfma-cortex-a53.S          |   29
-rw-r--r--  src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S         |   30
-rw-r--r--  src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in          |  393
-rw-r--r--  src/f32-igemm/4x8-aarch32-neon-ld64.S                   |  248
-rw-r--r--  src/f32-igemm/6x8-aarch64-neonfma-cortex-a53.S          |   26
-rw-r--r--  src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S         |  369
-rw-r--r--  src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S     |  389
-rw-r--r--  src/init.c                                              |    6
-rw-r--r--  src/runtime.c                                           |  168
-rw-r--r--  src/subgraph.c                                          |  213
-rw-r--r--  src/xnnpack/igemm.h                                     |    4
-rw-r--r--  src/xnnpack/subgraph.h                                  |    5
-rw-r--r--  test/f32-igemm.cc                                       | 1466
-rw-r--r--  test/f32-igemm.yaml                                     |    9
-rw-r--r--  third_party/cpuinfo.BUILD                               |   99
64 files changed, 4202 insertions(+), 610 deletions(-)
diff --git a/Android.bp b/Android.bp
index 61f5fe67c..1343ca93a 100644
--- a/Android.bp
+++ b/Android.bp
@@ -1244,6 +1244,9 @@ AARCH32_ASM_UKERNELS = [
"src/f32-gemm/gen/4x8-aarch32-neon-cortex-a75.S",
"src/f32-gemm/gen/4x8-aarch32-neon-pld-cortex-a75.S",
"src/f32-gemm/4x8-aarch32-neon-ld64.S",
+ "src/f32-igemm/4x8-aarch32-neon-ld64.S",
+ "src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S",
+ "src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S",
]
AARCH64_ASM_UKERNELS = [
diff --git a/BUILD.bazel b/BUILD.bazel
index 095b25967..a6e6c1f48 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1373,6 +1373,9 @@ AARCH32_ASM_UKERNELS = [
"src/f32-gemm/gen/4x8-aarch32-neon-cortex-a75.S",
"src/f32-gemm/gen/4x8-aarch32-neon-pld-cortex-a75.S",
"src/f32-gemm/4x8-aarch32-neon-ld64.S",
+ "src/f32-igemm/4x8-aarch32-neon-ld64.S",
+ "src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S",
+ "src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S",
]
AARCH64_ASM_UKERNELS = [
@@ -3244,3 +3247,91 @@ config_setting(
"cpu": "asmjs",
},
)
+
+config_setting(
+ name = "ios_armv7",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "ios_armv7",
+ },
+)
+
+config_setting(
+ name = "ios_arm64",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "ios_arm64",
+ },
+)
+
+config_setting(
+ name = "ios_arm64e",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "ios_arm64e",
+ },
+)
+
+config_setting(
+ name = "ios_x86",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "ios_i386",
+ },
+)
+
+config_setting(
+ name = "ios_x86_64",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "ios_x86_64",
+ },
+)
+
+config_setting(
+ name = "watchos_armv7k",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "watchos_armv7k",
+ },
+)
+
+config_setting(
+ name = "watchos_arm64_32",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "watchos_arm64_32",
+ },
+)
+
+config_setting(
+ name = "watchos_x86",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "watchos_i386",
+ },
+)
+
+config_setting(
+ name = "watchos_x86_64",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "watchos_x86_64",
+ },
+)
+
+config_setting(
+ name = "tvos_arm64",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "tvos_arm64",
+ },
+)
+
+config_setting(
+ name = "tvos_x86_64",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "tvos_x86_64",
+ },
+)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 46ec94f9e..cd5d59e63 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1357,7 +1357,10 @@ SET(XNNPACK_AARCH32_ASM_MICROKERNEL_SRCS
src/f32-gemm/4x8-aarch32-neon-cortex-a53.S
src/f32-gemm/gen/4x8-aarch32-neon-cortex-a75.S
src/f32-gemm/gen/4x8-aarch32-neon-pld-cortex-a75.S
- src/f32-gemm/4x8-aarch32-neon-ld64.S)
+ src/f32-gemm/4x8-aarch32-neon-ld64.S
+ src/f32-igemm/4x8-aarch32-neon-ld64.S
+ src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S
+ src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S)
SET(XNNPACK_AARCH64_ASM_MICROKERNEL_SRCS
src/f32-dwconv/up4x9-aarch64-neonfma-cortex-a55.S
diff --git a/METADATA b/METADATA
index e8e914bda..a515390e0 100644
--- a/METADATA
+++ b/METADATA
@@ -1,12 +1,5 @@
name: "XNNPACK"
-description:
- "XNNPACK is a highly optimized library of floating-point neural network "
- "inference operators for ARM, WebAssembly, and x86 (SSE2 level) platforms. "
- "XNNPACK is not intended for direct use by deep learning practitioners and "
- "researchers; instead it provides low-level performance primitives for "
- "accelerating high-level machine learning frameworks, such as MediaPipe, "
- "TensorFlow Lite, and TensorFlow.js."
-
+description: "XNNPACK is a highly optimized library of floating-point neural network inference operators for ARM, WebAssembly, and x86 (SSE2 level) platforms. XNNPACK is not intended for direct use by deep learning practitioners and researchers; instead it provides low-level performance primitives for accelerating high-level machine learning frameworks, such as MediaPipe, TensorFlow Lite, and TensorFlow.js."
third_party {
url {
type: HOMEPAGE
@@ -16,7 +9,11 @@ third_party {
type: GIT
value: "https://github.com/google/XNNPACK"
}
- version: "98ca635f1b6d84c81539ccafe721650fb174da67"
- last_upgrade_date { year: 2020 month: 2 day: 3 }
+ version: "1498d1d4d0430480dfe5c4538049b4f789d29134"
license_type: NOTICE
+ last_upgrade_date {
+ year: 2020
+ month: 2
+ day: 11
+ }
}
diff --git a/README.md b/README.md
index eba09d9bb..b52194565 100644
--- a/README.md
+++ b/README.md
@@ -4,11 +4,11 @@ XNNPACK is a highly optimized library of floating-point neural network inference
## Supported Architectures
-- ARM64 on Android and Linux
-- ARMv7 (with NEON) on Android and Linux
+- ARM64 on Android, Linux, and iOS (including WatchOS and tvOS)
+- ARMv7 (with NEON) on Android, Linux, and iOS (including WatchOS)
- WebAssembly MVP
- WebAssembly SIMD (experimental)
-- x86 and x86-64 (up to AVX512) on Android, Linux, and macOS
+- x86 and x86-64 (up to AVX512) on Android, Linux, macOS, and iOS simulator
## Operator Coverage
diff --git a/bench/f32-gemm-e2e.cc b/bench/f32-gemm-e2e.cc
index c83da0b0e..9661eb13c 100644
--- a/bench/f32-gemm-e2e.cc
+++ b/bench/f32-gemm-e2e.cc
@@ -258,7 +258,7 @@ static void GEMMEnd2EndBenchmark(
static void f32_gemm_4x8__aarch32_neon_ld64(benchmark::State& state, models::ExecutionPlanFactory model) {
GEMMEnd2EndBenchmark(state, model,
xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64,
- xnn_f32_igemm_ukernel_4x8__neon_lane_ld64,
+ xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64,
xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
4 /* mr */, 8 /* nr */, 0 /* log2_kr */, 0 /* log2_sr */,
@@ -276,7 +276,7 @@ static void GEMMEnd2EndBenchmark(
static void f32_gemm_4x8__aarch32_neon_cortex_a75(benchmark::State& state, models::ExecutionPlanFactory model) {
GEMMEnd2EndBenchmark(state, model,
xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a75,
- xnn_f32_igemm_ukernel_4x8__neon_lane_ld64,
+ xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75,
xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
4 /* mr */, 8 /* nr */, 0 /* log2_kr */, 0 /* log2_sr */,
@@ -285,7 +285,7 @@ static void GEMMEnd2EndBenchmark(
static void f32_gemm_4x8__aarch32_neon_pld_cortex_a75(benchmark::State& state, models::ExecutionPlanFactory model) {
GEMMEnd2EndBenchmark(state, model,
xnn_f32_gemm_ukernel_4x8__aarch32_neon_pld_cortex_a75,
- xnn_f32_igemm_ukernel_4x8__neon_lane_ld64,
+ xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75,
xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
4 /* mr */, 8 /* nr */, 0 /* log2_kr */, 0 /* log2_sr */,
diff --git a/bench/f32-igemm.cc b/bench/f32-igemm.cc
index ca2cf162b..064506538 100644
--- a/bench/f32-igemm.cc
+++ b/bench/f32-igemm.cc
@@ -258,6 +258,14 @@ static void IGEMMBenchmark(benchmark::State& state,
BENCHMARK_CONV(f32_igemm_8x8s4__neonfma)
#endif
+#if XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY
+ static void f32_igemm_4x8__aarch32_neon_ld64(benchmark::State& state, const char* net) {
+ IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64, 4, 8, 1, 1);
+ }
+
+ BENCHMARK_CONV(f32_igemm_4x8__aarch32_neon_ld64)
+#endif /* XNN_ARCH_ARM */
+
#if XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
static void f32_igemm_1x12__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53, 1, 12, 1, 1);
@@ -366,6 +374,18 @@ static void IGEMMBenchmark(benchmark::State& state,
BENCHMARK_CONV(f32_igemm_6x8__neonfma_lane_ld128)
#endif /* XNN_ARCH_ARM64 */
+#if XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY
+ static void f32_igemm_4x8__aarch32_neon_pld_cortex_a75(benchmark::State& state, const char* net) {
+ IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75, 4, 8, 1, 1);
+ }
+ static void f32_igemm_4x8__aarch32_neon_cortex_a75(benchmark::State& state, const char* net) {
+ IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75, 4, 8, 1, 1);
+ }
+
+ BENCHMARK_CONV(f32_igemm_4x8__aarch32_neon_pld_cortex_a75)
+ BENCHMARK_CONV(f32_igemm_4x8__aarch32_neon_cortex_a75)
+#endif /* XNN_ARCH_ARM */
+
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
static void f32_igemm_1x8__sse_load1(benchmark::State& state, const char* net) {
IGEMMBenchmark(state, xnn_f32_igemm_ukernel_1x8__sse_load1, 1, 8, 1, 1);
diff --git a/build_defs.bzl b/build_defs.bzl
index db1d303ea..9112ab27b 100644
--- a/build_defs.bzl
+++ b/build_defs.bzl
@@ -113,6 +113,17 @@ def xnnpack_cc_library(
":android_arm64": aarch64_srcs,
":android_x86": x86_srcs,
":android_x86_64": x86_srcs,
+ ":ios_armv7": aarch32_srcs,
+ ":ios_arm64": aarch64_srcs,
+ ":ios_arm64e": aarch64_srcs,
+ ":ios_x86": x86_srcs,
+ ":ios_x86_64": x86_srcs,
+ ":watchos_armv7k": aarch32_srcs,
+ ":watchos_arm64_32": aarch64_srcs,
+ ":watchos_x86": x86_srcs,
+ ":watchos_x86_64": x86_srcs,
+ ":tvos_arm64": aarch64_srcs,
+ ":tvos_x86_64": x86_srcs,
":emscripten_asmjs": asmjs_srcs,
":emscripten_wasm": wasm_srcs,
":emscripten_wasmsimd": wasmsimd_srcs,
@@ -129,6 +140,17 @@ def xnnpack_cc_library(
":android_arm64": aarch64_copts,
":android_x86": x86_copts,
":android_x86_64": x86_copts,
+ ":ios_armv7": aarch32_copts,
+ ":ios_arm64": aarch64_copts,
+ ":ios_arm64e": aarch64_copts,
+ ":ios_x86": x86_copts,
+ ":ios_x86_64": x86_copts,
+ ":watchos_armv7k": aarch32_copts,
+ ":watchos_arm64_32": aarch64_copts,
+ ":watchos_x86": x86_copts,
+ ":watchos_x86_64": x86_copts,
+ ":tvos_arm64": aarch64_copts,
+ ":tvos_x86_64": x86_copts,
":emscripten_asmjs": asmjs_copts,
":emscripten_wasm": wasm_copts,
":emscripten_wasmsimd": wasmsimd_copts,
@@ -180,6 +202,17 @@ def xnnpack_aggregate_library(
":android_arm64": aarch64_deps,
":android_x86": x86_deps,
":android_x86_64": x86_deps,
+ ":ios_armv7": aarch32_deps,
+ ":ios_arm64": aarch64_deps,
+ ":ios_arm64e": aarch64_deps,
+ ":ios_x86": x86_deps,
+ ":ios_x86_64": x86_deps,
+ ":watchos_armv7k": aarch32_deps,
+ ":watchos_arm64_32": aarch64_deps,
+ ":watchos_x86": x86_deps,
+ ":watchos_x86_64": x86_deps,
+ ":tvos_arm64": aarch64_deps,
+ ":tvos_x86_64": x86_deps,
":emscripten_wasm": wasm_deps,
":emscripten_wasmsimd": wasmsimd_deps,
":emscripten_asmjs": [],
diff --git a/include/xnnpack.h b/include/xnnpack.h
index ea171f060..b7bef08ce 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -349,6 +349,80 @@ enum xnn_status xnn_define_multiply2(
uint32_t output_id,
uint32_t flags);
+/// Define a PReLU (Parametric ReLU) Node and add it to a Subgraph.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
+///                    with [N, H, W, channels] dimensions.
+/// @param slope_id - Value ID for the slope tensor. The slope tensor must be a 1D tensor defined in the @a subgraph
+///                    with [channels] dimensions.
+/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
+/// with [N, H, W, channels] dimensions.
+/// @param flags - binary features of the PReLU Node. No supported flags are currently defined.
+enum xnn_status xnn_define_prelu(
+ xnn_subgraph_t subgraph,
+ uint32_t input_id,
+ uint32_t slope_id,
+ uint32_t output_id,
+ uint32_t flags);
+
+/// Define a Clamp Node and add it to a Subgraph.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param output_min - lower bound for clipping output values.
+/// @param output_max - upper bound for clipping output values.
+/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
+/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
+/// shape must match the shape of the input tensor.
+/// @param flags - binary features of the Clamp Node. No supported flags are currently defined.
+enum xnn_status xnn_define_clamp(
+ xnn_subgraph_t subgraph,
+ float output_min,
+ float output_max,
+ uint32_t input_id,
+ uint32_t output_id,
+ uint32_t flags);
+
+/// Define a HardSwish Node and add it to a Subgraph.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
+/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
+/// shape must match the shape of the input tensor.
+/// @param flags - binary features of the HardSwish Node. No supported flags are currently defined.
+enum xnn_status xnn_define_hardswish(
+ xnn_subgraph_t subgraph,
+ uint32_t input_id,
+ uint32_t output_id,
+ uint32_t flags);
+
+/// Define a Sigmoid Node and add it to a Subgraph.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
+/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
+/// shape must match the shape of the input tensor.
+/// @param flags - binary features of the Sigmoid Node. No supported flags are currently defined.
+enum xnn_status xnn_define_sigmoid(
+ xnn_subgraph_t subgraph,
+ uint32_t input_id,
+ uint32_t output_id,
+ uint32_t flags);
+
+/// Define a SoftMax Node and add it to a Subgraph.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph, and have at
+/// least one dimension.
+/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
+/// shape must match the shape of the input tensor.
+/// @param flags - binary features of the SoftMax Node. No supported flags are currently defined.
+enum xnn_status xnn_define_softmax(
+ xnn_subgraph_t subgraph,
+ uint32_t input_id,
+ uint32_t output_id,
+ uint32_t flags);
+
/// Runtime is a combination of an execution plan for subgraph Nodes and a memory manager for subgraph Values.
typedef struct xnn_runtime* xnn_runtime_t;
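
The five Node-definition declarations added above are header-only in this patch; as a rough sketch of how they compose (not taken from this commit — the helper name, the subgraph, and the Value IDs are assumptions), a caller holding an already-built subgraph could chain two of the new Nodes like this:

#include <stdint.h>
#include <xnnpack.h>

/* Illustrative helper only: append a Clamp (here acting as a ReLU6) followed by
 * a Sigmoid to an existing subgraph. The subgraph and the three Value IDs are
 * assumed to have been defined elsewhere, e.g. with xnn_define_tensor_value. */
static enum xnn_status append_clamp_sigmoid(
    xnn_subgraph_t subgraph, uint32_t input_id, uint32_t mid_id, uint32_t output_id)
{
  /* Clamp the input into [0.0, 6.0], writing into the intermediate Value. */
  enum xnn_status status =
      xnn_define_clamp(subgraph, 0.0f, 6.0f, input_id, mid_id, 0 /* flags */);
  if (status != xnn_status_success) {
    return status;
  }
  /* The Sigmoid output shape must match its input shape, per the comments above. */
  return xnn_define_sigmoid(subgraph, mid_id, output_id, 0 /* flags */);
}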
diff --git a/scripts/generate-f32-igemm.sh b/scripts/generate-f32-igemm.sh
index 763e34fe4..ceed44310 100755
--- a/scripts/generate-f32-igemm.sh
+++ b/scripts/generate-f32-igemm.sh
@@ -28,6 +28,10 @@ tools/xngen src/f32-igemm/5x8-aarch64-neonfma-cortex-a75.S.in -D INC=0 -D PREFE
tools/xngen src/f32-igemm/6x8-aarch64-neonfma-cortex-a75.S.in -D INC=0 -D PREFETCH=0 -o src/f32-igemm/gen/6x8-aarch64-neonfma-cortex-a57.S
tools/xngen src/f32-igemm/6x8-aarch64-neonfma-cortex-a75.S.in -D INC=0 -D PREFETCH=1 -o src/f32-igemm/gen/6x8-aarch64-neonfma-cortex-a75.S
+############################### AArch32 assembly ##############################
+tools/xngen src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in -D INC=0 -D PREFETCH=0 -o src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S
+tools/xngen src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in -D INC=0 -D PREFETCH=1 -o src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S
+
################################### ARM NEON ##################################
### LD64 micro-kernels
tools/xngen src/f32-igemm/neon-ld64.c.in -D MR=1 -D NR=8 -D FMA=0 -D DUP=0 -o src/f32-igemm/gen/1x8-neon-lane-ld64.c
diff --git a/src/f32-dwconv/gen/up16x25-avx512f-acc2.c b/src/f32-dwconv/gen/up16x25-avx512f-acc2.c
index 72d39a8f2..1d0d4b696 100644
--- a/src/f32-dwconv/gen/up16x25-avx512f-acc2.c
+++ b/src/f32-dwconv/gen/up16x25-avx512f-acc2.c
@@ -256,106 +256,106 @@ void xnn_f32_dwconv_ukernel_up16x25__avx512f_acc2(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
__m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
- const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
+ const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
- const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
+ const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
- const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
+ const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
- const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
+ const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
- const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
+ const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i9);
- const __m512 vk9x0123456789ABCDEF = _mm512_load_ps(w + 160);
+ const __m512 vk9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i10);
- const __m512 vk10x0123456789ABCDEF = _mm512_load_ps(w + 176);
+ const __m512 vk10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 176);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i11);
- const __m512 vk11x0123456789ABCDEF = _mm512_load_ps(w + 192);
+ const __m512 vk11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i12);
- const __m512 vk12x0123456789ABCDEF = _mm512_load_ps(w + 208);
+ const __m512 vk12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 208);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i13);
- const __m512 vk13x0123456789ABCDEF = _mm512_load_ps(w + 224);
+ const __m512 vk13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i14);
- const __m512 vk14x0123456789ABCDEF = _mm512_load_ps(w + 240);
+ const __m512 vk14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 240);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i15);
- const __m512 vk15x0123456789ABCDEF = _mm512_load_ps(w + 256);
+ const __m512 vk15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i16);
- const __m512 vk16x0123456789ABCDEF = _mm512_load_ps(w + 272);
+ const __m512 vk16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 272);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i17);
- const __m512 vk17x0123456789ABCDEF = _mm512_load_ps(w + 288);
+ const __m512 vk17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i18);
- const __m512 vk18x0123456789ABCDEF = _mm512_load_ps(w + 304);
+ const __m512 vk18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 304);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i19);
- const __m512 vk19x0123456789ABCDEF = _mm512_load_ps(w + 320);
+ const __m512 vk19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 320);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i20);
- const __m512 vk20x0123456789ABCDEF = _mm512_load_ps(w + 336);
+ const __m512 vk20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 336);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i21);
- const __m512 vk21x0123456789ABCDEF = _mm512_load_ps(w + 352);
+ const __m512 vk21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 352);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i22);
- const __m512 vk22x0123456789ABCDEF = _mm512_load_ps(w + 368);
+ const __m512 vk22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 368);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i23);
- const __m512 vk23x0123456789ABCDEF = _mm512_load_ps(w + 384);
+ const __m512 vk23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 384);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i24);
- const __m512 vk24x0123456789ABCDEF = _mm512_load_ps(w + 400);
+ const __m512 vk24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 400);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
// Add up all accumulators to vacc0123456789ABCDEFp0
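
The same substitution repeats through every AVX512 dwconv kernel in this patch: in the channel-tail path (fewer than 16 channels left) the bias and kernel weights now go through the same masked _mm512_maskz_loadu_ps as the inputs, instead of the full-width aligned _mm512_load_ps used before, so no lanes beyond the packed weights are read. A stand-alone sketch of that pattern, with illustrative pointer names rather than the kernel's real state:

#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

/* Masked-tail load pattern adopted by this patch: with c (< 16) channels left,
 * enable lanes 0..c-1 and route every load of the weights w and the input row
 * i0 through the mask, so nothing is read past the end of either buffer. */
static __m512 masked_tail_step(const float* w, const float* i0, size_t c)
{
  const __mmask16 vmask =
      _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
  const __m512 vacc = _mm512_maskz_loadu_ps(vmask, w);       /* bias lanes       */
  const __m512 vi0  = _mm512_maskz_loadu_ps(vmask, i0);      /* input lanes      */
  const __m512 vk0  = _mm512_maskz_loadu_ps(vmask, w + 16);  /* kernel-tap lanes */
  return _mm512_fmadd_ps(vi0, vk0, vacc);                    /* vi0 * vk0 + vacc */
}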
diff --git a/src/f32-dwconv/gen/up16x25-avx512f.c b/src/f32-dwconv/gen/up16x25-avx512f.c
index 2c7b8d452..140e387e2 100644
--- a/src/f32-dwconv/gen/up16x25-avx512f.c
+++ b/src/f32-dwconv/gen/up16x25-avx512f.c
@@ -254,106 +254,106 @@ void xnn_f32_dwconv_ukernel_up16x25__avx512f(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
- const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
+ const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
- const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
+ const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
- const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
+ const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
- const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
+ const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
- const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
+ const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i9);
- const __m512 vk9x0123456789ABCDEF = _mm512_load_ps(w + 160);
+ const __m512 vk9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i10);
- const __m512 vk10x0123456789ABCDEF = _mm512_load_ps(w + 176);
+ const __m512 vk10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 176);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i11);
- const __m512 vk11x0123456789ABCDEF = _mm512_load_ps(w + 192);
+ const __m512 vk11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i12);
- const __m512 vk12x0123456789ABCDEF = _mm512_load_ps(w + 208);
+ const __m512 vk12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 208);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i13);
- const __m512 vk13x0123456789ABCDEF = _mm512_load_ps(w + 224);
+ const __m512 vk13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i14);
- const __m512 vk14x0123456789ABCDEF = _mm512_load_ps(w + 240);
+ const __m512 vk14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 240);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i15);
- const __m512 vk15x0123456789ABCDEF = _mm512_load_ps(w + 256);
+ const __m512 vk15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i16);
- const __m512 vk16x0123456789ABCDEF = _mm512_load_ps(w + 272);
+ const __m512 vk16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 272);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i17);
- const __m512 vk17x0123456789ABCDEF = _mm512_load_ps(w + 288);
+ const __m512 vk17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i18);
- const __m512 vk18x0123456789ABCDEF = _mm512_load_ps(w + 304);
+ const __m512 vk18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 304);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i19);
- const __m512 vk19x0123456789ABCDEF = _mm512_load_ps(w + 320);
+ const __m512 vk19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 320);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i20);
- const __m512 vk20x0123456789ABCDEF = _mm512_load_ps(w + 336);
+ const __m512 vk20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 336);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i21);
- const __m512 vk21x0123456789ABCDEF = _mm512_load_ps(w + 352);
+ const __m512 vk21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 352);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i22);
- const __m512 vk22x0123456789ABCDEF = _mm512_load_ps(w + 368);
+ const __m512 vk22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 368);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i23);
- const __m512 vk23x0123456789ABCDEF = _mm512_load_ps(w + 384);
+ const __m512 vk23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 384);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i24);
- const __m512 vk24x0123456789ABCDEF = _mm512_load_ps(w + 400);
+ const __m512 vk24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 400);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
diff --git a/src/f32-dwconv/gen/up16x4-avx512f-acc2.c b/src/f32-dwconv/gen/up16x4-avx512f-acc2.c
index 36f05a2e3..c9e6586c2 100644
--- a/src/f32-dwconv/gen/up16x4-avx512f-acc2.c
+++ b/src/f32-dwconv/gen/up16x4-avx512f-acc2.c
@@ -88,22 +88,22 @@ void xnn_f32_dwconv_ukernel_up16x4__avx512f_acc2(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
__m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1);
// Add up all accumulators to vacc0123456789ABCDEFp0
diff --git a/src/f32-dwconv/gen/up16x4-avx512f.c b/src/f32-dwconv/gen/up16x4-avx512f.c
index 79a5c498a..0a625551c 100644
--- a/src/f32-dwconv/gen/up16x4-avx512f.c
+++ b/src/f32-dwconv/gen/up16x4-avx512f.c
@@ -86,22 +86,22 @@ void xnn_f32_dwconv_ukernel_up16x4__avx512f(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
diff --git a/src/f32-dwconv/gen/up16x9-avx512f-acc2.c b/src/f32-dwconv/gen/up16x9-avx512f-acc2.c
index 63848915a..17587f5d0 100644
--- a/src/f32-dwconv/gen/up16x9-avx512f-acc2.c
+++ b/src/f32-dwconv/gen/up16x9-avx512f-acc2.c
@@ -128,42 +128,42 @@ void xnn_f32_dwconv_ukernel_up16x9__avx512f_acc2(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
__m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
- const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
+ const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
- const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
+ const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
- const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
+ const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
- const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
+ const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
- const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
+ const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
// Add up all accumulators to vacc0123456789ABCDEFp0
diff --git a/src/f32-dwconv/gen/up16x9-avx512f.c b/src/f32-dwconv/gen/up16x9-avx512f.c
index eaab45c51..4a7606ffa 100644
--- a/src/f32-dwconv/gen/up16x9-avx512f.c
+++ b/src/f32-dwconv/gen/up16x9-avx512f.c
@@ -126,42 +126,42 @@ void xnn_f32_dwconv_ukernel_up16x9__avx512f(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
- const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
+ const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
- const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
+ const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
- const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
+ const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
- const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
+ const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
- const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
+ const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
diff --git a/src/f32-dwconv/gen/up32x25-avx512f-acc2.c b/src/f32-dwconv/gen/up32x25-avx512f-acc2.c
index 9d5d8271f..a666d1e84 100644
--- a/src/f32-dwconv/gen/up32x25-avx512f-acc2.c
+++ b/src/f32-dwconv/gen/up32x25-avx512f-acc2.c
@@ -500,106 +500,106 @@ void xnn_f32_dwconv_ukernel_up32x25__avx512f_acc2(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
__m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
- const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 160);
+ const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
- const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 192);
+ const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
- const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 224);
+ const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
- const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 256);
+ const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
- const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 288);
+ const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i9);
- const __m512 vk9x0123456789ABCDEF = _mm512_load_ps(w + 320);
+ const __m512 vk9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 320);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i10);
- const __m512 vk10x0123456789ABCDEF = _mm512_load_ps(w + 352);
+ const __m512 vk10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 352);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i11);
- const __m512 vk11x0123456789ABCDEF = _mm512_load_ps(w + 384);
+ const __m512 vk11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 384);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i12);
- const __m512 vk12x0123456789ABCDEF = _mm512_load_ps(w + 416);
+ const __m512 vk12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 416);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i13);
- const __m512 vk13x0123456789ABCDEF = _mm512_load_ps(w + 448);
+ const __m512 vk13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 448);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i14);
- const __m512 vk14x0123456789ABCDEF = _mm512_load_ps(w + 480);
+ const __m512 vk14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 480);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i15);
- const __m512 vk15x0123456789ABCDEF = _mm512_load_ps(w + 512);
+ const __m512 vk15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 512);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i16);
- const __m512 vk16x0123456789ABCDEF = _mm512_load_ps(w + 544);
+ const __m512 vk16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 544);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i17);
- const __m512 vk17x0123456789ABCDEF = _mm512_load_ps(w + 576);
+ const __m512 vk17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 576);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i18);
- const __m512 vk18x0123456789ABCDEF = _mm512_load_ps(w + 608);
+ const __m512 vk18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 608);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i19);
- const __m512 vk19x0123456789ABCDEF = _mm512_load_ps(w + 640);
+ const __m512 vk19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 640);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i20);
- const __m512 vk20x0123456789ABCDEF = _mm512_load_ps(w + 672);
+ const __m512 vk20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 672);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i21);
- const __m512 vk21x0123456789ABCDEF = _mm512_load_ps(w + 704);
+ const __m512 vk21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 704);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i22);
- const __m512 vk22x0123456789ABCDEF = _mm512_load_ps(w + 736);
+ const __m512 vk22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 736);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i23);
- const __m512 vk23x0123456789ABCDEF = _mm512_load_ps(w + 768);
+ const __m512 vk23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 768);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i24);
- const __m512 vk24x0123456789ABCDEF = _mm512_load_ps(w + 800);
+ const __m512 vk24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 800);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
// Add up all accumulators to vacc0123456789ABCDEFp0
diff --git a/src/f32-dwconv/gen/up32x25-avx512f.c b/src/f32-dwconv/gen/up32x25-avx512f.c
index 64ad2b6cb..5b896a31b 100644
--- a/src/f32-dwconv/gen/up32x25-avx512f.c
+++ b/src/f32-dwconv/gen/up32x25-avx512f.c
@@ -495,106 +495,106 @@ void xnn_f32_dwconv_ukernel_up32x25__avx512f(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
- const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 160);
+ const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
- const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 192);
+ const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
- const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 224);
+ const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
- const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 256);
+ const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
- const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 288);
+ const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i9);
- const __m512 vk9x0123456789ABCDEF = _mm512_load_ps(w + 320);
+ const __m512 vk9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 320);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i10);
- const __m512 vk10x0123456789ABCDEF = _mm512_load_ps(w + 352);
+ const __m512 vk10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 352);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i11);
- const __m512 vk11x0123456789ABCDEF = _mm512_load_ps(w + 384);
+ const __m512 vk11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 384);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i12);
- const __m512 vk12x0123456789ABCDEF = _mm512_load_ps(w + 416);
+ const __m512 vk12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 416);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i13);
- const __m512 vk13x0123456789ABCDEF = _mm512_load_ps(w + 448);
+ const __m512 vk13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 448);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i14);
- const __m512 vk14x0123456789ABCDEF = _mm512_load_ps(w + 480);
+ const __m512 vk14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 480);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i15);
- const __m512 vk15x0123456789ABCDEF = _mm512_load_ps(w + 512);
+ const __m512 vk15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 512);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i16);
- const __m512 vk16x0123456789ABCDEF = _mm512_load_ps(w + 544);
+ const __m512 vk16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 544);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i17);
- const __m512 vk17x0123456789ABCDEF = _mm512_load_ps(w + 576);
+ const __m512 vk17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 576);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i18);
- const __m512 vk18x0123456789ABCDEF = _mm512_load_ps(w + 608);
+ const __m512 vk18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 608);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i19);
- const __m512 vk19x0123456789ABCDEF = _mm512_load_ps(w + 640);
+ const __m512 vk19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 640);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i20);
- const __m512 vk20x0123456789ABCDEF = _mm512_load_ps(w + 672);
+ const __m512 vk20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 672);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i21);
- const __m512 vk21x0123456789ABCDEF = _mm512_load_ps(w + 704);
+ const __m512 vk21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 704);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i22);
- const __m512 vk22x0123456789ABCDEF = _mm512_load_ps(w + 736);
+ const __m512 vk22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 736);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i23);
- const __m512 vk23x0123456789ABCDEF = _mm512_load_ps(w + 768);
+ const __m512 vk23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 768);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i24);
- const __m512 vk24x0123456789ABCDEF = _mm512_load_ps(w + 800);
+ const __m512 vk24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 800);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
diff --git a/src/f32-dwconv/gen/up32x4-avx512f-acc2.c b/src/f32-dwconv/gen/up32x4-avx512f-acc2.c
index 67bad500d..d8aaa0862 100644
--- a/src/f32-dwconv/gen/up32x4-avx512f-acc2.c
+++ b/src/f32-dwconv/gen/up32x4-avx512f-acc2.c
@@ -143,22 +143,22 @@ void xnn_f32_dwconv_ukernel_up32x4__avx512f_acc2(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
__m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1);
// Add up all accumulators to vacc0123456789ABCDEFp0
diff --git a/src/f32-dwconv/gen/up32x4-avx512f.c b/src/f32-dwconv/gen/up32x4-avx512f.c
index 126e982a5..dc5b4a88a 100644
--- a/src/f32-dwconv/gen/up32x4-avx512f.c
+++ b/src/f32-dwconv/gen/up32x4-avx512f.c
@@ -138,22 +138,22 @@ void xnn_f32_dwconv_ukernel_up32x4__avx512f(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
diff --git a/src/f32-dwconv/gen/up32x9-avx512f-acc2.c b/src/f32-dwconv/gen/up32x9-avx512f-acc2.c
index 9978c55f8..cec1e0790 100644
--- a/src/f32-dwconv/gen/up32x9-avx512f-acc2.c
+++ b/src/f32-dwconv/gen/up32x9-avx512f-acc2.c
@@ -228,42 +228,42 @@ void xnn_f32_dwconv_ukernel_up32x9__avx512f_acc2(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
__m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
- const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 160);
+ const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
- const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 192);
+ const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
- const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 224);
+ const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
- const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 256);
+ const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp1);
const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
- const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 288);
+ const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
// Add up all accumulators to vacc0123456789ABCDEFp0
diff --git a/src/f32-dwconv/gen/up32x9-avx512f.c b/src/f32-dwconv/gen/up32x9-avx512f.c
index 62fec2eaa..3cb060086 100644
--- a/src/f32-dwconv/gen/up32x9-avx512f.c
+++ b/src/f32-dwconv/gen/up32x9-avx512f.c
@@ -223,42 +223,42 @@ void xnn_f32_dwconv_ukernel_up32x9__avx512f(
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+ __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
- const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32);
+ const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
- const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64);
+ const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
- const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96);
+ const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
- const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128);
+ const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
- const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 160);
+ const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
- const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 192);
+ const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
- const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 224);
+ const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
- const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 256);
+ const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
- const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 288);
+ const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
diff --git a/src/f32-dwconv/up-avx512.c.in b/src/f32-dwconv/up-avx512.c.in
index 4d5a8db74..e9cff30bb 100644
--- a/src/f32-dwconv/up-avx512.c.in
+++ b/src/f32-dwconv/up-avx512.c.in
@@ -117,11 +117,11 @@ void xnn_f32_dwconv_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__avx512f${"" if ACC
// Prepare mask for valid 32-bit elements (depends on nc).
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
- __m512 vacc${ABC[0:16]}p0 = _mm512_load_ps(w);
+ __m512 vacc${ABC[0:16]}p0 = _mm512_maskz_loadu_ps(vmask, w);
$for K in range(KERNEL_TILE):
const __m512 vi${K}x${ABC[0:16]} = _mm512_maskz_loadu_ps(vmask, i${K});
- const __m512 vk${K}x${ABC[0:16]} = _mm512_load_ps(w + ${(K + 1) * CHANNEL_TILE});
+ const __m512 vk${K}x${ABC[0:16]} = _mm512_maskz_loadu_ps(vmask, w + ${(K + 1) * CHANNEL_TILE});
$if 1 <= K < ACCUMULATORS:
__m512 vacc${ABC[0:16]}p${K} = _mm512_mul_ps(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]});
$else:
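
Note: across the generated AVX-512 dwconv kernels above and this up-avx512.c.in template, the channel-remainder path now loads the bias and kernel weights with _mm512_maskz_loadu_ps(vmask, ...) instead of aligned _mm512_load_ps, mirroring how the input rows were already loaded, so lanes past the last valid channel are zeroed and the loads no longer assume fully padded, 64-byte-aligned weight memory. Below is a minimal standalone sketch of that masked remainder pattern; the function name and arguments are illustrative only and are not XNNPACK API.

    #include <immintrin.h>
    #include <stddef.h>
    #include <stdint.h>

    // One tap of the depthwise-conv channel remainder: c is the number of valid
    // channels left (1..15), i0 points at the input row, w at the packed weights
    // (bias followed by per-tap weights, 32 floats apart as in the up32 kernels).
    static inline __m512 dwconv_remainder_tap(size_t c, const float* i0, const float* w) {
      // Mask with the low c bits set: only those lanes are loaded, the rest are zeroed.
      const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
      __m512 vacc = _mm512_maskz_loadu_ps(vmask, w);            // bias, masked
      const __m512 vi0 = _mm512_maskz_loadu_ps(vmask, i0);      // input row, masked
      const __m512 vk0 = _mm512_maskz_loadu_ps(vmask, w + 32);  // weights, masked (previously _mm512_load_ps)
      return _mm512_fmadd_ps(vi0, vk0, vacc);                   // vacc += vi0 * vk0
    }
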
diff --git a/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in b/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in
index 4ff2dc336..8880dceb3 100644
--- a/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in
+++ b/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in
@@ -67,7 +67,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
# Prologue - loads for first group of 6 fma
@@ -261,6 +261,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma
# BLOCK 4
INS v19.d[1], x9
FMLA v20.4s, v17.4s, v1.s[1]
+ TST x0, 15
# BLOCK 5
FMLA v21.4s, v18.4s, v1.s[1]
@@ -269,11 +270,8 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma
FMLA v22.4s, v19.4s, v1.s[1]
# BLOCK 7
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 5f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 6f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
4:
# Clamp
@@ -290,12 +288,13 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma
ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14
SUB x3, x3, x2 // a0 -= kc
-
B.HI 0b
-
RET
5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder - 2 floats of A (8 bytes)
# Read first block of 1 A.
LDR d0, [x3], 8 // a0
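
Note: in the cortex-a53 GEMM/GEMMINC kernels the remainder dispatch after the main loop is restructured: the old 3: label with two TBNZ tests is dropped, a TST x0, 15 is folded into a spare slot of the last main-loop block, a single B.NE 5f routes any remainder to out-of-line code placed after the fast path's RET, and a TBZ inside block 5 then separates the 2-float from the 1-float remainder before both rejoin the clamp/store at 4:. A rough C sketch of the same decision logic follows, using hypothetical stand-ins for the labelled blocks (the real kernels are hand-written assembly).

    #include <stddef.h>
    #include <stdio.h>

    static void remainder_2_floats(void) { puts("remainder: 2 floats of A (8 bytes)"); }
    static void remainder_1_float(void)  { puts("remainder: 1 float of A (4 bytes)"); }
    static void clamp_and_store(void)    { puts("clamp + store"); }

    // k corresponds to x0 = kc - 16 after the prologue; kc is in bytes and a multiple of 4.
    static void epilogue(size_t k) {
      if ((k & 15) != 0) {            // TST x0, 15 ; B.NE 5f
        if (k & 8) {                  // 5: TBZ x0, 3, 6f
          remainder_2_floats();
        }
        if (k & 4) {                  // TBZ x0, 2, 4b, else fall into 6:
          remainder_1_float();
        }
      }
      clamp_and_store();              // 4:
    }

    int main(void) {
      epilogue(12);  // e.g. 12 leftover bytes: both the 2-float and 1-float remainders run
      return 0;
    }
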
diff --git a/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in b/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in
index a20699867..d0bb021ed 100644
--- a/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in
+++ b/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in
@@ -144,7 +144,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma
# Is there at least 4 floats (16 bytes)?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
SUBS x0, x0, 16
@@ -416,6 +416,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma
FMLA v27.4s, v18.4s, v3.s[1]
FMLA v30.4s, v18.4s, v3.s[3]
FMLA v22.4s, v19.4s, v2.s[1]
+ TST x0, 15
# BLOCK 7
FMLA v25.4s, v19.4s, v2.s[3]
@@ -423,11 +424,8 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma
ADD x5, x5, 96
FMLA v31.4s, v19.4s, v3.s[3]
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 5f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 6f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
4:
# Clamp
@@ -488,6 +486,9 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma
RET
5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder - 2 floats of A (8 bytes)
# Read first block of 4 A.
LDR d0, [x3], 8 // a0
diff --git a/src/f32-gemm/4x8-aarch32-neon-ld64.S b/src/f32-gemm/4x8-aarch32-neon-ld64.S
index cc711f9ae..72c2d5c8e 100644
--- a/src/f32-gemm/4x8-aarch32-neon-ld64.S
+++ b/src/f32-gemm/4x8-aarch32-neon-ld64.S
@@ -86,7 +86,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64
VMOV q13, q9
VMOV q14, q8
VMOV q15, q9
- BLO 3f // less than 2 channels?
+ BLO 8f // less than 2 channels?
// Main loop - 2 floats of A (8 bytes)
2:
@@ -116,7 +116,6 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64
VMLA.F32 q15, q7, d3[1]
BHS 2b
-3:
// Is there a remainder?- 1 floats of A (4 bytes)
TST r5, 4
BNE 8f
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-cortex-a53.S.in b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a53.S.in
index 37860d04a..2c5108b4c 100644
--- a/src/f32-gemm/4x8-aarch64-neonfma-cortex-a53.S.in
+++ b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a53.S.in
@@ -142,7 +142,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
# Prologue - First group loads, no FMA
LDR d0, [x3], 8 // a0
@@ -346,6 +346,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
FMLA v22.4s, v14.4s, v3.s[3]
FMLA v24.4s, v14.4s, v4.s[1]
FMLA v26.4s, v14.4s, v4.s[3]
+ TST x0, 15
// BLOCK 4
FMLA v21.4s, v15.4s, v3.s[1]
@@ -356,11 +357,9 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
// BLOCK 5
FMLA v27.4s, v15.4s, v4.s[3]
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 6f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 7f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
+
4:
# Clamp
FMIN v20.4s, v20.4s, v6.4s
@@ -410,8 +409,11 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
LDP d12, d13, [sp], 32
RET
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder- 2 floats of A (8 bytes)
-6:
LDR d0, [x3], 8
LDR q16, [x5], 16
LD1 {v0.d}[1], [x9], 8
@@ -441,7 +443,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
# Is there a remainder?- 1 floats of A (4 bytes)
TBZ x0, 2, 4b
-7:
+6:
# Remainder- 1 floats of A (4 bytes)
LDR s0, [x3], 4
LDR q16, [x5], 16
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in b/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in
index dba8f7d47..aab2dbdaf 100644
--- a/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in
+++ b/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in
@@ -88,7 +88,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
# Is there at least 4 floats (16 bytes)?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 2f
+ B.LO 5f
# Main loop - 4 floats of A (16 bytes)
1:
@@ -135,52 +135,10 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
FMLA v31.4s, v27.4s, v3.s[3]
B.HS 1b
- # Remainder- 2 floats of A (8 bytes)
-2:
- TBZ x0, 3, 3f
+ TST x0, 15
+ B.NE 5f
- LDR d0, [x3], 8
- LDP q20, q21, [x5], 32
- LDR d1, [x11], 8
- LDR d2, [x12], 8
- LDR d3, [x4], 8
- FMLA v16.4s, v20.4s, v0.s[0]
- FMLA v17.4s, v21.4s, v0.s[0]
- FMLA v18.4s, v20.4s, v1.s[0]
- FMLA v19.4s, v21.4s, v1.s[0]
- LDP q22, q23, [x5], 32
- FMLA v28.4s, v20.4s, v2.s[0]
- FMLA v29.4s, v21.4s, v2.s[0]
- FMLA v30.4s, v20.4s, v3.s[0]
- FMLA v31.4s, v21.4s, v3.s[0]
- FMLA v16.4s, v22.4s, v0.s[1]
- FMLA v17.4s, v23.4s, v0.s[1]
- FMLA v18.4s, v22.4s, v1.s[1]
- FMLA v19.4s, v23.4s, v1.s[1]
- FMLA v28.4s, v22.4s, v2.s[1]
- FMLA v29.4s, v23.4s, v2.s[1]
- FMLA v30.4s, v22.4s, v3.s[1]
- FMLA v31.4s, v23.4s, v3.s[1]
-
- # Remainder- 1 float of A (4 bytes)
-3:
- TBZ x0, 2, 6f
-
- LDR s0, [x3], 4
- LDP q20, q21, [x5], 32
- LDR s1, [x11], 4
- LDR s2, [x12], 4
- LDR s3, [x4], 4
- FMLA v16.4s, v20.4s, v0.s[0]
- FMLA v17.4s, v21.4s, v0.s[0]
- FMLA v18.4s, v20.4s, v1.s[0]
- FMLA v19.4s, v21.4s, v1.s[0]
- FMLA v28.4s, v20.4s, v2.s[0]
- FMLA v29.4s, v21.4s, v2.s[0]
- FMLA v30.4s, v20.4s, v3.s[0]
- FMLA v31.4s, v21.4s, v3.s[0]
-
-6:
+4:
# Clamp
FMIN v16.4s, v16.4s, v4.4s
SUBS x1, x1, 8
@@ -223,9 +181,58 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
SUB x4, x4, x2 // a3 -= kc
B.HI 0b
-
RET
+ # Remainder- 2 floats of A (8 bytes)
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x11], 8
+ LDR d2, [x12], 8
+ LDR d3, [x4], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+
+        # Is there a remainder?- 1 float of A (4 bytes)
+ TBZ x0, 2, 4b
+
+ # Remainder- 1 float of A (4 bytes)
+6:
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3, [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ B 4b
+
+
# Store odd width
7:
TBZ x1, 2, 8f
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in b/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in
index e2bedbcd8..8147e90c4 100644
--- a/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in
+++ b/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in
@@ -88,10 +88,9 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
# Is there at least 2 floats (8 bytes)?
SUBS x0, x2, 8 // k = kc - 8
- B.LO 2f
+ B.LO 5f
# Main loop - 2 floats of A (8 bytes)
-
1:
LDR d0, [x3], 8
LDP q20, q21, [x5], 32
@@ -117,25 +116,11 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
FMLA v30.4s, v22.4s, v3.s[1]
FMLA v31.4s, v23.4s, v3.s[1]
B.HS 1b
-2:
- # Remainder- 1 floats of A (4 bytes)
- TBZ x0, 2, 6f
- LDR s0, [x3], 4
- LDP q20, q21, [x5], 32
- LDR s1, [x11], 4
- LDR s2, [x12], 4
- LDR s3 , [x4], 4
- FMLA v16.4s, v20.4s, v0.s[0]
- FMLA v17.4s, v21.4s, v0.s[0]
- FMLA v18.4s, v20.4s, v1.s[0]
- FMLA v19.4s, v21.4s, v1.s[0]
- FMLA v28.4s, v20.4s, v2.s[0]
- FMLA v29.4s, v21.4s, v2.s[0]
- FMLA v30.4s, v20.4s, v3.s[0]
- FMLA v31.4s, v21.4s, v3.s[0]
+        # Is there a remainder?- 1 float of A (4 bytes)
+ TBNZ x0, 2, 5f
-6:
+4:
# Clamp
FMIN v16.4s, v16.4s, v4.4s
SUBS x1, x1, 8
@@ -181,6 +166,23 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_
RET
+ # Remainder- 1 float of A (4 bytes)
+5:
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3 , [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ B 4b
+
# Store odd width
7:
TBZ x1, 2, 8f
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-cortex-a53.S.in b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a53.S.in
index 9ac56da97..fa7056c90 100644
--- a/src/f32-gemm/6x8-aarch64-neonfma-cortex-a53.S.in
+++ b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a53.S.in
@@ -167,7 +167,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
# Prologue - First group loads, no FMA
LDR d0, [x3], 8 // a0
@@ -430,6 +430,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
FMLA v21.4s, v15.4s, v3.s[1]
FMLA v23.4s, v15.4s, v3.s[3]
FMLA v25.4s, v15.4s, v4.s[1]
+ TST x0, 15
// BLOCK 7
FMLA v27.4s, v15.4s, v4.s[3]
@@ -437,11 +438,8 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
FMLA v31.4s, v15.4s, v5.s[3]
ADD x5, x5, 64
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 6f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 7f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
4:
# Clamp
FMIN v20.4s, v20.4s, v6.4s
@@ -507,8 +505,11 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
LDP d12, d13, [sp], 32
RET
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder- 2 floats of A (8 bytes)
-6:
LDR d0, [x3], 8
LDR q16, [x5], 16
LD1 {v0.d}[1], [x9], 8
@@ -548,8 +549,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
# Is there a remainder?- 1 floats of A (4 bytes)
TBZ x0, 2, 4b
-
-7:
+6:
# Remainder- 1 floats of A (4 bytes)
LDR s0, [x3], 4
LDR q16, [x5], 16
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in b/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in
index b3e3f5526..50b74f330 100644
--- a/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in
+++ b/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in
@@ -149,7 +149,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
# Is there at least 4 floats (16 bytes)?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 2f
+ B.LO 5f
# Main loop - 4 floats of A (16 bytes)
# 48 FMA + 6 ld128 A + 4 LDP B
@@ -218,12 +218,11 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
FMLA v31.4s, v19.4s, v5.s[3]
B.HS 1b
-2:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 4f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 5f
-3:
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ TST x0, 15
+ B.NE 5f
+
+4:
# Clamp
FMIN v20.4s, v20.4s, v6.4s
SUBS x1, x1, 8
@@ -252,7 +251,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
FMAX v31.4s, v31.4s, v7.4s
# Store full 6 x 8
- B.LO 6f
+ B.LO 7f
$if INC:
ST1 {v30.16b, v31.16b}, [x7], x14
@@ -282,10 +281,12 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
SUB x4, x4, x2 // a5 -= kc
B.HI 0b
-
RET
-4:
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder- 2 floats of A (8 bytes)
LDR d0, [x3], 8
LDP q16, q17, [x5], 32
@@ -320,10 +321,12 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
FMLA v27.4s, v19.4s, v3.s[1]
FMLA v29.4s, v19.4s, v4.s[1]
FMLA v31.4s, v19.4s, v5.s[1]
- TBZ x0, 2, 3b
-5:
- # Remainder- 1 floats of A (4 bytes)
+        # Is there a remainder?- 1 float of A (4 bytes)
+ TBZ x0, 2, 4b
+
+ # Remainder- 1 float of A (4 bytes)
+6:
LDR s0, [x3], 4
LDP q16, q17, [x5], 32
LDR s1, [x9], 4
@@ -343,11 +346,11 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
FMLA v27.4s, v17.4s, v3.s[0]
FMLA v29.4s, v17.4s, v4.s[0]
FMLA v31.4s, v17.4s, v5.s[0]
- B 3b
+ B 4b
# Store odd width
-6:
- TBZ x1, 2, 7f
+7:
+ TBZ x1, 2, 8f
$if INC:
STR q30, [x7], 16
MOV v30.16b, v31.16b
@@ -375,8 +378,8 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
STR q30, [x7], 16
MOV v30.16b, v31.16b
-7:
- TBZ x1, 1, 8f
+8:
+ TBZ x1, 1, 9f
$if INC:
STR d30, [x7], 8
DUP d30, v30.d[1]
@@ -404,8 +407,8 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
STR d30, [x7], 8
DUP d30, v30.d[1]
-8:
- TBZ x1, 0, 9f
+9:
+ TBZ x1, 0, 10f
$if INC:
STR s30, [x7]
STR s28, [x13]
@@ -420,7 +423,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
STR s26, [x18]
STR s28, [x13]
STR s30, [x7]
-9:
+10:
RET
END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ld128
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in b/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in
index 64a1be3d4..daae4a001 100644
--- a/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in
+++ b/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in
@@ -149,7 +149,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
# Is there at least 2 floats (8 bytes) for main loop?
SUBS x0, x2, 8 // k = kc - 8
- B.LO 2f
+ B.LO 4f
# Main loop - 2 floats of A (8 bytes)
# 24 FMA + 6 LD64 A + 2 LDP B
@@ -190,7 +190,6 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
FMLA v31.4s, v19.4s, v5.s[1]
B.HS 1b
-2:
# Is there a remainder?- 1 floats of A (4 bytes)
TBNZ x0, 2, 4f
3:
@@ -252,7 +251,6 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_
SUB x4, x4, x2 // a5 -= kc
B.HI 0b
-
RET
4:
diff --git a/src/f32-gemm/gen-inc/1x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen-inc/1x12-aarch64-neonfma-cortex-a53.S
index fc9bf22a5..cb9cac776 100644
--- a/src/f32-gemm/gen-inc/1x12-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-gemm/gen-inc/1x12-aarch64-neonfma-cortex-a53.S
@@ -60,7 +60,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
# Prologue - loads for first group of 6 fma
@@ -254,6 +254,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53
# BLOCK 4
INS v19.d[1], x9
FMLA v20.4s, v17.4s, v1.s[1]
+ TST x0, 15
# BLOCK 5
FMLA v21.4s, v18.4s, v1.s[1]
@@ -262,11 +263,8 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53
FMLA v22.4s, v19.4s, v1.s[1]
# BLOCK 7
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 5f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 6f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
4:
# Clamp
@@ -283,12 +281,13 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53
ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14
SUB x3, x3, x2 // a0 -= kc
-
B.HI 0b
-
RET
5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder - 2 floats of A (8 bytes)
# Read first block of 1 A.
LDR d0, [x3], 8 // a0
diff --git a/src/f32-gemm/gen-inc/4x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen-inc/4x12-aarch64-neonfma-cortex-a53.S
index ba8e20400..897ad35fb 100644
--- a/src/f32-gemm/gen-inc/4x12-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-gemm/gen-inc/4x12-aarch64-neonfma-cortex-a53.S
@@ -114,7 +114,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53
# Is there at least 4 floats (16 bytes)?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
SUBS x0, x0, 16
@@ -386,6 +386,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53
FMLA v27.4s, v18.4s, v3.s[1]
FMLA v30.4s, v18.4s, v3.s[3]
FMLA v22.4s, v19.4s, v2.s[1]
+ TST x0, 15
# BLOCK 7
FMLA v25.4s, v19.4s, v2.s[3]
@@ -393,11 +394,8 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53
ADD x5, x5, 96
FMLA v31.4s, v19.4s, v3.s[3]
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 5f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 6f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
4:
# Clamp
@@ -448,6 +446,9 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53
RET
5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder - 2 floats of A (8 bytes)
# Read first block of 4 A.
LDR d0, [x3], 8 // a0
diff --git a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-cortex-a53.S
index 9b447aa95..b0ef94fa2 100644
--- a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-cortex-a53.S
@@ -117,7 +117,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a53
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
# Prologue - First group loads, no FMA
LDR d0, [x3], 8 // a0
@@ -321,6 +321,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a53
FMLA v22.4s, v14.4s, v3.s[3]
FMLA v24.4s, v14.4s, v4.s[1]
FMLA v26.4s, v14.4s, v4.s[3]
+ TST x0, 15
// BLOCK 4
FMLA v21.4s, v15.4s, v3.s[1]
@@ -331,11 +332,9 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a53
// BLOCK 5
FMLA v27.4s, v15.4s, v4.s[3]
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 6f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 7f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
+
4:
# Clamp
FMIN v20.4s, v20.4s, v6.4s
@@ -375,8 +374,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a53
LDP d12, d13, [sp], 32
RET
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder- 2 floats of A (8 bytes)
-6:
LDR d0, [x3], 8
LDR q16, [x5], 16
LD1 {v0.d}[1], [x9], 8
@@ -406,7 +408,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a53
# Is there a remainder?- 1 floats of A (4 bytes)
TBZ x0, 2, 4b
-7:
+6:
# Remainder- 1 floats of A (4 bytes)
LDR s0, [x3], 4
LDR q16, [x5], 16
diff --git a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld128.S b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld128.S
index 198884e76..87d7ddbc0 100644
--- a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld128.S
+++ b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld128.S
@@ -75,7 +75,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128
# Is there at least 4 floats (16 bytes)?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 2f
+ B.LO 5f
# Main loop - 4 floats of A (16 bytes)
1:
@@ -122,10 +122,50 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128
FMLA v31.4s, v27.4s, v3.s[3]
B.HS 1b
+ TST x0, 15
+ B.NE 5f
+
+4:
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ SUBS x1, x1, 8
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 8
+ B.LO 7f
+
+ ST1 {v30.16b, v31.16b}, [x7], x14
+ SUB x3, x3, x2 // a0 -= kc
+ ST1 {v28.16b, v29.16b}, [x10], x14
+ SUB x11, x11, x2 // a1 -= kc
+ ST1 {v18.16b, v19.16b}, [x9], x14
+ SUB x12, x12, x2 // a2 -= kc
+ ST1 {v16.16b, v17.16b}, [x6], x14
+ SUB x4, x4, x2 // a3 -= kc
+
+ B.HI 0b
+ RET
+
# Remainder- 2 floats of A (8 bytes)
-2:
- TBZ x0, 3, 3f
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+ # Remainder- 2 floats of A (8 bytes)
LDR d0, [x3], 8
LDP q20, q21, [x5], 32
LDR d1, [x11], 8
@@ -149,10 +189,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128
FMLA v30.4s, v22.4s, v3.s[1]
FMLA v31.4s, v23.4s, v3.s[1]
- # Remainder- 1 float of A (4 bytes)
-3:
- TBZ x0, 2, 6f
+        # Is there a remainder?- 1 float of A (4 bytes)
+ TBZ x0, 2, 4b
+ # Remainder- 1 float of A (4 bytes)
+6:
LDR s0, [x3], 4
LDP q20, q21, [x5], 32
LDR s1, [x11], 4
@@ -166,42 +207,8 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128
FMLA v29.4s, v21.4s, v2.s[0]
FMLA v30.4s, v20.4s, v3.s[0]
FMLA v31.4s, v21.4s, v3.s[0]
+ B 4b
-6:
- # Clamp
- FMIN v16.4s, v16.4s, v4.4s
- SUBS x1, x1, 8
- FMIN v17.4s, v17.4s, v4.4s
- FMIN v18.4s, v18.4s, v4.4s
- FMIN v19.4s, v19.4s, v4.4s
- FMIN v28.4s, v28.4s, v4.4s
- FMIN v29.4s, v29.4s, v4.4s
- FMIN v30.4s, v30.4s, v4.4s
- FMIN v31.4s, v31.4s, v4.4s
- FMAX v16.4s, v16.4s, v5.4s
- FMAX v17.4s, v17.4s, v5.4s
- FMAX v18.4s, v18.4s, v5.4s
- FMAX v19.4s, v19.4s, v5.4s
- FMAX v28.4s, v28.4s, v5.4s
- FMAX v29.4s, v29.4s, v5.4s
- FMAX v30.4s, v30.4s, v5.4s
- FMAX v31.4s, v31.4s, v5.4s
-
- # Store full 4 x 8
- B.LO 7f
-
- ST1 {v30.16b, v31.16b}, [x7], x14
- SUB x3, x3, x2 // a0 -= kc
- ST1 {v28.16b, v29.16b}, [x10], x14
- SUB x11, x11, x2 // a1 -= kc
- ST1 {v18.16b, v19.16b}, [x9], x14
- SUB x12, x12, x2 // a2 -= kc
- ST1 {v16.16b, v17.16b}, [x6], x14
- SUB x4, x4, x2 // a3 -= kc
-
- B.HI 0b
-
- RET
# Store odd width
7:
diff --git a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld64.S b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld64.S
index 0f7a645c7..6d146e1a7 100644
--- a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld64.S
+++ b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld64.S
@@ -75,10 +75,9 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld64
# Is there at least 2 floats (8 bytes)?
SUBS x0, x2, 8 // k = kc - 8
- B.LO 2f
+ B.LO 5f
# Main loop - 2 floats of A (8 bytes)
-
1:
LDR d0, [x3], 8
LDP q20, q21, [x5], 32
@@ -104,25 +103,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld64
FMLA v30.4s, v22.4s, v3.s[1]
FMLA v31.4s, v23.4s, v3.s[1]
B.HS 1b
-2:
- # Remainder- 1 floats of A (4 bytes)
- TBZ x0, 2, 6f
- LDR s0, [x3], 4
- LDP q20, q21, [x5], 32
- LDR s1, [x11], 4
- LDR s2, [x12], 4
- LDR s3 , [x4], 4
- FMLA v16.4s, v20.4s, v0.s[0]
- FMLA v17.4s, v21.4s, v0.s[0]
- FMLA v18.4s, v20.4s, v1.s[0]
- FMLA v19.4s, v21.4s, v1.s[0]
- FMLA v28.4s, v20.4s, v2.s[0]
- FMLA v29.4s, v21.4s, v2.s[0]
- FMLA v30.4s, v20.4s, v3.s[0]
- FMLA v31.4s, v21.4s, v3.s[0]
+        # Is there a remainder?- 1 float of A (4 bytes)
+ TBNZ x0, 2, 5f
-6:
+4:
# Clamp
FMIN v16.4s, v16.4s, v4.4s
SUBS x1, x1, 8
@@ -158,6 +143,23 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld64
RET
+ # Remainder- 1 float of A (4 bytes)
+5:
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3 , [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ B 4b
+
# Store odd width
7:
TBZ x1, 2, 8f
diff --git a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-cortex-a53.S
index 736ac6744..0b5cb6e8e 100644
--- a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-cortex-a53.S
@@ -134,7 +134,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a53
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
# Prologue - First group loads, no FMA
LDR d0, [x3], 8 // a0
@@ -397,6 +397,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a53
FMLA v21.4s, v15.4s, v3.s[1]
FMLA v23.4s, v15.4s, v3.s[3]
FMLA v25.4s, v15.4s, v4.s[1]
+ TST x0, 15
// BLOCK 7
FMLA v27.4s, v15.4s, v4.s[3]
@@ -404,11 +405,8 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a53
FMLA v31.4s, v15.4s, v5.s[3]
ADD x5, x5, 64
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 6f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 7f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
4:
# Clamp
FMIN v20.4s, v20.4s, v6.4s
@@ -460,8 +458,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a53
LDP d12, d13, [sp], 32
RET
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder- 2 floats of A (8 bytes)
-6:
LDR d0, [x3], 8
LDR q16, [x5], 16
LD1 {v0.d}[1], [x9], 8
@@ -501,8 +502,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a53
# Is there a remainder?- 1 floats of A (4 bytes)
TBZ x0, 2, 4b
-
-7:
+6:
# Remainder- 1 floats of A (4 bytes)
LDR s0, [x3], 4
LDR q16, [x5], 16
diff --git a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld128.S b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld128.S
index 559a71a35..2fdf4f153 100644
--- a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld128.S
+++ b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld128.S
@@ -122,7 +122,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
# Is there at least 4 floats (16 bytes)?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 2f
+ B.LO 5f
# Main loop - 4 floats of A (16 bytes)
# 48 FMA + 6 ld128 A + 4 LDP B
@@ -191,12 +191,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
FMLA v31.4s, v19.4s, v5.s[3]
B.HS 1b
-2:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 4f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 5f
-3:
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ TST x0, 15
+ B.NE 5f
+
+4:
# Clamp
FMIN v20.4s, v20.4s, v6.4s
SUBS x1, x1, 8
@@ -225,7 +224,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
FMAX v31.4s, v31.4s, v7.4s
# Store full 6 x 8
- B.LO 6f
+ B.LO 7f
ST1 {v30.16b, v31.16b}, [x7], x14
SUB x3, x3, x2 // a0 -= kc
@@ -241,10 +240,12 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
SUB x4, x4, x2 // a5 -= kc
B.HI 0b
-
RET
-4:
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder- 2 floats of A (8 bytes)
LDR d0, [x3], 8
LDP q16, q17, [x5], 32
@@ -279,10 +280,12 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
FMLA v27.4s, v19.4s, v3.s[1]
FMLA v29.4s, v19.4s, v4.s[1]
FMLA v31.4s, v19.4s, v5.s[1]
- TBZ x0, 2, 3b
-5:
- # Remainder- 1 floats of A (4 bytes)
+        # Is there a remainder?- 1 float of A (4 bytes)
+ TBZ x0, 2, 4b
+
+ # Remainder- 1 float of A (4 bytes)
+6:
LDR s0, [x3], 4
LDP q16, q17, [x5], 32
LDR s1, [x9], 4
@@ -302,11 +305,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
FMLA v27.4s, v17.4s, v3.s[0]
FMLA v29.4s, v17.4s, v4.s[0]
FMLA v31.4s, v17.4s, v5.s[0]
- B 3b
+ B 4b
# Store odd width
-6:
- TBZ x1, 2, 7f
+7:
+ TBZ x1, 2, 8f
STR q30, [x7], 16
MOV v30.16b, v31.16b
STR q28, [x13], 16
@@ -320,8 +323,8 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
STR q20, [x6], 16
MOV v20.16b, v21.16b
-7:
- TBZ x1, 1, 8f
+8:
+ TBZ x1, 1, 9f
STR d30, [x7], 8
DUP d30, v30.d[1]
STR d28, [x13], 8
@@ -335,15 +338,15 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
STR d20, [x6], 8
DUP d20, v20.d[1]
-8:
- TBZ x1, 0, 9f
+9:
+ TBZ x1, 0, 10f
STR s30, [x7]
STR s28, [x13]
STR s26, [x18]
STR s24, [x17]
STR s22, [x16]
STR s20, [x6]
-9:
+10:
RET
END_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
diff --git a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld64.S b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld64.S
index b66f5c70a..eaa68f7d7 100644
--- a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld64.S
+++ b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld64.S
@@ -122,7 +122,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld64
# Is there at least 2 floats (8 bytes) for main loop?
SUBS x0, x2, 8 // k = kc - 8
- B.LO 2f
+ B.LO 4f
# Main loop - 2 floats of A (8 bytes)
# 24 FMA + 6 LD64 A + 2 LDP B
@@ -163,7 +163,6 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld64
FMLA v31.4s, v19.4s, v5.s[1]
B.HS 1b
-2:
# Is there a remainder?- 1 floats of A (4 bytes)
TBNZ x0, 2, 4f
3:
@@ -211,7 +210,6 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld64
SUB x4, x4, x2 // a5 -= kc
B.HI 0b
-
RET
4:
diff --git a/src/f32-gemm/gen/1x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen/1x12-aarch64-neonfma-cortex-a53.S
index c6eb6f37e..f1730103d 100644
--- a/src/f32-gemm/gen/1x12-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-gemm/gen/1x12-aarch64-neonfma-cortex-a53.S
@@ -57,7 +57,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
# Prologue - loads for first group of 6 fma
@@ -251,6 +251,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53
# BLOCK 4
INS v19.d[1], x9
FMLA v20.4s, v17.4s, v1.s[1]
+ TST x0, 15
# BLOCK 5
FMLA v21.4s, v18.4s, v1.s[1]
@@ -259,11 +260,8 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53
FMLA v22.4s, v19.4s, v1.s[1]
# BLOCK 7
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 5f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 6f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
4:
# Clamp
@@ -280,12 +278,13 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53
ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14
SUB x3, x3, x2 // a0 -= kc
-
B.HI 0b
-
RET
5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder - 2 floats of A (8 bytes)
# Read first block of 1 A.
LDR d0, [x3], 8 // a0
diff --git a/src/f32-gemm/gen/4x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen/4x12-aarch64-neonfma-cortex-a53.S
index 15669a9a7..257ab4bb0 100644
--- a/src/f32-gemm/gen/4x12-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-gemm/gen/4x12-aarch64-neonfma-cortex-a53.S
@@ -117,7 +117,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53
# Is there at least 4 floats (16 bytes)?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
SUBS x0, x0, 16
@@ -389,6 +389,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53
FMLA v27.4s, v18.4s, v3.s[1]
FMLA v30.4s, v18.4s, v3.s[3]
FMLA v22.4s, v19.4s, v2.s[1]
+ TST x0, 15
# BLOCK 7
FMLA v25.4s, v19.4s, v2.s[3]
@@ -396,11 +397,8 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53
ADD x5, x5, 96
FMLA v31.4s, v19.4s, v3.s[3]
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 5f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 6f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
4:
# Clamp
@@ -451,6 +449,9 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53
RET
5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder - 2 floats of A (8 bytes)
# Read first block of 4 A.
LDR d0, [x3], 8 // a0
diff --git a/src/f32-gemm/gen/4x8-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen/4x8-aarch64-neonfma-cortex-a53.S
index a0674a2c0..6b689e902 100644
--- a/src/f32-gemm/gen/4x8-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-gemm/gen/4x8-aarch64-neonfma-cortex-a53.S
@@ -119,7 +119,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
# Prologue - First group loads, no FMA
LDR d0, [x3], 8 // a0
@@ -323,6 +323,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53
FMLA v22.4s, v14.4s, v3.s[3]
FMLA v24.4s, v14.4s, v4.s[1]
FMLA v26.4s, v14.4s, v4.s[3]
+ TST x0, 15
// BLOCK 4
FMLA v21.4s, v15.4s, v3.s[1]
@@ -333,11 +334,9 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53
// BLOCK 5
FMLA v27.4s, v15.4s, v4.s[3]
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 6f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 7f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
+
4:
# Clamp
FMIN v20.4s, v20.4s, v6.4s
@@ -377,8 +376,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53
LDP d12, d13, [sp], 32
RET
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder- 2 floats of A (8 bytes)
-6:
LDR d0, [x3], 8
LDR q16, [x5], 16
LD1 {v0.d}[1], [x9], 8
@@ -408,7 +410,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53
# Is there a remainder?- 1 floats of A (4 bytes)
TBZ x0, 2, 4b
-7:
+6:
# Remainder- 1 floats of A (4 bytes)
LDR s0, [x3], 4
LDR q16, [x5], 16
diff --git a/src/f32-gemm/gen/4x8-aarch64-neonfma-ld128.S b/src/f32-gemm/gen/4x8-aarch64-neonfma-ld128.S
index c2f496745..aa7e7a7e6 100644
--- a/src/f32-gemm/gen/4x8-aarch64-neonfma-ld128.S
+++ b/src/f32-gemm/gen/4x8-aarch64-neonfma-ld128.S
@@ -75,7 +75,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128
# Is there at least 4 floats (16 bytes)?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 2f
+ B.LO 5f
# Main loop - 4 floats of A (16 bytes)
1:
@@ -122,10 +122,50 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128
FMLA v31.4s, v27.4s, v3.s[3]
B.HS 1b
+ TST x0, 15
+ B.NE 5f
+
+4:
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ SUBS x1, x1, 8
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 8
+ B.LO 7f
+
+ ST1 {v16.16b, v17.16b}, [x6], x14
+ SUB x3, x3, x2 // a0 -= kc
+ ST1 {v18.16b, v19.16b}, [x9], x14
+ SUB x11, x11, x2 // a1 -= kc
+ ST1 {v28.16b, v29.16b}, [x10], x14
+ SUB x12, x12, x2 // a2 -= kc
+ ST1 {v30.16b, v31.16b}, [x7], x14
+ SUB x4, x4, x2 // a3 -= kc
+
+ B.HI 0b
+ RET
+
# Remainder- 2 floats of A (8 bytes)
-2:
- TBZ x0, 3, 3f
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+ # Remainder- 2 floats of A (8 bytes)
LDR d0, [x3], 8
LDP q20, q21, [x5], 32
LDR d1, [x11], 8
@@ -149,10 +189,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128
FMLA v30.4s, v22.4s, v3.s[1]
FMLA v31.4s, v23.4s, v3.s[1]
- # Remainder- 1 float of A (4 bytes)
-3:
- TBZ x0, 2, 6f
+        # Is there a remainder?- 1 float of A (4 bytes)
+ TBZ x0, 2, 4b
+ # Remainder- 1 float of A (4 bytes)
+6:
LDR s0, [x3], 4
LDP q20, q21, [x5], 32
LDR s1, [x11], 4
@@ -166,42 +207,8 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128
FMLA v29.4s, v21.4s, v2.s[0]
FMLA v30.4s, v20.4s, v3.s[0]
FMLA v31.4s, v21.4s, v3.s[0]
+ B 4b
-6:
- # Clamp
- FMIN v16.4s, v16.4s, v4.4s
- SUBS x1, x1, 8
- FMIN v17.4s, v17.4s, v4.4s
- FMIN v18.4s, v18.4s, v4.4s
- FMIN v19.4s, v19.4s, v4.4s
- FMIN v28.4s, v28.4s, v4.4s
- FMIN v29.4s, v29.4s, v4.4s
- FMIN v30.4s, v30.4s, v4.4s
- FMIN v31.4s, v31.4s, v4.4s
- FMAX v16.4s, v16.4s, v5.4s
- FMAX v17.4s, v17.4s, v5.4s
- FMAX v18.4s, v18.4s, v5.4s
- FMAX v19.4s, v19.4s, v5.4s
- FMAX v28.4s, v28.4s, v5.4s
- FMAX v29.4s, v29.4s, v5.4s
- FMAX v30.4s, v30.4s, v5.4s
- FMAX v31.4s, v31.4s, v5.4s
-
- # Store full 4 x 8
- B.LO 7f
-
- ST1 {v16.16b, v17.16b}, [x6], x14
- SUB x3, x3, x2 // a0 -= kc
- ST1 {v18.16b, v19.16b}, [x9], x14
- SUB x11, x11, x2 // a1 -= kc
- ST1 {v28.16b, v29.16b}, [x10], x14
- SUB x12, x12, x2 // a2 -= kc
- ST1 {v30.16b, v31.16b}, [x7], x14
- SUB x4, x4, x2 // a3 -= kc
-
- B.HI 0b
-
- RET
# Store odd width
7:
diff --git a/src/f32-gemm/gen/4x8-aarch64-neonfma-ld64.S b/src/f32-gemm/gen/4x8-aarch64-neonfma-ld64.S
index 1770a0290..548450f64 100644
--- a/src/f32-gemm/gen/4x8-aarch64-neonfma-ld64.S
+++ b/src/f32-gemm/gen/4x8-aarch64-neonfma-ld64.S
@@ -75,10 +75,9 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64
# Is there at least 2 floats (8 bytes)?
SUBS x0, x2, 8 // k = kc - 8
- B.LO 2f
+ B.LO 5f
# Main loop - 2 floats of A (8 bytes)
-
1:
LDR d0, [x3], 8
LDP q20, q21, [x5], 32
@@ -104,25 +103,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64
FMLA v30.4s, v22.4s, v3.s[1]
FMLA v31.4s, v23.4s, v3.s[1]
B.HS 1b
-2:
- # Remainder- 1 floats of A (4 bytes)
- TBZ x0, 2, 6f
- LDR s0, [x3], 4
- LDP q20, q21, [x5], 32
- LDR s1, [x11], 4
- LDR s2, [x12], 4
- LDR s3 , [x4], 4
- FMLA v16.4s, v20.4s, v0.s[0]
- FMLA v17.4s, v21.4s, v0.s[0]
- FMLA v18.4s, v20.4s, v1.s[0]
- FMLA v19.4s, v21.4s, v1.s[0]
- FMLA v28.4s, v20.4s, v2.s[0]
- FMLA v29.4s, v21.4s, v2.s[0]
- FMLA v30.4s, v20.4s, v3.s[0]
- FMLA v31.4s, v21.4s, v3.s[0]
+	# Is there a remainder?- 1 float of A (4 bytes)
+ TBNZ x0, 2, 5f
-6:
+4:
# Clamp
FMIN v16.4s, v16.4s, v4.4s
SUBS x1, x1, 8
@@ -158,6 +143,23 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64
RET
+ # Remainder- 1 float of A (4 bytes)
+5:
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3 , [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ B 4b
+
# Store odd width
7:
TBZ x1, 2, 8f
diff --git a/src/f32-gemm/gen/6x8-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen/6x8-aarch64-neonfma-cortex-a53.S
index 3851823a6..1e80a6010 100644
--- a/src/f32-gemm/gen/6x8-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-gemm/gen/6x8-aarch64-neonfma-cortex-a53.S
@@ -138,7 +138,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 3f
+ B.LO 5f
# Prologue - First group loads, no FMA
LDR d0, [x3], 8 // a0
@@ -401,6 +401,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53
FMLA v21.4s, v15.4s, v3.s[1]
FMLA v23.4s, v15.4s, v3.s[3]
FMLA v25.4s, v15.4s, v4.s[1]
+ TST x0, 15
// BLOCK 7
FMLA v27.4s, v15.4s, v4.s[3]
@@ -408,11 +409,8 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53
FMLA v31.4s, v15.4s, v5.s[3]
ADD x5, x5, 64
-3:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 6f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 7f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
4:
# Clamp
FMIN v20.4s, v20.4s, v6.4s
@@ -464,8 +462,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53
LDP d12, d13, [sp], 32
RET
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder- 2 floats of A (8 bytes)
-6:
LDR d0, [x3], 8
LDR q16, [x5], 16
LD1 {v0.d}[1], [x9], 8
@@ -505,8 +506,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53
# Is there a remainder?- 1 floats of A (4 bytes)
TBZ x0, 2, 4b
-
-7:
+6:
# Remainder- 1 floats of A (4 bytes)
LDR s0, [x3], 4
LDR q16, [x5], 16
diff --git a/src/f32-gemm/gen/6x8-aarch64-neonfma-ld128.S b/src/f32-gemm/gen/6x8-aarch64-neonfma-ld128.S
index 48ac7af57..7909ba23f 100644
--- a/src/f32-gemm/gen/6x8-aarch64-neonfma-ld128.S
+++ b/src/f32-gemm/gen/6x8-aarch64-neonfma-ld128.S
@@ -126,7 +126,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
# Is there at least 4 floats (16 bytes)?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 2f
+ B.LO 5f
# Main loop - 4 floats of A (16 bytes)
# 48 FMA + 6 ld128 A + 4 LDP B
@@ -195,12 +195,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
FMLA v31.4s, v19.4s, v5.s[3]
B.HS 1b
-2:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 4f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 5f
-3:
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ TST x0, 15
+ B.NE 5f
+
+4:
# Clamp
FMIN v20.4s, v20.4s, v6.4s
SUBS x1, x1, 8
@@ -229,7 +228,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
FMAX v31.4s, v31.4s, v7.4s
# Store full 6 x 8
- B.LO 6f
+ B.LO 7f
ST1 {v20.16b, v21.16b}, [x6], x14
SUB x3, x3, x2 // a0 -= kc
@@ -245,10 +244,12 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
SUB x4, x4, x2 // a5 -= kc
B.HI 0b
-
RET
-4:
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
# Remainder- 2 floats of A (8 bytes)
LDR d0, [x3], 8
LDP q16, q17, [x5], 32
@@ -283,10 +284,12 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
FMLA v27.4s, v19.4s, v3.s[1]
FMLA v29.4s, v19.4s, v4.s[1]
FMLA v31.4s, v19.4s, v5.s[1]
- TBZ x0, 2, 3b
-5:
- # Remainder- 1 floats of A (4 bytes)
+	# Is there a remainder?- 1 float of A (4 bytes)
+ TBZ x0, 2, 4b
+
+ # Remainder- 1 float of A (4 bytes)
+6:
LDR s0, [x3], 4
LDP q16, q17, [x5], 32
LDR s1, [x9], 4
@@ -306,11 +309,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
FMLA v27.4s, v17.4s, v3.s[0]
FMLA v29.4s, v17.4s, v4.s[0]
FMLA v31.4s, v17.4s, v5.s[0]
- B 3b
+ B 4b
# Store odd width
-6:
- TBZ x1, 2, 7f
+7:
+ TBZ x1, 2, 8f
STR q20, [x6], 16
MOV v20.16b, v21.16b
STR q22, [x16], 16
@@ -324,8 +327,8 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
STR q30, [x7], 16
MOV v30.16b, v31.16b
-7:
- TBZ x1, 1, 8f
+8:
+ TBZ x1, 1, 9f
STR d20, [x6], 8
DUP d20, v20.d[1]
STR d22, [x16], 8
@@ -339,15 +342,15 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
STR d30, [x7], 8
DUP d30, v30.d[1]
-8:
- TBZ x1, 0, 9f
+9:
+ TBZ x1, 0, 10f
STR s20, [x6]
STR s22, [x16]
STR s24, [x17]
STR s26, [x18]
STR s28, [x13]
STR s30, [x7]
-9:
+10:
RET
END_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
diff --git a/src/f32-gemm/gen/6x8-aarch64-neonfma-ld64.S b/src/f32-gemm/gen/6x8-aarch64-neonfma-ld64.S
index 4f869950b..d9460005f 100644
--- a/src/f32-gemm/gen/6x8-aarch64-neonfma-ld64.S
+++ b/src/f32-gemm/gen/6x8-aarch64-neonfma-ld64.S
@@ -126,7 +126,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64
# Is there at least 2 floats (8 bytes) for main loop?
SUBS x0, x2, 8 // k = kc - 8
- B.LO 2f
+ B.LO 4f
# Main loop - 2 floats of A (8 bytes)
# 24 FMA + 6 LD64 A + 2 LDP B
@@ -167,7 +167,6 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64
FMLA v31.4s, v19.4s, v5.s[1]
B.HS 1b
-2:
# Is there a remainder?- 1 floats of A (4 bytes)
TBNZ x0, 2, 4f
3:
@@ -215,7 +214,6 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64
SUB x4, x4, x2 // a5 -= kc
B.HI 0b
-
RET
4:
diff --git a/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S b/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S
index 1f3edf73c..5f24466fc 100644
--- a/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S
@@ -71,7 +71,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 4f
+ B.LO 5f
# Prologue - loads for first group of 6 fma
@@ -267,6 +267,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53
# BLOCK 4
INS v19.d[1], x7
FMLA v20.4s, v17.4s, v1.s[1]
+ TST x0, 15
# BLOCK 5
FMLA v21.4s, v18.4s, v1.s[1]
@@ -275,14 +276,10 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53
FMLA v22.4s, v19.4s, v1.s[1]
# BLOCK 7
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
4:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 6f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 7f
-
-5:
# ks loop
SUBS x9, x9, 8 // ks -= MR * sizeof(void*)
B.NE 1b
@@ -304,12 +301,13 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53
# nc loop
B.HI 0b
-
RET
-6:
- # Remainder - 2 floats of A (8 bytes)
- # Read first block of 1 A.
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
LDR d0, [x8], 8 // a0
LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48
LD1 {v5.16b, v6.16b, v7.16b}, [x5], 48
@@ -324,8 +322,8 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53
FMLA v21.4s, v6.4s, v0.s[1]
FMLA v22.4s, v7.4s, v0.s[1]
- TBZ x0, 2, 5b
-7:
+ TBZ x0, 2, 4b
+6:
# Remainder - 1 float of A (4 bytes)
LDR s0, [x8], 4 // a0
LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48
@@ -333,7 +331,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53
FMLA v20.4s, v2.4s, v0.s[0]
FMLA v21.4s, v3.4s, v0.s[0]
FMLA v22.4s, v4.4s, v0.s[0]
- B 5b
+ B 4b
8:
ADD x1, x1, 12
diff --git a/src/f32-igemm/1x8-aarch64-neonfma-cortex-a53.S b/src/f32-igemm/1x8-aarch64-neonfma-cortex-a53.S
index f43e79e37..8d6e51d09 100644
--- a/src/f32-igemm/1x8-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-igemm/1x8-aarch64-neonfma-cortex-a53.S
@@ -63,7 +63,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53
# Is there at least 8 floats (32 bytes) for prologue + epilogue?
SUBS x0, x2, 32 // k = kc - 32 // k = kc
- B.LO 4f
+ B.LO 5f
# 16 prologue
# Read first block of A and B.
@@ -148,18 +148,13 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53
FMLA v19.4s, v23.4s, v1.s[1]
FMLA v16.4s, v24.4s, v1.s[2]
FMLA v17.4s, v25.4s, v1.s[2]
+ TST x0, 31
FMLA v18.4s, v26.4s, v1.s[3]
FMLA v19.4s, v27.4s, v1.s[3]
+ # Is there a remainder?- 4 floats of A (16 bytes) or less
+ B.NE 5f
4:
- # Is there a remainder?- 4 floats of A (16 bytes)
- TBNZ x0, 4, 6f
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 7f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 9f
-
-5:
# ks loop
SUBS x9, x9, 8 // ks -= MR * sizeof(void*)
B.NE 1b
@@ -185,7 +180,10 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53
RET
-6:
+5:
+	# Is there a remainder?- 4 floats of A (16 bytes)
+ TBZ x0, 4, 6f
+
# Remainder- 4 floats of A (16 bytes)
LDR q20, [x5], 16
LDR q21, [x5], 16
@@ -205,8 +203,8 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53
FMLA v18.4s, v26.4s, v0.s[3]
FMLA v19.4s, v27.4s, v0.s[3]
- TBZ x0, 3, 8f
-7:
+6:
+ TBZ x0, 3, 7f
# Remainder- 2 floats of A (8 bytes)
LDR q20, [x5], 16
LDR q21, [x5], 16
@@ -217,16 +215,15 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53
LDR q23, [x5], 16
FMLA v18.4s, v22.4s, v0.s[1]
FMLA v19.4s, v23.4s, v0.s[1]
-8:
- TBZ x0, 2, 5b
-9:
+7:
+ TBZ x0, 2, 4b
# Remainder- 1 float of A (4 bytes)
LDR q20, [x5], 16
LDR q21, [x5], 16
LDR s0, [x8], 4
FMLA v16.4s, v20.4s, v0.s[0]
FMLA v17.4s, v21.4s, v0.s[0]
- B 5b
+ B 4b
10:
# Store odd channels
diff --git a/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S b/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S
index a25dce0af..e0675a8cf 100644
--- a/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S
@@ -136,7 +136,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53
PRFM PLDL1KEEP, [x15, 64]
PRFM PLDL1KEEP, [x16, 0]
PRFM PLDL1KEEP, [x16, 64]
- B.LO 4f
+ B.LO 5f
SUBS x0, x0, 16 // 4 floats for main loop
@@ -408,19 +408,18 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53
FMLA v27.4s, v18.4s, v3.s[1]
FMLA v30.4s, v18.4s, v3.s[3]
FMLA v22.4s, v19.4s, v2.s[1]
+ TST x0, 15
# BLOCK 7
FMLA v25.4s, v19.4s, v2.s[3]
FMLA v28.4s, v19.4s, v3.s[1]
ADD x5, x5, 96
FMLA v31.4s, v19.4s, v3.s[3]
-4:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 6f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 7f
-5:
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
+
+4:
# ks loop
SUBS x9, x9, 32 // ks -= MR * sizeof(void*)
B.NE 1b
@@ -470,9 +469,11 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53
LDP d8, d9, [sp], 48
RET
-6:
- # Remainder - 2 floats of A (8 bytes)
- # Read first block of 4 A.
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
LDR d0, [x13], 8 // a0
LD1 {v6.16b, v7.16b, v8.16b}, [x5], 48
LDR d1, [x14], 8 // a1
@@ -508,9 +509,10 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53
FMLA v28.4s, v11.4s, v2.s[1]
FMLA v31.4s, v11.4s, v3.s[1]
- TBZ x0, 2, 5b
-7:
- # Remainder - 1 float of A (4 bytes)
+	# Is there a remainder?- 1 float of A (4 bytes)
+ TBZ x0, 2, 4b
+6:
+	# Remainder- 1 float of A (4 bytes)
LDR s0, [x13], 4 // a0
LD1 {v6.16b, v7.16b, v8.16b}, [x5], 48
LDR s1, [x14], 4 // a1
@@ -529,7 +531,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53
FMLA v25.4s, v8.4s, v1.s[0]
FMLA v28.4s, v8.4s, v2.s[0]
FMLA v31.4s, v8.4s, v3.s[0]
- B 5b
+ B 4b
8:
ADD x1, x1, 12
diff --git a/src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in b/src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in
new file mode 100644
index 000000000..d862d608a
--- /dev/null
+++ b/src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in
@@ -0,0 +1,393 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+.syntax unified
+
+// void xnn_f32_igemm_ukernel_4x8__aarch32_neon_${"pld_" if PREFETCH else ""}cortex_a75(
+// size_t mr, r0
+// size_t nc, r1
+// size_t kc, r2 -> r5 -> sp + 64
+// size_t ks, r3 -> sp + 68 -> r14
+// const float**restrict a, sp + 108 -> r2
+// const void*restrict w, sp + 112 -> r9
+// uint8_t*restrict c, sp + 116 -> r11
+// size_t cm_stride, sp + 120 -> (r6)
+// size_t cn_stride, sp + 124 -> (r7)
+// size_t a_offset, sp + 128 -> (r5)
+// const float* zero, sp + 132 -> (r7)
+// output_params*params, sp + 136 -> (r5)
+
+// inner loop registers
+
+// A0 r3 d0
+// A1 r12 d1
+// A2 r10 d2
+// A3 r0 d3
+
+// B r9 d8, d9, d10, d11
+// B d12, d13, d14, d15
+
+// C0 r11 d16-d17 q8 d18-d19 q9
+// C1 r4 d20-d21 q10 d22-d23 q11
+// C2 r8 d24-d25 q12 d26-d27 q13
+// C3 r6 d28-d29 q14 d30-d31 q15
+
+// Clamp (r5) d4 d5 d6 d7
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_${"pld_" if PREFETCH else ""}cortex_a75
+ .arm
+#ifndef __APPLE__
+ .arch armv7-a
+ .fpu neon
+#endif
+ // Push 108 bytes
+ // r2 will be reloaded in outer loop. r3 is ks
+ PUSH {r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r14} // +44
+ VPUSH {d8-d15} // +64 = 108
+
+ LDR r2, [sp, 108] // a
+ LDR r9, [sp, 112] // w
+ LDR r11, [sp, 116] // c
+ LDR r6, [sp, 120] // cm_stride
+ LDR r14, [sp, 68] // p = ks
+
+ // Clamp C pointers
+ CMP r0, 2 // if mr >= 2
+ ADD r4, r11, r6 // c1 = c0 + cm_stride
+ MOVLO r4, r11 // c1
+ // if mr > 2
+ ADD r8, r4, r6 // c2 = c1 + cm_stride
+ MOVLS r8, r4 // c2
+ CMP r0, 4 // if mr >=4
+ ADD r6, r8, r6 // c3 = c2 + cm_stride
+ MOVLO r6, r8 // c3
+
+ .p2align 3
+0:
+ # Load initial bias from w into accumulators
+ VLDM r9!, {d16-d19} // Bias
+ VMOV q10, q8
+ VMOV q11, q9
+ VMOV q12, q8
+ VMOV q13, q9
+ VMOV q14, q8
+ VMOV q15, q9
+
+ $if PREFETCH:
+ PLD [r9, 0] // Prefetch B
+ PLD [r9, 64]
+ PLD [r9, 128]
+ PLD [r9, 192]
+ PLD [r9, 256]
+ PLD [r9, 320]
+
+1:
+ # Load next 4 A pointers
+ LDR r3, [r2, 0]
+ LDR r12, [r2, 4]
+ LDR r10, [r2, 8]
+ LDR r0, [r2, 12]
+ ADD r2, r2, 16
+
+ // Add a_offset
+ LDR r5, [sp, 128] // a_offset
+ LDR r7, [sp, 132] // zero
+ CMP r3, r7 // if a0 == zero
+ ADD r3, r3, r5 // a0 += a_offset
+ MOVEQ r3, r7 // a0 = zero, else += a0 + a_offset
+ CMP r12, r7 // if a1 == zero
+ ADD r12, r12, r5 // a1 += a_offset
+ MOVEQ r12, r7 // a1 = zero, else += a1 + a_offset
+ CMP r10, r7 // if a2 == zero
+ ADD r10, r10, r5 // a2 += a_offset
+ MOVEQ r10, r7 // a2 = zero, else += a2 + a_offset
+ CMP r0, r7 // if a3 == zero
+ ADD r0, r0, r5 // a3 += a_offset
+ LDR r5, [sp, 64] // kc
+ MOVEQ r0, r7 // a3 = zero, else += a3 + a_offset
+
+ $if PREFETCH:
+ PLD [r3, 0] // Prefetch A
+ PLD [r3, 64]
+ PLD [r12, 0]
+ PLD [r12, 64]
+ PLD [r10, 0]
+ PLD [r10, 64]
+ PLD [r0, 0]
+ PLD [r0, 64]
+
+ SUBS r5, r5, 16 // kc - 16
+ BLO 4f // less than 4 channels?
+
+ // Prologue
+ VLD1.32 {d0}, [r3]! // A0
+ VLDM r9!, {d8-d11} // B0
+ VLD1.32 {d1}, [r12]! // A1
+ VLD1.32 {d2}, [r10]! // A2
+ VLD1.32 {d3}, [ r0]! // A3
+
+ SUBS r5, r5, 16
+ BLO 3f // less than 4 channels? skip main loop
+
+ .p2align 3
+
+ // Main loop - 4 floats of A (16 bytes)
+2:
+ VMLA.F32 q8, q4, d0[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q12, q4, d2[0]
+ VLD1.32 {d4}, [r3]! // A0
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q9, q5, d0[0]
+ VLD1.32 {d5}, [r12]! // A1
+ VMLA.F32 q11, q5, d1[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q15, q5, d3[0]
+ VLD1.32 {d6}, [r10]! // A2
+ VMLA.F32 q8, q6, d0[1]
+ VMLA.F32 q10, q6, d1[1]
+ VLD1.32 {d7}, [ r0]! // A3
+ VMLA.F32 q12, q6, d2[1]
+ VMLA.F32 q14, q6, d3[1]
+ VLDM r9!, {d8-d11} // B0
+ VMLA.F32 q9, q7, d0[1]
+ VMLA.F32 q11, q7, d1[1]
+ VMLA.F32 q13, q7, d2[1]
+ VMLA.F32 q15, q7, d3[1]
+
+ VMLA.F32 q8, q4, d4[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d5[0]
+ $if PREFETCH:
+ PLD [r3, 128] // Prefetch A0
+ VMLA.F32 q12, q4, d6[0]
+ VLD1.32 {d0}, [r3]! // A0
+ VMLA.F32 q14, q4, d7[0]
+ $if PREFETCH:
+ PLD [r12, 128] // Prefetch A1
+ VMLA.F32 q9, q5, d4[0]
+ VLD1.32 {d1}, [r12]! // A1
+ VMLA.F32 q11, q5, d5[0]
+ $if PREFETCH:
+ PLD [r10, 128] // Prefetch A2
+ VMLA.F32 q13, q5, d6[0]
+ VLD1.32 {d2}, [r10]! // A2
+ VMLA.F32 q15, q5, d7[0]
+ $if PREFETCH:
+ PLD [r0, 128] // Prefetch A3
+ VMLA.F32 q8, q6, d4[1]
+ VLD1.32 {d3}, [ r0]! // A3
+ VMLA.F32 q10, q6, d5[1]
+ $if PREFETCH:
+ PLD [r9, 384] // Prefetch B
+ VMLA.F32 q12, q6, d6[1]
+ $if PREFETCH:
+ PLD [r9, 448] // Prefetch B
+ VMLA.F32 q14, q6, d7[1]
+ VLDM r9!, {d8-d11} // B0
+ VMLA.F32 q9, q7, d4[1]
+ VMLA.F32 q11, q7, d5[1]
+ SUBS r5, r5, 16
+ VMLA.F32 q13, q7, d6[1]
+ VMLA.F32 q15, q7, d7[1]
+ BHS 2b
+
+ // Epilogue
+3:
+ VMLA.F32 q8, q4, d0[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q12, q4, d2[0]
+ VLD1.32 {d4}, [r3]! // A0
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q9, q5, d0[0]
+ VLD1.32 {d5}, [r12]! // A1
+ VMLA.F32 q11, q5, d1[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q15, q5, d3[0]
+ VLD1.32 {d6}, [r10]! // A2
+ VMLA.F32 q8, q6, d0[1]
+ VMLA.F32 q10, q6, d1[1]
+ VLD1.32 {d7}, [ r0]! // A3
+ VMLA.F32 q12, q6, d2[1]
+ VMLA.F32 q14, q6, d3[1]
+ VLDM r9!, {d8-d11} // B0
+ VMLA.F32 q9, q7, d0[1]
+ VMLA.F32 q11, q7, d1[1]
+ VMLA.F32 q13, q7, d2[1]
+ VMLA.F32 q15, q7, d3[1]
+
+ VMLA.F32 q8, q4, d4[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d5[0]
+ VMLA.F32 q12, q4, d6[0]
+ VMLA.F32 q14, q4, d7[0]
+ VMLA.F32 q9, q5, d4[0]
+ VMLA.F32 q11, q5, d5[0]
+ VMLA.F32 q13, q5, d6[0]
+ VMLA.F32 q15, q5, d7[0]
+ VMLA.F32 q8, q6, d4[1]
+ VMLA.F32 q10, q6, d5[1]
+ VMLA.F32 q12, q6, d6[1]
+ VMLA.F32 q14, q6, d7[1]
+ VMLA.F32 q9, q7, d4[1]
+ VMLA.F32 q11, q7, d5[1]
+ VMLA.F32 q13, q7, d6[1]
+ VMLA.F32 q15, q7, d7[1]
+
+4:
+ // Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes)
+ TST r5, 12
+ BNE 7f
+
+ .p2align 3
+5:
+ # ks loop
+ SUBS r14, r14, 16 // ks -= MR * sizeof(void*)
+ BNE 1b
+
+ // Load params pointer
+ LDR r5, [sp, 136] // clamping_params
+ LDR r7, [sp, 124] // cn_stride
+ LDR r14, [sp, 68] // p = ks
+
+ // Load clamping_params values
+ VLD1.32 {d4[],d5[]}, [r5]!
+ SUBS r1, r1, 8
+ VLD1.32 {d6[],d7[]}, [r5]
+
+ // Clamp
+ VMIN.F32 q8, q8, q2
+ VMIN.F32 q9, q9, q2
+ VMIN.F32 q10, q10, q2
+ VMIN.F32 q11, q11, q2
+ VMIN.F32 q12, q12, q2
+ VMIN.F32 q13, q13, q2
+ VMIN.F32 q14, q14, q2
+ VMIN.F32 q15, q15, q2
+ VMAX.F32 q8, q8, q3
+ VMAX.F32 q9, q9, q3
+ VMAX.F32 q10, q10, q3
+ VMAX.F32 q11, q11, q3
+ VMAX.F32 q12, q12, q3
+ VMAX.F32 q13, q13, q3
+ VMAX.F32 q14, q14, q3
+ VMAX.F32 q15, q15, q3
+
+ // Store full 4 x 8
+ BLO 10f
+ VST1.32 {d28-d31}, [r6], r7
+ VST1.32 {d24-d27}, [r8], r7
+ VST1.32 {d20-d23}, [r4], r7
+ VST1.32 {d16-d19}, [r11], r7
+ SUB r2, r2, r14 // a -= ks
+ BHI 0b
+
+6:
+ VPOP {d8-d15}
+ ADD sp, sp, 8 // skip r2, r3
+ POP {r4, r5, r6, r7, r8, r9, r10, r11, pc}
+
+ .p2align 3
+7:
+ // Is there a remainder?- 2 floats of A (8 bytes)
+ TST r5, 8
+ BEQ 8f
+
+ // Remainder - 2 floats of A (8 bytes)
+ VLD1.32 {d0}, [r3]! // A0
+ VLDM r9!, {d8-d11} // B0
+ VLD1.32 {d1}, [r12]! // A1
+ VLD1.32 {d2}, [r10]! // A2
+ VLD1.32 {d3}, [ r0]! // A3
+
+ VMLA.F32 q8, q4, d0[0]
+ VMLA.F32 q9, q5, d0[0]
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q11, q5, d1[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q12, q4, d2[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q15, q5, d3[0]
+ VMLA.F32 q8, q6, d0[1]
+ VMLA.F32 q9, q7, d0[1]
+ VMLA.F32 q10, q6, d1[1]
+ VMLA.F32 q11, q7, d1[1]
+ VMLA.F32 q12, q6, d2[1]
+ VMLA.F32 q13, q7, d2[1]
+ VMLA.F32 q14, q6, d3[1]
+ VMLA.F32 q15, q7, d3[1]
+8:
+        // Is there a remainder?- 1 float of A (4 bytes)
+ TST r5, 4
+ BEQ 5b
+
+9:
+        // Remainder- 1 float of A (4 bytes)
+ VLDM r3!, {s0} // A0
+ VLDM r9!, {d8-d11} // B0
+ VLDM r12!, {s2} // A1
+ VLDM r10!, {s4} // A2
+ VLDM r0!, {s6} // A3
+ VMLA.F32 q8, q4, d0[0]
+ VMLA.F32 q9, q5, d0[0]
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q11, q5, d1[0]
+ VMLA.F32 q12, q4, d2[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q15, q5, d3[0]
+ B 5b
+
+ // Store odd width
+10:
+ TST r1, 4
+ BEQ 11f
+ VST1.32 {d28-d29}, [r6]!
+ VMOV q14, q15
+ VST1.32 {d24-d25}, [r8]!
+ VMOV q12, q13
+ VST1.32 {d20-d21}, [r4]!
+ VMOV q10, q11
+ VST1.32 {d16-d17}, [r11]!
+ VMOV q8, q9
+
+11:
+ TST r1, 2
+ BEQ 12f
+ VST1.32 {d28}, [r6]!
+ VMOV d28, d29
+ VST1.32 {d24}, [r8]!
+ VMOV d24, d25
+ VST1.32 {d20}, [r4]!
+ VMOV d20, d21
+ VST1.32 {d16}, [r11]!
+ VMOV d16, d17
+
+12:
+ TST r1, 1
+ BEQ 13f
+ VST1.32 {d28[0]}, [r6]!
+ VST1.32 {d24[0]}, [r8]!
+ VST1.32 {d20[0]}, [r4]!
+ VST1.32 {d16[0]}, [r11]!
+
+13:
+ VPOP {d8-d15}
+ ADD sp, sp, 8 // skip r2, r3
+ POP {r4, r5, r6, r7, r8, r9, r10, r11, pc}
+
+END_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_${"pld_" if PREFETCH else ""}cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
+
+
+
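
In the new aarch32 IGEMM kernels, each A pointer loaded from the indirection buffer is adjusted with the CMP/ADD/MOVEQ sequence above: real rows are advanced by a_offset, while pointers equal to the shared zero buffer (padding rows) are left pointing at it. A minimal C equivalent of that fixup is sketched below; the function name is illustrative and not an XNNPACK API.

#include <stdint.h>
#include <stddef.h>

/* Sketch of the per-pointer fixup done for each of the MR rows read from the
   indirection buffer. */
const float* fixup_a_pointer(const float* a, const float* zero, size_t a_offset) {
  if (a == zero) {
    return zero;  /* padding row: keep reading from the zero buffer */
  }
  return (const float*) ((uintptr_t) a + a_offset);  /* real row: apply byte offset */
}
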
diff --git a/src/f32-igemm/4x8-aarch32-neon-ld64.S b/src/f32-igemm/4x8-aarch32-neon-ld64.S
new file mode 100644
index 000000000..3e0a77f8d
--- /dev/null
+++ b/src/f32-igemm/4x8-aarch32-neon-ld64.S
@@ -0,0 +1,248 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+.syntax unified
+
+// void xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64(
+// size_t mr, r0
+// size_t nc, r1
+// size_t kc, r2 -> r5 -> sp + 68
+// size_t ks, r3 -> sp + 72 -> r14
+// const float**restrict a, sp + 112 -> r2
+// const void*restrict w, sp + 116 -> r9
+// uint8_t*restrict c, sp + 120 -> r11
+// size_t cm_stride, sp + 124 -> (r6)
+// size_t cn_stride, sp + 128 -> (r7)
+// size_t a_offset, sp + 132 -> (r5)
+// const float* zero, sp + 136 -> (r7)
+// output_params*params, sp + 140 -> (r5)
+
+// inner loop registers
+
+// A0 r3 d0
+// A1 r12 d1
+// A2 r10 d2
+// A3 r0 d3
+
+// B r9 d8, d9, d10, d11
+// B d12, d13, d14, d15
+
+// C0 r11 d16-d17 q8 d18-d19 q9
+// C1 r4 d20-d21 q10 d22-d23 q11
+// C2 r8 d24-d25 q12 d26-d27 q13
+// C3 r6 d28-d29 q14 d30-d31 q15
+
+// Clamp (r5) d4 d5 d6 d7
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64
+ .arm
+#ifndef __APPLE__
+ .arch armv7-a
+ .fpu neon
+#endif
+ // Push 112 bytes
+ // r2 will be reloaded in outer loop. r3 is ks
+ PUSH {r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r14} // +44
+ SUB sp, sp, 4 // 4
+ VPUSH {d8-d15} // +64 = 112
+
+ MOV r14, r3 // p = ks
+ LDR r2, [sp, 112] // a
+ LDR r9, [sp, 116] // w
+ LDR r11, [sp, 120] // c
+ LDR r6, [sp, 124] // cm_stride
+ LDR r5, [sp, 140] // clamping_params
+
+ // Clamp C pointers
+ CMP r0, 2 // if mr >= 2
+ ADD r4, r11, r6 // c1 = c0 + cm_stride
+ MOVLO r4, r11 // c1
+ // if mr > 2
+ ADD r8, r4, r6 // c2 = c1 + cm_stride
+ MOVLS r8, r4 // c2
+ CMP r0, 4 // if mr >=4
+ ADD r6, r8, r6 // c3 = c2 + cm_stride
+ MOVLO r6, r8 // c3
+
+ // Load clamping_params values
+ VLD1.32 {d4[], d5[]}, [r5]!
+ VLD1.32 {d6[], d7[]}, [r5]
+
+0:
+ # Load initial bias from w into accumulators
+ VLDM r9!, {d16-d19} // Bias
+ VMOV q10, q8
+ VMOV q11, q9
+ VMOV q12, q8
+ VMOV q13, q9
+ VMOV q14, q8
+ VMOV q15, q9
+
+1:
+ # Load next 4 A pointers
+ LDR r3, [r2, 0]
+ LDR r12, [r2, 4]
+ LDR r10, [r2, 8]
+ LDR r0, [r2, 12]
+ ADD r2, r2, 16
+
+ // Add a_offset
+ LDR r5, [sp, 132] // a_offset
+ LDR r7, [sp, 136] // zero
+ CMP r3, r7 // if a0 == zero
+ ADD r3, r3, r5 // a0 += a_offset
+ MOVEQ r3, r7 // a0 = zero, else += a0 + a_offset
+ CMP r12, r7 // if a1 == zero
+ ADD r12, r12, r5 // a1 += a_offset
+ MOVEQ r12, r7 // a1 = zero, else += a1 + a_offset
+ CMP r10, r7 // if a2 == zero
+ ADD r10, r10, r5 // a2 += a_offset
+ MOVEQ r10, r7 // a2 = zero, else += a2 + a_offset
+ CMP r0, r7 // if a3 == zero
+ ADD r0, r0, r5 // a3 += a_offset
+ LDR r5, [sp, 68] // kc
+ MOVEQ r0, r7 // a3 = zero, else += a3 + a_offset
+
+ SUBS r5, r5, 8 // kc - 8
+ BLO 8f // less than 2 channels?
+
+ // Main loop - 2 floats of A (8 bytes)
+2:
+ VLD1.32 {d0}, [r3]! // A0
+ VLDM r9!, {d8-d11} // B0
+ VLD1.32 {d1}, [r12]! // A1
+ VLD1.32 {d2}, [r10]! // A2
+ VLD1.32 {d3}, [ r0]! // A3
+
+ VMLA.F32 q8, q4, d0[0]
+ VMLA.F32 q9, q5, d0[0]
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q11, q5, d1[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q12, q4, d2[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q15, q5, d3[0]
+ VMLA.F32 q8, q6, d0[1]
+ VMLA.F32 q9, q7, d0[1]
+ VMLA.F32 q10, q6, d1[1]
+ VMLA.F32 q11, q7, d1[1]
+ VMLA.F32 q12, q6, d2[1]
+ VMLA.F32 q13, q7, d2[1]
+ SUBS r5, r5, 8
+ VMLA.F32 q14, q6, d3[1]
+ VMLA.F32 q15, q7, d3[1]
+ BHS 2b
+
+        // Is there a remainder?- 1 float of A (4 bytes)
+ TST r5, 4
+ BNE 8f
+
+4:
+ # ks loop
+ SUBS r14, r14, 16 // ks -= MR * sizeof(void*)
+ BNE 1b
+
+ LDR r7, [sp, 128] // cn_stride
+ LDR r14, [sp, 72] // p = ks
+
+ // Clamp
+ VMIN.F32 q8, q8, q2
+ SUBS r1, r1, 8
+ VMIN.F32 q9, q9, q2
+ VMIN.F32 q10, q10, q2
+ VMIN.F32 q11, q11, q2
+ VMIN.F32 q12, q12, q2
+ VMIN.F32 q13, q13, q2
+ VMIN.F32 q14, q14, q2
+ VMIN.F32 q15, q15, q2
+ VMAX.F32 q8, q8, q3
+ VMAX.F32 q9, q9, q3
+ VMAX.F32 q10, q10, q3
+ VMAX.F32 q11, q11, q3
+ VMAX.F32 q12, q12, q3
+ VMAX.F32 q13, q13, q3
+ VMAX.F32 q14, q14, q3
+ VMAX.F32 q15, q15, q3
+
+ // Store full 4 x 8
+ BLO 10f
+ VST1.32 {d28-d31}, [r6], r7
+ VST1.32 {d24-d27}, [r8], r7
+ VST1.32 {d20-d23}, [r4], r7
+ VST1.32 {d16-d19}, [r11], r7
+ SUB r2, r2, r14 // a -= ks
+ BHI 0b
+
+6:
+ VPOP {d8-d15}
+ ADD sp, sp, 12 // skip pad, r2, r3
+ POP {r4, r5, r6, r7, r8, r9, r10, r11, pc}
+
+8:
+        // Remainder- 1 float of A (4 bytes)
+ VLDM r3!, {s0} // A0
+ VLDM r9!, {d8-d11} // B0
+ VLDM r12!, {s2} // A1
+ VLDM r10!, {s4} // A2
+ VLDM r0!, {s6} // A3
+ VMLA.F32 q8, q4, d0[0]
+ VMLA.F32 q9, q5, d0[0]
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q11, q5, d1[0]
+ VMLA.F32 q12, q4, d2[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q15, q5, d3[0]
+ B 4b
+
+ // Store odd width
+10:
+ TST r1, 4
+ BEQ 11f
+ VST1.32 {d28-d29}, [r6]!
+ VMOV q14, q15
+ VST1.32 {d24-d25}, [r8]!
+ VMOV q12, q13
+ VST1.32 {d20-d21}, [r4]!
+ VMOV q10, q11
+ VST1.32 {d16-d17}, [r11]!
+ VMOV q8, q9
+
+11:
+ TST r1, 2
+ BEQ 12f
+ VST1.32 {d28}, [r6]!
+ VMOV d28, d29
+ VST1.32 {d24}, [r8]!
+ VMOV d24, d25
+ VST1.32 {d20}, [r4]!
+ VMOV d20, d21
+ VST1.32 {d16}, [r11]!
+ VMOV d16, d17
+
+12:
+ TST r1, 1
+ BEQ 13f
+ VST1.32 {d28[0]}, [r6]!
+ VST1.32 {d24[0]}, [r8]!
+ VST1.32 {d20[0]}, [r4]!
+ VST1.32 {d16[0]}, [r11]!
+
+13:
+ VPOP {d8-d15}
+ ADD sp, sp, 12 // skip pad, r2, r3
+ POP {r4, r5, r6, r7, r8, r9, r10, r11, pc}
+
+END_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
+
+
+
diff --git a/src/f32-igemm/6x8-aarch64-neonfma-cortex-a53.S b/src/f32-igemm/6x8-aarch64-neonfma-cortex-a53.S
index c4ea8a041..3094b16b0 100644
--- a/src/f32-igemm/6x8-aarch64-neonfma-cortex-a53.S
+++ b/src/f32-igemm/6x8-aarch64-neonfma-cortex-a53.S
@@ -145,7 +145,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53
# Is there at least 4 floats (16 bytes) for prologue + epilogue?
SUBS x0, x2, 16 // k = kc - 16
- B.LO 4f
+ B.LO 5f
# Prologue - First group loads, no FMA
LDR d0, [x14], 8 // a0
@@ -408,6 +408,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53
FMLA v21.4s, v15.4s, v3.s[1]
FMLA v23.4s, v15.4s, v3.s[3]
FMLA v25.4s, v15.4s, v4.s[1]
+ TST x0, 15
// BLOCK 7
FMLA v27.4s, v15.4s, v4.s[3]
@@ -415,13 +416,10 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53
FMLA v31.4s, v15.4s, v5.s[3]
ADD x5, x5, 64
-4:
- # Is there a remainder?- 2 floats of A (8 bytes)
- TBNZ x0, 3, 6f
- # Is there a remainder?- 1 floats of A (4 bytes)
- TBNZ x0, 2, 7f
+ # Is there a remainder?- 2 floats of A (8 bytes) or less
+ B.NE 5f
-5:
+4:
# ks loop
SUBS x9, x9, 48 // ks -= MR * sizeof(void*)
B.NE 1b
@@ -482,9 +480,11 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53
LDP d12, d13, [sp], 80
RET
- # Remainder - 2 floats of A (8 bytes)
- # 24 FMA + 6 LD64 A + 2 LDP B
-6:
+5:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
LDR d0, [x14], 8
LDR q16, [x5], 16
LD1 {v0.d}[1], [x15], 8
@@ -522,8 +522,8 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53
FMLA v31.4s, v19.4s, v2.s[3]
# Is there a remainder?- 1 floats of A (4 bytes)
- TBZ x0, 2, 5b
-7:
+ TBZ x0, 2, 4b
+6:
# Remainder- 1 floats of A (4 bytes)
LDR s0, [x14], 4
LDR q16, [x5], 16
@@ -546,7 +546,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53
FMLA v27.4s, v17.4s, v1.s[2]
FMLA v29.4s, v17.4s, v2.s[0]
FMLA v31.4s, v17.4s, v2.s[2]
- B 5b
+ B 4b
# Store odd width
8:
diff --git a/src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S b/src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S
new file mode 100644
index 000000000..0e2f07fba
--- /dev/null
+++ b/src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S
@@ -0,0 +1,369 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+.syntax unified
+
+// void xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75(
+// size_t mr, r0
+// size_t nc, r1
+// size_t kc, r2 -> r5 -> sp + 64
+// size_t ks, r3 -> sp + 68 -> r14
+// const float**restrict a, sp + 108 -> r2
+// const void*restrict w, sp + 112 -> r9
+// uint8_t*restrict c, sp + 116 -> r11
+// size_t cm_stride, sp + 120 -> (r6)
+// size_t cn_stride, sp + 124 -> (r7)
+// size_t a_offset, sp + 128 -> (r5)
+// const float* zero, sp + 132 -> (r7)
+// output_params*params, sp + 136 -> (r5)
+
+// inner loop registers
+
+// A0 r3 d0
+// A1 r12 d1
+// A2 r10 d2
+// A3 r0 d3
+
+// B r9 d8, d9, d10, d11
+// B d12, d13, d14, d15
+
+// C0 r11 d16-d17 q8 d18-d19 q9
+// C1 r4 d20-d21 q10 d22-d23 q11
+// C2 r8 d24-d25 q12 d26-d27 q13
+// C3 r6 d28-d29 q14 d30-d31 q15
+
+// Clamp (r5) d4 d5 d6 d7
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75
+ .arm
+#ifndef __APPLE__
+ .arch armv7-a
+ .fpu neon
+#endif
+ // Push 108 bytes
+ // r2 will be reloaded in outer loop. r3 is ks
+ PUSH {r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r14} // +44
+ VPUSH {d8-d15} // +64 = 108
+
+ LDR r2, [sp, 108] // a
+ LDR r9, [sp, 112] // w
+ LDR r11, [sp, 116] // c
+ LDR r6, [sp, 120] // cm_stride
+ LDR r14, [sp, 68] // p = ks
+
+ // Clamp C pointers
+ CMP r0, 2 // if mr >= 2
+ ADD r4, r11, r6 // c1 = c0 + cm_stride
+ MOVLO r4, r11 // c1
+ // if mr > 2
+ ADD r8, r4, r6 // c2 = c1 + cm_stride
+ MOVLS r8, r4 // c2
+ CMP r0, 4 // if mr >=4
+ ADD r6, r8, r6 // c3 = c2 + cm_stride
+ MOVLO r6, r8 // c3
+
+ .p2align 3
+0:
+ # Load initial bias from w into accumulators
+ VLDM r9!, {d16-d19} // Bias
+ VMOV q10, q8
+ VMOV q11, q9
+ VMOV q12, q8
+ VMOV q13, q9
+ VMOV q14, q8
+ VMOV q15, q9
+
+
+1:
+ # Load next 4 A pointers
+ LDR r3, [r2, 0]
+ LDR r12, [r2, 4]
+ LDR r10, [r2, 8]
+ LDR r0, [r2, 12]
+ ADD r2, r2, 16
+
+ // Add a_offset
+ LDR r5, [sp, 128] // a_offset
+ LDR r7, [sp, 132] // zero
+ CMP r3, r7 // if a0 == zero
+ ADD r3, r3, r5 // a0 += a_offset
+ MOVEQ r3, r7 // a0 = zero, else += a0 + a_offset
+ CMP r12, r7 // if a1 == zero
+ ADD r12, r12, r5 // a1 += a_offset
+ MOVEQ r12, r7 // a1 = zero, else += a1 + a_offset
+ CMP r10, r7 // if a2 == zero
+ ADD r10, r10, r5 // a2 += a_offset
+ MOVEQ r10, r7 // a2 = zero, else += a2 + a_offset
+ CMP r0, r7 // if a3 == zero
+ ADD r0, r0, r5 // a3 += a_offset
+ LDR r5, [sp, 64] // kc
+ MOVEQ r0, r7 // a3 = zero, else += a3 + a_offset
+
+
+ SUBS r5, r5, 16 // kc - 16
+ BLO 4f // less than 4 channels?
+
+ // Prologue
+ VLD1.32 {d0}, [r3]! // A0
+ VLDM r9!, {d8-d11} // B0
+ VLD1.32 {d1}, [r12]! // A1
+ VLD1.32 {d2}, [r10]! // A2
+ VLD1.32 {d3}, [ r0]! // A3
+
+ SUBS r5, r5, 16
+ BLO 3f // less than 4 channels? skip main loop
+
+ .p2align 3
+
+ // Main loop - 4 floats of A (16 bytes)
+2:
+ VMLA.F32 q8, q4, d0[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q12, q4, d2[0]
+ VLD1.32 {d4}, [r3]! // A0
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q9, q5, d0[0]
+ VLD1.32 {d5}, [r12]! // A1
+ VMLA.F32 q11, q5, d1[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q15, q5, d3[0]
+ VLD1.32 {d6}, [r10]! // A2
+ VMLA.F32 q8, q6, d0[1]
+ VMLA.F32 q10, q6, d1[1]
+ VLD1.32 {d7}, [ r0]! // A3
+ VMLA.F32 q12, q6, d2[1]
+ VMLA.F32 q14, q6, d3[1]
+ VLDM r9!, {d8-d11} // B0
+ VMLA.F32 q9, q7, d0[1]
+ VMLA.F32 q11, q7, d1[1]
+ VMLA.F32 q13, q7, d2[1]
+ VMLA.F32 q15, q7, d3[1]
+
+ VMLA.F32 q8, q4, d4[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d5[0]
+ VMLA.F32 q12, q4, d6[0]
+ VLD1.32 {d0}, [r3]! // A0
+ VMLA.F32 q14, q4, d7[0]
+ VMLA.F32 q9, q5, d4[0]
+ VLD1.32 {d1}, [r12]! // A1
+ VMLA.F32 q11, q5, d5[0]
+ VMLA.F32 q13, q5, d6[0]
+ VLD1.32 {d2}, [r10]! // A2
+ VMLA.F32 q15, q5, d7[0]
+ VMLA.F32 q8, q6, d4[1]
+ VLD1.32 {d3}, [ r0]! // A3
+ VMLA.F32 q10, q6, d5[1]
+ VMLA.F32 q12, q6, d6[1]
+ VMLA.F32 q14, q6, d7[1]
+ VLDM r9!, {d8-d11} // B0
+ VMLA.F32 q9, q7, d4[1]
+ VMLA.F32 q11, q7, d5[1]
+ SUBS r5, r5, 16
+ VMLA.F32 q13, q7, d6[1]
+ VMLA.F32 q15, q7, d7[1]
+ BHS 2b
+
+ // Epilogue
+3:
+ VMLA.F32 q8, q4, d0[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q12, q4, d2[0]
+ VLD1.32 {d4}, [r3]! // A0
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q9, q5, d0[0]
+ VLD1.32 {d5}, [r12]! // A1
+ VMLA.F32 q11, q5, d1[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q15, q5, d3[0]
+ VLD1.32 {d6}, [r10]! // A2
+ VMLA.F32 q8, q6, d0[1]
+ VMLA.F32 q10, q6, d1[1]
+ VLD1.32 {d7}, [ r0]! // A3
+ VMLA.F32 q12, q6, d2[1]
+ VMLA.F32 q14, q6, d3[1]
+ VLDM r9!, {d8-d11} // B0
+ VMLA.F32 q9, q7, d0[1]
+ VMLA.F32 q11, q7, d1[1]
+ VMLA.F32 q13, q7, d2[1]
+ VMLA.F32 q15, q7, d3[1]
+
+ VMLA.F32 q8, q4, d4[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d5[0]
+ VMLA.F32 q12, q4, d6[0]
+ VMLA.F32 q14, q4, d7[0]
+ VMLA.F32 q9, q5, d4[0]
+ VMLA.F32 q11, q5, d5[0]
+ VMLA.F32 q13, q5, d6[0]
+ VMLA.F32 q15, q5, d7[0]
+ VMLA.F32 q8, q6, d4[1]
+ VMLA.F32 q10, q6, d5[1]
+ VMLA.F32 q12, q6, d6[1]
+ VMLA.F32 q14, q6, d7[1]
+ VMLA.F32 q9, q7, d4[1]
+ VMLA.F32 q11, q7, d5[1]
+ VMLA.F32 q13, q7, d6[1]
+ VMLA.F32 q15, q7, d7[1]
+
+4:
+ // Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes)
+ TST r5, 12
+ BNE 7f
+
+ .p2align 3
+5:
+ # ks loop
+ SUBS r14, r14, 16 // ks -= MR * sizeof(void*)
+ BNE 1b
+
+ // Load params pointer
+ LDR r5, [sp, 136] // clamping_params
+ LDR r7, [sp, 124] // cn_stride
+ LDR r14, [sp, 68] // p = ks
+
+ // Load clamping_params values
+ VLD1.32 {d4[],d5[]}, [r5]!
+ SUBS r1, r1, 8
+ VLD1.32 {d6[],d7[]}, [r5]
+
+ // Clamp
+ VMIN.F32 q8, q8, q2
+ VMIN.F32 q9, q9, q2
+ VMIN.F32 q10, q10, q2
+ VMIN.F32 q11, q11, q2
+ VMIN.F32 q12, q12, q2
+ VMIN.F32 q13, q13, q2
+ VMIN.F32 q14, q14, q2
+ VMIN.F32 q15, q15, q2
+ VMAX.F32 q8, q8, q3
+ VMAX.F32 q9, q9, q3
+ VMAX.F32 q10, q10, q3
+ VMAX.F32 q11, q11, q3
+ VMAX.F32 q12, q12, q3
+ VMAX.F32 q13, q13, q3
+ VMAX.F32 q14, q14, q3
+ VMAX.F32 q15, q15, q3
+
+ // Store full 4 x 8
+ BLO 10f
+ VST1.32 {d28-d31}, [r6], r7
+ VST1.32 {d24-d27}, [r8], r7
+ VST1.32 {d20-d23}, [r4], r7
+ VST1.32 {d16-d19}, [r11], r7
+ SUB r2, r2, r14 // a -= ks
+ BHI 0b
+
+6:
+ VPOP {d8-d15}
+ ADD sp, sp, 8 // skip r2, r3
+ POP {r4, r5, r6, r7, r8, r9, r10, r11, pc}
+
+ .p2align 3
+7:
+ // Is there a remainder?- 2 floats of A (8 bytes)
+ TST r5, 8
+ BEQ 8f
+
+ // Remainder - 2 floats of A (8 bytes)
+ VLD1.32 {d0}, [r3]! // A0
+ VLDM r9!, {d8-d11} // B0
+ VLD1.32 {d1}, [r12]! // A1
+ VLD1.32 {d2}, [r10]! // A2
+ VLD1.32 {d3}, [ r0]! // A3
+
+ VMLA.F32 q8, q4, d0[0]
+ VMLA.F32 q9, q5, d0[0]
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q11, q5, d1[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q12, q4, d2[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q15, q5, d3[0]
+ VMLA.F32 q8, q6, d0[1]
+ VMLA.F32 q9, q7, d0[1]
+ VMLA.F32 q10, q6, d1[1]
+ VMLA.F32 q11, q7, d1[1]
+ VMLA.F32 q12, q6, d2[1]
+ VMLA.F32 q13, q7, d2[1]
+ VMLA.F32 q14, q6, d3[1]
+ VMLA.F32 q15, q7, d3[1]
+8:
+        // Is there a remainder?- 1 float of A (4 bytes)
+ TST r5, 4
+ BEQ 5b
+
+9:
+        // Remainder- 1 float of A (4 bytes)
+ VLDM r3!, {s0} // A0
+ VLDM r9!, {d8-d11} // B0
+ VLDM r12!, {s2} // A1
+ VLDM r10!, {s4} // A2
+ VLDM r0!, {s6} // A3
+ VMLA.F32 q8, q4, d0[0]
+ VMLA.F32 q9, q5, d0[0]
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q11, q5, d1[0]
+ VMLA.F32 q12, q4, d2[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q15, q5, d3[0]
+ B 5b
+
+ // Store odd width
+10:
+ TST r1, 4
+ BEQ 11f
+ VST1.32 {d28-d29}, [r6]!
+ VMOV q14, q15
+ VST1.32 {d24-d25}, [r8]!
+ VMOV q12, q13
+ VST1.32 {d20-d21}, [r4]!
+ VMOV q10, q11
+ VST1.32 {d16-d17}, [r11]!
+ VMOV q8, q9
+
+11:
+ TST r1, 2
+ BEQ 12f
+ VST1.32 {d28}, [r6]!
+ VMOV d28, d29
+ VST1.32 {d24}, [r8]!
+ VMOV d24, d25
+ VST1.32 {d20}, [r4]!
+ VMOV d20, d21
+ VST1.32 {d16}, [r11]!
+ VMOV d16, d17
+
+12:
+ TST r1, 1
+ BEQ 13f
+ VST1.32 {d28[0]}, [r6]!
+ VST1.32 {d24[0]}, [r8]!
+ VST1.32 {d20[0]}, [r4]!
+ VST1.32 {d16[0]}, [r11]!
+
+13:
+ VPOP {d8-d15}
+ ADD sp, sp, 8 // skip r2, r3
+ POP {r4, r5, r6, r7, r8, r9, r10, r11, pc}
+
+END_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
+
+
+
diff --git a/src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S b/src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S
new file mode 100644
index 000000000..0051bad2f
--- /dev/null
+++ b/src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S
@@ -0,0 +1,389 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+.syntax unified
+
+// void xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75(
+// size_t mr, r0
+// size_t nc, r1
+// size_t kc, r2 -> r5 -> sp + 64
+// size_t ks, r3 -> sp + 68 -> r14
+// const float**restrict a, sp + 108 -> r2
+// const void*restrict w, sp + 112 -> r9
+// uint8_t*restrict c, sp + 116 -> r11
+// size_t cm_stride, sp + 120 -> (r6)
+// size_t cn_stride, sp + 124 -> (r7)
+// size_t a_offset, sp + 128 -> (r5)
+// const float* zero, sp + 132 -> (r7)
+// output_params*params, sp + 136 -> (r5)
+
+// inner loop registers
+
+// A0 r3 d0
+// A1 r12 d1
+// A2 r10 d2
+// A3 r0 d3
+
+// B r9 d8, d9, d10, d11
+// B d12, d13, d14, d15
+
+// C0 r11 d16-d17 q8 d18-d19 q9
+// C1 r4 d20-d21 q10 d22-d23 q11
+// C2 r8 d24-d25 q12 d26-d27 q13
+// C3 r6 d28-d29 q14 d30-d31 q15
+
+// Clamp (r5) d4 d5 d6 d7
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75
+ .arm
+#ifndef __APPLE__
+ .arch armv7-a
+ .fpu neon
+#endif
+ // Push 108 bytes
+ // r2 will be reloaded in outer loop. r3 is ks
+ PUSH {r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r14} // +44
+ VPUSH {d8-d15} // +64 = 108
+
+ LDR r2, [sp, 108] // a
+ LDR r9, [sp, 112] // w
+ LDR r11, [sp, 116] // c
+ LDR r6, [sp, 120] // cm_stride
+ LDR r14, [sp, 68] // p = ks
+
+ // Clamp C pointers
+ CMP r0, 2 // if mr >= 2
+ ADD r4, r11, r6 // c1 = c0 + cm_stride
+ MOVLO r4, r11 // c1
+ // if mr > 2
+ ADD r8, r4, r6 // c2 = c1 + cm_stride
+ MOVLS r8, r4 // c2
+ CMP r0, 4 // if mr >=4
+ ADD r6, r8, r6 // c3 = c2 + cm_stride
+ MOVLO r6, r8 // c3
+
+ .p2align 3
+0:
+ # Load initial bias from w into accumulators
+ VLDM r9!, {d16-d19} // Bias
+ VMOV q10, q8
+ VMOV q11, q9
+ VMOV q12, q8
+ VMOV q13, q9
+ VMOV q14, q8
+ VMOV q15, q9
+
+ PLD [r9, 0] // Prefetch B
+ PLD [r9, 64]
+ PLD [r9, 128]
+ PLD [r9, 192]
+ PLD [r9, 256]
+ PLD [r9, 320]
+
+1:
+ # Load next 4 A pointers
+ LDR r3, [r2, 0]
+ LDR r12, [r2, 4]
+ LDR r10, [r2, 8]
+ LDR r0, [r2, 12]
+ ADD r2, r2, 16
+
+ // Add a_offset
+ LDR r5, [sp, 128] // a_offset
+ LDR r7, [sp, 132] // zero
+ CMP r3, r7 // if a0 == zero
+ ADD r3, r3, r5 // a0 += a_offset
+ MOVEQ r3, r7 // a0 = zero, else += a0 + a_offset
+ CMP r12, r7 // if a1 == zero
+ ADD r12, r12, r5 // a1 += a_offset
+ MOVEQ r12, r7 // a1 = zero, else += a1 + a_offset
+ CMP r10, r7 // if a2 == zero
+ ADD r10, r10, r5 // a2 += a_offset
+ MOVEQ r10, r7 // a2 = zero, else += a2 + a_offset
+ CMP r0, r7 // if a3 == zero
+ ADD r0, r0, r5 // a3 += a_offset
+ LDR r5, [sp, 64] // kc
+ MOVEQ r0, r7 // a3 = zero, else += a3 + a_offset
+
+ PLD [r3, 0] // Prefetch A
+ PLD [r3, 64]
+ PLD [r12, 0]
+ PLD [r12, 64]
+ PLD [r10, 0]
+ PLD [r10, 64]
+ PLD [r0, 0]
+ PLD [r0, 64]
+
+ SUBS r5, r5, 16 // kc - 16
+ BLO 4f // less than 4 channels?
+
+ // Prologue
+ VLD1.32 {d0}, [r3]! // A0
+ VLDM r9!, {d8-d11} // B0
+ VLD1.32 {d1}, [r12]! // A1
+ VLD1.32 {d2}, [r10]! // A2
+ VLD1.32 {d3}, [ r0]! // A3
+
+ SUBS r5, r5, 16
+ BLO 3f // less than 4 channels? skip main loop
+
+ .p2align 3
+
+ // Main loop - 4 floats of A (16 bytes)
+2:
+ VMLA.F32 q8, q4, d0[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q12, q4, d2[0]
+ VLD1.32 {d4}, [r3]! // A0
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q9, q5, d0[0]
+ VLD1.32 {d5}, [r12]! // A1
+ VMLA.F32 q11, q5, d1[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q15, q5, d3[0]
+ VLD1.32 {d6}, [r10]! // A2
+ VMLA.F32 q8, q6, d0[1]
+ VMLA.F32 q10, q6, d1[1]
+ VLD1.32 {d7}, [ r0]! // A3
+ VMLA.F32 q12, q6, d2[1]
+ VMLA.F32 q14, q6, d3[1]
+ VLDM r9!, {d8-d11} // B0
+ VMLA.F32 q9, q7, d0[1]
+ VMLA.F32 q11, q7, d1[1]
+ VMLA.F32 q13, q7, d2[1]
+ VMLA.F32 q15, q7, d3[1]
+
+ VMLA.F32 q8, q4, d4[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d5[0]
+ PLD [r3, 128] // Prefetch A0
+ VMLA.F32 q12, q4, d6[0]
+ VLD1.32 {d0}, [r3]! // A0
+ VMLA.F32 q14, q4, d7[0]
+ PLD [r12, 128] // Prefetch A1
+ VMLA.F32 q9, q5, d4[0]
+ VLD1.32 {d1}, [r12]! // A1
+ VMLA.F32 q11, q5, d5[0]
+ PLD [r10, 128] // Prefetch A2
+ VMLA.F32 q13, q5, d6[0]
+ VLD1.32 {d2}, [r10]! // A2
+ VMLA.F32 q15, q5, d7[0]
+ PLD [r0, 128] // Prefetch A3
+ VMLA.F32 q8, q6, d4[1]
+ VLD1.32 {d3}, [ r0]! // A3
+ VMLA.F32 q10, q6, d5[1]
+ PLD [r9, 384] // Prefetch B
+ VMLA.F32 q12, q6, d6[1]
+ PLD [r9, 448] // Prefetch B
+ VMLA.F32 q14, q6, d7[1]
+ VLDM r9!, {d8-d11} // B0
+ VMLA.F32 q9, q7, d4[1]
+ VMLA.F32 q11, q7, d5[1]
+ SUBS r5, r5, 16
+ VMLA.F32 q13, q7, d6[1]
+ VMLA.F32 q15, q7, d7[1]
+ BHS 2b
+
+ // Epilogue
+3:
+ VMLA.F32 q8, q4, d0[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q12, q4, d2[0]
+ VLD1.32 {d4}, [r3]! // A0
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q9, q5, d0[0]
+ VLD1.32 {d5}, [r12]! // A1
+ VMLA.F32 q11, q5, d1[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q15, q5, d3[0]
+ VLD1.32 {d6}, [r10]! // A2
+ VMLA.F32 q8, q6, d0[1]
+ VMLA.F32 q10, q6, d1[1]
+ VLD1.32 {d7}, [ r0]! // A3
+ VMLA.F32 q12, q6, d2[1]
+ VMLA.F32 q14, q6, d3[1]
+ VLDM r9!, {d8-d11} // B0
+ VMLA.F32 q9, q7, d0[1]
+ VMLA.F32 q11, q7, d1[1]
+ VMLA.F32 q13, q7, d2[1]
+ VMLA.F32 q15, q7, d3[1]
+
+ VMLA.F32 q8, q4, d4[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q10, q4, d5[0]
+ VMLA.F32 q12, q4, d6[0]
+ VMLA.F32 q14, q4, d7[0]
+ VMLA.F32 q9, q5, d4[0]
+ VMLA.F32 q11, q5, d5[0]
+ VMLA.F32 q13, q5, d6[0]
+ VMLA.F32 q15, q5, d7[0]
+ VMLA.F32 q8, q6, d4[1]
+ VMLA.F32 q10, q6, d5[1]
+ VMLA.F32 q12, q6, d6[1]
+ VMLA.F32 q14, q6, d7[1]
+ VMLA.F32 q9, q7, d4[1]
+ VMLA.F32 q11, q7, d5[1]
+ VMLA.F32 q13, q7, d6[1]
+ VMLA.F32 q15, q7, d7[1]
+
+4:
+ // Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes)
+ TST r5, 12
+ BNE 7f
+
+ .p2align 3
+5:
+ # ks loop
+ SUBS r14, r14, 16 // ks -= MR * sizeof(void*)
+ BNE 1b
+
+ // Load params pointer
+ LDR r5, [sp, 136] // clamping_params
+ LDR r7, [sp, 124] // cn_stride
+ LDR r14, [sp, 68] // p = ks
+
+ // Load clamping_params values
+ VLD1.32 {d4[],d5[]}, [r5]!
+ SUBS r1, r1, 8
+ VLD1.32 {d6[],d7[]}, [r5]
+
+ // Clamp
+ VMIN.F32 q8, q8, q2
+ VMIN.F32 q9, q9, q2
+ VMIN.F32 q10, q10, q2
+ VMIN.F32 q11, q11, q2
+ VMIN.F32 q12, q12, q2
+ VMIN.F32 q13, q13, q2
+ VMIN.F32 q14, q14, q2
+ VMIN.F32 q15, q15, q2
+ VMAX.F32 q8, q8, q3
+ VMAX.F32 q9, q9, q3
+ VMAX.F32 q10, q10, q3
+ VMAX.F32 q11, q11, q3
+ VMAX.F32 q12, q12, q3
+ VMAX.F32 q13, q13, q3
+ VMAX.F32 q14, q14, q3
+ VMAX.F32 q15, q15, q3
+
+ // Store full 4 x 8
+ BLO 10f
+ VST1.32 {d28-d31}, [r6], r7
+ VST1.32 {d24-d27}, [r8], r7
+ VST1.32 {d20-d23}, [r4], r7
+ VST1.32 {d16-d19}, [r11], r7
+ SUB r2, r2, r14 // a -= ks
+ BHI 0b
+
+6:
+ VPOP {d8-d15}
+ ADD sp, sp, 8 // skip r2, r3
+ POP {r4, r5, r6, r7, r8, r9, r10, r11, pc}
+
+ .p2align 3
+7:
+ // Is there a remainder?- 2 floats of A (8 bytes)
+ TST r5, 8
+ BEQ 8f
+
+ // Remainder - 2 floats of A (8 bytes)
+ VLD1.32 {d0}, [r3]! // A0
+ VLDM r9!, {d8-d11} // B0
+ VLD1.32 {d1}, [r12]! // A1
+ VLD1.32 {d2}, [r10]! // A2
+ VLD1.32 {d3}, [ r0]! // A3
+
+ VMLA.F32 q8, q4, d0[0]
+ VMLA.F32 q9, q5, d0[0]
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q11, q5, d1[0]
+ VLDM r9!, {d12-d15} // B1
+ VMLA.F32 q12, q4, d2[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q15, q5, d3[0]
+ VMLA.F32 q8, q6, d0[1]
+ VMLA.F32 q9, q7, d0[1]
+ VMLA.F32 q10, q6, d1[1]
+ VMLA.F32 q11, q7, d1[1]
+ VMLA.F32 q12, q6, d2[1]
+ VMLA.F32 q13, q7, d2[1]
+ VMLA.F32 q14, q6, d3[1]
+ VMLA.F32 q15, q7, d3[1]
+8:
+        // Is there a remainder?- 1 float of A (4 bytes)
+ TST r5, 4
+ BEQ 5b
+
+9:
+        // Remainder- 1 float of A (4 bytes)
+ VLDM r3!, {s0} // A0
+ VLDM r9!, {d8-d11} // B0
+ VLDM r12!, {s2} // A1
+ VLDM r10!, {s4} // A2
+ VLDM r0!, {s6} // A3
+ VMLA.F32 q8, q4, d0[0]
+ VMLA.F32 q9, q5, d0[0]
+ VMLA.F32 q10, q4, d1[0]
+ VMLA.F32 q11, q5, d1[0]
+ VMLA.F32 q12, q4, d2[0]
+ VMLA.F32 q13, q5, d2[0]
+ VMLA.F32 q14, q4, d3[0]
+ VMLA.F32 q15, q5, d3[0]
+ B 5b
+
+ // Store odd width
+10:
+ TST r1, 4
+ BEQ 11f
+ VST1.32 {d28-d29}, [r6]!
+ VMOV q14, q15
+ VST1.32 {d24-d25}, [r8]!
+ VMOV q12, q13
+ VST1.32 {d20-d21}, [r4]!
+ VMOV q10, q11
+ VST1.32 {d16-d17}, [r11]!
+ VMOV q8, q9
+
+11:
+ TST r1, 2
+ BEQ 12f
+ VST1.32 {d28}, [r6]!
+ VMOV d28, d29
+ VST1.32 {d24}, [r8]!
+ VMOV d24, d25
+ VST1.32 {d20}, [r4]!
+ VMOV d20, d21
+ VST1.32 {d16}, [r11]!
+ VMOV d16, d17
+
+12:
+ TST r1, 1
+ BEQ 13f
+ VST1.32 {d28[0]}, [r6]!
+ VST1.32 {d24[0]}, [r8]!
+ VST1.32 {d20[0]}, [r4]!
+ VST1.32 {d16[0]}, [r11]!
+
+13:
+ VPOP {d8-d15}
+ ADD sp, sp, 8 // skip r2, r3
+ POP {r4, r5, r6, r7, r8, r9, r10, r11, pc}
+
+END_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
+
+
+
diff --git a/src/init.c b/src/init.c
index 43d0abf8c..9f35c5422 100644
--- a/src/init.c
+++ b/src/init.c
@@ -138,7 +138,7 @@ static void init(void) {
case cpuinfo_uarch_cortex_a55:
xnn_params.f32.gemm = (struct gemm_parameters) {
.gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a53,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neon_lane_ld128,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64,
.gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
.igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
.mr = 4,
@@ -151,7 +151,7 @@ static void init(void) {
case cpuinfo_uarch_cortex_a73:
xnn_params.f32.gemm = (struct gemm_parameters) {
.gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_pld_cortex_a75,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neon_lane_ld128,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75,
.gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
.igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
.mr = 4,
@@ -162,7 +162,7 @@ static void init(void) {
default:
xnn_params.f32.gemm = (struct gemm_parameters) {
.gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a75,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neon_lane_ld128,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75,
.gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
.igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
.mr = 4,
diff --git a/src/runtime.c b/src/runtime.c
index b210b34bb..f0e262f01 100644
--- a/src/runtime.c
+++ b/src/runtime.c
@@ -68,10 +68,10 @@ enum xnn_status xnn_create_runtime_v2(
if (status != xnn_status_success) {
goto error;
}
- runtime->ops[i].shape1.num_dims = subgraph->values[node->inputs.raw[0]].shape.num_dims;
- runtime->ops[i].shape2.num_dims = subgraph->values[node->inputs.raw[1]].shape.num_dims;
- memcpy(runtime->ops[i].shape1.dim, subgraph->values[node->inputs.raw[0]].shape.dim, subgraph->values[node->inputs.raw[0]].shape.num_dims * sizeof(size_t));
- memcpy(runtime->ops[i].shape2.dim, subgraph->values[node->inputs.raw[1]].shape.dim, subgraph->values[node->inputs.raw[1]].shape.num_dims * sizeof(size_t));
+ runtime->ops[i].shape1.num_dims = values[node->inputs.raw[0]].shape.num_dims;
+ runtime->ops[i].shape2.num_dims = values[node->inputs.raw[1]].shape.num_dims;
+ memcpy(runtime->ops[i].shape1.dim, values[node->inputs.raw[0]].shape.dim, values[node->inputs.raw[0]].shape.num_dims * sizeof(size_t));
+ memcpy(runtime->ops[i].shape2.dim, values[node->inputs.raw[1]].shape.dim, values[node->inputs.raw[1]].shape.num_dims * sizeof(size_t));
runtime->ops[i].inputs[0] = node->inputs.raw[0];
runtime->ops[i].inputs[1] = node->inputs.raw[1];
runtime->ops[i].outputs[0] = node->outputs.raw[0];
@@ -102,9 +102,28 @@ enum xnn_status xnn_create_runtime_v2(
if (status != xnn_status_success) {
goto error;
}
- runtime->ops[i].batch_size = subgraph->values[node->inputs.raw[0]].shape.dim[0];
- runtime->ops[i].input_height = subgraph->values[node->inputs.raw[0]].shape.dim[1];
- runtime->ops[i].input_width = subgraph->values[node->inputs.raw[0]].shape.dim[2];
+ runtime->ops[i].batch_size = values[node->inputs.raw[0]].shape.dim[0];
+ runtime->ops[i].input_height = values[node->inputs.raw[0]].shape.dim[1];
+ runtime->ops[i].input_width = values[node->inputs.raw[0]].shape.dim[2];
+ runtime->ops[i].inputs[0] = node->inputs.raw[0];
+ runtime->ops[i].outputs[0] = node->outputs.raw[0];
+ break;
+ case xnn_node_type_clamp:
+ status = xnn_create_clamp_nc_f32(
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* channels */,
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* input stride */,
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* output stride */,
+ node->activation.output_min,
+ node->activation.output_max,
+ node->flags,
+ &runtime->ops[i].op);
+ if (status != xnn_status_success) {
+ goto error;
+ }
+ runtime->ops[i].batch_size = 1;
+        for (size_t d = 0; d + 1 < values[node->inputs.raw[0]].shape.num_dims; d++) {
+          runtime->ops[i].batch_size *= values[node->inputs.raw[0]].shape.dim[d];
+ }
runtime->ops[i].inputs[0] = node->inputs.raw[0];
runtime->ops[i].outputs[0] = node->outputs.raw[0];
break;
@@ -134,9 +153,26 @@ enum xnn_status xnn_create_runtime_v2(
if (status != xnn_status_success) {
goto error;
}
- runtime->ops[i].batch_size = subgraph->values[node->inputs.raw[0]].shape.dim[0];
- runtime->ops[i].input_height = subgraph->values[node->inputs.raw[0]].shape.dim[1];
- runtime->ops[i].input_width = subgraph->values[node->inputs.raw[0]].shape.dim[2];
+ runtime->ops[i].batch_size = values[node->inputs.raw[0]].shape.dim[0];
+ runtime->ops[i].input_height = values[node->inputs.raw[0]].shape.dim[1];
+ runtime->ops[i].input_width = values[node->inputs.raw[0]].shape.dim[2];
+ runtime->ops[i].inputs[0] = node->inputs.raw[0];
+ runtime->ops[i].outputs[0] = node->outputs.raw[0];
+ break;
+ case xnn_node_type_hardswish:
+ status = xnn_create_hardswish_nc_f32(
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* channels */,
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* input stride */,
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* output stride */,
+ node->flags,
+ &runtime->ops[i].op);
+ if (status != xnn_status_success) {
+ goto error;
+ }
+ runtime->ops[i].batch_size = 1;
+        for (size_t d = 0; d + 1 < values[node->inputs.raw[0]].shape.num_dims; d++) {
+          runtime->ops[i].batch_size *= values[node->inputs.raw[0]].shape.dim[d];
+ }
runtime->ops[i].inputs[0] = node->inputs.raw[0];
runtime->ops[i].outputs[0] = node->outputs.raw[0];
break;
@@ -149,14 +185,68 @@ enum xnn_status xnn_create_runtime_v2(
if (status != xnn_status_success) {
goto error;
}
- runtime->ops[i].shape1.num_dims = subgraph->values[node->inputs.raw[0]].shape.num_dims;
- runtime->ops[i].shape2.num_dims = subgraph->values[node->inputs.raw[1]].shape.num_dims;
- memcpy(runtime->ops[i].shape1.dim, subgraph->values[node->inputs.raw[0]].shape.dim, subgraph->values[node->inputs.raw[0]].shape.num_dims * sizeof(size_t));
- memcpy(runtime->ops[i].shape2.dim, subgraph->values[node->inputs.raw[1]].shape.dim, subgraph->values[node->inputs.raw[1]].shape.num_dims * sizeof(size_t));
+ runtime->ops[i].shape1.num_dims = values[node->inputs.raw[0]].shape.num_dims;
+ runtime->ops[i].shape2.num_dims = values[node->inputs.raw[1]].shape.num_dims;
+ memcpy(runtime->ops[i].shape1.dim, values[node->inputs.raw[0]].shape.dim, values[node->inputs.raw[0]].shape.num_dims * sizeof(size_t));
+ memcpy(runtime->ops[i].shape2.dim, values[node->inputs.raw[1]].shape.dim, values[node->inputs.raw[1]].shape.num_dims * sizeof(size_t));
runtime->ops[i].inputs[0] = node->inputs.raw[0];
runtime->ops[i].inputs[1] = node->inputs.raw[1];
runtime->ops[i].outputs[0] = node->outputs.raw[0];
break;
+ case xnn_node_type_prelu:
+ status = xnn_create_prelu_nc_f32(
+ values[node->inputs.raw[1]].shape.dim[values[node->inputs.raw[1]].shape.num_dims - 1] /* channels */,
+ values[node->inputs.raw[1]].shape.dim[values[node->inputs.raw[1]].shape.num_dims - 1] /* input stride */,
+ values[node->inputs.raw[1]].shape.dim[values[node->inputs.raw[1]].shape.num_dims - 1] /* output stride */,
+ values[node->inputs.raw[1]].data /* negative slope */,
+ -INFINITY,
+ +INFINITY,
+ node->flags,
+ &runtime->ops[i].op);
+ if (status != xnn_status_success) {
+ goto error;
+ }
+ runtime->ops[i].batch_size = 1;
+        for (size_t d = 0; d + 1 < values[node->inputs.raw[0]].shape.num_dims; d++) {
+          runtime->ops[i].batch_size *= values[node->inputs.raw[0]].shape.dim[d];
+ }
+ runtime->ops[i].inputs[0] = node->inputs.raw[0];
+ runtime->ops[i].outputs[0] = node->outputs.raw[0];
+ break;
+ case xnn_node_type_sigmoid:
+ status = xnn_create_sigmoid_nc_f32(
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* channels */,
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* input stride */,
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* output stride */,
+ node->flags,
+ &runtime->ops[i].op);
+ if (status != xnn_status_success) {
+ goto error;
+ }
+ runtime->ops[i].batch_size = 1;
+        for (size_t d = 0; d + 1 < values[node->inputs.raw[0]].shape.num_dims; d++) {
+          runtime->ops[i].batch_size *= values[node->inputs.raw[0]].shape.dim[d];
+ }
+ runtime->ops[i].inputs[0] = node->inputs.raw[0];
+ runtime->ops[i].outputs[0] = node->outputs.raw[0];
+ break;
+ case xnn_node_type_softmax:
+ status = xnn_create_softmax_nc_f32(
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* channels */,
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* input stride */,
+ values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* output stride */,
+ node->flags,
+ &runtime->ops[i].op);
+ if (status != xnn_status_success) {
+ goto error;
+ }
+ runtime->ops[i].batch_size = 1;
+        for (size_t d = 0; d + 1 < values[node->inputs.raw[0]].shape.num_dims; d++) {
+          runtime->ops[i].batch_size *= values[node->inputs.raw[0]].shape.dim[d];
+ }
+ runtime->ops[i].inputs[0] = node->inputs.raw[0];
+ runtime->ops[i].outputs[0] = node->outputs.raw[0];
+ break;
case xnn_node_type_invalid:
xnn_log_fatal("unexpected node type %d in node #%zu", node->type, i);
XNN_UNREACHABLE;
@@ -282,6 +372,26 @@ enum xnn_status xnn_setup_runtime(
runtime->blobs[op->outputs[0]].data,
runtime->threadpool);
break;
+ case xnn_operator_type_clamp_nc_f32:
+ assert(runtime->blobs[op->inputs[0]].data != NULL);
+ assert(runtime->blobs[op->outputs[0]].data != NULL);
+ status = xnn_setup_clamp_nc_f32(
+ op->op,
+ op->batch_size,
+ runtime->blobs[op->inputs[0]].data,
+ runtime->blobs[op->outputs[0]].data,
+ runtime->threadpool);
+ break;
+ case xnn_operator_type_hardswish_nc_f32:
+ assert(runtime->blobs[op->inputs[0]].data != NULL);
+ assert(runtime->blobs[op->outputs[0]].data != NULL);
+ status = xnn_setup_hardswish_nc_f32(
+ op->op,
+ op->batch_size,
+ runtime->blobs[op->inputs[0]].data,
+ runtime->blobs[op->outputs[0]].data,
+ runtime->threadpool);
+ break;
case xnn_operator_type_multiply_nd_f32:
assert(runtime->blobs[op->inputs[0]].data != NULL);
assert(runtime->blobs[op->inputs[1]].data != NULL);
@@ -297,6 +407,36 @@ enum xnn_status xnn_setup_runtime(
runtime->blobs[op->outputs[0]].data,
runtime->threadpool);
break;
+ case xnn_operator_type_prelu_nc_f32:
+ assert(runtime->blobs[op->inputs[0]].data != NULL);
+ assert(runtime->blobs[op->outputs[0]].data != NULL);
+ status = xnn_setup_prelu_nc_f32(
+ op->op,
+ op->batch_size,
+ runtime->blobs[op->inputs[0]].data,
+ runtime->blobs[op->outputs[0]].data,
+ runtime->threadpool);
+ break;
+ case xnn_operator_type_sigmoid_nc_f32:
+ assert(runtime->blobs[op->inputs[0]].data != NULL);
+ assert(runtime->blobs[op->outputs[0]].data != NULL);
+ status = xnn_setup_sigmoid_nc_f32(
+ op->op,
+ op->batch_size,
+ runtime->blobs[op->inputs[0]].data,
+ runtime->blobs[op->outputs[0]].data,
+ runtime->threadpool);
+ break;
+ case xnn_operator_type_softmax_nc_f32:
+ assert(runtime->blobs[op->inputs[0]].data != NULL);
+ assert(runtime->blobs[op->outputs[0]].data != NULL);
+ status = xnn_setup_softmax_nc_f32(
+ op->op,
+ op->batch_size,
+ runtime->blobs[op->inputs[0]].data,
+ runtime->blobs[op->outputs[0]].data,
+ runtime->threadpool);
+ break;
default:
xnn_log_fatal("unexpected operator type %d in operator #%zu", op->op->type, i);
XNN_UNREACHABLE;
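
[Editor's note, illustrative only and not part of the patch] The new setup cases above (clamp, hardswish, prelu, sigmoid, softmax) all follow the operator API's usual create/setup/run life cycle with a single input and a single output blob. A minimal standalone sketch for the newly plumbed sigmoid operator, assuming XNNPACK is already initialized and that the public signatures match the calls used in the hunks above:

#include <stddef.h>
#include <xnnpack.h>
#include <pthreadpool.h>

/* Sketch: run an F32 sigmoid over a dense [batch_size, channels] buffer. */
static enum xnn_status run_sigmoid_example(
    size_t batch_size, size_t channels,
    const float* input, float* output,
    pthreadpool_t threadpool)
{
  xnn_operator_t op = NULL;
  enum xnn_status status = xnn_create_sigmoid_nc_f32(
      channels, channels /* input stride */, channels /* output stride */,
      0 /* flags */, &op);
  if (status != xnn_status_success) {
    return status;
  }
  status = xnn_setup_sigmoid_nc_f32(op, batch_size, input, output, threadpool);
  if (status == xnn_status_success) {
    status = xnn_run_operator(op, threadpool);
  }
  xnn_delete_operator(op);
  return status;
}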
diff --git a/src/subgraph.c b/src/subgraph.c
index 714195f65..d90d62589 100644
--- a/src/subgraph.c
+++ b/src/subgraph.c
@@ -554,6 +554,219 @@ enum xnn_status xnn_define_multiply2(
return xnn_status_success;
}
+enum xnn_status xnn_define_prelu(
+ xnn_subgraph_t subgraph,
+ uint32_t input_id,
+ uint32_t slope_id,
+ uint32_t output_id,
+ uint32_t flags)
+{
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to define PReLU operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (input_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define PReLU operator with input ID #%" PRIu32 ": invalid Value ID",
+ input_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (slope_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define PReLU operator with slope ID #%" PRIu32 ": invalid Value ID",
+ slope_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (output_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define PReLU operator with output ID #%" PRIu32 ": invalid Value ID",
+ output_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ struct xnn_node* node = xnn_subgraph_new_node(subgraph);
+ if (node == NULL) {
+ return xnn_status_out_of_memory;
+ }
+
+ node->type = xnn_node_type_prelu;
+ node->num_inputs = 2;
+ node->inputs.raw[0] = input_id;
+ node->inputs.raw[1] = slope_id;
+ node->num_outputs = 1;
+ node->outputs.raw[0] = output_id;
+ node->flags = flags;
+
+ return xnn_status_success;
+}
+
+enum xnn_status xnn_define_clamp(
+ xnn_subgraph_t subgraph,
+ float output_min,
+ float output_max,
+ uint32_t input_id,
+ uint32_t output_id,
+ uint32_t flags)
+{
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to define Clamp operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (input_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define Clamp operator with input ID #%" PRIu32 ": invalid Value ID",
+ input_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (output_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define Clamp operator with output ID #%" PRIu32 ": invalid Value ID",
+ output_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ struct xnn_node* node = xnn_subgraph_new_node(subgraph);
+ if (node == NULL) {
+ return xnn_status_out_of_memory;
+ }
+
+ node->type = xnn_node_type_clamp;
+ node->activation.output_min = output_min;
+ node->activation.output_max = output_max;
+ node->num_inputs = 1;
+ node->inputs.raw[0] = input_id;
+ node->num_outputs = 1;
+ node->outputs.raw[0] = output_id;
+ node->flags = flags;
+
+ return xnn_status_success;
+}
+
+enum xnn_status xnn_define_hardswish(
+ xnn_subgraph_t subgraph,
+ uint32_t input_id,
+ uint32_t output_id,
+ uint32_t flags)
+{
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to define HardSwish operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (input_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define HardSwish operator with input ID #%" PRIu32 ": invalid Value ID",
+ input_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (output_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define HardSwish operator with output ID #%" PRIu32 ": invalid Value ID",
+ output_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ struct xnn_node* node = xnn_subgraph_new_node(subgraph);
+ if (node == NULL) {
+ return xnn_status_out_of_memory;
+ }
+
+ node->type = xnn_node_type_hardswish;
+ node->num_inputs = 1;
+ node->inputs.raw[0] = input_id;
+ node->num_outputs = 1;
+ node->outputs.raw[0] = output_id;
+ node->flags = flags;
+
+ return xnn_status_success;
+}
+
+enum xnn_status xnn_define_sigmoid(
+ xnn_subgraph_t subgraph,
+ uint32_t input_id,
+ uint32_t output_id,
+ uint32_t flags)
+{
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to define Sigmoid operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (input_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define Sigmoid operator with input ID #%" PRIu32 ": invalid Value ID",
+ input_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (output_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define Sigmoid operator with output ID #%" PRIu32 ": invalid Value ID",
+ output_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ struct xnn_node* node = xnn_subgraph_new_node(subgraph);
+ if (node == NULL) {
+ return xnn_status_out_of_memory;
+ }
+
+ node->type = xnn_node_type_sigmoid;
+ node->num_inputs = 1;
+ node->inputs.raw[0] = input_id;
+ node->num_outputs = 1;
+ node->outputs.raw[0] = output_id;
+ node->flags = flags;
+
+ return xnn_status_success;
+}
+
+enum xnn_status xnn_define_softmax(
+ xnn_subgraph_t subgraph,
+ uint32_t input_id,
+ uint32_t output_id,
+ uint32_t flags)
+{
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to define SoftMax operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (input_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define SoftMax operator with input ID #%" PRIu32 ": invalid Value ID",
+ input_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (output_id >= subgraph->num_values) {
+ xnn_log_error(
+ "failed to define SoftMax operator with output ID #%" PRIu32 ": invalid Value ID",
+ output_id);
+ return xnn_status_invalid_parameter;
+ }
+
+ struct xnn_node* node = xnn_subgraph_new_node(subgraph);
+ if (node == NULL) {
+ return xnn_status_out_of_memory;
+ }
+
+ node->type = xnn_node_type_softmax;
+ node->num_inputs = 1;
+ node->inputs.raw[0] = input_id;
+ node->num_outputs = 1;
+ node->outputs.raw[0] = output_id;
+ node->flags = flags;
+
+ return xnn_status_success;
+}
+
enum xnn_status xnn_delete_subgraph(
xnn_subgraph_t subgraph)
{
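
[Editor's note, illustrative only and not part of the patch] Each new xnn_define_* entry point above validates its Value IDs and appends a single node to the subgraph. A usage sketch that chains three of them, assuming the subgraph and all Value IDs (including the static slope tensor) were defined elsewhere and are valid:

#include <stdint.h>
#include <xnnpack.h>

/* Sketch: Clamp -> PReLU -> Sigmoid, using only the definitions added above.
 * The intermediate IDs clamped_id and prelu_out_id are assumed to be valid
 * Value IDs of the same subgraph. */
static enum xnn_status define_activation_chain(
    xnn_subgraph_t subgraph,
    uint32_t input_id, uint32_t slope_id,
    uint32_t clamped_id, uint32_t prelu_out_id, uint32_t output_id)
{
  enum xnn_status status =
      xnn_define_clamp(subgraph, 0.0f, 6.0f, input_id, clamped_id, 0 /* flags */);
  if (status != xnn_status_success) {
    return status;
  }
  status = xnn_define_prelu(subgraph, clamped_id, slope_id, prelu_out_id, 0 /* flags */);
  if (status != xnn_status_success) {
    return status;
  }
  return xnn_define_sigmoid(subgraph, prelu_out_id, output_id, 0 /* flags */);
}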
diff --git a/src/xnnpack/igemm.h b/src/xnnpack/igemm.h
index 1de690495..d2bf06b53 100644
--- a/src/xnnpack/igemm.h
+++ b/src/xnnpack/igemm.h
@@ -88,6 +88,10 @@ DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_co
DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53)
DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75)
+
DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8__sse_load1)
DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__sse_load1)
diff --git a/src/xnnpack/subgraph.h b/src/xnnpack/subgraph.h
index de1a9d3f6..7f9342efe 100644
--- a/src/xnnpack/subgraph.h
+++ b/src/xnnpack/subgraph.h
@@ -58,9 +58,14 @@ struct xnn_blob {
enum xnn_node_type {
xnn_node_type_invalid = 0,
xnn_node_type_add2,
+ xnn_node_type_clamp,
xnn_node_type_convolution_2d,
xnn_node_type_depthwise_convolution_2d,
+ xnn_node_type_hardswish,
xnn_node_type_multiply2,
+ xnn_node_type_prelu,
+ xnn_node_type_sigmoid,
+ xnn_node_type_softmax,
};
struct xnn_node {
diff --git a/test/f32-igemm.cc b/test/f32-igemm.cc
index 597b8fcfa..54eee542e 100644
--- a/test/f32-igemm.cc
+++ b/test/f32-igemm.cc
@@ -3016,6 +3016,1472 @@
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
+#if XNN_ARCH_ARM
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_eq_2) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(2)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(2)
+ .cn_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_eq_2_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(2)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_eq_2_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(2)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_eq_2_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(2)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_lt_2) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 2; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_lt_2_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 2; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_gt_2) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 3; k < 4; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_gt_2_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 3; k < 4; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_div_2) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 4; k <= 20; k += 2) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_div_2_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 4; k <= 20; k += 2) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 10; k += 3) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 10; k += 3) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_gt_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_div_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 10; k += 3) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 10; k += 3) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(43)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 4; mz++) {
+ for (size_t k = 1; k <= 10; k += 3) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(43)
+ .zero_index(mz)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(2)
+ .qmin(128)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(2)
+ .qmax(128)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(2)
+ .cm_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64);
+ }
+#endif // XNN_ARCH_ARM
+
+
+#if XNN_ARCH_ARM
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_4) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(4)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(4)
+ .cn_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_4_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(4)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_4_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(4)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_4_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(4)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_8) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(8)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_lt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_lt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 8; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 9; k < 16; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_div_4) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 12; k <= 40; k += 4) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_div_4_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 12; k <= 40; k += 4) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 20; k += 5) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_gt_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_div_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 20; k += 5) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(83)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 4; mz++) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(83)
+ .zero_index(mz)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(4)
+ .qmin(128)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(4)
+ .qmax(128)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(4)
+ .cm_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75);
+ }
+#endif // XNN_ARCH_ARM
+
+
+#if XNN_ARCH_ARM
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_4) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(4)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(4)
+ .cn_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_4_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(4)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_4_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(4)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_4_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(4)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_8) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(8)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_lt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_lt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 8; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 9; k < 16; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_div_4) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 12; k <= 40; k += 4) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_div_4_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 12; k <= 40; k += 4) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 20; k += 5) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_gt_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_div_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 20; k += 5) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(83)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 4; mz++) {
+ for (size_t k = 1; k <= 20; k += 5) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(83)
+ .zero_index(mz)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+ }
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(4)
+ .qmin(128)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(4)
+ .qmax(128)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+
+ TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(1)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(4)
+ .cm_stride(11)
+ .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75);
+ }
+#endif // XNN_ARCH_ARM
+
+
#if XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
TEST(F32_IGEMM_5X8__AARCH64_NEONFMA_CORTEX_A57, k_eq_8) {
TEST_REQUIRES_ARM_NEON_FMA;
diff --git a/test/f32-igemm.yaml b/test/f32-igemm.yaml
index 5563dee19..a8d5f1fd0 100644
--- a/test/f32-igemm.yaml
+++ b/test/f32-igemm.yaml
@@ -26,6 +26,15 @@
k-block: 8
pipelined: true
assembly: true
+- name: xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64
+ k-block: 2
+ pipelined: false
+- name: xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75
+ k-block: 4
+ pipelined: true
+- name: xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75
+ k-block: 4
+ pipelined: true
- name: xnn_f32_igemm_ukernel_5x8__aarch64_neonfma_cortex_a57
k-block: 8
pipelined: true
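
[Editor's note, an assumption about the test generator and not part of the patch] The k-block and pipelined fields above are what the test generator is assumed to turn into the loop bounds seen in test/f32-igemm.cc: for a pipelined kernel the effective block doubles, so k-block 4 yields k_eq_8, a k_gt range of 9..15, and k_div steps of 4 up to 40. A sketch of that assumed mapping, with illustrative names only:

#include <stddef.h>

struct igemm_test_bounds {
  size_t k_eq;       /* value used by the k_eq_* tests            */
  size_t k_gt_first; /* first k exercised by the k_gt_* tests     */
  size_t k_gt_last;  /* k_gt_* loops run while k < k_gt_last      */
  size_t k_div_max;  /* k_div_* tests step by k_block up to this  */
};

static struct igemm_test_bounds bounds_from_yaml(size_t k_block, int pipelined) {
  const size_t adj_k_block = pipelined ? 2 * k_block : k_block;
  const struct igemm_test_bounds bounds = {
    .k_eq = adj_k_block,          /* 2 for the ld64 kernel, 8 for the cortex-a75 kernels */
    .k_gt_first = adj_k_block + 1,
    .k_gt_last = 2 * adj_k_block,
    .k_div_max = 10 * k_block,
  };
  return bounds;
}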
diff --git a/third_party/cpuinfo.BUILD b/third_party/cpuinfo.BUILD
index ad8a07000..58ab81717 100644
--- a/third_party/cpuinfo.BUILD
+++ b/third_party/cpuinfo.BUILD
@@ -107,6 +107,17 @@ cc_library(
":android_arm64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS + ANDROID_ARM_SRCS,
":android_x86": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS,
":android_x86_64": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS,
+ ":ios_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":ios_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":ios_armv7": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":ios_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":ios_arm64e": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":watchos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":watchos_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":watchos_armv7k": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":watchos_arm64_32": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":tvos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":tvos_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
":emscripten_wasm": COMMON_SRCS + EMSCRIPTEN_SRCS,
}),
copts = C99OPTS + [
@@ -201,6 +212,94 @@ config_setting(
)
config_setting(
+ name = "ios_armv7",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "ios_armv7",
+ },
+)
+
+config_setting(
+ name = "ios_arm64",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "ios_arm64",
+ },
+)
+
+config_setting(
+ name = "ios_arm64e",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "ios_arm64e",
+ },
+)
+
+config_setting(
+ name = "ios_x86",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "ios_i386",
+ },
+)
+
+config_setting(
+ name = "ios_x86_64",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "ios_x86_64",
+ },
+)
+
+config_setting(
+ name = "watchos_armv7k",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "watchos_armv7k",
+ },
+)
+
+config_setting(
+ name = "watchos_arm64_32",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "watchos_arm64_32",
+ },
+)
+
+config_setting(
+ name = "watchos_x86",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "watchos_i386",
+ },
+)
+
+config_setting(
+ name = "watchos_x86_64",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "watchos_x86_64",
+ },
+)
+
+config_setting(
+ name = "tvos_arm64",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "tvos_arm64",
+ },
+)
+
+config_setting(
+ name = "tvos_x86_64",
+ values = {
+ "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "cpu": "tvos_x86_64",
+ },
+)
+
+config_setting(
name = "emscripten_wasm",
values = {
"cpu": "wasm",