author     android-build-team Robot <android-build-team-robot@google.com>  2020-02-19 03:10:42 +0000
committer  android-build-team Robot <android-build-team-robot@google.com>  2020-02-19 03:10:42 +0000
commit     8b23c3bfd1d8d5cfb431576ad1fb798c44df4d64 (patch)
tree       d1ca88d7bef60a4ac8e70481861ad0e024555f9f
parent     33e4796a20c318d3a7e9be15650a34a7d0d85539 (diff)
parent     5fa0858b57c36722f0ab2606c606e927b5a40448 (diff)
download   XNNPACK-android11-d1-s1-release.tar.gz
Snap for 6217125 from 5fa0858b57c36722f0ab2606c606e927b5a40448 to rvc-d1-release

Refs: android-11.0.0_r9, android-11.0.0_r8, android-11.0.0_r7, android-11.0.0_r15, android-11.0.0_r14, android-11.0.0_r13, android-11.0.0_r12, android-11.0.0_r11, android-11.0.0_r10, android11-d1-s7-release, android11-d1-s6-release, android11-d1-s5-release, android11-d1-s1-release, android11-d1-release
Change-Id: Ibf363d121c5ecd3b06809b46550e5e09ffedd166
64 files changed, 4202 insertions(+), 610 deletions(-)
diff --git a/Android.bp b/Android.bp
index 61f5fe67c..1343ca93a 100644
--- a/Android.bp
+++ b/Android.bp
@@ -1244,6 +1244,9 @@ AARCH32_ASM_UKERNELS = [
     "src/f32-gemm/gen/4x8-aarch32-neon-cortex-a75.S",
     "src/f32-gemm/gen/4x8-aarch32-neon-pld-cortex-a75.S",
     "src/f32-gemm/4x8-aarch32-neon-ld64.S",
+    "src/f32-igemm/4x8-aarch32-neon-ld64.S",
+    "src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S",
+    "src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S",
 ]
 
 AARCH64_ASM_UKERNELS = [
diff --git a/BUILD.bazel b/BUILD.bazel
index 095b25967..a6e6c1f48 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1373,6 +1373,9 @@ AARCH32_ASM_UKERNELS = [
     "src/f32-gemm/gen/4x8-aarch32-neon-cortex-a75.S",
     "src/f32-gemm/gen/4x8-aarch32-neon-pld-cortex-a75.S",
     "src/f32-gemm/4x8-aarch32-neon-ld64.S",
+    "src/f32-igemm/4x8-aarch32-neon-ld64.S",
+    "src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S",
+    "src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S",
 ]
 
 AARCH64_ASM_UKERNELS = [
@@ -3244,3 +3247,91 @@ config_setting(
         "cpu": "asmjs",
     },
 )
+
+config_setting(
+    name = "ios_armv7",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "ios_armv7",
+    },
+)
+
+config_setting(
+    name = "ios_arm64",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "ios_arm64",
+    },
+)
+
+config_setting(
+    name = "ios_arm64e",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "ios_arm64e",
+    },
+)
+
+config_setting(
+    name = "ios_x86",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "ios_i386",
+    },
+)
+
+config_setting(
+    name = "ios_x86_64",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "ios_x86_64",
+    },
+)
+
+config_setting(
+    name = "watchos_armv7k",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "watchos_armv7k",
+    },
+)
+
+config_setting(
+    name = "watchos_arm64_32",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "watchos_arm64_32",
+    },
+)
+
+config_setting(
+    name = "watchos_x86",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "watchos_i386",
+    },
+)
+
+config_setting(
+    name = "watchos_x86_64",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "watchos_x86_64",
+    },
+)
+
+config_setting(
+    name = "tvos_arm64",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "tvos_arm64",
+    },
+)
+
+config_setting(
+    name = "tvos_x86_64",
+    values = {
+        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "cpu": "tvos_x86_64",
+    },
+)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 46ec94f9e..cd5d59e63 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1357,7 +1357,10 @@ SET(XNNPACK_AARCH32_ASM_MICROKERNEL_SRCS
   src/f32-gemm/4x8-aarch32-neon-cortex-a53.S
   src/f32-gemm/gen/4x8-aarch32-neon-cortex-a75.S
   src/f32-gemm/gen/4x8-aarch32-neon-pld-cortex-a75.S
-  src/f32-gemm/4x8-aarch32-neon-ld64.S)
+  src/f32-gemm/4x8-aarch32-neon-ld64.S
+  src/f32-igemm/4x8-aarch32-neon-ld64.S
+  src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S
+  src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S)
 
 SET(XNNPACK_AARCH64_ASM_MICROKERNEL_SRCS
   src/f32-dwconv/up4x9-aarch64-neonfma-cortex-a55.S
diff --git a/METADATA b/METADATA
@@ -1,12 +1,5 @@
 name: "XNNPACK"
-description:
-    "XNNPACK is a highly optimized library of floating-point neural network "
-    "inference operators for ARM, WebAssembly, and x86 (SSE2 level) platforms. "
-    "XNNPACK is not intended for direct use by deep learning practitioners and "
-    "researchers; instead it provides low-level performance primitives for "
-    "accelerating high-level machine learning frameworks, such as MediaPipe, "
-    "TensorFlow Lite, and TensorFlow.js."
-
+description: "XNNPACK is a highly optimized library of floating-point neural network inference operators for ARM, WebAssembly, and x86 (SSE2 level) platforms. XNNPACK is not intended for direct use by deep learning practitioners and researchers; instead it provides low-level performance primitives for accelerating high-level machine learning frameworks, such as MediaPipe, TensorFlow Lite, and TensorFlow.js."
 third_party {
   url {
     type: HOMEPAGE
@@ -16,7 +9,11 @@ third_party {
     type: GIT
     value: "https://github.com/google/XNNPACK"
   }
-  version: "98ca635f1b6d84c81539ccafe721650fb174da67"
-  last_upgrade_date { year: 2020 month: 2 day: 3 }
+  version: "1498d1d4d0430480dfe5c4538049b4f789d29134"
   license_type: NOTICE
+  last_upgrade_date {
+    year: 2020
+    month: 2
+    day: 11
+  }
 }
diff --git a/README.md b/README.md
@@ -4,11 +4,11 @@ XNNPACK is a highly optimized library of floating-point neural network inference
 
 ## Supported Architectures
 
-- ARM64 on Android and Linux
-- ARMv7 (with NEON) on Android and Linux
+- ARM64 on Android, Linux, and iOS (including WatchOS and tvOS)
+- ARMv7 (with NEON) on Android, Linux, and iOS (including WatchOS)
 - WebAssembly MVP
 - WebAssembly SIMD (experimental)
-- x86 and x86-64 (up to AVX512) on Android, Linux, and macOS
+- x86 and x86-64 (up to AVX512) on Android, Linux, macOS, and iOS simulator
 
 ## Operator Coverage
diff --git a/bench/f32-gemm-e2e.cc b/bench/f32-gemm-e2e.cc
index c83da0b0e..9661eb13c 100644
--- a/bench/f32-gemm-e2e.cc
+++ b/bench/f32-gemm-e2e.cc
@@ -258,7 +258,7 @@ static void GEMMEnd2EndBenchmark(
   static void f32_gemm_4x8__aarch32_neon_ld64(benchmark::State& state, models::ExecutionPlanFactory model) {
     GEMMEnd2EndBenchmark(state, model,
       xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64,
-      xnn_f32_igemm_ukernel_4x8__neon_lane_ld64,
+      xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64,
       xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
       xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
       4 /* mr */, 8 /* nr */, 0 /* log2_kr */, 0 /* log2_sr */,
@@ -276,7 +276,7 @@ static void GEMMEnd2EndBenchmark(
   static void f32_gemm_4x8__aarch32_neon_cortex_a75(benchmark::State& state, models::ExecutionPlanFactory model) {
     GEMMEnd2EndBenchmark(state, model,
       xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a75,
-      xnn_f32_igemm_ukernel_4x8__neon_lane_ld64,
+      xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75,
       xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
       xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
       4 /* mr */, 8 /* nr */, 0 /* log2_kr */, 0 /* log2_sr */,
@@ -285,7 +285,7 @@ static void GEMMEnd2EndBenchmark(
   static void f32_gemm_4x8__aarch32_neon_pld_cortex_a75(benchmark::State& state, models::ExecutionPlanFactory model) {
     GEMMEnd2EndBenchmark(state, model,
       xnn_f32_gemm_ukernel_4x8__aarch32_neon_pld_cortex_a75,
-      xnn_f32_igemm_ukernel_4x8__neon_lane_ld64,
+      xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75,
       xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
       xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
       4 /* mr */, 8 /* nr */, 0 /* log2_kr */, 0 /* log2_sr */,
diff --git a/bench/f32-igemm.cc b/bench/f32-igemm.cc
index ca2cf162b..064506538 100644
--- a/bench/f32-igemm.cc
+++ b/bench/f32-igemm.cc
@@ -258,6 +258,14 @@ static void IGEMMBenchmark(benchmark::State& state,
   BENCHMARK_CONV(f32_igemm_8x8s4__neonfma)
 #endif
 
+#if XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY
+  static void f32_igemm_4x8__aarch32_neon_ld64(benchmark::State& state, const char* net) {
+    IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64, 4, 8, 1, 1);
+  }
+
+  BENCHMARK_CONV(f32_igemm_4x8__aarch32_neon_ld64)
+#endif  /* XNN_ARCH_ARM */
+
 #if XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
   static void f32_igemm_1x12__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
     IGEMMBenchmark(state, xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53, 1, 12, 1, 1);
@@ -366,6 +374,18 @@ static void IGEMMBenchmark(benchmark::State& state,
   BENCHMARK_CONV(f32_igemm_6x8__neonfma_lane_ld128)
 #endif  /* XNN_ARCH_ARM64 */
 
+#if XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY
+  static void f32_igemm_4x8__aarch32_neon_pld_cortex_a75(benchmark::State& state, const char* net) {
+    IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75, 4, 8, 1, 1);
+  }
+  static void f32_igemm_4x8__aarch32_neon_cortex_a75(benchmark::State& state, const char* net) {
+    IGEMMBenchmark(state, xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75, 4, 8, 1, 1);
+  }
+
+  BENCHMARK_CONV(f32_igemm_4x8__aarch32_neon_pld_cortex_a75)
+  BENCHMARK_CONV(f32_igemm_4x8__aarch32_neon_cortex_a75)
+#endif  /* XNN_ARCH_ARM */
+
 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
   static void f32_igemm_1x8__sse_load1(benchmark::State& state, const char* net) {
     IGEMMBenchmark(state, xnn_f32_igemm_ukernel_1x8__sse_load1, 1, 8, 1, 1);
diff --git a/build_defs.bzl b/build_defs.bzl
index db1d303ea..9112ab27b 100644
--- a/build_defs.bzl
+++ b/build_defs.bzl
@@ -113,6 +113,17 @@ def xnnpack_cc_library(
         ":android_arm64": aarch64_srcs,
         ":android_x86": x86_srcs,
         ":android_x86_64": x86_srcs,
+        ":ios_armv7": aarch32_srcs,
+        ":ios_arm64": aarch64_srcs,
+        ":ios_arm64e": aarch64_srcs,
+        ":ios_x86": x86_srcs,
+        ":ios_x86_64": x86_srcs,
+        ":watchos_armv7k": aarch32_srcs,
+        ":watchos_arm64_32": aarch64_srcs,
+        ":watchos_x86": x86_srcs,
+        ":watchos_x86_64": x86_srcs,
+        ":tvos_arm64": aarch64_srcs,
+        ":tvos_x86_64": x86_srcs,
         ":emscripten_asmjs": asmjs_srcs,
         ":emscripten_wasm": wasm_srcs,
         ":emscripten_wasmsimd": wasmsimd_srcs,
@@ -129,6 +140,17 @@ def xnnpack_cc_library(
         ":android_arm64": aarch64_copts,
         ":android_x86": x86_copts,
         ":android_x86_64": x86_copts,
+        ":ios_armv7": aarch32_copts,
+        ":ios_arm64": aarch64_copts,
+        ":ios_arm64e": aarch64_copts,
+        ":ios_x86": x86_copts,
+        ":ios_x86_64": x86_copts,
+        ":watchos_armv7k": aarch32_copts,
+        ":watchos_arm64_32": aarch64_copts,
+        ":watchos_x86": x86_copts,
+        ":watchos_x86_64": x86_copts,
+        ":tvos_arm64": aarch64_copts,
+        ":tvos_x86_64": x86_copts,
         ":emscripten_asmjs": asmjs_copts,
         ":emscripten_wasm": wasm_copts,
         ":emscripten_wasmsimd": wasmsimd_copts,
@@ -180,6 +202,17 @@ def xnnpack_aggregate_library(
         ":android_arm64": aarch64_deps,
         ":android_x86": x86_deps,
         ":android_x86_64": x86_deps,
+        ":ios_armv7": aarch32_deps,
+        ":ios_arm64": aarch64_deps,
+        ":ios_arm64e": aarch64_deps,
+        ":ios_x86": x86_deps,
+        ":ios_x86_64": x86_deps,
+        ":watchos_armv7k": aarch32_deps,
+        ":watchos_arm64_32": aarch64_deps,
+        ":watchos_x86": x86_deps,
+        ":watchos_x86_64": x86_deps,
+        ":tvos_arm64": aarch64_deps,
+        ":tvos_x86_64": x86_deps,
         ":emscripten_wasm": wasm_deps,
         ":emscripten_wasmsimd": wasmsimd_deps,
         ":emscripten_asmjs": [],
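Background on why IGEMM kernels exist at all (context, not from this diff): IGEMM is an indirect GEMM in which the left-hand rows arrive through an array of pointers (an indirection buffer), which is how XNNPACK computes convolution without an explicit im2col copy. A scalar sketch with hypothetical names and an assumed bias-first weight layout:

#include <stddef.h>

/* Sketch: indirect GEMM for one 4x8 tile over ks kernel spatial positions.
   Pointer layout and packing are illustrative, not XNNPACK's exact scheme. */
void igemm_4x8_reference(size_t ks, size_t k, const float** indirect_a,
                         const float* w, float* c, size_t c_stride) {
  for (size_t m = 0; m < 4; m++) {
    for (size_t n = 0; n < 8; n++) {
      float acc = w[n];  /* assumed bias-first packing */
      for (size_t p = 0; p < ks; p++) {
        const float* a = indirect_a[p * 4 + m];  /* one pointer per (position, row) */
        for (size_t i = 0; i < k; i++) {
          acc += a[i] * w[8 + (p * k + i) * 8 + n];
        }
      }
      c[m * c_stride + n] = acc;
    }
  }
}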
diff --git a/include/xnnpack.h b/include/xnnpack.h
index ea171f060..b7bef08ce 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -349,6 +349,80 @@ enum xnn_status xnn_define_multiply2(
   uint32_t output_id,
   uint32_t flags);
 
+/// Define a PReLU (Parametric ReLU) Node and add it to a Subgraph.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
+///                   with [N, H, W, channels] dimensions.
+/// @param slope_id - Value ID for the slope tensor. The slope tensor must be a 1D tensor defined in the @a subgraph
+///                   with [channels] dimensions.
+/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
+///                    with [N, H, W, channels] dimensions.
+/// @param flags - binary features of the PReLU Node. No supported flags are currently defined.
+enum xnn_status xnn_define_prelu(
+  xnn_subgraph_t subgraph,
+  uint32_t input_id,
+  uint32_t slope_id,
+  uint32_t output_id,
+  uint32_t flags);
+
+/// Define a Clamp Node and add it to a Subgraph.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param output_min - lower bound for clipping output values.
+/// @param output_max - upper bound for clipping output values.
+/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
+/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
+///                    shape must match the shape of the input tensor.
+/// @param flags - binary features of the Clamp Node. No supported flags are currently defined.
+enum xnn_status xnn_define_clamp(
+  xnn_subgraph_t subgraph,
+  float output_min,
+  float output_max,
+  uint32_t input_id,
+  uint32_t output_id,
+  uint32_t flags);
+
+/// Define a HardSwish Node and add it to a Subgraph.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
+/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
+///                    shape must match the shape of the input tensor.
+/// @param flags - binary features of the HardSwish Node. No supported flags are currently defined.
+enum xnn_status xnn_define_hardswish(
+  xnn_subgraph_t subgraph,
+  uint32_t input_id,
+  uint32_t output_id,
+  uint32_t flags);
+
+/// Define a Sigmoid Node and add it to a Subgraph.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
+/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
+///                    shape must match the shape of the input tensor.
+/// @param flags - binary features of the Sigmoid Node. No supported flags are currently defined.
+enum xnn_status xnn_define_sigmoid(
+  xnn_subgraph_t subgraph,
+  uint32_t input_id,
+  uint32_t output_id,
+  uint32_t flags);
+
+/// Define a SoftMax Node and add it to a Subgraph.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph, and have at
+///                   least one dimension.
+/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
+///                    shape must match the shape of the input tensor.
+/// @param flags - binary features of the SoftMax Node. No supported flags are currently defined.
+enum xnn_status xnn_define_softmax(
+  xnn_subgraph_t subgraph,
+  uint32_t input_id,
+  uint32_t output_id,
+  uint32_t flags);
+
 /// Runtime is a combination of an execution plan for subgraph Nodes and a memory manager for subgraph Values.
 typedef struct xnn_runtime* xnn_runtime_t;
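To make the new Node definitions concrete, here is a minimal sketch of defining one of them. The setup calls (xnn_initialize, xnn_create_subgraph, xnn_define_tensor_value) and their flags are assumptions based on the rest of the XNNPACK Subgraph API, not on this diff; treat them as illustrative:

#include <stddef.h>
#include <stdint.h>
#include <xnnpack.h>

/* Sketch: clamp a [1, 8, 8, 4] tensor to [0, 6] via the Subgraph API. */
enum xnn_status build_clamp_subgraph(xnn_subgraph_t* out) {
  enum xnn_status status = xnn_initialize(NULL /* allocator */);
  if (status != xnn_status_success) return status;

  xnn_subgraph_t subgraph = NULL;
  status = xnn_create_subgraph(2 /* external value IDs */, 0 /* flags */, &subgraph);
  if (status != xnn_status_success) return status;

  const size_t dims[4] = {1, 8, 8, 4};
  uint32_t input_id = XNN_INVALID_VALUE_ID;
  uint32_t output_id = XNN_INVALID_VALUE_ID;
  xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 4, dims, NULL,
                          0 /* external ID */, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id);
  xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 4, dims, NULL,
                          1 /* external ID */, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);

  /* The Node defined by this change: output = min(max(input, 0), 6). */
  status = xnn_define_clamp(subgraph, 0.0f, 6.0f, input_id, output_id, 0 /* flags */);
  if (status != xnn_status_success) return status;

  *out = subgraph;
  return xnn_status_success;
}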
diff --git a/scripts/generate-f32-igemm.sh b/scripts/generate-f32-igemm.sh
index 763e34fe4..ceed44310 100755
--- a/scripts/generate-f32-igemm.sh
+++ b/scripts/generate-f32-igemm.sh
@@ -28,6 +28,10 @@ tools/xngen src/f32-igemm/5x8-aarch64-neonfma-cortex-a75.S.in -D INC=0 -D PREFE
 tools/xngen src/f32-igemm/6x8-aarch64-neonfma-cortex-a75.S.in -D INC=0 -D PREFETCH=0 -o src/f32-igemm/gen/6x8-aarch64-neonfma-cortex-a57.S
 tools/xngen src/f32-igemm/6x8-aarch64-neonfma-cortex-a75.S.in -D INC=0 -D PREFETCH=1 -o src/f32-igemm/gen/6x8-aarch64-neonfma-cortex-a75.S
 
+############################### AArch32 assembly ##############################
+tools/xngen src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in -D INC=0 -D PREFETCH=0 -o src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S
+tools/xngen src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in -D INC=0 -D PREFETCH=1 -o src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S
+
 ################################### ARM NEON ##################################
 ### LD64 micro-kernels
 tools/xngen src/f32-igemm/neon-ld64.c.in -D MR=1 -D NR=8 -D FMA=0 -D DUP=0 -o src/f32-igemm/gen/1x8-neon-lane-ld64.c
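The two xngen invocations above stamp the same template out twice, with -D PREFETCH selecting the pld (preload) variant. As a rough C analogy of what such a template flag toggles, with hypothetical names; the real templates emit AArch32 assembly, not C:

#include <stddef.h>

/* Rough analogy of the PREFETCH template flag: the pld variant issues
   prefetch hints ahead of the loads it is about to perform. */
float dot(const float* a, const float* b, size_t n) {
  float acc = 0.0f;
  for (size_t i = 0; i < n; i++) {
#if PREFETCH
    __builtin_prefetch(a + i + 64);  /* hint, like the AArch32 PLD instruction */
    __builtin_prefetch(b + i + 64);
#endif
    acc += a[i] * b[i];
  }
  return acc;
}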
diff --git a/src/f32-dwconv/gen/up16x25-avx512f-acc2.c b/src/f32-dwconv/gen/up16x25-avx512f-acc2.c
index 72d39a8f2..1d0d4b696 100644
--- a/src/f32-dwconv/gen/up16x25-avx512f-acc2.c
+++ b/src/f32-dwconv/gen/up16x25-avx512f-acc2.c
@@ -256,106 +256,106 @@ void xnn_f32_dwconv_ukernel_up16x25__avx512f_acc2(
     // Prepare mask for valid 32-bit elements (depends on nc).
     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
-    __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+    __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
     const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
-    const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+    const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
-    const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+    const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
     __m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF);
     const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
-    const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+    const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
-    const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+    const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
-    const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
+    const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
-    const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
+    const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
-    const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
+    const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
-    const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
+    const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
-    const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
+    const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i9);
-    const __m512 vk9x0123456789ABCDEF = _mm512_load_ps(w + 160);
+    const __m512 vk9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i10);
-    const __m512 vk10x0123456789ABCDEF = _mm512_load_ps(w + 176);
+    const __m512 vk10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 176);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i11);
-    const __m512 vk11x0123456789ABCDEF = _mm512_load_ps(w + 192);
+    const __m512 vk11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i12);
-    const __m512 vk12x0123456789ABCDEF = _mm512_load_ps(w + 208);
+    const __m512 vk12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 208);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i13);
-    const __m512 vk13x0123456789ABCDEF = _mm512_load_ps(w + 224);
+    const __m512 vk13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i14);
-    const __m512 vk14x0123456789ABCDEF = _mm512_load_ps(w + 240);
+    const __m512 vk14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 240);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i15);
-    const __m512 vk15x0123456789ABCDEF = _mm512_load_ps(w + 256);
+    const __m512 vk15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i16);
-    const __m512 vk16x0123456789ABCDEF = _mm512_load_ps(w + 272);
+    const __m512 vk16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 272);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i17);
-    const __m512 vk17x0123456789ABCDEF = _mm512_load_ps(w + 288);
+    const __m512 vk17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i18);
-    const __m512 vk18x0123456789ABCDEF = _mm512_load_ps(w + 304);
+    const __m512 vk18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 304);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i19);
-    const __m512 vk19x0123456789ABCDEF = _mm512_load_ps(w + 320);
+    const __m512 vk19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 320);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i20);
-    const __m512 vk20x0123456789ABCDEF = _mm512_load_ps(w + 336);
+    const __m512 vk20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 336);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i21);
-    const __m512 vk21x0123456789ABCDEF = _mm512_load_ps(w + 352);
+    const __m512 vk21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 352);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i22);
-    const __m512 vk22x0123456789ABCDEF = _mm512_load_ps(w + 368);
+    const __m512 vk22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 368);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i23);
-    const __m512 vk23x0123456789ABCDEF = _mm512_load_ps(w + 384);
+    const __m512 vk23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 384);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i24);
-    const __m512 vk24x0123456789ABCDEF = _mm512_load_ps(w + 400);
+    const __m512 vk24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 400);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
 
     // Add up all accumulators to vacc0123456789ABCDEFp0
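The change in this and the following generated kernels is uniform: in the channel-remainder path, the per-tap weight loads switch from aligned full loads (_mm512_load_ps) to masked unaligned loads (_mm512_maskz_loadu_ps), so the kernel never requires 64-byte alignment or reads past the end of the packed weights when fewer than 16 channels remain. A minimal standalone sketch of the masking idiom (my example, not XNNPACK code):

#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

/* Masked remainder handling, as used above. For c remaining floats
   (0 < c < 16), (1 << c) - 1 sets the low c mask bits: c = 5 gives
   0b11111, i.e. lanes 0..4. _mm512_maskz_loadu_ps loads only those
   lanes and zero-fills the rest, with no alignment requirement and
   no read past the end of the buffer. */
void scale_tail(const float* x, const float* k, float* y, size_t c) {
  const __mmask16 vmask =
      _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
  const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
  const __m512 vk = _mm512_maskz_loadu_ps(vmask, k);
  const __m512 vy = _mm512_mul_ps(vx, vk);
  _mm512_mask_storeu_ps(y, vmask, vy);  /* store only the c valid lanes */
}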
diff --git a/src/f32-dwconv/gen/up16x25-avx512f.c b/src/f32-dwconv/gen/up16x25-avx512f.c
index 2c7b8d452..140e387e2 100644
--- a/src/f32-dwconv/gen/up16x25-avx512f.c
+++ b/src/f32-dwconv/gen/up16x25-avx512f.c
@@ -254,106 +254,106 @@ void xnn_f32_dwconv_ukernel_up16x25__avx512f(
     // Prepare mask for valid 32-bit elements (depends on nc).
     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
-    __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+    __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
     const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
-    const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+    const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
-    const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+    const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
-    const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+    const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
-    const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+    const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
-    const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
+    const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
-    const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
+    const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
-    const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
+    const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
-    const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
+    const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
-    const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
+    const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i9);
-    const __m512 vk9x0123456789ABCDEF = _mm512_load_ps(w + 160);
+    const __m512 vk9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i10);
-    const __m512 vk10x0123456789ABCDEF = _mm512_load_ps(w + 176);
+    const __m512 vk10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 176);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i11);
-    const __m512 vk11x0123456789ABCDEF = _mm512_load_ps(w + 192);
+    const __m512 vk11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i12);
-    const __m512 vk12x0123456789ABCDEF = _mm512_load_ps(w + 208);
+    const __m512 vk12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 208);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i13);
-    const __m512 vk13x0123456789ABCDEF = _mm512_load_ps(w + 224);
+    const __m512 vk13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i14);
-    const __m512 vk14x0123456789ABCDEF = _mm512_load_ps(w + 240);
+    const __m512 vk14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 240);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i15);
-    const __m512 vk15x0123456789ABCDEF = _mm512_load_ps(w + 256);
+    const __m512 vk15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i16);
-    const __m512 vk16x0123456789ABCDEF = _mm512_load_ps(w + 272);
+    const __m512 vk16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 272);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i17);
-    const __m512 vk17x0123456789ABCDEF = _mm512_load_ps(w + 288);
+    const __m512 vk17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i18);
-    const __m512 vk18x0123456789ABCDEF = _mm512_load_ps(w + 304);
+    const __m512 vk18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 304);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i19);
-    const __m512 vk19x0123456789ABCDEF = _mm512_load_ps(w + 320);
+    const __m512 vk19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 320);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i20);
-    const __m512 vk20x0123456789ABCDEF = _mm512_load_ps(w + 336);
+    const __m512 vk20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 336);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i21);
-    const __m512 vk21x0123456789ABCDEF = _mm512_load_ps(w + 352);
+    const __m512 vk21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 352);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i22);
-    const __m512 vk22x0123456789ABCDEF = _mm512_load_ps(w + 368);
+    const __m512 vk22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 368);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i23);
-    const __m512 vk23x0123456789ABCDEF = _mm512_load_ps(w + 384);
+    const __m512 vk23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 384);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i24);
-    const __m512 vk24x0123456789ABCDEF = _mm512_load_ps(w + 400);
+    const __m512 vk24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 400);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
diff --git a/src/f32-dwconv/gen/up16x4-avx512f-acc2.c b/src/f32-dwconv/gen/up16x4-avx512f-acc2.c
index 36f05a2e3..c9e6586c2 100644
--- a/src/f32-dwconv/gen/up16x4-avx512f-acc2.c
+++ b/src/f32-dwconv/gen/up16x4-avx512f-acc2.c
@@ -88,22 +88,22 @@ void xnn_f32_dwconv_ukernel_up16x4__avx512f_acc2(
     // Prepare mask for valid 32-bit elements (depends on nc).
     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
-    __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+    __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
     const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
-    const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+    const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
-    const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+    const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
     __m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF);
     const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
-    const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+    const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
-    const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+    const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1);
 
     // Add up all accumulators to vacc0123456789ABCDEFp0
diff --git a/src/f32-dwconv/gen/up16x4-avx512f.c b/src/f32-dwconv/gen/up16x4-avx512f.c
index 79a5c498a..0a625551c 100644
--- a/src/f32-dwconv/gen/up16x4-avx512f.c
+++ b/src/f32-dwconv/gen/up16x4-avx512f.c
@@ -86,22 +86,22 @@ void xnn_f32_dwconv_ukernel_up16x4__avx512f(
     // Prepare mask for valid 32-bit elements (depends on nc).
     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
-    __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+    __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
     const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
-    const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+    const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
-    const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+    const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
-    const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+    const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
-    const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+    const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
diff --git a/src/f32-dwconv/gen/up16x9-avx512f-acc2.c b/src/f32-dwconv/gen/up16x9-avx512f-acc2.c
index 63848915a..17587f5d0 100644
--- a/src/f32-dwconv/gen/up16x9-avx512f-acc2.c
+++ b/src/f32-dwconv/gen/up16x9-avx512f-acc2.c
@@ -128,42 +128,42 @@ void xnn_f32_dwconv_ukernel_up16x9__avx512f_acc2(
    // Prepare mask for valid 32-bit elements (depends on nc).
     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
-    __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+    __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
     const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
-    const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+    const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
-    const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+    const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
     __m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF);
     const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
-    const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+    const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
-    const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+    const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
-    const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
+    const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
-    const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
+    const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
-    const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
+    const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
-    const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
+    const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
-    const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
+    const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
 
     // Add up all accumulators to vacc0123456789ABCDEFp0
diff --git a/src/f32-dwconv/gen/up16x9-avx512f.c b/src/f32-dwconv/gen/up16x9-avx512f.c
index eaab45c51..4a7606ffa 100644
--- a/src/f32-dwconv/gen/up16x9-avx512f.c
+++ b/src/f32-dwconv/gen/up16x9-avx512f.c
@@ -126,42 +126,42 @@ void xnn_f32_dwconv_ukernel_up16x9__avx512f(
     // Prepare mask for valid 32-bit elements (depends on nc).
     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
-    __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+    __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
     const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
-    const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
+    const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
-    const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
+    const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
-    const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
+    const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
-    const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
+    const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
-    const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
+    const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
-    const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
+    const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
-    const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
+    const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
-    const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
+    const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
-    const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
+    const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
diff --git a/src/f32-dwconv/gen/up32x25-avx512f-acc2.c b/src/f32-dwconv/gen/up32x25-avx512f-acc2.c
index 9d5d8271f..a666d1e84 100644
--- a/src/f32-dwconv/gen/up32x25-avx512f-acc2.c
+++ b/src/f32-dwconv/gen/up32x25-avx512f-acc2.c
@@ -500,106 +500,106 @@ void xnn_f32_dwconv_ukernel_up32x25__avx512f_acc2(
     // Prepare mask for valid 32-bit elements (depends on nc).
     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
-    __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
+    __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
     const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
-    const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32);
+    const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
-    const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64);
+    const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
     __m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF);
     const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
-    const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96);
+    const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
-    const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128);
+    const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
-    const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 160);
+    const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
-    const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 192);
+    const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
-    const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 224);
+    const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
-    const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 256);
+    const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
-    const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 288);
+    const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i9);
-    const __m512 vk9x0123456789ABCDEF = _mm512_load_ps(w + 320);
+    const __m512 vk9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 320);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i10);
-    const __m512 vk10x0123456789ABCDEF = _mm512_load_ps(w + 352);
+    const __m512 vk10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 352);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i11);
-    const __m512 vk11x0123456789ABCDEF = _mm512_load_ps(w + 384);
+    const __m512 vk11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 384);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i12);
-    const __m512 vk12x0123456789ABCDEF = _mm512_load_ps(w + 416);
+    const __m512 vk12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 416);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i13);
-    const __m512 vk13x0123456789ABCDEF = _mm512_load_ps(w + 448);
+    const __m512 vk13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 448);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i14);
-    const __m512 vk14x0123456789ABCDEF = _mm512_load_ps(w + 480);
+    const __m512 vk14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 480);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i15);
-    const __m512 vk15x0123456789ABCDEF = _mm512_load_ps(w + 512);
+    const __m512 vk15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 512);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i16);
-    const __m512 vk16x0123456789ABCDEF = _mm512_load_ps(w + 544);
+    const __m512 vk16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 544);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i17);
-    const __m512 vk17x0123456789ABCDEF = _mm512_load_ps(w + 576);
+    const __m512 vk17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 576);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i18);
-    const __m512 vk18x0123456789ABCDEF = _mm512_load_ps(w + 608);
+    const __m512 vk18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 608);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i19);
-    const __m512 vk19x0123456789ABCDEF = _mm512_load_ps(w + 640);
+    const __m512 vk19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 640);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i20);
-    const __m512 vk20x0123456789ABCDEF = _mm512_load_ps(w + 672);
+    const __m512 vk20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 672);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i21);
-    const __m512 vk21x0123456789ABCDEF = _mm512_load_ps(w + 704);
+    const __m512 vk21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 704);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i22);
-    const __m512 vk22x0123456789ABCDEF = _mm512_load_ps(w + 736);
+    const __m512 vk22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 736);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
     const __m512 vi23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i23);
-    const __m512 vk23x0123456789ABCDEF = _mm512_load_ps(w + 768);
+    const __m512 vk23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 768);
     vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp1);
     const __m512 vi24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i24);
-    const __m512 vk24x0123456789ABCDEF = _mm512_load_ps(w + 800);
+    const __m512 vk24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 800);
     vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
 
     // Add up all accumulators to vacc0123456789ABCDEFp0
vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6); - const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 224); + const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7); - const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 256); + const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8); - const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 288); + const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i9); - const __m512 vk9x0123456789ABCDEF = _mm512_load_ps(w + 320); + const __m512 vk9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 320); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i10); - const __m512 vk10x0123456789ABCDEF = _mm512_load_ps(w + 352); + const __m512 vk10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 352); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i11); - const __m512 vk11x0123456789ABCDEF = _mm512_load_ps(w + 384); + const __m512 vk11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 384); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i12); - const __m512 vk12x0123456789ABCDEF = _mm512_load_ps(w + 416); + const __m512 vk12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 416); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i13); - const __m512 vk13x0123456789ABCDEF = _mm512_load_ps(w + 448); + const __m512 vk13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 448); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i14); - const __m512 vk14x0123456789ABCDEF = _mm512_load_ps(w + 480); + const __m512 vk14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 480); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i15); - const __m512 vk15x0123456789ABCDEF = _mm512_load_ps(w + 512); + const __m512 vk15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 512); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i16); - const __m512 vk16x0123456789ABCDEF = _mm512_load_ps(w + 544); + const 
__m512 vk16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 544); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i17); - const __m512 vk17x0123456789ABCDEF = _mm512_load_ps(w + 576); + const __m512 vk17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 576); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i18); - const __m512 vk18x0123456789ABCDEF = _mm512_load_ps(w + 608); + const __m512 vk18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 608); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i19); - const __m512 vk19x0123456789ABCDEF = _mm512_load_ps(w + 640); + const __m512 vk19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 640); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i20); - const __m512 vk20x0123456789ABCDEF = _mm512_load_ps(w + 672); + const __m512 vk20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 672); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i21); - const __m512 vk21x0123456789ABCDEF = _mm512_load_ps(w + 704); + const __m512 vk21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 704); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i22); - const __m512 vk22x0123456789ABCDEF = _mm512_load_ps(w + 736); + const __m512 vk22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 736); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i23); - const __m512 vk23x0123456789ABCDEF = _mm512_load_ps(w + 768); + const __m512 vk23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 768); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i24); - const __m512 vk24x0123456789ABCDEF = _mm512_load_ps(w + 800); + const __m512 vk24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 800); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0); diff --git a/src/f32-dwconv/gen/up32x4-avx512f-acc2.c b/src/f32-dwconv/gen/up32x4-avx512f-acc2.c index 67bad500d..d8aaa0862 100644 --- a/src/f32-dwconv/gen/up32x4-avx512f-acc2.c +++ b/src/f32-dwconv/gen/up32x4-avx512f-acc2.c @@ -143,22 +143,22 @@ void xnn_f32_dwconv_ukernel_up32x4__avx512f_acc2( // Prepare mask for valid 32-bit elements (depends on nc). 
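The masked-load change repeated through these dwconv hunks swaps _mm512_load_ps (an aligned, full 16-lane load) for _mm512_maskz_loadu_ps when handling the channel tail: with c channels left (0 < c < 16), (UINT32_C(1) << c) - 1 sets the low c mask bits, the masked load reads only those lanes and zeroes the rest, and the zeroed lanes contribute nothing to the FMA. That keeps the kernel from reading past the end of the packed weights, which need not extend a full vector (or stay 64-byte aligned) at the tail. A minimal sketch of the pattern, assuming AVX-512F; tail_fma is an illustrative name, not an XNNPACK function:

    #include <immintrin.h>
    #include <stdint.h>
    #include <stddef.h>

    // Multiply-accumulate over the final c (0 < c < 16) valid floats of i and k.
    // Lanes >= c load as 0.0f, so they add 0 * 0 to the accumulator.
    static inline __m512 tail_fma(size_t c, const float* i, const float* k, __m512 vacc) {
      const __mmask16 vmask =
          _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
      const __m512 vi = _mm512_maskz_loadu_ps(vmask, i);  // masked, unaligned-safe
      const __m512 vk = _mm512_maskz_loadu_ps(vmask, k);  // was _mm512_load_ps(k)
      return _mm512_fmadd_ps(vi, vk, vacc);
    }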
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1))); - __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w); + __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w); const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0); - const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32); + const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1); - const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64); + const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64); __m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF); const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2); - const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96); + const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3); - const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128); + const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128); vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1); // Add up all accumulators to vacc0123456789ABCDEFp0 diff --git a/src/f32-dwconv/gen/up32x4-avx512f.c b/src/f32-dwconv/gen/up32x4-avx512f.c index 126e982a5..dc5b4a88a 100644 --- a/src/f32-dwconv/gen/up32x4-avx512f.c +++ b/src/f32-dwconv/gen/up32x4-avx512f.c @@ -138,22 +138,22 @@ void xnn_f32_dwconv_ukernel_up32x4__avx512f( // Prepare mask for valid 32-bit elements (depends on nc). 
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1))); - __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w); + __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w); const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0); - const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32); + const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1); - const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64); + const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2); - const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96); + const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3); - const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128); + const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0); diff --git a/src/f32-dwconv/gen/up32x9-avx512f-acc2.c b/src/f32-dwconv/gen/up32x9-avx512f-acc2.c index 9978c55f8..cec1e0790 100644 --- a/src/f32-dwconv/gen/up32x9-avx512f-acc2.c +++ b/src/f32-dwconv/gen/up32x9-avx512f-acc2.c @@ -228,42 +228,42 @@ void xnn_f32_dwconv_ukernel_up32x9__avx512f_acc2( // Prepare mask for valid 32-bit elements (depends on nc). 
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1))); - __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w); + __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w); const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0); - const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32); + const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1); - const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64); + const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64); __m512 vacc0123456789ABCDEFp1 = _mm512_mul_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF); const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2); - const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96); + const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3); - const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128); + const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128); vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp1); const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4); - const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 160); + const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5); - const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 192); + const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192); vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp1); const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6); - const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 224); + const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7); - const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 256); + const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256); vacc0123456789ABCDEFp1 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp1); const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8); - const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 288); + const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0); // Add up all accumulators to vacc0123456789ABCDEFp0 diff --git a/src/f32-dwconv/gen/up32x9-avx512f.c b/src/f32-dwconv/gen/up32x9-avx512f.c index 62fec2eaa..3cb060086 100644 --- a/src/f32-dwconv/gen/up32x9-avx512f.c +++ b/src/f32-dwconv/gen/up32x9-avx512f.c @@ -223,42 +223,42 @@ void xnn_f32_dwconv_ukernel_up32x9__avx512f( // Prepare mask for valid 32-bit elements (depends on nc). 
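The -acc2 variants in the hunks above split the reduction across two accumulators: even-numbered taps feed vacc...p0, odd-numbered taps feed vacc...p1 (whose chain begins with a plain _mm512_mul_ps rather than an FMA), and the "Add up all accumulators" comment marks the single combine at the end. Two independent chains let the core keep two FMAs in flight instead of serializing on one register; the reassociation can change last-bit rounding, which is one reason both variants exist. The same idea in scalar C:

    #include <stddef.h>

    // Two-accumulator dot product: even-indexed terms feed acc0, odd-indexed
    // terms feed acc1, halving each floating-point dependency chain.
    static float dot_acc2(const float* a, const float* b, size_t n) {
      float acc0 = 0.0f;
      float acc1 = 0.0f;
      size_t i = 0;
      for (; i + 1 < n; i += 2) {
        acc0 += a[i] * b[i];
        acc1 += a[i + 1] * b[i + 1];
      }
      if (i < n) {
        acc0 += a[i] * b[i];  // odd trip count
      }
      return acc0 + acc1;  // single combine at the end, as in the kernels
    }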
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1))); - __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w); + __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w); const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0); - const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 32); + const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1); - const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 64); + const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2); - const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 96); + const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3); - const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 128); + const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4); - const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 160); + const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5); - const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 192); + const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6); - const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 224); + const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7); - const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 256); + const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0); const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8); - const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 288); + const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288); vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0); diff --git a/src/f32-dwconv/up-avx512.c.in b/src/f32-dwconv/up-avx512.c.in index 4d5a8db74..e9cff30bb 100644 --- a/src/f32-dwconv/up-avx512.c.in +++ b/src/f32-dwconv/up-avx512.c.in @@ -117,11 +117,11 @@ void xnn_f32_dwconv_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__avx512f${"" if ACC // Prepare mask for valid 32-bit elements (depends on nc). 
const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1))); - __m512 vacc${ABC[0:16]}p0 = _mm512_load_ps(w); + __m512 vacc${ABC[0:16]}p0 = _mm512_maskz_loadu_ps(vmask, w); $for K in range(KERNEL_TILE): const __m512 vi${K}x${ABC[0:16]} = _mm512_maskz_loadu_ps(vmask, i${K}); - const __m512 vk${K}x${ABC[0:16]} = _mm512_load_ps(w + ${(K + 1) * CHANNEL_TILE}); + const __m512 vk${K}x${ABC[0:16]} = _mm512_maskz_loadu_ps(vmask, w + ${(K + 1) * CHANNEL_TILE}); $if 1 <= K < ACCUMULATORS: __m512 vacc${ABC[0:16]}p${K} = _mm512_mul_ps(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]}); $else: diff --git a/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in b/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in index 4ff2dc336..8880dceb3 100644 --- a/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in +++ b/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in @@ -67,7 +67,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma # Is there at least 4 floats (16 bytes) for prologue + epilogue? SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f # Prologue - loads for first group of 6 fma @@ -261,6 +261,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma # BLOCK 4 INS v19.d[1], x9 FMLA v20.4s, v17.4s, v1.s[1] + TST x0, 15 # BLOCK 5 FMLA v21.4s, v18.4s, v1.s[1] @@ -269,11 +270,8 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma FMLA v22.4s, v19.4s, v1.s[1] # BLOCK 7 -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 5f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 6f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f 4: # Clamp @@ -290,12 +288,13 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14 SUB x3, x3, x2 // a0 -= kc - B.HI 0b - RET 5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder - 2 floats of A (8 bytes) # Read first block of 1 A. LDR d0, [x3], 8 // a0 diff --git a/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in b/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in index a20699867..d0bb021ed 100644 --- a/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in +++ b/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in @@ -144,7 +144,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma # Is there at least 4 floats (16 bytes)? SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f SUBS x0, x0, 16 @@ -416,6 +416,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma FMLA v27.4s, v18.4s, v3.s[1] FMLA v30.4s, v18.4s, v3.s[3] FMLA v22.4s, v19.4s, v2.s[1] + TST x0, 15 # BLOCK 7 FMLA v25.4s, v19.4s, v2.s[3] @@ -423,11 +424,8 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma ADD x5, x5, 96 FMLA v31.4s, v19.4s, v3.s[3] -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 5f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 6f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f 4: # Clamp @@ -488,6 +486,9 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma RET 5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder - 2 floats of A (8 bytes) # Read first block of 4 A. 
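The assembly restructure running through these GEMM hunks follows one recipe: the old code fell through into a pair of hot-path bit tests (TBNZ x0, 3 for 2 leftover floats, TBNZ x0, 2 for 1), while the new code folds a single TST x0, 15 into a spare issue slot of the last main-loop block, so the common case takes one not-taken B.NE straight into clamp/store and the tail cases re-test the individual bits off the hot path. x0 holds the leftover kc in bytes, so bit 3 means 8 bytes (2 floats) and bit 2 means 4 bytes (1 float). A C rendering of the new control flow, as a sketch with stand-in tail bodies:

    #include <stdint.h>

    // Sketch of the restructured remainder dispatch; process2/process1 stand in
    // for the 8-byte and 4-byte tail bodies of the kernel.
    static void tail_dispatch(uint64_t k, void (*process2)(void), void (*process1)(void)) {
      if ((k & 15) != 0) {    // TST x0, 15 ; B.NE 5f -- any tail at all?
        if ((k & 8) != 0) {   // 5: TBZ x0, 3, 6f (skip when bit clear)
          process2();         //    consume 2 floats (8 bytes) of each A row
        }
        if ((k & 4) != 0) {   // TBZ x0, 2, 4b (back to clamp when bit clear)
          process1();         //    consume 1 float (4 bytes) of each A row
        }
      }
      /* label 4: fall through to clamp and store */
    }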
LDR d0, [x3], 8 // a0 diff --git a/src/f32-gemm/4x8-aarch32-neon-ld64.S b/src/f32-gemm/4x8-aarch32-neon-ld64.S index cc711f9ae..72c2d5c8e 100644 --- a/src/f32-gemm/4x8-aarch32-neon-ld64.S +++ b/src/f32-gemm/4x8-aarch32-neon-ld64.S @@ -86,7 +86,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64 VMOV q13, q9 VMOV q14, q8 VMOV q15, q9 - BLO 3f // less than 2 channels? + BLO 8f // less than 2 channels? // Main loop - 2 floats of A (8 bytes) 2: @@ -116,7 +116,6 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64 VMLA.F32 q15, q7, d3[1] BHS 2b -3: // Is there a remainder?- 1 floats of A (4 bytes) TST r5, 4 BNE 8f diff --git a/src/f32-gemm/4x8-aarch64-neonfma-cortex-a53.S.in b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a53.S.in index 37860d04a..2c5108b4c 100644 --- a/src/f32-gemm/4x8-aarch64-neonfma-cortex-a53.S.in +++ b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a53.S.in @@ -142,7 +142,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ # Is there at least 4 floats (16 bytes) for prologue + epilogue? SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f # Prologue - First group loads, no FMA LDR d0, [x3], 8 // a0 @@ -346,6 +346,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ FMLA v22.4s, v14.4s, v3.s[3] FMLA v24.4s, v14.4s, v4.s[1] FMLA v26.4s, v14.4s, v4.s[3] + TST x0, 15 // BLOCK 4 FMLA v21.4s, v15.4s, v3.s[1] @@ -356,11 +357,9 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ // BLOCK 5 FMLA v27.4s, v15.4s, v4.s[3] -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 6f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 7f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f + 4: # Clamp FMIN v20.4s, v20.4s, v6.4s @@ -410,8 +409,11 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ LDP d12, d13, [sp], 32 RET +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) -6: LDR d0, [x3], 8 LDR q16, [x5], 16 LD1 {v0.d}[1], [x9], 8 @@ -441,7 +443,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ # Is there a remainder?- 1 floats of A (4 bytes) TBZ x0, 2, 4b -7: +6: # Remainder- 1 floats of A (4 bytes) LDR s0, [x3], 4 LDR q16, [x5], 16 diff --git a/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in b/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in index dba8f7d47..aab2dbdaf 100644 --- a/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in +++ b/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in @@ -88,7 +88,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ # Is there at least 4 floats (16 bytes)? 
SUBS x0, x2, 16 // k = kc - 16 - B.LO 2f + B.LO 5f # Main loop - 4 floats of A (16 bytes) 1: @@ -135,52 +135,10 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ FMLA v31.4s, v27.4s, v3.s[3] B.HS 1b - # Remainder- 2 floats of A (8 bytes) -2: - TBZ x0, 3, 3f + TST x0, 15 + B.NE 5f - LDR d0, [x3], 8 - LDP q20, q21, [x5], 32 - LDR d1, [x11], 8 - LDR d2, [x12], 8 - LDR d3, [x4], 8 - FMLA v16.4s, v20.4s, v0.s[0] - FMLA v17.4s, v21.4s, v0.s[0] - FMLA v18.4s, v20.4s, v1.s[0] - FMLA v19.4s, v21.4s, v1.s[0] - LDP q22, q23, [x5], 32 - FMLA v28.4s, v20.4s, v2.s[0] - FMLA v29.4s, v21.4s, v2.s[0] - FMLA v30.4s, v20.4s, v3.s[0] - FMLA v31.4s, v21.4s, v3.s[0] - FMLA v16.4s, v22.4s, v0.s[1] - FMLA v17.4s, v23.4s, v0.s[1] - FMLA v18.4s, v22.4s, v1.s[1] - FMLA v19.4s, v23.4s, v1.s[1] - FMLA v28.4s, v22.4s, v2.s[1] - FMLA v29.4s, v23.4s, v2.s[1] - FMLA v30.4s, v22.4s, v3.s[1] - FMLA v31.4s, v23.4s, v3.s[1] - - # Remainder- 1 float of A (4 bytes) -3: - TBZ x0, 2, 6f - - LDR s0, [x3], 4 - LDP q20, q21, [x5], 32 - LDR s1, [x11], 4 - LDR s2, [x12], 4 - LDR s3, [x4], 4 - FMLA v16.4s, v20.4s, v0.s[0] - FMLA v17.4s, v21.4s, v0.s[0] - FMLA v18.4s, v20.4s, v1.s[0] - FMLA v19.4s, v21.4s, v1.s[0] - FMLA v28.4s, v20.4s, v2.s[0] - FMLA v29.4s, v21.4s, v2.s[0] - FMLA v30.4s, v20.4s, v3.s[0] - FMLA v31.4s, v21.4s, v3.s[0] - -6: +4: # Clamp FMIN v16.4s, v16.4s, v4.4s SUBS x1, x1, 8 @@ -223,9 +181,58 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ SUB x4, x4, x2 // a3 -= kc B.HI 0b - RET + # Remainder- 2 floats of A (8 bytes) +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + + # Remainder- 2 floats of A (8 bytes) + LDR d0, [x3], 8 + LDP q20, q21, [x5], 32 + LDR d1, [x11], 8 + LDR d2, [x12], 8 + LDR d3, [x4], 8 + FMLA v16.4s, v20.4s, v0.s[0] + FMLA v17.4s, v21.4s, v0.s[0] + FMLA v18.4s, v20.4s, v1.s[0] + FMLA v19.4s, v21.4s, v1.s[0] + LDP q22, q23, [x5], 32 + FMLA v28.4s, v20.4s, v2.s[0] + FMLA v29.4s, v21.4s, v2.s[0] + FMLA v30.4s, v20.4s, v3.s[0] + FMLA v31.4s, v21.4s, v3.s[0] + FMLA v16.4s, v22.4s, v0.s[1] + FMLA v17.4s, v23.4s, v0.s[1] + FMLA v18.4s, v22.4s, v1.s[1] + FMLA v19.4s, v23.4s, v1.s[1] + FMLA v28.4s, v22.4s, v2.s[1] + FMLA v29.4s, v23.4s, v2.s[1] + FMLA v30.4s, v22.4s, v3.s[1] + FMLA v31.4s, v23.4s, v3.s[1] + + # Is there a remainder?- 1 floats of A (4 bytes) + TBZ x0, 2, 4b + + # Remainder- 1 float of A (4 bytes) +6: + LDR s0, [x3], 4 + LDP q20, q21, [x5], 32 + LDR s1, [x11], 4 + LDR s2, [x12], 4 + LDR s3, [x4], 4 + FMLA v16.4s, v20.4s, v0.s[0] + FMLA v17.4s, v21.4s, v0.s[0] + FMLA v18.4s, v20.4s, v1.s[0] + FMLA v19.4s, v21.4s, v1.s[0] + FMLA v28.4s, v20.4s, v2.s[0] + FMLA v29.4s, v21.4s, v2.s[0] + FMLA v30.4s, v20.4s, v3.s[0] + FMLA v31.4s, v21.4s, v3.s[0] + B 4b + + # Store odd width 7: TBZ x1, 2, 8f diff --git a/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in b/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in index e2bedbcd8..8147e90c4 100644 --- a/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in +++ b/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in @@ -88,10 +88,9 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ # Is there at least 2 floats (8 bytes)? 
SUBS x0, x2, 8 // k = kc - 8 - B.LO 2f + B.LO 5f # Main loop - 2 floats of A (8 bytes) - 1: LDR d0, [x3], 8 LDP q20, q21, [x5], 32 @@ -117,25 +116,11 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ FMLA v30.4s, v22.4s, v3.s[1] FMLA v31.4s, v23.4s, v3.s[1] B.HS 1b -2: - # Remainder- 1 floats of A (4 bytes) - TBZ x0, 2, 6f - LDR s0, [x3], 4 - LDP q20, q21, [x5], 32 - LDR s1, [x11], 4 - LDR s2, [x12], 4 - LDR s3 , [x4], 4 - FMLA v16.4s, v20.4s, v0.s[0] - FMLA v17.4s, v21.4s, v0.s[0] - FMLA v18.4s, v20.4s, v1.s[0] - FMLA v19.4s, v21.4s, v1.s[0] - FMLA v28.4s, v20.4s, v2.s[0] - FMLA v29.4s, v21.4s, v2.s[0] - FMLA v30.4s, v20.4s, v3.s[0] - FMLA v31.4s, v21.4s, v3.s[0] + # Is there a remainder?- 1 floats of A (4 bytes) + TBNZ x0, 2, 5f -6: +4: # Clamp FMIN v16.4s, v16.4s, v4.4s SUBS x1, x1, 8 @@ -181,6 +166,23 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ RET + # Remainder- 1 float of A (4 bytes) +5: + LDR s0, [x3], 4 + LDP q20, q21, [x5], 32 + LDR s1, [x11], 4 + LDR s2, [x12], 4 + LDR s3 , [x4], 4 + FMLA v16.4s, v20.4s, v0.s[0] + FMLA v17.4s, v21.4s, v0.s[0] + FMLA v18.4s, v20.4s, v1.s[0] + FMLA v19.4s, v21.4s, v1.s[0] + FMLA v28.4s, v20.4s, v2.s[0] + FMLA v29.4s, v21.4s, v2.s[0] + FMLA v30.4s, v20.4s, v3.s[0] + FMLA v31.4s, v21.4s, v3.s[0] + B 4b + # Store odd width 7: TBZ x1, 2, 8f diff --git a/src/f32-gemm/6x8-aarch64-neonfma-cortex-a53.S.in b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a53.S.in index 9ac56da97..fa7056c90 100644 --- a/src/f32-gemm/6x8-aarch64-neonfma-cortex-a53.S.in +++ b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a53.S.in @@ -167,7 +167,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ # Is there at least 4 floats (16 bytes) for prologue + epilogue? SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f # Prologue - First group loads, no FMA LDR d0, [x3], 8 // a0 @@ -430,6 +430,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ FMLA v21.4s, v15.4s, v3.s[1] FMLA v23.4s, v15.4s, v3.s[3] FMLA v25.4s, v15.4s, v4.s[1] + TST x0, 15 // BLOCK 7 FMLA v27.4s, v15.4s, v4.s[3] @@ -437,11 +438,8 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ FMLA v31.4s, v15.4s, v5.s[3] ADD x5, x5, 64 -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 6f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 7f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f 4: # Clamp FMIN v20.4s, v20.4s, v6.4s @@ -507,8 +505,11 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ LDP d12, d13, [sp], 32 RET +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) -6: LDR d0, [x3], 8 LDR q16, [x5], 16 LD1 {v0.d}[1], [x9], 8 @@ -548,8 +549,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ # Is there a remainder?- 1 floats of A (4 bytes) TBZ x0, 2, 4b - -7: +6: # Remainder- 1 floats of A (4 bytes) LDR s0, [x3], 4 LDR q16, [x5], 16 diff --git a/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in b/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in index b3e3f5526..50b74f330 100644 --- a/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in +++ b/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in @@ -149,7 +149,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ # Is there at least 4 floats (16 bytes)? 
SUBS x0, x2, 16 // k = kc - 16 - B.LO 2f + B.LO 5f # Main loop - 4 floats of A (16 bytes) # 48 FMA + 6 ld128 A + 4 LDP B @@ -218,12 +218,11 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ FMLA v31.4s, v19.4s, v5.s[3] B.HS 1b -2: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 4f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 5f -3: + # Is there a remainder?- 2 floats of A (8 bytes) or less + TST x0, 15 + B.NE 5f + +4: # Clamp FMIN v20.4s, v20.4s, v6.4s SUBS x1, x1, 8 @@ -252,7 +251,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ FMAX v31.4s, v31.4s, v7.4s # Store full 6 x 8 - B.LO 6f + B.LO 7f $if INC: ST1 {v30.16b, v31.16b}, [x7], x14 @@ -282,10 +281,12 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ SUB x4, x4, x2 // a5 -= kc B.HI 0b - RET -4: +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) LDR d0, [x3], 8 LDP q16, q17, [x5], 32 @@ -320,10 +321,12 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ FMLA v27.4s, v19.4s, v3.s[1] FMLA v29.4s, v19.4s, v4.s[1] FMLA v31.4s, v19.4s, v5.s[1] - TBZ x0, 2, 3b -5: - # Remainder- 1 floats of A (4 bytes) + # Is there a remainder?- 1 floats of A (4 bytes) + TBZ x0, 2, 4b + + # Remainder- 1 float of A (4 bytes) +6: LDR s0, [x3], 4 LDP q16, q17, [x5], 32 LDR s1, [x9], 4 @@ -343,11 +346,11 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ FMLA v27.4s, v17.4s, v3.s[0] FMLA v29.4s, v17.4s, v4.s[0] FMLA v31.4s, v17.4s, v5.s[0] - B 3b + B 4b # Store odd width -6: - TBZ x1, 2, 7f +7: + TBZ x1, 2, 8f $if INC: STR q30, [x7], 16 MOV v30.16b, v31.16b @@ -375,8 +378,8 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ STR q30, [x7], 16 MOV v30.16b, v31.16b -7: - TBZ x1, 1, 8f +8: + TBZ x1, 1, 9f $if INC: STR d30, [x7], 8 DUP d30, v30.d[1] @@ -404,8 +407,8 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ STR d30, [x7], 8 DUP d30, v30.d[1] -8: - TBZ x1, 0, 9f +9: + TBZ x1, 0, 10f $if INC: STR s30, [x7] STR s28, [x13] @@ -420,7 +423,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ STR s26, [x18] STR s28, [x13] STR s30, [x7] -9: +10: RET END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ld128 diff --git a/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in b/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in index 64a1be3d4..daae4a001 100644 --- a/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in +++ b/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in @@ -149,7 +149,7 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ # Is there at least 2 floats (8 bytes) for main loop? 
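The renumbered "Store odd width" ladders above (labels 7/8/9 becoming 8/9/10 after the remainder rework) walk the leftover column count in x1 from the widest piece down: TBZ x1, 2 guards a 4-float STR q, TBZ x1, 1 a 2-float STR d, TBZ x1, 0 a final STR s, with MOV/DUP sliding the remaining lanes into position between steps. The equivalent tail store in C:

    #include <stddef.h>
    #include <string.h>

    // Store n (0 < n < 8) floats of one output row, widest chunk first;
    // mirrors the TBZ x1,2 / TBZ x1,1 / TBZ x1,0 ladder of STR q/d/s.
    static void store_odd_width(float* c, const float* acc, size_t n) {
      if (n & 4) { memcpy(c, acc, 4 * sizeof(float)); c += 4; acc += 4; }
      if (n & 2) { memcpy(c, acc, 2 * sizeof(float)); c += 2; acc += 2; }
      if (n & 1) { *c = *acc; }
    }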
SUBS x0, x2, 8 // k = kc - 8 - B.LO 2f + B.LO 4f # Main loop - 2 floats of A (8 bytes) # 24 FMA + 6 LD64 A + 2 LDP B @@ -190,7 +190,6 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ FMLA v31.4s, v19.4s, v5.s[1] B.HS 1b -2: # Is there a remainder?- 1 floats of A (4 bytes) TBNZ x0, 2, 4f 3: @@ -252,7 +251,6 @@ BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ SUB x4, x4, x2 // a5 -= kc B.HI 0b - RET 4: diff --git a/src/f32-gemm/gen-inc/1x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen-inc/1x12-aarch64-neonfma-cortex-a53.S index fc9bf22a5..cb9cac776 100644 --- a/src/f32-gemm/gen-inc/1x12-aarch64-neonfma-cortex-a53.S +++ b/src/f32-gemm/gen-inc/1x12-aarch64-neonfma-cortex-a53.S @@ -60,7 +60,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53 # Is there at least 4 floats (16 bytes) for prologue + epilogue? SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f # Prologue - loads for first group of 6 fma @@ -254,6 +254,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53 # BLOCK 4 INS v19.d[1], x9 FMLA v20.4s, v17.4s, v1.s[1] + TST x0, 15 # BLOCK 5 FMLA v21.4s, v18.4s, v1.s[1] @@ -262,11 +263,8 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53 FMLA v22.4s, v19.4s, v1.s[1] # BLOCK 7 -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 5f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 6f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f 4: # Clamp @@ -283,12 +281,13 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53 ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14 SUB x3, x3, x2 // a0 -= kc - B.HI 0b - RET 5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder - 2 floats of A (8 bytes) # Read first block of 1 A. LDR d0, [x3], 8 // a0 diff --git a/src/f32-gemm/gen-inc/4x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen-inc/4x12-aarch64-neonfma-cortex-a53.S index ba8e20400..897ad35fb 100644 --- a/src/f32-gemm/gen-inc/4x12-aarch64-neonfma-cortex-a53.S +++ b/src/f32-gemm/gen-inc/4x12-aarch64-neonfma-cortex-a53.S @@ -114,7 +114,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53 # Is there at least 4 floats (16 bytes)? SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f SUBS x0, x0, 16 @@ -386,6 +386,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53 FMLA v27.4s, v18.4s, v3.s[1] FMLA v30.4s, v18.4s, v3.s[3] FMLA v22.4s, v19.4s, v2.s[1] + TST x0, 15 # BLOCK 7 FMLA v25.4s, v19.4s, v2.s[3] @@ -393,11 +394,8 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53 ADD x5, x5, 96 FMLA v31.4s, v19.4s, v3.s[3] -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 5f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 6f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f 4: # Clamp @@ -448,6 +446,9 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53 RET 5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder - 2 floats of A (8 bytes) # Read first block of 4 A. 
LDR d0, [x3], 8 // a0 diff --git a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-cortex-a53.S index 9b447aa95..b0ef94fa2 100644 --- a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-cortex-a53.S +++ b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-cortex-a53.S @@ -117,7 +117,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a53 # Is there at least 4 floats (16 bytes) for prologue + epilogue? SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f # Prologue - First group loads, no FMA LDR d0, [x3], 8 // a0 @@ -321,6 +321,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a53 FMLA v22.4s, v14.4s, v3.s[3] FMLA v24.4s, v14.4s, v4.s[1] FMLA v26.4s, v14.4s, v4.s[3] + TST x0, 15 // BLOCK 4 FMLA v21.4s, v15.4s, v3.s[1] @@ -331,11 +332,9 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a53 // BLOCK 5 FMLA v27.4s, v15.4s, v4.s[3] -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 6f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 7f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f + 4: # Clamp FMIN v20.4s, v20.4s, v6.4s @@ -375,8 +374,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a53 LDP d12, d13, [sp], 32 RET +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) -6: LDR d0, [x3], 8 LDR q16, [x5], 16 LD1 {v0.d}[1], [x9], 8 @@ -406,7 +408,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a53 # Is there a remainder?- 1 floats of A (4 bytes) TBZ x0, 2, 4b -7: +6: # Remainder- 1 floats of A (4 bytes) LDR s0, [x3], 4 LDR q16, [x5], 16 diff --git a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld128.S b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld128.S index 198884e76..87d7ddbc0 100644 --- a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld128.S +++ b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld128.S @@ -75,7 +75,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128 # Is there at least 4 floats (16 bytes)? 
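The clamp block these kernels branch to (and which several hunks below relocate ahead of the remainder code) applies the fused activation bounds kept in a vector-register pair (v4/v5 or v6/v7): FMIN against the max vector, then FMAX against the min vector, on every accumulator before the store. Per element it is simply:

    #include <math.h>

    // Output clamping to the fused activation range [vmin, vmax]:
    // FMIN against the upper bound, then FMAX against the lower bound.
    static inline float clamp_output(float acc, float vmin, float vmax) {
      acc = fminf(acc, vmax);  // FMIN vAcc.4s, vAcc.4s, vMax.4s
      acc = fmaxf(acc, vmin);  // FMAX vAcc.4s, vAcc.4s, vMin.4s
      return acc;
    }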
SUBS x0, x2, 16 // k = kc - 16 - B.LO 2f + B.LO 5f # Main loop - 4 floats of A (16 bytes) 1: @@ -122,10 +122,50 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128 FMLA v31.4s, v27.4s, v3.s[3] B.HS 1b + TST x0, 15 + B.NE 5f + +4: + # Clamp + FMIN v16.4s, v16.4s, v4.4s + SUBS x1, x1, 8 + FMIN v17.4s, v17.4s, v4.4s + FMIN v18.4s, v18.4s, v4.4s + FMIN v19.4s, v19.4s, v4.4s + FMIN v28.4s, v28.4s, v4.4s + FMIN v29.4s, v29.4s, v4.4s + FMIN v30.4s, v30.4s, v4.4s + FMIN v31.4s, v31.4s, v4.4s + FMAX v16.4s, v16.4s, v5.4s + FMAX v17.4s, v17.4s, v5.4s + FMAX v18.4s, v18.4s, v5.4s + FMAX v19.4s, v19.4s, v5.4s + FMAX v28.4s, v28.4s, v5.4s + FMAX v29.4s, v29.4s, v5.4s + FMAX v30.4s, v30.4s, v5.4s + FMAX v31.4s, v31.4s, v5.4s + + # Store full 4 x 8 + B.LO 7f + + ST1 {v30.16b, v31.16b}, [x7], x14 + SUB x3, x3, x2 // a0 -= kc + ST1 {v28.16b, v29.16b}, [x10], x14 + SUB x11, x11, x2 // a1 -= kc + ST1 {v18.16b, v19.16b}, [x9], x14 + SUB x12, x12, x2 // a2 -= kc + ST1 {v16.16b, v17.16b}, [x6], x14 + SUB x4, x4, x2 // a3 -= kc + + B.HI 0b + RET + # Remainder- 2 floats of A (8 bytes) -2: - TBZ x0, 3, 3f +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) LDR d0, [x3], 8 LDP q20, q21, [x5], 32 LDR d1, [x11], 8 @@ -149,10 +189,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128 FMLA v30.4s, v22.4s, v3.s[1] FMLA v31.4s, v23.4s, v3.s[1] - # Remainder- 1 float of A (4 bytes) -3: - TBZ x0, 2, 6f + # Is there a remainder?- 1 floats of A (4 bytes) + TBZ x0, 2, 4b + # Remainder- 1 float of A (4 bytes) +6: LDR s0, [x3], 4 LDP q20, q21, [x5], 32 LDR s1, [x11], 4 @@ -166,42 +207,8 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128 FMLA v29.4s, v21.4s, v2.s[0] FMLA v30.4s, v20.4s, v3.s[0] FMLA v31.4s, v21.4s, v3.s[0] + B 4b -6: - # Clamp - FMIN v16.4s, v16.4s, v4.4s - SUBS x1, x1, 8 - FMIN v17.4s, v17.4s, v4.4s - FMIN v18.4s, v18.4s, v4.4s - FMIN v19.4s, v19.4s, v4.4s - FMIN v28.4s, v28.4s, v4.4s - FMIN v29.4s, v29.4s, v4.4s - FMIN v30.4s, v30.4s, v4.4s - FMIN v31.4s, v31.4s, v4.4s - FMAX v16.4s, v16.4s, v5.4s - FMAX v17.4s, v17.4s, v5.4s - FMAX v18.4s, v18.4s, v5.4s - FMAX v19.4s, v19.4s, v5.4s - FMAX v28.4s, v28.4s, v5.4s - FMAX v29.4s, v29.4s, v5.4s - FMAX v30.4s, v30.4s, v5.4s - FMAX v31.4s, v31.4s, v5.4s - - # Store full 4 x 8 - B.LO 7f - - ST1 {v30.16b, v31.16b}, [x7], x14 - SUB x3, x3, x2 // a0 -= kc - ST1 {v28.16b, v29.16b}, [x10], x14 - SUB x11, x11, x2 // a1 -= kc - ST1 {v18.16b, v19.16b}, [x9], x14 - SUB x12, x12, x2 // a2 -= kc - ST1 {v16.16b, v17.16b}, [x6], x14 - SUB x4, x4, x2 // a3 -= kc - - B.HI 0b - - RET # Store odd width 7: diff --git a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld64.S b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld64.S index 0f7a645c7..6d146e1a7 100644 --- a/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld64.S +++ b/src/f32-gemm/gen-inc/4x8-aarch64-neonfma-ld64.S @@ -75,10 +75,9 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld64 # Is there at least 2 floats (8 bytes)? 
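The store block moved up in this ld128 hunk also shows the outer-loop bookkeeping: SUBS x1, x1, 8 decrements the remaining column count while the clamps execute, a full store advances each C row by cn_stride (x14) and rewinds every A pointer by kc (SUB x3, x3, x2 and friends), and B.HI 0b restarts the kernel for the next 8 columns, with B.LO 7f peeling off a final partial tile. The shape of that loop, with illustrative C names and kc counted in floats rather than the assembly's bytes:

    #include <stddef.h>

    // Illustrative outer column loop of a 4-row GEMM microkernel; kc is in
    // floats here (x2 holds bytes in the assembly).
    static void gemm_outer(size_t nc, size_t kc,
                           const float* a0, const float* a1,
                           const float* a2, const float* a3) {
      do {
        /* compute and clamp the 4x8 tile; a0..a3 advance by kc inside */
        if (nc >= 8) {
          /* store 8 floats per C row, stepping each row by cn_stride */
          a0 -= kc; a1 -= kc; a2 -= kc; a3 -= kc;  // rewind A (SUB x3, x3, x2)
          nc -= 8;                                 // SUBS x1, x1, 8
        } else {
          /* partial store (see store_odd_width above), then done */
          nc = 0;
        }
      } while (nc != 0);                           // B.HI 0b
    }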
SUBS x0, x2, 8 // k = kc - 8 - B.LO 2f + B.LO 5f # Main loop - 2 floats of A (8 bytes) - 1: LDR d0, [x3], 8 LDP q20, q21, [x5], 32 @@ -104,25 +103,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld64 FMLA v30.4s, v22.4s, v3.s[1] FMLA v31.4s, v23.4s, v3.s[1] B.HS 1b -2: - # Remainder- 1 floats of A (4 bytes) - TBZ x0, 2, 6f - LDR s0, [x3], 4 - LDP q20, q21, [x5], 32 - LDR s1, [x11], 4 - LDR s2, [x12], 4 - LDR s3 , [x4], 4 - FMLA v16.4s, v20.4s, v0.s[0] - FMLA v17.4s, v21.4s, v0.s[0] - FMLA v18.4s, v20.4s, v1.s[0] - FMLA v19.4s, v21.4s, v1.s[0] - FMLA v28.4s, v20.4s, v2.s[0] - FMLA v29.4s, v21.4s, v2.s[0] - FMLA v30.4s, v20.4s, v3.s[0] - FMLA v31.4s, v21.4s, v3.s[0] + # Is there a remainder?- 1 floats of A (4 bytes) + TBNZ x0, 2, 5f -6: +4: # Clamp FMIN v16.4s, v16.4s, v4.4s SUBS x1, x1, 8 @@ -158,6 +143,23 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld64 RET + # Remainder- 1 float of A (4 bytes) +5: + LDR s0, [x3], 4 + LDP q20, q21, [x5], 32 + LDR s1, [x11], 4 + LDR s2, [x12], 4 + LDR s3 , [x4], 4 + FMLA v16.4s, v20.4s, v0.s[0] + FMLA v17.4s, v21.4s, v0.s[0] + FMLA v18.4s, v20.4s, v1.s[0] + FMLA v19.4s, v21.4s, v1.s[0] + FMLA v28.4s, v20.4s, v2.s[0] + FMLA v29.4s, v21.4s, v2.s[0] + FMLA v30.4s, v20.4s, v3.s[0] + FMLA v31.4s, v21.4s, v3.s[0] + B 4b + # Store odd width 7: TBZ x1, 2, 8f diff --git a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-cortex-a53.S index 736ac6744..0b5cb6e8e 100644 --- a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-cortex-a53.S +++ b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-cortex-a53.S @@ -134,7 +134,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a53 # Is there at least 4 floats (16 bytes) for prologue + epilogue? SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f # Prologue - First group loads, no FMA LDR d0, [x3], 8 // a0 @@ -397,6 +397,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a53 FMLA v21.4s, v15.4s, v3.s[1] FMLA v23.4s, v15.4s, v3.s[3] FMLA v25.4s, v15.4s, v4.s[1] + TST x0, 15 // BLOCK 7 FMLA v27.4s, v15.4s, v4.s[3] @@ -404,11 +405,8 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a53 FMLA v31.4s, v15.4s, v5.s[3] ADD x5, x5, 64 -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 6f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 7f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f 4: # Clamp FMIN v20.4s, v20.4s, v6.4s @@ -460,8 +458,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a53 LDP d12, d13, [sp], 32 RET +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) -6: LDR d0, [x3], 8 LDR q16, [x5], 16 LD1 {v0.d}[1], [x9], 8 @@ -501,8 +502,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a53 # Is there a remainder?- 1 floats of A (4 bytes) TBZ x0, 2, 4b - -7: +6: # Remainder- 1 floats of A (4 bytes) LDR s0, [x3], 4 LDR q16, [x5], 16 diff --git a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld128.S b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld128.S index 559a71a35..2fdf4f153 100644 --- a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld128.S +++ b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld128.S @@ -122,7 +122,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128 # Is there at least 4 floats (16 bytes)? 
SUBS x0, x2, 16 // k = kc - 16 - B.LO 2f + B.LO 5f # Main loop - 4 floats of A (16 bytes) # 48 FMA + 6 ld128 A + 4 LDP B @@ -191,12 +191,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128 FMLA v31.4s, v19.4s, v5.s[3] B.HS 1b -2: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 4f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 5f -3: + # Is there a remainder?- 2 floats of A (8 bytes) or less + TST x0, 15 + B.NE 5f + +4: # Clamp FMIN v20.4s, v20.4s, v6.4s SUBS x1, x1, 8 @@ -225,7 +224,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128 FMAX v31.4s, v31.4s, v7.4s # Store full 6 x 8 - B.LO 6f + B.LO 7f ST1 {v30.16b, v31.16b}, [x7], x14 SUB x3, x3, x2 // a0 -= kc @@ -241,10 +240,12 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128 SUB x4, x4, x2 // a5 -= kc B.HI 0b - RET -4: +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) LDR d0, [x3], 8 LDP q16, q17, [x5], 32 @@ -279,10 +280,12 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128 FMLA v27.4s, v19.4s, v3.s[1] FMLA v29.4s, v19.4s, v4.s[1] FMLA v31.4s, v19.4s, v5.s[1] - TBZ x0, 2, 3b -5: - # Remainder- 1 floats of A (4 bytes) + # Is there a remainder?- 1 floats of A (4 bytes) + TBZ x0, 2, 4b + + # Remainder- 1 float of A (4 bytes) +6: LDR s0, [x3], 4 LDP q16, q17, [x5], 32 LDR s1, [x9], 4 @@ -302,11 +305,11 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128 FMLA v27.4s, v17.4s, v3.s[0] FMLA v29.4s, v17.4s, v4.s[0] FMLA v31.4s, v17.4s, v5.s[0] - B 3b + B 4b # Store odd width -6: - TBZ x1, 2, 7f +7: + TBZ x1, 2, 8f STR q30, [x7], 16 MOV v30.16b, v31.16b STR q28, [x13], 16 @@ -320,8 +323,8 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128 STR q20, [x6], 16 MOV v20.16b, v21.16b -7: - TBZ x1, 1, 8f +8: + TBZ x1, 1, 9f STR d30, [x7], 8 DUP d30, v30.d[1] STR d28, [x13], 8 @@ -335,15 +338,15 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128 STR d20, [x6], 8 DUP d20, v20.d[1] -8: - TBZ x1, 0, 9f +9: + TBZ x1, 0, 10f STR s30, [x7] STR s28, [x13] STR s26, [x18] STR s24, [x17] STR s22, [x16] STR s20, [x6] -9: +10: RET END_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128 diff --git a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld64.S b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld64.S index b66f5c70a..eaa68f7d7 100644 --- a/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld64.S +++ b/src/f32-gemm/gen-inc/6x8-aarch64-neonfma-ld64.S @@ -122,7 +122,7 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld64 # Is there at least 2 floats (8 bytes) for main loop? SUBS x0, x2, 8 // k = kc - 8 - B.LO 2f + B.LO 4f # Main loop - 2 floats of A (8 bytes) # 24 FMA + 6 LD64 A + 2 LDP B @@ -163,7 +163,6 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld64 FMLA v31.4s, v19.4s, v5.s[1] B.HS 1b -2: # Is there a remainder?- 1 floats of A (4 bytes) TBNZ x0, 2, 4f 3: @@ -211,7 +210,6 @@ BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld64 SUB x4, x4, x2 // a5 -= kc B.HI 0b - RET 4: diff --git a/src/f32-gemm/gen/1x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen/1x12-aarch64-neonfma-cortex-a53.S index c6eb6f37e..f1730103d 100644 --- a/src/f32-gemm/gen/1x12-aarch64-neonfma-cortex-a53.S +++ b/src/f32-gemm/gen/1x12-aarch64-neonfma-cortex-a53.S @@ -57,7 +57,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53 # Is there at least 4 floats (16 bytes) for prologue + epilogue? 
SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f # Prologue - loads for first group of 6 fma @@ -251,6 +251,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53 # BLOCK 4 INS v19.d[1], x9 FMLA v20.4s, v17.4s, v1.s[1] + TST x0, 15 # BLOCK 5 FMLA v21.4s, v18.4s, v1.s[1] @@ -259,11 +260,8 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53 FMLA v22.4s, v19.4s, v1.s[1] # BLOCK 7 -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 5f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 6f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f 4: # Clamp @@ -280,12 +278,13 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53 ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14 SUB x3, x3, x2 // a0 -= kc - B.HI 0b - RET 5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder - 2 floats of A (8 bytes) # Read first block of 1 A. LDR d0, [x3], 8 // a0 diff --git a/src/f32-gemm/gen/4x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen/4x12-aarch64-neonfma-cortex-a53.S index 15669a9a7..257ab4bb0 100644 --- a/src/f32-gemm/gen/4x12-aarch64-neonfma-cortex-a53.S +++ b/src/f32-gemm/gen/4x12-aarch64-neonfma-cortex-a53.S @@ -117,7 +117,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53 # Is there at least 4 floats (16 bytes)? SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f SUBS x0, x0, 16 @@ -389,6 +389,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53 FMLA v27.4s, v18.4s, v3.s[1] FMLA v30.4s, v18.4s, v3.s[3] FMLA v22.4s, v19.4s, v2.s[1] + TST x0, 15 # BLOCK 7 FMLA v25.4s, v19.4s, v2.s[3] @@ -396,11 +397,8 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53 ADD x5, x5, 96 FMLA v31.4s, v19.4s, v3.s[3] -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 5f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 6f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f 4: # Clamp @@ -451,6 +449,9 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53 RET 5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder - 2 floats of A (8 bytes) # Read first block of 4 A. LDR d0, [x3], 8 // a0 diff --git a/src/f32-gemm/gen/4x8-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen/4x8-aarch64-neonfma-cortex-a53.S index a0674a2c0..6b689e902 100644 --- a/src/f32-gemm/gen/4x8-aarch64-neonfma-cortex-a53.S +++ b/src/f32-gemm/gen/4x8-aarch64-neonfma-cortex-a53.S @@ -119,7 +119,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53 # Is there at least 4 floats (16 bytes) for prologue + epilogue? 
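The SUBS x0, x2, 16 / B.LO opener that all of these hunks retarget biases the byte count down by one loop-worth so the main loop can spin on B.HS (no borrow) and, on exit, the low bits of the now-negative counter still equal kc mod 16, which is exactly what the new TST x0, 15 inspects: subtracting multiples of 16 never disturbs the low 4 bits. In C terms, assuming two's complement:

    #include <stddef.h>

    // Down-biased loop counter: the remainder survives in k's low bits even
    // after k goes negative.
    static void main_loop(size_t kc /* bytes of A per row */) {
      ptrdiff_t k = (ptrdiff_t) kc - 16;  // SUBS x0, x2, 16 ; B.LO skips the loop
      while (k >= 0) {                    // B.HS 1b
        /* process 4 floats (16 bytes) of each A row */
        k -= 16;                          // SUBS x0, x0, 16
      }
      if ((k & 15) != 0) {                // TST x0, 15 ; B.NE -> remainder paths
        /* handle the 1..3 leftover floats */
      }
    }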
SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f # Prologue - First group loads, no FMA LDR d0, [x3], 8 // a0 @@ -323,6 +323,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53 FMLA v22.4s, v14.4s, v3.s[3] FMLA v24.4s, v14.4s, v4.s[1] FMLA v26.4s, v14.4s, v4.s[3] + TST x0, 15 // BLOCK 4 FMLA v21.4s, v15.4s, v3.s[1] @@ -333,11 +334,9 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53 // BLOCK 5 FMLA v27.4s, v15.4s, v4.s[3] -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 6f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 7f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f + 4: # Clamp FMIN v20.4s, v20.4s, v6.4s @@ -377,8 +376,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53 LDP d12, d13, [sp], 32 RET +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) -6: LDR d0, [x3], 8 LDR q16, [x5], 16 LD1 {v0.d}[1], [x9], 8 @@ -408,7 +410,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53 # Is there a remainder?- 1 floats of A (4 bytes) TBZ x0, 2, 4b -7: +6: # Remainder- 1 floats of A (4 bytes) LDR s0, [x3], 4 LDR q16, [x5], 16 diff --git a/src/f32-gemm/gen/4x8-aarch64-neonfma-ld128.S b/src/f32-gemm/gen/4x8-aarch64-neonfma-ld128.S index c2f496745..aa7e7a7e6 100644 --- a/src/f32-gemm/gen/4x8-aarch64-neonfma-ld128.S +++ b/src/f32-gemm/gen/4x8-aarch64-neonfma-ld128.S @@ -75,7 +75,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128 # Is there at least 4 floats (16 bytes)? SUBS x0, x2, 16 // k = kc - 16 - B.LO 2f + B.LO 5f # Main loop - 4 floats of A (16 bytes) 1: @@ -122,10 +122,50 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128 FMLA v31.4s, v27.4s, v3.s[3] B.HS 1b + TST x0, 15 + B.NE 5f + +4: + # Clamp + FMIN v16.4s, v16.4s, v4.4s + SUBS x1, x1, 8 + FMIN v17.4s, v17.4s, v4.4s + FMIN v18.4s, v18.4s, v4.4s + FMIN v19.4s, v19.4s, v4.4s + FMIN v28.4s, v28.4s, v4.4s + FMIN v29.4s, v29.4s, v4.4s + FMIN v30.4s, v30.4s, v4.4s + FMIN v31.4s, v31.4s, v4.4s + FMAX v16.4s, v16.4s, v5.4s + FMAX v17.4s, v17.4s, v5.4s + FMAX v18.4s, v18.4s, v5.4s + FMAX v19.4s, v19.4s, v5.4s + FMAX v28.4s, v28.4s, v5.4s + FMAX v29.4s, v29.4s, v5.4s + FMAX v30.4s, v30.4s, v5.4s + FMAX v31.4s, v31.4s, v5.4s + + # Store full 4 x 8 + B.LO 7f + + ST1 {v16.16b, v17.16b}, [x6], x14 + SUB x3, x3, x2 // a0 -= kc + ST1 {v18.16b, v19.16b}, [x9], x14 + SUB x11, x11, x2 // a1 -= kc + ST1 {v28.16b, v29.16b}, [x10], x14 + SUB x12, x12, x2 // a2 -= kc + ST1 {v30.16b, v31.16b}, [x7], x14 + SUB x4, x4, x2 // a3 -= kc + + B.HI 0b + RET + # Remainder- 2 floats of A (8 bytes) -2: - TBZ x0, 3, 3f +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) LDR d0, [x3], 8 LDP q20, q21, [x5], 32 LDR d1, [x11], 8 @@ -149,10 +189,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128 FMLA v30.4s, v22.4s, v3.s[1] FMLA v31.4s, v23.4s, v3.s[1] - # Remainder- 1 float of A (4 bytes) -3: - TBZ x0, 2, 6f + # Is there a remainder?- 1 floats of A (4 bytes) + TBZ x0, 2, 4b + # Remainder- 1 float of A (4 bytes) +6: LDR s0, [x3], 4 LDP q20, q21, [x5], 32 LDR s1, [x11], 4 @@ -166,42 +207,8 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128 FMLA v29.4s, v21.4s, v2.s[0] FMLA v30.4s, v20.4s, v3.s[0] FMLA v31.4s, v21.4s, v3.s[0] + B 4b -6: - # Clamp - FMIN v16.4s, v16.4s, v4.4s - SUBS x1, x1, 8 - FMIN v17.4s, v17.4s, v4.4s - FMIN v18.4s, v18.4s, v4.4s - FMIN v19.4s, 
v19.4s, v4.4s - FMIN v28.4s, v28.4s, v4.4s - FMIN v29.4s, v29.4s, v4.4s - FMIN v30.4s, v30.4s, v4.4s - FMIN v31.4s, v31.4s, v4.4s - FMAX v16.4s, v16.4s, v5.4s - FMAX v17.4s, v17.4s, v5.4s - FMAX v18.4s, v18.4s, v5.4s - FMAX v19.4s, v19.4s, v5.4s - FMAX v28.4s, v28.4s, v5.4s - FMAX v29.4s, v29.4s, v5.4s - FMAX v30.4s, v30.4s, v5.4s - FMAX v31.4s, v31.4s, v5.4s - - # Store full 4 x 8 - B.LO 7f - - ST1 {v16.16b, v17.16b}, [x6], x14 - SUB x3, x3, x2 // a0 -= kc - ST1 {v18.16b, v19.16b}, [x9], x14 - SUB x11, x11, x2 // a1 -= kc - ST1 {v28.16b, v29.16b}, [x10], x14 - SUB x12, x12, x2 // a2 -= kc - ST1 {v30.16b, v31.16b}, [x7], x14 - SUB x4, x4, x2 // a3 -= kc - - B.HI 0b - - RET # Store odd width 7: diff --git a/src/f32-gemm/gen/4x8-aarch64-neonfma-ld64.S b/src/f32-gemm/gen/4x8-aarch64-neonfma-ld64.S index 1770a0290..548450f64 100644 --- a/src/f32-gemm/gen/4x8-aarch64-neonfma-ld64.S +++ b/src/f32-gemm/gen/4x8-aarch64-neonfma-ld64.S @@ -75,10 +75,9 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64 # Is there at least 2 floats (8 bytes)? SUBS x0, x2, 8 // k = kc - 8 - B.LO 2f + B.LO 5f # Main loop - 2 floats of A (8 bytes) - 1: LDR d0, [x3], 8 LDP q20, q21, [x5], 32 @@ -104,25 +103,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64 FMLA v30.4s, v22.4s, v3.s[1] FMLA v31.4s, v23.4s, v3.s[1] B.HS 1b -2: - # Remainder- 1 floats of A (4 bytes) - TBZ x0, 2, 6f - LDR s0, [x3], 4 - LDP q20, q21, [x5], 32 - LDR s1, [x11], 4 - LDR s2, [x12], 4 - LDR s3 , [x4], 4 - FMLA v16.4s, v20.4s, v0.s[0] - FMLA v17.4s, v21.4s, v0.s[0] - FMLA v18.4s, v20.4s, v1.s[0] - FMLA v19.4s, v21.4s, v1.s[0] - FMLA v28.4s, v20.4s, v2.s[0] - FMLA v29.4s, v21.4s, v2.s[0] - FMLA v30.4s, v20.4s, v3.s[0] - FMLA v31.4s, v21.4s, v3.s[0] + # Is there a remainder?- 1 floats of A (4 bytes) + TBNZ x0, 2, 5f -6: +4: # Clamp FMIN v16.4s, v16.4s, v4.4s SUBS x1, x1, 8 @@ -158,6 +143,23 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64 RET + # Remainder- 1 float of A (4 bytes) +5: + LDR s0, [x3], 4 + LDP q20, q21, [x5], 32 + LDR s1, [x11], 4 + LDR s2, [x12], 4 + LDR s3 , [x4], 4 + FMLA v16.4s, v20.4s, v0.s[0] + FMLA v17.4s, v21.4s, v0.s[0] + FMLA v18.4s, v20.4s, v1.s[0] + FMLA v19.4s, v21.4s, v1.s[0] + FMLA v28.4s, v20.4s, v2.s[0] + FMLA v29.4s, v21.4s, v2.s[0] + FMLA v30.4s, v20.4s, v3.s[0] + FMLA v31.4s, v21.4s, v3.s[0] + B 4b + # Store odd width 7: TBZ x1, 2, 8f diff --git a/src/f32-gemm/gen/6x8-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/gen/6x8-aarch64-neonfma-cortex-a53.S index 3851823a6..1e80a6010 100644 --- a/src/f32-gemm/gen/6x8-aarch64-neonfma-cortex-a53.S +++ b/src/f32-gemm/gen/6x8-aarch64-neonfma-cortex-a53.S @@ -138,7 +138,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53 # Is there at least 4 floats (16 bytes) for prologue + epilogue? 
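The cortex-a53 variants touched here are software-pipelined: a prologue issues the first group of loads with no FMAs, each numbered BLOCK of the main loop mixes this group's FMLAs with the LDR/INS traffic for the next group (which is also why the new TST x0, 15 can hide in a BLOCK's spare issue slot), and an epilogue drains the final group, hence the up-front demand for at least 4 floats of kc. A generic skeleton of that structure, not the kernels' actual schedule; load() and compute() are illustrative stand-ins:

    // Software-pipelined loop skeleton: iteration i's compute overlaps
    // iteration i+1's loads.
    static void load(int i)    { /* issue the loads for group i */ }
    static void compute(int i) { /* run the FMAs for group i */ }

    static void pipelined(int n) {  // requires n >= 1 (kc covers one full group)
      load(0);                      // prologue: loads only
      for (int i = 0; i + 1 < n; i++) {
        load(i + 1);                // next group's data in flight...
        compute(i);                 // ...while this group's FMAs execute
      }
      compute(n - 1);               // epilogue: drain the last group
    }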
SUBS x0, x2, 16 // k = kc - 16 - B.LO 3f + B.LO 5f # Prologue - First group loads, no FMA LDR d0, [x3], 8 // a0 @@ -401,6 +401,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53 FMLA v21.4s, v15.4s, v3.s[1] FMLA v23.4s, v15.4s, v3.s[3] FMLA v25.4s, v15.4s, v4.s[1] + TST x0, 15 // BLOCK 7 FMLA v27.4s, v15.4s, v4.s[3] @@ -408,11 +409,8 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53 FMLA v31.4s, v15.4s, v5.s[3] ADD x5, x5, 64 -3: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 6f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 7f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f 4: # Clamp FMIN v20.4s, v20.4s, v6.4s @@ -464,8 +462,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53 LDP d12, d13, [sp], 32 RET +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) -6: LDR d0, [x3], 8 LDR q16, [x5], 16 LD1 {v0.d}[1], [x9], 8 @@ -505,8 +506,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53 # Is there a remainder?- 1 floats of A (4 bytes) TBZ x0, 2, 4b - -7: +6: # Remainder- 1 floats of A (4 bytes) LDR s0, [x3], 4 LDR q16, [x5], 16 diff --git a/src/f32-gemm/gen/6x8-aarch64-neonfma-ld128.S b/src/f32-gemm/gen/6x8-aarch64-neonfma-ld128.S index 48ac7af57..7909ba23f 100644 --- a/src/f32-gemm/gen/6x8-aarch64-neonfma-ld128.S +++ b/src/f32-gemm/gen/6x8-aarch64-neonfma-ld128.S @@ -126,7 +126,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128 # Is there at least 4 floats (16 bytes)? SUBS x0, x2, 16 // k = kc - 16 - B.LO 2f + B.LO 5f # Main loop - 4 floats of A (16 bytes) # 48 FMA + 6 ld128 A + 4 LDP B @@ -195,12 +195,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128 FMLA v31.4s, v19.4s, v5.s[3] B.HS 1b -2: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 4f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 5f -3: + # Is there a remainder?- 2 floats of A (8 bytes) or less + TST x0, 15 + B.NE 5f + +4: # Clamp FMIN v20.4s, v20.4s, v6.4s SUBS x1, x1, 8 @@ -229,7 +228,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128 FMAX v31.4s, v31.4s, v7.4s # Store full 6 x 8 - B.LO 6f + B.LO 7f ST1 {v20.16b, v21.16b}, [x6], x14 SUB x3, x3, x2 // a0 -= kc @@ -245,10 +244,12 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128 SUB x4, x4, x2 // a5 -= kc B.HI 0b - RET -4: +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + # Remainder- 2 floats of A (8 bytes) LDR d0, [x3], 8 LDP q16, q17, [x5], 32 @@ -283,10 +284,12 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128 FMLA v27.4s, v19.4s, v3.s[1] FMLA v29.4s, v19.4s, v4.s[1] FMLA v31.4s, v19.4s, v5.s[1] - TBZ x0, 2, 3b -5: - # Remainder- 1 floats of A (4 bytes) + # Is there a remainder?- 1 floats of A (4 bytes) + TBZ x0, 2, 4b + + # Remainder- 1 float of A (4 bytes) +6: LDR s0, [x3], 4 LDP q16, q17, [x5], 32 LDR s1, [x9], 4 @@ -306,11 +309,11 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128 FMLA v27.4s, v17.4s, v3.s[0] FMLA v29.4s, v17.4s, v4.s[0] FMLA v31.4s, v17.4s, v5.s[0] - B 3b + B 4b # Store odd width -6: - TBZ x1, 2, 7f +7: + TBZ x1, 2, 8f STR q20, [x6], 16 MOV v20.16b, v21.16b STR q22, [x16], 16 @@ -324,8 +327,8 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128 STR q30, [x7], 16 MOV v30.16b, v31.16b -7: - TBZ x1, 1, 8f +8: + TBZ x1, 1, 9f STR d20, [x6], 8 DUP d20, v20.d[1] STR d22, [x16], 8 @@ -339,15 +342,15 @@ 
BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128 STR d30, [x7], 8 DUP d30, v30.d[1] -8: - TBZ x1, 0, 9f +9: + TBZ x1, 0, 10f STR s20, [x6] STR s22, [x16] STR s24, [x17] STR s26, [x18] STR s28, [x13] STR s30, [x7] -9: +10: RET END_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128 diff --git a/src/f32-gemm/gen/6x8-aarch64-neonfma-ld64.S b/src/f32-gemm/gen/6x8-aarch64-neonfma-ld64.S index 4f869950b..d9460005f 100644 --- a/src/f32-gemm/gen/6x8-aarch64-neonfma-ld64.S +++ b/src/f32-gemm/gen/6x8-aarch64-neonfma-ld64.S @@ -126,7 +126,7 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64 # Is there at least 2 floats (8 bytes) for main loop? SUBS x0, x2, 8 // k = kc - 8 - B.LO 2f + B.LO 4f # Main loop - 2 floats of A (8 bytes) # 24 FMA + 6 LD64 A + 2 LDP B @@ -167,7 +167,6 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64 FMLA v31.4s, v19.4s, v5.s[1] B.HS 1b -2: # Is there a remainder?- 1 floats of A (4 bytes) TBNZ x0, 2, 4f 3: @@ -215,7 +214,6 @@ BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64 SUB x4, x4, x2 // a5 -= kc B.HI 0b - RET 4: diff --git a/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S b/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S index 1f3edf73c..5f24466fc 100644 --- a/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S +++ b/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S @@ -71,7 +71,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53 # Is there at least 4 floats (16 bytes) for prologue + epilogue? SUBS x0, x2, 16 // k = kc - 16 - B.LO 4f + B.LO 5f # Prologue - loads for first group of 6 fma @@ -267,6 +267,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53 # BLOCK 4 INS v19.d[1], x7 FMLA v20.4s, v17.4s, v1.s[1] + TST x0, 15 # BLOCK 5 FMLA v21.4s, v18.4s, v1.s[1] @@ -275,14 +276,10 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53 FMLA v22.4s, v19.4s, v1.s[1] # BLOCK 7 + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f 4: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 6f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 7f - -5: # ks loop SUBS x9, x9, 8 // ks -= MR * sizeof(void*) B.NE 1b @@ -304,12 +301,13 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53 # nc loop B.HI 0b - RET -6: - # Remainder - 2 floats of A (8 bytes) - # Read first block of 1 A. +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + + # Remainder- 2 floats of A (8 bytes) LDR d0, [x8], 8 // a0 LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48 LD1 {v5.16b, v6.16b, v7.16b}, [x5], 48 @@ -324,8 +322,8 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53 FMLA v21.4s, v6.4s, v0.s[1] FMLA v22.4s, v7.4s, v0.s[1] - TBZ x0, 2, 5b -7: + TBZ x0, 2, 4b +6: # Remainder - 1 float of A (4 bytes) LDR s0, [x8], 4 // a0 LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48 @@ -333,7 +331,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53 FMLA v20.4s, v2.4s, v0.s[0] FMLA v21.4s, v3.4s, v0.s[0] FMLA v22.4s, v4.4s, v0.s[0] - B 5b + B 4b 8: ADD x1, x1, 12 diff --git a/src/f32-igemm/1x8-aarch64-neonfma-cortex-a53.S b/src/f32-igemm/1x8-aarch64-neonfma-cortex-a53.S index f43e79e37..8d6e51d09 100644 --- a/src/f32-igemm/1x8-aarch64-neonfma-cortex-a53.S +++ b/src/f32-igemm/1x8-aarch64-neonfma-cortex-a53.S @@ -63,7 +63,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53 # Is there at least 8 floats (32 bytes) for prologue + epilogue? 
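 # (kc is counted in bytes: the software-pipelined prologue and epilogue
 # together consume 8 floats (32 bytes) of A, so any smaller kc falls
 # through to the remainder paths below.)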
SUBS x0, x2, 32 // k = kc - 32 // k = kc - B.LO 4f + B.LO 5f # 16 prologue # Read first block of A and B. @@ -148,18 +148,13 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53 FMLA v19.4s, v23.4s, v1.s[1] FMLA v16.4s, v24.4s, v1.s[2] FMLA v17.4s, v25.4s, v1.s[2] + TST x0, 31 FMLA v18.4s, v26.4s, v1.s[3] FMLA v19.4s, v27.4s, v1.s[3] + # Is there a remainder?- 4 floats of A (16 bytes) or less + B.NE 5f 4: - # Is there a remainder?- 4 floats of A (16 bytes) - TBNZ x0, 4, 6f - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 7f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 9f - -5: # ks loop SUBS x9, x9, 8 // ks -= MR * sizeof(void*) B.NE 1b @@ -185,7 +180,10 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53 RET -6: +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 4, 6f + # Remainder- 4 floats of A (16 bytes) LDR q20, [x5], 16 LDR q21, [x5], 16 @@ -205,8 +203,8 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53 FMLA v18.4s, v26.4s, v0.s[3] FMLA v19.4s, v27.4s, v0.s[3] - TBZ x0, 3, 8f -7: +6: + TBZ x0, 3, 7f # Remainder- 2 floats of A (8 bytes) LDR q20, [x5], 16 LDR q21, [x5], 16 @@ -217,16 +215,15 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53 LDR q23, [x5], 16 FMLA v18.4s, v22.4s, v0.s[1] FMLA v19.4s, v23.4s, v0.s[1] -8: - TBZ x0, 2, 5b -9: +7: + TBZ x0, 2, 4b # Remainder- 1 float of A (4 bytes) LDR q20, [x5], 16 LDR q21, [x5], 16 LDR s0, [x8], 4 FMLA v16.4s, v20.4s, v0.s[0] FMLA v17.4s, v21.4s, v0.s[0] - B 5b + B 4b 10: # Store odd channels diff --git a/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S b/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S index a25dce0af..e0675a8cf 100644 --- a/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S +++ b/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S @@ -136,7 +136,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53 PRFM PLDL1KEEP, [x15, 64] PRFM PLDL1KEEP, [x16, 0] PRFM PLDL1KEEP, [x16, 64] - B.LO 4f + B.LO 5f SUBS x0, x0, 16 // 4 floats for main loop @@ -408,19 +408,18 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53 FMLA v27.4s, v18.4s, v3.s[1] FMLA v30.4s, v18.4s, v3.s[3] FMLA v22.4s, v19.4s, v2.s[1] + TST x0, 15 # BLOCK 7 FMLA v25.4s, v19.4s, v2.s[3] FMLA v28.4s, v19.4s, v3.s[1] ADD x5, x5, 96 FMLA v31.4s, v19.4s, v3.s[3] -4: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 6f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 7f -5: + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f + +4: # ks loop SUBS x9, x9, 32 // ks -= MR * sizeof(void*) B.NE 1b @@ -470,9 +469,11 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53 LDP d8, d9, [sp], 48 RET -6: - # Remainder - 2 floats of A (8 bytes) - # Read first block of 4 A. 
+5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + + # Remainder- 2 floats of A (8 bytes) LDR d0, [x13], 8 // a0 LD1 {v6.16b, v7.16b, v8.16b}, [x5], 48 LDR d1, [x14], 8 // a1 @@ -508,9 +509,10 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53 FMLA v28.4s, v11.4s, v2.s[1] FMLA v31.4s, v11.4s, v3.s[1] - TBZ x0, 2, 5b -7: - # Remainder - 1 float of A (4 bytes) + # Is there a remainder?- 1 floats of A (4 bytes) + TBZ x0, 2, 4b +6: + # Remainder- 1 floats of A (4 bytes) LDR s0, [x13], 4 // a0 LD1 {v6.16b, v7.16b, v8.16b}, [x5], 48 LDR s1, [x14], 4 // a1 @@ -529,7 +531,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53 FMLA v25.4s, v8.4s, v1.s[0] FMLA v28.4s, v8.4s, v2.s[0] FMLA v31.4s, v8.4s, v3.s[0] - B 5b + B 4b 8: ADD x1, x1, 12 diff --git a/src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in b/src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in new file mode 100644 index 000000000..d862d608a --- /dev/null +++ b/src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in @@ -0,0 +1,393 @@ +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <xnnpack/assembly.h> + +.syntax unified + +// void xnn_f32_igemm_ukernel_4x8__aarch32_neon_${"pld_" if PREFETCH else ""}cortex_a75( +// size_t mr, r0 +// size_t nc, r1 +// size_t kc, r2 -> r5 -> sp + 64 +// size_t ks, r3 -> sp + 68 -> r14 +// const float**restrict a, sp + 108 -> r2 +// const void*restrict w, sp + 112 -> r9 +// uint8_t*restrict c, sp + 116 -> r11 +// size_t cm_stride, sp + 120 -> (r6) +// size_t cn_stride, sp + 124 -> (r7) +// size_t a_offset, sp + 128 -> (r5) +// const float* zero, sp + 132 -> (r7) +// output_params*params, sp + 136 -> (r5) + +// inner loop registers + +// A0 r3 d0 +// A1 r12 d1 +// A2 r10 d2 +// A3 r0 d3 + +// B r9 d8, d9, d10, d11 +// B d12, d13, d14, d15 + +// C0 r11 d16-d17 q8 d18-d19 q9 +// C1 r4 d20-d21 q10 d22-d23 q11 +// C2 r8 d24-d25 q12 d26-d27 q13 +// C3 r6 d28-d29 q14 d30-d31 q15 + +// Clamp (r5) d4 d5 d6 d7 + +BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_${"pld_" if PREFETCH else ""}cortex_a75 + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + // Push 108 bytes + // r2 will be reloaded in outer loop. 
r3 is ks + PUSH {r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r14} // +44 + VPUSH {d8-d15} // +64 = 108 + + LDR r2, [sp, 108] // a + LDR r9, [sp, 112] // w + LDR r11, [sp, 116] // c + LDR r6, [sp, 120] // cm_stride + LDR r14, [sp, 68] // p = ks + + // Clamp C pointers + CMP r0, 2 // if mr >= 2 + ADD r4, r11, r6 // c1 = c0 + cm_stride + MOVLO r4, r11 // c1 + // if mr > 2 + ADD r8, r4, r6 // c2 = c1 + cm_stride + MOVLS r8, r4 // c2 + CMP r0, 4 // if mr >=4 + ADD r6, r8, r6 // c3 = c2 + cm_stride + MOVLO r6, r8 // c3 + + .p2align 3 +0: + # Load initial bias from w into accumulators + VLDM r9!, {d16-d19} // Bias + VMOV q10, q8 + VMOV q11, q9 + VMOV q12, q8 + VMOV q13, q9 + VMOV q14, q8 + VMOV q15, q9 + + $if PREFETCH: + PLD [r9, 0] // Prefetch B + PLD [r9, 64] + PLD [r9, 128] + PLD [r9, 192] + PLD [r9, 256] + PLD [r9, 320] + +1: + # Load next 4 A pointers + LDR r3, [r2, 0] + LDR r12, [r2, 4] + LDR r10, [r2, 8] + LDR r0, [r2, 12] + ADD r2, r2, 16 + + // Add a_offset + LDR r5, [sp, 128] // a_offset + LDR r7, [sp, 132] // zero + CMP r3, r7 // if a0 == zero + ADD r3, r3, r5 // a0 += a_offset + MOVEQ r3, r7 // a0 = zero, else += a0 + a_offset + CMP r12, r7 // if a1 == zero + ADD r12, r12, r5 // a1 += a_offset + MOVEQ r12, r7 // a1 = zero, else += a1 + a_offset + CMP r10, r7 // if a2 == zero + ADD r10, r10, r5 // a2 += a_offset + MOVEQ r10, r7 // a2 = zero, else += a2 + a_offset + CMP r0, r7 // if a3 == zero + ADD r0, r0, r5 // a3 += a_offset + LDR r5, [sp, 64] // kc + MOVEQ r0, r7 // a3 = zero, else += a3 + a_offset + + $if PREFETCH: + PLD [r3, 0] // Prefetch A + PLD [r3, 64] + PLD [r12, 0] + PLD [r12, 64] + PLD [r10, 0] + PLD [r10, 64] + PLD [r0, 0] + PLD [r0, 64] + + SUBS r5, r5, 16 // kc - 16 + BLO 4f // less than 4 channels? + + // Prologue + VLD1.32 {d0}, [r3]! // A0 + VLDM r9!, {d8-d11} // B0 + VLD1.32 {d1}, [r12]! // A1 + VLD1.32 {d2}, [r10]! // A2 + VLD1.32 {d3}, [ r0]! // A3 + + SUBS r5, r5, 16 + BLO 3f // less than 4 channels? skip main loop + + .p2align 3 + + // Main loop - 4 floats of A (16 bytes) +2: + VMLA.F32 q8, q4, d0[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q12, q4, d2[0] + VLD1.32 {d4}, [r3]! // A0 + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q9, q5, d0[0] + VLD1.32 {d5}, [r12]! // A1 + VMLA.F32 q11, q5, d1[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q15, q5, d3[0] + VLD1.32 {d6}, [r10]! // A2 + VMLA.F32 q8, q6, d0[1] + VMLA.F32 q10, q6, d1[1] + VLD1.32 {d7}, [ r0]! // A3 + VMLA.F32 q12, q6, d2[1] + VMLA.F32 q14, q6, d3[1] + VLDM r9!, {d8-d11} // B0 + VMLA.F32 q9, q7, d0[1] + VMLA.F32 q11, q7, d1[1] + VMLA.F32 q13, q7, d2[1] + VMLA.F32 q15, q7, d3[1] + + VMLA.F32 q8, q4, d4[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d5[0] + $if PREFETCH: + PLD [r3, 128] // Prefetch A0 + VMLA.F32 q12, q4, d6[0] + VLD1.32 {d0}, [r3]! // A0 + VMLA.F32 q14, q4, d7[0] + $if PREFETCH: + PLD [r12, 128] // Prefetch A1 + VMLA.F32 q9, q5, d4[0] + VLD1.32 {d1}, [r12]! // A1 + VMLA.F32 q11, q5, d5[0] + $if PREFETCH: + PLD [r10, 128] // Prefetch A2 + VMLA.F32 q13, q5, d6[0] + VLD1.32 {d2}, [r10]! // A2 + VMLA.F32 q15, q5, d7[0] + $if PREFETCH: + PLD [r0, 128] // Prefetch A3 + VMLA.F32 q8, q6, d4[1] + VLD1.32 {d3}, [ r0]! 
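+    // (A registers ping-pong between d0-d3 and d4-d7, so the loads for
+    // the next 2 floats of each row overlap the FMAs for the current 2.)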
// A3 + VMLA.F32 q10, q6, d5[1] + $if PREFETCH: + PLD [r9, 384] // Prefetch B + VMLA.F32 q12, q6, d6[1] + $if PREFETCH: + PLD [r9, 448] // Prefetch B + VMLA.F32 q14, q6, d7[1] + VLDM r9!, {d8-d11} // B0 + VMLA.F32 q9, q7, d4[1] + VMLA.F32 q11, q7, d5[1] + SUBS r5, r5, 16 + VMLA.F32 q13, q7, d6[1] + VMLA.F32 q15, q7, d7[1] + BHS 2b + + // Epilogue +3: + VMLA.F32 q8, q4, d0[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q12, q4, d2[0] + VLD1.32 {d4}, [r3]! // A0 + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q9, q5, d0[0] + VLD1.32 {d5}, [r12]! // A1 + VMLA.F32 q11, q5, d1[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q15, q5, d3[0] + VLD1.32 {d6}, [r10]! // A2 + VMLA.F32 q8, q6, d0[1] + VMLA.F32 q10, q6, d1[1] + VLD1.32 {d7}, [ r0]! // A3 + VMLA.F32 q12, q6, d2[1] + VMLA.F32 q14, q6, d3[1] + VLDM r9!, {d8-d11} // B0 + VMLA.F32 q9, q7, d0[1] + VMLA.F32 q11, q7, d1[1] + VMLA.F32 q13, q7, d2[1] + VMLA.F32 q15, q7, d3[1] + + VMLA.F32 q8, q4, d4[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d5[0] + VMLA.F32 q12, q4, d6[0] + VMLA.F32 q14, q4, d7[0] + VMLA.F32 q9, q5, d4[0] + VMLA.F32 q11, q5, d5[0] + VMLA.F32 q13, q5, d6[0] + VMLA.F32 q15, q5, d7[0] + VMLA.F32 q8, q6, d4[1] + VMLA.F32 q10, q6, d5[1] + VMLA.F32 q12, q6, d6[1] + VMLA.F32 q14, q6, d7[1] + VMLA.F32 q9, q7, d4[1] + VMLA.F32 q11, q7, d5[1] + VMLA.F32 q13, q7, d6[1] + VMLA.F32 q15, q7, d7[1] + +4: + // Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes) + TST r5, 12 + BNE 7f + + .p2align 3 +5: + # ks loop + SUBS r14, r14, 16 // ks -= MR * sizeof(void*) + BNE 1b + + // Load params pointer + LDR r5, [sp, 136] // clamping_params + LDR r7, [sp, 124] // cn_stride + LDR r14, [sp, 68] // p = ks + + // Load clamping_params values + VLD1.32 {d4[],d5[]}, [r5]! + SUBS r1, r1, 8 + VLD1.32 {d6[],d7[]}, [r5] + + // Clamp + VMIN.F32 q8, q8, q2 + VMIN.F32 q9, q9, q2 + VMIN.F32 q10, q10, q2 + VMIN.F32 q11, q11, q2 + VMIN.F32 q12, q12, q2 + VMIN.F32 q13, q13, q2 + VMIN.F32 q14, q14, q2 + VMIN.F32 q15, q15, q2 + VMAX.F32 q8, q8, q3 + VMAX.F32 q9, q9, q3 + VMAX.F32 q10, q10, q3 + VMAX.F32 q11, q11, q3 + VMAX.F32 q12, q12, q3 + VMAX.F32 q13, q13, q3 + VMAX.F32 q14, q14, q3 + VMAX.F32 q15, q15, q3 + + // Store full 4 x 8 + BLO 10f + VST1.32 {d28-d31}, [r6], r7 + VST1.32 {d24-d27}, [r8], r7 + VST1.32 {d20-d23}, [r4], r7 + VST1.32 {d16-d19}, [r11], r7 + SUB r2, r2, r14 // a -= ks + BHI 0b + +6: + VPOP {d8-d15} + ADD sp, sp, 8 // skip r2, r3 + POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} + + .p2align 3 +7: + // Is there a remainder?- 2 floats of A (8 bytes) + TST r5, 8 + BEQ 8f + + // Remainder - 2 floats of A (8 bytes) + VLD1.32 {d0}, [r3]! // A0 + VLDM r9!, {d8-d11} // B0 + VLD1.32 {d1}, [r12]! // A1 + VLD1.32 {d2}, [r10]! // A2 + VLD1.32 {d3}, [ r0]! 
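+    // (2-float remainder: lane 0 of each A pair is multiplied against
+    // the B0 group, lane 1 against the B1 group loaded below.)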
// A3 + + VMLA.F32 q8, q4, d0[0] + VMLA.F32 q9, q5, d0[0] + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q11, q5, d1[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q12, q4, d2[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q15, q5, d3[0] + VMLA.F32 q8, q6, d0[1] + VMLA.F32 q9, q7, d0[1] + VMLA.F32 q10, q6, d1[1] + VMLA.F32 q11, q7, d1[1] + VMLA.F32 q12, q6, d2[1] + VMLA.F32 q13, q7, d2[1] + VMLA.F32 q14, q6, d3[1] + VMLA.F32 q15, q7, d3[1] +8: + // Is there a remainder?- 1 floats of A (4 bytes) + TST r5, 4 + BEQ 5b + +9: + // Remainder- 1 floats of A (4 bytes) + VLDM r3!, {s0} // A0 + VLDM r9!, {d8-d11} // B0 + VLDM r12!, {s2} // A1 + VLDM r10!, {s4} // A2 + VLDM r0!, {s6} // A3 + VMLA.F32 q8, q4, d0[0] + VMLA.F32 q9, q5, d0[0] + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q11, q5, d1[0] + VMLA.F32 q12, q4, d2[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q15, q5, d3[0] + B 5b + + // Store odd width +10: + TST r1, 4 + BEQ 11f + VST1.32 {d28-d29}, [r6]! + VMOV q14, q15 + VST1.32 {d24-d25}, [r8]! + VMOV q12, q13 + VST1.32 {d20-d21}, [r4]! + VMOV q10, q11 + VST1.32 {d16-d17}, [r11]! + VMOV q8, q9 + +11: + TST r1, 2 + BEQ 12f + VST1.32 {d28}, [r6]! + VMOV d28, d29 + VST1.32 {d24}, [r8]! + VMOV d24, d25 + VST1.32 {d20}, [r4]! + VMOV d20, d21 + VST1.32 {d16}, [r11]! + VMOV d16, d17 + +12: + TST r1, 1 + BEQ 13f + VST1.32 {d28[0]}, [r6]! + VST1.32 {d24[0]}, [r8]! + VST1.32 {d20[0]}, [r4]! + VST1.32 {d16[0]}, [r11]! + +13: + VPOP {d8-d15} + ADD sp, sp, 8 // skip r2, r3 + POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} + +END_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_${"pld_" if PREFETCH else ""}cortex_a75 + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif + + + diff --git a/src/f32-igemm/4x8-aarch32-neon-ld64.S b/src/f32-igemm/4x8-aarch32-neon-ld64.S new file mode 100644 index 000000000..3e0a77f8d --- /dev/null +++ b/src/f32-igemm/4x8-aarch32-neon-ld64.S @@ -0,0 +1,248 @@ +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <xnnpack/assembly.h> + +.syntax unified + +// void xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64( +// size_t mr, r0 +// size_t nc, r1 +// size_t kc, r2 -> r5 -> sp + 68 +// size_t ks, r3 -> sp + 72 -> r14 +// const float**restrict a, sp + 112 -> r2 +// const void*restrict w, sp + 116 -> r9 +// uint8_t*restrict c, sp + 120 -> r11 +// size_t cm_stride, sp + 124 -> (r6) +// size_t cn_stride, sp + 128 -> (r7) +// size_t a_offset, sp + 132 -> (r5) +// const float* zero, sp + 136 -> (r7) +// output_params*params, sp + 140 -> (r5) + +// inner loop registers + +// A0 r3 d0 +// A1 r12 d1 +// A2 r10 d2 +// A3 r0 d3 + +// B r9 d8, d9, d10, d11 +// B d12, d13, d14, d15 + +// C0 r11 d16-d17 q8 d18-d19 q9 +// C1 r4 d20-d21 q10 d22-d23 q11 +// C2 r8 d24-d25 q12 d26-d27 q13 +// C3 r6 d28-d29 q14 d30-d31 q15 + +// Clamp (r5) d4 d5 d6 d7 + +BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64 + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + // Push 112 bytes + // r2 will be reloaded in outer loop. 
r3 is ks + PUSH {r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r14} // +44 + SUB sp, sp, 4 // 4 + VPUSH {d8-d15} // +64 = 112 + + MOV r14, r3 // p = ks + LDR r2, [sp, 112] // a + LDR r9, [sp, 116] // w + LDR r11, [sp, 120] // c + LDR r6, [sp, 124] // cm_stride + LDR r5, [sp, 140] // clamping_params + + // Clamp C pointers + CMP r0, 2 // if mr >= 2 + ADD r4, r11, r6 // c1 = c0 + cm_stride + MOVLO r4, r11 // c1 + // if mr > 2 + ADD r8, r4, r6 // c2 = c1 + cm_stride + MOVLS r8, r4 // c2 + CMP r0, 4 // if mr >=4 + ADD r6, r8, r6 // c3 = c2 + cm_stride + MOVLO r6, r8 // c3 + + // Load clamping_params values + VLD1.32 {d4[], d5[]}, [r5]! + VLD1.32 {d6[], d7[]}, [r5] + +0: + # Load initial bias from w into accumulators + VLDM r9!, {d16-d19} // Bias + VMOV q10, q8 + VMOV q11, q9 + VMOV q12, q8 + VMOV q13, q9 + VMOV q14, q8 + VMOV q15, q9 + +1: + # Load next 4 A pointers + LDR r3, [r2, 0] + LDR r12, [r2, 4] + LDR r10, [r2, 8] + LDR r0, [r2, 12] + ADD r2, r2, 16 + + // Add a_offset + LDR r5, [sp, 132] // a_offset + LDR r7, [sp, 136] // zero + CMP r3, r7 // if a0 == zero + ADD r3, r3, r5 // a0 += a_offset + MOVEQ r3, r7 // a0 = zero, else += a0 + a_offset + CMP r12, r7 // if a1 == zero + ADD r12, r12, r5 // a1 += a_offset + MOVEQ r12, r7 // a1 = zero, else += a1 + a_offset + CMP r10, r7 // if a2 == zero + ADD r10, r10, r5 // a2 += a_offset + MOVEQ r10, r7 // a2 = zero, else += a2 + a_offset + CMP r0, r7 // if a3 == zero + ADD r0, r0, r5 // a3 += a_offset + LDR r5, [sp, 68] // kc + MOVEQ r0, r7 // a3 = zero, else += a3 + a_offset + + SUBS r5, r5, 8 // kc - 8 + BLO 8f // less than 2 channels? + + // Main loop - 2 floats of A (8 bytes) +2: + VLD1.32 {d0}, [r3]! // A0 + VLDM r9!, {d8-d11} // B0 + VLD1.32 {d1}, [r12]! // A1 + VLD1.32 {d2}, [r10]! // A2 + VLD1.32 {d3}, [ r0]! 
// A3 + + VMLA.F32 q8, q4, d0[0] + VMLA.F32 q9, q5, d0[0] + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q11, q5, d1[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q12, q4, d2[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q15, q5, d3[0] + VMLA.F32 q8, q6, d0[1] + VMLA.F32 q9, q7, d0[1] + VMLA.F32 q10, q6, d1[1] + VMLA.F32 q11, q7, d1[1] + VMLA.F32 q12, q6, d2[1] + VMLA.F32 q13, q7, d2[1] + SUBS r5, r5, 8 + VMLA.F32 q14, q6, d3[1] + VMLA.F32 q15, q7, d3[1] + BHS 2b + + // Is there a remainder?- 1 floats of A (4 bytes) + TST r5, 4 + BNE 8f + +4: + # ks loop + SUBS r14, r14, 16 // ks -= MR * sizeof(void*) + BNE 1b + + LDR r7, [sp, 128] // cn_stride + LDR r14, [sp, 72] // p = ks + + // Clamp + VMIN.F32 q8, q8, q2 + SUBS r1, r1, 8 + VMIN.F32 q9, q9, q2 + VMIN.F32 q10, q10, q2 + VMIN.F32 q11, q11, q2 + VMIN.F32 q12, q12, q2 + VMIN.F32 q13, q13, q2 + VMIN.F32 q14, q14, q2 + VMIN.F32 q15, q15, q2 + VMAX.F32 q8, q8, q3 + VMAX.F32 q9, q9, q3 + VMAX.F32 q10, q10, q3 + VMAX.F32 q11, q11, q3 + VMAX.F32 q12, q12, q3 + VMAX.F32 q13, q13, q3 + VMAX.F32 q14, q14, q3 + VMAX.F32 q15, q15, q3 + + // Store full 4 x 8 + BLO 10f + VST1.32 {d28-d31}, [r6], r7 + VST1.32 {d24-d27}, [r8], r7 + VST1.32 {d20-d23}, [r4], r7 + VST1.32 {d16-d19}, [r11], r7 + SUB r2, r2, r14 // a -= ks + BHI 0b + +6: + VPOP {d8-d15} + ADD sp, sp, 12 // skip pad, r2, r3 + POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} + +8: + // Remainder- 1 floats of A (4 bytes) + VLDM r3!, {s0} // A0 + VLDM r9!, {d8-d11} // B0 + VLDM r12!, {s2} // A1 + VLDM r10!, {s4} // A2 + VLDM r0!, {s6} // A3 + VMLA.F32 q8, q4, d0[0] + VMLA.F32 q9, q5, d0[0] + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q11, q5, d1[0] + VMLA.F32 q12, q4, d2[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q15, q5, d3[0] + B 4b + + // Store odd width +10: + TST r1, 4 + BEQ 11f + VST1.32 {d28-d29}, [r6]! + VMOV q14, q15 + VST1.32 {d24-d25}, [r8]! + VMOV q12, q13 + VST1.32 {d20-d21}, [r4]! + VMOV q10, q11 + VST1.32 {d16-d17}, [r11]! + VMOV q8, q9 + +11: + TST r1, 2 + BEQ 12f + VST1.32 {d28}, [r6]! + VMOV d28, d29 + VST1.32 {d24}, [r8]! + VMOV d24, d25 + VST1.32 {d20}, [r4]! + VMOV d20, d21 + VST1.32 {d16}, [r11]! + VMOV d16, d17 + +12: + TST r1, 1 + BEQ 13f + VST1.32 {d28[0]}, [r6]! + VST1.32 {d24[0]}, [r8]! + VST1.32 {d20[0]}, [r4]! + VST1.32 {d16[0]}, [r11]! + +13: + VPOP {d8-d15} + ADD sp, sp, 12 // skip pad, r2, r3 + POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} + +END_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64 + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif + + + diff --git a/src/f32-igemm/6x8-aarch64-neonfma-cortex-a53.S b/src/f32-igemm/6x8-aarch64-neonfma-cortex-a53.S index c4ea8a041..3094b16b0 100644 --- a/src/f32-igemm/6x8-aarch64-neonfma-cortex-a53.S +++ b/src/f32-igemm/6x8-aarch64-neonfma-cortex-a53.S @@ -145,7 +145,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53 # Is there at least 4 floats (16 bytes) for prologue + epilogue? 
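 # (kc is counted in bytes: the pipelined prologue and epilogue each
 # consume 2 floats of A, so the pipelined path needs at least 4 floats;
 # smaller kc is handled entirely by the remainder code.)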
SUBS x0, x2, 16 // k = kc - 16 - B.LO 4f + B.LO 5f # Prologue - First group loads, no FMA LDR d0, [x14], 8 // a0 @@ -408,6 +408,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53 FMLA v21.4s, v15.4s, v3.s[1] FMLA v23.4s, v15.4s, v3.s[3] FMLA v25.4s, v15.4s, v4.s[1] + TST x0, 15 // BLOCK 7 FMLA v27.4s, v15.4s, v4.s[3] @@ -415,13 +416,10 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53 FMLA v31.4s, v15.4s, v5.s[3] ADD x5, x5, 64 -4: - # Is there a remainder?- 2 floats of A (8 bytes) - TBNZ x0, 3, 6f - # Is there a remainder?- 1 floats of A (4 bytes) - TBNZ x0, 2, 7f + # Is there a remainder?- 2 floats of A (8 bytes) or less + B.NE 5f -5: +4: # ks loop SUBS x9, x9, 48 // ks -= MR * sizeof(void*) B.NE 1b @@ -482,9 +480,11 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53 LDP d12, d13, [sp], 80 RET - # Remainder - 2 floats of A (8 bytes) - # 24 FMA + 6 LD64 A + 2 LDP B -6: +5: + # Is there a remainder?- 2 floats of A (8 bytes) + TBZ x0, 3, 6f + + # Remainder- 2 floats of A (8 bytes) LDR d0, [x14], 8 LDR q16, [x5], 16 LD1 {v0.d}[1], [x15], 8 @@ -522,8 +522,8 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53 FMLA v31.4s, v19.4s, v2.s[3] # Is there a remainder?- 1 floats of A (4 bytes) - TBZ x0, 2, 5b -7: + TBZ x0, 2, 4b +6: # Remainder- 1 floats of A (4 bytes) LDR s0, [x14], 4 LDR q16, [x5], 16 @@ -546,7 +546,7 @@ BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53 FMLA v27.4s, v17.4s, v1.s[2] FMLA v29.4s, v17.4s, v2.s[0] FMLA v31.4s, v17.4s, v2.s[2] - B 5b + B 4b # Store odd width 8: diff --git a/src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S b/src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S new file mode 100644 index 000000000..0e2f07fba --- /dev/null +++ b/src/f32-igemm/gen/4x8-aarch32-neon-cortex-a75.S @@ -0,0 +1,369 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in +// Generator: tools/xngen +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <xnnpack/assembly.h> + +.syntax unified + +// void xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75( +// size_t mr, r0 +// size_t nc, r1 +// size_t kc, r2 -> r5 -> sp + 64 +// size_t ks, r3 -> sp + 68 -> r14 +// const float**restrict a, sp + 108 -> r2 +// const void*restrict w, sp + 112 -> r9 +// uint8_t*restrict c, sp + 116 -> r11 +// size_t cm_stride, sp + 120 -> (r6) +// size_t cn_stride, sp + 124 -> (r7) +// size_t a_offset, sp + 128 -> (r5) +// const float* zero, sp + 132 -> (r7) +// output_params*params, sp + 136 -> (r5) + +// inner loop registers + +// A0 r3 d0 +// A1 r12 d1 +// A2 r10 d2 +// A3 r0 d3 + +// B r9 d8, d9, d10, d11 +// B d12, d13, d14, d15 + +// C0 r11 d16-d17 q8 d18-d19 q9 +// C1 r4 d20-d21 q10 d22-d23 q11 +// C2 r8 d24-d25 q12 d26-d27 q13 +// C3 r6 d28-d29 q14 d30-d31 q15 + +// Clamp (r5) d4 d5 d6 d7 + +BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75 + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + // Push 108 bytes + // r2 will be reloaded in outer loop. 
r3 is ks + PUSH {r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r14} // +44 + VPUSH {d8-d15} // +64 = 108 + + LDR r2, [sp, 108] // a + LDR r9, [sp, 112] // w + LDR r11, [sp, 116] // c + LDR r6, [sp, 120] // cm_stride + LDR r14, [sp, 68] // p = ks + + // Clamp C pointers + CMP r0, 2 // if mr >= 2 + ADD r4, r11, r6 // c1 = c0 + cm_stride + MOVLO r4, r11 // c1 + // if mr > 2 + ADD r8, r4, r6 // c2 = c1 + cm_stride + MOVLS r8, r4 // c2 + CMP r0, 4 // if mr >=4 + ADD r6, r8, r6 // c3 = c2 + cm_stride + MOVLO r6, r8 // c3 + + .p2align 3 +0: + # Load initial bias from w into accumulators + VLDM r9!, {d16-d19} // Bias + VMOV q10, q8 + VMOV q11, q9 + VMOV q12, q8 + VMOV q13, q9 + VMOV q14, q8 + VMOV q15, q9 + + +1: + # Load next 4 A pointers + LDR r3, [r2, 0] + LDR r12, [r2, 4] + LDR r10, [r2, 8] + LDR r0, [r2, 12] + ADD r2, r2, 16 + + // Add a_offset + LDR r5, [sp, 128] // a_offset + LDR r7, [sp, 132] // zero + CMP r3, r7 // if a0 == zero + ADD r3, r3, r5 // a0 += a_offset + MOVEQ r3, r7 // a0 = zero, else += a0 + a_offset + CMP r12, r7 // if a1 == zero + ADD r12, r12, r5 // a1 += a_offset + MOVEQ r12, r7 // a1 = zero, else += a1 + a_offset + CMP r10, r7 // if a2 == zero + ADD r10, r10, r5 // a2 += a_offset + MOVEQ r10, r7 // a2 = zero, else += a2 + a_offset + CMP r0, r7 // if a3 == zero + ADD r0, r0, r5 // a3 += a_offset + LDR r5, [sp, 64] // kc + MOVEQ r0, r7 // a3 = zero, else += a3 + a_offset + + + SUBS r5, r5, 16 // kc - 16 + BLO 4f // less than 4 channels? + + // Prologue + VLD1.32 {d0}, [r3]! // A0 + VLDM r9!, {d8-d11} // B0 + VLD1.32 {d1}, [r12]! // A1 + VLD1.32 {d2}, [r10]! // A2 + VLD1.32 {d3}, [ r0]! // A3 + + SUBS r5, r5, 16 + BLO 3f // less than 4 channels? skip main loop + + .p2align 3 + + // Main loop - 4 floats of A (16 bytes) +2: + VMLA.F32 q8, q4, d0[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q12, q4, d2[0] + VLD1.32 {d4}, [r3]! // A0 + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q9, q5, d0[0] + VLD1.32 {d5}, [r12]! // A1 + VMLA.F32 q11, q5, d1[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q15, q5, d3[0] + VLD1.32 {d6}, [r10]! // A2 + VMLA.F32 q8, q6, d0[1] + VMLA.F32 q10, q6, d1[1] + VLD1.32 {d7}, [ r0]! // A3 + VMLA.F32 q12, q6, d2[1] + VMLA.F32 q14, q6, d3[1] + VLDM r9!, {d8-d11} // B0 + VMLA.F32 q9, q7, d0[1] + VMLA.F32 q11, q7, d1[1] + VMLA.F32 q13, q7, d2[1] + VMLA.F32 q15, q7, d3[1] + + VMLA.F32 q8, q4, d4[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d5[0] + VMLA.F32 q12, q4, d6[0] + VLD1.32 {d0}, [r3]! // A0 + VMLA.F32 q14, q4, d7[0] + VMLA.F32 q9, q5, d4[0] + VLD1.32 {d1}, [r12]! // A1 + VMLA.F32 q11, q5, d5[0] + VMLA.F32 q13, q5, d6[0] + VLD1.32 {d2}, [r10]! // A2 + VMLA.F32 q15, q5, d7[0] + VMLA.F32 q8, q6, d4[1] + VLD1.32 {d3}, [ r0]! // A3 + VMLA.F32 q10, q6, d5[1] + VMLA.F32 q12, q6, d6[1] + VMLA.F32 q14, q6, d7[1] + VLDM r9!, {d8-d11} // B0 + VMLA.F32 q9, q7, d4[1] + VMLA.F32 q11, q7, d5[1] + SUBS r5, r5, 16 + VMLA.F32 q13, q7, d6[1] + VMLA.F32 q15, q7, d7[1] + BHS 2b + + // Epilogue +3: + VMLA.F32 q8, q4, d0[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q12, q4, d2[0] + VLD1.32 {d4}, [r3]! // A0 + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q9, q5, d0[0] + VLD1.32 {d5}, [r12]! // A1 + VMLA.F32 q11, q5, d1[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q15, q5, d3[0] + VLD1.32 {d6}, [r10]! // A2 + VMLA.F32 q8, q6, d0[1] + VMLA.F32 q10, q6, d1[1] + VLD1.32 {d7}, [ r0]! 
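+    // (epilogue: d4-d7 is the last group of A; the second half below
+    // repeats the FMA pattern with no further A loads.)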
// A3 + VMLA.F32 q12, q6, d2[1] + VMLA.F32 q14, q6, d3[1] + VLDM r9!, {d8-d11} // B0 + VMLA.F32 q9, q7, d0[1] + VMLA.F32 q11, q7, d1[1] + VMLA.F32 q13, q7, d2[1] + VMLA.F32 q15, q7, d3[1] + + VMLA.F32 q8, q4, d4[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d5[0] + VMLA.F32 q12, q4, d6[0] + VMLA.F32 q14, q4, d7[0] + VMLA.F32 q9, q5, d4[0] + VMLA.F32 q11, q5, d5[0] + VMLA.F32 q13, q5, d6[0] + VMLA.F32 q15, q5, d7[0] + VMLA.F32 q8, q6, d4[1] + VMLA.F32 q10, q6, d5[1] + VMLA.F32 q12, q6, d6[1] + VMLA.F32 q14, q6, d7[1] + VMLA.F32 q9, q7, d4[1] + VMLA.F32 q11, q7, d5[1] + VMLA.F32 q13, q7, d6[1] + VMLA.F32 q15, q7, d7[1] + +4: + // Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes) + TST r5, 12 + BNE 7f + + .p2align 3 +5: + # ks loop + SUBS r14, r14, 16 // ks -= MR * sizeof(void*) + BNE 1b + + // Load params pointer + LDR r5, [sp, 136] // clamping_params + LDR r7, [sp, 124] // cn_stride + LDR r14, [sp, 68] // p = ks + + // Load clamping_params values + VLD1.32 {d4[],d5[]}, [r5]! + SUBS r1, r1, 8 + VLD1.32 {d6[],d7[]}, [r5] + + // Clamp + VMIN.F32 q8, q8, q2 + VMIN.F32 q9, q9, q2 + VMIN.F32 q10, q10, q2 + VMIN.F32 q11, q11, q2 + VMIN.F32 q12, q12, q2 + VMIN.F32 q13, q13, q2 + VMIN.F32 q14, q14, q2 + VMIN.F32 q15, q15, q2 + VMAX.F32 q8, q8, q3 + VMAX.F32 q9, q9, q3 + VMAX.F32 q10, q10, q3 + VMAX.F32 q11, q11, q3 + VMAX.F32 q12, q12, q3 + VMAX.F32 q13, q13, q3 + VMAX.F32 q14, q14, q3 + VMAX.F32 q15, q15, q3 + + // Store full 4 x 8 + BLO 10f + VST1.32 {d28-d31}, [r6], r7 + VST1.32 {d24-d27}, [r8], r7 + VST1.32 {d20-d23}, [r4], r7 + VST1.32 {d16-d19}, [r11], r7 + SUB r2, r2, r14 // a -= ks + BHI 0b + +6: + VPOP {d8-d15} + ADD sp, sp, 8 // skip r2, r3 + POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} + + .p2align 3 +7: + // Is there a remainder?- 2 floats of A (8 bytes) + TST r5, 8 + BEQ 8f + + // Remainder - 2 floats of A (8 bytes) + VLD1.32 {d0}, [r3]! // A0 + VLDM r9!, {d8-d11} // B0 + VLD1.32 {d1}, [r12]! // A1 + VLD1.32 {d2}, [r10]! // A2 + VLD1.32 {d3}, [ r0]! // A3 + + VMLA.F32 q8, q4, d0[0] + VMLA.F32 q9, q5, d0[0] + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q11, q5, d1[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q12, q4, d2[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q15, q5, d3[0] + VMLA.F32 q8, q6, d0[1] + VMLA.F32 q9, q7, d0[1] + VMLA.F32 q10, q6, d1[1] + VMLA.F32 q11, q7, d1[1] + VMLA.F32 q12, q6, d2[1] + VMLA.F32 q13, q7, d2[1] + VMLA.F32 q14, q6, d3[1] + VMLA.F32 q15, q7, d3[1] +8: + // Is there a remainder?- 1 floats of A (4 bytes) + TST r5, 4 + BEQ 5b + +9: + // Remainder- 1 floats of A (4 bytes) + VLDM r3!, {s0} // A0 + VLDM r9!, {d8-d11} // B0 + VLDM r12!, {s2} // A1 + VLDM r10!, {s4} // A2 + VLDM r0!, {s6} // A3 + VMLA.F32 q8, q4, d0[0] + VMLA.F32 q9, q5, d0[0] + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q11, q5, d1[0] + VMLA.F32 q12, q4, d2[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q15, q5, d3[0] + B 5b + + // Store odd width +10: + TST r1, 4 + BEQ 11f + VST1.32 {d28-d29}, [r6]! + VMOV q14, q15 + VST1.32 {d24-d25}, [r8]! + VMOV q12, q13 + VST1.32 {d20-d21}, [r4]! + VMOV q10, q11 + VST1.32 {d16-d17}, [r11]! + VMOV q8, q9 + +11: + TST r1, 2 + BEQ 12f + VST1.32 {d28}, [r6]! + VMOV d28, d29 + VST1.32 {d24}, [r8]! + VMOV d24, d25 + VST1.32 {d20}, [r4]! + VMOV d20, d21 + VST1.32 {d16}, [r11]! + VMOV d16, d17 + +12: + TST r1, 1 + BEQ 13f + VST1.32 {d28[0]}, [r6]! + VST1.32 {d24[0]}, [r8]! + VST1.32 {d20[0]}, [r4]! + VST1.32 {d16[0]}, [r11]! 
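+    // (odd nc fully stored: decomposed as 4 + 2 + 1 columns, shifting
+    // the surviving lanes down after each partial store.)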
+ +13: + VPOP {d8-d15} + ADD sp, sp, 8 // skip r2, r3 + POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} + +END_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75 + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif + + + diff --git a/src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S b/src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S new file mode 100644 index 000000000..0051bad2f --- /dev/null +++ b/src/f32-igemm/gen/4x8-aarch32-neon-pld-cortex-a75.S @@ -0,0 +1,389 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-igemm/4x8-aarch32-neon-cortex-a75.S.in +// Generator: tools/xngen +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <xnnpack/assembly.h> + +.syntax unified + +// void xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75( +// size_t mr, r0 +// size_t nc, r1 +// size_t kc, r2 -> r5 -> sp + 64 +// size_t ks, r3 -> sp + 68 -> r14 +// const float**restrict a, sp + 108 -> r2 +// const void*restrict w, sp + 112 -> r9 +// uint8_t*restrict c, sp + 116 -> r11 +// size_t cm_stride, sp + 120 -> (r6) +// size_t cn_stride, sp + 124 -> (r7) +// size_t a_offset, sp + 128 -> (r5) +// const float* zero, sp + 132 -> (r7) +// output_params*params, sp + 136 -> (r5) + +// inner loop registers + +// A0 r3 d0 +// A1 r12 d1 +// A2 r10 d2 +// A3 r0 d3 + +// B r9 d8, d9, d10, d11 +// B d12, d13, d14, d15 + +// C0 r11 d16-d17 q8 d18-d19 q9 +// C1 r4 d20-d21 q10 d22-d23 q11 +// C2 r8 d24-d25 q12 d26-d27 q13 +// C3 r6 d28-d29 q14 d30-d31 q15 + +// Clamp (r5) d4 d5 d6 d7 + +BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75 + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + // Push 108 bytes + // r2 will be reloaded in outer loop. 
r3 is ks + PUSH {r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r14} // +44 + VPUSH {d8-d15} // +64 = 108 + + LDR r2, [sp, 108] // a + LDR r9, [sp, 112] // w + LDR r11, [sp, 116] // c + LDR r6, [sp, 120] // cm_stride + LDR r14, [sp, 68] // p = ks + + // Clamp C pointers + CMP r0, 2 // if mr >= 2 + ADD r4, r11, r6 // c1 = c0 + cm_stride + MOVLO r4, r11 // c1 + // if mr > 2 + ADD r8, r4, r6 // c2 = c1 + cm_stride + MOVLS r8, r4 // c2 + CMP r0, 4 // if mr >=4 + ADD r6, r8, r6 // c3 = c2 + cm_stride + MOVLO r6, r8 // c3 + + .p2align 3 +0: + # Load initial bias from w into accumulators + VLDM r9!, {d16-d19} // Bias + VMOV q10, q8 + VMOV q11, q9 + VMOV q12, q8 + VMOV q13, q9 + VMOV q14, q8 + VMOV q15, q9 + + PLD [r9, 0] // Prefetch B + PLD [r9, 64] + PLD [r9, 128] + PLD [r9, 192] + PLD [r9, 256] + PLD [r9, 320] + +1: + # Load next 4 A pointers + LDR r3, [r2, 0] + LDR r12, [r2, 4] + LDR r10, [r2, 8] + LDR r0, [r2, 12] + ADD r2, r2, 16 + + // Add a_offset + LDR r5, [sp, 128] // a_offset + LDR r7, [sp, 132] // zero + CMP r3, r7 // if a0 == zero + ADD r3, r3, r5 // a0 += a_offset + MOVEQ r3, r7 // a0 = zero, else += a0 + a_offset + CMP r12, r7 // if a1 == zero + ADD r12, r12, r5 // a1 += a_offset + MOVEQ r12, r7 // a1 = zero, else += a1 + a_offset + CMP r10, r7 // if a2 == zero + ADD r10, r10, r5 // a2 += a_offset + MOVEQ r10, r7 // a2 = zero, else += a2 + a_offset + CMP r0, r7 // if a3 == zero + ADD r0, r0, r5 // a3 += a_offset + LDR r5, [sp, 64] // kc + MOVEQ r0, r7 // a3 = zero, else += a3 + a_offset + + PLD [r3, 0] // Prefetch A + PLD [r3, 64] + PLD [r12, 0] + PLD [r12, 64] + PLD [r10, 0] + PLD [r10, 64] + PLD [r0, 0] + PLD [r0, 64] + + SUBS r5, r5, 16 // kc - 16 + BLO 4f // less than 4 channels? + + // Prologue + VLD1.32 {d0}, [r3]! // A0 + VLDM r9!, {d8-d11} // B0 + VLD1.32 {d1}, [r12]! // A1 + VLD1.32 {d2}, [r10]! // A2 + VLD1.32 {d3}, [ r0]! // A3 + + SUBS r5, r5, 16 + BLO 3f // less than 4 channels? skip main loop + + .p2align 3 + + // Main loop - 4 floats of A (16 bytes) +2: + VMLA.F32 q8, q4, d0[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q12, q4, d2[0] + VLD1.32 {d4}, [r3]! // A0 + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q9, q5, d0[0] + VLD1.32 {d5}, [r12]! // A1 + VMLA.F32 q11, q5, d1[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q15, q5, d3[0] + VLD1.32 {d6}, [r10]! // A2 + VMLA.F32 q8, q6, d0[1] + VMLA.F32 q10, q6, d1[1] + VLD1.32 {d7}, [ r0]! // A3 + VMLA.F32 q12, q6, d2[1] + VMLA.F32 q14, q6, d3[1] + VLDM r9!, {d8-d11} // B0 + VMLA.F32 q9, q7, d0[1] + VMLA.F32 q11, q7, d1[1] + VMLA.F32 q13, q7, d2[1] + VMLA.F32 q15, q7, d3[1] + + VMLA.F32 q8, q4, d4[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d5[0] + PLD [r3, 128] // Prefetch A0 + VMLA.F32 q12, q4, d6[0] + VLD1.32 {d0}, [r3]! // A0 + VMLA.F32 q14, q4, d7[0] + PLD [r12, 128] // Prefetch A1 + VMLA.F32 q9, q5, d4[0] + VLD1.32 {d1}, [r12]! // A1 + VMLA.F32 q11, q5, d5[0] + PLD [r10, 128] // Prefetch A2 + VMLA.F32 q13, q5, d6[0] + VLD1.32 {d2}, [r10]! // A2 + VMLA.F32 q15, q5, d7[0] + PLD [r0, 128] // Prefetch A3 + VMLA.F32 q8, q6, d4[1] + VLD1.32 {d3}, [ r0]! // A3 + VMLA.F32 q10, q6, d5[1] + PLD [r9, 384] // Prefetch B + VMLA.F32 q12, q6, d6[1] + PLD [r9, 448] // Prefetch B + VMLA.F32 q14, q6, d7[1] + VLDM r9!, {d8-d11} // B0 + VMLA.F32 q9, q7, d4[1] + VMLA.F32 q11, q7, d5[1] + SUBS r5, r5, 16 + VMLA.F32 q13, q7, d6[1] + VMLA.F32 q15, q7, d7[1] + BHS 2b + + // Epilogue +3: + VMLA.F32 q8, q4, d0[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q12, q4, d2[0] + VLD1.32 {d4}, [r3]! 
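+    // (no PLDs in the epilogue: prefetching only pays off inside the
+    // main loop, far enough ahead to hide the A and B load latency.)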
// A0 + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q9, q5, d0[0] + VLD1.32 {d5}, [r12]! // A1 + VMLA.F32 q11, q5, d1[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q15, q5, d3[0] + VLD1.32 {d6}, [r10]! // A2 + VMLA.F32 q8, q6, d0[1] + VMLA.F32 q10, q6, d1[1] + VLD1.32 {d7}, [ r0]! // A3 + VMLA.F32 q12, q6, d2[1] + VMLA.F32 q14, q6, d3[1] + VLDM r9!, {d8-d11} // B0 + VMLA.F32 q9, q7, d0[1] + VMLA.F32 q11, q7, d1[1] + VMLA.F32 q13, q7, d2[1] + VMLA.F32 q15, q7, d3[1] + + VMLA.F32 q8, q4, d4[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q10, q4, d5[0] + VMLA.F32 q12, q4, d6[0] + VMLA.F32 q14, q4, d7[0] + VMLA.F32 q9, q5, d4[0] + VMLA.F32 q11, q5, d5[0] + VMLA.F32 q13, q5, d6[0] + VMLA.F32 q15, q5, d7[0] + VMLA.F32 q8, q6, d4[1] + VMLA.F32 q10, q6, d5[1] + VMLA.F32 q12, q6, d6[1] + VMLA.F32 q14, q6, d7[1] + VMLA.F32 q9, q7, d4[1] + VMLA.F32 q11, q7, d5[1] + VMLA.F32 q13, q7, d6[1] + VMLA.F32 q15, q7, d7[1] + +4: + // Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes) + TST r5, 12 + BNE 7f + + .p2align 3 +5: + # ks loop + SUBS r14, r14, 16 // ks -= MR * sizeof(void*) + BNE 1b + + // Load params pointer + LDR r5, [sp, 136] // clamping_params + LDR r7, [sp, 124] // cn_stride + LDR r14, [sp, 68] // p = ks + + // Load clamping_params values + VLD1.32 {d4[],d5[]}, [r5]! + SUBS r1, r1, 8 + VLD1.32 {d6[],d7[]}, [r5] + + // Clamp + VMIN.F32 q8, q8, q2 + VMIN.F32 q9, q9, q2 + VMIN.F32 q10, q10, q2 + VMIN.F32 q11, q11, q2 + VMIN.F32 q12, q12, q2 + VMIN.F32 q13, q13, q2 + VMIN.F32 q14, q14, q2 + VMIN.F32 q15, q15, q2 + VMAX.F32 q8, q8, q3 + VMAX.F32 q9, q9, q3 + VMAX.F32 q10, q10, q3 + VMAX.F32 q11, q11, q3 + VMAX.F32 q12, q12, q3 + VMAX.F32 q13, q13, q3 + VMAX.F32 q14, q14, q3 + VMAX.F32 q15, q15, q3 + + // Store full 4 x 8 + BLO 10f + VST1.32 {d28-d31}, [r6], r7 + VST1.32 {d24-d27}, [r8], r7 + VST1.32 {d20-d23}, [r4], r7 + VST1.32 {d16-d19}, [r11], r7 + SUB r2, r2, r14 // a -= ks + BHI 0b + +6: + VPOP {d8-d15} + ADD sp, sp, 8 // skip r2, r3 + POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} + + .p2align 3 +7: + // Is there a remainder?- 2 floats of A (8 bytes) + TST r5, 8 + BEQ 8f + + // Remainder - 2 floats of A (8 bytes) + VLD1.32 {d0}, [r3]! // A0 + VLDM r9!, {d8-d11} // B0 + VLD1.32 {d1}, [r12]! // A1 + VLD1.32 {d2}, [r10]! // A2 + VLD1.32 {d3}, [ r0]! // A3 + + VMLA.F32 q8, q4, d0[0] + VMLA.F32 q9, q5, d0[0] + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q11, q5, d1[0] + VLDM r9!, {d12-d15} // B1 + VMLA.F32 q12, q4, d2[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q15, q5, d3[0] + VMLA.F32 q8, q6, d0[1] + VMLA.F32 q9, q7, d0[1] + VMLA.F32 q10, q6, d1[1] + VMLA.F32 q11, q7, d1[1] + VMLA.F32 q12, q6, d2[1] + VMLA.F32 q13, q7, d2[1] + VMLA.F32 q14, q6, d3[1] + VMLA.F32 q15, q7, d3[1] +8: + // Is there a remainder?- 1 floats of A (4 bytes) + TST r5, 4 + BEQ 5b + +9: + // Remainder- 1 floats of A (4 bytes) + VLDM r3!, {s0} // A0 + VLDM r9!, {d8-d11} // B0 + VLDM r12!, {s2} // A1 + VLDM r10!, {s4} // A2 + VLDM r0!, {s6} // A3 + VMLA.F32 q8, q4, d0[0] + VMLA.F32 q9, q5, d0[0] + VMLA.F32 q10, q4, d1[0] + VMLA.F32 q11, q5, d1[0] + VMLA.F32 q12, q4, d2[0] + VMLA.F32 q13, q5, d2[0] + VMLA.F32 q14, q4, d3[0] + VMLA.F32 q15, q5, d3[0] + B 5b + + // Store odd width +10: + TST r1, 4 + BEQ 11f + VST1.32 {d28-d29}, [r6]! + VMOV q14, q15 + VST1.32 {d24-d25}, [r8]! + VMOV q12, q13 + VST1.32 {d20-d21}, [r4]! + VMOV q10, q11 + VST1.32 {d16-d17}, [r11]! + VMOV q8, q9 + +11: + TST r1, 2 + BEQ 12f + VST1.32 {d28}, [r6]! + VMOV d28, d29 + VST1.32 {d24}, [r8]! + VMOV d24, d25 + VST1.32 {d20}, [r4]! 
+ VMOV d20, d21 + VST1.32 {d16}, [r11]! + VMOV d16, d17 + +12: + TST r1, 1 + BEQ 13f + VST1.32 {d28[0]}, [r6]! + VST1.32 {d24[0]}, [r8]! + VST1.32 {d20[0]}, [r4]! + VST1.32 {d16[0]}, [r11]! + +13: + VPOP {d8-d15} + ADD sp, sp, 8 // skip r2, r3 + POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} + +END_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75 + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif + + + diff --git a/src/init.c b/src/init.c index 43d0abf8c..9f35c5422 100644 --- a/src/init.c +++ b/src/init.c @@ -138,7 +138,7 @@ static void init(void) { case cpuinfo_uarch_cortex_a55: xnn_params.f32.gemm = (struct gemm_parameters) { .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a53, - .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neon_lane_ld128, + .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64, .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64, .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64, .mr = 4, @@ -151,7 +151,7 @@ static void init(void) { case cpuinfo_uarch_cortex_a73: xnn_params.f32.gemm = (struct gemm_parameters) { .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_pld_cortex_a75, - .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neon_lane_ld128, + .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75, .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64, .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64, .mr = 4, @@ -162,7 +162,7 @@ static void init(void) { default: xnn_params.f32.gemm = (struct gemm_parameters) { .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a75, - .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neon_lane_ld128, + .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75, .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64, .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64, .mr = 4, diff --git a/src/runtime.c b/src/runtime.c index b210b34bb..f0e262f01 100644 --- a/src/runtime.c +++ b/src/runtime.c @@ -68,10 +68,10 @@ enum xnn_status xnn_create_runtime_v2( if (status != xnn_status_success) { goto error; } - runtime->ops[i].shape1.num_dims = subgraph->values[node->inputs.raw[0]].shape.num_dims; - runtime->ops[i].shape2.num_dims = subgraph->values[node->inputs.raw[1]].shape.num_dims; - memcpy(runtime->ops[i].shape1.dim, subgraph->values[node->inputs.raw[0]].shape.dim, subgraph->values[node->inputs.raw[0]].shape.num_dims * sizeof(size_t)); - memcpy(runtime->ops[i].shape2.dim, subgraph->values[node->inputs.raw[1]].shape.dim, subgraph->values[node->inputs.raw[1]].shape.num_dims * sizeof(size_t)); + runtime->ops[i].shape1.num_dims = values[node->inputs.raw[0]].shape.num_dims; + runtime->ops[i].shape2.num_dims = values[node->inputs.raw[1]].shape.num_dims; + memcpy(runtime->ops[i].shape1.dim, values[node->inputs.raw[0]].shape.dim, values[node->inputs.raw[0]].shape.num_dims * sizeof(size_t)); + memcpy(runtime->ops[i].shape2.dim, values[node->inputs.raw[1]].shape.dim, values[node->inputs.raw[1]].shape.num_dims * sizeof(size_t)); runtime->ops[i].inputs[0] = node->inputs.raw[0]; runtime->ops[i].inputs[1] = node->inputs.raw[1]; runtime->ops[i].outputs[0] = node->outputs.raw[0]; @@ -102,9 +102,28 @@ enum xnn_status 
xnn_create_runtime_v2( if (status != xnn_status_success) { goto error; } - runtime->ops[i].batch_size = subgraph->values[node->inputs.raw[0]].shape.dim[0]; - runtime->ops[i].input_height = subgraph->values[node->inputs.raw[0]].shape.dim[1]; - runtime->ops[i].input_width = subgraph->values[node->inputs.raw[0]].shape.dim[2]; + runtime->ops[i].batch_size = values[node->inputs.raw[0]].shape.dim[0]; + runtime->ops[i].input_height = values[node->inputs.raw[0]].shape.dim[1]; + runtime->ops[i].input_width = values[node->inputs.raw[0]].shape.dim[2]; + runtime->ops[i].inputs[0] = node->inputs.raw[0]; + runtime->ops[i].outputs[0] = node->outputs.raw[0]; + break; + case xnn_node_type_clamp: + status = xnn_create_clamp_nc_f32( + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* channels */, + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* input stride */, + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* output stride */, + node->activation.output_min, + node->activation.output_max, + node->flags, + &runtime->ops[i].op); + if (status != xnn_status_success) { + goto error; + } + runtime->ops[i].batch_size = 1; + for (size_t i = 0; i + 1 < values[node->inputs.raw[0]].shape.num_dims; i++) { + runtime->ops[i].batch_size *= values[node->inputs.raw[0]].shape.dim[i]; + } runtime->ops[i].inputs[0] = node->inputs.raw[0]; runtime->ops[i].outputs[0] = node->outputs.raw[0]; break; @@ -134,9 +153,26 @@ enum xnn_status xnn_create_runtime_v2( if (status != xnn_status_success) { goto error; } - runtime->ops[i].batch_size = subgraph->values[node->inputs.raw[0]].shape.dim[0]; - runtime->ops[i].input_height = subgraph->values[node->inputs.raw[0]].shape.dim[1]; - runtime->ops[i].input_width = subgraph->values[node->inputs.raw[0]].shape.dim[2]; + runtime->ops[i].batch_size = values[node->inputs.raw[0]].shape.dim[0]; + runtime->ops[i].input_height = values[node->inputs.raw[0]].shape.dim[1]; + runtime->ops[i].input_width = values[node->inputs.raw[0]].shape.dim[2]; + runtime->ops[i].inputs[0] = node->inputs.raw[0]; + runtime->ops[i].outputs[0] = node->outputs.raw[0]; + break; + case xnn_node_type_hardswish: + status = xnn_create_hardswish_nc_f32( + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* channels */, + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* input stride */, + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* output stride */, + node->flags, + &runtime->ops[i].op); + if (status != xnn_status_success) { + goto error; + } + runtime->ops[i].batch_size = 1; + for (size_t i = 0; i + 1 < values[node->inputs.raw[0]].shape.num_dims; i++) { + runtime->ops[i].batch_size *= values[node->inputs.raw[0]].shape.dim[i]; + } runtime->ops[i].inputs[0] = node->inputs.raw[0]; runtime->ops[i].outputs[0] = node->outputs.raw[0]; break; @@ -149,14 +185,68 @@ enum xnn_status xnn_create_runtime_v2( if (status != xnn_status_success) { goto error; } - runtime->ops[i].shape1.num_dims = subgraph->values[node->inputs.raw[0]].shape.num_dims; - runtime->ops[i].shape2.num_dims = subgraph->values[node->inputs.raw[1]].shape.num_dims; - memcpy(runtime->ops[i].shape1.dim, subgraph->values[node->inputs.raw[0]].shape.dim, subgraph->values[node->inputs.raw[0]].shape.num_dims * sizeof(size_t)); - memcpy(runtime->ops[i].shape2.dim, subgraph->values[node->inputs.raw[1]].shape.dim, 
subgraph->values[node->inputs.raw[1]].shape.num_dims * sizeof(size_t)); + runtime->ops[i].shape1.num_dims = values[node->inputs.raw[0]].shape.num_dims; + runtime->ops[i].shape2.num_dims = values[node->inputs.raw[1]].shape.num_dims; + memcpy(runtime->ops[i].shape1.dim, values[node->inputs.raw[0]].shape.dim, values[node->inputs.raw[0]].shape.num_dims * sizeof(size_t)); + memcpy(runtime->ops[i].shape2.dim, values[node->inputs.raw[1]].shape.dim, values[node->inputs.raw[1]].shape.num_dims * sizeof(size_t)); runtime->ops[i].inputs[0] = node->inputs.raw[0]; runtime->ops[i].inputs[1] = node->inputs.raw[1]; runtime->ops[i].outputs[0] = node->outputs.raw[0]; break; + case xnn_node_type_prelu: + status = xnn_create_prelu_nc_f32( + values[node->inputs.raw[1]].shape.dim[values[node->inputs.raw[1]].shape.num_dims - 1] /* channels */, + values[node->inputs.raw[1]].shape.dim[values[node->inputs.raw[1]].shape.num_dims - 1] /* input stride */, + values[node->inputs.raw[1]].shape.dim[values[node->inputs.raw[1]].shape.num_dims - 1] /* output stride */, + values[node->inputs.raw[1]].data /* negative slope */, + -INFINITY, + +INFINITY, + node->flags, + &runtime->ops[i].op); + if (status != xnn_status_success) { + goto error; + } + runtime->ops[i].batch_size = 1; + for (size_t i = 0; i + 1 < values[node->inputs.raw[0]].shape.num_dims; i++) { + runtime->ops[i].batch_size *= values[node->inputs.raw[0]].shape.dim[i]; + } + runtime->ops[i].inputs[0] = node->inputs.raw[0]; + runtime->ops[i].outputs[0] = node->outputs.raw[0]; + break; + case xnn_node_type_sigmoid: + status = xnn_create_sigmoid_nc_f32( + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* channels */, + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* input stride */, + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* output stride */, + node->flags, + &runtime->ops[i].op); + if (status != xnn_status_success) { + goto error; + } + runtime->ops[i].batch_size = 1; + for (size_t i = 0; i + 1 < values[node->inputs.raw[0]].shape.num_dims; i++) { + runtime->ops[i].batch_size *= values[node->inputs.raw[0]].shape.dim[i]; + } + runtime->ops[i].inputs[0] = node->inputs.raw[0]; + runtime->ops[i].outputs[0] = node->outputs.raw[0]; + break; + case xnn_node_type_softmax: + status = xnn_create_softmax_nc_f32( + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* channels */, + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* input stride */, + values[node->inputs.raw[0]].shape.dim[values[node->inputs.raw[0]].shape.num_dims - 1] /* output stride */, + node->flags, + &runtime->ops[i].op); + if (status != xnn_status_success) { + goto error; + } + runtime->ops[i].batch_size = 1; + for (size_t i = 0; i + 1 < values[node->inputs.raw[0]].shape.num_dims; i++) { + runtime->ops[i].batch_size *= values[node->inputs.raw[0]].shape.dim[i]; + } + runtime->ops[i].inputs[0] = node->inputs.raw[0]; + runtime->ops[i].outputs[0] = node->outputs.raw[0]; + break; case xnn_node_type_invalid: xnn_log_fatal("unexpected node type %d in node #%zu", node->type, i); XNN_UNREACHABLE; @@ -282,6 +372,26 @@ enum xnn_status xnn_setup_runtime( runtime->blobs[op->outputs[0]].data, runtime->threadpool); break; + case xnn_operator_type_clamp_nc_f32: + assert(runtime->blobs[op->inputs[0]].data != NULL); + assert(runtime->blobs[op->outputs[0]].data != NULL); + status = xnn_setup_clamp_nc_f32( + op->op, + 
op->batch_size, + runtime->blobs[op->inputs[0]].data, + runtime->blobs[op->outputs[0]].data, + runtime->threadpool); + break; + case xnn_operator_type_hardswish_nc_f32: + assert(runtime->blobs[op->inputs[0]].data != NULL); + assert(runtime->blobs[op->outputs[0]].data != NULL); + status = xnn_setup_hardswish_nc_f32( + op->op, + op->batch_size, + runtime->blobs[op->inputs[0]].data, + runtime->blobs[op->outputs[0]].data, + runtime->threadpool); + break; case xnn_operator_type_multiply_nd_f32: assert(runtime->blobs[op->inputs[0]].data != NULL); assert(runtime->blobs[op->inputs[1]].data != NULL); @@ -297,6 +407,36 @@ enum xnn_status xnn_setup_runtime( runtime->blobs[op->outputs[0]].data, runtime->threadpool); break; + case xnn_operator_type_prelu_nc_f32: + assert(runtime->blobs[op->inputs[0]].data != NULL); + assert(runtime->blobs[op->outputs[0]].data != NULL); + status = xnn_setup_prelu_nc_f32( + op->op, + op->batch_size, + runtime->blobs[op->inputs[0]].data, + runtime->blobs[op->outputs[0]].data, + runtime->threadpool); + break; + case xnn_operator_type_sigmoid_nc_f32: + assert(runtime->blobs[op->inputs[0]].data != NULL); + assert(runtime->blobs[op->outputs[0]].data != NULL); + status = xnn_setup_sigmoid_nc_f32( + op->op, + op->batch_size, + runtime->blobs[op->inputs[0]].data, + runtime->blobs[op->outputs[0]].data, + runtime->threadpool); + break; + case xnn_operator_type_softmax_nc_f32: + assert(runtime->blobs[op->inputs[0]].data != NULL); + assert(runtime->blobs[op->outputs[0]].data != NULL); + status = xnn_setup_softmax_nc_f32( + op->op, + op->batch_size, + runtime->blobs[op->inputs[0]].data, + runtime->blobs[op->outputs[0]].data, + runtime->threadpool); + break; default: xnn_log_fatal("unexpected operator type %d in operator #%zu", op->op->type, i); XNN_UNREACHABLE; diff --git a/src/subgraph.c b/src/subgraph.c index 714195f65..d90d62589 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -554,6 +554,219 @@ enum xnn_status xnn_define_multiply2( return xnn_status_success; } +enum xnn_status xnn_define_prelu( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t slope_id, + uint32_t output_id, + uint32_t flags) +{ + if (!xnn_params.initialized) { + xnn_log_error("failed to define PReLU operator: XNNPACK is not initialized"); + return xnn_status_uninitialized; + } + + if (input_id >= subgraph->num_values) { + xnn_log_error( + "failed to define PReLU operator with input ID #%" PRIu32 ": invalid Value ID", + input_id); + return xnn_status_invalid_parameter; + } + + if (slope_id >= subgraph->num_values) { + xnn_log_error( + "failed to define PReLU operator with slope ID #%" PRIu32 ": invalid Value ID", + slope_id); + return xnn_status_invalid_parameter; + } + + if (output_id >= subgraph->num_values) { + xnn_log_error( + "failed to define PReLU operator with output ID #%" PRIu32 ": invalid Value ID", + output_id); + return xnn_status_invalid_parameter; + } + + struct xnn_node* node = xnn_subgraph_new_node(subgraph); + if (node == NULL) { + return xnn_status_out_of_memory; + } + + node->type = xnn_node_type_prelu; + node->num_inputs = 2; + node->inputs.raw[0] = input_id; + node->inputs.raw[1] = slope_id; + node->num_outputs = 1; + node->outputs.raw[0] = output_id; + node->flags = flags; + + return xnn_status_success; +} + +enum xnn_status xnn_define_clamp( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input_id, + uint32_t output_id, + uint32_t flags) +{ + if (!xnn_params.initialized) { + xnn_log_error("failed to define Clamp operator: XNNPACK is not 
initialized"); + return xnn_status_uninitialized; + } + + if (input_id >= subgraph->num_values) { + xnn_log_error( + "failed to define Clamp operator with input ID #%" PRIu32 ": invalid Value ID", + input_id); + return xnn_status_invalid_parameter; + } + + if (output_id >= subgraph->num_values) { + xnn_log_error( + "failed to define Clamp operator with output ID #%" PRIu32 ": invalid Value ID", + output_id); + return xnn_status_invalid_parameter; + } + + struct xnn_node* node = xnn_subgraph_new_node(subgraph); + if (node == NULL) { + return xnn_status_out_of_memory; + } + + node->type = xnn_node_type_clamp; + node->activation.output_min = output_min; + node->activation.output_max = output_max; + node->num_inputs = 1; + node->inputs.raw[0] = input_id; + node->num_outputs = 1; + node->outputs.raw[0] = output_id; + node->flags = flags; + + return xnn_status_success; +} + +enum xnn_status xnn_define_hardswish( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags) +{ + if (!xnn_params.initialized) { + xnn_log_error("failed to define HardSwish operator: XNNPACK is not initialized"); + return xnn_status_uninitialized; + } + + if (input_id >= subgraph->num_values) { + xnn_log_error( + "failed to define HardSwish operator with input ID #%" PRIu32 ": invalid Value ID", + input_id); + return xnn_status_invalid_parameter; + } + + if (output_id >= subgraph->num_values) { + xnn_log_error( + "failed to define HardSwish operator with output ID #%" PRIu32 ": invalid Value ID", + output_id); + return xnn_status_invalid_parameter; + } + + struct xnn_node* node = xnn_subgraph_new_node(subgraph); + if (node == NULL) { + return xnn_status_out_of_memory; + } + + node->type = xnn_node_type_hardswish; + node->num_inputs = 1; + node->inputs.raw[0] = input_id; + node->num_outputs = 1; + node->outputs.raw[0] = output_id; + node->flags = flags; + + return xnn_status_success; +} + +enum xnn_status xnn_define_sigmoid( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags) +{ + if (!xnn_params.initialized) { + xnn_log_error("failed to define Sigmoid operator: XNNPACK is not initialized"); + return xnn_status_uninitialized; + } + + if (input_id >= subgraph->num_values) { + xnn_log_error( + "failed to define Sigmoid operator with input ID #%" PRIu32 ": invalid Value ID", + input_id); + return xnn_status_invalid_parameter; + } + + if (output_id >= subgraph->num_values) { + xnn_log_error( + "failed to define Sigmoid operator with output ID #%" PRIu32 ": invalid Value ID", + output_id); + return xnn_status_invalid_parameter; + } + + struct xnn_node* node = xnn_subgraph_new_node(subgraph); + if (node == NULL) { + return xnn_status_out_of_memory; + } + + node->type = xnn_node_type_sigmoid; + node->num_inputs = 1; + node->inputs.raw[0] = input_id; + node->num_outputs = 1; + node->outputs.raw[0] = output_id; + node->flags = flags; + + return xnn_status_success; +} + +enum xnn_status xnn_define_softmax( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags) +{ + if (!xnn_params.initialized) { + xnn_log_error("failed to define SoftMax operator: XNNPACK is not initialized"); + return xnn_status_uninitialized; + } + + if (input_id >= subgraph->num_values) { + xnn_log_error( + "failed to define SoftMax operator with input ID #%" PRIu32 ": invalid Value ID", + input_id); + return xnn_status_invalid_parameter; + } + + if (output_id >= subgraph->num_values) { + xnn_log_error( + "failed to define SoftMax operator with output ID #%" PRIu32 
": invalid Value ID", + output_id); + return xnn_status_invalid_parameter; + } + + struct xnn_node* node = xnn_subgraph_new_node(subgraph); + if (node == NULL) { + return xnn_status_out_of_memory; + } + + node->type = xnn_node_type_softmax; + node->num_inputs = 1; + node->inputs.raw[0] = input_id; + node->num_outputs = 1; + node->outputs.raw[0] = output_id; + node->flags = flags; + + return xnn_status_success; +} + enum xnn_status xnn_delete_subgraph( xnn_subgraph_t subgraph) { diff --git a/src/xnnpack/igemm.h b/src/xnnpack/igemm.h index 1de690495..d2bf06b53 100644 --- a/src/xnnpack/igemm.h +++ b/src/xnnpack/igemm.h @@ -88,6 +88,10 @@ DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_co DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53) DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53) +DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64) +DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75) +DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75) + DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8__sse_load1) DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__sse_load1) diff --git a/src/xnnpack/subgraph.h b/src/xnnpack/subgraph.h index de1a9d3f6..7f9342efe 100644 --- a/src/xnnpack/subgraph.h +++ b/src/xnnpack/subgraph.h @@ -58,9 +58,14 @@ struct xnn_blob { enum xnn_node_type { xnn_node_type_invalid = 0, xnn_node_type_add2, + xnn_node_type_clamp, xnn_node_type_convolution_2d, xnn_node_type_depthwise_convolution_2d, + xnn_node_type_hardswish, xnn_node_type_multiply2, + xnn_node_type_prelu, + xnn_node_type_sigmoid, + xnn_node_type_softmax, }; struct xnn_node { diff --git a/test/f32-igemm.cc b/test/f32-igemm.cc index 597b8fcfa..54eee542e 100644 --- a/test/f32-igemm.cc +++ b/test/f32-igemm.cc @@ -3016,6 +3016,1472 @@ #endif // XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY +#if XNN_ARCH_ARM + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_eq_2) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(2) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(2) + .cn_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_eq_2_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(2) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_eq_2_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(8) + .k(2) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_eq_2_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(2) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_lt_2) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 2; k++) { + 
GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_lt_2_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 2; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_gt_2) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 3; k < 4; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_gt_2_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 3; k < 4; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_div_2) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 4; k <= 20; k += 2) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, k_div_2_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 4; k <= 20; k += 2) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 10; k += 3) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 10; k += 3) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 10; k += 3) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 10; k += 3) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 10; k += 3) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_div_8_subtile) { + 
TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 10; k += 3) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 10; k += 3) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 10; k += 3) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .ks(3) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_gt_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 10; k += 3) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .ks(3) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, n_div_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 10; k += 3) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .ks(3) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 10; k += 3) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, a_offset) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 10; k += 3) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .a_offset(43) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, zero) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t mz = 0; mz < 4; mz++) { + for (size_t k = 1; k <= 10; k += 3) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .a_offset(43) + .zero_index(mz) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(2) + .qmin(128) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(2) + .qmax(128) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_LD64, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(2) + .cm_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64); + } +#endif // XNN_ARCH_ARM + + +#if XNN_ARCH_ARM + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_4) { + TEST_REQUIRES_ARM_NEON; 
+ GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(4) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(4) + .cn_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_4_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_4_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(8) + .k(4) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_4_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(8) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_lt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 8; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_div_4) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 12; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + + 
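+ // The Cortex-A75 kernel is software-pipelined with a k-block of 4 (see the + // test/f32-igemm.yaml hunk below), so in addition to k_eq_4 the cases above + // probe one full pipelined pass at k_eq_8 and the k_lt_8/k_gt_8 boundaries + // around it, while k_div_4 sweeps larger multiples of the native k-block.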
TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 12; k <= 40; k += 4) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .ks(3) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_gt_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .ks(3) + 
.Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, n_div_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .ks(3) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, a_offset) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .a_offset(83) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, zero) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t mz = 0; mz < 4; mz++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .a_offset(83) + .zero_index(mz) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(4) + .qmin(128) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(4) + .qmax(128) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_CORTEX_A75, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(4) + .cm_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75); + } +#endif // XNN_ARCH_ARM + + +#if XNN_ARCH_ARM + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_4) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(4) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(4) + .cn_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_4_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_4_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(8) + .k(4) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_4_subtile_n) { + 
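+ // The PLD variant exercised from here on is the same Cortex-A75 tile with + // added PLD (data-preload) prefetch hints; its test coverage mirrors the + // non-PLD kernel above.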
TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(8) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_lt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 8; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_div_4) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 12; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 12; k <= 40; k += 4) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } 
+ } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .ks(3) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_gt_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .ks(3) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, n_div_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .ks(3) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, a_offset) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + 
.n(8) + .k(k) + .ks(3) + .a_offset(83) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, zero) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t mz = 0; mz < 4; mz++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .a_offset(83) + .zero_index(mz) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + } + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(4) + .qmin(128) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(4) + .qmax(128) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } + + TEST(F32_IGEMM_4X8__AARCH32_NEON_PLD_CORTEX_A75, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(1) + .sr(1) + .m(4) + .n(8) + .k(4) + .cm_stride(11) + .Test(xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75); + } +#endif // XNN_ARCH_ARM + + #if XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY TEST(F32_IGEMM_5X8__AARCH64_NEONFMA_CORTEX_A57, k_eq_8) { TEST_REQUIRES_ARM_NEON_FMA; diff --git a/test/f32-igemm.yaml b/test/f32-igemm.yaml index 5563dee19..a8d5f1fd0 100644 --- a/test/f32-igemm.yaml +++ b/test/f32-igemm.yaml @@ -26,6 +26,15 @@ k-block: 8 pipelined: true assembly: true +- name: xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64 + k-block: 2 + pipelined: false +- name: xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75 + k-block: 4 + pipelined: true +- name: xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75 + k-block: 4 + pipelined: true - name: xnn_f32_igemm_ukernel_5x8__aarch64_neonfma_cortex_a57 k-block: 8 pipelined: true diff --git a/third_party/cpuinfo.BUILD b/third_party/cpuinfo.BUILD index ad8a07000..58ab81717 100644 --- a/third_party/cpuinfo.BUILD +++ b/third_party/cpuinfo.BUILD @@ -107,6 +107,17 @@ cc_library( ":android_arm64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS + ANDROID_ARM_SRCS, ":android_x86": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS, ":android_x86_64": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS, + ":ios_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":ios_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":ios_armv7": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":ios_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":ios_arm64e": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":watchos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":watchos_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":watchos_armv7k": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":watchos_arm64_32": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":tvos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":tvos_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, ":emscripten_wasm": COMMON_SRCS + EMSCRIPTEN_SRCS, }), copts = C99OPTS + [ @@ -201,6 +212,94 @@ config_setting( ) config_setting( + name = "ios_armv7", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "ios_armv7", + }, +) + +config_setting( + name = "ios_arm64", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "ios_arm64", + }, +) + +config_setting( + name = "ios_arm64e", + values 
= { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "ios_arm64e", + }, +) + +config_setting( + name = "ios_x86", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "ios_i386", + }, +) + +config_setting( + name = "ios_x86_64", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "ios_x86_64", + }, +) + +config_setting( + name = "watchos_armv7k", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "watchos_armv7k", + }, +) + +config_setting( + name = "watchos_arm64_32", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "watchos_arm64_32", + }, +) + +config_setting( + name = "watchos_x86", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "watchos_i386", + }, +) + +config_setting( + name = "watchos_x86_64", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "watchos_x86_64", + }, +) + +config_setting( + name = "tvos_arm64", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "tvos_arm64", + }, +) + +config_setting( + name = "tvos_x86_64", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "tvos_x86_64", + }, +) + +config_setting( + name = "emscripten_wasm", + values = { + "cpu": "wasm",
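For context on the subgraph changes above: each new xnn_define_* entry point only validates its Value IDs and records a node (type, inputs, outputs, flags) on the subgraph; at execution time the runtime dispatches each node's operator to the matching xnn_setup_*_nc_f32 call, as in the src/runtime.c hunk. Below is a minimal sketch of chaining two of the new operators. The helper name define_relu6_hardswish and the pre-created Value IDs (input_id, tmp_id, output_id) are illustrative assumptions, as is the availability of the subgraph API via <xnnpack.h>; the xnn_define_clamp and xnn_define_hardswish signatures are taken directly from the src/subgraph.c hunk.

#include <stdint.h>
#include <xnnpack.h>

// Sketch: record a Clamp (here a ReLU6) followed by HardSwish on an existing
// subgraph. subgraph, input_id, tmp_id, and output_id are assumed to be valid
// Value IDs created elsewhere (e.g. with xnn_define_tensor_value).
static enum xnn_status define_relu6_hardswish(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t tmp_id,
  uint32_t output_id)
{
  // Clamp input into [0.0, 6.0]; per the hunk the signature is
  // (subgraph, output_min, output_max, input_id, output_id, flags).
  enum xnn_status status =
    xnn_define_clamp(subgraph, 0.0f, 6.0f, input_id, tmp_id, /*flags=*/0);
  if (status != xnn_status_success) {
    return status;
  }
  // HardSwish, like Sigmoid and SoftMax, takes only
  // (subgraph, input_id, output_id, flags).
  return xnn_define_hardswish(subgraph, tmp_id, output_id, /*flags=*/0);
}

The error handling mirrors the library's own style: every definition returns an enum xnn_status, and xnn_status_success should be checked before the next node is recorded.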