author | Ian Hua <ianhua@google.com> | 2021-08-16 17:20:47 +0000
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2021-08-16 17:20:47 +0000
commit | 0a3dd72c00d27c1b37f35a638a97e1cfa506dd61 (patch)
tree | 420dd314b49b6a7f5772ae3cb6e5f738520572b8
parent | 033350c4252004f65e85e8d547b473ae28ebd158 (diff)
parent | d4ddc68d70a2a5d9006ec93f1927819baea347be (diff)
download | ruy-0a3dd72c00d27c1b37f35a638a97e1cfa506dd61.tar.gz
Merge remote-tracking branch 'aosp/upstream-master' to 'aosp/master' for external/ruy. am: b635c099fe am: 6cb29cf3e2 am: f058a5f8c9 am: d4ddc68d70
Original change: https://android-review.googlesource.com/c/platform/external/ruy/+/1793687
Change-Id: If3277202dc5c34bef0f9fa1da383eb7621ef837f
-rw-r--r-- | Android.bp | 1
-rwxr-xr-x | cmake/bazel_to_cmake.py | 297
-rw-r--r-- | ruy/BUILD | 26
-rw-r--r-- | ruy/CMakeLists.txt | 34
-rw-r--r-- | ruy/apply_multiplier.cc | 1
-rw-r--r-- | ruy/apply_multiplier_test.cc | 5
-rw-r--r-- | ruy/block_map.cc | 3
-rw-r--r-- | ruy/context.cc | 5
-rw-r--r-- | ruy/context.h | 3
-rw-r--r-- | ruy/cpuinfo.cc | 12
-rw-r--r-- | ruy/cpuinfo.h | 1
-rw-r--r-- | ruy/denormal.cc | 121
-rw-r--r-- | ruy/denormal.h | 53
-rw-r--r-- | ruy/kernel_arm.h | 14
-rw-r--r-- | ruy/kernel_arm32.cc | 8
-rw-r--r-- | ruy/kernel_arm64.cc | 1809
-rw-r--r-- | ruy/kernel_avx512.cc | 428
-rw-r--r-- | ruy/kernel_common.h | 7
-rw-r--r-- | ruy/kernel_x86.h | 16
-rw-r--r-- | ruy/mat.h | 8
-rw-r--r-- | ruy/mul_params.h | 62
-rw-r--r-- | ruy/mul_params_test.cc | 2
-rw-r--r-- | ruy/prepacked_cache.cc | 10
-rw-r--r-- | ruy/prepacked_cache.h | 8
-rw-r--r-- | ruy/ruy.h | 8
-rw-r--r-- | ruy/test.h | 3
-rw-r--r-- | ruy/test_overflow_dst_zero_point.cc | 133
-rw-r--r-- | ruy/thread_pool.cc | 4
-rw-r--r-- | ruy/trmul.cc | 7
-rw-r--r-- | ruy/trmul_params.h | 4
-rw-r--r-- | ruy/tune.cc | 8
-rw-r--r-- | ruy/tune.h | 8
32 files changed, 2504 insertions, 605 deletions
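This merge brings in, among other things, a new denormal-suppression utility (ruy/denormal.h, ruy/denormal.cc), a Context::get_runtime_enabled_paths() accessor, Cortex-X1 detection plus X1-tuned NEON kernels, and saturating destination-zero-point adds in the 8-bit kernels, exercised by the new ruy/test_overflow_dst_zero_point.cc. Below is a minimal sketch of the RAII pattern that denormal.{h,cc} implements, restricted to x86/SSE for brevity; the class name and the main() driver are illustrative only, and the real class in this diff additionally handles ARM FPSCR/FPCR and the MSVC intrinsics.

// Minimal x86/SSE-only sketch of the RAII denormal-suppression pattern
// added by ruy/denormal.{h,cc}. Assumes an SSE-capable target.
#include <cstdint>
#include <cstdio>
#include <xmmintrin.h>

class ScopedSuppressDenormalsSketch {
 public:
  ScopedSuppressDenormalsSketch() : saved_mxcsr_(_mm_getcsr()) {
    // 0x8000 = FTZ (flush-to-zero), 0x0040 = DAZ (denormals-are-zero);
    // together they form the 0x8040 mask used in disable_fpu_denormals().
    _mm_setcsr(saved_mxcsr_ | 0x8040);
  }
  // Destructor restores the original control/status register, as the
  // real ScopedSuppressDenormals does via set_fpu_state(restore_).
  ~ScopedSuppressDenormalsSketch() { _mm_setcsr(saved_mxcsr_); }

  ScopedSuppressDenormalsSketch(const ScopedSuppressDenormalsSketch&) = delete;
  void operator=(const ScopedSuppressDenormalsSketch&) = delete;

 private:
  std::uint32_t saved_mxcsr_;
};

int main() {
  volatile float tiny = 1e-39f;  // subnormal for float32
  {
    ScopedSuppressDenormalsSketch suppress;
    std::printf("with FTZ/DAZ: %g\n", tiny * 1.0f);  // prints 0
  }
  std::printf("restored:     %g\n", tiny * 1.0f);    // prints ~1e-39
  return 0;
}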
@@ -50,6 +50,7 @@ cc_defaults { "ruy/context_get_ctx.cc", "ruy/cpuinfo.cc", "ruy/ctx.cc", + "ruy/denormal.cc", "ruy/frontend.cc", "ruy/have_built_path_for_avx.cc", "ruy/have_built_path_for_avx2_fma.cc", diff --git a/cmake/bazel_to_cmake.py b/cmake/bazel_to_cmake.py index ba1a38b..8f972ba 100755 --- a/cmake/bazel_to_cmake.py +++ b/cmake/bazel_to_cmake.py @@ -49,88 +49,92 @@ replacements = [ ['selects.config_setting_group', 'config_setting_group'], ['@com_google_googletest//:gtest', 'gtest'], ['@com_google_googletest//:gtest_main', 'gtest_main'], - ['@cpuinfo//:cpuinfo_with_unstripped_include_path', 'cpuinfo'], + ['@cpuinfo', 'cpuinfo'], ] def preprocess_input_text(text): - result = text - for replacement in replacements: - result = result.replace(replacement[0], replacement[1]) - return result + result = text + for replacement in replacements: + result = result.replace(replacement[0], replacement[1]) + return result def set_cmake_list(list_name, values, indent): - semicolon_separated = ";".join(values) - print(f'{indent}set({list_name} "{semicolon_separated}")') + semicolon_separated = ';'.join(values) + print(f'{indent}set({list_name} "{semicolon_separated}")') def generate_cmake_select(select_name, dict): - new_if_branch_keyword = 'if' - default_value = [] - for key in dict: - condition = '' - if key == '//conditions:default': - default_value = dict[key] - continue - elif re.search(r':windows$', key): - condition = 'CMAKE_SYSTEM_NAME STREQUAL Windows' - elif re.search(r':ppc$', key): - condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le' - elif re.search(r':s390x$', key): - condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR CMAKE_SYSTEM_PROCESSOR STREQUAL s390x' - elif re.search(r':fuchsia$', key): - condition = 'CMAKE_SYSTEM_NAME STREQUAL Fuchsia' - elif re.search(r':arm32_assuming_neon$', key): - condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL arm' - elif re.search(r':do_not_want_O3$', key): - # Ruy is a specialist library: we always want code to be compiled - # with -O3 unless the build type is Debug or the compiler does not - # support that flag syntax. - condition = '(CMAKE_BUILD_TYPE STREQUAL Debug) OR MSVC' - elif re.search(r':x86_64_and_not_msvc$', key): - condition = '(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC' - elif re.search(r':windows_msvc$', key): - condition = 'MSVC' - elif re.search(r':ruy_profiler$', key): - condition = '${RUY_PROFILER}' - else: - raise ValueError(f'Unhandled key in select: {key}') + new_if_branch_keyword = 'if' + default_value = [] + for key in dict: + condition = '' + if key == '//conditions:default': + default_value = dict[key] + continue + elif re.search(r':windows$', key): + condition = 'CMAKE_SYSTEM_NAME STREQUAL Windows' + elif re.search(r':ppc$', key): + condition = ('CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR ' + 'CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le') + elif re.search(r':s390x$', key): + condition = ('CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR ' + 'CMAKE_SYSTEM_PROCESSOR STREQUAL s390x') + elif re.search(r':fuchsia$', key): + condition = 'CMAKE_SYSTEM_NAME STREQUAL Fuchsia' + elif re.search(r':arm32_assuming_neon$', key): + condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL arm' + elif re.search(r':do_not_want_O3$', key): + # Ruy is a specialist library: we always want code to be compiled + # with -O3 unless the build type is Debug or the compiler does not + # support that flag syntax. 
+ condition = '(CMAKE_BUILD_TYPE STREQUAL Debug) OR MSVC' + elif re.search(r':x86_64_and_not_msvc$', key): + condition = ('(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR ' + 'CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC') + elif re.search(r':windows_msvc$', key): + condition = 'MSVC' + elif re.search(r':ruy_profiler$', key): + condition = '${RUY_PROFILER}' + else: + raise ValueError(f'Unhandled key in select: {key}') - print(f'{new_if_branch_keyword}({condition})') - set_cmake_list(select_name, dict[key], ' ') - new_if_branch_keyword = 'elseif' + print(f'{new_if_branch_keyword}({condition})') + set_cmake_list(select_name, dict[key], ' ') + new_if_branch_keyword = 'elseif' - print('else()') - set_cmake_list(select_name, default_value, ' ') + print('else()') + set_cmake_list(select_name, default_value, ' ') - print('endif()\n') + print('endif()\n') def trim_multiple_ruy_prefixes(name): - return re.sub(r'(ruy_)+ruy', 'ruy', name) + return re.sub(r'(ruy_)+ruy', 'ruy', name) + def get_cmake_local_target_name(name): - global package_prefix - return trim_multiple_ruy_prefixes(f'ruy_{package_prefix}_{name}') + global package_prefix + return trim_multiple_ruy_prefixes(f'ruy_{package_prefix}_{name}') def get_cmake_dep_target_name(name): - if name in external_targets: - return name - if name.startswith('$'): - # Happens for deps that are the result of expanding a select() that we - # have compiled to expanding a variable. - return name - if name.startswith('//'): - after_last_slash = name.split('/')[-1] - if not ':' in after_last_slash: - name = f'{name}:{after_last_slash}' - raw=name[2:].replace('/', '_').replace(':', '_') - return trim_multiple_ruy_prefixes(raw) - if name.startswith(':'): - name = name[1:] - return get_cmake_local_target_name(name) + if name in external_targets: + return name + if name.startswith('$'): + # Happens for deps that are the result of expanding a select() that we + # have compiled to expanding a variable. + return name + if name.startswith('//'): + after_last_slash = name.split('/')[-1] + if ':' not in after_last_slash: + name = f'{name}:{after_last_slash}' + raw = name[2:].replace('/', '_').replace(':', '_') + return trim_multiple_ruy_prefixes(raw) + if name.startswith(':'): + name = name[1:] + return get_cmake_local_target_name(name) # @@ -139,45 +143,45 @@ def get_cmake_dep_target_name(name): def package(**kwargs): - pass + pass def exports_files(*args): - pass + pass def load(filename, *args): - if filename.startswith('@'): - return - elif filename.startswith(':'): - filename = os.path.join(bazel_package_dir, filename[1:]) - elif filename.startswith('//'): - split = filename[2:].split(':') - filename = os.path.join(bazel_workspace_dir, split[0], split[1]) + if filename.startswith('@'): + return + elif filename.startswith(':'): + filename = os.path.join(bazel_package_dir, filename[1:]) + elif filename.startswith('//'): + split = filename[2:].split(':') + filename = os.path.join(bazel_workspace_dir, split[0], split[1]) - src_file_content = open(filename).read() - processed_file_content = preprocess_input_text(src_file_content) - exec(processed_file_content, globals(), globals()) + src_file_content = open(filename).read() + processed_file_content = preprocess_input_text(src_file_content) + exec(processed_file_content, globals(), globals()) def config_setting(**kwargs): - # Nothing to do since our implementation of select() is based on parsing - # the names of config_settings, not looking deep into their actual - # implementation. 
- pass + # Nothing to do since our implementation of select() is based on parsing + # the names of config_settings, not looking deep into their actual + # implementation. + pass def filegroup(**kwargs): - pass + pass def config_setting_group(**kwargs): - # See config_setting. - pass + # See config_setting. + pass def bzl_library(**kwargs): - pass + pass select_index = 0 @@ -185,95 +189,96 @@ select_cache = {} def select(select_dict): - global select_index - global select_cache - global package_prefix - key = pickle.dumps(sorted(select_dict.items())) - if key in select_cache: - select_name = select_cache[key] - else: - unique_values = sorted(set(itertools.chain.from_iterable(select_dict.values()))) # sorting ensures determinism, no spurious diffs - description = '_'.join(unique_values) - select_name = f'{package_prefix}_{select_index}_{description}' - select_name = select_name.replace('c++', 'cxx') - select_name = re.sub(r'[^a-zA-Z0-9]+', '_', select_name) - select_index = select_index + 1 - select_cache[key] = select_name - generate_cmake_select(select_name, select_dict) - - return [f'${{{select_name}}}'] + global select_index + global select_cache + global package_prefix + key = pickle.dumps(sorted(select_dict.items())) + if key in select_cache: + select_name = select_cache[key] + else: + unique_values = sorted( + set(itertools.chain.from_iterable(select_dict.values())) + ) # sorting ensures determinism, no spurious diffs + description = '_'.join(unique_values) + select_name = f'{package_prefix}_{select_index}_{description}' + select_name = select_name.replace('c++', 'cxx') + select_name = re.sub(r'[^a-zA-Z0-9]+', '_', select_name) + select_index = select_index + 1 + select_cache[key] = select_name + generate_cmake_select(select_name, select_dict) + + return [f'${{{select_name}}}'] def generic_rule(rule_name, **kwargs): - print(f'{rule_name}(') - for key in kwargs.keys(): - values = kwargs[key] - if type(values) is bool: - if values: - print(f' {key.upper()}') - continue - else: - raise ValueError( - 'Cannot specify FALSE boolean args in CMake') - if key == 'visibility': - if values == ['//visibility:public']: - print(f' PUBLIC') - continue - if key == 'tags': - values = list(filter(lambda x : not x.startswith('req_dep'), values)) - if not values: - continue + print(f'{rule_name}(') + for key in kwargs.keys(): + values = kwargs[key] + if type(values) is bool: + if values: print(f' {key.upper()}') - if type(values) is list: - for value in values: - if key == 'deps': - target_name = get_cmake_dep_target_name(value) - print(f' {target_name}') - else: - print(f' {value}') + continue + else: + raise ValueError('Cannot specify FALSE boolean args in CMake') + if key == 'visibility': + if values == ['//visibility:public']: + print(f' PUBLIC') + continue + if key == 'tags': + values = list(filter(lambda x: not x.startswith('req_dep'), values)) + if not values: + continue + print(f' {key.upper()}') + if type(values) is list: + for value in values: + if key == 'deps': + target_name = get_cmake_dep_target_name(value) + print(f' {target_name}') else: - if key == 'name': - target_name = get_cmake_local_target_name(values) - print(f' {target_name}') - else: - print(f' {values}') - print(')\n') + print(f' {value}') + else: + if key == 'name': + target_name = get_cmake_local_target_name(values) + print(f' {target_name}') + else: + print(f' {values}') + print(')\n') def cc_library(**kwargs): - generic_rule('ruy_cc_library', **kwargs) + generic_rule('ruy_cc_library', **kwargs) def cc_test(**kwargs): - 
generic_rule('ruy_cc_test', **kwargs) + generic_rule('ruy_cc_test', **kwargs) def cc_binary(**kwargs): - generic_rule('ruy_cc_binary', **kwargs) + generic_rule('ruy_cc_binary', **kwargs) # # Program entry point. # if __name__ == "__main__": - if len(sys.argv) != 3: - print("Usage: bazel_to_cmake.py bazel_workspace_dir bazel_package_dir") - sys.exit(1) + if len(sys.argv) != 3: + print('Usage: bazel_to_cmake.py bazel_workspace_dir bazel_package_dir') + sys.exit(1) - bazel_workspace_dir = sys.argv[1] - bazel_package_dir = sys.argv[2] - bazel_package_relative_dir = os.path.relpath( - bazel_package_dir, bazel_workspace_dir) - package_prefix = bazel_package_relative_dir.replace(os.path.sep, '_') + bazel_workspace_dir = sys.argv[1] + bazel_package_dir = sys.argv[2] + bazel_package_relative_dir = os.path.relpath(bazel_package_dir, + bazel_workspace_dir) + package_prefix = bazel_package_relative_dir.replace(os.path.sep, '_') - print("""# This file is generated (whence no license header). Do not edit! + print("""# This file is generated (whence no license header). Do not edit! # To regenerate, run: # cmake/bazel_to_cmake.sh """) - src_build_file = os.path.join(bazel_package_dir, "BUILD") - src_build_content = open(src_build_file).read() - processed_build_content = preprocess_input_text(src_build_content) - exec(processed_build_content) + src_build_file = os.path.join(bazel_package_dir, 'BUILD') + src_build_content = open(src_build_file).read() + processed_build_content = preprocess_input_text(src_build_content) + exec(processed_build_content) - print("ruy_add_all_subdirs()") + print('ruy_add_all_subdirs()') @@ -357,6 +357,7 @@ cc_library( deps = [ ":blocking_counter", ":check_macros", + ":denormal", ":time", ":trace", ":wait", @@ -420,6 +421,14 @@ cc_library( ) cc_library( + name = "denormal", + srcs = ["denormal.cc"], + hdrs = ["denormal.h"], + copts = ruy_copts(), + visibility = ["//visibility:public"], +) + +cc_library( name = "performance_advisory", hdrs = ["performance_advisory.h"], copts = ruy_copts(), @@ -956,6 +965,7 @@ cc_library( ":cpu_cache_params", ":cpuinfo", ":ctx", + ":denormal", ":mat", ":matrix", ":mul_params", @@ -1195,6 +1205,22 @@ ruy_test( ], ) +cc_test( + name = "test_overflow_dst_zero_point", + srcs = [ + "test_overflow_dst_zero_point.cc", + ], + copts = ruy_copts(), + deps = [ + ":gtest_wrapper", + ":matrix", + ":path", + ":ruy", + ":test_lib", + ":tune", + ], +) + bzl_library( name = "ruy_test_ext.oss_bzl", srcs = ["ruy_test_ext.oss.bzl"], diff --git a/ruy/CMakeLists.txt b/ruy/CMakeLists.txt index 4c3e394..502ad8a 100644 --- a/ruy/CMakeLists.txt +++ b/ruy/CMakeLists.txt @@ -376,6 +376,7 @@ ruy_cc_library( DEPS ruy_blocking_counter ruy_check_macros + ruy_denormal ruy_time ruy_trace ruy_wait @@ -455,6 +456,20 @@ ruy_cc_library( ruy_cc_library( NAME + ruy_denormal + SRCS + denormal.cc + HDRS + denormal.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + PUBLIC +) + +ruy_cc_library( + NAME ruy_performance_advisory HDRS performance_advisory.h @@ -1102,6 +1117,7 @@ ruy_cc_library( ruy_cpu_cache_params ruy_cpuinfo ruy_ctx + ruy_denormal ruy_mat ruy_matrix ruy_mul_params @@ -1693,4 +1709,22 @@ ruy_cc_test( slow ) +ruy_cc_test( + NAME + ruy_test_overflow_dst_zero_point + SRCS + test_overflow_dst_zero_point.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_gtest_wrapper + ruy_matrix + ruy + ruy_path + ruy_test_lib + ruy_tune +) + ruy_add_all_subdirs() diff --git a/ruy/apply_multiplier.cc 
b/ruy/apply_multiplier.cc index 19bfd88..b28c3b0 100644 --- a/ruy/apply_multiplier.cc +++ b/ruy/apply_multiplier.cc @@ -49,7 +49,6 @@ std::int32_t MultiplyByQuantizedMultiplier(std::int32_t x, std::int32_t quantized_multiplier, int shift) { RUY_CHECK_GE(shift, -31); - RUY_CHECK_LE(shift, 7); int total_shift = 31 - shift; diff --git a/ruy/apply_multiplier_test.cc b/ruy/apply_multiplier_test.cc index 2df80d7..ff4cb2c 100644 --- a/ruy/apply_multiplier_test.cc +++ b/ruy/apply_multiplier_test.cc @@ -104,14 +104,9 @@ void TestApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params, TEST(ApplyMultiplierTest, ApplyMultiplierUniform) { MulParams<std::int32_t, std::int8_t> mul_params; - // Test that default values give a multiplication by 1. - TestApplyMultiplier(mul_params, 0, 1000, 1000); mul_params.set_multiplier_fixedpoint(1 << 30); mul_params.set_multiplier_exponent(-1); TestApplyMultiplier(mul_params, 0, 1000, 250); - mul_params.set_multiplier_fixedpoint(1 << 25); - mul_params.set_multiplier_exponent(3); - TestApplyMultiplier(mul_params, 0, 1000, 125); } TEST(ApplyMultiplierTest, ApplyMultiplierPerChannel) { diff --git a/ruy/block_map.cc b/ruy/block_map.cc index 8240de2..e04e7af 100644 --- a/ruy/block_map.cc +++ b/ruy/block_map.cc @@ -17,6 +17,7 @@ limitations under the License. #include <algorithm> #include <cstdint> +#include <limits> #ifdef RUY_MAKEBLOCKMAP_DEBUG #include <cstdio> @@ -330,7 +331,7 @@ bool IsObviouslyLinearTraversal(int rows, int cols, int depth, // as that requires knowing the kernel block layout. Since we just want // a coarse estimate with only the guarantee that if we return `true` then // linear traversal will be used, it is OK here to over-estimate `rows` and - // `cols`, by omitting to divide them by the rectangularness factors.ß + // `cols`, by omitting to divide them by the rectangularness factors. return GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, cpu_cache_params) == BlockMapTraversalOrder::kLinear; } diff --git a/ruy/context.cc b/ruy/context.cc index 4661738..342ce52 100644 --- a/ruy/context.cc +++ b/ruy/context.cc @@ -55,4 +55,9 @@ void Context::set_runtime_enabled_paths(Path paths) { mutable_ctx()->SetRuntimeEnabledPaths(paths); } +Path Context::get_runtime_enabled_paths() { + // The `& kAllPaths` hides internal test-only paths. + return mutable_ctx()->GetRuntimeEnabledPaths() & ruy::kAllPaths; +} + } // namespace ruy diff --git a/ruy/context.h b/ruy/context.h index 79a4b5c..f148f0f 100644 --- a/ruy/context.h +++ b/ruy/context.h @@ -90,6 +90,9 @@ class Context final { // Paths in kNonArchPaths are always implicitly supported. void set_runtime_enabled_paths(Path paths); + // Returns the set of Path's that are available. 
+ Path get_runtime_enabled_paths(); + private: CtxImpl* const impl_; diff --git a/ruy/cpuinfo.cc b/ruy/cpuinfo.cc index b1f54bc..a3e75d7 100644 --- a/ruy/cpuinfo.cc +++ b/ruy/cpuinfo.cc @@ -133,6 +133,17 @@ bool CpuInfo::CurrentCpuIsA55ish() { } } +bool CpuInfo::CurrentCpuIsX1() { + if (!EnsureInitialized()) { + return false; + } + if (cpuinfo_get_uarch(cpuinfo_get_current_uarch_index())->uarch == + cpuinfo_uarch_cortex_x1) { + return true; + } + return false; +} + #else // not defined RUY_HAVE_CPUINFO CpuInfo::~CpuInfo() {} @@ -151,6 +162,7 @@ bool CpuInfo::Avx2Fma() { return false; } bool CpuInfo::Avx512() { return false; } bool CpuInfo::AvxVnni() { return false; } bool CpuInfo::CurrentCpuIsA55ish() { return false; } +bool CpuInfo::CurrentCpuIsX1() { return false; } #endif diff --git a/ruy/cpuinfo.h b/ruy/cpuinfo.h index e45fa51..2c7bc6a 100644 --- a/ruy/cpuinfo.h +++ b/ruy/cpuinfo.h @@ -39,6 +39,7 @@ class CpuInfo final { // Common features const CpuCacheParams& CacheParams(); bool CurrentCpuIsA55ish(); + bool CurrentCpuIsX1(); private: enum class InitStatus { diff --git a/ruy/denormal.cc b/ruy/denormal.cc new file mode 100644 index 0000000..35bb739 --- /dev/null +++ b/ruy/denormal.cc @@ -0,0 +1,121 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/denormal.h" + +// NOTE: this is simply a copy of pthreadpool/src/threadpool-utils.h that's not +// exposed by the pthreadpool library +// (https://github.com/Maratyszcza/pthreadpool), but with an additional C++ +// helper class to suppress floating-point denormal values. 
+ +/* SSE-specific headers */ +#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \ + (defined(_M_IX86_FP) && _M_IX86_FP >= 1) +#include <xmmintrin.h> +#endif + +/* MSVC-specific headers */ +#if defined(_MSC_VER) +#include <intrin.h> +#endif + +namespace ruy { +namespace { +inline struct fpu_state get_fpu_state() { + struct fpu_state state = {}; +#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \ + (defined(_M_IX86_FP) && _M_IX86_FP >= 1) + state.mxcsr = static_cast<std::uint32_t>(_mm_getcsr()); +#elif defined(_MSC_VER) && defined(_M_ARM) + state.fpscr = + static_cast<std::uint32_t>(_MoveFromCoprocessor(10, 7, 1, 0, 0)); +#elif defined(_MSC_VER) && defined(_M_ARM64) + state.fpcr = static_cast<std::uint64_t>(_ReadStatusReg(0x5A20)); +#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \ + (__ARM_FP != 0) + __asm__ __volatile__("VMRS %[fpscr], fpscr" : [fpscr] "=r"(state.fpscr)); +#elif defined(__GNUC__) && defined(__aarch64__) + __asm__ __volatile__("MRS %[fpcr], fpcr" : [fpcr] "=r"(state.fpcr)); +#endif + return state; +} + +inline void set_fpu_state(const struct fpu_state state) { +#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \ + (defined(_M_IX86_FP) && _M_IX86_FP >= 1) + _mm_setcsr(static_cast<unsigned int>(state.mxcsr)); +#elif defined(_MSC_VER) && defined(_M_ARM) + _MoveToCoprocessor(static_cast<int>(state.fpscr), 10, 7, 1, 0, 0); +#elif defined(_MSC_VER) && defined(_M_ARM64) + _WriteStatusReg(0x5A20, static_cast<__int64>(state.fpcr)); +#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \ + (__ARM_FP != 0) + __asm__ __volatile__("VMSR fpscr, %[fpscr]" : : [fpscr] "r"(state.fpscr)); +#elif defined(__GNUC__) && defined(__aarch64__) + __asm__ __volatile__("MSR fpcr, %[fpcr]" : : [fpcr] "r"(state.fpcr)); +#else + (void)state; +#endif +} + +inline void disable_fpu_denormals() { +#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \ + (defined(_M_IX86_FP) && _M_IX86_FP >= 1) + _mm_setcsr(_mm_getcsr() | 0x8040); +#elif defined(_MSC_VER) && defined(_M_ARM) + int fpscr = _MoveFromCoprocessor(10, 7, 1, 0, 0); + fpscr |= 0x1000000; + _MoveToCoprocessor(fpscr, 10, 7, 1, 0, 0); +#elif defined(_MSC_VER) && defined(_M_ARM64) + __int64 fpcr = _ReadStatusReg(0x5A20); + fpcr |= 0x1080000; + _WriteStatusReg(0x5A20, fpcr); +#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \ + (__ARM_FP != 0) + std::uint32_t fpscr; +#if defined(__thumb__) && !defined(__thumb2__) + __asm__ __volatile__( + "VMRS %[fpscr], fpscr\n" + "ORRS %[fpscr], %[bitmask]\n" + "VMSR fpscr, %[fpscr]\n" + : [fpscr] "=l"(fpscr) + : [bitmask] "l"(0x1000000) + : "cc"); +#else + __asm__ __volatile__( + "VMRS %[fpscr], fpscr\n" + "ORR %[fpscr], #0x1000000\n" + "VMSR fpscr, %[fpscr]\n" + : [fpscr] "=r"(fpscr)); +#endif +#elif defined(__GNUC__) && defined(__aarch64__) + std::uint64_t fpcr; + __asm__ __volatile__( + "MRS %[fpcr], fpcr\n" + "ORR %w[fpcr], %w[fpcr], 0x1000000\n" + "ORR %w[fpcr], %w[fpcr], 0x80000\n" + "MSR fpcr, %[fpcr]\n" + : [fpcr] "=r"(fpcr)); +#endif +} +} // namespace + +ScopedSuppressDenormals::ScopedSuppressDenormals() { + restore_ = get_fpu_state(); + disable_fpu_denormals(); +} + +ScopedSuppressDenormals::~ScopedSuppressDenormals() { set_fpu_state(restore_); } +} // namespace ruy diff --git a/ruy/denormal.h b/ruy/denormal.h new file mode 100644 index 0000000..e5b836c --- /dev/null +++ b/ruy/denormal.h @@ -0,0 +1,53 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef RUY_RUY_DENORMAL_H_ +#define RUY_RUY_DENORMAL_H_ + +#include <cstdint> + +namespace ruy { +// NOTE: the following 'fpu_state' struct is copied from +// pthreadpool/src/threadpool-utils.h that's not exposed by the pthreadpool +// library (https://github.com/Maratyszcza/pthreadpool). +struct fpu_state { +#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \ + (defined(_M_IX86_FP) && _M_IX86_FP >= 1) + std::uint32_t mxcsr; +#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \ + (__ARM_FP != 0) || \ + defined(_MSC_VER) && defined(_M_ARM) + std::uint32_t fpscr; +#elif defined(__GNUC__) && defined(__aarch64__) || \ + defined(_MSC_VER) && defined(_M_ARM64) + std::uint64_t fpcr; +#endif +}; + +// While this class is active, denormal floating point numbers are suppressed. +// The destructor restores the original flags. +class ScopedSuppressDenormals { + public: + ScopedSuppressDenormals(); + ~ScopedSuppressDenormals(); + + private: + fpu_state restore_; + + ScopedSuppressDenormals(const ScopedSuppressDenormals&) = delete; + void operator=(const ScopedSuppressDenormals&) = delete; +}; +} // namespace ruy + +#endif // RUY_RUY_DENORMAL_H_ diff --git a/ruy/kernel_arm.h b/ruy/kernel_arm.h index 76cfc82..15a5a89 100644 --- a/ruy/kernel_arm.h +++ b/ruy/kernel_arm.h @@ -49,6 +49,7 @@ void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params); void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params); void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params); void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params); +void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params); #if RUY_PLATFORM_NEON_64 template <typename DstScalar> @@ -104,7 +105,8 @@ struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> { #if RUY_PLATFORM_NEON_64 template <typename DstScalar> -struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t, DstScalar> { +struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t, + DstScalar> { static constexpr Path kPath = Path::kNeonDotprod; Tuning tuning = Tuning::kAuto; using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; @@ -121,6 +123,8 @@ struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t, DstSca Kernel8bitNeonDotprod1Col(params); } else if (__builtin_expect(tuning == Tuning::kA55ish, true)) { Kernel8bitNeonDotprodA55ish(params); + } else if (tuning == Tuning::kX1) { + Kernel8bitNeonDotprodX1(params); } else { Kernel8bitNeonDotprod(params); } @@ -129,6 +133,7 @@ struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t, DstSca #endif void KernelFloatNeon(const KernelParamsFloat<8, 8>& params); +void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params); void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params); void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params); void 
KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params); @@ -150,6 +155,8 @@ struct Kernel<Path::kNeon, float, float, float, float> { end_col, dst, ¶ms); if (__builtin_expect(tuning == Tuning::kA55ish, true)) { KernelFloatNeonA55ish(params); + } else if (tuning == Tuning::kX1) { + KernelFloatNeonX1(params); } else { KernelFloatNeon(params); } @@ -188,8 +195,7 @@ struct Kernel<Path::kNeonDotprod, float, float, float, float> { Tuning tuning = Tuning::kAuto; using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; - using Base = - Kernel<Path::kNeon, float, float, float, float>; + using Base = Kernel<Path::kNeon, float, float, float, float>; explicit Kernel(Tuning tuning_) : tuning(tuning_) {} void Run(const PMat<float>& lhs, const PMat<float>& rhs, const MulParams<float, float>& mul_params, int start_row, @@ -199,6 +205,8 @@ struct Kernel<Path::kNeonDotprod, float, float, float, float> { end_col, dst, ¶ms); if (__builtin_expect(tuning == Tuning::kA55ish, true)) { KernelFloatNeonDotprodA55ish(params); + } else if (tuning == Tuning::kX1) { + KernelFloatNeonX1(params); } else { KernelFloatNeon(params); } diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc index b20f668..8782dce 100644 --- a/ruy/kernel_arm32.cc +++ b/ruy/kernel_arm32.cc @@ -1102,7 +1102,7 @@ void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) { "vdup.16 q13, r4\n" // dst_zero_point // Add the destination zero point - "vadd.i16 q14, q14, q13\n" + "vqadd.s16 q14, q14, q13\n" // Cast-and-saturate from int16 to uint8 // Now all 8 1-byte values are in d30. @@ -1226,7 +1226,7 @@ void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) { "vdup.16 q13, r4\n" // dst_zero_point // Add the destination zero point - "vadd.i16 q14, q14, q13\n" + "vqadd.s16 q14, q14, q13\n" // Cast-and-saturate from int16 to int8 // Now all 8 1-byte values are in d30. 
@@ -2014,7 +2014,7 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) { "vdup.16 q13, r4\n" // dst_zero_point // Add the destination zero point - "vadd.i16 q14, q14, q13\n" + "vqadd.s16 q14, q14, q13\n" // Cast-and-saturate from int16 to uint8 "vqmovun.s16 d30, q14\n" @@ -2126,7 +2126,7 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) { "vdup.16 q13, r4\n" // dst_zero_point // Add the destination zero point - "vadd.i16 q14, q14, q13\n" + "vqadd.s16 q14, q14, q13\n" // Cast-and-saturate from int16 to int8 "vqmovn.s16 d30, q14\n" diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc index fe65d9c..5424107 100644 --- a/ruy/kernel_arm64.cc +++ b/ruy/kernel_arm64.cc @@ -623,8 +623,8 @@ void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) { // Add the destination zero point "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" + "sqadd v17.8h, v17.8h, v14.8h\n" // Cast-and-saturate from int16 to uint8 "sqxtun v16.8b, v16.8h\n" @@ -750,8 +750,8 @@ void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) { // Add the destination zero point "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" + "sqadd v17.8h, v17.8h, v14.8h\n" // Cast-and-saturate from int16 to int8 "sqxtn v16.8b, v16.8h\n" @@ -1472,7 +1472,7 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) { // Add the destination zero point "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" // Cast-and-saturate from int16 to uint8 // Now all data is in the first 32-bits of v16 @@ -1553,7 +1553,7 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) { // Add the destination zero point "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" // Cast-and-saturate from int16 to int8 "sqxtn v16.8b, v16.8h\n" @@ -2394,9 +2394,9 @@ void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) { "dup v14.8h, v13.h[4]\n" RUY_MAKE_ZERO(v20) "add %[rhs_ptr], %[rhs_ptr], #64\n" - "add v16.8h, v16.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" RUY_MAKE_ZERO(v21) - "add v17.8h, v17.8h, v14.8h\n" + "sqadd v17.8h, v17.8h, v14.8h\n" RUY_MAKE_ZERO(v22) // Cast-and-saturate from int16 to uint8 @@ -2526,9 +2526,9 @@ void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) { "dup v14.8h, v13.h[4]\n" RUY_MAKE_ZERO(v20) "add %[rhs_ptr], %[rhs_ptr], #64\n" - "add v16.8h, v16.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" RUY_MAKE_ZERO(v21) - "add v17.8h, v17.8h, v14.8h\n" + "sqadd v17.8h, v17.8h, v14.8h\n" RUY_MAKE_ZERO(v22) // Cast-and-saturate from int16 to uint8 @@ -3713,14 +3713,14 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) { // Add the destination zero point "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" + "sqadd v17.8h, v17.8h, v14.8h\n" + "sqadd v18.8h, v18.8h, v14.8h\n" + "sqadd v19.8h, v19.8h, v14.8h\n" + "sqadd v20.8h, v20.8h, v14.8h\n" + "sqadd v21.8h, v21.8h, v14.8h\n" + "sqadd v22.8h, v22.8h, v14.8h\n" + "sqadd v23.8h, v23.8h, v14.8h\n" // Cast-and-saturate from int16 to uint8 "sqxtun v16.8b, v16.8h\n" @@ -3888,14 +3888,14 @@ void Kernel8bitNeonDotprod(const 
KernelParams8bit<8, 8>& params) { // Add the destination zero point "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" + "sqadd v17.8h, v17.8h, v14.8h\n" + "sqadd v18.8h, v18.8h, v14.8h\n" + "sqadd v19.8h, v19.8h, v14.8h\n" + "sqadd v20.8h, v20.8h, v14.8h\n" + "sqadd v21.8h, v21.8h, v14.8h\n" + "sqadd v22.8h, v22.8h, v14.8h\n" + "sqadd v23.8h, v23.8h, v14.8h\n" // Cast-and-saturate from int16 to uint8 "sqxtn v16.8b, v16.8h\n" @@ -4402,6 +4402,1261 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) { "v26", "v27", "v28", "v29", "v30", "v31"); } +// A fork of the above 8bitNeonDotprod kernel but removes the max streaming +// manual unrolling. Manually unrolling the inner loops benefits some GEMM +// shapes on the Cortex-A76 but destroys performance on the X1 by increasing +// backend stalls. Therefore, we remove the MAX_STREAMING option in this +// kernel. The target CPU for this kernel is currently only the Cortex-X1. +void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel (kNeonDotprod)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v15 are used to load int8 data from LHS and + // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and + // v3 are used to load a 4x8 block of RHS, like this: + // + // int8 RHS 4x8 block + // /-----------------------------------------| + // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| + // | ... ... | + // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| + // \-----------------------------------------/ + // int8 LHS 8x4 block + // /---------------------\ /-----------------------------------------| + // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| + // | ... ... | | ... ... | + // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| + // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| + // | ... ... | | ... ... | + // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 8x8 block + // + // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step + // is repeated 4 times, using 4x more registers for LHS and RHS, so that + // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. + // + // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are + // unused, and v8 -- v15 are used for loading parameters used for the + // post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Kernel inner loop (over depth). + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Because of the data that we have already loaded, we can start the + // loop body right away with some multiply-adds. + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + // Each iteration of this loop advances by 4 levels of depth. + "add w1, w1, #4\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + // Loop termination condition. + "cmp w1, w12\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + + "blt 2b\n" + + "79:\n" + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last 4 levels of depth, for which the LHS + // and RHS data is already loaded. 
+ + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "mvni v8.4s, #0\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + // Determine the channel index. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + // Offset the bias pointer as needed given the current row, col. + "add x5, x1, x3, lsl #2\n" + + // If there is no bias, use no offset, just address the passed zero + // data. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. 
+ "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "add v15.4s, v15.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + // Jump based on channel dimension. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v15.4s\n" + "add v18.4s, v18.4s, v14.4s\n" + "add v19.4s, v19.4s, v15.4s\n" + "add v20.4s, v20.4s, v14.4s\n" + "add v21.4s, v21.4s, v15.4s\n" + "add v22.4s, v22.4s, v14.4s\n" + "add v23.4s, v23.4s, v15.4s\n" + "add v24.4s, v24.4s, v14.4s\n" + "add v25.4s, v25.4s, v15.4s\n" + "add v26.4s, v26.4s, v14.4s\n" + "add v27.4s, v27.4s, v15.4s\n" + "add v28.4s, v28.4s, v14.4s\n" + "add v29.4s, v29.4s, v15.4s\n" + "add v30.4s, v30.4s, v14.4s\n" + "add v31.4s, v31.4s, v15.4s\n" + "b 7f\n" + + "6:\n" + // Case where channels are columns + "dup v10.4s, v14.s[0]\n" + "dup v11.4s, v14.s[1]\n" + "dup v12.4s, v14.s[2]\n" + "dup v13.4s, v14.s[3]\n" + "add v16.4s, v16.4s, v10.4s\n" + "add v17.4s, v17.4s, v10.4s\n" + "add v18.4s, v18.4s, v11.4s\n" + "add v19.4s, v19.4s, v11.4s\n" + "add v20.4s, v20.4s, v12.4s\n" + "add v21.4s, v21.4s, v12.4s\n" + "add v22.4s, v22.4s, v13.4s\n" + "add v23.4s, v23.4s, v13.4s\n" + "dup v10.4s, v15.s[0]\n" + "dup v11.4s, v15.s[1]\n" + "dup v12.4s, v15.s[2]\n" + "dup v13.4s, v15.s[3]\n" + "add v24.4s, v24.4s, v10.4s\n" + "add v25.4s, v25.4s, v10.4s\n" + "add v26.4s, v26.4s, v11.4s\n" + "add v27.4s, v27.4s, v11.4s\n" + "add v28.4s, v28.4s, v12.4s\n" + "add v29.4s, v29.4s, v12.4s\n" + "add v30.4s, v30.4s, v13.4s\n" + "add v31.4s, v31.4s, v13.4s\n" + "7:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3], #16\n" + "ld1 {v15.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[0]\n" + "mls v18.4s, v10.4s, v14.s[1]\n" + "mls v19.4s, v10.4s, v14.s[1]\n" + "mls v20.4s, v10.4s, v14.s[2]\n" + "mls v21.4s, v10.4s, v14.s[2]\n" + "mls v22.4s, v10.4s, v14.s[3]\n" + "mls v23.4s, v10.4s, v14.s[3]\n" + "mls v24.4s, v10.4s, v15.s[0]\n" + "mls v25.4s, v10.4s, v15.s[0]\n" + "mls v26.4s, v10.4s, v15.s[1]\n" + "mls v27.4s, v10.4s, v15.s[1]\n" + "mls v28.4s, v10.4s, v15.s[2]\n" + "mls v29.4s, v10.4s, v15.s[2]\n" + "mls v30.4s, v10.4s, v15.s[3]\n" + "mls v31.4s, v10.4s, v15.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" 
RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2], #16\n" + "ld1 {v12.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + "mul v12.4s, v12.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v12.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v12.4s\n" + "sub v20.4s, v20.4s, v11.4s\n" + "sub v21.4s, v21.4s, v12.4s\n" + "sub v22.4s, v22.4s, v11.4s\n" + "sub v23.4s, v23.4s, v12.4s\n" + "sub v24.4s, v24.4s, v11.4s\n" + "sub v25.4s, v25.4s, v12.4s\n" + "sub v26.4s, v26.4s, v11.4s\n" + "sub v27.4s, v27.4s, v12.4s\n" + "sub v28.4s, v28.4s, v11.4s\n" + "sub v29.4s, v29.4s, v12.4s\n" + "sub v30.4s, v30.4s, v11.4s\n" + "sub v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + // Determine the channel index. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + // Compute the multiplier_exponent pointer + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, x3, lsl #2\n" + "csel x1, x1, x5, eq\n" + // Load multiplier_exponent + "ldr q9, [x1]\n" + "ldr q10, [x1, #16]\n" + // Separate positive and negative exponents + "smin v11.4s, v8.4s, v9.4s\n" + "smin v12.4s, v8.4s, v10.4s\n" + "sub v9.4s, v9.4s, v11.4s\n" + "sub v10.4s, v10.4s, v12.4s\n" + + // Compute the multiplier_fixedpoint pointer + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "add x5, x4, x3, lsl #2\n" + "csel x4, x4, x5, eq\n" + // Load multiplier_fixedpoint + "ldr q14, [x4]\n" + "ldr q15, [x4, #16]\n" + + // Jump based on channel dimension. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 8f\n" + // Case where channels are rows + + // Apply the positive exponent part of the multiplier. + "sshl v16.4s, v16.4s, v9.4s\n" + "sshl v17.4s, v17.4s, v10.4s\n" + "sshl v18.4s, v18.4s, v9.4s\n" + "sshl v19.4s, v19.4s, v10.4s\n" + "sshl v20.4s, v20.4s, v9.4s\n" + "sshl v21.4s, v21.4s, v10.4s\n" + "sshl v22.4s, v22.4s, v9.4s\n" + "sshl v23.4s, v23.4s, v10.4s\n" + "sshl v24.4s, v24.4s, v9.4s\n" + "sshl v25.4s, v25.4s, v10.4s\n" + "sshl v26.4s, v26.4s, v9.4s\n" + "sshl v27.4s, v27.4s, v10.4s\n" + "sshl v28.4s, v28.4s, v9.4s\n" + "sshl v29.4s, v29.4s, v10.4s\n" + "sshl v30.4s, v30.4s, v9.4s\n" + "sshl v31.4s, v31.4s, v10.4s\n" + "10:\n" + + // Apply the fixed-point part of the multiplier. 
+ "sqdmulh v16.4s, v16.4s, v14.4s\n" + "sqdmulh v17.4s, v17.4s, v15.4s\n" + "sqdmulh v18.4s, v18.4s, v14.4s\n" + "sqdmulh v19.4s, v19.4s, v15.4s\n" + "sqdmulh v20.4s, v20.4s, v14.4s\n" + "sqdmulh v21.4s, v21.4s, v15.4s\n" + "sqdmulh v22.4s, v22.4s, v14.4s\n" + "sqdmulh v23.4s, v23.4s, v15.4s\n" + "sqdmulh v24.4s, v24.4s, v14.4s\n" + "sqdmulh v25.4s, v25.4s, v15.4s\n" + "sqdmulh v26.4s, v26.4s, v14.4s\n" + "sqdmulh v27.4s, v27.4s, v15.4s\n" + "sqdmulh v28.4s, v28.4s, v14.4s\n" + "sqdmulh v29.4s, v29.4s, v15.4s\n" + "sqdmulh v30.4s, v30.4s, v14.4s\n" + "sqdmulh v31.4s, v31.4s, v15.4s\n" + + // Apply the negative exponent part of the multiplier. + "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v11.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + "srshl v20.4s, v20.4s, v11.4s\n" + "srshl v21.4s, v21.4s, v12.4s\n" + "srshl v22.4s, v22.4s, v11.4s\n" + "srshl v23.4s, v23.4s, v12.4s\n" + "srshl v24.4s, v24.4s, v11.4s\n" + "srshl v25.4s, v25.4s, v12.4s\n" + "srshl v26.4s, v26.4s, v11.4s\n" + "srshl v27.4s, v27.4s, v12.4s\n" + "srshl v28.4s, v28.4s, v11.4s\n" + "srshl v29.4s, v29.4s, v12.4s\n" + "srshl v30.4s, v30.4s, v11.4s\n" + "srshl v31.4s, v31.4s, v12.4s\n" + "b 9f\n" + + "8:\n" + // Case where channels are columns + + // Apply the positive exponent part of the multiplier. + "dup v4.4s, v9.s[0]\n" + "dup v5.4s, v9.s[1]\n" + "dup v6.4s, v9.s[2]\n" + "dup v7.4s, v9.s[3]\n" + "sshl v16.4s, v16.4s, v4.4s\n" + "sshl v17.4s, v17.4s, v4.4s\n" + "sshl v18.4s, v18.4s, v5.4s\n" + "sshl v19.4s, v19.4s, v5.4s\n" + "sshl v20.4s, v20.4s, v6.4s\n" + "sshl v21.4s, v21.4s, v6.4s\n" + "sshl v22.4s, v22.4s, v7.4s\n" + "sshl v23.4s, v23.4s, v7.4s\n" + "dup v4.4s, v10.s[0]\n" + "dup v5.4s, v10.s[1]\n" + "dup v6.4s, v10.s[2]\n" + "dup v7.4s, v10.s[3]\n" + "sshl v24.4s, v24.4s, v4.4s\n" + "sshl v25.4s, v25.4s, v4.4s\n" + "sshl v26.4s, v26.4s, v5.4s\n" + "sshl v27.4s, v27.4s, v5.4s\n" + "sshl v28.4s, v28.4s, v6.4s\n" + "sshl v29.4s, v29.4s, v6.4s\n" + "sshl v30.4s, v30.4s, v7.4s\n" + "sshl v31.4s, v31.4s, v7.4s\n" + "11:\n" + + // Apply the fixed-point part of the multiplier. + "sqdmulh v16.4s, v16.4s, v14.s[0]\n" + "sqdmulh v17.4s, v17.4s, v14.s[0]\n" + "sqdmulh v18.4s, v18.4s, v14.s[1]\n" + "sqdmulh v19.4s, v19.4s, v14.s[1]\n" + "sqdmulh v20.4s, v20.4s, v14.s[2]\n" + "sqdmulh v21.4s, v21.4s, v14.s[2]\n" + "sqdmulh v22.4s, v22.4s, v14.s[3]\n" + "sqdmulh v23.4s, v23.4s, v14.s[3]\n" + "sqdmulh v24.4s, v24.4s, v15.s[0]\n" + "sqdmulh v25.4s, v25.4s, v15.s[0]\n" + "sqdmulh v26.4s, v26.4s, v15.s[1]\n" + "sqdmulh v27.4s, v27.4s, v15.s[1]\n" + "sqdmulh v28.4s, v28.4s, v15.s[2]\n" + "sqdmulh v29.4s, v29.4s, v15.s[2]\n" + "sqdmulh v30.4s, v30.4s, v15.s[3]\n" + "sqdmulh v31.4s, v31.4s, v15.s[3]\n" + + // Apply the negative exponent part of the multiplier. 
+ "dup v4.4s, v11.s[0]\n" + "dup v5.4s, v11.s[1]\n" + "dup v6.4s, v11.s[2]\n" + "dup v7.4s, v11.s[3]\n" + "srshl v16.4s, v16.4s, v4.4s\n" + "srshl v17.4s, v17.4s, v4.4s\n" + "srshl v18.4s, v18.4s, v5.4s\n" + "srshl v19.4s, v19.4s, v5.4s\n" + "srshl v20.4s, v20.4s, v6.4s\n" + "srshl v21.4s, v21.4s, v6.4s\n" + "srshl v22.4s, v22.4s, v7.4s\n" + "srshl v23.4s, v23.4s, v7.4s\n" + "dup v4.4s, v12.s[0]\n" + "dup v5.4s, v12.s[1]\n" + "dup v6.4s, v12.s[2]\n" + "dup v7.4s, v12.s[3]\n" + "srshl v24.4s, v24.4s, v4.4s\n" + "srshl v25.4s, v25.4s, v4.4s\n" + "srshl v26.4s, v26.4s, v5.4s\n" + "srshl v27.4s, v27.4s, v5.4s\n" + "srshl v28.4s, v28.4s, v6.4s\n" + "srshl v29.4s, v29.4s, v6.4s\n" + "srshl v30.4s, v30.4s, v7.4s\n" + "srshl v31.4s, v31.4s, v7.4s\n" + "9:\n" + + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "sqadd v16.8h, v16.8h, v14.8h\n" + "sqadd v17.8h, v17.8h, v14.8h\n" + "sqadd v18.8h, v18.8h, v14.8h\n" + "sqadd v19.8h, v19.8h, v14.8h\n" + "sqadd v20.8h, v20.8h, v14.8h\n" + "sqadd v21.8h, v21.8h, v14.8h\n" + "sqadd v22.8h, v22.8h, v14.8h\n" + "sqadd v23.8h, v23.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + "sqxtun2 v16.16b, v17.8h\n" + "sqxtun v17.8b, v18.8h\n" + "sqxtun2 v17.16b, v19.8h\n" + "sqxtun v18.8b, v20.8h\n" + "sqxtun2 v18.16b, v21.8h\n" + "sqxtun v19.8b, v22.8h\n" + "sqxtun2 v19.16b, v23.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + "umax v17.16b, v17.16b, v14.16b\n" + "umax v18.16b, v18.16b, v14.16b\n" + "umax v19.16b, v19.16b, v14.16b\n" + + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + "umin v17.16b, v17.16b, v15.16b\n" + "umin v18.16b, v18.16b, v15.16b\n" + "umin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. 
+ "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. 
+ + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "sqadd v16.8h, v16.8h, v14.8h\n" + "sqadd v17.8h, v17.8h, v14.8h\n" + "sqadd v18.8h, v18.8h, v14.8h\n" + "sqadd v19.8h, v19.8h, v14.8h\n" + "sqadd v20.8h, v20.8h, v14.8h\n" + "sqadd v21.8h, v21.8h, v14.8h\n" + "sqadd v22.8h, v22.8h, v14.8h\n" + "sqadd v23.8h, v23.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + "sqxtn2 v16.16b, v17.8h\n" + "sqxtn v17.8b, v18.8h\n" + "sqxtn2 v17.16b, v19.8h\n" + "sqxtn v18.8b, v20.8h\n" + "sqxtn2 v18.16b, v21.8h\n" + "sqxtn v19.8b, v22.8h\n" + "sqxtn2 v19.16b, v23.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + "smax v17.16b, v17.16b, v14.16b\n" + "smax v18.16b, v18.16b, v14.16b\n" + "smax v19.16b, v19.16b, v14.16b\n" + + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + "smin v17.16b, v17.16b, v15.16b\n" + "smin v18.16b, v18.16b, v15.16b\n" + "smin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 130f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 131f\n" + "130:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. 
+ "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "131:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 141f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "150:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "151:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 151b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 150b\n" + "141:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + "saddw v20.4s, v20.4s, v14.4h\n" + "saddw v21.4s, v21.4s, v14.4h\n" + "saddw v22.4s, v22.4s, v14.4h\n" + "saddw v23.4s, v23.4s, v14.4h\n" + "saddw v24.4s, v24.4s, v14.4h\n" + "saddw v25.4s, v25.4s, v14.4h\n" + "saddw v26.4s, v26.4s, v14.4h\n" + "saddw v27.4s, v27.4s, v14.4h\n" + "saddw v28.4s, v28.4s, v14.4h\n" + "saddw v29.4s, v29.4s, v14.4h\n" + "saddw v30.4s, v30.4s, v14.4h\n" + "saddw v31.4s, v31.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ "smax v17.8h, v17.8h, v14.8h\n"
+ "smax v18.8h, v18.8h, v14.8h\n"
+ "smax v19.8h, v19.8h, v14.8h\n"
+ "smax v20.8h, v20.8h, v14.8h\n"
+ "smax v21.8h, v21.8h, v14.8h\n"
+ "smax v22.8h, v22.8h, v14.8h\n"
+ "smax v23.8h, v23.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+ "smin v17.8h, v17.8h, v15.8h\n"
+ "smin v18.8h, v18.8h, v15.8h\n"
+ "smin v19.8h, v19.8h, v15.8h\n"
+ "smin v20.8h, v20.8h, v15.8h\n"
+ "smin v21.8h, v21.8h, v15.8h\n"
+ "smin v22.8h, v22.8h, v15.8h\n"
+ "smin v23.8h, v23.8h, v15.8h\n"
+
+ // Compute how much of the 8x8 block of destination 16bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 230f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #16\n"
+ "b 231f\n"
+ "230:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "231:\n"
+
+ // Write our 16bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 241f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "250:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "251:\n"
+ "ldrsh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 251b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #16\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 250b\n"
+ "241:\n"
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // Compute how much of the 8x8 block of destination 32bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 330f\n"
+ // Not all of the 8x8 block fits.
+ // Write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "st1 {v16.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v16)
+ "st1 {v17.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v17)
+ "st1 {v18.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v18)
+ "st1 {v19.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v19)
+ "st1 {v20.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v20)
+ "st1 {v21.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v21)
+ "st1 {v22.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v22)
+ "st1 {v23.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v23)
+ "st1 {v24.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v24)
+ "st1 {v25.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v25)
+ "st1 {v26.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v26)
+ "st1 {v27.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v27)
+ "st1 {v28.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v28)
+ "st1 {v29.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v29)
+ "st1 {v30.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v30)
+ "st1 {v31.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v31)
+
+ "b 331f\n"
+
+ "330:\n"
+ // Yes, all of the 8x8 block fits.
+ "mov x4, %[dst_ptr]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.4s, v17.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v18.4s, v19.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v20.4s, v21.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v22.4s, v23.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v24.4s, v25.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v26.4s, v27.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v28.4s, v29.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v30.4s, v31.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + "331:\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 341f\n" + + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "350:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "351:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 351b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 350b\n" + "341:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? 
+ "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + + // Similar to the above 8-bit dotprod kernel, but specialized for the case of // RHS cols == 1. // Relevant target CPUs for this kernel include ARM Cortex-A76, @@ -4692,7 +5947,7 @@ void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) { // Add the destination zero point "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" // Cast-and-saturate from int16 to uint8, leaving all data in the // lower half of v16. @@ -4788,7 +6043,7 @@ void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) { // Add the destination zero point "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" // Cast-and-saturate from int16 to uint8 "sqxtn v16.8b, v16.8h\n" @@ -5691,14 +6946,14 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) { RUY_MAKE_ZERO(v31) // Add the destination zero point - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" + "sqadd v17.8h, v17.8h, v14.8h\n" + "sqadd v18.8h, v18.8h, v14.8h\n" + "sqadd v19.8h, v19.8h, v14.8h\n" + "sqadd v20.8h, v20.8h, v14.8h\n" + "sqadd v21.8h, v21.8h, v14.8h\n" + "sqadd v22.8h, v22.8h, v14.8h\n" + "sqadd v23.8h, v23.8h, v14.8h\n" // Load the clamp_min, clamp_max bounds "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" @@ -5865,14 +7120,14 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) { RUY_MAKE_ZERO(v31) // Add the destination zero point - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" + "sqadd v16.8h, v16.8h, v14.8h\n" + "sqadd v17.8h, v17.8h, v14.8h\n" + "sqadd v18.8h, v18.8h, v14.8h\n" + "sqadd v19.8h, v19.8h, v14.8h\n" + "sqadd v20.8h, v20.8h, v14.8h\n" + "sqadd v21.8h, v21.8h, v14.8h\n" + "sqadd v22.8h, v22.8h, v14.8h\n" + "sqadd v23.8h, v23.8h, v14.8h\n" // Load the clamp_min, clamp_max bounds "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" @@ -7069,6 +8324,472 @@ void KernelFloatNeon(const KernelParamsFloat<8, 8>& params) { "v26", "v27", "v28", "v29", "v30", "v31"); } +// A fork of the standard float kernel where we omit the manual loop unrolling +// to recover 
+void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params) {
+ CheckOffsetsInKernelParamsFloat(params);
+ profiler::ScopeLabel label("Kernel (kNeon) X1");
+
+ const float* lhs_col_ptr = params.lhs_base_ptr;
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ float* dst_col_ptr = params.dst_base_ptr;
+ float* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are accumulators.
+ // During accumulation, v0 -- v15 are used to load data from LHS and RHS.
+ // At least v0 and v1 are used to load an 8x1 block of LHS, and v2 and
+ // v3 are used to load a 1x8 block of RHS, like this:
+ //
+ //                                     RHS 1x8 block
+ //                        /-----------------------------------------|
+ //                        |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]  |
+ //                        \-----------------------------------------/
+ //    LHS 8x1 block
+ //  /---------------------\ /-----------------------------------------|
+ //  | v0.s[0]             | |v16.s[0]           ...          v30.s[0]|
+ //  | ...                 | | ...                               ...  |
+ //  | v0.s[3]             | |v16.s[3]           ...          v30.s[3]|
+ //  | v1.s[0]             | |v17.s[0]           ...          v31.s[0]|
+ //  | ...                 | | ...                               ...  |
+ //  | v1.s[3]             | |v17.s[3]           ...          v31.s[3]|
+ //  \---------------------/ \-----------------------------------------/
+ //                                  accumulators 8x8 block
+ //
+ // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
+ // is repeated 4 times, using 4x more registers for LHS and RHS, so that
+ // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
+ //
+ // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
+ // unused, and v8 -- v15 are used for loading parameters used for the
+ // post-accumulation part of the kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 1.
+ "mov w1, #1\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
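For reference, the elementary step that the fmla sequences below implement is a rank-1 update: each RHS lane scales the two halves of the LHS column into one accumulator column. A minimal AArch64 intrinsics sketch of one depth level (the helper name FloatOuterProductStep is illustrative, not ruy API):

#include <arm_neon.h>

// One depth step of the 8x8 float outer product, mirroring the
// "fmla vN.4s, v0.4s, v2.s[j]" pattern: lhs0/lhs1 hold rows 0..3 and
// 4..7 of the LHS column, rhs0/rhs1 hold cols 0..3 and 4..7 of the
// RHS row, and acc[0..15] correspond to v16..v31.
void FloatOuterProductStep(float32x4_t lhs0, float32x4_t lhs1,
                           float32x4_t rhs0, float32x4_t rhs1,
                           float32x4_t acc[16]) {
  acc[0] = vfmaq_laneq_f32(acc[0], lhs0, rhs0, 0);    // v16
  acc[1] = vfmaq_laneq_f32(acc[1], lhs1, rhs0, 0);    // v17
  acc[2] = vfmaq_laneq_f32(acc[2], lhs0, rhs0, 1);    // v18
  acc[3] = vfmaq_laneq_f32(acc[3], lhs1, rhs0, 1);    // v19
  acc[4] = vfmaq_laneq_f32(acc[4], lhs0, rhs0, 2);    // v20
  acc[5] = vfmaq_laneq_f32(acc[5], lhs1, rhs0, 2);    // v21
  acc[6] = vfmaq_laneq_f32(acc[6], lhs0, rhs0, 3);    // v22
  acc[7] = vfmaq_laneq_f32(acc[7], lhs1, rhs0, 3);    // v23
  acc[8] = vfmaq_laneq_f32(acc[8], lhs0, rhs1, 0);    // v24
  acc[9] = vfmaq_laneq_f32(acc[9], lhs1, rhs1, 0);    // v25
  acc[10] = vfmaq_laneq_f32(acc[10], lhs0, rhs1, 1);  // v26
  acc[11] = vfmaq_laneq_f32(acc[11], lhs1, rhs1, 1);  // v27
  acc[12] = vfmaq_laneq_f32(acc[12], lhs0, rhs1, 2);  // v28
  acc[13] = vfmaq_laneq_f32(acc[13], lhs1, rhs1, 2);  // v29
  acc[14] = vfmaq_laneq_f32(acc[14], lhs0, rhs1, 3);  // v30
  acc[15] = vfmaq_laneq_f32(acc[15], lhs1, rhs1, 3);  // v31
}

The "1:" loop that follows interleaves these multiply-adds with the next LHS/RHS loads rather than unrolling them, which is the whole point of this X1 variant.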
+ "1:\n" + + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + + // Accumulation loop + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "ld1 {v4.4s}, [%[rhs_ptr]], #16\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "add w1, w1, #1\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "cmp w1, w12\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v18.4s, v0.4s, v4.s[1]\n" + "mov v2.16b, v4.16b\n" + "fmla v20.4s, v0.4s, v4.s[2]\n" + "fmla v22.4s, v0.4s, v4.s[3]\n" + "blt 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. + + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. 
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Determine the channel index. + "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + // Offset the bias pointer as needed given the current row, col. + "add x5, x1, x3, lsl #2\n" + + // If there is no bias, use no offset, just address the passed zero + // data. + "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Perform the bias-addition. + // Jump based on channel dimension. + "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows + "fadd v16.4s, v16.4s, v14.4s\n" + "fadd v17.4s, v17.4s, v15.4s\n" + "fadd v18.4s, v18.4s, v14.4s\n" + "fadd v19.4s, v19.4s, v15.4s\n" + "fadd v20.4s, v20.4s, v14.4s\n" + "fadd v21.4s, v21.4s, v15.4s\n" + "fadd v22.4s, v22.4s, v14.4s\n" + "fadd v23.4s, v23.4s, v15.4s\n" + "fadd v24.4s, v24.4s, v14.4s\n" + "fadd v25.4s, v25.4s, v15.4s\n" + "fadd v26.4s, v26.4s, v14.4s\n" + "fadd v27.4s, v27.4s, v15.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v15.4s\n" + "fadd v30.4s, v30.4s, v14.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + "b 7f\n" + + "6:\n" + // Case where channels are columns + "dup v8.4s, v14.s[0]\n" + "dup v9.4s, v14.s[1]\n" + "dup v10.4s, v14.s[2]\n" + "dup v11.4s, v14.s[3]\n" + "dup v12.4s, v15.s[0]\n" + "dup v13.4s, v15.s[1]\n" + "dup v14.4s, v15.s[2]\n" + "dup v15.4s, v15.s[3]\n" + "fadd v16.4s, v16.4s, v8.4s\n" + "fadd v17.4s, v17.4s, v8.4s\n" + "fadd v18.4s, v18.4s, v9.4s\n" + "fadd v19.4s, v19.4s, v9.4s\n" + "fadd v20.4s, v20.4s, v10.4s\n" + "fadd v21.4s, v21.4s, v10.4s\n" + "fadd v22.4s, v22.4s, v11.4s\n" + "fadd v23.4s, v23.4s, v11.4s\n" + "fadd v24.4s, v24.4s, v12.4s\n" + "fadd v25.4s, v25.4s, v12.4s\n" + "fadd v26.4s, v26.4s, v13.4s\n" + "fadd v27.4s, v27.4s, v13.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v14.4s\n" + "fadd v30.4s, v30.4s, v15.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + "7:\n" + + // Load the clamp_min, clamp_max bounds + "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.4s, w2\n" // clamp_min + "dup v15.4s, w3\n" // clamp_max + + // Apply the clamp_min bound + "fmax v16.4s, v16.4s, v14.4s\n" + "fmax v17.4s, v17.4s, v14.4s\n" + "fmax v18.4s, v18.4s, v14.4s\n" + "fmax v19.4s, v19.4s, v14.4s\n" + "fmax v20.4s, v20.4s, v14.4s\n" + "fmax v21.4s, v21.4s, v14.4s\n" + "fmax v22.4s, v22.4s, v14.4s\n" + "fmax v23.4s, v23.4s, v14.4s\n" + "fmax v24.4s, v24.4s, v14.4s\n" + "fmax v25.4s, v25.4s, v14.4s\n" + "fmax v26.4s, v26.4s, v14.4s\n" + "fmax v27.4s, v27.4s, v14.4s\n" + "fmax v28.4s, 
v28.4s, v14.4s\n" + "fmax v29.4s, v29.4s, v14.4s\n" + "fmax v30.4s, v30.4s, v14.4s\n" + "fmax v31.4s, v31.4s, v14.4s\n" + + // Apply the clamp_max bound + "fmin v16.4s, v16.4s, v15.4s\n" + "fmin v17.4s, v17.4s, v15.4s\n" + "fmin v18.4s, v18.4s, v15.4s\n" + "fmin v19.4s, v19.4s, v15.4s\n" + "fmin v20.4s, v20.4s, v15.4s\n" + "fmin v21.4s, v21.4s, v15.4s\n" + "fmin v22.4s, v22.4s, v15.4s\n" + "fmin v23.4s, v23.4s, v15.4s\n" + "fmin v24.4s, v24.4s, v15.4s\n" + "fmin v25.4s, v25.4s, v15.4s\n" + "fmin v26.4s, v26.4s, v15.4s\n" + "fmin v27.4s, v27.4s, v15.4s\n" + "fmin v28.4s, v28.4s, v15.4s\n" + "fmin v29.4s, v29.4s, v15.4s\n" + "fmin v30.4s, v30.4s, v15.4s\n" + "fmin v31.4s, v31.4s, v15.4s\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "str q16, [x3, #0]\n" + "str q17, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "str q18, [x3, #0]\n" + "str q19, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "str q20, [x3, #0]\n" + "str q21, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "str q22, [x3, #0]\n" + "str q23, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "str q24, [x3, #0]\n" + "str q25, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "str q26, [x3, #0]\n" + "str q27, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "str q28, [x3, #0]\n" + "str q29, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "str q30, [x3, #0]\n" + "str q31, [x3, #16]\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. 
+ "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov w1, #1\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + // Variant of KernelFloatNeon tuned for in-order CPUs that do not // support dotprod (while dotprod by itself is not relevant to floating-point, // this additional bit of information that we have about the target happens to diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc index fddb482..84b9380 100644 --- a/ruy/kernel_avx512.cc +++ b/ruy/kernel_avx512.cc @@ -52,45 +52,6 @@ void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) { #else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) -namespace { -namespace intrin_utils { - -__m256i mm256_blendv_epi64(const __m256i& a, const __m256i& b, - const __m256i& mask) { - __m256d result = - _mm256_blendv_pd(_mm256_castsi256_pd(a), _mm256_castsi256_pd(b), - _mm256_castsi256_pd(mask)); - return _mm256_castpd_si256(result); -} - -__m512i mm512_blendv_epi64(const __m512i& a, const __m512i& b, - const __m512i& mask) { - __m256i a_lo = _mm512_extracti64x4_epi64(a, 0); - __m256i a_hi = _mm512_extracti64x4_epi64(a, 1); - __m256i b_lo = _mm512_extracti64x4_epi64(b, 0); - __m256i b_hi = _mm512_extracti64x4_epi64(b, 1); - __m256i mask_lo = _mm512_extracti64x4_epi64(mask, 0); - 
__m256i mask_hi = _mm512_extracti64x4_epi64(mask, 1); - __m256i lo = mm256_blendv_epi64(a_lo, b_lo, mask_lo); - __m256i hi = mm256_blendv_epi64(a_hi, b_hi, mask_hi); - __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0); - return _mm512_inserti64x4(result, hi, 1); -} - -__m512i mm512_cmpgt_epi64(const __m512i& a, const __m512i& b) { - __m256i a_lo = _mm512_extracti64x4_epi64(a, 0); - __m256i a_hi = _mm512_extracti64x4_epi64(a, 1); - __m256i b_lo = _mm512_extracti64x4_epi64(b, 0); - __m256i b_hi = _mm512_extracti64x4_epi64(b, 1); - __m256i lo = _mm256_cmpgt_epi64(a_lo, b_lo); - __m256i hi = _mm256_cmpgt_epi64(a_hi, b_hi); - __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0); - return _mm512_inserti64x4(result, hi, 1); -} - -} // namespace intrin_utils -} // namespace - void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { profiler::ScopeLabel label("Kernel kAvx512 8-bit"); @@ -391,13 +352,13 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { // Construct the "nudge" value for each lane if the exponent is // greater than 0. Otherwise, the nudge is 0. const __m512i zeros = _mm512_setzero_si512(); - const __m512i mask_rightshift_gtz = - intrin_utils::mm512_cmpgt_epi64(exponent, zeros); + const auto mask_rightshift_gtz = + _mm512_cmpgt_epi64_mask(exponent, zeros); const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64( _mm512_set1_epi64(1), _mm512_sub_epi64(exponent, _mm512_set1_epi64(1))); - __m512i nudge = intrin_utils::mm512_blendv_epi64( - zeros, one_shift_exp_minus1, mask_rightshift_gtz); + __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz, + one_shift_exp_minus1); // Calculate the shifted sum (results + nudge) >> exp. const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge); const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent); @@ -406,14 +367,12 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { const __m512i one_shift_31minus_exp = _mm512_sllv_epi64( _mm512_set1_epi64(1), _mm512_sub_epi64(_mm512_set1_epi64(31), exponent)); - const __m512i mask_num_plus_nudge_overflow = - intrin_utils::mm512_cmpgt_epi64( - results, - _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge)); + const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask( + results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge)); // Fill results with either (results + nudge) >> exponent or // 1 << (31 - exp) in the case of overflow. - results = intrin_utils::mm512_blendv_epi64( - shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); + results = _mm512_mask_mov_epi64( + shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp); }; if (per_column_multiplier) { @@ -424,8 +383,8 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift); __m512i m_64bit_val = _mm512_permutexvar_epi64( perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high); - __m512i offset_vector_val = _mm512_permutexvar_epi64( - perm_64bit_vals, offset_vector); + __m512i offset_vector_val = + _mm512_permutexvar_epi64(perm_64bit_vals, offset_vector); __m512i final_right_shift_val = _mm512_permutexvar_epi64( perm_64bit_vals, col < 8 ? final_right_shift_low : final_right_shift_high); @@ -802,13 +761,13 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { // Construct the "nudge" value for each lane if the exponent is // greater than 0. Otherwise, the nudge is 0. 
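The comment above describes ruy's 64-bit rounding right shift: add a "nudge" of half the shift amount, shift arithmetically, and patch the lanes where adding the nudge would have left the int32 range. A per-lane scalar sketch (RoundingRightShift64 is an illustrative name; exp is assumed to be in [0, 31], as in this code path):

#include <cstdint>

// Scalar model of the masked AVX-512 sequence below.
inline std::int64_t RoundingRightShift64(std::int64_t result, int exp) {
  const std::int64_t nudge =
      exp > 0 ? (std::int64_t{1} << (exp - 1)) : 0;
  if (result > std::int64_t{0x7fffffff} - nudge) {
    // result + nudge would exceed the int32 range: substitute
    // 1 << (31 - exp), matching mask_num_plus_nudge_overflow below.
    return std::int64_t{1} << (31 - exp);
  }
  return (result + nudge) >> exp;
}

The _mm512_cmpgt_epi64_mask / _mm512_mask_mov_epi64 pair in the new code performs this same select directly in AVX-512 mask registers, which is what this change swaps in for the old emulated blendv helpers being deleted here.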
const __m512i zeros = _mm512_setzero_si512(); - const __m512i mask_rightshift_gtz = - intrin_utils::mm512_cmpgt_epi64(exponent, zeros); + const auto mask_rightshift_gtz = + _mm512_cmpgt_epi64_mask(exponent, zeros); const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(_mm512_set1_epi64(1), _mm512_sub_epi64(exponent, _mm512_set1_epi64(1))); - __m512i nudge = intrin_utils::mm512_blendv_epi64( - zeros, one_shift_exp_minus1, mask_rightshift_gtz); + __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz, + one_shift_exp_minus1); // Calculate the shifted sum (results + nudge) >> exp. const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge); const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent); @@ -817,14 +776,12 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { const __m512i one_shift_31minus_exp = _mm512_sllv_epi64( _mm512_set1_epi64(1), _mm512_sub_epi64(_mm512_set1_epi64(31), exponent)); - const __m512i mask_num_plus_nudge_overflow = - intrin_utils::mm512_cmpgt_epi64( - results, - _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge)); + const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask( + results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge)); // Fill results with either (results + nudge) >> exponent or // 1 << (31 - exp) in the case of overflow. - results = intrin_utils::mm512_blendv_epi64( - shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); + results = _mm512_mask_mov_epi64( + shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp); }; // Shift and round column 0. @@ -930,9 +887,8 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { float* dst_ptr = dst_col_ptr + row; // Process block in two halves, split by columns. - { - constexpr int mmm = 0; - +#pragma unroll(1) + for (int mmm = 0; mmm < 2; ++mmm) { __m512 accum_data_v0; __m512 accum_data_v1; __m512 accum_data_v2; @@ -972,81 +928,49 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { const float* rhs_ptr = rhs_col_ptr + 8 * mmm; for (int d = 0; d < (params.depth - 1); ++d) { const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - // In this version RHS values are loaded individually rather than - // first loading together and then extract with broadcasting. This is - // because AVX flavours and instrinsics and compilers in combination - // do not handle this pattern of extraction very well. const float* rhs_data = rhs_ptr; lhs_ptr += 16; rhs_ptr += 16; - { - // Load 8 float32 values. 
- __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); - __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 - __m512 rhs4_7 = - _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 - - const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } + // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast: + // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do + // so if given an rvalue. + accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]), + accum_data_v0); + accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]), + accum_data_v1); + accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]), + accum_data_v2); + accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]), + accum_data_v3); + accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]), + accum_data_v4); + accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]), + accum_data_v5); + accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]), + accum_data_v6); + accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]), + accum_data_v7); } - { + { // nested extra blocks lead to measurable speed gains const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); const float* rhs_data = rhs_ptr; - { - // Load 8 float32 values. 
- __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); - __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 - __m512 rhs4_7 = - _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 - const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } + accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]), + accum_data_v0); + accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]), + accum_data_v1); + accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]), + accum_data_v2); + accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]), + accum_data_v3); + accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]), + accum_data_v4); + accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]), + accum_data_v5); + accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]), + accum_data_v6); + accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]), + accum_data_v7); { float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); @@ -1075,147 +999,7 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); } } - } // Inner half-block loop, unrolled, first iteration. - { - constexpr int mmm = 1; - - __m512 accum_data_v0; - __m512 accum_data_v1; - __m512 accum_data_v2; - __m512 accum_data_v3; - __m512 accum_data_v4; - __m512 accum_data_v5; - __m512 accum_data_v6; - __m512 accum_data_v7; - - // Initialize with bias. 
- if (channel_dimension_is_col) { - const float* bias_elem_ptr = - bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment; - accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]); - accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]); - accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]); - accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]); - accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]); - accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]); - accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]); - accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]); - } else { - const __m512 initial_accum_data = - _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment); - - accum_data_v0 = initial_accum_data; - accum_data_v1 = initial_accum_data; - accum_data_v2 = initial_accum_data; - accum_data_v3 = initial_accum_data; - accum_data_v4 = initial_accum_data; - accum_data_v5 = initial_accum_data; - accum_data_v6 = initial_accum_data; - accum_data_v7 = initial_accum_data; - } - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < (params.depth - 1); ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - lhs_ptr += 16; - rhs_ptr += 16; - { - // Load 8 float32 values. - __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); - __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 - __m512 rhs4_7 = - _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 - - const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - } - { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - { - // Load 8 float32 values. 
- __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); - __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 - __m512 rhs4_7 = - _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 - const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - { - float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; - accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); - _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); - accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); - _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); - accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); - _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); - accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); - _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); - accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); - _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); - accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); - _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); - accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); - _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); - accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); - _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); - } - } - } // Inner half-block loop, unrolled, second iteration. + } } // End row-block loop. // The unrolling within this conditional may be somewhat pointless. It @@ -1273,73 +1057,45 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { const float* rhs_data = rhs_ptr; lhs_ptr += 16; rhs_ptr += 16; - { - // Load 8 float32 values. 
- __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); - __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 - __m512 rhs4_7 = - _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 - - const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } + // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast: + // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do + // so if given an rvalue. + accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]), + accum_data_v0); + accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]), + accum_data_v1); + accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]), + accum_data_v2); + accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]), + accum_data_v3); + accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]), + accum_data_v4); + accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]), + accum_data_v5); + accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]), + accum_data_v6); + accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]), + accum_data_v7); } { const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); const float* rhs_data = rhs_ptr; - { - // Load 8 float32 values. 
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h
index 9509b8f..cff243b 100644
--- a/ruy/kernel_common.h
+++ b/ruy/kernel_common.h
@@ -177,6 +177,8 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
   params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
   params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
   if (mul_params.multiplier_fixedpoint_perchannel()) {
+    // Temporary release-assert to debug some crashes in an application.
+    RUY_CHECK(mul_params.multiplier_exponent_perchannel());
     params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
     params->multiplier_fixedpoint =
         mul_params.multiplier_fixedpoint_perchannel();
@@ -200,6 +202,11 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
   params->dst_type_id = DstTypeId<DstScalar>::kValue;
   params->dst_base_ptr =
       dst->data.get() + start_col * dst->layout.stride + start_row;
+
+  // Temporary release-asserts to debug some crashes in an application.
+  RUY_CHECK(params->multiplier_fixedpoint);
+  RUY_CHECK(params->multiplier_exponent);
+  RUY_CHECK(params->bias);
 }
 
 template <int LhsCols, int RhsCols>
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index 2f8fe19..b716502 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -607,14 +607,12 @@ inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
       const float* rhs_ptr = rhs_col_ptr;
       for (int d = 0; d < params.depth; ++d) {
         const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-        const float* rhs_data = rhs_ptr;
-        // Load 8 RHS values, then use permute instructions to
-        // broadcast each value to a register.
-        __m256 rhs1 = _mm256_loadu_ps(rhs_data);  // Load [0 1 2 3 4 5 6 7]
+        // Load 8 RHS values, then use permute instructions to broadcast each
+        // value to a register. _mm256_permute2f128_ps is slow on AMD.
         __m256 rhs0_3 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 0);  // [0 1 2 3 0 1 2 3]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
         __m256 rhs4_7 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 17);  // [4 5 6 7 4 5 6 7]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));
 
         const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
         accum_data_v[0] = intrin_utils::MulAdd<path>(
@@ -707,13 +705,11 @@ inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
       const float* rhs_ptr = rhs_col_ptr;
       for (int d = 0; d < params.depth; ++d) {
         const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-        const float* rhs_data = rhs_ptr;
-        __m256 rhs1 = _mm256_loadu_ps(rhs_data);  // Load [0 1 2 3 4 5 6 7]
         __m256 rhs0_3 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 0);  // [0 1 2 3 0 1 2 3]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
         __m256 rhs4_7 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 17);  // [4 5 6 7 4 5 6 7]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));
 
         const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
         accum_data_v[0] = intrin_utils::MulAdd<path>(
diff --git a/ruy/mat.h b/ruy/mat.h
--- a/ruy/mat.h
+++ b/ruy/mat.h
@@ -327,7 +327,7 @@ inline bool IsColMajor(const MatLayout& layout) {
   return layout.order == Order::kColMajor;
 }
 
-inline int FlatSize(const MatLayout& layout) {
+inline std::ptrdiff_t FlatSize(const MatLayout& layout) {
   const int outerdim =
       layout.order == Order::kColMajor ? layout.cols : layout.rows;
   return layout.stride * outerdim;
@@ -349,7 +349,7 @@ inline bool IsColMajor(const PMatLayout& layout) {
   return layout.order == Order::kColMajor;
 }
 
-inline int FlatSize(const PMatLayout& layout) {
+inline std::ptrdiff_t FlatSize(const PMatLayout& layout) {
   const int outerdim =
       layout.order == Order::kColMajor ? layout.cols : layout.rows;
   return layout.stride * outerdim;
@@ -429,11 +429,11 @@ Scalar Element(const PMat<Scalar>& mat, int row, int col) {
 
 // Helpers for PEMat.
 
-inline int DataBytes(const PEMat& packed) {
+inline std::ptrdiff_t DataBytes(const PEMat& packed) {
   return FlatSize(packed.layout) * packed.data_type.size;
 }
 
-inline int SumsBytes(const PEMat& packed) {
+inline std::ptrdiff_t SumsBytes(const PEMat& packed) {
   // Packed matrices are only relevant for Ruy's TrMul implementations. For
   // TrMul, the number of sums is always equal to the number of columns.
   return packed.layout.cols * packed.sums_type.size;
diff --git a/ruy/mul_params.h b/ruy/mul_params.h
index d5aa27b..42a5700 100644
--- a/ruy/mul_params.h
+++ b/ruy/mul_params.h
@@ -103,14 +103,9 @@ class MulParams final {
   // The bias vector data, if not null.
   const AccumScalar* bias() const { return storage_.bias; }
   void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
-  // Only for non-floating-point cases. The fixed-point part of the multiplier
-  // by which accumulators are multiplied before being casted to the destination
-  // type. This is a fixed-point quantity with 0 integer bits. Since
-  // (as explained in the class comment) AccumScalar must be std::int32_t,
-  // that means that the fixed-point format is Q0.31. For example,
-  // a multiplier_fixedpoint value of 2^30 has the effect of multiplying
-  // by one half (1/2). More generally, the effect is to multiply by
-  // (multiplier_fixedpoint / (2^31)).
+  // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
+  // of the multiplier by which accumulators are multiplied before being cast
+  // to the destination type.
   AccumScalar multiplier_fixedpoint() const {
     return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
   }
@@ -132,10 +127,9 @@ class MulParams final {
   // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
   // and `multiplier_exponent_perchannel` are used instead.
   //
-  // This must point to a buffer of as many values as there are rows or columns
-  // in the destination matrix, whichever is the channels dimension. Each
-  // channel of the destination matrix will use the corresponding buffer element
-  // instead of multiplier_fixedpoint.
+  // This must point to a buffer of as many values as there are rows in the
+  // destination matrix. Each row of the destination matrix will use the
+  // corresponding buffer element instead of multiplier_fixedpoint.
   const AccumScalar* multiplier_fixedpoint_perchannel() const {
     return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
                                : nullptr;
@@ -205,6 +199,16 @@ class MulParams final {
   detail::MulParamsStorage<AccumScalar, DstScalar> storage_;
 
   void set_perchannel(bool perchannel) {
+    if (storage_.perchannel == perchannel) {
+      return;
+    }
+    if (perchannel) {
+      RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0);
+      RUY_DCHECK_EQ(storage_.multiplier_exponent, 0);
+    } else {
+      RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr);
+      RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr);
+    }
     storage_.perchannel = perchannel;
   }
 };
@@ -240,25 +244,25 @@ template <typename DstScalar>
 struct MulParamsStorage<std::int32_t, DstScalar> final {
   using AccumScalar = std::int32_t;
   static_assert(std::is_integral<DstScalar>::value, "");
-  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, "");
+  static_assert(sizeof(DstScalar) < sizeof(AccumScalar), "");
 
   const AccumScalar* bias = nullptr;
-  union {
-    const AccumScalar* multiplier_fixedpoint_perchannel;
-    // Let the default multiplier be effecively a multiplication by 1, so that
-    // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
-    // 1 is not exactly representable in fixedpoint with 0 integer bits, but
-    // using the highest representable value is a sufficiently good
-    // approximation: since this specialization of MulParams is for the case
-    // where DstScalar is at least 2x narrower than MulScalar, the values
-    // for which there would be a difference will get saturated anyway.
-    AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
-  };
-  union {
-    const int* multiplier_exponent_perchannel;
-    // See the above comment about the default value of multiplier_fixedpoint.
-    int multiplier_exponent = 0;
-  };
+  // union {  // This used to be a union, temporarily flattened to debug a crash
+  const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
+  // Let the default multiplier be effectively a multiplication by 1, so that
+  // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
+  // 1 is not exactly representable in fixedpoint with 0 integer bits, but
+  // using the highest representable value is a sufficiently good
+  // approximation: since this specialization of MulParams is for the case
+  // where DstScalar is at least 2x narrower than MulScalar, the values
+  // for which there would be a difference will get saturated anyway.
+  AccumScalar multiplier_fixedpoint = 0;
+  //};
+  // union {  // This used to be a union, temporarily flattened to debug a crash
+  const int* multiplier_exponent_perchannel = nullptr;
+  // See the above comment about the default value of multiplier_fixedpoint.
+  int multiplier_exponent = 0;
+  // };
 
   DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
   DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
   ChannelDimension channel_dimension = ChannelDimension::kRow;
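The comment block removed above spelled out the Q0.31 semantics that still apply to multiplier_fixedpoint: the effective multiplier is multiplier_fixedpoint / 2^31, so a value of 2^30 multiplies by one half. A scalar sketch of that requantization step, assuming a simple round-to-nearest nudge (illustrative only, not code from the patch):

    #include <cstdint>

    // Multiplies a 32-bit accumulator by (multiplier_fixedpoint / 2^31),
    // rounding to nearest. E.g. multiplier = 1 << 30 halves the accumulator.
    std::int32_t ApplyFixedpointMultiplier(std::int32_t accum,
                                           std::int32_t multiplier) {
      std::int64_t prod = static_cast<std::int64_t>(accum) * multiplier;
      return static_cast<std::int32_t>((prod + (std::int64_t{1} << 30)) >> 31);
    }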
diff --git a/ruy/mul_params_test.cc b/ruy/mul_params_test.cc
index 4bc9f87..feb7dbb 100644
--- a/ruy/mul_params_test.cc
+++ b/ruy/mul_params_test.cc
@@ -31,7 +31,7 @@ TEST(MulParamsTest, SpecClassSanity) {
   MulParamsType mul_params;
 
   EXPECT_EQ(mul_params.bias(), nullptr);
-  EXPECT_EQ(mul_params.multiplier_fixedpoint(), std::numeric_limits<std::int32_t>::max());
+  EXPECT_EQ(mul_params.multiplier_fixedpoint(), 0);
   EXPECT_EQ(mul_params.multiplier_exponent(), 0);
   EXPECT_EQ(mul_params.multiplier_fixedpoint_perchannel(), nullptr);
   EXPECT_EQ(mul_params.multiplier_exponent_perchannel(), nullptr);
diff --git a/ruy/prepacked_cache.cc b/ruy/prepacked_cache.cc
index ee891cb..5080ca9 100644
--- a/ruy/prepacked_cache.cc
+++ b/ruy/prepacked_cache.cc
@@ -26,10 +26,10 @@ namespace {
 // Allocates the `data` and `sums` buffers, and sets the corresponding
 // pointer fields, in a PEMat whose other fields, particularly `layout`
 // and the runtime data types, are already populated.
-int AllocateBuffers(PEMat* packed_matrix) {
-  const int data_bytes = DataBytes(*packed_matrix);
+std::ptrdiff_t AllocateBuffers(PEMat* packed_matrix) {
+  const std::ptrdiff_t data_bytes = DataBytes(*packed_matrix);
   packed_matrix->data = detail::SystemAlignedAlloc(data_bytes);
-  int sums_bytes = 0;
+  std::ptrdiff_t sums_bytes = 0;
   if (!packed_matrix->sums_type.is_floating_point) {
     // Integer quantized matrices also need the `sums` buffer.
     sums_bytes = SumsBytes(*packed_matrix);
@@ -93,7 +93,7 @@ PrepackedCache::Action PrepackedCache::Get(const void* src_data,
   }
 
   // No existing entry found. Allocate new buffers now and insert in the cache.
-  const int new_bytes = AllocateBuffers(packed_matrix);
+  const std::ptrdiff_t new_bytes = AllocateBuffers(packed_matrix);
   EjectUntilRoomFor(new_bytes);
   Entry entry{*packed_matrix, timestamp_++};
   cache_.emplace(key, entry);
@@ -101,7 +101,7 @@ PrepackedCache::Action PrepackedCache::Get(const void* src_data,
   return Action::kInsertedNewEntry;
 }
 
-void PrepackedCache::EjectUntilRoomFor(int new_bytes) {
+void PrepackedCache::EjectUntilRoomFor(std::ptrdiff_t new_bytes) {
   profiler::ScopeLabel label("PrepackedCacheEjection");
   // While we are above the threshold of ejection, eject the LRU entry.
   while (!cache_.empty() && buffers_bytes_ + new_bytes > max_buffers_bytes_) {
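The int-to-std::ptrdiff_t widening in this file and in mat.h guards the byte-size arithmetic against 32-bit overflow: a packed buffer only a couple of gigabytes large already wraps a 32-bit int. A small illustration of the failure mode the change avoids (hypothetical sizes, not from the patch):

    #include <cstddef>

    // For a 40000 x 16000 float matrix, stride * outerdim * elem_size is
    // 2.56e9 bytes, above INT32_MAX (~2.15e9): computed in int, this would
    // wrap to a negative value; in std::ptrdiff_t (64-bit on typical
    // targets) it stays correct.
    std::ptrdiff_t PackedDataBytes(std::ptrdiff_t stride,
                                   std::ptrdiff_t outerdim,
                                   std::ptrdiff_t elem_size) {
      return stride * outerdim * elem_size;
    }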
diff --git a/ruy/prepacked_cache.h b/ruy/prepacked_cache.h
index cb3a113..c58593e 100644
--- a/ruy/prepacked_cache.h
+++ b/ruy/prepacked_cache.h
@@ -101,7 +101,7 @@ class PrepackedCache final {
   ~PrepackedCache();
 
   // Returns the total size in bytes of buffers held in this cache.
-  int BuffersBytes() const { return buffers_bytes_; }
+  std::ptrdiff_t BuffersBytes() const { return buffers_bytes_; }
 
   // Returns the number of packed matrices held in this cache.
   int MatrixCount() const { return cache_.size(); }
@@ -128,11 +128,11 @@ class PrepackedCache final {
 
  private:
   void EjectOne();
-  void EjectUntilRoomFor(int new_bytes);
+  void EjectUntilRoomFor(std::ptrdiff_t new_bytes);
 
   std::unordered_map<Key, Entry, KeyHash> cache_;
-  const int max_buffers_bytes_;
-  int buffers_bytes_ = 0;
+  const std::ptrdiff_t max_buffers_bytes_;
+  std::ptrdiff_t buffers_bytes_ = 0;
   Timestamp timestamp_ = 0;
 };
diff --git a/ruy/ruy.h b/ruy/ruy.h
--- a/ruy/ruy.h
+++ b/ruy/ruy.h
@@ -93,14 +93,6 @@ void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
 // (e.g. the number of CPU cores in typical scenarios). At least ruy forces
 // each invocation to make an explicit decision here, there is no automatic
 // detection of the best number of threads to use in ruy.
-//
-// Constraints on the template parameters:
-// * If DstScalar is floating-point then AccumScalar must also be.
-// * If DstScalar is integral then AccumScalar must be std::int32_t.
-// Please refer to MulParams' class comment for more information. When
-// DstScalar is integral and is narrower than AccumScalar, additional
-// MulParams fields must be set to control the scaling of internal accumulators
-// before the final saturating cast to the DstScalar type.
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
           typename DstScalar>
 void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
diff --git a/ruy/test.h b/ruy/test.h
--- a/ruy/test.h
+++ b/ruy/test.h
@@ -122,6 +122,7 @@ inline const char* TuningName(Tuning tuning)
     return #NAME;
   switch (tuning) {
     RUY_SUBPATHNAME_CASE(kA55ish)
+    RUY_SUBPATHNAME_CASE(kX1)
     RUY_SUBPATHNAME_CASE(kGeneric)
     default:
       RUY_CHECK(false);
@@ -1825,7 +1826,7 @@ inline std::vector<Tuning> EnumerateTuningsForPath(Path path, bool benchmark) {
   }
 #if RUY_PLATFORM_ARM
   if (path == Path::kNeon || path == Path::kNeonDotprod) {
-    return {Tuning::kA55ish, Tuning::kGeneric, Tuning::kAuto};
+    return {Tuning::kA55ish, Tuning::kX1, Tuning::kGeneric, Tuning::kAuto};
   }
 #endif
   (void)path;
diff --git a/ruy/test_overflow_dst_zero_point.cc b/ruy/test_overflow_dst_zero_point.cc
new file mode 100644
index 0000000..db1f08d
--- /dev/null
+++ b/ruy/test_overflow_dst_zero_point.cc
@@ -0,0 +1,133 @@
+/* Copyright 2021 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This test covers destination zero_points that cause internal int16 overflow.
+
+// Kernels tend to perform the addition of the destination zero_point in int16.
+// Although this happens after the rescaling to the destination scale, it is
+// still possible for this int16 addition to overflow. This should be handled
+// by saturating, which ensures correct results as the subsequent cast to
+// the destination 8-bit type is saturating anyway, so this second saturation
+// eats any effect of the previous saturation in the int16 addition of the
+// destination zero_point.
+// When this is not correctly saturating, a typical effect is wrapping around
+// to the opposite end of the range of int16, which causes the latter saturation
+// to the int8/uint8 range to saturate to the opposite end of that, resulting
+// in a large numerical difference in the output values.
+
+#include <limits>
+#include <type_traits>
+#include <vector>
+
+#include "ruy/context.h"
+#include "ruy/gtest_wrapper.h"
+#include "ruy/matrix.h"
+#include "ruy/mul_params.h"
+#include "ruy/path.h"
+#include "ruy/ruy.h"
+#include "ruy/test.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+namespace {
+
+template <typename DstScalar>
+void TestOverflowingAdditionOfDestinationZeroPoint(ruy::Context* context,
+                                                   int cols,
+                                                   DstScalar dst_zero_point) {
+  // Set the bias value so that the int16 addition of the zero_point will
+  // overflow.
+  const int bias_value = dst_zero_point > 0
+                             ? std::numeric_limits<std::int16_t>::max()
+                             : std::numeric_limits<std::int16_t>::min();
+  // This is the end of the DstScalar range that we expect values will be
+  // clamped to.
+  const int expected_dst_value = dst_zero_point > 0
+                                     ? std::numeric_limits<DstScalar>::max()
+                                     : std::numeric_limits<DstScalar>::min();
+
+  const std::vector<std::int8_t> lhs_data(1, 0);
+  const std::vector<std::int8_t> rhs_data(cols, 0);
+  std::vector<DstScalar> dst_data(cols, 0);
+
+  ruy::MulParams<std::int32_t, DstScalar> mul_params;
+  std::int32_t bias_data[1] = {bias_value};
+  mul_params.set_bias(bias_data);
+  // Set the quantized multiplier to essentially 1 so we get unscaled
+  // accumulators in the output, only clamped.
+  mul_params.set_multiplier_fixedpoint(
+      std::numeric_limits<std::int32_t>::max());
+
+  ruy::Matrix<std::int8_t> lhs;
+  ruy::MakeSimpleLayout(1, 1, ruy::Order::kColMajor, lhs.mutable_layout());
+  lhs.set_data(lhs_data.data());
+
+  ruy::Matrix<std::int8_t> rhs;
+  ruy::MakeSimpleLayout(1, cols, ruy::Order::kColMajor, rhs.mutable_layout());
+  rhs.set_data(rhs_data.data());
+
+  ruy::Matrix<DstScalar> dst;
+  ruy::MakeSimpleLayout(1, cols, ruy::Order::kColMajor, dst.mutable_layout());
+  dst.set_data(dst_data.data());
+  dst.set_zero_point(dst_zero_point);
+
+  ruy::Mul(lhs, rhs, mul_params, context, &dst);
+
+  // Check that the DstScalar overflow was clamped, not wrapped around.
+  for (auto d : dst_data) {
+    EXPECT_EQ(d, expected_dst_value);
+  }
+}
+
+template <typename DstScalar>
+void TestOverflowingAdditionOfDestinationZeroPoint(ruy::Context* context) {
+  // Test both a matrix*vector and a general matrix*matrix (in the sense that
+  // cols>1) as these may exercise different kernels.
+  TestOverflowingAdditionOfDestinationZeroPoint<DstScalar>(context, 1, 1);
+  TestOverflowingAdditionOfDestinationZeroPoint<DstScalar>(context, 8, 1);
+  if (std::is_signed<DstScalar>::value) {
+    TestOverflowingAdditionOfDestinationZeroPoint<DstScalar>(context, 1, -1);
+    TestOverflowingAdditionOfDestinationZeroPoint<DstScalar>(context, 8, -1);
+  }
+}
+
+TEST(RuyTest, OverflowingAdditionOfDestinationZeroPoint) {
+  ruy::Context context;
+  ruy::Path runtime_enabled_paths = context.get_runtime_enabled_paths();
+  for (unsigned bit = 0; bit < 8 * sizeof(ruy::Path); bit++) {
+    ruy::Path path = static_cast<ruy::Path>(1 << bit);
+    if ((path & runtime_enabled_paths) == ruy::Path::kNone) {
+      continue;
+    }
+    context.set_runtime_enabled_paths(path);
+    for (ruy::Tuning tuning :
+         {ruy::Tuning::kGeneric, ruy::Tuning::kA55ish, ruy::Tuning::kX1}) {
+      fprintf(stderr, "Testing path %s, tuning %s\n", PathName(path),
+              TuningName(tuning));
+      context.set_explicit_tuning(tuning);
+      TestOverflowingAdditionOfDestinationZeroPoint<std::int8_t>(&context);
+      TestOverflowingAdditionOfDestinationZeroPoint<std::uint8_t>(&context);
+      TestOverflowingAdditionOfDestinationZeroPoint<std::int16_t>(&context);
+    }
+  }
+}
+
+}  // namespace
+}  // namespace ruy
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
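A scalar model of the property this test asserts (not ruy kernel code): with bias = 32767 and dst_zero_point = 1, a wrapping int16 add would produce -32768 and the final cast would clamp to the destination minimum; a saturating add keeps 32767, so the cast clamps to the destination maximum, which is what the EXPECT_EQ above checks.

    #include <algorithm>
    #include <cstdint>

    std::int8_t AddZeroPointThenCast(std::int16_t accum,
                                     std::int16_t zero_point) {
      // Saturating int16 addition of the destination zero_point.
      std::int32_t sum = std::int32_t{accum} + zero_point;
      sum = std::min<std::int32_t>(32767, std::max<std::int32_t>(-32768, sum));
      // Saturating cast to the 8-bit destination type.
      return static_cast<std::int8_t>(
          std::min<std::int32_t>(127, std::max<std::int32_t>(-128, sum)));
    }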
diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc
index 100cfe3..5f22a13 100644
--- a/ruy/thread_pool.cc
+++ b/ruy/thread_pool.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include <thread>  // NOLINT(build/c++11)
 
 #include "ruy/check_macros.h"
+#include "ruy/denormal.h"
 #include "ruy/trace.h"
 #include "ruy/wait.h"
 
@@ -113,6 +114,9 @@ class Thread {
     RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
     ChangeState(State::Ready);
 
+    // Suppress denormals to avoid computation inefficiency.
+    ScopedSuppressDenormals suppress_denormals;
+
     // Thread main loop
     while (true) {
       RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
index 9345f0c..602660b 100644
--- a/ruy/trmul.cc
+++ b/ruy/trmul.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "ruy/cpu_cache_params.h"
 #include "ruy/cpuinfo.h"
 #include "ruy/ctx.h"
+#include "ruy/denormal.h"
 #include "ruy/mat.h"
 #include "ruy/matrix.h"
 #include "ruy/mul_params.h"
@@ -307,6 +308,12 @@ void TrMul(Ctx* ctx, TrMulParams* params) {
       GetTentativeThreadCount(ctx, rows, cols, depth);
   const auto& cpu_cache_params = ctx->mutable_cpuinfo()->CacheParams();
 
+  // Suppress denormals to avoid computation inefficiency.
+  // Note this only handles the denormal suppression on the main thread. As for
+  // worker threads, the suppression is handled in each thread's main loop. See
+  // the corresponding code in thread_pool.cc for details.
+  ScopedSuppressDenormals suppress_denormals;
+
   // Case of running this TrMul as a simple loop.
   // This is a good place to start reading this function: all the rest
   // of this function is just an optimized, but functionally equivalent,
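The ScopedSuppressDenormals object added above flushes denormal floats to zero for the current thread and restores the previous state when it goes out of scope; denormal operands can make float arithmetic dramatically slower. A sketch of the x86 flavor of this mechanism via the MXCSR FTZ/DAZ bits (illustrative only; the class name here is made up, and ruy's real, portable implementation lives in ruy/denormal.cc):

    #include <pmmintrin.h>
    #include <xmmintrin.h>

    // RAII guard: enable flush-to-zero (FTZ) and denormals-are-zero (DAZ),
    // then restore the previous MXCSR modes on destruction.
    class ScopedFtzDaz {
     public:
      ScopedFtzDaz()
          : saved_ftz_(_MM_GET_FLUSH_ZERO_MODE()),
            saved_daz_(_MM_GET_DENORMALS_ZERO_MODE()) {
        _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
        _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
      }
      ~ScopedFtzDaz() {
        _MM_SET_FLUSH_ZERO_MODE(saved_ftz_);
        _MM_SET_DENORMALS_ZERO_MODE(saved_daz_);
      }

     private:
      unsigned int saved_ftz_;
      unsigned int saved_daz_;
    };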
diff --git a/ruy/trmul_params.h b/ruy/trmul_params.h
index e68d909..486a6c6 100644
--- a/ruy/trmul_params.h
+++ b/ruy/trmul_params.h
@@ -53,7 +53,9 @@ constexpr int kMaxMulParamsSize =
              kMaxMulParamsSizeQuantizedIntegerCase));
 
 // OK to adjust as needed, but we want to avoid unnecessarily inflating that.
-static_assert(kMaxMulParamsSize <= 32, "");
+// Temporarily bumped from 32 to 48 as part of temporarily not using unions
+// in MulParams.
+static_assert(kMaxMulParamsSize <= 48, "");
 
 // Type-erased data needed for implementing TrMul.
 struct TrMulParams {
diff --git a/ruy/tune.cc b/ruy/tune.cc
index 1f615bf..004bd5a 100644
--- a/ruy/tune.cc
+++ b/ruy/tune.cc
@@ -23,7 +23,13 @@ limitations under the License.
 namespace ruy {
 
 Tuning TuningResolver::ResolveNow(CpuInfo* cpuinfo) {
-  return cpuinfo->CurrentCpuIsA55ish() ? Tuning::kA55ish : Tuning::kGeneric;
+  if (cpuinfo->CurrentCpuIsA55ish()) {
+    return Tuning::kA55ish;
+  }
+  if (cpuinfo->CurrentCpuIsX1()) {
+    return Tuning::kX1;
+  }
+  return Tuning::kGeneric;
 }
 
 TuningResolver::TuningResolver()
diff --git a/ruy/tune.h b/ruy/tune.h
--- a/ruy/tune.h
+++ b/ruy/tune.h
@@ -69,7 +69,13 @@ enum class Tuning {
   // A55r1 supports dotprod unlike A55r0 and A53, they are not using the same
   // kernels in practice anyway, so there was no need to distinguish them with
   // separate Tuning values.
-  kA55ish
+  kA55ish,
+  // Use code tuned for Cortex-X1 CPUs. Currently, the driver to distinguish
+  // this CPU is to get maximum performance on the dotprod kernels, where we
+  // attain high performance simply by avoiding any manual loop unrolling. As a
+  // purely performance-oriented microarchitecture, the X1 will likely give us
+  // additional reasons to distinguish it from other CPUs.
+  kX1
 };
 
 // Why a TuningResolver class?
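Callers normally leave tuning at Tuning::kAuto and let ResolveNow pick kX1 on a Cortex-X1; the new value can also be forced through the same Context API the test above exercises. Minimal usage sketch (the function name is illustrative):

    #include "ruy/context.h"
    #include "ruy/tune.h"

    void ForceX1Tuning(ruy::Context& context) {
      // Bypass auto-detection; subsequent ruy::Mul calls on this context
      // use the X1-tuned code on paths that support it.
      context.set_explicit_tuning(ruy::Tuning::kX1);
    }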