aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenoit Jacob <benoitjacob@google.com>2021-06-16 12:08:00 -0700
committerCopybara-Service <copybara-worker@google.com>2021-06-16 12:08:32 -0700
commit22f89228893b3abae601004a6e48b64e089fc626 (patch)
treefc236951db7127ed496e1365ff79b36a18ea44d5
parent23633b37099b614a2f836ef012cafc8087fdb98c (diff)
downloadruy-22f89228893b3abae601004a6e48b64e089fc626.tar.gz
Fix integer overflow causing incorrect results.
Kernels perform the addition of the destination zero_point in int16. This addition needed to be saturating to avoid wrapping around. Thanks to Marat Dukhan for reporting and debugging this issue.

Additionally, this commit:
- makes the new Cortex-X1 tuned kernels tested.
- adds Context::get_runtime_enabled_paths() to query the runtime CPU detection from the public Context interface.
- updates the Bazel-to-CMake converter to support some minor recent BUILD changes.

PiperOrigin-RevId: 379778779
-rwxr-xr-xcmake/bazel_to_cmake.py299
-rw-r--r--ruy/BUILD16
-rw-r--r--ruy/CMakeLists.txt20
-rw-r--r--ruy/context.cc5
-rw-r--r--ruy/context.h3
-rw-r--r--ruy/kernel_arm32.cc8
-rw-r--r--ruy/kernel_arm64.cc120
-rw-r--r--ruy/profiler/CMakeLists.txt2
-rw-r--r--ruy/test.h3
-rw-r--r--ruy/test_overflow_dst_zero_point.cc133
10 files changed, 398 insertions, 211 deletions
diff --git a/cmake/bazel_to_cmake.py b/cmake/bazel_to_cmake.py
index ba1a38b..fc01c2a 100755
--- a/cmake/bazel_to_cmake.py
+++ b/cmake/bazel_to_cmake.py
@@ -49,88 +49,94 @@ replacements = [
['selects.config_setting_group', 'config_setting_group'],
['@com_google_googletest//:gtest', 'gtest'],
['@com_google_googletest//:gtest_main', 'gtest_main'],
- ['@cpuinfo//:cpuinfo_with_unstripped_include_path', 'cpuinfo'],
+ ['@cpuinfo', 'cpuinfo'],
]
def preprocess_input_text(text):
- result = text
- for replacement in replacements:
- result = result.replace(replacement[0], replacement[1])
- return result
+ result = text
+ for replacement in replacements:
+ result = result.replace(replacement[0], replacement[1])
+ return result
def set_cmake_list(list_name, values, indent):
- semicolon_separated = ";".join(values)
- print(f'{indent}set({list_name} "{semicolon_separated}")')
+ semicolon_separated = ';'.join(values)
+ print(f'{indent}set({list_name} "{semicolon_separated}")')
def generate_cmake_select(select_name, dict):
- new_if_branch_keyword = 'if'
- default_value = []
- for key in dict:
- condition = ''
- if key == '//conditions:default':
- default_value = dict[key]
- continue
- elif re.search(r':windows$', key):
- condition = 'CMAKE_SYSTEM_NAME STREQUAL Windows'
- elif re.search(r':ppc$', key):
- condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le'
- elif re.search(r':s390x$', key):
- condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR CMAKE_SYSTEM_PROCESSOR STREQUAL s390x'
- elif re.search(r':fuchsia$', key):
- condition = 'CMAKE_SYSTEM_NAME STREQUAL Fuchsia'
- elif re.search(r':arm32_assuming_neon$', key):
- condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL arm'
- elif re.search(r':do_not_want_O3$', key):
- # Ruy is a specialist library: we always want code to be compiled
- # with -O3 unless the build type is Debug or the compiler does not
- # support that flag syntax.
- condition = '(CMAKE_BUILD_TYPE STREQUAL Debug) OR MSVC'
- elif re.search(r':x86_64_and_not_msvc$', key):
- condition = '(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC'
- elif re.search(r':windows_msvc$', key):
- condition = 'MSVC'
- elif re.search(r':ruy_profiler$', key):
- condition = '${RUY_PROFILER}'
- else:
- raise ValueError(f'Unhandled key in select: {key}')
+ new_if_branch_keyword = 'if'
+ default_value = []
+ for key in dict:
+ condition = ''
+ if key == '//conditions:default':
+ default_value = dict[key]
+ continue
+ elif re.search(r':windows$', key):
+ condition = 'CMAKE_SYSTEM_NAME STREQUAL Windows'
+ elif re.search(r':ppc$', key):
+ condition = ('CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR '
+ 'CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le')
+ elif re.search(r':s390x$', key):
+ condition = ('CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR '
+ 'CMAKE_SYSTEM_PROCESSOR STREQUAL s390x')
+ elif re.search(r':fuchsia$', key):
+ condition = 'CMAKE_SYSTEM_NAME STREQUAL Fuchsia'
+ elif re.search(r':arm32_assuming_neon$', key):
+ condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL arm'
+ elif re.search(r':do_not_want_O3$', key):
+ # Ruy is a specialist library: we always want code to be compiled
+ # with -O3 unless the build type is Debug or the compiler does not
+ # support that flag syntax.
+ condition = '(CMAKE_BUILD_TYPE STREQUAL Debug) OR MSVC'
+ elif re.search(r':x86_64_and_not_msvc$', key):
+ condition = ('(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR '
+ 'CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC')
+ elif re.search(r':windows_msvc$', key):
+ condition = 'MSVC'
+ elif re.search(r':ruy_profiler$', key):
+ condition = '${RUY_PROFILER}'
+ elif re.search(r'//tools/cc_target_os:xtensa$', key):
+ condition = 'CMAKE_SYSTEM_NAME STREQUAL xtensa-esp32-elf'
+ else:
+ raise ValueError(f'Unhandled key in select: {key}')
- print(f'{new_if_branch_keyword}({condition})')
- set_cmake_list(select_name, dict[key], ' ')
- new_if_branch_keyword = 'elseif'
+ print(f'{new_if_branch_keyword}({condition})')
+ set_cmake_list(select_name, dict[key], ' ')
+ new_if_branch_keyword = 'elseif'
- print('else()')
- set_cmake_list(select_name, default_value, ' ')
+ print('else()')
+ set_cmake_list(select_name, default_value, ' ')
- print('endif()\n')
+ print('endif()\n')
def trim_multiple_ruy_prefixes(name):
- return re.sub(r'(ruy_)+ruy', 'ruy', name)
+ return re.sub(r'(ruy_)+ruy', 'ruy', name)
+
def get_cmake_local_target_name(name):
- global package_prefix
- return trim_multiple_ruy_prefixes(f'ruy_{package_prefix}_{name}')
+ global package_prefix
+ return trim_multiple_ruy_prefixes(f'ruy_{package_prefix}_{name}')
def get_cmake_dep_target_name(name):
- if name in external_targets:
- return name
- if name.startswith('$'):
- # Happens for deps that are the result of expanding a select() that we
- # have compiled to expanding a variable.
- return name
- if name.startswith('//'):
- after_last_slash = name.split('/')[-1]
- if not ':' in after_last_slash:
- name = f'{name}:{after_last_slash}'
- raw=name[2:].replace('/', '_').replace(':', '_')
- return trim_multiple_ruy_prefixes(raw)
- if name.startswith(':'):
- name = name[1:]
- return get_cmake_local_target_name(name)
+ if name in external_targets:
+ return name
+ if name.startswith('$'):
+ # Happens for deps that are the result of expanding a select() that we
+ # have compiled to expanding a variable.
+ return name
+ if name.startswith('//'):
+ after_last_slash = name.split('/')[-1]
+ if ':' not in after_last_slash:
+ name = f'{name}:{after_last_slash}'
+ raw = name[2:].replace('/', '_').replace(':', '_')
+ return trim_multiple_ruy_prefixes(raw)
+ if name.startswith(':'):
+ name = name[1:]
+ return get_cmake_local_target_name(name)
#
@@ -139,45 +145,45 @@ def get_cmake_dep_target_name(name):
def package(**kwargs):
- pass
+ pass
def exports_files(*args):
- pass
+ pass
def load(filename, *args):
- if filename.startswith('@'):
- return
- elif filename.startswith(':'):
- filename = os.path.join(bazel_package_dir, filename[1:])
- elif filename.startswith('//'):
- split = filename[2:].split(':')
- filename = os.path.join(bazel_workspace_dir, split[0], split[1])
+ if filename.startswith('@'):
+ return
+ elif filename.startswith(':'):
+ filename = os.path.join(bazel_package_dir, filename[1:])
+ elif filename.startswith('//'):
+ split = filename[2:].split(':')
+ filename = os.path.join(bazel_workspace_dir, split[0], split[1])
- src_file_content = open(filename).read()
- processed_file_content = preprocess_input_text(src_file_content)
- exec(processed_file_content, globals(), globals())
+ src_file_content = open(filename).read()
+ processed_file_content = preprocess_input_text(src_file_content)
+ exec(processed_file_content, globals(), globals())
def config_setting(**kwargs):
- # Nothing to do since our implementation of select() is based on parsing
- # the names of config_settings, not looking deep into their actual
- # implementation.
- pass
+ # Nothing to do since our implementation of select() is based on parsing
+ # the names of config_settings, not looking deep into their actual
+ # implementation.
+ pass
def filegroup(**kwargs):
- pass
+ pass
def config_setting_group(**kwargs):
- # See config_setting.
- pass
+ # See config_setting.
+ pass
def bzl_library(**kwargs):
- pass
+ pass
select_index = 0
@@ -185,95 +191,96 @@ select_cache = {}
def select(select_dict):
- global select_index
- global select_cache
- global package_prefix
- key = pickle.dumps(sorted(select_dict.items()))
- if key in select_cache:
- select_name = select_cache[key]
- else:
- unique_values = sorted(set(itertools.chain.from_iterable(select_dict.values()))) # sorting ensures determinism, no spurious diffs
- description = '_'.join(unique_values)
- select_name = f'{package_prefix}_{select_index}_{description}'
- select_name = select_name.replace('c++', 'cxx')
- select_name = re.sub(r'[^a-zA-Z0-9]+', '_', select_name)
- select_index = select_index + 1
- select_cache[key] = select_name
- generate_cmake_select(select_name, select_dict)
-
- return [f'${{{select_name}}}']
+ global select_index
+ global select_cache
+ global package_prefix
+ key = pickle.dumps(sorted(select_dict.items()))
+ if key in select_cache:
+ select_name = select_cache[key]
+ else:
+ unique_values = sorted(
+ set(itertools.chain.from_iterable(select_dict.values()))
+ ) # sorting ensures determinism, no spurious diffs
+ description = '_'.join(unique_values)
+ select_name = f'{package_prefix}_{select_index}_{description}'
+ select_name = select_name.replace('c++', 'cxx')
+ select_name = re.sub(r'[^a-zA-Z0-9]+', '_', select_name)
+ select_index = select_index + 1
+ select_cache[key] = select_name
+ generate_cmake_select(select_name, select_dict)
+
+ return [f'${{{select_name}}}']
def generic_rule(rule_name, **kwargs):
- print(f'{rule_name}(')
- for key in kwargs.keys():
- values = kwargs[key]
- if type(values) is bool:
- if values:
- print(f' {key.upper()}')
- continue
- else:
- raise ValueError(
- 'Cannot specify FALSE boolean args in CMake')
- if key == 'visibility':
- if values == ['//visibility:public']:
- print(f' PUBLIC')
- continue
- if key == 'tags':
- values = list(filter(lambda x : not x.startswith('req_dep'), values))
- if not values:
- continue
+ print(f'{rule_name}(')
+ for key in kwargs.keys():
+ values = kwargs[key]
+ if type(values) is bool:
+ if values:
print(f' {key.upper()}')
- if type(values) is list:
- for value in values:
- if key == 'deps':
- target_name = get_cmake_dep_target_name(value)
- print(f' {target_name}')
- else:
- print(f' {value}')
+ continue
+ else:
+ raise ValueError('Cannot specify FALSE boolean args in CMake')
+ if key == 'visibility':
+ if values == ['//visibility:public']:
+ print(f' PUBLIC')
+ continue
+ if key == 'tags':
+ values = list(filter(lambda x: not x.startswith('req_dep'), values))
+ if not values:
+ continue
+ print(f' {key.upper()}')
+ if type(values) is list:
+ for value in values:
+ if key == 'deps':
+ target_name = get_cmake_dep_target_name(value)
+ print(f' {target_name}')
else:
- if key == 'name':
- target_name = get_cmake_local_target_name(values)
- print(f' {target_name}')
- else:
- print(f' {values}')
- print(')\n')
+ print(f' {value}')
+ else:
+ if key == 'name':
+ target_name = get_cmake_local_target_name(values)
+ print(f' {target_name}')
+ else:
+ print(f' {values}')
+ print(')\n')
def cc_library(**kwargs):
- generic_rule('ruy_cc_library', **kwargs)
+ generic_rule('ruy_cc_library', **kwargs)
def cc_test(**kwargs):
- generic_rule('ruy_cc_test', **kwargs)
+ generic_rule('ruy_cc_test', **kwargs)
def cc_binary(**kwargs):
- generic_rule('ruy_cc_binary', **kwargs)
+ generic_rule('ruy_cc_binary', **kwargs)
#
# Program entry point.
#
if __name__ == "__main__":
- if len(sys.argv) != 3:
- print("Usage: bazel_to_cmake.py bazel_workspace_dir bazel_package_dir")
- sys.exit(1)
+ if len(sys.argv) != 3:
+ print('Usage: bazel_to_cmake.py bazel_workspace_dir bazel_package_dir')
+ sys.exit(1)
- bazel_workspace_dir = sys.argv[1]
- bazel_package_dir = sys.argv[2]
- bazel_package_relative_dir = os.path.relpath(
- bazel_package_dir, bazel_workspace_dir)
- package_prefix = bazel_package_relative_dir.replace(os.path.sep, '_')
+ bazel_workspace_dir = sys.argv[1]
+ bazel_package_dir = sys.argv[2]
+ bazel_package_relative_dir = os.path.relpath(bazel_package_dir,
+ bazel_workspace_dir)
+ package_prefix = bazel_package_relative_dir.replace(os.path.sep, '_')
- print("""# This file is generated (whence no license header). Do not edit!
+ print("""# This file is generated (whence no license header). Do not edit!
# To regenerate, run:
# cmake/bazel_to_cmake.sh
""")
- src_build_file = os.path.join(bazel_package_dir, "BUILD")
- src_build_content = open(src_build_file).read()
- processed_build_content = preprocess_input_text(src_build_content)
- exec(processed_build_content)
+ src_build_file = os.path.join(bazel_package_dir, 'BUILD')
+ src_build_content = open(src_build_file).read()
+ processed_build_content = preprocess_input_text(src_build_content)
+ exec(processed_build_content)
- print("ruy_add_all_subdirs()")
+ print('ruy_add_all_subdirs()')
diff --git a/ruy/BUILD b/ruy/BUILD
index b16f161..d04a45d 100644
--- a/ruy/BUILD
+++ b/ruy/BUILD
@@ -1205,6 +1205,22 @@ ruy_test(
],
)
+cc_test(
+ name = "test_overflow_dst_zero_point",
+ srcs = [
+ "test_overflow_dst_zero_point.cc",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":gtest_wrapper",
+ ":matrix",
+ ":path",
+ ":ruy",
+ ":test_lib",
+ ":tune",
+ ],
+)
+
bzl_library(
name = "ruy_test_ext.oss_bzl",
srcs = ["ruy_test_ext.oss.bzl"],
diff --git a/ruy/CMakeLists.txt b/ruy/CMakeLists.txt
index b83bc8c..0337768 100644
--- a/ruy/CMakeLists.txt
+++ b/ruy/CMakeLists.txt
@@ -107,6 +107,8 @@ ruy_cc_library(
if(CMAKE_SYSTEM_NAME STREQUAL Windows)
set(ruy_3_pthread "")
+elseif(CMAKE_SYSTEM_NAME STREQUAL xtensa-esp32-elf)
+ set(ruy_3_pthread "")
else()
set(ruy_3_pthread "-pthread")
endif()
@@ -1709,4 +1711,22 @@ ruy_cc_test(
slow
)
+ruy_cc_test(
+ NAME
+ ruy_test_overflow_dst_zero_point
+ SRCS
+ test_overflow_dst_zero_point.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_gtest_wrapper
+ ruy_matrix
+ ruy
+ ruy_path
+ ruy_test_lib
+ ruy_tune
+)
+
ruy_add_all_subdirs()
diff --git a/ruy/context.cc b/ruy/context.cc
index 4661738..342ce52 100644
--- a/ruy/context.cc
+++ b/ruy/context.cc
@@ -55,4 +55,9 @@ void Context::set_runtime_enabled_paths(Path paths) {
mutable_ctx()->SetRuntimeEnabledPaths(paths);
}
+Path Context::get_runtime_enabled_paths() {
+  // Query the paths currently enabled on the underlying context, then mask
+  // with kAllPaths so that internal test-only paths are not exposed.
+  const auto enabled_paths = mutable_ctx()->GetRuntimeEnabledPaths();
+  return enabled_paths & ruy::kAllPaths;
+}
+
} // namespace ruy
diff --git a/ruy/context.h b/ruy/context.h
index 79a4b5c..f148f0f 100644
--- a/ruy/context.h
+++ b/ruy/context.h
@@ -90,6 +90,9 @@ class Context final {
// Paths in kNonArchPaths are always implicitly supported.
void set_runtime_enabled_paths(Path paths);
+ // Returns the set of Paths currently enabled at runtime. Internal test-only
+ // paths are masked out of the result.
+ Path get_runtime_enabled_paths();
+
private:
CtxImpl* const impl_;
diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc
index b20f668..8782dce 100644
--- a/ruy/kernel_arm32.cc
+++ b/ruy/kernel_arm32.cc
@@ -1102,7 +1102,7 @@ void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) {
"vdup.16 q13, r4\n" // dst_zero_point
// Add the destination zero point
- "vadd.i16 q14, q14, q13\n"
+ "vqadd.s16 q14, q14, q13\n"
// Cast-and-saturate from int16 to uint8
// Now all 8 1-byte values are in d30.
@@ -1226,7 +1226,7 @@ void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) {
"vdup.16 q13, r4\n" // dst_zero_point
// Add the destination zero point
- "vadd.i16 q14, q14, q13\n"
+ "vqadd.s16 q14, q14, q13\n"
// Cast-and-saturate from int16 to int8
// Now all 8 1-byte values are in d30.
@@ -2014,7 +2014,7 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
"vdup.16 q13, r4\n" // dst_zero_point
// Add the destination zero point
- "vadd.i16 q14, q14, q13\n"
+ "vqadd.s16 q14, q14, q13\n"
// Cast-and-saturate from int16 to uint8
"vqmovun.s16 d30, q14\n"
@@ -2126,7 +2126,7 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
"vdup.16 q13, r4\n" // dst_zero_point
// Add the destination zero point
- "vadd.i16 q14, q14, q13\n"
+ "vqadd.s16 q14, q14, q13\n"
// Cast-and-saturate from int16 to int8
"vqmovn.s16 d30, q14\n"
diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc
index b06a06e..5424107 100644
--- a/ruy/kernel_arm64.cc
+++ b/ruy/kernel_arm64.cc
@@ -623,8 +623,8 @@ void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) {
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
- "add v16.8h, v16.8h, v14.8h\n"
- "add v17.8h, v17.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
+ "sqadd v17.8h, v17.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
@@ -750,8 +750,8 @@ void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) {
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
- "add v16.8h, v16.8h, v14.8h\n"
- "add v17.8h, v17.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
+ "sqadd v17.8h, v17.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
@@ -1472,7 +1472,7 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) {
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
- "add v16.8h, v16.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8
// Now all data is in the first 32-bits of v16
@@ -1553,7 +1553,7 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) {
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
- "add v16.8h, v16.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
@@ -2394,9 +2394,9 @@ void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) {
"dup v14.8h, v13.h[4]\n"
RUY_MAKE_ZERO(v20)
"add %[rhs_ptr], %[rhs_ptr], #64\n"
- "add v16.8h, v16.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
RUY_MAKE_ZERO(v21)
- "add v17.8h, v17.8h, v14.8h\n"
+ "sqadd v17.8h, v17.8h, v14.8h\n"
RUY_MAKE_ZERO(v22)
// Cast-and-saturate from int16 to uint8
@@ -2526,9 +2526,9 @@ void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) {
"dup v14.8h, v13.h[4]\n"
RUY_MAKE_ZERO(v20)
"add %[rhs_ptr], %[rhs_ptr], #64\n"
- "add v16.8h, v16.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
RUY_MAKE_ZERO(v21)
- "add v17.8h, v17.8h, v14.8h\n"
+ "sqadd v17.8h, v17.8h, v14.8h\n"
RUY_MAKE_ZERO(v22)
// Cast-and-saturate from int16 to uint8
@@ -3713,14 +3713,14 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
- "add v16.8h, v16.8h, v14.8h\n"
- "add v17.8h, v17.8h, v14.8h\n"
- "add v18.8h, v18.8h, v14.8h\n"
- "add v19.8h, v19.8h, v14.8h\n"
- "add v20.8h, v20.8h, v14.8h\n"
- "add v21.8h, v21.8h, v14.8h\n"
- "add v22.8h, v22.8h, v14.8h\n"
- "add v23.8h, v23.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
+ "sqadd v17.8h, v17.8h, v14.8h\n"
+ "sqadd v18.8h, v18.8h, v14.8h\n"
+ "sqadd v19.8h, v19.8h, v14.8h\n"
+ "sqadd v20.8h, v20.8h, v14.8h\n"
+ "sqadd v21.8h, v21.8h, v14.8h\n"
+ "sqadd v22.8h, v22.8h, v14.8h\n"
+ "sqadd v23.8h, v23.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
@@ -3888,14 +3888,14 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
- "add v16.8h, v16.8h, v14.8h\n"
- "add v17.8h, v17.8h, v14.8h\n"
- "add v18.8h, v18.8h, v14.8h\n"
- "add v19.8h, v19.8h, v14.8h\n"
- "add v20.8h, v20.8h, v14.8h\n"
- "add v21.8h, v21.8h, v14.8h\n"
- "add v22.8h, v22.8h, v14.8h\n"
- "add v23.8h, v23.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
+ "sqadd v17.8h, v17.8h, v14.8h\n"
+ "sqadd v18.8h, v18.8h, v14.8h\n"
+ "sqadd v19.8h, v19.8h, v14.8h\n"
+ "sqadd v20.8h, v20.8h, v14.8h\n"
+ "sqadd v21.8h, v21.8h, v14.8h\n"
+ "sqadd v22.8h, v22.8h, v14.8h\n"
+ "sqadd v23.8h, v23.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
@@ -4967,14 +4967,14 @@ void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params) {
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
- "add v16.8h, v16.8h, v14.8h\n"
- "add v17.8h, v17.8h, v14.8h\n"
- "add v18.8h, v18.8h, v14.8h\n"
- "add v19.8h, v19.8h, v14.8h\n"
- "add v20.8h, v20.8h, v14.8h\n"
- "add v21.8h, v21.8h, v14.8h\n"
- "add v22.8h, v22.8h, v14.8h\n"
- "add v23.8h, v23.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
+ "sqadd v17.8h, v17.8h, v14.8h\n"
+ "sqadd v18.8h, v18.8h, v14.8h\n"
+ "sqadd v19.8h, v19.8h, v14.8h\n"
+ "sqadd v20.8h, v20.8h, v14.8h\n"
+ "sqadd v21.8h, v21.8h, v14.8h\n"
+ "sqadd v22.8h, v22.8h, v14.8h\n"
+ "sqadd v23.8h, v23.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
@@ -5142,14 +5142,14 @@ void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params) {
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
- "add v16.8h, v16.8h, v14.8h\n"
- "add v17.8h, v17.8h, v14.8h\n"
- "add v18.8h, v18.8h, v14.8h\n"
- "add v19.8h, v19.8h, v14.8h\n"
- "add v20.8h, v20.8h, v14.8h\n"
- "add v21.8h, v21.8h, v14.8h\n"
- "add v22.8h, v22.8h, v14.8h\n"
- "add v23.8h, v23.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
+ "sqadd v17.8h, v17.8h, v14.8h\n"
+ "sqadd v18.8h, v18.8h, v14.8h\n"
+ "sqadd v19.8h, v19.8h, v14.8h\n"
+ "sqadd v20.8h, v20.8h, v14.8h\n"
+ "sqadd v21.8h, v21.8h, v14.8h\n"
+ "sqadd v22.8h, v22.8h, v14.8h\n"
+ "sqadd v23.8h, v23.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
@@ -5947,7 +5947,7 @@ void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) {
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
- "add v16.8h, v16.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8, leaving all data in the
// lower half of v16.
@@ -6043,7 +6043,7 @@ void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) {
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
- "add v16.8h, v16.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
@@ -6946,14 +6946,14 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
RUY_MAKE_ZERO(v31)
// Add the destination zero point
- "add v16.8h, v16.8h, v14.8h\n"
- "add v17.8h, v17.8h, v14.8h\n"
- "add v18.8h, v18.8h, v14.8h\n"
- "add v19.8h, v19.8h, v14.8h\n"
- "add v20.8h, v20.8h, v14.8h\n"
- "add v21.8h, v21.8h, v14.8h\n"
- "add v22.8h, v22.8h, v14.8h\n"
- "add v23.8h, v23.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
+ "sqadd v17.8h, v17.8h, v14.8h\n"
+ "sqadd v18.8h, v18.8h, v14.8h\n"
+ "sqadd v19.8h, v19.8h, v14.8h\n"
+ "sqadd v20.8h, v20.8h, v14.8h\n"
+ "sqadd v21.8h, v21.8h, v14.8h\n"
+ "sqadd v22.8h, v22.8h, v14.8h\n"
+ "sqadd v23.8h, v23.8h, v14.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
@@ -7120,14 +7120,14 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
RUY_MAKE_ZERO(v31)
// Add the destination zero point
- "add v16.8h, v16.8h, v14.8h\n"
- "add v17.8h, v17.8h, v14.8h\n"
- "add v18.8h, v18.8h, v14.8h\n"
- "add v19.8h, v19.8h, v14.8h\n"
- "add v20.8h, v20.8h, v14.8h\n"
- "add v21.8h, v21.8h, v14.8h\n"
- "add v22.8h, v22.8h, v14.8h\n"
- "add v23.8h, v23.8h, v14.8h\n"
+ "sqadd v16.8h, v16.8h, v14.8h\n"
+ "sqadd v17.8h, v17.8h, v14.8h\n"
+ "sqadd v18.8h, v18.8h, v14.8h\n"
+ "sqadd v19.8h, v19.8h, v14.8h\n"
+ "sqadd v20.8h, v20.8h, v14.8h\n"
+ "sqadd v21.8h, v21.8h, v14.8h\n"
+ "sqadd v22.8h, v22.8h, v14.8h\n"
+ "sqadd v23.8h, v23.8h, v14.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
diff --git a/ruy/profiler/CMakeLists.txt b/ruy/profiler/CMakeLists.txt
index df4b30a..35644fa 100644
--- a/ruy/profiler/CMakeLists.txt
+++ b/ruy/profiler/CMakeLists.txt
@@ -10,6 +10,8 @@ endif()
if(CMAKE_SYSTEM_NAME STREQUAL Windows)
set(ruy_profiler_1_pthread "")
+elseif(CMAKE_SYSTEM_NAME STREQUAL xtensa-esp32-elf)
+ set(ruy_profiler_1_pthread "")
else()
set(ruy_profiler_1_pthread "-pthread")
endif()
diff --git a/ruy/test.h b/ruy/test.h
index 5aa4c41..0b05399 100644
--- a/ruy/test.h
+++ b/ruy/test.h
@@ -122,6 +122,7 @@ inline const char* TuningName(Tuning tuning) {
return #NAME;
switch (tuning) {
RUY_SUBPATHNAME_CASE(kA55ish)
+ RUY_SUBPATHNAME_CASE(kX1)
RUY_SUBPATHNAME_CASE(kGeneric)
default:
RUY_CHECK(false);
@@ -1825,7 +1826,7 @@ inline std::vector<Tuning> EnumerateTuningsForPath(Path path, bool benchmark) {
}
#if RUY_PLATFORM_ARM
if (path == Path::kNeon || path == Path::kNeonDotprod) {
- return {Tuning::kA55ish, Tuning::kGeneric, Tuning::kAuto};
+ return {Tuning::kA55ish, Tuning::kX1, Tuning::kGeneric, Tuning::kAuto};
}
#endif
(void)path;
diff --git a/ruy/test_overflow_dst_zero_point.cc b/ruy/test_overflow_dst_zero_point.cc
new file mode 100644
index 0000000..db1f08d
--- /dev/null
+++ b/ruy/test_overflow_dst_zero_point.cc
@@ -0,0 +1,133 @@
+/* Copyright 2021 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This test covers destination zero_points that cause internal int16 overflow.
+
+// Kernels tend to perform the addition of the destination zero_point in int16.
+// Although this happens after the rescaling to the destination scale, it is
+// still possible for this int16 addition to overflow. This should be handled
+// by saturating, which ensures correct results as the subsequent cast to
+// the destination 8-bit type is saturating anyway, so this second saturation
+// eats any effect of the previous saturation in the int16 addition of the
+// destination zero_point.
+// When this is not correctly saturating, a typical effect is wrapping around
+// to the opposite end of the range of int16, which causes the latter saturation
+// to the int8/uint8 range to saturate to the opposite end of that, resulting
+// in a large numerical difference in the output values.
+
+#include <limits>
+#include <type_traits>
+#include <vector>
+
+#include "ruy/context.h"
+#include "ruy/gtest_wrapper.h"
+#include "ruy/matrix.h"
+#include "ruy/mul_params.h"
+#include "ruy/path.h"
+#include "ruy/ruy.h"
+#include "ruy/test.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+namespace {
+
+// Multiplies an all-zero 1x1 LHS by an all-zero 1xcols RHS, so that the only
+// contribution to each accumulator is the bias. The bias is chosen at the end
+// of the int16 range matching the sign of `dst_zero_point`, so that adding the
+// destination zero point in int16 overflows. Every destination value is then
+// expected to be clamped to the corresponding end of the DstScalar range.
+//
+// context: ruy context passed to ruy::Mul.
+// cols: number of columns of the RHS and destination matrices.
+// dst_zero_point: destination zero point; a positive value targets the upper
+//     end of the range, any other value targets the lower end.
+template <typename DstScalar>
+void TestOverflowingAdditionOfDestinationZeroPoint(ruy::Context* context,
+                                                   int cols,
+                                                   DstScalar dst_zero_point) {
+  // Set the bias value so that the int16 addition of the zero_point will
+  // overflow.
+  const int bias_value = dst_zero_point > 0
+                             ? std::numeric_limits<std::int16_t>::max()
+                             : std::numeric_limits<std::int16_t>::min();
+  // This is the end of the DstScalar range that we expect values will be
+  // clamped to.
+  const int expected_dst_value = dst_zero_point > 0
+                                     ? std::numeric_limits<DstScalar>::max()
+                                     : std::numeric_limits<DstScalar>::min();
+
+  // The element type must not be const-qualified: std::vector requires a
+  // non-const value_type. Making the vector itself const is sufficient, and
+  // data() then yields a pointer-to-const, exactly as for rhs_data.
+  const std::vector<std::int8_t> lhs_data(1, 0);
+  const std::vector<std::int8_t> rhs_data(cols, 0);
+  std::vector<DstScalar> dst_data(cols, 0);
+
+  ruy::MulParams<std::int32_t, DstScalar> mul_params;
+  std::int32_t bias_data[1] = {bias_value};
+  mul_params.set_bias(bias_data);
+  // Set the quantized multiplier to essentially 1 so we get unscaled
+  // accumulators in the output, only clamped.
+  mul_params.set_multiplier_fixedpoint(
+      std::numeric_limits<std::int32_t>::max());
+
+  ruy::Matrix<std::int8_t> lhs;
+  ruy::MakeSimpleLayout(1, 1, ruy::Order::kColMajor, lhs.mutable_layout());
+  lhs.set_data(lhs_data.data());
+
+  ruy::Matrix<std::int8_t> rhs;
+  ruy::MakeSimpleLayout(1, cols, ruy::Order::kColMajor, rhs.mutable_layout());
+  rhs.set_data(rhs_data.data());
+
+  ruy::Matrix<DstScalar> dst;
+  ruy::MakeSimpleLayout(1, cols, ruy::Order::kColMajor, dst.mutable_layout());
+  dst.set_data(dst_data.data());
+  dst.set_zero_point(dst_zero_point);
+
+  ruy::Mul(lhs, rhs, mul_params, context, &dst);
+
+  // Check that the DstScalar overflow was clamped, not wrapped around.
+  for (auto d : dst_data) {
+    EXPECT_EQ(d, expected_dst_value);
+  }
+}
+
+// Runs the overflow check for DstScalar on both a matrix*vector shape
+// (cols == 1) and a general matrix*matrix shape (cols > 1), as these may
+// exercise different kernels. A negative destination zero point is only
+// tried when DstScalar is a signed type.
+template <typename DstScalar>
+void TestOverflowingAdditionOfDestinationZeroPoint(ruy::Context* context) {
+  constexpr bool kDstIsSigned = std::is_signed<DstScalar>::value;
+  for (int num_cols : {1, 8}) {
+    TestOverflowingAdditionOfDestinationZeroPoint<DstScalar>(context, num_cols,
+                                                             1);
+  }
+  if (kDstIsSigned) {
+    for (int num_cols : {1, 8}) {
+      TestOverflowingAdditionOfDestinationZeroPoint<DstScalar>(context,
+                                                               num_cols, -1);
+    }
+  }
+}
+
+// Walks every single-bit Path value, and for each one that is enabled at
+// runtime restricts the context to that path alone. It then runs the overflow
+// checks under each explicit tuning for int8, uint8 and int16 destinations.
+TEST(RuyTest, OverflowingAdditionOfDestinationZeroPoint) {
+  ruy::Context context;
+  // Snapshot the enabled paths up front, since the loop below overwrites them.
+  const ruy::Path initially_enabled = context.get_runtime_enabled_paths();
+  const ruy::Tuning tunings[] = {ruy::Tuning::kGeneric, ruy::Tuning::kA55ish,
+                                 ruy::Tuning::kX1};
+  constexpr auto kNumPathBits = 8 * sizeof(ruy::Path);
+  for (unsigned shift = 0; shift < kNumPathBits; ++shift) {
+    const ruy::Path candidate = static_cast<ruy::Path>(1 << shift);
+    const bool is_enabled =
+        (candidate & initially_enabled) != ruy::Path::kNone;
+    if (is_enabled) {
+      context.set_runtime_enabled_paths(candidate);
+      for (const ruy::Tuning tuning : tunings) {
+        fprintf(stderr, "Testing path %s, tuning %s\n", PathName(candidate),
+                TuningName(tuning));
+        context.set_explicit_tuning(tuning);
+        TestOverflowingAdditionOfDestinationZeroPoint<std::int8_t>(&context);
+        TestOverflowingAdditionOfDestinationZeroPoint<std::uint8_t>(&context);
+        TestOverflowingAdditionOfDestinationZeroPoint<std::int16_t>(&context);
+      }
+    }
+  }
+}
+
+} // namespace
+} // namespace ruy
+
+// Test binary entry point: hands the command line to googletest, then runs
+// every registered test and returns the resulting status as the exit code.
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  const int test_status = RUN_ALL_TESTS();
+  return test_status;
+}