aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJenkins <bsgcomp@arm.com>2024-04-29 10:53:06 +0000
committerJenkins <bsgcomp@arm.com>2024-04-29 10:53:06 +0000
commit4fda7a803eaadf00ba36bd532481a33c18952089 (patch)
treeae280dfbfbfa3e870dfb5fd7e4df80ae50c7e4a8
parentf2eda6665c12d568e179f5b0e7a24ccdc0ac824d (diff)
downloadComputeLibrary-upstream-main.tar.gz
Compute Library v24.04upstream-main
-rw-r--r--Android.bp27
-rw-r--r--CMakeLists.txt4
-rw-r--r--README.md24
-rw-r--r--SConscript25
-rw-r--r--SConstruct3
-rw-r--r--arm_compute/core/QuantizationInfo.h72
-rw-r--r--arm_compute/core/Types.h2
-rw-r--r--arm_compute/core/utils/DataTypeUtils.h26
-rw-r--r--arm_compute/function_info/GEMMInfo.h31
-rw-r--r--arm_compute/function_info/ScatterInfo.h (renamed from compute_kernel_writer/prototype/examples/common/ExampleKernelWriter.cpp)48
-rw-r--r--arm_compute/runtime/CL/CLFunctions.h3
-rw-r--r--arm_compute/runtime/CL/functions/CLScatter.h109
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h15
-rw-r--r--arm_compute/runtime/NEON/functions/NEMatMul.h23
-rw-r--r--cmake/Options.cmake4
-rw-r--r--compute_kernel_writer/prototype/CMakeLists.txt78
-rw-r--r--compute_kernel_writer/prototype/examples/add_exp_store.cpp206
-rw-r--r--compute_kernel_writer/prototype/examples/common/ExampleComponentArgument.cpp98
-rw-r--r--compute_kernel_writer/prototype/examples/common/ExampleComponentArgument.h112
-rw-r--r--compute_kernel_writer/prototype/examples/common/ExampleKernelWriter.h56
-rw-r--r--compute_kernel_writer/prototype/examples/common/ExampleScopedKernelWriter.cpp59
-rw-r--r--compute_kernel_writer/prototype/examples/common/ExampleScopedKernelWriter.h62
-rw-r--r--compute_kernel_writer/prototype/examples/writer_helper.cpp113
-rw-r--r--compute_kernel_writer/prototype/include/ckw/Error.h78
-rw-r--r--compute_kernel_writer/prototype/include/ckw/Kernel.h102
-rw-r--r--compute_kernel_writer/prototype/include/ckw/KernelArgument.h107
-rw-r--r--compute_kernel_writer/prototype/include/ckw/KernelWriter.h338
-rw-r--r--compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h1286
-rw-r--r--compute_kernel_writer/prototype/include/ckw/OperandBase.h78
-rw-r--r--compute_kernel_writer/prototype/include/ckw/ScalarValue.h137
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TensorInfo.h153
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TensorOperand.h196
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h169
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TileInfo.h92
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TileOperand.h127
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h41
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/DataType.h50
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/Functions.h62
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/GpuTargetLanguage.h41
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/Operators.h78
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/TensorSamplerTypes.h82
-rw-r--r--compute_kernel_writer/prototype/src/Kernel.cpp163
-rw-r--r--compute_kernel_writer/prototype/src/KernelArgument.cpp66
-rw-r--r--compute_kernel_writer/prototype/src/KernelWriter.cpp371
-rw-r--r--compute_kernel_writer/prototype/src/OperandBase.cpp49
-rw-r--r--compute_kernel_writer/prototype/src/Prototype.h4189
-rw-r--r--compute_kernel_writer/prototype/src/TensorInfo.cpp77
-rw-r--r--compute_kernel_writer/prototype/src/TensorOperand.cpp272
-rw-r--r--compute_kernel_writer/prototype/src/TensorTileSampler.cpp191
-rw-r--r--compute_kernel_writer/prototype/src/TileInfo.cpp73
-rw-r--r--compute_kernel_writer/prototype/src/TileOperand.cpp135
-rw-r--r--docs/Doxyfile2
-rw-r--r--docs/user_guide/operator_list.dox4
-rw-r--r--docs/user_guide/release_version_and_change_log.dox11
-rw-r--r--examples/CMakeLists.txt8
-rw-r--r--examples/SConscript17
-rw-r--r--examples/neon_gemm_s8_f32.cpp239
-rw-r--r--filelist.json38
-rwxr-xr-xscripts/clang_tidy_rules.py6
-rwxr-xr-xscripts/format_code.py8
-rwxr-xr-xscripts/generate_android_bp.py5
-rw-r--r--scripts/generate_build_files.py4
-rw-r--r--src/BUILD.bazel11
-rw-r--r--src/CMakeLists.txt11
-rw-r--r--src/common/cpuinfo/CpuIsaInfo.cpp4
-rw-r--r--src/core/CL/cl_kernels/common/elementwise_operation.cl6
-rw-r--r--src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp4
-rw-r--r--src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp4
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp79
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp15
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp8
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp24
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp8
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int8.cpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp137
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp142
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemv_batched.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp19
-rw-r--r--src/core/NEON/kernels/arm_gemm/interleave-8way.cpp267
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp111
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp3240
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_gemv_fp16fp32fp16_dot_16VL/generic.cpp530
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp93
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp417
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp93
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp448
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp93
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp513
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp8
-rw-r--r--src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp3
-rw-r--r--src/core/NEON/kernels/arm_gemm/mergeresults.cpp7
-rw-r--r--src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp32_bf16_8x12.hpp2809
-rw-r--r--src/core/NEON/kernels/arm_gemm/merges/list-sve.hpp3
-rw-r--r--src/core/NEON/kernels/arm_gemm/merges/list.hpp3
-rw-r--r--src/core/NEON/kernels/arm_gemm/merges/sve_merge_fp32_bf16_8x3VL.hpp2137
-rw-r--r--src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp8
-rw-r--r--src/core/NEON/kernels/arm_gemm/quantized.cpp60
-rw-r--r--src/core/NEON/kernels/arm_gemm/quantized.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp9
-rw-r--r--src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp87
-rw-r--r--src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp9
-rw-r--r--src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp9
-rw-r--r--src/core/NEON/kernels/arm_gemm/transform.cpp11
-rw-r--r--src/core/common/Registrars.h16
-rw-r--r--src/core/utils/helpers/tensor_transform.cpp7
-rw-r--r--src/core/utils/quantization/AsymmHelpers.cpp16
-rw-r--r--src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp324
-rw-r--r--src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h41
-rw-r--r--src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp21
-rw-r--r--src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h26
-rw-r--r--src/cpu/kernels/CpuKernelSelectionTypes.h3
-rw-r--r--src/cpu/kernels/CpuQuantizeKernel.cpp170
-rw-r--r--src/cpu/kernels/CpuQuantizeKernel.h25
-rw-r--r--src/cpu/kernels/CpuSoftmaxKernel.cpp72
-rw-r--r--src/cpu/kernels/CpuSoftmaxKernel.h12
-rw-r--r--src/cpu/kernels/assembly/arm_gemm.hpp24
-rw-r--r--src/cpu/kernels/assembly/gemm_common.hpp47
-rw-r--r--src/cpu/kernels/elementwise_binary/generic/neon/impl.h26
-rw-r--r--src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp6
-rw-r--r--src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp22
-rw-r--r--src/cpu/kernels/softmax/generic/neon/fp16.cpp22
-rw-r--r--src/cpu/kernels/softmax/generic/neon/fp32.cpp22
-rw-r--r--src/cpu/kernels/softmax/generic/neon/impl.cpp353
-rw-r--r--src/cpu/kernels/softmax/generic/neon/impl.h197
-rw-r--r--src/cpu/kernels/softmax/generic/neon/qasymm8.cpp22
-rw-r--r--src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp17
-rw-r--r--src/cpu/kernels/softmax/generic/sme2/fp16.cpp774
-rw-r--r--src/cpu/kernels/softmax/generic/sme2/fp32.cpp578
-rw-r--r--src/cpu/kernels/softmax/list.h14
-rw-r--r--src/cpu/operators/CpuConv2d.cpp22
-rw-r--r--src/cpu/operators/CpuGemm.cpp13
-rw-r--r--src/cpu/operators/CpuGemmConv2d.cpp13
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp94
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h9
-rw-r--r--src/cpu/operators/CpuMatMul.cpp28
-rw-r--r--src/cpu/operators/CpuQuantize.cpp5
-rw-r--r--src/cpu/operators/CpuSoftmax.cpp90
-rw-r--r--src/cpu/operators/CpuSoftmax.h9
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp149
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.h3
-rw-r--r--src/dynamic_fusion/runtime/gpu/cl/ClKernelRuntime.cpp139
-rw-r--r--src/dynamic_fusion/runtime/gpu/cl/ClKernelRuntime.h24
-rw-r--r--src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h99
-rw-r--r--src/dynamic_fusion/sketch/gpu/GpuKernelSourceCode.h16
-rw-r--r--src/dynamic_fusion/sketch/gpu/GpuLogicalKernel.cpp22
-rw-r--r--src/dynamic_fusion/sketch/gpu/GpuLogicalKernel.h10
-rw-r--r--src/dynamic_fusion/sketch/gpu/GpuWorkloadSourceCode.h71
-rw-r--r--src/dynamic_fusion/sketch/gpu/IGpuKernelWriter.h17
-rw-r--r--src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.h3
-rw-r--r--src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwStore.h10
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h12
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.cpp14
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.h22
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.cpp16
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.h22
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.cpp14
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.h19
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp15
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h22
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp14
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.h23
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.cpp93
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.h130
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.cpp95
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.h127
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp5
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.cpp12
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.h16
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.cpp10
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.h13
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.cpp19
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.h20
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.cpp17
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.h22
-rw-r--r--src/dynamic_fusion/sketch/gpu/operators/GpuClamp.cpp3
-rw-r--r--src/dynamic_fusion/sketch/gpu/operators/GpuMatMul.cpp4
-rw-r--r--src/dynamic_fusion/sketch/gpu/operators/GpuSoftmax.cpp11
-rw-r--r--src/dynamic_fusion/sketch/gpu/operators/GpuTanh.cpp3
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp114
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h135
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h140
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.cpp181
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.h120
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp212
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.h103
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.cpp364
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.h112
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.cpp393
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.h116
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp274
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.h115
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.cpp267
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.h107
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.cpp171
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.h106
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp470
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.h132
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.cpp161
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.h107
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.cpp279
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.h120
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.cpp89
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.h86
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp325
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.h92
-rw-r--r--src/gpu/cl/kernels/ClScatterKernel.cpp78
-rw-r--r--src/gpu/cl/kernels/ClScatterKernel.h79
-rw-r--r--src/gpu/cl/operators/ClScatter.cpp93
-rw-r--r--src/gpu/cl/operators/ClScatter.h96
-rw-r--r--src/runtime/CL/functions/CLScatter.cpp93
-rw-r--r--support/Bfloat16.h18
-rw-r--r--tests/CMakeLists.txt3
-rw-r--r--tests/SConscript5
-rw-r--r--tests/datasets/LargeConvolutionLayerDataset.h12
-rw-r--r--tests/datasets/LargeGEMMDataset.h21
-rw-r--r--tests/datasets/ScatterDataset.h128
-rw-r--r--tests/datasets/SmallGEMMDataset.h19
-rw-r--r--tests/validation/CL/DepthwiseConvolutionLayer.cpp248
-rw-r--r--tests/validation/CL/GEMMLowp.cpp13
-rw-r--r--tests/validation/CL/ScatterLayer.cpp116
-rw-r--r--tests/validation/CPP/DFT.cpp4
-rw-r--r--tests/validation/Helpers.h45
-rw-r--r--tests/validation/NEON/ConvolutionLayer.cpp96
-rw-r--r--tests/validation/NEON/DepthwiseConvolutionLayer.cpp479
-rw-r--r--tests/validation/NEON/DilatedConvolutionLayer.cpp4
-rw-r--r--tests/validation/NEON/ElementwiseDivision.cpp6
-rw-r--r--tests/validation/NEON/FullyConnectedLayer.cpp6
-rw-r--r--tests/validation/NEON/GEMM.cpp145
-rw-r--r--tests/validation/NEON/GEMMLowp.cpp125
-rw-r--r--tests/validation/NEON/LSTMLayerQuantized.cpp6
-rw-r--r--tests/validation/NEON/MatMul.cpp408
-rw-r--r--tests/validation/NEON/PoolingLayer.cpp16
-rw-r--r--tests/validation/NEON/QuantizationLayer.cpp30
-rw-r--r--tests/validation/NEON/RNNLayer.cpp4
-rw-r--r--tests/validation/NEON/SoftmaxLayer.cpp37
-rw-r--r--tests/validation/dynamic_fusion/gpu/Integration.cpp10
-rw-r--r--tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp10
-rw-r--r--tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp6
-rw-r--r--tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp3
-rw-r--r--tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp3
-rw-r--r--tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp16
-rw-r--r--tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp184
-rw-r--r--tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h159
-rw-r--r--tests/validation/fixtures/FullyConnectedLayerFixture.h4
-rw-r--r--tests/validation/fixtures/GEMMFixture.h60
-rw-r--r--tests/validation/fixtures/GEMMLowpFixture.h224
-rw-r--r--tests/validation/fixtures/MatMulFixture.h383
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h174
-rw-r--r--tests/validation/reference/ActivationLayer.cpp27
-rw-r--r--tests/validation/reference/ActivationLayer.h23
-rw-r--r--tests/validation/reference/DepthConvertLayer.cpp4
-rw-r--r--tests/validation/reference/ElementwiseOperations.cpp11
-rw-r--r--tests/validation/reference/GEMM.cpp79
-rw-r--r--tests/validation/reference/GEMM.h11
-rw-r--r--tests/validation/reference/GEMMLowp.cpp12
-rw-r--r--tests/validation/reference/GEMMLowp.h11
-rw-r--r--tests/validation/reference/Permute.cpp18
-rw-r--r--tests/validation/reference/QuantizationLayer.cpp12
-rw-r--r--tests/validation/reference/ReshapeLayer.cpp15
-rw-r--r--tests/validation/reference/ScatterLayer.cpp113
-rw-r--r--tests/validation/reference/ScatterLayer.h (renamed from src/dynamic_fusion/sketch/gpu/GpuKernelArgument.cpp)29
-rw-r--r--utils/GraphUtils.cpp6
-rw-r--r--utils/TypePrinter.h74
272 files changed, 18076 insertions, 17274 deletions
diff --git a/Android.bp b/Android.bp
index 2983e2e21..6cc85f192 100644
--- a/Android.bp
+++ b/Android.bp
@@ -172,6 +172,7 @@ cc_library_static {
proprietary: true,
local_include_dirs: ["build/android-arm64v8a/src/core",
"build/android-arm64v8a/src/core/CL",
+ "compute_kernel_writer/include",
"src/core/common",
"src/core/helpers",
"src/core/NEON/kernels/arm_gemm",
@@ -323,14 +324,17 @@ cc_library_static {
"src/core/NEON/kernels/arm_conv/pooling/pooling_u8.cpp",
"src/core/NEON/kernels/arm_conv/pooling/pooling_u8q.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp",
+ "src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_int16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_int8.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp",
+ "src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp",
+ "src/core/NEON/kernels/arm_gemm/interleave-8way.cpp",
"src/core/NEON/kernels/arm_gemm/interleave_indirect-sve.cpp",
"src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp",
"src/core/NEON/kernels/arm_gemm/mergeresults-fp16.cpp",
@@ -621,7 +625,6 @@ cc_library_static {
"src/dynamic_fusion/sketch/attributes/ReshapeAttributes.cpp",
"src/dynamic_fusion/sketch/attributes/ResizeAttributes.cpp",
"src/dynamic_fusion/sketch/attributes/SoftmaxAttributes.cpp",
- "src/dynamic_fusion/sketch/gpu/GpuKernelArgument.cpp",
"src/dynamic_fusion/sketch/gpu/GpuKernelComponentGraph.cpp",
"src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.cpp",
"src/dynamic_fusion/sketch/gpu/GpuKernelComponentStream.cpp",
@@ -634,8 +637,6 @@ cc_library_static {
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.cpp",
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp",
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp",
- "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.cpp",
- "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.cpp",
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp",
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.cpp",
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.cpp",
@@ -657,19 +658,6 @@ cc_library_static {
"src/dynamic_fusion/sketch/gpu/operators/GpuSub.cpp",
"src/dynamic_fusion/sketch/gpu/operators/GpuTanh.cpp",
"src/dynamic_fusion/sketch/gpu/operators/internal/GpuElementwiseBinaryCommon.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp",
"src/gpu/cl/ClContext.cpp",
"src/gpu/cl/ClKernelLibrary.cpp",
"src/gpu/cl/ClQueue.cpp",
@@ -720,6 +708,7 @@ cc_library_static {
"src/gpu/cl/kernels/ClQuantizeKernel.cpp",
"src/gpu/cl/kernels/ClReshapeKernel.cpp",
"src/gpu/cl/kernels/ClScaleKernel.cpp",
+ "src/gpu/cl/kernels/ClScatterKernel.cpp",
"src/gpu/cl/kernels/ClSoftmaxKernel.cpp",
"src/gpu/cl/kernels/ClTransposeKernel.cpp",
"src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp",
@@ -771,6 +760,7 @@ cc_library_static {
"src/gpu/cl/operators/ClQuantize.cpp",
"src/gpu/cl/operators/ClReshape.cpp",
"src/gpu/cl/operators/ClScale.cpp",
+ "src/gpu/cl/operators/ClScatter.cpp",
"src/gpu/cl/operators/ClSoftmax.cpp",
"src/gpu/cl/operators/ClSub.cpp",
"src/gpu/cl/operators/ClTranspose.cpp",
@@ -869,6 +859,7 @@ cc_library_static {
"src/runtime/CL/functions/CLReshapeLayer.cpp",
"src/runtime/CL/functions/CLReverse.cpp",
"src/runtime/CL/functions/CLScale.cpp",
+ "src/runtime/CL/functions/CLScatter.cpp",
"src/runtime/CL/functions/CLSelect.cpp",
"src/runtime/CL/functions/CLSlice.cpp",
"src/runtime/CL/functions/CLSoftmaxLayer.cpp",
@@ -1224,6 +1215,7 @@ cc_library_static {
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp",
@@ -1311,6 +1303,9 @@ cc_library_static {
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL/generic.cpp",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0462c2f08..c67479ce4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -28,7 +28,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
list(APPEND CMAKE_MESSAGE_CONTEXT ArmCompute)
project(
ArmCompute
- VERSION 35.0.1
+ VERSION 36.0.0
DESCRIPTION
"The Arm Compute Library is a collection of low-level machine learning functions optimized for Arm® Cortex®-A CPU and Arm® Mali™ GPU architectures"
LANGUAGES C CXX ASM)
@@ -57,7 +57,7 @@ endif()
# ---------------------------------------------------------------------
# Configuration
-set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -gdwarf-2 -DARM_COMPUTE_ASSERTS_ENABLED")
+set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -gdwarf-2 -DARM_COMPUTE_ASSERTS_ENABLED -DARM_COMPUTE_DEBUG_ENABLED")
set(CMAKE_CXX_FLAGS_RELEASE "-O3")
# Default to Release Build
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
diff --git a/README.md b/README.md
index 51d4dfbe6..112f40225 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
<img src="https://raw.githubusercontent.com/ARM-software/ComputeLibrary/gh-pages/ACL_logo.png"/><br><br>
</div>
-# Compute Library ![](https://img.shields.io/badge/latest_release-24.02.1-green)
+# Compute Library ![](https://img.shields.io/badge/latest_release-24.04-green)
The Compute Library is a collection of low-level machine learning functions optimized for Arm® Cortex®-A, Arm® Neoverse® and Arm® Mali™ GPUs architectures.<br>
@@ -37,7 +37,7 @@ Key Features:
<br>
## Documentation
-[![Documentation](https://img.shields.io/badge/documentation-24.02.1-green)](https://arm-software.github.io/ComputeLibrary/latest)
+[![Documentation](https://img.shields.io/badge/documentation-24.04-green)](https://arm-software.github.io/ComputeLibrary/latest)
> Note: The documentation includes the reference API, changelogs, build guide, contribution guide, errata, etc.
@@ -50,24 +50,24 @@ All the binaries can be downloaded from [here](https://github.com/ARM-software/C
| Platform | Operating System | Release archive (Download) |
| -------------- | ---------------- | -------------------------- |
-| Raspberry Pi 4 | Linux® 32bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-armv7a-neon.tar.gz) |
-| Raspberry Pi 4 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon.tar.gz) |
-| Odroid N2 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon-cl.tar.gz) |
-| HiKey960 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon-cl.tar.gz) |
+| Raspberry Pi 4 | Linux® 32bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-armv7a-neon.tar.gz) |
+| Raspberry Pi 4 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon.tar.gz) |
+| Odroid N2 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon-cl.tar.gz) |
+| HiKey960 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon-cl.tar.gz) |
<br>
| Architecture | Operating System | Release archive (Download) |
| ------------ | ---------------- | -------------------------- |
-| armv7 | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-armv7a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-armv7a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-armv7a-neon-cl.tar.gz) |
-| arm64-v8a | Androidâ„¢ | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8a-neon-cl.tar.gz) |
-| arm64-v8a | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon-cl.tar.gz) |
-| arm64-v8.2-a | Androidâ„¢ | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8.2-a-neon-cl.tar.gz) |
-| arm64-v8.2-a | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8.2-a-neon-cl.tar.gz) |
+| armv7 | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-armv7a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-armv7a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-armv7a-neon-cl.tar.gz) |
+| arm64-v8a | Androidâ„¢ | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8a-neon-cl.tar.gz) |
+| arm64-v8a | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon-cl.tar.gz) |
+| arm64-v8.2-a | Androidâ„¢ | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8.2-a-neon-cl.tar.gz) |
+| arm64-v8.2-a | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8.2-a-neon-cl.tar.gz) |
<br>
-Please refer to the following link for more pre-built binaries: [![](https://img.shields.io/badge/v24.02.1-bins-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/tag/v24.02.1)
+Please refer to the following link for more pre-built binaries: [![](https://img.shields.io/badge/v24.04-bins-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/tag/v24.04)
Pre-build binaries are generated with the following security / good coding practices related flags:
> -Wall, -Wextra, -Wformat=2, -Winit-self, -Wstrict-overflow=2, -Wswitch-default, -Woverloaded-virtual, -Wformat-security, -Wctor-dtor-privacy, -Wsign-promo, -Weffc++, -pedantic, -fstack-protector-strong
diff --git a/SConscript b/SConscript
index cf4ab39e8..3430ecf4e 100644
--- a/SConscript
+++ b/SConscript
@@ -30,17 +30,25 @@ import subprocess
import zlib
import json
import codecs
+import platform
-VERSION = "v24.02.1"
-LIBRARY_VERSION_MAJOR = 35
+VERSION = "v24.04"
+LIBRARY_VERSION_MAJOR = 36
LIBRARY_VERSION_MINOR = 0
-LIBRARY_VERSION_PATCH = 1
+LIBRARY_VERSION_PATCH = 0
SONAME_VERSION = str(LIBRARY_VERSION_MAJOR) + "." + str(LIBRARY_VERSION_MINOR) + "." + str(LIBRARY_VERSION_PATCH)
Import('env')
Import('vars')
Import('install_lib')
+# Workaround to enable cross-compiling from macOS® to Android™ using the Android NDK.
+if platform.system() == 'Darwin' and env['os'] == 'android':
+ # SCons incorrectly assumes that we always want to build a dynamic library on a macOS host.
+ # When targeting Android, we overwrite the following construction variables to build a shared library instead.
+ env.Replace(SHLIBSUFFIX = '.so') # overwrites .dylib
+ env.Replace(SHLINKFLAGS = ['$LINKFLAGS', '-shared']) # overwrites -dynamiclib
+
def build_bootcode_objs(sources):
arm_compute_env.Append(ASFLAGS = "-I bootcode/")
obj = arm_compute_env.Object(sources)
@@ -564,12 +572,6 @@ if env['fixed_format_kernels']:
# Dynamic fusion
if env['experimental_dynamic_fusion']:
lib_files += filelist['experimental']['dynamic_fusion']['common']
- lib_files += filelist['experimental']['dynamic_fusion']['template_writer']
-
-if "ACL_INTERNAL_TEST_CKW_IN_DF" in env["extra_cxx_flags"]:
- if not env["experimental_dynamic_fusion"]:
- print("To use ACL_INTERNAL_TEST_CKW_IN_DF experimental_dynamic_fusion must be set to 1")
- Exit(1)
lib_files += filelist['experimental']['dynamic_fusion']['ckw_driver']
# Logging files
@@ -722,10 +724,7 @@ Export('bootcode_o')
if (env['multi_isa']):
lib_static_objs, lib_shared_objs = build_multiisa_lib_objects()
-
-
-# STATIC library build.
-if (env['multi_isa']):
+ # STATIC library build.
arm_compute_a = build_library('arm_compute-static', arm_compute_env, lib_static_objs, static=True)
else:
if 'sve2' in env['arch']:
diff --git a/SConstruct b/SConstruct
index 6f498b51c..bad85e503 100644
--- a/SConstruct
+++ b/SConstruct
@@ -227,9 +227,6 @@ if env['experimental_dynamic_fusion']:
# Dynamic Fusion on GPU has a direct dependency on OpenCL and Compute Kernel Writer
env['opencl'] = 1
- # Build CKW by default
- env["extra_cxx_flags"] += ' -DACL_INTERNAL_TEST_CKW_IN_DF'
-
if env['opencl'] and env['embed_kernels'] and env['compress_kernels'] and env['os'] not in ['android']:
print("Compressed kernels are supported only for android builds")
Exit(1)
diff --git a/arm_compute/core/QuantizationInfo.h b/arm_compute/core/QuantizationInfo.h
index 471b8c57a..aecba3712 100644
--- a/arm_compute/core/QuantizationInfo.h
+++ b/arm_compute/core/QuantizationInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2023 Arm Limited.
+ * Copyright (c) 2019-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_QUANTIZATION_INFO_H
-#define ARM_COMPUTE_QUANTIZATION_INFO_H
+#ifndef ACL_ARM_COMPUTE_CORE_QUANTIZATIONINFO_H
+#define ACL_ARM_COMPUTE_CORE_QUANTIZATIONINFO_H
#include "arm_compute/core/Rounding.h"
#include "arm_compute/core/utils/misc/Utility.h"
@@ -84,10 +84,12 @@ public:
*
* @note Used for asymmetric quantization
*
- * @param[in] scale Scale.
- * @param[in] offset Offset.
+ * @param[in] scale Scale.
+ * @param[in] offset Offset.
+ * @param[in] is_dynamic Whether this QuantizationInfo is dynamic, i.e. the scale and offset may change.
*/
- QuantizationInfo(float scale, int offset) : _scale(1, scale), _offset(1, offset)
+ QuantizationInfo(float scale, int offset, bool is_dynamic = false)
+ : _scale(1, scale), _offset(1, offset), _is_dynamic(is_dynamic)
{
}
/** Construct quantization info.
@@ -103,10 +105,12 @@ public:
*
* @note Used for asymmetric per channel quantization
*
- * @param[in] scale Scale.
- * @param[in] offset Offset.
+ * @param[in] scale Scale.
+ * @param[in] offset Offset.
+ * @param[in] is_dynamic Whether this QuantizationInfo is dynamic, i.e. the scale and offset may change.
*/
- QuantizationInfo(std::vector<float> scale, std::vector<int32_t> offset) : _scale(scale), _offset(offset)
+ QuantizationInfo(std::vector<float> scale, std::vector<int32_t> offset, bool is_dynamic = false)
+ : _scale(scale), _offset(offset), _is_dynamic(is_dynamic)
{
}
/** Scale vector accessor
@@ -125,6 +129,14 @@ public:
{
return _offset;
}
+ /** is_dynamic accessor
+ *
+ * @return If true, the scale and offset may change, so operators will need to read on every run
+ */
+ bool is_dynamic() const
+ {
+ return _is_dynamic;
+ }
/** Indicates whether this QuantizationInfo has valid settings or not
*
* @return True if the this has invalid settings.
@@ -149,6 +161,8 @@ public:
private:
std::vector<float> _scale; /**< Vector containing scaling factors */
std::vector<int32_t> _offset; /**< Vector containing zero offsets */
+ bool _is_dynamic =
+ false; /**< If true, the scale and offset may change, so operators will need to read on every run */
};
/** Check whether two quantization info are equal.
@@ -430,6 +444,19 @@ inline float dequantize(uint16_t value, float scale, int32_t offset)
return (static_cast<int>(value) - offset) * scale;
}
+/** Dequantize a value given a 32-bit asymmetric quantization scheme
+ *
+ * @param[in] value Value to dequantize
+ * @param[in] scale Scale to use for dequantization
+ * @param[in] offset Zero-offset to use for dequantization
+ *
+ * @return Dequantized value
+ */
+inline float dequantize(int32_t value, float scale, int32_t offset)
+{
+ return (static_cast<int>(value) - offset) * scale;
+}
+
/** Quantize a value given a 16-bit symmetric quantization scheme
*
* @param[in] value Value to quantize
@@ -536,6 +563,31 @@ inline float dequantize_qasymm16(uint16_t value, const QuantizationInfo &qinfo)
return dequantize_qasymm16(value, qinfo.uniform());
}
+/** Dequantize a value given a 32-bit asymmetric quantization scheme
+ *
+ * @param[in] value Value to dequantize
+ * @param[in] qinfo Quantization information to use for dequantizing
+ *
+ * @return Dequantized value
+ */
+inline float dequantize_s32(int32_t value, const UniformQuantizationInfo &qinfo)
+{
+ return (static_cast<int>(value) - qinfo.offset) * qinfo.scale;
+}
+
+/** Dequantize a value given a 32-bit asymmetric quantization scheme
+ *
+ * @param[in] value Value to dequantize
+ * @param[in] qinfo Quantization information to use for dequantizing
+ *
+ * @return Dequantized value
+ */
+
+inline float dequantize_s32(int32_t value, const QuantizationInfo &qinfo)
+{
+ return dequantize_s32(value, qinfo.uniform());
+}
+
/*
* In case of requantization of a quantized input tensor to an output tensor with another quantization
* instead of applying dequantization and then a quantization functions, we just compute new scale and
@@ -581,4 +633,4 @@ inline UniformQuantizationInfo compute_requantization_scale_offset(const Uniform
}
} // namespace arm_compute
-#endif /* ARM_COMPUTE_QUANTIZATION_INFO_H */
+#endif // ACL_ARM_COMPUTE_CORE_QUANTIZATIONINFO_H
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 6b51af17d..f2f60c150 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
diff --git a/arm_compute/core/utils/DataTypeUtils.h b/arm_compute/core/utils/DataTypeUtils.h
index 7ea5eb767..6fabb19b6 100644
--- a/arm_compute/core/utils/DataTypeUtils.h
+++ b/arm_compute/core/utils/DataTypeUtils.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H
-#define ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H
+#ifndef ACL_ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H
+#define ACL_ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H
#include "arm_compute/core/PixelValue.h"
#include "arm_compute/core/Types.h"
@@ -373,6 +373,24 @@ inline bool is_data_type_quantized_asymmetric_signed(DataType dt)
}
}
+/** Check if a given data type is of 8-bit asymmetric quantized signed type
+ *
+ * @param[in] dt Input data type.
+ *
+ * @return True if data type is of 8-bit asymmetric quantized signed type, else false.
+ */
+inline bool is_data_type_quantized_asymmetric_char(DataType dt)
+{
+ switch (dt)
+ {
+ case DataType::QASYMM8_SIGNED:
+ case DataType::QASYMM8:
+ return true;
+ default:
+ return false;
+ }
+}
+
/** Check if a given data type is of symmetric quantized type
*
* @param[in] dt Input data type.
@@ -528,4 +546,4 @@ inline std::string cpu_impl_dt(const DataType &data_type)
}
} // namespace arm_compute
-#endif /*ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H */
+#endif // ACL_ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H
diff --git a/arm_compute/function_info/GEMMInfo.h b/arm_compute/function_info/GEMMInfo.h
index a827c79fd..74fe30454 100644
--- a/arm_compute/function_info/GEMMInfo.h
+++ b/arm_compute/function_info/GEMMInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -85,7 +85,8 @@ public:
_pretranspose_B(false),
_activation_info(),
_fixed_format(false),
- _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
+ _weight_format(arm_compute::WeightFormat::UNSPECIFIED),
+ _accumulate(false)
{
}
/** Constructor
@@ -106,6 +107,7 @@ public:
* @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in amemory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
* @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
* @param[in] pretranspose_B (Optional) Pretranspose matrix B (transposition of its lowest 2 dimensions), in addition to and before, any further transformations of B
+ * @param[in] accumulate (Optional) Whether to accumulate in destination or not
*/
GEMMInfo(bool is_a_reshaped,
bool is_b_reshaped,
@@ -120,7 +122,8 @@ public:
const ActivationLayerInfo &activation_info = ActivationLayerInfo(),
bool fixed_format = false,
arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED,
- bool pretranspose_B = false) noexcept
+ bool pretranspose_B = false,
+ bool accumulate = false) noexcept
: _is_a_reshaped(is_a_reshaped),
_is_b_reshaped(is_b_reshaped),
_reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -135,7 +138,8 @@ public:
_pretranspose_B(pretranspose_B),
_activation_info(activation_info),
_fixed_format(fixed_format),
- _weight_format(weight_format)
+ _weight_format(weight_format),
+ _accumulate(accumulate)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -294,7 +298,14 @@ public:
{
return _fixed_format;
}
-
+ /** Flag which specifies if GEMM should accumulate the result in destination or not.
+ *
+ * @return True if GEMM is accumulating the result.
+ */
+ bool accumulate() const
+ {
+ return _accumulate;
+ }
/** Set fixed-format flag
*
* @param[in] fixed_format sets whether or not to use fixed-format kernels
@@ -303,12 +314,19 @@ public:
{
_fixed_format = fixed_format;
}
+ /** Set accumulate flag
+ *
+ * @param[in] accumulate sets whether or not to use accumulation
+ */
+ void set_accumulate(bool accumulate)
+ {
+ _accumulate = accumulate;
+ }
arm_compute::WeightFormat weight_format() const
{
return _weight_format;
}
-
/** Set weight format to be used
*
* @param[in] weight_format arm_compute::WeightFormat enumeration
@@ -334,6 +352,7 @@ private:
ActivationLayerInfo _activation_info;
bool _fixed_format;
arm_compute::WeightFormat _weight_format;
+ bool _accumulate;
};
} //namespace arm_compute
#endif // ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H
diff --git a/compute_kernel_writer/prototype/examples/common/ExampleKernelWriter.cpp b/arm_compute/function_info/ScatterInfo.h
index 1734ce882..176a863ac 100644
--- a/compute_kernel_writer/prototype/examples/common/ExampleKernelWriter.cpp
+++ b/arm_compute/function_info/ScatterInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -22,31 +22,33 @@
* SOFTWARE.
*/
-#include "ExampleKernelWriter.h"
+#ifndef ACL_ARM_COMPUTE_FUNCTION_INFO_SCATTERINFO_H
+#define ACL_ARM_COMPUTE_FUNCTION_INFO_SCATTERINFO_H
-#include "ckw/Error.h"
-#include "ckw/TileInfo.h"
+#include "arm_compute/core/Error.h"
-#include "ExampleComponentArgument.h"
-
-ExampleKernelWriter::ExampleKernelWriter(ckw::Kernel &kernel) : KernelWriter(kernel)
+namespace arm_compute
{
-}
-
-void ExampleKernelWriter::op_load_once(ExampleComponentArgument *tensor_or_tile, const ckw::TensorTileSampler &sampler)
+/** Scatter Function */
+enum class ScatterFunction
{
- if (!tensor_or_tile->has_tile())
+ Update = 0,
+ Add = 1,
+ Sub = 2,
+ Max = 3,
+ Min = 4
+};
+/** Scatter operator information */
+struct ScatterInfo
+{
+ ScatterInfo(ScatterFunction f, bool zero) : func(f), zero_initialization(zero)
{
- CKW_ASSERT(tensor_or_tile->has_tensor());
-
- auto &tensor = tensor_or_tile->tensor();
-
- const auto tile_name = tensor.name() + "_tile";
- auto &tile =
- declare_tile(tile_name.c_str(), ckw::TileInfo(tensor.data_type(), sampler.height(), sampler.width()));
-
- op_load(tile, tensor, sampler);
-
- tensor_or_tile->init_virtual_tensor(tile, sampler);
+ ARM_COMPUTE_ERROR_ON_MSG(f != ScatterFunction::Add && zero,
+ "Zero initialisation is only supported with Add Scatter Function.");
}
-}
+ ScatterFunction func; /**< Type of scatter function to use with scatter operator*/
+ bool zero_initialization{false}; /**< Fill output tensors with 0. Only available with add scatter function. */
+};
+} // namespace arm_compute
+
+#endif // ACL_ARM_COMPUTE_FUNCTION_INFO_SCATTERINFO_H
diff --git a/arm_compute/runtime/CL/CLFunctions.h b/arm_compute/runtime/CL/CLFunctions.h
index cf757239c..a09ca551d 100644
--- a/arm_compute/runtime/CL/CLFunctions.h
+++ b/arm_compute/runtime/CL/CLFunctions.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -101,6 +101,7 @@
#include "arm_compute/runtime/CL/functions/CLROIAlignLayer.h"
#include "arm_compute/runtime/CL/functions/CLROIPoolingLayer.h"
#include "arm_compute/runtime/CL/functions/CLScale.h"
+#include "arm_compute/runtime/CL/functions/CLScatter.h"
#include "arm_compute/runtime/CL/functions/CLSelect.h"
#include "arm_compute/runtime/CL/functions/CLSlice.h"
#include "arm_compute/runtime/CL/functions/CLSoftmaxLayer.h"
diff --git a/arm_compute/runtime/CL/functions/CLScatter.h b/arm_compute/runtime/CL/functions/CLScatter.h
new file mode 100644
index 000000000..1c90d208b
--- /dev/null
+++ b/arm_compute/runtime/CL/functions/CLScatter.h
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H
+#define ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/runtime/IFunction.h"
+
+#include <memory>
+
+namespace arm_compute
+{
+class ICLTensor;
+class ITensorInfo;
+struct ScatterInfo;
+class CLCompileContext;
+
+/** Function to compute ScatterND Layer */
+class CLScatter : public IFunction
+{
+public:
+ /** Default Constructor */
+ CLScatter();
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLScatter(const CLScatter &) = delete;
+ /** Default move constructor */
+ CLScatter(CLScatter &&);
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLScatter &operator=(const CLScatter &) = delete;
+ /** Default move assignment operator */
+ CLScatter &operator=(CLScatter &&);
+ /** Default destructor */
+ ~CLScatter();
+ /** Initialise the kernel's inputs and outputs
+ *
+ * Valid data layouts:
+ * - All
+ *
+ *
+ * @param[in] compile_context The compile context to be used.
+ * @param[in] src Source tensor. Values used to fill output. Can be nullptr when zero initialization is true.
+ * @param[in] updates Tensor containing values used to update output tensor. Data types supported: same as @p src
+ * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : U32
+ * @param[out] output Destination tensor. Data types supported: same as @p src.
+ * @param[in] info Scatter info object.
+ */
+ void configure(const CLCompileContext &compile_context,
+ const ICLTensor *src,
+ const ICLTensor *updates,
+ const ICLTensor *indices,
+ ICLTensor *output,
+ const ScatterInfo &info);
+ /** Initialise the kernel's inputs and output
+ *
+ * Similar to @ref CLScatter::configure()
+ */
+ void configure(const ICLTensor *src,
+ const ICLTensor *updates,
+ const ICLTensor *indices,
+ ICLTensor *output,
+ const ScatterInfo &info);
+ /** Static function to check if given info will lead to a valid configuration of @ref CLScatter
+ *
+ * @param[in] src Source tensor.
+ * @param[in] updates Tensor containing values used for updating the output Tensor. Data types supported : same as @p src
+ * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : U32
+ * @param[in] output Destination tensor. Data types supported: same as @p src.
+ * @param[in] info Scatter info containing type of scatter.
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *output,
+ const ScatterInfo &info);
+
+ // Inherited methods overridden:
+ void run() override;
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> _impl;
+};
+} // namespace arm_compute
+
+#endif // ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
index 824c4443a..6d07675d3 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021, 2023 Arm Limited.
+ * Copyright (c) 2017-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H
-#define ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H
+#ifndef ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEGEMMLOWPMATRIXMULTIPLYCORE_H
+#define ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEGEMMLOWPMATRIXMULTIPLYCORE_H
#include "arm_compute/core/Types.h"
#include "arm_compute/function_info/GEMMInfo.h"
@@ -80,6 +80,7 @@ public:
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |S32 |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |S32 |
* |QASYMM8_SIGNED |QSYMM8 |S32 |S32 |
+ * |QASYMM8_SIGNED |QASYMM8_SIGNED |F32 |F32 |
*
* @note GEMM_LOWP: low precision GEMM kernel
* This kernel performs the following computations:
@@ -88,12 +89,12 @@ public:
* -# Convert b values from QASYMM8 to int32 add b_offset to each of them.
* -# Compute the matrix product of the resulting a * b in int32.
*
- * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8/QASYMM8_SIGNED otherwise
+ * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8/QASYMM8_SIGNED/F32 otherwise
*
* @param[in] a First input tensor (Matrix A). Data type supported: QASYMM8/QASYMM8_SIGNED.
* @param[in] b Second input tensor (Matrix B). Data type supported: QASYMM8/QASYMM8_SIGNED/QSYMM8/QSYMM8_PER_CHANNEL.
- * @param[in] c Third input tensor (Matrix C). It can be a nullptr. Data type supported: S32
- * @param[out] output Output tensor. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED
+ * @param[in] c Third input tensor (Matrix C). It can be a nullptr. Data type supported: S32/F32
+ * @param[out] output Output tensor. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED/F32
* @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
* if the reshape of matrix B should be executed only for the first run
*/
@@ -120,4 +121,4 @@ private:
std::unique_ptr<Impl> _impl;
};
} // namespace arm_compute
-#endif /*ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H */
+#endif // ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEGEMMLOWPMATRIXMULTIPLYCORE_H
diff --git a/arm_compute/runtime/NEON/functions/NEMatMul.h b/arm_compute/runtime/NEON/functions/NEMatMul.h
index 414fc2f3f..58dd7a6f6 100644
--- a/arm_compute/runtime/NEON/functions/NEMatMul.h
+++ b/arm_compute/runtime/NEON/functions/NEMatMul.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEMATMUL
-#define ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEMATMUL
+#ifndef ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEMATMUL_H
+#define ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEMATMUL_H
#include "arm_compute/core/Types.h"
#include "arm_compute/function_info/ActivationLayerInfo.h"
@@ -41,15 +41,27 @@ public:
{
return _fast_math;
}
+ // get fixed format flag
+ bool fixed_format() const
+ {
+ return _fixed_format;
+ }
// Set fast math flag
CpuMatMulSettings &fast_math(bool fmath)
{
_fast_math = fmath;
return *this;
- };
+ }
+ // Set fixed format flag
+ CpuMatMulSettings &fixed_format(bool fixed_format)
+ {
+ _fixed_format = fixed_format;
+ return *this;
+ }
private:
bool _fast_math{false};
+ bool _fixed_format{false};
};
// Forward declarations
@@ -87,6 +99,7 @@ public:
* |:--------------|:------------------|:--------------|
* |F32 |F32 |F32 |
* |F16 |F16 |F16 |
+ * |BFLOAT16 |BFLOAT16 |BFLOAT16 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |QASYMM8_SIGNED |
* |QASYMM8 |QASYMM8 |QASYMM8 |
*
@@ -129,4 +142,4 @@ private:
std::unique_ptr<Impl> _impl;
};
} // namespace arm_compute
-#endif /* ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEMATMUL */
+#endif // ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEMATMUL_H
diff --git a/cmake/Options.cmake b/cmake/Options.cmake
index e5c8cb8ef..2e351fde7 100644
--- a/cmake/Options.cmake
+++ b/cmake/Options.cmake
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 Arm Limited.
+# Copyright (c) 2023-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -48,7 +48,7 @@ set(ARM_COMPUTE_ARCH armv8-a CACHE STRING "Architecture to use")
# ---------------------------------------------------------------------
# Backends
-option(ARM_COMPUTE_ENABLE_BF16_VALIDATION "" OFF)
+option(ARM_COMPUTE_ENABLE_BF16_VALIDATION "" ON)
option(ARM_COMPUTE_ENABLE_SVE_VALIDATION "" OFF)
option(ENABLE_NEON "Enable Arm® Neon™ support" ON)
diff --git a/compute_kernel_writer/prototype/CMakeLists.txt b/compute_kernel_writer/prototype/CMakeLists.txt
deleted file mode 100644
index 439dcd3b7..000000000
--- a/compute_kernel_writer/prototype/CMakeLists.txt
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) 2023 Arm Limited.
-#
-# SPDX-License-Identifier: MIT
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to
-# deal in the Software without restriction, including without limitation the
-# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
-# sell copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
-
-#---------------------------------------------------------------------
-# Prototype
-
-add_library(ckw_prototype
- src/TileInfo.cpp
- src/TensorInfo.cpp
- src/Kernel.cpp
- src/KernelWriter.cpp
- src/OperandBase.cpp
- src/TileOperand.cpp
- src/TensorOperand.cpp
- src/TensorTileSampler.cpp
- src/KernelArgument.cpp
-)
-
-target_compile_options(ckw_prototype
- PUBLIC
- ${CKW_CXX_FLAGS}
- "$<$<CXX_COMPILER_ID:GNU>:${GNU_WARNINGS}>"
- "$<$<CONFIG:Debug>:${CKW_ASSERTS_OPTS}>"
- "$<$<BOOL:${CKW_ENABLE_ASSERTS}>:${CKW_ASSERTS_OPTS}>"
- ${CMAKE_CXX_FLAGS}
- PRIVATE
- $<$<CONFIG:Release>:-Os>
-)
-
-target_compile_definitions(ckw_prototype PUBLIC
- $<$<CONFIG:Debug>:COMPUTE_KERNEL_WRITER_DEBUG_ENABLED>
- $<$<CONFIG:Debug>:COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED>
- $<$<BOOL:${CKW_ENABLE_ASSERTS}>:COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED>
- $<$<BOOL:${CKW_ENABLE_OPENCL}>:COMPUTE_KERNEL_WRITER_OPENCL_ENABLED>
-)
-
-target_include_directories(ckw_prototype
- PUBLIC ${CMAKE_CURRENT_LIST_DIR}/include
- PRIVATE ${CMAKE_CURRENT_LIST_DIR}
-)
-
-#---------------------------------------------------------------------
-# Examples
-
-add_library(ckw_prototype_examples_common
- examples/common/ExampleKernelWriter.cpp
- examples/common/ExampleScopedKernelWriter.cpp
- examples/common/ExampleComponentArgument.cpp
-)
-
-target_link_libraries(ckw_prototype_examples_common PUBLIC ckw_prototype)
-
-add_executable(ckw_prototype_examples_add_exp_store examples/add_exp_store.cpp)
-target_link_libraries(ckw_prototype_examples_add_exp_store PUBLIC ckw_prototype_examples_common)
-
-add_executable(writer_helper examples/writer_helper.cpp)
-target_link_libraries(writer_helper PUBLIC ckw_prototype)
diff --git a/compute_kernel_writer/prototype/examples/add_exp_store.cpp b/compute_kernel_writer/prototype/examples/add_exp_store.cpp
deleted file mode 100644
index 2b640ca01..000000000
--- a/compute_kernel_writer/prototype/examples/add_exp_store.cpp
+++ /dev/null
@@ -1,206 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ckw/Error.h"
-#include "ckw/KernelArgument.h"
-#include "ckw/KernelWriter.h"
-#include "ckw/TensorOperand.h"
-#include "ckw/TensorTileSampler.h"
-#include "ckw/TileOperand.h"
-
-#include "common/ExampleComponentArgument.h"
-#include "common/ExampleKernelWriter.h"
-#include "common/ExampleScopedKernelWriter.h"
-#include <iostream>
-#include <vector>
-
-using namespace ckw;
-
-TensorTileSampler create_simple_sampler(ExampleScopedKernelWriter writer)
-{
- TensorTileSampler sampler;
-
- constexpr int32_t m0 = 4;
- constexpr int32_t n0 = 4;
-
- auto &gid_0 = writer->declare_tile("gid_0", DataType::Int32);
- auto &gid_1 = writer->declare_tile("gid_1", DataType::Int32);
- auto &gid_2 = writer->declare_tile("gid_2", DataType::Int32);
-
- auto &const_0 = writer->declare_tile("0", 0);
-
- writer->op_get_global_id(gid_0, 0);
- writer->op_get_global_id(gid_1, 1);
- writer->op_get_global_id(gid_2, 2);
-
- sampler.x(gid_0);
- sampler.y(gid_1);
- sampler.z(const_0);
- sampler.b(gid_2);
-
- sampler.width(n0);
- sampler.height(m0);
-
- sampler.format(TensorSamplerFormat::C_WH_1);
- sampler.address_mode_x(TensorSamplerAddressModeX::None);
- sampler.address_mode_y(TensorSamplerAddressModeY::ClampToBorder);
- sampler.address_mode_z(TensorSamplerAddressModeZ::Skip);
-
- return sampler;
-}
-
-void op_binary_elementwise(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgument *> operands)
-{
- auto lhs = operands.at(0);
- auto rhs = operands.at(1);
- auto dst = operands.at(2);
-
- // Load the LHS and RHS tile and prepare the tensor sampler.
- if (!lhs->has_tile() && !rhs->has_tile())
- {
- const auto sampler = create_simple_sampler(writer);
-
- writer->op_load_once(lhs, sampler);
- writer->op_load_once(rhs, sampler);
- }
- else if (lhs->has_tile())
- {
- const auto &sampler = lhs->tile_sampler();
- writer->op_load_once(rhs, sampler);
- }
- else
- {
- const auto &sampler = rhs->tile_sampler();
- writer->op_load_once(lhs, sampler);
- }
-
- auto &lhs_tile = lhs->tile();
- auto &rhs_tile = rhs->tile();
- const auto &sampler = lhs->tile_sampler();
-
- // Prepare the output tile.
- if (!dst->has_tile())
- {
- auto &tile = writer->declare_tile("dst_tile", lhs_tile.tile_info());
- dst->init_virtual_tensor(tile, sampler);
- }
-
- auto &dst_tile = dst->tile();
-
- // Perform the operation.
- writer->op_binary_expression(dst_tile, lhs_tile, BinaryOp::Add, rhs_tile);
-}
-
-void op_exp(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgument *> operands)
-{
- auto src = operands.at(0);
- auto dst = operands.at(1);
-
- // Load the source tile and prepare the sampler.
- if (!src->has_tile())
- {
- const auto sampler = create_simple_sampler(writer);
- writer->op_load_once(src, sampler);
- }
-
- auto &src_tile = src->tile();
- const auto &sampler = src->tile_sampler();
-
- // Prepare the output tile.
- if (!dst->has_tile())
- {
- auto &tile = writer->declare_tile("dst_tile", src_tile.tile_info());
- dst->init_virtual_tensor(tile, sampler);
- }
-
- auto &dst_tile = dst->tile();
-
- // Perform the operation.
- writer->op_unary_elementwise_function(dst_tile, UnaryFunction::Exp, src_tile);
-}
-
-void op_store(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgument *> operands)
-{
- auto src = operands.at(0);
- auto dst = operands.at(1);
-
- auto &src_tile = src->tile();
- const auto &sampler = src->tile_sampler();
- auto &dst_tensor = dst->tensor();
-
- writer->op_store(dst_tensor, src_tile, sampler);
-}
-
-int main()
-{
- Kernel kernel("example", GpuTargetLanguage::OpenCL);
- ExampleKernelWriter root_writer(kernel);
-
- ExampleScopedKernelWriter writer(&root_writer);
-
- const TensorInfo src0_info(DataType::Fp32, TensorShape({3, 10, 20, 1, 1}), TensorDataLayout::Nhwc, 0);
- const TensorInfo src1_info(DataType::Fp32, TensorShape({3, 10, 20, 1, 1}), TensorDataLayout::Nhwc, 1);
- const TensorInfo dst_info(DataType::Fp32, TensorShape({3, 10, 20, 1, 1}), TensorDataLayout::Nhwc, 2);
-
- ExampleComponentArgument src0(
- writer->declare_tensor_argument("src0", src0_info, TensorStorageType::BufferUint8Ptr));
- ExampleComponentArgument src1(
- writer->declare_tensor_argument("src1", src1_info, TensorStorageType::BufferUint8Ptr));
- ExampleComponentArgument dst(writer->declare_tensor_argument("dst", dst_info, TensorStorageType::BufferUint8Ptr));
-
- ExampleComponentArgument ans;
-
- op_binary_elementwise(writer, {&src0, &src1, &ans});
- op_exp(writer, {&ans, &ans});
- op_store(writer, {&ans, &dst});
-
- const auto arguments = kernel.arguments();
-
- std::cout << "\n====================\nArguments:\n====================\n";
-
- for (auto &arg : arguments)
- {
- switch (arg.type())
- {
- case ckw::KernelArgument::Type::TensorStorage:
- std::cout << "* Tensor storage: ID = " << arg.id() << ", type = " << std::hex << "0x"
- << static_cast<uint32_t>(arg.tensor_storage_type()) << std::dec << "\n";
- break;
-
- case ckw::KernelArgument::Type::TensorComponent:
- std::cout << "* Tensor component: ID = " << arg.id() << ", type = " << std::hex << "0x"
- << static_cast<uint32_t>(arg.tensor_component_type()) << std::dec << "\n";
- break;
-
- default:
- CKW_ASSERT(false);
- }
- }
-
- std::cout << "\n====================\nCode:\n====================\n";
- const auto code = root_writer.generate_code();
- std::cout << code;
-
- return 0;
-}
diff --git a/compute_kernel_writer/prototype/examples/common/ExampleComponentArgument.cpp b/compute_kernel_writer/prototype/examples/common/ExampleComponentArgument.cpp
deleted file mode 100644
index 55223dae0..000000000
--- a/compute_kernel_writer/prototype/examples/common/ExampleComponentArgument.cpp
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ExampleComponentArgument.h"
-
-#include "ckw/Error.h"
-
-ExampleComponentArgument::ExampleComponentArgument()
-{
-}
-
-ExampleComponentArgument::ExampleComponentArgument(ckw::TensorOperand &tensor) : _tensor(&tensor)
-{
-}
-
-ExampleComponentArgument &ExampleComponentArgument::init_virtual_tensor(ckw::TileOperand &tile,
- const ckw::TensorTileSampler &tile_sampler)
-{
- CKW_ASSERT(_tile == nullptr);
-
- _tile = &tile;
- _tile_sampler = tile_sampler;
-
- return *this;
-}
-
-bool ExampleComponentArgument::has_tensor() const
-{
- return _tensor != nullptr;
-}
-
-ckw::TensorOperand &ExampleComponentArgument::tensor()
-{
- CKW_ASSERT(_tensor != nullptr);
-
- return *_tensor;
-}
-
-const ckw::TensorOperand &ExampleComponentArgument::tensor() const
-{
- CKW_ASSERT(_tensor != nullptr);
-
- return *_tensor;
-}
-
-bool ExampleComponentArgument::has_tile() const
-{
- return _tile != nullptr;
-}
-
-ckw::TileOperand &ExampleComponentArgument::tile()
-{
- CKW_ASSERT(_tile != nullptr);
-
- return *_tile;
-}
-
-const ckw::TileOperand &ExampleComponentArgument::tile() const
-{
- CKW_ASSERT(_tile != nullptr);
-
- return *_tile;
-}
-
-ckw::TensorTileSampler &ExampleComponentArgument::tile_sampler()
-{
- CKW_ASSERT(_tile != nullptr);
-
- return _tile_sampler;
-}
-
-const ckw::TensorTileSampler &ExampleComponentArgument::tile_sampler() const
-{
- CKW_ASSERT(_tile != nullptr);
-
- return _tile_sampler;
-}
diff --git a/compute_kernel_writer/prototype/examples/common/ExampleComponentArgument.h b/compute_kernel_writer/prototype/examples/common/ExampleComponentArgument.h
deleted file mode 100644
index 0e029b115..000000000
--- a/compute_kernel_writer/prototype/examples/common/ExampleComponentArgument.h
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_EXAMPLES_COMMON_EXAMPLECOMPONENTARGUMENT_H
-#define CKW_PROTOTYPE_EXAMPLES_COMMON_EXAMPLECOMPONENTARGUMENT_H
-
-#include "ckw/TensorTileSampler.h"
-
-namespace ckw
-{
-class TensorOperand;
-
-class TileOperand;
-} // namespace ckw
-
-/** The argument of a dynamic fusion component which can be either user tensor or virtual tensor. */
-class ExampleComponentArgument
-{
-public:
- /** Initialize a new instance of @ref ExampleComponentArgument class for empty virtual tensor. */
- ExampleComponentArgument();
-
- /** Initialize a new instance of @ref ExampleComponentArgument class for user tensor.
- *
- * @param[in] tensor The user tensor.
- */
- explicit ExampleComponentArgument(ckw::TensorOperand &tensor);
-
- /** Set virtual tensor information (tile, sampler) for the argument.
- *
- * If the component is a user tensor, it can be treated as virtual tensor as well
- * and won't be loaded again using @ref ExampleKernelWriter::op_load_once method.
- *
- * @param[in] tile The tile that has been loaded.
- * @param[in] sampler The tensor sampling information that has been used to load the tile.
- */
- ExampleComponentArgument &init_virtual_tensor(ckw::TileOperand &tile, const ckw::TensorTileSampler &sampler);
-
- /** Get whether the argument is a user tensor. */
- bool has_tensor() const;
-
- /** Get the tensor operand.
- *
- * If the tensor is not available, throw an error.
- */
- ckw::TensorOperand &tensor();
-
- /** Get the tensor operand.
- *
- * If the tensor is not available, throw an error.
- */
- const ckw::TensorOperand &tensor() const;
-
- /** Get whether the argument contains a tile.
- *
- * The argument can be either a user tensor that has been loaded,
- * or a virtual tensor (i.e. a tile with tensor sampling information).
- */
- bool has_tile() const;
-
- /** Get the tile operand.
- *
- * If the tile is not available, throw an error.
- */
- ckw::TileOperand &tile();
-
- /** Get the tile operand.
- *
- * If the tile is not available, throw an error.
- */
- const ckw::TileOperand &tile() const;
-
- /** Get the tensor sampling information for the tile.
- *
- * If the tile is not available, throw an error.
- */
- ckw::TensorTileSampler &tile_sampler();
-
- /** Get the tensor sampling information for the tile.
- *
- * If the tile is not available, throw an error.
- */
- const ckw::TensorTileSampler &tile_sampler() const;
-
-private:
- ckw::TensorOperand *_tensor{nullptr};
- ckw::TileOperand *_tile{nullptr};
- ckw::TensorTileSampler _tile_sampler{};
-};
-
-#endif // CKW_PROTOTYPE_EXAMPLES_COMMON_EXAMPLECOMPONENTARGUMENT_H
diff --git a/compute_kernel_writer/prototype/examples/common/ExampleKernelWriter.h b/compute_kernel_writer/prototype/examples/common/ExampleKernelWriter.h
deleted file mode 100644
index 1528c3d93..000000000
--- a/compute_kernel_writer/prototype/examples/common/ExampleKernelWriter.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_EXAMPLES_COMMON_EXAMPLEKERNELWRITER_H
-#define CKW_PROTOTYPE_EXAMPLES_COMMON_EXAMPLEKERNELWRITER_H
-
-#include "ckw/KernelWriter.h"
-#include "ckw/TensorTileSampler.h"
-
-class ExampleComponentArgument;
-
-namespace ckw
-{
-class Kernel;
-} // namespace ckw
-
-/** Extended implementation of kernel writer for dynamic fusion. */
-class ExampleKernelWriter : public ckw::KernelWriter
-{
-public:
- /** Initialize a new instance of @ref ExampleKernelWriter class.
- *
- * @param[in] kernel The kernel to be generated.
- */
- explicit ExampleKernelWriter(ckw::Kernel &kernel);
-
- /** Load the user tensor to the tile in the same component argument if it hasn't been loaded.
- *
- * @param[in] tensor_or_tile The component argument that is either a user tensor or a virtual tensor.
- * @param[in] sampler The tensor sampling information to load the tile.
- */
- void op_load_once(ExampleComponentArgument *tensor_or_tile, const ckw::TensorTileSampler &sampler);
-};
-
-#endif // CKW_PROTOTYPE_EXAMPLES_COMMON_EXAMPLEKERNELWRITER_H
diff --git a/compute_kernel_writer/prototype/examples/common/ExampleScopedKernelWriter.cpp b/compute_kernel_writer/prototype/examples/common/ExampleScopedKernelWriter.cpp
deleted file mode 100644
index 784d5ffb9..000000000
--- a/compute_kernel_writer/prototype/examples/common/ExampleScopedKernelWriter.cpp
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ExampleScopedKernelWriter.h"
-
-#include "ExampleKernelWriter.h"
-
-ExampleScopedKernelWriter::ExampleScopedKernelWriter(ExampleKernelWriter *writer)
- : _writer(writer), _parent_id_space(writer->id_space())
-{
- _writer->next_id_space();
-}
-
-ExampleScopedKernelWriter::ExampleScopedKernelWriter(const ExampleScopedKernelWriter &other)
- : _writer(other._writer), _parent_id_space(other._writer->id_space())
-{
- _writer->next_id_space();
-}
-
-ExampleKernelWriter *ExampleScopedKernelWriter::operator->()
-{
- return _writer;
-}
-
-const ExampleKernelWriter *ExampleScopedKernelWriter::operator->() const
-{
- return _writer;
-}
-
-ExampleKernelWriter *ExampleScopedKernelWriter::writer()
-{
- return _writer;
-}
-
-const ExampleKernelWriter *ExampleScopedKernelWriter::writer() const
-{
- return _writer;
-}
diff --git a/compute_kernel_writer/prototype/examples/common/ExampleScopedKernelWriter.h b/compute_kernel_writer/prototype/examples/common/ExampleScopedKernelWriter.h
deleted file mode 100644
index 4655b1897..000000000
--- a/compute_kernel_writer/prototype/examples/common/ExampleScopedKernelWriter.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_EXAMPLES_COMMON_EXAMPLESCOPEDKERNELWRITER_H
-#define CKW_PROTOTYPE_EXAMPLES_COMMON_EXAMPLESCOPEDKERNELWRITER_H
-
-#include <cstdint>
-
-class ExampleKernelWriter;
-
-/** Helper to automatically manage kernel writer ID space. */
-class ExampleScopedKernelWriter
-{
-public:
- /** Initialize a new instance of @ref ExampleScopedKernelWriter class. */
- explicit ExampleScopedKernelWriter(ExampleKernelWriter *writer);
-
- /** Create a new scope from the specified scoped kernel writer. */
- ExampleScopedKernelWriter(const ExampleScopedKernelWriter &other);
-
- /** Assignment is disallowed. */
- ExampleScopedKernelWriter &operator=(const ExampleScopedKernelWriter &) = delete;
-
- /** Access the underlying kernel writer. */
- ExampleKernelWriter *operator->();
-
- /** Access the underlying kernel writer. */
- const ExampleKernelWriter *operator->() const;
-
- /** Get the kernel writer. */
- ExampleKernelWriter *writer();
-
- /** Get the kernel writer. */
- const ExampleKernelWriter *writer() const;
-
-private:
- ExampleKernelWriter *_writer;
- int32_t _parent_id_space;
-};
-
-#endif // CKW_PROTOTYPE_EXAMPLES_COMMON_EXAMPLESCOPEDKERNELWRITER_H
diff --git a/compute_kernel_writer/prototype/examples/writer_helper.cpp b/compute_kernel_writer/prototype/examples/writer_helper.cpp
deleted file mode 100644
index 8623afbf5..000000000
--- a/compute_kernel_writer/prototype/examples/writer_helper.cpp
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
-* Copyright (c) 2023 Arm Limited.
-*
-* SPDX-License-Identifier: MIT
-*
-* Permission is hereby granted, free of charge, to any person obtaining a copy
-* of this software and associated documentation files (the "Software"), to
-* deal in the Software without restriction, including without limitation the
-* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
-* sell copies of the Software, and to permit persons to whom the Software is
-* furnished to do so, subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be included in all
-* copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-* SOFTWARE.
-*/
-
-#include "ckw/KernelWriter.h"
-#include "ckw/TensorTileSampler.h"
-
-#include "../include/ckw/KernelWriterHelper.h"
-#include <iostream>
-
-using namespace ckw;
-
-TensorTileSampler create_simple_sampler(KernelWriter &writer)
-{
- TensorTileSampler sampler;
-
- constexpr int32_t m0 = 1;
- constexpr int32_t n0 = 1;
-
- auto &gid_0 = writer.declare_tile("gid_0", DataType::Int32);
- auto &gid_1 = writer.declare_tile("gid_1", DataType::Int32);
- auto &gid_2 = writer.declare_tile("gid_2", DataType::Int32);
-
- auto &const_0 = writer.declare_tile("0", 0);
-
- writer.op_get_global_id(gid_0, 0);
- writer.op_get_global_id(gid_1, 1);
- writer.op_get_global_id(gid_2, 2);
-
- sampler.x(gid_0);
- sampler.y(gid_1);
- sampler.z(gid_2);
- sampler.b(const_0);
-
- sampler.width(n0);
- sampler.height(m0);
-
- sampler.format(TensorSamplerFormat::C_WH_1);
- sampler.address_mode_x(TensorSamplerAddressModeX::None);
- sampler.address_mode_y(TensorSamplerAddressModeY::ClampToBorder);
- sampler.address_mode_z(TensorSamplerAddressModeZ::Skip);
-
- return sampler;
-}
-
-int main()
-{
- Kernel kernel("test", GpuTargetLanguage::OpenCL);
- KernelWriterHelper<KernelWriter> writer(kernel);
-
- const TensorInfo src_info(DataType::Fp32, TensorShape({1, 1, 1, 1, 1}), TensorDataLayout::Nhwc, 0);
- const TensorInfo dst_info(DataType::Fp32, TensorShape({1, 1, 1, 1, 1}), TensorDataLayout::Nhwc, 1);
-
- auto &src_tensor = writer.declare_tensor_argument("src", src_info);
- auto &dst_tensor = writer.declare_tensor_argument("dst", dst_info);
-
- const auto sampler = create_simple_sampler(writer);
-
- auto &src = writer.declare_tile("src_tile", TileInfo(src_tensor.data_type(), sampler.height(), sampler.width()));
- auto &other =
- writer.declare_tile("other_tile", TileInfo(src_tensor.data_type(), sampler.height(), sampler.width()));
- auto &dst = writer.declare_tile("dst_tile", TileInfo(src_tensor.data_type(), sampler.height(), sampler.width()));
-
- writer.op_load(src, src_tensor, sampler);
- writer.op_load(other, src_tensor, sampler);
- writer.op_load(dst, dst_tensor, sampler);
-
- auto test = dst ^ src ^ other;
- auto other_test = logical_and(dst, src, other);
- writer.op_assign(dst, logical_and(dst, src, other));
- writer.op_assign(dst, test);
- writer.op_assign(dst, other_test);
- writer.op_assign(dst, operator^(operator^(dst, src), other));
-
- writer.op_if(exp(src) == dst, [&] { writer.op_binary_expression(dst, src, BinaryOp::Add, src); })
- .op_else_if(exp(src) > dst, [&] { writer.op_binary_expression(dst, src, BinaryOp::Add, src); })
- .op_else([&] { writer.op_assign(dst, src); });
-
- writer.op_assign(dst, src + src * src);
- writer.op_assign(dst, src * max(src, dst) + src);
- writer.op_assign(dst, src * select(src, dst, src) + src);
-
- writer.op_assign(dst, src ^ dst);
- writer.op_assign(dst, ~src);
-
- writer.op_for_loop(dst < src, dst += src, [&] { writer.op_assign(dst, src + dst); });
-
- writer.op_assign(dst += src);
- writer.op_assign(dst += exp(src));
-
- std::cout << "======== KERNEL ========" << std::endl;
- std::cout << writer.generate_code() << std::endl;
-}
diff --git a/compute_kernel_writer/prototype/include/ckw/Error.h b/compute_kernel_writer/prototype/include/ckw/Error.h
deleted file mode 100644
index aab713c81..000000000
--- a/compute_kernel_writer/prototype/include/ckw/Error.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_ERROR_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_ERROR_H
-
-#include <stdexcept>
-#include <string>
-
-namespace ckw
-{
-
-/** If the condition is not met, throw an std::runtime_error with the specified message.
- *
- * @param[in] cond The condition that is expected to be true.
- * @param[in] msg The error message when the condition is not met.
- */
-#define CKW_ASSERT_MSG(cond, msg) \
- do \
- { \
- if (!(cond)) \
- { \
- throw ::std::runtime_error(msg); \
- } \
- } while (false)
-
-/** If the condition is not met, throw an std::runtime_error.
- *
- * @param[in] cond The condition that is expected to be true.
- */
-#define CKW_ASSERT(cond) CKW_ASSERT_MSG(cond, #cond)
-
-/** If the precondition is met but the consequence is not met, throw an std::runtime_error.
- *
- * @param[in] precond The condition if is met requires the consequence must also be met.
- * @param[in] cond The condition that is expected to be true if the precondition is true.
- */
-#define CKW_ASSERT_IF(precond, cond) CKW_ASSERT_MSG(!(precond) || ((precond) && (cond)), #precond " |-> " #cond)
-
-/** Mark the variables as unused.
- *
- * @param[in] ... Variables which are unused.
- */
-#define CKW_UNUSED(...) ::ckw::ignore_unused(__VA_ARGS__) // NOLINT
-
-/** Mark the variables as unused.
- *
- * @param[in] ... Variables which are unused.
- */
-template <typename... T>
-inline void ignore_unused(T &&...)
-{
-}
-
-} // namespace ckw
-
-#endif // CKW_INCLUDE_CKW_ERROR_H
diff --git a/compute_kernel_writer/prototype/include/ckw/Kernel.h b/compute_kernel_writer/prototype/include/ckw/Kernel.h
deleted file mode 100644
index ba31a29ba..000000000
--- a/compute_kernel_writer/prototype/include/ckw/Kernel.h
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
-
-#include "ckw/KernelArgument.h"
-#include "ckw/OperandBase.h"
-#include "ckw/types/GpuTargetLanguage.h"
-
-#include <map>
-#include <memory>
-#include <string>
-#include <vector>
-
-namespace ckw
-{
-
-class TileOperand;
-
-namespace prototype
-{
-class GpuKernelWriterDataHolder;
-} // namespace prototype
-
-/** The target for kernel writer to write into. */
-class Kernel
-{
-public:
- /** Constructor
- *
- * @param[in] language The programming language to write the kernel.
- */
- Kernel(GpuTargetLanguage language);
- /** Constructor
- *
- * @param[in] name The name of the kernel function.
- * @param[in] language The programming language to write the kernel.
- */
- Kernel(const char *name, GpuTargetLanguage language);
-
- /** Destructor */
- ~Kernel();
-
- /** Get the name of the kernel function. */
- const std::string &name() const;
-
- /** Set the name of the kernel function.
- *
- * @param[in] name The name of the kernel function.
- */
- void name(const std::string &name);
-
- /** Get the list of kernel arguments. */
- ::std::vector<KernelArgument> arguments() const;
-
- /** (Internal use only) Register the tile operand.
- *
- * @param operand The tile operand to be registered.
- */
- TileOperand &register_operand(::std::unique_ptr<TileOperand> operand);
-
- /** (Internal use only) Register the tensor operand.
- *
- * @param operand The tensor operand to be registered.
- */
- TensorOperand &register_operand(::std::unique_ptr<TensorOperand> operand);
-
- /** (Internal use only) Get the implementation data. */
- prototype::GpuKernelWriterDataHolder *impl();
-
-private:
- ::std::string _name;
- ::std::unique_ptr<prototype::GpuKernelWriterDataHolder> _kernel;
- ::std::map<::std::string, ::std::unique_ptr<OperandBase>> _operands;
- ::std::map<int32_t, TensorOperand *> _tensor_id_operands;
-};
-
-} // namespace ckw
-
-#endif // CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelArgument.h b/compute_kernel_writer/prototype/include/ckw/KernelArgument.h
deleted file mode 100644
index 3384a20ae..000000000
--- a/compute_kernel_writer/prototype/include/ckw/KernelArgument.h
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNELARGUMENT_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_KERNELARGUMENT_H
-
-#include "ckw/TensorInfo.h"
-
-#include <cstdint>
-
-namespace ckw
-{
-
-class TensorOperand;
-class TensorComponentOperand;
-
-/** A kernel argument which can be either a tensor storage or a tensor component. */
-class KernelArgument
-{
-public:
- /** The type of kernel argument. */
- enum class Type : int32_t
- {
- /** The argument that provides the read and/or write access to the tensor data.
- *
- * See @ref ckw::TensorStorage to see the list of supported storage type.
- */
- TensorStorage,
-
- /** The argument that provides extra information about the tensor.
- *
- * See @ref ckw::TensorComponent to see the list of supported component.
- */
- TensorComponent,
- };
-
- /** Initialize a new instance of kernel argument class for a tensor storage argument.
- *
- * @param[in] tensor The tensor whose storage is exposed to kernel arguments.
- */
- KernelArgument(TensorOperand &tensor);
-
- /** Initialize a new instance of kernel argument class for a tensor component argument.
- *
- * @param[in] tensor_component The tensor component to be exposed to kernel arguments.
- */
- KernelArgument(TensorComponentOperand &tensor_component);
-
- /** Get the type of kernel argument. */
- Type type() const;
-
- /** Get the argument ID.
- *
- * This method can be used to get the tensor info ID of both tensor storage and tensor component arguments.
- */
- int32_t id() const;
-
- /** Get the type of tensor storage.
- *
- * This method can only be used for tensor storage argument.
- */
- TensorStorageType tensor_storage_type() const;
-
- /** Get the tensor component type.
- *
- * This method can only be used for tensor component argument.
- */
- TensorComponentType tensor_component_type() const;
-
-private:
- Type _type;
- int32_t _id;
-
- union SubId
- {
- int32_t unknown;
- TensorStorageType tensor_storage_type;
- TensorComponentType tensor_component_type;
- };
-
- SubId _sub_id{0};
-};
-
-} // namespace ckw
-
-#endif // CKW_PROTOTYPE_INCLUDE_CKW_KERNELARGUMENT_H
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
deleted file mode 100644
index f9e0066f9..000000000
--- a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
+++ /dev/null
@@ -1,338 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H
-
-#include "ckw/Kernel.h"
-#include "ckw/TensorInfo.h"
-#include "ckw/TensorOperand.h"
-#include "ckw/TileInfo.h"
-#include "ckw/TileOperand.h"
-#include "ckw/types/ConvertPolicy.h"
-#include "ckw/types/Functions.h"
-#include "ckw/types/Operators.h"
-
-#include <memory>
-
-namespace ckw
-{
-
-namespace prototype
-{
-struct GpuKernelWriterAttribute;
-
-class IGpuKernelWriter;
-} // namespace prototype
-
-/** Kernel writer. */
-class KernelWriter
-{
-public:
- // =============================================================================================
- // Constructors and destructor
- // =============================================================================================
-
- /** Initialize a new instance of kernel writer.
- *
- * @param[in] kernel The kernel to be written to.
- */
- explicit KernelWriter(Kernel &kernel);
-
- /** Destructor */
- ~KernelWriter();
-
- /** No copy constructor. */
- KernelWriter(const KernelWriter &) = delete;
-
- /** No copy assignment. */
- KernelWriter &operator=(const KernelWriter &) = delete;
-
- // =============================================================================================
- // Scope management
- // =============================================================================================
-
- /** Get the current ID space. */
- int32_t id_space() const;
-
- /** Set the current ID space. */
- KernelWriter &id_space(int32_t id_space);
-
- /** Switch to and return a new ID space. */
- int32_t next_id_space();
-
- // =============================================================================================
- // Tensor and tile declaration
- // =============================================================================================
-
- /** Declare a tensor argument.
- *
- * @param[in] name The name of the tensor.
- * @param[in] info The tensor info.
- * @param[in] storage_type The tensor storage type.
- *
- * @return The @ref TensorOperand object.
- */
- TensorOperand &declare_tensor_argument(const std::string &name,
- const TensorInfo &info,
- TensorStorageType storage_type = TensorStorageType::BufferUint8Ptr);
-
- /** Declare a compile-time constant scalar argument.
- *
- * @param[in] name The name of the tile.
- * @param[in] value The value of the tile.
- *
- * @return The @ref TileOperand object.
- */
- TileOperand &declare_tile_argument(const std::string &name, int32_t value);
-
- /** Declare a new tile.
- *
- * The name of the tile must be unique in the current ID space.
- *
- * @param[in] name The name of the tile.
- * @param[in] ... The necessary arguments to create a new @ref TileOperand.
- *
- * @return The @ref TileOperand object.
- */
- template <typename... TArgs>
- TileOperand &declare_tile(const std::string &name, TArgs &&...args)
- {
- const auto var_name = generate_variable_name(name);
- auto operand = std::make_unique<TileOperand>(var_name, ::std::forward<TArgs>(args)...);
-
- return declare_tile_operand(std::move(operand));
- }
-
- // =============================================================================================
- // Load and store
- // =============================================================================================
-
- /** Load the data from the tensor memory to the tile using the sampling information.
- *
- * @param[out] tile The tile to be loaded.
- * @param[in] tensor The tensor to be read.
- * @param[in] sampler The tensor sampling information.
- * @param[in] dilation_y Dilation in the Y dimension.
- */
- void op_load(TileOperand &tile,
- const TensorOperand &tensor,
- const TensorTileSampler &sampler,
- const TileOperand &dilation_y = TileOperand("dil_y", 1));
-
- /** Load the data from the tensor memory to the tile using the indirect buffer approach and respective of the sampling information.
- *
- * @param[out] tile The tile to be loaded.
- * @param[in] tensor The tensor to be read.
- * @param[in] sampler The tensor sampling information.
- */
- void op_load_indirect(TileOperand &tile, const TensorOperand &tensor, const TensorTileSampler &sampler);
-
- /** Construct an indirection buffer in @p tile containing the precalculated addresses of elements in the source tensor.
- *
- * @param[out] tile The tile to be loaded.
- * @param[in] tensor The tensor the be read.
- * @param[in] sampler The tensor sampling information.
- * @param[in] x The X coordinate.
- * @param[in] y The Y coordinate.
- * @param[in] x_off Offset in the X dimension.
- * @param[in] y_off Offset in the Y dimension.
- */
- void util_get_indirect_buffer(TileOperand &tile,
- const TensorOperand &tensor,
- const TensorTileSampler &sampler,
- const TileOperand &x,
- const TileOperand &y,
- const TileOperand &x_off,
- const TileOperand &y_off);
-
- /** Store the tile to the tensor using the specified sampling information.
- *
- * @param[out] dst The tensor that the tile is written to.
- * @param[in] src The tile to be stored.
- * @param[in] sampler The tensor sampling information.
- */
- void op_store(TensorOperand &tensor, const TileOperand &tile, const TensorTileSampler &sampler);
-
- // =============================================================================================
- // Data processing
- // =============================================================================================
-
- /** Write assignment: `<dst> = <src>;`.
- *
- * @param[out] dst The destination tile.
- * @param[in] src The source tile.
- */
- void op_assign(const TileOperand &dst, const TileOperand &src);
-
- /** Write the cast: `<dst> = convert_<dst.type><_sat>(<src>);`.
- *
- * @param[out] dst The destination tile.
- * @param[in] src The source tile.
- * @param[in] policy The policy governing the behavior of the cast.
- */
- void op_cast_expression(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy);
-
- /** Write the unary expression: `<dst> = <op> <src>`.
- *
- * @param[out] dst The destination tile.
- * @param[in] op The unary operator.
- * @param[in] src The source tile.
- */
- void op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src);
-
- /** Write binary expression: `<dst> = <lhs> <op> <rhs>;`.
- *
- * @param[out] dst The destination tile.
- * @param[in] lhs The LHS tile.
- * @param[in] op The binary operator.
- * @param[in] rhs The RHS tile.
- */
- void op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs);
-
- /** Write function applied to scalar value: `<dst> = <func>(<src>);`.
- *
- * @param[out] dst The destination tile.
- * @param[in] func The function to be applied to the source tile.
- * @param[in] src The source tile.
- */
- void op_unary_elementwise_function(const TileOperand &dst, UnaryFunction func, const TileOperand &src);
-
- /** Write function applied to scalar value: `<dst> = <func>(<first>, <second>);`.
- *
- * @param[out] dst The destination tile.
- * @param[in] func The function to be applied to the source tiles.
- * @param[in] first The first argument tile.
- * @param[in] second The second argument tile.
- */
- void op_binary_elementwise_function(const TileOperand &dst,
- BinaryFunction func,
- const TileOperand &first,
- const TileOperand &second);
-
- /** Write function applied to scalar value: `<dst> = <func>(<first>, <second>, <third>);`.
- *
- * @param[out] dst The destination tile.
- * @param[in] func The function to be applied to the source tiles.
- * @param[in] first The first argument tile.
- * @param[in] second The second argument tile.
- * @param[in] third The third argument tile.
- */
- void op_ternary_elementwise_function(const TileOperand &dst,
- TernaryFunction func,
- const TileOperand &first,
- const TileOperand &second,
- const TileOperand &third);
-
- /** Write if-statement: `if(<lhs> <op> <rhs>) { <body> }`.
- *
- * @param[in] lhs The LHS tile of the condition.
- * @param[in] op The relational binary operator.
- * @param[in] rhs The RHS tile of the condition.
- * @param[in] body The body of the if-statement.
- */
- void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
-
- /** Write else-if-statement: `else if(<lhs> <op> <rhs>) { <body> }`.
- *
- * @param[in] lhs The LHS tile of the condition.
- * @param[in] op The relational binary operator.
- * @param[in] rhs The RHS tile of the condition.
- * @param[in] body The body of the else-if-statement.
- */
- void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
-
- /** Write an else-statement: `else { <body> }`.
- *
- * @param[in] body The body of the else-statement.
- */
- void op_else(const std::function<void()> &body);
-
- /** Write for-loops: `for(; <var> <cond_op> <cond_value>; <var> <update_op> <update_value>) { body }`.
- *
- * @param[in] var_name The name of the variable used in condition.
- * @param[in] cond_op The relational binary operator used in condition.
- * @param[in] cond_value_name The value which the variable is compared against.
- * @param[in] update_var_name The name of the variable which is updated.
- * @param[in] update_op The assignment operator used for updating the update value.
- * @param[in, out] update_value The value which is updated at every iteration.
- * @param[in] body The body of the for-loop.
- */
- void op_for_loop(const TileOperand &var_name,
- BinaryOp cond_op,
- const TileOperand &cond_value_name,
- const TileOperand &update_var_name,
- AssignmentOp update_op,
- const TileOperand &update_value_name,
- const std::function<void()> &body);
-
- /** Write the return statement: `return;`
- */
- void op_return();
-
- // =============================================================================================
- // Misc
- // =============================================================================================
-
- /** Set `dst` the global ID of dimension `dim`.
- *
- * @param[out] dst The tile to be written to.
- * @param[in] dim The global ID dimension.
- */
- void op_get_global_id(const TileOperand &dst, int32_t dim);
-
- // =============================================================================================
- // Code generation
- // =============================================================================================
-
- /** Generate the source code of the kernel. */
- ::std::string generate_code();
-
-private:
- /** Generate the full variable name based on the original name and the ID space.
- *
- * @param[in] name The name of the variable.
- *
- * @return The full variable name.
- */
- ::std::string generate_variable_name(const std::string &name) const;
-
- /** Declare the tile operand.
- *
- * @param[in] operand The tile operand to be declared.
- */
- TileOperand &declare_tile_operand(std::unique_ptr<TileOperand> operand);
-
-private:
- Kernel *_kernel;
- ::std::unique_ptr<prototype::GpuKernelWriterAttribute> _impl_attr;
- ::std::unique_ptr<prototype::IGpuKernelWriter> _impl;
-
- int32_t _id_space{0};
- int32_t _max_id_space{0};
-};
-
-} // namespace ckw
-
-#endif // CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h b/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h
deleted file mode 100644
index 3ba079bbc..000000000
--- a/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h
+++ /dev/null
@@ -1,1286 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
-#define CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
-
-#include "ckw/KernelWriter.h"
-#include "ckw/TensorOperand.h"
-#include "ckw/TileOperand.h"
-
-#include <iostream>
-#include <type_traits>
-
-/*
- * By including this header file you will be able to supplement the default
- * Compute Kernel Writer API with additional syntax to help ease the use of CKW.
- *
- * To use the KernelWriterHelper you need to wrap your instance of KernelWriter
- * (or any class deriving from KernelWriter):
- * KernelWriterHelper<KernelWriter> writer;
- * The resulting writer object comprises the original KernelWriter
- * functionality (drop-in replacement), but extends the syntax as follows.
- *
- * Common functions/operators have natural syntax:
- * 1. Unary expressions:
- * writer.op_assign(dst, !src); // Logical NOT
- * writer.op_assign(dst, ~src); // Bitwise NOT
- *
- * 2. Binary expressions:
- * writer.op_assign(dst, lhs + rhs); // Addition
- * writer.op_assign(dst, lhs - rhs); // Subtraction
- * writer.op_assign(dst, lhs * rhs); // Multiplication
- * writer.op_assign(dst, lhs / rhs); // Division
- * writer.op_assign(dst, lhs % rhs); // Modulo
- * writer.op_assign(dst, lhs == rhs); // Equality
- * writer.op_assign(dst, lhs < rhs); // Less-than
- * writer.op_assign(dst, lhs <= rhs); // Less-than-or-equal
- * writer.op_assign(dst, lhs > rhs); // Greater-than
- * writer.op_assign(dst, lhs >= rhs); // Greater-than-or-equal
- * writer.op_assign(dst, lhs ^ rhs); // Bitwise XOR
- * writer.op_assign(dst, logical_and(lhs, rhs)); // Logical AND
- * writer.op_assign(dst, logical_or(lhs, rhs)); // Logical OR
- *
- * 3. Unary elementwise functions:
- * writer.op_assign(dst, exp(src)); // Exponent
- * writer.op_assign(dst, tanh(src)); // Hyperbolic tangent
- * writer.op_assign(dst, sqrt(src)); // Square root
- * writer.op_assign(dst, erf(src)); // Error function
- * writer.op_assign(dst, fabs(src)); // Absolute of floating-point number
- * writer.op_assign(dst, log(src)); // Natural logarithm
- * writer.op_assign(dst, round(src)); // Round
- * writer.op_assign(dst, sizeOf(src)); // sizeof
- *
- * 4. Binary elementwise functions:
- * writer.op_assign(dst, max(first, second)); // Max
- * writer.op_assign(dst, min(first, second)); // Min
- *
- * 5. Ternary elementwise functions:
- * writer.op_assign(dst, select(first, second, third)); // Select
- *
- * NOTE: All the above examples support nesting, so you could write
- * something like: writer.op_assign(dst, src * (log(arg) + sqrt(abs(arg)));
- *
- *
- * 6. If-statements. The preceding syntax also allows easier writing of if-statements:
- * writer.op_if(<cond>, <body>);
- *
- * For example:
- * writer.op_if(exp(first_arg) == dst, [&]{
- * //...
- * }).op_else_if(exp(first_arg) > dst, [&]{
- * //...
- * }).op_else([&] {
- * //...
- * });
- *
- * 7. For-loops. A similar syntax exists for for-loops:
- * writer.op_for_loop(<cond>, <updater>, <body>);
- *
- * For example:
- * writer.op_for_loop(index < limit, index += step, [&]{
- * //...
- * });
- *
- * NOTE: There are limitations on the for-loop <cond> and <updater> parameters.
- * In neither the <cond> (Binary expression) or <updater> (Increment/Decrement)
- * is it allowed to use nesting. For example, `(index + other) < limit` and
- * `index < round(limit)` are invalid <cond> parameters. This is because the
- * semantics of for-loops rely on the condition being evaluated at every iteration,
- * but as temporary variables might be defined for nested expressions the semantics
- * cannot be guaranteed.
- */
-
-namespace ckw
-{
-
-// ==================================================
-// Type traits
-// ==================================================
-
-/** Specifies if the type can be used as an operand for functions (e.g. max), operations (e.g. *), or assignments. */
-template <typename T>
-struct can_be_operand : ::std::false_type
-{
-};
-
-/** Specifies if the type can be assigned/written to. */
-template <typename T>
-struct can_be_assigned : ::std::false_type
-{
-};
-
-template <>
-struct can_be_operand<TileOperand &> : ::std::true_type
-{
-};
-
-template <>
-struct can_be_assigned<TileOperand &> : ::std::true_type
-{
-};
-
-// ==================================================
-// Assignment
-// ==================================================
-
-/** AST node for assignments.
- *
- * Note that \p TRight must be an operand, and \p TLeft must be assignable.
- *
- * @tparam TLeft The type of the destination of the assignment.
- * @tparam TRight The type of the source assigned to the destination.
- */
-template <typename TLeft,
- typename TRight,
- typename = ::std::enable_if<can_be_operand<TRight>::value && can_be_assigned<TLeft>::value>>
-struct Assignment
-{
- TLeft lhs;
- TRight rhs;
- AssignmentOp opcode;
-};
-
-/** Represents the expression: `\p lhs += \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the assignment.
- * @tparam TRight The type of the RHS of the assignment.
- * @param[in] lhs The LHS of the assignment.
- * @param[in] rhs The RHS of the assignment.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline Assignment<TLeft, TRight> operator+=(TLeft &&lhs, TRight &&rhs)
-{
- return Assignment<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), AssignmentOp::Increment};
-}
-
-/** Represents the expression: `\p lhs -= \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the assignment.
- * @tparam TRight The type of the RHS of the assignment.
- * @param[in] lhs The LHS of the assignment.
- * @param[in] rhs The RHS of the assignment.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline Assignment<TLeft, TRight> operator-=(TLeft &&lhs, TRight &&rhs)
-{
- return Assignment<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), AssignmentOp::Decrement};
-}
-
-// ==================================================
-// Unary expression
-// ==================================================
-
-/** AST node for unary expressions.
- *
- * Note that \p TSrc must be an operand.
- *
- * @tparam TSrc The type of the argument to the expression.
- */
-template <typename TSrc, typename = ::std::enable_if<can_be_operand<TSrc>::value>>
-struct UnaryExpression
-{
- TSrc src;
- UnaryOp opcode;
-};
-
-template <typename TLeft>
-struct can_be_operand<UnaryExpression<TLeft>> : ::std::true_type
-{
-};
-
-/** Represents the expression: `!\p src`.
- *
- * @tparam TSrc The type of the argument.
- * @param[in] src The argument.
- * @return The resulting AST node.
- */
-template <typename TSrc>
-inline UnaryExpression<TSrc> operator!(TSrc &&src)
-{
- return UnaryExpression<TSrc>{std::forward<TSrc>(src), UnaryOp::LogicalNot};
-}
-
-/** Represents the expression: `~\p src`.
- *
- * @tparam TSrc The type of the argument.
- * @param[in] src The argument.
- * @return The resulting AST node.
- */
-template <typename TSrc>
-inline UnaryExpression<TSrc> operator~(TSrc &&src)
-{
- return UnaryExpression<TSrc>{std::forward<TSrc>(src), UnaryOp::BitwiseNot};
-}
-
-// ==================================================
-// Binary expressions
-// ==================================================
-
-/** AST node for binary expressions.
- *
- * Note that both \p TLeft and \p TRight must be operands.
- *
- * @tparam TLeft The type of the left argument of the expression.
- * @tparam TRight The type of the right argument of the expression.
- */
-template <typename TLeft,
- typename TRight,
- typename = ::std::enable_if_t<can_be_operand<TLeft>::value && can_be_operand<TRight>::value>>
-struct BinaryExpression
-{
- TLeft lhs;
- TRight rhs;
- BinaryOp opcode;
-};
-
-template <typename TLeft, typename TRight>
-struct can_be_operand<BinaryExpression<TLeft, TRight>> : ::std::true_type
-{
-};
-
-/** Represents the expression: `\p lhs + \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator+(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Add};
-}
-
-/** Represents the expression: `\p lhs - \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator-(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Sub};
-}
-
-/** Represents the expression: `\p lhs * \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator*(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Mul};
-}
-
-/** Represents the expression: `\p lhs / \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator/(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Div};
-}
-
-/** Represents the expression: `\p lhs % \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator%(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Mod};
-}
-
-/** Represents the expression: `\p lhs == \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator==(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Equal};
-}
-
-/** Represents the expression: `\p lhs < \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator<(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Less};
-}
-
-/** Represents the expression: `\p lhs <= \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator<=(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LessEqual};
-}
-
-/** Represents the expression: `\p lhs > \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator>(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Greater};
-}
-
-/** Represents the expression: `\p lhs >= \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator>=(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::GreaterEqual};
-}
-
-/** Represents the expression: `\p lhs ^ \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> operator^(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::BitwiseXOR};
-}
-
-/** Represents the expression: `\p lhs && \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> logical_and(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalAnd};
-}
-
-/** Represents the expression: `\p lhs && \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight, typename... TOps>
-inline BinaryExpression<BinaryExpression<TLeft, TRight>, TOps...> logical_and(TLeft &&lhs, TRight &&rhs, TOps &&...ops)
-{
- return logical_and(
- BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalAnd},
- std::forward<TOps>(ops)...);
-}
-
-/** Represents the expression: `\p lhs || \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight>
-inline BinaryExpression<TLeft, TRight> logical_or(TLeft &&lhs, TRight &&rhs)
-{
- return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalOr};
-}
-
-/** Represents the expression: `\p lhs || \p rhs`.
- *
- * @tparam TLeft The type of the LHS of the expression.
- * @tparam TRight The type of the RHS of the expression.
- * @param[in] lhs The LHS of the expression.
- * @param[in] rhs The RHS of the expression.
- * @return The resulting AST node.
- */
-template <typename TLeft, typename TRight, typename... TOps>
-inline BinaryExpression<BinaryExpression<TLeft, TRight>, TOps...> logical_or(TLeft &&lhs, TRight &&rhs, TOps &&...ops)
-{
- return logical_or(
- BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalOr},
- std::forward<TOps>(ops)...);
-}
-
-// ==================================================
-// Unary elementwise functions
-// ==================================================
-
-/** AST node for unary elementwise functions.
- *
- * Note that \p TSrc must be an operand.
- *
- * @tparam TSrc The type of the argument to the function.
- */
-template <typename TSrc, typename = ::std::enable_if<can_be_operand<TSrc>::value>>
-struct UnaryElementwiseFunction
-{
- TSrc src;
- UnaryFunction opcode;
-};
-
-template <typename TLeft>
-struct can_be_operand<UnaryElementwiseFunction<TLeft>> : ::std::true_type
-{
-};
-
-/** Represents the expression: `exp(\p src)`.
- *
- * @tparam TSrc The type of the argument.
- * @param[in] src The argument.
- * @return The resulting AST node.
- */
-template <typename TSrc>
-UnaryElementwiseFunction<TSrc> exp(TSrc &&src)
-{
- return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Exp};
-}
-
-/** Represents the expression: `tanh(\p src)`.
- *
- * @tparam TSrc The type of the argument.
- * @param[in] src The argument.
- * @return The resulting AST node.
- */
-template <typename TSrc>
-UnaryElementwiseFunction<TSrc> tanh(TSrc &&src)
-{
- return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Tanh};
-}
-
-/** Represents the expression: `sqrt(\p src)`.
- *
- * @tparam TSrc The type of the argument.
- * @param[in] src The argument.
- * @return The resulting AST node.
- */
-template <typename TSrc>
-UnaryElementwiseFunction<TSrc> sqrt(TSrc &&src)
-{
- return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Sqrt};
-}
-
-/** Represents the expression: `erf(\p src)`.
- *
- * @tparam TSrc The type of the argument.
- * @param[in] src The argument.
- * @return The resulting AST node.
- */
-template <typename TSrc>
-UnaryElementwiseFunction<TSrc> erf(TSrc &&src)
-{
- return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Erf};
-}
-
-/** Represents the expression: `fabs(\p src)`.
- *
- * @tparam TSrc The type of the argument.
- * @param[in] src The argument.
- * @return The resulting AST node.
- */
-template <typename TSrc>
-UnaryElementwiseFunction<TSrc> fabs(TSrc &&src)
-{
- return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Fabs};
-}
-
-/** Represents the expression: `log(\p src)`.
- *
- * @tparam TSrc The type of the argument.
- * @param[in] src The argument.
- * @return The resulting AST node.
- */
-template <typename TSrc>
-UnaryElementwiseFunction<TSrc> log(TSrc &&src)
-{
- return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Log};
-}
-
-/** Represents the expression: `round(\p src)`.
- *
- * @tparam TSrc The type of the argument.
- * @param[in] src The argument.
- * @return The resulting AST node.
- */
-template <typename TSrc>
-UnaryElementwiseFunction<TSrc> round(TSrc &&src)
-{
- return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Round};
-}
-
-/** Represents the expression: `sizeof(\p src)`.
- *
- * @tparam TSrc The type of the argument.
- * @param[in] src The argument.
- * @return The resulting AST node.
- */
-template <typename TSrc>
-UnaryElementwiseFunction<TSrc> sizeOf(TSrc &&src)
-{
- return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::SizeOf};
-}
-
-// ==================================================
-// Binary elementwise functions
-// ==================================================
-
-/** AST node for binary elementwise functions.
- *
- * Note that both \p TFirst and \p TSecond must be operands.
- *
- * @tparam TFirst The type of the left argument of the function.
- * @tparam TSecond The type of the right argument of the function.
- */
-template <typename TFirst,
- typename TSecond,
- typename = ::std::enable_if<can_be_operand<TFirst>::value && can_be_operand<TSecond>::value>>
-struct BinaryElementwiseFunction
-{
- TFirst first;
- TSecond second;
- BinaryFunction opcode;
-};
-
-template <typename TFirst, typename TSecond>
-struct can_be_operand<BinaryElementwiseFunction<TFirst, TSecond>> : ::std::true_type
-{
-};
-
-/** Represents the function call: `max(\p first, \p second)`.
- *
- * @tparam TFirst The type of the first argument.
- * @tparam TSecond The type of the second argument.
- * @param[in] first The first argument.
- * @param[in] second The second argument.
- * @return The resulting AST node.
- */
-template <typename TFirst, typename TSecond>
-BinaryElementwiseFunction<TFirst, TSecond> max(TFirst &&first, TSecond &&second)
-{
- return BinaryElementwiseFunction<TFirst, TSecond>{std::forward<TFirst>(first), std::forward<TSecond>(second),
- BinaryFunction::Max};
-}
-
-/** Represents the function call: `min(\p first, \p second)`.
- *
- * @tparam TFirst The type of the first argument.
- * @tparam TSecond The type of the second argument.
- * @param[in] first The first argument.
- * @param[in] second The second argument.
- * @return The resulting AST node.
- */
-template <typename TFirst, typename TSecond>
-BinaryElementwiseFunction<TFirst, TSecond> min(TFirst &&first, TSecond &&second)
-{
- return BinaryElementwiseFunction<TFirst, TSecond>{std::forward<TFirst>(first), std::forward<TSecond>(second),
- BinaryFunction::Min};
-}
-
-// ==================================================
-// Ternary elementwise functions
-// ==================================================
-
-/** AST node for ternary elementwise functions.
- *
- * Note that \p TFirst, \p TSecond, and \p TThird all must be operands.
- *
- * @tparam TFirst The type of the first argument to the function.
- * @tparam TSecond The type of the second argument to the function.
- * @tparam TThird The type of the third argument to the function.
- */
-template <typename TFirst,
- typename TSecond,
- typename TThird,
- typename = ::std::enable_if<can_be_operand<TFirst>::value && can_be_operand<TSecond>::value &&
- can_be_operand<TThird>::value>>
-struct TernaryElementwiseFunction
-{
- TFirst first;
- TSecond second;
- TThird third;
- TernaryFunction opcode;
-};
-
-template <typename TFirst, typename TSecond, typename TThird>
-struct can_be_operand<TernaryElementwiseFunction<TFirst, TSecond, TThird>> : ::std::true_type
-{
-};
-
-/** Represents the function call: `select(\p first, \p second, \p third)`.
- *
- * @tparam TFirst The type of the first argument.
- * @tparam TSecond The type of the second argument.
- * @tparam TThird The type of the third argument.
- * @param[in] first The first argument.
- * @param[in] second The second argument.
- * @param[in] third The third argument.
- * @return The resulting AST node.
- */
-template <typename TFirst, typename TSecond, typename TThird>
-TernaryElementwiseFunction<TFirst, TSecond, TThird> select(TFirst &&first, TSecond &&second, TThird &&third)
-{
- return TernaryElementwiseFunction<TFirst, TSecond, TThird>{std::forward<TFirst>(first),
- std::forward<TSecond>(second),
- std::forward<TThird>(third), TernaryFunction::Select};
-}
-
-/** Helper class used to extend a KernelWriter with additional functionality
- * in order to make writing easier.
- *
- * This extension automatically handles creation of temporary variables, and
- * allows nested function calls and operations.
- *
- * @tparam TWriter The type of KernelWriter to be overloaded. This must inherit from KernelWriter.
- */
-template <class TWriter, typename = std::enable_if<std::is_base_of<KernelWriter, TWriter>::value>>
-class KernelWriterHelper : public TWriter
-{
-public:
- using TWriter::TWriter;
-
- // ==================================================
- // If-statements
- // ==================================================
-
- // Un-hide original implementation, in case the original implementation is required.
- using TWriter::op_if;
-
- /** Represents the if-statement: `if(\p cond) { \p body }`.
- *
- * The BinaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] cond The BinaryExpression representing the condition.
- * @param[in] body The body of the if-statement.
- */
- KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TileOperand &, TileOperand &> &cond,
- const std::function<void()> &body)
- {
- TWriter::op_if(cond.lhs, cond.opcode, cond.rhs, body);
- return *this;
- }
-
- /** Represents the if-statement: `if(\p cond) { \p body }`.
- *
- * The BinaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] cond The BinaryExpression representing the condition.
- * @param[in] body The body of the if-statement.
- */
- template <typename TRight>
- KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TileOperand &, TRight> &cond,
- const std::function<void()> &body)
- {
- auto &tmp1 = declare_temp_tile(cond.lhs.tile_info());
- op_assign(tmp1, cond.rhs);
- TWriter::op_if(cond.lhs, cond.opcode, tmp1, body);
- return *this;
- }
-
- /** Represents the if-statement: `if(\p cond) { \p body }`.
- *
- * The BinaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] cond The BinaryExpression representing the condition.
- * @param[in] body The body of the if-statement.
- */
- template <typename TLeft>
- KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TLeft, TileOperand &> &cond,
- const std::function<void()> &body)
- {
- auto &tmp1 = declare_temp_tile(cond.rhs.tile_info());
- op_assign(tmp1, cond.lhs);
- TWriter::op_if(tmp1, cond.opcode, cond.rhs, body);
- return *this;
- }
-
- // Un-hide original implementation, in case the original implementation is required.
- using TWriter::op_else_if;
-
- /** Represents the else-if-statement: `else if(\p cond) { \p body }`.
- *
- * The BinaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] cond The BinaryExpression representing the condition.
- * @param[in] body The body of the else-if-statement.
- */
- KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TileOperand &, TileOperand &> &cond,
- const std::function<void()> &body)
- {
- TWriter::op_else_if(cond.lhs, cond.opcode, cond.rhs, body);
- return *this;
- }
-
- /** Represents the else-if-statement: `else if(\p cond) { \p body }`.
- *
- * The BinaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] cond The BinaryExpression representing the condition.
- * @param[in] body The body of the else-if-statement.
- */
- template <typename TRight>
- KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TileOperand &, TRight> &cond,
- const std::function<void()> &body)
- {
- auto &tmp1 = declare_temp_tile(cond.lhs.tile_info());
- op_assign(tmp1, cond.rhs);
- TWriter::op_else_if(cond.lhs, cond.opcode, tmp1, body);
- return *this;
- }
-
- /** Represents the else-if-statement: `else if(\p cond) { \p body }`.
- *
- * The BinaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] cond The BinaryExpression representing the condition.
- * @param[in] body The body of the else-if-statement.
- */
- template <typename TLeft>
- KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TLeft, TileOperand &> &cond,
- const std::function<void()> &body)
- {
- auto &tmp1 = declare_temp_tile(cond.rhs.tile_info());
- op_assign(tmp1, cond.lhs);
- TWriter::op_else_if(tmp1, cond.opcode, cond.rhs, body);
- return *this;
- }
-
- // ==================================================
- // For-loops
- // ==================================================
-
- // Un-hide original implementation, in case the original implementation is required.
- using TWriter::op_for_loop;
-
- /** Represents the for-loop: `for(;\p cond; \p updater) { \p body }`.
- *
- * The BinaryExpression for the condition and the Assignment
- * for the updater are unpacked and their components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] cond The BinaryExpression representing the condition.
- * @param[in] updater The Assignment representing the updater.
- * @param[in] body The body of the for-loop.
- */
- void op_for_loop(const BinaryExpression<TileOperand &, TileOperand &> &cond,
- const Assignment<TileOperand &, TileOperand &> &updater,
- const std::function<void()> &body)
- {
- TWriter::op_for_loop(cond.lhs, cond.opcode, cond.rhs, updater.lhs, updater.opcode, updater.rhs, body);
- }
-
- // ==================================================
- // Unary expressions
- // ==================================================
-
- // Un-hide original implementation, in case the original implementation is required.
- using TWriter::op_assign;
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The UnaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The UnaryExpression representing the expression to be evaluated and assigned.
- */
- void op_assign(const TileOperand &dst, const UnaryExpression<TileOperand &> &exp)
- {
- TWriter::op_unary_expression(dst, exp.opcode, exp.src);
- }
-
- // ==================================================
- // Binary expressions
- // ==================================================
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The BinaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
- */
- void op_assign(const TileOperand &dst, const BinaryExpression<TileOperand &, TileOperand &> &exp)
- {
- TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, exp.rhs);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The BinaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
- */
- template <typename TRight>
- void op_assign(const TileOperand &dst, const BinaryExpression<TileOperand &, TRight> &exp)
- {
- std::cout << "Beginning assignment!" << std::endl;
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.rhs);
- TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, tmp1);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The BinaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
- */
- template <typename TLeft>
- void op_assign(const TileOperand &dst, const BinaryExpression<TLeft, TileOperand &> &exp)
- {
- std::cout << "Beginning assignment!" << std::endl;
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.lhs);
- TWriter::op_binary_expression(dst, tmp1, exp.opcode, exp.rhs);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The BinaryExpression is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
- */
- template <typename TLeft, typename TRight>
- void op_assign(const TileOperand &dst, const BinaryExpression<TLeft, TRight> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- auto &tmp2 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.lhs);
- op_assign(tmp2, exp.rhs);
- TWriter::op_binary_expression(dst, tmp1, exp.opcode, tmp2);
- }
-
- // ==================================================
- // Unary elementwise functions
- // ==================================================
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The UnaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- void op_assign(const TileOperand &dst, const UnaryElementwiseFunction<TileOperand &> &exp)
- {
- TWriter::op_unary_elementwise_function(dst, exp.opcode, exp.src);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The UnaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TArg>
- void op_assign(const TileOperand &dst, const UnaryElementwiseFunction<TArg> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.lhs);
- TWriter::op_unary_elementwise_function(dst, exp.opcode, tmp1);
- }
-
- // ==================================================
- // Binary elementwise functions
- // ==================================================
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The BinaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TileOperand &, TileOperand &> &exp)
- {
- TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, exp.second);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The BinaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TRight>
- void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TileOperand &, TRight> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.second);
- TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, tmp1);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The BinaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TLeft>
- void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TLeft, TileOperand &> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.first);
- TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, exp.second);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The BinaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TLeft, typename TRight>
- void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TLeft, TRight> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- auto &tmp2 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.first);
- op_assign(tmp2, exp.second);
- TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, tmp2);
- }
-
- // ==================================================
- // Ternary elementwise functions
- // ==================================================
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The TernaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- void op_assign(const TileOperand &dst,
- const TernaryElementwiseFunction<TileOperand &, TileOperand &, TileOperand &> &exp)
- {
- TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, exp.third);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The TernaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TFirst>
- void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TileOperand &, TileOperand &> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.first);
- TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, exp.third);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The TernaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TSecond>
- void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TSecond, TileOperand &> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.second);
- TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, exp.third);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The TernaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TThird>
- void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TileOperand &, TThird> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.third);
- TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, tmp1);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The TernaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TFirst, typename TSecond>
- void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TSecond, TileOperand &> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- auto &tmp2 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.first);
- op_assign(tmp2, exp.second);
- TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, exp.third);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The TernaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TFirst, typename TThird>
- void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TileOperand &, TThird> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- auto &tmp2 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.first);
- op_assign(tmp2, exp.third);
- TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, tmp2);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The TernaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TSecond, typename TThird>
- void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TSecond, TThird> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info());
- auto &tmp2 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.second);
- op_assign(tmp2, exp.third);
- TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, tmp2);
- }
-
- /** Represents the assignment: `\p dst = \p exp`.
- *
- * The TernaryElementwiseFunction is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] dst The tile which is assigned to.
- * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
- */
- template <typename TFirst, typename TSecond, typename TThird>
- void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TSecond, TThird> &exp)
- {
- auto &tmp1 = declare_temp_tile(dst.tile_info(), dst.tile_info(), dst.tile_info());
- auto &tmp2 = declare_temp_tile(dst.tile_info());
- auto &tmp3 = declare_temp_tile(dst.tile_info());
- op_assign(tmp1, exp.first);
- op_assign(tmp2, exp.second);
- op_assign(tmp3, exp.third);
- TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, tmp3);
- }
-
- // ==================================================
- // Assignments
- // ==================================================
-
- /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`.
- *
- * The Assignment is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @param[in] exp The Assignment representing the expression to be evaluated.
- */
- void op_assign(const Assignment<TileOperand &, TileOperand &> &exp)
- {
- if (exp.opcode == AssignmentOp::Increment)
- {
- TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Add, exp.rhs);
- }
- else if (exp.opcode == AssignmentOp::Decrement)
- {
- TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Sub, exp.rhs);
- }
- }
-
- /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`.
- *
- * The Assignment is unpacked and its components are forwarded to
- * the underlying KernelWriter's implementation.
- *
- * @tparam TRight The type of the RHS of the assignment.
- * @param[in] exp The Assignment representing the expression to be evaluated.
- */
- template <typename TRight>
- void op_assign(const Assignment<TileOperand &, TRight> &exp)
- {
- auto &tmp1 = declare_temp_tile(exp.lhs.tile_info());
- op_assign(tmp1, exp.rhs);
- op_assign(Assignment<TileOperand &, TileOperand &>{exp.lhs, tmp1, exp.opcode});
- }
-
-private:
- unsigned int temp_var_counter = 0;
-
- /** Return the current counter value, then increment it.
- *
- * @return The current counter value.
- */
- int next_ctr()
- {
- return temp_var_counter++;
- }
-
- /** Gets the next temporary variable counter value,
- * and returns a suitable temporary variable name.
- *
- * @return A temporary variable name.
- */
- std::string next_tmp_var_name()
- {
- return "tmp_" + std::to_string(next_ctr());
- }
-
- /** Returns the argument.
- *
- * Used for recursion with the variadic function version of this function.
- *
- * @param[in] arg The TileInfo to return.
- * @return The \p arg.
- */
- TileInfo get_largest_size(const TileInfo &arg)
- {
- return arg;
- }
-
- /** Returns a TileInfo object where the size in each dimension (width, height) is the largest
- * of either TileInfo argument in the corresponding dimension.
- *
- * @tparam TOps Must be of TileInfo type.
- * @param[in] first A TileInfo object.
- * @param[in] second A TileInfo object.
- * @param[in] ops A number of TileInfo objects.
- * @return A TileInfo object which represents the largest shape in each dimension across the arguments.
- */
- template <typename... TOps, typename = ::std::enable_if_t<std::is_same<TOps..., TileInfo>::value>>
- TileInfo get_largest_size(const TileInfo &first, const TileInfo &second, const TOps &...ops)
- {
- TileInfo largest = {first.data_type(), std::max(first.width(), second.width()),
- std::max(first.height(), second.height())};
- return get_largest_size(largest, ops...);
- }
-
- /** Helper function to define a suitable TileOperand with appropriate TileInfo
- * such that broadcasting is taken into account, based on the arguments provided.
- *
- * @tparam TArgs Must be of TileInfo type.
- * @param[in] args A number of TileInfo which determine the shape of the TileOperand to declare.
- * @return A newly created TileOperand.
- */
- template <typename... TArgs, typename = ::std::enable_if_t<std::is_same<TArgs..., TileInfo>::value>>
- TileOperand &declare_temp_tile(const TArgs &...args)
- {
- return TWriter::declare_tile(next_tmp_var_name().c_str(), get_largest_size(args...));
- }
-};
-
-} // namespace ckw
-
-#endif // CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
diff --git a/compute_kernel_writer/prototype/include/ckw/OperandBase.h b/compute_kernel_writer/prototype/include/ckw/OperandBase.h
deleted file mode 100644
index 984212733..000000000
--- a/compute_kernel_writer/prototype/include/ckw/OperandBase.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_OPERANDBASE_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_OPERANDBASE_H
-
-#include "ckw/types/DataType.h"
-
-#include <string>
-
-namespace ckw
-{
-namespace prototype
-{
-class IGpuKernelWriter;
-
-class Operand;
-} // namespace prototype
-
-/** The base class for all operands. */
-class OperandBase
-{
-public:
- /** Constructor
- *
- * @param[in] name The name of the operand.
- */
- explicit OperandBase(const ::std::string &name);
-
- /** Destructor */
- virtual ~OperandBase();
-
- /** (Internal use only) Create the implementation operand.
- *
- * @param[in] writer The implementation kernel writer.
- */
- virtual prototype::Operand create_impl_operand(prototype::IGpuKernelWriter *writer) const = 0;
-
- /** Get the name of the operand. */
- const ::std::string &name() const;
-
- /** Set the name of the operand. */
- OperandBase &name(const ::std::string &name);
-
- /** Get the data type of the operand. */
- virtual DataType data_type() const = 0;
-
- /** Get whether the operand is compile-time constant. */
- virtual bool is_constant() const = 0;
-
-private:
- ::std::string _name;
-};
-
-} // namespace ckw
-
-#endif // CKW_PROTOTYPE_INCLUDE_CKW_OPERANDBASE_H
diff --git a/compute_kernel_writer/prototype/include/ckw/ScalarValue.h b/compute_kernel_writer/prototype/include/ckw/ScalarValue.h
deleted file mode 100644
index 2a9c42acc..000000000
--- a/compute_kernel_writer/prototype/include/ckw/ScalarValue.h
+++ /dev/null
@@ -1,137 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_SCALARVALUE_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_SCALARVALUE_H
-
-#include "ckw/Error.h"
-
-#include <cstdint>
-
-namespace ckw
-{
-
-/** The scalar value known at compile-time. */
-class ScalarValue
-{
-public:
- /** Initialize a new instance of @ref ScalarValue class with integer value 0. */
- ScalarValue()
- {
- _type = Type::INT;
- _value.i64 = 0;
- }
-
- /** Initialize a new instance of @ref ScalarValue class with the specified value. */
- template <typename T>
- ScalarValue(T value)
- {
- set(value);
- }
-
- /** Set the value. */
- template <typename T>
- void set(T value)
- {
- CKW_ASSERT(::std::is_integral<T>::value || ::std::is_floating_point<T>::value);
- CKW_ASSERT(sizeof(T) <= 8);
-
- _size = sizeof(T);
-
- if (::std::is_integral<T>::value)
- {
- if (::std::is_signed<T>::value)
- {
- _type = Type::INT;
- _value.i64 = value;
- }
- else
- {
- _type = Type::UINT;
- _value.u64 = value;
- }
- }
- else
- {
- _type = Type::FLOAT;
- _value.f64 = value;
- }
- }
-
- /** Get the value.
- *
- * The caller must make sure that what has been stored in the object must fit
- * the output data type without data corruption or loss of accuracy.
- */
- template <typename T>
- T get() const
- {
- CKW_ASSERT(::std::is_integral<T>::value || ::std::is_floating_point<T>::value);
- CKW_ASSERT(sizeof(T) >= _size);
-
- if (::std::is_integral<T>::value)
- {
- if (::std::is_signed<T>::value)
- {
- CKW_ASSERT(_type == Type::INT || _type == Type::UINT);
- CKW_ASSERT_IF(_type == Type::UINT, sizeof(T) > _size);
-
- return _value.i64;
- }
- else
- {
- CKW_ASSERT(_type == Type::INT);
-
- return _value.u64;
- }
- }
- else
- {
- return _value.f64;
- }
- }
-
-private:
- union Value
- {
- int64_t i64;
- uint64_t u64;
- double f64;
- };
-
- enum class Type : int32_t
- {
- UINT,
- INT,
- FLOAT,
- };
-
- Value _value{};
- Type _type{};
- uint32_t _size{};
-};
-
-} // namespace ckw
-
-#endif // CKW_PROTOTYPE_INCLUDE_CKW_SCALARVALUE_H
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorInfo.h b/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
deleted file mode 100644
index 24da7dc8a..000000000
--- a/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_TENSORINFO_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_TENSORINFO_H
-
-#include "ckw/types/DataType.h"
-
-#include <array>
-#include <cstdint>
-
-namespace ckw
-{
-/** Compute Kernel Writer tensor data layout (or memory format) */
-enum class TensorDataLayout
-{
- Unknown,
- Nhwc,
- Ndhwc
-};
-
-/** Compute Kernel Writer tensor data layout component */
-enum class TensorDataLayoutComponent
-{
- Unknown,
- N,
- D,
- H,
- W,
- C,
-};
-
-/** Compute Kernel Writer tensor component bitmask. The bitmask can be used to retrieve
- * the info from @ref TensorComponent.
- */
-enum class TensorComponentBitmask : uint32_t
-{
- OffsetFirstElement = 0x01000000, // For example, OffsetFirstElement in @ref TensorComponent
- Stride = 0x02000000, // For example, stride0 in @ref TensorComponent
- Dimension = 0x04000000, // For example, Dim0 in @ref TensorComponent
- FoldedDimensions = 0x08000000, // For example, Dim0xDim1 in @ref TensorComponent
-};
-
-/** Compute Kernel Writer tensor component. The tensor components are used to access specific backend-agnostic tensor arguments,
- * such as the tensor dimensions and tensor strides.
- * The data type is represented as an integer. The value of the integer value
- * is assigned to retrieve the information through the @ref TensorComponentBitmask.
- */
-enum class TensorComponentType : uint32_t
-{
- Unknown = 0x00000000,
- OffsetFirstElement = 0x01000000,
- Stride0 = 0x02000001,
- Stride1 = 0x02000010,
- Stride2 = 0x02000100,
- Stride3 = 0x02001000,
- Stride4 = 0x02010000,
- Dim0 = 0x04000001,
- Dim1 = 0x04000010,
- Dim2 = 0x04000100,
- Dim3 = 0x04001000,
- Dim4 = 0x04010000,
- Dim1xDim2 = 0x08000110,
- Dim2xDim3 = 0x08001100,
- Dim1xDim2xDim3 = 0x08001110
-};
-
-/** Compute Kernel Writer tensor storage. The tensor storage represents the type of tensor memory object.
- */
-enum class TensorStorageType : uint32_t
-{
- Unknown = 0x00000000,
- BufferUint8Ptr = 0x01000000,
- Texture2dReadOnly = 0x02000001,
- Texture2dWriteOnly = 0x02000010,
-};
-
-/** Compute Kernel Writer tensor shape
- * Negative dimensions can be interpreted as dynamic dimensions by the Compute Kernel Writer
- */
-using TensorShape = std::array<int32_t, 5>;
-
-/** Compute Kernel Writer tensor info */
-class TensorInfo
-{
-public:
- /** Constructor
- *
- * @param[in] dt Tensor data type
- * @param[in] shape Tensor shape
- * @param[in] dl Tensor data layout
- * @param[in] id Tensor id. The id is used to keep track of the bound user tensor. Through the id,
- * the user can know what tensor has been used by the Compute Kernel Writer.
- * Possible id values:
- * - greater than or equal to 0: bind a user specific tensors
- * - less than 0: bind a virtual tensor (tile)
- */
- TensorInfo(DataType dt, const TensorShape &shape, TensorDataLayout dl, int32_t id);
-
- /** Set shape */
- TensorInfo &shape(const TensorShape &shape);
-
- /** Get shape */
- TensorShape shape() const;
-
- /** Set data type */
- TensorInfo &data_type(DataType dt);
-
- /** Get data type */
- DataType data_type() const;
-
- /** Set data layout */
- TensorInfo &data_layout(TensorDataLayout dl);
-
- /** Get data layout */
- TensorDataLayout data_layout() const;
-
- /** Set id */
- TensorInfo &id(int32_t id);
-
- /** Get layout */
- int32_t id() const;
-
-private:
- TensorShape _shape{{0}};
- DataType _dt{DataType::Unknown};
- TensorDataLayout _dl{TensorDataLayout::Unknown};
- int32_t _id{-1};
-};
-} // namespace ckw
-
-#endif /* CKW_PROTOTYPE_INCLUDE_CKW_TENSORINFO_H */
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
deleted file mode 100644
index c221b449f..000000000
--- a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
+++ /dev/null
@@ -1,196 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_TENSOROPERAND_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_TENSOROPERAND_H
-
-#include "ckw/OperandBase.h"
-#include "ckw/TensorInfo.h"
-#include "ckw/TensorTileSampler.h"
-#include "ckw/TileOperand.h"
-#include "ckw/types/DataType.h"
-
-#include <memory>
-
-namespace ckw
-{
-
-class TensorComponentOperand;
-
-// =================================================================================================
-// TensorOperand
-// =================================================================================================
-
-/** Tensor operand */
-class TensorOperand : public OperandBase
-{
-public:
- /** Initialize a new instance of @ref TensorOperand class.
- *
- * @param[in] name The name of the tensor.
- * @param[in] info The tensor info.
- * @param[in] storage_type The tensor storage type.
- */
- TensorOperand(const ::std::string &name, const TensorInfo &info, TensorStorageType storage_type);
-
- /** No copy constructor. */
- TensorOperand(const TensorOperand &other) = delete;
-
- /** No copy assignment. */
- TensorOperand &operator=(const TensorOperand &other) = delete;
-
- /** (Internal use only) Create the implementation operand.
- *
- * @param[in] writer The implementation kernel writer.
- */
- virtual prototype::Operand create_impl_operand(prototype::IGpuKernelWriter *writer) const override;
-
- /** Get the tensor info. */
- const TensorInfo &info() const;
-
- /** Get the tensor info. */
- TensorInfo &info();
-
- /** Get the tensor storage type. */
- TensorStorageType storage_type() const;
-
- /** Get the data type. */
- virtual DataType data_type() const override;
-
- /** Get whether the tensor is compile-time constant. */
- virtual bool is_constant() const override;
-
- /** Get the default tile attached to the tensor. */
- const TileOperand &tile() const;
-
- /** Get the default tile attached to the tensor. */
- TileOperand &tile();
-
- /** Set the default tile attached to the tensor. */
- TensorOperand &tile(TileOperand &tile);
-
- /** Get the tensor sampler of the default tile. */
- const TensorTileSampler &tile_sampler() const;
-
- /** Get the tensor sampler of the default tile. */
- TensorTileSampler &tile_sampler();
-
- /** Set the tensor sampler of the default tile. */
- TensorOperand &tile_sampler(const TensorTileSampler &value);
-
- /** Get the operand that contains the stride in y dimension of the tensor. */
- TensorComponentOperand &stride1();
-
- /** Get the operand that contains the stride in z dimension of the tensor. */
- TensorComponentOperand &stride2();
-
- /** Get the operand that contains the stride in w dimension of the tensor. */
- TensorComponentOperand &stride3();
-
- /** Get the operand that contains the stride in w dimension of the tensor. */
- TensorComponentOperand &stride4();
-
- /** Get the operand that contains the size of dimension 0 of the tensor. */
- TensorComponentOperand &dim0();
-
- /** Get the operand that contains the size of dimension 1 of the tensor. */
- TensorComponentOperand &dim1();
-
- /** Get the operand that contains the size of dimension 2 of the tensor. */
- TensorComponentOperand &dim2();
-
- /** Get the operand that contains the size of dimension 3 of the tensor. */
- TensorComponentOperand &dim3();
-
- /** Get the operand that contains the size of dimension 4 of the tensor. */
- TensorComponentOperand &dim4();
-
- /** Get the operand that contains the size of dimensions 1 and 2 collapsed. */
- TensorComponentOperand &dim1_dim2();
-
- /** Get the operand that contains the size of dimensions 1, 2 and 3 collapsed. */
- TensorComponentOperand &dim1_dim2_dim3();
-
- /** Get the operand that contains the offset in bytes to the first element. */
- TensorComponentOperand &offset_first_element_in_bytes();
-
-private:
- TensorInfo _info;
- TensorStorageType _storage_type;
-
- TileOperand *_tile{nullptr};
- TensorTileSampler _tile_sampler{};
-
- ::std::unique_ptr<TensorComponentOperand> _stride1{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _stride2{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _stride3{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _stride4{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _dim0{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _dim1{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _dim2{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _dim3{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _dim4{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _dim1_dim2{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _dim1_dim2_dim3{nullptr};
- ::std::unique_ptr<TensorComponentOperand> _offset_first_element_in_bytes{nullptr};
-};
-
-// =================================================================================================
-// TensorComponentOperand
-// =================================================================================================
-
-/** Tile operand that contains tensor information. */
-class TensorComponentOperand : public TileOperand
-{
-public:
- /** Initialize a new instance of @ref TensorComponentOperand class.
- *
- * @param[in] tensor The tensor operand.
- * @param[in] component The tensor info component.
- */
- TensorComponentOperand(TensorOperand &tensor, TensorComponentType component);
-
- /** Get the tensor operand. */
- TensorOperand &tensor();
-
- /** Get the tensor operand. */
- const TensorOperand &tensor() const;
-
- /** Get the tensor component. */
- TensorComponentType component_type() const;
-
- /** (Internal use only) Create the implementation operand.
- *
- * @param[in] writer The implementation kernel writer.
- */
- virtual prototype::Operand create_impl_operand(prototype::IGpuKernelWriter *writer) const override;
-
-private:
- TensorOperand &_tensor;
- TensorComponentType _component;
-};
-
-} // namespace ckw
-
-#endif // CKW_PROTOTYPE_INCLUDE_CKW_TENSOROPERAND_H
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h b/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h
deleted file mode 100644
index 606dec353..000000000
--- a/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h
+++ /dev/null
@@ -1,169 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_TENSORTILESAMPLER_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_TENSORTILESAMPLER_H
-
-#include "ckw/types/TensorSamplerTypes.h"
-
-#include <functional>
-
-namespace ckw
-{
-
-class TileOperand;
-
-/** Tensor sampler
- *
- * It contains information about how the result tile should be stored to tensor memory.
- * It can also be used to dictate how the subsequent operators fetch the input tensor.
- */
-class TensorTileSampler
-{
-public:
- /** Initialize a new instance of @ref TensorSampler class. */
- TensorTileSampler();
-
- /** Initialize a new instance of @ref TensorSampler class.
- *
- * @param[in] x The coordinate in the x dimension.
- * @param[in] y The coordinate in the y dimension.
- * @param[in] z The coordinate in the z dimension.
- * @param[in] b The coordinate in the batch dimension.
- * @param[in] format The tensor data format.
- * @param[in] address_mode_x The address mode of the x dimension.
- * @param[in] address_mode_y The address mode of the y dimension.
- * @param[in] address_mode_z The address mode of the z dimension.
- */
- TensorTileSampler(TileOperand &x,
- TileOperand &y,
- TileOperand &z,
- TileOperand &b,
- TensorSamplerFormat format,
- TensorSamplerAddressModeX address_mode_x,
- TensorSamplerAddressModeY address_mode_y,
- TensorSamplerAddressModeZ address_mode_z);
-
- /** Initialize a new instance of @ref TensorSampler class.
- *
- * @param[in] x The coordinate in the x dimension.
- * @param[in] y The coordinate in the y dimension.
- * @param[in] z The coordinate in the z dimension.
- * @param[in] b The coordinate in the batch dimension.
- * @param[in] height The height of the tile.
- * @param[in] width The width of the tile.
- * @param[in] format The tensor data format.
- * @param[in] address_mode_x The address mode of the x dimension.
- * @param[in] address_mode_y The address mode of the y dimension.
- * @param[in] address_mode_z The address mode of the z dimension.
- */
- TensorTileSampler(TileOperand &x,
- TileOperand &y,
- TileOperand &z,
- TileOperand &b,
- int32_t height,
- int32_t width,
- TensorSamplerFormat format,
- TensorSamplerAddressModeX address_mode_x,
- TensorSamplerAddressModeY address_mode_y,
- TensorSamplerAddressModeZ address_mode_z);
-
- /** Get the coordinate in the x dimension. */
- const TileOperand &x() const;
-
- /** Set the coordinate in the x dimension. */
- TensorTileSampler &x(TileOperand &x);
-
- /** Get the coordinate in the y dimension. */
- const TileOperand &y() const;
-
- /** Set the coordinate in the y dimension. */
- TensorTileSampler &y(TileOperand &y);
-
- /** Get the coordinate in the z dimension. */
- const TileOperand &z() const;
-
- /** Set the coordinate in the z dimension. */
- TensorTileSampler &z(TileOperand &z);
-
- /** Get the coordinate in the batch dimension. */
- const TileOperand &b() const;
-
- /** Set the coordinate in the batch dimension. */
- TensorTileSampler &b(TileOperand &b);
-
- /** Get the width of the tile. */
- int32_t width() const;
-
- /** Set the width of the tile. */
- TensorTileSampler &width(int32_t width);
-
- /** Get the height of the tile. */
- int32_t height() const;
-
- /** Set the height of the tile. */
- TensorTileSampler &height(int32_t height);
-
- /** Get the format of the tensor. */
- TensorSamplerFormat format() const;
-
- /** Set the format of the tensor. */
- TensorTileSampler &format(TensorSamplerFormat format);
-
- /** Get the address mode of the x dimension. */
- TensorSamplerAddressModeX address_mode_x() const;
-
- /** Set the address mode of the x-dimension. */
- TensorTileSampler &address_mode_x(TensorSamplerAddressModeX address_mode_x);
-
- /** Get the address mode of the y dimension. */
- TensorSamplerAddressModeY address_mode_y() const;
-
- /** Set the address mode of the y dimension. */
- TensorTileSampler &address_mode_y(TensorSamplerAddressModeY address_mode_y);
-
- /** Get the address mode of the z dimension. */
- TensorSamplerAddressModeZ address_mode_z() const;
-
- /** Set the address mode of the z dimension. */
- TensorTileSampler &address_mode_z(TensorSamplerAddressModeZ address_mode_z);
-
-private:
- TileOperand *_x{nullptr};
- TileOperand *_y{nullptr};
- TileOperand *_z{nullptr};
- TileOperand *_b{nullptr};
-
- int32_t _height{0};
- int32_t _width{0};
-
- TensorSamplerFormat _format{TensorSamplerFormat::Unknown};
- TensorSamplerAddressModeX _address_mode_x{TensorSamplerAddressModeX::Unknown};
- TensorSamplerAddressModeY _address_mode_y{TensorSamplerAddressModeY::Unknown};
- TensorSamplerAddressModeZ _address_mode_z{TensorSamplerAddressModeZ::Unknown};
-};
-
-} // namespace ckw
-
-#endif // CKW_PROTOTYPE_INCLUDE_CKW_TENSORTILESAMPLER_H
diff --git a/compute_kernel_writer/prototype/include/ckw/TileInfo.h b/compute_kernel_writer/prototype/include/ckw/TileInfo.h
deleted file mode 100644
index e0d064169..000000000
--- a/compute_kernel_writer/prototype/include/ckw/TileInfo.h
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_TILEINFO_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_TILEINFO_H
-
-#include "ckw/types/DataType.h"
-
-#include <array>
-#include <cstdint>
-
-namespace ckw
-{
-// Constants to access the tile width and height in the TileShape
-constexpr int32_t kTileWidthIdx = 0;
-constexpr int32_t kTileHeightIdx = 1;
-
-/** Compute Kernel Writer tile shape. It is used to define the shape of the tile */
-using TileShape = std::array<int32_t, 2>;
-
-/** Compute Kernel Writer tile info */
-class TileInfo
-{
-public:
- /** Constructor used to initialize a scalar variable with a given data type
- *
- * @param[in] dt Tile data type
- */
- TileInfo(DataType dt);
-
- /** Constructor used to initialize a vector with a given data type and vector length.
- *
- * @param[in] dt Tile data type
- * @param[in] w Tile width (or vector length)
- */
- TileInfo(DataType dt, int32_t w);
-
- /** Constructor used to initialize a tile with a given data type and tile sizes.
- *
- * @param[in] dt Tile data type
- * @param[in] h Tile height
- * @param[in] w Tile width
- */
- TileInfo(DataType dt, int32_t h, int32_t w);
-
- /** Set width */
- TileInfo &width(int32_t w);
-
- /** Get width */
- int32_t width() const;
-
- /** Set height */
- TileInfo &height(int32_t h);
-
- /** Get height */
- int32_t height() const;
-
- /** Set data type */
- TileInfo &data_type(DataType dt);
-
- /** Get data type */
- DataType data_type() const;
-
-private:
- DataType _dt{DataType::Unknown};
- TileShape _shape{};
-};
-
-} // namespace ckw
-
-#endif /* COMPUTE_KERNEL_WRITER_INCLUDE_CKW_TILEINFO_H */
diff --git a/compute_kernel_writer/prototype/include/ckw/TileOperand.h b/compute_kernel_writer/prototype/include/ckw/TileOperand.h
deleted file mode 100644
index 24ee373a2..000000000
--- a/compute_kernel_writer/prototype/include/ckw/TileOperand.h
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_TILEOPERAND_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_TILEOPERAND_H
-
-#include "ckw/Error.h"
-#include "ckw/OperandBase.h"
-#include "ckw/ScalarValue.h"
-#include "ckw/TileInfo.h"
-
-#include <vector>
-
-namespace ckw
-{
-
-class Kernel;
-
-using TileContainer = std::vector<std::vector<std::string>>;
-
-/** Tile operand which can be either scalar, vector or 2D tile. */
-class TileOperand : public OperandBase
-{
-public:
- /** Initialize a new instance of @ref TileOperand class with the tile information.
- *
- * @param[in] name The name of the tile.
- * @param[in] tile_info The tile info.
- */
- TileOperand(const ::std::string &name, const TileInfo &tile_info);
-
- /** Initialize a new instance of @ref TileOperand for scalar variable.
- *
- * @param[in] name The name of the tile.
- * @param[in] data_type The data type of the tile.
- */
- TileOperand(const ::std::string &name, DataType data_type);
-
- /** Initialize a new instance of @ref TileOperand for compile-time constant scalar variable.
- *
- * @param[in] name The name of the tile.
- * @param[in] value The value of the tile.
- */
- TileOperand(const ::std::string &name, int32_t value);
-
- /** Initialize a new instance of @ref TileOperand for compile-time constant scalar variable.
- *
- * @param[in] name The name of the tile.
- * @param[in] value The value of the tile.
- */
- TileOperand(const ::std::string &name, float value);
-
- /** Initialize a new instance of @ref TileOperand for compile-time constant variable.
- *
- * @param[in] name The name of the tile.
- * @param[in] value The value of the tile.
- */
- TileOperand(const ::std::string &name, const ::std::vector<std::vector<std::string>> &value, DataType dt);
-
- /** Prohibit copy of tile operand. */
- TileOperand(const TileOperand &) = delete;
-
- /** Prohibit copy of tile operand. */
- TileOperand &operator=(const TileOperand &) = delete;
-
- /** (Internal use only) Create the implementation operand.
- *
- * @param[in] writer The implementation kernel writer.
- */
- virtual prototype::Operand create_impl_operand(prototype::IGpuKernelWriter *writer) const override;
-
- /** Get the tile info. */
- const TileInfo &tile_info() const;
-
- /** Get the data type of the tile. */
- virtual DataType data_type() const override;
-
- /** Get whether the tile is compile-time constant. */
- virtual bool is_constant() const override;
-
- /** Get whether the tile is a scalar value. */
- bool is_scalar() const;
-
- /** Get the scalar value of the tile.
- *
- * The tile must have the shape of 1, 1 (i.e. scalar).
- *
- * @return Scalar value as a string.
- */
- std::string scalar_value() const;
-
- /** Get the values of the tile.
- *
- * @return 2D container of values.
- */
- const TileContainer &value() const;
-
-private:
- TileInfo _info;
- TileContainer _value{};
- bool _constant;
-};
-
-} // namespace ckw
-
-#endif // CKW_PROTOTYPE_INCLUDE_CKW_TILEOPERAND_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h b/compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h
deleted file mode 100644
index 2a198507e..000000000
--- a/compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_INCLUDE_CKW_CONVERTPOLICY_H
-#define CKW_INCLUDE_CKW_CONVERTPOLICY_H
-
-#include <cstdint>
-
-namespace ckw
-{
-
-enum class ConvertPolicy : int32_t
-{
- None = 0, // No policy specified.
- Saturate = 1, // Saturated.
-};
-
-} // namespace ckw
-
-#endif //CKW_INCLUDE_CKW_CONVERTPOLICY_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/DataType.h b/compute_kernel_writer/prototype/include/ckw/types/DataType.h
deleted file mode 100644
index 3447dd61d..000000000
--- a/compute_kernel_writer/prototype/include/ckw/types/DataType.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
-* Copyright (c) 2023 Arm Limited.
-*
-* SPDX-License-Identifier: MIT
-*
-* Permission is hereby granted, free of charge, to any person obtaining a copy
-* of this software and associated documentation files (the "Software"), to
-* deal in the Software without restriction, including without limitation the
-* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
-* sell copies of the Software, and to permit persons to whom the Software is
-* furnished to do so, subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be included in all
-* copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-* SOFTWARE.
-*/
-
-#ifndef CKW_INCLUDE_CKW_DATATYPE_H
-#define CKW_INCLUDE_CKW_DATATYPE_H
-
-#include <cstdint>
-
-namespace ckw
-{
-
-/** Compute Kernel Writer data types. This data type is used by the code variables and tensor arguments. */
-enum class DataType : int32_t
-{
- Unknown = 0x00,
- Fp32 = 0x11,
- Fp16 = 0x12,
- Int32 = 0x21,
- Int16 = 0x22,
- Int8 = 0x24,
- Uint32 = 0x31,
- Uint16 = 0x32,
- Uint8 = 0x34,
- Bool = 0x41
-};
-
-} // namespace ckw
-
-#endif //CKW_INCLUDE_CKW_DATATYPE_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/Functions.h b/compute_kernel_writer/prototype/include/ckw/types/Functions.h
deleted file mode 100644
index c6afaa0ac..000000000
--- a/compute_kernel_writer/prototype/include/ckw/types/Functions.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
-* Copyright (c) 2023 Arm Limited.
-*
-* SPDX-License-Identifier: MIT
-*
-* Permission is hereby granted, free of charge, to any person obtaining a copy
-* of this software and associated documentation files (the "Software"), to
-* deal in the Software without restriction, including without limitation the
-* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
-* sell copies of the Software, and to permit persons to whom the Software is
-* furnished to do so, subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be included in all
-* copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-* SOFTWARE.
-*/
-
-#ifndef CKW_INCLUDE_CKW_FUNCTIONS_H
-#define CKW_INCLUDE_CKW_FUNCTIONS_H
-
-#include <cstdint>
-
-namespace ckw
-{
-
-enum class UnaryFunction : int32_t
-{
- Exp = 0x0000,
- Tanh = 0x0001,
- Sqrt = 0x0002,
- Erf = 0x0003,
- Fabs = 0x0004,
- Log = 0x0006,
- Round = 0x0007,
- Floor = 0x0008,
-
- // Misc
- SizeOf = 0x0009,
-};
-
-enum class BinaryFunction : int32_t
-{
- Min = 0x0000,
- Max = 0x0001,
-};
-
-enum class TernaryFunction : int32_t
-{
- Select = 0x0000,
- Clamp = 0x0001,
-};
-
-} // namespace ckw
-
-#endif //CKW_INCLUDE_CKW_FUNCTIONS_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/GpuTargetLanguage.h b/compute_kernel_writer/prototype/include/ckw/types/GpuTargetLanguage.h
deleted file mode 100644
index 6c0861794..000000000
--- a/compute_kernel_writer/prototype/include/ckw/types/GpuTargetLanguage.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_INCLUDE_CKW_GPUTARGETLANGUAGE_H
-#define CKW_INCLUDE_CKW_GPUTARGETLANGUAGE_H
-
-#include <cstdint>
-
-namespace ckw
-{
-
-enum class GpuTargetLanguage : int32_t
-{
- Unknown,
- OpenCL
-};
-
-} // namespace ckw
-
-#endif //CKW_INCLUDE_CKW_GPUTARGETLANGUAGE_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/Operators.h b/compute_kernel_writer/prototype/include/ckw/types/Operators.h
deleted file mode 100644
index b56099683..000000000
--- a/compute_kernel_writer/prototype/include/ckw/types/Operators.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
-* Copyright (c) 2023 Arm Limited.
-*
-* SPDX-License-Identifier: MIT
-*
-* Permission is hereby granted, free of charge, to any person obtaining a copy
-* of this software and associated documentation files (the "Software"), to
-* deal in the Software without restriction, including without limitation the
-* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
-* sell copies of the Software, and to permit persons to whom the Software is
-* furnished to do so, subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be included in all
-* copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-* SOFTWARE.
-*/
-
-#ifndef CKW_INCLUDE_CKW_OPERATORS_H
-#define CKW_INCLUDE_CKW_OPERATORS_H
-
-#include <cstdint>
-
-namespace ckw
-{
-
-enum class UnaryOp : int32_t
-{
- LogicalNot = 0x0000, // !
- BitwiseNot = 0x0001, // ~
- Negate = 0x0002, // -
-};
-
-/* Binary operations
-*/
-enum class BinaryOp : int32_t
-{
- // Elementwise
- Add = 0x0000, // +
- Sub = 0x0001, // -
- Mul = 0x0002, // *
- Div = 0x0003, // /
- Mod = 0x0004, // %
- // Relational
- Equal = 0x1000, // ==
- Less = 0x1001, // <
- LessEqual = 0x1002, // <=
- Greater = 0x1003, // >
- GreaterEqual = 0x1004, // >=
- // Algebra
- MatMul_Nt_Nt = 0x2000, // X
- MatMul_Nt_T = 0x2001, // X
- MatMul_T_Nt = 0x2002, // X
- MatMul_T_T = 0x2003, // X
- Dot = 0x2004, // .
- // Logical
- LogicalAnd = 0x3000, // &&
- LogicalOr = 0x3001, // ||
- // Bitwise
- BitwiseXOR = 0x4000, // ^
-};
-
-enum class AssignmentOp : int32_t
-{
- // Unary
- Increment = 0x0000, // +=
- Decrement = 0x0001, // -=
-};
-
-} // namespace ckw
-
-#endif //CKW_INCLUDE_CKW_OPERATORS_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/TensorSamplerTypes.h b/compute_kernel_writer/prototype/include/ckw/types/TensorSamplerTypes.h
deleted file mode 100644
index 63405a076..000000000
--- a/compute_kernel_writer/prototype/include/ckw/types/TensorSamplerTypes.h
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_INCLUDE_CKW_TENSORSAMPLERTYPES_H
-#define CKW_INCLUDE_CKW_TENSORSAMPLERTYPES_H
-
-#include <cstdint>
-
-namespace ckw
-{
-
-enum class TensorSamplerFormat : int32_t
-{
- Unknown = 0,
- C_WH_1 = 1,
- C_W_H = 2
-};
-
-enum class TensorSamplerAddressModeX : int32_t
-{
- Unknown = 0,
- None = 1, // The user guarantees that the X coordinate is always in-bound
- OverlappingMin =
- 2 // (FIXED shapes only) Reduce the load/store length when x == 0 (MIN). The load length will be width % original length
- // Leftover elements can be handled using overlapping. This involves processing some of the elements in the array twice.
-};
-
-enum class TensorSamplerAddressModeY : int32_t
-{
- Unknown = 0,
- None = 1, // The user guarantees that the Y coordinate is always in-bound
- OverlappingMin =
- 2, // (FIXED shapes only) Reduce the load/store length when x == 0 (MIN). The load length will be width % original length
- Skip = 3, // Skip the read/write
- SkipMinEdgeOnly =
- 4, // Skip greater than or equal to max only. The user guarantees that the Y coordinate is always >= 0
- SkipMaxEdgeOnly = 5, // Skip less than 0 only
- ClampToNearest = 6, // Clamp the coordinate to nearest edge (0 or max value allowed on Y)
- ClampToMinEdgeOnly = 7, // Clamp the negative coordinate to 0 only. Therefore, we expect Y to be always < MAX
- ClampToMaxEdgeOnly = 8, // Clamp the coordinate to the max value allowed on Y only. We expect Y to be always >= 0
- ClampToBorder = 9, // Clamp to border which always has 0 value
- ClampToBorderMinEdgeOnly = 10,
- ClampToBorderMaxEdgeOnly = 11
-};
-
-enum class TensorSamplerAddressModeZ : int32_t
-{
- Unknown = 0,
- None = 1, // The user guarantees that the Y coordinate is always in-bound
- Skip = 3, // Skip the read/write
- SkipMinEdgeOnly =
- 4, // Skip greater than or equal to max only. The user guarantees that the Y coordinate is always >= 0
- SkipMaxEdgeOnly = 5, // Skip less than 0 only
- ClampToNearest = 6, // Clamp the coordinate to nearest edge (0 or max value allowed on Y)
- ClampToMinEdgeOnly = 7, // Clamp the negative coordinate to 0 only. Therefore, we expect Y to be always < MAX
- ClampToMaxEdgeOnly = 8, // Clamp the coordinate to the max value allowed on Y only. We expect Y to be always >= 0
-};
-
-} // namespace ckw
-
-#endif //CKW_INCLUDE_CKW_TENSORSAMPLERTYPES_H
diff --git a/compute_kernel_writer/prototype/src/Kernel.cpp b/compute_kernel_writer/prototype/src/Kernel.cpp
deleted file mode 100644
index 6228ed17d..000000000
--- a/compute_kernel_writer/prototype/src/Kernel.cpp
+++ /dev/null
@@ -1,163 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ckw/Kernel.h"
-
-#include "ckw/TensorOperand.h"
-#include "ckw/types/GpuTargetLanguage.h"
-
-#include "src/Prototype.h"
-
-namespace ckw
-{
-
-Kernel::Kernel(GpuTargetLanguage language) : Kernel{"unnamed", language}
-{
-}
-
-Kernel::Kernel(const char *name, GpuTargetLanguage language)
- : _name(name),
- _kernel(std::make_unique<prototype::GpuKernelWriterDataHolder>(language)),
- _operands{},
- _tensor_id_operands{}
-{
-}
-
-Kernel::~Kernel()
-{
-}
-
-const std::string &Kernel::name() const
-{
- return _name;
-}
-
-void Kernel::name(const std::string &name)
-{
- _name = name;
-}
-std::vector<KernelArgument> Kernel::arguments() const
-{
- std::vector<KernelArgument> arguments;
-
- const auto impl_args = _kernel->arguments.tensor_argument_declarations();
-
- for (auto tensor_arg : impl_args)
- {
- auto tensor = _tensor_id_operands.at(tensor_arg->format().id);
- arguments.push_back(*tensor);
-
- for (auto component_arg : tensor_arg->component_declarations())
- {
- switch (component_arg)
- {
- case TensorComponentType::OffsetFirstElement:
- arguments.push_back(tensor->offset_first_element_in_bytes());
- break;
-
- case TensorComponentType::Stride1:
- arguments.push_back(tensor->stride1());
- break;
-
- case TensorComponentType::Stride2:
- arguments.push_back(tensor->stride2());
- break;
-
- case TensorComponentType::Stride3:
- arguments.push_back(tensor->stride3());
- break;
-
- case TensorComponentType::Stride4:
- arguments.push_back(tensor->stride4());
- break;
-
- case TensorComponentType::Dim0:
- arguments.push_back(tensor->dim0());
- break;
-
- case TensorComponentType::Dim1:
- arguments.push_back(tensor->dim1());
- break;
-
- case TensorComponentType::Dim2:
- arguments.push_back(tensor->dim2());
- break;
-
- case TensorComponentType::Dim3:
- arguments.push_back(tensor->dim3());
- break;
-
- case TensorComponentType::Dim4:
- arguments.push_back(tensor->dim4());
- break;
-
- case TensorComponentType::Dim1xDim2:
- arguments.push_back(tensor->dim1_dim2());
- break;
-
- case TensorComponentType::Dim1xDim2xDim3:
- arguments.push_back(tensor->dim1_dim2_dim3());
- break;
-
- default:
- CKW_ASSERT(false);
- }
- }
- }
-
- return arguments;
-}
-
-TileOperand &Kernel::register_operand(std::unique_ptr<TileOperand> operand)
-{
- const auto &name = operand->name();
- auto ptr = operand.get();
-
- CKW_ASSERT(_operands.find(name) == _operands.end());
- _operands[name] = std::move(operand);
-
- return *ptr;
-}
-
-TensorOperand &Kernel::register_operand(std::unique_ptr<TensorOperand> operand)
-{
- const auto id = operand->info().id();
- const auto &name = operand->name();
- auto ptr = operand.get();
-
- CKW_ASSERT(_tensor_id_operands.find(id) == _tensor_id_operands.end());
- CKW_ASSERT(_operands.find(name) == _operands.end());
-
- _tensor_id_operands[id] = operand.get();
- _operands[name] = std::move(operand);
-
- return *ptr;
-}
-
-prototype::GpuKernelWriterDataHolder *Kernel::impl()
-{
- return _kernel.get();
-}
-
-} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/KernelArgument.cpp b/compute_kernel_writer/prototype/src/KernelArgument.cpp
deleted file mode 100644
index 24ace28eb..000000000
--- a/compute_kernel_writer/prototype/src/KernelArgument.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ckw/KernelArgument.h"
-
-#include "ckw/Error.h"
-#include "ckw/TensorOperand.h"
-
-namespace ckw
-{
-
-KernelArgument::KernelArgument(TensorOperand &tensor) : _type(Type::TensorStorage), _id(tensor.info().id())
-{
- _sub_id.tensor_storage_type = tensor.storage_type();
-}
-
-KernelArgument::KernelArgument(TensorComponentOperand &tensor_component)
- : _type(Type::TensorComponent), _id(tensor_component.tensor().info().id())
-{
- _sub_id.tensor_component_type = tensor_component.component_type();
-}
-
-KernelArgument::Type KernelArgument::type() const
-{
- return _type;
-}
-
-int32_t KernelArgument::id() const
-{
- return _id;
-}
-
-TensorStorageType KernelArgument::tensor_storage_type() const
-{
- CKW_ASSERT(_type == Type::TensorStorage);
- return _sub_id.tensor_storage_type;
-}
-
-TensorComponentType KernelArgument::tensor_component_type() const
-{
- CKW_ASSERT(_type == Type::TensorComponent);
- return _sub_id.tensor_component_type;
-}
-
-} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
deleted file mode 100644
index 9f58d9fef..000000000
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ /dev/null
@@ -1,371 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ckw/KernelWriter.h"
-
-#include "ckw/Error.h"
-#include "ckw/TensorInfo.h"
-#include "ckw/TensorOperand.h"
-
-#include "src/Prototype.h"
-
-#include <sstream>
-
-namespace ckw
-{
-
-namespace
-{
-
-inline prototype::TensorInfo create_impl_tensor_info(const TensorInfo &info)
-{
- return prototype::TensorInfo{info.shape(), info.data_type(), info.data_layout(), info.id()};
-}
-
-} // namespace
-
-// =================================================================================================
-// Constructors and destructor
-// =================================================================================================
-
-KernelWriter::KernelWriter(Kernel &kernel)
- : _kernel(&kernel),
- _impl_attr(std::make_unique<prototype::GpuKernelWriterAttribute>()),
- _impl(prototype::GpuKernelWriterFactory::create(_impl_attr.get(), kernel.impl()))
-{
- _impl->set_IdSpace(1);
-}
-
-KernelWriter::~KernelWriter()
-{
-}
-
-// =================================================================================================
-// Scope management
-// =================================================================================================
-
-int32_t KernelWriter::id_space() const
-{
- return _id_space;
-}
-
-KernelWriter &KernelWriter::id_space(int32_t id_space)
-{
- CKW_ASSERT(id_space <= _max_id_space);
-
- _id_space = id_space;
- return *this;
-}
-
-int32_t KernelWriter::next_id_space()
-{
- id_space(++_max_id_space);
- return _id_space;
-}
-
-// =================================================================================================
-// Tensor and tile declaration
-// =================================================================================================
-
-TensorOperand &
-KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type)
-{
- const auto var_name = generate_variable_name(name);
-
- _impl->declare_argument(var_name, create_impl_tensor_info(info));
-
- auto &operand = _kernel->register_operand(std::make_unique<TensorOperand>(var_name, info, storage_type));
-
- return operand;
-}
-
-TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value)
-{
- const auto var_name = generate_variable_name(name);
-
- auto &operand = _kernel->register_operand(std::make_unique<TileOperand>(var_name, value));
-
- return operand;
-}
-
-std::string KernelWriter::generate_variable_name(const std::string &name) const
-{
- std::stringstream var_name;
-
- var_name << "_" << _id_space << "_" << name;
-
- return var_name.str();
-}
-
-TileOperand &KernelWriter::declare_tile_operand(std::unique_ptr<TileOperand> operand_ptr)
-{
- auto &operand = _kernel->register_operand(std::move(operand_ptr));
- const auto &name = operand.name();
-
- if (!operand.is_constant())
- {
- const auto &info = operand.tile_info();
-
- _impl->declare_tile(name, prototype::TileInfo(info.data_type(), info.width(), info.height()));
- }
- else
- {
- _impl->declare_const_tile(name, operand.value(), operand.data_type());
- }
-
- return operand;
-}
-
-// =================================================================================================
-// Load and store
-// =================================================================================================
-
-void KernelWriter::op_load(TileOperand &tile,
- const TensorOperand &tensor,
- const TensorTileSampler &sampler,
- const TileOperand &dilation_y)
-{
- prototype::TensorOperand impl_tensor(
- tensor.name(),
- prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()),
- sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()});
-
- auto impl_x = sampler.x().create_impl_operand(_impl.get());
- auto impl_y = sampler.y().create_impl_operand(_impl.get());
- auto impl_z = sampler.z().create_impl_operand(_impl.get());
- auto impl_b = sampler.b().create_impl_operand(_impl.get());
-
- auto impl_dilation_y = dilation_y.create_impl_operand(_impl.get());
-
- auto impl_dst = tile.create_impl_operand(_impl.get());
-
- _impl->op_load_immediate(impl_tensor, impl_dst, impl_x, impl_y, impl_z, impl_b, impl_dilation_y);
-}
-
-void KernelWriter::op_load_indirect(TileOperand &tile, const TensorOperand &tensor, const TensorTileSampler &sampler)
-{
- prototype::TensorOperand impl_tensor(
- tensor.name(),
- prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()),
- sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()});
-
- auto impl_x = sampler.x().create_impl_operand(_impl.get());
- auto impl_y = sampler.y().create_impl_operand(_impl.get());
- auto impl_z = sampler.z().create_impl_operand(_impl.get());
- auto impl_b = sampler.b().create_impl_operand(_impl.get());
-
- auto impl_dst = tile.create_impl_operand(_impl.get());
-
- _impl->op_load_indirect(impl_tensor, impl_dst, impl_x, impl_y, impl_z, impl_b);
-}
-
-void KernelWriter::util_get_indirect_buffer(TileOperand &tile,
- const TensorOperand &tensor,
- const TensorTileSampler &sampler,
- const TileOperand &x,
- const TileOperand &y,
- const TileOperand &x_off,
- const TileOperand &y_off)
-{
- prototype::TensorOperand impl_tensor(
- tensor.name(),
- prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()),
- sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()});
-
- auto impl_x = x.create_impl_operand(_impl.get());
- auto impl_y = y.create_impl_operand(_impl.get());
- auto impl_x_off = x_off.create_impl_operand(_impl.get());
- auto impl_y_off = y_off.create_impl_operand(_impl.get());
-
- auto impl_dst = tile.create_impl_operand(_impl.get());
-
- _impl->util_get_indirect_buffer(impl_dst, impl_tensor, impl_x, impl_y, impl_x_off, impl_y_off);
-}
-
-void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, const TensorTileSampler &sampler)
-{
- prototype::TensorOperand impl_tensor(
- tensor.name(),
- prototype::GpuSampler{sampler.format(), prototype::to_gpu_tensor_storage(tensor.storage_type()),
- sampler.address_mode_x(), sampler.address_mode_y(), sampler.address_mode_z()});
- auto impl_src = tile.create_impl_operand(_impl.get());
- auto impl_x = sampler.x().create_impl_operand(_impl.get());
- auto impl_y = sampler.y().create_impl_operand(_impl.get());
- auto impl_z = sampler.z().create_impl_operand(_impl.get());
- auto impl_b = sampler.b().create_impl_operand(_impl.get());
-
- _impl->op_store_immediate(impl_tensor, impl_src, impl_x, impl_y, impl_z, impl_b);
-}
-
-// =================================================================================================
-// Data processing
-// =================================================================================================
-
-void KernelWriter::op_assign(const TileOperand &dst, const TileOperand &src)
-{
- auto impl_dst = dst.create_impl_operand(_impl.get());
- auto impl_src = src.create_impl_operand(_impl.get());
-
- _impl->op_assign(impl_dst, impl_src);
-}
-
-void KernelWriter::op_cast_expression(const TileOperand &dst, const TileOperand &src, const ConvertPolicy policy)
-{
- auto impl_dst = dst.create_impl_operand(_impl.get());
- auto impl_src = src.create_impl_operand(_impl.get());
-
- _impl->op_cast_expression(impl_dst, impl_src, policy);
-}
-
-void KernelWriter::op_binary_expression(const TileOperand &dst,
- const TileOperand &lhs,
- BinaryOp op,
- const TileOperand &rhs)
-{
- auto impl_lhs = lhs.create_impl_operand(_impl.get());
- auto impl_rhs = rhs.create_impl_operand(_impl.get());
- auto impl_dst = dst.create_impl_operand(_impl.get());
-
- _impl->op_binary_expression(impl_dst, impl_lhs, op, impl_rhs);
-}
-
-void KernelWriter::op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src)
-{
- auto impl_dst = dst.create_impl_operand(_impl.get());
- auto impl_src = src.create_impl_operand(_impl.get());
-
- _impl->op_unary_expression(impl_dst, op, impl_src);
-}
-
-void KernelWriter::op_unary_elementwise_function(const TileOperand &dst, UnaryFunction opcode, const TileOperand &src)
-{
- auto impl_dst = dst.create_impl_operand(_impl.get());
- auto impl_src = src.create_impl_operand(_impl.get());
-
- _impl->op_unary_elementwise_function(impl_dst, opcode, impl_src);
-}
-
-void KernelWriter::op_binary_elementwise_function(const TileOperand &dst,
- BinaryFunction opcode,
- const TileOperand &first,
- const TileOperand &second)
-{
- auto impl_dst = dst.create_impl_operand(_impl.get());
- auto impl_first = first.create_impl_operand(_impl.get());
- auto impl_second = second.create_impl_operand(_impl.get());
-
- _impl->op_binary_elementwise_function(impl_dst, opcode, impl_first, impl_second);
-}
-
-void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst,
- TernaryFunction opcode,
- const TileOperand &first,
- const TileOperand &second,
- const TileOperand &third)
-{
- auto impl_dst = dst.create_impl_operand(_impl.get());
- auto impl_first = first.create_impl_operand(_impl.get());
- auto impl_second = second.create_impl_operand(_impl.get());
- auto impl_third = third.create_impl_operand(_impl.get());
-
- _impl->op_ternary_elementwise_function(impl_dst, opcode, impl_first, impl_second, impl_third);
-}
-
-void KernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
-{
- auto impl_lhs = lhs.create_impl_operand(_impl.get());
- auto impl_rhs = rhs.create_impl_operand(_impl.get());
-
- _impl->op_if_header(impl_lhs, op, impl_rhs);
- _impl->compound_statement_begin();
- body();
- _impl->compound_statement_end();
-}
-
-void KernelWriter::op_else_if(const TileOperand &lhs,
- BinaryOp op,
- const TileOperand &rhs,
- const std::function<void()> &body)
-{
- auto impl_lhs = lhs.create_impl_operand(_impl.get());
- auto impl_rhs = rhs.create_impl_operand(_impl.get());
-
- _impl->op_else_if_header(impl_lhs, op, impl_rhs);
- _impl->compound_statement_begin();
- body();
- _impl->compound_statement_end();
-}
-
-void KernelWriter::op_else(const std::function<void()> &body)
-{
- _impl->op_else_header();
- _impl->compound_statement_begin();
- body();
- _impl->compound_statement_end();
-}
-
-void KernelWriter::op_for_loop(const TileOperand &var_name,
- BinaryOp cond_op,
- const TileOperand &cond_value_name,
- const TileOperand &update_var_name,
- AssignmentOp update_op,
- const TileOperand &update_value_name,
- const std::function<void()> &body)
-{
- auto impl_var_name = var_name.create_impl_operand(_impl.get());
- auto impl_cond_value_name = cond_value_name.create_impl_operand(_impl.get());
- auto impl_update_var_name = update_var_name.create_impl_operand(_impl.get());
- auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get());
-
- _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, impl_update_var_name, update_op,
- impl_update_value_name);
- _impl->compound_statement_begin();
- body();
- _impl->compound_statement_end();
-}
-
-// =================================================================================================
-// Misc
-// =================================================================================================
-
-void KernelWriter::op_get_global_id(const TileOperand &dst, int32_t dim)
-{
- _impl->op_get_global_id(prototype::Operand(dst.name()), dim);
-}
-
-void KernelWriter::op_return()
-{
- _impl->op_return();
-}
-
-// =================================================================================================
-// Code generation
-// =================================================================================================
-
-std::string KernelWriter::generate_code()
-{
- return prototype::generate_code(*_kernel->impl(), _kernel->name());
-}
-
-} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/OperandBase.cpp b/compute_kernel_writer/prototype/src/OperandBase.cpp
deleted file mode 100644
index e0617fdc0..000000000
--- a/compute_kernel_writer/prototype/src/OperandBase.cpp
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ckw/OperandBase.h"
-
-namespace ckw
-{
-
-OperandBase::OperandBase(const std::string &name) : _name(name)
-{
-}
-
-OperandBase::~OperandBase()
-{
-}
-
-const std::string &OperandBase::name() const
-{
- return _name;
-}
-
-OperandBase &OperandBase::name(const std::string &name)
-{
- _name = name;
- return *this;
-}
-
-} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/Prototype.h b/compute_kernel_writer/prototype/src/Prototype.h
deleted file mode 100644
index b392fe265..000000000
--- a/compute_kernel_writer/prototype/src/Prototype.h
+++ /dev/null
@@ -1,4189 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef CKW_PROTOTYPE_SRC_PROTOTYPE_H
-#define CKW_PROTOTYPE_SRC_PROTOTYPE_H
-
-#include "ckw/Error.h"
-#include "ckw/TensorInfo.h"
-#include "ckw/types/ConvertPolicy.h"
-#include "ckw/types/DataType.h"
-#include "ckw/types/Functions.h"
-#include "ckw/types/GpuTargetLanguage.h"
-#include "ckw/types/Operators.h"
-#include "ckw/types/TensorSamplerTypes.h"
-
-#include <algorithm>
-#include <array>
-#include <cassert> // assert (to be removed)
-#include <chrono>
-#include <cmath>
-#include <cstdint> // int32_t
-#include <functional>
-#include <iostream> // cout (to be removed)
-#include <map>
-#include <memory>
-#include <stdexcept>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-namespace ckw
-{
-namespace prototype
-{
-
-// Dummy data structure for Size2D
-using Size2D = std::vector<int32_t>;
-
-// Dummy Status
-using Status = void;
-
-enum class ComponentType : int32_t
-{
- Complex = 0,
- Simple = 1,
- Unfusable = 2
-};
-
-enum class GpuCompilationSpeed
-{
- Fast = 0x00, // fast compilation may increase the latency of the network
- Slow = 0x01 // slow compilation may decrease the latency of the network
-};
-
-enum class GpuExtensions
-{
- Fp16,
- Dot8,
- Mmul,
- FastMath
-};
-
-struct TensorInfo
-{
- TensorShape shape{{0}};
- DataType data_type{DataType::Unknown};
- TensorDataLayout data_layout{TensorDataLayout::Nhwc};
- int32_t id{-1};
-};
-
-struct ComponentAttribute
-{
- GpuCompilationSpeed compilation_speed{GpuCompilationSpeed::Fast};
- bool overwrite_tile{true};
-};
-
-inline std::string data_type_to_cl_type(DataType dt)
-{
- switch (dt)
- {
- case DataType::Fp32:
- return "float";
- case DataType::Fp16:
- return "half";
- case DataType::Int8:
- return "char";
- case DataType::Uint8:
- return "uchar";
- case DataType::Uint16:
- return "ushort";
- case DataType::Int16:
- return "short";
- case DataType::Uint32:
- return "uint";
- case DataType::Int32:
- return "int";
- case DataType::Bool:
- return "bool";
- default:
- assert(false);
- return "";
- }
-}
-
-inline int32_t width_to_cl_vector_size(int32_t width)
-{
- switch (width)
- {
- case 1:
- return 1;
- case 2:
- return 2;
- case 3:
- return 3;
- case 4:
- return 4;
- case 5:
- case 6:
- case 7:
- case 8:
- return 8;
- case 9:
- case 10:
- case 11:
- case 12:
- case 13:
- case 14:
- case 15:
- case 16:
- return 16;
- default:
- assert(false);
- return 0;
- }
-}
-
-inline std::string get_cl_data_type(DataType dt, int32_t width)
-{
- std::string data_type;
- int32_t w = width_to_cl_vector_size(width);
- data_type += data_type_to_cl_type(dt);
- if (w != 1)
- {
- data_type += std::to_string(w);
- }
- return data_type;
-}
-
-inline std::string to_opencl_store(int32_t vector_length)
-{
- if (vector_length != 1)
- {
- return "vstore" + std::to_string(vector_length) + "(";
- }
- else
- {
- return "*(";
- }
-}
-
-struct TileInfo
-{
- TileInfo()
- {
- }
-
- TileInfo(DataType dt) : dt(dt), w(1), h(1)
- {
- }
-
- TileInfo(DataType dt, int32_t width) : dt(dt), w(width), h(1)
- {
- }
-
- TileInfo(DataType dt, int32_t width, int32_t height) : dt(dt), w(width), h(height)
- {
- }
-
- DataType dt{DataType::Unknown}; // Data type of the tile
- int32_t w{0}; // Width (i.e. c0 - portion of the channels)
- int32_t h{0}; // Height (i.e. s0 - portion of the spatial dimensions)
-};
-
-inline std::ostream &operator<<(std::ostream &o, const TileInfo &a)
-{
- o << a.w << " x " << a.h;
- return o;
-}
-
-struct DataTypeAsString
-{
- std::string str{""};
- DataType dt{DataType::Unknown};
- int32_t size{1};
-};
-
-struct ValueAsString
-{
- std::string str{""};
- DataTypeAsString type{};
-};
-
-// https://stackoverflow.com/questions/51515378/storing-and-accessing-tile-properties-in-c
-// A Tile is a collection of variables used to express a 2D data.
-class IScalarTile
-{
-public:
- virtual ~IScalarTile() = default;
-
- /** Method to get the scalar variable from a tile
- * @param[in] x X coordinate on the width of the tile. If out-of-bound, the coordinate is clamped to the nearest valid edge
- * @param[in] y Y coordinate on the height of the tile. If out-of-bound, the coordinate is clamped to the nearest valid edge
- *
- * @return the scalar variable as a string
- */
- virtual ValueAsString scalar(int32_t x, int32_t y) const = 0;
-
- /** Method to get the list of underlying variable names used by the tile
- *
- * @return the list of variable names
- */
- virtual std::vector<ValueAsString> underlying_source_variables() const = 0;
-
- /** Method to get the name of the tile.
- *
- * @return the name of the tile
- */
- std::string name() const
- {
- return _basename;
- }
-
- /** Method to get the tile format
- *
- * @return the format
- */
- TileInfo format() const
- {
- return _format;
- }
-
- /** Method to know whether the tile is assignable or not (constant)
- *
- * @return true if the tile is assignable
- */
- virtual bool is_assignable() const = 0;
-
- /** Method to know whether the tile needs to be declared
- *
- * @return true if the tile needs to be declared in the code before being used
- */
- virtual bool need_declaration() const = 0;
-
-protected:
- TileInfo _format{}; // Tile format
- std::string _basename{""}; // Tile name
-};
-
-// A tile is a collection of variables used to express a 2D data. The variables are vectors in the GPU context.
-// The vector size is given by the width of the tile. The number of vectors height by depth defines the number of vectors
-class IVectorTile : public IScalarTile
-{
-public:
- virtual ~IVectorTile() = default;
-
- /** Method to get the vector variable from a tile. A vector is an ordered homogeneous collection of two or more scalars.
- * The user can query the list of supported width for the vectors through preferred_vector_sizes().
- *
- * @param[in] y Y coordinate on the height of the tile. If out-of-bound, the coordinate is clamped to the nearest valid edge
- *
- * @return the vector variable as a string
- */
- virtual ValueAsString vector(int32_t y) const = 0;
-
- /** Method to get a vector variable from a tile. A vector is an ordered homogeneous collection of two or more scalars.
- *
- * @return the vector variable as a string
- */
- virtual ValueAsString vector(int32_t x_start, int32_t width, int32_t y) const = 0;
- /** Method to get the preferred vector sizes.
- *
- * @return a vector with the preferred vector sizes
- */
- //virtual std::vector<int32_t> preferred_vector_sizes() const = 0;
-};
-
-class ClTile : public IVectorTile
-{
-public:
- ClTile(const std::string &name, TileInfo format)
- {
- _format = format;
- _basename = name;
- }
-
- ValueAsString scalar(int32_t x, int32_t y) const override
- {
- x = std::max(std::min(x, _format.w - 1), static_cast<int32_t>(0));
- y = std::max(std::min(y, _format.h - 1), static_cast<int32_t>(0));
-
- ValueAsString t;
- t.str = build_variable_name(y);
- t.type.str = get_cl_data_type(_format.dt, 1);
- t.type.dt = _format.dt;
- t.type.size = 1;
-
- // Check required because if the width has only one element, we cannot use .s0
- if (_format.w != 1)
- {
- // Automatic broadcasting
- t.str += ".s" + std::to_string(x);
- }
-
- return t;
- }
-
- ValueAsString vector(int32_t y) const override
- {
- y = std::max(std::min(y, _format.h - 1), static_cast<int32_t>(0));
-
- ValueAsString t;
- t.str = build_variable_name(y);
- t.type.str = get_cl_data_type(_format.dt, _format.w);
- t.type.dt = _format.dt;
- t.type.size = _format.w;
- return t;
- }
-
- ValueAsString vector(int32_t x_start, int32_t width, int32_t y) const override
- {
- y = std::max(std::min(y, _format.h - 1), static_cast<int32_t>(0));
-
- ValueAsString t;
- t.str = build_variable_name(y);
- t.type.str = get_cl_data_type(_format.dt, width);
- t.type.dt = _format.dt;
- t.type.size = width;
-
- if (_format.w != 1)
- {
- t.str += ".s";
- for (int i = 0; i < width; ++i)
- {
- t.str += to_scalar_hex(x_start + i);
- }
- }
- return t;
- }
-
- std::vector<ValueAsString> underlying_source_variables() const override
- {
- std::vector<ValueAsString> vars;
- for (int32_t y = 0; y < _format.h; ++y)
- {
- ValueAsString t;
- t.str = build_variable_name(y);
- t.type.str = get_cl_data_type(_format.dt, _format.w);
- t.type.dt = _format.dt;
- t.type.size = _format.w;
- vars.push_back(t);
- }
- return vars;
- }
-
- bool is_assignable() const override
- {
- return true;
- }
-
- bool need_declaration() const override
- {
- return true;
- }
-
-private:
- std::string build_variable_name(int32_t y) const
- {
- std::string var_name = _basename;
-
- if (_format.h == 1)
- {
- return var_name;
- }
- else
- {
- var_name += "_";
- var_name += std::to_string(y);
- }
-
- return var_name;
- }
-
- std::string to_scalar_hex(int32_t x) const
- {
- switch (x)
- {
- case 0:
- case 1:
- case 2:
- case 3:
- case 4:
- case 5:
- case 6:
- case 7:
- case 8:
- case 9:
- return std::to_string(x);
- case 10:
- return "A";
- case 11:
- return "B";
- case 12:
- return "C";
- case 13:
- return "D";
- case 14:
- return "E";
- case 15:
- return "F";
- default:
- std::cout << "Unsupported hexadecimal value" << std::endl;
- assert(false);
- return "";
- }
- }
-};
-
-// Unique features: It contains values in the form of string. The name used for this object is misleading since the variables can change the value over time.
-class ClConstantTile : public IVectorTile
-{
-public:
- ClConstantTile(const std::vector<std::vector<std::string>> &in, DataType dt)
- {
- _format.w = in[0].size();
- _format.h = in.size();
- _format.dt = dt;
-
- _data = std::vector<std::vector<std::string>>(_format.h, std::vector<std::string>(_format.w));
-
- for (int32_t y = 0; y < _format.h; ++y)
- {
- for (int32_t x = 0; x < _format.w; ++x)
- {
- _data[y][x] = in[y][x];
- }
- }
- }
-
- ValueAsString scalar(int32_t x, int32_t y) const override
- {
- x = std::max(std::min(x, _format.w - 1), static_cast<int32_t>(0));
- y = std::max(std::min(y, _format.h - 1), static_cast<int32_t>(0));
-
- ValueAsString t;
- t.str = _data[y][x];
- t.type.str = get_cl_data_type(_format.dt, 1);
- t.type.dt = _format.dt;
- t.type.size = 1;
-
- return t;
- }
-
- ValueAsString vector(int32_t y) const override
- {
- y = std::max(std::min(y, _format.h - 1), static_cast<int32_t>(0));
-
- return vector(0, _format.w, y);
- }
-
- ValueAsString vector(int32_t x_start, int32_t width, int32_t y) const override
- {
- y = std::max(std::min(y, _format.h - 1), static_cast<int32_t>(0));
-
- ValueAsString t;
- t.str = "";
- t.type.str = get_cl_data_type(_format.dt, width);
- t.type.dt = _format.dt;
- t.type.size = width;
-
- if (width > 1)
- {
- t.str += "((" + get_cl_data_type(_format.dt, width) + ")(";
- }
-
- int32_t x = x_start;
- for (; x < width - 1; ++x)
- {
- t.str += scalar(x, y).str;
- t.str += ", ";
- }
- t.str += scalar(x, y).str;
-
- if (width > 1)
- {
- t.str += "))";
- }
-
- return t;
- }
-
- std::vector<ValueAsString> underlying_source_variables() const override
- {
- std::vector<ValueAsString> vars;
-
- for (int32_t y = 0; y < _format.h; ++y)
- {
- for (int32_t x = 0; x < _format.w; ++x)
- {
- ValueAsString t;
- t.str = _data[y][x];
- t.type.str = get_cl_data_type(_format.dt, 1);
- t.type.dt = _format.dt;
- t.type.size = 1;
- vars.push_back(t);
- }
- }
-
- return vars;
- }
-
- bool is_assignable() const override
- {
- return false;
- }
-
- bool need_declaration() const override
- {
- return false;
- }
-
-private:
- std::vector<std::vector<std::string>> _data{};
-};
-
-enum class TensorComponentIndex : int32_t
-{
- IndexMask = 0x0000000f,
-};
-
-enum class TensorComponentGroup : int32_t
-{
- OffsetFirstElement = 0x00000100,
- Stride = 0x00001000,
- Dimension = 0x00010000,
- FoldedDimension = 0x00100000,
- Constant = 0x01000000
-};
-
-inline std::string to_string(TensorComponentType x)
-{
- switch (x)
- {
- case TensorComponentType::Unknown:
- return "Unknown";
- case TensorComponentType::OffsetFirstElement:
- return "OffsetFirstElement";
- case TensorComponentType::Stride1:
- return "Stride1";
- case TensorComponentType::Stride2:
- return "Stride2";
- case TensorComponentType::Stride3:
- return "Stride3";
- case TensorComponentType::Stride4:
- return "Stride4";
- case TensorComponentType::Dim0:
- return "Dim0";
- case TensorComponentType::Dim1:
- return "Dim1";
- case TensorComponentType::Dim2:
- return "Dim2";
- case TensorComponentType::Dim3:
- return "Dim3";
- case TensorComponentType::Dim4:
- return "Dim4";
- case TensorComponentType::Dim1xDim2:
- return "Dim1xDim2";
- case TensorComponentType::Dim1xDim2xDim3:
- return "Dim1xDim2xDim3";
- default:
- assert(false);
- return "";
- }
-}
-
-class ITensorArgument
-{
-public:
- virtual ~ITensorArgument() = default;
-
- /** Method to get the tensor component as a string
- *
- * @param[in] x tensor component to query
- *
- * @return the tensor component as a string
- */
- virtual std::string component(TensorComponentType x) = 0;
-
- /** Method to get the tensor component type declaration as a string
- *
- * @return the tensor component type declaration as a string
- */
- virtual std::string component_type_declaration() const = 0;
-
- /** Method to get the tensor component data type
- *
- * @return the tensor component data type
- */
- virtual DataType component_data_type() const = 0;
-
- /** Method to get the tensor component declarations
- *
- * @return a vector containing the tensor component declarations
- */
- virtual std::vector<TensorComponentType> component_declarations() const = 0;
-
- /** Method to get the name of the tensor argument.
- *
- * @return the name of the tensor argument
- */
- std::string name() const
- {
- return _basename;
- }
-
- /** Method to get the tensor format
- *
- * @return the format
- */
- TensorInfo format() const
- {
- return _format;
- }
-
-protected:
- TensorInfo _format{};
- std::string _basename{};
-};
-
-enum class GpuTensorStorage : int32_t
-{
- Unknown = 0x0000,
- BufferUint8Ptr = 0x0012,
- Image2dReadOnly = 0x0020,
- Image2dWriteOnly = 0x0021,
- Image3dReadOnly = 0x0030,
- Image3dWriteOnly = 0x0031
-};
-
-inline GpuTensorStorage to_gpu_tensor_storage(TensorStorageType s)
-{
- switch (s)
- {
- case TensorStorageType::Unknown:
- return GpuTensorStorage::Unknown;
-
- case TensorStorageType::BufferUint8Ptr:
- return GpuTensorStorage::BufferUint8Ptr;
-
- case TensorStorageType::Texture2dReadOnly:
- return GpuTensorStorage::Image2dReadOnly;
-
- case TensorStorageType::Texture2dWriteOnly:
- return GpuTensorStorage::Image2dWriteOnly;
-
- default:
- assert(false);
- return GpuTensorStorage::Unknown;
- }
-}
-
-inline TensorStorageType to_tensor_storage(GpuTensorStorage s)
-{
- switch (s)
- {
- case GpuTensorStorage::Unknown:
- return TensorStorageType::Unknown;
-
- case GpuTensorStorage::BufferUint8Ptr:
- return TensorStorageType::BufferUint8Ptr;
-
- case GpuTensorStorage::Image2dReadOnly:
- return TensorStorageType::Texture2dReadOnly;
-
- case GpuTensorStorage::Image2dWriteOnly:
- return TensorStorageType::Texture2dWriteOnly;
-
- default:
- assert(false);
- return TensorStorageType::Unknown;
- }
-}
-
-class IGpuTensorArgument : public ITensorArgument
-{
-public:
- virtual ~IGpuTensorArgument() = default;
-
- /** Method to get the tensor storage, which is the underlying storage used to keep the data memory
- *
- * @param[in] x tensor storage to query
- *
- * @return the tensor storage as a string
- */
- virtual std::string storage(GpuTensorStorage x) = 0;
-
- /** Method to get the tensor storage type declaration as a string
- *
- * @param[in] x tensor component to query
- *
- * @return the tensor storage type declaration as a string
- */
- virtual std::string storage_type_declaration(GpuTensorStorage x) const = 0;
-
- /** Method to get the tensor storage declarations
- *
- * @return a vector containing the tensor storage declarations
- */
- virtual std::vector<GpuTensorStorage> storage_declarations() const = 0;
-};
-
-class ClTensorArgument : public IGpuTensorArgument
-{
-public:
- ClTensorArgument(const std::string &name, const TensorInfo &x, bool return_by_value_when_possible)
- {
- _basename = name;
- _format = x;
- _return_by_value_when_possible = return_by_value_when_possible;
- }
-
- // Methods to override
- std::string component(TensorComponentType x) override
- {
- if ((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentGroup::Constant)))
- {
- int32_t idx = static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentIndex::IndexMask);
- return std::to_string(idx - 1);
- }
-
- if (_return_by_value_when_possible)
- {
- if ((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentGroup::Dimension)))
- {
- int32_t idx = static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentIndex::IndexMask);
- return std::to_string(_format.shape[idx]);
- }
-
- if ((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentGroup::FoldedDimension)))
- {
- switch (x)
- {
- case TensorComponentType::Dim1xDim2:
- return std::to_string(_format.shape[1] * _format.shape[2]);
- case TensorComponentType::Dim1xDim2xDim3:
- return std::to_string(_format.shape[1] * _format.shape[2] * _format.shape[2]);
- default:
- std::cout << "Unsupported folded dimension" << std::endl;
- assert(false);
- }
- }
- }
-
- if (std::find(_components_required.begin(), _components_required.end(), x) == _components_required.end())
- {
- _components_required.push_back(x);
- }
-
- return build_component_name(x);
- }
-
- std::string component_type_declaration() const override
- {
- return "int";
- };
-
- DataType component_data_type() const override
- {
- return DataType::Int32;
- }
-
- std::string storage(GpuTensorStorage x) override
- {
- if (std::find(_storage_required.begin(), _storage_required.end(), x) == _storage_required.end())
- {
- _storage_required.push_back(x);
- }
-
- return build_storage_name(x);
- }
-
- std::string storage_type_declaration(GpuTensorStorage x) const override
- {
- switch (x)
- {
- case GpuTensorStorage::BufferUint8Ptr:
- return "__global uchar*";
- case GpuTensorStorage::Image2dReadOnly:
- return "__read_only image2d_t";
- case GpuTensorStorage::Image2dWriteOnly:
- return "__write_only image2d_t";
- case GpuTensorStorage::Image3dReadOnly:
- return "__read_only image3d_t ";
- case GpuTensorStorage::Image3dWriteOnly:
- return "__write_only image3d_t ";
- default:
- std::cout << "Unsupported storage" << std::endl;
- assert(false);
- return "";
- }
- };
-
- std::vector<GpuTensorStorage> storage_declarations() const override
- {
- return _storage_required;
- }
-
- std::vector<TensorComponentType> component_declarations() const override
- {
- return _components_required;
- }
-
-private:
- std::string build_storage_name(GpuTensorStorage x) const
- {
- std::string var_name = _basename;
-
- switch (x)
- {
- case GpuTensorStorage::BufferUint8Ptr:
- return var_name + "_ptr";
- case GpuTensorStorage::Image2dReadOnly:
- case GpuTensorStorage::Image2dWriteOnly:
- return var_name + "_img2d";
- case GpuTensorStorage::Image3dReadOnly:
- case GpuTensorStorage::Image3dWriteOnly:
- return var_name + "_img3d";
- default:
- std::cout << "Unsupported storage" << std::endl;
- assert(false);
- }
-
- return var_name;
- }
-
- std::string build_component_name(TensorComponentType x) const
- {
- std::string var_name = _basename;
-
- switch (x)
- {
- case TensorComponentType::OffsetFirstElement:
- return var_name + "_offset_first_element";
- case TensorComponentType::Stride1:
- return var_name + "_stride1";
- case TensorComponentType::Stride2:
- return var_name + "_stride2";
- case TensorComponentType::Stride3:
- return var_name + "_stride3";
- case TensorComponentType::Dim0:
- return var_name + "_dim0";
- case TensorComponentType::Dim1:
- return var_name + "_dim1";
- case TensorComponentType::Dim2:
- return var_name + "_dim2";
- case TensorComponentType::Dim3:
- return var_name + "_dim3";
- case TensorComponentType::Dim1xDim2:
- return var_name + "_dim1xdim2";
- case TensorComponentType::Dim1xDim2xDim3:
- return var_name + "_dim1xdim2xdim3";
- default:
- std::cout << "Unsupported component" << std::endl;
- assert(false);
- }
-
- return var_name;
- }
-
- bool _return_by_value_when_possible{false};
- std::vector<GpuTensorStorage> _storage_required{};
- std::vector<TensorComponentType> _components_required{};
-};
-
-/**
- * @brief Data structure that contains the declared tiles by the components.
- * The registry is a linear data structure that follows the similar principle of the stack. The user can use the @p increment_registry_level() method to
- * increase the level of the stack (0 when it starts). When the user uses the @p decrement_registry_level() method, the registry decreases the level of the stack
- * and remove (pop) all the tiles from the level above.
- * When a tile is declared on the level 0, it is a global tile. A global tile is visible in all parts of the code.
- * Since different components may use the same name to define a tile, the registry adopts the IdSpace concept, an @p id to prevent name collisions
- * when declaring tiles among different components.
- *
- */
-class GpuTileRegistry
-{
-public:
- enum class RegistryTileType
- {
- Tile,
- Link
- };
-
- using RegistryIdSpace = int32_t;
- using RegistryLevel = int32_t;
- using RegistryTileName = std::string;
-
- struct RegistryTileTableEntry
- {
- RegistryLevel registry_level{0};
- std::unique_ptr<IVectorTile> tile_object{nullptr};
- };
-
- struct RegistryTileTypeTableEntry
- {
- RegistryTileType tile_type{RegistryTileType::Tile};
- RegistryTileName tile_name{};
- RegistryIdSpace registry_idspace{0};
- RegistryLevel registry_level{0};
- };
-
- using RegistryTileTable = std::map<RegistryIdSpace, std::map<RegistryTileName, RegistryTileTableEntry>>;
- using RegistryTileTypeTable = std::map<RegistryIdSpace, std::map<RegistryTileName, RegistryTileTypeTableEntry>>;
-
- /**
- * @brief Construct a new Gpu Tile Registry object
- *
- */
- GpuTileRegistry()
- {
- _language = GpuTargetLanguage::Unknown;
- }
-
- /**
- * @brief Construct a new Gpu Tile Registry object providing the Gpu programming language
- *
- * @param[in] language Gpu programming language to use
- */
- GpuTileRegistry(GpuTargetLanguage language)
- {
- _language = language;
- }
-
- /**
- * @brief Default destructor. Destroy the Gpu Tile Registry object
- *
- */
- ~GpuTileRegistry() = default;
-
- /**
- * @brief Set the working IdSpace for the tile registry. IdSpace is used to prevent name collisions when declaring tiles.
- * Therefore, the IdSpace should be set before declaring any tiles.
- *
- * @param[in] id The IdSpace id
- */
- void set_IdSpace(int32_t id)
- {
- _IdSpace = id;
- }
-
- /**
- * @brief Get the current working IdSpace for the tile registry. IdSpace is used to prevent name collisions when declaring tiles
- *
- * @return The IdSpace id
- */
- int32_t IdSpace() const
- {
- return _IdSpace;
- }
-
- /**
- * @brief Gets all the IdSpace declarations defined in the tile registry.
- *
- * @return all the IdSpace declarations defined in the tile registry as std::vector<int32_t>. It returns an empty vector if there are no IdSpace declarations.
- */
- std::vector<int32_t> IdSpace_declarations() const
- {
- std::vector<int32_t> x;
-
- auto it = _frags.begin();
-
- while (it != _frags.end())
- {
- x.push_back(it->first);
-
- it++;
- }
-
- return x;
- }
-
- /**
- * @brief Declare a tile from a previously created tile
- */
- void insert(const std::string &name, const IVectorTile *frag)
- {
- assert(_language == GpuTargetLanguage::OpenCL);
- const int32_t key_IdSpace = _IdSpace;
- const std::string key_var_name = name;
- const std::string var_name = frag->name();
- TileInfo format = frag->format();
-
- // First check whether a tile with the same name exists
- IVectorTile *result = (*this)[key_var_name];
- assert(result == nullptr);
- if (result == nullptr)
- {
- std::unique_ptr<ClTile> tile = std::make_unique<ClTile>(var_name, format);
-
- _frags[key_IdSpace][key_var_name].tile_object = std::move(tile);
- _frags[key_IdSpace][key_var_name].registry_level = _registry_level;
-
- _frag_types[key_IdSpace][key_var_name].tile_type = RegistryTileType::Link;
- _frag_types[key_IdSpace][key_var_name].tile_name = key_var_name;
- _frag_types[key_IdSpace][key_var_name].registry_idspace = _IdSpace;
- _frag_types[key_IdSpace][key_var_name].registry_level = _registry_level;
- }
- }
-
- /**
- * @brief Declare a tile with TileInfo. The tile will be stored in the IdSpace set with @p set_IdSpace()
- *
- * @note The reference name used for declaring the tile should not be previously used in the IdSpace
- *
- * @param[in] name Reference name for the tile. The reference name can be used to retrieve the tile stored in the registry.
- * @param[in] format Tile format use to use
- */
- void insert(const std::string &name, const TileInfo &format)
- {
- assert(_language == GpuTargetLanguage::OpenCL);
- const int32_t key_IdSpace = _IdSpace;
- const std::string key_var_name = name;
- const std::string var_name = generate_tile_name(name);
-
- // First check whether a tile with the same name exists
- IVectorTile *result = (*this)[key_var_name];
- assert(result == nullptr);
- if (result == nullptr)
- {
- std::unique_ptr<ClTile> tile = std::make_unique<ClTile>(var_name, format);
- _frags[key_IdSpace][key_var_name].tile_object = std::move(tile);
- _frags[key_IdSpace][key_var_name].registry_level = _registry_level;
-
- _frag_types[key_IdSpace][key_var_name].tile_type = RegistryTileType::Tile;
- _frag_types[key_IdSpace][key_var_name].tile_name = key_var_name;
- _frag_types[key_IdSpace][key_var_name].registry_idspace = _IdSpace;
- _frag_types[key_IdSpace][key_var_name].registry_level = _registry_level;
- }
- }
-
- /**
- * @brief Declare a constant tile. The content of the tile is passed as a vector of std::string
- *
- * @note The reference name used for declaring the tile should not be previously used in the IdSpace
- *
- * @param[in] name Reference name for the tile. The reference name can be used to retrieve the tile stored in the registry.
- * @param[in] in A 3D std::vector of std::string. From the 3D std::vector we can know the dimensions for the tile
- * @param[in] dt The data type for the elements stored in the 3D std::vector as std::string. It is user's responsibilty to ensure
- * that the data type is aligned with the content of the std::string.
- */
- void insert(const std::string &name, const std::vector<std::vector<std::string>> &in, DataType dt)
- {
- assert(_language == GpuTargetLanguage::OpenCL);
- const int32_t key_IdSpace = _IdSpace;
- const std::string key_var_name = name;
-
- // First check whether a tile with the same name exists
- IVectorTile *result = (*this)[key_var_name];
- assert(result == nullptr);
- if (result == nullptr)
- {
- std::unique_ptr<ClConstantTile> tile = std::make_unique<ClConstantTile>(in, dt);
- _frags[key_IdSpace][key_var_name].tile_object = std::move(tile);
- _frags[key_IdSpace][key_var_name].registry_level = _registry_level;
-
- _frag_types[key_IdSpace][key_var_name].tile_type = RegistryTileType::Tile;
- _frag_types[key_IdSpace][key_var_name].tile_name = key_var_name;
- _frag_types[key_IdSpace][key_var_name].registry_idspace = _IdSpace;
- _frag_types[key_IdSpace][key_var_name].registry_level = _registry_level;
- }
- }
-
- /**
- * @brief Declare an anonymous constant tile. The content of the tile is passed as a vector of std::string
- *
- * @note This method can be used to declare temporary tiles that need to be accessed only once.
- *
- * @param[in] in A 3D std::vector of std::string. From the 3D std::vector we can know the dimensions for the tile
- * @param[in] dt The data type for the elements stored in the 3D std::vector as std::string. It is user responsibilty to ensure
- * that the data type is aligned with what passed with the std::string.
- *
- * @return IVectorTile* the anonymous constant tile
- */
- IVectorTile *insert(const std::vector<std::vector<std::string>> &in, DataType dt)
- {
- assert(_language == GpuTargetLanguage::OpenCL);
- const int32_t key_IdSpace = _IdSpace;
- const std::string key_var_name = "_" + std::to_string(_anonymous_frag_count++);
-
- // First check whether a tile with the same name exists
- IVectorTile *result = (*this)[key_var_name];
- assert(result == nullptr);
- if (result == nullptr)
- {
- std::unique_ptr<ClConstantTile> tile = std::make_unique<ClConstantTile>(in, dt);
- _frags[key_IdSpace][key_var_name].tile_object = std::move(tile);
- _frags[key_IdSpace][key_var_name].registry_level = _registry_level;
-
- _frag_types[key_IdSpace][key_var_name].tile_type = RegistryTileType::Tile;
- _frag_types[key_IdSpace][key_var_name].tile_name = key_var_name;
- _frag_types[key_IdSpace][key_var_name].registry_idspace = _IdSpace;
- _frag_types[key_IdSpace][key_var_name].registry_level = _registry_level;
- }
-
- return (*this)[key_var_name];
- }
-
- /**
- * @brief Get the tile from the registry. This method searches the tile in the IdSpace provided by the user
- *
- * @param[in] name The name of the tile to retrieve
- * @param[in] IdSpace The IdSpace id where to search the tile
- *
- * @return IVectorTile* The tile
- */
- IVectorTile *get(const std::string &name, int32_t IdSpace)
- {
- const int32_t key_IdSpace = IdSpace;
- const std::string key_var_name = name;
-
- IVectorTile *result = nullptr;
- auto search_IdSpace = _frags.find(key_IdSpace);
- if (search_IdSpace != _frags.end())
- {
- auto search_tile = _frags[key_IdSpace].find(key_var_name);
- if (search_tile != _frags[key_IdSpace].end())
- {
- result = search_tile->second.tile_object.get();
- assert(result != nullptr);
- }
- }
-
- return result;
- }
-
- /**
- * @brief Get the tile from the registry. This method searches the tile in the IdSpace set with @p set_IdSpace()
- *
- * @param[in] name The name of the tile to retrieve
- *
- * @return IVectorTile* The tile
- */
- IVectorTile *operator[](const std::string &name)
- {
- return get(name, _IdSpace);
- }
-
- /**
- * @brief Check whether the tile in the in the IdSpace provided by the user exists
- *
- * @param[in] name Name of the tile to search for
- * @param[in] IdSpace The IdSpace id where to search the tile
- *
- * @return true if the tile exists
- * @return false if the tile does not exist
- */
- bool has_tile(const std::string &name, int32_t IdSpace) const
- {
- const int32_t key_IdSpace = IdSpace;
- const std::string key_var_name = name;
-
- // IVectorTile* result = nullptr;
- auto search_IdSpace = _frags.find(key_IdSpace);
-
- return search_IdSpace != _frags.end();
- }
-
- /**
- * @brief Check whether the tile within the current IdSpace exists
- *
- * @param[in] name Name of the tile to search for
- *
- * @return true if the tile exists
- * @return false if the tile does not exist
- */
- bool has_tile(const std::string &name) const
- {
- return has_tile(name, _IdSpace);
- }
-
- /**
- * @brief Get all the tiles declared within the IdSpace provided by the user
- *
- * @param[in] IdSpace IdSpace where to retrieve all the declared tiles
- *
- * @return std::vector<IVectorTile*> A vector with all the declared tiles in the IdSpace provided by the user
- */
- std::vector<IVectorTile *> tile_declarations(int32_t IdSpace)
- {
- std::vector<IVectorTile *> tiles;
-
- std::map<RegistryTileName, RegistryTileTypeTableEntry>::iterator it = _frag_types[IdSpace].begin();
-
- while (it != _frag_types[IdSpace].end())
- {
- // The following line should be enabled. However, we cannot at this stage
- // because it used to retrieve the output tile produced by each component.
- // However, this method should NOT be used to retrieve the output tile
- //if(it->second.tile_type == RegistryTileType::Tile)
- {
- tiles.push_back(get(it->second.tile_name, it->second.registry_idspace));
- }
- it++;
- }
-
- return tiles;
- }
-
- /**
- * @brief Increase the level of stack.
- *
- */
- void increment_registry_level()
- {
- _registry_level++;
- }
-
- /**
- * @brief Remove all the tiles declared at the current stack level and decrease the level of the stack.
- *
- */
- void decrement_registry_level()
- {
- assert(_registry_level >= 0);
-
- // Remove all variables in the local scope
- std::map<RegistryTileName, RegistryTileTableEntry>::iterator it = _frags[_IdSpace].begin();
-
- while (it != _frags[_IdSpace].end())
- {
- if (it->second.registry_level == _registry_level)
- {
- it = _frags[_IdSpace].erase(it);
- }
- else
- {
- it++;
- }
- }
-
- std::map<RegistryTileName, RegistryTileTypeTableEntry>::iterator it_type = _frag_types[_IdSpace].begin();
-
- while (it_type != _frag_types[_IdSpace].end())
- {
- if (it_type->second.registry_level == _registry_level)
- {
- it_type = _frag_types[_IdSpace].erase(it_type);
- }
- else
- {
- it_type++;
- }
- }
-
- _registry_level--;
- }
-
- /**
- * @brief Get the level of the stack
- *
- */
- int32_t level() const
- {
- return _registry_level;
- }
-
-private:
- // This method ensures that the key is unique among different components
- std::string generate_tile_name(const std::string &name)
- {
- assert(_IdSpace >= 0);
- if (_registry_level == 0)
- {
- return "_G" + std::to_string(_IdSpace) + "_" + name;
- }
- else
- {
- return name;
- }
- }
-
- RegistryTileTable _frags{};
- RegistryTileTypeTable _frag_types{};
- RegistryLevel _registry_level{0};
- RegistryIdSpace _IdSpace{-1};
- int32_t _anonymous_frag_count{0}; // Counter used to create the anonymous tiles
- GpuTargetLanguage _language{GpuTargetLanguage::Unknown}; // Gpu programming language
-};
-
-using TensorEntry = std::unique_ptr<IGpuTensorArgument>;
-
-/**
- * @brief Data structure that contains the tensors consumed by the components.
- * Since different components may use the same name as reference for a tensor, the registry adopts the IdSpace concept, an @p id to prevent name collisions
- * when declaring tensors among different components.
- *
- */
-class GpuTensorArgumentRegistry
-{
-public:
- /**
- * @brief Construct a new Gpu Tensor Registry object
- *
- */
- GpuTensorArgumentRegistry()
- {
- _language = GpuTargetLanguage::Unknown;
- }
-
- /**
- * @brief Construct a new Gpu Tensor Registry object
- *
- * @param[in] language Gpu programming language to use
- */
- GpuTensorArgumentRegistry(GpuTargetLanguage language)
- {
- _language = language;
- }
-
- /**
- * @brief Default destructor. Destroy the Gpu Tensor Registry object
- *
- */
- ~GpuTensorArgumentRegistry() = default;
-
- /**
- * @brief Set the working IdSpace for the tensor registry. IdSpace is used to prevent name collisions when declaring tensors.
- * Therefore, the IdSpace should be set before declaring any tensors.
- *
- * @param[in] id The IdSpace id
- */
- void set_IdSpace(int32_t id)
- {
- _IdSpace = id;
- }
-
- /**
- * @brief Get the current working IdSpace for the tensor registry. IdSpace is used to prevent name collisions when declaring tensors
- *
- * @return The IdSpace id
- */
- int32_t IdSpace() const
- {
- return _IdSpace;
- }
-
- /**
- * @brief Gets all the IdSpace declarations defined in the tensor registry.
- *
- * @return all the IdSpace declarations defined in the tensor registry as std::vector<int32_t>. It returns an empty vector if there are no IdSpace declarations.
- */
- std::vector<int32_t> IdSpace_declarations() const
- {
- std::vector<int32_t> x;
-
- auto it = _refs.begin();
-
- while (it != _refs.end())
- {
- x.push_back(it->first);
-
- it++;
- }
-
- return x;
- }
-
- /**
- * @brief Declare a tensor with TensorInfo. The tensor will be stored in the IdSpace set with @p set_IdSpace()
- *
- * @note The reference name used for declaring the tensor should not be previously used in the IdSpace
- *
- * @param[in] name Reference name for the tensor. The reference name can be used to retrieve the tensor stored in the registry.
- * @param[in] x Pair of tensor info and tensor id
- * @param[in] return_by_value_when_possible True if we want the value stored in the tensor components
- */
- void insert(const std::string &name, const TensorInfo &x, bool return_by_value_when_possible)
- {
- assert(_language == GpuTargetLanguage::OpenCL);
- const int32_t key_IdSpace = _IdSpace;
- const int32_t tensor_id = x.id;
- const std::string key_var_name = name;
- const std::string var_name = generate_tensor_name(name, tensor_id);
-
- // First, check whether the tensor has already a reference. If so, trigger an assert
- assert(!has_tensor_argument(name));
-
- // Check whether a tensor with that tensorID exists
- auto result = _tensor_arguments.find(tensor_id);
- if (result == _tensor_arguments.end())
- {
- // It means that we haven't added a tensor with that tensor_id yet. Create a IGpuTensorArgument before creating the reference
- std::unique_ptr<ClTensorArgument> arg =
- std::make_unique<ClTensorArgument>(var_name, x, return_by_value_when_possible);
- _tensor_arguments[tensor_id] = std::move(arg);
- }
-
- _refs[key_IdSpace][key_var_name] = tensor_id;
- }
-
- /**
- * @brief Get the tensor from the registry. This method searches the tensor in the IdSpace set with @p set_IdSpace()
- *
- * @param[in] name The name of the tensor to retrieve
- *
- * @return IGpuTensor* The tensor
- */
- IGpuTensorArgument *operator[](const std::string &name)
- {
- const int32_t key_IdSpace = _IdSpace;
- const std::string key_var_name = name;
-
- IGpuTensorArgument *result = nullptr;
- auto search_IdSpace = _refs.find(key_IdSpace);
- if (search_IdSpace != _refs.end())
- {
- auto search_tensor_id = _refs[key_IdSpace].find(key_var_name);
-
- if (search_tensor_id != _refs[key_IdSpace].end())
- {
- const int32_t tensor_id = search_tensor_id->second;
- auto search_tensor_argument = _tensor_arguments.find(tensor_id);
- if (search_tensor_argument != _tensor_arguments.end())
- {
- result = search_tensor_argument->second.get();
- }
- assert(result != nullptr);
- }
- }
-
- return result;
- }
-
- /**
- * @brief Get all the tensors declared in the IdSpace provided by the user
- *
- * @return std::vector<IGpuTensorArgument*> A vector with all the declared tensors
- */
- std::vector<IGpuTensorArgument *> tensor_argument_declarations()
- {
- std::vector<IGpuTensorArgument *> args;
-
- auto it = _tensor_arguments.begin();
-
- while (it != _tensor_arguments.end())
- {
- args.push_back(it->second.get());
- it++;
- }
-
- return args;
- }
-
- /**
- * @brief Check whether the tensor argument in the IdSpace set with @p set_IdSpace() exists
- *
- * @param[in] name Name of the tensor argument to search for
- *
- * @return true if the tensor argument exists
- * @return false if the tensor argument does not exist
- */
- bool has_tensor_argument(const std::string &name)
- {
- const int32_t key_IdSpace = _IdSpace;
- const std::string key_var_name = name;
-
- auto search_IdSpace = _refs.find(key_IdSpace);
-
- if (search_IdSpace != _refs.end())
- {
- auto search_tensor_id = _refs[key_IdSpace].find(key_var_name);
-
- return search_tensor_id != _refs[key_IdSpace].end();
- }
- else
- {
- return false;
- }
- }
-
- /**
- * @brief Check whether the tensor argument is in the the IdSpace provided by the user
- *
- * @param[in] name Name of the tensor argument to search for
- * @param[in] IdSpace The IdSpace id where to search the tensor argument
- *
- * @return true if the tile exists
- * @return false if the tile does not exist
- */
- bool has_tensor_argument(const std::string &name, int32_t IdSpace)
- {
- const int32_t key_IdSpace = IdSpace;
- const std::string key_var_name = name;
-
- auto search_IdSpace = _refs.find(key_IdSpace);
-
- if (search_IdSpace != _refs.end())
- {
- auto search_tensor_id = _refs[key_IdSpace].find(key_var_name);
-
- return search_tensor_id != _refs[key_IdSpace].end();
- }
- else
- {
- return false;
- }
- }
-
-private:
- // This method ensures that the key is unique among different components
- std::string generate_tensor_name(const std::string &name, int32_t tensor_id)
- {
- assert(tensor_id >= 0);
-
- return name + std::to_string(tensor_id);
- }
-
- std::map<int32_t, TensorEntry> _tensor_arguments{};
- std::map<int32_t, std::map<std::string, int32_t>> _refs{};
- int32_t _IdSpace{-1};
- GpuTargetLanguage _language{GpuTargetLanguage::Unknown}; // Gpu programming language
-};
-
-enum class OpType : int32_t
-{
- Elementwise = 0x0000,
- Relational = 0x1000,
- Algebra = 0x2000
-};
-
-inline std::string to_string(AssignmentOp op)
-{
- switch (op)
- {
- case AssignmentOp::Decrement:
- return "-=";
- case AssignmentOp::Increment:
- return "+=";
- default:
- assert(false);
- return "";
- }
-}
-
-inline std::string to_string(UnaryOp op)
-{
- switch (op)
- {
- case UnaryOp::LogicalNot:
- return "!";
- case UnaryOp::BitwiseNot:
- return "~";
- case UnaryOp::Negate:
- return "-";
- default:
- assert(false);
- return "";
- }
-}
-
-inline std::string to_string(BinaryOp op)
-{
- switch (op)
- {
- case BinaryOp::Add:
- return "+";
- case BinaryOp::Sub:
- return "-";
- case BinaryOp::Mul:
- return "*";
- case BinaryOp::Div:
- return "/";
- case BinaryOp::Mod:
- return "%";
- case BinaryOp::Equal:
- return "==";
- case BinaryOp::Less:
- return "<";
- case BinaryOp::LessEqual:
- return "<=";
- case BinaryOp::Greater:
- return ">";
- case BinaryOp::GreaterEqual:
- return ">=";
- case BinaryOp::LogicalAnd:
- return "&&";
- case BinaryOp::LogicalOr:
- return "||";
- case BinaryOp::BitwiseXOR:
- return "^";
- default:
- assert(false);
- return "";
- }
-}
-
-inline std::string binary_op_string(BinaryOp op)
-{
- switch (op)
- {
- case BinaryOp::Add:
- return "add";
- case BinaryOp::Sub:
- return "sub";
- case BinaryOp::Mul:
- return "mul";
- case BinaryOp::Div:
- return "div";
- case BinaryOp::Mod:
- return "mod";
- case BinaryOp::Equal:
- return "eq";
- case BinaryOp::Less:
- return "gt";
- case BinaryOp::LessEqual:
- return "gteq";
- case BinaryOp::Greater:
- return "lt";
- case BinaryOp::GreaterEqual:
- return "lte";
- default:
- assert(false);
- return "";
- }
-}
-
-enum class OperandType : int32_t
-{
- Unknown = 0x00000000,
- ScalarFp32 = 0x00001011, // Immediate scalar tile
- ScalarFp16 = 0x00001012, // Immediate scalar tile
- ScalarInt32 = 0x00001021, // Immediate scalar tile
- ScalarInt16 = 0x00001022, // Immediate scalar tile
- ScalarInt8 = 0x00001024, // Immediate scalar tile
- ScalarUInt32 = 0x00001031, // Immediate scalar tile
- ScalarUInt16 = 0x00001032, // Immediate scalar tile
- ScalarUInt8 = 0x00001034, // Immediate scalar tile
- ScalarBool = 0x00001041, // Immediate scalar tile
- ScalarTile = 0x00001050, // Scalar from a tile
- Tile = 0x00010000, // Tile
- TensorStride1 = 0x00100001, // Tensor component
- TensorStride2 = 0x00100002, // Tensor component
- TensorStride3 = 0x00100003, // Tensor component
- TensorStride4 = 0x00100004, // Tensor component
- TensorDim0 = 0x00100010, // Tensor component
- TensorDim1 = 0x00100020, // Tensor component
- TensorDim2 = 0x00100030, // Tensor component
- TensorDim3 = 0x00100040, // Tensor component
- TensorDim4 = 0x00100050, // Tensor component
- TensorC = 0x00100010, // Tensor component
- TensorW = 0x00100020, // Tensor component
- TensorH = 0x00100030, // Tensor component
- TensorD = 0x00100040, // Tensor component
- TensorN = 0x00100050, // Tensor component
- TensorDim1xDim2 = 0x00100100, // Tensor component
- TensorDim1xDim2xDim3 = 0x00100200, // Tensor component
- TensorWxH = 0x00100300, // Tensor component
- TensorWxHxD = 0x00100400, // Tensor component
- TensorDataOffset = 0x00100500, // Tensor component
-};
-
-struct ScalarTileCoord
-{
- ScalarTileCoord()
- {
- }
-
- ScalarTileCoord(int32_t x0, int32_t y0) : x(x0), y(y0)
- {
- }
-
- int32_t x{-1};
- int32_t y{-1};
-};
-
-/**
- * @brief Operand class. This object is used to pass the operands to the operations performed by the writer.
- * Operand can be of three types:
- * -# Scalar immediate: constant expression
- * -# Tile: A tile
- * -# Tensor component: A component (scalar) of a tensor
- *
- */
-class Operand
-{
-public:
- Operand(const std::string &val)
- {
- _str = val;
- _type = OperandType::Tile;
- }
-
- Operand(const std::string &val, const ScalarTileCoord &coord)
- {
- _str = val;
- _type = OperandType::ScalarTile;
- _coord = coord;
- }
-
- Operand(const std::string &val, OperandType type)
- {
- _str = val;
- _type = type;
- }
-
- Operand(const Operand &t)
- {
- _str = t.value();
- _type = t.type();
- }
-
- Operand &operator=(const Operand &t)
- {
- _str = t.value();
- _type = t.type();
- _coord = t.scalar_tile_coordinate();
- return *this;
- }
-
- std::string value() const
- {
- return _str;
- }
-
- OperandType type() const
- {
- return _type;
- }
-
- ScalarTileCoord scalar_tile_coordinate() const
- {
- return _coord;
- }
-
-private:
- std::string _str{};
- OperandType _type{OperandType::Unknown};
- ScalarTileCoord _coord{};
-};
-
-using GpuSamplerTensorStorage = GpuTensorStorage;
-
-struct GpuSampler
-{
- GpuSampler() = default;
-
- TensorSamplerFormat format{TensorSamplerFormat::Unknown};
- GpuSamplerTensorStorage storage{GpuSamplerTensorStorage::Unknown};
- TensorSamplerAddressModeX address_mode_x{TensorSamplerAddressModeX::Unknown};
- TensorSamplerAddressModeY address_mode_y{TensorSamplerAddressModeY::Unknown};
- TensorSamplerAddressModeZ address_mode_z{TensorSamplerAddressModeZ::Unknown};
-};
-
-inline GpuSampler create_simple_sampler(
- const TensorInfo *tensor_info_id, GpuSampler sampler, int32_t step_x, int32_t step_y, int32_t step_z)
-{
- CKW_UNUSED(step_x, step_y, step_z);
-
- auto tensor = tensor_info_id->shape;
-
- GpuSampler dst_sampler;
- dst_sampler.format = sampler.format;
- dst_sampler.storage = GpuSamplerTensorStorage::BufferUint8Ptr;
- dst_sampler.address_mode_x = sampler.address_mode_x;
- dst_sampler.address_mode_y = sampler.address_mode_y;
- dst_sampler.address_mode_z = sampler.address_mode_z;
-
- int32_t dim_x = 0;
- int32_t dim_y = 0;
- int32_t dim_z = 0;
-
- switch (sampler.format)
- {
- case TensorSamplerFormat::C_W_H:
- dim_x = tensor[0];
- dim_y = tensor[1];
- dim_z = tensor[2];
- break;
- case TensorSamplerFormat::C_WH_1:
- dim_x = tensor[0];
- dim_y = tensor[1] * tensor[2];
- dim_z = 1;
- break;
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- break;
- }
-
- if (dim_x == 1)
- {
- assert(step_x == 1);
- dst_sampler.address_mode_x = TensorSamplerAddressModeX::None;
- }
-
- if (dim_y == 1)
- {
- assert(step_y == 1);
- dst_sampler.address_mode_y = TensorSamplerAddressModeY::None;
- }
-
- if (dim_z == 1)
- {
- assert(step_z == 1);
- dst_sampler.address_mode_z = TensorSamplerAddressModeZ::None;
- }
-
- return dst_sampler;
-}
-
-class GpuOutputSampler
-{
-public:
- GpuOutputSampler() = default;
-
- /**
- * @brief Method used to initialize the GpuOutputSampler. The GpuOutputSampler can be initialized only once
- * by the root component. Once initialized, all simpler components will need to used this sampler
- * or a broadcasted version of it
- *
- * @param[in] sampler GpuSampler
- * @param[in] step_x Increment step in the X direction. Not necessarily it is the same of n0 of tile!
- * @param[in] step_y Increment step in the Y direction. Not necessarily it is the same of m0 of tile!
- * @param[in] step_z Increment step in the Z direction. Not necessarily it is the same of d0 of tile!
- */
- void initialize(const TensorInfo *tensor_info_id,
- GpuSamplerTensorStorage tensor_storage,
- TensorSamplerFormat tensor_format,
- int32_t step_x,
- int32_t step_y,
- int32_t step_z)
- {
- assert(_is_initialized == false);
-
- _step_x = step_x;
- _step_y = step_y;
- _step_z = step_z;
- _tensor_info_id = tensor_info_id;
- _sampler = create_sampler(tensor_storage, tensor_format);
- _is_initialized = true;
- };
-
- GpuSampler sampler() const
- {
- return _sampler;
- };
-
- int32_t step_x() const
- {
- return _step_x;
- };
-
- int32_t step_y() const
- {
- return _step_y;
- };
-
- int32_t step_z() const
- {
- return _step_z;
- };
-
-private:
- GpuSampler create_sampler(GpuSamplerTensorStorage tensor_storage, TensorSamplerFormat tensor_format)
- {
- // Output can only be in output mode
- assert(tensor_storage != GpuSamplerTensorStorage::Image2dReadOnly);
- assert(tensor_storage != GpuSamplerTensorStorage::Image3dReadOnly);
-
- auto tensor = _tensor_info_id->shape;
-
- GpuSampler sampler;
- sampler.format = tensor_format;
- sampler.storage = tensor_storage;
- sampler.address_mode_x = TensorSamplerAddressModeX::None;
- sampler.address_mode_y = TensorSamplerAddressModeY::None;
- sampler.address_mode_z = TensorSamplerAddressModeZ::None;
-
- // In the case of texture, we do not need any special checks at the border
- if (tensor_storage == GpuSamplerTensorStorage::BufferUint8Ptr)
- {
- int32_t dim_x = 0;
- int32_t dim_y = 0;
- int32_t dim_z = 0;
-
- switch (tensor_format)
- {
- case TensorSamplerFormat::C_W_H:
- dim_x = tensor[0];
- dim_y = tensor[1];
- dim_z = tensor[2];
- break;
- case TensorSamplerFormat::C_WH_1:
- dim_x = tensor[0];
- dim_y = tensor[1] * tensor[2];
- dim_z = 1;
- break;
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- break;
- }
-
- if ((dim_x % _step_x) != 0 && dim_x != 1)
- {
- sampler.address_mode_x = TensorSamplerAddressModeX::OverlappingMin;
- }
-
- if ((dim_y % _step_y) != 0 && dim_y != 1)
- {
- sampler.address_mode_y = TensorSamplerAddressModeY::ClampToMaxEdgeOnly;
- }
-
- if ((dim_z % _step_z) != 0 && dim_z != 1)
- {
- sampler.address_mode_z = TensorSamplerAddressModeZ::ClampToMaxEdgeOnly;
- }
- }
-
- return sampler;
- }
-
- GpuSampler _sampler{}; // GpuSampler
- int32_t _step_x{1};
- int32_t _step_y{1};
- int32_t _step_z{1};
- const TensorInfo *_tensor_info_id{nullptr};
- bool _is_initialized{false};
-};
-
-/**
- * @brief Tensor operand class. This object is used to pass the operands as tensor to the operations performed by the writer.
- */
-class TensorOperand
-{
-public:
- TensorOperand(const std::string &val, GpuSampler sampler) : _str(val), _sampler(sampler)
- {
- }
-
- TensorOperand &operator=(const TensorOperand &t)
- {
- _str = t.value();
- _sampler = t.sampler();
- return *this;
- }
-
- std::string value() const
- {
- return _str;
- }
-
- GpuSampler sampler() const
- {
- return _sampler;
- }
-
-private:
- std::string _str{};
- GpuSampler _sampler{};
-};
-
-/**
- * @brief Data structure that contains all the necessary information to write the Gpu kernel with the Gpu kernel Writer
- * This data structure must be initialized before being passed to the Gpu Kernel Writer
- *
- */
-class GpuKernelWriterDataHolder
-{
-public:
- /**
- * @brief Construct a new Gpu Kernel Data object. In this phase, we should also store
- * the GPU target and target specific capabilities (extensions). For now, we just initialize the
- * programming language
- *
- * @param[in] language Gpu programming language to use
- */
- GpuKernelWriterDataHolder(GpuTargetLanguage language)
- : tiles(language), arguments(language), code(""), _language(language)
- {
- }
-
- /**
- * @brief Get the Gpu programming language used
- *
- * @return GpuTargetLanguage the Gpu programming language
- */
- GpuTargetLanguage programming_language() const
- {
- return _language;
- }
-
- /**
- * @brief @ref GpuTileRegistry
- *
- */
- GpuTileRegistry tiles{};
- /**
- * @brief @ref GpuTensorArgumentRegistry
- *
- */
- GpuTensorArgumentRegistry arguments{};
- /**
- * @brief @ref GpuOutputSampler.
- *
- */
- GpuOutputSampler output_sampler{};
- /**
- * @brief Source code
- *
- */
- std::string code{};
-
- // GpuExtensionRegistry extensions{};
-private:
- GpuTargetLanguage _language;
-};
-
-struct LWS
-{
- int32_t x{1};
- int32_t y{1};
- int32_t z{1};
-};
-
-/**
- * @brief Utility class used to get the tile from the operand. If the operand is not a tile, @ref OperandUnpacker
- * declare an anonymous tile in the tile registry.
- */
-class OperandUnpacker
-{
-public:
- OperandUnpacker(GpuTileRegistry &tiles, GpuTensorArgumentRegistry &arguments) : _tiles(tiles), _arguments(arguments)
- {
- // Increase the level of the stack to allocate possible temporary tiles
- _tiles.increment_registry_level();
- };
-
- ~OperandUnpacker()
- {
- // Decrease the level of the stack to deallocate any temporary tiles
- _tiles.decrement_registry_level();
- }
-
- IVectorTile *unpack(const Operand &src)
- {
- // Get the tile
- if (src.type() == OperandType::Tile)
- {
- assert(_tiles.has_tile(src.value()));
- return _tiles[src.value()];
- }
- // Create an anonymous tile with a constant
- else if (static_cast<int32_t>(src.type()) & 0x00001000)
- {
- if (src.type() == OperandType::ScalarTile)
- {
- ScalarTileCoord coord = src.scalar_tile_coordinate();
- assert(_tiles.has_tile(src.value()));
- assert(coord.x >= 0);
- assert(coord.y >= 0);
- auto val = _tiles[src.value()]->scalar(coord.x, coord.y);
- return _tiles.insert({{{val.str}}}, val.type.dt);
- }
- else
- {
- return _tiles.insert({{{src.value()}}}, to_tile_data_type(src.type()));
- }
- }
- // Create an anonymous tile with the tensor component
- else
- {
- assert(_arguments.has_tensor_argument(src.value()));
- auto x = _arguments[src.value()];
- const std::string val = x->component(to_tensor_component(src.type()));
- const DataType dt = x->component_data_type();
- return _tiles.insert({{{val}}}, dt);
- }
- }
-
-private:
- DataType to_tile_data_type(OperandType x)
- {
- return static_cast<DataType>(static_cast<int32_t>(x) & 0x00ff);
- }
-
- TensorComponentType to_tensor_component(OperandType x)
- {
- switch (x)
- {
- case OperandType::TensorDim0:
- return TensorComponentType::Dim0;
- case OperandType::TensorDim1:
- return TensorComponentType::Dim1;
- case OperandType::TensorDim2:
- return TensorComponentType::Dim2;
- case OperandType::TensorDim3:
- return TensorComponentType::Dim3;
- case OperandType::TensorDim4:
- return TensorComponentType::Dim4;
- case OperandType::TensorStride1:
- return TensorComponentType::Stride1;
- case OperandType::TensorStride2:
- return TensorComponentType::Stride2;
- case OperandType::TensorStride3:
- return TensorComponentType::Stride3;
- case OperandType::TensorStride4:
- return TensorComponentType::Stride4;
- case OperandType::TensorDim1xDim2:
- return TensorComponentType::Dim1xDim2;
- case OperandType::TensorDim1xDim2xDim3:
- return TensorComponentType::Dim1xDim2xDim3;
- case OperandType::TensorDataOffset:
- return TensorComponentType::OffsetFirstElement;
- default:
- assert(false);
- return TensorComponentType::Unknown;
- }
- }
-
- GpuTileRegistry &_tiles;
- GpuTensorArgumentRegistry &_arguments;
-};
-
-/**
- * @brief Utility class used to get the tensor argument from the operand. If the operand is not a tile, @ref OperandUnpacker
- * declare an anonymous tile in the tile registry.
- * Tensor dimension reduction aims for reducing the tensor data dimension while keeping data's tensor structure.
- */
-class TensorOperandUnpacker
-{
-public:
- TensorOperandUnpacker(GpuTensorArgumentRegistry &arguments) : _arguments(arguments){};
-
- IGpuTensorArgument *unpack(const TensorOperand &src)
- {
- assert(_arguments.has_tensor_argument(src.value()));
- return _arguments[src.value()];
- }
-
-private:
- GpuTensorArgumentRegistry &_arguments;
-};
-
-/**
- * @brief The GpuKernel will be used in three occasions (stages):
- * #- Compilation stage
- * #- Tuning stage
- * #- Dispatch stage
- */
-struct GpuKernel
-{
- // Compilation stage
- std::string code{}; // Source code, required for the compilation stage
- std::vector<GpuExtensions> list_extensions{}; // Extensions, required for the compilation stage
- // Tuning stage
- std::string config_id{}; // Unique id, required for the tuning stage
- std::vector<LWS> list_lws{}; // LWS to test, required for the tuning stage
- // Dispatch stage
- GpuOutputSampler output_sampler{}; // GpuOutputSampler, required for the dispatch stage
- std::vector<std::pair<int32_t, GpuTensorStorage>>
- list_tensor_storages; // List of tensor storages, required for the dispatch stage
- std::vector<std::pair<int32_t, TensorComponentType>>
- list_tensor_components; // List of tensor components (width, stride,..), required for the dispatch stage)
-};
-
-// Generate all extension pragmas (hardcoded for now)
-inline std::string generate_extensions()
-{
- std::string ext = R"(
-#if defined(cl_khr_fp16)
-#pragma OPENCL EXTENSION cl_khr_fp16 : enable
-#endif // defined(cl_khr_fp16)
-
-#if defined(cl_arm_integer_dot_product_int8)
-#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
-#endif // defined(cl_arm_integer_dot_product_int8)
-
-#if defined(cl_arm_integer_dot_product_accumulate_int8)
-#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
-#endif // defined(cl_arm_integer_dot_product_accumulate_int8)
-
-#if defined(cl_arm_printf)
-#pragma OPENCL EXTENSION cl_arm_printf : enable
-#endif // defined(cl_arm_printf);
-)";
- return ext;
-}
-
-// This function should produce an object with the source
-inline std::string generate_code(GpuKernelWriterDataHolder &in, const std::string &name)
-{
- std::string code;
- code += generate_extensions();
- code += "__kernel void ";
- code += name;
- code += "(\n";
-
- auto IdSpaces = in.arguments.IdSpace_declarations();
-
- std::vector<std::string> arg_str;
-
- auto tensor_args = in.arguments.tensor_argument_declarations();
-
- for (auto &i : tensor_args)
- {
- // For each tensor used, get the storage and tensor components
- auto storages = i->storage_declarations();
- auto components = i->component_declarations();
-
- for (auto &y : storages)
- {
- std::string str;
- str += i->storage_type_declaration(y);
- str += " ";
- str += i->storage(y);
- arg_str.push_back(str);
- }
-
- for (auto &y : components)
- {
- std::string str;
- str += i->component_type_declaration();
- str += " ";
- str += i->component(y);
- arg_str.push_back(str);
- }
- }
-
- for (size_t i = 0; i < arg_str.size(); ++i)
- {
- code += arg_str[i];
- if (i + 1 < arg_str.size())
- {
- code += ",\n";
- }
- }
-
- code += ")\n";
- code += "{\n";
- code += in.code;
- code += "}\n";
-
- return code;
-}
-
-/**
- * @brief This class is responsible to map a N-Tensor to a 3d tensor. The mapper needs the GpuSampler to know
- * how to reduce the dimensionality of a tensor
- *
- */
-class GpuTensor3dMapper
-{
-public:
- GpuTensor3dMapper(IGpuTensorArgument *tensor, GpuSampler sampler) : _sampler(sampler), _tensor(tensor){};
-
- std::string tensor_component_x() const
- {
- const auto format = _sampler.format;
- switch (format)
- {
- case TensorSamplerFormat::C_WH_1:
- case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponentType::Dim0);
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- return "";
- }
- }
-
- std::string tensor_component_y() const
- {
- const auto format = _sampler.format;
- switch (format)
- {
- case TensorSamplerFormat::C_WH_1:
- return _tensor->component(TensorComponentType::Dim1xDim2);
- case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponentType::Dim1);
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- return "";
- }
- }
-
- std::string tensor_component_z() const
- {
- const auto format = _sampler.format;
- switch (format)
- {
- case TensorSamplerFormat::C_WH_1:
- return "1";
- case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponentType::Dim2);
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- return "";
- }
- }
-
- std::string tensor_component_stride_y() const
- {
- const auto format = _sampler.format;
- switch (format)
- {
- case TensorSamplerFormat::C_WH_1:
- case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponentType::Stride1);
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- return "";
- }
- }
-
- std::string tensor_component_stride_z() const
- {
- const auto format = _sampler.format;
- switch (format)
- {
- case TensorSamplerFormat::C_WH_1:
- return "0";
- case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponentType::Stride2);
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- return "";
- }
- }
-
- std::string tensor_component_stride_batch() const
- {
- const auto format = _sampler.format;
- switch (format)
- {
- case TensorSamplerFormat::C_WH_1:
- case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponentType::Stride3);
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- return "";
- }
- }
-
- bool is_one_component_x() const
- {
- auto t = _tensor->format();
- const auto format = _sampler.format;
- switch (format)
- {
- case TensorSamplerFormat::C_WH_1:
- case TensorSamplerFormat::C_W_H:
- return t.shape[0] == 1;
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- return "";
- }
- }
-
- bool is_one_component_y() const
- {
- auto t = _tensor->format();
- const auto format = _sampler.format;
- switch (format)
- {
- case TensorSamplerFormat::C_WH_1:
- return (t.shape[1] * t.shape[2]) == 1;
- case TensorSamplerFormat::C_W_H:
- return t.shape[1] == 1;
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- return "";
- }
- }
-
- bool is_one_component_z() const
- {
- auto t = _tensor->format();
- const auto format = _sampler.format;
- switch (format)
- {
- case TensorSamplerFormat::C_WH_1:
- return true;
- case TensorSamplerFormat::C_W_H:
- return t.shape[2] == 1;
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- return "";
- }
- }
-
- bool is_one_component_batch() const
- {
- auto t = _tensor->format();
- const auto format = _sampler.format;
- switch (format)
- {
- case TensorSamplerFormat::C_WH_1:
- case TensorSamplerFormat::C_W_H:
- return t.shape[3] == 1;
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- return "";
- }
- }
-
- GpuSampler gpu_sampler() const
- {
- return _sampler;
- }
-
- IGpuTensorArgument *tensor_argument() const
- {
- return _tensor;
- }
-
-private:
- GpuSampler _sampler;
- IGpuTensorArgument *_tensor;
-};
-
-struct GpuKernelWriterAttribute
-{
- bool return_tensor_component_by_value{false};
-};
-
-enum class RoundingMode
-{
- None,
- Rte,
- Rtz,
- Rtp,
- Rtn
-};
-
-// https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl05.html
-class IGpuKernelWriter
-{
-public:
- virtual ~IGpuKernelWriter() = default;
-
- virtual void set_IdSpace(int32_t id) = 0;
-
- virtual void import_tile(const std::string &dst, const IVectorTile *src) = 0;
-
- virtual void declare_argument(const std::string &name, const TensorInfo &tensor) = 0;
-
- virtual void declare_tile(const std::string &name, const TileInfo &info) = 0;
-
- virtual void
- declare_const_tile(const std::string &name, const std::vector<std::vector<std::string>> &in, DataType dt) = 0;
-
- virtual void write_text(const std::string &x) = 0;
-
- virtual void compound_statement_begin() = 0;
-
- virtual void compound_statement_end() = 0;
-
- // Operations
- virtual void op_get_global_id(const Operand &dst_var, int32_t dim) = 0;
-
- virtual void
- op_get_global_coord(const Operand &dst, const Operand &step, const TensorOperand &tensor, int32_t dim) = 0;
-
- virtual void op_get_global_batch(const Operand &dst, const TensorOperand &tensor) = 0;
-
- virtual void op_get_global_size(const Operand &dst_var, int32_t dim) = 0;
-
- virtual void op_unary_expression(const Operand &dst, UnaryOp op, const Operand &src) = 0;
-
- virtual void op_binary_expression(const Operand &dst, const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
-
- virtual void op_assign(const Operand &dst_name, const Operand &src_name) = 0;
-
- virtual void
- op_unary_elementwise_function(const Operand &dst_name, UnaryFunction func, const Operand &src_name) = 0;
-
- virtual void op_binary_elementwise_function(const Operand &dst_name,
- BinaryFunction func,
- const Operand &first_name,
- const Operand &second_name) = 0;
-
- virtual void op_ternary_elementwise_function(const Operand &dst_name,
- TernaryFunction func,
- const Operand &first_name,
- const Operand &second_name,
- const Operand &third_name) = 0;
-
- virtual void op_if_header(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
-
- virtual void op_else_if_header(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
-
- virtual void op_else_header() = 0;
-
- virtual void op_for_loop_header(const Operand &var_name,
- BinaryOp cond_op,
- const Operand &cond_value,
- const Operand &update_var,
- AssignmentOp update_op,
- const Operand &update_value) = 0;
-
- virtual void op_load_indirect(const TensorOperand &tensor,
- const Operand &dst,
- const Operand &x,
- const Operand &y_indirect,
- const Operand &z,
- const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
-
- virtual void op_load_immediate(const TensorOperand &tensor,
- const Operand &dst,
- const Operand &x,
- const Operand &y,
- const Operand &z,
- const Operand &b = Operand("0", OperandType::ScalarInt32),
- const Operand &dilation_y = Operand("1", OperandType::ScalarInt32)) = 0;
-
- virtual void op_store_immediate(const TensorOperand &tensor,
- const Operand &src,
- const Operand &x,
- const Operand &y,
- const Operand &z,
- const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
-
- virtual void op_cast_expression(const Operand &dst, const Operand &src, ConvertPolicy policy) = 0;
-
- virtual void op_return() = 0;
-
- // Utils
- // It is the process of converting
- virtual void util_get_indirect_buffer(const Operand &dst,
- const TensorOperand &tensor,
- const Operand &x,
- const Operand &y,
- const Operand &x_off,
- const Operand &y_off) = 0;
-};
-
-enum class GpuLoadStoreType
-{
- Load = 1,
- Store = 2
-};
-
-class IGpuLoadStoreHelperWriter
-{
-public:
- IGpuLoadStoreHelperWriter(IGpuKernelWriter *x, GpuTensor3dMapper mapper, GpuLoadStoreType type)
- : _writer(x), _mapper(mapper), _type(type)
- {
- }
-
- IGpuLoadStoreHelperWriter(const IGpuLoadStoreHelperWriter &) = default;
-
- IGpuLoadStoreHelperWriter &operator=(const IGpuLoadStoreHelperWriter &) = default;
-
- virtual ~IGpuLoadStoreHelperWriter() = default;
-
- virtual void initialize(IVectorTile *dst, IVectorTile *x, IVectorTile *z, IVectorTile *b) = 0;
-
- virtual void write(const std::pair<int32_t, std::string> &y) = 0;
-
- virtual void finalize() = 0;
-
-protected:
- IGpuKernelWriter *_writer;
- GpuTensor3dMapper _mapper;
- GpuLoadStoreType _type;
-};
-
-class ClLoadStoreBufferHelperWriter : public IGpuLoadStoreHelperWriter
-{
-public:
- ClLoadStoreBufferHelperWriter(IGpuKernelWriter *x, const GpuTensor3dMapper &mapper, GpuLoadStoreType type)
- : IGpuLoadStoreHelperWriter(x, mapper, type)
- {
- }
-
- ClLoadStoreBufferHelperWriter(const ClLoadStoreBufferHelperWriter &) = default;
-
- ClLoadStoreBufferHelperWriter &operator=(const ClLoadStoreBufferHelperWriter &) = default;
-
- static bool validate(IGpuKernelWriter *x, GpuTensor3dMapper mapper, GpuLoadStoreType type, IVectorTile *dst)
- {
- CKW_UNUSED(x, type, dst);
-
- if (mapper.gpu_sampler().storage != GpuSamplerTensorStorage::BufferUint8Ptr)
- {
- return false;
- }
- return true;
- }
-
- void initialize(IVectorTile *dst, IVectorTile *x, IVectorTile *z, IVectorTile *b) override
- {
- assert(validate(_writer, _mapper, _type, dst));
-
- _dst = dst;
- _ls_width_full = dst->format().w;
-
- _coord_x = x->scalar(0, 0).str;
- _coord_z = z->scalar(0, 0).str;
- _coord_b = b->scalar(0, 0).str;
- _coord_orig_z = _coord_z;
-
- out_of_bound_initialize_x(_coord_x);
- out_of_bound_initialize_z(_coord_z);
-
- /*
- meaning of else:
- - x: partial load/store
- - y: no load/store operation
- - z: no load/store operation
- if(x)
- {
- if(z)
- {
- if(y)
- {
- // full load/store width
- }
- else
- {
- // no load/store
- }
- }
- else
- {
- // no load/store
- }
- }
- else
- {
- if(z)
- {
- if(y)
- {
- // partial load/store width
- }
- else
- {
- // no load/store
- }
- }
- else
- {
- // no load/store
- }
- }
- */
- }
-
- void write(const std::pair<int32_t, std::string> &y) override
- {
- int32_t idx_y = y.first;
- std::string coord_y = y.second;
-
- // The only check required is on Y.
- out_of_bound_initialize_y(coord_y);
-
- const std::string dst = _dst->vector(idx_y).str;
- const std::string address = to_ls_buffer_address(_coord_x, coord_y, _coord_z, _coord_b);
- const std::string ls_buf = to_ls_buffer(_type, _ls_width_full, dst, address);
-
- _writer->write_text(ls_buf);
- _writer->write_text(";\n");
-
- out_of_bound_finalize_y(dst);
-
- // The left over load/store will be written in the finalize stage
- if (_ls_width_part.size() != 0)
- {
- int32_t w = 0;
- for (auto &p : _ls_width_part)
- {
- const std::string dst0 = _dst->vector(w, p, idx_y).str;
- const std::string coord_x = _coord_x + " + " + std::to_string(w);
- const std::string address = to_ls_buffer_address(coord_x, coord_y, _coord_z, _coord_b);
- const std::string ls_buf0 = to_ls_buffer(_type, p, dst0, address);
- _leftovers_x.push_back(std::make_pair(std::make_pair(dst0, coord_y), ls_buf0));
-
- w += p;
- }
- }
- }
-
- void finalize() override
- {
- out_of_bound_finalize_z();
- out_of_bound_finalize_x();
- }
-
-private:
- IVectorTile *_dst{nullptr};
- int32_t _ls_width_full{0};
- std::vector<int32_t> _ls_width_part{};
- std::vector<std::pair<std::pair<std::string, std::string>, std::string>> _leftovers_x{};
- std::string _coord_x{};
- std::string _coord_z{};
- std::string _coord_orig_z{};
- std::string _coord_b{};
-
- void out_of_bound_initialize_x(std::string &coord)
- {
- if (_mapper.gpu_sampler().address_mode_x == TensorSamplerAddressModeX::OverlappingMin)
- {
- auto tensor_format = _mapper.tensor_argument()->format();
- auto shape = tensor_format.shape;
-
- _ls_width_part = decompose_leftover_ls_vector_width(shape[0] % _ls_width_full);
- if (_ls_width_part.size() != 0)
- {
- _writer->write_text("if(" + coord + " > 0)\n");
- _writer->compound_statement_begin();
- }
- }
- };
-
- void out_of_bound_finalize_x()
- {
- if (_mapper.gpu_sampler().address_mode_x == TensorSamplerAddressModeX::OverlappingMin)
- {
- if (_ls_width_part.size() != 0)
- {
- _writer->compound_statement_end();
- _writer->write_text("else\n");
- _writer->compound_statement_begin();
-
- out_of_bound_initialize_z(_coord_orig_z);
- for (auto &i : _leftovers_x)
- {
- out_of_bound_initialize_y(i.first.second);
- _writer->write_text(i.second);
- _writer->write_text(";\n");
- out_of_bound_finalize_y(i.first.first);
- }
- out_of_bound_finalize_z();
- _writer->compound_statement_end();
- }
- }
- };
-
- void out_of_bound_initialize_y(std::string &coord)
- {
- std::string max = "";
-
- const auto address_mode_y = _mapper.gpu_sampler().address_mode_y;
-
- switch (address_mode_y)
- {
- case TensorSamplerAddressModeY::Skip:
- case TensorSamplerAddressModeY::ClampToBorder:
- // NOTE: This line should not be moved outside of the switch statement.
- // The reason for that is because when we query the component, the component is marked as used
- // and added to the list of arguments of the kernel. Since, not in all cases this component is required,
- // we should request the component only when used
- max = _mapper.tensor_component_y();
- _writer->write_text("if((" + coord + " >= 0) && (" + coord + " < " + max + "))\n");
- _writer->compound_statement_begin();
- break;
- case TensorSamplerAddressModeY::SkipMinEdgeOnly:
- case TensorSamplerAddressModeY::ClampToBorderMinEdgeOnly:
- _writer->write_text("if(" + coord + " >= 0)\n");
- _writer->compound_statement_begin();
- break;
- case TensorSamplerAddressModeY::SkipMaxEdgeOnly:
- case TensorSamplerAddressModeY::ClampToBorderMaxEdgeOnly:
- max = _mapper.tensor_component_y();
- _writer->write_text("if(" + coord + " < " + max + ")\n");
- _writer->compound_statement_begin();
- break;
- case TensorSamplerAddressModeY::ClampToNearest:
- max = _mapper.tensor_component_y();
- coord = "clamp(" + coord + ", 0, " + max + " - 1)";
- break;
- case TensorSamplerAddressModeY::ClampToMaxEdgeOnly:
- max = _mapper.tensor_component_y();
- coord = "min(" + coord + ", " + max + " - 1)";
- break;
- case TensorSamplerAddressModeY::ClampToMinEdgeOnly:
- coord = "max(" + coord + ", 0)";
- break;
- case TensorSamplerAddressModeY::None:
- break;
- default:
- std::cout << "Unsupported address mode for write_out_of_bound_check_yz" << std::endl;
- assert(false);
- }
- };
-
- void out_of_bound_finalize_y(const std::string &dst)
- {
- const auto address_mode_y = _mapper.gpu_sampler().address_mode_y;
-
- switch (address_mode_y)
- {
- case TensorSamplerAddressModeY::ClampToBorder:
- case TensorSamplerAddressModeY::ClampToBorderMaxEdgeOnly:
- case TensorSamplerAddressModeY::ClampToBorderMinEdgeOnly:
- case TensorSamplerAddressModeY::Skip:
- case TensorSamplerAddressModeY::SkipMaxEdgeOnly:
- case TensorSamplerAddressModeY::SkipMinEdgeOnly:
- _writer->compound_statement_end();
- break;
- case TensorSamplerAddressModeY::None:
- break;
-
- default:
- assert(false);
- }
-
- switch (address_mode_y)
- {
- case TensorSamplerAddressModeY::ClampToBorder:
- case TensorSamplerAddressModeY::ClampToBorderMinEdgeOnly:
- case TensorSamplerAddressModeY::ClampToBorderMaxEdgeOnly:
- _writer->write_text("else\n");
- _writer->compound_statement_begin();
- _writer->write_text(dst);
- _writer->write_text(" = 0.0f;\n");
- _writer->compound_statement_end();
- break;
- case TensorSamplerAddressModeY::None:
- break;
-
- default:
- assert(false);
- }
- };
-
- void out_of_bound_initialize_z(std::string &coord)
- {
- std::string max = "";
-
- const auto address_mode_z = _mapper.gpu_sampler().address_mode_z;
-
- switch (address_mode_z)
- {
- case TensorSamplerAddressModeZ::Skip:
- max = _mapper.tensor_component_z();
- _writer->write_text("if((" + coord + " >= 0) && (" + coord + " < " + max + "))\n");
- _writer->compound_statement_begin();
- break;
- case TensorSamplerAddressModeZ::SkipMinEdgeOnly:
- _writer->write_text("if(" + coord + " >= 0)\n");
- _writer->compound_statement_begin();
- break;
- case TensorSamplerAddressModeZ::SkipMaxEdgeOnly:
- max = _mapper.tensor_component_z();
- _writer->write_text("if(" + coord + " < " + max + ")\n");
- _writer->compound_statement_begin();
- break;
- case TensorSamplerAddressModeZ::ClampToNearest:
- max = _mapper.tensor_component_z();
- coord = "clamp(" + coord + ", 0, " + max + " - 1)";
- break;
- case TensorSamplerAddressModeZ::ClampToMaxEdgeOnly:
- max = _mapper.tensor_component_z();
- coord = "min(" + coord + ", " + max + " - 1)";
- break;
- case TensorSamplerAddressModeZ::ClampToMinEdgeOnly:
- coord = "max(" + coord + ", 0)";
- break;
- case TensorSamplerAddressModeZ::None:
- break;
- default:
- std::cout << "Unsupported address mode for write_out_of_bound_check_yz" << std::endl;
- assert(false);
- }
- };
-
- void out_of_bound_finalize_z()
- {
- const auto address_mode_z = _mapper.gpu_sampler().address_mode_z;
-
- switch (address_mode_z)
- {
- case TensorSamplerAddressModeZ::Skip:
- case TensorSamplerAddressModeZ::SkipMinEdgeOnly:
- case TensorSamplerAddressModeZ::SkipMaxEdgeOnly:
- _writer->compound_statement_end();
- break;
- case TensorSamplerAddressModeZ::None:
- break;
-
- default:
- assert(false);
- }
- };
-
- std::vector<int32_t> decompose_leftover_ls_vector_width(int32_t ls_leftover_vector_width) const
- {
- std::vector<int32_t> x;
-
- switch (ls_leftover_vector_width)
- {
- case 0:
- break;
- case 1:
- case 2:
- case 3:
- case 4:
- case 8:
- case 16:
- x.push_back(ls_leftover_vector_width);
- break;
- case 5:
- x.push_back(4);
- x.push_back(1);
- break;
- case 6:
- x.push_back(4);
- x.push_back(2);
- break;
- case 7:
- x.push_back(4);
- x.push_back(3);
- break;
- case 9:
- x.push_back(8);
- x.push_back(1);
- break;
- case 10:
- x.push_back(8);
- x.push_back(2);
- break;
- case 11:
- x.push_back(8);
- x.push_back(3);
- break;
- case 12:
- x.push_back(8);
- x.push_back(4);
- break;
- case 13:
- x.push_back(8);
- x.push_back(4);
- x.push_back(1);
- break;
- case 14:
- x.push_back(8);
- x.push_back(4);
- x.push_back(2);
- break;
- case 15:
- x.push_back(8);
- x.push_back(4);
- x.push_back(3);
- break;
-
- default:
- assert(false);
- }
- return x;
- }
-
- std::string
- to_ls_buffer(GpuLoadStoreType type, int32_t vector_width, const std::string &data, const std::string &address)
- {
- switch (type)
- {
- case GpuLoadStoreType::Load:
- if (vector_width != 1)
- {
- return data + " = vload" + std::to_string(vector_width) + "(0, " + address + ")";
- }
- else
- {
- return data + " = *(" + address + ")";
- }
- break;
- case GpuLoadStoreType::Store:
- if (vector_width != 1)
- {
- return "vstore" + std::to_string(vector_width) + "(" + data + ", 0, " + address + ")";
- }
- else
- {
- return "*(" + address + ") = " + data;
- }
- break;
- default:
- std::cout << "Unsupported GpuLoadStoreType" << std::endl;
- assert(false);
- return "";
- }
- }
-
- std::string
- to_ls_buffer_address(const std::string &x, const std::string &y, const std::string &z, const std::string &b) const
- {
- auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
- assert(tensor_storage == GpuTensorStorage::BufferUint8Ptr);
- const std::string ptr_buf = _mapper.tensor_argument()->storage(tensor_storage);
- const std::string dst_type = get_cl_data_type(_dst->format().dt, 1);
-
- std::string address;
- address += "(__global ";
- address += dst_type;
- address += "*)(";
- address += ptr_buf;
- if (x != "0" && (_mapper.is_one_component_x() != true))
- {
- address += " + (";
- address += x + ") * sizeof(" + dst_type + ")";
- }
- if (y != "0")
- {
- const std::string stride_y = _mapper.tensor_component_stride_y();
- address += " + (";
- address += y + ")";
- address += " * ";
- address += stride_y;
- }
- if (z != "0")
- {
- const std::string stride_z = _mapper.tensor_component_stride_z();
- address += " + (";
- address += z + ")";
- address += " * ";
- address += stride_z;
- }
- if (b != "0" && (_mapper.is_one_component_batch() != true))
- {
- const std::string stride_b = _mapper.tensor_component_stride_batch();
- address += " + (";
- address += b + ")";
- address += " * ";
- address += stride_b;
- }
- address += ")";
- return address;
- }
-};
-
-class ClLoadStoreImage2dHelperWriter : public IGpuLoadStoreHelperWriter
-{
-public:
- static bool validate(IGpuKernelWriter *x, const GpuTensor3dMapper &mapper, GpuLoadStoreType type, IVectorTile *dst)
- {
- CKW_UNUSED(x);
-
- if (dst->format().w != 4)
- {
- return false;
- }
- if (mapper.gpu_sampler().address_mode_x != TensorSamplerAddressModeX::None)
- {
- return false;
- }
- if (mapper.gpu_sampler().address_mode_z != TensorSamplerAddressModeZ::None)
- {
- return false;
- }
- if (mapper.gpu_sampler().storage != GpuSamplerTensorStorage::Image2dReadOnly && type == GpuLoadStoreType::Load)
- {
- return false;
- }
- if (mapper.gpu_sampler().storage != GpuSamplerTensorStorage::Image2dWriteOnly &&
- type == GpuLoadStoreType::Store)
- {
- return false;
- }
- if ((dst->format().dt != DataType::Fp32) && (dst->format().dt != DataType::Fp16))
- {
- return false;
- }
- return true;
- /*
- - x: Only GpuSamplerAddressModeX::None is supported and vector length = 4
- - z: Only GpuSamplerAddressModeZ::None is supported
- */
- }
-
- ClLoadStoreImage2dHelperWriter(IGpuKernelWriter *x, const GpuTensor3dMapper &mapper, GpuLoadStoreType type)
- : IGpuLoadStoreHelperWriter(x, mapper, type)
- {
- }
-
- ClLoadStoreImage2dHelperWriter(const ClLoadStoreImage2dHelperWriter &) = default;
-
- ClLoadStoreImage2dHelperWriter &operator=(const ClLoadStoreImage2dHelperWriter &) = default;
-
- void initialize(IVectorTile *dst, IVectorTile *x, IVectorTile *z, IVectorTile *b) override
- {
- assert(validate(_writer, _mapper, _type, dst));
-
- _dst = dst;
- _ls_width_full = dst->format().w;
- _coord_x = x->scalar(0, 0).str;
- _coord_z = z->scalar(0, 0).str;
- _coord_b = b->scalar(0, 0).str;
-
- /*
- if(y)
- {
- // full load/store width
- }
- else
- {
- // no load/store
- }
- */
- }
-
- void write(const std::pair<int32_t, std::string> &y) override
- {
- int32_t idx_y = y.first;
- std::string coord_y = y.second;
-
- // The only check required is on Y.
- out_of_bound_initialize_y(coord_y);
-
- const std::string dst = _dst->vector(idx_y).str;
- const std::string sampler = to_ls_image2d_sampler();
- const std::string coord = to_ls_image2d_coord(_coord_x, coord_y, _coord_z, _coord_b);
- const std::string ls_buf = to_ls_image2d(_type, _ls_width_full, dst, sampler, coord);
-
- _writer->write_text(ls_buf);
- _writer->write_text(";\n");
-
- out_of_bound_finalize_y(dst);
- }
-
- void finalize() override
- {
- }
-
-private:
- IVectorTile *_dst{nullptr};
- int32_t _ls_width_full{0};
- std::string _coord_x{};
- std::string _coord_z{};
- std::string _coord_b{};
-
- void out_of_bound_initialize_y(std::string &coord)
- {
- std::string max = "";
-
- const auto address_mode_y = _mapper.gpu_sampler().address_mode_y;
-
- switch (address_mode_y)
- {
- case TensorSamplerAddressModeY::Skip:
- max = _mapper.tensor_component_y();
- _writer->write_text("if((" + coord + " >= 0) && (" + coord + " < " + max + "))\n");
- _writer->compound_statement_begin();
- break;
- case TensorSamplerAddressModeY::SkipMinEdgeOnly:
- _writer->write_text("if(" + coord + " >= 0)\n");
- _writer->compound_statement_begin();
- break;
- case TensorSamplerAddressModeY::SkipMaxEdgeOnly:
- max = _mapper.tensor_component_y();
- _writer->write_text("if(" + coord + " < " + max + ")\n");
- _writer->compound_statement_begin();
- break;
- case TensorSamplerAddressModeY::ClampToBorder:
- case TensorSamplerAddressModeY::ClampToBorderMinEdgeOnly:
- case TensorSamplerAddressModeY::ClampToBorderMaxEdgeOnly:
- case TensorSamplerAddressModeY::ClampToNearest:
- case TensorSamplerAddressModeY::ClampToMaxEdgeOnly:
- case TensorSamplerAddressModeY::ClampToMinEdgeOnly:
- case TensorSamplerAddressModeY::None:
- break;
- default:
- std::cout << "Unsupported address mode for write_out_of_bound_check_y" << std::endl;
- assert(false);
- }
- };
-
- void out_of_bound_finalize_y(const std::string &dst)
- {
- CKW_UNUSED(dst);
-
- const auto address_mode_y = _mapper.gpu_sampler().address_mode_y;
-
- switch (address_mode_y)
- {
- case TensorSamplerAddressModeY::Skip:
- case TensorSamplerAddressModeY::SkipMinEdgeOnly:
- case TensorSamplerAddressModeY::SkipMaxEdgeOnly:
- _writer->compound_statement_end();
- break;
-
- default:
- assert(false);
- }
- };
-
- std::string to_ls_image2d(GpuLoadStoreType type,
- int32_t vector_width,
- const std::string &data,
- const std::string &sampler,
- const std::string &coord)
- {
- CKW_UNUSED(vector_width);
-
- auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
- const std::string image2d_obj = _mapper.tensor_argument()->storage(tensor_storage);
- const std::string post_fix = _dst->format().dt == DataType::Fp32 ? "f" : "h";
-
- switch (type)
- {
- case GpuLoadStoreType::Load:
- return data + " = read_image" + post_fix + "(" + image2d_obj + ", " + sampler + ", " + coord + ")";
- break;
- case GpuLoadStoreType::Store:
- return "write_image" + post_fix + "(" + image2d_obj + ", " + coord + ", " + data + ")";
- default:
- assert(false);
- std::cout << "Unsupported GpuLoadStoreType" << std::endl;
- assert(false);
- return "";
- }
- }
-
- std::string to_ls_image2d_sampler() const
- {
- const auto address_mode_y = _mapper.gpu_sampler().address_mode_y;
-
- switch (address_mode_y)
- {
- case TensorSamplerAddressModeY::None:
- return "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST";
- case TensorSamplerAddressModeY::Skip:
- case TensorSamplerAddressModeY::SkipMinEdgeOnly:
- case TensorSamplerAddressModeY::SkipMaxEdgeOnly:
- case TensorSamplerAddressModeY::ClampToBorder:
- case TensorSamplerAddressModeY::ClampToBorderMinEdgeOnly:
- case TensorSamplerAddressModeY::ClampToBorderMaxEdgeOnly:
- return "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST";
- case TensorSamplerAddressModeY::ClampToNearest:
- case TensorSamplerAddressModeY::ClampToMaxEdgeOnly:
- case TensorSamplerAddressModeY::ClampToMinEdgeOnly:
- return "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST";
- default:
- std::cout << "Unsupported address_mode_coord" << std::endl;
- assert(false);
- return "";
- }
- }
-
- std::string
- to_ls_image2d_coord(const std::string &x, const std::string &y, const std::string &z, const std::string &b) const
- {
- std::string coord_x = "(" + x + ") >> 2";
- std::string coord_y = "(";
-
- if (y != "0")
- {
- coord_y += y;
- }
- if (z != "0" && (_mapper.is_one_component_z() != true))
- {
- const std::string dim = _mapper.tensor_component_y();
- coord_y += " + (";
- coord_y += z + ")";
- coord_y += " * ";
- coord_y += dim;
- }
- if (b != "0" && (_mapper.is_one_component_batch() != true))
- {
- const std::string dim0 = _mapper.tensor_component_y();
- const std::string dim1 = _mapper.tensor_component_z();
- coord_y += " + (";
- coord_y += b + ")";
- coord_y += " * ";
- coord_y += dim0;
- coord_y += " * ";
- coord_y += dim1;
- }
- coord_y += ")";
- return "(int2)(" + coord_x + ", " + coord_y + ")";
- }
-};
-
-/** IGpuLoadStoreHelperWriter factory class */
-class ClLoadStoreHelperWriterFactory final
-{
-public:
- /** Static method to call the IGpuLoadStoreHelperWriter class accordingly with the tensor storage set in the mapper
- *
- *
- * @return IGpuLoadStoreHelperWriter
- */
- static std::unique_ptr<IGpuLoadStoreHelperWriter>
- create(IGpuKernelWriter *x, const GpuTensor3dMapper &mapper, GpuLoadStoreType type)
- {
- const auto tensor_storage = mapper.gpu_sampler().storage;
- switch (tensor_storage)
- {
- case GpuSamplerTensorStorage::BufferUint8Ptr:
- return std::make_unique<ClLoadStoreBufferHelperWriter>(x, mapper, type);
- case GpuSamplerTensorStorage::Image2dReadOnly:
- case GpuSamplerTensorStorage::Image2dWriteOnly:
- return std::make_unique<ClLoadStoreImage2dHelperWriter>(x, mapper, type);
- default:
- std::cout << "Unsupported Gpu tensor storage" << std::endl;
- assert(false);
- return nullptr;
- }
- }
-};
-
-// This utility method needs to go in utils.h
-inline bool is_tile_scalar(const IVectorTile *x)
-{
- return x->format().w == 1 && x->format().h == 1;
-}
-
-class ClKernelWriter : public IGpuKernelWriter
-{
-public:
- ClKernelWriter(GpuKernelWriterAttribute *attr, GpuKernelWriterDataHolder *x)
- {
- _data = x;
- _attr = attr;
- }
-
- ClKernelWriter(const ClKernelWriter &) = default;
-
- ClKernelWriter &operator=(const ClKernelWriter &) = default;
-
- // A IdSpaced ID is a term used to describe a fragment that is registered in ICode to ensure
- // there are no conflicts or ambiguity in the code
- void set_IdSpace(int32_t id) override
- {
- _data->tiles.set_IdSpace(id);
- _data->arguments.set_IdSpace(id);
- }
-
- void import_tile(const std::string &dst_name, const IVectorTile *src) override
- {
- _data->tiles.insert(dst_name, src);
- }
-
- void declare_argument(const std::string &name, const TensorInfo &tensor) override
- {
- assert(_data->arguments[name] == nullptr);
- _data->arguments.insert(name, tensor, _attr->return_tensor_component_by_value);
- }
-
- void declare_tile(const std::string &name, const TileInfo &format) override
- {
- assert(_data->tiles[name] == nullptr);
- _data->tiles.insert(name, format);
-
- IVectorTile *x = _data->tiles[name];
-
- for (auto &t : x->underlying_source_variables())
- {
- _data->code += t.type.str + " " + t.str + ";\n";
- }
- }
-
- void
- declare_const_tile(const std::string &name, const std::vector<std::vector<std::string>> &in, DataType dt) override
- {
- assert(_data->tiles[name] == nullptr);
- _data->tiles.insert(name, in, dt);
- // Note: A constant does not need to be declared in the code
- }
-
- void write_text(const std::string &x) override
- {
- _data->code += x;
- }
-
- void compound_statement_begin() override
- {
- _data->tiles.increment_registry_level();
- _data->code += "{\n";
- }
-
- void compound_statement_end() override
- {
- _data->tiles.decrement_registry_level();
- _data->code += "}\n";
- }
-
- void op_get_global_id(const Operand &dst_var, int32_t dim) override
- {
- assert(dst_var.type() == OperandType::Tile);
- assert(_data->tiles.has_tile(dst_var.value()));
- assert(_data->tiles[dst_var.value()]->format().w == 1 &&
- _data->tiles[dst_var.value()]->format().h == 1); // It must be a scalar variable
-
- auto var = _data->tiles[dst_var.value()];
-
- _data->code += var->scalar(0, 0).str;
- _data->code += " = get_global_id(";
- _data->code += std::to_string(dim);
- _data->code += ");\n";
- };
-
- void op_get_global_coord(const Operand &o_dst,
- const Operand &o_step,
- const TensorOperand &o_tensor,
- int32_t dim) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto dst = operands.unpack(o_dst);
- auto step = operands.unpack(o_step);
-
- // Validation: Check that x, y and z are scalar
-
- TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(o_tensor);
- auto gpu_sampler = o_tensor.sampler();
-
- GpuTensor3dMapper mapper(tensor, gpu_sampler);
-
- switch (dim)
- {
- case 0:
- if (mapper.is_one_component_x())
- {
- _data->code += dst->scalar(0, 0).str;
- _data->code += " = 0;\n";
- }
- else
- {
- if (mapper.gpu_sampler().address_mode_x == TensorSamplerAddressModeX::OverlappingMin)
- {
- // Validation: Check: fixed tensor shape
- // TO BE CHANGED
- _data->code += dst->scalar(0, 0).str;
- _data->code += " = get_global_id(0) * ";
- _data->code += step->scalar(0, 0).str;
- _data->code += ";\n";
- }
- else
- {
- _data->code += dst->scalar(0, 0).str;
- _data->code += " = get_global_id(0) * ";
- _data->code += step->scalar(0, 0).str;
- _data->code += ";\n";
- }
- }
- break;
- case 1:
- if (mapper.is_one_component_y())
- {
- _data->code += dst->scalar(0, 0).str;
- _data->code += " = 0;\n";
- }
- else
- {
- if (mapper.gpu_sampler().address_mode_y == TensorSamplerAddressModeY::OverlappingMin)
- {
- }
- else
- {
- _data->code += dst->scalar(0, 0).str;
- _data->code += " = get_global_id(1) * ";
- _data->code += step->scalar(0, 0).str;
- _data->code += ";\n";
- }
- }
- break;
- case 2:
- if (mapper.is_one_component_z())
- {
- _data->code += dst->scalar(0, 0).str;
- _data->code += " = 0;\n";
- }
- else
- {
- _data->code += dst->scalar(0, 0).str;
- _data->code += " = get_global_id(2) * ";
- _data->code += step->scalar(0, 0).str;
- _data->code += ";\n";
- }
- break;
- default:
- break;
- }
- };
-
- void op_get_global_batch(const Operand &o_dst, const TensorOperand &o_tensor) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *dst = operands.unpack(o_dst);
-
- TensorOperandUnpacker tensor_operands(_data->arguments);
- IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
- auto gpu_sampler = o_tensor.sampler();
-
- GpuTensor3dMapper mapper(tensor, gpu_sampler);
-
- if (mapper.is_one_component_batch())
- {
- _data->code += dst->scalar(0, 0).str;
- _data->code += " = 0;\n";
- }
- else
- {
- std::cout << "Unsupported batched computation" << std::endl;
- assert(false);
- }
- };
-
- void op_get_global_size(const Operand &dst_var, int32_t dim) override
- {
- assert(dst_var.type() == OperandType::Tile);
- assert(_data->tiles.has_tile(dst_var.value()));
- assert(_data->tiles[dst_var.value()]->format().w == 1 &&
- _data->tiles[dst_var.value()]->format().h == 1); // It must be a scalar variable
-
- auto var = _data->tiles[dst_var.value()];
-
- _data->code += var->scalar(0, 0).str;
- _data->code += " = get_global_size(";
- _data->code += std::to_string(dim);
- _data->code += ");\n";
- }
-
- void op_unary_expression(const Operand &dst_name, UnaryOp op, const Operand &src_name) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *src = operands.unpack(src_name);
- const IVectorTile *dst = operands.unpack(dst_name);
-
- const int32_t dst_w = dst->format().w;
- const int32_t dst_h = dst->format().h;
- const int32_t src_w = src->format().w;
- const std::string dt = dst->underlying_source_variables()[0].type.str;
-
- const bool broadcast_src_x = dst_w != 1 && src_w == 1;
-
- const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
-
- // Broadcasting on Y is automatic
- for (int32_t y = 0; y < dst_h; ++y)
- {
- _data->code += dst->vector(y).str;
- _data->code += " = ";
- _data->code += to_string(op);
- _data->code += src_prefix + src->vector(y).str;
- _data->code += ";\n";
- }
- }
-
- void op_binary_expression(const Operand &dst_name,
- const Operand &lhs_name,
- BinaryOp op,
- const Operand &rhs_name) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *lhs = operands.unpack(lhs_name);
- const IVectorTile *rhs = operands.unpack(rhs_name);
- const IVectorTile *dst = operands.unpack(dst_name);
-
- const int32_t dst_w = dst->format().w;
- const int32_t dst_h = dst->format().h;
- assert(lhs != nullptr);
- const int32_t lhs_w = lhs->format().w;
- const int32_t rhs_w = rhs->format().w;
-
- if (op == BinaryOp::MatMul_Nt_T)
- {
- assert((dst->format().dt == DataType::Fp32) || (dst->format().dt == DataType::Fp16));
- for (int32_t y = 0; y < dst_h; ++y)
- {
- for (int32_t x = 0; x < dst_w; ++x)
- {
- for (int32_t k = 0; k < lhs_w; ++k)
- {
- _data->code += dst->scalar(x, y).str;
- _data->code += " = fma(";
- _data->code += lhs->scalar(k, y).str;
- _data->code += ", ";
- _data->code += rhs->scalar(k, x).str;
- _data->code += ", ";
- _data->code += dst->scalar(x, y).str;
- _data->code += ");\n";
- }
- }
- }
-
- return;
- }
-
- const bool broadcast_lhs_x = dst_w != 1 && lhs_w == 1;
- const bool broadcast_rhs_x = dst_w != 1 && rhs_w == 1;
-
- const std::string lhs_prefix =
- broadcast_lhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
- const std::string rhs_prefix =
- broadcast_rhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
- const std::string op_str = to_string(op);
-
- // Broadcasting on Y is automatic
- for (int32_t y = 0; y < dst_h; ++y)
- {
- _data->code += dst->vector(y).str;
- _data->code += " = ";
- _data->code += lhs_prefix + lhs->vector(y).str;
- _data->code += " ";
- _data->code += op_str;
- _data->code += " ";
- _data->code += rhs_prefix + rhs->vector(y).str;
- _data->code += ";\n";
- }
- };
-
- void op_cast_expression(const Operand &o_dst, const Operand &o_src, ConvertPolicy policy) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *src = operands.unpack(o_src);
- const IVectorTile *dst = operands.unpack(o_dst);
- // const int32_t dst_w = dst->format().w;
- const int32_t dst_h = dst->format().h;
- const std::string dt = dst->underlying_source_variables()[0].type.str;
- const bool is_float = (dst->format().dt == DataType::Fp32) || (dst->format().dt == DataType::Fp16);
- const std::string sat = ((policy == ConvertPolicy::Saturate && !is_float) ? "_sat" : "");
-
- // Broadcasting on Y is automatic
- for (int32_t y = 0; y < dst_h; ++y)
- {
- _data->code += dst->vector(y).str;
- _data->code += " = convert_" + dt + sat + "(";
- _data->code += src->vector(y).str;
- _data->code += ");\n";
- }
- };
-
- void op_assign(const Operand &dst_name, const Operand &src_name) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *src = operands.unpack(src_name);
- const IVectorTile *dst = operands.unpack(dst_name);
-
- const int32_t dst_w = dst->format().w;
- const int32_t dst_h = dst->format().h;
- const int32_t src_w = src->format().w;
- const std::string dt = dst->underlying_source_variables()[0].type.str;
-
- const bool broadcast_src_x = dst_w != 1 && src_w == 1;
-
- const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
-
- // Broadcasting on Y is automatic
- for (int32_t y = 0; y < dst_h; ++y)
- {
- _data->code += dst->vector(y).str;
- _data->code += " = ";
- _data->code += src_prefix + src->vector(y).str;
- _data->code += ";\n";
- }
- }
-
- void op_unary_elementwise_function(const Operand &dst_name, UnaryFunction func, const Operand &src_name) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *src = operands.unpack(src_name);
- const IVectorTile *dst = operands.unpack(dst_name);
-
- const int32_t dst_h = dst->format().h;
- const std::string dt = dst->underlying_source_variables()[0].type.str;
-
- // Always perform an explicit cast. This automatically covers at least the 2 scenarios:
- // 1. Widen a scalar into a vector type. This enables scalar-vector broadcasting
- // 2. Ensure non-ambiguity over function overloads.
- // E.g. a constant tile may be accidentally initialized with a double literal. By casting it to single float,
- // it avoids ambiguous function calls
- const std::string src_prefix = "(" + dt + ")";
-
- // Broadcasting on Y is automatic
- for (int32_t y = 0; y < dst_h; ++y)
- {
- _data->code += dst->vector(y).str;
- _data->code += " = ";
-
- switch (func)
- {
- case UnaryFunction::Exp:
- _data->code += "exp(";
- break;
- case UnaryFunction::Tanh:
- _data->code += "tanh(";
- break;
- case UnaryFunction::Sqrt:
- _data->code += "sqrt(";
- break;
- case UnaryFunction::Erf:
- _data->code += "erf(";
- break;
- case UnaryFunction::Fabs:
- _data->code += "fabs(";
- break;
- case UnaryFunction::Log:
- _data->code += "log(";
- break;
- case UnaryFunction::SizeOf:
- _data->code += "sizeof(";
- break;
- case UnaryFunction::Round:
- _data->code += "round(";
- break;
- case UnaryFunction::Floor:
- _data->code += "floor(";
- break;
- default:
- CKW_ASSERT_MSG(false, "Unexpected UnaryFunction used.");
- }
-
- _data->code += src_prefix + src->vector(y).str;
- _data->code += ");\n";
- }
- }
-
- void op_binary_elementwise_function(const Operand &dst_name,
- BinaryFunction func,
- const Operand &first_name,
- const Operand &second_name) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *first = operands.unpack(first_name);
- const IVectorTile *second = operands.unpack(second_name);
- const IVectorTile *dst = operands.unpack(dst_name);
-
- const int32_t dst_h = dst->format().h;
- const auto datatype = dst->underlying_source_variables()[0].type;
- const std::string datatype_str = datatype.str;
-
- // Always perform an explicit cast. See similar comments in op_unary_elementwise_function
- const std::string first_prefix = "(" + datatype_str + ")";
- const std::string second_prefix = "(" + datatype_str + ")";
-
- const bool is_float = (datatype.dt == DataType::Fp32 || datatype.dt == DataType::Fp16);
-
- // Broadcasting on Y is automatic
- for (int32_t y = 0; y < dst_h; ++y)
- {
- _data->code += dst->vector(y).str;
- _data->code += " = ";
-
- switch (func)
- {
- case BinaryFunction::Min:
- _data->code += is_float ? "fmin(" : "min(";
- break;
- case BinaryFunction::Max:
- _data->code += is_float ? "fmax(" : "max(";
- break;
- default:
- CKW_ASSERT_MSG(false, "Unexpected BinaryFunction used.");
- }
-
- _data->code += first_prefix + first->vector(y).str;
- _data->code += ", ";
- _data->code += second_prefix + second->vector(y).str;
- _data->code += ");\n";
- }
- }
-
- void op_ternary_elementwise_function(const Operand &dst_name,
- TernaryFunction func,
- const Operand &first_name,
- const Operand &second_name,
- const Operand &third_name) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *first = operands.unpack(first_name);
- const IVectorTile *second = operands.unpack(second_name);
- const IVectorTile *third = operands.unpack(third_name);
- const IVectorTile *dst = operands.unpack(dst_name);
-
- const int32_t dst_h = dst->format().h;
- const std::string dt = dst->underlying_source_variables()[0].type.str;
-
- // Always perform an explicit cast. See similar comments in op_unary_elementwise_function
- const std::string first_prefix = "(" + dt + ")";
- const std::string second_prefix = "(" + dt + ")";
- const std::string third_prefix = "(" + dt + ")";
-
- // Broadcasting on Y is automatic
- for (int32_t y = 0; y < dst_h; ++y)
- {
- _data->code += dst->vector(y).str;
- _data->code += " = ";
-
- switch (func)
- {
- case TernaryFunction::Select:
- _data->code += "select(";
- break;
- case TernaryFunction::Clamp:
- _data->code += "clamp(";
- break;
- default:
- CKW_ASSERT_MSG(false, "Unexpected TernaryFunction used.");
- }
-
- _data->code += first_prefix + first->vector(y).str;
- _data->code += ", ";
- _data->code += second_prefix + second->vector(y).str;
- _data->code += ", ";
- _data->code += third_prefix + third->vector(y).str;
- _data->code += ");\n";
- }
- }
-
- void op_if_header(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *lhs = operands.unpack(o_lhs);
- const IVectorTile *rhs = operands.unpack(o_rhs);
-
- assert(is_tile_scalar(lhs));
- assert(is_tile_scalar(rhs));
-
- _data->code += "if(";
- _data->code += lhs->scalar(0, 0).str;
- _data->code += " ";
- _data->code += to_string(op);
- _data->code += " ";
- _data->code += rhs->scalar(0, 0).str;
- _data->code += ")\n";
- }
-
- void op_else_if_header(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
- {
- _data->code += "else ";
- op_if_header(o_lhs, op, o_rhs);
- }
-
- void op_else_header() override
- {
- _data->code += "else\n";
- }
-
- void op_for_loop_header(const Operand &var_name,
- BinaryOp cond_op,
- const Operand &cond_value_name,
- const Operand &update_var_name,
- AssignmentOp update_op,
- const Operand &update_value_name) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *var = operands.unpack(var_name);
- const IVectorTile *cond_value = operands.unpack(cond_value_name);
- const IVectorTile *update_var = operands.unpack(update_var_name);
- const IVectorTile *update_value = operands.unpack(update_value_name);
-
- const int32_t dst_w = var->format().w;
- const int32_t dst_h = var->format().h;
-
- // It must be a scalar variable
- CKW_UNUSED(dst_w, dst_h);
- assert(dst_w == 1);
- assert(dst_h == 1);
-
- _data->code += "for(; ";
- _data->code += var->scalar(0, 0).str;
- _data->code += " ";
- _data->code += to_string(cond_op);
- _data->code += " " + cond_value->scalar(0, 0).str + "; ";
- _data->code += update_var->scalar(0, 0).str;
- _data->code += " ";
- _data->code += to_string(update_op);
- _data->code += " " + update_value->scalar(0, 0).str + ")";
- _data->code += "\n";
- }
-
- void op_load_immediate(const TensorOperand &o_tensor,
- const Operand &o_dst,
- const Operand &o_x,
- const Operand &o_y,
- const Operand &o_z,
- const Operand &o_batch_idx,
- const Operand &dilation_y) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
-
- // Not const as it requires changes to 'load_writer'.
- IVectorTile *dst = operands.unpack(o_dst);
- IVectorTile *x = operands.unpack(o_x);
- IVectorTile *y = operands.unpack(o_y);
- IVectorTile *z = operands.unpack(o_z);
- IVectorTile *dil_y = operands.unpack(dilation_y);
- IVectorTile *b = operands.unpack(o_batch_idx);
-
- TensorOperandUnpacker tensor_operands(_data->arguments);
- IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
- auto gpu_sampler = o_tensor.sampler();
-
- GpuTensor3dMapper mapper(tensor, gpu_sampler);
-
- auto load_writer = ClLoadStoreHelperWriterFactory::create(this, mapper, GpuLoadStoreType::Load);
-
- // Initialize the constant part
- load_writer->initialize(dst, x, z, b);
-
- for (int i = 0; i < dst->format().h; ++i)
- {
- std::string coord_y = y->scalar(0, 0).str + " + " + std::to_string(i);
- if (dil_y->scalar(0, 0).str != "1")
- {
- coord_y += " * " + dil_y->scalar(0, 0).str;
- }
- load_writer->write(std::make_pair(i, coord_y));
- }
-
- load_writer->finalize();
- }
-
- void op_load_indirect(const TensorOperand &o_tensor,
- const Operand &o_dst,
- const Operand &o_x,
- const Operand &o_indirect_h,
- const Operand &o_z,
- const Operand &o_batch_idx) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
-
- // Not const as it requires changes to 'load_writer'.
- IVectorTile *dst = operands.unpack(o_dst);
- IVectorTile *x = operands.unpack(o_x);
- IVectorTile *y_ind = operands.unpack(o_indirect_h);
- IVectorTile *z = operands.unpack(o_z);
- IVectorTile *b = operands.unpack(o_batch_idx);
-
- TensorOperandUnpacker tensor_operands(_data->arguments);
- IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
- auto gpu_sampler = o_tensor.sampler();
-
- GpuTensor3dMapper mapper(tensor, gpu_sampler);
-
- auto load_writer = ClLoadStoreHelperWriterFactory::create(this, mapper, GpuLoadStoreType::Load);
-
- // Initialize the constant part
- load_writer->initialize(dst, x, z, b);
-
- for (int i = 0; i < dst->format().h; ++i)
- {
- load_writer->write(std::make_pair(i, y_ind->scalar(0, i).str));
- }
-
- load_writer->finalize();
- }
-
- void op_store_immediate(const TensorOperand &tensor_name,
- const Operand &src_name,
- const Operand &x_name,
- const Operand &y_name,
- const Operand &z_name,
- const Operand &batch_index_name) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
-
- // Not const as it requires changes to 'load_writer'.
- IVectorTile *src = operands.unpack(src_name);
- IVectorTile *x = operands.unpack(x_name);
- IVectorTile *y = operands.unpack(y_name);
- IVectorTile *z = operands.unpack(z_name);
- IVectorTile *b = operands.unpack(batch_index_name);
-
- TensorOperandUnpacker tensor_operands(_data->arguments);
- IGpuTensorArgument *tensor = tensor_operands.unpack(tensor_name);
- auto gpu_sampler = tensor_name.sampler();
-
- GpuTensor3dMapper mapper(tensor, gpu_sampler);
-
- auto store_writer = ClLoadStoreHelperWriterFactory::create(this, mapper, GpuLoadStoreType::Store);
-
- // Initialize the constant part
- store_writer->initialize(src, x, z, b);
-
- int32_t tile_h = src->format().h;
-
- for (int m0 = tile_h - 1; m0 >= 0; m0--)
- {
- store_writer->write(std::make_pair(m0, y->scalar(0, 0).str + " + " + std::to_string(m0)));
- }
-
- store_writer->finalize();
- }
-
- void op_return() override
- {
- _data->code += "return;\n";
- }
-
- void util_get_indirect_buffer(const Operand &o_dst,
- const TensorOperand &o_tensor,
- const Operand &o_x,
- const Operand &o_y,
- const Operand &o_x_off,
- const Operand &o_y_off) override
- {
- OperandUnpacker operands(_data->tiles, _data->arguments);
- const IVectorTile *dst = operands.unpack(o_dst);
- const IVectorTile *x = operands.unpack(o_x);
- const IVectorTile *y = operands.unpack(o_y);
- const IVectorTile *x_off = operands.unpack(o_x_off);
- const IVectorTile *y_off = operands.unpack(o_y_off);
-
- TensorOperandUnpacker tensor_operands(_data->arguments);
- IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
-
- assert(dst->format().w == 1);
- assert(x->format().w == 1);
- assert(y->format().w == 1);
- assert(x_off->format().w == 1);
- assert(y_off->format().w == 1);
- assert(dst->format().dt == DataType::Int32);
- assert(x->format().dt == DataType::Int32);
- assert(y->format().dt == DataType::Int32);
- assert(x_off->format().dt == DataType::Int32);
- assert(y_off->format().dt == DataType::Int32);
-
- const std::string width = tensor->component(TensorComponentType::Dim1);
- const std::string height = tensor->component(TensorComponentType::Dim2);
- const std::string wxh = tensor->component(TensorComponentType::Dim1xDim2);
- /*
- int x_s;
- int y_s;
- x_s = (xi_0 + x_k);
- y_s = (yi_0 + y_k);
- mi_0 = x_s + y_s * width + b * widthxheight;
- mi_0 = select(-1, mi_0, x_s >= 0);
- mi_0 = select(-1, mi_0, y_s >= 0);
- mi_0 = select(-1, mi_0, x_s < 128);
- mi_0 = select(-1, mi_0, y_s < 128);
- */
- compound_statement_begin();
- declare_tile("_x_s", TileInfo(DataType::Int32));
- declare_tile("_y_s", TileInfo(DataType::Int32));
- auto x_s = operands.unpack(Operand("_x_s"));
- auto y_s = operands.unpack(Operand("_y_s"));
- for (int i = 0; i < dst->format().h; ++i)
- {
- // x_s = (xi_0 + x_k);
- // y_s = (yi_0 + y_k);
- _data->code += x_s->scalar(0, i).str;
- _data->code += " = (";
- _data->code += x->scalar(0, i).str;
- _data->code += " + ";
- _data->code += x_off->scalar(0, i).str;
- _data->code += ");\n";
- _data->code += y_s->scalar(0, i).str;
- _data->code += " = (";
- _data->code += y->scalar(0, i).str;
- _data->code += " + ";
- _data->code += y_off->scalar(0, i).str;
- _data->code += ");\n";
- // mi_0 = x_s + y_s * width;
- _data->code += dst->scalar(0, i).str;
- _data->code += " = ";
- _data->code += x_s->scalar(0, i).str;
- _data->code += " + ";
- _data->code += y_s->scalar(0, i).str;
- _data->code += " * " + width + ";\n";
- // mi_0 = select(wxh, mi_0, x_s >= 0);
- _data->code += dst->scalar(0, i).str;
- _data->code += " = select(-1, ";
- _data->code += dst->scalar(0, i).str;
- _data->code += ", ";
- _data->code += x_s->scalar(0, i).str;
- _data->code += " >= 0);\n";
- // mi_0 = select(wxh, mi_0, x_s < width);
- _data->code += dst->scalar(0, i).str;
- _data->code += " = select(-1, ";
- _data->code += dst->scalar(0, i).str;
- _data->code += ", ";
- _data->code += x_s->scalar(0, i).str;
- _data->code += " < ";
- _data->code += width + ");\n";
- // mi_0 = select(wxh, mi_0, y_s >= 0);
- _data->code += dst->scalar(0, i).str;
- _data->code += " = select(-1, ";
- _data->code += dst->scalar(0, i).str;
- _data->code += ", ";
- _data->code += y_s->scalar(0, i).str;
- _data->code += " >= 0);\n";
- // mi_0 = select(wxh, mi_0, y_s < height);
- _data->code += dst->scalar(0, i).str;
- _data->code += " = select(-1, ";
- _data->code += dst->scalar(0, i).str;
- _data->code += ", ";
- _data->code += y_s->scalar(0, i).str;
- _data->code += " < ";
- _data->code += height + ");\n";
- }
- compound_statement_end();
- }
-
-private:
- GpuKernelWriterDataHolder *_data{nullptr};
- GpuKernelWriterAttribute *_attr{nullptr};
-};
-
-/** IGpuKernelWriter factory class */
-class GpuKernelWriterFactory final
-{
-public:
- /** Static method to call the IGpuKernelWriter class accordingly with the Gpu programming language
- *
- * @param[in] gpu GPU target
- *
- * @return IGpuKernelWriter
- */
- static std::unique_ptr<IGpuKernelWriter> create(GpuKernelWriterAttribute *attr, GpuKernelWriterDataHolder *x)
- {
- switch (x->programming_language())
- {
- case GpuTargetLanguage::OpenCL:
- return std::make_unique<ClKernelWriter>(attr, x);
- default:
- std::cout << "Unsupported Gpu programming language" << std::endl;
- assert(false);
- return nullptr;
- }
- }
-};
-
-inline int32_t
-adjust_step(TensorSamplerFormat tensor_format, int32_t step, const TensorInfo *tensor_info_id, int32_t idx)
-{
- auto tensor = tensor_info_id->shape;
-
- int32_t dim[3] = {0};
-
- switch (tensor_format)
- {
- case TensorSamplerFormat::C_W_H:
- dim[0] = tensor[0];
- dim[1] = tensor[1];
- dim[2] = tensor[2];
- break;
- case TensorSamplerFormat::C_WH_1:
- dim[0] = tensor[0];
- dim[1] = tensor[1] * tensor[2];
- dim[2] = 1;
- break;
- default:
- std::cout << "Unsupported tensor format" << std::endl;
- assert(false);
- break;
- }
-
- return std::min(step, dim[idx]);
-}
-
-} // namespace prototype
-} // namespace ckw
-
-#endif // CKW_PROTOTYPE_SRC_PROTOTYPE_H
diff --git a/compute_kernel_writer/prototype/src/TensorInfo.cpp b/compute_kernel_writer/prototype/src/TensorInfo.cpp
deleted file mode 100644
index 561c12646..000000000
--- a/compute_kernel_writer/prototype/src/TensorInfo.cpp
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ckw/TensorInfo.h"
-
-namespace ckw
-{
-TensorInfo::TensorInfo(DataType dt, const TensorShape &shape, TensorDataLayout dl, int32_t id)
- : _shape(shape), _dt(dt), _dl(dl), _id(id)
-{
-}
-
-TensorInfo &TensorInfo::shape(const TensorShape &shape)
-{
- _shape = shape;
- return *this;
-}
-
-TensorShape TensorInfo::shape() const
-{
- return _shape;
-}
-
-TensorInfo &TensorInfo::data_type(DataType dt)
-{
- _dt = dt;
- return *this;
-}
-
-DataType TensorInfo::data_type() const
-{
- return _dt;
-}
-
-TensorInfo &TensorInfo::data_layout(TensorDataLayout dl)
-{
- _dl = dl;
- return *this;
-}
-
-TensorDataLayout TensorInfo::data_layout() const
-{
- return _dl;
-}
-
-TensorInfo &TensorInfo::id(int32_t id)
-{
- _id = id;
- return *this;
-}
-
-int32_t TensorInfo::id() const
-{
- return _id;
-}
-} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/TensorOperand.cpp b/compute_kernel_writer/prototype/src/TensorOperand.cpp
deleted file mode 100644
index d1aefbbb7..000000000
--- a/compute_kernel_writer/prototype/src/TensorOperand.cpp
+++ /dev/null
@@ -1,272 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ckw/TensorOperand.h"
-
-#include "ckw/Error.h"
-#include "ckw/Kernel.h"
-#include "ckw/TensorInfo.h"
-#include "ckw/TileOperand.h"
-
-#include "src/Prototype.h"
-
-namespace ckw
-{
-
-namespace
-{
-
-TensorComponentOperand &get_or_create_component(TensorOperand &tensor,
- std::unique_ptr<TensorComponentOperand> &ptr,
- TensorComponentType component)
-{
- if (ptr == nullptr)
- {
- ptr = std::make_unique<TensorComponentOperand>(tensor, component);
- }
-
- return *ptr;
-}
-
-} // namespace
-
-// =================================================================================================
-// TensorOperand
-// =================================================================================================
-
-TensorOperand::TensorOperand(const std::string &name, const TensorInfo &info, TensorStorageType storage_type)
- : OperandBase(name), _info(info), _storage_type(storage_type)
-{
-}
-
-prototype::Operand TensorOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
-{
- CKW_UNUSED(writer);
- return {name()};
-}
-
-const TensorInfo &TensorOperand::info() const
-{
- return _info;
-}
-
-TensorInfo &TensorOperand::info()
-{
- return _info;
-}
-
-TensorStorageType TensorOperand::storage_type() const
-{
- return _storage_type;
-}
-
-DataType TensorOperand::data_type() const
-{
- return _info.data_type();
-}
-
-bool TensorOperand::is_constant() const
-{
- return false;
-}
-
-const TileOperand &TensorOperand::tile() const
-{
- return *_tile;
-}
-
-TileOperand &TensorOperand::tile()
-{
- return *_tile;
-}
-
-TensorOperand &TensorOperand::tile(TileOperand &tile)
-{
- _tile = &tile;
- return *this;
-}
-
-const TensorTileSampler &TensorOperand::tile_sampler() const
-{
- return _tile_sampler;
-}
-
-TensorTileSampler &TensorOperand::tile_sampler()
-{
- return _tile_sampler;
-}
-
-TensorOperand &TensorOperand::tile_sampler(const TensorTileSampler &value)
-{
- _tile_sampler = value;
- return *this;
-}
-
-TensorComponentOperand &TensorOperand::stride1()
-{
- return get_or_create_component(*this, _stride1, TensorComponentType::Stride1);
-}
-
-TensorComponentOperand &TensorOperand::stride2()
-{
- return get_or_create_component(*this, _stride2, TensorComponentType::Stride2);
-}
-
-TensorComponentOperand &TensorOperand::stride3()
-{
- return get_or_create_component(*this, _stride3, TensorComponentType::Stride3);
-}
-
-TensorComponentOperand &TensorOperand::stride4()
-{
- return get_or_create_component(*this, _stride4, TensorComponentType::Stride4);
-}
-
-TensorComponentOperand &TensorOperand::dim0()
-{
- return get_or_create_component(*this, _dim0, TensorComponentType::Dim0);
-}
-
-TensorComponentOperand &TensorOperand::dim1()
-{
- return get_or_create_component(*this, _dim1, TensorComponentType::Dim1);
-}
-
-TensorComponentOperand &TensorOperand::dim2()
-{
- return get_or_create_component(*this, _dim2, TensorComponentType::Dim2);
-}
-
-TensorComponentOperand &TensorOperand::dim3()
-{
- return get_or_create_component(*this, _dim3, TensorComponentType::Dim3);
-}
-
-TensorComponentOperand &TensorOperand::dim4()
-{
- return get_or_create_component(*this, _dim4, TensorComponentType::Dim4);
-}
-
-TensorComponentOperand &TensorOperand::dim1_dim2()
-{
- return get_or_create_component(*this, _dim1_dim2, TensorComponentType::Dim1xDim2);
-}
-
-TensorComponentOperand &TensorOperand::dim1_dim2_dim3()
-{
- return get_or_create_component(*this, _dim1_dim2_dim3, TensorComponentType::Dim1xDim2xDim3);
-}
-
-TensorComponentOperand &TensorOperand::offset_first_element_in_bytes()
-{
- return get_or_create_component(*this, _offset_first_element_in_bytes, TensorComponentType::OffsetFirstElement);
-}
-
-// =================================================================================================
-// TensorComponentOperand
-// =================================================================================================
-
-TensorComponentOperand::TensorComponentOperand(TensorOperand &tensor, TensorComponentType component)
- : TileOperand(tensor.name(), DataType::Int32), _tensor(tensor), _component(component)
-{
-}
-
-TensorOperand &TensorComponentOperand::tensor()
-{
- return _tensor;
-}
-
-const TensorOperand &TensorComponentOperand::tensor() const
-{
- return _tensor;
-}
-
-TensorComponentType TensorComponentOperand::component_type() const
-{
- return _component;
-}
-
-prototype::Operand TensorComponentOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
-{
- CKW_UNUSED(writer);
- prototype::OperandType type{prototype::OperandType::Unknown};
-
- switch (_component)
- {
- case TensorComponentType::OffsetFirstElement:
- type = prototype::OperandType::TensorDataOffset;
- break;
-
- case TensorComponentType::Stride1:
- type = prototype::OperandType::TensorStride1;
- break;
-
- case TensorComponentType::Stride2:
- type = prototype::OperandType::TensorStride2;
- break;
-
- case TensorComponentType::Stride3:
- type = prototype::OperandType::TensorStride3;
- break;
-
- case TensorComponentType::Stride4:
- type = prototype::OperandType::TensorStride4;
- break;
-
- case TensorComponentType::Dim0:
- type = prototype::OperandType::TensorDim0;
- break;
-
- case TensorComponentType::Dim1:
- type = prototype::OperandType::TensorDim1;
- break;
-
- case TensorComponentType::Dim2:
- type = prototype::OperandType::TensorDim2;
- break;
-
- case TensorComponentType::Dim3:
- type = prototype::OperandType::TensorDim3;
- break;
-
- case TensorComponentType::Dim4:
- type = prototype::OperandType::TensorDim4;
- break;
-
- case TensorComponentType::Dim1xDim2:
- type = prototype::OperandType::TensorDim1xDim2;
- break;
-
- case TensorComponentType::Dim1xDim2xDim3:
- type = prototype::OperandType::TensorDim1xDim2xDim3;
- break;
-
- default:
- CKW_ASSERT(false);
- }
-
- return prototype::Operand(name(), type);
-}
-
-} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/TensorTileSampler.cpp b/compute_kernel_writer/prototype/src/TensorTileSampler.cpp
deleted file mode 100644
index bf9f946ce..000000000
--- a/compute_kernel_writer/prototype/src/TensorTileSampler.cpp
+++ /dev/null
@@ -1,191 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ckw/TensorTileSampler.h"
-
-#include "ckw/TileOperand.h"
-#include "ckw/types/TensorSamplerTypes.h"
-
-namespace ckw
-{
-
-TensorTileSampler::TensorTileSampler()
-{
-}
-
-TensorTileSampler::TensorTileSampler(TileOperand &x,
- TileOperand &y,
- TileOperand &z,
- TileOperand &b,
- TensorSamplerFormat format,
- TensorSamplerAddressModeX address_mode_x,
- TensorSamplerAddressModeY address_mode_y,
- TensorSamplerAddressModeZ address_mode_z)
- : _x(&x),
- _y(&y),
- _z(&z),
- _b(&b),
- _height(0),
- _width(0),
- _format(format),
- _address_mode_x(address_mode_x),
- _address_mode_y(address_mode_y),
- _address_mode_z(address_mode_z)
-{
-}
-
-TensorTileSampler::TensorTileSampler(TileOperand &x,
- TileOperand &y,
- TileOperand &z,
- TileOperand &b,
- int32_t height,
- int32_t width,
- TensorSamplerFormat format,
- TensorSamplerAddressModeX address_mode_x,
- TensorSamplerAddressModeY address_mode_y,
- TensorSamplerAddressModeZ address_mode_z)
- : _x(&x),
- _y(&y),
- _z(&z),
- _b(&b),
- _height(height),
- _width(width),
- _format(format),
- _address_mode_x(address_mode_x),
- _address_mode_y(address_mode_y),
- _address_mode_z(address_mode_z)
-{
-}
-
-const TileOperand &TensorTileSampler::x() const
-{
- return *_x;
-}
-
-TensorTileSampler &TensorTileSampler::x(TileOperand &x)
-{
- _x = &x;
- return *this;
-}
-
-const TileOperand &TensorTileSampler::y() const
-{
- return *_y;
-}
-
-TensorTileSampler &TensorTileSampler::y(TileOperand &y)
-{
- _y = &y;
- return *this;
-}
-
-const TileOperand &TensorTileSampler::z() const
-{
- return *_z;
-}
-
-TensorTileSampler &TensorTileSampler::z(TileOperand &z)
-{
- _z = &z;
- return *this;
-}
-
-const TileOperand &TensorTileSampler::b() const
-{
- return *_b;
-}
-
-TensorTileSampler &TensorTileSampler::b(TileOperand &b)
-{
- _b = &b;
- return *this;
-}
-
-int32_t TensorTileSampler::width() const
-{
- return _width;
-}
-
-TensorTileSampler &TensorTileSampler::width(int32_t width)
-{
- _width = width;
- return *this;
-}
-
-int32_t TensorTileSampler::height() const
-{
- return _height;
-}
-
-TensorTileSampler &TensorTileSampler::height(int32_t height)
-{
- _height = height;
- return *this;
-}
-
-TensorSamplerFormat TensorTileSampler::format() const
-{
- return _format;
-}
-
-TensorTileSampler &TensorTileSampler::format(TensorSamplerFormat format)
-{
- _format = format;
- return *this;
-}
-
-TensorSamplerAddressModeX TensorTileSampler::address_mode_x() const
-{
- return _address_mode_x;
-}
-
-TensorTileSampler &TensorTileSampler::address_mode_x(TensorSamplerAddressModeX address_mode_x)
-{
- _address_mode_x = address_mode_x;
- return *this;
-}
-
-TensorSamplerAddressModeY TensorTileSampler::address_mode_y() const
-{
- return _address_mode_y;
-}
-
-TensorTileSampler &TensorTileSampler::address_mode_y(TensorSamplerAddressModeY address_mode_y)
-{
- _address_mode_y = address_mode_y;
- return *this;
-}
-
-TensorSamplerAddressModeZ TensorTileSampler::address_mode_z() const
-{
- return _address_mode_z;
-}
-
-TensorTileSampler &TensorTileSampler::address_mode_z(TensorSamplerAddressModeZ address_mode_z)
-{
- _address_mode_z = address_mode_z;
- return *this;
-}
-
-} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/TileInfo.cpp b/compute_kernel_writer/prototype/src/TileInfo.cpp
deleted file mode 100644
index 273266eed..000000000
--- a/compute_kernel_writer/prototype/src/TileInfo.cpp
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ckw/TileInfo.h"
-
-namespace ckw
-{
-TileInfo::TileInfo(DataType dt) : _dt(dt), _shape({{1, 1}})
-{
-}
-
-TileInfo::TileInfo(DataType dt, int32_t w) : _dt(dt), _shape({{w, 1}})
-{
-}
-
-TileInfo::TileInfo(DataType dt, int32_t h, int32_t w) : _dt(dt), _shape({{w, h}})
-{
-}
-
-TileInfo &TileInfo::width(int32_t w)
-{
- _shape[kTileWidthIdx] = w;
- return *this;
-}
-
-int32_t TileInfo::width() const
-{
- return _shape[kTileWidthIdx];
-}
-
-TileInfo &TileInfo::height(int32_t h)
-{
- _shape[kTileHeightIdx] = h;
- return *this;
-}
-
-int32_t TileInfo::height() const
-{
- return _shape[kTileHeightIdx];
-}
-
-TileInfo &TileInfo::data_type(DataType dt)
-{
- _dt = dt;
- return *this;
-}
-
-DataType TileInfo::data_type() const
-{
- return _dt;
-}
-} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/TileOperand.cpp b/compute_kernel_writer/prototype/src/TileOperand.cpp
deleted file mode 100644
index e09c833d9..000000000
--- a/compute_kernel_writer/prototype/src/TileOperand.cpp
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ckw/TileOperand.h"
-
-#include "ckw/Error.h"
-
-#include "src/Prototype.h"
-
-namespace ckw
-{
-
-TileOperand::TileOperand(const std::string &name, const TileInfo &info)
- : OperandBase(name), _info(info), _value{std::vector<std::string>{"0"}}, _constant(false)
-{
-}
-
-TileOperand::TileOperand(const std::string &name, DataType data_type)
- : OperandBase(name), _info(TileInfo{data_type}), _value{std::vector<std::string>{"0"}}, _constant(false)
-{
-}
-
-TileOperand::TileOperand(const std::string &name, int32_t value)
- : OperandBase(name),
- _info(TileInfo{DataType::Int32}),
- _value{std::vector<std::string>{std::to_string(value)}},
- _constant(true)
-{
-}
-
-TileOperand::TileOperand(const std::string &name, float value)
- : OperandBase(name),
- _info(TileInfo{DataType::Fp32}),
- _value{std::vector<std::string>{std::to_string(value)}},
- _constant(true)
-{
-}
-
-TileOperand::TileOperand(const std::string &name, const TileContainer &vals, DataType dt)
- : OperandBase(name),
- _info(TileInfo{dt, static_cast<int32_t>(vals.size()), static_cast<int32_t>(vals[0].size())}),
- _value(vals),
- _constant(true)
-{
-}
-
-prototype::Operand TileOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
-{
- CKW_UNUSED(writer);
-
- if (_constant)
- {
- if (is_scalar())
- {
- switch (_info.data_type())
- {
- case DataType::Int32:
- return prototype::Operand(_value[0][0], prototype::OperandType::ScalarInt32);
-
- case DataType::Fp32:
- return prototype::Operand(_value[0][0], prototype::OperandType::ScalarFp32);
-
- case DataType::Fp16:
- return prototype::Operand(_value[0][0], prototype::OperandType::ScalarFp16);
-
- default:
- CKW_ASSERT(false);
- }
- }
- else
- {
- return prototype::Operand(name());
- }
- }
- else
- {
- return prototype::Operand(name(), prototype::OperandType::Tile);
- }
-}
-
-const TileInfo &TileOperand::tile_info() const
-{
- return _info;
-}
-
-DataType TileOperand::data_type() const
-{
- return _info.data_type();
-}
-
-bool TileOperand::is_constant() const
-{
- return _constant;
-}
-
-bool TileOperand::is_scalar() const
-{
- return _info.width() == 1 && _info.height() == 1;
-}
-
-std::string TileOperand::scalar_value() const
-{
- CKW_ASSERT(is_scalar());
- CKW_ASSERT(is_constant());
-
- return _value[0][0];
-}
-
-const TileContainer &TileOperand::value() const
-{
- return _value;
-}
-
-} // namespace ckw
diff --git a/docs/Doxyfile b/docs/Doxyfile
index 998421a14..cca32210e 100644
--- a/docs/Doxyfile
+++ b/docs/Doxyfile
@@ -38,7 +38,7 @@ PROJECT_NAME = "Compute Library"
# could be handy for archiving the generated documentation or if some version
# control system is used.
-PROJECT_NUMBER = 24.02.1
+PROJECT_NUMBER = 24.04
# Using the PROJECT_BRIEF tag one can provide an optional one line description
# for a project that appears at the top of each page and should give viewer a
diff --git a/docs/user_guide/operator_list.dox b/docs/user_guide/operator_list.dox
index 25c856da1..e7f1823f8 100644
--- a/docs/user_guide/operator_list.dox
+++ b/docs/user_guide/operator_list.dox
@@ -1,5 +1,5 @@
///
-/// Copyright (c) 2021-2023 Arm Limited.
+/// Copyright (c) 2021-2024 Arm Limited.
///
/// SPDX-License-Identifier: MIT
///
@@ -1773,6 +1773,7 @@ where N = batches, C = channels, H = height, W = width, D = depth
<tr><td>QASYMM8_SIGNED<td>QASYMM8_SIGNED<td>S32<td>S32
<tr><td>QASYMM8_SIGNED<td>QSYMM8_PER_CHANNEL<td>S32<td>S32
<tr><td>QASYMM8_SIGNED<td>QSYMM8<td>S32<td>S32
+ <tr><td>QASYMM8_SIGNED<td>QASYMM8_SIGNED<td>F32<td>F32
</table>
<tr>
<td>CLGEMMLowpMatrixMultiplyCore
@@ -2091,6 +2092,7 @@ where N = batches, C = channels, H = height, W = width, D = depth
<tr><th>lhs<th>rhs<th>dst
<tr><td>F32<td>F32<td>F32
<tr><td>F16<td>F16<td>F16
+ <tr><td>BFLOAT16<td>BFLOAT16<td>BFLOAT16
<tr><td>QASYMM8_SIGNED<td>QASYMM8_SIGNED<td>QASYMM8_SIGNED
<tr><td>QASYMM8<td>QASYMM8<td>QASYMM8
</table>
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index 0d7c5fe37..b29b81580 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -41,6 +41,17 @@ If there is more than one release in a month then an extra sequential number is
@section S2_2_changelog Changelog
+v24.04 Public major release
+ - Add Bfloat16 data type support for @ref NEMatMul.
+ - Add support for SoftMax in SME2 for FP32 and FP16.
+ - Add support for in place accumulation to CPU GEMM kernels.
+ - Add low-precision Int8 * Int8 -> FP32 CPU GEMM which dequantizes after multiplication
+ - Add is_dynamic flag to QuantizationInfo to signal to operators that it may change after configuration
+ - Performance optimizations:
+ - Optimize start-up time of @ref NEConvolutionLayer for some input configurations where GeMM is selected as the convolution algorithm
+ - Optimize @ref NEConvolutionLayer for input tensor size > 1e7 bytes and weight tensor height > 7
+ - Optimize @ref NESoftmaxLayer for axis != 0 by natively supporting higher axes up to axis 3.
+
v24.02.1 Public patch release
- Fix performance regression in fixed-format kernels
- Fix compile and runtime errors in arm_compute_validation for Windows on Arm(WoA)
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 139b968e4..6b7fbded5 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 Arm Limited.
+# Copyright (c) 2023-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -48,6 +48,10 @@ set(EXAMPLE_GRAPH_NAMES
PARENT_SCOPE)
set(EXAMPLE_NEON_NAMES
- neon_cnn neon_copy_objects neon_gemm_qasymm8 neon_permute neon_scale
+ neon_cnn neon_copy_objects
+ neon_gemm_qasymm8
+ neon_gemm_s8_f32
+ neon_permute
+ neon_scale
neon_sgemm
PARENT_SCOPE)
diff --git a/examples/SConscript b/examples/SConscript
index 16f31d93d..8ece7e60b 100644
--- a/examples/SConscript
+++ b/examples/SConscript
@@ -1,7 +1,7 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
-# Copyright (c) 2017-2023 Arm Limited.
+# Copyright (c) 2017-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -51,11 +51,18 @@ else:
graph_dependency = [arm_compute_graph_so]
extra_link_flags = []
-if env['os'] != 'bare_metal':
+
+if not env['os'] in ['windows','bare_metal'] :
extra_link_flags += ['-fstack-protector-strong']
-load_whole_archive = '-Wl,--whole-archive'
-noload_whole_archive = '-Wl,--no-whole-archive'
+
+if env['os'] != 'windows' :
+ load_whole_archive = '-Wl,--whole-archive'
+ noload_whole_archive = '-Wl,--no-whole-archive'
+else:
+ load_whole_archive = '/wholearchive'
+ noload_whole_archive = ''
+
if 'macos' in examples_env['os']:
load_whole_archive = '-Wl,-force_load'
noload_whole_archive = ''
@@ -67,7 +74,7 @@ examples_libs = examples_env.get("LIBS",[])
for file in Glob("./graph_*.cpp"):
example = os.path.basename(os.path.splitext(str(file))[0])
prog = None
- if env['os'] in ['android', 'macos', 'bare_metal'] or env['standalone']:
+ if env['os'] in ['android','windows', 'macos', 'bare_metal'] or env['standalone']:
prog = examples_env.Program(example, ["{}.cpp".format(example), utils, graph_utils], LIBS = examples_libs + arm_compute_graph_libs, LINKFLAGS=examples_env["LINKFLAGS"]+[load_whole_archive, graph_dependency, noload_whole_archive] + extra_link_flags)
Depends(prog, graph_dependency)
prog = install_bin(prog)
diff --git a/examples/neon_gemm_s8_f32.cpp b/examples/neon_gemm_s8_f32.cpp
new file mode 100644
index 000000000..7c1497ec4
--- /dev/null
+++ b/examples/neon_gemm_s8_f32.cpp
@@ -0,0 +1,239 @@
+/*
+ * Copyright (c) 2020-2021, 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+#include "arm_compute/core/WindowIterator.h"
+#include "arm_compute/runtime/NEON/NEFunctions.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
+#include "support/ToolchainSupport.h"
+#include "utils/Utils.h"
+
+#include <cstdlib>
+
+using namespace arm_compute;
+using namespace utils;
+
+QuantizationInfo dynamic_qinfo(QuantizationInfo qinfo)
+{
+ return QuantizationInfo(qinfo.scale(), qinfo.offset(), true);
+}
+void set_qinfo_dynamic(Tensor &t)
+{
+ t.info()->set_quantization_info(dynamic_qinfo(t.info()->quantization_info()));
+}
+
+void quantize(Tensor &qt, const Tensor &t, float min, float max)
+{
+ DataType dt = DataType::QASYMM8_SIGNED;
+
+ // Determine the scale
+ const float scale = (max - min) / 256.0f;
+
+ // Determine the zero-point; using affine equation val = (qval-zerop) * scale
+ const float zero_point = -128.0f - min / scale;
+
+ QuantizationInfo qinfo(scale, (int32_t)round(zero_point), true);
+
+ // We now have the quantisation info and can configure the quantised tensor
+ qt.allocator()->init(TensorInfo(t.info()->tensor_shape(), 1, dt, qinfo));
+ qt.allocator()->allocate();
+ NEQuantizationLayer quantization;
+ quantization.configure(&t, &qt);
+ quantization.run();
+}
+
+void invert_qinfo_offset(Tensor &t)
+{
+ QuantizationInfo qinfo = t.info()->quantization_info();
+ t.info()->set_quantization_info(QuantizationInfo(qinfo.scale()[0], -qinfo.offset()[0], qinfo.is_dynamic()));
+}
+
+void print_quantization_info(const Tensor &t, const std::string &name_prefix)
+{
+ QuantizationInfo qinfo = t.info()->quantization_info();
+ std::cout << name_prefix << "_qinfo="
+ << "QuantizationInfo(" << qinfo.scale()[0] << ", " << qinfo.offset()[0] << ")\n";
+}
+
+int main(int argc, char **argv)
+{
+ size_t M = 4;
+ size_t N = 4;
+ size_t K = 4;
+
+ // Parse args
+ if (argc < 3) /* case default matrix sizes */
+ {
+ // Print help
+ std::cout << "Usage: ./build/neon_gemm_qasymm8 M N K\n";
+ std::cout << "Too few or no inputs provided. Using default M=4, N=4, K=4\n\n";
+ }
+ else /* case M N K arguments provided */
+ {
+ M = strtol(argv[1], nullptr, 10);
+ N = strtol(argv[2], nullptr, 10);
+ K = strtol(argv[3], nullptr, 10);
+ }
+
+ /*** Floating point matrix multiplication ***/
+
+ // Initialise input matrices
+ NEGEMM fgemm{};
+
+ Tensor src1;
+ Tensor src2;
+ Tensor dst;
+ src1.allocator()->init(TensorInfo(TensorShape(K, M), 1, DataType::F32));
+ src2.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::F32));
+ dst.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::F32));
+ fgemm.configure(&src1, &src2, nullptr, &dst, 1, 0);
+
+ // Allocate matrices
+ src1.allocator()->allocate();
+ src2.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ float min1 = 0.0f;
+ float max1 = 1.0f;
+ fill_random_tensor(src1, 0, min1, max1);
+
+ float min2 = -1.0f;
+ float max2 = 2.0f;
+ fill_random_tensor(src2, 1, min2, max2);
+
+ // Run single precision gemm and print result
+ fgemm.run();
+
+#if ARM_COMPUTE_DEBUG_ENABLED
+ std::cout << "# F32 GEMM result:\n";
+ std::cout << "src1=[ \n";
+ src1.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "src2=[ \n";
+ src2.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "dst=[ \n";
+ dst.print(std::cout);
+ std::cout << "] \n";
+#endif // ARM_COMPUTE_DEBUG_ENABLED
+
+ Tensor q_src1;
+ quantize(q_src1, src1, min1, max1);
+ print_quantization_info(q_src1, "src1");
+ q_src1.info()->set_are_values_constant(false);
+
+ // NEGEMMLowpMatrixMultiplyCore adopts the opposite convention for the offset
+ // compared to NEQuantizeLayer
+ invert_qinfo_offset(q_src1);
+
+ Tensor q_src2;
+ quantize(q_src2, src2, min2, max2);
+ print_quantization_info(q_src2, "src2");
+ q_src2.info()->set_are_values_constant(false);
+
+ // NEGEMMLowpMatrixMultiplyCore adopts the opposite convention for the offset
+ // compared to NEQuantizeLayer
+ invert_qinfo_offset(q_src2);
+
+ // q_dst will be Dequantized to F32 so it doesn't need a QuantizationInfo
+ Tensor q_dst;
+ q_dst.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::F32));
+
+ // Configure low precision gemm and initialise result tensor (pre-output)
+ NEGEMMLowpMatrixMultiplyCore qgemm;
+ qgemm.configure(&q_src1, &q_src2, nullptr, &q_dst);
+
+ q_dst.allocator()->allocate();
+
+ // Run low precision matrix multiply kernel
+ qgemm.run();
+
+#if ARM_COMPUTE_DEBUG_ENABLED
+ // Print quantized source matrices
+ std::cout << "q_src1=[ \n";
+ q_src1.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "q_src2=[ \n";
+ q_src2.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "# Lowp GEMM output (FP32):\n";
+ std::cout << "q_dst=[ \n";
+ q_dst.print(std::cout);
+ std::cout << "] \n";
+
+ // Expected result
+ std::cout << "# Expected result:\n";
+ std::cout << "dst=[ \n";
+ dst.print(std::cout);
+ std::cout << "] \n";
+#endif // ARM_COMPUTE_DEBUG_ENABLED
+
+ // Rerun to test the ability to modify the Tensor contents and QuantizationInfo (dynamic quantization)
+ min1 = -1.0f;
+ max1 = 1.0f;
+ fill_random_tensor(src1, 2, min1, max1);
+
+#if ARM_COMPUTE_DEBUG_ENABLED
+ std::cout << "# Refilled src1\n";
+ std::cout << "src1=[ \n";
+ src1.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "src2=[ \n";
+ src2.print(std::cout);
+ std::cout << "] \n";
+#endif // ARM_COMPUTE_DEBUG_ENABLED
+
+ fgemm.run();
+
+ quantize(q_src1, src1, min1, max1);
+ set_qinfo_dynamic(q_src1);
+ print_quantization_info(q_src1, "src1");
+
+ // NEGEMMLowpMatrixMultiplyCore adopts the opposite convention for the offset
+ // compared to NEQuantizeLayer
+ invert_qinfo_offset(q_src1);
+
+ qgemm.run();
+
+#if ARM_COMPUTE_DEBUG_ENABLED
+ // Print quantized source matrices
+ std::cout << "q_src1=[ \n";
+ q_src1.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "q_src2=[ \n";
+ q_src2.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "# Lowp GEMM output (FP32):\n";
+ std::cout << "q_dst=[ \n";
+ q_dst.print(std::cout);
+ std::cout << "] \n";
+
+ // Expected result
+ std::cout << "# Expected result:\n";
+ std::cout << "dst=[ \n";
+ dst.print(std::cout);
+ std::cout << "] \n";
+#endif // ARM_COMPUTE_DEBUG_ENABLED
+}
diff --git a/filelist.json b/filelist.json
index 2f33b5cd5..2c3621cd8 100644
--- a/filelist.json
+++ b/filelist.json
@@ -770,6 +770,15 @@
]
}
},
+ "Scatter": {
+ "files": {
+ "common": [
+ "src/gpu/cl/kernels/ClScatterKernel.cpp",
+ "src/gpu/cl/operators/ClScatter.cpp",
+ "src/runtime/CL/functions/CLScatter.cpp"
+ ]
+ }
+ },
"Select": {
"files": {
"common": [
@@ -1586,12 +1595,15 @@
"src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp",
+ "src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_int16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_int8.cpp",
+ "src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp",
+ "src/core/NEON/kernels/arm_gemm/interleave-8way.cpp",
"src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp",
"src/core/NEON/kernels/arm_gemm/mergeresults-fp16.cpp",
"src/core/NEON/kernels/arm_gemm/mergeresults.cpp",
@@ -1684,6 +1696,7 @@
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp",
@@ -1711,6 +1724,9 @@
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL/generic.cpp",
@@ -2225,7 +2241,9 @@
"common": [ "src/cpu/kernels/softmax/generic/sve/impl.cpp" ]
},
"sve2":{
- "common" :["src/cpu/kernels/softmax/generic/sve2/impl.cpp"]
+ "common" :["src/cpu/kernels/softmax/generic/sve2/impl.cpp"],
+ "fp32" :["src/cpu/kernels/softmax/generic/sme2/fp32.cpp"],
+ "fp16" :["src/cpu/kernels/softmax/generic/sme2/fp16.cpp"]
}
}
},
@@ -2324,7 +2342,6 @@
"src/dynamic_fusion/sketch/attributes/ResizeAttributes.cpp",
"src/dynamic_fusion/sketch/attributes/SoftmaxAttributes.cpp",
"src/dynamic_fusion/sketch/attributes/ReshapeAttributes.cpp",
- "src/dynamic_fusion/sketch/gpu/GpuKernelArgument.cpp",
"src/dynamic_fusion/sketch/gpu/GpuKernelComponentGraph.cpp",
"src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.cpp",
"src/dynamic_fusion/sketch/gpu/GpuKernelComponentStream.cpp",
@@ -2339,8 +2356,6 @@
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp",
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.cpp",
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp",
- "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.cpp",
- "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.cpp",
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.cpp",
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.cpp",
"src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.cpp",
@@ -2361,21 +2376,6 @@
"src/dynamic_fusion/sketch/gpu/operators/GpuTanh.cpp",
"src/dynamic_fusion/sketch/gpu/operators/internal/GpuElementwiseBinaryCommon.cpp"
],
- "template_writer": [
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp",
- "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp"
- ],
"ckw_driver": [
"src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwComponentArgument.cpp",
"src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.cpp",
diff --git a/scripts/clang_tidy_rules.py b/scripts/clang_tidy_rules.py
index 1e1ab7f54..f244017db 100755
--- a/scripts/clang_tidy_rules.py
+++ b/scripts/clang_tidy_rules.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
-# Copyright (c) 2017-2023 Arm Limited.
+# Copyright (c) 2017-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -28,7 +28,7 @@ import re
import sys
def get_list_includes():
- return "compute_kernel_writer/prototype/include " \
+ return "compute_kernel_writer/include " \
"src/cpu/kernels/assembly " \
"src/core/NEON/kernels/assembly " \
"src/core/NEON/kernels/convolution/winograd " \
@@ -43,8 +43,6 @@ def get_list_flags( filename, arch):
flags.append("-DARM_COMPUTE_OPENCL_ENABLED")
if arch == "aarch64":
flags.append("-DARM_COMPUTE_AARCH64_V8_2")
- if "ckw_driver" in filename:
- flags.append("-DACL_INTERNAL_TEST_CKW_IN_DF")
return flags
diff --git a/scripts/format_code.py b/scripts/format_code.py
index b456bd435..8bfb3f560 100755
--- a/scripts/format_code.py
+++ b/scripts/format_code.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-# Copyright (c) 2023 Arm Limited.
+# Copyright (c) 2023-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -79,7 +79,7 @@ def check_copyright( filename ):
start = 2
if("SConscript" in filename):
start = 3
- m = re.match("(# Copyright \(c\) )(.*\d{4})( [Arm|ARM].*)", content[start])
+ m = re.match(r"(# Copyright \(c\) )(.*\d{4})( [Arm|ARM].*)", content[start])
line = m.group(1)
if m.group(2): # Is there a year already?
@@ -101,7 +101,7 @@ def check_copyright( filename ):
return
# This only works until year 9999
- m = re.match("(.*Copyright \(c\) )(.*\d{4})( [Arm|ARM].*)", content[1])
+ m = re.match(r"(.*Copyright \(c\) )(.*\d{4})( [Arm|ARM].*)", content[1])
start =len(ref)+2
if content[0] != "/*\n" or not m:
start = 0
@@ -146,7 +146,7 @@ def check_license(filename):
year = datetime.datetime.now().year
# This only works until year 9999
- m = re.match("(.*Copyright \(c\) )(.*\d{4})( [Arm|ARM].*)", content[2])
+ m = re.match(r"(.*Copyright \(c\) )(.*\d{4})( [Arm|ARM].*)", content[2])
if not m:
f.write("Copyright (c) {} Arm Limited\n".format(year))
diff --git a/scripts/generate_android_bp.py b/scripts/generate_android_bp.py
index f7ecbc468..d5b268f52 100755
--- a/scripts/generate_android_bp.py
+++ b/scripts/generate_android_bp.py
@@ -45,7 +45,9 @@ excluded_paths = ["build",
"/sve/",
"/SVE/",
"/sve2/",
- "/SVE2/"
+ "/SVE2/",
+ "/sme/",
+ "/sme2/",
]
excluded_files = ["TracePoint.cpp"]
@@ -108,6 +110,7 @@ cc_library_static {
proprietary: true,
local_include_dirs: ["build/android-arm64v8a/src/core",
"build/android-arm64v8a/src/core/CL",
+ "compute_kernel_writer/include",
"src/core/common",
"src/core/helpers",
"src/core/NEON/kernels/arm_gemm",
diff --git a/scripts/generate_build_files.py b/scripts/generate_build_files.py
index 17cf49c0a..f88cf1af4 100644
--- a/scripts/generate_build_files.py
+++ b/scripts/generate_build_files.py
@@ -1,7 +1,7 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
-# Copyright (c) 2023 Arm Limited.
+# Copyright (c) 2023-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -93,7 +93,7 @@ def resolve_operator_dependencies(filelist, operators, backend=''):
return resolved_operators
def get_template_header():
- return """# Copyright (c) 2023 Arm Limited.
+ return """# Copyright (c) 2023-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
diff --git a/src/BUILD.bazel b/src/BUILD.bazel
index 9d5ae6348..e3cac07de 100644
--- a/src/BUILD.bazel
+++ b/src/BUILD.bazel
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 Arm Limited.
+# Copyright (c) 2023-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -117,6 +117,8 @@ filegroup(
"cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp",
"cpu/kernels/elementwise_unary/generic/sve2/q8.cpp",
"cpu/kernels/lut/generic/sve2/u8.cpp",
+ "cpu/kernels/softmax/generic/sme2/fp16.cpp",
+ "cpu/kernels/softmax/generic/sme2/fp32.cpp",
"cpu/kernels/softmax/generic/sve2/impl.cpp"] +
glob(["**/*.h",
"**/*.hpp",
@@ -261,6 +263,9 @@ filegroup(
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL/generic.cpp",
+ "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp",
+ "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp",
+ "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL/generic.cpp",
@@ -509,19 +514,23 @@ filegroup(
"core/NEON/kernels/arm_conv/pooling/pooling_u8.cpp",
"core/NEON/kernels/arm_conv/pooling/pooling_u8q.cpp",
"core/NEON/kernels/arm_gemm/gemm_bf16.cpp",
+ "core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp",
"core/NEON/kernels/arm_gemm/gemm_fp16.cpp",
"core/NEON/kernels/arm_gemm/gemm_fp32.cpp",
"core/NEON/kernels/arm_gemm/gemm_int16.cpp",
"core/NEON/kernels/arm_gemm/gemm_int8.cpp",
"core/NEON/kernels/arm_gemm/gemm_qint8.cpp",
"core/NEON/kernels/arm_gemm/gemm_quint8.cpp",
+ "core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp",
"core/NEON/kernels/arm_gemm/gemm_uint16.cpp",
"core/NEON/kernels/arm_gemm/gemm_uint8.cpp",
+ "core/NEON/kernels/arm_gemm/interleave-8way.cpp",
"core/NEON/kernels/arm_gemm/interleave_indirect.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp",
+ "core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index be7a6ef18..984db79c1 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 Arm Limited.
+# Copyright (c) 2023-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -238,6 +238,9 @@ target_sources(
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL/generic.cpp
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL/generic.cpp
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL/generic.cpp
+ core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp
+ core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp
+ core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL/generic.cpp
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL/generic.cpp
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL/generic.cpp
@@ -335,6 +338,8 @@ target_sources(
cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp
cpu/kernels/elementwise_unary/generic/sve2/q8.cpp
cpu/kernels/lut/generic/sve2/u8.cpp
+ cpu/kernels/softmax/generic/sme2/fp16.cpp
+ cpu/kernels/softmax/generic/sme2/fp32.cpp
cpu/kernels/softmax/generic/sve2/impl.cpp
)
@@ -500,19 +505,23 @@ target_sources(
core/NEON/kernels/arm_conv/pooling/pooling_u8.cpp
core/NEON/kernels/arm_conv/pooling/pooling_u8q.cpp
core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+ core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp
core/NEON/kernels/arm_gemm/gemm_fp16.cpp
core/NEON/kernels/arm_gemm/gemm_fp32.cpp
core/NEON/kernels/arm_gemm/gemm_int16.cpp
core/NEON/kernels/arm_gemm/gemm_int8.cpp
core/NEON/kernels/arm_gemm/gemm_qint8.cpp
core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+ core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp
core/NEON/kernels/arm_gemm/gemm_uint16.cpp
core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+ core/NEON/kernels/arm_gemm/interleave-8way.cpp
core/NEON/kernels/arm_gemm/interleave_indirect.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp
+ core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp
diff --git a/src/common/cpuinfo/CpuIsaInfo.cpp b/src/common/cpuinfo/CpuIsaInfo.cpp
index 597768530..c9e39b9a0 100644
--- a/src/common/cpuinfo/CpuIsaInfo.cpp
+++ b/src/common/cpuinfo/CpuIsaInfo.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -60,7 +60,7 @@ inline bool is_feature_supported(uint64_t features, uint64_t feature_mask)
void decode_hwcaps(CpuIsaInfo &isa, const uint32_t hwcaps, const uint32_t hwcaps2)
{
ARM_COMPUTE_UNUSED(hwcaps2);
- isa.fp16 = is_feature_supported(hwcaps, ARM_COMPUTE_CPU_FEATURE_HWCAP_HALF);
+ isa.fp16 = false;
isa.neon = is_feature_supported(hwcaps, ARM_COMPUTE_CPU_FEATURE_HWCAP_NEON);
}
#elif defined(__aarch64__)
diff --git a/src/core/CL/cl_kernels/common/elementwise_operation.cl b/src/core/CL/cl_kernels/common/elementwise_operation.cl
index 45dcbfc6e..91e51d9d1 100644
--- a/src/core/CL/cl_kernels/common/elementwise_operation.cl
+++ b/src/core/CL/cl_kernels/common/elementwise_operation.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,11 +46,7 @@
#define PRELU(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))))
#endif // VEC_SIZE_OUT == 1
-#if defined(S32)
-#define DIV(x, y) CONVERT(floor(CONVERT(x, VEC_DATA_TYPE(float, VEC_SIZE_OUT)) / CONVERT(y, VEC_DATA_TYPE(float, VEC_SIZE_OUT))), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT));
-#else /* S32 */
#define DIV(x, y) (x / y)
-#endif /* S32 */
#define AND(x, y) (CONVERT((x && y), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)) & ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))1))
#define OR(x, y) (CONVERT((x || y), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)) & ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))1))
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
index c807cb3ad..6cb10a7bb 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,7 +45,7 @@ void a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst_impl(
{
struct Params
{
- long unsigned int n_channels;
+ uint64_t n_channels;
const void *weights;
const int32_t *bias;
const arm_gemm::Requantize32 *requant;
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
index c8fe567e7..931673263 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,7 +45,7 @@ void a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst_impl(
{
struct Params
{
- long unsigned int n_channels;
+ uint64_t n_channels;
const void *weights;
const int32_t *bias;
const arm_gemm::Requantize32 *requant;
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp
new file mode 100644
index 000000000..aa761b46e
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "bfloat.hpp"
+#include "gemm_implementation.hpp"
+#include "gemm_interleaved.hpp"
+
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
+#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+
+namespace arm_gemm {
+
+static const GemmImplementation<bfloat16, bfloat16> gemm_bf16bf16_methods[] =
+{
+#ifdef __aarch64__
+#ifdef ARM_COMPUTE_ENABLE_BF16
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+GemmImplementation<bfloat16, bfloat16>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_ffinterleaved_bf16fp32_mmla_8x12",
+ KernelWeightFormat::VL256_BL64,
+ [](const GemmArgs &args) { return args._ci->has_bf16(); },
+ [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, bfloat16>::estimate_cycles<bfloat16>(args); },
+ [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, bfloat16>(args); }
+),
+GemmImplementation<bfloat16, bfloat16>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "sve_ffinterleaved_bf16fp32_mmla_8x3VL",
+ KernelWeightFormat::VL2VL_BL64,
+ [](const GemmArgs &args) { return args._ci->has_svebf16(); },
+ [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, bfloat16, bfloat16>::estimate_cycles<bfloat16>(args); },
+ [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, bfloat16, bfloat16>(args); }
+),
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_BF16
+#endif // __aarch64__
+{
+ GemmMethod::DEFAULT,
+ "",
+ nullptr,
+ nullptr,
+ nullptr
+}
+};
+
+template<>
+const GemmImplementation<bfloat16, bfloat16> *gemm_implementation_list<bfloat16, bfloat16>() {
+ return gemm_bf16bf16_methods;
+}
+
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<bfloat16, bfloat16> gemm<bfloat16, bfloat16, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<bfloat16, bfloat16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
+template KernelDescription get_gemm_method<bfloat16, bfloat16, Nothing>(const GemmArgs &args, const Nothing &);
+template std::vector<KernelDescription> get_compatible_kernels<bfloat16, bfloat16, Nothing>(const GemmArgs &args, const Nothing &);
+
+} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 44a7bb894..290fe8723 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,6 +34,7 @@
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp"
#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp"
+#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp"
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
@@ -292,14 +293,14 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_fp32_mla_8x4",
- [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_8x4, float, float>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_fp32_mla_6x4",
- [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input; },
+ [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_6x4, float, float>(args); }
},
@@ -350,6 +351,14 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>(args); }
),
+GemmImplementation<float, float>::with_estimate(
+ GemmMethod::GEMM_HYBRID,
+ "a64_ffhybrid_fp32bf16fp32_mmla_6x16",
+ KernelWeightFormat::VL256_BL64_BF16,
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
+ [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16, float, float>::estimate_cycles<float>(args); },
+ [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16, float, float>(args); }
+),
#endif // BF16
GemmImplementation<float, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
index 436316c0f..a6c967730 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -221,7 +221,9 @@ public:
return roundup(_Nsize, strategy::out_width()) * roundup(_Ksize, strategy::k_unroll()) * _nmulti * sizeof(Toi);
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+ assert(!transposed);
+
Toi *buffer = reinterpret_cast<Toi *>(in_buffer);
_B_transposed = buffer;
strategy strat(_ci);
@@ -237,7 +239,7 @@ public:
const unsigned int size = roundup(xmax-x0, strategy::out_width()) * k_size;
strat.transforms.PrepareB( buffer, B + (multi * B_multi_stride), ldb,
- x0, xmax, k0, kmax);
+ x0, xmax, k0, kmax, false);
buffer += size;
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index 1780375c4..0cc4d4f3d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -530,7 +530,7 @@ public:
(m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
(this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
last_pass ? _args._act : Activation(),
- !first_pass,
+ !first_pass || _args._accumulate,
// Quantization parameters
_os, _col_bias+(multi * _args._Nsize), n0);
} else if (_convolver) {
@@ -563,7 +563,7 @@ public:
(m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
(this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
last_pass ? _args._act : Activation(),
- !first_pass,
+ !first_pass || _args._accumulate,
// Quantization parameters
_os, _col_bias+(multi * _args._Nsize), n0);
} else {
@@ -579,7 +579,7 @@ public:
(m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
(this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
last_pass ? _args._act : Activation(),
- !first_pass,
+ !first_pass || _args._accumulate,
// Quantization parameters
_os, _col_bias+(multi * _args._Nsize), n0);
}
@@ -631,11 +631,16 @@ public:
}
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
- pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, 0, get_B_pretranspose_window_size());
+ bool B_pretranspose_supports_transpose() const override {
+ strategy strat(_args._ci);
+ return strat.transforms.PrepareB_supports_transpose();
+ }
+
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+ pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size());
}
- void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, size_t start, size_t end) override {
+ void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed, size_t start, size_t end) override {
if (end >= get_B_pretranspose_window_size()) {
requantize_bias(in_buffer, B, ldb, B_multi_stride);
}
@@ -717,7 +722,8 @@ public:
strat.transforms.PrepareB(buffer, B + (multi * B_multi_stride), ldb,
x0, xmax,
(k_section_base * _args._Ksize) + k_offset, // K starting point - compute row to read based on our section and the true section length.
- (k_section_base * _args._Ksize) + k_offset + k_length); // K end point - starting point plus length computed above.
+ (k_section_base * _args._Ksize) + k_offset + k_length, // K end point - starting point plus length computed above.
+ transposed);
// We need to modify our position based on the ROUNDED version of what we just did.
unsigned int padded_length = roundup(k_length, strategy::k_unroll());
@@ -731,7 +737,7 @@ public:
} else {
// In the single K section case, can process the whole lot in one go.
strat.transforms.PrepareB(buffer, B + (multi * B_multi_stride), ldb,
- n_start, n_end, k0, std::min(kmax, _args._Ksize));
+ n_start, n_end, k0, std::min(kmax, _args._Ksize), transposed);
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
index efb5bd1bb..f12efe428 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -277,7 +277,9 @@ public:
}
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+ assert(!transposed);
+
requantize_bias(in_buffer, B, ldb, B_multi_stride);
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
@@ -296,7 +298,7 @@ public:
const unsigned int size = roundup(xmax-x0, strategy::out_width()) * k_size;
strat.transforms.PrepareB( buffer, B + (multi * B_multi_stride), ldb,
- x0, xmax, k0, kmax);
+ x0, xmax, k0, kmax, false);
buffer += size;
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index fd20e53f6..0dc0d55b2 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -128,14 +128,14 @@ GemmImplementation<int8_t, int32_t>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_s8s32_dot_8x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int32_t>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_s8s32_dot_6x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int32_t>(args); }
},
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 362a3e30e..ae344f09b 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -29,7 +29,6 @@
#include "arm_gemm.hpp"
#include "bfloat.hpp"
#include "convolver.hpp"
-#include "kernel_weight_format.hpp"
#include "kernel_traits.hpp"
#include "kernel_weight_format.hpp"
#include "mergeresults.hpp"
@@ -247,6 +246,84 @@ void kernel_and_merge<true, false, Requantize32>::run(
}
}
+// Run a kernel with integrated merge, dequantizing to FP32
+template<>
+template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+void kernel_and_merge<false, false, DequantizeFloat>::run(
+#ifdef CYCLE_PROFILING
+ profiler &prof,
+#endif
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *,
+ Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max,
+ unsigned int n_0, unsigned int n_max, const Tr *bias,
+ const Activation &act, bool accumulate, const DequantizeFloat &dq, const int32_t *col_bias,
+ Tab *acc_buff)
+{
+#ifdef CYCLE_PROFILING
+ auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k);
+#endif
+
+ const int32_t *offset_col_bias = nullptr;
+ const Tr *offset_bias = nullptr;
+
+ if (col_bias) {
+ offset_col_bias = col_bias + n_0;
+ }
+
+ if (bias) {
+ offset_bias = bias + n_0;
+ }
+
+ strat.kernel(// A and B pointers are just the packed panels.
+ a_ptr, b_panel,
+ // Provide relevant part of output array and row stride.
+ c_ptr ? (c_ptr + m_0 * ldc + n_0) : nullptr, ldc,
+ // M, N, K sizes
+ m_max-m_0, n_max - n_0, kern_k,
+ // Bias, activation, accumulation. Need to offset the bias as needed.
+ offset_col_bias, dq, offset_bias, act, accumulate, acc_buff);
+}
+
+template<>
+template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+void kernel_and_merge<true, false, DequantizeFloat>::run(
+#ifdef CYCLE_PROFILING
+ profiler &prof,
+#endif
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel,
+ Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
+ unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *bias,
+ const Activation &act, bool accumulate, const DequantizeFloat &qp, const int32_t *,
+ Tab *)
+{
+ const int bblocks = iceildiv(n_max - n_0, strategy::out_width());
+
+ {
+#ifdef CYCLE_PROFILING
+ auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k));
+#endif
+
+ strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k);
+ }
+
+ {
+#ifdef CYCLE_PROFILING
+ auto p=prof.ScopedProfiler(PROFILE_QUANTIZE, ((m_max-m_0) * bblocks * strategy::out_width() * sizeof(Tr)));
+#endif
+ auto out_area = strategy::out_width() * strategy::out_height();
+ for (int i=0; i<bblocks; i++) {
+ const unsigned int n_start = n_0 + (strategy::out_width() * i);
+ const unsigned int n_end = std::min(n_start + strategy::out_width(), n_max);
+
+ dequantize_block_32(qp, (n_end - n_start), (m_max - m_0),
+ c_panel + (i * out_area), strategy::out_width(),
+ c_ptr + m_0 * ldc + n_start, ldc,
+ bias != nullptr ? bias + n_start : nullptr, accumulate, act);
+
+ }
+ }
+}
+
// Integer GEMMs can be used in two contexts - "normal" where the full 32-bit output is required, or in
// "requantizing" context where the output will be requantized.
//
@@ -280,6 +357,12 @@ public:
typedef int32_t type;
};
+template<typename strategy>
+class accumulate_buffer_type<strategy, DequantizeFloat, false> {
+public:
+ typedef int32_t type;
+};
+
template<typename strategy, typename OutputStage>
class accumulate_buffer_type<strategy, OutputStage, true> {
public:
@@ -350,6 +433,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
const bool _thread_columns;
const Activation _act;
+ const bool _accumulate;
const int _maxthreads;
int _nthreads;
@@ -680,7 +764,7 @@ public:
_Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
_rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
_nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
- _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
_k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
_os(os) { }
@@ -690,7 +774,7 @@ public:
_Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
_rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
_nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
- _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
_k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
_os() { }
@@ -763,6 +847,9 @@ public:
const bool first_pass = (k0==0);
const bool last_pass = (kmax==_Ktotal);
+ // Bias is passed for the first pass only, except for dequantizefloat nomerge cases where it's the last pass.
+ const bool bias_pass = (std::is_same<OutputStage, DequantizeFloat>::value && !MergeStep) ? last_pass : first_pass;
+
// Figure out how many "K" the kernel will actually process.
unsigned int kern_k = roundup(kmax - k0, strategy::k_unroll());
@@ -821,9 +908,9 @@ public:
// K size, and M/N ranges
kern_k, start_row, end_row, start_x, end_x,
// Only do bias on the first pass
- ((first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr),
+ ((bias_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr),
// Only do activation on the last pass, and accumulation on any non-first pass.
- (last_pass ? _act : Activation()), !first_pass,
+ (last_pass ? _act : Activation()), (!first_pass || _accumulate),
// Pass in quantization parameters for requantizing kernels (others will ignore)
_os, col_bias + (multi * _Nsize),
// Accumulation buffer
@@ -948,6 +1035,9 @@ public:
const bool first_pass = (current.k0() == 0);
const bool last_pass = (current.kmax() == _Ktotal);
+ // Bias is passed for the first pass only, except for dequantizefloat nomerge cases where it's the last pass.
+ const bool bias_pass = (std::is_same<OutputStage, DequantizeFloat>::value && !MergeStep) ? last_pass : first_pass;
+
// Pointer to appropriate part of result array.
Tr *result_ptr = this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride);
@@ -969,9 +1059,9 @@ public:
// K size, and M/N ranges
kern_k, y, ymax, current.x0(), current.xmax(),
// Only do bias on the first pass
- ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
+ ((bias_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
// Only do activation on the last pass, and accumulation on any non-first pass.
- (last_pass ? _act : Activation()), !first_pass,
+ (last_pass ? _act : Activation()), (!first_pass || _accumulate),
// Pass in quantization parameters for requantizing kernels (others will ignore)
_os, col_bias + (current.multi() * _Nsize),
// Accumulation buffer
@@ -1067,11 +1157,18 @@ public:
}
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
- pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, 0, get_B_pretranspose_window_size());
+ // Support for transposed B is a property of the strategy::transpose type
+ bool B_pretranspose_supports_transpose() const override {
+ typename transform_type<strategy, MergeStep && std::is_same<OutputStage, Requantize32>::value>::type transforms;
+
+ return transforms.PrepareB_supports_transpose();
+ }
+
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed) override {
+ pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size());
}
- void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, size_t start, size_t end) override {
+ void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed, size_t start, size_t end) override {
// Perform column sums etc as part of the last block.
if (end >= get_B_pretranspose_window_size()) {
requantize_bias(in_buffer, B, ldb, B_multi_stride);
@@ -1134,7 +1231,8 @@ public:
strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
x0, xmax,
(k_section_base * _Ksize) + k_offset, // K starting point - compute row to read based on our section and the true section length.
- (k_section_base * _Ksize) + k_offset + k_length); // K end point - starting point plus length computed above.
+ (k_section_base * _Ksize) + k_offset + k_length, // K end point - starting point plus length computed above.
+ transposed);
// We need to modify our position based on the ROUNDED version of what we just did.
unsigned int padded_length = roundup(k_length, strategy::k_unroll());
@@ -1149,7 +1247,7 @@ public:
// In the single K section case, can process the whole lot in one go.
// Caution: 'blockwalker::kmax()' rounds up, so clamp to valid _Ksize.
strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
- current.x0(), current.xmax(), current.k0(), std::min(current.kmax(), _Ksize));
+ current.x0(), current.xmax(), current.k0(), std::min(current.kmax(), _Ksize), transposed);
buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll());
}
@@ -1176,6 +1274,13 @@ public:
}
}
+ void set_dequantize_scale(const float scale) override {
+ if(std::is_same<OutputStage, DequantizeFloat>::value) {
+ DequantizeFloat* df = reinterpret_cast<DequantizeFloat *>(&_os);
+ df->scale = scale;
+ }
+ }
+
void set_indirect_parameters(size_t string_len, const To * const * const *ptr) override {
assert(string_len == _Ksize);
_indirect_buf = ptr;
@@ -1240,4 +1345,10 @@ using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved<strat
template<typename strategy, typename To, typename Tr>
using GemmInterleavedQuantized = GemmInterleaved<strategy, To, Tr, Requantize32>;
+template<typename strategy, typename To, typename Tr>
+using GemmInterleavedNoMergeDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat, false>;
+
+template<typename strategy, typename To, typename Tr>
+using GemmInterleavedDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat>;
+
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index 7c860a24a..d1c4e49ed 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -71,7 +71,7 @@ static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods
#ifdef ARM_COMPUTE_ENABLE_SVE
#ifdef ARM_COMPUTE_ENABLE_SME2
{
- GemmMethod::GEMM_HYBRID,
+ GemmMethod::GEMV_PRETRANSPOSED,
"sme2_gemv_s8qa_dot_16VL",
[](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && quant_hybrid_asymmetric(qp) && args._Msize == 1 && !args._indirect_input && args._nbatches == 1; },
nullptr,
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index 3baf9857d..b85b1c4fc 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -67,7 +67,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth
#ifdef ARM_COMPUTE_ENABLE_SME2
// SME kernels
{
- GemmMethod::GEMM_HYBRID,
+ GemmMethod::GEMV_PRETRANSPOSED,
"sme2_gemv_u8qa_dot_16VL",
[](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && quant_hybrid_asymmetric(qp) && args._Msize == 1 && !args._indirect_input && args._nbatches == 1; },
nullptr,
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp
new file mode 100644
index 000000000..782399df8
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include "arm_gemm.hpp"
+
+#include "kernels/a64_gemm_s16_8x12.hpp"
+#include "kernels/a64_gemm_s8_8x12.hpp"
+#include "kernels/a64_gemm_s8_4x4.hpp"
+#include "kernels/a64_interleaved_s8s32_mmla_8x12.hpp"
+
+#ifdef ARM_COMPUTE_ENABLE_SVE
+#ifdef ARM_COMPUTE_ENABLE_SME2
+#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp"
+#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp"
+#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp"
+#endif // ARM_COMPUTE_ENABLE_SME2
+#include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp"
+#include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp"
+#endif // ARM_COMPUTE_ENABLE_SVE
+
+#include "gemm_implementation.hpp"
+#include "gemm_interleaved.hpp"
+#include "utils.hpp"
+
+#include <cstdint>
+#include <vector>
+namespace arm_gemm {
+
+static const GemmImplementation<int8_t, float, DequantizeFloat> gemm_s8fp32_methods[] =
+{
+#ifdef ARM_COMPUTE_ENABLE_SVE
+#ifdef ARM_COMPUTE_ENABLE_SME2
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>();
+ return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+ [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL, int8_t, float>(args, dq); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8qfp32_mopa_4Vx1VL.hpp",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>();
+ return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
+ [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL, int8_t, float>(args, dq); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8qfp32_mopa_2Vx2VL.hpp",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); },
+ nullptr,
+ [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL, int8_t, float>(args, dq); }
+},
+#endif // ARM_COMPUTE_ENABLE_SME2
+GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "sve_interleaved_s8s32_mmla_8x3VL",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_svei8mm(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, float>::estimate_cycles<int8_t>(args); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, float>(args, qp); }
+),
+GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "sve_interleaved_s8s32_dot_8x3VL",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sve(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, float>::estimate_cycles<int8_t>(args); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, float>(args, qp); }
+),
+#endif // ARM_COMPUTE_ENABLE_SVE
+GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_interleaved_s8s32_mmla_8x12",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_i8mm(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, float>::estimate_cycles<int8_t>(args); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, float>(args, qp); }
+),
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_gemm_s16_8x12",
+ nullptr,
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->get_cpu_model() == CPUModel::A53 && ((args._Msize > 28) || ((args._Msize % 8) > 4)); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s16_8x12, int8_t, float>(args, qp); }
+},
+GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_gemm_s8_8x12",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_dotprod(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_gemm_s8_8x12, int8_t, float>::estimate_cycles<int8_t>(args); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s8_8x12, int8_t, float>(args, qp); }
+),
+GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_gemm_s8_4x4",
+ nullptr,
+ [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_gemm_s8_4x4, int8_t, float>::estimate_cycles<int8_t>(args); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s8_4x4, int8_t, float>(args, qp); }
+),
+{
+ GemmMethod::DEFAULT,
+ "",
+ nullptr,
+ nullptr,
+ nullptr
+}
+};
+
+template<>
+const GemmImplementation<int8_t, float, DequantizeFloat> *gemm_implementation_list<int8_t, float, DequantizeFloat>() {
+ return gemm_s8fp32_methods;
+}
+
+template UniqueGemmCommon<int8_t, float> gemm<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os);
+template KernelDescription get_gemm_method<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os);
+template std::vector<KernelDescription> get_compatible_kernels<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os);
+
+} // namespace arm_gemm
+
+#endif // __aarch64__ \ No newline at end of file
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index af5cfbbf2..dfacb687a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -94,14 +94,14 @@ GemmImplementation<uint8_t, uint32_t>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_u8u32_dot_8x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint32_t>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_u8u32_dot_6x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint32_t>(args); }
},
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
index 4fc9b3456..ad504f266 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -88,8 +88,8 @@ public:
return _subgemm->get_B_pretransposed_array_size();
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
- _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride);
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+ _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride, transposed);
}
void set_pretransposed_B_data(void *buffer) override {
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 86b33d081..dbada3605 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -180,7 +180,7 @@ public:
this->_Cptr + (multi * this->_C_multi_stride) + n,
(nmax - n), (kmax-k0),
this->_bias ? this->_bias + (multi * this->_bias_multi_stride) + n : nullptr,
- _args._act, (k0 != 0),
+ _args._act, (k0 != 0) || _args._accumulate,
_os, col_bias, n + (_args._Nsize * multi));
}
}
@@ -215,7 +215,18 @@ public:
}
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void set_quantized_bias(const int32_t *bias, size_t bias_multi_stride) override {
+ if (std::is_same<OutputStage, Requantize32>::value) {
+ Requantize32 *qp = reinterpret_cast<Requantize32 *>(&_os);
+
+ qp->bias = bias;
+ qp->bias_multi_stride = bias_multi_stride;
+ }
+ }
+
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+ assert(!transposed);
+
requantize_bias(buffer, B, ldb, B_multi_stride);
// The actual transposed buffer goes after the column sums (if any)
@@ -225,7 +236,7 @@ public:
strategy strat(_args._ci);
for (unsigned int multi=0; multi<_args._nmulti; multi++) {
- strat.transforms.PrepareB(B_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _args._Nsize, 0, _args._Ksize);
+ strat.transforms.PrepareB(B_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _args._Nsize, 0, _args._Ksize, false);
}
_B_pretransposed = B_buffer;
diff --git a/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp b/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp
new file mode 100644
index 000000000..a05d700c5
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp
@@ -0,0 +1,267 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include <arm_neon.h>
+
+#if !defined(_WIN64) && !defined(__OpenBSD__)
+#include <alloca.h>
+#endif /* !defined(_WIN64) && !defined(__OpenBSD__) */
+
+#include <cstring>
+
+#include "transform.hpp"
+#include "utils.hpp"
+
+namespace arm_gemm {
+
+namespace {
+
+// Helper function to interleave a single 4x4 block of 32-bin values
+// together.
+
+// _full version doesn't need to worry about any padding.
+static inline void transpose_block_32_full(const uint8_t * __restrict in_ptr0, const uint8_t * __restrict in_ptr1, const uint8_t * __restrict in_ptr2, const uint8_t * __restrict in_ptr3, uint8_t * __restrict out_ptr, long output_stride) {
+ uint32x4_t inputs[4];
+ uint32x4_t inters[4];
+ uint32x4_t outputs[4];
+
+ inputs[0] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr0));
+ inputs[1] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr1));
+ inputs[2] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr2));
+ inputs[3] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr3));
+
+ inters[0] = vzip1q_u32(inputs[0], inputs[2]);
+ inters[1] = vzip2q_u32(inputs[0], inputs[2]);
+ inters[2] = vzip1q_u32(inputs[1], inputs[3]);
+ inters[3] = vzip2q_u32(inputs[1], inputs[3]);
+
+ outputs[0] = vzip1q_u32(inters[0], inters[2]);
+ outputs[1] = vzip2q_u32(inters[0], inters[2]);
+ outputs[2] = vzip1q_u32(inters[1], inters[3]);
+ outputs[3] = vzip2q_u32(inters[1], inters[3]);
+
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr), outputs[0]);
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride), outputs[1]);
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*2), outputs[2]);
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*3), outputs[3]);
+}
+
+// _part version: Only read "bytes_in" bytes, not a full vector. Only write
+// out 4-byte blocks that have some live content (if bytes_in is not a
+// multiple of 4 there will some padding in each 4-block)
+static inline void transpose_block_32_part(const uint8_t *in_ptr0, const uint8_t *in_ptr1, const uint8_t *in_ptr2, const uint8_t *in_ptr3, uint8_t *out_ptr, long bytes_in, long output_stride) {
+ uint32x4_t inputs[4];
+ uint32x4_t inters[4];
+ uint32x4_t outputs[4];
+ uint8_t scratch[16] = {0};
+
+ long num_outs = iceildiv<long>(bytes_in, 4);
+
+ memcpy(scratch, in_ptr0, bytes_in);
+ inputs[0] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+ memcpy(scratch, in_ptr1, bytes_in);
+ inputs[1] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+ memcpy(scratch, in_ptr2, bytes_in);
+ inputs[2] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+ memcpy(scratch, in_ptr3, bytes_in);
+ inputs[3] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+
+ inters[0] = vzip1q_u32(inputs[0], inputs[2]);
+ inters[1] = vzip2q_u32(inputs[0], inputs[2]);
+ inters[2] = vzip1q_u32(inputs[1], inputs[3]);
+ inters[3] = vzip2q_u32(inputs[1], inputs[3]);
+
+ outputs[0] = vzip1q_u32(inters[0], inters[2]);
+ outputs[1] = vzip2q_u32(inters[0], inters[2]);
+ outputs[2] = vzip1q_u32(inters[1], inters[3]);
+ outputs[3] = vzip2q_u32(inters[1], inters[3]);
+
+ do {
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr), outputs[0]);
+ if (num_outs < 2)
+ break;
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride), outputs[1]);
+ if (num_outs < 3)
+ break;
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*2), outputs[2]);
+ if (num_outs < 4)
+ break;
+ vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*3), outputs[3]);
+ } while (0);
+}
+
+template<unsigned N>
+struct Unroll {
+ template<typename F>
+ static void run(F f) {
+ Unroll<N-1>::run(f);
+ f(N-1);
+ }
+};
+
+template<>
+struct Unroll<0> {
+ template<typename F>
+ static void run(F) {
+ }
+};
+
+// Interleave some multiple of 4 rows together.
+//
+// The template parameter BLOCKS controls the size of the inner loop - each BLOCK is 4 rows.
+// The function parameter interleave_multiple controls the number of times the inner loop is run.
+
+// The total interleave depth for a given run is therefore BLOCKS * interleave_multiple * 4.
+template<unsigned BLOCKS>
+void a64_interleave_1x4(uint8_t *out, const uint8_t *in, long width, long in_stride, long height, long interleave_multiple) {
+ const long total_interleave_depth = BLOCKS * 4 * interleave_multiple;
+ constexpr long loop_interleave_depth = BLOCKS * 4;
+
+ uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width));
+
+ if (height % total_interleave_depth) {
+ memset(pad_row, 0, width);
+ }
+
+ // Outer loop: process blocks of total_interleave_depth rows at a time.
+ for (long y0_base=0; y0_base<height; y0_base+=total_interleave_depth) {
+ // Middle loop: process each "interlave_multiple" block of rows.
+ for (long block=0; block<interleave_multiple; block++) {
+ const long y0 = y0_base + (block * loop_interleave_depth);
+ uint8_t *out_ptr = out + (block * loop_interleave_depth * 4); // 4 is the blocking depth (we interleave 4 bytes at a time from each input)
+
+ // Create and set up input row pointers. The idea is that these
+ // should entirely fit in the register file, so we don't have to
+ // repeatedly load them (or perform the padding check)
+ const uint8_t *in_ptrs[loop_interleave_depth];
+ Unroll<loop_interleave_depth>::run( [&](unsigned y) {
+ in_ptrs[y] = (y+y0 < height) ? in + ((y+y0) * in_stride) : pad_row;
+ });
+
+ long bytes_left = width;
+ // Process full vectors using transpose_block_32_full()
+ while (bytes_left >= 16) { // 16 is the vector length in bytes
+ Unroll<BLOCKS>::run( [&](unsigned u) {
+ transpose_block_32_full(in_ptrs[u*4 + 0], in_ptrs[u*4 + 1], in_ptrs[u*4 + 2], in_ptrs[u*4 + 3],
+ out_ptr + 16*u, total_interleave_depth * 4); // 4 is the blocking depth
+ });
+
+ Unroll<loop_interleave_depth>::run( [&](unsigned y) {
+ in_ptrs[y] += 16; // 16 is the vector length in bytes
+ });
+
+ out_ptr += total_interleave_depth * 16; // 16 is the vector length in bytes
+ bytes_left -= 16; // 16 is the vector length in bytes
+ }
+
+ // Process any remaining bytes using transpose_block_32_part()
+ if (bytes_left) {
+ Unroll<BLOCKS>::run( [&](unsigned u) {
+ transpose_block_32_part(in_ptrs[u*4 + 0], in_ptrs[u*4 + 1], in_ptrs[u*4 + 2], in_ptrs[u*4 + 3],
+ out_ptr + 16*u, bytes_left, total_interleave_depth * 4);
+ });
+ }
+ }
+
+ // Update "out" pointer for next set of total_interleave_depth rows
+ out += total_interleave_depth * roundup<long>(width, 4);
+ }
+}
+
+} // anonymous namespace
+
+template<>
+void Transform<16, 4, false, VLType::None>(
+ uint8_t *out, const uint8_t *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+ a64_interleave_1x4<4>(
+ reinterpret_cast<uint8_t *>(out),
+ reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+ (xmax - x0),
+ stride,
+ (ymax - y0),
+ 1
+ );
+}
+
+template<>
+void Transform<16, 4, false, VLType::None>(
+ int8_t *out, const int8_t *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+ a64_interleave_1x4<4>(
+ reinterpret_cast<uint8_t *>(out),
+ reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+ (xmax - x0),
+ stride,
+ (ymax - y0),
+ 1
+ );
+}
+
+template<>
+void Transform<12, 1, false, VLType::None>(
+ float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+ a64_interleave_1x4<3>(
+ reinterpret_cast<uint8_t *>(out),
+ reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+ (xmax - x0) * sizeof(float),
+ stride * sizeof(float),
+ (ymax - y0),
+ 1
+ );
+}
+
+template<>
+void Transform<16, 1, false, VLType::None>(
+ float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+ a64_interleave_1x4<4>(
+ reinterpret_cast<uint8_t *>(out),
+ reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+ (xmax - x0) * sizeof(float),
+ stride * sizeof(float),
+ (ymax - y0),
+ 1
+ );
+}
+
+template<>
+void Transform<24, 1, false, VLType::None>(
+ float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+ a64_interleave_1x4<3>(
+ reinterpret_cast<uint8_t *>(out),
+ reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+ (xmax - x0) * sizeof(float),
+ stride * sizeof(float),
+ (ymax - y0),
+ 2
+ );
+}
+
+} // namespace arm_gemm
+
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp
index 923d008bb..ac3cbf943 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -88,8 +88,10 @@ public:
{
if (std::is_same<T, float>::value) {
switch (ci->get_cpu_model()) {
+ case CPUModel::V1:
+ return { 23.64 };
default:
- return { 28.48 };
+ return { 16.89 };
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp
new file mode 100644
index 000000000..98f7fc940
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#ifdef __aarch64__
+
+#include "../std_transforms_fixed.hpp"
+#include "../bfloat.hpp"
+#include "../kernel_weight_format.hpp"
+#include "../performance_parameters.hpp"
+
+#define ARGLIST \
+ unsigned int, const unsigned int *, \
+ IndirectInputArg<float>, \
+ size_t, size_t, \
+ const bfloat16 *, \
+ size_t, \
+ IndirectOutputArg<float>, \
+ const float *, Activation, bool
+
+namespace arm_gemm
+{
+// Actual kernel implementations
+void a64_ffhybrid_fp32bf16fp32_mmla_6x16( ARGLIST );
+
+class cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16
+{
+public:
+ typedef float lhs_operand_type;
+ typedef bfloat16 rhs_operand_type;
+ typedef float result_type;
+
+ typedef void (*kern_type)( ARGLIST );
+
+ /* Kernel blocking parameters */
+ static constexpr unsigned int out_height()
+ {
+ return 6;
+ }
+ static unsigned int stripe_width()
+ {
+ return 4;
+ }
+
+ static KernelWeightFormat kernel_weight_format()
+ {
+ return KernelWeightFormat::VL256_BL64_BF16;
+ }
+
+ static unsigned int out_width()
+ {
+ return 16;
+ }
+
+ static constexpr unsigned int k_unroll()
+ {
+ return 4;
+ }
+
+ static constexpr bool supports_accumulate()
+ {
+ return true;
+ }
+
+ StdTransformsFixed<rhs_operand_type, result_type, 6, 16, 4> transforms = {};
+ template<typename T>
+ static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci)
+ {
+ if (std::is_same<T, float>::value) {
+ switch (ci->get_cpu_model()) {
+ case CPUModel::V1:
+ return { 21.05 };
+ default:
+ return { 15.27 };
+ }
+ }
+
+ return { 1.0 };
+ }
+
+ // Default to the generic kernel
+ kern_type kernel=a64_ffhybrid_fp32bf16fp32_mmla_6x16;
+ cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16(const CPUInfo *)
+ {
+ }
+};
+
+} // namespace arm_gemm
+
+#undef ARGLIST
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp
new file mode 100644
index 000000000..9ab4aa98f
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp
@@ -0,0 +1,3240 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include "arm_gemm.hpp"
+#include "../../utils.hpp"
+#include "../../bfloat.hpp"
+
+#include <cassert>
+#include <limits>
+
+namespace arm_gemm {
+
+void a64_ffhybrid_fp32bf16fp32_mmla_6x16 (
+ unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<float> A_arg,
+ size_t M, size_t N, const bfloat16 *B_ptr, size_t B_stride, IndirectOutputArg<float> output_arg,
+ const float *bias, Activation act, bool accumulate
+)
+{
+ struct KernelArgs {
+ float maxval = static_cast<float>(std::numeric_limits<float>::infinity());
+ float minval = - static_cast<float>(std::numeric_limits<float>::infinity());
+ unsigned int num_strings = {};
+ const unsigned int *string_lengths = {};
+ size_t N = {};
+ const bfloat16 *B_ptr = {};
+ const bfloat16 *cur_B_ptr = {};
+ size_t B_stride = {};
+ size_t output_offset = {};
+ size_t input_initial_col = {};
+ size_t input_offset = {};
+ void *output_ptr = nullptr;
+ const float *bias = nullptr;
+ } ka;
+
+ unsigned long flags=0;
+ void *input_ptr;
+
+ if (output_arg.is_indirect) {
+ ka.output_ptr=(void *)(output_arg.indirect.ptr);
+ ka.output_offset=output_arg.indirect.offset;
+ flags |= 0x4;
+ } else {
+ ka.output_ptr=(void *)(output_arg.direct.base);
+ ka.output_offset=output_arg.direct.stride;
+ }
+
+ if (A_arg.is_indirect) {
+ input_ptr=(void *)(A_arg.indirect.ptr);
+ ka.input_offset=A_arg.indirect.start_row;
+ ka.input_initial_col=A_arg.indirect.start_col;
+ flags |= 0x8;
+ } else {
+ assert(num_strings==1);
+ input_ptr=(void *)(A_arg.direct.base);
+ ka.input_offset=A_arg.direct.stride;
+ }
+ if (accumulate) {
+ flags |= 0x1;
+ }
+ ka.num_strings = num_strings;
+ ka.string_lengths = string_lengths;
+ ka.N = N;
+ ka.B_ptr = B_ptr;
+ ka.bias = bias;
+ ka.B_stride = B_stride;
+ switch(act.type) {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ ka.maxval = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ ka.minval = 0;
+ flags |= 0x2;
+ break;
+ }
+ __asm__ __volatile__(
+ "1:" // Row loop
+ "cmp %x[M], #0x6\n"
+ "bge 181f\n"
+ "cmp %x[M], #0x4\n"
+ "bgt 145f\n"
+ "beq 109f\n"
+ "cmp %x[M], #0x2\n"
+ "bgt 73f\n"
+ "beq 37f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "2:" // Height 1: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 3f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 3f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 3f\n"
+ "mov x11, x12\n"
+ "3:" // Height 1: B setup done
+ "cbz x15, 4f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "b 16f\n"
+ "4:" // Height 1: no bias
+ "tbz %x[flags], #0, 15f\n"
+ "cmp x14, #0x10\n"
+ "bge 13f\n"
+ "tbz x14, #3, 8f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "tbz x14, #2, 6f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "tbz x14, #1, 5f\n"
+ "ldr d16, [x13], #0x8\n"
+ "mov x20, #0x38\n"
+ "tbz x14, #0, 12f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "b 12f\n"
+ "5:" // Height 1: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 12f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "b 12f\n"
+ "6:" // Height 1: Partial accumulate: partial_2_8
+ "tbz x14, #1, 7f\n"
+ "ldr d11, [x13], #0x8\n"
+ "mov x20, #0x28\n"
+ "tbz x14, #0, 12f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "b 12f\n"
+ "7:" // Height 1: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 12f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "b 12f\n"
+ "8:" // Height 1: Partial accumulate: partial_4_0
+ "tbz x14, #2, 10f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "tbz x14, #1, 9f\n"
+ "ldr d10, [x13], #0x8\n"
+ "mov x20, #0x18\n"
+ "tbz x14, #0, 12f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "b 12f\n"
+ "9:" // Height 1: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 12f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "b 12f\n"
+ "10:" // Height 1: Partial accumulate: partial_2_0
+ "tbz x14, #1, 11f\n"
+ "ldr d9, [x13], #0x8\n"
+ "mov x20, #0x8\n"
+ "tbz x14, #0, 12f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "b 12f\n"
+ "11:" // Height 1: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "mov x20, #0x0\n"
+ "12:" // Height 1: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 14f\n"
+ "13:" // Height 1: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "14:" // Height 1: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "b 16f\n"
+ "15:" // Height 1: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "16:" // Height 1: setup done
+ "mov x28, #0x0\n"
+ "17:" // Height 1: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 18f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "cbnz x28, 19f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "b 19f\n"
+ "18:" // Height 1: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "19:" // Height 1: input setup done
+ "cmp x27, #0x4\n"
+ "blt 22f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ldr q6, [x12, #0x0]\n"
+ "cmp x27, #0x8\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 21f\n"
+ "20:" // Height 1: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 20b\n"
+ "21:" // Height 1: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "22:" // Height 1: Multiply loop: Main loop skip
+ "cbz x27, 25f\n"
+ "cbz x27, 25f\n"
+ "tbz x27, #1, 23f\n"
+ "ldr d0, [x26], #0x8\n"
+ "tbz x27, #0, 24f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "b 24f\n"
+ "23:" // Height 1: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "24:" // Height 1: Multiply loop: Ragged operand read: Done
+ "ldr q18, [x12, #0x0]\n"
+ "ldr q17, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x6e52ec08 // bfmmla v8.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e51ec0c // bfmmla v12.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "25:" // Height 1: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 17b\n"
+ "uzp1 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v10.2d, v10.2d, v14.2d\n"
+ "uzp1 v11.2d, v11.2d, v15.2d\n"
+ "tbz %x[flags], #1, 26f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v18.4s }, [x21]\n"
+ "ld1r { v17.4s }, [x20]\n"
+ "fmin v8.4s, v8.4s, v18.4s\n"
+ "fmin v9.4s, v9.4s, v18.4s\n"
+ "fmin v10.4s, v10.4s, v18.4s\n"
+ "fmin v11.4s, v11.4s, v18.4s\n"
+ "fmax v8.4s, v8.4s, v17.4s\n"
+ "fmax v9.4s, v9.4s, v17.4s\n"
+ "fmax v10.4s, v10.4s, v17.4s\n"
+ "fmax v11.4s, v11.4s, v17.4s\n"
+ "26:" // Height 1: No activation
+ "cmp x14, #0x10\n"
+ "bge 35f\n"
+ "tbz x14, #3, 30f\n"
+ "st1 { v8.4s }, [x13], #0x10\n"
+ "st1 { v9.4s }, [x13], #0x10\n"
+ "tbz x14, #2, 28f\n"
+ "st1 { v10.4s }, [x13], #0x10\n"
+ "tbz x14, #1, 27f\n"
+ "str d11, [x13], #0x8\n"
+ "tbz x14, #0, 34f\n"
+ "st1 { v11.s }[2], [x13]\n"
+ "b 34f\n"
+ "27:" // Height 1: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 34f\n"
+ "str s11, [x13, #0x0]\n"
+ "b 34f\n"
+ "28:" // Height 1: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 29f\n"
+ "str d10, [x13], #0x8\n"
+ "tbz x14, #0, 34f\n"
+ "st1 { v10.s }[2], [x13]\n"
+ "b 34f\n"
+ "29:" // Height 1: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 34f\n"
+ "str s10, [x13, #0x0]\n"
+ "b 34f\n"
+ "30:" // Height 1: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 32f\n"
+ "st1 { v8.4s }, [x13], #0x10\n"
+ "tbz x14, #1, 31f\n"
+ "str d9, [x13], #0x8\n"
+ "tbz x14, #0, 34f\n"
+ "st1 { v9.s }[2], [x13]\n"
+ "b 34f\n"
+ "31:" // Height 1: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 34f\n"
+ "str s9, [x13, #0x0]\n"
+ "b 34f\n"
+ "32:" // Height 1: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 33f\n"
+ "str d8, [x13], #0x8\n"
+ "tbz x14, #0, 34f\n"
+ "st1 { v8.s }[2], [x13]\n"
+ "b 34f\n"
+ "33:" // Height 1: Partial direct writeback: partial_1_0
+ "str s8, [x13, #0x0]\n"
+ "34:" // Height 1: Partial direct writeback: Done
+ "b 36f\n"
+ "35:" // Height 1: Full writeback
+ "str q8, [x13, #0x0]\n"
+ "str q9, [x13, #0x10]\n"
+ "str q10, [x13, #0x20]\n"
+ "str q11, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "36:" // Height 1: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 2b\n"
+ "b 218f\n"
+ "37:" // Height 2
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "38:" // Height 2: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 39f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 39f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 39f\n"
+ "mov x11, x12\n"
+ "39:" // Height 2: B setup done
+ "cbz x15, 40f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "b 52f\n"
+ "40:" // Height 2: no bias
+ "tbz %x[flags], #0, 51f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "cmp x14, #0x10\n"
+ "add x26, x13, x20, LSL #2\n"
+ "bge 49f\n"
+ "tbz x14, #3, 44f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "ld1 { v13.4s }, [x26], #0x10\n"
+ "tbz x14, #2, 42f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "ld1 { v14.4s }, [x26], #0x10\n"
+ "tbz x14, #1, 41f\n"
+ "ldr d16, [x13], #0x8\n"
+ "ldr d15, [x26], #0x8\n"
+ "mov x20, #0x38\n"
+ "tbz x14, #0, 48f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "ld1 { v15.s }[2], [x26]\n"
+ "b 48f\n"
+ "41:" // Height 2: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 48f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "ldr s15, [x26, #0x0]\n"
+ "b 48f\n"
+ "42:" // Height 2: Partial accumulate: partial_2_8
+ "tbz x14, #1, 43f\n"
+ "ldr d11, [x13], #0x8\n"
+ "ldr d14, [x26], #0x8\n"
+ "mov x20, #0x28\n"
+ "tbz x14, #0, 48f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "ld1 { v14.s }[2], [x26]\n"
+ "b 48f\n"
+ "43:" // Height 2: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 48f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "ldr s14, [x26, #0x0]\n"
+ "b 48f\n"
+ "44:" // Height 2: Partial accumulate: partial_4_0
+ "tbz x14, #2, 46f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "tbz x14, #1, 45f\n"
+ "ldr d10, [x13], #0x8\n"
+ "ldr d13, [x26], #0x8\n"
+ "mov x20, #0x18\n"
+ "tbz x14, #0, 48f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "ld1 { v13.s }[2], [x26]\n"
+ "b 48f\n"
+ "45:" // Height 2: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 48f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "ldr s13, [x26, #0x0]\n"
+ "b 48f\n"
+ "46:" // Height 2: Partial accumulate: partial_2_0
+ "tbz x14, #1, 47f\n"
+ "ldr d9, [x13], #0x8\n"
+ "ldr d12, [x26], #0x8\n"
+ "mov x20, #0x8\n"
+ "tbz x14, #0, 48f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "ld1 { v12.s }[2], [x26]\n"
+ "b 48f\n"
+ "47:" // Height 2: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "ldr s12, [x26, #0x0]\n"
+ "mov x20, #0x0\n"
+ "48:" // Height 2: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 50f\n"
+ "49:" // Height 2: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "ldr q12, [x26, #0x0]\n"
+ "ldr q13, [x26, #0x10]\n"
+ "ldr q14, [x26, #0x20]\n"
+ "ldr q15, [x26, #0x30]\n"
+ "50:" // Height 2: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "b 52f\n"
+ "51:" // Height 2: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "52:" // Height 2: setup done
+ "mov x28, #0x0\n"
+ "53:" // Height 2: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 54f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "ldr x25, [x20, #0x8]\n"
+ "cbnz x28, 55f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "add x25, x25, x20, LSL #2\n"
+ "b 55f\n"
+ "54:" // Height 2: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "add x25, x26, x21, LSL #2\n"
+ "55:" // Height 2: input setup done
+ "cmp x27, #0x4\n"
+ "blt 58f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ "cmp x27, #0x8\n"
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 57f\n"
+ "56:" // Height 2: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 56b\n"
+ "57:" // Height 2: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "58:" // Height 2: Multiply loop: Main loop skip
+ "cbz x27, 61f\n"
+ "cbz x27, 61f\n"
+ "tbz x27, #1, 59f\n"
+ "ldr d0, [x26], #0x8\n"
+ "ldr d1, [x25], #0x8\n"
+ "tbz x27, #0, 60f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "ld1 { v1.s }[2], [x25]\n"
+ "b 60f\n"
+ "59:" // Height 2: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "ldr s1, [x25, #0x0]\n"
+ "60:" // Height 2: Multiply loop: Ragged operand read: Done
+ "ldr q18, [x12, #0x0]\n"
+ "ldr q17, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x6e52ec08 // bfmmla v8.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e51ec0c // bfmmla v12.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "61:" // Height 2: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 53b\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "uzp1 v6.2d, v8.2d, v12.2d\n"
+ "uzp2 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v12.2d, v9.2d, v13.2d\n"
+ "uzp2 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v13.2d, v10.2d, v14.2d\n"
+ "uzp2 v10.2d, v10.2d, v14.2d\n"
+ "uzp1 v14.2d, v11.2d, v15.2d\n"
+ "uzp2 v11.2d, v11.2d, v15.2d\n"
+ "add x26, x13, x20, LSL #2\n"
+ "tbz %x[flags], #1, 62f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v18.4s }, [x21]\n"
+ "ld1r { v17.4s }, [x20]\n"
+ "fmin v6.4s, v6.4s, v18.4s\n"
+ "fmin v12.4s, v12.4s, v18.4s\n"
+ "fmin v13.4s, v13.4s, v18.4s\n"
+ "fmin v14.4s, v14.4s, v18.4s\n"
+ "fmin v8.4s, v8.4s, v18.4s\n"
+ "fmin v9.4s, v9.4s, v18.4s\n"
+ "fmin v10.4s, v10.4s, v18.4s\n"
+ "fmin v11.4s, v11.4s, v18.4s\n"
+ "fmax v6.4s, v6.4s, v17.4s\n"
+ "fmax v12.4s, v12.4s, v17.4s\n"
+ "fmax v13.4s, v13.4s, v17.4s\n"
+ "fmax v14.4s, v14.4s, v17.4s\n"
+ "fmax v8.4s, v8.4s, v17.4s\n"
+ "fmax v9.4s, v9.4s, v17.4s\n"
+ "fmax v10.4s, v10.4s, v17.4s\n"
+ "fmax v11.4s, v11.4s, v17.4s\n"
+ "62:" // Height 2: No activation
+ "cmp x14, #0x10\n"
+ "bge 71f\n"
+ "tbz x14, #3, 66f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v12.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v9.4s }, [x26], #0x10\n"
+ "tbz x14, #2, 64f\n"
+ "st1 { v13.4s }, [x13], #0x10\n"
+ "st1 { v10.4s }, [x26], #0x10\n"
+ "tbz x14, #1, 63f\n"
+ "str d14, [x13], #0x8\n"
+ "str d11, [x26], #0x8\n"
+ "tbz x14, #0, 70f\n"
+ "st1 { v14.s }[2], [x13]\n"
+ "st1 { v11.s }[2], [x26]\n"
+ "b 70f\n"
+ "63:" // Height 2: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 70f\n"
+ "str s14, [x13, #0x0]\n"
+ "str s11, [x26, #0x0]\n"
+ "b 70f\n"
+ "64:" // Height 2: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 65f\n"
+ "str d13, [x13], #0x8\n"
+ "str d10, [x26], #0x8\n"
+ "tbz x14, #0, 70f\n"
+ "st1 { v13.s }[2], [x13]\n"
+ "st1 { v10.s }[2], [x26]\n"
+ "b 70f\n"
+ "65:" // Height 2: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 70f\n"
+ "str s13, [x13, #0x0]\n"
+ "str s10, [x26, #0x0]\n"
+ "b 70f\n"
+ "66:" // Height 2: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 68f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "tbz x14, #1, 67f\n"
+ "str d12, [x13], #0x8\n"
+ "str d9, [x26], #0x8\n"
+ "tbz x14, #0, 70f\n"
+ "st1 { v12.s }[2], [x13]\n"
+ "st1 { v9.s }[2], [x26]\n"
+ "b 70f\n"
+ "67:" // Height 2: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 70f\n"
+ "str s12, [x13, #0x0]\n"
+ "str s9, [x26, #0x0]\n"
+ "b 70f\n"
+ "68:" // Height 2: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 69f\n"
+ "str d6, [x13], #0x8\n"
+ "str d8, [x26], #0x8\n"
+ "tbz x14, #0, 70f\n"
+ "st1 { v6.s }[2], [x13]\n"
+ "st1 { v8.s }[2], [x26]\n"
+ "b 70f\n"
+ "69:" // Height 2: Partial direct writeback: partial_1_0
+ "str s6, [x13, #0x0]\n"
+ "str s8, [x26, #0x0]\n"
+ "70:" // Height 2: Partial direct writeback: Done
+ "b 72f\n"
+ "71:" // Height 2: Full writeback
+ "str q6, [x13, #0x0]\n"
+ "str q12, [x13, #0x10]\n"
+ "str q13, [x13, #0x20]\n"
+ "str q14, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "str q8, [x26, #0x0]\n"
+ "str q9, [x26, #0x10]\n"
+ "str q10, [x26, #0x20]\n"
+ "str q11, [x26, #0x30]\n"
+ "72:" // Height 2: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 38b\n"
+ "b 218f\n"
+ "73:" // Height 3
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "74:" // Height 3: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 75f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 75f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 75f\n"
+ "mov x11, x12\n"
+ "75:" // Height 3: B setup done
+ "cbz x15, 76f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "mov v16.16b, v8.16b\n"
+ "mov v20.16b, v12.16b\n"
+ "mov v17.16b, v9.16b\n"
+ "mov v21.16b, v13.16b\n"
+ "mov v18.16b, v10.16b\n"
+ "mov v22.16b, v14.16b\n"
+ "mov v19.16b, v11.16b\n"
+ "mov v23.16b, v15.16b\n"
+ "b 88f\n"
+ "76:" // Height 3: no bias
+ "tbz %x[flags], #0, 87f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "cmp x14, #0x10\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "bge 85f\n"
+ "tbz x14, #3, 80f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "ld1 { v13.4s }, [x26], #0x10\n"
+ "ld1 { v18.4s }, [x25], #0x10\n"
+ "tbz x14, #2, 78f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "ld1 { v14.4s }, [x26], #0x10\n"
+ "ld1 { v19.4s }, [x25], #0x10\n"
+ "tbz x14, #1, 77f\n"
+ "ldr d16, [x13], #0x8\n"
+ "ldr d15, [x26], #0x8\n"
+ "mov x20, #0x38\n"
+ "ldr d24, [x25], #0x8\n"
+ "tbz x14, #0, 84f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "ld1 { v15.s }[2], [x26]\n"
+ "ld1 { v24.s }[2], [x25]\n"
+ "b 84f\n"
+ "77:" // Height 3: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 84f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "ldr s15, [x26, #0x0]\n"
+ "ldr s24, [x25, #0x0]\n"
+ "b 84f\n"
+ "78:" // Height 3: Partial accumulate: partial_2_8
+ "tbz x14, #1, 79f\n"
+ "ldr d11, [x13], #0x8\n"
+ "ldr d14, [x26], #0x8\n"
+ "mov x20, #0x28\n"
+ "ldr d19, [x25], #0x8\n"
+ "tbz x14, #0, 84f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "ld1 { v14.s }[2], [x26]\n"
+ "ld1 { v19.s }[2], [x25]\n"
+ "b 84f\n"
+ "79:" // Height 3: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 84f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "ldr s14, [x26, #0x0]\n"
+ "ldr s19, [x25, #0x0]\n"
+ "b 84f\n"
+ "80:" // Height 3: Partial accumulate: partial_4_0
+ "tbz x14, #2, 82f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "tbz x14, #1, 81f\n"
+ "ldr d10, [x13], #0x8\n"
+ "ldr d13, [x26], #0x8\n"
+ "mov x20, #0x18\n"
+ "ldr d18, [x25], #0x8\n"
+ "tbz x14, #0, 84f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "ld1 { v13.s }[2], [x26]\n"
+ "ld1 { v18.s }[2], [x25]\n"
+ "b 84f\n"
+ "81:" // Height 3: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 84f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "ldr s13, [x26, #0x0]\n"
+ "ldr s18, [x25, #0x0]\n"
+ "b 84f\n"
+ "82:" // Height 3: Partial accumulate: partial_2_0
+ "tbz x14, #1, 83f\n"
+ "ldr d9, [x13], #0x8\n"
+ "ldr d12, [x26], #0x8\n"
+ "mov x20, #0x8\n"
+ "ldr d17, [x25], #0x8\n"
+ "tbz x14, #0, 84f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "ld1 { v12.s }[2], [x26]\n"
+ "ld1 { v17.s }[2], [x25]\n"
+ "b 84f\n"
+ "83:" // Height 3: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "ldr s12, [x26, #0x0]\n"
+ "mov x20, #0x0\n"
+ "ldr s17, [x25, #0x0]\n"
+ "84:" // Height 3: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 86f\n"
+ "85:" // Height 3: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "ldr q12, [x26, #0x0]\n"
+ "ldr q13, [x26, #0x10]\n"
+ "ldr q14, [x26, #0x20]\n"
+ "ldr q15, [x26, #0x30]\n"
+ "ldr q17, [x25, #0x0]\n"
+ "ldr q18, [x25, #0x10]\n"
+ "ldr q19, [x25, #0x20]\n"
+ "ldr q24, [x25, #0x30]\n"
+ "86:" // Height 3: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "zip1 v16.2d, v17.2d, v20.2d\n"
+ "zip2 v20.2d, v17.2d, v20.2d\n"
+ "zip1 v17.2d, v18.2d, v21.2d\n"
+ "zip2 v21.2d, v18.2d, v21.2d\n"
+ "zip1 v18.2d, v19.2d, v22.2d\n"
+ "zip2 v22.2d, v19.2d, v22.2d\n"
+ "zip1 v19.2d, v24.2d, v23.2d\n"
+ "zip2 v23.2d, v24.2d, v23.2d\n"
+ "b 88f\n"
+ "87:" // Height 3: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v17.16b, #0x0\n"
+ "movi v18.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "movi v20.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "88:" // Height 3: setup done
+ "mov x28, #0x0\n"
+ "89:" // Height 3: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 90f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "ldr x25, [x20, #0x8]\n"
+ "ldr x24, [x20, #0x10]\n"
+ "cbnz x28, 91f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "add x25, x25, x20, LSL #2\n"
+ "add x24, x24, x20, LSL #2\n"
+ "b 91f\n"
+ "90:" // Height 3: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "add x25, x26, x21, LSL #2\n"
+ "add x24, x25, x21, LSL #2\n"
+ "91:" // Height 3: input setup done
+ "cmp x27, #0x4\n"
+ "blt 94f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ "cmp x27, #0x8\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 93f\n"
+ "92:" // Height 3: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 92b\n"
+ "93:" // Height 3: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "94:" // Height 3: Multiply loop: Main loop skip
+ "cbz x27, 97f\n"
+ "cbz x27, 97f\n"
+ "tbz x27, #1, 95f\n"
+ "ldr d0, [x26], #0x8\n"
+ "ldr d1, [x25], #0x8\n"
+ "ldr d2, [x24], #0x8\n"
+ "tbz x27, #0, 96f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "ld1 { v1.s }[2], [x25]\n"
+ "ld1 { v2.s }[2], [x24]\n"
+ "b 96f\n"
+ "95:" // Height 3: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "ldr s1, [x25, #0x0]\n"
+ "ldr s2, [x24, #0x0]\n"
+ "96:" // Height 3: Multiply loop: Ragged operand read: Done
+ "ldr q26, [x12, #0x0]\n"
+ "ldr q25, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x6e5aec50 // bfmmla v16.4s, v2.8h, v26.8h\n"
+ ".inst 0x6e59ec54 // bfmmla v20.4s, v2.8h, v25.8h\n"
+ ".inst 0x6e5aec08 // bfmmla v8.4s, v0.8h, v26.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e59ec0c // bfmmla v12.4s, v0.8h, v25.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "97:" // Height 3: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 89b\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "uzp1 v6.2d, v8.2d, v12.2d\n"
+ "uzp2 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v12.2d, v9.2d, v13.2d\n"
+ "uzp2 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v13.2d, v10.2d, v14.2d\n"
+ "uzp2 v10.2d, v10.2d, v14.2d\n"
+ "add x26, x13, x20, LSL #2\n"
+ "uzp1 v14.2d, v11.2d, v15.2d\n"
+ "uzp2 v11.2d, v11.2d, v15.2d\n"
+ "add x25, x26, x20, LSL #2\n"
+ "uzp1 v16.2d, v16.2d, v20.2d\n"
+ "uzp1 v17.2d, v17.2d, v21.2d\n"
+ "uzp1 v18.2d, v18.2d, v22.2d\n"
+ "uzp1 v19.2d, v19.2d, v23.2d\n"
+ "tbz %x[flags], #1, 98f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v26.4s }, [x21]\n"
+ "ld1r { v25.4s }, [x20]\n"
+ "fmin v6.4s, v6.4s, v26.4s\n"
+ "fmin v12.4s, v12.4s, v26.4s\n"
+ "fmin v13.4s, v13.4s, v26.4s\n"
+ "fmin v14.4s, v14.4s, v26.4s\n"
+ "fmin v8.4s, v8.4s, v26.4s\n"
+ "fmin v9.4s, v9.4s, v26.4s\n"
+ "fmin v10.4s, v10.4s, v26.4s\n"
+ "fmin v11.4s, v11.4s, v26.4s\n"
+ "fmin v16.4s, v16.4s, v26.4s\n"
+ "fmin v17.4s, v17.4s, v26.4s\n"
+ "fmin v18.4s, v18.4s, v26.4s\n"
+ "fmin v19.4s, v19.4s, v26.4s\n"
+ "fmax v6.4s, v6.4s, v25.4s\n"
+ "fmax v12.4s, v12.4s, v25.4s\n"
+ "fmax v13.4s, v13.4s, v25.4s\n"
+ "fmax v14.4s, v14.4s, v25.4s\n"
+ "fmax v8.4s, v8.4s, v25.4s\n"
+ "fmax v9.4s, v9.4s, v25.4s\n"
+ "fmax v10.4s, v10.4s, v25.4s\n"
+ "fmax v11.4s, v11.4s, v25.4s\n"
+ "fmax v16.4s, v16.4s, v25.4s\n"
+ "fmax v17.4s, v17.4s, v25.4s\n"
+ "fmax v18.4s, v18.4s, v25.4s\n"
+ "fmax v19.4s, v19.4s, v25.4s\n"
+ "98:" // Height 3: No activation
+ "cmp x14, #0x10\n"
+ "bge 107f\n"
+ "tbz x14, #3, 102f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v12.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v9.4s }, [x26], #0x10\n"
+ "st1 { v16.4s }, [x25], #0x10\n"
+ "st1 { v17.4s }, [x25], #0x10\n"
+ "tbz x14, #2, 100f\n"
+ "st1 { v13.4s }, [x13], #0x10\n"
+ "st1 { v10.4s }, [x26], #0x10\n"
+ "st1 { v18.4s }, [x25], #0x10\n"
+ "tbz x14, #1, 99f\n"
+ "str d14, [x13], #0x8\n"
+ "str d11, [x26], #0x8\n"
+ "str d19, [x25], #0x8\n"
+ "tbz x14, #0, 106f\n"
+ "st1 { v14.s }[2], [x13]\n"
+ "st1 { v11.s }[2], [x26]\n"
+ "st1 { v19.s }[2], [x25]\n"
+ "b 106f\n"
+ "99:" // Height 3: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 106f\n"
+ "str s14, [x13, #0x0]\n"
+ "str s11, [x26, #0x0]\n"
+ "str s19, [x25, #0x0]\n"
+ "b 106f\n"
+ "100:" // Height 3: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 101f\n"
+ "str d13, [x13], #0x8\n"
+ "str d10, [x26], #0x8\n"
+ "str d18, [x25], #0x8\n"
+ "tbz x14, #0, 106f\n"
+ "st1 { v13.s }[2], [x13]\n"
+ "st1 { v10.s }[2], [x26]\n"
+ "st1 { v18.s }[2], [x25]\n"
+ "b 106f\n"
+ "101:" // Height 3: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 106f\n"
+ "str s13, [x13, #0x0]\n"
+ "str s10, [x26, #0x0]\n"
+ "str s18, [x25, #0x0]\n"
+ "b 106f\n"
+ "102:" // Height 3: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 104f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v16.4s }, [x25], #0x10\n"
+ "tbz x14, #1, 103f\n"
+ "str d12, [x13], #0x8\n"
+ "str d9, [x26], #0x8\n"
+ "str d17, [x25], #0x8\n"
+ "tbz x14, #0, 106f\n"
+ "st1 { v12.s }[2], [x13]\n"
+ "st1 { v9.s }[2], [x26]\n"
+ "st1 { v17.s }[2], [x25]\n"
+ "b 106f\n"
+ "103:" // Height 3: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 106f\n"
+ "str s12, [x13, #0x0]\n"
+ "str s9, [x26, #0x0]\n"
+ "str s17, [x25, #0x0]\n"
+ "b 106f\n"
+ "104:" // Height 3: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 105f\n"
+ "str d6, [x13], #0x8\n"
+ "str d8, [x26], #0x8\n"
+ "str d16, [x25], #0x8\n"
+ "tbz x14, #0, 106f\n"
+ "st1 { v6.s }[2], [x13]\n"
+ "st1 { v8.s }[2], [x26]\n"
+ "st1 { v16.s }[2], [x25]\n"
+ "b 106f\n"
+ "105:" // Height 3: Partial direct writeback: partial_1_0
+ "str s6, [x13, #0x0]\n"
+ "str s8, [x26, #0x0]\n"
+ "str s16, [x25, #0x0]\n"
+ "106:" // Height 3: Partial direct writeback: Done
+ "b 108f\n"
+ "107:" // Height 3: Full writeback
+ "str q6, [x13, #0x0]\n"
+ "str q12, [x13, #0x10]\n"
+ "str q13, [x13, #0x20]\n"
+ "str q14, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "str q8, [x26, #0x0]\n"
+ "str q9, [x26, #0x10]\n"
+ "str q10, [x26, #0x20]\n"
+ "str q11, [x26, #0x30]\n"
+ "str q16, [x25, #0x0]\n"
+ "str q17, [x25, #0x10]\n"
+ "str q18, [x25, #0x20]\n"
+ "str q19, [x25, #0x30]\n"
+ "108:" // Height 3: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 74b\n"
+ "b 218f\n"
+ "109:" // Height 4
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "110:" // Height 4: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 111f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 111f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 111f\n"
+ "mov x11, x12\n"
+ "111:" // Height 4: B setup done
+ "cbz x15, 112f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "mov v16.16b, v8.16b\n"
+ "mov v20.16b, v12.16b\n"
+ "mov v17.16b, v9.16b\n"
+ "mov v21.16b, v13.16b\n"
+ "mov v18.16b, v10.16b\n"
+ "mov v22.16b, v14.16b\n"
+ "mov v19.16b, v11.16b\n"
+ "mov v23.16b, v15.16b\n"
+ "b 124f\n"
+ "112:" // Height 4: no bias
+ "tbz %x[flags], #0, 123f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "cmp x14, #0x10\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "add x24, x25, x20, LSL #2\n"
+ "bge 121f\n"
+ "tbz x14, #3, 116f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "ld1 { v13.4s }, [x26], #0x10\n"
+ "ld1 { v18.4s }, [x25], #0x10\n"
+ "ld1 { v21.4s }, [x24], #0x10\n"
+ "tbz x14, #2, 114f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "ld1 { v14.4s }, [x26], #0x10\n"
+ "ld1 { v19.4s }, [x25], #0x10\n"
+ "ld1 { v22.4s }, [x24], #0x10\n"
+ "tbz x14, #1, 113f\n"
+ "ldr d16, [x13], #0x8\n"
+ "ldr d15, [x26], #0x8\n"
+ "mov x20, #0x38\n"
+ "ldr d24, [x25], #0x8\n"
+ "ldr d23, [x24], #0x8\n"
+ "tbz x14, #0, 120f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "ld1 { v15.s }[2], [x26]\n"
+ "ld1 { v24.s }[2], [x25]\n"
+ "ld1 { v23.s }[2], [x24]\n"
+ "b 120f\n"
+ "113:" // Height 4: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 120f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "ldr s15, [x26, #0x0]\n"
+ "ldr s24, [x25, #0x0]\n"
+ "ldr s23, [x24, #0x0]\n"
+ "b 120f\n"
+ "114:" // Height 4: Partial accumulate: partial_2_8
+ "tbz x14, #1, 115f\n"
+ "ldr d11, [x13], #0x8\n"
+ "ldr d14, [x26], #0x8\n"
+ "mov x20, #0x28\n"
+ "ldr d19, [x25], #0x8\n"
+ "ldr d22, [x24], #0x8\n"
+ "tbz x14, #0, 120f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "ld1 { v14.s }[2], [x26]\n"
+ "ld1 { v19.s }[2], [x25]\n"
+ "ld1 { v22.s }[2], [x24]\n"
+ "b 120f\n"
+ "115:" // Height 4: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 120f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "ldr s14, [x26, #0x0]\n"
+ "ldr s19, [x25, #0x0]\n"
+ "ldr s22, [x24, #0x0]\n"
+ "b 120f\n"
+ "116:" // Height 4: Partial accumulate: partial_4_0
+ "tbz x14, #2, 118f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "tbz x14, #1, 117f\n"
+ "ldr d10, [x13], #0x8\n"
+ "ldr d13, [x26], #0x8\n"
+ "mov x20, #0x18\n"
+ "ldr d18, [x25], #0x8\n"
+ "ldr d21, [x24], #0x8\n"
+ "tbz x14, #0, 120f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "ld1 { v13.s }[2], [x26]\n"
+ "ld1 { v18.s }[2], [x25]\n"
+ "ld1 { v21.s }[2], [x24]\n"
+ "b 120f\n"
+ "117:" // Height 4: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 120f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "ldr s13, [x26, #0x0]\n"
+ "ldr s18, [x25, #0x0]\n"
+ "ldr s21, [x24, #0x0]\n"
+ "b 120f\n"
+ "118:" // Height 4: Partial accumulate: partial_2_0
+ "tbz x14, #1, 119f\n"
+ "ldr d9, [x13], #0x8\n"
+ "ldr d12, [x26], #0x8\n"
+ "mov x20, #0x8\n"
+ "ldr d17, [x25], #0x8\n"
+ "ldr d20, [x24], #0x8\n"
+ "tbz x14, #0, 120f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "ld1 { v12.s }[2], [x26]\n"
+ "ld1 { v17.s }[2], [x25]\n"
+ "ld1 { v20.s }[2], [x24]\n"
+ "b 120f\n"
+ "119:" // Height 4: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "ldr s12, [x26, #0x0]\n"
+ "mov x20, #0x0\n"
+ "ldr s17, [x25, #0x0]\n"
+ "ldr s20, [x24, #0x0]\n"
+ "120:" // Height 4: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 122f\n"
+ "121:" // Height 4: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "ldr q12, [x26, #0x0]\n"
+ "ldr q13, [x26, #0x10]\n"
+ "ldr q14, [x26, #0x20]\n"
+ "ldr q15, [x26, #0x30]\n"
+ "ldr q17, [x25, #0x0]\n"
+ "ldr q18, [x25, #0x10]\n"
+ "ldr q19, [x25, #0x20]\n"
+ "ldr q24, [x25, #0x30]\n"
+ "ldr q20, [x24, #0x0]\n"
+ "ldr q21, [x24, #0x10]\n"
+ "ldr q22, [x24, #0x20]\n"
+ "ldr q23, [x24, #0x30]\n"
+ "122:" // Height 4: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "zip1 v16.2d, v17.2d, v20.2d\n"
+ "zip2 v20.2d, v17.2d, v20.2d\n"
+ "zip1 v17.2d, v18.2d, v21.2d\n"
+ "zip2 v21.2d, v18.2d, v21.2d\n"
+ "zip1 v18.2d, v19.2d, v22.2d\n"
+ "zip2 v22.2d, v19.2d, v22.2d\n"
+ "zip1 v19.2d, v24.2d, v23.2d\n"
+ "zip2 v23.2d, v24.2d, v23.2d\n"
+ "b 124f\n"
+ "123:" // Height 4: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v17.16b, #0x0\n"
+ "movi v18.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "movi v20.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "124:" // Height 4: setup done
+ "mov x28, #0x0\n"
+ "125:" // Height 4: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 126f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "ldr x25, [x20, #0x8]\n"
+ "ldr x24, [x20, #0x10]\n"
+ "ldr x23, [x20, #0x18]\n"
+ "cbnz x28, 127f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "add x25, x25, x20, LSL #2\n"
+ "add x24, x24, x20, LSL #2\n"
+ "add x23, x23, x20, LSL #2\n"
+ "b 127f\n"
+ "126:" // Height 4: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "add x25, x26, x21, LSL #2\n"
+ "add x24, x25, x21, LSL #2\n"
+ "add x23, x24, x21, LSL #2\n"
+ "127:" // Height 4: input setup done
+ "cmp x27, #0x4\n"
+ "blt 130f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "cmp x27, #0x8\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 129f\n"
+ "128:" // Height 4: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 128b\n"
+ "129:" // Height 4: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "130:" // Height 4: Multiply loop: Main loop skip
+ "cbz x27, 133f\n"
+ "cbz x27, 133f\n"
+ "tbz x27, #1, 131f\n"
+ "ldr d0, [x26], #0x8\n"
+ "ldr d1, [x25], #0x8\n"
+ "ldr d2, [x24], #0x8\n"
+ "ldr d3, [x23], #0x8\n"
+ "tbz x27, #0, 132f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "ld1 { v1.s }[2], [x25]\n"
+ "ld1 { v2.s }[2], [x24]\n"
+ "ld1 { v3.s }[2], [x23]\n"
+ "b 132f\n"
+ "131:" // Height 4: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "ldr s1, [x25, #0x0]\n"
+ "ldr s2, [x24, #0x0]\n"
+ "ldr s3, [x23, #0x0]\n"
+ "132:" // Height 4: Multiply loop: Ragged operand read: Done
+ "ldr q26, [x12, #0x0]\n"
+ "ldr q25, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x6e5aec08 // bfmmla v8.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec50 // bfmmla v16.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e59ec0c // bfmmla v12.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec54 // bfmmla v20.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "133:" // Height 4: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 125b\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "uzp1 v6.2d, v8.2d, v12.2d\n"
+ "uzp2 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v12.2d, v9.2d, v13.2d\n"
+ "uzp2 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v13.2d, v10.2d, v14.2d\n"
+ "uzp2 v10.2d, v10.2d, v14.2d\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "uzp1 v14.2d, v11.2d, v15.2d\n"
+ "uzp2 v11.2d, v11.2d, v15.2d\n"
+ "add x24, x25, x20, LSL #2\n"
+ "uzp1 v15.2d, v16.2d, v20.2d\n"
+ "uzp2 v16.2d, v16.2d, v20.2d\n"
+ "uzp1 v20.2d, v17.2d, v21.2d\n"
+ "uzp2 v17.2d, v17.2d, v21.2d\n"
+ "uzp1 v21.2d, v18.2d, v22.2d\n"
+ "uzp2 v18.2d, v18.2d, v22.2d\n"
+ "uzp1 v22.2d, v19.2d, v23.2d\n"
+ "uzp2 v19.2d, v19.2d, v23.2d\n"
+ "tbz %x[flags], #1, 134f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v26.4s }, [x21]\n"
+ "ld1r { v25.4s }, [x20]\n"
+ "fmin v6.4s, v6.4s, v26.4s\n"
+ "fmin v12.4s, v12.4s, v26.4s\n"
+ "fmin v13.4s, v13.4s, v26.4s\n"
+ "fmin v14.4s, v14.4s, v26.4s\n"
+ "fmin v8.4s, v8.4s, v26.4s\n"
+ "fmin v9.4s, v9.4s, v26.4s\n"
+ "fmin v10.4s, v10.4s, v26.4s\n"
+ "fmin v11.4s, v11.4s, v26.4s\n"
+ "fmin v15.4s, v15.4s, v26.4s\n"
+ "fmin v20.4s, v20.4s, v26.4s\n"
+ "fmin v21.4s, v21.4s, v26.4s\n"
+ "fmin v22.4s, v22.4s, v26.4s\n"
+ "fmin v16.4s, v16.4s, v26.4s\n"
+ "fmin v17.4s, v17.4s, v26.4s\n"
+ "fmin v18.4s, v18.4s, v26.4s\n"
+ "fmin v19.4s, v19.4s, v26.4s\n"
+ "fmax v6.4s, v6.4s, v25.4s\n"
+ "fmax v12.4s, v12.4s, v25.4s\n"
+ "fmax v13.4s, v13.4s, v25.4s\n"
+ "fmax v14.4s, v14.4s, v25.4s\n"
+ "fmax v8.4s, v8.4s, v25.4s\n"
+ "fmax v9.4s, v9.4s, v25.4s\n"
+ "fmax v10.4s, v10.4s, v25.4s\n"
+ "fmax v11.4s, v11.4s, v25.4s\n"
+ "fmax v15.4s, v15.4s, v25.4s\n"
+ "fmax v20.4s, v20.4s, v25.4s\n"
+ "fmax v21.4s, v21.4s, v25.4s\n"
+ "fmax v22.4s, v22.4s, v25.4s\n"
+ "fmax v16.4s, v16.4s, v25.4s\n"
+ "fmax v17.4s, v17.4s, v25.4s\n"
+ "fmax v18.4s, v18.4s, v25.4s\n"
+ "fmax v19.4s, v19.4s, v25.4s\n"
+ "134:" // Height 4: No activation
+ "cmp x14, #0x10\n"
+ "bge 143f\n"
+ "tbz x14, #3, 138f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v12.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v9.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v20.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "st1 { v17.4s }, [x24], #0x10\n"
+ "tbz x14, #2, 136f\n"
+ "st1 { v13.4s }, [x13], #0x10\n"
+ "st1 { v10.4s }, [x26], #0x10\n"
+ "st1 { v21.4s }, [x25], #0x10\n"
+ "st1 { v18.4s }, [x24], #0x10\n"
+ "tbz x14, #1, 135f\n"
+ "str d14, [x13], #0x8\n"
+ "str d11, [x26], #0x8\n"
+ "str d22, [x25], #0x8\n"
+ "str d19, [x24], #0x8\n"
+ "tbz x14, #0, 142f\n"
+ "st1 { v14.s }[2], [x13]\n"
+ "st1 { v11.s }[2], [x26]\n"
+ "st1 { v22.s }[2], [x25]\n"
+ "st1 { v19.s }[2], [x24]\n"
+ "b 142f\n"
+ "135:" // Height 4: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 142f\n"
+ "str s14, [x13, #0x0]\n"
+ "str s11, [x26, #0x0]\n"
+ "str s22, [x25, #0x0]\n"
+ "str s19, [x24, #0x0]\n"
+ "b 142f\n"
+ "136:" // Height 4: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 137f\n"
+ "str d13, [x13], #0x8\n"
+ "str d10, [x26], #0x8\n"
+ "str d21, [x25], #0x8\n"
+ "str d18, [x24], #0x8\n"
+ "tbz x14, #0, 142f\n"
+ "st1 { v13.s }[2], [x13]\n"
+ "st1 { v10.s }[2], [x26]\n"
+ "st1 { v21.s }[2], [x25]\n"
+ "st1 { v18.s }[2], [x24]\n"
+ "b 142f\n"
+ "137:" // Height 4: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 142f\n"
+ "str s13, [x13, #0x0]\n"
+ "str s10, [x26, #0x0]\n"
+ "str s21, [x25, #0x0]\n"
+ "str s18, [x24, #0x0]\n"
+ "b 142f\n"
+ "138:" // Height 4: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 140f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "tbz x14, #1, 139f\n"
+ "str d12, [x13], #0x8\n"
+ "str d9, [x26], #0x8\n"
+ "str d20, [x25], #0x8\n"
+ "str d17, [x24], #0x8\n"
+ "tbz x14, #0, 142f\n"
+ "st1 { v12.s }[2], [x13]\n"
+ "st1 { v9.s }[2], [x26]\n"
+ "st1 { v20.s }[2], [x25]\n"
+ "st1 { v17.s }[2], [x24]\n"
+ "b 142f\n"
+ "139:" // Height 4: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 142f\n"
+ "str s12, [x13, #0x0]\n"
+ "str s9, [x26, #0x0]\n"
+ "str s20, [x25, #0x0]\n"
+ "str s17, [x24, #0x0]\n"
+ "b 142f\n"
+ "140:" // Height 4: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 141f\n"
+ "str d6, [x13], #0x8\n"
+ "str d8, [x26], #0x8\n"
+ "str d15, [x25], #0x8\n"
+ "str d16, [x24], #0x8\n"
+ "tbz x14, #0, 142f\n"
+ "st1 { v6.s }[2], [x13]\n"
+ "st1 { v8.s }[2], [x26]\n"
+ "st1 { v15.s }[2], [x25]\n"
+ "st1 { v16.s }[2], [x24]\n"
+ "b 142f\n"
+ "141:" // Height 4: Partial direct writeback: partial_1_0
+ "str s6, [x13, #0x0]\n"
+ "str s8, [x26, #0x0]\n"
+ "str s15, [x25, #0x0]\n"
+ "str s16, [x24, #0x0]\n"
+ "142:" // Height 4: Partial direct writeback: Done
+ "b 144f\n"
+ "143:" // Height 4: Full writeback
+ "str q6, [x13, #0x0]\n"
+ "str q12, [x13, #0x10]\n"
+ "str q13, [x13, #0x20]\n"
+ "str q14, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "str q8, [x26, #0x0]\n"
+ "str q9, [x26, #0x10]\n"
+ "str q10, [x26, #0x20]\n"
+ "str q11, [x26, #0x30]\n"
+ "str q15, [x25, #0x0]\n"
+ "str q20, [x25, #0x10]\n"
+ "str q21, [x25, #0x20]\n"
+ "str q22, [x25, #0x30]\n"
+ "str q16, [x24, #0x0]\n"
+ "str q17, [x24, #0x10]\n"
+ "str q18, [x24, #0x20]\n"
+ "str q19, [x24, #0x30]\n"
+ "144:" // Height 4: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 110b\n"
+ "b 218f\n"
+ "145:" // Height 5
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "146:" // Height 5: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 147f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 147f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 147f\n"
+ "mov x11, x12\n"
+ "147:" // Height 5: B setup done
+ "cbz x15, 148f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "mov v16.16b, v8.16b\n"
+ "mov v20.16b, v12.16b\n"
+ "mov v17.16b, v9.16b\n"
+ "mov v21.16b, v13.16b\n"
+ "mov v18.16b, v10.16b\n"
+ "mov v22.16b, v14.16b\n"
+ "mov v19.16b, v11.16b\n"
+ "mov v23.16b, v15.16b\n"
+ "mov v24.16b, v8.16b\n"
+ "mov v28.16b, v12.16b\n"
+ "mov v25.16b, v9.16b\n"
+ "mov v29.16b, v13.16b\n"
+ "mov v26.16b, v10.16b\n"
+ "mov v30.16b, v14.16b\n"
+ "mov v27.16b, v11.16b\n"
+ "mov v31.16b, v15.16b\n"
+ "b 160f\n"
+ "148:" // Height 5: no bias
+ "tbz %x[flags], #0, 159f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "cmp x14, #0x10\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "add x24, x25, x20, LSL #2\n"
+ "add x23, x24, x20, LSL #2\n"
+ "bge 157f\n"
+ "tbz x14, #3, 152f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "ld1 { v25.4s }, [x23], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "ld1 { v13.4s }, [x26], #0x10\n"
+ "ld1 { v18.4s }, [x25], #0x10\n"
+ "ld1 { v21.4s }, [x24], #0x10\n"
+ "ld1 { v26.4s }, [x23], #0x10\n"
+ "tbz x14, #2, 150f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "ld1 { v14.4s }, [x26], #0x10\n"
+ "ld1 { v19.4s }, [x25], #0x10\n"
+ "ld1 { v22.4s }, [x24], #0x10\n"
+ "ld1 { v27.4s }, [x23], #0x10\n"
+ "tbz x14, #1, 149f\n"
+ "ldr d16, [x13], #0x8\n"
+ "ldr d15, [x26], #0x8\n"
+ "mov x20, #0x38\n"
+ "ldr d24, [x25], #0x8\n"
+ "ldr d23, [x24], #0x8\n"
+ "ldr d6, [x23], #0x8\n"
+ "tbz x14, #0, 156f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "ld1 { v15.s }[2], [x26]\n"
+ "ld1 { v24.s }[2], [x25]\n"
+ "ld1 { v23.s }[2], [x24]\n"
+ "ld1 { v6.s }[2], [x23]\n"
+ "b 156f\n"
+ "149:" // Height 5: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 156f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "ldr s15, [x26, #0x0]\n"
+ "ldr s24, [x25, #0x0]\n"
+ "ldr s23, [x24, #0x0]\n"
+ "ldr s6, [x23, #0x0]\n"
+ "b 156f\n"
+ "150:" // Height 5: Partial accumulate: partial_2_8
+ "tbz x14, #1, 151f\n"
+ "ldr d11, [x13], #0x8\n"
+ "ldr d14, [x26], #0x8\n"
+ "mov x20, #0x28\n"
+ "ldr d19, [x25], #0x8\n"
+ "ldr d22, [x24], #0x8\n"
+ "ldr d27, [x23], #0x8\n"
+ "tbz x14, #0, 156f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "ld1 { v14.s }[2], [x26]\n"
+ "ld1 { v19.s }[2], [x25]\n"
+ "ld1 { v22.s }[2], [x24]\n"
+ "ld1 { v27.s }[2], [x23]\n"
+ "b 156f\n"
+ "151:" // Height 5: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 156f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "ldr s14, [x26, #0x0]\n"
+ "ldr s19, [x25, #0x0]\n"
+ "ldr s22, [x24, #0x0]\n"
+ "ldr s27, [x23, #0x0]\n"
+ "b 156f\n"
+ "152:" // Height 5: Partial accumulate: partial_4_0
+ "tbz x14, #2, 154f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "ld1 { v25.4s }, [x23], #0x10\n"
+ "tbz x14, #1, 153f\n"
+ "ldr d10, [x13], #0x8\n"
+ "ldr d13, [x26], #0x8\n"
+ "mov x20, #0x18\n"
+ "ldr d18, [x25], #0x8\n"
+ "ldr d21, [x24], #0x8\n"
+ "ldr d26, [x23], #0x8\n"
+ "tbz x14, #0, 156f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "ld1 { v13.s }[2], [x26]\n"
+ "ld1 { v18.s }[2], [x25]\n"
+ "ld1 { v21.s }[2], [x24]\n"
+ "ld1 { v26.s }[2], [x23]\n"
+ "b 156f\n"
+ "153:" // Height 5: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 156f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "ldr s13, [x26, #0x0]\n"
+ "ldr s18, [x25, #0x0]\n"
+ "ldr s21, [x24, #0x0]\n"
+ "ldr s26, [x23, #0x0]\n"
+ "b 156f\n"
+ "154:" // Height 5: Partial accumulate: partial_2_0
+ "tbz x14, #1, 155f\n"
+ "ldr d9, [x13], #0x8\n"
+ "ldr d12, [x26], #0x8\n"
+ "mov x20, #0x8\n"
+ "ldr d17, [x25], #0x8\n"
+ "ldr d20, [x24], #0x8\n"
+ "ldr d25, [x23], #0x8\n"
+ "tbz x14, #0, 156f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "ld1 { v12.s }[2], [x26]\n"
+ "ld1 { v17.s }[2], [x25]\n"
+ "ld1 { v20.s }[2], [x24]\n"
+ "ld1 { v25.s }[2], [x23]\n"
+ "b 156f\n"
+ "155:" // Height 5: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "ldr s12, [x26, #0x0]\n"
+ "mov x20, #0x0\n"
+ "ldr s17, [x25, #0x0]\n"
+ "ldr s20, [x24, #0x0]\n"
+ "ldr s25, [x23, #0x0]\n"
+ "156:" // Height 5: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 158f\n"
+ "157:" // Height 5: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "ldr q12, [x26, #0x0]\n"
+ "ldr q13, [x26, #0x10]\n"
+ "ldr q14, [x26, #0x20]\n"
+ "ldr q15, [x26, #0x30]\n"
+ "ldr q17, [x25, #0x0]\n"
+ "ldr q18, [x25, #0x10]\n"
+ "ldr q19, [x25, #0x20]\n"
+ "ldr q24, [x25, #0x30]\n"
+ "ldr q20, [x24, #0x0]\n"
+ "ldr q21, [x24, #0x10]\n"
+ "ldr q22, [x24, #0x20]\n"
+ "ldr q23, [x24, #0x30]\n"
+ "ldr q25, [x23, #0x0]\n"
+ "ldr q26, [x23, #0x10]\n"
+ "ldr q27, [x23, #0x20]\n"
+ "ldr q6, [x23, #0x30]\n"
+ "158:" // Height 5: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "zip1 v16.2d, v17.2d, v20.2d\n"
+ "zip2 v20.2d, v17.2d, v20.2d\n"
+ "zip1 v17.2d, v18.2d, v21.2d\n"
+ "zip2 v21.2d, v18.2d, v21.2d\n"
+ "zip1 v18.2d, v19.2d, v22.2d\n"
+ "zip2 v22.2d, v19.2d, v22.2d\n"
+ "zip1 v19.2d, v24.2d, v23.2d\n"
+ "zip2 v23.2d, v24.2d, v23.2d\n"
+ "zip1 v24.2d, v25.2d, v28.2d\n"
+ "zip2 v28.2d, v25.2d, v28.2d\n"
+ "zip1 v25.2d, v26.2d, v29.2d\n"
+ "zip2 v29.2d, v26.2d, v29.2d\n"
+ "zip1 v26.2d, v27.2d, v30.2d\n"
+ "zip2 v30.2d, v27.2d, v30.2d\n"
+ "zip1 v27.2d, v6.2d, v31.2d\n"
+ "zip2 v31.2d, v6.2d, v31.2d\n"
+ "b 160f\n"
+ "159:" // Height 5: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v17.16b, #0x0\n"
+ "movi v18.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "movi v20.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "movi v24.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "movi v26.16b, #0x0\n"
+ "movi v27.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "movi v29.16b, #0x0\n"
+ "movi v30.16b, #0x0\n"
+ "movi v31.16b, #0x0\n"
+ "160:" // Height 5: setup done
+ "mov x28, #0x0\n"
+ "161:" // Height 5: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 162f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "ldr x25, [x20, #0x8]\n"
+ "ldr x24, [x20, #0x10]\n"
+ "ldr x23, [x20, #0x18]\n"
+ "ldr x22, [x20, #0x20]\n"
+ "cbnz x28, 163f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "add x25, x25, x20, LSL #2\n"
+ "add x24, x24, x20, LSL #2\n"
+ "add x23, x23, x20, LSL #2\n"
+ "add x22, x22, x20, LSL #2\n"
+ "b 163f\n"
+ "162:" // Height 5: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "add x25, x26, x21, LSL #2\n"
+ "add x24, x25, x21, LSL #2\n"
+ "add x23, x24, x21, LSL #2\n"
+ "add x22, x23, x21, LSL #2\n"
+ "163:" // Height 5: input setup done
+ "cmp x27, #0x4\n"
+ "blt 166f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "cmp x27, #0x8\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ "ld1 { v4.4s }, [x22], #0x10\n"
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 165f\n"
+ "164:" // Height 5: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n"
+ ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ "ldr q6, [x11, #0x0]\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ "ldr q5, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec99 // bfmmla v25.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x10, #0x0]\n"
+ ".inst 0x6e45ec0d // bfmmla v13.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n"
+ ".inst 0x6e45ec9d // bfmmla v29.4s, v4.8h, v5.8h\n"
+ "ldr q5, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec9a // bfmmla v26.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x9, #0x0]\n"
+ ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n"
+ ".inst 0x6e45ec9e // bfmmla v30.4s, v4.8h, v5.8h\n"
+ "ldr q5, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec9b // bfmmla v27.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e45ec0f // bfmmla v15.4s, v0.8h, v5.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ ".inst 0x6e45ec57 // bfmmla v23.4s, v2.8h, v5.8h\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ ".inst 0x6e45ec9f // bfmmla v31.4s, v4.8h, v5.8h\n"
+ "ld1 { v4.4s }, [x22], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 164b\n"
+ "165:" // Height 5: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n"
+ ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ "ldr q3, [x11, #0x0]\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ "ldr q1, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x10, #0x0]\n"
+ ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x9, #0x0]\n"
+ ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n"
+ ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n"
+ "166:" // Height 5: Multiply loop: Main loop skip
+ "cbz x27, 169f\n"
+ "cbz x27, 169f\n"
+ "tbz x27, #1, 167f\n"
+ "ldr d0, [x26], #0x8\n"
+ "ldr d1, [x25], #0x8\n"
+ "ldr d2, [x24], #0x8\n"
+ "ldr d3, [x23], #0x8\n"
+ "ldr d4, [x22], #0x8\n"
+ "tbz x27, #0, 168f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "ld1 { v1.s }[2], [x25]\n"
+ "ld1 { v2.s }[2], [x24]\n"
+ "ld1 { v3.s }[2], [x23]\n"
+ "ld1 { v4.s }[2], [x22]\n"
+ "b 168f\n"
+ "167:" // Height 5: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "ldr s1, [x25, #0x0]\n"
+ "ldr s2, [x24, #0x0]\n"
+ "ldr s3, [x23, #0x0]\n"
+ "ldr s4, [x22, #0x0]\n"
+ "168:" // Height 5: Multiply loop: Ragged operand read: Done
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q5, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n"
+ ".inst 0x6e45ec0c // bfmmla v12.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e45ec9c // bfmmla v28.4s, v4.8h, v5.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ "ldr q3, [x11, #0x0]\n"
+ ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n"
+ "ldr q1, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x10, #0x0]\n"
+ ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x9, #0x0]\n"
+ ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n"
+ ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n"
+ "169:" // Height 5: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 161b\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "uzp1 v6.2d, v8.2d, v12.2d\n"
+ "uzp2 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v12.2d, v9.2d, v13.2d\n"
+ "uzp2 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v13.2d, v10.2d, v14.2d\n"
+ "uzp2 v10.2d, v10.2d, v14.2d\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "add x24, x25, x20, LSL #2\n"
+ "uzp1 v14.2d, v11.2d, v15.2d\n"
+ "uzp2 v11.2d, v11.2d, v15.2d\n"
+ "uzp1 v15.2d, v16.2d, v20.2d\n"
+ "uzp2 v16.2d, v16.2d, v20.2d\n"
+ "add x23, x24, x20, LSL #2\n"
+ "uzp1 v20.2d, v17.2d, v21.2d\n"
+ "uzp2 v17.2d, v17.2d, v21.2d\n"
+ "uzp1 v21.2d, v18.2d, v22.2d\n"
+ "uzp2 v18.2d, v18.2d, v22.2d\n"
+ "uzp1 v22.2d, v19.2d, v23.2d\n"
+ "uzp2 v19.2d, v19.2d, v23.2d\n"
+ "uzp1 v24.2d, v24.2d, v28.2d\n"
+ "uzp1 v25.2d, v25.2d, v29.2d\n"
+ "uzp1 v26.2d, v26.2d, v30.2d\n"
+ "uzp1 v27.2d, v27.2d, v31.2d\n"
+ "tbz %x[flags], #1, 170f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v1.4s }, [x21]\n"
+ "ld1r { v0.4s }, [x20]\n"
+ "fmin v6.4s, v6.4s, v1.4s\n"
+ "fmin v12.4s, v12.4s, v1.4s\n"
+ "fmin v13.4s, v13.4s, v1.4s\n"
+ "fmin v14.4s, v14.4s, v1.4s\n"
+ "fmin v8.4s, v8.4s, v1.4s\n"
+ "fmin v9.4s, v9.4s, v1.4s\n"
+ "fmin v10.4s, v10.4s, v1.4s\n"
+ "fmin v11.4s, v11.4s, v1.4s\n"
+ "fmin v15.4s, v15.4s, v1.4s\n"
+ "fmin v20.4s, v20.4s, v1.4s\n"
+ "fmin v21.4s, v21.4s, v1.4s\n"
+ "fmin v22.4s, v22.4s, v1.4s\n"
+ "fmin v16.4s, v16.4s, v1.4s\n"
+ "fmin v17.4s, v17.4s, v1.4s\n"
+ "fmin v18.4s, v18.4s, v1.4s\n"
+ "fmin v19.4s, v19.4s, v1.4s\n"
+ "fmin v24.4s, v24.4s, v1.4s\n"
+ "fmin v25.4s, v25.4s, v1.4s\n"
+ "fmin v26.4s, v26.4s, v1.4s\n"
+ "fmin v27.4s, v27.4s, v1.4s\n"
+ "fmax v6.4s, v6.4s, v0.4s\n"
+ "fmax v12.4s, v12.4s, v0.4s\n"
+ "fmax v13.4s, v13.4s, v0.4s\n"
+ "fmax v14.4s, v14.4s, v0.4s\n"
+ "fmax v8.4s, v8.4s, v0.4s\n"
+ "fmax v9.4s, v9.4s, v0.4s\n"
+ "fmax v10.4s, v10.4s, v0.4s\n"
+ "fmax v11.4s, v11.4s, v0.4s\n"
+ "fmax v15.4s, v15.4s, v0.4s\n"
+ "fmax v20.4s, v20.4s, v0.4s\n"
+ "fmax v21.4s, v21.4s, v0.4s\n"
+ "fmax v22.4s, v22.4s, v0.4s\n"
+ "fmax v16.4s, v16.4s, v0.4s\n"
+ "fmax v17.4s, v17.4s, v0.4s\n"
+ "fmax v18.4s, v18.4s, v0.4s\n"
+ "fmax v19.4s, v19.4s, v0.4s\n"
+ "fmax v24.4s, v24.4s, v0.4s\n"
+ "fmax v25.4s, v25.4s, v0.4s\n"
+ "fmax v26.4s, v26.4s, v0.4s\n"
+ "fmax v27.4s, v27.4s, v0.4s\n"
+ "170:" // Height 5: No activation
+ "cmp x14, #0x10\n"
+ "bge 179f\n"
+ "tbz x14, #3, 174f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v12.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v9.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v20.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "st1 { v17.4s }, [x24], #0x10\n"
+ "st1 { v24.4s }, [x23], #0x10\n"
+ "st1 { v25.4s }, [x23], #0x10\n"
+ "tbz x14, #2, 172f\n"
+ "st1 { v13.4s }, [x13], #0x10\n"
+ "st1 { v10.4s }, [x26], #0x10\n"
+ "st1 { v21.4s }, [x25], #0x10\n"
+ "st1 { v18.4s }, [x24], #0x10\n"
+ "st1 { v26.4s }, [x23], #0x10\n"
+ "tbz x14, #1, 171f\n"
+ "str d14, [x13], #0x8\n"
+ "str d11, [x26], #0x8\n"
+ "str d22, [x25], #0x8\n"
+ "str d19, [x24], #0x8\n"
+ "str d27, [x23], #0x8\n"
+ "tbz x14, #0, 178f\n"
+ "st1 { v14.s }[2], [x13]\n"
+ "st1 { v11.s }[2], [x26]\n"
+ "st1 { v22.s }[2], [x25]\n"
+ "st1 { v19.s }[2], [x24]\n"
+ "st1 { v27.s }[2], [x23]\n"
+ "b 178f\n"
+ "171:" // Height 5: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 178f\n"
+ "str s14, [x13, #0x0]\n"
+ "str s11, [x26, #0x0]\n"
+ "str s22, [x25, #0x0]\n"
+ "str s19, [x24, #0x0]\n"
+ "str s27, [x23, #0x0]\n"
+ "b 178f\n"
+ "172:" // Height 5: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 173f\n"
+ "str d13, [x13], #0x8\n"
+ "str d10, [x26], #0x8\n"
+ "str d21, [x25], #0x8\n"
+ "str d18, [x24], #0x8\n"
+ "str d26, [x23], #0x8\n"
+ "tbz x14, #0, 178f\n"
+ "st1 { v13.s }[2], [x13]\n"
+ "st1 { v10.s }[2], [x26]\n"
+ "st1 { v21.s }[2], [x25]\n"
+ "st1 { v18.s }[2], [x24]\n"
+ "st1 { v26.s }[2], [x23]\n"
+ "b 178f\n"
+ "173:" // Height 5: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 178f\n"
+ "str s13, [x13, #0x0]\n"
+ "str s10, [x26, #0x0]\n"
+ "str s21, [x25, #0x0]\n"
+ "str s18, [x24, #0x0]\n"
+ "str s26, [x23, #0x0]\n"
+ "b 178f\n"
+ "174:" // Height 5: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 176f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "st1 { v24.4s }, [x23], #0x10\n"
+ "tbz x14, #1, 175f\n"
+ "str d12, [x13], #0x8\n"
+ "str d9, [x26], #0x8\n"
+ "str d20, [x25], #0x8\n"
+ "str d17, [x24], #0x8\n"
+ "str d25, [x23], #0x8\n"
+ "tbz x14, #0, 178f\n"
+ "st1 { v12.s }[2], [x13]\n"
+ "st1 { v9.s }[2], [x26]\n"
+ "st1 { v20.s }[2], [x25]\n"
+ "st1 { v17.s }[2], [x24]\n"
+ "st1 { v25.s }[2], [x23]\n"
+ "b 178f\n"
+ "175:" // Height 5: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 178f\n"
+ "str s12, [x13, #0x0]\n"
+ "str s9, [x26, #0x0]\n"
+ "str s20, [x25, #0x0]\n"
+ "str s17, [x24, #0x0]\n"
+ "str s25, [x23, #0x0]\n"
+ "b 178f\n"
+ "176:" // Height 5: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 177f\n"
+ "str d6, [x13], #0x8\n"
+ "str d8, [x26], #0x8\n"
+ "str d15, [x25], #0x8\n"
+ "str d16, [x24], #0x8\n"
+ "str d24, [x23], #0x8\n"
+ "tbz x14, #0, 178f\n"
+ "st1 { v6.s }[2], [x13]\n"
+ "st1 { v8.s }[2], [x26]\n"
+ "st1 { v15.s }[2], [x25]\n"
+ "st1 { v16.s }[2], [x24]\n"
+ "st1 { v24.s }[2], [x23]\n"
+ "b 178f\n"
+ "177:" // Height 5: Partial direct writeback: partial_1_0
+ "str s6, [x13, #0x0]\n"
+ "str s8, [x26, #0x0]\n"
+ "str s15, [x25, #0x0]\n"
+ "str s16, [x24, #0x0]\n"
+ "str s24, [x23, #0x0]\n"
+ "178:" // Height 5: Partial direct writeback: Done
+ "b 180f\n"
+ "179:" // Height 5: Full writeback
+ "str q6, [x13, #0x0]\n"
+ "str q12, [x13, #0x10]\n"
+ "str q13, [x13, #0x20]\n"
+ "str q14, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "str q8, [x26, #0x0]\n"
+ "str q9, [x26, #0x10]\n"
+ "str q10, [x26, #0x20]\n"
+ "str q11, [x26, #0x30]\n"
+ "str q15, [x25, #0x0]\n"
+ "str q20, [x25, #0x10]\n"
+ "str q21, [x25, #0x20]\n"
+ "str q22, [x25, #0x30]\n"
+ "str q16, [x24, #0x0]\n"
+ "str q17, [x24, #0x10]\n"
+ "str q18, [x24, #0x20]\n"
+ "str q19, [x24, #0x30]\n"
+ "str q24, [x23, #0x0]\n"
+ "str q25, [x23, #0x10]\n"
+ "str q26, [x23, #0x20]\n"
+ "str q27, [x23, #0x30]\n"
+ "180:" // Height 5: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 146b\n"
+ "b 218f\n"
+ "181:" // Height 6
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "mov x21, #0x18\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "madd x21, x20, x21, x13\n"
+ "str x21, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "182:" // Height 6: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 183f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 183f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 183f\n"
+ "mov x11, x12\n"
+ "183:" // Height 6: B setup done
+ "cbz x15, 184f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "mov v16.16b, v8.16b\n"
+ "mov v20.16b, v12.16b\n"
+ "mov v17.16b, v9.16b\n"
+ "mov v21.16b, v13.16b\n"
+ "mov v18.16b, v10.16b\n"
+ "mov v22.16b, v14.16b\n"
+ "mov v19.16b, v11.16b\n"
+ "mov v23.16b, v15.16b\n"
+ "mov v24.16b, v8.16b\n"
+ "mov v28.16b, v12.16b\n"
+ "mov v25.16b, v9.16b\n"
+ "mov v29.16b, v13.16b\n"
+ "mov v26.16b, v10.16b\n"
+ "mov v30.16b, v14.16b\n"
+ "mov v27.16b, v11.16b\n"
+ "mov v31.16b, v15.16b\n"
+ "b 196f\n"
+ "184:" // Height 6: no bias
+ "tbz %x[flags], #0, 195f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "cmp x14, #0x10\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "add x24, x25, x20, LSL #2\n"
+ "add x23, x24, x20, LSL #2\n"
+ "add x22, x23, x20, LSL #2\n"
+ "bge 193f\n"
+ "tbz x14, #3, 188f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "ld1 { v25.4s }, [x23], #0x10\n"
+ "ld1 { v28.4s }, [x22], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "ld1 { v13.4s }, [x26], #0x10\n"
+ "ld1 { v18.4s }, [x25], #0x10\n"
+ "ld1 { v21.4s }, [x24], #0x10\n"
+ "ld1 { v26.4s }, [x23], #0x10\n"
+ "ld1 { v29.4s }, [x22], #0x10\n"
+ "tbz x14, #2, 186f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "ld1 { v14.4s }, [x26], #0x10\n"
+ "ld1 { v19.4s }, [x25], #0x10\n"
+ "ld1 { v22.4s }, [x24], #0x10\n"
+ "ld1 { v27.4s }, [x23], #0x10\n"
+ "ld1 { v30.4s }, [x22], #0x10\n"
+ "tbz x14, #1, 185f\n"
+ "ldr d16, [x13], #0x8\n"
+ "ldr d15, [x26], #0x8\n"
+ "mov x20, #0x38\n"
+ "ldr d24, [x25], #0x8\n"
+ "ldr d23, [x24], #0x8\n"
+ "ldr d6, [x23], #0x8\n"
+ "ldr d31, [x22], #0x8\n"
+ "tbz x14, #0, 192f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "ld1 { v15.s }[2], [x26]\n"
+ "ld1 { v24.s }[2], [x25]\n"
+ "ld1 { v23.s }[2], [x24]\n"
+ "ld1 { v6.s }[2], [x23]\n"
+ "ld1 { v31.s }[2], [x22]\n"
+ "b 192f\n"
+ "185:" // Height 6: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 192f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "ldr s15, [x26, #0x0]\n"
+ "ldr s24, [x25, #0x0]\n"
+ "ldr s23, [x24, #0x0]\n"
+ "ldr s6, [x23, #0x0]\n"
+ "ldr s31, [x22, #0x0]\n"
+ "b 192f\n"
+ "186:" // Height 6: Partial accumulate: partial_2_8
+ "tbz x14, #1, 187f\n"
+ "ldr d11, [x13], #0x8\n"
+ "ldr d14, [x26], #0x8\n"
+ "mov x20, #0x28\n"
+ "ldr d19, [x25], #0x8\n"
+ "ldr d22, [x24], #0x8\n"
+ "ldr d27, [x23], #0x8\n"
+ "ldr d30, [x22], #0x8\n"
+ "tbz x14, #0, 192f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "ld1 { v14.s }[2], [x26]\n"
+ "ld1 { v19.s }[2], [x25]\n"
+ "ld1 { v22.s }[2], [x24]\n"
+ "ld1 { v27.s }[2], [x23]\n"
+ "ld1 { v30.s }[2], [x22]\n"
+ "b 192f\n"
+ "187:" // Height 6: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 192f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "ldr s14, [x26, #0x0]\n"
+ "ldr s19, [x25, #0x0]\n"
+ "ldr s22, [x24, #0x0]\n"
+ "ldr s27, [x23, #0x0]\n"
+ "ldr s30, [x22, #0x0]\n"
+ "b 192f\n"
+ "188:" // Height 6: Partial accumulate: partial_4_0
+ "tbz x14, #2, 190f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "ld1 { v25.4s }, [x23], #0x10\n"
+ "ld1 { v28.4s }, [x22], #0x10\n"
+ "tbz x14, #1, 189f\n"
+ "ldr d10, [x13], #0x8\n"
+ "ldr d13, [x26], #0x8\n"
+ "mov x20, #0x18\n"
+ "ldr d18, [x25], #0x8\n"
+ "ldr d21, [x24], #0x8\n"
+ "ldr d26, [x23], #0x8\n"
+ "ldr d29, [x22], #0x8\n"
+ "tbz x14, #0, 192f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "ld1 { v13.s }[2], [x26]\n"
+ "ld1 { v18.s }[2], [x25]\n"
+ "ld1 { v21.s }[2], [x24]\n"
+ "ld1 { v26.s }[2], [x23]\n"
+ "ld1 { v29.s }[2], [x22]\n"
+ "b 192f\n"
+ "189:" // Height 6: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 192f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "ldr s13, [x26, #0x0]\n"
+ "ldr s18, [x25, #0x0]\n"
+ "ldr s21, [x24, #0x0]\n"
+ "ldr s26, [x23, #0x0]\n"
+ "ldr s29, [x22, #0x0]\n"
+ "b 192f\n"
+ "190:" // Height 6: Partial accumulate: partial_2_0
+ "tbz x14, #1, 191f\n"
+ "ldr d9, [x13], #0x8\n"
+ "ldr d12, [x26], #0x8\n"
+ "mov x20, #0x8\n"
+ "ldr d17, [x25], #0x8\n"
+ "ldr d20, [x24], #0x8\n"
+ "ldr d25, [x23], #0x8\n"
+ "ldr d28, [x22], #0x8\n"
+ "tbz x14, #0, 192f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "ld1 { v12.s }[2], [x26]\n"
+ "ld1 { v17.s }[2], [x25]\n"
+ "ld1 { v20.s }[2], [x24]\n"
+ "ld1 { v25.s }[2], [x23]\n"
+ "ld1 { v28.s }[2], [x22]\n"
+ "b 192f\n"
+ "191:" // Height 6: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "ldr s12, [x26, #0x0]\n"
+ "mov x20, #0x0\n"
+ "ldr s17, [x25, #0x0]\n"
+ "ldr s20, [x24, #0x0]\n"
+ "ldr s25, [x23, #0x0]\n"
+ "ldr s28, [x22, #0x0]\n"
+ "192:" // Height 6: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 194f\n"
+ "193:" // Height 6: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "ldr q12, [x26, #0x0]\n"
+ "ldr q13, [x26, #0x10]\n"
+ "ldr q14, [x26, #0x20]\n"
+ "ldr q15, [x26, #0x30]\n"
+ "ldr q17, [x25, #0x0]\n"
+ "ldr q18, [x25, #0x10]\n"
+ "ldr q19, [x25, #0x20]\n"
+ "ldr q24, [x25, #0x30]\n"
+ "ldr q20, [x24, #0x0]\n"
+ "ldr q21, [x24, #0x10]\n"
+ "ldr q22, [x24, #0x20]\n"
+ "ldr q23, [x24, #0x30]\n"
+ "ldr q25, [x23, #0x0]\n"
+ "ldr q26, [x23, #0x10]\n"
+ "ldr q27, [x23, #0x20]\n"
+ "ldr q6, [x23, #0x30]\n"
+ "ldr q28, [x22, #0x0]\n"
+ "ldr q29, [x22, #0x10]\n"
+ "ldr q30, [x22, #0x20]\n"
+ "ldr q31, [x22, #0x30]\n"
+ "194:" // Height 6: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "zip1 v16.2d, v17.2d, v20.2d\n"
+ "zip2 v20.2d, v17.2d, v20.2d\n"
+ "zip1 v17.2d, v18.2d, v21.2d\n"
+ "zip2 v21.2d, v18.2d, v21.2d\n"
+ "zip1 v18.2d, v19.2d, v22.2d\n"
+ "zip2 v22.2d, v19.2d, v22.2d\n"
+ "zip1 v19.2d, v24.2d, v23.2d\n"
+ "zip2 v23.2d, v24.2d, v23.2d\n"
+ "zip1 v24.2d, v25.2d, v28.2d\n"
+ "zip2 v28.2d, v25.2d, v28.2d\n"
+ "zip1 v25.2d, v26.2d, v29.2d\n"
+ "zip2 v29.2d, v26.2d, v29.2d\n"
+ "zip1 v26.2d, v27.2d, v30.2d\n"
+ "zip2 v30.2d, v27.2d, v30.2d\n"
+ "zip1 v27.2d, v6.2d, v31.2d\n"
+ "zip2 v31.2d, v6.2d, v31.2d\n"
+ "b 196f\n"
+ "195:" // Height 6: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v17.16b, #0x0\n"
+ "movi v18.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "movi v20.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "movi v24.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "movi v26.16b, #0x0\n"
+ "movi v27.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "movi v29.16b, #0x0\n"
+ "movi v30.16b, #0x0\n"
+ "movi v31.16b, #0x0\n"
+ "196:" // Height 6: setup done
+ "mov x28, #0x0\n"
+ "197:" // Height 6: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 198f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "ldr x25, [x20, #0x8]\n"
+ "ldr x24, [x20, #0x10]\n"
+ "ldr x23, [x20, #0x18]\n"
+ "ldr x22, [x20, #0x20]\n"
+ "ldr x21, [x20, #0x28]\n"
+ "cbnz x28, 199f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "add x25, x25, x20, LSL #2\n"
+ "add x24, x24, x20, LSL #2\n"
+ "add x23, x23, x20, LSL #2\n"
+ "add x22, x22, x20, LSL #2\n"
+ "add x21, x21, x20, LSL #2\n"
+ "b 199f\n"
+ "198:" // Height 6: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "add x25, x26, x21, LSL #2\n"
+ "add x24, x25, x21, LSL #2\n"
+ "add x23, x24, x21, LSL #2\n"
+ "add x22, x23, x21, LSL #2\n"
+ "add x21, x22, x21, LSL #2\n"
+ "199:" // Height 6: input setup done
+ "cmp x27, #0x4\n"
+ "blt 202f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "cmp x27, #0x8\n"
+ "ld1 { v4.4s }, [x22], #0x10\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ "ld1 { v5.4s }, [x21], #0x10\n"
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 201f\n"
+ "200:" // Height 6: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n"
+ "ld1 { v5.4s }, [x21], #0x10\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x11, #0x0]\n"
+ ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n"
+ "ldr q7, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec99 // bfmmla v25.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x10, #0x0]\n"
+ ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e47ec9d // bfmmla v29.4s, v4.8h, v7.8h\n"
+ "ldr q7, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec9a // bfmmla v26.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x9, #0x0]\n"
+ ".inst 0x6e47ec0e // bfmmla v14.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e47ec9e // bfmmla v30.4s, v4.8h, v7.8h\n"
+ "ldr q7, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec9b // bfmmla v27.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ ".inst 0x6e47ec57 // bfmmla v23.4s, v2.8h, v7.8h\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ ".inst 0x6e47ec9f // bfmmla v31.4s, v4.8h, v7.8h\n"
+ "ld1 { v4.4s }, [x22], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 200b\n"
+ "201:" // Height 6: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n"
+ "ldr q3, [x11, #0x0]\n"
+ ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n"
+ "ldr q1, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x10, #0x0]\n"
+ ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x9, #0x0]\n"
+ ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n"
+ ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n"
+ "202:" // Height 6: Multiply loop: Main loop skip
+ "cbz x27, 205f\n"
+ "cbz x27, 205f\n"
+ "tbz x27, #1, 203f\n"
+ "ldr d0, [x26], #0x8\n"
+ "ldr d1, [x25], #0x8\n"
+ "ldr d2, [x24], #0x8\n"
+ "ldr d3, [x23], #0x8\n"
+ "ldr d4, [x22], #0x8\n"
+ "ldr d5, [x21], #0x8\n"
+ "tbz x27, #0, 204f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "ld1 { v1.s }[2], [x25]\n"
+ "ld1 { v2.s }[2], [x24]\n"
+ "ld1 { v3.s }[2], [x23]\n"
+ "ld1 { v4.s }[2], [x22]\n"
+ "ld1 { v5.s }[2], [x21]\n"
+ "b 204f\n"
+ "203:" // Height 6: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "ldr s1, [x25, #0x0]\n"
+ "ldr s2, [x24, #0x0]\n"
+ "ldr s3, [x23, #0x0]\n"
+ "ldr s4, [x22, #0x0]\n"
+ "ldr s5, [x21, #0x0]\n"
+ "204:" // Height 6: Multiply loop: Ragged operand read: Done
+ "ldr q7, [x12, #0x0]\n"
+ "ldr q6, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n"
+ ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e47ec98 // bfmmla v24.4s, v4.8h, v7.8h\n"
+ "ldr q3, [x11, #0x0]\n"
+ ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec9c // bfmmla v28.4s, v4.8h, v6.8h\n"
+ "ldr q1, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x10, #0x0]\n"
+ ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x9, #0x0]\n"
+ ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n"
+ ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n"
+ "205:" // Height 6: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 197b\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "uzp1 v6.2d, v8.2d, v12.2d\n"
+ "uzp2 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v12.2d, v9.2d, v13.2d\n"
+ "uzp2 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v13.2d, v10.2d, v14.2d\n"
+ "uzp2 v10.2d, v10.2d, v14.2d\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "add x24, x25, x20, LSL #2\n"
+ "uzp1 v14.2d, v11.2d, v15.2d\n"
+ "uzp2 v11.2d, v11.2d, v15.2d\n"
+ "add x23, x24, x20, LSL #2\n"
+ "uzp1 v15.2d, v16.2d, v20.2d\n"
+ "uzp2 v16.2d, v16.2d, v20.2d\n"
+ "add x22, x23, x20, LSL #2\n"
+ "uzp1 v20.2d, v17.2d, v21.2d\n"
+ "uzp2 v17.2d, v17.2d, v21.2d\n"
+ "uzp1 v21.2d, v18.2d, v22.2d\n"
+ "uzp2 v18.2d, v18.2d, v22.2d\n"
+ "uzp1 v22.2d, v19.2d, v23.2d\n"
+ "uzp2 v19.2d, v19.2d, v23.2d\n"
+ "uzp1 v23.2d, v24.2d, v28.2d\n"
+ "uzp2 v24.2d, v24.2d, v28.2d\n"
+ "uzp1 v28.2d, v25.2d, v29.2d\n"
+ "uzp2 v25.2d, v25.2d, v29.2d\n"
+ "uzp1 v29.2d, v26.2d, v30.2d\n"
+ "uzp2 v26.2d, v26.2d, v30.2d\n"
+ "uzp1 v30.2d, v27.2d, v31.2d\n"
+ "uzp2 v27.2d, v27.2d, v31.2d\n"
+ "tbz %x[flags], #1, 206f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v1.4s }, [x21]\n"
+ "ld1r { v0.4s }, [x20]\n"
+ "fmin v6.4s, v6.4s, v1.4s\n"
+ "fmin v12.4s, v12.4s, v1.4s\n"
+ "fmin v13.4s, v13.4s, v1.4s\n"
+ "fmin v14.4s, v14.4s, v1.4s\n"
+ "fmin v8.4s, v8.4s, v1.4s\n"
+ "fmin v9.4s, v9.4s, v1.4s\n"
+ "fmin v10.4s, v10.4s, v1.4s\n"
+ "fmin v11.4s, v11.4s, v1.4s\n"
+ "fmin v15.4s, v15.4s, v1.4s\n"
+ "fmin v20.4s, v20.4s, v1.4s\n"
+ "fmin v21.4s, v21.4s, v1.4s\n"
+ "fmin v22.4s, v22.4s, v1.4s\n"
+ "fmin v16.4s, v16.4s, v1.4s\n"
+ "fmin v17.4s, v17.4s, v1.4s\n"
+ "fmin v18.4s, v18.4s, v1.4s\n"
+ "fmin v19.4s, v19.4s, v1.4s\n"
+ "fmin v23.4s, v23.4s, v1.4s\n"
+ "fmin v28.4s, v28.4s, v1.4s\n"
+ "fmin v29.4s, v29.4s, v1.4s\n"
+ "fmin v30.4s, v30.4s, v1.4s\n"
+ "fmin v24.4s, v24.4s, v1.4s\n"
+ "fmin v25.4s, v25.4s, v1.4s\n"
+ "fmin v26.4s, v26.4s, v1.4s\n"
+ "fmin v27.4s, v27.4s, v1.4s\n"
+ "fmax v6.4s, v6.4s, v0.4s\n"
+ "fmax v12.4s, v12.4s, v0.4s\n"
+ "fmax v13.4s, v13.4s, v0.4s\n"
+ "fmax v14.4s, v14.4s, v0.4s\n"
+ "fmax v8.4s, v8.4s, v0.4s\n"
+ "fmax v9.4s, v9.4s, v0.4s\n"
+ "fmax v10.4s, v10.4s, v0.4s\n"
+ "fmax v11.4s, v11.4s, v0.4s\n"
+ "fmax v15.4s, v15.4s, v0.4s\n"
+ "fmax v20.4s, v20.4s, v0.4s\n"
+ "fmax v21.4s, v21.4s, v0.4s\n"
+ "fmax v22.4s, v22.4s, v0.4s\n"
+ "fmax v16.4s, v16.4s, v0.4s\n"
+ "fmax v17.4s, v17.4s, v0.4s\n"
+ "fmax v18.4s, v18.4s, v0.4s\n"
+ "fmax v19.4s, v19.4s, v0.4s\n"
+ "fmax v23.4s, v23.4s, v0.4s\n"
+ "fmax v28.4s, v28.4s, v0.4s\n"
+ "fmax v29.4s, v29.4s, v0.4s\n"
+ "fmax v30.4s, v30.4s, v0.4s\n"
+ "fmax v24.4s, v24.4s, v0.4s\n"
+ "fmax v25.4s, v25.4s, v0.4s\n"
+ "fmax v26.4s, v26.4s, v0.4s\n"
+ "fmax v27.4s, v27.4s, v0.4s\n"
+ "206:" // Height 6: No activation
+ "cmp x14, #0x10\n"
+ "bge 215f\n"
+ "tbz x14, #3, 210f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v12.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v9.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v20.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "st1 { v17.4s }, [x24], #0x10\n"
+ "st1 { v23.4s }, [x23], #0x10\n"
+ "st1 { v28.4s }, [x23], #0x10\n"
+ "st1 { v24.4s }, [x22], #0x10\n"
+ "st1 { v25.4s }, [x22], #0x10\n"
+ "tbz x14, #2, 208f\n"
+ "st1 { v13.4s }, [x13], #0x10\n"
+ "st1 { v10.4s }, [x26], #0x10\n"
+ "st1 { v21.4s }, [x25], #0x10\n"
+ "st1 { v18.4s }, [x24], #0x10\n"
+ "st1 { v29.4s }, [x23], #0x10\n"
+ "st1 { v26.4s }, [x22], #0x10\n"
+ "tbz x14, #1, 207f\n"
+ "str d14, [x13], #0x8\n"
+ "str d11, [x26], #0x8\n"
+ "str d22, [x25], #0x8\n"
+ "str d19, [x24], #0x8\n"
+ "str d30, [x23], #0x8\n"
+ "str d27, [x22], #0x8\n"
+ "tbz x14, #0, 214f\n"
+ "st1 { v14.s }[2], [x13]\n"
+ "st1 { v11.s }[2], [x26]\n"
+ "st1 { v22.s }[2], [x25]\n"
+ "st1 { v19.s }[2], [x24]\n"
+ "st1 { v30.s }[2], [x23]\n"
+ "st1 { v27.s }[2], [x22]\n"
+ "b 214f\n"
+ "207:" // Height 6: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 214f\n"
+ "str s14, [x13, #0x0]\n"
+ "str s11, [x26, #0x0]\n"
+ "str s22, [x25, #0x0]\n"
+ "str s19, [x24, #0x0]\n"
+ "str s30, [x23, #0x0]\n"
+ "str s27, [x22, #0x0]\n"
+ "b 214f\n"
+ "208:" // Height 6: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 209f\n"
+ "str d13, [x13], #0x8\n"
+ "str d10, [x26], #0x8\n"
+ "str d21, [x25], #0x8\n"
+ "str d18, [x24], #0x8\n"
+ "str d29, [x23], #0x8\n"
+ "str d26, [x22], #0x8\n"
+ "tbz x14, #0, 214f\n"
+ "st1 { v13.s }[2], [x13]\n"
+ "st1 { v10.s }[2], [x26]\n"
+ "st1 { v21.s }[2], [x25]\n"
+ "st1 { v18.s }[2], [x24]\n"
+ "st1 { v29.s }[2], [x23]\n"
+ "st1 { v26.s }[2], [x22]\n"
+ "b 214f\n"
+ "209:" // Height 6: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 214f\n"
+ "str s13, [x13, #0x0]\n"
+ "str s10, [x26, #0x0]\n"
+ "str s21, [x25, #0x0]\n"
+ "str s18, [x24, #0x0]\n"
+ "str s29, [x23, #0x0]\n"
+ "str s26, [x22, #0x0]\n"
+ "b 214f\n"
+ "210:" // Height 6: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 212f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "st1 { v23.4s }, [x23], #0x10\n"
+ "st1 { v24.4s }, [x22], #0x10\n"
+ "tbz x14, #1, 211f\n"
+ "str d12, [x13], #0x8\n"
+ "str d9, [x26], #0x8\n"
+ "str d20, [x25], #0x8\n"
+ "str d17, [x24], #0x8\n"
+ "str d28, [x23], #0x8\n"
+ "str d25, [x22], #0x8\n"
+ "tbz x14, #0, 214f\n"
+ "st1 { v12.s }[2], [x13]\n"
+ "st1 { v9.s }[2], [x26]\n"
+ "st1 { v20.s }[2], [x25]\n"
+ "st1 { v17.s }[2], [x24]\n"
+ "st1 { v28.s }[2], [x23]\n"
+ "st1 { v25.s }[2], [x22]\n"
+ "b 214f\n"
+ "211:" // Height 6: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 214f\n"
+ "str s12, [x13, #0x0]\n"
+ "str s9, [x26, #0x0]\n"
+ "str s20, [x25, #0x0]\n"
+ "str s17, [x24, #0x0]\n"
+ "str s28, [x23, #0x0]\n"
+ "str s25, [x22, #0x0]\n"
+ "b 214f\n"
+ "212:" // Height 6: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 213f\n"
+ "str d6, [x13], #0x8\n"
+ "str d8, [x26], #0x8\n"
+ "str d15, [x25], #0x8\n"
+ "str d16, [x24], #0x8\n"
+ "str d23, [x23], #0x8\n"
+ "str d24, [x22], #0x8\n"
+ "tbz x14, #0, 214f\n"
+ "st1 { v6.s }[2], [x13]\n"
+ "st1 { v8.s }[2], [x26]\n"
+ "st1 { v15.s }[2], [x25]\n"
+ "st1 { v16.s }[2], [x24]\n"
+ "st1 { v23.s }[2], [x23]\n"
+ "st1 { v24.s }[2], [x22]\n"
+ "b 214f\n"
+ "213:" // Height 6: Partial direct writeback: partial_1_0
+ "str s6, [x13, #0x0]\n"
+ "str s8, [x26, #0x0]\n"
+ "str s15, [x25, #0x0]\n"
+ "str s16, [x24, #0x0]\n"
+ "str s23, [x23, #0x0]\n"
+ "str s24, [x22, #0x0]\n"
+ "214:" // Height 6: Partial direct writeback: Done
+ "b 216f\n"
+ "215:" // Height 6: Full writeback
+ "str q6, [x13, #0x0]\n"
+ "str q12, [x13, #0x10]\n"
+ "str q13, [x13, #0x20]\n"
+ "str q14, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "str q8, [x26, #0x0]\n"
+ "str q9, [x26, #0x10]\n"
+ "str q10, [x26, #0x20]\n"
+ "str q11, [x26, #0x30]\n"
+ "str q15, [x25, #0x0]\n"
+ "str q20, [x25, #0x10]\n"
+ "str q21, [x25, #0x20]\n"
+ "str q22, [x25, #0x30]\n"
+ "str q16, [x24, #0x0]\n"
+ "str q17, [x24, #0x10]\n"
+ "str q18, [x24, #0x20]\n"
+ "str q19, [x24, #0x30]\n"
+ "str q23, [x23, #0x0]\n"
+ "str q28, [x23, #0x10]\n"
+ "str q29, [x23, #0x20]\n"
+ "str q30, [x23, #0x30]\n"
+ "str q24, [x22, #0x0]\n"
+ "str q25, [x22, #0x10]\n"
+ "str q26, [x22, #0x20]\n"
+ "str q27, [x22, #0x30]\n"
+ "216:" // Height 6: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 182b\n"
+ "subs %x[M], %x[M], #0x6\n"
+ "beq 218f\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "tbz %x[flags], #3, 217f\n"
+ "add x21, x21, #0x6\n"
+ "str x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "b 1b\n"
+ "217:" // Update direct input
+ "mov x20, #0x18\n"
+ "madd %x[input_ptr], x20, x21, %x[input_ptr]\n"
+ "b 1b\n"
+ "218:" // Exit
+ : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr)
+ : [args_ptr] "r" (&ka), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_output_ptr] "I" (offsetof(KernelArgs, output_ptr)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths))
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+ );
+}
+
+} // namespace arm_gemm
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp
index cf4d74266..1a8b0fd63 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -88,8 +88,10 @@ public:
if (std::is_same<T, float>::value) {
switch (ci->get_cpu_model()) {
+ case CPUModel::V1:
+ return { 45.25, 4.29, 4.80 };
default:
- return { 38.10, 5.23, 3.15 };
+ return { 29.85, 2.60, 5.49 };
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp
index 171929e65..bce4de74f 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,7 +24,7 @@
#pragma once
#ifdef __aarch64__
-#include "../std_transforms_fixed.hpp"
+#include "../std_transforms_fixed_trB.hpp"
#include "../performance_parameters.hpp"
#define ARGLIST \
@@ -71,7 +71,7 @@ public:
return true;
}
- StdTransformsFixed<rhs_operand_type, result_type, 4, 24, 1> transforms = {};
+ StdTransformsFixedTRB<rhs_operand_type, result_type, 4, 24, 1> transforms = {};
template<typename T>
static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci)
{
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp
index 759729de5..7f85d2dd4 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021, 2023 Arm Limited.
+ * Copyright (c) 2019-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,7 +24,7 @@
#pragma once
#ifdef __aarch64__
-#include "../std_transforms_fixed.hpp"
+#include "../std_transforms_fixed_trB.hpp"
#include "../performance_parameters.hpp"
#define ARGLIST \
@@ -71,7 +71,7 @@ public:
return true;
}
- StdTransformsFixed<rhs_operand_type, result_type, 6, 16, 1> transforms = {};
+ StdTransformsFixedTRB<rhs_operand_type, result_type, 6, 16, 1> transforms = {};
template<typename T>
static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci)
{
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp
index 65ef407f7..19acfe8ae 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,7 +25,7 @@
#ifdef __aarch64__
-#include "../std_transforms_fixed.hpp"
+#include "../std_transforms_fixed_trB.hpp"
#include "../performance_parameters.hpp"
#include "../bfloat.hpp"
@@ -68,7 +68,7 @@ public:
}
// Use the standard fixed size transforms.
- StdTransformsFixed<operand_type, result_type, 8, 12> transforms = {};
+ StdTransformsFixedTRB<operand_type, result_type, 8, 12> transforms = {};
template<typename T>
static PerformanceParameters get_performance_parameters(const CPUInfo *ci) {
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_gemv_fp16fp32fp16_dot_16VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_gemv_fp16fp32fp16_dot_16VL/generic.cpp
index 1067a8548..97c242761 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/sme2_gemv_fp16fp32fp16_dot_16VL/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_gemv_fp16fp32fp16_dot_16VL/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -105,10 +105,18 @@ void sme2_gemv_fp16fp32fp16_dot_16VL (
"mov x20, %x[N]\n"
"mov x22, %x[K]\n"
".inst 0xf8b54af8 // rprfm pldmany, x21, [x23]\n"
- ".inst 0x257467f0 // whilelt p8.h, XZR, x20, VLx4\n"
+ ".inst 0x257447f0 // whilelt p8.h, XZR, x20, VLx2\n"
"cbz x24, 5f\n"
- ".inst 0xa040c700 // ld1w { z0.s-z3.s }, pn9.b/Z, [x24]\n"
- ".inst 0xc0042c00 // mova za.d[x9, #0], { z0.d-z3.d }\n"
+ "ld1h { z20.s }, p1/Z, [x24]\n"
+ "addvl x20, x24, #4\n"
+ "ld1h { z21.s }, p1/Z, [x24, #1, MUL VL]\n"
+ "ld1h { z22.s }, p1/Z, [x24, #2, MUL VL]\n"
+ "ld1h { z23.s }, p1/Z, [x24, #3, MUL VL]\n"
+ "fcvt z20.s, p1/m, z20.h\n"
+ "fcvt z21.s, p1/m, z21.h\n"
+ "fcvt z22.s, p1/m, z22.h\n"
+ "fcvt z23.s, p1/m, z23.h\n"
+ ".inst 0xc0042e80 // mova za.d[x9, #0], { z20.d-z23.d }\n"
"b 6f\n"
"5:" // Width 1: no bias
".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
@@ -117,63 +125,63 @@ void sme2_gemv_fp16fp32fp16_dot_16VL (
"ble 8f\n"
"7:" // Width 1: Multiply loop: Main loop head
"whilelt p0.h, XZR, x22\n"
- ".inst 0xa040a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27]\n"
"addvl x27, x27, #16\n"
- "ld1rqh { z14.h }, p0/Z, [x23]\n"
+ "ld1rqh { z0.h }, p0/Z, [x23]\n"
"sub x22, x22, #0x8\n"
"add x23, x23, #0x10\n"
- ".inst 0xa040a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27]\n"
"addvl x27, x27, #16\n"
"cmp x22, #0x8\n"
- ".inst 0xa040a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc15eb308 // fdot za.s[x9, 0], { z24.h-z27.h }, z14.h[0]\n"
- ".inst 0xa040a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xc150b208 // fdot za.s[x9, 0], { z16.h-z19.h }, z0.h[0]\n"
+ ".inst 0xa040a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc15eb608 // fdot za.s[x9, 0], { z16.h-z19.h }, z14.h[1]\n"
- ".inst 0xc15eb808 // fdot za.s[x9, 0], { z0.h-z3.h }, z14.h[2]\n"
- ".inst 0xc15ebf88 // fdot za.s[x9, 0], { z28.h-z31.h }, z14.h[3]\n"
+ ".inst 0xc150b788 // fdot za.s[x9, 0], { z28.h-z31.h }, z0.h[1]\n"
+ ".inst 0xc150bb08 // fdot za.s[x9, 0], { z24.h-z27.h }, z0.h[2]\n"
+ ".inst 0xc150bc88 // fdot za.s[x9, 0], { z4.h-z7.h }, z0.h[3]\n"
"bgt 7b\n"
"8:" // Width 1: Multiply loop: Single iteration only
"whilelt p0.h, XZR, x22\n"
- ".inst 0xa040a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
- "ld1rqh { z7.h }, p0/Z, [x23]\n"
+ "ld1rqh { z11.h }, p0/Z, [x23]\n"
"add x23, x23, #0x10\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b388 // fdot za.s[x9, 0], { z28.h-z31.h }, z7.h[0]\n"
+ ".inst 0xc15bb308 // fdot za.s[x9, 0], { z24.h-z27.h }, z11.h[0]\n"
"ble 9f\n"
- ".inst 0xa040a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b508 // fdot za.s[x9, 0], { z8.h-z11.h }, z7.h[1]\n"
+ ".inst 0xc15bb708 // fdot za.s[x9, 0], { z24.h-z27.h }, z11.h[1]\n"
"ble 9f\n"
- ".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b988 // fdot za.s[x9, 0], { z12.h-z15.h }, z7.h[2]\n"
+ ".inst 0xc15bbb08 // fdot za.s[x9, 0], { z24.h-z27.h }, z11.h[2]\n"
"ble 9f\n"
- ".inst 0xa040a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157bd08 // fdot za.s[x9, 0], { z8.h-z11.h }, z7.h[3]\n"
+ ".inst 0xc15bbf88 // fdot za.s[x9, 0], { z28.h-z31.h }, z11.h[3]\n"
"9:" // Width 1: Multiply loop: multiply skip
"tbz %x[flags], #1, 10f\n"
".inst 0xc0062c10 // mova { z16.d-z19.d }, za.d[x9, #0]\n"
"add x21, %x[args_ptr], %[offset_min]\n"
"add x20, %x[args_ptr], %[offset_max]\n"
- "ld1rh { z2.h }, p1/Z, [x21]\n"
- "ld1rh { z1.h }, p1/Z, [x20]\n"
- ".inst 0xc120e210 // fcvt z16.h, { z16.s-z17.s }\n"
- ".inst 0xc120e251 // fcvt z17.h, { z18.s-z19.s }\n"
- ".inst 0xc161c050 // fclamp { z16.h-z17.h }, z2.h, z1.h\n"
- ".inst 0xa0602330 // st1h { z16.h-z17.h }, p8, [x25]\n"
+ "ld1rh { z29.h }, p1/Z, [x21]\n"
+ "ld1rh { z20.h }, p1/Z, [x20]\n"
+ ".inst 0xc120e204 // fcvt z4.h, { z16.s-z17.s }\n"
+ ".inst 0xc120e245 // fcvt z5.h, { z18.s-z19.s }\n"
+ ".inst 0xc174c3a4 // fclamp { z4.h-z5.h }, z29.h, z20.h\n"
+ ".inst 0xa0602324 // st1h { z4.h-z5.h }, p8, [x25]\n"
"addvl x25, x25, #2\n"
"b 11f\n"
"10:" // Width 1: No activation
- ".inst 0xc0062c04 // mova { z4.d-z7.d }, za.d[x9, #0]\n"
- ".inst 0xc120e09e // fcvt z30.h, { z4.s-z5.s }\n"
- ".inst 0xc120e0df // fcvt z31.h, { z6.s-z7.s }\n"
- ".inst 0xa060233e // st1h { z30.h-z31.h }, p8, [x25]\n"
+ ".inst 0xc0062c00 // mova { z0.d-z3.d }, za.d[x9, #0]\n"
+ ".inst 0xc120e012 // fcvt z18.h, { z0.s-z1.s }\n"
+ ".inst 0xc120e05a // fcvt z26.h, { z2.s-z3.s }\n"
+ ".inst 0xa1602332 // st1h { z18.h, z26.h }, p8, [x25]\n"
"addvl x25, x25, #2\n"
"11:" // Width 1: Output done
"b 36f\n"
@@ -183,12 +191,27 @@ void sme2_gemv_fp16fp32fp16_dot_16VL (
"sub x20, %x[N], x28\n"
"mov x22, %x[K]\n"
".inst 0xf8b54af8 // rprfm pldmany, x21, [x23]\n"
- ".inst 0x257467f0 // whilelt p8.h, XZR, x20, VLx4\n"
+ ".inst 0x257447f0 // whilelt p8.h, XZR, x20, VLx2\n"
"cbz x24, 13f\n"
- ".inst 0xa040c71c // ld1w { z28.s-z31.s }, pn9.b/Z, [x24]\n"
- ".inst 0xa041c700 // ld1w { z0.s-z3.s }, pn9.b/Z, [x24, #0x4, MUL VL]\n"
- ".inst 0xc0042f80 // mova za.d[x9, #0], { z28.d-z31.d }\n"
- ".inst 0xc0042c01 // mova za.d[x9, #1], { z0.d-z3.d }\n"
+ "ld1h { z12.s }, p1/Z, [x24]\n"
+ "addvl x20, x24, #4\n"
+ "ld1h { z13.s }, p1/Z, [x24, #1, MUL VL]\n"
+ "ld1h { z14.s }, p1/Z, [x24, #2, MUL VL]\n"
+ "ld1h { z15.s }, p1/Z, [x24, #3, MUL VL]\n"
+ "fcvt z12.s, p1/m, z12.h\n"
+ "ld1h { z28.s }, p1/Z, [x24, #4, MUL VL]\n"
+ "fcvt z13.s, p1/m, z13.h\n"
+ "ld1h { z29.s }, p1/Z, [x24, #5, MUL VL]\n"
+ "fcvt z14.s, p1/m, z14.h\n"
+ "ld1h { z30.s }, p1/Z, [x24, #6, MUL VL]\n"
+ "fcvt z15.s, p1/m, z15.h\n"
+ "ld1h { z31.s }, p1/Z, [x24, #7, MUL VL]\n"
+ "fcvt z28.s, p1/m, z28.h\n"
+ "fcvt z29.s, p1/m, z29.h\n"
+ "fcvt z30.s, p1/m, z30.h\n"
+ "fcvt z31.s, p1/m, z31.h\n"
+ ".inst 0xc0042d80 // mova za.d[x9, #0], { z12.d-z15.d }\n"
+ ".inst 0xc0042f81 // mova za.d[x9, #1], { z28.d-z31.d }\n"
"b 14f\n"
"13:" // Width 2: no bias
".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
@@ -197,88 +220,88 @@ void sme2_gemv_fp16fp32fp16_dot_16VL (
"ble 16f\n"
"15:" // Width 2: Multiply loop: Main loop head
"whilelt p0.h, XZR, x22\n"
- ".inst 0xa040a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27]\n"
"sub x22, x22, #0x8\n"
- "ld1rqh { z1.h }, p0/Z, [x23]\n"
+ "ld1rqh { z8.h }, p0/Z, [x23]\n"
"cmp x22, #0x8\n"
"add x23, x23, #0x10\n"
- ".inst 0xa041a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa041a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xa040a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27]\n"
- ".inst 0xc151b308 // fdot za.s[x9, 0], { z24.h-z27.h }, z1.h[0]\n"
+ ".inst 0xa040a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xc158b088 // fdot za.s[x9, 0], { z4.h-z7.h }, z8.h[0]\n"
".inst 0xa041a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc151b289 // fdot za.s[x9, 1], { z20.h-z23.h }, z1.h[0]\n"
- ".inst 0xa040a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27]\n"
- ".inst 0xa041a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xc158b009 // fdot za.s[x9, 1], { z0.h-z3.h }, z8.h[0]\n"
+ ".inst 0xa040a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa041a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc151b488 // fdot za.s[x9, 0], { z4.h-z7.h }, z1.h[1]\n"
".inst 0xa040a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27]\n"
- ".inst 0xc151b589 // fdot za.s[x9, 1], { z12.h-z15.h }, z1.h[1]\n"
- ".inst 0xa041a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xc158b608 // fdot za.s[x9, 0], { z16.h-z19.h }, z8.h[1]\n"
+ ".inst 0xa041a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc151bb88 // fdot za.s[x9, 0], { z28.h-z31.h }, z1.h[2]\n"
- ".inst 0xc151b909 // fdot za.s[x9, 1], { z8.h-z11.h }, z1.h[2]\n"
- ".inst 0xc151bc88 // fdot za.s[x9, 0], { z4.h-z7.h }, z1.h[3]\n"
- ".inst 0xc151bd89 // fdot za.s[x9, 1], { z12.h-z15.h }, z1.h[3]\n"
+ ".inst 0xc158b589 // fdot za.s[x9, 1], { z12.h-z15.h }, z8.h[1]\n"
+ ".inst 0xc158bb08 // fdot za.s[x9, 0], { z24.h-z27.h }, z8.h[2]\n"
+ ".inst 0xc158b809 // fdot za.s[x9, 1], { z0.h-z3.h }, z8.h[2]\n"
+ ".inst 0xc158bc88 // fdot za.s[x9, 0], { z4.h-z7.h }, z8.h[3]\n"
+ ".inst 0xc158bf89 // fdot za.s[x9, 1], { z28.h-z31.h }, z8.h[3]\n"
"bgt 15b\n"
"16:" // Width 2: Multiply loop: Single iteration only
"whilelt p0.h, XZR, x22\n"
- ".inst 0xa040a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
- "ld1rqh { z7.h }, p0/Z, [x23]\n"
+ "ld1rqh { z11.h }, p0/Z, [x23]\n"
"add x23, x23, #0x10\n"
- ".inst 0xa041a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa041a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b108 // fdot za.s[x9, 0], { z8.h-z11.h }, z7.h[0]\n"
- ".inst 0xc157b309 // fdot za.s[x9, 1], { z24.h-z27.h }, z7.h[0]\n"
+ ".inst 0xc15bb088 // fdot za.s[x9, 0], { z4.h-z7.h }, z11.h[0]\n"
+ ".inst 0xc15bb189 // fdot za.s[x9, 1], { z12.h-z15.h }, z11.h[0]\n"
"ble 17f\n"
- ".inst 0xa040a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
".inst 0xa041a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b708 // fdot za.s[x9, 0], { z24.h-z27.h }, z7.h[1]\n"
- ".inst 0xc157b689 // fdot za.s[x9, 1], { z20.h-z23.h }, z7.h[1]\n"
+ ".inst 0xc15bb608 // fdot za.s[x9, 0], { z16.h-z19.h }, z11.h[1]\n"
+ ".inst 0xc15bb689 // fdot za.s[x9, 1], { z20.h-z23.h }, z11.h[1]\n"
"ble 17f\n"
- ".inst 0xa040a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
- ".inst 0xa041a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa041a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b808 // fdot za.s[x9, 0], { z0.h-z3.h }, z7.h[2]\n"
- ".inst 0xc157ba09 // fdot za.s[x9, 1], { z16.h-z19.h }, z7.h[2]\n"
+ ".inst 0xc15bb988 // fdot za.s[x9, 0], { z12.h-z15.h }, z11.h[2]\n"
+ ".inst 0xc15bba89 // fdot za.s[x9, 1], { z20.h-z23.h }, z11.h[2]\n"
"ble 17f\n"
- ".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
- ".inst 0xa041a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa040a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa041a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157bd88 // fdot za.s[x9, 0], { z12.h-z15.h }, z7.h[3]\n"
- ".inst 0xc157bf89 // fdot za.s[x9, 1], { z28.h-z31.h }, z7.h[3]\n"
+ ".inst 0xc15bbc08 // fdot za.s[x9, 0], { z0.h-z3.h }, z11.h[3]\n"
+ ".inst 0xc15bbf09 // fdot za.s[x9, 1], { z24.h-z27.h }, z11.h[3]\n"
"17:" // Width 2: Multiply loop: multiply skip
"tbz %x[flags], #1, 18f\n"
- ".inst 0xc0062c1c // mova { z28.d-z31.d }, za.d[x9, #0]\n"
+ ".inst 0xc0062c0c // mova { z12.d-z15.d }, za.d[x9, #0]\n"
"add x21, %x[args_ptr], %[offset_min]\n"
"add x20, %x[args_ptr], %[offset_max]\n"
- ".inst 0xc0062c30 // mova { z16.d-z19.d }, za.d[x9, #1]\n"
- "ld1rh { z13.h }, p1/Z, [x21]\n"
- "ld1rh { z0.h }, p1/Z, [x20]\n"
- ".inst 0xc120e38e // fcvt z14.h, { z28.s-z29.s }\n"
- ".inst 0xc120e3cf // fcvt z15.h, { z30.s-z31.s }\n"
- ".inst 0xc120e218 // fcvt z24.h, { z16.s-z17.s }\n"
- ".inst 0xc120e259 // fcvt z25.h, { z18.s-z19.s }\n"
- ".inst 0xc160c1ae // fclamp { z14.h-z15.h }, z13.h, z0.h\n"
- ".inst 0xc160c1b8 // fclamp { z24.h-z25.h }, z13.h, z0.h\n"
- ".inst 0xa060272e // st1h { z14.h-z15.h }, pn9.b, [x25]\n"
- ".inst 0xa0612338 // st1h { z24.h-z25.h }, p8, [x25, #0x2, MUL VL]\n"
+ ".inst 0xc0062c3c // mova { z28.d-z31.d }, za.d[x9, #1]\n"
+ "ld1rh { z5.h }, p1/Z, [x21]\n"
+ "ld1rh { z21.h }, p1/Z, [x20]\n"
+ ".inst 0xc120e188 // fcvt z8.h, { z12.s-z13.s }\n"
+ ".inst 0xc120e1c9 // fcvt z9.h, { z14.s-z15.s }\n"
+ ".inst 0xc120e39c // fcvt z28.h, { z28.s-z29.s }\n"
+ ".inst 0xc120e3dd // fcvt z29.h, { z30.s-z31.s }\n"
+ ".inst 0xc175c0a8 // fclamp { z8.h-z9.h }, z5.h, z21.h\n"
+ ".inst 0xc175c0bc // fclamp { z28.h-z29.h }, z5.h, z21.h\n"
+ ".inst 0xa0602728 // st1h { z8.h-z9.h }, pn9.b, [x25]\n"
+ ".inst 0xa061233c // st1h { z28.h-z29.h }, p8, [x25, #0x2, MUL VL]\n"
"addvl x25, x25, #4\n"
"b 19f\n"
"18:" // Width 2: No activation
- ".inst 0xc0062c10 // mova { z16.d-z19.d }, za.d[x9, #0]\n"
- ".inst 0xc0062c38 // mova { z24.d-z27.d }, za.d[x9, #1]\n"
- ".inst 0xc120e205 // fcvt z5.h, { z16.s-z17.s }\n"
- ".inst 0xc120e24d // fcvt z13.h, { z18.s-z19.s }\n"
- ".inst 0xa1602725 // st1h { z5.h, z13.h }, pn9.b, [x25]\n"
- ".inst 0xc120e316 // fcvt z22.h, { z24.s-z25.s }\n"
- ".inst 0xc120e35e // fcvt z30.h, { z26.s-z27.s }\n"
- ".inst 0xa1612336 // st1h { z22.h, z30.h }, p8, [x25, #0x2, MUL VL]\n"
+ ".inst 0xc0062c0c // mova { z12.d-z15.d }, za.d[x9, #0]\n"
+ ".inst 0xc0062c24 // mova { z4.d-z7.d }, za.d[x9, #1]\n"
+ ".inst 0xc120e194 // fcvt z20.h, { z12.s-z13.s }\n"
+ ".inst 0xc120e1dc // fcvt z28.h, { z14.s-z15.s }\n"
+ ".inst 0xa1602734 // st1h { z20.h, z28.h }, pn9.b, [x25]\n"
+ ".inst 0xc120e09a // fcvt z26.h, { z4.s-z5.s }\n"
+ ".inst 0xc120e0db // fcvt z27.h, { z6.s-z7.s }\n"
+ ".inst 0xa061233a // st1h { z26.h-z27.h }, p8, [x25, #0x2, MUL VL]\n"
"addvl x25, x25, #4\n"
"19:" // Width 2: Output done
"b 36f\n"
@@ -289,14 +312,36 @@ void sme2_gemv_fp16fp32fp16_dot_16VL (
"msub x20, x28, x20, %x[N]\n"
"mov x22, %x[K]\n"
".inst 0xf8b54af8 // rprfm pldmany, x21, [x23]\n"
- ".inst 0x257467f0 // whilelt p8.h, XZR, x20, VLx4\n"
+ ".inst 0x257447f0 // whilelt p8.h, XZR, x20, VLx2\n"
"cbz x24, 21f\n"
- ".inst 0xa040c718 // ld1w { z24.s-z27.s }, pn9.b/Z, [x24]\n"
- ".inst 0xa041c70c // ld1w { z12.s-z15.s }, pn9.b/Z, [x24, #0x4, MUL VL]\n"
- ".inst 0xa042c708 // ld1w { z8.s-z11.s }, pn9.b/Z, [x24, #0x8, MUL VL]\n"
- ".inst 0xc0042f00 // mova za.d[x9, #0], { z24.d-z27.d }\n"
- ".inst 0xc0042d81 // mova za.d[x9, #1], { z12.d-z15.d }\n"
- ".inst 0xc0042d02 // mova za.d[x9, #2], { z8.d-z11.d }\n"
+ "addvl x20, x24, #4\n"
+ "ld1h { z16.s }, p1/Z, [x24]\n"
+ "ld1h { z17.s }, p1/Z, [x24, #1, MUL VL]\n"
+ "ld1h { z18.s }, p1/Z, [x24, #2, MUL VL]\n"
+ "ld1h { z19.s }, p1/Z, [x24, #3, MUL VL]\n"
+ "fcvt z16.s, p1/m, z16.h\n"
+ "ld1h { z8.s }, p1/Z, [x24, #4, MUL VL]\n"
+ "fcvt z17.s, p1/m, z17.h\n"
+ "ld1h { z9.s }, p1/Z, [x24, #5, MUL VL]\n"
+ "fcvt z18.s, p1/m, z18.h\n"
+ "ld1h { z10.s }, p1/Z, [x24, #6, MUL VL]\n"
+ "fcvt z19.s, p1/m, z19.h\n"
+ "ld1h { z11.s }, p1/Z, [x24, #7, MUL VL]\n"
+ "fcvt z8.s, p1/m, z8.h\n"
+ "ld1h { z24.s }, p1/Z, [x20]\n"
+ "fcvt z9.s, p1/m, z9.h\n"
+ "ld1h { z25.s }, p1/Z, [x20, #1, MUL VL]\n"
+ "fcvt z10.s, p1/m, z10.h\n"
+ "ld1h { z26.s }, p1/Z, [x20, #2, MUL VL]\n"
+ "fcvt z11.s, p1/m, z11.h\n"
+ ".inst 0xc0042e00 // mova za.d[x9, #0], { z16.d-z19.d }\n"
+ "ld1h { z27.s }, p1/Z, [x20, #3, MUL VL]\n"
+ "fcvt z24.s, p1/m, z24.h\n"
+ "fcvt z25.s, p1/m, z25.h\n"
+ "fcvt z26.s, p1/m, z26.h\n"
+ "fcvt z27.s, p1/m, z27.h\n"
+ ".inst 0xc0042d01 // mova za.d[x9, #1], { z8.d-z11.d }\n"
+ ".inst 0xc0042f02 // mova za.d[x9, #2], { z24.d-z27.d }\n"
"b 22f\n"
"21:" // Width 3: no bias
".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
@@ -305,112 +350,112 @@ void sme2_gemv_fp16fp32fp16_dot_16VL (
"ble 24f\n"
"23:" // Width 3: Multiply loop: Main loop head
"whilelt p0.h, XZR, x22\n"
- ".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27]\n"
"sub x22, x22, #0x8\n"
- "ld1rqh { z7.h }, p0/Z, [x23]\n"
+ "ld1rqh { z6.h }, p0/Z, [x23]\n"
"cmp x22, #0x8\n"
"add x23, x23, #0x10\n"
- ".inst 0xa041a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xa042a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
- "addvl x27, x27, #16\n"
- ".inst 0xc157b188 // fdot za.s[x9, 0], { z12.h-z15.h }, z7.h[0]\n"
- ".inst 0xa040a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27]\n"
- ".inst 0xc157b309 // fdot za.s[x9, 1], { z24.h-z27.h }, z7.h[0]\n"
- ".inst 0xa041a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xc157b38a // fdot za.s[x9, 2], { z28.h-z31.h }, z7.h[0]\n"
- ".inst 0xa042a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ ".inst 0xa041a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa042a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b408 // fdot za.s[x9, 0], { z0.h-z3.h }, z7.h[1]\n"
+ ".inst 0xc156b288 // fdot za.s[x9, 0], { z20.h-z23.h }, z6.h[0]\n"
".inst 0xa040a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27]\n"
- ".inst 0xc157b609 // fdot za.s[x9, 1], { z16.h-z19.h }, z7.h[1]\n"
- ".inst 0xa041a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xc157b58a // fdot za.s[x9, 2], { z12.h-z15.h }, z7.h[1]\n"
+ ".inst 0xc156b189 // fdot za.s[x9, 1], { z12.h-z15.h }, z6.h[0]\n"
+ ".inst 0xa041a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xc156b00a // fdot za.s[x9, 2], { z0.h-z3.h }, z6.h[0]\n"
".inst 0xa042a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157bb88 // fdot za.s[x9, 0], { z28.h-z31.h }, z7.h[2]\n"
- ".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
- ".inst 0xc157b909 // fdot za.s[x9, 1], { z8.h-z11.h }, z7.h[2]\n"
+ ".inst 0xa040a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xc156b788 // fdot za.s[x9, 0], { z28.h-z31.h }, z6.h[1]\n"
+ ".inst 0xa041a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xc156b589 // fdot za.s[x9, 1], { z12.h-z15.h }, z6.h[1]\n"
+ ".inst 0xa042a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ "addvl x27, x27, #16\n"
+ ".inst 0xc156b40a // fdot za.s[x9, 2], { z0.h-z3.h }, z6.h[1]\n"
+ ".inst 0xa040a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27]\n"
".inst 0xa041a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xc157b80a // fdot za.s[x9, 2], { z0.h-z3.h }, z7.h[2]\n"
+ ".inst 0xc156ba88 // fdot za.s[x9, 0], { z20.h-z23.h }, z6.h[2]\n"
".inst 0xa042a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157bd88 // fdot za.s[x9, 0], { z12.h-z15.h }, z7.h[3]\n"
- ".inst 0xc157be09 // fdot za.s[x9, 1], { z16.h-z19.h }, z7.h[3]\n"
- ".inst 0xc157be8a // fdot za.s[x9, 2], { z20.h-z23.h }, z7.h[3]\n"
+ ".inst 0xc156b909 // fdot za.s[x9, 1], { z8.h-z11.h }, z6.h[2]\n"
+ ".inst 0xc156b98a // fdot za.s[x9, 2], { z12.h-z15.h }, z6.h[2]\n"
+ ".inst 0xc156bc08 // fdot za.s[x9, 0], { z0.h-z3.h }, z6.h[3]\n"
+ ".inst 0xc156be09 // fdot za.s[x9, 1], { z16.h-z19.h }, z6.h[3]\n"
+ ".inst 0xc156be8a // fdot za.s[x9, 2], { z20.h-z23.h }, z6.h[3]\n"
"bgt 23b\n"
"24:" // Width 3: Multiply loop: Single iteration only
"whilelt p0.h, XZR, x22\n"
".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
- "ld1rqh { z7.h }, p0/Z, [x23]\n"
+ "ld1rqh { z11.h }, p0/Z, [x23]\n"
"add x23, x23, #0x10\n"
- ".inst 0xa041a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xa042a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ ".inst 0xa041a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa042a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b188 // fdot za.s[x9, 0], { z12.h-z15.h }, z7.h[0]\n"
- ".inst 0xc157b209 // fdot za.s[x9, 1], { z16.h-z19.h }, z7.h[0]\n"
- ".inst 0xc157b10a // fdot za.s[x9, 2], { z8.h-z11.h }, z7.h[0]\n"
+ ".inst 0xc15bb188 // fdot za.s[x9, 0], { z12.h-z15.h }, z11.h[0]\n"
+ ".inst 0xc15bb009 // fdot za.s[x9, 1], { z0.h-z3.h }, z11.h[0]\n"
+ ".inst 0xc15bb20a // fdot za.s[x9, 2], { z16.h-z19.h }, z11.h[0]\n"
"ble 25f\n"
- ".inst 0xa040a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
- ".inst 0xa041a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa041a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
".inst 0xa042a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b608 // fdot za.s[x9, 0], { z16.h-z19.h }, z7.h[1]\n"
- ".inst 0xc157b409 // fdot za.s[x9, 1], { z0.h-z3.h }, z7.h[1]\n"
- ".inst 0xc157b68a // fdot za.s[x9, 2], { z20.h-z23.h }, z7.h[1]\n"
+ ".inst 0xc15bb588 // fdot za.s[x9, 0], { z12.h-z15.h }, z11.h[1]\n"
+ ".inst 0xc15bb609 // fdot za.s[x9, 1], { z16.h-z19.h }, z11.h[1]\n"
+ ".inst 0xc15bb68a // fdot za.s[x9, 2], { z20.h-z23.h }, z11.h[1]\n"
"ble 25f\n"
- ".inst 0xa040a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
- ".inst 0xa041a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xa042a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ ".inst 0xa041a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa042a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157bb08 // fdot za.s[x9, 0], { z24.h-z27.h }, z7.h[2]\n"
- ".inst 0xc157ba89 // fdot za.s[x9, 1], { z20.h-z23.h }, z7.h[2]\n"
- ".inst 0xc157b90a // fdot za.s[x9, 2], { z8.h-z11.h }, z7.h[2]\n"
+ ".inst 0xc15bb888 // fdot za.s[x9, 0], { z4.h-z7.h }, z11.h[2]\n"
+ ".inst 0xc15bbb89 // fdot za.s[x9, 1], { z28.h-z31.h }, z11.h[2]\n"
+ ".inst 0xc15bba8a // fdot za.s[x9, 2], { z20.h-z23.h }, z11.h[2]\n"
"ble 25f\n"
- ".inst 0xa040a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27]\n"
".inst 0xa041a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
".inst 0xa042a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157bc08 // fdot za.s[x9, 0], { z0.h-z3.h }, z7.h[3]\n"
- ".inst 0xc157bf89 // fdot za.s[x9, 1], { z28.h-z31.h }, z7.h[3]\n"
- ".inst 0xc157bd8a // fdot za.s[x9, 2], { z12.h-z15.h }, z7.h[3]\n"
+ ".inst 0xc15bbc88 // fdot za.s[x9, 0], { z4.h-z7.h }, z11.h[3]\n"
+ ".inst 0xc15bbf89 // fdot za.s[x9, 1], { z28.h-z31.h }, z11.h[3]\n"
+ ".inst 0xc15bbd8a // fdot za.s[x9, 2], { z12.h-z15.h }, z11.h[3]\n"
"25:" // Width 3: Multiply loop: multiply skip
"tbz %x[flags], #1, 26f\n"
".inst 0xc0062c0c // mova { z12.d-z15.d }, za.d[x9, #0]\n"
"add x21, %x[args_ptr], %[offset_min]\n"
"add x20, %x[args_ptr], %[offset_max]\n"
- ".inst 0xc0062c24 // mova { z4.d-z7.d }, za.d[x9, #1]\n"
+ ".inst 0xc0062c20 // mova { z0.d-z3.d }, za.d[x9, #1]\n"
"ld1rh { z17.h }, p1/Z, [x21]\n"
- ".inst 0xc0062c40 // mova { z0.d-z3.d }, za.d[x9, #2]\n"
+ ".inst 0xc0062c44 // mova { z4.d-z7.d }, za.d[x9, #2]\n"
"ld1rh { z16.h }, p1/Z, [x20]\n"
".inst 0xc120e18c // fcvt z12.h, { z12.s-z13.s }\n"
".inst 0xc120e1cd // fcvt z13.h, { z14.s-z15.s }\n"
- ".inst 0xc120e092 // fcvt z18.h, { z4.s-z5.s }\n"
- ".inst 0xc120e0d3 // fcvt z19.h, { z6.s-z7.s }\n"
- ".inst 0xc170c22c // fclamp { z12.h-z13.h }, z17.h, z16.h\n"
- ".inst 0xc170c232 // fclamp { z18.h-z19.h }, z17.h, z16.h\n"
".inst 0xc120e00e // fcvt z14.h, { z0.s-z1.s }\n"
".inst 0xc120e04f // fcvt z15.h, { z2.s-z3.s }\n"
+ ".inst 0xc170c22c // fclamp { z12.h-z13.h }, z17.h, z16.h\n"
+ ".inst 0xc120e092 // fcvt z18.h, { z4.s-z5.s }\n"
+ ".inst 0xc120e0d3 // fcvt z19.h, { z6.s-z7.s }\n"
".inst 0xc170c22e // fclamp { z14.h-z15.h }, z17.h, z16.h\n"
+ ".inst 0xc170c232 // fclamp { z18.h-z19.h }, z17.h, z16.h\n"
".inst 0xa060272c // st1h { z12.h-z13.h }, pn9.b, [x25]\n"
- ".inst 0xa0612732 // st1h { z18.h-z19.h }, pn9.b, [x25, #0x2, MUL VL]\n"
- ".inst 0xa062232e // st1h { z14.h-z15.h }, p8, [x25, #0x4, MUL VL]\n"
+ ".inst 0xa061272e // st1h { z14.h-z15.h }, pn9.b, [x25, #0x2, MUL VL]\n"
+ ".inst 0xa0622332 // st1h { z18.h-z19.h }, p8, [x25, #0x4, MUL VL]\n"
"addvl x25, x25, #6\n"
"b 27f\n"
"26:" // Width 3: No activation
- ".inst 0xc0062c04 // mova { z4.d-z7.d }, za.d[x9, #0]\n"
- ".inst 0xc0062c20 // mova { z0.d-z3.d }, za.d[x9, #1]\n"
- ".inst 0xc0062c48 // mova { z8.d-z11.d }, za.d[x9, #2]\n"
- ".inst 0xc120e091 // fcvt z17.h, { z4.s-z5.s }\n"
- ".inst 0xc120e0d9 // fcvt z25.h, { z6.s-z7.s }\n"
+ ".inst 0xc0062c18 // mova { z24.d-z27.d }, za.d[x9, #0]\n"
+ ".inst 0xc0062c28 // mova { z8.d-z11.d }, za.d[x9, #1]\n"
+ ".inst 0xc0062c4c // mova { z12.d-z15.d }, za.d[x9, #2]\n"
+ ".inst 0xc120e311 // fcvt z17.h, { z24.s-z25.s }\n"
+ ".inst 0xc120e359 // fcvt z25.h, { z26.s-z27.s }\n"
".inst 0xa1602731 // st1h { z17.h, z25.h }, pn9.b, [x25]\n"
- ".inst 0xc120e012 // fcvt z18.h, { z0.s-z1.s }\n"
- ".inst 0xc120e053 // fcvt z19.h, { z2.s-z3.s }\n"
+ ".inst 0xc120e112 // fcvt z18.h, { z8.s-z9.s }\n"
+ ".inst 0xc120e153 // fcvt z19.h, { z10.s-z11.s }\n"
".inst 0xa0612732 // st1h { z18.h-z19.h }, pn9.b, [x25, #0x2, MUL VL]\n"
- ".inst 0xc120e111 // fcvt z17.h, { z8.s-z9.s }\n"
- ".inst 0xc120e159 // fcvt z25.h, { z10.s-z11.s }\n"
+ ".inst 0xc120e191 // fcvt z17.h, { z12.s-z13.s }\n"
+ ".inst 0xc120e1d9 // fcvt z25.h, { z14.s-z15.s }\n"
".inst 0xa1622331 // st1h { z17.h, z25.h }, p8, [x25, #0x4, MUL VL]\n"
"addvl x25, x25, #6\n"
"27:" // Width 3: Output done
@@ -422,16 +467,45 @@ void sme2_gemv_fp16fp32fp16_dot_16VL (
"msub x20, x28, x20, %x[N]\n"
"mov x22, %x[K]\n"
".inst 0xf8b54af8 // rprfm pldmany, x21, [x23]\n"
- ".inst 0x257467f0 // whilelt p8.h, XZR, x20, VLx4\n"
+ ".inst 0x257447f0 // whilelt p8.h, XZR, x20, VLx2\n"
"cbz x24, 29f\n"
- ".inst 0xa040c704 // ld1w { z4.s-z7.s }, pn9.b/Z, [x24]\n"
- ".inst 0xa041c710 // ld1w { z16.s-z19.s }, pn9.b/Z, [x24, #0x4, MUL VL]\n"
- ".inst 0xa042c708 // ld1w { z8.s-z11.s }, pn9.b/Z, [x24, #0x8, MUL VL]\n"
- ".inst 0xa043c71c // ld1w { z28.s-z31.s }, pn9.b/Z, [x24, #0xc, MUL VL]\n"
- ".inst 0xc0042c80 // mova za.d[x9, #0], { z4.d-z7.d }\n"
- "addvl x24, x24, #16\n"
- ".inst 0xc0042e01 // mova za.d[x9, #1], { z16.d-z19.d }\n"
- ".inst 0xc0042d02 // mova za.d[x9, #2], { z8.d-z11.d }\n"
+ "addvl x20, x24, #4\n"
+ "ld1h { z28.s }, p1/Z, [x24]\n"
+ "ld1h { z29.s }, p1/Z, [x24, #1, MUL VL]\n"
+ "ld1h { z30.s }, p1/Z, [x24, #2, MUL VL]\n"
+ "ld1h { z31.s }, p1/Z, [x24, #3, MUL VL]\n"
+ "fcvt z28.s, p1/m, z28.h\n"
+ "ld1h { z8.s }, p1/Z, [x24, #4, MUL VL]\n"
+ "fcvt z29.s, p1/m, z29.h\n"
+ "ld1h { z9.s }, p1/Z, [x24, #5, MUL VL]\n"
+ "fcvt z30.s, p1/m, z30.h\n"
+ "ld1h { z10.s }, p1/Z, [x24, #6, MUL VL]\n"
+ "fcvt z31.s, p1/m, z31.h\n"
+ "ld1h { z11.s }, p1/Z, [x24, #7, MUL VL]\n"
+ "fcvt z8.s, p1/m, z8.h\n"
+ "addvl x24, x24, #8\n"
+ "ld1h { z0.s }, p1/Z, [x20]\n"
+ "fcvt z9.s, p1/m, z9.h\n"
+ "ld1h { z1.s }, p1/Z, [x20, #1, MUL VL]\n"
+ "fcvt z10.s, p1/m, z10.h\n"
+ "ld1h { z2.s }, p1/Z, [x20, #2, MUL VL]\n"
+ "fcvt z11.s, p1/m, z11.h\n"
+ ".inst 0xc0042f80 // mova za.d[x9, #0], { z28.d-z31.d }\n"
+ "ld1h { z3.s }, p1/Z, [x20, #3, MUL VL]\n"
+ "fcvt z0.s, p1/m, z0.h\n"
+ "ld1h { z28.s }, p1/Z, [x20, #4, MUL VL]\n"
+ "fcvt z1.s, p1/m, z1.h\n"
+ "ld1h { z29.s }, p1/Z, [x20, #5, MUL VL]\n"
+ "fcvt z2.s, p1/m, z2.h\n"
+ "ld1h { z30.s }, p1/Z, [x20, #6, MUL VL]\n"
+ "fcvt z3.s, p1/m, z3.h\n"
+ ".inst 0xc0042d01 // mova za.d[x9, #1], { z8.d-z11.d }\n"
+ "ld1h { z31.s }, p1/Z, [x20, #7, MUL VL]\n"
+ "fcvt z28.s, p1/m, z28.h\n"
+ "fcvt z29.s, p1/m, z29.h\n"
+ "fcvt z30.s, p1/m, z30.h\n"
+ "fcvt z31.s, p1/m, z31.h\n"
+ ".inst 0xc0042c02 // mova za.d[x9, #2], { z0.d-z3.d }\n"
".inst 0xc0042f83 // mova za.d[x9, #3], { z28.d-z31.d }\n"
"b 30f\n"
"29:" // Width 4: no bias
@@ -441,93 +515,93 @@ void sme2_gemv_fp16fp32fp16_dot_16VL (
"ble 32f\n"
"31:" // Width 4: Multiply loop: Main loop head
"whilelt p0.h, XZR, x22\n"
- ".inst 0xa040a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27]\n"
"sub x22, x22, #0x8\n"
- "ld1rqh { z7.h }, p0/Z, [x23]\n"
+ "ld1rqh { z3.h }, p0/Z, [x23]\n"
"cmp x22, #0x8\n"
"add x23, x23, #0x10\n"
- ".inst 0xa041a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xa042a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
- ".inst 0xa043a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
- ".inst 0xc157b288 // fdot za.s[x9, 0], { z20.h-z23.h }, z7.h[0]\n"
+ ".inst 0xa041a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa042a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ ".inst 0xa043a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
+ ".inst 0xc153b108 // fdot za.s[x9, 0], { z8.h-z11.h }, z3.h[0]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b189 // fdot za.s[x9, 1], { z12.h-z15.h }, z7.h[0]\n"
+ ".inst 0xc153b389 // fdot za.s[x9, 1], { z28.h-z31.h }, z3.h[0]\n"
".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
- ".inst 0xc157b38a // fdot za.s[x9, 2], { z28.h-z31.h }, z7.h[0]\n"
- ".inst 0xa041a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xc157b10b // fdot za.s[x9, 3], { z8.h-z11.h }, z7.h[0]\n"
- ".inst 0xa042a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
- ".inst 0xa043a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
- ".inst 0xc157b588 // fdot za.s[x9, 0], { z12.h-z15.h }, z7.h[1]\n"
- "addvl x27, x27, #16\n"
- ".inst 0xc157b409 // fdot za.s[x9, 1], { z0.h-z3.h }, z7.h[1]\n"
- ".inst 0xa040a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27]\n"
- ".inst 0xc157b50a // fdot za.s[x9, 2], { z8.h-z11.h }, z7.h[1]\n"
+ ".inst 0xc153b30a // fdot za.s[x9, 2], { z24.h-z27.h }, z3.h[0]\n"
".inst 0xa041a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xc157b70b // fdot za.s[x9, 3], { z24.h-z27.h }, z7.h[1]\n"
- ".inst 0xa042a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ ".inst 0xc153b08b // fdot za.s[x9, 3], { z4.h-z7.h }, z3.h[0]\n"
+ ".inst 0xa042a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
".inst 0xa043a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
- ".inst 0xc157b808 // fdot za.s[x9, 0], { z0.h-z3.h }, z7.h[2]\n"
+ ".inst 0xc153b588 // fdot za.s[x9, 0], { z12.h-z15.h }, z3.h[1]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b909 // fdot za.s[x9, 1], { z8.h-z11.h }, z7.h[2]\n"
- ".inst 0xa040a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27]\n"
- ".inst 0xc157b98a // fdot za.s[x9, 2], { z12.h-z15.h }, z7.h[2]\n"
- ".inst 0xa041a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xc157bb0b // fdot za.s[x9, 3], { z24.h-z27.h }, z7.h[2]\n"
+ ".inst 0xc153b509 // fdot za.s[x9, 1], { z8.h-z11.h }, z3.h[1]\n"
+ ".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xc153b60a // fdot za.s[x9, 2], { z16.h-z19.h }, z3.h[1]\n"
+ ".inst 0xa041a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xc153b70b // fdot za.s[x9, 3], { z24.h-z27.h }, z3.h[1]\n"
+ ".inst 0xa042a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ ".inst 0xa043a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
+ ".inst 0xc153b988 // fdot za.s[x9, 0], { z12.h-z15.h }, z3.h[2]\n"
+ "addvl x27, x27, #16\n"
+ ".inst 0xc153b889 // fdot za.s[x9, 1], { z4.h-z7.h }, z3.h[2]\n"
+ ".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xc153ba8a // fdot za.s[x9, 2], { z20.h-z23.h }, z3.h[2]\n"
+ ".inst 0xa041a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xc153b90b // fdot za.s[x9, 3], { z8.h-z11.h }, z3.h[2]\n"
".inst 0xa042a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
- ".inst 0xa043a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
- ".inst 0xc157bf88 // fdot za.s[x9, 0], { z28.h-z31.h }, z7.h[3]\n"
+ ".inst 0xa043a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
+ ".inst 0xc153bd88 // fdot za.s[x9, 0], { z12.h-z15.h }, z3.h[3]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157bc09 // fdot za.s[x9, 1], { z0.h-z3.h }, z7.h[3]\n"
- ".inst 0xc157bd0a // fdot za.s[x9, 2], { z8.h-z11.h }, z7.h[3]\n"
- ".inst 0xc157bd8b // fdot za.s[x9, 3], { z12.h-z15.h }, z7.h[3]\n"
+ ".inst 0xc153bc89 // fdot za.s[x9, 1], { z4.h-z7.h }, z3.h[3]\n"
+ ".inst 0xc153bd0a // fdot za.s[x9, 2], { z8.h-z11.h }, z3.h[3]\n"
+ ".inst 0xc153be0b // fdot za.s[x9, 3], { z16.h-z19.h }, z3.h[3]\n"
"bgt 31b\n"
"32:" // Width 4: Multiply loop: Single iteration only
"whilelt p0.h, XZR, x22\n"
- ".inst 0xa040a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
- "ld1rqh { z7.h }, p0/Z, [x23]\n"
+ "ld1rqh { z11.h }, p0/Z, [x23]\n"
"add x23, x23, #0x10\n"
- ".inst 0xa041a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xa042a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
- ".inst 0xa043a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
- ".inst 0xc157b108 // fdot za.s[x9, 0], { z8.h-z11.h }, z7.h[0]\n"
+ ".inst 0xa041a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa042a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ ".inst 0xa043a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
+ ".inst 0xc15bb208 // fdot za.s[x9, 0], { z16.h-z19.h }, z11.h[0]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b009 // fdot za.s[x9, 1], { z0.h-z3.h }, z7.h[0]\n"
- ".inst 0xc157b20a // fdot za.s[x9, 2], { z16.h-z19.h }, z7.h[0]\n"
- ".inst 0xc157b18b // fdot za.s[x9, 3], { z12.h-z15.h }, z7.h[0]\n"
+ ".inst 0xc15bb089 // fdot za.s[x9, 1], { z4.h-z7.h }, z11.h[0]\n"
+ ".inst 0xc15bb18a // fdot za.s[x9, 2], { z12.h-z15.h }, z11.h[0]\n"
+ ".inst 0xc15bb38b // fdot za.s[x9, 3], { z28.h-z31.h }, z11.h[0]\n"
"ble 33f\n"
- ".inst 0xa040a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
- ".inst 0xa041a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xa042a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
- ".inst 0xa043a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
- ".inst 0xc157b508 // fdot za.s[x9, 0], { z8.h-z11.h }, z7.h[1]\n"
+ ".inst 0xa041a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa042a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ ".inst 0xa043a779 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
+ ".inst 0xc15bb488 // fdot za.s[x9, 0], { z4.h-z7.h }, z11.h[1]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b789 // fdot za.s[x9, 1], { z28.h-z31.h }, z7.h[1]\n"
- ".inst 0xc157b40a // fdot za.s[x9, 2], { z0.h-z3.h }, z7.h[1]\n"
- ".inst 0xc157b68b // fdot za.s[x9, 3], { z20.h-z23.h }, z7.h[1]\n"
+ ".inst 0xc15bb609 // fdot za.s[x9, 1], { z16.h-z19.h }, z11.h[1]\n"
+ ".inst 0xc15bb58a // fdot za.s[x9, 2], { z12.h-z15.h }, z11.h[1]\n"
+ ".inst 0xc15bb70b // fdot za.s[x9, 3], { z24.h-z27.h }, z11.h[1]\n"
"ble 33f\n"
- ".inst 0xa040a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa040a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27]\n"
"subs x22, x22, #0x2\n"
- ".inst 0xa041a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xa042a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
- ".inst 0xa043a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
- ".inst 0xc157b908 // fdot za.s[x9, 0], { z8.h-z11.h }, z7.h[2]\n"
+ ".inst 0xa041a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa042a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ ".inst 0xa043a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
+ ".inst 0xc15bb988 // fdot za.s[x9, 0], { z12.h-z15.h }, z11.h[2]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157b989 // fdot za.s[x9, 1], { z12.h-z15.h }, z7.h[2]\n"
- ".inst 0xc157bb8a // fdot za.s[x9, 2], { z28.h-z31.h }, z7.h[2]\n"
- ".inst 0xc157ba0b // fdot za.s[x9, 3], { z16.h-z19.h }, z7.h[2]\n"
+ ".inst 0xc15bbb89 // fdot za.s[x9, 1], { z28.h-z31.h }, z11.h[2]\n"
+ ".inst 0xc15bb80a // fdot za.s[x9, 2], { z0.h-z3.h }, z11.h[2]\n"
+ ".inst 0xc15bb88b // fdot za.s[x9, 3], { z4.h-z7.h }, z11.h[2]\n"
"ble 33f\n"
- ".inst 0xa040a769 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x27]\n"
- ".inst 0xa041a77d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
- ".inst 0xa042a76d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
+ ".inst 0xa040a761 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x27]\n"
+ ".inst 0xa041a771 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa042a765 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x27, #0x8, MUL VL]\n"
".inst 0xa043a775 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x27, #0xc, MUL VL]\n"
- ".inst 0xc157bd08 // fdot za.s[x9, 0], { z8.h-z11.h }, z7.h[3]\n"
+ ".inst 0xc15bbc08 // fdot za.s[x9, 0], { z0.h-z3.h }, z11.h[3]\n"
"addvl x27, x27, #16\n"
- ".inst 0xc157bf89 // fdot za.s[x9, 1], { z28.h-z31.h }, z7.h[3]\n"
- ".inst 0xc157bd8a // fdot za.s[x9, 2], { z12.h-z15.h }, z7.h[3]\n"
- ".inst 0xc157be8b // fdot za.s[x9, 3], { z20.h-z23.h }, z7.h[3]\n"
+ ".inst 0xc15bbe09 // fdot za.s[x9, 1], { z16.h-z19.h }, z11.h[3]\n"
+ ".inst 0xc15bbc8a // fdot za.s[x9, 2], { z4.h-z7.h }, z11.h[3]\n"
+ ".inst 0xc15bbe8b // fdot za.s[x9, 3], { z20.h-z23.h }, z11.h[3]\n"
"33:" // Width 4: Multiply loop: multiply skip
"tbz %x[flags], #1, 34f\n"
".inst 0xc0062c1c // mova { z28.d-z31.d }, za.d[x9, #0]\n"
@@ -543,12 +617,12 @@ void sme2_gemv_fp16fp32fp16_dot_16VL (
".inst 0xc120e18c // fcvt z12.h, { z12.s-z13.s }\n"
".inst 0xc120e1cd // fcvt z13.h, { z14.s-z15.s }\n"
".inst 0xc172c26a // fclamp { z10.h-z11.h }, z19.h, z18.h\n"
- ".inst 0xc172c26c // fclamp { z12.h-z13.h }, z19.h, z18.h\n"
".inst 0xc120e00e // fcvt z14.h, { z0.s-z1.s }\n"
".inst 0xc120e04f // fcvt z15.h, { z2.s-z3.s }\n"
- ".inst 0xc172c26e // fclamp { z14.h-z15.h }, z19.h, z18.h\n"
+ ".inst 0xc172c26c // fclamp { z12.h-z13.h }, z19.h, z18.h\n"
".inst 0xc120e090 // fcvt z16.h, { z4.s-z5.s }\n"
".inst 0xc120e0d1 // fcvt z17.h, { z6.s-z7.s }\n"
+ ".inst 0xc172c26e // fclamp { z14.h-z15.h }, z19.h, z18.h\n"
".inst 0xc172c270 // fclamp { z16.h-z17.h }, z19.h, z18.h\n"
".inst 0xa060272a // st1h { z10.h-z11.h }, pn9.b, [x25]\n"
".inst 0xa061272c // st1h { z12.h-z13.h }, pn9.b, [x25, #0x2, MUL VL]\n"
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp
new file mode 100644
index 000000000..779219285
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include <cstdint>
+#include "../std_transforms_sme.hpp"
+
+namespace arm_gemm
+{
+
+// Implementations
+void sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+class cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL
+{
+public:
+ typedef int8_t operand_type;
+ typedef float result_type;
+
+ typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+ /* Kernel blocking parameters */
+ static unsigned int out_height()
+ {
+ return sme::get_vector_length<int32_t>() * 1;
+ }
+
+ static unsigned int out_width()
+ {
+ return sme::get_vector_length<int32_t>() * 4;
+ }
+
+ static constexpr unsigned int k_unroll()
+ {
+ return 4;
+ }
+
+ static constexpr bool supports_accumulate()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_bias()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_activation()
+ {
+ return true;
+ }
+
+ static constexpr bool is_sme()
+ {
+ return true;
+ }
+
+ // Default to the generic kernel
+ kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL;
+
+ StdTransformsSME<operand_type, result_type, 1, 4, 4> transforms = {};
+
+ cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const CPUInfo *)
+ {
+ }
+};
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp
new file mode 100644
index 000000000..4b26a6578
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp
@@ -0,0 +1,417 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_gemm.hpp"
+
+#include <cstdint>
+#include "../../asmlib.hpp"
+#include "../../utils.hpp"
+
+namespace arm_gemm {
+
+void sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer)
+{
+ struct KernelArgs
+ {
+ KernelArgs(
+ const int8_t *const A,
+ const int8_t *const B,
+ float *const C, const int ldc,
+ const int M, const int N, const int K,
+ const int32_t *const bias, const float *const late_bias, const Activation act,
+ bool accumulate,
+ int32_t *const accumulator_buffer
+ ) : A(A),
+ B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)),
+ C(C), ldcb(ldc * sizeof(float)),
+ M(M), N(N), K(K),
+ min(-std::numeric_limits<float>::infinity()),
+ max(std::numeric_limits<float>::infinity()),
+ bias(bias), late_bias(late_bias),
+ accumulator_buffer(accumulator_buffer),
+ flags(0x0)
+ {
+ if (accumulate)
+ {
+ flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER
+ }
+ if (C == nullptr)
+ {
+ flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER
+ }
+
+ // Initialise the activation values
+ switch (act.type)
+ {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ this->max = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ this->min = static_cast<float>(0);
+ break;
+ }
+ }
+
+ const int8_t *const A;
+ const int8_t *const B;
+ const long kstride_bytes;
+ float *const C;
+ const long ldcb;
+ const long M, N, K;
+ float min = -std::numeric_limits<float>::infinity();
+ float max = std::numeric_limits<float>::infinity();
+
+ const int32_t *const bias;
+ const float *const late_bias;
+
+ int32_t *const accumulator_buffer;
+ uint64_t flags;
+ };
+
+ // Construct arguments for this kernel
+ KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer);
+
+ __asm__ __volatile__(
+ "ldr x13, [%x[args], %[offsetof_flags]]\n"
+ ".inst 0xd503477f // SMSTART ZA\n"
+ "ptrue p0.b\n"
+ ".inst 0x25207811 // ptrue pn9.b\n"
+ "ldr x11, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "ldr x10, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "tbz x13, #0, 2f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "1:" // Initial accumulator load from buffer: Loop
+ ".inst 0xa040c57c // ld1w { z28.s-z31.s }, pn9.b/Z, [x11]\n"
+ ".inst 0xa041c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n"
+ ".inst 0xa042c578 // ld1w { z24.s-z27.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n"
+ ".inst 0xa043c56c // ld1w { z12.s-z15.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n"
+ ".inst 0xc0840780 // mova za0h.s[x12], { z28.s-z31.s }\n"
+ "addvl x11, x11, #16\n"
+ ".inst 0xc0840401 // mova za1h.s[x12], { z0.s-z3.s }\n"
+ ".inst 0xc0840702 // mova za2h.s[x12], { z24.s-z27.s }\n"
+ ".inst 0xc0840583 // mova za3h.s[x12], { z12.s-z15.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 1b\n"
+ "2:" // Initial accumulator load from buffer: End
+ "ldr w9, [%x[args], %[offsetof_M]]\n"
+ "mov x28, #0x0\n"
+ "mov x27, #0x0\n"
+ "ldr w26, [%x[args], %[offsetof_N]]\n"
+ "ldr x25, [%x[args], %[offsetof_A]]\n"
+ "3:" // M and N loop
+ "mov x24, x25\n"
+ ".inst 0x25ba6770 // whilelt pn8.s, x27, x26, VLx4\n"
+ "tbnz x13, #0, 4f\n"
+ "ldr x20, [%x[args], %[offsetof_bias]]\n"
+ ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
+ "cbz x20, 5f\n"
+ ".inst 0xa01bc288 // ld1w { z8.s-z11.s }, p8/Z, [x20, x27, LSL #2]\n"
+ ".inst 0xc0900100 // addha za0.s, p0/M, p0/M, z8.s\n"
+ ".inst 0xc0900121 // addha za1.s, p0/M, p0/M, z9.s\n"
+ ".inst 0xc0900142 // addha za2.s, p0/M, p0/M, z10.s\n"
+ ".inst 0xc0900163 // addha za3.s, p0/M, p0/M, z11.s\n"
+ "4:" // Prepare accumulators: Test for last block
+ "mov x20, x27\n"
+ "mov x21, x28\n"
+ "incw x20, ALL, MUL #4\n"
+ "incw x21\n"
+ "cmp x20, x26\n"
+ "mov x20, x13\n"
+ "csel x21, x28, x21, LT\n"
+ "bfm x13, XZR, #0x0, #0x0 // bfc x13, #0x0, #0x1\n"
+ "cmp x21, x9\n"
+ "csel x13, x20, x13, LT\n"
+ "5:" // Prepare accumulators: End
+ "ldr x20, [%x[args], %[offsetof_K]]\n"
+ "ldr x23, [%x[args], %[offsetof_B]]\n"
+ "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n"
+ "add x20, x20, #0x3\n"
+ "lsr x20, x20, #0x2\n"
+ "lsr x21, x20, #0x2\n"
+ "madd x23, x27, x22, x23\n" // bptr = B + n * kstride_bytes
+ "and x20, x20, #0x3\n"
+ "cbz x21, 8f\n"
+ "subs x21, x21, #0x1\n"
+ "ld1b { z31.b }, p0/Z, [x24]\n"
+ ".inst 0xa04086e8 // ld1b { z8.b-z11.b }, pn9.b/Z, [x23]\n"
+ "ld1b { z1.b }, p0/Z, [x24, #1, MUL VL]\n"
+ ".inst 0xa04186e4 // ld1b { z4.b-z7.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
+ "ld1b { z0.b }, p0/Z, [x24, #2, MUL VL]\n"
+ ".inst 0xa04286ec // ld1b { z12.b-z15.b }, pn9.b/Z, [x23, #0x8, MUL VL]\n"
+ "ld1b { z3.b }, p0/Z, [x24, #3, MUL VL]\n"
+ "addvl x24, x24, #4\n"
+ ".inst 0xa04386f0 // ld1b { z16.b-z19.b }, pn9.b/Z, [x23, #0xc, MUL VL]\n"
+ "addvl x23, x23, #16\n"
+ "ble 7f\n"
+ "6:" // K loop
+ ".inst 0xa08803e0 // smopa za0.s, p0/M, p0/M, z31.b, z8.b\n"
+ "subs x21, x21, #0x1\n"
+ ".inst 0xa08903e1 // smopa za1.s, p0/M, p0/M, z31.b, z9.b\n"
+ ".inst 0xa08a03e2 // smopa za2.s, p0/M, p0/M, z31.b, z10.b\n"
+ ".inst 0xa08b03e3 // smopa za3.s, p0/M, p0/M, z31.b, z11.b\n"
+ "ld1b { z31.b }, p0/Z, [x24]\n"
+ ".inst 0xa0840020 // smopa za0.s, p0/M, p0/M, z1.b, z4.b\n"
+ ".inst 0xa04086e8 // ld1b { z8.b-z11.b }, pn9.b/Z, [x23]\n"
+ ".inst 0xa0850021 // smopa za1.s, p0/M, p0/M, z1.b, z5.b\n"
+ ".inst 0xa0860022 // smopa za2.s, p0/M, p0/M, z1.b, z6.b\n"
+ ".inst 0xa0870023 // smopa za3.s, p0/M, p0/M, z1.b, z7.b\n"
+ "ld1b { z1.b }, p0/Z, [x24, #1, MUL VL]\n"
+ ".inst 0xa08c0000 // smopa za0.s, p0/M, p0/M, z0.b, z12.b\n"
+ ".inst 0xa04186e4 // ld1b { z4.b-z7.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
+ ".inst 0xa08d0001 // smopa za1.s, p0/M, p0/M, z0.b, z13.b\n"
+ ".inst 0xa08e0002 // smopa za2.s, p0/M, p0/M, z0.b, z14.b\n"
+ ".inst 0xa08f0003 // smopa za3.s, p0/M, p0/M, z0.b, z15.b\n"
+ "ld1b { z0.b }, p0/Z, [x24, #2, MUL VL]\n"
+ ".inst 0xa04286ec // ld1b { z12.b-z15.b }, pn9.b/Z, [x23, #0x8, MUL VL]\n"
+ ".inst 0xa0900060 // smopa za0.s, p0/M, p0/M, z3.b, z16.b\n"
+ ".inst 0xa0910061 // smopa za1.s, p0/M, p0/M, z3.b, z17.b\n"
+ ".inst 0xa0920062 // smopa za2.s, p0/M, p0/M, z3.b, z18.b\n"
+ ".inst 0xa0930063 // smopa za3.s, p0/M, p0/M, z3.b, z19.b\n"
+ "ld1b { z3.b }, p0/Z, [x24, #3, MUL VL]\n"
+ "addvl x24, x24, #4\n"
+ ".inst 0xa04386f0 // ld1b { z16.b-z19.b }, pn9.b/Z, [x23, #0xc, MUL VL]\n"
+ "addvl x23, x23, #16\n"
+ "bgt 6b\n"
+ "7:" // K loop tail
+ ".inst 0xa08803e0 // smopa za0.s, p0/M, p0/M, z31.b, z8.b\n"
+ ".inst 0xa08903e1 // smopa za1.s, p0/M, p0/M, z31.b, z9.b\n"
+ ".inst 0xa08a03e2 // smopa za2.s, p0/M, p0/M, z31.b, z10.b\n"
+ ".inst 0xa08b03e3 // smopa za3.s, p0/M, p0/M, z31.b, z11.b\n"
+ ".inst 0xa0840020 // smopa za0.s, p0/M, p0/M, z1.b, z4.b\n"
+ ".inst 0xa0850021 // smopa za1.s, p0/M, p0/M, z1.b, z5.b\n"
+ ".inst 0xa0860022 // smopa za2.s, p0/M, p0/M, z1.b, z6.b\n"
+ ".inst 0xa0870023 // smopa za3.s, p0/M, p0/M, z1.b, z7.b\n"
+ ".inst 0xa08c0000 // smopa za0.s, p0/M, p0/M, z0.b, z12.b\n"
+ ".inst 0xa08d0001 // smopa za1.s, p0/M, p0/M, z0.b, z13.b\n"
+ ".inst 0xa08e0002 // smopa za2.s, p0/M, p0/M, z0.b, z14.b\n"
+ ".inst 0xa08f0003 // smopa za3.s, p0/M, p0/M, z0.b, z15.b\n"
+ ".inst 0xa0900060 // smopa za0.s, p0/M, p0/M, z3.b, z16.b\n"
+ ".inst 0xa0910061 // smopa za1.s, p0/M, p0/M, z3.b, z17.b\n"
+ ".inst 0xa0920062 // smopa za2.s, p0/M, p0/M, z3.b, z18.b\n"
+ ".inst 0xa0930063 // smopa za3.s, p0/M, p0/M, z3.b, z19.b\n"
+ "8:" // K oddments
+ "cbz x20, 10f\n"
+ "9:" // K oddments: Loop
+ "ld1b { z18.b }, p0/Z, [x24]\n"
+ "subs x20, x20, #0x1\n"
+ "addvl x24, x24, #1\n"
+ ".inst 0xa04086fc // ld1b { z28.b-z31.b }, pn9.b/Z, [x23]\n"
+ "addvl x23, x23, #4\n"
+ ".inst 0xa09c0240 // smopa za0.s, p0/M, p0/M, z18.b, z28.b\n"
+ ".inst 0xa09d0241 // smopa za1.s, p0/M, p0/M, z18.b, z29.b\n"
+ ".inst 0xa09e0242 // smopa za2.s, p0/M, p0/M, z18.b, z30.b\n"
+ ".inst 0xa09f0243 // smopa za3.s, p0/M, p0/M, z18.b, z31.b\n"
+ "bgt 9b\n"
+ "10:" // K oddments: End
+ "tbz x13, #1, 14f\n"
+ "tbz x13, #0, 12f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "11:" // Store to partial result buffer: Store and refill: Loop
+ ".inst 0xa040c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11]\n"
+ ".inst 0xc0860408 // mova { z8.s-z11.s }, za0h.s[x12]\n"
+ ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n"
+ ".inst 0xa041c57c // ld1w { z28.s-z31.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n"
+ ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n"
+ ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n"
+ ".inst 0xa042c578 // ld1w { z24.s-z27.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n"
+ ".inst 0xa043c574 // ld1w { z20.s-z23.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n"
+ ".inst 0xc0840400 // mova za0h.s[x12], { z0.s-z3.s }\n"
+ "addvl x11, x11, #16\n"
+ ".inst 0xc0840781 // mova za1h.s[x12], { z28.s-z31.s }\n"
+ ".inst 0xa060c548 // st1w { z8.s-z11.s }, pn9.b, [x10]\n"
+ ".inst 0xc0840702 // mova za2h.s[x12], { z24.s-z27.s }\n"
+ ".inst 0xa061c54c // st1w { z12.s-z15.s }, pn9.b, [x10, #0x4, MUL VL]\n"
+ ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa062c544 // st1w { z4.s-z7.s }, pn9.b, [x10, #0x8, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa063c550 // st1w { z16.s-z19.s }, pn9.b, [x10, #0xc, MUL VL]\n"
+ "addvl x10, x10, #16\n"
+ "blt 11b\n"
+ "b 21f\n"
+ "12:" // Store to partial result buffer: Store only
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "13:" // Store to partial result buffer: Store only: Loop
+ ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n"
+ ".inst 0xc0860430 // mova { z16.s-z19.s }, za1h.s[x12]\n"
+ ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n"
+ ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n"
+ ".inst 0xa060c544 // st1w { z4.s-z7.s }, pn9.b, [x10]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa061c550 // st1w { z16.s-z19.s }, pn9.b, [x10, #0x4, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa062c548 // st1w { z8.s-z11.s }, pn9.b, [x10, #0x8, MUL VL]\n"
+ ".inst 0xa063c54c // st1w { z12.s-z15.s }, pn9.b, [x10, #0xc, MUL VL]\n"
+ "addvl x10, x10, #16\n"
+ "blt 13b\n"
+ "b 21f\n"
+ "14:" // Store to output array
+ "ldr x23, [%x[args], %[offsetof_C]]\n"
+ "sub x21, x9, x28\n"
+ "ld1rw { z18.s }, p0/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n"
+ "fmov z20.s, #0x0\n"
+ "ldr x22, [%x[args], %[offsetof_ldcb]]\n"
+ "fmov z21.s, #0x0\n"
+ "fmov z22.s, #0x0\n"
+ "ldr x20, [%x[args], %[offsetof_late_bias]]\n"
+ "fmov z23.s, #0x0\n"
+ "add x23, x23, x27, LSL #2\n" // C += n
+ "madd x23, x28, x22, x23\n" // C += m * ldc
+ "cbz x20, 15f\n"
+ "add x20, x20, x27, LSL #2\n"
+ ".inst 0xa040c294 // ld1w { z20.s-z23.s }, p8/Z, [x20]\n"
+ "15:" // Store to output array: no late bias
+ "cntw x20\n"
+ "ld1rw { z17.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n"
+ "mov x12, #0x0\n"
+ "cmp x21, x20\n"
+ "ld1rw { z16.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n"
+ "csel x20, x21, x20, LT\n"
+ "lsr x21, x20, #0x2\n"
+ "and x20, x20, #0x3\n"
+ "cbz x21, 17f\n"
+ "16:" // Store to output array: Accumulator row 0 loop
+ ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n"
+ ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n"
+ ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n"
+ ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n"
+ ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n"
+ ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n"
+ ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n"
+ ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n"
+ "fmad z0.s, p0/M, z18.s, z20.s\n"
+ "fmad z1.s, p0/M, z18.s, z20.s\n"
+ "fmad z2.s, p0/M, z18.s, z20.s\n"
+ "fmad z3.s, p0/M, z18.s, z20.s\n"
+ "add x12, x12, #0x4\n"
+ "fmad z4.s, p0/M, z18.s, z21.s\n"
+ "fmad z5.s, p0/M, z18.s, z21.s\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z6.s, p0/M, z18.s, z21.s\n"
+ "fmad z7.s, p0/M, z18.s, z21.s\n"
+ "fmad z8.s, p0/M, z18.s, z22.s\n"
+ "fmad z9.s, p0/M, z18.s, z22.s\n"
+ "fmad z10.s, p0/M, z18.s, z22.s\n"
+ "fmad z11.s, p0/M, z18.s, z22.s\n"
+ "fmad z12.s, p0/M, z18.s, z23.s\n"
+ "fmad z13.s, p0/M, z18.s, z23.s\n"
+ "fmad z14.s, p0/M, z18.s, z23.s\n"
+ "fmad z15.s, p0/M, z18.s, z23.s\n"
+ ".inst 0xc1b0ca20 // fclamp { z0.s-z3.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca24 // fclamp { z4.s-z7.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca28 // fclamp { z8.s-z11.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca2c // fclamp { z12.s-z15.s }, z17.s, z16.s\n"
+ ".inst 0xa160c2e0 // st1w { z0.s, z4.s, z8.s, z12.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ ".inst 0xa160c2e1 // st1w { z1.s, z5.s, z9.s, z13.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ ".inst 0xa160c2e2 // st1w { z2.s, z6.s, z10.s, z14.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ ".inst 0xa160c2e3 // st1w { z3.s, z7.s, z11.s, z15.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ "blt 16b\n"
+ "17:" // Store to output array: Accumulator row 0 oddments
+ "cbz x20, 18f\n"
+ ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n"
+ ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n"
+ ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n"
+ ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n"
+ ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n"
+ ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n"
+ ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n"
+ ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n"
+ "fmad z0.s, p0/M, z18.s, z20.s\n"
+ "fmad z1.s, p0/M, z18.s, z20.s\n"
+ "fmad z2.s, p0/M, z18.s, z20.s\n"
+ "fmad z3.s, p0/M, z18.s, z20.s\n"
+ "subs x20, x20, #0x1\n"
+ "fmad z4.s, p0/M, z18.s, z21.s\n"
+ "fmad z5.s, p0/M, z18.s, z21.s\n"
+ "fmad z6.s, p0/M, z18.s, z21.s\n"
+ "fmad z7.s, p0/M, z18.s, z21.s\n"
+ "fmad z8.s, p0/M, z18.s, z22.s\n"
+ "fmad z9.s, p0/M, z18.s, z22.s\n"
+ "fmad z10.s, p0/M, z18.s, z22.s\n"
+ "fmad z11.s, p0/M, z18.s, z22.s\n"
+ "fmad z12.s, p0/M, z18.s, z23.s\n"
+ "fmad z13.s, p0/M, z18.s, z23.s\n"
+ "fmad z14.s, p0/M, z18.s, z23.s\n"
+ "fmad z15.s, p0/M, z18.s, z23.s\n"
+ ".inst 0xc1b0ca20 // fclamp { z0.s-z3.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca24 // fclamp { z4.s-z7.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca28 // fclamp { z8.s-z11.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca2c // fclamp { z12.s-z15.s }, z17.s, z16.s\n"
+ ".inst 0xa160c2e0 // st1w { z0.s, z4.s, z8.s, z12.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ "beq 18f\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xa160c2e1 // st1w { z1.s, z5.s, z9.s, z13.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ "beq 18f\n"
+ ".inst 0xa160c2e2 // st1w { z2.s, z6.s, z10.s, z14.s }, p8, [x23]\n"
+ "18:" // Store to output array: Accumulator row 0 oddments: End
+ "19:" // Store to output array: End
+ "tbz x13, #0, 21f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "20:" // Store to output array: Refill accumulators: Loop
+ ".inst 0xa040c574 // ld1w { z20.s-z23.s }, pn9.b/Z, [x11]\n"
+ ".inst 0xa041c56c // ld1w { z12.s-z15.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n"
+ ".inst 0xa042c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n"
+ ".inst 0xa043c568 // ld1w { z8.s-z11.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n"
+ ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n"
+ "addvl x11, x11, #16\n"
+ ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n"
+ ".inst 0xc0840402 // mova za2h.s[x12], { z0.s-z3.s }\n"
+ ".inst 0xc0840503 // mova za3h.s[x12], { z8.s-z11.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 20b\n"
+ "21:" // End block
+ "incw x27, ALL, MUL #4\n"
+ "cmp x27, x26\n"
+ "blt 3b\n"
+ "incw x28\n"
+ "mov x27, #0x0\n"
+ "cmp x28, x9\n"
+ "mov x25, x24\n"
+ "blt 3b\n"
+ ".inst 0xd503467f // SMSTOP\n"
+ :
+ : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb))
+ : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+ );
+}
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp
new file mode 100644
index 000000000..df2c9c0ca
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include <cstdint>
+#include "../std_transforms_sme.hpp"
+
+namespace arm_gemm
+{
+
+// Implementations
+void sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+class cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL
+{
+public:
+ typedef int8_t operand_type;
+ typedef float result_type;
+
+ typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+ /* Kernel blocking parameters */
+ static unsigned int out_height()
+ {
+ return sme::get_vector_length<int32_t>() * 2;
+ }
+
+ static unsigned int out_width()
+ {
+ return sme::get_vector_length<int32_t>() * 2;
+ }
+
+ static constexpr unsigned int k_unroll()
+ {
+ return 4;
+ }
+
+ static constexpr bool supports_accumulate()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_bias()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_activation()
+ {
+ return true;
+ }
+
+ static constexpr bool is_sme()
+ {
+ return true;
+ }
+
+ // Default to the generic kernel
+ kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL;
+
+ StdTransformsSME<operand_type, result_type, 2, 2, 4> transforms = {};
+
+ cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const CPUInfo *)
+ {
+ }
+};
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp
new file mode 100644
index 000000000..1631fae8e
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp
@@ -0,0 +1,448 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_gemm.hpp"
+
+#include <cstdint>
+#include "../../asmlib.hpp"
+#include "../../utils.hpp"
+
+namespace arm_gemm {
+
+void sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer)
+{
+ struct KernelArgs
+ {
+ KernelArgs(
+ const int8_t *const A,
+ const int8_t *const B,
+ float *const C, const int ldc,
+ const int M, const int N, const int K,
+ const int32_t *const bias, const float *const late_bias, const Activation act,
+ bool accumulate,
+ int32_t *const accumulator_buffer
+ ) : A(A),
+ B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)),
+ C(C), ldcb(ldc * sizeof(float)),
+ M(M), N(N), K(K),
+ min(-std::numeric_limits<float>::infinity()),
+ max(std::numeric_limits<float>::infinity()),
+ bias(bias), late_bias(late_bias),
+ accumulator_buffer(accumulator_buffer),
+ flags(0x0)
+ {
+ if (accumulate)
+ {
+ flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER
+ }
+ if (C == nullptr)
+ {
+ flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER
+ }
+
+ // Initialise the activation values
+ switch (act.type)
+ {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ this->max = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ this->min = static_cast<float>(0);
+ break;
+ }
+ }
+
+ const int8_t *const A;
+ const int8_t *const B;
+ const long kstride_bytes;
+ float *const C;
+ const long ldcb;
+ const long M, N, K;
+ float min = -std::numeric_limits<float>::infinity();
+ float max = std::numeric_limits<float>::infinity();
+
+ const int32_t *const bias;
+ const float *const late_bias;
+
+ int32_t *const accumulator_buffer;
+ uint64_t flags;
+ };
+
+ // Construct arguments for this kernel
+ KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer);
+
+ __asm__ __volatile__(
+ "ldr x16, [%x[args], %[offsetof_flags]]\n"
+ ".inst 0xd503477f // SMSTART ZA\n"
+ "ptrue p0.b\n"
+ ".inst 0x25207811 // ptrue pn9.b\n"
+ "ldr x15, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "ldr x14, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "tbz x16, #0, 2f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "1:" // Initial accumulator load from buffer: Loop
+ ".inst 0xa040c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15]\n"
+ ".inst 0xa041c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xa042c5e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c5f8 // ld1w { z24.s-z27.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840580 // mova za0h.s[x12], { z12.s-z15.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840681 // mova za1h.s[x12], { z20.s-z23.s }\n"
+ ".inst 0xc0840402 // mova za2h.s[x12], { z0.s-z3.s }\n"
+ ".inst 0xc0840703 // mova za3h.s[x12], { z24.s-z27.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 1b\n"
+ "2:" // Initial accumulator load from buffer: End
+ "ldr w13, [%x[args], %[offsetof_M]]\n"
+ "mov x11, #0x0\n"
+ "mov x10, #0x0\n"
+ "ldr w9, [%x[args], %[offsetof_N]]\n"
+ "ldr x28, [%x[args], %[offsetof_A]]\n"
+ "3:" // M and N loop
+ "mov x27, x28\n"
+ ".inst 0x25a94550 // whilelt pn8.s, x10, x9, VLx2\n"
+ "tbnz x16, #0, 4f\n"
+ "ldr x20, [%x[args], %[offsetof_bias]]\n"
+ ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
+ "cbz x20, 5f\n"
+ ".inst 0xa10a4286 // ld1w { z6.s, z14.s }, p8/Z, [x20, x10, LSL #2]\n"
+ ".inst 0xc09000c0 // addha za0.s, p0/M, p0/M, z6.s\n"
+ ".inst 0xc09001c1 // addha za1.s, p0/M, p0/M, z14.s\n"
+ ".inst 0xc09000c2 // addha za2.s, p0/M, p0/M, z6.s\n"
+ ".inst 0xc09001c3 // addha za3.s, p0/M, p0/M, z14.s\n"
+ "4:" // Prepare accumulators: Test for last block
+ "mov x20, x10\n"
+ "mov x21, x11\n"
+ "incw x20, ALL, MUL #2\n"
+ "incw x21, ALL, MUL #2\n"
+ "cmp x20, x9\n"
+ "mov x20, x16\n"
+ "csel x21, x11, x21, LT\n"
+ "bfm x16, XZR, #0x0, #0x0 // bfc x16, #0x0, #0x1\n"
+ "cmp x21, x13\n"
+ "csel x16, x20, x16, LT\n"
+ "5:" // Prepare accumulators: End
+ "ldr x20, [%x[args], %[offsetof_K]]\n"
+ "ldr x23, [%x[args], %[offsetof_B]]\n"
+ "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n"
+ "add x20, x20, #0x3\n"
+ "lsr x20, x20, #0x2\n"
+ "lsr x21, x20, #0x2\n"
+ "madd x23, x10, x22, x23\n" // bptr = B + n * kstride_bytes
+ "and x20, x20, #0x3\n"
+ "cbz x21, 8f\n"
+ "subs x21, x21, #0x1\n"
+ ".inst 0xa1400775 // ld1b { z21.b, z29.b }, pn9.b/Z, [x27]\n"
+ ".inst 0xa04006f2 // ld1b { z18.b-z19.b }, pn9.b/Z, [x23]\n"
+ ".inst 0xa041076a // ld1b { z10.b-z11.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n"
+ ".inst 0xa14106e5 // ld1b { z5.b, z13.b }, pn9.b/Z, [x23, #0x2, MUL VL]\n"
+ ".inst 0xa1420767 // ld1b { z7.b, z15.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa14206f0 // ld1b { z16.b, z24.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
+ ".inst 0xa1430774 // ld1b { z20.b, z28.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n"
+ "addvl x27, x27, #8\n"
+ ".inst 0xa14306f7 // ld1b { z23.b, z31.b }, pn9.b/Z, [x23, #0x6, MUL VL]\n"
+ "addvl x23, x23, #8\n"
+ "ble 7f\n"
+ "6:" // K loop
+ ".inst 0xa09202a0 // smopa za0.s, p0/M, p0/M, z21.b, z18.b\n"
+ "subs x21, x21, #0x1\n"
+ ".inst 0xa09302a1 // smopa za1.s, p0/M, p0/M, z21.b, z19.b\n"
+ ".inst 0xa09203a2 // smopa za2.s, p0/M, p0/M, z29.b, z18.b\n"
+ ".inst 0xa09303a3 // smopa za3.s, p0/M, p0/M, z29.b, z19.b\n"
+ ".inst 0xa1400775 // ld1b { z21.b, z29.b }, pn9.b/Z, [x27]\n"
+ ".inst 0xa0850140 // smopa za0.s, p0/M, p0/M, z10.b, z5.b\n"
+ ".inst 0xa04006f2 // ld1b { z18.b-z19.b }, pn9.b/Z, [x23]\n"
+ ".inst 0xa08d0141 // smopa za1.s, p0/M, p0/M, z10.b, z13.b\n"
+ ".inst 0xa0850162 // smopa za2.s, p0/M, p0/M, z11.b, z5.b\n"
+ ".inst 0xa08d0163 // smopa za3.s, p0/M, p0/M, z11.b, z13.b\n"
+ ".inst 0xa041076a // ld1b { z10.b-z11.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n"
+ ".inst 0xa09000e0 // smopa za0.s, p0/M, p0/M, z7.b, z16.b\n"
+ ".inst 0xa14106e5 // ld1b { z5.b, z13.b }, pn9.b/Z, [x23, #0x2, MUL VL]\n"
+ ".inst 0xa09800e1 // smopa za1.s, p0/M, p0/M, z7.b, z24.b\n"
+ ".inst 0xa09001e2 // smopa za2.s, p0/M, p0/M, z15.b, z16.b\n"
+ ".inst 0xa09801e3 // smopa za3.s, p0/M, p0/M, z15.b, z24.b\n"
+ ".inst 0xa1420767 // ld1b { z7.b, z15.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa14206f0 // ld1b { z16.b, z24.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
+ ".inst 0xa0970280 // smopa za0.s, p0/M, p0/M, z20.b, z23.b\n"
+ ".inst 0xa09f0281 // smopa za1.s, p0/M, p0/M, z20.b, z31.b\n"
+ ".inst 0xa0970382 // smopa za2.s, p0/M, p0/M, z28.b, z23.b\n"
+ ".inst 0xa09f0383 // smopa za3.s, p0/M, p0/M, z28.b, z31.b\n"
+ ".inst 0xa1430774 // ld1b { z20.b, z28.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n"
+ "addvl x27, x27, #8\n"
+ ".inst 0xa14306f7 // ld1b { z23.b, z31.b }, pn9.b/Z, [x23, #0x6, MUL VL]\n"
+ "addvl x23, x23, #8\n"
+ "bgt 6b\n"
+ "7:" // K loop tail
+ ".inst 0xa09202a0 // smopa za0.s, p0/M, p0/M, z21.b, z18.b\n"
+ ".inst 0xa09302a1 // smopa za1.s, p0/M, p0/M, z21.b, z19.b\n"
+ ".inst 0xa09203a2 // smopa za2.s, p0/M, p0/M, z29.b, z18.b\n"
+ ".inst 0xa09303a3 // smopa za3.s, p0/M, p0/M, z29.b, z19.b\n"
+ ".inst 0xa0850140 // smopa za0.s, p0/M, p0/M, z10.b, z5.b\n"
+ ".inst 0xa08d0141 // smopa za1.s, p0/M, p0/M, z10.b, z13.b\n"
+ ".inst 0xa0850162 // smopa za2.s, p0/M, p0/M, z11.b, z5.b\n"
+ ".inst 0xa08d0163 // smopa za3.s, p0/M, p0/M, z11.b, z13.b\n"
+ ".inst 0xa09000e0 // smopa za0.s, p0/M, p0/M, z7.b, z16.b\n"
+ ".inst 0xa09800e1 // smopa za1.s, p0/M, p0/M, z7.b, z24.b\n"
+ ".inst 0xa09001e2 // smopa za2.s, p0/M, p0/M, z15.b, z16.b\n"
+ ".inst 0xa09801e3 // smopa za3.s, p0/M, p0/M, z15.b, z24.b\n"
+ ".inst 0xa0970280 // smopa za0.s, p0/M, p0/M, z20.b, z23.b\n"
+ ".inst 0xa09f0281 // smopa za1.s, p0/M, p0/M, z20.b, z31.b\n"
+ ".inst 0xa0970382 // smopa za2.s, p0/M, p0/M, z28.b, z23.b\n"
+ ".inst 0xa09f0383 // smopa za3.s, p0/M, p0/M, z28.b, z31.b\n"
+ "8:" // K oddments
+ "cbz x20, 10f\n"
+ "9:" // K oddments: Loop
+ ".inst 0xa040077e // ld1b { z30.b-z31.b }, pn9.b/Z, [x27]\n"
+ "subs x20, x20, #0x1\n"
+ "addvl x27, x27, #2\n"
+ ".inst 0xa14006e7 // ld1b { z7.b, z15.b }, pn9.b/Z, [x23]\n"
+ "addvl x23, x23, #2\n"
+ ".inst 0xa08703c0 // smopa za0.s, p0/M, p0/M, z30.b, z7.b\n"
+ ".inst 0xa08f03c1 // smopa za1.s, p0/M, p0/M, z30.b, z15.b\n"
+ ".inst 0xa08703e2 // smopa za2.s, p0/M, p0/M, z31.b, z7.b\n"
+ ".inst 0xa08f03e3 // smopa za3.s, p0/M, p0/M, z31.b, z15.b\n"
+ "bgt 9b\n"
+ "10:" // K oddments: End
+ "tbz x16, #1, 14f\n"
+ "tbz x16, #0, 12f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "11:" // Store to partial result buffer: Store and refill: Loop
+ ".inst 0xa040c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15]\n"
+ ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n"
+ ".inst 0xc0860428 // mova { z8.s-z11.s }, za1h.s[x12]\n"
+ ".inst 0xa041c5f0 // ld1w { z16.s-z19.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xc0860440 // mova { z0.s-z3.s }, za2h.s[x12]\n"
+ ".inst 0xc0860478 // mova { z24.s-z27.s }, za3h.s[x12]\n"
+ ".inst 0xa042c5fc // ld1w { z28.s-z31.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840580 // mova za0h.s[x12], { z12.s-z15.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840601 // mova za1h.s[x12], { z16.s-z19.s }\n"
+ ".inst 0xa060c5c4 // st1w { z4.s-z7.s }, pn9.b, [x14]\n"
+ ".inst 0xc0840782 // mova za2h.s[x12], { z28.s-z31.s }\n"
+ ".inst 0xa061c5c8 // st1w { z8.s-z11.s }, pn9.b, [x14, #0x4, MUL VL]\n"
+ ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa062c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14, #0x8, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa063c5d8 // st1w { z24.s-z27.s }, pn9.b, [x14, #0xc, MUL VL]\n"
+ "addvl x14, x14, #16\n"
+ "blt 11b\n"
+ "b 24f\n"
+ "12:" // Store to partial result buffer: Store only
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "13:" // Store to partial result buffer: Store only: Loop
+ ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n"
+ ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n"
+ ".inst 0xc0860450 // mova { z16.s-z19.s }, za2h.s[x12]\n"
+ ".inst 0xc0860468 // mova { z8.s-z11.s }, za3h.s[x12]\n"
+ ".inst 0xa060c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa061c5cc // st1w { z12.s-z15.s }, pn9.b, [x14, #0x4, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa062c5d0 // st1w { z16.s-z19.s }, pn9.b, [x14, #0x8, MUL VL]\n"
+ ".inst 0xa063c5c8 // st1w { z8.s-z11.s }, pn9.b, [x14, #0xc, MUL VL]\n"
+ "addvl x14, x14, #16\n"
+ "blt 13b\n"
+ "b 24f\n"
+ "14:" // Store to output array
+ "ldr x26, [%x[args], %[offsetof_C]]\n"
+ "sub x25, x13, x11\n"
+ "ld1rw { z3.s }, p0/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n"
+ "fmov z2.s, #0x0\n"
+ "ldr x24, [%x[args], %[offsetof_ldcb]]\n"
+ "fmov z10.s, #0x0\n"
+ "ldr x20, [%x[args], %[offsetof_late_bias]]\n"
+ "add x26, x26, x10, LSL #2\n" // C += n
+ "madd x26, x11, x24, x26\n" // C += m * ldc
+ "cbz x20, 15f\n"
+ "add x20, x20, x10, LSL #2\n"
+ ".inst 0xa1404282 // ld1w { z2.s, z10.s }, p8/Z, [x20]\n"
+ "15:" // Store to output array: no late bias
+ "cntw x23\n"
+ "ld1rw { z1.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n"
+ "mov x12, #0x0\n"
+ "cmp x25, x23\n"
+ "ld1rw { z0.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n"
+ "csel x22, x25, x23, LT\n"
+ "lsr x21, x22, #0x2\n"
+ "and x20, x22, #0x3\n"
+ "cbz x21, 17f\n"
+ "16:" // Store to output array: Accumulator row 0 loop
+ ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n"
+ ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n"
+ ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n"
+ ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n"
+ "fmad z4.s, p0/M, z3.s, z2.s\n"
+ "fmad z5.s, p0/M, z3.s, z2.s\n"
+ "add x12, x12, #0x4\n"
+ "fmad z6.s, p0/M, z3.s, z2.s\n"
+ "fmad z7.s, p0/M, z3.s, z2.s\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z12.s, p0/M, z3.s, z10.s\n"
+ "fmad z13.s, p0/M, z3.s, z10.s\n"
+ "fmad z14.s, p0/M, z3.s, z10.s\n"
+ "fmad z15.s, p0/M, z3.s, z10.s\n"
+ ".inst 0xc1a0c824 // fclamp { z4.s-z7.s }, z1.s, z0.s\n"
+ ".inst 0xc1a0c82c // fclamp { z12.s-z15.s }, z1.s, z0.s\n"
+ ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 16b\n"
+ "17:" // Store to output array: Accumulator row 0 oddments
+ "cbz x20, 18f\n"
+ ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n"
+ ".inst 0xc0860438 // mova { z24.s-z27.s }, za1h.s[x12]\n"
+ ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n"
+ ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n"
+ "fmad z16.s, p0/M, z3.s, z2.s\n"
+ "fmad z17.s, p0/M, z3.s, z2.s\n"
+ "subs x20, x20, #0x1\n"
+ "fmad z18.s, p0/M, z3.s, z2.s\n"
+ "fmad z19.s, p0/M, z3.s, z2.s\n"
+ "fmad z24.s, p0/M, z3.s, z10.s\n"
+ "fmad z25.s, p0/M, z3.s, z10.s\n"
+ "fmad z26.s, p0/M, z3.s, z10.s\n"
+ "fmad z27.s, p0/M, z3.s, z10.s\n"
+ ".inst 0xc1a0c830 // fclamp { z16.s-z19.s }, z1.s, z0.s\n"
+ ".inst 0xc1a0c838 // fclamp { z24.s-z27.s }, z1.s, z0.s\n"
+ ".inst 0xa1604350 // st1w { z16.s, z24.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 18f\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xa1604351 // st1w { z17.s, z25.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 18f\n"
+ ".inst 0xa1604352 // st1w { z18.s, z26.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "18:" // Store to output array: Accumulator row 0 oddments: End
+ "subs x25, x25, x22\n"
+ "beq 22f\n"
+ "cmp x25, x23\n"
+ "mov x12, #0x0\n"
+ "csel x20, x25, x23, LT\n"
+ "lsr x21, x20, #0x2\n"
+ "and x20, x20, #0x3\n"
+ "cbz x21, 20f\n"
+ "19:" // Store to output array: Accumulator row 1 loop
+ ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n"
+ ".inst 0xc086047c // mova { z28.s-z31.s }, za3h.s[x12]\n"
+ ".inst 0xc132e294 // scvtf { z20.s-z23.s }, { z20.s-z23.s }\n"
+ ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n"
+ "fmad z20.s, p0/M, z3.s, z2.s\n"
+ "fmad z21.s, p0/M, z3.s, z2.s\n"
+ "add x12, x12, #0x4\n"
+ "fmad z22.s, p0/M, z3.s, z2.s\n"
+ "fmad z23.s, p0/M, z3.s, z2.s\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z28.s, p0/M, z3.s, z10.s\n"
+ "fmad z29.s, p0/M, z3.s, z10.s\n"
+ "fmad z30.s, p0/M, z3.s, z10.s\n"
+ "fmad z31.s, p0/M, z3.s, z10.s\n"
+ ".inst 0xc1a0c834 // fclamp { z20.s-z23.s }, z1.s, z0.s\n"
+ ".inst 0xc1a0c83c // fclamp { z28.s-z31.s }, z1.s, z0.s\n"
+ ".inst 0xa1604354 // st1w { z20.s, z28.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604355 // st1w { z21.s, z29.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604356 // st1w { z22.s, z30.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604357 // st1w { z23.s, z31.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 19b\n"
+ "20:" // Store to output array: Accumulator row 1 oddments
+ "cbz x20, 21f\n"
+ ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n"
+ ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n"
+ ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n"
+ ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n"
+ "fmad z4.s, p0/M, z3.s, z2.s\n"
+ "fmad z5.s, p0/M, z3.s, z2.s\n"
+ "subs x20, x20, #0x1\n"
+ "fmad z6.s, p0/M, z3.s, z2.s\n"
+ "fmad z7.s, p0/M, z3.s, z2.s\n"
+ "fmad z12.s, p0/M, z3.s, z10.s\n"
+ "fmad z13.s, p0/M, z3.s, z10.s\n"
+ "fmad z14.s, p0/M, z3.s, z10.s\n"
+ "fmad z15.s, p0/M, z3.s, z10.s\n"
+ ".inst 0xc1a0c824 // fclamp { z4.s-z7.s }, z1.s, z0.s\n"
+ ".inst 0xc1a0c82c // fclamp { z12.s-z15.s }, z1.s, z0.s\n"
+ ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 21f\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 21f\n"
+ ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n"
+ "21:" // Store to output array: Accumulator row 1 oddments: End
+ "22:" // Store to output array: End
+ "tbz x16, #0, 24f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "23:" // Store to output array: Refill accumulators: Loop
+ ".inst 0xa040c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15]\n"
+ ".inst 0xa041c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xa042c5e4 // ld1w { z4.s-z7.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c5e8 // ld1w { z8.s-z11.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n"
+ ".inst 0xc0840482 // mova za2h.s[x12], { z4.s-z7.s }\n"
+ ".inst 0xc0840503 // mova za3h.s[x12], { z8.s-z11.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 23b\n"
+ "24:" // End block
+ "incw x10, ALL, MUL #2\n"
+ "cmp x10, x9\n"
+ "blt 3b\n"
+ "incw x11, ALL, MUL #2\n"
+ "mov x10, #0x0\n"
+ "cmp x11, x13\n"
+ "mov x28, x27\n"
+ "blt 3b\n"
+ ".inst 0xd503467f // SMSTOP\n"
+ :
+ : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb))
+ : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+ );
+}
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp
new file mode 100644
index 000000000..70952f4f0
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include <cstdint>
+#include "../std_transforms_sme.hpp"
+
+namespace arm_gemm
+{
+
+// Implementations
+void sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+class cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL
+{
+public:
+ typedef int8_t operand_type;
+ typedef float result_type;
+
+ typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+ /* Kernel blocking parameters */
+ static unsigned int out_height()
+ {
+ return sme::get_vector_length<int32_t>() * 4;
+ }
+
+ static unsigned int out_width()
+ {
+ return sme::get_vector_length<int32_t>() * 1;
+ }
+
+ static constexpr unsigned int k_unroll()
+ {
+ return 4;
+ }
+
+ static constexpr bool supports_accumulate()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_bias()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_activation()
+ {
+ return true;
+ }
+
+ static constexpr bool is_sme()
+ {
+ return true;
+ }
+
+ // Default to the generic kernel
+ kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL;
+
+ StdTransformsSME<operand_type, result_type, 4, 1, 4> transforms = {};
+
+ cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const CPUInfo *)
+ {
+ }
+};
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp
new file mode 100644
index 000000000..bafb16bca
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp
@@ -0,0 +1,513 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_gemm.hpp"
+
+#include <cstdint>
+#include "../../asmlib.hpp"
+#include "../../utils.hpp"
+
+namespace arm_gemm {
+
+void sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer)
+{
+ struct KernelArgs
+ {
+ KernelArgs(
+ const int8_t *const A,
+ const int8_t *const B,
+ float *const C, const int ldc,
+ const int M, const int N, const int K,
+ const int32_t *const bias, const float *const late_bias, const Activation act,
+ bool accumulate,
+ int32_t *const accumulator_buffer
+ ) : A(A),
+ B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)),
+ C(C), ldcb(ldc * sizeof(float)),
+ M(M), N(N), K(K),
+ min(-std::numeric_limits<float>::infinity()),
+ max(std::numeric_limits<float>::infinity()),
+ bias(bias), late_bias(late_bias),
+ accumulator_buffer(accumulator_buffer),
+ flags(0x0)
+ {
+ if (accumulate)
+ {
+ flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER
+ }
+ if (C == nullptr)
+ {
+ flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER
+ }
+
+ // Initialise the activation values
+ switch (act.type)
+ {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ this->max = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ this->min = static_cast<float>(0);
+ break;
+ }
+ }
+
+ const int8_t *const A;
+ const int8_t *const B;
+ const long kstride_bytes;
+ float *const C;
+ const long ldcb;
+ const long M, N, K;
+ float min = -std::numeric_limits<float>::infinity();
+ float max = std::numeric_limits<float>::infinity();
+
+ const int32_t *const bias;
+ const float *const late_bias;
+
+ int32_t *const accumulator_buffer;
+ uint64_t flags;
+ };
+
+ // Construct arguments for this kernel
+ KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer);
+
+ __asm__ __volatile__(
+ "ldr x16, [%x[args], %[offsetof_flags]]\n"
+ ".inst 0xd503477f // SMSTART ZA\n"
+ "ptrue p1.b\n"
+ ".inst 0x25207810 // ptrue pn8.b\n"
+ "ldr x15, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "ldr x14, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "tbz x16, #0, 2f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "1:" // Initial accumulator load from buffer: Loop
+ ".inst 0xa040c1f4 // ld1w { z20.s-z23.s }, pn8.b/Z, [x15]\n"
+ ".inst 0xa041c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xa042c1e8 // ld1w { z8.s-z11.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c1f0 // ld1w { z16.s-z19.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840781 // mova za1h.s[x12], { z28.s-z31.s }\n"
+ ".inst 0xc0840502 // mova za2h.s[x12], { z8.s-z11.s }\n"
+ ".inst 0xc0840603 // mova za3h.s[x12], { z16.s-z19.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 1b\n"
+ "2:" // Initial accumulator load from buffer: End
+ "ldr w13, [%x[args], %[offsetof_M]]\n"
+ "mov x11, #0x0\n"
+ "mov x10, #0x0\n"
+ "ldr w9, [%x[args], %[offsetof_N]]\n"
+ "ldr x28, [%x[args], %[offsetof_A]]\n"
+ "3:" // M and N loop
+ "mov x27, x28\n"
+ "whilelt p0.s, x10, x9\n"
+ "tbnz x16, #0, 4f\n"
+ "ldr x20, [%x[args], %[offsetof_bias]]\n"
+ ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
+ "cbz x20, 5f\n"
+ "ld1w { z23.s }, p0/Z, [x20, x10, LSL #2]\n"
+ ".inst 0xc09026e0 // addha za0.s, p1/M, p1/M, z23.s\n"
+ ".inst 0xc09026e1 // addha za1.s, p1/M, p1/M, z23.s\n"
+ ".inst 0xc09026e2 // addha za2.s, p1/M, p1/M, z23.s\n"
+ ".inst 0xc09026e3 // addha za3.s, p1/M, p1/M, z23.s\n"
+ "4:" // Prepare accumulators: Test for last block
+ "mov x20, x10\n"
+ "mov x21, x11\n"
+ "incw x20\n"
+ "incw x21, ALL, MUL #4\n"
+ "cmp x20, x9\n"
+ "mov x20, x16\n"
+ "csel x21, x11, x21, LT\n"
+ "bfm x16, XZR, #0x0, #0x0 // bfc x16, #0x0, #0x1\n"
+ "cmp x21, x13\n"
+ "csel x16, x20, x16, LT\n"
+ "5:" // Prepare accumulators: End
+ "ldr x20, [%x[args], %[offsetof_K]]\n"
+ "ldr x23, [%x[args], %[offsetof_B]]\n"
+ "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n"
+ "add x20, x20, #0x3\n"
+ "lsr x20, x20, #0x2\n"
+ "lsr x21, x20, #0x2\n"
+ "madd x23, x10, x22, x23\n" // bptr = B + n * kstride_bytes
+ "and x20, x20, #0x3\n"
+ "cbz x21, 8f\n"
+ "subs x21, x21, #0x1\n"
+ ".inst 0xa0408378 // ld1b { z24.b-z27.b }, pn8.b/Z, [x27]\n"
+ "ld1b { z4.b }, p1/Z, [x23]\n"
+ ".inst 0xa0418374 // ld1b { z20.b-z23.b }, pn8.b/Z, [x27, #0x4, MUL VL]\n"
+ "ld1b { z2.b }, p1/Z, [x23, #1, MUL VL]\n"
+ ".inst 0xa042836c // ld1b { z12.b-z15.b }, pn8.b/Z, [x27, #0x8, MUL VL]\n"
+ "ld1b { z11.b }, p1/Z, [x23, #2, MUL VL]\n"
+ ".inst 0xa0438370 // ld1b { z16.b-z19.b }, pn8.b/Z, [x27, #0xc, MUL VL]\n"
+ "addvl x27, x27, #16\n"
+ "ld1b { z28.b }, p1/Z, [x23, #3, MUL VL]\n"
+ "addvl x23, x23, #4\n"
+ "ble 7f\n"
+ "6:" // K loop
+ ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n"
+ "subs x21, x21, #0x1\n"
+ ".inst 0xa0842721 // smopa za1.s, p1/M, p1/M, z25.b, z4.b\n"
+ ".inst 0xa0842742 // smopa za2.s, p1/M, p1/M, z26.b, z4.b\n"
+ ".inst 0xa0842763 // smopa za3.s, p1/M, p1/M, z27.b, z4.b\n"
+ ".inst 0xa0408378 // ld1b { z24.b-z27.b }, pn8.b/Z, [x27]\n"
+ ".inst 0xa0822680 // smopa za0.s, p1/M, p1/M, z20.b, z2.b\n"
+ "ld1b { z4.b }, p1/Z, [x23]\n"
+ ".inst 0xa08226a1 // smopa za1.s, p1/M, p1/M, z21.b, z2.b\n"
+ ".inst 0xa08226c2 // smopa za2.s, p1/M, p1/M, z22.b, z2.b\n"
+ ".inst 0xa08226e3 // smopa za3.s, p1/M, p1/M, z23.b, z2.b\n"
+ ".inst 0xa0418374 // ld1b { z20.b-z23.b }, pn8.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa08b2580 // smopa za0.s, p1/M, p1/M, z12.b, z11.b\n"
+ "ld1b { z2.b }, p1/Z, [x23, #1, MUL VL]\n"
+ ".inst 0xa08b25a1 // smopa za1.s, p1/M, p1/M, z13.b, z11.b\n"
+ ".inst 0xa08b25c2 // smopa za2.s, p1/M, p1/M, z14.b, z11.b\n"
+ ".inst 0xa08b25e3 // smopa za3.s, p1/M, p1/M, z15.b, z11.b\n"
+ ".inst 0xa042836c // ld1b { z12.b-z15.b }, pn8.b/Z, [x27, #0x8, MUL VL]\n"
+ "ld1b { z11.b }, p1/Z, [x23, #2, MUL VL]\n"
+ ".inst 0xa09c2600 // smopa za0.s, p1/M, p1/M, z16.b, z28.b\n"
+ ".inst 0xa09c2621 // smopa za1.s, p1/M, p1/M, z17.b, z28.b\n"
+ ".inst 0xa09c2642 // smopa za2.s, p1/M, p1/M, z18.b, z28.b\n"
+ ".inst 0xa09c2663 // smopa za3.s, p1/M, p1/M, z19.b, z28.b\n"
+ ".inst 0xa0438370 // ld1b { z16.b-z19.b }, pn8.b/Z, [x27, #0xc, MUL VL]\n"
+ "addvl x27, x27, #16\n"
+ "ld1b { z28.b }, p1/Z, [x23, #3, MUL VL]\n"
+ "addvl x23, x23, #4\n"
+ "bgt 6b\n"
+ "7:" // K loop tail
+ ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n"
+ ".inst 0xa0842721 // smopa za1.s, p1/M, p1/M, z25.b, z4.b\n"
+ ".inst 0xa0842742 // smopa za2.s, p1/M, p1/M, z26.b, z4.b\n"
+ ".inst 0xa0842763 // smopa za3.s, p1/M, p1/M, z27.b, z4.b\n"
+ ".inst 0xa0822680 // smopa za0.s, p1/M, p1/M, z20.b, z2.b\n"
+ ".inst 0xa08226a1 // smopa za1.s, p1/M, p1/M, z21.b, z2.b\n"
+ ".inst 0xa08226c2 // smopa za2.s, p1/M, p1/M, z22.b, z2.b\n"
+ ".inst 0xa08226e3 // smopa za3.s, p1/M, p1/M, z23.b, z2.b\n"
+ ".inst 0xa08b2580 // smopa za0.s, p1/M, p1/M, z12.b, z11.b\n"
+ ".inst 0xa08b25a1 // smopa za1.s, p1/M, p1/M, z13.b, z11.b\n"
+ ".inst 0xa08b25c2 // smopa za2.s, p1/M, p1/M, z14.b, z11.b\n"
+ ".inst 0xa08b25e3 // smopa za3.s, p1/M, p1/M, z15.b, z11.b\n"
+ ".inst 0xa09c2600 // smopa za0.s, p1/M, p1/M, z16.b, z28.b\n"
+ ".inst 0xa09c2621 // smopa za1.s, p1/M, p1/M, z17.b, z28.b\n"
+ ".inst 0xa09c2642 // smopa za2.s, p1/M, p1/M, z18.b, z28.b\n"
+ ".inst 0xa09c2663 // smopa za3.s, p1/M, p1/M, z19.b, z28.b\n"
+ "8:" // K oddments
+ "cbz x20, 10f\n"
+ "9:" // K oddments: Loop
+ ".inst 0xa1408373 // ld1b { z19.b, z23.b, z27.b, z31.b }, pn8.b/Z, [x27]\n"
+ "subs x20, x20, #0x1\n"
+ "addvl x27, x27, #4\n"
+ "ld1b { z16.b }, p1/Z, [x23]\n"
+ "addvl x23, x23, #1\n"
+ ".inst 0xa0902660 // smopa za0.s, p1/M, p1/M, z19.b, z16.b\n"
+ ".inst 0xa09026e1 // smopa za1.s, p1/M, p1/M, z23.b, z16.b\n"
+ ".inst 0xa0902762 // smopa za2.s, p1/M, p1/M, z27.b, z16.b\n"
+ ".inst 0xa09027e3 // smopa za3.s, p1/M, p1/M, z31.b, z16.b\n"
+ "bgt 9b\n"
+ "10:" // K oddments: End
+ "tbz x16, #1, 14f\n"
+ "tbz x16, #0, 12f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "11:" // Store to partial result buffer: Store and refill: Loop
+ ".inst 0xa040c1e8 // ld1w { z8.s-z11.s }, pn8.b/Z, [x15]\n"
+ ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n"
+ ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n"
+ ".inst 0xa041c1ec // ld1w { z12.s-z15.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xc0860458 // mova { z24.s-z27.s }, za2h.s[x12]\n"
+ ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n"
+ ".inst 0xa042c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c1f4 // ld1w { z20.s-z23.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840500 // mova za0h.s[x12], { z8.s-z11.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n"
+ ".inst 0xa060c1c0 // st1w { z0.s-z3.s }, pn8.b, [x14]\n"
+ ".inst 0xc0840782 // mova za2h.s[x12], { z28.s-z31.s }\n"
+ ".inst 0xa061c1c4 // st1w { z4.s-z7.s }, pn8.b, [x14, #0x4, MUL VL]\n"
+ ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa062c1d8 // st1w { z24.s-z27.s }, pn8.b, [x14, #0x8, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa063c1d0 // st1w { z16.s-z19.s }, pn8.b, [x14, #0xc, MUL VL]\n"
+ "addvl x14, x14, #16\n"
+ "blt 11b\n"
+ "b 30f\n"
+ "12:" // Store to partial result buffer: Store only
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "13:" // Store to partial result buffer: Store only: Loop
+ ".inst 0xc0860408 // mova { z8.s-z11.s }, za0h.s[x12]\n"
+ ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n"
+ ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n"
+ ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n"
+ ".inst 0xa060c1c8 // st1w { z8.s-z11.s }, pn8.b, [x14]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa061c1cc // st1w { z12.s-z15.s }, pn8.b, [x14, #0x4, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa062c1d4 // st1w { z20.s-z23.s }, pn8.b, [x14, #0x8, MUL VL]\n"
+ ".inst 0xa063c1d0 // st1w { z16.s-z19.s }, pn8.b, [x14, #0xc, MUL VL]\n"
+ "addvl x14, x14, #16\n"
+ "blt 13b\n"
+ "b 30f\n"
+ "14:" // Store to output array
+ "ldr x26, [%x[args], %[offsetof_C]]\n"
+ "sub x25, x13, x11\n"
+ "ld1rw { z23.s }, p1/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n"
+ "fmov z22.s, #0x0\n"
+ "ldr x24, [%x[args], %[offsetof_ldcb]]\n"
+ "ldr x20, [%x[args], %[offsetof_late_bias]]\n"
+ "add x26, x26, x10, LSL #2\n" // C += n
+ "madd x26, x11, x24, x26\n" // C += m * ldc
+ "cbz x20, 15f\n"
+ "add x20, x20, x10, LSL #2\n"
+ "ld1w { z22.s }, p0/Z, [x20]\n"
+ "15:" // Store to output array: no late bias
+ "cntw x23\n"
+ "ld1rw { z21.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n"
+ "mov x12, #0x0\n"
+ "cmp x25, x23\n"
+ "ld1rw { z20.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n"
+ "csel x22, x25, x23, LT\n"
+ "lsr x21, x22, #0x2\n"
+ "and x20, x22, #0x3\n"
+ "cbz x21, 17f\n"
+ "16:" // Store to output array: Accumulator row 0 loop
+ ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z0.s, p1/M, z23.s, z22.s\n"
+ "fmad z1.s, p1/M, z23.s, z22.s\n"
+ "fmad z2.s, p1/M, z23.s, z22.s\n"
+ "fmad z3.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4caa0 // fclamp { z0.s-z3.s }, z21.s, z20.s\n"
+ "st1w { z0.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z1.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z2.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z3.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 16b\n"
+ "17:" // Store to output array: Accumulator row 0 oddments
+ "cbz x20, 18f\n"
+ ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n"
+ "fmad z16.s, p1/M, z23.s, z22.s\n"
+ "fmad z17.s, p1/M, z23.s, z22.s\n"
+ "fmad z18.s, p1/M, z23.s, z22.s\n"
+ "fmad z19.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n"
+ "st1w { z16.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 18f\n"
+ "subs x20, x20, #0x1\n"
+ "st1w { z17.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 18f\n"
+ "st1w { z18.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "18:" // Store to output array: Accumulator row 0 oddments: End
+ "subs x25, x25, x22\n"
+ "beq 28f\n"
+ "cmp x25, x23\n"
+ "mov x12, #0x0\n"
+ "csel x22, x25, x23, LT\n"
+ "lsr x21, x22, #0x2\n"
+ "and x20, x22, #0x3\n"
+ "cbz x21, 20f\n"
+ "19:" // Store to output array: Accumulator row 1 loop
+ ".inst 0xc0860430 // mova { z16.s-z19.s }, za1h.s[x12]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z16.s, p1/M, z23.s, z22.s\n"
+ "fmad z17.s, p1/M, z23.s, z22.s\n"
+ "fmad z18.s, p1/M, z23.s, z22.s\n"
+ "fmad z19.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n"
+ "st1w { z16.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z17.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z18.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z19.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 19b\n"
+ "20:" // Store to output array: Accumulator row 1 oddments
+ "cbz x20, 21f\n"
+ ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n"
+ "fmad z28.s, p1/M, z23.s, z22.s\n"
+ "fmad z29.s, p1/M, z23.s, z22.s\n"
+ "fmad z30.s, p1/M, z23.s, z22.s\n"
+ "fmad z31.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cabc // fclamp { z28.s-z31.s }, z21.s, z20.s\n"
+ "st1w { z28.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 21f\n"
+ "subs x20, x20, #0x1\n"
+ "st1w { z29.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 21f\n"
+ "st1w { z30.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "21:" // Store to output array: Accumulator row 1 oddments: End
+ "subs x25, x25, x22\n"
+ "beq 28f\n"
+ "cmp x25, x23\n"
+ "mov x12, #0x0\n"
+ "csel x22, x25, x23, LT\n"
+ "lsr x21, x22, #0x2\n"
+ "and x20, x22, #0x3\n"
+ "cbz x21, 23f\n"
+ "22:" // Store to output array: Accumulator row 2 loop
+ ".inst 0xc086044c // mova { z12.s-z15.s }, za2h.s[x12]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z12.s, p1/M, z23.s, z22.s\n"
+ "fmad z13.s, p1/M, z23.s, z22.s\n"
+ "fmad z14.s, p1/M, z23.s, z22.s\n"
+ "fmad z15.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4caac // fclamp { z12.s-z15.s }, z21.s, z20.s\n"
+ "st1w { z12.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z13.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z14.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z15.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 22b\n"
+ "23:" // Store to output array: Accumulator row 2 oddments
+ "cbz x20, 24f\n"
+ ".inst 0xc0860450 // mova { z16.s-z19.s }, za2h.s[x12]\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n"
+ "fmad z16.s, p1/M, z23.s, z22.s\n"
+ "fmad z17.s, p1/M, z23.s, z22.s\n"
+ "fmad z18.s, p1/M, z23.s, z22.s\n"
+ "fmad z19.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n"
+ "st1w { z16.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 24f\n"
+ "subs x20, x20, #0x1\n"
+ "st1w { z17.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 24f\n"
+ "st1w { z18.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "24:" // Store to output array: Accumulator row 2 oddments: End
+ "subs x25, x25, x22\n"
+ "beq 28f\n"
+ "cmp x25, x23\n"
+ "mov x12, #0x0\n"
+ "csel x20, x25, x23, LT\n"
+ "lsr x21, x20, #0x2\n"
+ "and x20, x20, #0x3\n"
+ "cbz x21, 26f\n"
+ "25:" // Store to output array: Accumulator row 3 loop
+ ".inst 0xc0860478 // mova { z24.s-z27.s }, za3h.s[x12]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z24.s, p1/M, z23.s, z22.s\n"
+ "fmad z25.s, p1/M, z23.s, z22.s\n"
+ "fmad z26.s, p1/M, z23.s, z22.s\n"
+ "fmad z27.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cab8 // fclamp { z24.s-z27.s }, z21.s, z20.s\n"
+ "st1w { z24.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z25.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z26.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z27.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 25b\n"
+ "26:" // Store to output array: Accumulator row 3 oddments
+ "cbz x20, 27f\n"
+ ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n"
+ "fmad z16.s, p1/M, z23.s, z22.s\n"
+ "fmad z17.s, p1/M, z23.s, z22.s\n"
+ "fmad z18.s, p1/M, z23.s, z22.s\n"
+ "fmad z19.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n"
+ "st1w { z16.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 27f\n"
+ "subs x20, x20, #0x1\n"
+ "st1w { z17.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 27f\n"
+ "st1w { z18.s }, p0, [x26]\n"
+ "27:" // Store to output array: Accumulator row 3 oddments: End
+ "28:" // Store to output array: End
+ "tbz x16, #0, 30f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "29:" // Store to output array: Refill accumulators: Loop
+ ".inst 0xa040c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15]\n"
+ ".inst 0xa041c1e0 // ld1w { z0.s-z3.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xa042c1ec // ld1w { z12.s-z15.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c1e4 // ld1w { z4.s-z7.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840780 // mova za0h.s[x12], { z28.s-z31.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840401 // mova za1h.s[x12], { z0.s-z3.s }\n"
+ ".inst 0xc0840582 // mova za2h.s[x12], { z12.s-z15.s }\n"
+ ".inst 0xc0840483 // mova za3h.s[x12], { z4.s-z7.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 29b\n"
+ "30:" // End block
+ "incw x10\n"
+ "cmp x10, x9\n"
+ "blt 3b\n"
+ "incw x11, ALL, MUL #4\n"
+ "mov x10, #0x0\n"
+ "cmp x11, x13\n"
+ "mov x28, x27\n"
+ "blt 3b\n"
+ ".inst 0xd503467f // SMSTOP\n"
+ :
+ : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb))
+ : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+ );
+}
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp
index 887d78e1d..23f686a90 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -88,8 +88,10 @@ public:
{
if (std::is_same<T, float>::value) {
switch (ci->get_cpu_model()) {
+ case CPUModel::V1:
+ return { 28.74 };
default:
- return { 32.35 };
+ return { 15.27 };
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp
index d0ef531c3..1fe5f48da 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -88,8 +88,10 @@ public:
if (std::is_same<T, float>::value) {
switch (ci->get_cpu_model()) {
- default:
- return { 39.66, 5.18, 4.37 };
+ case CPUModel::V1:
+ return { 53.48, 4.23, 6.53 };
+ default:
+ return { 29.07, 2.76, 5.39 };
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp b/src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp
index a4124c4a5..d3665534a 100644
--- a/src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp
+++ b/src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,6 +30,7 @@
#include "arm_gemm.hpp"
#include "asmlib.hpp"
+#include "bfloat.hpp"
#include "utils.hpp"
#include "mergeresults.hpp"
diff --git a/src/core/NEON/kernels/arm_gemm/mergeresults.cpp b/src/core/NEON/kernels/arm_gemm/mergeresults.cpp
index 2b712cee6..e100d9fe4 100644
--- a/src/core/NEON/kernels/arm_gemm/mergeresults.cpp
+++ b/src/core/NEON/kernels/arm_gemm/mergeresults.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,6 +30,7 @@
#include "arm_gemm.hpp"
#include "asmlib.hpp"
+#include "bfloat.hpp"
#include "utils.hpp"
namespace arm_gemm {
@@ -114,4 +115,8 @@ template void MergeResults<12u, 8u, false, float, __fp16>(__fp16*, float const*,
template void MergeResults<8u, 6u, false, float, __fp16>(__fp16*, float const*, int, int, int, int, int, __fp16 const*, Activation, bool);
#endif
+#if defined(__arm__) && defined(ARM_COMPUTE_ENABLE_BF16)
+template void MergeResults<8u, 6u, false, float, bfloat16>(bfloat16*, float const*, int, int, int, int, int, bfloat16 const*, Activation, bool);
+#endif
+
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp32_bf16_8x12.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp32_bf16_8x12.hpp
new file mode 100644
index 000000000..a57a855e3
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp32_bf16_8x12.hpp
@@ -0,0 +1,2809 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#ifdef __aarch64__
+
+template<>
+void MergeResults<12, 8, false>(
+ bfloat16 *out_ptr,
+ const float * in_ptr,
+ const int ldout,
+ const int y0, const int ymax,
+ const int x0, const int xmax,
+ const bfloat16 *bias,
+ Activation act,
+ bool accumulate)
+{
+ float maxval = static_cast<float>(std::numeric_limits<float>::infinity());
+ float minval = - static_cast<float>(std::numeric_limits<float>::infinity());
+
+ switch(act.type) {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ maxval = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ minval = 0;
+ break;
+ }
+
+ size_t rows = ymax-y0;
+ size_t cols = xmax-x0;
+
+ out_ptr += (y0 * ldout) + x0;
+ bias = (bias == nullptr) ? nullptr : bias + x0;
+
+ __asm__ __volatile__(
+ "cbz %x[cols], 108f\n"
+ "cbz %x[rows], 108f\n"
+ "mov x11, #0x20\n"
+ "dup v13.4s, %w[maxval]\n"
+ "dup v12.4s, %w[minval]\n"
+ "mul x11, %x[ldout], x11\n"
+ "cbnz %x[accumulate], 66f\n"
+ "1:" // Initial: Row loop
+ "cmp %x[rows], #0x7\n"
+ "bgt 58f\n"
+ "beq 50f\n"
+ "cmp %x[rows], #0x5\n"
+ "bgt 42f\n"
+ "beq 34f\n"
+ "cmp %x[rows], #0x3\n"
+ "bgt 26f\n"
+ "beq 18f\n"
+ "cmp %x[rows], #0x1\n"
+ "bgt 10f\n"
+ "2:" // Initial: Height 1
+ "mov x10, %x[cols]\n"
+ "mov x9, %x[out_ptr]\n"
+ "mov x28, %x[bias]\n"
+ "cmp x10, #0xc\n"
+ "blt 6f\n"
+ "3:" // Initial: Height 1: Block loop
+ "cbnz %x[bias], 4f\n"
+ "movi v21.16b, #0x0\n"
+ "movi v20.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "b 5f\n"
+ "4:" // Initial: Height 1: Width 3: bias
+ "ldr d18, [x28, #0x0]\n"
+ "ldr d17, [x28, #0x8]\n"
+ "ldr d16, [x28, #0x10]\n"
+ "shll v21.4s, v18.4h, #0x10\n"
+ "shll v20.4s, v17.4h, #0x10\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "5:" // Initial: Height 1: Width 3: init done
+ "ldr q18, [%x[in_ptr], #0x0]\n"
+ "ldr q17, [%x[in_ptr], #0x10]\n"
+ "sub x10, x10, #0xc\n"
+ "add x28, x28, #0x18\n"
+ "ldr q16, [%x[in_ptr], #0x20]\n"
+ "cmp x10, #0xc\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fadd v18.4s, v18.4s, v21.4s\n"
+ "fadd v17.4s, v17.4s, v20.4s\n"
+ "fadd v16.4s, v16.4s, v19.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ "str d18, [x9, #0x0]\n"
+ "str d17, [x9, #0x8]\n"
+ "str d16, [x9, #0x10]\n"
+ "add x9, x9, #0x18\n"
+ "bge 3b\n"
+ "6:" // Initial: Height 1: no full blocks
+ "cbz x10, 9f\n"
+ "mov x20, %x[in_ptr]\n"
+ "7:" // Initial: Height 1: Single loop
+ "movi v17.16b, #0x0\n"
+ "cbz %x[bias], 8f\n"
+ "ldr h16, [x28, #0x0]\n"
+ "shll v17.4s, v16.4h, #0x10\n"
+ "8:" // Initial: Height 1: Scalar: no bias
+ "ldr s16, [%x[in_ptr], #0x0]\n"
+ "subs x10, x10, #0x1\n"
+ "add x28, x28, #0x2\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v16.4s, v16.4s, v17.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ "str h16, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ "bne 7b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "9:" // Initial: Height 1: no oddments
+ "b 108f\n"
+ "10:" // Initial: Height 2
+ "mov x10, %x[cols]\n"
+ "mov x9, %x[out_ptr]\n"
+ "mov x28, %x[bias]\n"
+ "cmp x10, #0xc\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "blt 14f\n"
+ "11:" // Initial: Height 2: Block loop
+ "cbnz %x[bias], 12f\n"
+ "movi v24.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "movi v22.16b, #0x0\n"
+ "b 13f\n"
+ "12:" // Initial: Height 2: Width 3: bias
+ "ldr d18, [x28, #0x0]\n"
+ "ldr d17, [x28, #0x8]\n"
+ "ldr d16, [x28, #0x10]\n"
+ "shll v24.4s, v18.4h, #0x10\n"
+ "shll v23.4s, v17.4h, #0x10\n"
+ "shll v22.4s, v16.4h, #0x10\n"
+ "13:" // Initial: Height 2: Width 3: init done
+ "ldr q16, [%x[in_ptr], #0x0]\n"
+ "ldr q20, [%x[in_ptr], #0x10]\n"
+ "sub x10, x10, #0xc\n"
+ "add x28, x28, #0x18\n"
+ "ldr q19, [%x[in_ptr], #0x20]\n"
+ "ldr q18, [%x[in_ptr], #0x30]\n"
+ "cmp x10, #0xc\n"
+ "ldr q17, [%x[in_ptr], #0x40]\n"
+ "ldr q21, [%x[in_ptr], #0x50]\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fadd v16.4s, v16.4s, v24.4s\n"
+ "fadd v20.4s, v20.4s, v23.4s\n"
+ "fadd v19.4s, v19.4s, v22.4s\n"
+ "fadd v18.4s, v18.4s, v24.4s\n"
+ "fadd v17.4s, v17.4s, v23.4s\n"
+ "fadd v21.4s, v21.4s, v22.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ "str d16, [x9, #0x0]\n"
+ ".inst 0x0ea16ab0 // bfcvtn v16.4h, v21.4s\n"
+ "str d20, [x9, #0x8]\n"
+ "str d19, [x9, #0x10]\n"
+ "add x9, x9, #0x18\n"
+ "str d18, [x27, #0x0]\n"
+ "str d17, [x27, #0x8]\n"
+ "str d16, [x27, #0x10]\n"
+ "add x27, x27, #0x18\n"
+ "bge 11b\n"
+ "14:" // Initial: Height 2: no full blocks
+ "cbz x10, 17f\n"
+ "mov x20, %x[in_ptr]\n"
+ "15:" // Initial: Height 2: Single loop
+ "movi v18.16b, #0x0\n"
+ "cbz %x[bias], 16f\n"
+ "ldr h16, [x28, #0x0]\n"
+ "shll v18.4s, v16.4h, #0x10\n"
+ "16:" // Initial: Height 2: Scalar: no bias
+ "ldr s17, [%x[in_ptr], #0x0]\n"
+ "ldr s16, [%x[in_ptr], #0x30]\n"
+ "subs x10, x10, #0x1\n"
+ "add x28, x28, #0x2\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v17.4s, v17.4s, v18.4s\n"
+ "fadd v16.4s, v16.4s, v18.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ "str h17, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ "str h16, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "bne 15b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "17:" // Initial: Height 2: no oddments
+ "b 108f\n"
+ "18:" // Initial: Height 3
+ "mov x10, %x[cols]\n"
+ "mov x9, %x[out_ptr]\n"
+ "mov x28, %x[bias]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "blt 22f\n"
+ "19:" // Initial: Height 3: Block loop
+ "cbnz %x[bias], 20f\n"
+ "movi v27.16b, #0x0\n"
+ "movi v26.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "b 21f\n"
+ "20:" // Initial: Height 3: Width 3: bias
+ "ldr d18, [x28, #0x0]\n"
+ "ldr d17, [x28, #0x8]\n"
+ "ldr d16, [x28, #0x10]\n"
+ "shll v27.4s, v18.4h, #0x10\n"
+ "shll v26.4s, v17.4h, #0x10\n"
+ "shll v25.4s, v16.4h, #0x10\n"
+ "21:" // Initial: Height 3: Width 3: init done
+ "ldr q18, [%x[in_ptr], #0x0]\n"
+ "ldr q17, [%x[in_ptr], #0x10]\n"
+ "sub x10, x10, #0xc\n"
+ "add x28, x28, #0x18\n"
+ "ldr q16, [%x[in_ptr], #0x20]\n"
+ "ldr q21, [%x[in_ptr], #0x30]\n"
+ "cmp x10, #0xc\n"
+ "ldr q20, [%x[in_ptr], #0x40]\n"
+ "ldr q19, [%x[in_ptr], #0x50]\n"
+ "ldr q24, [%x[in_ptr], #0x60]\n"
+ "ldr q23, [%x[in_ptr], #0x70]\n"
+ "fadd v18.4s, v18.4s, v27.4s\n"
+ "fadd v17.4s, v17.4s, v26.4s\n"
+ "ldr q22, [%x[in_ptr], #0x80]\n"
+ "fadd v16.4s, v16.4s, v25.4s\n"
+ "fadd v21.4s, v21.4s, v27.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fadd v20.4s, v20.4s, v26.4s\n"
+ "fadd v19.4s, v19.4s, v25.4s\n"
+ "fadd v24.4s, v24.4s, v27.4s\n"
+ "fadd v23.4s, v23.4s, v26.4s\n"
+ "fadd v22.4s, v22.4s, v25.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ "str d18, [x9, #0x0]\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16b12 // bfcvtn v18.4h, v24.4s\n"
+ "str d17, [x9, #0x8]\n"
+ "str d16, [x9, #0x10]\n"
+ ".inst 0x0ea16af1 // bfcvtn v17.4h, v23.4s\n"
+ ".inst 0x0ea16ad0 // bfcvtn v16.4h, v22.4s\n"
+ "add x9, x9, #0x18\n"
+ "str d21, [x27, #0x0]\n"
+ "str d20, [x27, #0x8]\n"
+ "str d19, [x27, #0x10]\n"
+ "add x27, x27, #0x18\n"
+ "str d18, [x26, #0x0]\n"
+ "str d17, [x26, #0x8]\n"
+ "str d16, [x26, #0x10]\n"
+ "add x26, x26, #0x18\n"
+ "bge 19b\n"
+ "22:" // Initial: Height 3: no full blocks
+ "cbz x10, 25f\n"
+ "mov x20, %x[in_ptr]\n"
+ "23:" // Initial: Height 3: Single loop
+ "movi v19.16b, #0x0\n"
+ "cbz %x[bias], 24f\n"
+ "ldr h16, [x28, #0x0]\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "24:" // Initial: Height 3: Scalar: no bias
+ "ldr s16, [%x[in_ptr], #0x0]\n"
+ "ldr s17, [%x[in_ptr], #0x30]\n"
+ "subs x10, x10, #0x1\n"
+ "add x28, x28, #0x2\n"
+ "ldr s18, [%x[in_ptr], #0x60]\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v16.4s, v16.4s, v19.4s\n"
+ "fadd v17.4s, v17.4s, v19.4s\n"
+ "fadd v18.4s, v18.4s, v19.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ "str h16, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ ".inst 0x0ea16a50 // bfcvtn v16.4h, v18.4s\n"
+ "str h17, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "str h16, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "bne 23b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "25:" // Initial: Height 3: no oddments
+ "b 108f\n"
+ "26:" // Initial: Height 4
+ "mov x9, %x[out_ptr]\n"
+ "mov x10, %x[cols]\n"
+ "mov x28, %x[bias]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "blt 30f\n"
+ "27:" // Initial: Height 4: Block loop
+ "cbnz %x[bias], 28f\n"
+ "movi v30.16b, #0x0\n"
+ "movi v29.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "b 29f\n"
+ "28:" // Initial: Height 4: Width 3: bias
+ "ldr d18, [x28, #0x0]\n"
+ "ldr d17, [x28, #0x8]\n"
+ "ldr d16, [x28, #0x10]\n"
+ "shll v30.4s, v18.4h, #0x10\n"
+ "shll v29.4s, v17.4h, #0x10\n"
+ "shll v28.4s, v16.4h, #0x10\n"
+ "29:" // Initial: Height 4: Width 3: init done
+ "ldr q19, [%x[in_ptr], #0x0]\n"
+ "ldr q18, [%x[in_ptr], #0x10]\n"
+ "sub x10, x10, #0xc\n"
+ "add x28, x28, #0x18\n"
+ "ldr q17, [%x[in_ptr], #0x20]\n"
+ "ldr q16, [%x[in_ptr], #0x30]\n"
+ "cmp x10, #0xc\n"
+ "ldr q23, [%x[in_ptr], #0x40]\n"
+ "ldr q22, [%x[in_ptr], #0x50]\n"
+ "ldr q21, [%x[in_ptr], #0x60]\n"
+ "ldr q20, [%x[in_ptr], #0x70]\n"
+ "fadd v19.4s, v19.4s, v30.4s\n"
+ "fadd v18.4s, v18.4s, v29.4s\n"
+ "ldr q27, [%x[in_ptr], #0x80]\n"
+ "ldr q26, [%x[in_ptr], #0x90]\n"
+ "fadd v17.4s, v17.4s, v28.4s\n"
+ "fadd v16.4s, v16.4s, v30.4s\n"
+ "ldr q25, [%x[in_ptr], #0xa0]\n"
+ "ldr q24, [%x[in_ptr], #0xb0]\n"
+ "fadd v23.4s, v23.4s, v29.4s\n"
+ "fadd v22.4s, v22.4s, v28.4s\n"
+ "fadd v21.4s, v21.4s, v30.4s\n"
+ "fadd v20.4s, v20.4s, v29.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fadd v27.4s, v27.4s, v28.4s\n"
+ "fadd v26.4s, v26.4s, v30.4s\n"
+ "fadd v25.4s, v25.4s, v29.4s\n"
+ "fadd v24.4s, v24.4s, v28.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16af7 // bfcvtn v23.4h, v23.4s\n"
+ ".inst 0x0ea16ad6 // bfcvtn v22.4h, v22.4s\n"
+ "str d19, [x9, #0x0]\n"
+ "str d18, [x9, #0x8]\n"
+ ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ "str d17, [x9, #0x10]\n"
+ ".inst 0x0ea16b73 // bfcvtn v19.4h, v27.4s\n"
+ ".inst 0x0ea16b52 // bfcvtn v18.4h, v26.4s\n"
+ "add x9, x9, #0x18\n"
+ "str d16, [x27, #0x0]\n"
+ ".inst 0x0ea16b31 // bfcvtn v17.4h, v25.4s\n"
+ ".inst 0x0ea16b10 // bfcvtn v16.4h, v24.4s\n"
+ "str d23, [x27, #0x8]\n"
+ "str d22, [x27, #0x10]\n"
+ "add x27, x27, #0x18\n"
+ "str d21, [x26, #0x0]\n"
+ "str d20, [x26, #0x8]\n"
+ "str d19, [x26, #0x10]\n"
+ "add x26, x26, #0x18\n"
+ "str d18, [x25, #0x0]\n"
+ "str d17, [x25, #0x8]\n"
+ "str d16, [x25, #0x10]\n"
+ "add x25, x25, #0x18\n"
+ "bge 27b\n"
+ "30:" // Initial: Height 4: no full blocks
+ "cbz x10, 33f\n"
+ "mov x20, %x[in_ptr]\n"
+ "31:" // Initial: Height 4: Single loop
+ "movi v20.16b, #0x0\n"
+ "cbz %x[bias], 32f\n"
+ "ldr h16, [x28, #0x0]\n"
+ "shll v20.4s, v16.4h, #0x10\n"
+ "32:" // Initial: Height 4: Scalar: no bias
+ "ldr s16, [%x[in_ptr], #0x0]\n"
+ "ldr s18, [%x[in_ptr], #0x30]\n"
+ "subs x10, x10, #0x1\n"
+ "add x28, x28, #0x2\n"
+ "ldr s17, [%x[in_ptr], #0x60]\n"
+ "ldr s19, [%x[in_ptr], #0x90]\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v16.4s, v16.4s, v20.4s\n"
+ "fadd v18.4s, v18.4s, v20.4s\n"
+ "fadd v17.4s, v17.4s, v20.4s\n"
+ "fadd v19.4s, v19.4s, v20.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ "str h16, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ ".inst 0x0ea16a70 // bfcvtn v16.4h, v19.4s\n"
+ "str h18, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "str h17, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "str h16, [x25, #0x0]\n"
+ "add x25, x25, #0x2\n"
+ "bne 31b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "33:" // Initial: Height 4: no oddments
+ "b 108f\n"
+ "34:" // Initial: Height 5
+ "mov x9, %x[out_ptr]\n"
+ "mov x10, %x[cols]\n"
+ "mov x28, %x[bias]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "blt 38f\n"
+ "35:" // Initial: Height 5: Block loop
+ "cbnz %x[bias], 36f\n"
+ "movi v1.16b, #0x0\n"
+ "movi v0.16b, #0x0\n"
+ "movi v31.16b, #0x0\n"
+ "b 37f\n"
+ "36:" // Initial: Height 5: Width 3: bias
+ "ldr d18, [x28, #0x0]\n"
+ "ldr d17, [x28, #0x8]\n"
+ "ldr d16, [x28, #0x10]\n"
+ "shll v1.4s, v18.4h, #0x10\n"
+ "shll v0.4s, v17.4h, #0x10\n"
+ "shll v31.4s, v16.4h, #0x10\n"
+ "37:" // Initial: Height 5: Width 3: init done
+ "ldr q16, [%x[in_ptr], #0x0]\n"
+ "ldr q20, [%x[in_ptr], #0x10]\n"
+ "sub x10, x10, #0xc\n"
+ "add x28, x28, #0x18\n"
+ "ldr q19, [%x[in_ptr], #0x20]\n"
+ "ldr q18, [%x[in_ptr], #0x30]\n"
+ "cmp x10, #0xc\n"
+ "ldr q17, [%x[in_ptr], #0x40]\n"
+ "ldr q30, [%x[in_ptr], #0x50]\n"
+ "ldr q24, [%x[in_ptr], #0x60]\n"
+ "ldr q23, [%x[in_ptr], #0x70]\n"
+ "fadd v16.4s, v16.4s, v1.4s\n"
+ "fadd v20.4s, v20.4s, v0.4s\n"
+ "ldr q22, [%x[in_ptr], #0x80]\n"
+ "ldr q21, [%x[in_ptr], #0x90]\n"
+ "fadd v19.4s, v19.4s, v31.4s\n"
+ "fadd v18.4s, v18.4s, v1.4s\n"
+ "ldr q29, [%x[in_ptr], #0xa0]\n"
+ "ldr q28, [%x[in_ptr], #0xb0]\n"
+ "fadd v17.4s, v17.4s, v0.4s\n"
+ "fadd v30.4s, v30.4s, v31.4s\n"
+ "ldr q27, [%x[in_ptr], #0xc0]\n"
+ "ldr q26, [%x[in_ptr], #0xd0]\n"
+ "fadd v24.4s, v24.4s, v1.4s\n"
+ "fadd v23.4s, v23.4s, v0.4s\n"
+ "ldr q25, [%x[in_ptr], #0xe0]\n"
+ "fadd v22.4s, v22.4s, v31.4s\n"
+ "fadd v21.4s, v21.4s, v1.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fadd v29.4s, v29.4s, v0.4s\n"
+ "fadd v28.4s, v28.4s, v31.4s\n"
+ "fadd v27.4s, v27.4s, v1.4s\n"
+ "fadd v26.4s, v26.4s, v0.4s\n"
+ "fadd v25.4s, v25.4s, v31.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v30.4s, v30.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fmin v28.4s, v28.4s, v13.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v30.4s, v30.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmax v28.4s, v28.4s, v12.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ "str d16, [x9, #0x0]\n"
+ ".inst 0x0ea16bd0 // bfcvtn v16.4h, v30.4s\n"
+ ".inst 0x0ea16b18 // bfcvtn v24.4h, v24.4s\n"
+ "str d20, [x9, #0x8]\n"
+ "str d19, [x9, #0x10]\n"
+ ".inst 0x0ea16af7 // bfcvtn v23.4h, v23.4s\n"
+ ".inst 0x0ea16ad6 // bfcvtn v22.4h, v22.4s\n"
+ "add x9, x9, #0x18\n"
+ "str d18, [x27, #0x0]\n"
+ ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n"
+ ".inst 0x0ea16bb4 // bfcvtn v20.4h, v29.4s\n"
+ "str d17, [x27, #0x8]\n"
+ ".inst 0x0ea16b93 // bfcvtn v19.4h, v28.4s\n"
+ ".inst 0x0ea16b72 // bfcvtn v18.4h, v27.4s\n"
+ "str d16, [x27, #0x10]\n"
+ ".inst 0x0ea16b51 // bfcvtn v17.4h, v26.4s\n"
+ ".inst 0x0ea16b30 // bfcvtn v16.4h, v25.4s\n"
+ "add x27, x27, #0x18\n"
+ "str d24, [x26, #0x0]\n"
+ "str d23, [x26, #0x8]\n"
+ "str d22, [x26, #0x10]\n"
+ "add x26, x26, #0x18\n"
+ "str d21, [x25, #0x0]\n"
+ "str d20, [x25, #0x8]\n"
+ "str d19, [x25, #0x10]\n"
+ "add x25, x25, #0x18\n"
+ "str d18, [x24, #0x0]\n"
+ "str d17, [x24, #0x8]\n"
+ "str d16, [x24, #0x10]\n"
+ "add x24, x24, #0x18\n"
+ "bge 35b\n"
+ "38:" // Initial: Height 5: no full blocks
+ "cbz x10, 41f\n"
+ "mov x20, %x[in_ptr]\n"
+ "39:" // Initial: Height 5: Single loop
+ "movi v21.16b, #0x0\n"
+ "cbz %x[bias], 40f\n"
+ "ldr h16, [x28, #0x0]\n"
+ "shll v21.4s, v16.4h, #0x10\n"
+ "40:" // Initial: Height 5: Scalar: no bias
+ "ldr s16, [%x[in_ptr], #0x0]\n"
+ "ldr s19, [%x[in_ptr], #0x30]\n"
+ "subs x10, x10, #0x1\n"
+ "add x28, x28, #0x2\n"
+ "ldr s18, [%x[in_ptr], #0x60]\n"
+ "ldr s17, [%x[in_ptr], #0x90]\n"
+ "ldr s20, [%x[in_ptr], #0xc0]\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v16.4s, v16.4s, v21.4s\n"
+ "fadd v19.4s, v19.4s, v21.4s\n"
+ "fadd v18.4s, v18.4s, v21.4s\n"
+ "fadd v17.4s, v17.4s, v21.4s\n"
+ "fadd v20.4s, v20.4s, v21.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ "str h16, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ ".inst 0x0ea16a90 // bfcvtn v16.4h, v20.4s\n"
+ "str h19, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "str h18, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "str h17, [x25, #0x0]\n"
+ "add x25, x25, #0x2\n"
+ "str h16, [x24, #0x0]\n"
+ "add x24, x24, #0x2\n"
+ "bne 39b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "41:" // Initial: Height 5: no oddments
+ "b 108f\n"
+ "42:" // Initial: Height 6
+ "mov x9, %x[out_ptr]\n"
+ "mov x10, %x[cols]\n"
+ "mov x28, %x[bias]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "add x23, x24, %x[ldout], LSL #1\n"
+ "blt 46f\n"
+ "43:" // Initial: Height 6: Block loop
+ "cbnz %x[bias], 44f\n"
+ "movi v4.16b, #0x0\n"
+ "movi v3.16b, #0x0\n"
+ "movi v2.16b, #0x0\n"
+ "b 45f\n"
+ "44:" // Initial: Height 6: Width 3: bias
+ "ldr d18, [x28, #0x0]\n"
+ "ldr d17, [x28, #0x8]\n"
+ "ldr d16, [x28, #0x10]\n"
+ "shll v4.4s, v18.4h, #0x10\n"
+ "shll v3.4s, v17.4h, #0x10\n"
+ "shll v2.4s, v16.4h, #0x10\n"
+ "45:" // Initial: Height 6: Width 3: init done
+ "ldr q21, [%x[in_ptr], #0x0]\n"
+ "ldr q16, [%x[in_ptr], #0x10]\n"
+ "sub x10, x10, #0xc\n"
+ "add x28, x28, #0x18\n"
+ "ldr q20, [%x[in_ptr], #0x20]\n"
+ "ldr q19, [%x[in_ptr], #0x30]\n"
+ "cmp x10, #0xc\n"
+ "ldr q18, [%x[in_ptr], #0x40]\n"
+ "ldr q17, [%x[in_ptr], #0x50]\n"
+ "ldr q1, [%x[in_ptr], #0x60]\n"
+ "ldr q26, [%x[in_ptr], #0x70]\n"
+ "fadd v21.4s, v21.4s, v4.4s\n"
+ "fadd v16.4s, v16.4s, v3.4s\n"
+ "ldr q25, [%x[in_ptr], #0x80]\n"
+ "ldr q24, [%x[in_ptr], #0x90]\n"
+ "fadd v20.4s, v20.4s, v2.4s\n"
+ "fadd v19.4s, v19.4s, v4.4s\n"
+ "ldr q23, [%x[in_ptr], #0xa0]\n"
+ "ldr q22, [%x[in_ptr], #0xb0]\n"
+ "fadd v18.4s, v18.4s, v3.4s\n"
+ "fadd v17.4s, v17.4s, v2.4s\n"
+ "ldr q0, [%x[in_ptr], #0xc0]\n"
+ "ldr q31, [%x[in_ptr], #0xd0]\n"
+ "fadd v1.4s, v1.4s, v4.4s\n"
+ "fadd v26.4s, v26.4s, v3.4s\n"
+ "ldr q30, [%x[in_ptr], #0xe0]\n"
+ "ldr q29, [%x[in_ptr], #0xf0]\n"
+ "fadd v25.4s, v25.4s, v2.4s\n"
+ "fadd v24.4s, v24.4s, v4.4s\n"
+ "ldr q28, [%x[in_ptr], #0x100]\n"
+ "ldr q27, [%x[in_ptr], #0x110]\n"
+ "fadd v23.4s, v23.4s, v3.4s\n"
+ "fadd v22.4s, v22.4s, v2.4s\n"
+ "fadd v0.4s, v0.4s, v4.4s\n"
+ "fadd v31.4s, v31.4s, v3.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fadd v30.4s, v30.4s, v2.4s\n"
+ "fadd v29.4s, v29.4s, v4.4s\n"
+ "fadd v28.4s, v28.4s, v3.4s\n"
+ "fadd v27.4s, v27.4s, v2.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v1.4s, v1.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmin v0.4s, v0.4s, v13.4s\n"
+ "fmin v31.4s, v31.4s, v13.4s\n"
+ "fmin v30.4s, v30.4s, v13.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fmin v28.4s, v28.4s, v13.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v1.4s, v1.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ "fmax v0.4s, v0.4s, v12.4s\n"
+ "fmax v31.4s, v31.4s, v12.4s\n"
+ "fmax v30.4s, v30.4s, v12.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmax v28.4s, v28.4s, v12.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ "str d21, [x9, #0x0]\n"
+ "str d16, [x9, #0x8]\n"
+ ".inst 0x0ea16830 // bfcvtn v16.4h, v1.4s\n"
+ ".inst 0x0ea16b5a // bfcvtn v26.4h, v26.4s\n"
+ "str d20, [x9, #0x10]\n"
+ ".inst 0x0ea16b39 // bfcvtn v25.4h, v25.4s\n"
+ ".inst 0x0ea16b18 // bfcvtn v24.4h, v24.4s\n"
+ "add x9, x9, #0x18\n"
+ "str d19, [x27, #0x0]\n"
+ ".inst 0x0ea16af7 // bfcvtn v23.4h, v23.4s\n"
+ ".inst 0x0ea16ad6 // bfcvtn v22.4h, v22.4s\n"
+ "str d18, [x27, #0x8]\n"
+ ".inst 0x0ea16815 // bfcvtn v21.4h, v0.4s\n"
+ ".inst 0x0ea16bf4 // bfcvtn v20.4h, v31.4s\n"
+ "str d17, [x27, #0x10]\n"
+ ".inst 0x0ea16bd3 // bfcvtn v19.4h, v30.4s\n"
+ ".inst 0x0ea16bb2 // bfcvtn v18.4h, v29.4s\n"
+ "add x27, x27, #0x18\n"
+ "str d16, [x26, #0x0]\n"
+ ".inst 0x0ea16b91 // bfcvtn v17.4h, v28.4s\n"
+ ".inst 0x0ea16b70 // bfcvtn v16.4h, v27.4s\n"
+ "str d26, [x26, #0x8]\n"
+ "str d25, [x26, #0x10]\n"
+ "add x26, x26, #0x18\n"
+ "str d24, [x25, #0x0]\n"
+ "str d23, [x25, #0x8]\n"
+ "str d22, [x25, #0x10]\n"
+ "add x25, x25, #0x18\n"
+ "str d21, [x24, #0x0]\n"
+ "str d20, [x24, #0x8]\n"
+ "str d19, [x24, #0x10]\n"
+ "add x24, x24, #0x18\n"
+ "str d18, [x23, #0x0]\n"
+ "str d17, [x23, #0x8]\n"
+ "str d16, [x23, #0x10]\n"
+ "add x23, x23, #0x18\n"
+ "bge 43b\n"
+ "46:" // Initial: Height 6: no full blocks
+ "cbz x10, 49f\n"
+ "mov x20, %x[in_ptr]\n"
+ "47:" // Initial: Height 6: Single loop
+ "movi v22.16b, #0x0\n"
+ "cbz %x[bias], 48f\n"
+ "ldr h16, [x28, #0x0]\n"
+ "shll v22.4s, v16.4h, #0x10\n"
+ "48:" // Initial: Height 6: Scalar: no bias
+ "ldr s16, [%x[in_ptr], #0x0]\n"
+ "ldr s20, [%x[in_ptr], #0x30]\n"
+ "subs x10, x10, #0x1\n"
+ "add x28, x28, #0x2\n"
+ "ldr s19, [%x[in_ptr], #0x60]\n"
+ "ldr s18, [%x[in_ptr], #0x90]\n"
+ "ldr s17, [%x[in_ptr], #0xc0]\n"
+ "ldr s21, [%x[in_ptr], #0xf0]\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v16.4s, v16.4s, v22.4s\n"
+ "fadd v20.4s, v20.4s, v22.4s\n"
+ "fadd v19.4s, v19.4s, v22.4s\n"
+ "fadd v18.4s, v18.4s, v22.4s\n"
+ "fadd v17.4s, v17.4s, v22.4s\n"
+ "fadd v21.4s, v21.4s, v22.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ "str h16, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ ".inst 0x0ea16ab0 // bfcvtn v16.4h, v21.4s\n"
+ "str h20, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "str h19, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "str h18, [x25, #0x0]\n"
+ "add x25, x25, #0x2\n"
+ "str h17, [x24, #0x0]\n"
+ "add x24, x24, #0x2\n"
+ "str h16, [x23, #0x0]\n"
+ "add x23, x23, #0x2\n"
+ "bne 47b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "49:" // Initial: Height 6: no oddments
+ "b 108f\n"
+ "50:" // Initial: Height 7
+ "mov x9, %x[out_ptr]\n"
+ "mov x10, %x[cols]\n"
+ "mov x28, %x[bias]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "add x23, x24, %x[ldout], LSL #1\n"
+ "add x22, x23, %x[ldout], LSL #1\n"
+ "blt 54f\n"
+ "51:" // Initial: Height 7: Block loop
+ "cbnz %x[bias], 52f\n"
+ "movi v7.16b, #0x0\n"
+ "movi v6.16b, #0x0\n"
+ "movi v5.16b, #0x0\n"
+ "b 53f\n"
+ "52:" // Initial: Height 7: Width 3: bias
+ "ldr d18, [x28, #0x0]\n"
+ "ldr d17, [x28, #0x8]\n"
+ "ldr d16, [x28, #0x10]\n"
+ "shll v7.4s, v18.4h, #0x10\n"
+ "shll v6.4s, v17.4h, #0x10\n"
+ "shll v5.4s, v16.4h, #0x10\n"
+ "53:" // Initial: Height 7: Width 3: init done
+ "ldr q18, [%x[in_ptr], #0x0]\n"
+ "ldr q17, [%x[in_ptr], #0x10]\n"
+ "sub x10, x10, #0xc\n"
+ "add x28, x28, #0x18\n"
+ "ldr q16, [%x[in_ptr], #0x20]\n"
+ "ldr q21, [%x[in_ptr], #0x30]\n"
+ "cmp x10, #0xc\n"
+ "ldr q20, [%x[in_ptr], #0x40]\n"
+ "ldr q19, [%x[in_ptr], #0x50]\n"
+ "ldr q4, [%x[in_ptr], #0x60]\n"
+ "ldr q3, [%x[in_ptr], #0x70]\n"
+ "fadd v18.4s, v18.4s, v7.4s\n"
+ "fadd v17.4s, v17.4s, v6.4s\n"
+ "ldr q2, [%x[in_ptr], #0x80]\n"
+ "ldr q27, [%x[in_ptr], #0x90]\n"
+ "fadd v16.4s, v16.4s, v5.4s\n"
+ "fadd v21.4s, v21.4s, v7.4s\n"
+ "ldr q26, [%x[in_ptr], #0xa0]\n"
+ "ldr q25, [%x[in_ptr], #0xb0]\n"
+ "fadd v20.4s, v20.4s, v6.4s\n"
+ "fadd v19.4s, v19.4s, v5.4s\n"
+ "ldr q24, [%x[in_ptr], #0xc0]\n"
+ "ldr q23, [%x[in_ptr], #0xd0]\n"
+ "fadd v4.4s, v4.4s, v7.4s\n"
+ "fadd v3.4s, v3.4s, v6.4s\n"
+ "ldr q22, [%x[in_ptr], #0xe0]\n"
+ "ldr q1, [%x[in_ptr], #0xf0]\n"
+ "fadd v2.4s, v2.4s, v5.4s\n"
+ "fadd v27.4s, v27.4s, v7.4s\n"
+ "ldr q0, [%x[in_ptr], #0x100]\n"
+ "ldr q31, [%x[in_ptr], #0x110]\n"
+ "fadd v26.4s, v26.4s, v6.4s\n"
+ "fadd v25.4s, v25.4s, v5.4s\n"
+ "ldr q30, [%x[in_ptr], #0x120]\n"
+ "ldr q29, [%x[in_ptr], #0x130]\n"
+ "fadd v24.4s, v24.4s, v7.4s\n"
+ "fadd v23.4s, v23.4s, v6.4s\n"
+ "ldr q28, [%x[in_ptr], #0x140]\n"
+ "fadd v22.4s, v22.4s, v5.4s\n"
+ "fadd v1.4s, v1.4s, v7.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fadd v0.4s, v0.4s, v6.4s\n"
+ "fadd v31.4s, v31.4s, v5.4s\n"
+ "fadd v30.4s, v30.4s, v7.4s\n"
+ "fadd v29.4s, v29.4s, v6.4s\n"
+ "fadd v28.4s, v28.4s, v5.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v4.4s, v4.4s, v13.4s\n"
+ "fmin v3.4s, v3.4s, v13.4s\n"
+ "fmin v2.4s, v2.4s, v13.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmin v1.4s, v1.4s, v13.4s\n"
+ "fmin v0.4s, v0.4s, v13.4s\n"
+ "fmin v31.4s, v31.4s, v13.4s\n"
+ "fmin v30.4s, v30.4s, v13.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fmin v28.4s, v28.4s, v13.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v4.4s, v4.4s, v12.4s\n"
+ "fmax v3.4s, v3.4s, v12.4s\n"
+ "fmax v2.4s, v2.4s, v12.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ "fmax v1.4s, v1.4s, v12.4s\n"
+ "fmax v0.4s, v0.4s, v12.4s\n"
+ "fmax v31.4s, v31.4s, v12.4s\n"
+ "fmax v30.4s, v30.4s, v12.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmax v28.4s, v28.4s, v12.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ "str d18, [x9, #0x0]\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16892 // bfcvtn v18.4h, v4.4s\n"
+ "str d17, [x9, #0x8]\n"
+ "str d16, [x9, #0x10]\n"
+ ".inst 0x0ea16871 // bfcvtn v17.4h, v3.4s\n"
+ ".inst 0x0ea16850 // bfcvtn v16.4h, v2.4s\n"
+ "add x9, x9, #0x18\n"
+ "str d21, [x27, #0x0]\n"
+ ".inst 0x0ea16b7b // bfcvtn v27.4h, v27.4s\n"
+ ".inst 0x0ea16b5a // bfcvtn v26.4h, v26.4s\n"
+ "str d20, [x27, #0x8]\n"
+ ".inst 0x0ea16b39 // bfcvtn v25.4h, v25.4s\n"
+ ".inst 0x0ea16b18 // bfcvtn v24.4h, v24.4s\n"
+ "str d19, [x27, #0x10]\n"
+ ".inst 0x0ea16af7 // bfcvtn v23.4h, v23.4s\n"
+ ".inst 0x0ea16ad6 // bfcvtn v22.4h, v22.4s\n"
+ "add x27, x27, #0x18\n"
+ "str d18, [x26, #0x0]\n"
+ ".inst 0x0ea16835 // bfcvtn v21.4h, v1.4s\n"
+ ".inst 0x0ea16814 // bfcvtn v20.4h, v0.4s\n"
+ "str d17, [x26, #0x8]\n"
+ ".inst 0x0ea16bf3 // bfcvtn v19.4h, v31.4s\n"
+ ".inst 0x0ea16bd2 // bfcvtn v18.4h, v30.4s\n"
+ "str d16, [x26, #0x10]\n"
+ ".inst 0x0ea16bb1 // bfcvtn v17.4h, v29.4s\n"
+ ".inst 0x0ea16b90 // bfcvtn v16.4h, v28.4s\n"
+ "add x26, x26, #0x18\n"
+ "str d27, [x25, #0x0]\n"
+ "str d26, [x25, #0x8]\n"
+ "str d25, [x25, #0x10]\n"
+ "add x25, x25, #0x18\n"
+ "str d24, [x24, #0x0]\n"
+ "str d23, [x24, #0x8]\n"
+ "str d22, [x24, #0x10]\n"
+ "add x24, x24, #0x18\n"
+ "str d21, [x23, #0x0]\n"
+ "str d20, [x23, #0x8]\n"
+ "str d19, [x23, #0x10]\n"
+ "add x23, x23, #0x18\n"
+ "str d18, [x22, #0x0]\n"
+ "str d17, [x22, #0x8]\n"
+ "str d16, [x22, #0x10]\n"
+ "add x22, x22, #0x18\n"
+ "bge 51b\n"
+ "54:" // Initial: Height 7: no full blocks
+ "cbz x10, 57f\n"
+ "mov x20, %x[in_ptr]\n"
+ "55:" // Initial: Height 7: Single loop
+ "movi v23.16b, #0x0\n"
+ "cbz %x[bias], 56f\n"
+ "ldr h16, [x28, #0x0]\n"
+ "shll v23.4s, v16.4h, #0x10\n"
+ "56:" // Initial: Height 7: Scalar: no bias
+ "ldr s16, [%x[in_ptr], #0x0]\n"
+ "ldr s21, [%x[in_ptr], #0x30]\n"
+ "subs x10, x10, #0x1\n"
+ "add x28, x28, #0x2\n"
+ "ldr s20, [%x[in_ptr], #0x60]\n"
+ "ldr s19, [%x[in_ptr], #0x90]\n"
+ "ldr s18, [%x[in_ptr], #0xc0]\n"
+ "ldr s17, [%x[in_ptr], #0xf0]\n"
+ "ldr s22, [%x[in_ptr], #0x120]\n"
+ "fadd v16.4s, v16.4s, v23.4s\n"
+ "fadd v21.4s, v21.4s, v23.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v20.4s, v20.4s, v23.4s\n"
+ "fadd v19.4s, v19.4s, v23.4s\n"
+ "fadd v18.4s, v18.4s, v23.4s\n"
+ "fadd v17.4s, v17.4s, v23.4s\n"
+ "fadd v22.4s, v22.4s, v23.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ "str h16, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16ad0 // bfcvtn v16.4h, v22.4s\n"
+ "str h21, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "str h20, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "str h19, [x25, #0x0]\n"
+ "add x25, x25, #0x2\n"
+ "str h18, [x24, #0x0]\n"
+ "add x24, x24, #0x2\n"
+ "str h17, [x23, #0x0]\n"
+ "add x23, x23, #0x2\n"
+ "str h16, [x22, #0x0]\n"
+ "add x22, x22, #0x2\n"
+ "bne 55b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "57:" // Initial: Height 7: no oddments
+ "b 108f\n"
+ "58:" // Initial: Height 8
+ "mov x9, %x[out_ptr]\n"
+ "mov x10, %x[cols]\n"
+ "mov x28, %x[bias]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "add x23, x24, %x[ldout], LSL #1\n"
+ "add x22, x23, %x[ldout], LSL #1\n"
+ "add x21, x22, %x[ldout], LSL #1\n"
+ "blt 62f\n"
+ "59:" // Initial: Height 8: Block loop
+ "cbnz %x[bias], 60f\n"
+ "movi v10.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v8.16b, #0x0\n"
+ "b 61f\n"
+ "60:" // Initial: Height 8: Width 3: bias
+ "ldr d18, [x28, #0x0]\n"
+ "ldr d17, [x28, #0x8]\n"
+ "ldr d16, [x28, #0x10]\n"
+ "shll v10.4s, v18.4h, #0x10\n"
+ "shll v9.4s, v17.4h, #0x10\n"
+ "shll v8.4s, v16.4h, #0x10\n"
+ "61:" // Initial: Height 8: Width 3: init done
+ "ldr q18, [%x[in_ptr], #0x0]\n"
+ "ldr q17, [%x[in_ptr], #0x10]\n"
+ "sub x10, x10, #0xc\n"
+ "add x28, x28, #0x18\n"
+ "ldr q16, [%x[in_ptr], #0x20]\n"
+ "ldr q22, [%x[in_ptr], #0x30]\n"
+ "cmp x10, #0xc\n"
+ "ldr q21, [%x[in_ptr], #0x40]\n"
+ "ldr q20, [%x[in_ptr], #0x50]\n"
+ "ldr q19, [%x[in_ptr], #0x60]\n"
+ "ldr q7, [%x[in_ptr], #0x70]\n"
+ "fadd v18.4s, v18.4s, v10.4s\n"
+ "fadd v17.4s, v17.4s, v9.4s\n"
+ "ldr q6, [%x[in_ptr], #0x80]\n"
+ "ldr q5, [%x[in_ptr], #0x90]\n"
+ "fadd v16.4s, v16.4s, v8.4s\n"
+ "fadd v22.4s, v22.4s, v10.4s\n"
+ "ldr q29, [%x[in_ptr], #0xa0]\n"
+ "ldr q28, [%x[in_ptr], #0xb0]\n"
+ "fadd v21.4s, v21.4s, v9.4s\n"
+ "fadd v20.4s, v20.4s, v8.4s\n"
+ "ldr q27, [%x[in_ptr], #0xc0]\n"
+ "ldr q26, [%x[in_ptr], #0xd0]\n"
+ "fadd v19.4s, v19.4s, v10.4s\n"
+ "fadd v7.4s, v7.4s, v9.4s\n"
+ "ldr q25, [%x[in_ptr], #0xe0]\n"
+ "ldr q24, [%x[in_ptr], #0xf0]\n"
+ "fadd v6.4s, v6.4s, v8.4s\n"
+ "fadd v5.4s, v5.4s, v10.4s\n"
+ "ldr q23, [%x[in_ptr], #0x100]\n"
+ "ldr q4, [%x[in_ptr], #0x110]\n"
+ "fadd v29.4s, v29.4s, v9.4s\n"
+ "fadd v28.4s, v28.4s, v8.4s\n"
+ "ldr q3, [%x[in_ptr], #0x120]\n"
+ "ldr q2, [%x[in_ptr], #0x130]\n"
+ "fadd v27.4s, v27.4s, v10.4s\n"
+ "fadd v26.4s, v26.4s, v9.4s\n"
+ "ldr q1, [%x[in_ptr], #0x140]\n"
+ "ldr q0, [%x[in_ptr], #0x150]\n"
+ "fadd v25.4s, v25.4s, v8.4s\n"
+ "fadd v24.4s, v24.4s, v10.4s\n"
+ "ldr q31, [%x[in_ptr], #0x160]\n"
+ "ldr q30, [%x[in_ptr], #0x170]\n"
+ "fadd v23.4s, v23.4s, v9.4s\n"
+ "fadd v4.4s, v4.4s, v8.4s\n"
+ "fadd v3.4s, v3.4s, v10.4s\n"
+ "fadd v2.4s, v2.4s, v9.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fadd v1.4s, v1.4s, v8.4s\n"
+ "fadd v0.4s, v0.4s, v10.4s\n"
+ "fadd v31.4s, v31.4s, v9.4s\n"
+ "fadd v30.4s, v30.4s, v8.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v7.4s, v7.4s, v13.4s\n"
+ "fmin v6.4s, v6.4s, v13.4s\n"
+ "fmin v5.4s, v5.4s, v13.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fmin v28.4s, v28.4s, v13.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v4.4s, v4.4s, v13.4s\n"
+ "fmin v3.4s, v3.4s, v13.4s\n"
+ "fmin v2.4s, v2.4s, v13.4s\n"
+ "fmin v1.4s, v1.4s, v13.4s\n"
+ "fmin v0.4s, v0.4s, v13.4s\n"
+ "fmin v31.4s, v31.4s, v13.4s\n"
+ "fmin v30.4s, v30.4s, v13.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v7.4s, v7.4s, v12.4s\n"
+ "fmax v6.4s, v6.4s, v12.4s\n"
+ "fmax v5.4s, v5.4s, v12.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmax v28.4s, v28.4s, v12.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v4.4s, v4.4s, v12.4s\n"
+ "fmax v3.4s, v3.4s, v12.4s\n"
+ "fmax v2.4s, v2.4s, v12.4s\n"
+ "fmax v1.4s, v1.4s, v12.4s\n"
+ "fmax v0.4s, v0.4s, v12.4s\n"
+ "fmax v31.4s, v31.4s, v12.4s\n"
+ "fmax v30.4s, v30.4s, v12.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16ad6 // bfcvtn v22.4h, v22.4s\n"
+ ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ "str d18, [x9, #0x0]\n"
+ "str d17, [x9, #0x8]\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea168f2 // bfcvtn v18.4h, v7.4s\n"
+ "str d16, [x9, #0x10]\n"
+ ".inst 0x0ea168d1 // bfcvtn v17.4h, v6.4s\n"
+ ".inst 0x0ea168b0 // bfcvtn v16.4h, v5.4s\n"
+ "add x9, x9, #0x18\n"
+ "str d22, [x27, #0x0]\n"
+ ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n"
+ ".inst 0x0ea16b9c // bfcvtn v28.4h, v28.4s\n"
+ "str d21, [x27, #0x8]\n"
+ ".inst 0x0ea16b7b // bfcvtn v27.4h, v27.4s\n"
+ ".inst 0x0ea16b5a // bfcvtn v26.4h, v26.4s\n"
+ "str d20, [x27, #0x10]\n"
+ ".inst 0x0ea16b39 // bfcvtn v25.4h, v25.4s\n"
+ ".inst 0x0ea16b18 // bfcvtn v24.4h, v24.4s\n"
+ "add x27, x27, #0x18\n"
+ "str d19, [x26, #0x0]\n"
+ ".inst 0x0ea16af7 // bfcvtn v23.4h, v23.4s\n"
+ ".inst 0x0ea16896 // bfcvtn v22.4h, v4.4s\n"
+ "str d18, [x26, #0x8]\n"
+ ".inst 0x0ea16875 // bfcvtn v21.4h, v3.4s\n"
+ ".inst 0x0ea16854 // bfcvtn v20.4h, v2.4s\n"
+ "str d17, [x26, #0x10]\n"
+ ".inst 0x0ea16833 // bfcvtn v19.4h, v1.4s\n"
+ ".inst 0x0ea16812 // bfcvtn v18.4h, v0.4s\n"
+ "add x26, x26, #0x18\n"
+ "str d16, [x25, #0x0]\n"
+ ".inst 0x0ea16bf1 // bfcvtn v17.4h, v31.4s\n"
+ ".inst 0x0ea16bd0 // bfcvtn v16.4h, v30.4s\n"
+ "str d29, [x25, #0x8]\n"
+ "str d28, [x25, #0x10]\n"
+ "add x25, x25, #0x18\n"
+ "str d27, [x24, #0x0]\n"
+ "str d26, [x24, #0x8]\n"
+ "str d25, [x24, #0x10]\n"
+ "add x24, x24, #0x18\n"
+ "str d24, [x23, #0x0]\n"
+ "str d23, [x23, #0x8]\n"
+ "str d22, [x23, #0x10]\n"
+ "add x23, x23, #0x18\n"
+ "str d21, [x22, #0x0]\n"
+ "str d20, [x22, #0x8]\n"
+ "str d19, [x22, #0x10]\n"
+ "add x22, x22, #0x18\n"
+ "str d18, [x21, #0x0]\n"
+ "str d17, [x21, #0x8]\n"
+ "str d16, [x21, #0x10]\n"
+ "add x21, x21, #0x18\n"
+ "bge 59b\n"
+ "62:" // Initial: Height 8: no full blocks
+ "cbz x10, 65f\n"
+ "mov x20, %x[in_ptr]\n"
+ "63:" // Initial: Height 8: Single loop
+ "movi v24.16b, #0x0\n"
+ "cbz %x[bias], 64f\n"
+ "ldr h16, [x28, #0x0]\n"
+ "shll v24.4s, v16.4h, #0x10\n"
+ "64:" // Initial: Height 8: Scalar: no bias
+ "ldr s17, [%x[in_ptr], #0x0]\n"
+ "ldr s16, [%x[in_ptr], #0x30]\n"
+ "subs x10, x10, #0x1\n"
+ "add x28, x28, #0x2\n"
+ "ldr s21, [%x[in_ptr], #0x60]\n"
+ "ldr s20, [%x[in_ptr], #0x90]\n"
+ "ldr s19, [%x[in_ptr], #0xc0]\n"
+ "ldr s18, [%x[in_ptr], #0xf0]\n"
+ "ldr s23, [%x[in_ptr], #0x120]\n"
+ "ldr s22, [%x[in_ptr], #0x150]\n"
+ "fadd v17.4s, v17.4s, v24.4s\n"
+ "fadd v16.4s, v16.4s, v24.4s\n"
+ "fadd v21.4s, v21.4s, v24.4s\n"
+ "fadd v20.4s, v20.4s, v24.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v19.4s, v19.4s, v24.4s\n"
+ "fadd v18.4s, v18.4s, v24.4s\n"
+ "fadd v23.4s, v23.4s, v24.4s\n"
+ "fadd v22.4s, v22.4s, v24.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v16.4s, v16.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v16.4s, v16.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ "str h17, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ "str h16, [x27, #0x0]\n"
+ ".inst 0x0ea16af1 // bfcvtn v17.4h, v23.4s\n"
+ ".inst 0x0ea16ad0 // bfcvtn v16.4h, v22.4s\n"
+ "add x27, x27, #0x2\n"
+ "str h21, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "str h20, [x25, #0x0]\n"
+ "add x25, x25, #0x2\n"
+ "str h19, [x24, #0x0]\n"
+ "add x24, x24, #0x2\n"
+ "str h18, [x23, #0x0]\n"
+ "add x23, x23, #0x2\n"
+ "str h17, [x22, #0x0]\n"
+ "add x22, x22, #0x2\n"
+ "str h16, [x21, #0x0]\n"
+ "add x21, x21, #0x2\n"
+ "bne 63b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "65:" // Initial: Height 8: no oddments
+ "subs %x[rows], %x[rows], #0x8\n"
+ "add %x[out_ptr], %x[out_ptr], x11\n"
+ "bgt 1b\n"
+ "b 108f\n"
+ "66:" // Accumulate
+ "67:" // Accumulate: Row loop
+ "cmp %x[rows], #0x7\n"
+ "bgt 103f\n"
+ "beq 98f\n"
+ "cmp %x[rows], #0x5\n"
+ "bgt 93f\n"
+ "beq 88f\n"
+ "cmp %x[rows], #0x3\n"
+ "bgt 83f\n"
+ "beq 78f\n"
+ "cmp %x[rows], #0x1\n"
+ "bgt 73f\n"
+ "68:" // Accumulate: Height 1
+ "mov x10, %x[cols]\n"
+ "mov x9, %x[out_ptr]\n"
+ "cmp x10, #0xc\n"
+ "blt 70f\n"
+ "69:" // Accumulate: Height 1: Block loop
+ "ldr d16, [x9, #0x0]\n"
+ "ldr q19, [%x[in_ptr], #0x0]\n"
+ "sub x10, x10, #0xc\n"
+ "ldr q18, [%x[in_ptr], #0x10]\n"
+ "ldr q17, [%x[in_ptr], #0x20]\n"
+ "cmp x10, #0xc\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v19.4s, v19.4s, v16.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ ".inst 0x0ea16a70 // bfcvtn v16.4h, v19.4s\n"
+ "str d16, [x9, #0x0]\n"
+ "ldr d16, [x9, #0x8]\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v18.4s, v18.4s, v16.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ ".inst 0x0ea16a50 // bfcvtn v16.4h, v18.4s\n"
+ "str d16, [x9, #0x8]\n"
+ "ldr d16, [x9, #0x10]\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v17.4s, v17.4s, v16.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ ".inst 0x0ea16a30 // bfcvtn v16.4h, v17.4s\n"
+ "str d16, [x9, #0x10]\n"
+ "add x9, x9, #0x18\n"
+ "bge 69b\n"
+ "70:" // Accumulate: Height 1: no full blocks
+ "cbz x10, 72f\n"
+ "mov x20, %x[in_ptr]\n"
+ "71:" // Accumulate: Height 1: Single loop
+ "ldr h16, [x9, #0x0]\n"
+ "ldr s17, [%x[in_ptr], #0x0]\n"
+ "subs x10, x10, #0x1\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v17.4s, v17.4s, v16.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ ".inst 0x0ea16a30 // bfcvtn v16.4h, v17.4s\n"
+ "str h16, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ "bne 71b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "72:" // Accumulate: Height 1: no oddments
+ "b 108f\n"
+ "73:" // Accumulate: Height 2
+ "mov x10, %x[cols]\n"
+ "mov x9, %x[out_ptr]\n"
+ "cmp x10, #0xc\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "blt 75f\n"
+ "74:" // Accumulate: Height 2: Block loop
+ "ldr d17, [x9, #0x0]\n"
+ "ldr d16, [x27, #0x0]\n"
+ "sub x10, x10, #0xc\n"
+ "ldr q23, [%x[in_ptr], #0x0]\n"
+ "ldr q22, [%x[in_ptr], #0x30]\n"
+ "cmp x10, #0xc\n"
+ "ldr q21, [%x[in_ptr], #0x10]\n"
+ "ldr q20, [%x[in_ptr], #0x40]\n"
+ "ldr q19, [%x[in_ptr], #0x20]\n"
+ "ldr q18, [%x[in_ptr], #0x50]\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fadd v23.4s, v23.4s, v17.4s\n"
+ "fadd v22.4s, v22.4s, v16.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16af0 // bfcvtn v16.4h, v23.4s\n"
+ ".inst 0x0ea16ad1 // bfcvtn v17.4h, v22.4s\n"
+ "str d16, [x9, #0x0]\n"
+ "ldr d16, [x9, #0x8]\n"
+ "str d17, [x27, #0x0]\n"
+ "shll v17.4s, v16.4h, #0x10\n"
+ "ldr d16, [x27, #0x8]\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v21.4s, v21.4s, v17.4s\n"
+ "fadd v20.4s, v20.4s, v16.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ ".inst 0x0ea16ab0 // bfcvtn v16.4h, v21.4s\n"
+ ".inst 0x0ea16a91 // bfcvtn v17.4h, v20.4s\n"
+ "str d16, [x9, #0x8]\n"
+ "ldr d16, [x9, #0x10]\n"
+ "str d17, [x27, #0x8]\n"
+ "shll v17.4s, v16.4h, #0x10\n"
+ "ldr d16, [x27, #0x10]\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v19.4s, v19.4s, v17.4s\n"
+ "fadd v18.4s, v18.4s, v16.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ ".inst 0x0ea16a71 // bfcvtn v17.4h, v19.4s\n"
+ ".inst 0x0ea16a50 // bfcvtn v16.4h, v18.4s\n"
+ "str d17, [x9, #0x10]\n"
+ "add x9, x9, #0x18\n"
+ "str d16, [x27, #0x10]\n"
+ "add x27, x27, #0x18\n"
+ "bge 74b\n"
+ "75:" // Accumulate: Height 2: no full blocks
+ "cbz x10, 77f\n"
+ "mov x20, %x[in_ptr]\n"
+ "76:" // Accumulate: Height 2: Single loop
+ "ldr h17, [x9, #0x0]\n"
+ "ldr h16, [x27, #0x0]\n"
+ "subs x10, x10, #0x1\n"
+ "ldr s19, [%x[in_ptr], #0x0]\n"
+ "ldr s18, [%x[in_ptr], #0x30]\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v19.4s, v19.4s, v17.4s\n"
+ "fadd v18.4s, v18.4s, v16.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ ".inst 0x0ea16a70 // bfcvtn v16.4h, v19.4s\n"
+ "str h16, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ ".inst 0x0ea16a50 // bfcvtn v16.4h, v18.4s\n"
+ "str h16, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "bne 76b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "77:" // Accumulate: Height 2: no oddments
+ "b 108f\n"
+ "78:" // Accumulate: Height 3
+ "mov x10, %x[cols]\n"
+ "mov x9, %x[out_ptr]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "blt 80f\n"
+ "79:" // Accumulate: Height 3: Block loop
+ "ldr d18, [x9, #0x0]\n"
+ "ldr d17, [x27, #0x0]\n"
+ "sub x10, x10, #0xc\n"
+ "ldr d16, [x26, #0x0]\n"
+ "ldr q27, [%x[in_ptr], #0x0]\n"
+ "cmp x10, #0xc\n"
+ "ldr q26, [%x[in_ptr], #0x30]\n"
+ "ldr q25, [%x[in_ptr], #0x60]\n"
+ "ldr q24, [%x[in_ptr], #0x10]\n"
+ "ldr q23, [%x[in_ptr], #0x40]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "ldr q22, [%x[in_ptr], #0x70]\n"
+ "ldr q21, [%x[in_ptr], #0x20]\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "ldr q20, [%x[in_ptr], #0x50]\n"
+ "ldr q19, [%x[in_ptr], #0x80]\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fadd v27.4s, v27.4s, v18.4s\n"
+ "fadd v26.4s, v26.4s, v17.4s\n"
+ "fadd v25.4s, v25.4s, v16.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ ".inst 0x0ea16b72 // bfcvtn v18.4h, v27.4s\n"
+ ".inst 0x0ea16b50 // bfcvtn v16.4h, v26.4s\n"
+ ".inst 0x0ea16b31 // bfcvtn v17.4h, v25.4s\n"
+ "str d18, [x9, #0x0]\n"
+ "str d16, [x27, #0x0]\n"
+ "ldr d16, [x9, #0x8]\n"
+ "str d17, [x26, #0x0]\n"
+ "ldr d17, [x27, #0x8]\n"
+ "shll v18.4s, v16.4h, #0x10\n"
+ "ldr d16, [x26, #0x8]\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v24.4s, v24.4s, v18.4s\n"
+ "fadd v23.4s, v23.4s, v17.4s\n"
+ "fadd v22.4s, v22.4s, v16.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16b10 // bfcvtn v16.4h, v24.4s\n"
+ ".inst 0x0ea16af2 // bfcvtn v18.4h, v23.4s\n"
+ "str d16, [x9, #0x8]\n"
+ ".inst 0x0ea16ad1 // bfcvtn v17.4h, v22.4s\n"
+ "ldr d16, [x9, #0x10]\n"
+ "str d18, [x27, #0x8]\n"
+ "str d17, [x26, #0x8]\n"
+ "shll v18.4s, v16.4h, #0x10\n"
+ "ldr d17, [x27, #0x10]\n"
+ "ldr d16, [x26, #0x10]\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v21.4s, v21.4s, v18.4s\n"
+ "fadd v20.4s, v20.4s, v17.4s\n"
+ "fadd v19.4s, v19.4s, v16.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ ".inst 0x0ea16ab0 // bfcvtn v16.4h, v21.4s\n"
+ ".inst 0x0ea16a91 // bfcvtn v17.4h, v20.4s\n"
+ "str d16, [x9, #0x10]\n"
+ "add x9, x9, #0x18\n"
+ ".inst 0x0ea16a70 // bfcvtn v16.4h, v19.4s\n"
+ "str d17, [x27, #0x10]\n"
+ "add x27, x27, #0x18\n"
+ "str d16, [x26, #0x10]\n"
+ "add x26, x26, #0x18\n"
+ "bge 79b\n"
+ "80:" // Accumulate: Height 3: no full blocks
+ "cbz x10, 82f\n"
+ "mov x20, %x[in_ptr]\n"
+ "81:" // Accumulate: Height 3: Single loop
+ "ldr h18, [x9, #0x0]\n"
+ "ldr h17, [x27, #0x0]\n"
+ "subs x10, x10, #0x1\n"
+ "ldr h16, [x26, #0x0]\n"
+ "ldr s21, [%x[in_ptr], #0x0]\n"
+ "ldr s20, [%x[in_ptr], #0x30]\n"
+ "ldr s19, [%x[in_ptr], #0x60]\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v21.4s, v21.4s, v18.4s\n"
+ "fadd v20.4s, v20.4s, v17.4s\n"
+ "fadd v19.4s, v19.4s, v16.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ ".inst 0x0ea16ab0 // bfcvtn v16.4h, v21.4s\n"
+ ".inst 0x0ea16a91 // bfcvtn v17.4h, v20.4s\n"
+ "str h16, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ ".inst 0x0ea16a70 // bfcvtn v16.4h, v19.4s\n"
+ "str h17, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "str h16, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "bne 81b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "82:" // Accumulate: Height 3: no oddments
+ "b 108f\n"
+ "83:" // Accumulate: Height 4
+ "mov x9, %x[out_ptr]\n"
+ "mov x10, %x[cols]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "blt 85f\n"
+ "84:" // Accumulate: Height 4: Block loop
+ "ldr d19, [x9, #0x0]\n"
+ "ldr d18, [x27, #0x0]\n"
+ "sub x10, x10, #0xc\n"
+ "ldr d17, [x26, #0x0]\n"
+ "ldr d16, [x25, #0x0]\n"
+ "cmp x10, #0xc\n"
+ "ldr q31, [%x[in_ptr], #0x0]\n"
+ "ldr q30, [%x[in_ptr], #0x30]\n"
+ "ldr q29, [%x[in_ptr], #0x60]\n"
+ "ldr q28, [%x[in_ptr], #0x90]\n"
+ "shll v19.4s, v19.4h, #0x10\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "ldr q27, [%x[in_ptr], #0x10]\n"
+ "ldr q26, [%x[in_ptr], #0x40]\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "ldr q25, [%x[in_ptr], #0x70]\n"
+ "ldr q24, [%x[in_ptr], #0xa0]\n"
+ "ldr q23, [%x[in_ptr], #0x20]\n"
+ "ldr q22, [%x[in_ptr], #0x50]\n"
+ "fadd v31.4s, v31.4s, v19.4s\n"
+ "fadd v30.4s, v30.4s, v18.4s\n"
+ "ldr q21, [%x[in_ptr], #0x80]\n"
+ "ldr q20, [%x[in_ptr], #0xb0]\n"
+ "fadd v29.4s, v29.4s, v17.4s\n"
+ "fadd v28.4s, v28.4s, v16.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fmin v31.4s, v31.4s, v13.4s\n"
+ "fmin v30.4s, v30.4s, v13.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fmin v28.4s, v28.4s, v13.4s\n"
+ "fmax v31.4s, v31.4s, v12.4s\n"
+ "fmax v30.4s, v30.4s, v12.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmax v28.4s, v28.4s, v12.4s\n"
+ ".inst 0x0ea16bf3 // bfcvtn v19.4h, v31.4s\n"
+ ".inst 0x0ea16bd0 // bfcvtn v16.4h, v30.4s\n"
+ ".inst 0x0ea16bb2 // bfcvtn v18.4h, v29.4s\n"
+ ".inst 0x0ea16b91 // bfcvtn v17.4h, v28.4s\n"
+ "str d19, [x9, #0x0]\n"
+ "str d16, [x27, #0x0]\n"
+ "ldr d16, [x9, #0x8]\n"
+ "str d18, [x26, #0x0]\n"
+ "str d17, [x25, #0x0]\n"
+ "ldr d18, [x27, #0x8]\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "ldr d17, [x26, #0x8]\n"
+ "ldr d16, [x25, #0x8]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "fadd v27.4s, v27.4s, v19.4s\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v26.4s, v26.4s, v18.4s\n"
+ "fadd v25.4s, v25.4s, v17.4s\n"
+ "fadd v24.4s, v24.4s, v16.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ ".inst 0x0ea16b71 // bfcvtn v17.4h, v27.4s\n"
+ ".inst 0x0ea16b53 // bfcvtn v19.4h, v26.4s\n"
+ ".inst 0x0ea16b30 // bfcvtn v16.4h, v25.4s\n"
+ "str d17, [x9, #0x8]\n"
+ ".inst 0x0ea16b12 // bfcvtn v18.4h, v24.4s\n"
+ "ldr d17, [x9, #0x10]\n"
+ "str d19, [x27, #0x8]\n"
+ "str d16, [x26, #0x8]\n"
+ "ldr d16, [x27, #0x10]\n"
+ "str d18, [x25, #0x8]\n"
+ "shll v19.4s, v17.4h, #0x10\n"
+ "ldr d17, [x26, #0x10]\n"
+ "shll v18.4s, v16.4h, #0x10\n"
+ "ldr d16, [x25, #0x10]\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "fadd v23.4s, v23.4s, v19.4s\n"
+ "fadd v22.4s, v22.4s, v18.4s\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v21.4s, v21.4s, v17.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fadd v20.4s, v20.4s, v16.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ ".inst 0x0ea16af1 // bfcvtn v17.4h, v23.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ ".inst 0x0ea16ad0 // bfcvtn v16.4h, v22.4s\n"
+ "str d17, [x9, #0x10]\n"
+ "add x9, x9, #0x18\n"
+ ".inst 0x0ea16ab1 // bfcvtn v17.4h, v21.4s\n"
+ "str d16, [x27, #0x10]\n"
+ "add x27, x27, #0x18\n"
+ ".inst 0x0ea16a90 // bfcvtn v16.4h, v20.4s\n"
+ "str d17, [x26, #0x10]\n"
+ "add x26, x26, #0x18\n"
+ "str d16, [x25, #0x10]\n"
+ "add x25, x25, #0x18\n"
+ "bge 84b\n"
+ "85:" // Accumulate: Height 4: no full blocks
+ "cbz x10, 87f\n"
+ "mov x20, %x[in_ptr]\n"
+ "86:" // Accumulate: Height 4: Single loop
+ "ldr h19, [x9, #0x0]\n"
+ "ldr h18, [x27, #0x0]\n"
+ "subs x10, x10, #0x1\n"
+ "ldr h17, [x26, #0x0]\n"
+ "ldr h16, [x25, #0x0]\n"
+ "ldr s23, [%x[in_ptr], #0x0]\n"
+ "ldr s22, [%x[in_ptr], #0x30]\n"
+ "ldr s21, [%x[in_ptr], #0x60]\n"
+ "ldr s20, [%x[in_ptr], #0x90]\n"
+ "shll v19.4s, v19.4h, #0x10\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v23.4s, v23.4s, v19.4s\n"
+ "fadd v22.4s, v22.4s, v18.4s\n"
+ "fadd v21.4s, v21.4s, v17.4s\n"
+ "fadd v20.4s, v20.4s, v16.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ ".inst 0x0ea16af3 // bfcvtn v19.4h, v23.4s\n"
+ ".inst 0x0ea16ad2 // bfcvtn v18.4h, v22.4s\n"
+ ".inst 0x0ea16ab1 // bfcvtn v17.4h, v21.4s\n"
+ ".inst 0x0ea16a90 // bfcvtn v16.4h, v20.4s\n"
+ "str h19, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ "str h18, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "str h17, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "str h16, [x25, #0x0]\n"
+ "add x25, x25, #0x2\n"
+ "bne 86b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "87:" // Accumulate: Height 4: no oddments
+ "b 108f\n"
+ "88:" // Accumulate: Height 5
+ "mov x9, %x[out_ptr]\n"
+ "mov x10, %x[cols]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "blt 90f\n"
+ "89:" // Accumulate: Height 5: Block loop
+ "ldr d20, [x9, #0x0]\n"
+ "ldr d19, [x27, #0x0]\n"
+ "sub x10, x10, #0xc\n"
+ "ldr d18, [x26, #0x0]\n"
+ "ldr d17, [x25, #0x0]\n"
+ "cmp x10, #0xc\n"
+ "ldr d16, [x24, #0x0]\n"
+ "ldr q3, [%x[in_ptr], #0x0]\n"
+ "ldr q2, [%x[in_ptr], #0x30]\n"
+ "ldr q1, [%x[in_ptr], #0x60]\n"
+ "shll v20.4s, v20.4h, #0x10\n"
+ "shll v19.4s, v19.4h, #0x10\n"
+ "ldr q0, [%x[in_ptr], #0x90]\n"
+ "ldr q31, [%x[in_ptr], #0xc0]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "ldr q30, [%x[in_ptr], #0x10]\n"
+ "ldr q29, [%x[in_ptr], #0x40]\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "ldr q28, [%x[in_ptr], #0x70]\n"
+ "ldr q27, [%x[in_ptr], #0xa0]\n"
+ "fadd v3.4s, v3.4s, v20.4s\n"
+ "fadd v2.4s, v2.4s, v19.4s\n"
+ "ldr q26, [%x[in_ptr], #0xd0]\n"
+ "ldr q25, [%x[in_ptr], #0x20]\n"
+ "fadd v1.4s, v1.4s, v18.4s\n"
+ "fadd v0.4s, v0.4s, v17.4s\n"
+ "ldr q24, [%x[in_ptr], #0x50]\n"
+ "ldr q23, [%x[in_ptr], #0x80]\n"
+ "fadd v31.4s, v31.4s, v16.4s\n"
+ "ldr q22, [%x[in_ptr], #0xb0]\n"
+ "ldr q21, [%x[in_ptr], #0xe0]\n"
+ "fmin v3.4s, v3.4s, v13.4s\n"
+ "fmin v2.4s, v2.4s, v13.4s\n"
+ "fmin v1.4s, v1.4s, v13.4s\n"
+ "fmin v0.4s, v0.4s, v13.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fmin v31.4s, v31.4s, v13.4s\n"
+ "fmax v3.4s, v3.4s, v12.4s\n"
+ "fmax v2.4s, v2.4s, v12.4s\n"
+ "fmax v1.4s, v1.4s, v12.4s\n"
+ "fmax v0.4s, v0.4s, v12.4s\n"
+ "fmax v31.4s, v31.4s, v12.4s\n"
+ ".inst 0x0ea16874 // bfcvtn v20.4h, v3.4s\n"
+ ".inst 0x0ea16853 // bfcvtn v19.4h, v2.4s\n"
+ ".inst 0x0ea16831 // bfcvtn v17.4h, v1.4s\n"
+ ".inst 0x0ea16810 // bfcvtn v16.4h, v0.4s\n"
+ ".inst 0x0ea16bf2 // bfcvtn v18.4h, v31.4s\n"
+ "str d20, [x9, #0x0]\n"
+ "str d19, [x27, #0x0]\n"
+ "str d17, [x26, #0x0]\n"
+ "ldr d17, [x9, #0x8]\n"
+ "str d16, [x25, #0x0]\n"
+ "ldr d16, [x27, #0x8]\n"
+ "str d18, [x24, #0x0]\n"
+ "ldr d18, [x26, #0x8]\n"
+ "shll v20.4s, v17.4h, #0x10\n"
+ "ldr d17, [x25, #0x8]\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "ldr d16, [x24, #0x8]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "fadd v30.4s, v30.4s, v20.4s\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v29.4s, v29.4s, v19.4s\n"
+ "fadd v28.4s, v28.4s, v18.4s\n"
+ "fadd v27.4s, v27.4s, v17.4s\n"
+ "fmin v30.4s, v30.4s, v13.4s\n"
+ "fadd v26.4s, v26.4s, v16.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fmin v28.4s, v28.4s, v13.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmax v30.4s, v30.4s, v12.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmax v28.4s, v28.4s, v12.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ ".inst 0x0ea16bd2 // bfcvtn v18.4h, v30.4s\n"
+ ".inst 0x0ea16bb3 // bfcvtn v19.4h, v29.4s\n"
+ ".inst 0x0ea16b91 // bfcvtn v17.4h, v28.4s\n"
+ ".inst 0x0ea16b70 // bfcvtn v16.4h, v27.4s\n"
+ "str d18, [x9, #0x8]\n"
+ ".inst 0x0ea16b52 // bfcvtn v18.4h, v26.4s\n"
+ "str d19, [x27, #0x8]\n"
+ "str d17, [x26, #0x8]\n"
+ "ldr d17, [x9, #0x10]\n"
+ "str d16, [x25, #0x8]\n"
+ "ldr d16, [x27, #0x10]\n"
+ "str d18, [x24, #0x8]\n"
+ "ldr d18, [x26, #0x10]\n"
+ "shll v20.4s, v17.4h, #0x10\n"
+ "ldr d17, [x25, #0x10]\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "ldr d16, [x24, #0x10]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "fadd v25.4s, v25.4s, v20.4s\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v24.4s, v24.4s, v19.4s\n"
+ "fadd v23.4s, v23.4s, v18.4s\n"
+ "fadd v22.4s, v22.4s, v17.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fadd v21.4s, v21.4s, v16.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ ".inst 0x0ea16b30 // bfcvtn v16.4h, v25.4s\n"
+ ".inst 0x0ea16b13 // bfcvtn v19.4h, v24.4s\n"
+ ".inst 0x0ea16af2 // bfcvtn v18.4h, v23.4s\n"
+ ".inst 0x0ea16ad1 // bfcvtn v17.4h, v22.4s\n"
+ "str d16, [x9, #0x10]\n"
+ "add x9, x9, #0x18\n"
+ ".inst 0x0ea16ab0 // bfcvtn v16.4h, v21.4s\n"
+ "str d19, [x27, #0x10]\n"
+ "add x27, x27, #0x18\n"
+ "str d18, [x26, #0x10]\n"
+ "add x26, x26, #0x18\n"
+ "str d17, [x25, #0x10]\n"
+ "add x25, x25, #0x18\n"
+ "str d16, [x24, #0x10]\n"
+ "add x24, x24, #0x18\n"
+ "bge 89b\n"
+ "90:" // Accumulate: Height 5: no full blocks
+ "cbz x10, 92f\n"
+ "mov x20, %x[in_ptr]\n"
+ "91:" // Accumulate: Height 5: Single loop
+ "ldr h20, [x9, #0x0]\n"
+ "ldr h19, [x27, #0x0]\n"
+ "subs x10, x10, #0x1\n"
+ "ldr h18, [x26, #0x0]\n"
+ "ldr h17, [x25, #0x0]\n"
+ "ldr h16, [x24, #0x0]\n"
+ "ldr s25, [%x[in_ptr], #0x0]\n"
+ "ldr s24, [%x[in_ptr], #0x30]\n"
+ "ldr s23, [%x[in_ptr], #0x60]\n"
+ "shll v20.4s, v20.4h, #0x10\n"
+ "shll v19.4s, v19.4h, #0x10\n"
+ "ldr s22, [%x[in_ptr], #0x90]\n"
+ "ldr s21, [%x[in_ptr], #0xc0]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v25.4s, v25.4s, v20.4s\n"
+ "fadd v24.4s, v24.4s, v19.4s\n"
+ "fadd v23.4s, v23.4s, v18.4s\n"
+ "fadd v22.4s, v22.4s, v17.4s\n"
+ "fadd v21.4s, v21.4s, v16.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmin v21.4s, v21.4s, v13.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ "fmax v21.4s, v21.4s, v12.4s\n"
+ ".inst 0x0ea16b34 // bfcvtn v20.4h, v25.4s\n"
+ ".inst 0x0ea16b13 // bfcvtn v19.4h, v24.4s\n"
+ ".inst 0x0ea16af2 // bfcvtn v18.4h, v23.4s\n"
+ ".inst 0x0ea16ad1 // bfcvtn v17.4h, v22.4s\n"
+ ".inst 0x0ea16ab0 // bfcvtn v16.4h, v21.4s\n"
+ "str h20, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ "str h19, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "str h18, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "str h17, [x25, #0x0]\n"
+ "add x25, x25, #0x2\n"
+ "str h16, [x24, #0x0]\n"
+ "add x24, x24, #0x2\n"
+ "bne 91b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "92:" // Accumulate: Height 5: no oddments
+ "b 108f\n"
+ "93:" // Accumulate: Height 6
+ "mov x9, %x[out_ptr]\n"
+ "mov x10, %x[cols]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "add x23, x24, %x[ldout], LSL #1\n"
+ "blt 95f\n"
+ "94:" // Accumulate: Height 6: Block loop
+ "ldr d21, [x9, #0x0]\n"
+ "ldr d20, [x27, #0x0]\n"
+ "sub x10, x10, #0xc\n"
+ "ldr d19, [x26, #0x0]\n"
+ "ldr d18, [x25, #0x0]\n"
+ "cmp x10, #0xc\n"
+ "ldr d17, [x24, #0x0]\n"
+ "ldr d16, [x23, #0x0]\n"
+ "ldr q6, [%x[in_ptr], #0x0]\n"
+ "ldr q5, [%x[in_ptr], #0x30]\n"
+ "shll v22.4s, v21.4h, #0x10\n"
+ "shll v21.4s, v20.4h, #0x10\n"
+ "ldr q4, [%x[in_ptr], #0x60]\n"
+ "ldr q3, [%x[in_ptr], #0x90]\n"
+ "shll v20.4s, v19.4h, #0x10\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "ldr q2, [%x[in_ptr], #0xc0]\n"
+ "ldr q19, [%x[in_ptr], #0xf0]\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "ldr q1, [%x[in_ptr], #0x10]\n"
+ "ldr q0, [%x[in_ptr], #0x40]\n"
+ "fadd v6.4s, v6.4s, v22.4s\n"
+ "fadd v5.4s, v5.4s, v21.4s\n"
+ "ldr q31, [%x[in_ptr], #0x70]\n"
+ "ldr q30, [%x[in_ptr], #0xa0]\n"
+ "fadd v4.4s, v4.4s, v20.4s\n"
+ "fadd v3.4s, v3.4s, v18.4s\n"
+ "ldr q29, [%x[in_ptr], #0xd0]\n"
+ "ldr q28, [%x[in_ptr], #0x100]\n"
+ "fadd v2.4s, v2.4s, v17.4s\n"
+ "fadd v19.4s, v19.4s, v16.4s\n"
+ "ldr q27, [%x[in_ptr], #0x20]\n"
+ "ldr q26, [%x[in_ptr], #0x50]\n"
+ "fmin v6.4s, v6.4s, v13.4s\n"
+ "fmin v5.4s, v5.4s, v13.4s\n"
+ "ldr q25, [%x[in_ptr], #0x80]\n"
+ "ldr q24, [%x[in_ptr], #0xb0]\n"
+ "fmin v4.4s, v4.4s, v13.4s\n"
+ "fmin v3.4s, v3.4s, v13.4s\n"
+ "ldr q23, [%x[in_ptr], #0xe0]\n"
+ "ldr q22, [%x[in_ptr], #0x110]\n"
+ "fmin v2.4s, v2.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmax v6.4s, v6.4s, v12.4s\n"
+ "fmax v5.4s, v5.4s, v12.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fmax v4.4s, v4.4s, v12.4s\n"
+ "fmax v3.4s, v3.4s, v12.4s\n"
+ "fmax v2.4s, v2.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ ".inst 0x0ea168d5 // bfcvtn v21.4h, v6.4s\n"
+ ".inst 0x0ea168b4 // bfcvtn v20.4h, v5.4s\n"
+ ".inst 0x0ea16892 // bfcvtn v18.4h, v4.4s\n"
+ ".inst 0x0ea16871 // bfcvtn v17.4h, v3.4s\n"
+ ".inst 0x0ea16850 // bfcvtn v16.4h, v2.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ "str d21, [x9, #0x0]\n"
+ "str d20, [x27, #0x0]\n"
+ "str d18, [x26, #0x0]\n"
+ "ldr d18, [x9, #0x8]\n"
+ "str d17, [x25, #0x0]\n"
+ "ldr d17, [x27, #0x8]\n"
+ "str d16, [x24, #0x0]\n"
+ "ldr d16, [x26, #0x8]\n"
+ "str d19, [x23, #0x0]\n"
+ "shll v21.4s, v18.4h, #0x10\n"
+ "ldr d18, [x25, #0x8]\n"
+ "shll v20.4s, v17.4h, #0x10\n"
+ "ldr d17, [x24, #0x8]\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "ldr d16, [x23, #0x8]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "fadd v1.4s, v1.4s, v21.4s\n"
+ "fadd v0.4s, v0.4s, v20.4s\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v31.4s, v31.4s, v19.4s\n"
+ "fadd v30.4s, v30.4s, v18.4s\n"
+ "fmin v1.4s, v1.4s, v13.4s\n"
+ "fmin v0.4s, v0.4s, v13.4s\n"
+ "fadd v29.4s, v29.4s, v17.4s\n"
+ "fadd v28.4s, v28.4s, v16.4s\n"
+ "fmin v31.4s, v31.4s, v13.4s\n"
+ "fmin v30.4s, v30.4s, v13.4s\n"
+ "fmax v1.4s, v1.4s, v12.4s\n"
+ "fmax v0.4s, v0.4s, v12.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fmin v28.4s, v28.4s, v13.4s\n"
+ "fmax v31.4s, v31.4s, v12.4s\n"
+ "fmax v30.4s, v30.4s, v12.4s\n"
+ ".inst 0x0ea16832 // bfcvtn v18.4h, v1.4s\n"
+ ".inst 0x0ea16810 // bfcvtn v16.4h, v0.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmax v28.4s, v28.4s, v12.4s\n"
+ ".inst 0x0ea16bf4 // bfcvtn v20.4h, v31.4s\n"
+ ".inst 0x0ea16bd1 // bfcvtn v17.4h, v30.4s\n"
+ "str d18, [x9, #0x8]\n"
+ "str d16, [x27, #0x8]\n"
+ ".inst 0x0ea16bb3 // bfcvtn v19.4h, v29.4s\n"
+ ".inst 0x0ea16b92 // bfcvtn v18.4h, v28.4s\n"
+ "ldr d16, [x9, #0x10]\n"
+ "str d20, [x26, #0x8]\n"
+ "str d17, [x25, #0x8]\n"
+ "ldr d17, [x27, #0x10]\n"
+ "str d19, [x24, #0x8]\n"
+ "shll v21.4s, v16.4h, #0x10\n"
+ "ldr d16, [x26, #0x10]\n"
+ "str d18, [x23, #0x8]\n"
+ "ldr d18, [x25, #0x10]\n"
+ "shll v20.4s, v17.4h, #0x10\n"
+ "ldr d17, [x24, #0x10]\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "fadd v27.4s, v27.4s, v21.4s\n"
+ "ldr d16, [x23, #0x10]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "fadd v26.4s, v26.4s, v20.4s\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v25.4s, v25.4s, v19.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fadd v24.4s, v24.4s, v18.4s\n"
+ "fadd v23.4s, v23.4s, v17.4s\n"
+ "fadd v22.4s, v22.4s, v16.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ ".inst 0x0ea16b71 // bfcvtn v17.4h, v27.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16b50 // bfcvtn v16.4h, v26.4s\n"
+ "str d17, [x9, #0x10]\n"
+ "add x9, x9, #0x18\n"
+ ".inst 0x0ea16b33 // bfcvtn v19.4h, v25.4s\n"
+ ".inst 0x0ea16b12 // bfcvtn v18.4h, v24.4s\n"
+ ".inst 0x0ea16af1 // bfcvtn v17.4h, v23.4s\n"
+ "str d16, [x27, #0x10]\n"
+ "add x27, x27, #0x18\n"
+ ".inst 0x0ea16ad0 // bfcvtn v16.4h, v22.4s\n"
+ "str d19, [x26, #0x10]\n"
+ "add x26, x26, #0x18\n"
+ "str d18, [x25, #0x10]\n"
+ "add x25, x25, #0x18\n"
+ "str d17, [x24, #0x10]\n"
+ "add x24, x24, #0x18\n"
+ "str d16, [x23, #0x10]\n"
+ "add x23, x23, #0x18\n"
+ "bge 94b\n"
+ "95:" // Accumulate: Height 6: no full blocks
+ "cbz x10, 97f\n"
+ "mov x20, %x[in_ptr]\n"
+ "96:" // Accumulate: Height 6: Single loop
+ "ldr h21, [x9, #0x0]\n"
+ "ldr h20, [x27, #0x0]\n"
+ "subs x10, x10, #0x1\n"
+ "ldr h19, [x26, #0x0]\n"
+ "ldr h18, [x25, #0x0]\n"
+ "ldr h17, [x24, #0x0]\n"
+ "ldr h16, [x23, #0x0]\n"
+ "ldr s27, [%x[in_ptr], #0x0]\n"
+ "ldr s26, [%x[in_ptr], #0x30]\n"
+ "shll v21.4s, v21.4h, #0x10\n"
+ "shll v20.4s, v20.4h, #0x10\n"
+ "ldr s25, [%x[in_ptr], #0x60]\n"
+ "ldr s24, [%x[in_ptr], #0x90]\n"
+ "shll v19.4s, v19.4h, #0x10\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "ldr s23, [%x[in_ptr], #0xc0]\n"
+ "ldr s22, [%x[in_ptr], #0xf0]\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v27.4s, v27.4s, v21.4s\n"
+ "fadd v26.4s, v26.4s, v20.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v25.4s, v25.4s, v19.4s\n"
+ "fadd v24.4s, v24.4s, v18.4s\n"
+ "fadd v23.4s, v23.4s, v17.4s\n"
+ "fadd v22.4s, v22.4s, v16.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16b75 // bfcvtn v21.4h, v27.4s\n"
+ ".inst 0x0ea16b54 // bfcvtn v20.4h, v26.4s\n"
+ ".inst 0x0ea16b33 // bfcvtn v19.4h, v25.4s\n"
+ ".inst 0x0ea16b12 // bfcvtn v18.4h, v24.4s\n"
+ ".inst 0x0ea16af1 // bfcvtn v17.4h, v23.4s\n"
+ ".inst 0x0ea16ad0 // bfcvtn v16.4h, v22.4s\n"
+ "str h21, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ "str h20, [x27, #0x0]\n"
+ "add x27, x27, #0x2\n"
+ "str h19, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "str h18, [x25, #0x0]\n"
+ "add x25, x25, #0x2\n"
+ "str h17, [x24, #0x0]\n"
+ "add x24, x24, #0x2\n"
+ "str h16, [x23, #0x0]\n"
+ "add x23, x23, #0x2\n"
+ "bne 96b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "97:" // Accumulate: Height 6: no oddments
+ "b 108f\n"
+ "98:" // Accumulate: Height 7
+ "mov x9, %x[out_ptr]\n"
+ "mov x10, %x[cols]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "add x23, x24, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "add x22, x23, %x[ldout], LSL #1\n"
+ "blt 100f\n"
+ "99:" // Accumulate: Height 7: Block loop
+ "ldr d22, [x9, #0x0]\n"
+ "ldr d21, [x27, #0x0]\n"
+ "sub x10, x10, #0xc\n"
+ "ldr d20, [x26, #0x0]\n"
+ "ldr d19, [x25, #0x0]\n"
+ "cmp x10, #0xc\n"
+ "ldr d18, [x24, #0x0]\n"
+ "ldr d17, [x23, #0x0]\n"
+ "ldr d16, [x22, #0x0]\n"
+ "ldr q9, [%x[in_ptr], #0x0]\n"
+ "shll v24.4s, v22.4h, #0x10\n"
+ "shll v23.4s, v21.4h, #0x10\n"
+ "ldr q8, [%x[in_ptr], #0x30]\n"
+ "ldr q7, [%x[in_ptr], #0x60]\n"
+ "shll v21.4s, v20.4h, #0x10\n"
+ "shll v19.4s, v19.4h, #0x10\n"
+ "ldr q6, [%x[in_ptr], #0x90]\n"
+ "ldr q5, [%x[in_ptr], #0xc0]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "ldr q20, [%x[in_ptr], #0xf0]\n"
+ "ldr q22, [%x[in_ptr], #0x120]\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v9.4s, v9.4s, v24.4s\n"
+ "ldr q4, [%x[in_ptr], #0x10]\n"
+ "ldr q3, [%x[in_ptr], #0x40]\n"
+ "fadd v8.4s, v8.4s, v23.4s\n"
+ "fadd v7.4s, v7.4s, v21.4s\n"
+ "ldr q2, [%x[in_ptr], #0x70]\n"
+ "ldr q1, [%x[in_ptr], #0xa0]\n"
+ "fadd v6.4s, v6.4s, v19.4s\n"
+ "fadd v5.4s, v5.4s, v18.4s\n"
+ "ldr q0, [%x[in_ptr], #0xd0]\n"
+ "ldr q31, [%x[in_ptr], #0x100]\n"
+ "fadd v20.4s, v20.4s, v17.4s\n"
+ "fadd v22.4s, v22.4s, v16.4s\n"
+ "ldr q30, [%x[in_ptr], #0x130]\n"
+ "ldr q29, [%x[in_ptr], #0x20]\n"
+ "fmin v9.4s, v9.4s, v13.4s\n"
+ "fmin v8.4s, v8.4s, v13.4s\n"
+ "ldr q28, [%x[in_ptr], #0x50]\n"
+ "ldr q27, [%x[in_ptr], #0x80]\n"
+ "fmin v7.4s, v7.4s, v13.4s\n"
+ "fmin v6.4s, v6.4s, v13.4s\n"
+ "ldr q26, [%x[in_ptr], #0xb0]\n"
+ "ldr q25, [%x[in_ptr], #0xe0]\n"
+ "fmin v5.4s, v5.4s, v13.4s\n"
+ "fmin v20.4s, v20.4s, v13.4s\n"
+ "ldr q24, [%x[in_ptr], #0x110]\n"
+ "ldr q23, [%x[in_ptr], #0x140]\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmax v9.4s, v9.4s, v12.4s\n"
+ "fmax v8.4s, v8.4s, v12.4s\n"
+ "fmax v7.4s, v7.4s, v12.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fmax v6.4s, v6.4s, v12.4s\n"
+ "fmax v5.4s, v5.4s, v12.4s\n"
+ "fmax v20.4s, v20.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16935 // bfcvtn v21.4h, v9.4s\n"
+ ".inst 0x0ea16913 // bfcvtn v19.4h, v8.4s\n"
+ ".inst 0x0ea168f0 // bfcvtn v16.4h, v7.4s\n"
+ ".inst 0x0ea168d2 // bfcvtn v18.4h, v6.4s\n"
+ ".inst 0x0ea168b1 // bfcvtn v17.4h, v5.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ "str d21, [x9, #0x0]\n"
+ "str d19, [x27, #0x0]\n"
+ ".inst 0x0ea16ad3 // bfcvtn v19.4h, v22.4s\n"
+ "str d16, [x26, #0x0]\n"
+ "ldr d16, [x9, #0x8]\n"
+ "str d18, [x25, #0x0]\n"
+ "ldr d18, [x27, #0x8]\n"
+ "str d17, [x24, #0x0]\n"
+ "ldr d17, [x26, #0x8]\n"
+ "str d20, [x23, #0x0]\n"
+ "shll v22.4s, v16.4h, #0x10\n"
+ "ldr d16, [x25, #0x8]\n"
+ "str d19, [x22, #0x0]\n"
+ "shll v21.4s, v18.4h, #0x10\n"
+ "ldr d18, [x24, #0x8]\n"
+ "shll v20.4s, v17.4h, #0x10\n"
+ "ldr d17, [x23, #0x8]\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "fadd v4.4s, v4.4s, v22.4s\n"
+ "ldr d16, [x22, #0x8]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "fadd v3.4s, v3.4s, v21.4s\n"
+ "fadd v2.4s, v2.4s, v20.4s\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v1.4s, v1.4s, v19.4s\n"
+ "fadd v0.4s, v0.4s, v18.4s\n"
+ "fmin v4.4s, v4.4s, v13.4s\n"
+ "fadd v31.4s, v31.4s, v17.4s\n"
+ "fmin v3.4s, v3.4s, v13.4s\n"
+ "fadd v30.4s, v30.4s, v16.4s\n"
+ "fmin v2.4s, v2.4s, v13.4s\n"
+ "fmin v1.4s, v1.4s, v13.4s\n"
+ "fmin v0.4s, v0.4s, v13.4s\n"
+ "fmin v31.4s, v31.4s, v13.4s\n"
+ "fmax v4.4s, v4.4s, v12.4s\n"
+ "fmin v30.4s, v30.4s, v13.4s\n"
+ "fmax v3.4s, v3.4s, v12.4s\n"
+ "fmax v2.4s, v2.4s, v12.4s\n"
+ "fmax v1.4s, v1.4s, v12.4s\n"
+ "fmax v0.4s, v0.4s, v12.4s\n"
+ "fmax v31.4s, v31.4s, v12.4s\n"
+ "fmax v30.4s, v30.4s, v12.4s\n"
+ ".inst 0x0ea16893 // bfcvtn v19.4h, v4.4s\n"
+ ".inst 0x0ea16875 // bfcvtn v21.4h, v3.4s\n"
+ ".inst 0x0ea16850 // bfcvtn v16.4h, v2.4s\n"
+ ".inst 0x0ea16832 // bfcvtn v18.4h, v1.4s\n"
+ ".inst 0x0ea16811 // bfcvtn v17.4h, v0.4s\n"
+ "str d19, [x9, #0x8]\n"
+ ".inst 0x0ea16bf4 // bfcvtn v20.4h, v31.4s\n"
+ ".inst 0x0ea16bd3 // bfcvtn v19.4h, v30.4s\n"
+ "str d21, [x27, #0x8]\n"
+ "str d16, [x26, #0x8]\n"
+ "ldr d16, [x9, #0x10]\n"
+ "str d18, [x25, #0x8]\n"
+ "ldr d18, [x27, #0x10]\n"
+ "str d17, [x24, #0x8]\n"
+ "ldr d17, [x26, #0x10]\n"
+ "str d20, [x23, #0x8]\n"
+ "shll v22.4s, v16.4h, #0x10\n"
+ "ldr d16, [x25, #0x10]\n"
+ "str d19, [x22, #0x8]\n"
+ "shll v21.4s, v18.4h, #0x10\n"
+ "ldr d18, [x24, #0x10]\n"
+ "shll v20.4s, v17.4h, #0x10\n"
+ "ldr d17, [x23, #0x10]\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "fadd v29.4s, v29.4s, v22.4s\n"
+ "ldr d16, [x22, #0x10]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "fadd v28.4s, v28.4s, v21.4s\n"
+ "fadd v27.4s, v27.4s, v20.4s\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v26.4s, v26.4s, v19.4s\n"
+ "fadd v25.4s, v25.4s, v18.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fadd v24.4s, v24.4s, v17.4s\n"
+ "fmin v28.4s, v28.4s, v13.4s\n"
+ "fadd v23.4s, v23.4s, v16.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmax v28.4s, v28.4s, v12.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ ".inst 0x0ea16bb0 // bfcvtn v16.4h, v29.4s\n"
+ ".inst 0x0ea16b95 // bfcvtn v21.4h, v28.4s\n"
+ ".inst 0x0ea16b74 // bfcvtn v20.4h, v27.4s\n"
+ ".inst 0x0ea16b53 // bfcvtn v19.4h, v26.4s\n"
+ ".inst 0x0ea16b32 // bfcvtn v18.4h, v25.4s\n"
+ "str d16, [x9, #0x10]\n"
+ "add x9, x9, #0x18\n"
+ ".inst 0x0ea16b11 // bfcvtn v17.4h, v24.4s\n"
+ ".inst 0x0ea16af0 // bfcvtn v16.4h, v23.4s\n"
+ "str d21, [x27, #0x10]\n"
+ "add x27, x27, #0x18\n"
+ "str d20, [x26, #0x10]\n"
+ "add x26, x26, #0x18\n"
+ "str d19, [x25, #0x10]\n"
+ "add x25, x25, #0x18\n"
+ "str d18, [x24, #0x10]\n"
+ "add x24, x24, #0x18\n"
+ "str d17, [x23, #0x10]\n"
+ "add x23, x23, #0x18\n"
+ "str d16, [x22, #0x10]\n"
+ "add x22, x22, #0x18\n"
+ "bge 99b\n"
+ "100:" // Accumulate: Height 7: no full blocks
+ "cbz x10, 102f\n"
+ "mov x20, %x[in_ptr]\n"
+ "101:" // Accumulate: Height 7: Single loop
+ "ldr h22, [x9, #0x0]\n"
+ "ldr h21, [x27, #0x0]\n"
+ "subs x10, x10, #0x1\n"
+ "ldr h20, [x26, #0x0]\n"
+ "ldr h19, [x25, #0x0]\n"
+ "ldr h18, [x24, #0x0]\n"
+ "ldr h17, [x23, #0x0]\n"
+ "ldr h16, [x22, #0x0]\n"
+ "ldr s29, [%x[in_ptr], #0x0]\n"
+ "shll v28.4s, v22.4h, #0x10\n"
+ "shll v27.4s, v21.4h, #0x10\n"
+ "ldr s26, [%x[in_ptr], #0x30]\n"
+ "ldr s25, [%x[in_ptr], #0x60]\n"
+ "shll v21.4s, v20.4h, #0x10\n"
+ "shll v20.4s, v19.4h, #0x10\n"
+ "ldr s24, [%x[in_ptr], #0x90]\n"
+ "ldr s23, [%x[in_ptr], #0xc0]\n"
+ "shll v19.4s, v18.4h, #0x10\n"
+ "shll v18.4s, v17.4h, #0x10\n"
+ "ldr s17, [%x[in_ptr], #0xf0]\n"
+ "ldr s22, [%x[in_ptr], #0x120]\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v29.4s, v29.4s, v28.4s\n"
+ "fadd v26.4s, v26.4s, v27.4s\n"
+ "fadd v25.4s, v25.4s, v21.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v24.4s, v24.4s, v20.4s\n"
+ "fadd v23.4s, v23.4s, v19.4s\n"
+ "fadd v17.4s, v17.4s, v18.4s\n"
+ "fadd v22.4s, v22.4s, v16.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v17.4s, v17.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v17.4s, v17.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16bb5 // bfcvtn v21.4h, v29.4s\n"
+ ".inst 0x0ea16b50 // bfcvtn v16.4h, v26.4s\n"
+ ".inst 0x0ea16b34 // bfcvtn v20.4h, v25.4s\n"
+ ".inst 0x0ea16b13 // bfcvtn v19.4h, v24.4s\n"
+ ".inst 0x0ea16af2 // bfcvtn v18.4h, v23.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ "str h21, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ "str h16, [x27, #0x0]\n"
+ ".inst 0x0ea16ad0 // bfcvtn v16.4h, v22.4s\n"
+ "add x27, x27, #0x2\n"
+ "str h20, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "str h19, [x25, #0x0]\n"
+ "add x25, x25, #0x2\n"
+ "str h18, [x24, #0x0]\n"
+ "add x24, x24, #0x2\n"
+ "str h17, [x23, #0x0]\n"
+ "add x23, x23, #0x2\n"
+ "str h16, [x22, #0x0]\n"
+ "add x22, x22, #0x2\n"
+ "bne 101b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "102:" // Accumulate: Height 7: no oddments
+ "b 108f\n"
+ "103:" // Accumulate: Height 8
+ "mov x9, %x[out_ptr]\n"
+ "mov x10, %x[cols]\n"
+ "add x27, x9, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "add x23, x24, %x[ldout], LSL #1\n"
+ "cmp x10, #0xc\n"
+ "add x22, x23, %x[ldout], LSL #1\n"
+ "add x21, x22, %x[ldout], LSL #1\n"
+ "blt 105f\n"
+ "104:" // Accumulate: Height 8: Block loop
+ "ldr d23, [x9, #0x0]\n"
+ "ldr d22, [x27, #0x0]\n"
+ "sub x10, x10, #0xc\n"
+ "ldr d21, [x26, #0x0]\n"
+ "ldr d20, [x25, #0x0]\n"
+ "cmp x10, #0xc\n"
+ "ldr d19, [x24, #0x0]\n"
+ "ldr d18, [x23, #0x0]\n"
+ "ldr d17, [x22, #0x0]\n"
+ "ldr d16, [x21, #0x0]\n"
+ "shll v26.4s, v23.4h, #0x10\n"
+ "shll v25.4s, v22.4h, #0x10\n"
+ "ldr q11, [%x[in_ptr], #0x0]\n"
+ "ldr q10, [%x[in_ptr], #0x30]\n"
+ "shll v24.4s, v21.4h, #0x10\n"
+ "shll v23.4s, v20.4h, #0x10\n"
+ "ldr q9, [%x[in_ptr], #0x60]\n"
+ "ldr q8, [%x[in_ptr], #0x90]\n"
+ "shll v21.4s, v19.4h, #0x10\n"
+ "shll v20.4s, v18.4h, #0x10\n"
+ "ldr q18, [%x[in_ptr], #0xc0]\n"
+ "ldr q19, [%x[in_ptr], #0xf0]\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "ldr q7, [%x[in_ptr], #0x120]\n"
+ "ldr q22, [%x[in_ptr], #0x150]\n"
+ "fadd v11.4s, v11.4s, v26.4s\n"
+ "fadd v10.4s, v10.4s, v25.4s\n"
+ "ldr q6, [%x[in_ptr], #0x10]\n"
+ "ldr q5, [%x[in_ptr], #0x40]\n"
+ "fadd v9.4s, v9.4s, v24.4s\n"
+ "fadd v8.4s, v8.4s, v23.4s\n"
+ "ldr q4, [%x[in_ptr], #0x70]\n"
+ "ldr q3, [%x[in_ptr], #0xa0]\n"
+ "fadd v18.4s, v18.4s, v21.4s\n"
+ "fadd v19.4s, v19.4s, v20.4s\n"
+ "ldr q2, [%x[in_ptr], #0xd0]\n"
+ "ldr q1, [%x[in_ptr], #0x100]\n"
+ "fadd v7.4s, v7.4s, v17.4s\n"
+ "fadd v22.4s, v22.4s, v16.4s\n"
+ "ldr q0, [%x[in_ptr], #0x130]\n"
+ "ldr q31, [%x[in_ptr], #0x160]\n"
+ "fmin v11.4s, v11.4s, v13.4s\n"
+ "fmin v10.4s, v10.4s, v13.4s\n"
+ "ldr q30, [%x[in_ptr], #0x20]\n"
+ "ldr q29, [%x[in_ptr], #0x50]\n"
+ "fmin v9.4s, v9.4s, v13.4s\n"
+ "fmin v8.4s, v8.4s, v13.4s\n"
+ "ldr q28, [%x[in_ptr], #0x80]\n"
+ "ldr q27, [%x[in_ptr], #0xb0]\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "ldr q26, [%x[in_ptr], #0xe0]\n"
+ "ldr q25, [%x[in_ptr], #0x110]\n"
+ "fmin v7.4s, v7.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "ldr q24, [%x[in_ptr], #0x140]\n"
+ "ldr q23, [%x[in_ptr], #0x170]\n"
+ "fmax v11.4s, v11.4s, v12.4s\n"
+ "fmax v10.4s, v10.4s, v12.4s\n"
+ "fmax v9.4s, v9.4s, v12.4s\n"
+ "fmax v8.4s, v8.4s, v12.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x180\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v7.4s, v7.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16975 // bfcvtn v21.4h, v11.4s\n"
+ ".inst 0x0ea16954 // bfcvtn v20.4h, v10.4s\n"
+ ".inst 0x0ea16931 // bfcvtn v17.4h, v9.4s\n"
+ ".inst 0x0ea16910 // bfcvtn v16.4h, v8.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ "str d21, [x9, #0x0]\n"
+ "str d20, [x27, #0x0]\n"
+ ".inst 0x0ea168f5 // bfcvtn v21.4h, v7.4s\n"
+ ".inst 0x0ea16ad4 // bfcvtn v20.4h, v22.4s\n"
+ "str d17, [x26, #0x0]\n"
+ "ldr d17, [x9, #0x8]\n"
+ "str d16, [x25, #0x0]\n"
+ "ldr d16, [x27, #0x8]\n"
+ "str d18, [x24, #0x0]\n"
+ "ldr d18, [x26, #0x8]\n"
+ "str d19, [x23, #0x0]\n"
+ "shll v19.4s, v17.4h, #0x10\n"
+ "ldr d17, [x25, #0x8]\n"
+ "str d21, [x22, #0x0]\n"
+ "shll v22.4s, v16.4h, #0x10\n"
+ "ldr d16, [x24, #0x8]\n"
+ "str d20, [x21, #0x0]\n"
+ "shll v21.4s, v18.4h, #0x10\n"
+ "ldr d18, [x23, #0x8]\n"
+ "shll v20.4s, v17.4h, #0x10\n"
+ "fadd v6.4s, v6.4s, v19.4s\n"
+ "ldr d17, [x22, #0x8]\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "fadd v5.4s, v5.4s, v22.4s\n"
+ "ldr d16, [x21, #0x8]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "fadd v4.4s, v4.4s, v21.4s\n"
+ "fadd v3.4s, v3.4s, v20.4s\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "fadd v2.4s, v2.4s, v19.4s\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v1.4s, v1.4s, v18.4s\n"
+ "fmin v6.4s, v6.4s, v13.4s\n"
+ "fmin v5.4s, v5.4s, v13.4s\n"
+ "fadd v0.4s, v0.4s, v17.4s\n"
+ "fmin v4.4s, v4.4s, v13.4s\n"
+ "fadd v31.4s, v31.4s, v16.4s\n"
+ "fmin v3.4s, v3.4s, v13.4s\n"
+ "fmin v2.4s, v2.4s, v13.4s\n"
+ "fmin v1.4s, v1.4s, v13.4s\n"
+ "fmin v0.4s, v0.4s, v13.4s\n"
+ "fmax v6.4s, v6.4s, v12.4s\n"
+ "fmin v31.4s, v31.4s, v13.4s\n"
+ "fmax v5.4s, v5.4s, v12.4s\n"
+ "fmax v4.4s, v4.4s, v12.4s\n"
+ "fmax v3.4s, v3.4s, v12.4s\n"
+ "fmax v2.4s, v2.4s, v12.4s\n"
+ "fmax v1.4s, v1.4s, v12.4s\n"
+ "fmax v0.4s, v0.4s, v12.4s\n"
+ "fmax v31.4s, v31.4s, v12.4s\n"
+ ".inst 0x0ea168d5 // bfcvtn v21.4h, v6.4s\n"
+ ".inst 0x0ea168b4 // bfcvtn v20.4h, v5.4s\n"
+ ".inst 0x0ea16891 // bfcvtn v17.4h, v4.4s\n"
+ ".inst 0x0ea16870 // bfcvtn v16.4h, v3.4s\n"
+ ".inst 0x0ea16852 // bfcvtn v18.4h, v2.4s\n"
+ ".inst 0x0ea16833 // bfcvtn v19.4h, v1.4s\n"
+ "str d21, [x9, #0x8]\n"
+ "str d20, [x27, #0x8]\n"
+ ".inst 0x0ea16815 // bfcvtn v21.4h, v0.4s\n"
+ ".inst 0x0ea16bf4 // bfcvtn v20.4h, v31.4s\n"
+ "str d17, [x26, #0x8]\n"
+ "ldr d17, [x9, #0x10]\n"
+ "str d16, [x25, #0x8]\n"
+ "ldr d16, [x27, #0x10]\n"
+ "str d18, [x24, #0x8]\n"
+ "ldr d18, [x26, #0x10]\n"
+ "str d19, [x23, #0x8]\n"
+ "shll v19.4s, v17.4h, #0x10\n"
+ "ldr d17, [x25, #0x10]\n"
+ "str d21, [x22, #0x8]\n"
+ "shll v22.4s, v16.4h, #0x10\n"
+ "ldr d16, [x24, #0x10]\n"
+ "str d20, [x21, #0x8]\n"
+ "shll v21.4s, v18.4h, #0x10\n"
+ "ldr d18, [x23, #0x10]\n"
+ "shll v20.4s, v17.4h, #0x10\n"
+ "fadd v30.4s, v30.4s, v19.4s\n"
+ "ldr d17, [x22, #0x10]\n"
+ "shll v19.4s, v16.4h, #0x10\n"
+ "fadd v29.4s, v29.4s, v22.4s\n"
+ "ldr d16, [x21, #0x10]\n"
+ "shll v18.4s, v18.4h, #0x10\n"
+ "fadd v28.4s, v28.4s, v21.4s\n"
+ "fadd v27.4s, v27.4s, v20.4s\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "fadd v26.4s, v26.4s, v19.4s\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "fadd v25.4s, v25.4s, v18.4s\n"
+ "fmin v30.4s, v30.4s, v13.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fadd v24.4s, v24.4s, v17.4s\n"
+ "fmin v28.4s, v28.4s, v13.4s\n"
+ "fadd v23.4s, v23.4s, v16.4s\n"
+ "fmin v27.4s, v27.4s, v13.4s\n"
+ "fmin v26.4s, v26.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmax v30.4s, v30.4s, v12.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmax v28.4s, v28.4s, v12.4s\n"
+ "fmax v27.4s, v27.4s, v12.4s\n"
+ "fmax v26.4s, v26.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ ".inst 0x0ea16bd1 // bfcvtn v17.4h, v30.4s\n"
+ ".inst 0x0ea16bb0 // bfcvtn v16.4h, v29.4s\n"
+ ".inst 0x0ea16b95 // bfcvtn v21.4h, v28.4s\n"
+ ".inst 0x0ea16b74 // bfcvtn v20.4h, v27.4s\n"
+ ".inst 0x0ea16b53 // bfcvtn v19.4h, v26.4s\n"
+ ".inst 0x0ea16b32 // bfcvtn v18.4h, v25.4s\n"
+ "str d17, [x9, #0x10]\n"
+ "add x9, x9, #0x18\n"
+ "str d16, [x27, #0x10]\n"
+ ".inst 0x0ea16b11 // bfcvtn v17.4h, v24.4s\n"
+ ".inst 0x0ea16af0 // bfcvtn v16.4h, v23.4s\n"
+ "add x27, x27, #0x18\n"
+ "str d21, [x26, #0x10]\n"
+ "add x26, x26, #0x18\n"
+ "str d20, [x25, #0x10]\n"
+ "add x25, x25, #0x18\n"
+ "str d19, [x24, #0x10]\n"
+ "add x24, x24, #0x18\n"
+ "str d18, [x23, #0x10]\n"
+ "add x23, x23, #0x18\n"
+ "str d17, [x22, #0x10]\n"
+ "add x22, x22, #0x18\n"
+ "str d16, [x21, #0x10]\n"
+ "add x21, x21, #0x18\n"
+ "bge 104b\n"
+ "105:" // Accumulate: Height 8: no full blocks
+ "cbz x10, 107f\n"
+ "mov x20, %x[in_ptr]\n"
+ "106:" // Accumulate: Height 8: Single loop
+ "ldr h23, [x9, #0x0]\n"
+ "ldr h22, [x27, #0x0]\n"
+ "subs x10, x10, #0x1\n"
+ "ldr h21, [x26, #0x0]\n"
+ "ldr h20, [x25, #0x0]\n"
+ "ldr h19, [x24, #0x0]\n"
+ "ldr h18, [x23, #0x0]\n"
+ "ldr h17, [x22, #0x0]\n"
+ "ldr h16, [x21, #0x0]\n"
+ "shll v31.4s, v23.4h, #0x10\n"
+ "shll v30.4s, v22.4h, #0x10\n"
+ "ldr s29, [%x[in_ptr], #0x0]\n"
+ "ldr s28, [%x[in_ptr], #0x30]\n"
+ "shll v27.4s, v21.4h, #0x10\n"
+ "shll v26.4s, v20.4h, #0x10\n"
+ "ldr s25, [%x[in_ptr], #0x60]\n"
+ "ldr s24, [%x[in_ptr], #0x90]\n"
+ "shll v21.4s, v19.4h, #0x10\n"
+ "shll v20.4s, v18.4h, #0x10\n"
+ "ldr s19, [%x[in_ptr], #0xc0]\n"
+ "ldr s18, [%x[in_ptr], #0xf0]\n"
+ "shll v17.4s, v17.4h, #0x10\n"
+ "shll v16.4s, v16.4h, #0x10\n"
+ "ldr s23, [%x[in_ptr], #0x120]\n"
+ "ldr s22, [%x[in_ptr], #0x150]\n"
+ "fadd v29.4s, v29.4s, v31.4s\n"
+ "fadd v28.4s, v28.4s, v30.4s\n"
+ "fadd v25.4s, v25.4s, v27.4s\n"
+ "fadd v24.4s, v24.4s, v26.4s\n"
+ "add %x[in_ptr], %x[in_ptr], #0x4\n"
+ "fadd v19.4s, v19.4s, v21.4s\n"
+ "fadd v18.4s, v18.4s, v20.4s\n"
+ "fadd v23.4s, v23.4s, v17.4s\n"
+ "fadd v22.4s, v22.4s, v16.4s\n"
+ "fmin v29.4s, v29.4s, v13.4s\n"
+ "fmin v28.4s, v28.4s, v13.4s\n"
+ "fmin v25.4s, v25.4s, v13.4s\n"
+ "fmin v24.4s, v24.4s, v13.4s\n"
+ "fmin v19.4s, v19.4s, v13.4s\n"
+ "fmin v18.4s, v18.4s, v13.4s\n"
+ "fmin v23.4s, v23.4s, v13.4s\n"
+ "fmin v22.4s, v22.4s, v13.4s\n"
+ "fmax v29.4s, v29.4s, v12.4s\n"
+ "fmax v28.4s, v28.4s, v12.4s\n"
+ "fmax v25.4s, v25.4s, v12.4s\n"
+ "fmax v24.4s, v24.4s, v12.4s\n"
+ "fmax v19.4s, v19.4s, v12.4s\n"
+ "fmax v18.4s, v18.4s, v12.4s\n"
+ "fmax v23.4s, v23.4s, v12.4s\n"
+ "fmax v22.4s, v22.4s, v12.4s\n"
+ ".inst 0x0ea16bb1 // bfcvtn v17.4h, v29.4s\n"
+ ".inst 0x0ea16b90 // bfcvtn v16.4h, v28.4s\n"
+ ".inst 0x0ea16b35 // bfcvtn v21.4h, v25.4s\n"
+ ".inst 0x0ea16b14 // bfcvtn v20.4h, v24.4s\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ "str h17, [x9, #0x0]\n"
+ "add x9, x9, #0x2\n"
+ "str h16, [x27, #0x0]\n"
+ ".inst 0x0ea16af1 // bfcvtn v17.4h, v23.4s\n"
+ ".inst 0x0ea16ad0 // bfcvtn v16.4h, v22.4s\n"
+ "add x27, x27, #0x2\n"
+ "str h21, [x26, #0x0]\n"
+ "add x26, x26, #0x2\n"
+ "str h20, [x25, #0x0]\n"
+ "add x25, x25, #0x2\n"
+ "str h19, [x24, #0x0]\n"
+ "add x24, x24, #0x2\n"
+ "str h18, [x23, #0x0]\n"
+ "add x23, x23, #0x2\n"
+ "str h17, [x22, #0x0]\n"
+ "add x22, x22, #0x2\n"
+ "str h16, [x21, #0x0]\n"
+ "add x21, x21, #0x2\n"
+ "bne 106b\n"
+ "add %x[in_ptr], x20, #0x180\n"
+ "107:" // Accumulate: Height 8: no oddments
+ "subs %x[rows], %x[rows], #0x8\n"
+ "add %x[out_ptr], %x[out_ptr], x11\n"
+ "bgt 67b\n"
+ "108:" // Exit
+ : [in_ptr] "+&r" (in_ptr), [out_ptr] "+&r" (out_ptr), [rows] "+&r" (rows)
+ : [accumulate] "r" (accumulate), [bias] "r" (bias), [cols] "r" (cols), [ldout] "r" (ldout), [maxval] "r" (maxval), [minval] "r" (minval)
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+ );
+}
+
+
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/merges/list-sve.hpp b/src/core/NEON/kernels/arm_gemm/merges/list-sve.hpp
index aded4b3b8..d11740e5c 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/list-sve.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/list-sve.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,5 +24,6 @@
#include "sve_merge_fp16_3VLx8.hpp"
#include "sve_merge_fp32_3VLx8.hpp"
+#include "sve_merge_fp32_bf16_8x3VL.hpp"
#include "sve_merge_s32_3VLx8.hpp"
#include "sve_merge_u32_3VLx8.hpp" \ No newline at end of file
diff --git a/src/core/NEON/kernels/arm_gemm/merges/list.hpp b/src/core/NEON/kernels/arm_gemm/merges/list.hpp
index 3443c6f0a..fd6be5b69 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/list.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/list.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "a32_merge_float_8x6.hpp"
#include "a64_merge_fp32_12x8.hpp"
+#include "a64_merge_fp32_bf16_8x12.hpp"
#include "a64_merge_s32_12x8.hpp"
#include "a64_merge_s32_4x4.hpp"
#include "a64_merge_u32_12x8.hpp"
diff --git a/src/core/NEON/kernels/arm_gemm/merges/sve_merge_fp32_bf16_8x3VL.hpp b/src/core/NEON/kernels/arm_gemm/merges/sve_merge_fp32_bf16_8x3VL.hpp
new file mode 100644
index 000000000..5d4a8bf34
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/merges/sve_merge_fp32_bf16_8x3VL.hpp
@@ -0,0 +1,2137 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#ifdef ARM_COMPUTE_ENABLE_SVE
+
+template<>
+void MergeResults<3, 8, true>(
+ bfloat16 *out_ptr,
+ const float * in_ptr,
+ const int ldout,
+ const int y0, const int ymax,
+ const int x0, const int xmax,
+ const bfloat16 *bias,
+ Activation act,
+ bool accumulate)
+{
+ float maxval = static_cast<float>(std::numeric_limits<float>::infinity());
+ float minval = - static_cast<float>(std::numeric_limits<float>::infinity());
+
+ switch(act.type) {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ maxval = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ minval = 0;
+ break;
+ }
+
+ size_t rows = ymax-y0;
+ size_t cols = xmax-x0;
+
+ out_ptr += (y0 * ldout) + x0;
+ bias = (bias == nullptr) ? nullptr : bias + x0;
+
+ __asm__ __volatile__(
+ "ptrue p3.b\n"
+ "cbz %x[cols], 52f\n"
+ "cbz %x[rows], 52f\n"
+ "mov x12, #0x20\n"
+ "dup z12.s, %w[maxval]\n"
+ "dup z11.s, %w[minval]\n"
+ "mul x12, %x[ldout], x12\n"
+ "cbnz %x[accumulate], 34f\n"
+ "1:" // Initial: Row loop
+ "cmp %x[rows], #0x7\n"
+ "bgt 30f\n"
+ "beq 26f\n"
+ "cmp %x[rows], #0x5\n"
+ "bgt 22f\n"
+ "beq 18f\n"
+ "cmp %x[rows], #0x3\n"
+ "bgt 14f\n"
+ "beq 10f\n"
+ "cmp %x[rows], #0x1\n"
+ "bgt 6f\n"
+ "2:" // Initial: Height 1
+ "mov x11, %x[cols]\n"
+ "mov x10, %x[out_ptr]\n"
+ "mov x9, %x[bias]\n"
+ "3:" // Initial: Height 1: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p0.s, x21, x11\n"
+ "incw x21\n"
+ "cbnz %x[bias], 4f\n"
+ "mov z21.b, #0x0\n"
+ "mov z20.b, #0x0\n"
+ "mov z19.b, #0x0\n"
+ "b 5f\n"
+ "4:" // Initial: Height 1: Width 3: bias
+ "ld1h { z18.s }, p2/Z, [x9]\n"
+ "ld1h { z17.s }, p1/Z, [x9, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x9, #2, MUL VL]\n"
+ "lsl z21.s, z18.s, #0x10\n"
+ "lsl z20.s, z17.s, #0x10\n"
+ "lsl z19.s, z16.s, #0x10\n"
+ "5:" // Initial: Height 1: Width 3: init done
+ "ld1w { z17.s }, p2/Z, [%x[in_ptr]]\n"
+ "ld1w { z16.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "decw x11, ALL, MUL #3\n"
+ "inch x9, ALL, MUL #3\n"
+ "ld1w { z18.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "fadd z17.s, z17.s, z21.s\n"
+ "fadd z16.s, z16.s, z20.s\n"
+ "cmp x11, XZR\n"
+ "fadd z18.s, z18.s, z19.s\n"
+ "fmin z17.s, p3/M, z17.s, z12.s\n"
+ "fmin z16.s, p3/M, z16.s, z12.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmax z17.s, p3/M, z17.s, z11.s\n"
+ "fmax z16.s, p3/M, z16.s, z11.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ ".inst 0x658aae31 // bfcvt z17.h, p3/M, z17.s\n"
+ ".inst 0x658aae10 // bfcvt z16.h, p3/M, z16.s\n"
+ "st1h { z17.s }, p2, [x10]\n"
+ "st1h { z16.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aae50 // bfcvt z16.h, p3/M, z18.s\n"
+ "st1h { z16.s }, p0, [x10, #2, MUL VL]\n"
+ "inch x10, ALL, MUL #3\n"
+ "bgt 3b\n"
+ "b 52f\n"
+ "6:" // Initial: Height 2
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "mov x9, %x[bias]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "7:" // Initial: Height 2: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p0.s, x21, x11\n"
+ "incw x21\n"
+ "cbnz %x[bias], 8f\n"
+ "mov z24.b, #0x0\n"
+ "mov z23.b, #0x0\n"
+ "mov z22.b, #0x0\n"
+ "b 9f\n"
+ "8:" // Initial: Height 2: Width 3: bias
+ "ld1h { z18.s }, p2/Z, [x9]\n"
+ "ld1h { z17.s }, p1/Z, [x9, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x9, #2, MUL VL]\n"
+ "lsl z24.s, z18.s, #0x10\n"
+ "lsl z23.s, z17.s, #0x10\n"
+ "lsl z22.s, z16.s, #0x10\n"
+ "9:" // Initial: Height 2: Width 3: init done
+ "ld1w { z17.s }, p2/Z, [%x[in_ptr]]\n"
+ "ld1w { z16.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "decw x11, ALL, MUL #3\n"
+ "inch x9, ALL, MUL #3\n"
+ "ld1w { z19.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z18.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "ld1w { z21.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "ld1w { z20.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "fadd z17.s, z17.s, z24.s\n"
+ "fadd z16.s, z16.s, z23.s\n"
+ "cmp x11, XZR\n"
+ "fadd z19.s, z19.s, z22.s\n"
+ "fadd z18.s, z18.s, z24.s\n"
+ "fadd z21.s, z21.s, z23.s\n"
+ "fadd z20.s, z20.s, z22.s\n"
+ "fmin z17.s, p3/M, z17.s, z12.s\n"
+ "fmin z16.s, p3/M, z16.s, z12.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fmax z17.s, p3/M, z17.s, z11.s\n"
+ "fmax z16.s, p3/M, z16.s, z11.s\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ ".inst 0x658aae31 // bfcvt z17.h, p3/M, z17.s\n"
+ ".inst 0x658aae10 // bfcvt z16.h, p3/M, z16.s\n"
+ ".inst 0x658aae73 // bfcvt z19.h, p3/M, z19.s\n"
+ ".inst 0x658aae52 // bfcvt z18.h, p3/M, z18.s\n"
+ "st1h { z17.s }, p2, [x10]\n"
+ "st1h { z16.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aaeb1 // bfcvt z17.h, p3/M, z21.s\n"
+ ".inst 0x658aae90 // bfcvt z16.h, p3/M, z20.s\n"
+ "st1h { z19.s }, p0, [x10, #2, MUL VL]\n"
+ "inch x10, ALL, MUL #3\n"
+ "st1h { z18.s }, p2, [x28]\n"
+ "st1h { z17.s }, p1, [x28, #1, MUL VL]\n"
+ "st1h { z16.s }, p0, [x28, #2, MUL VL]\n"
+ "inch x28, ALL, MUL #3\n"
+ "bgt 7b\n"
+ "b 52f\n"
+ "10:" // Initial: Height 3
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "mov x9, %x[bias]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "11:" // Initial: Height 3: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p0.s, x21, x11\n"
+ "incw x21\n"
+ "cbnz %x[bias], 12f\n"
+ "mov z27.b, #0x0\n"
+ "mov z26.b, #0x0\n"
+ "mov z25.b, #0x0\n"
+ "b 13f\n"
+ "12:" // Initial: Height 3: Width 3: bias
+ "ld1h { z18.s }, p2/Z, [x9]\n"
+ "ld1h { z17.s }, p1/Z, [x9, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x9, #2, MUL VL]\n"
+ "lsl z27.s, z18.s, #0x10\n"
+ "lsl z26.s, z17.s, #0x10\n"
+ "lsl z25.s, z16.s, #0x10\n"
+ "13:" // Initial: Height 3: Width 3: init done
+ "ld1w { z18.s }, p2/Z, [%x[in_ptr]]\n"
+ "ld1w { z17.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "decw x11, ALL, MUL #3\n"
+ "inch x9, ALL, MUL #3\n"
+ "ld1w { z16.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z21.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "ld1w { z20.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "ld1w { z19.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "ld1w { z24.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "ld1w { z23.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "fadd z18.s, z18.s, z27.s\n"
+ "fadd z17.s, z17.s, z26.s\n"
+ "ld1w { z22.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "fadd z16.s, z16.s, z25.s\n"
+ "fadd z21.s, z21.s, z27.s\n"
+ "cmp x11, XZR\n"
+ "fadd z20.s, z20.s, z26.s\n"
+ "fadd z19.s, z19.s, z25.s\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "fadd z24.s, z24.s, z27.s\n"
+ "fadd z23.s, z23.s, z26.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmin z17.s, p3/M, z17.s, z12.s\n"
+ "fadd z22.s, z22.s, z25.s\n"
+ "fmin z16.s, p3/M, z16.s, z12.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ "fmax z17.s, p3/M, z17.s, z11.s\n"
+ "fmax z16.s, p3/M, z16.s, z11.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ ".inst 0x658aae52 // bfcvt z18.h, p3/M, z18.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ ".inst 0x658aae31 // bfcvt z17.h, p3/M, z17.s\n"
+ ".inst 0x658aae10 // bfcvt z16.h, p3/M, z16.s\n"
+ ".inst 0x658aaeb5 // bfcvt z21.h, p3/M, z21.s\n"
+ ".inst 0x658aae94 // bfcvt z20.h, p3/M, z20.s\n"
+ "st1h { z18.s }, p2, [x10]\n"
+ ".inst 0x658aae73 // bfcvt z19.h, p3/M, z19.s\n"
+ ".inst 0x658aaf12 // bfcvt z18.h, p3/M, z24.s\n"
+ "st1h { z17.s }, p1, [x10, #1, MUL VL]\n"
+ "st1h { z16.s }, p0, [x10, #2, MUL VL]\n"
+ ".inst 0x658aaef1 // bfcvt z17.h, p3/M, z23.s\n"
+ ".inst 0x658aaed0 // bfcvt z16.h, p3/M, z22.s\n"
+ "inch x10, ALL, MUL #3\n"
+ "st1h { z21.s }, p2, [x28]\n"
+ "st1h { z20.s }, p1, [x28, #1, MUL VL]\n"
+ "st1h { z19.s }, p0, [x28, #2, MUL VL]\n"
+ "inch x28, ALL, MUL #3\n"
+ "st1h { z18.s }, p2, [x27]\n"
+ "st1h { z17.s }, p1, [x27, #1, MUL VL]\n"
+ "st1h { z16.s }, p0, [x27, #2, MUL VL]\n"
+ "inch x27, ALL, MUL #3\n"
+ "bgt 11b\n"
+ "b 52f\n"
+ "14:" // Initial: Height 4
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "mov x9, %x[bias]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "15:" // Initial: Height 4: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p0.s, x21, x11\n"
+ "incw x21\n"
+ "cbnz %x[bias], 16f\n"
+ "mov z30.b, #0x0\n"
+ "mov z29.b, #0x0\n"
+ "mov z28.b, #0x0\n"
+ "b 17f\n"
+ "16:" // Initial: Height 4: Width 3: bias
+ "ld1h { z18.s }, p2/Z, [x9]\n"
+ "ld1h { z17.s }, p1/Z, [x9, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x9, #2, MUL VL]\n"
+ "lsl z30.s, z18.s, #0x10\n"
+ "lsl z29.s, z17.s, #0x10\n"
+ "lsl z28.s, z16.s, #0x10\n"
+ "17:" // Initial: Height 4: Width 3: init done
+ "ld1w { z18.s }, p2/Z, [%x[in_ptr]]\n"
+ "ld1w { z17.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "decw x11, ALL, MUL #3\n"
+ "inch x9, ALL, MUL #3\n"
+ "ld1w { z16.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z24.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "ld1w { z23.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "ld1w { z22.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "ld1w { z21.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "ld1w { z20.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "fadd z18.s, z18.s, z30.s\n"
+ "fadd z17.s, z17.s, z29.s\n"
+ "ld1w { z19.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "ld1w { z27.s }, p2/Z, [x20, #-7, MUL VL]\n"
+ "fadd z16.s, z16.s, z28.s\n"
+ "fadd z24.s, z24.s, z30.s\n"
+ "ld1w { z26.s }, p1/Z, [x20, #-6, MUL VL]\n"
+ "ld1w { z25.s }, p0/Z, [x20, #-5, MUL VL]\n"
+ "fadd z23.s, z23.s, z29.s\n"
+ "fadd z22.s, z22.s, z28.s\n"
+ "fadd z21.s, z21.s, z30.s\n"
+ "fadd z20.s, z20.s, z29.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmin z17.s, p3/M, z17.s, z12.s\n"
+ "fadd z19.s, z19.s, z28.s\n"
+ "fadd z27.s, z27.s, z30.s\n"
+ "fmin z16.s, p3/M, z16.s, z12.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fadd z26.s, z26.s, z29.s\n"
+ "fadd z25.s, z25.s, z28.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "fmin z27.s, p3/M, z27.s, z12.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ "fmax z17.s, p3/M, z17.s, z11.s\n"
+ "fmax z16.s, p3/M, z16.s, z11.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ ".inst 0x658aae52 // bfcvt z18.h, p3/M, z18.s\n"
+ ".inst 0x658aae31 // bfcvt z17.h, p3/M, z17.s\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ "fmax z27.s, p3/M, z27.s, z11.s\n"
+ ".inst 0x658aae10 // bfcvt z16.h, p3/M, z16.s\n"
+ ".inst 0x658aaf18 // bfcvt z24.h, p3/M, z24.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ ".inst 0x658aaef7 // bfcvt z23.h, p3/M, z23.s\n"
+ ".inst 0x658aaed6 // bfcvt z22.h, p3/M, z22.s\n"
+ "cmp x11, XZR\n"
+ "st1h { z18.s }, p2, [x10]\n"
+ ".inst 0x658aaeb5 // bfcvt z21.h, p3/M, z21.s\n"
+ ".inst 0x658aae94 // bfcvt z20.h, p3/M, z20.s\n"
+ "st1h { z17.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aae73 // bfcvt z19.h, p3/M, z19.s\n"
+ ".inst 0x658aaf72 // bfcvt z18.h, p3/M, z27.s\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "st1h { z16.s }, p0, [x10, #2, MUL VL]\n"
+ ".inst 0x658aaf51 // bfcvt z17.h, p3/M, z26.s\n"
+ ".inst 0x658aaf30 // bfcvt z16.h, p3/M, z25.s\n"
+ "inch x10, ALL, MUL #3\n"
+ "st1h { z24.s }, p2, [x28]\n"
+ "st1h { z23.s }, p1, [x28, #1, MUL VL]\n"
+ "st1h { z22.s }, p0, [x28, #2, MUL VL]\n"
+ "inch x28, ALL, MUL #3\n"
+ "st1h { z21.s }, p2, [x27]\n"
+ "st1h { z20.s }, p1, [x27, #1, MUL VL]\n"
+ "st1h { z19.s }, p0, [x27, #2, MUL VL]\n"
+ "inch x27, ALL, MUL #3\n"
+ "st1h { z18.s }, p2, [x26]\n"
+ "st1h { z17.s }, p1, [x26, #1, MUL VL]\n"
+ "st1h { z16.s }, p0, [x26, #2, MUL VL]\n"
+ "inch x26, ALL, MUL #3\n"
+ "bgt 15b\n"
+ "b 52f\n"
+ "18:" // Initial: Height 5
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "mov x9, %x[bias]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "19:" // Initial: Height 5: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p0.s, x21, x11\n"
+ "incw x21\n"
+ "cbnz %x[bias], 20f\n"
+ "mov z1.b, #0x0\n"
+ "mov z0.b, #0x0\n"
+ "mov z31.b, #0x0\n"
+ "b 21f\n"
+ "20:" // Initial: Height 5: Width 3: bias
+ "ld1h { z18.s }, p2/Z, [x9]\n"
+ "ld1h { z17.s }, p1/Z, [x9, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x9, #2, MUL VL]\n"
+ "lsl z1.s, z18.s, #0x10\n"
+ "lsl z0.s, z17.s, #0x10\n"
+ "lsl z31.s, z16.s, #0x10\n"
+ "21:" // Initial: Height 5: Width 3: init done
+ "ld1w { z21.s }, p2/Z, [%x[in_ptr]]\n"
+ "ld1w { z20.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "decw x11, ALL, MUL #3\n"
+ "inch x9, ALL, MUL #3\n"
+ "ld1w { z19.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z18.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "ld1w { z17.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "ld1w { z16.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "ld1w { z24.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "ld1w { z23.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "fadd z21.s, z21.s, z1.s\n"
+ "fadd z20.s, z20.s, z0.s\n"
+ "ld1w { z22.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "ld1w { z30.s }, p2/Z, [x20, #-7, MUL VL]\n"
+ "fadd z19.s, z19.s, z31.s\n"
+ "fadd z18.s, z18.s, z1.s\n"
+ "ld1w { z29.s }, p1/Z, [x20, #-6, MUL VL]\n"
+ "ld1w { z28.s }, p0/Z, [x20, #-5, MUL VL]\n"
+ "fadd z17.s, z17.s, z0.s\n"
+ "fadd z16.s, z16.s, z31.s\n"
+ "ld1w { z27.s }, p2/Z, [x20, #-4, MUL VL]\n"
+ "ld1w { z26.s }, p1/Z, [x20, #-3, MUL VL]\n"
+ "fadd z24.s, z24.s, z1.s\n"
+ "fadd z23.s, z23.s, z0.s\n"
+ "ld1w { z25.s }, p0/Z, [x20, #-2, MUL VL]\n"
+ "fadd z22.s, z22.s, z31.s\n"
+ "fadd z30.s, z30.s, z1.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fadd z29.s, z29.s, z0.s\n"
+ "fadd z28.s, z28.s, z31.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "fadd z27.s, z27.s, z1.s\n"
+ "fadd z26.s, z26.s, z0.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmin z17.s, p3/M, z17.s, z12.s\n"
+ "fadd z25.s, z25.s, z31.s\n"
+ "fmin z16.s, p3/M, z16.s, z12.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "fmin z30.s, p3/M, z30.s, z12.s\n"
+ "fmin z29.s, p3/M, z29.s, z12.s\n"
+ "fmin z28.s, p3/M, z28.s, z12.s\n"
+ "fmin z27.s, p3/M, z27.s, z12.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ "fmax z17.s, p3/M, z17.s, z11.s\n"
+ "fmax z16.s, p3/M, z16.s, z11.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ ".inst 0x658aaeb5 // bfcvt z21.h, p3/M, z21.s\n"
+ ".inst 0x658aae94 // bfcvt z20.h, p3/M, z20.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ "fmax z30.s, p3/M, z30.s, z11.s\n"
+ ".inst 0x658aae73 // bfcvt z19.h, p3/M, z19.s\n"
+ ".inst 0x658aae52 // bfcvt z18.h, p3/M, z18.s\n"
+ "fmax z29.s, p3/M, z29.s, z11.s\n"
+ "fmax z28.s, p3/M, z28.s, z11.s\n"
+ ".inst 0x658aae31 // bfcvt z17.h, p3/M, z17.s\n"
+ ".inst 0x658aae10 // bfcvt z16.h, p3/M, z16.s\n"
+ "fmax z27.s, p3/M, z27.s, z11.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ "st1h { z21.s }, p2, [x10]\n"
+ ".inst 0x658aaf18 // bfcvt z24.h, p3/M, z24.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ "cmp x11, XZR\n"
+ "st1h { z20.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aaef7 // bfcvt z23.h, p3/M, z23.s\n"
+ "st1h { z19.s }, p0, [x10, #2, MUL VL]\n"
+ ".inst 0x658aaed6 // bfcvt z22.h, p3/M, z22.s\n"
+ ".inst 0x658aafd5 // bfcvt z21.h, p3/M, z30.s\n"
+ "inch x10, ALL, MUL #3\n"
+ "st1h { z18.s }, p2, [x28]\n"
+ ".inst 0x658aafb4 // bfcvt z20.h, p3/M, z29.s\n"
+ ".inst 0x658aaf93 // bfcvt z19.h, p3/M, z28.s\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "st1h { z17.s }, p1, [x28, #1, MUL VL]\n"
+ ".inst 0x658aaf72 // bfcvt z18.h, p3/M, z27.s\n"
+ ".inst 0x658aaf51 // bfcvt z17.h, p3/M, z26.s\n"
+ "st1h { z16.s }, p0, [x28, #2, MUL VL]\n"
+ ".inst 0x658aaf30 // bfcvt z16.h, p3/M, z25.s\n"
+ "inch x28, ALL, MUL #3\n"
+ "st1h { z24.s }, p2, [x27]\n"
+ "st1h { z23.s }, p1, [x27, #1, MUL VL]\n"
+ "st1h { z22.s }, p0, [x27, #2, MUL VL]\n"
+ "inch x27, ALL, MUL #3\n"
+ "st1h { z21.s }, p2, [x26]\n"
+ "st1h { z20.s }, p1, [x26, #1, MUL VL]\n"
+ "st1h { z19.s }, p0, [x26, #2, MUL VL]\n"
+ "inch x26, ALL, MUL #3\n"
+ "st1h { z18.s }, p2, [x25]\n"
+ "st1h { z17.s }, p1, [x25, #1, MUL VL]\n"
+ "st1h { z16.s }, p0, [x25, #2, MUL VL]\n"
+ "inch x25, ALL, MUL #3\n"
+ "bgt 19b\n"
+ "b 52f\n"
+ "22:" // Initial: Height 6
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "mov x9, %x[bias]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "23:" // Initial: Height 6: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p0.s, x21, x11\n"
+ "incw x21\n"
+ "cbnz %x[bias], 24f\n"
+ "mov z4.b, #0x0\n"
+ "mov z3.b, #0x0\n"
+ "mov z2.b, #0x0\n"
+ "b 25f\n"
+ "24:" // Initial: Height 6: Width 3: bias
+ "ld1h { z18.s }, p2/Z, [x9]\n"
+ "ld1h { z17.s }, p1/Z, [x9, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x9, #2, MUL VL]\n"
+ "lsl z4.s, z18.s, #0x10\n"
+ "lsl z3.s, z17.s, #0x10\n"
+ "lsl z2.s, z16.s, #0x10\n"
+ "25:" // Initial: Height 6: Width 3: init done
+ "ld1w { z17.s }, p2/Z, [%x[in_ptr]]\n"
+ "ld1w { z16.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "decw x11, ALL, MUL #3\n"
+ "inch x9, ALL, MUL #3\n"
+ "ld1w { z21.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z20.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "ld1w { z19.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "ld1w { z18.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "ld1w { z1.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "ld1w { z0.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "fadd z17.s, z17.s, z4.s\n"
+ "fadd z16.s, z16.s, z3.s\n"
+ "ld1w { z25.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "ld1w { z24.s }, p2/Z, [x20, #-7, MUL VL]\n"
+ "fadd z21.s, z21.s, z2.s\n"
+ "fadd z20.s, z20.s, z4.s\n"
+ "ld1w { z23.s }, p1/Z, [x20, #-6, MUL VL]\n"
+ "ld1w { z22.s }, p0/Z, [x20, #-5, MUL VL]\n"
+ "fadd z19.s, z19.s, z3.s\n"
+ "fadd z18.s, z18.s, z2.s\n"
+ "ld1w { z31.s }, p2/Z, [x20, #-4, MUL VL]\n"
+ "ld1w { z30.s }, p1/Z, [x20, #-3, MUL VL]\n"
+ "fadd z1.s, z1.s, z4.s\n"
+ "fadd z0.s, z0.s, z3.s\n"
+ "ld1w { z29.s }, p0/Z, [x20, #-2, MUL VL]\n"
+ "ld1w { z28.s }, p2/Z, [x20, #-1, MUL VL]\n"
+ "fadd z25.s, z25.s, z2.s\n"
+ "fadd z24.s, z24.s, z4.s\n"
+ "ld1w { z27.s }, p1/Z, [x20]\n"
+ "ld1w { z26.s }, p0/Z, [x20, #1, MUL VL]\n"
+ "fadd z23.s, z23.s, z3.s\n"
+ "fadd z22.s, z22.s, z2.s\n"
+ "fadd z31.s, z31.s, z4.s\n"
+ "fadd z30.s, z30.s, z3.s\n"
+ "fmin z17.s, p3/M, z17.s, z12.s\n"
+ "fmin z16.s, p3/M, z16.s, z12.s\n"
+ "fadd z29.s, z29.s, z2.s\n"
+ "fadd z28.s, z28.s, z4.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fadd z27.s, z27.s, z3.s\n"
+ "fadd z26.s, z26.s, z2.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmin z1.s, p3/M, z1.s, z12.s\n"
+ "fmin z0.s, p3/M, z0.s, z12.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "fmin z31.s, p3/M, z31.s, z12.s\n"
+ "fmin z30.s, p3/M, z30.s, z12.s\n"
+ "fmin z29.s, p3/M, z29.s, z12.s\n"
+ "fmin z28.s, p3/M, z28.s, z12.s\n"
+ "fmin z27.s, p3/M, z27.s, z12.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "fmax z17.s, p3/M, z17.s, z11.s\n"
+ "fmax z16.s, p3/M, z16.s, z11.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ "fmax z1.s, p3/M, z1.s, z11.s\n"
+ "fmax z0.s, p3/M, z0.s, z11.s\n"
+ ".inst 0x658aae31 // bfcvt z17.h, p3/M, z17.s\n"
+ ".inst 0x658aae10 // bfcvt z16.h, p3/M, z16.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ ".inst 0x658aaeb5 // bfcvt z21.h, p3/M, z21.s\n"
+ ".inst 0x658aae94 // bfcvt z20.h, p3/M, z20.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ ".inst 0x658aae73 // bfcvt z19.h, p3/M, z19.s\n"
+ ".inst 0x658aae52 // bfcvt z18.h, p3/M, z18.s\n"
+ "fmax z31.s, p3/M, z31.s, z11.s\n"
+ "fmax z30.s, p3/M, z30.s, z11.s\n"
+ "st1h { z17.s }, p2, [x10]\n"
+ ".inst 0x658aac31 // bfcvt z17.h, p3/M, z1.s\n"
+ "fmax z29.s, p3/M, z29.s, z11.s\n"
+ "fmax z28.s, p3/M, z28.s, z11.s\n"
+ "st1h { z16.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aac10 // bfcvt z16.h, p3/M, z0.s\n"
+ "fmax z27.s, p3/M, z27.s, z11.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ "st1h { z21.s }, p0, [x10, #2, MUL VL]\n"
+ ".inst 0x658aaf39 // bfcvt z25.h, p3/M, z25.s\n"
+ "cmp x11, XZR\n"
+ "st1h { z20.s }, p2, [x28]\n"
+ ".inst 0x658aaf18 // bfcvt z24.h, p3/M, z24.s\n"
+ ".inst 0x658aaef7 // bfcvt z23.h, p3/M, z23.s\n"
+ "st1h { z19.s }, p1, [x28, #1, MUL VL]\n"
+ ".inst 0x658aaed6 // bfcvt z22.h, p3/M, z22.s\n"
+ ".inst 0x658aaff5 // bfcvt z21.h, p3/M, z31.s\n"
+ "inch x10, ALL, MUL #3\n"
+ "st1h { z18.s }, p0, [x28, #2, MUL VL]\n"
+ ".inst 0x658aafd4 // bfcvt z20.h, p3/M, z30.s\n"
+ ".inst 0x658aafb3 // bfcvt z19.h, p3/M, z29.s\n"
+ "inch x28, ALL, MUL #3\n"
+ "st1h { z17.s }, p2, [x27]\n"
+ ".inst 0x658aaf92 // bfcvt z18.h, p3/M, z28.s\n"
+ ".inst 0x658aaf71 // bfcvt z17.h, p3/M, z27.s\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "st1h { z16.s }, p1, [x27, #1, MUL VL]\n"
+ ".inst 0x658aaf50 // bfcvt z16.h, p3/M, z26.s\n"
+ "st1h { z25.s }, p0, [x27, #2, MUL VL]\n"
+ "inch x27, ALL, MUL #3\n"
+ "st1h { z24.s }, p2, [x26]\n"
+ "st1h { z23.s }, p1, [x26, #1, MUL VL]\n"
+ "st1h { z22.s }, p0, [x26, #2, MUL VL]\n"
+ "inch x26, ALL, MUL #3\n"
+ "st1h { z21.s }, p2, [x25]\n"
+ "st1h { z20.s }, p1, [x25, #1, MUL VL]\n"
+ "st1h { z19.s }, p0, [x25, #2, MUL VL]\n"
+ "inch x25, ALL, MUL #3\n"
+ "st1h { z18.s }, p2, [x24]\n"
+ "st1h { z17.s }, p1, [x24, #1, MUL VL]\n"
+ "st1h { z16.s }, p0, [x24, #2, MUL VL]\n"
+ "inch x24, ALL, MUL #3\n"
+ "bgt 23b\n"
+ "b 52f\n"
+ "26:" // Initial: Height 7
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "mov x9, %x[bias]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "add x23, x24, %x[ldout], LSL #1\n"
+ "27:" // Initial: Height 7: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p0.s, x21, x11\n"
+ "incw x21\n"
+ "cbnz %x[bias], 28f\n"
+ "mov z7.b, #0x0\n"
+ "mov z6.b, #0x0\n"
+ "mov z5.b, #0x0\n"
+ "b 29f\n"
+ "28:" // Initial: Height 7: Width 3: bias
+ "ld1h { z18.s }, p2/Z, [x9]\n"
+ "ld1h { z17.s }, p1/Z, [x9, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x9, #2, MUL VL]\n"
+ "lsl z7.s, z18.s, #0x10\n"
+ "lsl z6.s, z17.s, #0x10\n"
+ "lsl z5.s, z16.s, #0x10\n"
+ "29:" // Initial: Height 7: Width 3: init done
+ "ld1w { z19.s }, p2/Z, [%x[in_ptr]]\n"
+ "ld1w { z18.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "decw x11, ALL, MUL #3\n"
+ "inch x9, ALL, MUL #3\n"
+ "ld1w { z17.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z16.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "ld1w { z21.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "ld1w { z20.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "ld1w { z4.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "ld1w { z3.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "fadd z19.s, z19.s, z7.s\n"
+ "fadd z18.s, z18.s, z6.s\n"
+ "ld1w { z2.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "ld1w { z1.s }, p2/Z, [x20, #-7, MUL VL]\n"
+ "fadd z17.s, z17.s, z5.s\n"
+ "fadd z16.s, z16.s, z7.s\n"
+ "ld1w { z26.s }, p1/Z, [x20, #-6, MUL VL]\n"
+ "ld1w { z25.s }, p0/Z, [x20, #-5, MUL VL]\n"
+ "fadd z21.s, z21.s, z6.s\n"
+ "fadd z20.s, z20.s, z5.s\n"
+ "ld1w { z24.s }, p2/Z, [x20, #-4, MUL VL]\n"
+ "ld1w { z23.s }, p1/Z, [x20, #-3, MUL VL]\n"
+ "fadd z4.s, z4.s, z7.s\n"
+ "fadd z3.s, z3.s, z6.s\n"
+ "ld1w { z22.s }, p0/Z, [x20, #-2, MUL VL]\n"
+ "ld1w { z0.s }, p2/Z, [x20, #-1, MUL VL]\n"
+ "fadd z2.s, z2.s, z5.s\n"
+ "fadd z1.s, z1.s, z7.s\n"
+ "ld1w { z31.s }, p1/Z, [x20]\n"
+ "ld1w { z30.s }, p0/Z, [x20, #1, MUL VL]\n"
+ "fadd z26.s, z26.s, z6.s\n"
+ "fadd z25.s, z25.s, z5.s\n"
+ "ld1w { z29.s }, p2/Z, [x20, #2, MUL VL]\n"
+ "ld1w { z28.s }, p1/Z, [x20, #3, MUL VL]\n"
+ "fadd z24.s, z24.s, z7.s\n"
+ "fadd z23.s, z23.s, z6.s\n"
+ "ld1w { z27.s }, p0/Z, [x20, #4, MUL VL]\n"
+ "fadd z22.s, z22.s, z5.s\n"
+ "fadd z0.s, z0.s, z7.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "fadd z31.s, z31.s, z6.s\n"
+ "fadd z30.s, z30.s, z5.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmin z17.s, p3/M, z17.s, z12.s\n"
+ "fadd z29.s, z29.s, z7.s\n"
+ "fadd z28.s, z28.s, z6.s\n"
+ "fmin z16.s, p3/M, z16.s, z12.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fadd z27.s, z27.s, z5.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fmin z4.s, p3/M, z4.s, z12.s\n"
+ "fmin z3.s, p3/M, z3.s, z12.s\n"
+ "fmin z2.s, p3/M, z2.s, z12.s\n"
+ "fmin z1.s, p3/M, z1.s, z12.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "fmin z0.s, p3/M, z0.s, z12.s\n"
+ "fmin z31.s, p3/M, z31.s, z12.s\n"
+ "fmin z30.s, p3/M, z30.s, z12.s\n"
+ "fmin z29.s, p3/M, z29.s, z12.s\n"
+ "fmin z28.s, p3/M, z28.s, z12.s\n"
+ "fmin z27.s, p3/M, z27.s, z12.s\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ "fmax z17.s, p3/M, z17.s, z11.s\n"
+ "fmax z16.s, p3/M, z16.s, z11.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ "fmax z4.s, p3/M, z4.s, z11.s\n"
+ "fmax z3.s, p3/M, z3.s, z11.s\n"
+ ".inst 0x658aae73 // bfcvt z19.h, p3/M, z19.s\n"
+ ".inst 0x658aae52 // bfcvt z18.h, p3/M, z18.s\n"
+ "fmax z2.s, p3/M, z2.s, z11.s\n"
+ "fmax z1.s, p3/M, z1.s, z11.s\n"
+ ".inst 0x658aae31 // bfcvt z17.h, p3/M, z17.s\n"
+ ".inst 0x658aae10 // bfcvt z16.h, p3/M, z16.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ ".inst 0x658aaeb5 // bfcvt z21.h, p3/M, z21.s\n"
+ ".inst 0x658aae94 // bfcvt z20.h, p3/M, z20.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "st1h { z19.s }, p2, [x10]\n"
+ ".inst 0x658aac93 // bfcvt z19.h, p3/M, z4.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ "fmax z0.s, p3/M, z0.s, z11.s\n"
+ "st1h { z18.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aac72 // bfcvt z18.h, p3/M, z3.s\n"
+ "fmax z31.s, p3/M, z31.s, z11.s\n"
+ "fmax z30.s, p3/M, z30.s, z11.s\n"
+ "st1h { z17.s }, p0, [x10, #2, MUL VL]\n"
+ ".inst 0x658aac51 // bfcvt z17.h, p3/M, z2.s\n"
+ "fmax z29.s, p3/M, z29.s, z11.s\n"
+ "fmax z28.s, p3/M, z28.s, z11.s\n"
+ "st1h { z16.s }, p2, [x28]\n"
+ ".inst 0x658aac30 // bfcvt z16.h, p3/M, z1.s\n"
+ "fmax z27.s, p3/M, z27.s, z11.s\n"
+ "cmp x11, XZR\n"
+ "st1h { z21.s }, p1, [x28, #1, MUL VL]\n"
+ ".inst 0x658aaf5a // bfcvt z26.h, p3/M, z26.s\n"
+ "st1h { z20.s }, p0, [x28, #2, MUL VL]\n"
+ ".inst 0x658aaf39 // bfcvt z25.h, p3/M, z25.s\n"
+ ".inst 0x658aaf18 // bfcvt z24.h, p3/M, z24.s\n"
+ "inch x10, ALL, MUL #3\n"
+ "st1h { z19.s }, p2, [x27]\n"
+ ".inst 0x658aaef7 // bfcvt z23.h, p3/M, z23.s\n"
+ ".inst 0x658aaed6 // bfcvt z22.h, p3/M, z22.s\n"
+ "inch x28, ALL, MUL #3\n"
+ "st1h { z18.s }, p1, [x27, #1, MUL VL]\n"
+ ".inst 0x658aac15 // bfcvt z21.h, p3/M, z0.s\n"
+ ".inst 0x658aaff4 // bfcvt z20.h, p3/M, z31.s\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "st1h { z17.s }, p0, [x27, #2, MUL VL]\n"
+ ".inst 0x658aafd3 // bfcvt z19.h, p3/M, z30.s\n"
+ ".inst 0x658aafb2 // bfcvt z18.h, p3/M, z29.s\n"
+ "inch x27, ALL, MUL #3\n"
+ "st1h { z16.s }, p2, [x26]\n"
+ ".inst 0x658aaf91 // bfcvt z17.h, p3/M, z28.s\n"
+ ".inst 0x658aaf70 // bfcvt z16.h, p3/M, z27.s\n"
+ "st1h { z26.s }, p1, [x26, #1, MUL VL]\n"
+ "st1h { z25.s }, p0, [x26, #2, MUL VL]\n"
+ "inch x26, ALL, MUL #3\n"
+ "st1h { z24.s }, p2, [x25]\n"
+ "st1h { z23.s }, p1, [x25, #1, MUL VL]\n"
+ "st1h { z22.s }, p0, [x25, #2, MUL VL]\n"
+ "inch x25, ALL, MUL #3\n"
+ "st1h { z21.s }, p2, [x24]\n"
+ "st1h { z20.s }, p1, [x24, #1, MUL VL]\n"
+ "st1h { z19.s }, p0, [x24, #2, MUL VL]\n"
+ "inch x24, ALL, MUL #3\n"
+ "st1h { z18.s }, p2, [x23]\n"
+ "st1h { z17.s }, p1, [x23, #1, MUL VL]\n"
+ "st1h { z16.s }, p0, [x23, #2, MUL VL]\n"
+ "inch x23, ALL, MUL #3\n"
+ "bgt 27b\n"
+ "b 52f\n"
+ "30:" // Initial: Height 8
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "mov x9, %x[bias]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "add x23, x24, %x[ldout], LSL #1\n"
+ "add x22, x23, %x[ldout], LSL #1\n"
+ "31:" // Initial: Height 8: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "whilelt p0.s, x21, x11\n"
+ "incw x21\n"
+ "cbnz %x[bias], 32f\n"
+ "mov z10.b, #0x0\n"
+ "mov z9.b, #0x0\n"
+ "mov z8.b, #0x0\n"
+ "b 33f\n"
+ "32:" // Initial: Height 8: Width 3: bias
+ "ld1h { z18.s }, p2/Z, [x9]\n"
+ "ld1h { z17.s }, p1/Z, [x9, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x9, #2, MUL VL]\n"
+ "lsl z10.s, z18.s, #0x10\n"
+ "lsl z9.s, z17.s, #0x10\n"
+ "lsl z8.s, z16.s, #0x10\n"
+ "33:" // Initial: Height 8: Width 3: init done
+ "ld1w { z21.s }, p2/Z, [%x[in_ptr]]\n"
+ "ld1w { z20.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "decw x11, ALL, MUL #3\n"
+ "inch x9, ALL, MUL #3\n"
+ "ld1w { z19.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z18.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "ld1w { z17.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "ld1w { z16.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "ld1w { z7.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "ld1w { z6.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "fadd z21.s, z21.s, z10.s\n"
+ "fadd z20.s, z20.s, z9.s\n"
+ "ld1w { z5.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "ld1w { z4.s }, p2/Z, [x20, #-7, MUL VL]\n"
+ "fadd z19.s, z19.s, z8.s\n"
+ "fadd z18.s, z18.s, z10.s\n"
+ "ld1w { z3.s }, p1/Z, [x20, #-6, MUL VL]\n"
+ "ld1w { z2.s }, p0/Z, [x20, #-5, MUL VL]\n"
+ "fadd z17.s, z17.s, z9.s\n"
+ "fadd z16.s, z16.s, z8.s\n"
+ "ld1w { z27.s }, p2/Z, [x20, #-4, MUL VL]\n"
+ "ld1w { z26.s }, p1/Z, [x20, #-3, MUL VL]\n"
+ "fadd z7.s, z7.s, z10.s\n"
+ "fadd z6.s, z6.s, z9.s\n"
+ "ld1w { z25.s }, p0/Z, [x20, #-2, MUL VL]\n"
+ "ld1w { z24.s }, p2/Z, [x20, #-1, MUL VL]\n"
+ "fadd z5.s, z5.s, z8.s\n"
+ "fadd z4.s, z4.s, z10.s\n"
+ "ld1w { z23.s }, p1/Z, [x20]\n"
+ "ld1w { z22.s }, p0/Z, [x20, #1, MUL VL]\n"
+ "fadd z3.s, z3.s, z9.s\n"
+ "fadd z2.s, z2.s, z8.s\n"
+ "ld1w { z1.s }, p2/Z, [x20, #2, MUL VL]\n"
+ "ld1w { z0.s }, p1/Z, [x20, #3, MUL VL]\n"
+ "fadd z27.s, z27.s, z10.s\n"
+ "fadd z26.s, z26.s, z9.s\n"
+ "ld1w { z31.s }, p0/Z, [x20, #4, MUL VL]\n"
+ "ld1w { z30.s }, p2/Z, [x20, #5, MUL VL]\n"
+ "fadd z25.s, z25.s, z8.s\n"
+ "fadd z24.s, z24.s, z10.s\n"
+ "ld1w { z29.s }, p1/Z, [x20, #6, MUL VL]\n"
+ "ld1w { z28.s }, p0/Z, [x20, #7, MUL VL]\n"
+ "fadd z23.s, z23.s, z9.s\n"
+ "fadd z22.s, z22.s, z8.s\n"
+ "fadd z1.s, z1.s, z10.s\n"
+ "fadd z0.s, z0.s, z9.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fadd z31.s, z31.s, z8.s\n"
+ "fadd z30.s, z30.s, z10.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fadd z29.s, z29.s, z9.s\n"
+ "fadd z28.s, z28.s, z8.s\n"
+ "fmin z17.s, p3/M, z17.s, z12.s\n"
+ "fmin z16.s, p3/M, z16.s, z12.s\n"
+ "fmin z7.s, p3/M, z7.s, z12.s\n"
+ "fmin z6.s, p3/M, z6.s, z12.s\n"
+ "fmin z5.s, p3/M, z5.s, z12.s\n"
+ "fmin z4.s, p3/M, z4.s, z12.s\n"
+ "fmin z3.s, p3/M, z3.s, z12.s\n"
+ "fmin z2.s, p3/M, z2.s, z12.s\n"
+ "fmin z27.s, p3/M, z27.s, z12.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "fmin z1.s, p3/M, z1.s, z12.s\n"
+ "fmin z0.s, p3/M, z0.s, z12.s\n"
+ "fmin z31.s, p3/M, z31.s, z12.s\n"
+ "fmin z30.s, p3/M, z30.s, z12.s\n"
+ "fmin z29.s, p3/M, z29.s, z12.s\n"
+ "fmin z28.s, p3/M, z28.s, z12.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ "fmax z17.s, p3/M, z17.s, z11.s\n"
+ "fmax z16.s, p3/M, z16.s, z11.s\n"
+ "fmax z7.s, p3/M, z7.s, z11.s\n"
+ "fmax z6.s, p3/M, z6.s, z11.s\n"
+ ".inst 0x658aaeb5 // bfcvt z21.h, p3/M, z21.s\n"
+ ".inst 0x658aae94 // bfcvt z20.h, p3/M, z20.s\n"
+ "fmax z5.s, p3/M, z5.s, z11.s\n"
+ "fmax z4.s, p3/M, z4.s, z11.s\n"
+ ".inst 0x658aae73 // bfcvt z19.h, p3/M, z19.s\n"
+ ".inst 0x658aae52 // bfcvt z18.h, p3/M, z18.s\n"
+ "fmax z3.s, p3/M, z3.s, z11.s\n"
+ "fmax z2.s, p3/M, z2.s, z11.s\n"
+ ".inst 0x658aae31 // bfcvt z17.h, p3/M, z17.s\n"
+ ".inst 0x658aae10 // bfcvt z16.h, p3/M, z16.s\n"
+ "fmax z27.s, p3/M, z27.s, z11.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ "st1h { z21.s }, p2, [x10]\n"
+ ".inst 0x658aacf5 // bfcvt z21.h, p3/M, z7.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ "st1h { z20.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aacd4 // bfcvt z20.h, p3/M, z6.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ "st1h { z19.s }, p0, [x10, #2, MUL VL]\n"
+ ".inst 0x658aacb3 // bfcvt z19.h, p3/M, z5.s\n"
+ "fmax z1.s, p3/M, z1.s, z11.s\n"
+ "fmax z0.s, p3/M, z0.s, z11.s\n"
+ "st1h { z18.s }, p2, [x28]\n"
+ ".inst 0x658aac92 // bfcvt z18.h, p3/M, z4.s\n"
+ "fmax z31.s, p3/M, z31.s, z11.s\n"
+ "fmax z30.s, p3/M, z30.s, z11.s\n"
+ "st1h { z17.s }, p1, [x28, #1, MUL VL]\n"
+ ".inst 0x658aac71 // bfcvt z17.h, p3/M, z3.s\n"
+ "fmax z29.s, p3/M, z29.s, z11.s\n"
+ "fmax z28.s, p3/M, z28.s, z11.s\n"
+ "st1h { z16.s }, p0, [x28, #2, MUL VL]\n"
+ ".inst 0x658aac50 // bfcvt z16.h, p3/M, z2.s\n"
+ "cmp x11, XZR\n"
+ "st1h { z21.s }, p2, [x27]\n"
+ ".inst 0x658aaf7b // bfcvt z27.h, p3/M, z27.s\n"
+ ".inst 0x658aaf5a // bfcvt z26.h, p3/M, z26.s\n"
+ "st1h { z20.s }, p1, [x27, #1, MUL VL]\n"
+ ".inst 0x658aaf39 // bfcvt z25.h, p3/M, z25.s\n"
+ ".inst 0x658aaf18 // bfcvt z24.h, p3/M, z24.s\n"
+ "inch x10, ALL, MUL #3\n"
+ "st1h { z19.s }, p0, [x27, #2, MUL VL]\n"
+ ".inst 0x658aaef7 // bfcvt z23.h, p3/M, z23.s\n"
+ ".inst 0x658aaed6 // bfcvt z22.h, p3/M, z22.s\n"
+ "inch x28, ALL, MUL #3\n"
+ "st1h { z18.s }, p2, [x26]\n"
+ ".inst 0x658aac35 // bfcvt z21.h, p3/M, z1.s\n"
+ ".inst 0x658aac14 // bfcvt z20.h, p3/M, z0.s\n"
+ "inch x27, ALL, MUL #3\n"
+ "st1h { z17.s }, p1, [x26, #1, MUL VL]\n"
+ ".inst 0x658aaff3 // bfcvt z19.h, p3/M, z31.s\n"
+ ".inst 0x658aafd2 // bfcvt z18.h, p3/M, z30.s\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "st1h { z16.s }, p0, [x26, #2, MUL VL]\n"
+ ".inst 0x658aafb1 // bfcvt z17.h, p3/M, z29.s\n"
+ ".inst 0x658aaf90 // bfcvt z16.h, p3/M, z28.s\n"
+ "inch x26, ALL, MUL #3\n"
+ "st1h { z27.s }, p2, [x25]\n"
+ "st1h { z26.s }, p1, [x25, #1, MUL VL]\n"
+ "st1h { z25.s }, p0, [x25, #2, MUL VL]\n"
+ "inch x25, ALL, MUL #3\n"
+ "st1h { z24.s }, p2, [x24]\n"
+ "st1h { z23.s }, p1, [x24, #1, MUL VL]\n"
+ "st1h { z22.s }, p0, [x24, #2, MUL VL]\n"
+ "inch x24, ALL, MUL #3\n"
+ "st1h { z21.s }, p2, [x23]\n"
+ "st1h { z20.s }, p1, [x23, #1, MUL VL]\n"
+ "st1h { z19.s }, p0, [x23, #2, MUL VL]\n"
+ "inch x23, ALL, MUL #3\n"
+ "st1h { z18.s }, p2, [x22]\n"
+ "st1h { z17.s }, p1, [x22, #1, MUL VL]\n"
+ "st1h { z16.s }, p0, [x22, #2, MUL VL]\n"
+ "inch x22, ALL, MUL #3\n"
+ "bgt 31b\n"
+ "subs %x[rows], %x[rows], #0x8\n"
+ "add %x[out_ptr], %x[out_ptr], x12\n"
+ "bgt 1b\n"
+ "b 52f\n"
+ "34:" // Accumulate
+ "35:" // Accumulate: Row loop
+ "cmp %x[rows], #0x7\n"
+ "bgt 50f\n"
+ "beq 48f\n"
+ "cmp %x[rows], #0x5\n"
+ "bgt 46f\n"
+ "beq 44f\n"
+ "cmp %x[rows], #0x3\n"
+ "bgt 42f\n"
+ "beq 40f\n"
+ "cmp %x[rows], #0x1\n"
+ "bgt 38f\n"
+ "36:" // Accumulate: Height 1
+ "mov x11, %x[cols]\n"
+ "mov x10, %x[out_ptr]\n"
+ "37:" // Accumulate: Height 1: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "ld1h { z16.s }, p2/Z, [x10]\n"
+ "ld1w { z19.s }, p2/Z, [%x[in_ptr]]\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "fadd z19.s, z19.s, z16.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "ld1w { z18.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "whilelt p0.s, x21, x11\n"
+ "decw x11, ALL, MUL #3\n"
+ "incw x21\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ "ld1w { z17.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "cmp x11, XZR\n"
+ ".inst 0x658aae70 // bfcvt z16.h, p3/M, z19.s\n"
+ "st1h { z16.s }, p2, [x10]\n"
+ "ld1h { z16.s }, p1/Z, [x10, #1, MUL VL]\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z18.s, z18.s, z16.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ ".inst 0x658aae50 // bfcvt z16.h, p3/M, z18.s\n"
+ "st1h { z16.s }, p1, [x10, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x10, #2, MUL VL]\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z17.s, z17.s, z16.s\n"
+ "fmin z17.s, p3/M, z17.s, z12.s\n"
+ "fmax z17.s, p3/M, z17.s, z11.s\n"
+ ".inst 0x658aae30 // bfcvt z16.h, p3/M, z17.s\n"
+ "st1h { z16.s }, p0, [x10, #2, MUL VL]\n"
+ "inch x10, ALL, MUL #3\n"
+ "bgt 37b\n"
+ "b 52f\n"
+ "38:" // Accumulate: Height 2
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "39:" // Accumulate: Height 2: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "ld1h { z17.s }, p2/Z, [x10]\n"
+ "ld1h { z16.s }, p2/Z, [x28]\n"
+ "ld1w { z23.s }, p2/Z, [%x[in_ptr]]\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "ld1w { z22.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "fadd z23.s, z23.s, z17.s\n"
+ "fadd z22.s, z22.s, z16.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "ld1w { z21.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "ld1w { z20.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "whilelt p0.s, x21, x11\n"
+ "decw x11, ALL, MUL #3\n"
+ "incw x21\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ "ld1w { z19.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z18.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "cmp x11, XZR\n"
+ ".inst 0x658aaef1 // bfcvt z17.h, p3/M, z23.s\n"
+ ".inst 0x658aaed0 // bfcvt z16.h, p3/M, z22.s\n"
+ "st1h { z17.s }, p2, [x10]\n"
+ "st1h { z16.s }, p2, [x28]\n"
+ "ld1h { z17.s }, p1/Z, [x10, #1, MUL VL]\n"
+ "ld1h { z16.s }, p1/Z, [x28, #1, MUL VL]\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z21.s, z21.s, z17.s\n"
+ "fadd z20.s, z20.s, z16.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ ".inst 0x658aaeb0 // bfcvt z16.h, p3/M, z21.s\n"
+ "st1h { z16.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aae90 // bfcvt z16.h, p3/M, z20.s\n"
+ "ld1h { z17.s }, p0/Z, [x10, #2, MUL VL]\n"
+ "st1h { z16.s }, p1, [x28, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x28, #2, MUL VL]\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z19.s, z19.s, z17.s\n"
+ "fadd z18.s, z18.s, z16.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ ".inst 0x658aae70 // bfcvt z16.h, p3/M, z19.s\n"
+ "st1h { z16.s }, p0, [x10, #2, MUL VL]\n"
+ "inch x10, ALL, MUL #3\n"
+ ".inst 0x658aae50 // bfcvt z16.h, p3/M, z18.s\n"
+ "st1h { z16.s }, p0, [x28, #2, MUL VL]\n"
+ "inch x28, ALL, MUL #3\n"
+ "bgt 39b\n"
+ "b 52f\n"
+ "40:" // Accumulate: Height 3
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "41:" // Accumulate: Height 3: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "ld1h { z18.s }, p2/Z, [x10]\n"
+ "ld1h { z17.s }, p2/Z, [x28]\n"
+ "ld1h { z16.s }, p2/Z, [x27]\n"
+ "ld1w { z26.s }, p2/Z, [%x[in_ptr]]\n"
+ "lsl z19.s, z18.s, #0x10\n"
+ "ld1w { z25.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "ld1w { z18.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z26.s, z26.s, z19.s\n"
+ "fadd z25.s, z25.s, z17.s\n"
+ "ld1w { z24.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "ld1w { z23.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "ld1w { z22.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "fadd z18.s, z18.s, z16.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "whilelt p0.s, x21, x11\n"
+ "decw x11, ALL, MUL #3\n"
+ "incw x21\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ "ld1w { z21.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z20.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "ld1w { z19.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "cmp x11, XZR\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ ".inst 0x658aaf51 // bfcvt z17.h, p3/M, z26.s\n"
+ ".inst 0x658aaf30 // bfcvt z16.h, p3/M, z25.s\n"
+ "st1h { z17.s }, p2, [x10]\n"
+ "st1h { z16.s }, p2, [x28]\n"
+ ".inst 0x658aae51 // bfcvt z17.h, p3/M, z18.s\n"
+ "ld1h { z16.s }, p1/Z, [x10, #1, MUL VL]\n"
+ "st1h { z17.s }, p2, [x27]\n"
+ "ld1h { z17.s }, p1/Z, [x28, #1, MUL VL]\n"
+ "lsl z18.s, z16.s, #0x10\n"
+ "ld1h { z16.s }, p1/Z, [x27, #1, MUL VL]\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z24.s, z24.s, z18.s\n"
+ "fadd z23.s, z23.s, z17.s\n"
+ "fadd z22.s, z22.s, z16.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ ".inst 0x658aaf10 // bfcvt z16.h, p3/M, z24.s\n"
+ "st1h { z16.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aaef2 // bfcvt z18.h, p3/M, z23.s\n"
+ ".inst 0x658aaed1 // bfcvt z17.h, p3/M, z22.s\n"
+ "ld1h { z16.s }, p0/Z, [x10, #2, MUL VL]\n"
+ "st1h { z18.s }, p1, [x28, #1, MUL VL]\n"
+ "st1h { z17.s }, p1, [x27, #1, MUL VL]\n"
+ "ld1h { z17.s }, p0/Z, [x28, #2, MUL VL]\n"
+ "lsl z18.s, z16.s, #0x10\n"
+ "ld1h { z16.s }, p0/Z, [x27, #2, MUL VL]\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z21.s, z21.s, z18.s\n"
+ "fadd z20.s, z20.s, z17.s\n"
+ "fadd z19.s, z19.s, z16.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ ".inst 0x658aaeb0 // bfcvt z16.h, p3/M, z21.s\n"
+ "st1h { z16.s }, p0, [x10, #2, MUL VL]\n"
+ "inch x10, ALL, MUL #3\n"
+ ".inst 0x658aae91 // bfcvt z17.h, p3/M, z20.s\n"
+ ".inst 0x658aae70 // bfcvt z16.h, p3/M, z19.s\n"
+ "st1h { z17.s }, p0, [x28, #2, MUL VL]\n"
+ "inch x28, ALL, MUL #3\n"
+ "st1h { z16.s }, p0, [x27, #2, MUL VL]\n"
+ "inch x27, ALL, MUL #3\n"
+ "bgt 41b\n"
+ "b 52f\n"
+ "42:" // Accumulate: Height 4
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "43:" // Accumulate: Height 4: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "ld1h { z19.s }, p2/Z, [x10]\n"
+ "ld1h { z18.s }, p2/Z, [x28]\n"
+ "ld1h { z17.s }, p2/Z, [x27]\n"
+ "ld1h { z16.s }, p2/Z, [x26]\n"
+ "ld1w { z30.s }, p2/Z, [%x[in_ptr]]\n"
+ "lsl z20.s, z19.s, #0x10\n"
+ "ld1w { z29.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "ld1w { z28.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "ld1w { z19.s }, p2/Z, [x20, #-7, MUL VL]\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z30.s, z30.s, z20.s\n"
+ "fadd z29.s, z29.s, z18.s\n"
+ "ld1w { z27.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "ld1w { z26.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "ld1w { z25.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "ld1w { z24.s }, p1/Z, [x20, #-6, MUL VL]\n"
+ "whilelt p0.s, x21, x11\n"
+ "decw x11, ALL, MUL #3\n"
+ "fadd z28.s, z28.s, z17.s\n"
+ "fadd z19.s, z19.s, z16.s\n"
+ "incw x21\n"
+ "fmin z30.s, p3/M, z30.s, z12.s\n"
+ "fmin z29.s, p3/M, z29.s, z12.s\n"
+ "ld1w { z23.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z22.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "ld1w { z21.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "ld1w { z20.s }, p0/Z, [x20, #-5, MUL VL]\n"
+ "cmp x11, XZR\n"
+ "fmin z28.s, p3/M, z28.s, z12.s\n"
+ "fmin z19.s, p3/M, z19.s, z12.s\n"
+ "fmax z30.s, p3/M, z30.s, z11.s\n"
+ "fmax z29.s, p3/M, z29.s, z11.s\n"
+ "fmax z28.s, p3/M, z28.s, z11.s\n"
+ "fmax z19.s, p3/M, z19.s, z11.s\n"
+ ".inst 0x658aafd2 // bfcvt z18.h, p3/M, z30.s\n"
+ ".inst 0x658aafb1 // bfcvt z17.h, p3/M, z29.s\n"
+ ".inst 0x658aaf90 // bfcvt z16.h, p3/M, z28.s\n"
+ "st1h { z18.s }, p2, [x10]\n"
+ "st1h { z17.s }, p2, [x28]\n"
+ ".inst 0x658aae71 // bfcvt z17.h, p3/M, z19.s\n"
+ "st1h { z16.s }, p2, [x27]\n"
+ "ld1h { z16.s }, p1/Z, [x10, #1, MUL VL]\n"
+ "st1h { z17.s }, p2, [x26]\n"
+ "ld1h { z18.s }, p1/Z, [x28, #1, MUL VL]\n"
+ "ld1h { z17.s }, p1/Z, [x27, #1, MUL VL]\n"
+ "lsl z19.s, z16.s, #0x10\n"
+ "ld1h { z16.s }, p1/Z, [x26, #1, MUL VL]\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "fadd z27.s, z27.s, z19.s\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z26.s, z26.s, z18.s\n"
+ "fadd z25.s, z25.s, z17.s\n"
+ "fadd z24.s, z24.s, z16.s\n"
+ "fmin z27.s, p3/M, z27.s, z12.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmax z27.s, p3/M, z27.s, z11.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ ".inst 0x658aaf71 // bfcvt z17.h, p3/M, z27.s\n"
+ ".inst 0x658aaf50 // bfcvt z16.h, p3/M, z26.s\n"
+ "st1h { z17.s }, p1, [x10, #1, MUL VL]\n"
+ "st1h { z16.s }, p1, [x28, #1, MUL VL]\n"
+ ".inst 0x658aaf32 // bfcvt z18.h, p3/M, z25.s\n"
+ ".inst 0x658aaf11 // bfcvt z17.h, p3/M, z24.s\n"
+ "ld1h { z16.s }, p0/Z, [x10, #2, MUL VL]\n"
+ "st1h { z18.s }, p1, [x27, #1, MUL VL]\n"
+ "st1h { z17.s }, p1, [x26, #1, MUL VL]\n"
+ "ld1h { z18.s }, p0/Z, [x28, #2, MUL VL]\n"
+ "lsl z19.s, z16.s, #0x10\n"
+ "ld1h { z17.s }, p0/Z, [x27, #2, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x26, #2, MUL VL]\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "fadd z23.s, z23.s, z19.s\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z22.s, z22.s, z18.s\n"
+ "fadd z21.s, z21.s, z17.s\n"
+ "fadd z20.s, z20.s, z16.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ ".inst 0x658aaef1 // bfcvt z17.h, p3/M, z23.s\n"
+ ".inst 0x658aaed0 // bfcvt z16.h, p3/M, z22.s\n"
+ "st1h { z17.s }, p0, [x10, #2, MUL VL]\n"
+ "inch x10, ALL, MUL #3\n"
+ "st1h { z16.s }, p0, [x28, #2, MUL VL]\n"
+ "inch x28, ALL, MUL #3\n"
+ ".inst 0x658aaeb1 // bfcvt z17.h, p3/M, z21.s\n"
+ ".inst 0x658aae90 // bfcvt z16.h, p3/M, z20.s\n"
+ "st1h { z17.s }, p0, [x27, #2, MUL VL]\n"
+ "inch x27, ALL, MUL #3\n"
+ "st1h { z16.s }, p0, [x26, #2, MUL VL]\n"
+ "inch x26, ALL, MUL #3\n"
+ "bgt 43b\n"
+ "b 52f\n"
+ "44:" // Accumulate: Height 5
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "45:" // Accumulate: Height 5: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "ld1h { z20.s }, p2/Z, [x10]\n"
+ "ld1h { z19.s }, p2/Z, [x28]\n"
+ "ld1h { z18.s }, p2/Z, [x27]\n"
+ "ld1h { z17.s }, p2/Z, [x26]\n"
+ "ld1h { z16.s }, p2/Z, [x25]\n"
+ "ld1w { z1.s }, p2/Z, [%x[in_ptr]]\n"
+ "lsl z22.s, z20.s, #0x10\n"
+ "ld1w { z0.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "lsl z21.s, z19.s, #0x10\n"
+ "ld1w { z31.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "whilelt p1.s, x21, x11\n"
+ "lsl z19.s, z18.s, #0x10\n"
+ "ld1w { z20.s }, p2/Z, [x20, #-7, MUL VL]\n"
+ "lsl z18.s, z17.s, #0x10\n"
+ "ld1w { z17.s }, p2/Z, [x20, #-4, MUL VL]\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z1.s, z1.s, z22.s\n"
+ "incw x21\n"
+ "fadd z0.s, z0.s, z21.s\n"
+ "ld1w { z30.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "ld1w { z29.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "fadd z31.s, z31.s, z19.s\n"
+ "fadd z20.s, z20.s, z18.s\n"
+ "ld1w { z28.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "ld1w { z27.s }, p1/Z, [x20, #-6, MUL VL]\n"
+ "fadd z17.s, z17.s, z16.s\n"
+ "fmin z1.s, p3/M, z1.s, z12.s\n"
+ "ld1w { z26.s }, p1/Z, [x20, #-3, MUL VL]\n"
+ "whilelt p0.s, x21, x11\n"
+ "fmin z0.s, p3/M, z0.s, z12.s\n"
+ "fmin z31.s, p3/M, z31.s, z12.s\n"
+ "fmin z20.s, p3/M, z20.s, z12.s\n"
+ "fmin z17.s, p3/M, z17.s, z12.s\n"
+ "fmax z1.s, p3/M, z1.s, z11.s\n"
+ "ld1w { z25.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z24.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "decw x11, ALL, MUL #3\n"
+ "fmax z0.s, p3/M, z0.s, z11.s\n"
+ "fmax z31.s, p3/M, z31.s, z11.s\n"
+ "ld1w { z23.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "ld1w { z22.s }, p0/Z, [x20, #-5, MUL VL]\n"
+ "fmax z20.s, p3/M, z20.s, z11.s\n"
+ "fmax z17.s, p3/M, z17.s, z11.s\n"
+ "ld1w { z21.s }, p0/Z, [x20, #-2, MUL VL]\n"
+ ".inst 0x658aac30 // bfcvt z16.h, p3/M, z1.s\n"
+ "cmp x11, XZR\n"
+ "incw x21\n"
+ ".inst 0x658aac13 // bfcvt z19.h, p3/M, z0.s\n"
+ ".inst 0x658aaff2 // bfcvt z18.h, p3/M, z31.s\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "st1h { z16.s }, p2, [x10]\n"
+ ".inst 0x658aae90 // bfcvt z16.h, p3/M, z20.s\n"
+ ".inst 0x658aae31 // bfcvt z17.h, p3/M, z17.s\n"
+ "st1h { z19.s }, p2, [x28]\n"
+ "st1h { z18.s }, p2, [x27]\n"
+ "st1h { z16.s }, p2, [x26]\n"
+ "ld1h { z16.s }, p1/Z, [x10, #1, MUL VL]\n"
+ "st1h { z17.s }, p2, [x25]\n"
+ "ld1h { z19.s }, p1/Z, [x28, #1, MUL VL]\n"
+ "ld1h { z18.s }, p1/Z, [x27, #1, MUL VL]\n"
+ "ld1h { z17.s }, p1/Z, [x26, #1, MUL VL]\n"
+ "lsl z20.s, z16.s, #0x10\n"
+ "ld1h { z16.s }, p1/Z, [x25, #1, MUL VL]\n"
+ "lsl z19.s, z19.s, #0x10\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z30.s, z30.s, z20.s\n"
+ "fadd z29.s, z29.s, z19.s\n"
+ "fadd z28.s, z28.s, z18.s\n"
+ "fadd z27.s, z27.s, z17.s\n"
+ "fadd z26.s, z26.s, z16.s\n"
+ "fmin z30.s, p3/M, z30.s, z12.s\n"
+ "fmin z29.s, p3/M, z29.s, z12.s\n"
+ "fmin z28.s, p3/M, z28.s, z12.s\n"
+ "fmin z27.s, p3/M, z27.s, z12.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "fmax z30.s, p3/M, z30.s, z11.s\n"
+ "fmax z29.s, p3/M, z29.s, z11.s\n"
+ "fmax z28.s, p3/M, z28.s, z11.s\n"
+ "fmax z27.s, p3/M, z27.s, z11.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ ".inst 0x658aafd2 // bfcvt z18.h, p3/M, z30.s\n"
+ ".inst 0x658aafb1 // bfcvt z17.h, p3/M, z29.s\n"
+ ".inst 0x658aaf90 // bfcvt z16.h, p3/M, z28.s\n"
+ "st1h { z18.s }, p1, [x10, #1, MUL VL]\n"
+ "st1h { z17.s }, p1, [x28, #1, MUL VL]\n"
+ ".inst 0x658aaf72 // bfcvt z18.h, p3/M, z27.s\n"
+ ".inst 0x658aaf51 // bfcvt z17.h, p3/M, z26.s\n"
+ "st1h { z16.s }, p1, [x27, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x10, #2, MUL VL]\n"
+ "st1h { z18.s }, p1, [x26, #1, MUL VL]\n"
+ "st1h { z17.s }, p1, [x25, #1, MUL VL]\n"
+ "ld1h { z19.s }, p0/Z, [x28, #2, MUL VL]\n"
+ "ld1h { z18.s }, p0/Z, [x27, #2, MUL VL]\n"
+ "lsl z20.s, z16.s, #0x10\n"
+ "ld1h { z17.s }, p0/Z, [x26, #2, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x25, #2, MUL VL]\n"
+ "lsl z19.s, z19.s, #0x10\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "fadd z25.s, z25.s, z20.s\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z24.s, z24.s, z19.s\n"
+ "fadd z23.s, z23.s, z18.s\n"
+ "fadd z22.s, z22.s, z17.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fadd z21.s, z21.s, z16.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ ".inst 0x658aaf31 // bfcvt z17.h, p3/M, z25.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ ".inst 0x658aaf10 // bfcvt z16.h, p3/M, z24.s\n"
+ "st1h { z17.s }, p0, [x10, #2, MUL VL]\n"
+ "inch x10, ALL, MUL #3\n"
+ ".inst 0x658aaef2 // bfcvt z18.h, p3/M, z23.s\n"
+ ".inst 0x658aaed1 // bfcvt z17.h, p3/M, z22.s\n"
+ "st1h { z16.s }, p0, [x28, #2, MUL VL]\n"
+ "inch x28, ALL, MUL #3\n"
+ ".inst 0x658aaeb0 // bfcvt z16.h, p3/M, z21.s\n"
+ "st1h { z18.s }, p0, [x27, #2, MUL VL]\n"
+ "inch x27, ALL, MUL #3\n"
+ "st1h { z17.s }, p0, [x26, #2, MUL VL]\n"
+ "inch x26, ALL, MUL #3\n"
+ "st1h { z16.s }, p0, [x25, #2, MUL VL]\n"
+ "inch x25, ALL, MUL #3\n"
+ "bgt 45b\n"
+ "b 52f\n"
+ "46:" // Accumulate: Height 6
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "47:" // Accumulate: Height 6: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "ld1h { z21.s }, p2/Z, [x10]\n"
+ "ld1h { z20.s }, p2/Z, [x28]\n"
+ "ld1h { z19.s }, p2/Z, [x27]\n"
+ "ld1h { z18.s }, p2/Z, [x26]\n"
+ "ld1h { z17.s }, p2/Z, [x25]\n"
+ "ld1h { z16.s }, p2/Z, [x24]\n"
+ "ld1w { z6.s }, p2/Z, [%x[in_ptr]]\n"
+ "lsl z22.s, z21.s, #0x10\n"
+ "ld1w { z5.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "lsl z21.s, z20.s, #0x10\n"
+ "ld1w { z4.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "lsl z20.s, z19.s, #0x10\n"
+ "ld1w { z3.s }, p2/Z, [x20, #-7, MUL VL]\n"
+ "lsl z19.s, z18.s, #0x10\n"
+ "ld1w { z2.s }, p2/Z, [x20, #-4, MUL VL]\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "ld1w { z18.s }, p2/Z, [x20, #-1, MUL VL]\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z6.s, z6.s, z22.s\n"
+ "fadd z5.s, z5.s, z21.s\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "fadd z4.s, z4.s, z20.s\n"
+ "fadd z3.s, z3.s, z19.s\n"
+ "fadd z2.s, z2.s, z17.s\n"
+ "fadd z18.s, z18.s, z16.s\n"
+ "fmin z6.s, p3/M, z6.s, z12.s\n"
+ "fmin z5.s, p3/M, z5.s, z12.s\n"
+ "ld1w { z1.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "ld1w { z0.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "whilelt p0.s, x21, x11\n"
+ "decw x11, ALL, MUL #3\n"
+ "fmin z4.s, p3/M, z4.s, z12.s\n"
+ "fmin z3.s, p3/M, z3.s, z12.s\n"
+ "ld1w { z31.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "ld1w { z30.s }, p1/Z, [x20, #-6, MUL VL]\n"
+ "fmin z2.s, p3/M, z2.s, z12.s\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "ld1w { z29.s }, p1/Z, [x20, #-3, MUL VL]\n"
+ "ld1w { z28.s }, p1/Z, [x20]\n"
+ "fmax z6.s, p3/M, z6.s, z11.s\n"
+ "fmax z5.s, p3/M, z5.s, z11.s\n"
+ "ld1w { z27.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z26.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "fmax z4.s, p3/M, z4.s, z11.s\n"
+ "fmax z3.s, p3/M, z3.s, z11.s\n"
+ "ld1w { z25.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "ld1w { z24.s }, p0/Z, [x20, #-5, MUL VL]\n"
+ "fmax z2.s, p3/M, z2.s, z11.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ "ld1w { z23.s }, p0/Z, [x20, #-2, MUL VL]\n"
+ "ld1w { z22.s }, p0/Z, [x20, #1, MUL VL]\n"
+ ".inst 0x658aacd5 // bfcvt z21.h, p3/M, z6.s\n"
+ ".inst 0x658aacb4 // bfcvt z20.h, p3/M, z5.s\n"
+ "cmp x11, XZR\n"
+ "incw x21\n"
+ ".inst 0x658aac93 // bfcvt z19.h, p3/M, z4.s\n"
+ ".inst 0x658aac71 // bfcvt z17.h, p3/M, z3.s\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ ".inst 0x658aac50 // bfcvt z16.h, p3/M, z2.s\n"
+ ".inst 0x658aae52 // bfcvt z18.h, p3/M, z18.s\n"
+ "st1h { z21.s }, p2, [x10]\n"
+ "st1h { z20.s }, p2, [x28]\n"
+ "st1h { z19.s }, p2, [x27]\n"
+ "st1h { z17.s }, p2, [x26]\n"
+ "ld1h { z17.s }, p1/Z, [x10, #1, MUL VL]\n"
+ "st1h { z16.s }, p2, [x25]\n"
+ "ld1h { z16.s }, p1/Z, [x28, #1, MUL VL]\n"
+ "st1h { z18.s }, p2, [x24]\n"
+ "ld1h { z19.s }, p1/Z, [x27, #1, MUL VL]\n"
+ "ld1h { z18.s }, p1/Z, [x26, #1, MUL VL]\n"
+ "lsl z21.s, z17.s, #0x10\n"
+ "ld1h { z17.s }, p1/Z, [x25, #1, MUL VL]\n"
+ "lsl z20.s, z16.s, #0x10\n"
+ "ld1h { z16.s }, p1/Z, [x24, #1, MUL VL]\n"
+ "lsl z19.s, z19.s, #0x10\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "fadd z1.s, z1.s, z21.s\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z0.s, z0.s, z20.s\n"
+ "fadd z31.s, z31.s, z19.s\n"
+ "fadd z30.s, z30.s, z18.s\n"
+ "fmin z1.s, p3/M, z1.s, z12.s\n"
+ "fadd z29.s, z29.s, z17.s\n"
+ "fadd z28.s, z28.s, z16.s\n"
+ "fmin z0.s, p3/M, z0.s, z12.s\n"
+ "fmin z31.s, p3/M, z31.s, z12.s\n"
+ "fmin z30.s, p3/M, z30.s, z12.s\n"
+ "fmin z29.s, p3/M, z29.s, z12.s\n"
+ "fmax z1.s, p3/M, z1.s, z11.s\n"
+ "fmin z28.s, p3/M, z28.s, z12.s\n"
+ "fmax z0.s, p3/M, z0.s, z11.s\n"
+ "fmax z31.s, p3/M, z31.s, z11.s\n"
+ "fmax z30.s, p3/M, z30.s, z11.s\n"
+ "fmax z29.s, p3/M, z29.s, z11.s\n"
+ "fmax z28.s, p3/M, z28.s, z11.s\n"
+ ".inst 0x658aac34 // bfcvt z20.h, p3/M, z1.s\n"
+ ".inst 0x658aac12 // bfcvt z18.h, p3/M, z0.s\n"
+ ".inst 0x658aaff3 // bfcvt z19.h, p3/M, z31.s\n"
+ ".inst 0x658aafd1 // bfcvt z17.h, p3/M, z30.s\n"
+ ".inst 0x658aafb0 // bfcvt z16.h, p3/M, z29.s\n"
+ "st1h { z20.s }, p1, [x10, #1, MUL VL]\n"
+ "st1h { z18.s }, p1, [x28, #1, MUL VL]\n"
+ ".inst 0x658aaf92 // bfcvt z18.h, p3/M, z28.s\n"
+ "st1h { z19.s }, p1, [x27, #1, MUL VL]\n"
+ "st1h { z17.s }, p1, [x26, #1, MUL VL]\n"
+ "ld1h { z17.s }, p0/Z, [x10, #2, MUL VL]\n"
+ "st1h { z16.s }, p1, [x25, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x28, #2, MUL VL]\n"
+ "st1h { z18.s }, p1, [x24, #1, MUL VL]\n"
+ "ld1h { z19.s }, p0/Z, [x27, #2, MUL VL]\n"
+ "ld1h { z18.s }, p0/Z, [x26, #2, MUL VL]\n"
+ "lsl z21.s, z17.s, #0x10\n"
+ "ld1h { z17.s }, p0/Z, [x25, #2, MUL VL]\n"
+ "lsl z20.s, z16.s, #0x10\n"
+ "ld1h { z16.s }, p0/Z, [x24, #2, MUL VL]\n"
+ "lsl z19.s, z19.s, #0x10\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "fadd z27.s, z27.s, z21.s\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z26.s, z26.s, z20.s\n"
+ "fadd z25.s, z25.s, z19.s\n"
+ "fadd z24.s, z24.s, z18.s\n"
+ "fmin z27.s, p3/M, z27.s, z12.s\n"
+ "fadd z23.s, z23.s, z17.s\n"
+ "fadd z22.s, z22.s, z16.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmax z27.s, p3/M, z27.s, z11.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ ".inst 0x658aaf74 // bfcvt z20.h, p3/M, z27.s\n"
+ ".inst 0x658aaf50 // bfcvt z16.h, p3/M, z26.s\n"
+ ".inst 0x658aaf33 // bfcvt z19.h, p3/M, z25.s\n"
+ ".inst 0x658aaf12 // bfcvt z18.h, p3/M, z24.s\n"
+ ".inst 0x658aaef1 // bfcvt z17.h, p3/M, z23.s\n"
+ "st1h { z20.s }, p0, [x10, #2, MUL VL]\n"
+ "inch x10, ALL, MUL #3\n"
+ "st1h { z16.s }, p0, [x28, #2, MUL VL]\n"
+ ".inst 0x658aaed0 // bfcvt z16.h, p3/M, z22.s\n"
+ "inch x28, ALL, MUL #3\n"
+ "st1h { z19.s }, p0, [x27, #2, MUL VL]\n"
+ "inch x27, ALL, MUL #3\n"
+ "st1h { z18.s }, p0, [x26, #2, MUL VL]\n"
+ "inch x26, ALL, MUL #3\n"
+ "st1h { z17.s }, p0, [x25, #2, MUL VL]\n"
+ "inch x25, ALL, MUL #3\n"
+ "st1h { z16.s }, p0, [x24, #2, MUL VL]\n"
+ "inch x24, ALL, MUL #3\n"
+ "bgt 47b\n"
+ "b 52f\n"
+ "48:" // Accumulate: Height 7
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "add x23, x24, %x[ldout], LSL #1\n"
+ "49:" // Accumulate: Height 7: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "ld1h { z22.s }, p2/Z, [x10]\n"
+ "ld1h { z21.s }, p2/Z, [x28]\n"
+ "ld1h { z20.s }, p2/Z, [x27]\n"
+ "ld1h { z19.s }, p2/Z, [x26]\n"
+ "ld1h { z18.s }, p2/Z, [x25]\n"
+ "ld1h { z17.s }, p2/Z, [x24]\n"
+ "ld1h { z16.s }, p2/Z, [x23]\n"
+ "ld1w { z8.s }, p2/Z, [%x[in_ptr]]\n"
+ "lsl z25.s, z22.s, #0x10\n"
+ "lsl z24.s, z21.s, #0x10\n"
+ "ld1w { z21.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "ld1w { z7.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "lsl z20.s, z20.s, #0x10\n"
+ "lsl z19.s, z19.s, #0x10\n"
+ "ld1w { z23.s }, p2/Z, [x20, #-7, MUL VL]\n"
+ "ld1w { z6.s }, p2/Z, [x20, #-4, MUL VL]\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "ld1w { z5.s }, p2/Z, [x20, #-1, MUL VL]\n"
+ "ld1w { z22.s }, p2/Z, [x20, #2, MUL VL]\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z8.s, z8.s, z25.s\n"
+ "fadd z21.s, z21.s, z24.s\n"
+ "fadd z7.s, z7.s, z20.s\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "fadd z23.s, z23.s, z19.s\n"
+ "fadd z6.s, z6.s, z18.s\n"
+ "fadd z5.s, z5.s, z17.s\n"
+ "fadd z22.s, z22.s, z16.s\n"
+ "fmin z8.s, p3/M, z8.s, z12.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmin z7.s, p3/M, z7.s, z12.s\n"
+ "ld1w { z4.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "ld1w { z3.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "whilelt p0.s, x21, x11\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z6.s, p3/M, z6.s, z12.s\n"
+ "ld1w { z2.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "ld1w { z1.s }, p1/Z, [x20, #-6, MUL VL]\n"
+ "fmin z5.s, p3/M, z5.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "ld1w { z0.s }, p1/Z, [x20, #-3, MUL VL]\n"
+ "ld1w { z31.s }, p1/Z, [x20]\n"
+ "fmax z8.s, p3/M, z8.s, z11.s\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "ld1w { z30.s }, p1/Z, [x20, #3, MUL VL]\n"
+ "ld1w { z29.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "fmax z7.s, p3/M, z7.s, z11.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "ld1w { z28.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "ld1w { z27.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "fmax z6.s, p3/M, z6.s, z11.s\n"
+ "fmax z5.s, p3/M, z5.s, z11.s\n"
+ "ld1w { z26.s }, p0/Z, [x20, #-5, MUL VL]\n"
+ "ld1w { z25.s }, p0/Z, [x20, #-2, MUL VL]\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ ".inst 0x658aad13 // bfcvt z19.h, p3/M, z8.s\n"
+ ".inst 0x658aaeb5 // bfcvt z21.h, p3/M, z21.s\n"
+ "ld1w { z24.s }, p0/Z, [x20, #1, MUL VL]\n"
+ ".inst 0x658aacf4 // bfcvt z20.h, p3/M, z7.s\n"
+ ".inst 0x658aaef2 // bfcvt z18.h, p3/M, z23.s\n"
+ "ld1w { z23.s }, p0/Z, [x20, #4, MUL VL]\n"
+ "decw x11, ALL, MUL #3\n"
+ ".inst 0x658aacd1 // bfcvt z17.h, p3/M, z6.s\n"
+ ".inst 0x658aacb0 // bfcvt z16.h, p3/M, z5.s\n"
+ "incw x21\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "st1h { z19.s }, p2, [x10]\n"
+ ".inst 0x658aaed3 // bfcvt z19.h, p3/M, z22.s\n"
+ "st1h { z21.s }, p2, [x28]\n"
+ "cmp x11, XZR\n"
+ "st1h { z20.s }, p2, [x27]\n"
+ "st1h { z18.s }, p2, [x26]\n"
+ "ld1h { z18.s }, p1/Z, [x10, #1, MUL VL]\n"
+ "st1h { z17.s }, p2, [x25]\n"
+ "ld1h { z17.s }, p1/Z, [x28, #1, MUL VL]\n"
+ "st1h { z16.s }, p2, [x24]\n"
+ "ld1h { z16.s }, p1/Z, [x27, #1, MUL VL]\n"
+ "st1h { z19.s }, p2, [x23]\n"
+ "ld1h { z19.s }, p1/Z, [x26, #1, MUL VL]\n"
+ "lsl z22.s, z18.s, #0x10\n"
+ "ld1h { z18.s }, p1/Z, [x25, #1, MUL VL]\n"
+ "lsl z21.s, z17.s, #0x10\n"
+ "ld1h { z17.s }, p1/Z, [x24, #1, MUL VL]\n"
+ "lsl z20.s, z16.s, #0x10\n"
+ "ld1h { z16.s }, p1/Z, [x23, #1, MUL VL]\n"
+ "lsl z19.s, z19.s, #0x10\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "fadd z4.s, z4.s, z22.s\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z3.s, z3.s, z21.s\n"
+ "fadd z2.s, z2.s, z20.s\n"
+ "fadd z1.s, z1.s, z19.s\n"
+ "fadd z0.s, z0.s, z18.s\n"
+ "fadd z31.s, z31.s, z17.s\n"
+ "fmin z4.s, p3/M, z4.s, z12.s\n"
+ "fadd z30.s, z30.s, z16.s\n"
+ "fmin z3.s, p3/M, z3.s, z12.s\n"
+ "fmin z2.s, p3/M, z2.s, z12.s\n"
+ "fmin z1.s, p3/M, z1.s, z12.s\n"
+ "fmin z0.s, p3/M, z0.s, z12.s\n"
+ "fmin z31.s, p3/M, z31.s, z12.s\n"
+ "fmax z4.s, p3/M, z4.s, z11.s\n"
+ "fmin z30.s, p3/M, z30.s, z12.s\n"
+ "fmax z3.s, p3/M, z3.s, z11.s\n"
+ "fmax z2.s, p3/M, z2.s, z11.s\n"
+ "fmax z1.s, p3/M, z1.s, z11.s\n"
+ "fmax z0.s, p3/M, z0.s, z11.s\n"
+ "fmax z31.s, p3/M, z31.s, z11.s\n"
+ ".inst 0x658aac90 // bfcvt z16.h, p3/M, z4.s\n"
+ "fmax z30.s, p3/M, z30.s, z11.s\n"
+ ".inst 0x658aac74 // bfcvt z20.h, p3/M, z3.s\n"
+ ".inst 0x658aac53 // bfcvt z19.h, p3/M, z2.s\n"
+ ".inst 0x658aac32 // bfcvt z18.h, p3/M, z1.s\n"
+ "st1h { z16.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aac11 // bfcvt z17.h, p3/M, z0.s\n"
+ ".inst 0x658aaff0 // bfcvt z16.h, p3/M, z31.s\n"
+ "st1h { z20.s }, p1, [x28, #1, MUL VL]\n"
+ "st1h { z19.s }, p1, [x27, #1, MUL VL]\n"
+ ".inst 0x658aafd3 // bfcvt z19.h, p3/M, z30.s\n"
+ "st1h { z18.s }, p1, [x26, #1, MUL VL]\n"
+ "ld1h { z18.s }, p0/Z, [x10, #2, MUL VL]\n"
+ "st1h { z17.s }, p1, [x25, #1, MUL VL]\n"
+ "ld1h { z17.s }, p0/Z, [x28, #2, MUL VL]\n"
+ "st1h { z16.s }, p1, [x24, #1, MUL VL]\n"
+ "ld1h { z16.s }, p0/Z, [x27, #2, MUL VL]\n"
+ "st1h { z19.s }, p1, [x23, #1, MUL VL]\n"
+ "ld1h { z19.s }, p0/Z, [x26, #2, MUL VL]\n"
+ "lsl z22.s, z18.s, #0x10\n"
+ "ld1h { z18.s }, p0/Z, [x25, #2, MUL VL]\n"
+ "lsl z21.s, z17.s, #0x10\n"
+ "ld1h { z17.s }, p0/Z, [x24, #2, MUL VL]\n"
+ "lsl z20.s, z16.s, #0x10\n"
+ "ld1h { z16.s }, p0/Z, [x23, #2, MUL VL]\n"
+ "lsl z19.s, z19.s, #0x10\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "fadd z29.s, z29.s, z22.s\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z28.s, z28.s, z21.s\n"
+ "fadd z27.s, z27.s, z20.s\n"
+ "fadd z26.s, z26.s, z19.s\n"
+ "fadd z25.s, z25.s, z18.s\n"
+ "fadd z24.s, z24.s, z17.s\n"
+ "fmin z29.s, p3/M, z29.s, z12.s\n"
+ "fadd z23.s, z23.s, z16.s\n"
+ "fmin z28.s, p3/M, z28.s, z12.s\n"
+ "fmin z27.s, p3/M, z27.s, z12.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmax z29.s, p3/M, z29.s, z11.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmax z28.s, p3/M, z28.s, z11.s\n"
+ "fmax z27.s, p3/M, z27.s, z11.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ ".inst 0x658aafb1 // bfcvt z17.h, p3/M, z29.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ ".inst 0x658aaf94 // bfcvt z20.h, p3/M, z28.s\n"
+ ".inst 0x658aaf70 // bfcvt z16.h, p3/M, z27.s\n"
+ ".inst 0x658aaf53 // bfcvt z19.h, p3/M, z26.s\n"
+ "st1h { z17.s }, p0, [x10, #2, MUL VL]\n"
+ "inch x10, ALL, MUL #3\n"
+ ".inst 0x658aaf32 // bfcvt z18.h, p3/M, z25.s\n"
+ ".inst 0x658aaf11 // bfcvt z17.h, p3/M, z24.s\n"
+ "st1h { z20.s }, p0, [x28, #2, MUL VL]\n"
+ "inch x28, ALL, MUL #3\n"
+ "st1h { z16.s }, p0, [x27, #2, MUL VL]\n"
+ ".inst 0x658aaef0 // bfcvt z16.h, p3/M, z23.s\n"
+ "inch x27, ALL, MUL #3\n"
+ "st1h { z19.s }, p0, [x26, #2, MUL VL]\n"
+ "inch x26, ALL, MUL #3\n"
+ "st1h { z18.s }, p0, [x25, #2, MUL VL]\n"
+ "inch x25, ALL, MUL #3\n"
+ "st1h { z17.s }, p0, [x24, #2, MUL VL]\n"
+ "inch x24, ALL, MUL #3\n"
+ "st1h { z16.s }, p0, [x23, #2, MUL VL]\n"
+ "inch x23, ALL, MUL #3\n"
+ "bgt 49b\n"
+ "b 52f\n"
+ "50:" // Accumulate: Height 8
+ "mov x10, %x[out_ptr]\n"
+ "mov x11, %x[cols]\n"
+ "add x28, x10, %x[ldout], LSL #1\n"
+ "add x27, x28, %x[ldout], LSL #1\n"
+ "add x26, x27, %x[ldout], LSL #1\n"
+ "add x25, x26, %x[ldout], LSL #1\n"
+ "add x24, x25, %x[ldout], LSL #1\n"
+ "add x23, x24, %x[ldout], LSL #1\n"
+ "add x22, x23, %x[ldout], LSL #1\n"
+ "51:" // Accumulate: Height 8: Block loop
+ "mov x21, #0x0\n"
+ "addvl x20, %x[in_ptr], #16\n"
+ "whilelt p2.s, x21, x11\n"
+ "incw x21\n"
+ "ld1h { z23.s }, p2/Z, [x10]\n"
+ "ld1h { z22.s }, p2/Z, [x28]\n"
+ "ld1h { z21.s }, p2/Z, [x27]\n"
+ "ld1h { z20.s }, p2/Z, [x26]\n"
+ "ld1h { z19.s }, p2/Z, [x25]\n"
+ "ld1h { z18.s }, p2/Z, [x24]\n"
+ "ld1h { z17.s }, p2/Z, [x23]\n"
+ "ld1h { z16.s }, p2/Z, [x22]\n"
+ "lsl z31.s, z23.s, #0x10\n"
+ "lsl z30.s, z22.s, #0x10\n"
+ "ld1w { z29.s }, p2/Z, [%x[in_ptr]]\n"
+ "ld1w { z28.s }, p2/Z, [%x[in_ptr], #3, MUL VL]\n"
+ "lsl z27.s, z21.s, #0x10\n"
+ "lsl z26.s, z20.s, #0x10\n"
+ "ld1w { z21.s }, p2/Z, [%x[in_ptr], #6, MUL VL]\n"
+ "ld1w { z25.s }, p2/Z, [x20, #-7, MUL VL]\n"
+ "lsl z20.s, z19.s, #0x10\n"
+ "lsl z19.s, z18.s, #0x10\n"
+ "ld1w { z18.s }, p2/Z, [x20, #-4, MUL VL]\n"
+ "ld1w { z24.s }, p2/Z, [x20, #-1, MUL VL]\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "ld1w { z23.s }, p2/Z, [x20, #2, MUL VL]\n"
+ "ld1w { z22.s }, p2/Z, [x20, #5, MUL VL]\n"
+ "fadd z29.s, z29.s, z31.s\n"
+ "fadd z28.s, z28.s, z30.s\n"
+ "fadd z21.s, z21.s, z27.s\n"
+ "fadd z25.s, z25.s, z26.s\n"
+ "whilelt p1.s, x21, x11\n"
+ "incw x21\n"
+ "fadd z18.s, z18.s, z20.s\n"
+ "fadd z24.s, z24.s, z19.s\n"
+ "fadd z23.s, z23.s, z17.s\n"
+ "fadd z22.s, z22.s, z16.s\n"
+ "fmin z29.s, p3/M, z29.s, z12.s\n"
+ "fmin z28.s, p3/M, z28.s, z12.s\n"
+ "fmin z21.s, p3/M, z21.s, z12.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "ld1w { z6.s }, p1/Z, [%x[in_ptr], #1, MUL VL]\n"
+ "ld1w { z5.s }, p1/Z, [%x[in_ptr], #4, MUL VL]\n"
+ "fmin z18.s, p3/M, z18.s, z12.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "ld1w { z4.s }, p1/Z, [%x[in_ptr], #7, MUL VL]\n"
+ "ld1w { z3.s }, p1/Z, [x20, #-6, MUL VL]\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmin z22.s, p3/M, z22.s, z12.s\n"
+ "ld1w { z2.s }, p1/Z, [x20, #-3, MUL VL]\n"
+ "ld1w { z1.s }, p1/Z, [x20]\n"
+ "fmax z29.s, p3/M, z29.s, z11.s\n"
+ "fmax z28.s, p3/M, z28.s, z11.s\n"
+ "ld1w { z0.s }, p1/Z, [x20, #3, MUL VL]\n"
+ "ld1w { z31.s }, p1/Z, [x20, #6, MUL VL]\n"
+ "fmax z21.s, p3/M, z21.s, z11.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ "fmax z18.s, p3/M, z18.s, z11.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "fmax z22.s, p3/M, z22.s, z11.s\n"
+ ".inst 0x658aafb4 // bfcvt z20.h, p3/M, z29.s\n"
+ ".inst 0x658aaf93 // bfcvt z19.h, p3/M, z28.s\n"
+ ".inst 0x658aaeb5 // bfcvt z21.h, p3/M, z21.s\n"
+ ".inst 0x658aaf30 // bfcvt z16.h, p3/M, z25.s\n"
+ "whilelt p0.s, x21, x11\n"
+ "decw x11, ALL, MUL #3\n"
+ ".inst 0x658aae52 // bfcvt z18.h, p3/M, z18.s\n"
+ ".inst 0x658aaf11 // bfcvt z17.h, p3/M, z24.s\n"
+ "incw x21\n"
+ "st1h { z20.s }, p2, [x10]\n"
+ "st1h { z19.s }, p2, [x28]\n"
+ ".inst 0x658aaef4 // bfcvt z20.h, p3/M, z23.s\n"
+ ".inst 0x658aaed3 // bfcvt z19.h, p3/M, z22.s\n"
+ "st1h { z21.s }, p2, [x27]\n"
+ "ld1w { z30.s }, p0/Z, [%x[in_ptr], #2, MUL VL]\n"
+ "ld1w { z29.s }, p0/Z, [%x[in_ptr], #5, MUL VL]\n"
+ "cmp x11, XZR\n"
+ "st1h { z16.s }, p2, [x26]\n"
+ "ld1h { z16.s }, p1/Z, [x10, #1, MUL VL]\n"
+ "ld1w { z28.s }, p0/Z, [x20, #-8, MUL VL]\n"
+ "addvl %x[in_ptr], %x[in_ptr], #24\n"
+ "st1h { z18.s }, p2, [x25]\n"
+ "ld1h { z18.s }, p1/Z, [x28, #1, MUL VL]\n"
+ "ld1w { z27.s }, p0/Z, [x20, #-5, MUL VL]\n"
+ "st1h { z17.s }, p2, [x24]\n"
+ "ld1h { z17.s }, p1/Z, [x27, #1, MUL VL]\n"
+ "ld1w { z26.s }, p0/Z, [x20, #-2, MUL VL]\n"
+ "st1h { z20.s }, p2, [x23]\n"
+ "ld1h { z20.s }, p1/Z, [x26, #1, MUL VL]\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "ld1w { z25.s }, p0/Z, [x20, #1, MUL VL]\n"
+ "st1h { z19.s }, p2, [x22]\n"
+ "ld1h { z19.s }, p1/Z, [x25, #1, MUL VL]\n"
+ "lsl z22.s, z18.s, #0x10\n"
+ "ld1w { z24.s }, p0/Z, [x20, #4, MUL VL]\n"
+ "ld1h { z18.s }, p1/Z, [x24, #1, MUL VL]\n"
+ "lsl z21.s, z17.s, #0x10\n"
+ "ld1w { z23.s }, p0/Z, [x20, #7, MUL VL]\n"
+ "ld1h { z17.s }, p1/Z, [x23, #1, MUL VL]\n"
+ "lsl z20.s, z20.s, #0x10\n"
+ "fadd z6.s, z6.s, z16.s\n"
+ "ld1h { z16.s }, p1/Z, [x22, #1, MUL VL]\n"
+ "lsl z19.s, z19.s, #0x10\n"
+ "fadd z5.s, z5.s, z22.s\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "fadd z4.s, z4.s, z21.s\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fmin z6.s, p3/M, z6.s, z12.s\n"
+ "fadd z3.s, z3.s, z20.s\n"
+ "fadd z2.s, z2.s, z19.s\n"
+ "fmin z5.s, p3/M, z5.s, z12.s\n"
+ "fadd z1.s, z1.s, z18.s\n"
+ "fmin z4.s, p3/M, z4.s, z12.s\n"
+ "fadd z0.s, z0.s, z17.s\n"
+ "fadd z31.s, z31.s, z16.s\n"
+ "fmax z6.s, p3/M, z6.s, z11.s\n"
+ "fmin z3.s, p3/M, z3.s, z12.s\n"
+ "fmin z2.s, p3/M, z2.s, z12.s\n"
+ "fmax z5.s, p3/M, z5.s, z11.s\n"
+ "fmin z1.s, p3/M, z1.s, z12.s\n"
+ "fmin z0.s, p3/M, z0.s, z12.s\n"
+ "fmin z31.s, p3/M, z31.s, z12.s\n"
+ "fmax z4.s, p3/M, z4.s, z11.s\n"
+ ".inst 0x658aacd0 // bfcvt z16.h, p3/M, z6.s\n"
+ "fmax z3.s, p3/M, z3.s, z11.s\n"
+ "fmax z2.s, p3/M, z2.s, z11.s\n"
+ ".inst 0x658aacb1 // bfcvt z17.h, p3/M, z5.s\n"
+ "fmax z1.s, p3/M, z1.s, z11.s\n"
+ "fmax z0.s, p3/M, z0.s, z11.s\n"
+ "fmax z31.s, p3/M, z31.s, z11.s\n"
+ "st1h { z16.s }, p1, [x10, #1, MUL VL]\n"
+ ".inst 0x658aac90 // bfcvt z16.h, p3/M, z4.s\n"
+ "st1h { z17.s }, p1, [x28, #1, MUL VL]\n"
+ ".inst 0x658aac75 // bfcvt z21.h, p3/M, z3.s\n"
+ ".inst 0x658aac52 // bfcvt z18.h, p3/M, z2.s\n"
+ ".inst 0x658aac31 // bfcvt z17.h, p3/M, z1.s\n"
+ ".inst 0x658aac14 // bfcvt z20.h, p3/M, z0.s\n"
+ "st1h { z16.s }, p1, [x27, #1, MUL VL]\n"
+ ".inst 0x658aaff3 // bfcvt z19.h, p3/M, z31.s\n"
+ "ld1h { z16.s }, p0/Z, [x10, #2, MUL VL]\n"
+ "st1h { z21.s }, p1, [x26, #1, MUL VL]\n"
+ "st1h { z18.s }, p1, [x25, #1, MUL VL]\n"
+ "ld1h { z18.s }, p0/Z, [x28, #2, MUL VL]\n"
+ "st1h { z17.s }, p1, [x24, #1, MUL VL]\n"
+ "ld1h { z17.s }, p0/Z, [x27, #2, MUL VL]\n"
+ "st1h { z20.s }, p1, [x23, #1, MUL VL]\n"
+ "ld1h { z20.s }, p0/Z, [x26, #2, MUL VL]\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "st1h { z19.s }, p1, [x22, #1, MUL VL]\n"
+ "ld1h { z19.s }, p0/Z, [x25, #2, MUL VL]\n"
+ "lsl z22.s, z18.s, #0x10\n"
+ "ld1h { z18.s }, p0/Z, [x24, #2, MUL VL]\n"
+ "lsl z21.s, z17.s, #0x10\n"
+ "ld1h { z17.s }, p0/Z, [x23, #2, MUL VL]\n"
+ "lsl z20.s, z20.s, #0x10\n"
+ "fadd z30.s, z30.s, z16.s\n"
+ "ld1h { z16.s }, p0/Z, [x22, #2, MUL VL]\n"
+ "lsl z19.s, z19.s, #0x10\n"
+ "lsl z18.s, z18.s, #0x10\n"
+ "fadd z29.s, z29.s, z22.s\n"
+ "lsl z17.s, z17.s, #0x10\n"
+ "fadd z28.s, z28.s, z21.s\n"
+ "lsl z16.s, z16.s, #0x10\n"
+ "fadd z27.s, z27.s, z20.s\n"
+ "fmin z30.s, p3/M, z30.s, z12.s\n"
+ "fadd z26.s, z26.s, z19.s\n"
+ "fadd z25.s, z25.s, z18.s\n"
+ "fmin z29.s, p3/M, z29.s, z12.s\n"
+ "fadd z24.s, z24.s, z17.s\n"
+ "fmin z28.s, p3/M, z28.s, z12.s\n"
+ "fadd z23.s, z23.s, z16.s\n"
+ "fmin z27.s, p3/M, z27.s, z12.s\n"
+ "fmax z30.s, p3/M, z30.s, z11.s\n"
+ "fmin z26.s, p3/M, z26.s, z12.s\n"
+ "fmin z25.s, p3/M, z25.s, z12.s\n"
+ "fmax z29.s, p3/M, z29.s, z11.s\n"
+ "fmin z24.s, p3/M, z24.s, z12.s\n"
+ "fmin z23.s, p3/M, z23.s, z12.s\n"
+ "fmax z28.s, p3/M, z28.s, z11.s\n"
+ "fmax z27.s, p3/M, z27.s, z11.s\n"
+ ".inst 0x658aafd0 // bfcvt z16.h, p3/M, z30.s\n"
+ "fmax z26.s, p3/M, z26.s, z11.s\n"
+ "fmax z25.s, p3/M, z25.s, z11.s\n"
+ ".inst 0x658aafb1 // bfcvt z17.h, p3/M, z29.s\n"
+ "fmax z24.s, p3/M, z24.s, z11.s\n"
+ "fmax z23.s, p3/M, z23.s, z11.s\n"
+ "st1h { z16.s }, p0, [x10, #2, MUL VL]\n"
+ ".inst 0x658aaf90 // bfcvt z16.h, p3/M, z28.s\n"
+ ".inst 0x658aaf74 // bfcvt z20.h, p3/M, z27.s\n"
+ "inch x10, ALL, MUL #3\n"
+ "st1h { z17.s }, p0, [x28, #2, MUL VL]\n"
+ "inch x28, ALL, MUL #3\n"
+ ".inst 0x658aaf53 // bfcvt z19.h, p3/M, z26.s\n"
+ ".inst 0x658aaf32 // bfcvt z18.h, p3/M, z25.s\n"
+ "st1h { z16.s }, p0, [x27, #2, MUL VL]\n"
+ "inch x27, ALL, MUL #3\n"
+ ".inst 0x658aaf11 // bfcvt z17.h, p3/M, z24.s\n"
+ ".inst 0x658aaef0 // bfcvt z16.h, p3/M, z23.s\n"
+ "st1h { z20.s }, p0, [x26, #2, MUL VL]\n"
+ "inch x26, ALL, MUL #3\n"
+ "st1h { z19.s }, p0, [x25, #2, MUL VL]\n"
+ "inch x25, ALL, MUL #3\n"
+ "st1h { z18.s }, p0, [x24, #2, MUL VL]\n"
+ "inch x24, ALL, MUL #3\n"
+ "st1h { z17.s }, p0, [x23, #2, MUL VL]\n"
+ "inch x23, ALL, MUL #3\n"
+ "st1h { z16.s }, p0, [x22, #2, MUL VL]\n"
+ "inch x22, ALL, MUL #3\n"
+ "bgt 51b\n"
+ "subs %x[rows], %x[rows], #0x8\n"
+ "add %x[out_ptr], %x[out_ptr], x12\n"
+ "bgt 35b\n"
+ "52:" // Exit
+ : [in_ptr] "+&r" (in_ptr), [out_ptr] "+&r" (out_ptr), [rows] "+&r" (rows)
+ : [accumulate] "r" (accumulate), [bias] "r" (bias), [cols] "r" (cols), [ldout] "r" (ldout), [maxval] "r" (maxval), [minval] "r" (minval)
+ : "cc", "memory", "p0", "p1", "p2", "p3", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+ );
+}
+
+#endif // ARM_COMPUTE_ENABLE_SVE
diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
index ce727032e..d35825c42 100644
--- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -184,9 +184,11 @@ public:
col_sums_pretransposed(B, ldb, B_multi_stride);
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+ assert(!transposed);
+
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
- _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride);
+ _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride, transposed);
requantize_bias(buffer, B, ldb, B_multi_stride);
}
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index 111d01ed3..6da9f4be0 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -1142,6 +1142,64 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
+void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height,
+ const int32_t* in_ptr, unsigned int in_stride, float *out_ptr, unsigned int out_stride,
+ const float* bias_ptr, bool accumulate, const Activation &act)
+{
+ const float32x4_t vscale = vdupq_n_f32(qp.scale);
+ float maxval = std::numeric_limits<float>::infinity();
+ float minval = -std::numeric_limits<float>::infinity();
+
+ switch(act.type) {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ maxval = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ minval = 0;
+ break;
+ }
+
+ const float32x4_t vmin = vdupq_n_f32(minval);
+ const float32x4_t vmax = vdupq_n_f32(maxval);
+
+ for(unsigned int row=0; row<height; row++) {
+ auto row_in_ptr = in_ptr + (row * in_stride);
+ auto row_out_ptr = out_ptr + (row * out_stride);
+ unsigned int col=0;
+ if (width >= 4) {
+ for(; col <= (width - 4); col+= 4) {
+ const int32x4_t vin = vld1q_s32(row_in_ptr + col);
+ float32x4_t vdeq = vmulq_f32(vcvtq_f32_s32(vin), vscale);
+ if(bias_ptr) {
+ const float32x4_t bin = vld1q_f32(bias_ptr + col);
+ vdeq = vaddq_f32(vdeq, bin);
+ }
+ if(accumulate) {
+ vdeq = vaddq_f32(vdeq, vld1q_f32(row_out_ptr + col));
+ }
+ vdeq = vminq_f32(vmaxq_f32(vdeq, vmin), vmax);
+ vst1q_f32(reinterpret_cast<float *>(row_out_ptr + col), vdeq);
+ }
+ }
+ // left-over elements
+ for(; col < width; ++col) {
+ const int32_t val = *(row_in_ptr + col);
+ float res = static_cast<float>(val * qp.scale);
+ if(bias_ptr) {
+ res += static_cast<float>(*(bias_ptr + col));
+ }
+ if(accumulate) {
+ res += *(row_out_ptr + col);
+ }
+ res = std::min(std::max(res, minval), maxval);
+ *(row_out_ptr + col) = res;
+ }
+ }
+}
+
} // namespace arm_gemm
#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.hpp b/src/core/NEON/kernels/arm_gemm/quantized.hpp
index 31dd65b39..bc64fd967 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019, 2023 Arm Limited.
+ * Copyright (c) 2019, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,4 +45,8 @@ template<typename T>
void row_sums_indirect(size_t num_strings, const unsigned int *string_lengths, IndirectInputArg<T> A_arg,
size_t M, int32_t *output_ptr, const Requantize32 *qp);
+void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height,
+ const int32_t* input, unsigned int in_stride, float *output, unsigned int out_stride,
+ const float *row_bias, bool not_first_pass, const Activation &act);
+
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp
index 4669be999..a9cbf4ec8 100644
--- a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 Arm Limited.
+ * Copyright (c) 2018-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -63,9 +63,14 @@ public:
ConvolutionInterleave<height, block, VLType::None>(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
}
+ bool PrepareB_supports_transpose() const {
+ return false;
+ }
+
template<typename TIn>
void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0,
- const int xmax, const int k0, const int kmax) const {
+ const int xmax, const int k0, const int kmax, bool transposed) const {
+ assert(!transposed);
Transform<width, block, true>(out, in, stride, x0, xmax, k0, kmax);
}
diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp
new file mode 100644
index 000000000..1db716455
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2018-2020, 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#include "convolver.hpp"
+#include "mergeresults.hpp"
+#include "transform.hpp"
+#include "interleave_indirect.hpp"
+
+namespace arm_gemm {
+
+/*
+ * Define "standard" transforms for the blocked GEMMs with fixed vector
+ * length. This version supports accepting the RHS/B matrix in transposed
+ * format.
+ *
+ * This assumes that A is interleaved 'height' ways, B is interleaved
+ * 'width' ways and transposed, and that the merge needs to work in 'height'
+ * x 'width' blocks.
+ *
+ * The optional 'block' parameter is for kernels using dot-product type
+ * instructions like UDOT and SDOT.
+ */
+template<typename TOperand, typename TResult, unsigned int height, unsigned int width, unsigned int block=1, bool integrate_sums=false>
+class StdTransformsFixedTRB
+{
+public:
+ template<typename TIn>
+ void PrepareA(TOperand *out, const TIn *in, const int stride, const int y0,
+ const int ymax, const int k0, const int kmax, int32_t row_sum_multiplier) const {
+ Interleave<height, block, VLType::None>(out, in, stride, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
+ }
+
+ template<typename TIn>
+ void PrepareA_indirect(TOperand *out, const TIn * const * const *ptr, size_t stringlen, size_t rounded_stringlen, const int y0,
+ const int ymax, const int k0, const int kmax, int32_t row_sum_multiplier) {
+ IndirectInterleave<height, block, VLType::None>(out, ptr, stringlen, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
+ }
+
+ template<typename TIn>
+ void PrepareA_convolution(TOperand *out, const TIn *ptr, size_t stride, const convolver<TIn> &conv, size_t rounded_stringlen,
+ const int y0, const int ymax, const int k0, const int kmax, int32_t row_sum_multiplier) {
+ ConvolutionInterleave<height, block, VLType::None>(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
+ }
+
+ bool PrepareB_supports_transpose() const {
+ return true;
+ }
+
+ template<typename TIn>
+ void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0,
+ const int xmax, const int k0, const int kmax, bool transposed) const {
+ if (transposed) {
+ Transform<width, block, false>(out, in, stride, x0, xmax, k0, kmax);
+ } else {
+ Transform<width, block, true>(out, in, stride, x0, xmax, k0, kmax);
+ }
+ }
+
+ template<typename TOut>
+ void Merge(TOut *out, const TResult *in, int stride, int y0, int ymax, int x0, int xmax, const TOut *bias, const Activation act, bool append) const {
+ MergeResults<width, height>(out, in, stride, y0, ymax, x0, xmax, bias, act, append);
+ }
+};
+
+} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp
index afe24e7ce..40f61626a 100644
--- a/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp
+++ b/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -60,9 +60,14 @@ public:
ConvolutionInterleave<height_vectors, block, VLType::SME>(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
}
+ bool PrepareB_supports_transpose() const {
+ return false;
+ }
+
template<typename TIn>
void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0,
- const int xmax, const int k0, const int kmax) {
+ const int xmax, const int k0, const int kmax, bool transposed) {
+ assert (!transposed);
Transform<width_vectors, block, true, VLType::SME>(out, in, stride, x0, xmax, k0, kmax);
}
diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp
index 3256d919e..c516bfc45 100644
--- a/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp
+++ b/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 Arm Limited.
+ * Copyright (c) 2017-2018,2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -61,9 +61,14 @@ public:
ConvolutionInterleave<height, block, VLType::None>(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
}
+ bool PrepareB_supports_transpose() const {
+ return false;
+ }
+
template<typename TIn>
void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0,
- const int xmax, const int k0, const int kmax) {
+ const int xmax, const int k0, const int kmax, bool transposed) {
+ assert (!transposed);
Transform<width_vectors, block, true, VLType::SVE>(out, in, stride, x0, xmax, k0, kmax);
}
diff --git a/src/core/NEON/kernels/arm_gemm/transform.cpp b/src/core/NEON/kernels/arm_gemm/transform.cpp
index 5aa62f0fe..45e4f0e1d 100644
--- a/src/core/NEON/kernels/arm_gemm/transform.cpp
+++ b/src/core/NEON/kernels/arm_gemm/transform.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -134,7 +134,14 @@ template void Transform<8, 1, true, VLType::None>(float *, const __fp16 *, int,
#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
#ifdef ARM_COMPUTE_ENABLE_BF16
template void Transform<8, 1, true, VLType::None>(float *, const bfloat16 *, int, int, int, int, int);
-#endif
+#endif // ARM_COMPUTE_ENABLE_BF16
#endif // AArch32
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+template void Transform<12, 1, false, VLType::None>(float *, const __fp16 *, int, int, int, int, int);
+#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#ifdef ARM_COMPUTE_ENABLE_BF16
+template void Transform<12, 1, false, VLType::None>(float *, const bfloat16 *, int, int, int, int, int);
+#endif // ARM_COMPUTE_ENABLE_BF16
+
} // namespace arm_gemm
diff --git a/src/core/common/Registrars.h b/src/core/common/Registrars.h
index 50b3fc128..a74316b48 100644
--- a/src/core/common/Registrars.h
+++ b/src/core/common/Registrars.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2023 Arm Limited.
+ * Copyright (c) 2020-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,6 +38,12 @@
#define REGISTER_FP16_SVE2(func_name) nullptr
#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
+#if defined(ARM_COMPUTE_ENABLE_SME2)
+#define REGISTER_FP16_SME2(func_name) &(func_name)
+#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */
+#define REGISTER_FP16_SME2(func_name) nullptr
+#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */
+
#if defined(ARM_COMPUTE_ENABLE_NEON)
#define REGISTER_FP16_NEON(func_name) &(func_name)
#else /* !defined(ARM_COMPUTE_ENABLE_NEON) */
@@ -48,6 +54,7 @@
#define REGISTER_FP16_NEON(func_name) nullptr
#define REGISTER_FP16_SVE(func_name) nullptr
#define REGISTER_FP16_SVE2(func_name) nullptr
+#define REGISTER_FP16_SME2(func_name) nullptr
#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
#if defined(ENABLE_FP32_KERNELS)
@@ -64,6 +71,12 @@
#define REGISTER_FP32_SVE2(func_name) nullptr
#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
+#if defined(ARM_COMPUTE_ENABLE_SME2)
+#define REGISTER_FP32_SME2(func_name) &(func_name)
+#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */
+#define REGISTER_FP32_SME2(func_name) nullptr
+#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */
+
#if defined(ARM_COMPUTE_ENABLE_NEON)
#define REGISTER_FP32_NEON(func_name) &(func_name)
#else /* !defined(ARM_COMPUTE_ENABLE_NEON) */
@@ -74,6 +87,7 @@
#define REGISTER_FP32_NEON(func_name) nullptr
#define REGISTER_FP32_SVE(func_name) nullptr
#define REGISTER_FP32_SVE2(func_name) nullptr
+#define REGISTER_FP32_SME2(func_name) nullptr
#endif /* defined(ENABLE_FP32_KERNELS) */
#if defined(ENABLE_QASYMM8_SIGNED_KERNELS)
diff --git a/src/core/utils/helpers/tensor_transform.cpp b/src/core/utils/helpers/tensor_transform.cpp
index 19d0badd7..212cfdaba 100644
--- a/src/core/utils/helpers/tensor_transform.cpp
+++ b/src/core/utils/helpers/tensor_transform.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -117,7 +117,10 @@ int calculate_end_on_index(TensorShape input_shape,
}
// Final clamp
- stop = (stride > 0) ? utility::clamp(stop, 0, dim_size) : utility::clamp(stop, -1, dim_size - 1);
+ if (stride > 0)
+ stop = utility::clamp(stop, 0, dim_size);
+ else
+ stop = utility::clamp(stop, -1, dim_size - 1);
return stop;
}
diff --git a/src/core/utils/quantization/AsymmHelpers.cpp b/src/core/utils/quantization/AsymmHelpers.cpp
index f66d3e706..f8b74a985 100644
--- a/src/core/utils/quantization/AsymmHelpers.cpp
+++ b/src/core/utils/quantization/AsymmHelpers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -122,13 +122,13 @@ arm_compute::Status calculate_quantized_multipliers(const QuantizationInfo &iq_
ARM_COMPUTE_RETURN_ERROR_ON(iq_info.scale().empty());
ARM_COMPUTE_RETURN_ERROR_ON(wq_info.scale().empty());
ARM_COMPUTE_RETURN_ERROR_ON(oq_info.scale().empty());
-
- const unsigned int size = wq_info.scale().size();
-
- auto &quant_multipliers = stage_info.gemmlowp_multipliers;
- auto &quant_shifts = stage_info.gemmlowp_shifts;
- quant_multipliers.resize(size);
- quant_shifts.resize(size);
+ constexpr unsigned int padding_elems = 32; // assembly kernels assume the shifts and multipliers buffers are padded
+ const unsigned int size = wq_info.scale().size();
+ const size_t padded_size = (size == 1) ? 1 : size + padding_elems;
+ auto &quant_multipliers = stage_info.gemmlowp_multipliers;
+ auto &quant_shifts = stage_info.gemmlowp_shifts;
+ quant_multipliers.resize(padded_size);
+ quant_shifts.resize(padded_size);
const auto &w_scales = wq_info.scale();
const float i_scale = iq_info.scale().at(0);
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp
index e29078302..2a76a5958 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,17 +51,19 @@ Status validate_arguments(const ITensorInfo *mm_result,
int32_t a_offset,
int32_t b_offset)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32, DataType::F32);
- // If a_offset == 0, vector_sum_col can be a nullptr
- if (a_offset != 0)
+ // We run if the offset is nonzero or a sum col has been provided, we need
+ // the second option in case the QuantizationInfo is dynamic
+ if (a_offset != 0 || vector_sum_col != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != mm_result->dimension(0));
}
- // If b_offset == 0, vector_sum_row can be a nullptr
- if (b_offset != 0)
+ // We run if the offset is nonzero or a sum row has been provided, we need
+ // the second option in case the QuantizationInfo is dynamic
+ if (b_offset != 0 || vector_sum_row != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);
@@ -86,7 +88,7 @@ Status validate_arguments(const ITensorInfo *mm_result,
ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[output_batch_idx],
"mm_result tensor must have the same number of batches of output tensor");
- if (a_offset != 0)
+ if (vector_sum_col != nullptr)
{
TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape();
vector_sum_col_shape.collapse_from(1);
@@ -102,6 +104,275 @@ Status validate_arguments(const ITensorInfo *mm_result,
return Status{};
}
+void run_offset_contribution_float(const Window &window,
+ ITensor *mm_result,
+ const ITensor *vector_sum_col,
+ const ITensor *vector_sum_row,
+ int32_t a_offset,
+ int32_t b_offset,
+ int32_t k_offset,
+ float scale,
+ bool slide_vector_sum_col,
+ bool is_gemm3d)
+{
+ Window collapsed_window = window.collapse_if_possible(window, Window::DimZ);
+ collapsed_window.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ const int height_input = is_gemm3d ? mm_result->info()->dimension(1) : 0;
+ const int depth_input = is_gemm3d ? mm_result->info()->dimension(2) : 1;
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16;
+
+ // if vector_sum_col is nullptr then stride_y is 0, else get stride_y
+ const size_t sum_col_stride_y = (vector_sum_col != nullptr) ? (vector_sum_col->info()->strides_in_bytes().y()) : 0;
+ Iterator mm_result_it(mm_result, collapsed_window);
+
+ if ((a_offset != 0) && (b_offset != 0) && (vector_sum_col != nullptr) && (vector_sum_row != nullptr)) // true, true
+ {
+ // Set window for vector_sum_col
+ Window win_vector_sum_col(collapsed_window);
+ win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ // Set window for vector_sum_row
+ Window win_vector_sum_row(collapsed_window);
+ win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
+ Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);
+
+ const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
+
+ // Offset in case vector_sum_col is batched
+ const int vector_sum_col_batch_offset =
+ slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
+
+ execute_window_loop(
+ collapsed_window,
+ [&](const Coordinates &id)
+ {
+ const int batch_id = id.z() / depth_input;
+ const size_t batch_offset_col = batch_id * (sum_col_stride_y);
+ auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col +
+ batch_id * vector_sum_col_batch_offset);
+ auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr());
+
+ // Compute the leftover term due to b_offset.
+ int32_t b_offset_term_s32 =
+ *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
+ id.y() + (id.z() % depth_input) * height_input);
+ b_offset_term_s32 *= b_offset;
+
+ const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Compute the leftover term due to a_offset.
+ int32x4x4_t a_offset_term_s32 = {
+ {vld1q_s32(vector_sum_col_ptr + x + 0), vld1q_s32(vector_sum_col_ptr + x + 4),
+ vld1q_s32(vector_sum_col_ptr + x + 8), vld1q_s32(vector_sum_col_ptr + x + 12)}};
+
+ a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
+ a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
+ a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
+ a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);
+
+ // Add a_offset_term_s32 and b_offset_term_s32
+ int32x4x4_t offset_term_s32 = {
+ {vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset)}};
+
+ offset_term_s32.val[0] =
+ vaddq_s32(offset_term_s32.val[0], vaddq_s32(a_offset_term_s32.val[0], b_offset_term_s32_vec));
+ offset_term_s32.val[1] =
+ vaddq_s32(offset_term_s32.val[1], vaddq_s32(a_offset_term_s32.val[1], b_offset_term_s32_vec));
+ offset_term_s32.val[2] =
+ vaddq_s32(offset_term_s32.val[2], vaddq_s32(a_offset_term_s32.val[2], b_offset_term_s32_vec));
+ offset_term_s32.val[3] =
+ vaddq_s32(offset_term_s32.val[3], vaddq_s32(a_offset_term_s32.val[3], b_offset_term_s32_vec));
+
+ float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4),
+ vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}};
+
+ // Convert and scale the S32 offsets to match the already scaled GEMM results
+ float32x4x4_t offset_terms_scaled = {{
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[0]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[1]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[2]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[3]), scale),
+ }};
+
+ // Add the offset terms to the GEMM result
+ in_f32.val[0] = vaddq_f32(in_f32.val[0], offset_terms_scaled.val[0]);
+ in_f32.val[1] = vaddq_f32(in_f32.val[1], offset_terms_scaled.val[1]);
+ in_f32.val[2] = vaddq_f32(in_f32.val[2], offset_terms_scaled.val[2]);
+ in_f32.val[3] = vaddq_f32(in_f32.val[3], offset_terms_scaled.val[3]);
+
+ // Store the result with the offset contribution
+ vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]);
+ vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]);
+ vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]);
+ vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Compute the leftover term due to a_offset.
+ int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);
+
+ a_offset_term_s32 *= a_offset;
+
+ // Add the offset terms to GEMM's result
+ // Store the result with the offset contribution
+ mm_result_ptr[x] += (k_offset + a_offset_term_s32 + b_offset_term_s32) * scale;
+ }
+ },
+ vector_sum_col_it, vector_sum_row_it, mm_result_it);
+ }
+ else if ((a_offset == 0) && (b_offset != 0) && (vector_sum_row != nullptr)) // false, true
+ {
+ ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_row);
+
+ // Set window for vector_sum_row
+ Window win_vector_sum_row(collapsed_window);
+ win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);
+
+ const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
+
+ execute_window_loop(
+ collapsed_window,
+ [&](const Coordinates &id)
+ {
+ const int batch_id = id.z() / depth_input;
+ auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr());
+
+ // Compute the leftover term due to b_offset.
+ int32_t row_sum =
+ *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
+ id.y() + (id.z() % depth_input) * height_input);
+ float scaled_b_offset_term_f32 = row_sum * b_offset * scale;
+
+ const float32x4_t b_offset_term_f32_vec = vdupq_n_f32(scaled_b_offset_term_f32);
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4),
+ vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}};
+
+ // Add the offset terms to GEMM's result
+ in_f32.val[0] = vaddq_f32(in_f32.val[0], b_offset_term_f32_vec);
+ in_f32.val[1] = vaddq_f32(in_f32.val[1], b_offset_term_f32_vec);
+ in_f32.val[2] = vaddq_f32(in_f32.val[2], b_offset_term_f32_vec);
+ in_f32.val[3] = vaddq_f32(in_f32.val[3], b_offset_term_f32_vec);
+
+ // Store the result with the offset contribution
+ vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]);
+ vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]);
+ vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]);
+ vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Add the offset terms to GEMM's result
+ // Store the result with the offset contribution
+ mm_result_ptr[x] += scaled_b_offset_term_f32;
+ }
+ },
+ vector_sum_row_it, mm_result_it);
+ }
+ else if ((a_offset != 0) && (b_offset == 0) && (vector_sum_col != nullptr)) // true, false
+ {
+ // Set window for vector_sum_col
+ Window win_vector_sum_col(collapsed_window);
+ win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
+
+ // Offset in case vector_sum_col is batched
+ const int vector_sum_col_batch_offset =
+ slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
+
+ execute_window_loop(
+ collapsed_window,
+ [&](const Coordinates &id)
+ {
+ const int batch_id = id.z() / depth_input;
+ const size_t batch_offset_col =
+ batch_id *
+ (sum_col_stride_y); // Value to offset vector_sum_col_ptr to allow for iteration of y values in tensor
+ auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col +
+ batch_id * vector_sum_col_batch_offset);
+ auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Compute the leftover term due to a_offset.
+ int32x4x4_t a_offset_term_s32 = {
+ {vld1q_s32(vector_sum_col_ptr + x + 0), vld1q_s32(vector_sum_col_ptr + x + 4),
+ vld1q_s32(vector_sum_col_ptr + x + 8), vld1q_s32(vector_sum_col_ptr + x + 12)}};
+
+ a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
+ a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
+ a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
+ a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);
+
+ float32x4x4_t a_offset_term_scaled = {{
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[0]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[1]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[2]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[3]), scale),
+ }};
+
+ float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4),
+ vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}};
+
+ // Add the offset terms to GEMM's result
+ in_f32.val[0] = vaddq_f32(in_f32.val[0], a_offset_term_scaled.val[0]);
+ in_f32.val[1] = vaddq_f32(in_f32.val[1], a_offset_term_scaled.val[1]);
+ in_f32.val[2] = vaddq_f32(in_f32.val[2], a_offset_term_scaled.val[2]);
+ in_f32.val[3] = vaddq_f32(in_f32.val[3], a_offset_term_scaled.val[3]);
+
+ // Store the result with the offset contribution
+ vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]);
+ vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]);
+ vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]);
+ vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Compute the leftover term due to a_offset.
+ const int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);
+
+ // Add the offset terms to GEMM's result
+ // Store the result with the offset contribution
+ mm_result_ptr[x] += a_offset_term_s32 * a_offset * scale;
+ }
+ },
+ vector_sum_col_it, mm_result_it);
+ }
+ else // false, false
+ {
+ // No offset contribution from matrix A and matrix B
+ return;
+ }
+}
+
void run_offset_contribution(const Window &window,
ITensor *mm_result,
const ITensor *vector_sum_col,
@@ -361,7 +632,8 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result,
ITensorInfo *vector_sum_row,
int32_t k,
int32_t a_offset,
- int32_t b_offset)
+ int32_t b_offset,
+ float scale)
{
// Perform validate step
ARM_COMPUTE_UNUSED(vector_sum_row);
@@ -370,10 +642,11 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result,
_a_offset = a_offset;
_b_offset = b_offset;
- _k_offset = a_offset * b_offset * k;
+ _k = k;
- // If a_offset == 0, vector_sum_col can be a nullptr
- if (a_offset != 0)
+ _scale = scale;
+
+ if (vector_sum_col != nullptr)
{
// Check if vector_sum_col_shape should be slidden or not
// Don't slide vector_sum_col_shape along the y dimension if vector_sum_col_shape has just 1 dimension and vector_sum_row_shape more than 1
@@ -386,6 +659,21 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result,
ICpuKernel::configure(win);
}
+void CpuGemmLowpOffsetContributionKernel::set_a_offset(int32_t a_offset)
+{
+ _a_offset = a_offset;
+}
+
+void CpuGemmLowpOffsetContributionKernel::set_b_offset(int32_t b_offset)
+{
+ _b_offset = b_offset;
+}
+
+void CpuGemmLowpOffsetContributionKernel::set_scale(float scale)
+{
+ _scale = scale;
+}
+
Status CpuGemmLowpOffsetContributionKernel::validate(const ITensorInfo *mm_result,
const ITensorInfo *vector_sum_col,
const ITensorInfo *vector_sum_row,
@@ -410,8 +698,18 @@ void CpuGemmLowpOffsetContributionKernel::run_op(ITensorPack &tensors, const Win
const bool reinterpret_as_3d = vector_sum_row != nullptr && mm_result->info()->num_dimensions() > 1 &&
mm_result->info()->tensor_shape().y() != vector_sum_row->info()->tensor_shape().x();
- run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, _k_offset,
- _slide_vector_sum_col, reinterpret_as_3d);
+ // check to see what is the output type of result
+ auto k_offset = _a_offset * _b_offset * _k;
+ if (mm_result->info()->data_type() == DataType::F32)
+ {
+ run_offset_contribution_float(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, k_offset,
+ _scale, _slide_vector_sum_col, reinterpret_as_3d);
+ }
+ else
+ {
+ run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, k_offset,
+ _slide_vector_sum_col, reinterpret_as_3d);
+ }
}
const char *CpuGemmLowpOffsetContributionKernel::name() const
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h
index 08b2d4752..ecbfb0c28 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,12 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H
-#define ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
+#include <cstdint>
+
namespace arm_compute
{
namespace cpu
@@ -62,13 +64,16 @@ public:
* @param[in] k Number of matrix A columns or Matrix B rows
* @param[in] a_offset Offset to be added to each element of the matrix A.
* @param[in] b_offset Offset to be added to each element of the matrix B.
+ * @param[in] scale (Optional) multiplies the contribution to make it the same scale as the dst in the case where mm_result is float
+ * (and so has already been scaled). Default is 1.0
*/
void configure(ITensorInfo *mm_result,
ITensorInfo *vector_sum_col,
ITensorInfo *vector_sum_row,
int32_t k,
int32_t a_offset,
- int32_t b_offset);
+ int32_t b_offset,
+ float scale = 1.0f);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuGemmLowpOffsetContributionKernel::configure()
@@ -81,6 +86,29 @@ public:
int32_t a_offset,
int32_t b_offset);
+ /** Set the a offset
+ * Warning: if a_offset is non-zero then vector_sum_col must be set in run_op.
+ * Run configure or validate again if you aren't sure
+ *
+ * @param[in] a_offset Offset to be added to each element of the matrix A.
+ */
+ void set_a_offset(int32_t a_offset);
+
+ /** Set the b offset
+ * Warning: if b_offset is non-zero then vector_sum_row must be set in run_op.
+ * Run configure or validate again if you aren't sure
+ *
+ * @param[in] b_offset Offset to be added to each element of the matrix B.
+ */
+ void set_b_offset(int32_t b_offset);
+
+ /** Set the dequantize scale
+ *
+ * @param[in] scale Multiplies the contribution to make it the same scale as the dst in the case where
+ * mm_result is float (and so has already been scaled).
+ */
+ void set_scale(float scale);
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
@@ -88,10 +116,11 @@ public:
private:
int32_t _a_offset{0};
int32_t _b_offset{0};
- int32_t _k_offset{0};
+ int32_t _k{0}; // Number of columns of A or rows of B, used in last offset term
+ float _scale{1.0};
bool _slide_vector_sum_col{true};
};
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
index d00884239..3c113f282 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021, 2023 Arm Limited.
+ * Copyright (c) 2019-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -919,7 +919,7 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::configure(const ITensorInfo
_a_offset = a_offset;
_b_offset = b_offset;
- _k_offset = a_offset * b_offset * k;
+ _k = k;
_output_stage = output_stage;
// If a_offset == 0, vector_sum_col can be a nullptr
@@ -958,6 +958,16 @@ Status CpuGemmLowpOffsetContributionOutputStageKernel::validate(const ITensorInf
return Status{};
}
+void CpuGemmLowpOffsetContributionOutputStageKernel::set_a_offset(int32_t a_offset)
+{
+ _a_offset = a_offset;
+}
+
+void CpuGemmLowpOffsetContributionOutputStageKernel::set_b_offset(int32_t b_offset)
+{
+ _b_offset = b_offset;
+}
+
void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &tensors,
const Window &window,
const ThreadInfo &info)
@@ -993,10 +1003,11 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &te
// Check if symmetric per-channel execution
const bool is_symm = _output_stage.is_quantized_per_channel;
+ auto k_offset = _a_offset * _b_offset * _k;
if (is_symm)
{
run_offset_contribution_output_stage_symm(window, mm_result, vector_sum_col, vector_sum_row, bias, dst,
- _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched,
+ _a_offset, _b_offset, k_offset, _is_vector_sum_col_batched,
_output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
}
else
@@ -1004,13 +1015,13 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &te
if (is_signed)
{
run_offset_contribution_output_stage<int8_t>(
- window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset,
+ window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, k_offset,
_is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
}
else
{
run_offset_contribution_output_stage<uint8_t>(
- window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset,
+ window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, k_offset,
_is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
}
}
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h
index af477d475..ff706ff3d 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H
-#define ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H
#include "arm_compute/core/KernelDescriptors.h"
@@ -110,6 +110,22 @@ public:
int32_t b_offset,
GEMMLowpOutputStageInfo output_stage);
+ /** Set the a offset
+ * Warning: if a_offset is non-zero then vector_sum_col must be set in run_op.
+ * Run configure or validate again if you aren't sure
+ *
+ * @param[in] a_offset Offset to be added to each element of the matrix A.
+ */
+ void set_a_offset(int32_t a_offset);
+
+ /** Set the b offset
+ * Warning: if b_offset is non-zero then vector_sum_col must be set in run_op.
+ * Run configure or validate again if you aren't sure
+ *
+ * @param[in] b_offset Offset to be added to each element of the matrix B.
+ */
+ void set_b_offset(int32_t b_offset);
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
@@ -118,11 +134,11 @@ private:
/** Function to use for the particular tensors passed to configure() */
int32_t _a_offset{0};
int32_t _b_offset{0};
- int32_t _k_offset{0};
+ int32_t _k{0}; // Number of columns of A or rows of B, used in last offset term
bool _is_vector_sum_col_batched{true};
GEMMLowpOutputStageInfo _output_stage{GEMMLowpOutputStageInfo()};
};
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H
diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h
index 45ebeec39..d71789cc3 100644
--- a/src/cpu/kernels/CpuKernelSelectionTypes.h
+++ b/src/cpu/kernels/CpuKernelSelectionTypes.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -104,6 +104,7 @@ struct SoftmaxKernelDataTypeISASelectorData
DataType dt;
cpuinfo::CpuIsaInfo isa;
bool is_log;
+ int axis;
};
// Selector pointer types
diff --git a/src/cpu/kernels/CpuQuantizeKernel.cpp b/src/cpu/kernels/CpuQuantizeKernel.cpp
index 5dde68083..d2ac6cf8a 100644
--- a/src/cpu/kernels/CpuQuantizeKernel.cpp
+++ b/src/cpu/kernels/CpuQuantizeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -104,6 +104,18 @@ vector_type<int8_t> vquantize_qasymm8<int8_t>(const float32x4x4_t &qv, const Uni
return vquantize_signed(qv, qi);
}
+template <typename TOut, typename = typename std::enable_if<std::is_signed<TOut>::value, bool>::type>
+inline int8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper)
+{
+ return wrapper::vcombine(wrapper::vqmovn(lower), wrapper::vqmovn(upper));
+}
+
+template <typename TOut, typename = typename std::enable_if<std::is_unsigned<TOut>::value, bool>::type>
+inline uint8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper)
+{
+ return wrapper::vcombine(wrapper::vqmovun(lower), wrapper::vqmovun(upper));
+}
+
} // namespace
void CpuQuantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
@@ -120,6 +132,19 @@ void CpuQuantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
{"op_QASYMM8_SIGNED_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<int8_t, int8_t>},
{"op_QASYMM8_SIGNED_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<int8_t>},
+ // Functions for offset only requantization
+ {"op_OFFSET_ONLY_QASYMM8_QASYMM8", &CpuQuantizeKernel::run_requantize_offset_only<uint8_t, uint8_t>},
+ {"op_OFFSET_ONLY_QASYMM8_QASYMM8_SIGNED", &CpuQuantizeKernel::run_requantize_offset_only<uint8_t, int8_t>},
+ {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8", &CpuQuantizeKernel::run_requantize_offset_only<int8_t, uint8_t>},
+ {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8_SIGNED",
+ &CpuQuantizeKernel::run_requantize_offset_only<int8_t, int8_t>},
+
+ // Functions for offset uint8 to int8 and vice versa quantization (no scale changes)
+ {"op_OFFSET_ONLY_CONVERT_QASYMM8_SIGNED_QASYMM8",
+ &CpuQuantizeKernel::run_requantize_offset_only_convert<int8_t, uint8_t>},
+ {"op_OFFSET_ONLY_CONVERT_QASYMM8_QASYMM8_SIGNED",
+ &CpuQuantizeKernel::run_requantize_offset_only_convert<uint8_t, int8_t>},
+
{"op_F32_QSYMM8", &CpuQuantizeKernel::run_quantize_qsymm8<float, int8_t>},
{"op_F32_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<float, uint8_t>},
@@ -134,6 +159,26 @@ void CpuQuantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
};
std::string function_to_call("op_");
+
+ // For offset only functions - must be 8-bit and have identical scale values.
+ if (src->quantization_info().scale() == dst->quantization_info().scale() &&
+ (is_data_type_quantized_asymmetric_char(src->data_type()) &&
+ is_data_type_quantized_asymmetric_char(dst->data_type())))
+ {
+ function_to_call += "OFFSET_ONLY_";
+ // For optimized datatype conversion 8-bit re-quantization offset only functions.
+ // These must have an offset of exactly 128 to match requirements - has specific circumstances to match use case.
+ auto uqinfo =
+ compute_requantization_scale_offset(src->quantization_info().uniform(), dst->quantization_info().uniform());
+ const auto src_dt = src->data_type();
+ if (src->data_type() != dst->data_type() && ((src_dt == DataType::QASYMM8_SIGNED && uqinfo.offset == 128) ||
+ (src_dt == DataType::QASYMM8 && uqinfo.offset == -128)))
+ {
+ function_to_call += "CONVERT_";
+ }
+ }
+
+ // Specify datatype for function
function_to_call += string_from_data_type(src->data_type()) + "_";
function_to_call += string_from_data_type(dst->data_type());
@@ -145,9 +190,11 @@ void CpuQuantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
}
_func = it->second;
- // Configure kernel window
- Window win_config = calculate_max_window(*src, Steps());
- ICpuKernel::configure(win_config);
+ // Calculate window. Squash if possible.
+ Window win;
+ std::tie(win, _split_dimension) = calculate_squashed_or_max_window(*src);
+
+ ICpuKernel::configure(win);
}
Status CpuQuantizeKernel::validate(const ITensorInfo *src, const ITensorInfo *dst)
@@ -164,10 +211,8 @@ void CpuQuantizeKernel::run_quantize_qsymm8(const ITensor *src, ITensor *dst, co
const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
- if (is_data_type_quantized_asymmetric(src->info()->data_type()))
- {
- uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
- }
+ uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
+
// Collapse window and reset first dimension to handle tail calculations manually
Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
@@ -195,6 +240,114 @@ void CpuQuantizeKernel::run_quantize_qsymm8(const ITensor *src, ITensor *dst, co
}
template <typename TIn, typename TOut>
+void CpuQuantizeKernel::run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window)
+{
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ // Calculate output offset difference.
+ const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
+ UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
+ uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Duplicate offset in signed vector format
+ const int8x16_t offset = wrapper::vdup_n(static_cast<int8_t>(uqinfo.offset), wrapper::traits::vector_128_tag{});
+
+ Iterator input(src, win_collapsed);
+ Iterator output(dst, win_collapsed);
+ execute_window_loop(
+ win_collapsed,
+ [&](const Coordinates &)
+ {
+ auto input_ptr = reinterpret_cast<const TIn *>(input.ptr());
+ auto output_ptr = reinterpret_cast<TOut *>(output.ptr());
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step); x += window_step)
+ {
+ const wrapper::traits::neon_vector_t<TIn, window_step> qv =
+ wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype
+
+ // Signed addition.
+ auto res = vaddq_s8(reinterpret_cast<int8x16_t>(qv), offset);
+
+ // Output is dependent on datatype.
+ wrapper::vstore(&output_ptr[x],
+ reinterpret_cast<wrapper::traits::neon_vector_t<TOut, window_step>>(res));
+ }
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ auto result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]);
+ output_ptr[x] = static_cast<TOut>(result);
+ }
+ },
+ input, output);
+}
+
+template <typename TIn, typename TOut>
+void CpuQuantizeKernel::run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window)
+{
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
+ UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
+ uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Duplicate offset in signed vector format
+ const int16x8_t offset = wrapper::vdup_n(static_cast<int16_t>(uqinfo.offset), wrapper::traits::vector_128_tag{});
+
+ const int32_t low_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 0 : -128;
+ const int32_t upper_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 255 : 127;
+
+ Iterator input(src, win_collapsed);
+ Iterator output(dst, win_collapsed);
+ execute_window_loop(
+ win_collapsed,
+ [&](const Coordinates &)
+ {
+ auto input_ptr = reinterpret_cast<const TIn *>(input.ptr());
+ TOut *output_ptr = reinterpret_cast<TOut *>(output.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step); x += window_step)
+ {
+ const auto qv = wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype
+ int16x8_t lower = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgetlow(qv)));
+ int16x8_t upper = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgethigh(qv)));
+
+ // Signed addition.
+ lower = wrapper::vqadd(lower, offset);
+ upper = wrapper::vqadd(upper, offset);
+
+ // Output is dependent on datatype.
+ auto res = recombine_8_16<TOut>(lower, upper);
+ wrapper::vstore(&output_ptr[x], res);
+ }
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ // Add offset and clamp result to within the range of the output datatype.
+ int32_t result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]);
+ result = utility::clamp<int32_t>(result, low_bound, upper_bound);
+
+ // Cast result to output datatype.
+ output_ptr[x] = static_cast<TOut>(result);
+ }
+ },
+ input, output);
+}
+
+template <typename TIn, typename TOut>
void CpuQuantizeKernel::run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
{
const auto window_start_x = static_cast<int>(window.x().start());
@@ -302,6 +455,7 @@ const char *CpuQuantizeKernel::name() const
{
return "CpuQuantizeKernel";
}
+
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/CpuQuantizeKernel.h b/src/cpu/kernels/CpuQuantizeKernel.h
index d6714136d..c2f7ac6d9 100644
--- a/src/cpu/kernels/CpuQuantizeKernel.h
+++ b/src/cpu/kernels/CpuQuantizeKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_QUANTIZE_KERNEL_H
-#define ARM_COMPUTE_CPU_QUANTIZE_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPUQUANTIZEKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUQUANTIZEKERNEL_H
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
@@ -58,6 +58,15 @@ public:
*/
static Status validate(const ITensorInfo *src, const ITensorInfo *dst);
+ /** Get the preferred dimension in which the scheduler splits the work into multiple jobs.
+ *
+ * @return The split dimension hint.
+ */
+ size_t get_split_dimension_hint() const
+ {
+ return _split_dimension;
+ }
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
@@ -86,9 +95,17 @@ private:
template <typename TIn, typename TOut>
void run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window);
+ template <typename TIn, typename TOut>
+ void run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window);
+
+ template <typename TIn, typename TOut>
+ void run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window);
+
QuantizeFunctionExecutorPtr _func{nullptr};
+ size_t _split_dimension{Window::DimY};
};
+
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_QUANTIZE_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPUQUANTIZEKERNEL_H
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp
index 68bc397ac..5cf81f815 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -50,9 +50,17 @@ namespace
{
/* Softmax */
static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_kernels = {
+ {"sme2_fp32_softmax",
+ [](const SoftmaxKernelDataTypeISASelectorData &data)
+ { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2 && data.axis == 0); },
+ REGISTER_FP32_SME2(sme2_fp32_softmax)},
{"neon_fp32_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F32); },
REGISTER_FP32_NEON(neon_fp32_softmax<false>)},
+ {"sme2_fp16_softmax",
+ [](const SoftmaxKernelDataTypeISASelectorData &data)
+ { return (!data.is_log && data.dt == DataType::F16 && data.isa.sme2 && data.axis == 0); },
+ REGISTER_FP16_SME2(sme2_fp16_softmax)},
{"neon_fp16_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data)
{ return (!data.is_log && data.dt == DataType::F16) && data.isa.fp16; },
@@ -81,7 +89,7 @@ static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_ker
};
Status validate_arguments_softmax(
- const ITensorInfo &src, const ITensorInfo &dst, float beta, const ITensorInfo &tmp, bool is_log)
+ const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log)
{
ARM_COMPUTE_UNUSED(beta);
// Check input
@@ -89,6 +97,8 @@ Status validate_arguments_softmax(
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(axis < 0 || axis > 3);
+
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
// Check output if configured
@@ -124,10 +134,13 @@ const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> &CpuSoftmaxKernel::g
return available_kernels;
}
-void CpuSoftmaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, ITensorInfo *tmp)
+void CpuSoftmaxKernel::configure(
+ const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp)
{
+ _axis = axis;
+
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
// Configure kernel window
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
@@ -145,7 +158,7 @@ void CpuSoftmaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float
}
const auto *uk = CpuSoftmaxKernel::get_implementation(
- SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log});
+ SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log, axis});
ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
std::string kernel_name = is_log ? std::string("CpuLogSoftmaxKernel") : std::string("CpuSoftmaxKernel");
@@ -154,25 +167,40 @@ void CpuSoftmaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float
_run_method = uk->ukernel;
_name = kernel_name.append("/").append(uk->name);
- Window win = calculate_max_window(*dst, Steps());
+ Window win;
+
+ int vec_size = 16 / dst->element_size();
- /// TODO: Check dimensions > 0 for holes only. For this, we need
- /// a utility function checking if there are holes after some dimension.
- if (!has_holes(*dst, dst->num_dimensions() - 1))
+ if (_axis == 0)
+ {
+ win = calculate_max_window(*dst, Steps());
+
+ /// TODO:Check dimensions > 0 for holes only. For this, we need
+ /// a utility function checking if there are holes after some dimension.
+ if (!has_holes(*dst, dst->num_dimensions() - 1))
+ {
+ win = win.collapse(win, Window::DimY);
+ }
+ }
+ else if (_axis > 0 && _axis <= 3)
+ {
+ win = calculate_max_window(*dst, Steps(vec_size));
+ }
+ else
{
- win = win.collapse(win, Window::DimY);
+ ARM_COMPUTE_ERROR("Invalid axis");
}
- win.set(Window::DimX, Window::Dimension(0, 1, 1)); // First dimension is the reduction axis
+ win.set(_axis, Window::Dimension(0, 1, 1));
ICpuKernel<CpuSoftmaxKernel>::configure(win);
}
Status CpuSoftmaxKernel::validate(
- const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp)
+ const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
return Status{};
}
@@ -188,19 +216,25 @@ void CpuSoftmaxKernel::run_op(ITensorPack &tensors, const Window &window, const
if (is_data_type_quantized_asymmetric(src->info()->data_type()))
{
- auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
-
- const unsigned int num_elems_processed_per_iteration = src->info()->valid_region().shape.x();
+ auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
+ unsigned int num_elems_processed_per_iteration;
+ if (_axis == 0)
+ {
+ num_elems_processed_per_iteration = src->info()->valid_region().shape[_axis];
+ }
+ else
+ {
+ //16 QASYMM8/QASYMM8_SIGNED elements can fit into the 16-byte vectors.
+ num_elems_processed_per_iteration = 16;
+ }
const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
- ARM_COMPUTE_ERROR_ON(tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
-
void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
- _run_method(src, tmp_for_thread, dst, _beta, window);
+ _run_method(src, tmp_for_thread, dst, _beta, _axis, window);
}
else
{
- _run_method(src, nullptr, dst, _beta, window);
+ _run_method(src, nullptr, dst, _beta, _axis, window);
}
}
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h
index 3db1f3d0e..043ad975d 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.h
+++ b/src/cpu/kernels/CpuSoftmaxKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,7 +38,7 @@ class CpuSoftmaxKernel : public ICpuKernel<CpuSoftmaxKernel>
{
private:
using SoftmaxKernelPtr =
- std::add_pointer<void(const ITensor *, void *const, ITensor *, float, const Window &)>::type;
+ std::add_pointer<void(const ITensor *, void *const, ITensor *, float, int, const Window &)>::type;
public:
CpuSoftmaxKernel() = default;
@@ -49,11 +49,12 @@ public:
* @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[out] dst Destination tensor info. Data types supported: same as @p input.
* @param[in] beta A scaling factor for the exponent.
- * @param[in] is_log True if the operation is log-softmax
+ * @param[in] is_log True if the operation is log-softmax.
+ * @param[in] axis The axis along which to perform the softmax operation.
*
* @param tmp Auxiliary tensor info. Must be type F32 and same shape as the input.
*/
- void configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, ITensorInfo *tmp);
+ void configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuSoftmaxKernel::configure()
@@ -61,7 +62,7 @@ public:
* @return a status
*/
static Status
- validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp);
+ validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp);
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
@@ -80,6 +81,7 @@ private:
float _beta{1.0f};
SoftmaxKernelPtr _run_method{nullptr};
std::string _name{};
+ int _axis{};
};
} // namespace kernels
} // namespace cpu
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 9a913c5c5..941fed0ba 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,6 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
+#define ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
+
#pragma once
#include "arm_gemm_local.hpp"
@@ -151,6 +155,7 @@ public:
int _maxthreads;
bool _fixed_format;
bool _fast_mode;
+ bool _accumulate;
const GemmConfig *_cfg;
GemmArgs(const CPUInfo *ci,
@@ -165,6 +170,7 @@ public:
const int maxthreads,
bool fixed_format = false,
bool fast_mode = false,
+ bool accumulate = false,
const GemmConfig *cfg = nullptr)
: _ci(ci),
_Msize(M),
@@ -178,6 +184,7 @@ public:
_maxthreads(maxthreads),
_fixed_format(fixed_format),
_fast_mode(fast_mode),
+ _accumulate(accumulate),
_cfg(cfg)
{
}
@@ -253,6 +260,19 @@ public:
}
};
+struct DequantizeFloat
+{
+public:
+ float scale = 0;
+
+ DequantizeFloat() = default;
+
+ // Constructor
+ DequantizeFloat(const float scale) : scale(scale)
+ {
+ }
+};
+
struct Nothing
{
};
@@ -278,3 +298,5 @@ template <typename Top, typename Tret, class OutputStage = Nothing>
bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
+
+#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index 6fe9f13f0..45d1e4327 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021,2023 Arm Limited.
+ * Copyright (c) 2017-2021,2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,6 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_GEMM_COMMON_HPP
+#define ACL_SRC_CPU_KERNELS_ASSEMBLY_GEMM_COMMON_HPP
+
#pragma once
#include "convolution_parameters.hpp"
@@ -116,6 +120,11 @@ public:
{
return false;
}
+ /* Does pretranspose accept the transposed flag? */
+ virtual bool B_pretranspose_supports_transpose() const
+ {
+ return false;
+ }
/* Total number of bytes of space needed for pretransposed arrays. */
virtual size_t get_B_pretransposed_array_size() const
{
@@ -128,10 +137,10 @@ public:
}
/* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */
/* The "real" version of this depends on the templated operand type (see below). */
- virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0;
+ virtual void pretranspose_B_array_generic(void *, const void *, const int, const int, bool) = 0;
/* Threaded version with window start/end parameters */
virtual void
- pretranspose_B_array_part_generic(void *, const void *, const int, const int, const size_t, const size_t) = 0;
+ pretranspose_B_array_part_generic(void *, const void *, const int, const int, bool, const size_t, const size_t) = 0;
/* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
virtual void set_pretransposed_B_data(void *)
@@ -157,6 +166,12 @@ public:
{
}
+ /*** Dequanize scale interface (optional) ***/
+ /* Set the dequantize scale for GEMMs when converting from int to float (float out = scale * float(int out) ) */
+ virtual void set_dequantize_scale(const float)
+ {
+ }
+
/*** Introspection interface ***/
/* Get the configuration of this GEMM */
virtual GemmConfig get_config() = 0;
@@ -251,28 +266,34 @@ public:
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
/* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
- virtual void pretranspose_B_array(void *, const To *, const int, const int){};
+ virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){};
/* Implementation of the void * overload which casts its arguments to the appropriate type. */
- void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override
+ void pretranspose_B_array_generic(
+ void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override
{
- pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
+ pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride, transposed);
}
/* Threaded versions of the above.
* The fallback/backwards compatible version of the threaded interface exposes a window size of 1 and
* just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only
* legal values for start and end are 0 and 1 respectively. */
- virtual void
- pretranspose_B_array_part(void *out, const To *in, const int row_stride, const int multi_stride, size_t, size_t)
+ virtual void pretranspose_B_array_part(
+ void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
{
- pretranspose_B_array(out, in, row_stride, multi_stride);
+ pretranspose_B_array(out, in, row_stride, multi_stride, transposed);
};
- void pretranspose_B_array_part_generic(
- void *out, const void *in, const int row_stride, const int multi_stride, size_t start, size_t end) override
+ void pretranspose_B_array_part_generic(void *out,
+ const void *in,
+ const int row_stride,
+ const int multi_stride,
+ bool transposed,
+ size_t start,
+ size_t end) override
{
- pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, start, end);
+ pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, transposed, start, end);
}
/*** Indirect interface ***/
@@ -287,3 +308,5 @@ public:
};
} // namespace arm_gemm
+
+#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_GEMM_COMMON_HPP
diff --git a/src/cpu/kernels/elementwise_binary/generic/neon/impl.h b/src/cpu/kernels/elementwise_binary/generic/neon/impl.h
index 98f7e8b94..78e3baf74 100644
--- a/src/cpu/kernels/elementwise_binary/generic/neon/impl.h
+++ b/src/cpu/kernels/elementwise_binary/generic/neon/impl.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
-#define SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
+#ifndef ACL_SRC_CPU_KERNELS_ELEMENTWISE_BINARY_GENERIC_NEON_IMPL_H
+#define ACL_SRC_CPU_KERNELS_ELEMENTWISE_BINARY_GENERIC_NEON_IMPL_H
#include "src/core/NEON/NEAsymm.h"
@@ -198,14 +198,6 @@ inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const Scalar
case ArithmeticOperation::DIV:
{
res = a / b;
- if (std::is_integral<ScalarType>::value)
- {
- res = (b == 0) ? 0 : res;
- if (static_cast<int32_t>(a) % static_cast<int32_t>(b) != 0 && ((a < 0) != (b < 0)))
- {
- --res;
- }
- }
break;
}
case ArithmeticOperation::POWER:
@@ -224,7 +216,15 @@ inline int32x4_t
elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<int32_t, 4>>(const int32x4_t &a,
const int32x4_t &b)
{
- return vcvtq_s32_f32(vfloorq_f32(wrapper::vdiv(vcvtq_f32_s32(a), vcvtq_f32_s32(b))));
+ int32x4_t result;
+
+ // Neon(TM) does not have vector integer division
+ result[0] = a[0] / b[0];
+ result[1] = a[1] / b[1];
+ result[2] = a[2] / b[2];
+ result[3] = a[3] / b[3];
+
+ return result;
}
template <>
@@ -1313,4 +1313,4 @@ void elementwise_comp_op_quantized_signed(const ITensor *in1, const ITensor *in2
} // namespace cpu
} // namespace arm_compute
-#endif /* SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H */
+#endif // ACL_SRC_CPU_KERNELS_ELEMENTWISE_BINARY_GENERIC_NEON_IMPL_H
diff --git a/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp b/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp
index 9ba245148..2c1cb1578 100644
--- a/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp
+++ b/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
@@ -100,7 +101,6 @@ Status
CpuPool2dAssemblyWrapperKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, const PoolingLayerInfo &info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst);
-
#ifndef __aarch64__
ARM_COMPUTE_RETURN_ERROR_MSG("32-bit is not supported by assembly kernels");
#endif /* __aarch64__ */
@@ -120,6 +120,8 @@ CpuPool2dAssemblyWrapperKernel::validate(const ITensorInfo *src, const ITensorIn
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
+ const TensorInfo out_info(compute_pool_shape(*src, info), 1, dst->data_type());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &out_info);
const auto src_qinfo = src->quantization_info().uniform();
const auto dst_qinfo = dst->quantization_info().uniform();
diff --git a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
index 6470f391e..344b9df0c 100644
--- a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -66,26 +66,20 @@ void mean_stddev_normalization<float16_t, 8>(ITensor *input, ITensor *output, fl
sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(dh, dh));
}
- float16x4_t sum_carry_res = vpadd_f16(vget_high_f16(sum_vec), vget_low_f16(sum_vec));
- sum_carry_res = vpadd_f16(sum_carry_res, sum_carry_res);
- sum_carry_res = vpadd_f16(sum_carry_res, sum_carry_res);
-
- float32x4_t sum_sq_carry_res = vpaddq_f32(sum_sq_vec, sum_sq_vec);
- sum_sq_carry_res = vpaddq_f32(sum_sq_carry_res, sum_sq_carry_res);
-
- float16_t sum = vget_lane_f16(sum_carry_res, 0);
- float sum_sq = vgetq_lane_f32(sum_sq_carry_res, 0);
+ float32x4_t sum_carry_res =
+ vpaddq_f32(vcvt_f32_f16(vget_high_f16(sum_vec)), vcvt_f32_f16(vget_low_f16(sum_vec)));
+ float sum = vaddvq_f32(sum_carry_res);
+ float sum_sq = vaddvq_f32(sum_sq_vec);
// Compute left-over elements
for (; x < window_end_x; ++x)
{
- float16_t data = *(in_ptr + x);
- sum += data;
- float fdata = static_cast<float>(data);
+ const float fdata = static_cast<float>(*(in_ptr + x));
+ sum += fdata;
sum_sq += fdata * fdata;
}
- float16_t mean = sum / input->info()->dimension(0);
+ float16_t mean = static_cast<float16_t>(sum / input->info()->dimension(0));
float var = (sum_sq / input->info()->dimension(0)) - (mean * mean);
float16_t stddev_inv = static_cast<float16_t>(1.f / sqrt(var + epsilon));
diff --git a/src/cpu/kernels/softmax/generic/neon/fp16.cpp b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
index db8f88171..da62d2d61 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -33,15 +33,23 @@ namespace cpu
{
template <bool IS_LOG>
-void neon_fp16_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_fp16_softmax(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
{
- return neon_softmax_float<float16_t, IS_LOG>(in, tmp, out, beta, window);
+ if (axis == 0)
+ {
+ return neon_softmax_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
+ else
+ {
+ return neon_softmax_non_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
}
-template void
-neon_fp16_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_fp16_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_fp16_softmax<true>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_fp16_softmax<false>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/fp32.cpp b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
index c281d1bf3..070162063 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp32.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,15 +31,23 @@ namespace cpu
{
template <bool IS_LOG>
-void neon_fp32_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_fp32_softmax(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
{
- return neon_softmax_float<float, IS_LOG>(in, tmp, out, beta, window);
+ if (axis == 0)
+ {
+ return neon_softmax_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
+ else
+ {
+ return neon_softmax_non_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
}
-template void
-neon_fp32_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_fp32_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_fp32_softmax<true>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_fp32_softmax<false>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.cpp b/src/cpu/kernels/softmax/generic/neon/impl.cpp
index 487f6ae05..31baf8a9d 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/impl.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,8 +30,11 @@ namespace arm_compute
namespace cpu
{
template <typename T, bool IS_LOG>
-void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+void neon_softmax_x_quantized(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
{
+ ARM_COMPUTE_UNUSED(axis);
+
static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
"quantized type should be either qasymm8_t or qasymm8_signed_t.");
@@ -248,16 +251,346 @@ void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, fl
in_it, out_it);
}
-template void neon_softmax_quantized<qasymm8_signed_t, true>(
- const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+template <typename T, bool IS_LOG>
+void neon_softmax_non_x_quantized(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
+{
+ static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
+ "quantized type should be either qasymm8_t or qasymm8_signed_t.");
+
+ const float scale_beta = -beta * in->info()->quantization_info().uniform().scale;
+ const float32x4_t scale_beta_vec = vdupq_n_f32(scale_beta);
+
+ Iterator in_it(in, window);
+ Iterator out_it(out, window);
+
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+ constexpr int vec_size = 16;
+ const ITensorInfo *in_info = in->info();
+ const ITensorInfo *out_info = out->info();
+ const int x_width = in_info->valid_region().shape.x();
+ const int in_axis_stride = in_info->strides_in_bytes()[axis];
+ const int out_axis_stride = out_info->strides_in_bytes()[axis];
+ const int tmp_axis_stride = in_axis_stride;
+ const int axis_width = in_info->dimension(axis);
+ const int end_actual = std::min(window[0].end(), x_width);
+
+ execute_window_loop(
+ window,
+ [&](const Coordinates &winCoords)
+ {
+ const bool vector_exceeds_bounds = ((winCoords[0] + vec_size) > end_actual);
+
+ int num_remaining = (end_actual - winCoords[0]);
+ int num_remaining_full = num_remaining / 4;
+ int num_remaining_partial = num_remaining % 4;
+
+ /* Get pointers */
+ const uint8_t *in_ptr = in_it.ptr();
+ uint8_t *out_ptr = out_it.ptr();
+ uint8_t *tmp_ptr = reinterpret_cast<uint8_t *>(tmp);
+
+ auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+
+ /* Compute Max */
+ {
+ if (!vector_exceeds_bounds)
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const auto current_value =
+ wrapper::vloadq((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+ vec_max = wrapper::vmax(vec_max, current_value);
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const T *const base_ptr_in = ((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+ int j = 0;
+ for (; j < num_remaining; ++j)
+ {
+ const T current_value = *(base_ptr_in + j);
+ vec_max[j] = std::max(vec_max[j], current_value);
+ }
+ }
+ }
+ } // Compute Max
+
+ float32x4x4_t vec_sum_transformed = {
+ vdupq_n_f32(0.f),
+ vdupq_n_f32(0.f),
+ vdupq_n_f32(0.f),
+ vdupq_n_f32(0.f),
+ };
+
+ /* Compute exponentials and sum */
+ {
+ /* Init sum to zero */
+ float32x4x4_t vec_sum = vec_sum_transformed;
+
+ auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+ float32x4x4_t vec_elements_flt;
+
+ if (!vector_exceeds_bounds)
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ vec_elements = wrapper::vloadq((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+ vec_elements = wrapper::vqsub(vec_max, vec_elements);
+ vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
+
+ if (IS_LOG)
+ {
+ vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
+ vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
+ vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
+ vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
+ vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
+ vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
+ vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
+ vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
+ }
+ else
+ {
+ vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
+ vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
+ vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
+ vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
+ vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
+ vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
+ vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
+ vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
+ }
+ vst4q_f32((i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr), vec_elements_flt);
+ }
+
+ auto vec_256 = wrapper::vdup_n(static_cast<float32_t>(256.f), ExactTagType{});
+ if (!IS_LOG)
+ {
+ vec_sum_transformed.val[0] = wrapper::vdiv(vec_256, vec_sum.val[0]);
+ vec_sum_transformed.val[1] = wrapper::vdiv(vec_256, vec_sum.val[1]);
+ vec_sum_transformed.val[2] = wrapper::vdiv(vec_256, vec_sum.val[2]);
+ vec_sum_transformed.val[3] = wrapper::vdiv(vec_256, vec_sum.val[3]);
+ }
+ else
+ {
+ vec_sum_transformed.val[0] = wrapper::vlog(vec_sum.val[0]);
+ vec_sum_transformed.val[1] = wrapper::vlog(vec_sum.val[1]);
+ vec_sum_transformed.val[2] = wrapper::vlog(vec_sum.val[2]);
+ vec_sum_transformed.val[3] = wrapper::vlog(vec_sum.val[3]);
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const T *const base_ptr_in = (i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr);
+ auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+ //vec_els is functionally redundant but is needed as a workaround for a toolchain bug.
+ std::vector<T> vec_els(16);
+
+ for (int k = 0; k < num_remaining_full; ++k)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ vec_els[k * 4 + j] = *(base_ptr_in + (4 * k + j));
+ }
+ }
+ for (int j = 0; j < num_remaining_partial; ++j)
+ {
+ vec_els[num_remaining_full * 4 + j] = *(base_ptr_in + (4 * num_remaining_full + j));
+ }
+ for (int q = 0; q < 16; q++)
+ {
+ vec_elements[q] = vec_els[q];
+ }
+ vec_elements = wrapper::vqsub(vec_max, vec_elements);
+ float32x4x4_t vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
+
+ if (IS_LOG)
+ {
+ vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
+ vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
+ vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
+ vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
+ vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
+ vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
+ vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
+ vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
+ }
+ else
+ {
+ vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
+ vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
+ vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
+ vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
+ vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
+ vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
+ vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
+ vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
+ }
+
+ float *const base_ptr_tmp = (i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr);
+ for (int k = 0; k < num_remaining_full; ++k)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ *(base_ptr_tmp + (4 * k + j)) = vec_elements_flt.val[k][j];
+ }
+ }
+
+ for (int j = 0; j < num_remaining_partial; ++j)
+ {
+ *(base_ptr_tmp + (4 * num_remaining_full + j)) =
+ vec_elements_flt.val[num_remaining_full][j];
+ }
+ }
+
+ auto vec_256 = wrapper::vdup_n(static_cast<float32_t>(256), ExactTagType{});
+ if (!IS_LOG)
+ {
+ vec_sum_transformed.val[0] = wrapper::vdiv(vec_256, vec_sum.val[0]);
+ vec_sum_transformed.val[1] = wrapper::vdiv(vec_256, vec_sum.val[1]);
+ vec_sum_transformed.val[2] = wrapper::vdiv(vec_256, vec_sum.val[2]);
+ vec_sum_transformed.val[3] = wrapper::vdiv(vec_256, vec_sum.val[3]);
+ }
+ else
+ {
+ vec_sum_transformed.val[0] = wrapper::vlog(vec_sum.val[0]);
+ vec_sum_transformed.val[1] = wrapper::vlog(vec_sum.val[1]);
+ vec_sum_transformed.val[2] = wrapper::vlog(vec_sum.val[2]);
+ vec_sum_transformed.val[3] = wrapper::vlog(vec_sum.val[3]);
+ }
+ }
+ } // Compute exponentials and sum
+
+ /* Normalize exponentials */
+ {
+ constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
+ if (!vector_exceeds_bounds)
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ using int_vec_type = wrapper::traits::neon_vector_t<T, 16>;
+ float32x4x4_t vec_in = vld4q_f32((i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr));
+
+ int_vec_type normalized_value{};
+
+ if (IS_LOG)
+ {
+ const float32x4x4_t sub = {
+ vsubq_f32(vec_in.val[0], vec_sum_transformed.val[0]),
+ vsubq_f32(vec_in.val[1], vec_sum_transformed.val[1]),
+ vsubq_f32(vec_in.val[2], vec_sum_transformed.val[2]),
+ vsubq_f32(vec_in.val[3], vec_sum_transformed.val[3]),
+ };
+ normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
+ }
+ else
+ {
+ float32x4x4_t mul = {
+ vmulq_f32(vec_in.val[0], vec_sum_transformed.val[0]),
+ vmulq_f32(vec_in.val[1], vec_sum_transformed.val[1]),
+ vmulq_f32(vec_in.val[2], vec_sum_transformed.val[2]),
+ vmulq_f32(vec_in.val[3], vec_sum_transformed.val[3]),
+ };
+
+ if (is_qasymm8_signed)
+ {
+ const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{});
+ mul.val[0] = wrapper::vsub(mul.val[0], offset_vec);
+ mul.val[1] = wrapper::vsub(mul.val[1], offset_vec);
+ mul.val[2] = wrapper::vsub(mul.val[2], offset_vec);
+ mul.val[3] = wrapper::vsub(mul.val[3], offset_vec);
+ }
+
+ normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
+ }
+ wrapper::vstore((i * out_axis_stride) + reinterpret_cast<T *>(out_ptr), normalized_value);
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ T *const base_ptr_out = (i * out_axis_stride) + reinterpret_cast<T *>(out_ptr);
+ float *const base_ptr_tmp = (i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr);
+ if (IS_LOG)
+ {
+ for (int k = 0; k < num_remaining_full; ++k)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ *(base_ptr_out + (4 * k + j)) = utils::cast::saturate_cast<T>(
+ (*(base_ptr_tmp + (4 * k + j)) - vec_sum_transformed.val[k][j]));
+ }
+ }
+ for (int j = 0; j < num_remaining_partial; ++j)
+ {
+ *(base_ptr_out + (4 * num_remaining_full + j)) =
+ utils::cast::saturate_cast<T>(*(base_ptr_tmp + (4 * num_remaining_full + j)) -
+ vec_sum_transformed.val[num_remaining_full][j]);
+ }
+ }
+ else
+ {
+ for (int k = 0; k < num_remaining_full; ++k)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ *(base_ptr_out + (4 * k + j)) = utils::cast::saturate_cast<T>(
+ *(base_ptr_tmp + (4 * k + j)) * vec_sum_transformed.val[k][j] -
+ (is_qasymm8_signed ? 128.f : 0));
+ }
+ }
+ for (int j = 0; j < num_remaining_partial; ++j)
+ {
+ *(base_ptr_out + (4 * num_remaining_full + j)) =
+ utils::cast::saturate_cast<T>(*(base_ptr_tmp + (4 * num_remaining_full + j)) *
+ vec_sum_transformed.val[num_remaining_full][j] -
+ (is_qasymm8_signed ? 128.f : 0));
+ }
+ }
+ }
+ }
+ } // Normalize exponentials
+ },
+ in_it, out_it);
+}
+
+template void neon_softmax_x_quantized<qasymm8_signed_t, true>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_signed_t, false>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_t, true>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_t, false>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_non_x_quantized<qasymm8_signed_t, true>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
-template void neon_softmax_quantized<qasymm8_signed_t, false>(
- const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+template void neon_softmax_non_x_quantized<qasymm8_signed_t, false>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
-template void neon_softmax_quantized<qasymm8_t, true>(
- const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+template void neon_softmax_non_x_quantized<qasymm8_t, true>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
-template void neon_softmax_quantized<qasymm8_t, false>(
- const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+template void neon_softmax_non_x_quantized<qasymm8_t, false>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.h b/src/cpu/kernels/softmax/generic/neon/impl.h
index 60380cd23..e417271d0 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.h
+++ b/src/cpu/kernels/softmax/generic/neon/impl.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -62,8 +62,9 @@ inline float wrapper_vaddv(const float32x4_t &a, int sum_stages)
// The template implementation for float data types is stored in the header file because
// we need all fp16 instantiated code to live in fp16.cpp files.
template <typename T, bool IS_LOG>
-void neon_softmax_float(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+void neon_softmax_x_float(const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
{
+ ARM_COMPUTE_UNUSED(axis);
ARM_COMPUTE_UNUSED(tmp);
const int input_width = in->info()->valid_region().shape.x();
@@ -228,9 +229,199 @@ void neon_softmax_float(const ITensor *in, void *const tmp, ITensor *out, float
},
in_it, out_it);
}
+template <typename T, bool IS_LOG>
+void neon_softmax_non_x_float(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(tmp);
+
+ Iterator in_it(in, window);
+ Iterator out_it(out, window);
+
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+ const auto beta_vec = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});
+ constexpr int vec_size = 16 / sizeof(T);
+ const ITensorInfo *in_info = in->info();
+ const ITensorInfo *out_info = out->info();
+ const int x_width = in_info->valid_region().shape.x();
+ const unsigned int in_axis_stride = in_info->strides_in_bytes()[axis];
+ const unsigned int out_axis_stride = out_info->strides_in_bytes()[axis];
+ const int axis_width = in_info->dimension(axis);
+
+ execute_window_loop(
+ window,
+ [&](const Coordinates &winCoords)
+ {
+ const bool vector_exceeds_bounds = (winCoords[0] + vec_size) > x_width;
+
+ /* Get pointers */
+ const uint8_t *in_ptr = in_it.ptr();
+ uint8_t *out_ptr = out_it.ptr();
+
+ // Init max value
+ auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+
+ /* Compute Max */
+ {
+ if (!vector_exceeds_bounds)
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const auto current_value =
+ wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
+ vec_max = wrapper::vmax(vec_max, current_value);
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
+ int j = 0;
+ for (; j < (x_width - winCoords[0]); ++j)
+ {
+ const auto current_value = *(base_ptr_in + j);
+ vec_max[j] = std::max(vec_max[j], current_value);
+ }
+ }
+ }
+ } // compute max
+
+ auto vec_sum_transformed = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+ auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+ /* Init sum to zero */
+ auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+ /* Compute exponentials and sum */
+ {
+ if (!vector_exceeds_bounds)
+ {
+ const auto vec_one = wrapper::vdup_n(static_cast<T>(1), ExactTagType{});
+ /* Loop over row and compute exponentials and sum */
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ vec_elements = wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
+ vec_elements = wrapper::vsub(vec_elements, vec_max);
+ if (IS_LOG)
+ {
+ vec_elements = wrapper::vmul(vec_elements, beta_vec);
+ vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
+ }
+ else
+ {
+ vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec));
+ vec_sum = wrapper::vadd(vec_sum, vec_elements);
+ }
+
+ wrapper::vstore(reinterpret_cast<T *>((i * out_axis_stride) + out_ptr), vec_elements);
+ }
+
+ if (!IS_LOG)
+ {
+ vec_sum_transformed = wrapper::vdiv(vec_one, vec_sum);
+ }
+ else
+ {
+ vec_sum_transformed = wrapper::vlog(vec_sum);
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
+ T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+ int j = 0;
+ for (; j < (x_width - winCoords[0]); ++j)
+ {
+ vec_elements[j] = *(base_ptr_in + j);
+ vec_elements[j] -= vec_max[j];
+ if (IS_LOG)
+ {
+ vec_elements[j] *= beta;
+ vec_sum[j] += std::exp(vec_elements[j]);
+ }
+ else
+ {
+ vec_elements[j] = std::exp(vec_elements[j] * beta);
+ vec_sum[j] += vec_elements[j];
+ }
+ *(base_ptr_out + j) = vec_elements[j];
+ }
+ }
+ int j = 0;
+ for (; j < (x_width - winCoords[0]); ++j)
+ {
+ if (!IS_LOG)
+ {
+ vec_sum_transformed[j] = 1 / vec_sum[j];
+ }
+ else
+ {
+ vec_sum_transformed[j] = std::log(vec_sum[j]);
+ }
+ }
+ }
+ } // Compute exponentials and sum
+
+ /* Normalize exponentials */
+ {
+ if (!vector_exceeds_bounds)
+ {
+ /* Loop over row and compute softmax */
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+ auto vec_in = wrapper::vloadq(base_ptr_out);
+ if (IS_LOG)
+ {
+ wrapper::vstore(base_ptr_out, wrapper::vsub(vec_in, vec_sum_transformed));
+ }
+ else
+ {
+ wrapper::vstore(base_ptr_out, wrapper::vmul(vec_in, vec_sum_transformed));
+ }
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+ int j = 0;
+ for (; j < (x_width - winCoords[0]); ++j)
+ {
+ if (IS_LOG)
+ {
+ *(base_ptr_out + j) -= vec_sum_transformed[j];
+ }
+ else
+ {
+ *(base_ptr_out + j) *= vec_sum_transformed[j];
+ }
+ }
+ }
+ }
+ } // Normalize exponentials
+ },
+ in_it, out_it);
+}
+template <typename T, bool IS_LOG>
+void neon_softmax_x_quantized(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
template <typename T, bool IS_LOG>
-void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+void neon_softmax_non_x_quantized(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
index 9589ebcd7..d39240bb3 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,15 +30,23 @@ namespace arm_compute
namespace cpu
{
template <bool IS_LOG>
-void neon_qasymm8_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_qasymm8_softmax(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
{
- return neon_softmax_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, window);
+ if (axis == 0)
+ {
+ return neon_softmax_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
+ else
+ {
+ return neon_softmax_non_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
}
-template void
-neon_qasymm8_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_qasymm8_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_qasymm8_softmax<true>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_qasymm8_softmax<false>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
index 0bf6b2859..26fd5dbfa 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,15 +31,22 @@ namespace cpu
{
template <bool IS_LOG>
void neon_qasymm8_signed_softmax(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
{
- return neon_softmax_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, window);
+ if (axis == 0)
+ {
+ return neon_softmax_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
+ else
+ {
+ return neon_softmax_non_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
}
template void neon_qasymm8_signed_softmax<true>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
template void neon_qasymm8_signed_softmax<false>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/sme2/fp16.cpp b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp
new file mode 100644
index 000000000..bcd34d1ca
--- /dev/null
+++ b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp
@@ -0,0 +1,774 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// SoftMax
+//
+// Steps:
+// * Find max: max_value = max(src)
+// * Regularize: dst[i] = exp(src[i] - max_value)
+// sum_value = sum(dst)
+// * Normalize: dst[i] = dst[i] / sum_value
+void sme2_f16_softmax_kernel( //
+ const float16_t *src,
+ float16_t *dst,
+ float beta,
+ const uintptr_t shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t dst_strides[4])
+{
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ // Registers
+ //
+ // * x9: temporary, index
+ // * x10: temporary, -inf
+ // * x11: temporary, 0
+ // * x12: temporary, 1.0f
+ // * x13: temporary, body_length
+ //
+ // * x20: index_3
+ // * x21: src_3
+ // * x22: dst_3
+ // * x23: index_2
+ // * x24: src_2
+ // * x25: dst_2
+ // * x26: index_1
+ // * x27: src_1
+ // * x28: dst_1
+ //
+ // * z0: c1
+ // * z1: c2
+ // * z2: c3
+ // * z3: c4
+ // * z4: c5
+ // * z5: shift
+ // * z6: inv_ln2
+ // * z7: neg_ln2_hi
+ // * z8: neg_ln2_lo
+ // * z9: min_input
+ // * z10: 23, 0
+ // * z11: max_value
+ // * z12-z15: x, x_fp32_lower_halves, r_hi, r, r2
+ // * z16-z19: max_value, shift, z, scale, poly
+ // * z20-z21: n, p1, p12345
+ // * z22-z23: n, p23, p2345
+ // * z24-z25: p45
+ // * z26: beta
+ // * z28-z31: sum_value, x_fp32_upper_halves
+ //
+ // * za0-za3: sum_value
+ //
+ // * p0: all-true
+ // * p1: left-over predicate for find-max & normalize loops
+ // * p2-p4: left-over predicates for regularize loop
+ // * p4-p7: underflow in vector loop
+ // * p5-p6: underflow in leftover loop
+ // *
+ // * pn9: all-true
+
+ // Prepares all constant values
+
+ ptrue p0.b
+ .inst 0x25207811 // ptrue pn9.b
+
+ mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ dup z0.s, w9 // c1.
+ dup z1.s, w10 // c2.
+ dup z2.s, w11 // c3.
+ dup z3.s, w12 // c4.
+ dup z4.s, w13 // c5.
+
+ mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ dup z5.s, w9 // shift
+ dup z6.s, w10 // inv_ln2
+ dup z7.s, w11 // neg_ln2_hi
+ dup z8.s, w12 // neg_ln2_lo
+ dup z9.s, w13 // min_input
+
+ dup z26.s, %w[beta] // beta
+ fcvt h26, s26
+ dup z26.h, z26.h[0]
+
+ mov w10, #0xfc00 // -inf: 0xfc00 for fp16
+
+ mov w11, #0 // 0
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+ cnth x13, ALL, MUL #4
+ udiv x9, %x[length], x13
+ mul x13, x13, x9
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+ mov x20, %x[shape_3]
+ mov x21, %x[src]
+ mov x22, %x[dst]
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x20, #0
+ b.eq loop_3_end%=
+ sub x20, x20, #1
+
+ mov x23, %x[shape_2]
+ mov x24, x21
+ mov x25, x22
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x23, #0
+ b.eq loop_2_end%=
+ sub x23, x23, #1
+
+ mov x26, %x[shape_1]
+ mov x27, x24
+ mov x28, x25
+
+loop_1_start%=:
+ // for index_1 in shape_2 downto 1
+ cmp x26, #0
+ b.eq loop_1_end%=
+ sub x26, x26, #1
+
+ // ==================================================
+ // Step 1: Find max
+ // ==================================================
+
+ // ---------------------------------------------------------------- z16-z19: max_value = -inf
+ dup z16.h, w10
+ dup z17.h, w10
+ dup z18.h, w10
+ dup z19.h, w10
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+ dup z11.h, w10 // z11: max_value = -inf
+
+find_max_body_start%=:
+ cmp x9, x13
+ b.eq find_max_body_end%=
+
+ .inst 0xa009a76c // ld1h {z12.h-z15.h}, pn9/z, [x27, x9, LSL #1] // z12-z15: x
+ .inst 0xc16cb910 // fmax {z16.h-z19.h}, {z16.h-z19.h}, {z12.h-z15.h} // z16-z19: max_value = max(max_value, x)
+
+ inch x9, ALL, MUL #4
+ b find_max_body_start%=
+find_max_body_end%=:
+
+ // Loop for processing the leftover part.
+find_max_leftover_start%=:
+ whilelo p1.h, x9, %x[length]
+ b.none find_max_leftover_end%=
+
+ ld1h z12.h, p1/z, [x27, x9, LSL #1] // z12: x
+ fmax z16.h, p1/m, z16.h, z12.h // z16: max_value = max(max_value, x)
+
+ inch x9
+ b find_max_leftover_start%=
+find_max_leftover_end%=:
+
+ // ---------------------------------------------------------------- z16: max_value
+ .inst 0xc172b110 // fmax {z16.h-z17.h}, {z16.h-z17.h}, {z18.s-z19.h}
+ fmax z16.h, p0/m, z16.h, z17.h
+ fmaxv h16, p0, z16.h
+
+ // ---------------------------------------------------------------- z11: max_value
+ dup z11.h, z16.h[0]
+
+ // ==================================================
+ // Step 2: Regularize, i.e. Calculate exp(x - max(x)
+ // ==================================================
+
+ .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value (in fp32)
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // ---------------------------------------------------- x9: index
+
+regularize_body_start%=:
+ cmp x9, x13
+ b.eq regularize_body_end%=
+
+ // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data
+ .inst 0xa009a76c // ld1h {z12.h-z15.h}, pn9/z, [x27, x9, LSL #1] // z12-z15: x
+
+ // ---------------------------------------------------------------- z12-z15: x = input_data - max_value
+ fsub z12.h, z12.h, z11.h
+ fsub z13.h, z13.h, z11.h
+ fsub z14.h, z14.h, z11.h
+ fsub z15.h, z15.h, z11.h
+
+ // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta
+ fmul z12.h, z12.h, z26.h
+ fmul z13.h, z13.h, z26.h
+ fmul z14.h, z14.h, z26.h
+ fmul z15.h, z15.h, z26.h
+
+ // ----------------------------------------------------------------
+ // Convert fp16 values to fp32. This results in four more registers.
+ // z12 --> z12, z28
+ fcvtlt z28.s, p0/m, z12.h
+ fcvt z12.s, p0/m, z12.h
+
+ // z13 --> z13, z29
+ fcvtlt z29.s, p0/m, z13.h
+ fcvt z13.s, p0/m, z13.h
+
+ // z14 --> z14, z30
+ fcvtlt z30.s, p0/m, z14.h
+ fcvt z14.s, p0/m, z14.h
+
+ // z15 --> z15, z31
+ fcvtlt z31.s, p0/m, z15.h
+ fcvt z15.s, p0/m, z15.h
+
+ // ----------------------------------------------------------------
+ // Process z12-z15
+ // ----------------------------------------------------------------
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- p4-p7: underflow = x < min_input
+ fcmlt p4.s, p0/z, z12.s, z9.s
+ fcmlt p5.s, p0/z, z13.s, z9.s
+ fcmlt p6.s, p0/z, z14.s, z9.s
+ fcmlt p7.s, p0/z, z15.s, z9.s
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z12.s, z6.s
+ fmla z17.s, p0/m, z13.s, z6.s
+ fmla z18.s, p0/m, z14.s, z6.s
+ fmla z19.s, p0/m, z15.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+ // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p0/m, z20.s, z7.s
+ fmla z13.s, p0/m, z21.s, z7.s
+ fmla z14.s, p0/m, z22.s, z7.s
+ fmla z15.s, p0/m, z23.s, z7.s
+
+ // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo
+ fmla z12.s, p0/m, z20.s, z8.s
+ fmla z13.s, p0/m, z21.s, z8.s
+ fmla z14.s, p0/m, z22.s, z8.s
+ fmla z15.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors. (z12-z13)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z12.s, z0.s
+ fmul z21.s, z13.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z12.s, z2.s
+ fmla z23.s, p0/m, z13.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z35: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z12.s, z4.s
+ fmla z25.s, p0/m, z13.s, z4.s
+
+ // ---------------------------------------------------------------- z12-z13: r2 = r * r
+ fmul z12.s, z12.s, z12.s
+ fmul z13.s, z13.s, z13.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z12.s, z24.s
+ fmla z23.s, p0/m, z13.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z12.s, z22.s
+ fmla z21.s, p0/m, z13.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors (z14-z15)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z14.s, z0.s
+ fmul z21.s, z15.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z14.s, z2.s
+ fmla z23.s, p0/m, z15.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z35: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z14.s, z4.s
+ fmla z25.s, p0/m, z15.s, z4.s
+
+ // ---------------------------------------------------------------- z14-z15: r2 = r * r
+ fmul z14.s, z14.s, z14.s
+ fmul z15.s, z15.s, z15.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z14.s, z24.s
+ fmla z23.s, p0/m, z15.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z14.s, z22.s
+ fmla z21.s, p0/m, z15.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z12.s, p4, z10.s, z16.s
+ sel z13.s, p5, z10.s, z17.s
+ sel z14.s, p6, z10.s, z18.s
+ sel z15.s, p7, z10.s, z19.s
+
+ // ---------------------------------------------------------------- sum in fp32
+ .inst 0xc1a17d80 // fadd za.s[w11, #0, VGx4], {z12.s-z15.s} za0-za3: sum_value = sum_value + poly
+
+ // ----------------------------------------------------------------
+ // Process z28-z31
+ // ----------------------------------------------------------------
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- p4-p7: underflow = x < min_input
+ fcmlt p4.s, p0/z, z28.s, z9.s
+ fcmlt p5.s, p0/z, z29.s, z9.s
+ fcmlt p6.s, p0/z, z30.s, z9.s
+ fcmlt p7.s, p0/z, z31.s, z9.s
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z28.s, z6.s
+ fmla z17.s, p0/m, z29.s, z6.s
+ fmla z18.s, p0/m, z30.s, z6.s
+ fmla z19.s, p0/m, z31.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+ // ---------------------------------------------------------------- z24-z27: r_hi = x + n * neg_ln2_hi
+ fmla z28.s, p0/m, z20.s, z7.s
+ fmla z29.s, p0/m, z21.s, z7.s
+ fmla z30.s, p0/m, z22.s, z7.s
+ fmla z31.s, p0/m, z23.s, z7.s
+
+ // ---------------------------------------------------------------- z27-z30: r = r_hi + n * neg_ln2_lo
+ fmla z28.s, p0/m, z20.s, z8.s
+ fmla z29.s, p0/m, z21.s, z8.s
+ fmla z30.s, p0/m, z22.s, z8.s
+ fmla z31.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors. (z28-z29)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z28.s, z0.s
+ fmul z21.s, z29.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z28.s, z2.s
+ fmla z23.s, p0/m, z29.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z28.s, z4.s
+ fmla z25.s, p0/m, z29.s, z4.s
+
+ // ---------------------------------------------------------------- z28-z29: r2 = r * r
+ fmul z28.s, z28.s, z28.s
+ fmul z29.s, z29.s, z29.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z28.s, z24.s
+ fmla z23.s, p0/m, z29.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z28.s, z22.s
+ fmla z21.s, p0/m, z29.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors (z30-z31)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z30.s, z0.s
+ fmul z21.s, z31.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z30.s, z2.s
+ fmla z23.s, p0/m, z31.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z35: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z30.s, z4.s
+ fmla z25.s, p0/m, z31.s, z4.s
+
+ // ---------------------------------------------------------------- z30-z31: r2 = r * r
+ fmul z30.s, z30.s, z30.s
+ fmul z31.s, z31.s, z31.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z30.s, z24.s
+ fmla z23.s, p0/m, z31.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z30.s, z22.s
+ fmla z21.s, p0/m, z31.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z28.s, p4, z10.s, z16.s
+ sel z29.s, p5, z10.s, z17.s
+ sel z30.s, p6, z10.s, z18.s
+ sel z31.s, p7, z10.s, z19.s
+
+ // ---------------------------------------------------------------- sum in fp32
+ .inst 0xc1a17f80 // fadd za.s[w11, #0, VGx4], {z28.s-z31.s} za0-za3: sum_value = sum_value + poly
+
+ fcvt z12.h, p0/m, z12.s
+ fcvtnt z12.h, p0/m, z28.s
+
+ fcvt z13.h, p0/m, z13.s
+ fcvtnt z13.h, p0/m, z29.s
+
+ fcvt z14.h, p0/m, z14.s
+ fcvtnt z14.h, p0/m, z30.s
+
+ fcvt z15.h, p0/m, z15.s
+ fcvtnt z15.h, p0/m, z31.s
+
+ // Stores 4 consecutive registers to the output
+ .inst 0xa029a78c // st1h {z12.h-z15.h}, pn9, [x28, x9, LSL #1]
+
+ inch x9, ALL, MUL #4
+ b regularize_body_start%=
+regularize_body_end%=:
+
+ // ---------------------------------------------------------------- z28: sum_value
+ .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4]
+ fadd z28.s, z28.s, z29.s
+ fadd z30.s, z30.s, z31.s
+ fadd z28.s, z28.s, z30.s
+
+ // Loop for processing the leftover part.
+regularize_leftover_start%=:
+ whilelo p2.h, x9, %x[length]
+ b.none regularize_leftover_end%=
+
+ ld1h z12.h, p2/z, [x27, x9, LSL #1] // x12: input_data
+
+ fsub z12.h, z12.h, z11.h // z12: x = input_data - max_value
+ fmul z12.h, z12.h, z26.h // z12: x = (input_data - max_value) * beta
+
+ // ---------------------------------------------------------------- z12.h --> z12.s, z13.s
+ fcvtlt z13.s, p2/m, z12.h
+ fcvt z12.s, p2/m, z12.h
+
+ // ---------------------------------------------------------------- p3, p4: predicates for z12, z14
+ pfalse p1.b
+ trn1 p3.h, p2.h, p1.h // for z12
+ trn2 p4.h, p2.h, p1.h // for z13
+
+ mov z16.d, z5.d // z16: shift
+ mov z17.d, z5.d // z17: shift
+ fcmlt p5.s, p3/z, z12.s, z9.s // p5: underflow = x < min_input
+ fcmlt p6.s, p4/z, z13.s, z9.s // p6: underflow = x < min_input
+ fmla z16.s, p3/m, z12.s, z6.s // z16: z = shift + x * inv_ln2
+ fmla z17.s, p4/m, z13.s, z6.s // z17: z = shift + x * inv_ln2
+ fsub z20.s, z16.s, z5.s // z20: n = z - shift
+ fsub z21.s, z17.s, z5.s // z21: n = z - shift
+ fmla z12.s, p3/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi
+ fmla z13.s, p4/m, z21.s, z7.s // z13: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p3/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo
+ fmla z13.s, p4/m, z21.s, z8.s // z13: r = r_hi + n * neg_ln2_lo
+ dup z10.s, #23 // z10: 23
+ urshl z16.s, p3/m, z16.s, z10.s // z16: scale = z << 23 (2^n)
+ urshl z17.s, p4/m, z17.s, z10.s // z17: scale = z << 23 (2^n)
+ fmul z20.s, z12.s, z0.s // z20: p1 = r * c1
+ fmul z21.s, z13.s, z0.s // z21: p1 = r * c1
+ mov z22.d, z1.d // z22: p23 = c2
+ mov z23.d, z1.d // z23: p23 = c2
+ fmla z22.s, p3/m, z12.s, z2.s // z22: p23 = c2 + r * c3
+ fmla z23.s, p4/m, z13.s, z2.s // z23: p23 = c2 + r * c3
+ mov z24.d, z3.d // z24: c4
+ mov z25.d, z3.d // z25: c4
+ fmla z24.s, p3/m, z12.s, z4.s // z24: p45 = c4 + r * c5
+ fmla z25.s, p4/m, z13.s, z4.s // z25: p45 = c4 + r * c5
+ fmul z12.s, z12.s, z12.s // z12: r2 = r * r
+ fmul z13.s, z13.s, z13.s // z13: r2 = r * r
+ fmla z22.s, p3/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45
+ fmla z23.s, p4/m, z13.s, z25.s // z23: p2345 = p23 + r2 * p45
+ fmla z20.s, p3/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345
+ fmla z21.s, p4/m, z13.s, z23.s // z21: p12345 = p1 + r2 * p2345
+ fmla z16.s, p3/m, z20.s, z16.s // z16: poly = scale + p12345 * scale
+ fmla z17.s, p4/m, z21.s, z17.s // z17: poly = scale + p12345 * scale
+ dup z10.s, #0 // z10: 0
+ sel z16.s, p5, z10.s, z16.s // z16: poly = underflow ? 0 : poly
+ sel z17.s, p6, z10.s, z17.s // z17: poly = underflow ? 0 : poly
+ fadd z28.s, p3/m, z28.s, z16.s // z28: sum_value = sum_value + poly
+ fadd z28.s, p4/m, z28.s, z17.s // z28: sum_value = sum_value + poly
+
+ fcvt z16.h, p3/m, z16.s
+ fcvtnt z16.h, p4/m, z17.s
+ st1h z16.h, p2, [x28, x9, LSL #1]
+
+ inch x9
+ b regularize_leftover_start%=
+regularize_leftover_end%=:
+
+ // ==================================================
+ // Step 3: Normalize
+ // ==================================================
+
+ // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value
+ faddv s28, p0, z28.s
+ fmov s29, #1.0 // 1.0f
+ fdiv s28, s29, s28
+ fcvt h28, s28
+
+ dup z28.h, z28.h[0]
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+
+normalize_body_start%=:
+ cmp x9, x13
+ b.eq normalize_body_end%=
+
+ .inst 0xa009a78c // ld1h {z12.h-z15.h}, pn9/z, [x28, x9, LSL #1]
+
+ // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value
+ fmul z12.h, z12.h, z28.h
+ fmul z13.h, z13.h, z28.h
+ fmul z14.h, z14.h, z28.h
+ fmul z15.h, z15.h, z28.h
+
+ .inst 0xa029a78c // st1h {z12.h-z15.h}, pn9, [x28, x9, LSL #1]
+
+ inch x9, ALL, MUL #4
+ b normalize_body_start%=
+normalize_body_end%=:
+
+ // Loop for processing the leftover part.
+normalize_leftover_start%=:
+ whilelo p1.h, x9, %x[length]
+ b.none normalize_leftover_end%=
+
+ ld1h z12.h, p1/z, [x28, x9, LSL #1] // z12: x
+ fmul z12.h, z12.h, z28.h // z12: result = x * inv_sum_value
+
+ st1h z12.h, p1, [x28, x9, LSL #1]
+
+ inch x9
+ b normalize_leftover_start%=
+normalize_leftover_end%=:
+
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x24, x24, %x[src_stride_2]
+ add x25, x25, %x[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x21, x21, %x[src_stride_3]
+ add x22, x22, %x[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), //
+ [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
+ [src_stride_3] "r"(src_strides[3]), //
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
+ [dst_stride_3] "r"(dst_strides[3]), //
+ [length] "r"(shape[0]) //
+ : "cc", "memory", //
+ "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p9", //
+ "x9", "x10", "x11", "x12", "x13", "x14", //
+ "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", //
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_fp16_softmax(const ITensor *in, void *const, ITensor *out, const float beta, int axis, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(axis);
+
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &full_shape = dst_info->tensor_shape();
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+
+ const uintptr_t k_shape[] = {
+ full_shape[0],
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ const uintptr_t k_src_strides[] = {
+ src_strides[0],
+ src_strides[1],
+ src_strides[2],
+ src_strides[3],
+ };
+
+ const uintptr_t k_dst_strides[] = {
+ dst_strides[0],
+ dst_strides[1],
+ dst_strides[2],
+ dst_strides[3],
+ };
+
+ const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
+ window[1].start() * src_strides[1] + //
+ window[2].start() * src_strides[2] + //
+ window[3].start() * src_strides[3];
+
+ const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
+ window[1].start() * dst_strides[1] + //
+ window[2].start() * dst_strides[2] + //
+ window[3].start() * dst_strides[3];
+
+ const auto *k_src = reinterpret_cast<const float16_t *>(in->buffer() + k_src_offset);
+ auto *k_dst = reinterpret_cast<float16_t *>(out->buffer() + k_dst_offset);
+
+ sme2_f16_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/cpu/kernels/softmax/generic/sme2/fp32.cpp b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp
new file mode 100644
index 000000000..159039a32
--- /dev/null
+++ b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp
@@ -0,0 +1,578 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// SoftMax
+//
+// Steps:
+// * Find max: max_value = max(src)
+// * Regularize: dst[i] = exp(src[i] - max_value)
+// sum_value = sum(dst)
+// * Normalize: dst[i] = dst[i] / sum_value
+void sme2_f32_softmax_kernel( //
+ const float *src,
+ float *dst,
+ float beta,
+ const uintptr_t shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t dst_strides[4])
+{
+ // Precondition:
+ // * src_strides[0] == sizeof(float)
+ // * dst_strides[0] == sizeof(float)
+
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ // Registers
+ //
+ // * x9: temporary, index
+ // * x10: temporary, -inf
+ // * x11: temporary, 0
+ // * x12: temporary, 1.0f
+ // * x13: temporary, body_length
+ //
+ // * x20: index_3
+ // * x21: src_3
+ // * x22: dst_3
+ // * x23: index_2
+ // * x24: src_2
+ // * x25: dst_2
+ // * x26: index_1
+ // * x27: src_1
+ // * x28: dst_1
+ //
+ // * z0: c1
+ // * z1: c2
+ // * z2: c3
+ // * z3: c4
+ // * z4: c5
+ // * z5: shift
+ // * z6: inv_ln2
+ // * z7: neg_ln2_hi
+ // * z8: neg_ln2_lo
+ // * z9: min_input
+ // * z10: 23, 0
+ // * z11: max_value
+ // * z12-z15: x, r_hi, r, r2
+ // * z16-z19: max_value, shift, z, scale, poly
+ // * z20-z21: n, p1, p12345
+ // * z22-z23: n, p23, p2345
+ // * z24-z25: p45
+ // * z26: beta
+ // * z28-z31: sum_value
+ //
+ // * za0-za3: sum_value
+ //
+ // * p0: all-true
+ // * p1: left-over predicate
+ // * p4-p7: underflow
+ // * pn9: all-true
+
+ // Prepares all constant values
+
+ ptrue p0.b
+ .inst 0x25207811 // ptrue pn9.b
+
+ mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ dup z0.s, w9 // c1.
+ dup z1.s, w10 // c2.
+ dup z2.s, w11 // c3.
+ dup z3.s, w12 // c4.
+ dup z4.s, w13 // c5.
+
+ mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ dup z5.s, w9 // shift
+ dup z6.s, w10 // inv_ln2
+ dup z7.s, w11 // neg_ln2_hi
+ dup z8.s, w12 // neg_ln2_lo
+ dup z9.s, w13 // min_input
+
+ dup z26.s, %w[beta] // beta
+
+ mov w10, #0x0000 // -inf: 0xff800000
+ movk w10, #0xff80 // -inf: 0xff800000
+
+ mov w11, #0 // 0
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+ cntw x13, ALL, MUL #4
+ udiv x9, %x[length], x13
+ mul x13, x13, x9
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+ mov x20, %x[shape_3]
+ mov x21, %x[src]
+ mov x22, %x[dst]
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x20, #0
+ b.eq loop_3_end%=
+ sub x20, x20, #1
+
+ mov x23, %x[shape_2]
+ mov x24, x21
+ mov x25, x22
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x23, #0
+ b.eq loop_2_end%=
+ sub x23, x23, #1
+
+ mov x26, %x[shape_1]
+ mov x27, x24
+ mov x28, x25
+
+loop_1_start%=:
+ // for index_1 in shape_2 downto 1
+ cmp x26, #0
+ b.eq loop_1_end%=
+ sub x26, x26, #1
+
+ // ==================================================
+ // Step 1: Find max
+ // ==================================================
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+ dup z11.s, w10 // z11: max_value = -inf
+
+ // ---------------------------------------------------------------- z16-z19: max_value = -inf
+ mov z16.d, z11.d
+ mov z17.d, z11.d
+ mov z18.d, z11.d
+ mov z19.d, z11.d
+
+find_max_body_start%=:
+ cmp x9, x13
+ b.eq find_max_body_end%=
+
+ .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] // z12-z15: x
+ .inst 0xc1acb910 // fmax {z16.s-z19.s}, {z16.s-z19.s}, {z12.s-z15.s} // z16-z19: max_value = max(max_value, x)
+
+ incw x9, ALL, MUL #4
+ b find_max_body_start%=
+find_max_body_end%=:
+
+ // Loop for processing the leftover part.
+find_max_leftover_start%=:
+ whilelo p1.s, x9, %x[length]
+ b.none find_max_leftover_end%=
+
+ ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: x
+ fmax z16.s, p1/m, z16.s, z12.s // z16: max_value = max(max_value, x)
+
+ incw x9
+ b find_max_leftover_start%=
+find_max_leftover_end%=:
+
+ // ---------------------------------------------------------------- z16: max_value
+ .inst 0xc1b2b110 // fmax {z16.s-z17.s}, {z16.s-z17.s}, {z18.s-z19.s}
+ fmax z16.s, p0/m, z16.s, z17.s
+ fmaxv s16, p0, z16.s
+
+ // ---------------------------------------------------------------- z11: max_value
+ dup z11.s, z16.s[0]
+
+ // ==================================================
+ // Step 2: Regularize
+ // ==================================================
+
+ .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // ---------------------------------------------------- x9: index
+
+regularize_body_start%=:
+ cmp x9, x13
+ b.eq regularize_body_end%=
+
+ // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data
+ .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2]
+
+ // ---------------------------------------------------------------- z12-z15: x = input_data - max_value
+ fsub z12.s, z12.s, z11.s
+ fsub z13.s, z13.s, z11.s
+ fsub z14.s, z14.s, z11.s
+ fsub z15.s, z15.s, z11.s
+
+ // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta
+ fmul z12.s, z12.s, z26.s
+ fmul z13.s, z13.s, z26.s
+ fmul z14.s, z14.s, z26.s
+ fmul z15.s, z15.s, z26.s
+
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- p4-p7: underflow = x < min_input
+ fcmlt p4.s, p0/z, z12.s, z9.s
+ fcmlt p5.s, p0/z, z13.s, z9.s
+ fcmlt p6.s, p0/z, z14.s, z9.s
+ fcmlt p7.s, p0/z, z15.s, z9.s
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z12.s, z6.s
+ fmla z17.s, p0/m, z13.s, z6.s
+ fmla z18.s, p0/m, z14.s, z6.s
+ fmla z19.s, p0/m, z15.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+ // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p0/m, z20.s, z7.s
+ fmla z13.s, p0/m, z21.s, z7.s
+ fmla z14.s, p0/m, z22.s, z7.s
+ fmla z15.s, p0/m, z23.s, z7.s
+
+ // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo
+ fmla z12.s, p0/m, z20.s, z8.s
+ fmla z13.s, p0/m, z21.s, z8.s
+ fmla z14.s, p0/m, z22.s, z8.s
+ fmla z15.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors.
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z12.s, z0.s
+ fmul z21.s, z13.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z12.s, z2.s
+ fmla z23.s, p0/m, z13.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z35: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z12.s, z4.s
+ fmla z25.s, p0/m, z13.s, z4.s
+
+ // ---------------------------------------------------------------- z12-z13: r2 = r * r
+ fmul z12.s, z12.s, z12.s
+ fmul z13.s, z13.s, z13.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z12.s, z24.s
+ fmla z23.s, p0/m, z13.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z12.s, z22.s
+ fmla z21.s, p0/m, z13.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z14.s, z0.s
+ fmul z21.s, z15.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z14.s, z2.s
+ fmla z23.s, p0/m, z15.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z35: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z14.s, z4.s
+ fmla z25.s, p0/m, z15.s, z4.s
+
+ // ---------------------------------------------------------------- z14-z15: r2 = r * r
+ fmul z14.s, z14.s, z14.s
+ fmul z15.s, z15.s, z15.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z14.s, z24.s
+ fmla z23.s, p0/m, z15.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z14.s, z22.s
+ fmla z21.s, p0/m, z15.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z16.s, p4, z10.s, z16.s
+ sel z17.s, p5, z10.s, z17.s
+ sel z18.s, p6, z10.s, z18.s
+ sel z19.s, p7, z10.s, z19.s
+
+ // Stores 4 consecutive registers to the output
+ .inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2]
+
+ .inst 0xc1a17e00 // fadd za.s[w11, #0, VGx4], {z16.s-z19.s} za0-za3: sum_value = sum_value + poly
+
+ incw x9, ALL, MUL #4
+ b regularize_body_start%=
+regularize_body_end%=:
+
+ // ---------------------------------------------------------------- z28: sum_value
+ .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4]
+ fadd z28.s, z28.s, z29.s
+ fadd z30.s, z30.s, z31.s
+ fadd z28.s, z28.s, z30.s
+
+ // Loop for processing the leftover part.
+regularize_leftover_start%=:
+ whilelo p1.s, x9, %x[length]
+ b.none regularize_leftover_end%=
+
+ ld1w z12.s, p1/z, [x27, x9, LSL #2] // x12: input_data
+
+ fsub z12.s, z12.s, z11.s // z12: x = input_data - max_value
+ fmul z12.s, z12.s, z26.s // z12: x = (input_data - max_value) * beta
+
+ mov z16.d, z5.d // z16: shift
+ fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input
+ fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2
+ fsub z20.s, z16.s, z5.s // z20: n = z - shift
+ fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo
+ dup z10.s, #23 // z10: 23
+ urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n)
+ fmul z20.s, z12.s, z0.s // z20: p1 = r * c1
+ mov z22.d, z1.d // z22: p23 = c2
+ fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3
+ mov z24.d, z3.d // z24: c4
+ fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5
+ fmul z12.s, z12.s, z12.s // z12: r2 = r * r
+ fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45
+ fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345
+ fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale
+ dup z10.s, #0 // z10: 0
+ sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly
+
+ st1w z16.s, p1, [x28, x9, LSL #2]
+
+ fadd z28.s, p1/m, z28.s, z16.s // z28: sum_value = sum_value + poly
+
+ incw x9
+ b regularize_leftover_start%=
+regularize_leftover_end%=:
+
+ // ==================================================
+ // Step 3: Normalize
+ // ==================================================
+
+ // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value
+ fmov s29, #1.0 // 1.0f
+ faddv s28, p0, z28.s
+ fdiv s28, s29, s28
+ dup z28.s, z28.s[0]
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+
+normalize_body_start%=:
+ cmp x9, x13
+ b.eq normalize_body_end%=
+
+ .inst 0xa009c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, x9, LSL #2] // z12-z15: x
+
+ // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value
+ fmul z12.s, z12.s, z28.s
+ fmul z13.s, z13.s, z28.s
+ fmul z14.s, z14.s, z28.s
+ fmul z15.s, z15.s, z28.s
+
+ .inst 0xa029c78c // st1w {z12.s-z15.s}, pn9, [x28, x9, LSL #2]
+
+ incw x9, ALL, MUL #4
+ b normalize_body_start%=
+normalize_body_end%=:
+
+ // Loop for processing the leftover part.
+normalize_leftover_start%=:
+ whilelo p1.s, x9, %x[length]
+ b.none normalize_leftover_end%=
+
+ ld1w z12.s, p1/z, [x28, x9, LSL #2] // z12: x
+ fmul z12.s, z12.s, z28.s // z12: result = x * inv_sum_value
+
+ st1w z12.s, p1, [x28, x9, LSL #2]
+
+ incw x9
+ b normalize_leftover_start%=
+normalize_leftover_end%=:
+
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x24, x24, %x[src_stride_2]
+ add x25, x25, %x[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x21, x21, %x[src_stride_3]
+ add x22, x22, %x[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), //
+ [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
+ [src_stride_3] "r"(src_strides[3]), //
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
+ [dst_stride_3] "r"(dst_strides[3]), //
+ [length] "r"(shape[0]) //
+ : "cc", "memory", //
+ "p0", "p4", "p5", "p6", "p7", "p9", //
+ "x9", "x10", "x11", "x12", "x13", //
+ "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", //
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_fp32_softmax(const ITensor *in, void *const, ITensor *out, const float beta, int axis, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(axis);
+
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &full_shape = dst_info->tensor_shape();
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+
+ const uintptr_t k_shape[] = {
+ full_shape[0],
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ const uintptr_t k_src_strides[] = {
+ src_strides[0],
+ src_strides[1],
+ src_strides[2],
+ src_strides[3],
+ };
+
+ const uintptr_t k_dst_strides[] = {
+ dst_strides[0],
+ dst_strides[1],
+ dst_strides[2],
+ dst_strides[3],
+ };
+
+ const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
+ window[1].start() * src_strides[1] + //
+ window[2].start() * src_strides[2] + //
+ window[3].start() * src_strides[3];
+
+ const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
+ window[1].start() * dst_strides[1] + //
+ window[2].start() * dst_strides[2] + //
+ window[3].start() * dst_strides[3];
+
+ const auto *k_src = reinterpret_cast<const float *>(in->buffer() + k_src_offset);
+ auto *k_dst = reinterpret_cast<float *>(out->buffer() + k_dst_offset);
+
+ sme2_f32_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/cpu/kernels/softmax/list.h b/src/cpu/kernels/softmax/list.h
index c143f6659..1bb8ed50f 100644
--- a/src/cpu/kernels/softmax/list.h
+++ b/src/cpu/kernels/softmax/list.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,13 +30,23 @@ namespace cpu
{
#define DECLARE_SOFTMAX_KERNEL(func_name) \
template <bool IS_LOG> \
- void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+ void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
DECLARE_SOFTMAX_KERNEL(neon_fp32_softmax);
DECLARE_SOFTMAX_KERNEL(neon_fp16_softmax);
DECLARE_SOFTMAX_KERNEL(neon_qasymm8_softmax);
DECLARE_SOFTMAX_KERNEL(neon_qasymm8_signed_softmax);
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+void sme2_fp32_softmax(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+
+void sme2_fp16_softmax(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+
+#endif // ARM_COMPUTE_ENABLE_SME2
+
#undef DECLARE_SOFTMAX_KERNEL
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/operators/CpuConv2d.cpp b/src/cpu/operators/CpuConv2d.cpp
index 19311733d..26ca2ee78 100644
--- a/src/cpu/operators/CpuConv2d.cpp
+++ b/src/cpu/operators/CpuConv2d.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021, 2023 Arm Limited.
+ * Copyright (c) 2017-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -209,12 +209,24 @@ ConvolutionMethod CpuConv2d::get_convolution_method(const ITensorInfo *i
}
else
{
+ const bool gemmDirectConv2d_validates =
+ bool(CpuGemmDirectConv2d::validate(input, weights, nullptr, output, info));
+
// SRGAN
// Output might not be initialized when it is an internal tensor of the layer using the convolution
- if (input->total_size() > 1e7 && (weights->dimension(idx_h) > 7) &&
- (CpuDirectConv2d::validate(input, weights, nullptr, output, conv_info, act_info)))
+ if (input->total_size() > 1e7 && weights->dimension(idx_h) > 7)
{
- return ConvolutionMethod::DIRECT;
+ // This configuration is memory demanding for GEMM method. GEMM_CONV2D which uses indirect convolution
+ // kernels underneath is the best option.
+ if (gemmDirectConv2d_validates)
+ {
+ return ConvolutionMethod::GEMM_CONV2D;
+ }
+ else if (bool(CpuDirectConv2d::validate(input, weights, nullptr, output, conv_info, act_info)))
+ {
+ // NCHW data layout is not supported by GEMM_CONV2D
+ return ConvolutionMethod::DIRECT;
+ }
}
if (input->dimension(idx_c) < 16)
{
@@ -270,7 +282,7 @@ ConvolutionMethod CpuConv2d::get_convolution_method(const ITensorInfo *i
{
return ConvolutionMethod::WINOGRAD;
}
- if (bool(CpuGemmDirectConv2d::validate(input, weights, nullptr, output, info)))
+ if (gemmDirectConv2d_validates)
{
return ConvolutionMethod::GEMM_CONV2D;
}
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index e035de013..905e86c18 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,6 +53,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.fast_mode = info.fast_math();
asm_info.fixed_format = info.fixed_format();
asm_info.weight_format = info.weight_format();
+ asm_info.accumulate = info.accumulate();
asm_info.transpose_b =
info.pretranspose_B(); // The "pretranspose_B" flag here is not the same as the pretranspose_B_array method. The flag here signals to pretranspose_B_array method if we want to perform additional transpose on B before the pretranspose_B_array method
@@ -219,6 +220,16 @@ Status CpuGemm::validate(const ITensorInfo *a,
const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
+ // When using accumulation(in place summation), for now, the only supported values for alpha and beta are 1 respectively 0.
+ // Do the appropriate checks before proceeding.
+ if (gemm_info.accumulate())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(alpha != 1, "Accumulation is not supported when alpha is different from 1");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ (beta != 0 && c != nullptr),
+ "Accumulation is not supported when beta is different from 0 with a non-null bias matrix c");
+ }
+
const bool is_c_bias = beta == 1 && c != nullptr;
const bool run_addition = c != nullptr && beta != 0 && beta != 1;
// Check if we should use the pretransposed_b or original b
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 7460f2020..55d950ff4 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -809,9 +809,16 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
if (!_skip_im2col)
{
// Run input reshaping
- unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
- ITensorPack pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, im2col_output.get()}};
- NEScheduler::get().schedule_op(_im2col_kernel.get(), y_dim, _im2col_kernel->window(), pack);
+ unsigned int hint_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
+ unsigned int x_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
+ unsigned int hint_dim_iterations = _im2col_kernel->window().num_iterations(hint_dim);
+ unsigned int x_dim_iterations = _im2col_kernel->window().num_iterations(x_dim);
+ if (hint_dim_iterations < NEScheduler::get().num_threads() && x_dim_iterations > hint_dim_iterations)
+ {
+ hint_dim = x_dim;
+ }
+ ITensorPack pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, im2col_output.get()}};
+ NEScheduler::get().schedule_op(_im2col_kernel.get(), hint_dim, _im2col_kernel->window(), pack);
gemm_input_to_use = im2col_output.get();
}
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index b25505a85..f3396fbb5 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -65,6 +65,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.activation_info = info.activation_info();
asm_info.output_stage = info.gemmlowp_output_stage();
asm_info.fast_mode = info.fast_math();
+ asm_info.accumulate = info.accumulate();
return asm_info;
}
@@ -127,6 +128,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_reshape_b_only_on_first_run;
_gemm_info = gemm_info;
+ // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+ // It is not needed if the datatype is symmetric, because there is no offset
+ bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
+
_asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
const ITensorInfo *a_to_use = a;
@@ -228,8 +234,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
// Build reduction info
const GEMMLowpReductionKernelInfo reduction_info(a_to_use->dimension(0), false, 0, false);
- // Initialize matrix B reduction kernel only if _a_offset is not equal to 0
- if (_a_offset != 0)
+ if (a_offset_kernel_needed)
{
_vector_sum_col = TensorInfo(compute_reductionA_shape(*b), 1, DataType::S32);
@@ -238,8 +243,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_mtx_b_reduction_kernel->configure(b, &_vector_sum_col, reduction_info);
}
- // Initialize Matrix A reduction kernel only if _b_offset is not equal to 0
- if (_b_offset != 0)
+ if (b_offset_kernel_needed)
{
_vector_sum_row = TensorInfo(compute_reductionB_shape(*a_to_use), 1, DataType::S32);
@@ -260,8 +264,8 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_offset_contribution_output_stage_kernel =
std::make_unique<kernels::CpuGemmLowpOffsetContributionOutputStageKernel>();
_offset_contribution_output_stage_kernel->configure(
- &_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col,
- _b_offset == 0 ? nullptr : &_vector_sum_row, c, _flip_signedness ? &_signed_output : dst,
+ &_mm_result_s32, a_offset_kernel_needed ? &_vector_sum_col : nullptr,
+ b_offset_kernel_needed ? &_vector_sum_row : nullptr, c, _flip_signedness ? &_signed_output : dst,
a->dimension(0), _a_offset, _b_offset, info.gemmlowp_output_stage());
if (_flip_signedness)
@@ -272,6 +276,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
}
else
{
+ // This scale is needed for the s8_f32 kernel where the multiplication output is dequantized to F32.
+ const float dequantize_scale =
+ (dst->data_type() == DataType::F32)
+ ? a->quantization_info().uniform().scale * b->quantization_info().uniform().scale
+ : 1.0f;
// Configure matrix multiply kernel
if (!_assembly_path)
{
@@ -280,9 +289,9 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
}
// Configure offset contribution kernel
_offset_contribution_kernel = std::make_unique<kernels::CpuGemmLowpOffsetContributionKernel>();
- _offset_contribution_kernel->configure(dst, _a_offset == 0 ? nullptr : &_vector_sum_col,
- _b_offset == 0 ? nullptr : &_vector_sum_row, a_to_use->dimension(0),
- _a_offset, _b_offset);
+ _offset_contribution_kernel->configure(dst, a_offset_kernel_needed ? &_vector_sum_col : nullptr,
+ b_offset_kernel_needed ? &_vector_sum_row : nullptr,
+ a_to_use->dimension(0), _a_offset, _b_offset, dequantize_scale);
}
}
// Configure activation
@@ -305,11 +314,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
}
// Request memory for LHS and RHS reshape matrix
- _aux_mem[VectorSumCol] =
- MemoryInfo(offset_int_vec(VectorSumCol),
- !_fused_assembly_path && _a_offset != 0 && _reshape_b_only_on_first_run ? MemoryLifetime::Persistent
- : MemoryLifetime::Temporary,
- _vector_sum_col.total_size());
+ _aux_mem[VectorSumCol] = MemoryInfo(offset_int_vec(VectorSumCol),
+ !_fused_assembly_path && a_offset_kernel_needed && _reshape_b_only_on_first_run
+ ? MemoryLifetime::Persistent
+ : MemoryLifetime::Temporary,
+ _vector_sum_col.total_size());
_aux_mem[VectorSumRow] =
MemoryInfo(offset_int_vec(VectorSumRow), MemoryLifetime::Temporary, _vector_sum_row.total_size());
_aux_mem[TmpA] = MemoryInfo(offset_int_vec(TmpA), MemoryLifetime::Temporary, _tmp_a.total_size());
@@ -333,8 +342,8 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(b, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32, DataType::QASYMM8,
- DataType::QASYMM8_SIGNED);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(c != nullptr &&
+ DataType::QASYMM8_SIGNED, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(c != nullptr && output->data_type() != DataType::F32 &&
gemm_info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::NONE,
"Bias addition not supported in NEGEMMLowpMatrixMultiplyCore for output S32");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
@@ -343,6 +352,16 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
+ // When using accumulation(in place summation), for now, the only supported DataType for output is S32.
+ if (gemm_info.accumulate())
+ {
+#ifdef __arm__
+ ARM_COMPUTE_RETURN_ERROR_MSG("Accumulation is not supported for armv7");
+#endif /* __arm__ */
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE,
+ "Accumulation is not supported for output QASYMM8/QASYMM8_SIGNED");
+ }
+
GEMMInfo info = gemm_info;
const ITensorInfo *matrix_a_info = a;
const ITensorInfo *matrix_b_info = b;
@@ -356,6 +375,10 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
int32_t a_offset = a->quantization_info().uniform().offset;
int32_t b_offset = b->quantization_info().uniform().offset;
+ // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+ bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
+
bool fuse_output_stage = info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE;
if (fuse_output_stage)
{
@@ -478,7 +501,7 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
const GEMMLowpReductionKernelInfo reduction_info(a_to_use->dimension(0), false, 0, false);
// Validate matrix B reduction kernel only if _a_offset is not equal to 0
- if (a_offset != 0)
+ if (a_offset_kernel_needed)
{
info_vector_sum_col = TensorInfo(compute_reductionA_shape(*b), 1, DataType::S32);
@@ -488,7 +511,7 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
}
// Validate Matrix A reduction kernel only if _b_offset is not equal to 0
- if (b_offset != 0)
+ if (b_offset_kernel_needed)
{
info_vector_sum_row = TensorInfo(compute_reductionB_shape(*a), 1, DataType::S32);
@@ -514,9 +537,9 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
// Validate offset contribution kernel
ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuGemmLowpOffsetContributionOutputStageKernel::validate(
- &mm_result_s32_info, a_offset == 0 ? nullptr : &info_vector_sum_col,
- b_offset == 0 ? nullptr : &info_vector_sum_row, c, flip_signedness ? &signed_output : output, a_offset,
- b_offset, info.gemmlowp_output_stage()));
+ &mm_result_s32_info, a_offset_kernel_needed ? &info_vector_sum_col : nullptr,
+ b_offset_kernel_needed ? &info_vector_sum_row : nullptr, c, flip_signedness ? &signed_output : output,
+ a_offset, b_offset, info.gemmlowp_output_stage()));
}
else
{
@@ -534,8 +557,8 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
}
// Validate offset contribution kernel
ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuGemmLowpOffsetContributionKernel::validate(
- output, a_offset == 0 ? nullptr : &info_vector_sum_col, b_offset == 0 ? nullptr : &info_vector_sum_row,
- a_offset, b_offset));
+ output, a_offset_kernel_needed ? &info_vector_sum_col : nullptr,
+ b_offset_kernel_needed ? &info_vector_sum_row : nullptr, a_offset, b_offset));
}
}
@@ -569,6 +592,14 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors)
CpuAuxTensorHandler signed_a(offset_int_vec(SignedA), _signed_a, tensors, false);
CpuAuxTensorHandler signed_output(offset_int_vec(SignedOutput), _signed_output, tensors, false);
+ const QuantizationInfo a_qinfo = a->info()->quantization_info();
+ const QuantizationInfo b_qinfo = b->info()->quantization_info();
+
+ if (a_qinfo.is_dynamic())
+ _a_offset = a_qinfo.uniform().offset;
+ if (b_qinfo.is_dynamic())
+ _b_offset = b_qinfo.uniform().offset;
+
// Convert QASYMM8->QASYMM8_SIGNED
if (_flip_signedness)
{
@@ -651,6 +682,11 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors)
if (_fuse_output_stage)
{
+ if (a_qinfo.is_dynamic())
+ _offset_contribution_output_stage_kernel->set_a_offset(_a_offset);
+ if (b_qinfo.is_dynamic())
+ _offset_contribution_output_stage_kernel->set_b_offset(_b_offset);
+
ITensorPack pack;
pack.add_tensor(TensorType::ACL_SRC_0, mm_result_s32.get());
pack.add_tensor(TensorType::ACL_SRC_1, _a_offset == 0 ? nullptr : vector_sum_col.get());
@@ -664,6 +700,16 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors)
}
else
{
+ if (a_qinfo.is_dynamic())
+ _offset_contribution_kernel->set_a_offset(_a_offset);
+ if (b_qinfo.is_dynamic())
+ _offset_contribution_kernel->set_b_offset(_b_offset);
+ if (a_qinfo.is_dynamic() || b_qinfo.is_dynamic())
+ {
+ const float dequantize_scale = a_qinfo.uniform().scale * b_qinfo.uniform().scale;
+ _offset_contribution_kernel->set_scale(dequantize_scale);
+ }
+
ITensorPack pack;
pack.add_tensor(TensorType::ACL_SRC_0, _a_offset == 0 ? nullptr : vector_sum_col.get());
pack.add_tensor(TensorType::ACL_SRC_1, _b_offset == 0 ? nullptr : vector_sum_row.get());
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
index 78065a895..38121c9bb 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -92,6 +92,7 @@ public:
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |S32 |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |S32 |
* |QASYMM8_SIGNED |QSYMM8 |S32 |S32 |
+ * |QASYMM8_SIGNED |QASYMM8_SIGNED |F32 |F32 |
*
* @note GEMM_LOWP: low precision GEMM kernel
* This kernel performs the following computations:
@@ -100,12 +101,12 @@ public:
* -# Convert b values from QASYMM8 to int32 add b_offset to each of them.
* -# Compute the matrix product of the resulting a * b in int32.
*
- * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8/QASYMM8_SIGNED otherwise
+ * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8/QASYMM8_SIGNED/F32 otherwise
*
* @param[in] a First input tensor info (Matrix A). Data type supported: QASYMM8/QASYMM8_SIGNED.
* @param[in] b Second input tensor info (Matrix B). Data type supported: QASYMM8/QASYMM8_SIGNED/QSYMM8/QSYMM8_PER_CHANNEL.
- * @param[in] c Third input tensor info (Matrix C). It can be a nullptr. Data type supported: S32
- * @param[out] dst Output tensor info. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED
+ * @param[in] c Third input tensor info (Matrix C). It can be a nullptr. Data type supported: S32/F32
+ * @param[out] dst Output tensor info. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED/F32
* @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
* if the reshape of matrix B should be executed only for the first run
*/
diff --git a/src/cpu/operators/CpuMatMul.cpp b/src/cpu/operators/CpuMatMul.cpp
index 89087129c..f68ae9883 100644
--- a/src/cpu/operators/CpuMatMul.cpp
+++ b/src/cpu/operators/CpuMatMul.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -102,8 +102,8 @@ Status CpuMatMul::validate(const ITensorInfo *lhs,
const ActivationLayerInfo &act_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::QASYMM8,
- DataType::QASYMM8_SIGNED);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::BFLOAT16,
+ DataType::QASYMM8, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs->are_values_constant(), "LHS Tensor must be dynamic.");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs->are_values_constant(), "RHS Tensor must be dynamic.");
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs);
@@ -120,6 +120,7 @@ Status CpuMatMul::validate(const ITensorInfo *lhs,
auto gemm_info = AsmGemmInfo();
gemm_info.activation_info = act_info;
gemm_info.fast_mode = settings.fast_math();
+ gemm_info.fixed_format = settings.fixed_format();
// Validate and then permute a/b
if (adj_lhs)
@@ -157,6 +158,14 @@ Status CpuMatMul::validate(const ITensorInfo *lhs,
gemm_info.activation_info, gemm_info.output_stage));
}
+ if (gemm_info.fixed_format)
+ {
+ gemm_info.weight_format = WeightFormat::ANY;
+ arm_compute::WeightFormat expected_weight_format = WeightFormat::ANY;
+ ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, lhs_to_use,
+ rhs_to_use, nullptr, dst, gemm_info));
+ }
+
cpu::CpuGemmAssemblyDispatch::validate(lhs_to_use, rhs_to_use, nullptr, dst, gemm_info);
return Status{};
@@ -221,6 +230,7 @@ void CpuMatMul::configure(ITensorInfo *lhs,
// Fill AsmGemmInfo class object before configuration
_gemm_info.activation_info = act_info;
_gemm_info.fast_mode = settings.fast_math();
+ _gemm_info.fixed_format = settings.fixed_format();
_gemm_info.negated_offsets = false;
lhs_to_use = (_adj_lhs) ? _lhs_transposed : lhs_to_use;
@@ -233,6 +243,18 @@ void CpuMatMul::configure(ITensorInfo *lhs,
_gemm_info.output_stage);
}
+ if (_gemm_info.fixed_format)
+ {
+ _gemm_info.weight_format = WeightFormat::ANY;
+ arm_compute::WeightFormat expected_weight_format = WeightFormat::ANY;
+ ARM_COMPUTE_ERROR_THROW_ON(cpu::CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, &lhs_to_use,
+ &rhs_to_use, nullptr, dst, _gemm_info));
+ // Set gemm weights info to the one returned by has_opt_impl
+ _gemm_info.weight_format = expected_weight_format;
+ // has_opt_impl may return a non fast math kernel, even if we requested one
+ _gemm_info.fast_mode = arm_compute::is_fixed_format_fast_math(expected_weight_format);
+ }
+
// Configure Asm Kernel
_asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
_asm_glue->configure(&lhs_to_use, &rhs_to_use, nullptr, &dst_to_use,
diff --git a/src/cpu/operators/CpuQuantize.cpp b/src/cpu/operators/CpuQuantize.cpp
index 4315499c3..4a3f1827c 100644
--- a/src/cpu/operators/CpuQuantize.cpp
+++ b/src/cpu/operators/CpuQuantize.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -55,7 +55,8 @@ void CpuQuantize::configure(const ITensorInfo *src, ITensorInfo *dst)
void CpuQuantize::run(ITensorPack &tensors)
{
ARM_COMPUTE_ERROR_ON_MSG(tensors.empty(), "No inputs provided");
- NEScheduler::get().schedule_op(_kernel.get(), Window::DimY, _kernel->window(), tensors);
+ auto split_dimension = static_cast<kernels::CpuQuantizeKernel *>(_kernel.get())->get_split_dimension_hint();
+ NEScheduler::get().schedule_op(_kernel.get(), split_dimension, _kernel->window(), tensors);
}
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/operators/CpuSoftmax.cpp b/src/cpu/operators/CpuSoftmax.cpp
index ae14381ad..fecee7d76 100644
--- a/src/cpu/operators/CpuSoftmax.cpp
+++ b/src/cpu/operators/CpuSoftmax.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,15 +41,7 @@ namespace arm_compute
{
namespace cpu
{
-CpuSoftmaxGeneric::CpuSoftmaxGeneric()
- : _permute_input(),
- _permute_output(),
- _softmax_kernel(),
- _tmp(),
- _input_permuted(),
- _output_permuted(),
- _needs_permute(false),
- _aux_mem(InternalTensorIdx::COUNT)
+CpuSoftmaxGeneric::CpuSoftmaxGeneric() : _softmax_kernel(), _tmp(), _aux_mem(InternalTensorIdx::COUNT)
{
}
@@ -63,17 +55,9 @@ void CpuSoftmaxGeneric::configure(const ITensorInfo *src, ITensorInfo *dst, floa
const unsigned int actual_axis =
static_cast<unsigned int>(wrap_around(axis, static_cast<int32_t>(src->num_dimensions())));
- _needs_permute = actual_axis > 0;
+ _axis = actual_axis;
- if (_needs_permute)
- {
- _permute_input.configure(src, &_input_permuted,
- softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
- }
-
- // We want to deal with a 2D input. Either it is the permuted version of the original input (4D case)
- // or it is the original input case (2D case)
- const ITensorInfo *tmp_input = (_needs_permute ? &_input_permuted : src);
+ const ITensorInfo *tmp_input = src;
TensorInfo tensor_info_tmp;
if (is_data_type_quantized_asymmetric(src->data_type()))
@@ -88,20 +72,10 @@ void CpuSoftmaxGeneric::configure(const ITensorInfo *src, ITensorInfo *dst, floa
// Configure kernels
auto sm = std::make_unique<kernels::CpuSoftmaxKernel>();
- if (_needs_permute)
- {
- // The normalization kernel stores the result in a permuted output tensor
- sm->configure(tmp_input, &_output_permuted, beta, is_log, &_tmp);
- // Re-permute the permuted output into the requested (4D) output
- _permute_output.configure(&_output_permuted, dst,
- softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
- }
- else
- {
- // Softmax 2D case
- sm->configure(tmp_input, dst, beta, is_log, &_tmp);
- }
+ // Softmax 2D case
+ sm->configure(tmp_input, dst, beta, is_log, actual_axis, &_tmp);
+
_softmax_kernel = std::move(sm);
if (_tmp.total_size() > 0)
@@ -109,11 +83,6 @@ void CpuSoftmaxGeneric::configure(const ITensorInfo *src, ITensorInfo *dst, floa
_aux_mem[InternalTensorIdx::TMP] =
MemoryInfo(offset_int_vec(InternalTensorIdx::TMP), MemoryLifetime::Temporary, _tmp.total_size());
}
-
- _aux_mem[InternalTensorIdx::PERMUTED_SRC] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_SRC),
- MemoryLifetime::Temporary, _input_permuted.total_size());
- _aux_mem[InternalTensorIdx::PERMUTED_DST] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_DST),
- MemoryLifetime::Temporary, _output_permuted.total_size());
}
Status
@@ -133,25 +102,11 @@ CpuSoftmaxGeneric::validate(const ITensorInfo *src, const ITensorInfo *dst, floa
{
tensor_info_tmp = src->clone()->set_data_type(DataType::F32).set_is_resizable(true);
}
-
const unsigned int actual_axis =
static_cast<unsigned int>(wrap_around(axis, static_cast<int32_t>(src->num_dimensions())));
- const bool needs_permute = actual_axis > 0;
-
- if (needs_permute)
- {
- const PermutationVector permutation_vector =
- softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis);
- const TensorShape permuted_shape =
- misc::shape_calculator::compute_permutation_output_shape(*src, permutation_vector);
- TensorInfo input_permuted(src->clone()->set_tensor_shape(permuted_shape));
- ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(src, &input_permuted, permutation_vector));
- TensorInfo output_permuted(dst->clone()->set_tensor_shape(permuted_shape));
- ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(&output_permuted, dst, permutation_vector));
- }
-
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuSoftmaxKernel::validate(src, dst, beta, is_log, &tensor_info_tmp));
+ ARM_COMPUTE_RETURN_ON_ERROR(
+ kernels::CpuSoftmaxKernel::validate(src, dst, beta, actual_axis, is_log, &tensor_info_tmp));
return Status{};
}
@@ -165,34 +120,17 @@ void CpuSoftmaxGeneric::run(ITensorPack &tensors)
CpuAuxTensorHandler tmp(offset_int_vec(InternalTensorIdx::TMP), _tmp, tensors, true);
- CpuAuxTensorHandler input_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), _input_permuted, tensors, true);
- CpuAuxTensorHandler output_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_DST), _output_permuted, tensors,
- true);
-
ITensorPack softmax_pack;
- if (_needs_permute)
- {
- ITensorPack permute_in_pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, input_permuted.get()}};
- _permute_input.run(permute_in_pack);
+ softmax_pack = {{TensorType::ACL_SRC_0, src}, {TensorType::ACL_DST_0, dst}, {TensorType::ACL_DST_1, tmp.get()}};
- softmax_pack = {{TensorType::ACL_SRC_0, input_permuted.get()},
- {TensorType::ACL_DST_0, output_permuted.get()},
- {TensorType::ACL_DST_1, tmp.get()}};
- }
- else
+ if (_axis == 0)
{
- softmax_pack = {{TensorType::ACL_SRC_0, src}, {TensorType::ACL_DST_0, dst}, {TensorType::ACL_DST_1, tmp.get()}};
+ NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimY, _softmax_kernel->window(), softmax_pack);
}
-
- NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimY, _softmax_kernel->window(), softmax_pack);
-
- if (_needs_permute)
+ else
{
- ITensorPack permute_out_pack;
- permute_out_pack.add_tensor(TensorType::ACL_SRC, output_permuted.get());
- permute_out_pack.add_tensor(TensorType::ACL_DST, dst);
- _permute_output.run(permute_out_pack);
+ NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimX, _softmax_kernel->window(), softmax_pack);
}
}
diff --git a/src/cpu/operators/CpuSoftmax.h b/src/cpu/operators/CpuSoftmax.h
index 47020e9b7..6ba3476ef 100644
--- a/src/cpu/operators/CpuSoftmax.h
+++ b/src/cpu/operators/CpuSoftmax.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -89,16 +89,13 @@ private:
COUNT
};
- CpuPermute _permute_input;
- CpuPermute _permute_output;
std::unique_ptr<ICPPKernel> _softmax_kernel;
TensorInfo _tmp;
- TensorInfo _input_permuted;
- TensorInfo _output_permuted;
- bool _needs_permute;
experimental::MemoryRequirements _aux_mem{};
+
+ unsigned int _axis = 0;
};
} // namespace cpu
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 611bc7646..7d8588565 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2023 Arm Limited.
+ * Copyright (c) 2018-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -60,7 +60,8 @@ void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutpu
const TypeInput *src,
int src_ld,
int src_multi_stride,
- unsigned int num_threads)
+ unsigned int num_threads,
+ bool transpose)
{
ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
ARM_COMPUTE_ERROR_ON(num_threads == 0);
@@ -77,7 +78,8 @@ void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutpu
if (start < end)
{
- gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, start, end);
+ gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, transpose, start,
+ end);
}
};
}
@@ -279,6 +281,8 @@ private:
bool _B_pretranspose_required{false};
bool _is_b_constant{true};
bool _is_c_constant{true};
+ bool _run_pre_pretranspose_b{false};
+ bool _B_pre_pretranspose_required{false};
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -443,8 +447,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
const AsmGemmInfo &gemm_info,
const OutputStage &os)
{
- ARM_COMPUTE_UNUSED(c);
-
_is_b_constant = b->are_values_constant();
_is_c_constant = c ? c->are_values_constant() : true;
@@ -479,16 +481,23 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
_optimised_kernel = std::move(acl_gemm_wrapper);
_gemm_info = gemm_info;
+
// Check if we need to pre-pretranspose B. Fixed format kernels need no pre-pretranspose.
- const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
- if (run_pre_pretranspose_b)
+ _B_pre_pretranspose_required = _gemm_info.transpose_b && !isVarWeightsKernel();
+ _B_pretranspose_required = _gemm_kernel_asm->B_pretranspose_required();
+
+ const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
+ const bool kernel_can_fuse_transpose = _B_pretranspose_required && kernel_supports_transpose;
+ _run_pre_pretranspose_b = _B_pre_pretranspose_required && !kernel_can_fuse_transpose;
+
+ if (_run_pre_pretranspose_b)
{
_pre_pretranspose_b = std::make_unique<CpuTranspose>();
_pre_pretranspose_b->configure(b, &_pre_pretransposed_b_info);
MemoryLifetime lifetime;
if (_is_b_constant)
{
- if (_gemm_kernel_asm->B_pretranspose_required())
+ if (_B_pretranspose_required)
{
// PrePretransposedB tensor is only used in prepare(), but is then succeeded by Pretranspose
// So PrePretransposedB can be freed inside prepare()
@@ -513,7 +522,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
}
// Check for pre-transposed support
- if (_gemm_kernel_asm->B_pretranspose_required())
+ if (_B_pretranspose_required)
{
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
@@ -524,7 +533,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
_pretranspose_info = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
_aux_mem[Pretranspose] =
MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
- _B_pretranspose_required = true;
}
// Handle indirect GEMM convolution
@@ -532,6 +540,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
{
configure_indirect(a, b, d, gemm_info);
}
+
+ if (std::is_same<OutputStage, arm_gemm::DequantizeFloat>::value)
+ {
+ // Output dequantization is just the two src scales multiplied together
+ _gemm_kernel_asm->set_dequantize_scale(a->quantization_info().uniform().scale *
+ b->quantization_info().uniform().scale);
+ }
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -550,15 +565,16 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
}
const ITensor *b_to_use = b;
+
// Pre-pretranspose B if required
- const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
CpuAuxTensorHandler pre_pretransposed_b(
offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors,
/*pack_inject: no need to inject into tensors*/
false,
/*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/
- !run_pre_pretranspose_b);
- if (run_pre_pretranspose_b)
+ !_run_pre_pretranspose_b);
+
+ if (_run_pre_pretranspose_b)
{
ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr);
ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}};
@@ -567,7 +583,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
}
// Pretranspose B if required
- if (_gemm_kernel_asm->B_pretranspose_required())
+ if (_B_pretranspose_required)
{
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
@@ -578,13 +594,17 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
+
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(),
- in1_ptr, ldb, multi_stride_b,
- NEScheduler::get().num_threads());
+
+ const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
+ _gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b,
+ NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
b->mark_as_unused();
- // Note that we don't need to mark b_to_use as unused, as if it's been assigned to pre_pretransposed_b, its memory will be auto-managed by the handler
+ // Note that we don't need to mark b_to_use as unused, as if it's been assigned to pre_pretransposed_b,
+ // its memory will be auto-managed by the handler
}
if (_gemm_info.method == AsmConvMethod::Indirect)
@@ -617,6 +637,15 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
auto d = tensors.get_tensor(TensorType::ACL_DST);
ARM_COMPUTE_ERROR_ON_NULLPTR(a, d);
+ // Only update at runtime if the src quantization is dynamic
+ if (std::is_same<OutputStage, arm_gemm::DequantizeFloat>::value &&
+ (a->info()->quantization_info().is_dynamic() || b->info()->quantization_info().is_dynamic()))
+ {
+ // Output dequantization is just the two src scales multiplied together
+ _gemm_kernel_asm->set_dequantize_scale(a->info()->quantization_info().uniform().scale *
+ b->info()->quantization_info().uniform().scale);
+ }
+
int lda = a->info()->strides_in_bytes().y() / a->info()->element_size();
int ldb = 0;
const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size();
@@ -640,12 +669,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const ITensor *b_to_use = b;
// Pre-pretranspose B if required
- const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
CpuAuxTensorHandler pre_pretransposed_b(
offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors,
false /*pack_inject: no need to inject into tensors*/,
- !run_pre_pretranspose_b /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/);
- if (b_to_use && !_is_b_constant && run_pre_pretranspose_b)
+ !_run_pre_pretranspose_b /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/);
+ if (b_to_use && !_is_b_constant && _run_pre_pretranspose_b)
{
ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr);
ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}};
@@ -691,9 +719,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
}
else
{
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(),
- b_ptr, ldb, multi_stride_b,
- NEScheduler::get().num_threads());
+ const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
+ _gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b,
+ NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
}
}
}
@@ -762,7 +791,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
@@ -771,6 +800,39 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
}
template <typename TypeInput, typename TypeOutput>
+void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
+ const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ arm_gemm::Activation activation,
+ const AsmGemmInfo &info)
+{
+ ARM_COMPUTE_UNUSED(activation);
+
+ Params p = extract_parameters(a, b, d, info);
+ const CPUInfo &ci = NEScheduler::get().cpu_info();
+ const unsigned int num_threads = NEScheduler::get().num_threads();
+
+ arm_gemm::GemmConfig cfg;
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
+
+ // Create arm_gemm fallback
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>();
+
+ // Configure requantization info
+ const GEMMLowpOutputStageInfo os_info = info.output_stage;
+
+ arm_gemm::DequantizeFloat gemm_dequant_info{};
+ gemm_dequant_info = arm_gemm::DequantizeFloat(d->quantization_info().uniform().scale);
+
+ fallback->configure(a, b, c, d, args, info, gemm_dequant_info);
+ arm_gemm = std::move(fallback);
+}
+
+template <typename TypeInput, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -787,7 +849,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
@@ -842,7 +904,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// TODO: Incorporate info.transpose_b COMPMID-6595
switch (a->data_type())
{
@@ -886,9 +948,18 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
#if defined(ARM_COMPUTE_ENABLE_BF16)
case DataType::BFLOAT16:
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
- "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ if (d->data_type() == DataType::BFLOAT16)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output");
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ }
break;
}
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
@@ -944,8 +1015,9 @@ Status CpuGemmAssemblyDispatch::validate(
"Only F32 output supported for F32 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16,
"Only F16 output supported for F16 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32,
- "Only F32 output supported for BFLOAT16 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 &&
+ (d->data_type() != DataType::F32 && d->data_type() != DataType::BFLOAT16),
+ "Only F32/BFLOAT16 output supported for BFLOAT16 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32,
"Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32,
@@ -1008,6 +1080,10 @@ void CpuGemmAssemblyDispatch::configure(
{
create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
}
+ else if (d->data_type() == DataType::F32)
+ {
+ create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info);
+ }
else
{
create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
@@ -1016,7 +1092,14 @@ void CpuGemmAssemblyDispatch::configure(
#endif /* __aarch64__ */
#if defined(ARM_COMPUTE_ENABLE_BF16)
case DataType::BFLOAT16:
- create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+ if (d->data_type() == DataType::BFLOAT16)
+ {
+ create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
+ }
+ else
+ {
+ create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+ }
break;
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 671a222fe..44c5c189a 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2023 Arm Limited.
+ * Copyright (c) 2018-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -57,6 +57,7 @@ struct AsmGemmInfo
bool fixed_format{false};
arm_compute::WeightFormat weight_format{arm_compute::WeightFormat::UNSPECIFIED};
bool reshape_b_only_on_first_run{true};
+ bool accumulate{false};
/** Whether we want to perform an additional transpose of b before passing it to gemm or pretranspose_B_array
* @note This transpose b operation is also considered a form of "reshape" or "transform", so should be counted for
* by the reshape_b_only_on_first_run flag
diff --git a/src/dynamic_fusion/runtime/gpu/cl/ClKernelRuntime.cpp b/src/dynamic_fusion/runtime/gpu/cl/ClKernelRuntime.cpp
index 9ca20fa15..eab5cddd0 100644
--- a/src/dynamic_fusion/runtime/gpu/cl/ClKernelRuntime.cpp
+++ b/src/dynamic_fusion/runtime/gpu/cl/ClKernelRuntime.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,12 +26,11 @@
#include "arm_compute/core/CL/ICLTensor.h"
#include "src/core/CL/CLUtils.h"
-#ifdef ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/runtime/gpu/cl/ckw_driver/GpuCkwKernelArgumentsHelpers.h"
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/sketch/gpu/GpuKernelSourceCode.h"
#include "src/gpu/cl/ClKernelLibrary.h"
#include "support/Cast.h"
+
namespace arm_compute
{
namespace experimental
@@ -61,128 +60,6 @@ void ClKernelRuntime::configure(const ClCompileContext &compile_ctx, const GpuKe
_arguments = code.arguments();
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-
-inline void ClKernelRuntime::add_tensor_argument(unsigned int &idx,
- const GpuKernelArgumentInfo &arg,
- const ICLTensor *tensor,
- const Window &arg_slice,
- std::vector<cl::Image2D> &cl_images)
-{
- ARM_COMPUTE_ERROR_ON_NULLPTR(tensor);
-
- switch (arg.type)
- {
- case GpuKernelArgumentInfo::Type::Scalar:
- {
- ARM_COMPUTE_ERROR("Unsupported yet");
- break;
- }
-
- case GpuKernelArgumentInfo::Type::Vector:
- {
- add_1D_tensor_argument(idx, tensor, arg_slice);
- break;
- }
-
- case GpuKernelArgumentInfo::Type::Image:
- {
- add_2D_tensor_argument(idx, tensor, arg_slice);
- break;
- }
- case GpuKernelArgumentInfo::Type::Image_Reinterpret_As_3D:
- {
- add_2D_tensor_argument(idx, tensor, arg_slice);
- const unsigned int total_cross_plane_pad = tensor->info()->padding().top + tensor->info()->padding().bottom;
- _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(total_cross_plane_pad));
- break;
- }
- case GpuKernelArgumentInfo::Type::Image_Export_To_ClImage2D:
- {
- const TensorShape shape2d(tensor->info()->dimension(0) / 4, tensor->info()->dimension(1) *
- tensor->info()->dimension(2) *
- tensor->info()->dimension(3));
- const size_t image_row_pitch = tensor->info()->strides_in_bytes()[1];
- cl::Image2D tensor_image2d =
- create_image2d_from_buffer(CLKernelLibrary::get().context(), tensor->cl_buffer(), shape2d,
- tensor->info()->data_type(), image_row_pitch, CLImage2DType::ReadOnly);
- cl_images.push_back(tensor_image2d);
- _kernel.setArg(idx++, tensor_image2d);
- break;
- }
-
- case GpuKernelArgumentInfo::Type::Image_3D:
- {
- add_2D_tensor_argument(idx, tensor, arg_slice);
- _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(tensor->info()->strides_in_bytes()[2]));
- break;
- }
- case GpuKernelArgumentInfo::Type::Image_3D_Export_To_ClImage2D:
- {
- const TensorShape shape2d(tensor->info()->dimension(0) / 4, tensor->info()->dimension(1) *
- tensor->info()->dimension(2) *
- tensor->info()->dimension(3));
- const size_t image_row_pitch = tensor->info()->strides_in_bytes()[1];
- cl::Image2D tensor_image2d =
- create_image2d_from_buffer(CLKernelLibrary::get().context(), tensor->cl_buffer(), shape2d,
- tensor->info()->data_type(), image_row_pitch, CLImage2DType::ReadOnly);
- cl_images.push_back(tensor_image2d);
- _kernel.setArg(idx++, tensor_image2d);
- _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(tensor->info()->strides_in_bytes()[2]));
- break;
- }
-
- case GpuKernelArgumentInfo::Type::Tensor_3D:
- {
- add_3D_tensor_argument(idx, tensor, arg_slice);
- break;
- }
-
- case GpuKernelArgumentInfo::Type::Tensor_4D:
- {
- add_4D_tensor_argument(idx, tensor, arg_slice);
- break;
- }
- case GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer:
- {
- add_4d_tensor_nhwc_argument(idx, tensor);
- break;
- }
- case GpuKernelArgumentInfo::Type::Tensor_4D_t_Image:
- {
- const size_t image_w = tensor->info()->dimension(0) / 4;
- const size_t image_h = tensor->info()->tensor_shape().total_size_upper(1);
- const size_t image_stride_y = tensor->info()->strides_in_bytes()[1];
-
- cl::Image2D tensor_image2d = create_image2d_from_buffer(
- CLKernelLibrary::get().context(), tensor->cl_buffer(), TensorShape(image_w, image_h),
- tensor->info()->data_type(), image_stride_y, CLImage2DType::ReadOnly);
- cl_images.push_back(tensor_image2d);
-
- _kernel.setArg(idx++, tensor_image2d);
- add_4d_tensor_nhwc_argument(idx, tensor);
- break;
- }
- case GpuKernelArgumentInfo::Type::Tensor_Special_0:
- {
- const ITensorInfo *info = tensor->info();
- const Strides &strides = info->strides_in_bytes();
-
- _kernel.setArg(idx++, tensor->cl_buffer());
- const size_t dim1xdim2 = info->tensor_shape()[1] * info->tensor_shape()[2];
- _kernel.setArg<cl_int>(idx++, static_cast<int32_t>(dim1xdim2));
- const size_t stride1 = strides[1];
- _kernel.setArg<cl_int>(idx++, static_cast<int32_t>(stride1));
- break;
- }
- default:
- {
- ARM_COMPUTE_ERROR("Unsupported");
- }
- }
-}
-
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
inline void ClKernelRuntime::add_kernel_argument(unsigned int &idx,
const GpuKernelArgumentBinding &arg,
const ICLTensor *tensor,
@@ -234,7 +111,6 @@ inline void ClKernelRuntime::add_kernel_argument(unsigned int
}
}
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
void ClKernelRuntime::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
@@ -253,17 +129,7 @@ void ClKernelRuntime::run_op(ITensorPack &tensors, const Window &window, cl::Com
// Set kernel arguments
// CLImages created from tensor arguments. Need to be retained until enqueue
std::vector<cl::Image2D> cl_images;
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- for (auto id_arg : _arguments)
- {
- const auto arg = id_arg.second;
- auto tensor = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(id_arg.first));
- ARM_COMPUTE_ERROR_ON_NULLPTR(tensor);
- ARM_COMPUTE_ERROR_ON_NULLPTR(tensor->info());
- add_tensor_argument(idx, *arg.kernel_argument_info(), tensor, slice, cl_images);
- }
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
for (const auto &arg : _arguments)
{
auto tensor = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(arg.id()));
@@ -271,7 +137,6 @@ void ClKernelRuntime::run_op(ITensorPack &tensors, const Window &window, cl::Com
ARM_COMPUTE_ERROR_ON_NULLPTR(tensor->info());
add_kernel_argument(idx, arg, tensor, cl_images);
}
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
// Dispatch kernel
enqueue(queue, *this, slice, lws_hint(), use_dummy_work_items);
diff --git a/src/dynamic_fusion/runtime/gpu/cl/ClKernelRuntime.h b/src/dynamic_fusion/runtime/gpu/cl/ClKernelRuntime.h
index e78567eb9..148e4db58 100644
--- a/src/dynamic_fusion/runtime/gpu/cl/ClKernelRuntime.h
+++ b/src/dynamic_fusion/runtime/gpu/cl/ClKernelRuntime.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLKERNELRUNTIME
-#define SRC_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLKERNELRUNTIME
+#ifndef ACL_SRC_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLKERNELRUNTIME_H
+#define ACL_SRC_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLKERNELRUNTIME_H
#include "src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h"
#include "src/dynamic_fusion/sketch/gpu/GpuKernelSourceCode.h"
@@ -59,21 +59,6 @@ public:
virtual void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
private:
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- /** Set a kernel tensor argument
- *
- * @param[in,out] idx Index at which to start adding the tensor's arguments. Will be incremented by the number of kernel arguments set.
- * @param[in] arg Kernel argument descriptor accompanying @p tensor
- * @param[in] tensor Tensor to set as an argument of the object's kernel
- * @param[in] arg_slice Window the kernel will be run on
- * @param[out] cl_images Extra cl images created from the tensor (will need to be retained until the kernel is enqueued)
- */
- inline void add_tensor_argument(unsigned int &idx,
- const GpuKernelArgumentInfo &arg,
- const ICLTensor *tensor,
- const Window &arg_slice,
- std::vector<cl::Image2D> &cl_images);
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
/** Set a kernel argument as part of a tensor
*
* @param[in,out] idx Index at which to start adding the tensor's arguments. Will be incremented by the number of kernel arguments set.
@@ -85,7 +70,6 @@ private:
const GpuKernelArgumentBinding &arg,
const ICLTensor *tensor,
std::vector<cl::Image2D> &cl_images);
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
private:
GpuKernelArgumentList _arguments{};
@@ -94,4 +78,4 @@ private:
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLKERNELRUNTIME */
+#endif // ACL_SRC_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLKERNELRUNTIME_H
diff --git a/src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h b/src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h
index 03817173f..c923bf9c1 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h
+++ b/src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELARGUMENT
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELARGUMENT
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELARGUMENT_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELARGUMENT_H
#include "arm_compute/core/TensorInfo.h"
@@ -32,96 +32,6 @@ namespace experimental
{
namespace dynamic_fusion
{
-/** Contain information required to set up a kernel argument at run time
- * @deprecated To be removed along with ClTemplateWriter
- */
-struct GpuKernelArgumentInfo
-{
- /** Enumerate all the tensor arguments variants used by all kernel implementations. */
- enum class Type : int
- {
- Scalar,
-
- Vector,
-
- Image,
- Image_Reinterpret_As_3D,
- Image_Export_To_ClImage2D,
-
- Image_3D, // 3D Tensor represented as a 2D Image + stride_z
- Image_3D_Export_To_ClImage2D,
-
- Tensor_3D,
- Tensor_4D,
- Tensor_4D_t_Buffer,
- Tensor_4D_t_Image,
-
- Tensor_Special_0,
- };
- /** Default constructor */
- GpuKernelArgumentInfo() = default;
- /** Constructor */
- GpuKernelArgumentInfo(Type type) : type{type}
- {
- }
- Type type{Type::Tensor_4D_t_Buffer};
-};
-bool operator==(const GpuKernelArgumentInfo &info0, const GpuKernelArgumentInfo &info1);
-/** Kernel argument information linked with its corresponding @ref ITensorInfo
- * @deprecated To be removed along with ClTemplateWriter
- */
-class GpuKernelArgument
-{
-public:
- /** Constructor
- *
- * @param[in] tensor_info Associated @ref ITensorInfo
- * @param[in] kernel_arg_info Associated @ref GpuKernelArgumentInfo
- */
- GpuKernelArgument(const ITensorInfo &tensor_info, const GpuKernelArgumentInfo &kernel_arg_info)
- : _tensor_info{tensor_info}, _kernel_arg_info{kernel_arg_info}
- {
- }
- /** Get workload tensor id */
- ITensorInfo::Id id() const
- {
- return _tensor_info.id();
- }
- /** Get associated @ref ITensorInfo */
- ITensorInfo *tensor_info()
- {
- return &_tensor_info;
- }
- /** Get associated @ref ITensorInfo */
- const ITensorInfo *tensor_info() const
- {
- return &_tensor_info;
- }
- /** Get associated @ref GpuKernelArgumentInfo */
- GpuKernelArgumentInfo *kernel_argument_info()
- {
- return &_kernel_arg_info;
- }
- /** Get associated @ref GpuKernelArgumentInfo */
- const GpuKernelArgumentInfo *kernel_argument_info() const
- {
- return &_kernel_arg_info;
- }
- /** Check if the associated workload tensor has valid id
- *
- * @return true if has valid id
- * @return false otherwise
- */
- bool has_valid_id() const
- {
- return _tensor_info.has_valid_id();
- }
-
-private:
- TensorInfo _tensor_info{};
- GpuKernelArgumentInfo _kernel_arg_info{};
-};
-#ifdef ACL_INTERNAL_TEST_CKW_IN_DF
/** Describe how the tensor runtime memory can be accessed
*
* Please see documentation under @ref GpuKernelArgumentBinding
@@ -243,9 +153,8 @@ private:
};
Value _value;
};
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELARGUMENT */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELARGUMENT_H
diff --git a/src/dynamic_fusion/sketch/gpu/GpuKernelSourceCode.h b/src/dynamic_fusion/sketch/gpu/GpuKernelSourceCode.h
index 24812cd8a..11d916eec 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuKernelSourceCode.h
+++ b/src/dynamic_fusion/sketch/gpu/GpuKernelSourceCode.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,19 +21,15 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELSOURCECODE
-#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELSOURCECODE
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELSOURCECODE_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELSOURCECODE_H
#include "arm_compute/core/CL/CLCompileContext.h"
#include "arm_compute/core/Window.h"
#include "src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h"
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-#include <map>
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
#include <deque>
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
#include <string>
namespace arm_compute
@@ -43,11 +39,7 @@ namespace experimental
namespace dynamic_fusion
{
/** The argument list of a @ref GpuKernelSourceCode */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-using GpuKernelArgumentList = std::map<ITensorInfo::Id, GpuKernelArgument>;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
using GpuKernelArgumentList = std::deque<GpuKernelArgumentBinding>;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
/** Container of kernel code to be compiled and run in a @ref GpuUnitWorkload
*/
@@ -132,4 +124,4 @@ private:
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELSOURCECODE */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUKERNELSOURCECODE_H
diff --git a/src/dynamic_fusion/sketch/gpu/GpuLogicalKernel.cpp b/src/dynamic_fusion/sketch/gpu/GpuLogicalKernel.cpp
index 502ceab80..725a46e91 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuLogicalKernel.cpp
+++ b/src/dynamic_fusion/sketch/gpu/GpuLogicalKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,14 +26,10 @@
#include "arm_compute/core/experimental/Types.h"
#include "src/dynamic_fusion/sketch/ArgumentPack.h"
+#include "src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.h"
#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.h"
#include "src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h"
#include "src/dynamic_fusion/sketch/gpu/GpuComponentServices.h"
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.h"
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
-#include "src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.h"
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
namespace arm_compute
{
@@ -41,8 +37,8 @@ namespace experimental
{
namespace dynamic_fusion
{
-GpuLogicalKernel::GpuLogicalKernel(GpuComponentServices *services, const GpuKernelComponentGroup &components)
- : _comp_group{components}, _store_components{}
+GpuLogicalKernel::GpuLogicalKernel(GpuComponentServices *services, GpuKernelComponentGroup components) // NOLINT
+ : _comp_group{std::move(components)}, _store_components{}
{
ARM_COMPUTE_UNUSED(services);
}
@@ -50,19 +46,11 @@ GpuLogicalKernel::GpuLogicalKernel(GpuComponentServices *services, const GpuKern
GpuKernelSourceCode GpuLogicalKernel::write_kernel_code()
{
GpuKernelSourceCode code;
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- ClTemplateWriter writer{_comp_group};
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
- GpuCkwDriver writer{_comp_group};
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
+ GpuCkwDriver writer{_comp_group};
code.name(writer.get_name());
code.code(writer.get_code());
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- code.arguments(writer.get_tensors());
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
code.arguments(writer.get_kernel_arguments());
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
code.build_options(writer.get_build_options());
code.config_id(writer.get_config_id());
code.window(writer.get_window());
diff --git a/src/dynamic_fusion/sketch/gpu/GpuLogicalKernel.h b/src/dynamic_fusion/sketch/gpu/GpuLogicalKernel.h
index 1fd40f0ac..e2bc83b28 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuLogicalKernel.h
+++ b/src/dynamic_fusion/sketch/gpu/GpuLogicalKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022 Arm Limited.
+ * Copyright (c) 2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_GPULOGICALKERNEL
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_GPULOGICALKERNEL
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPULOGICALKERNEL_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPULOGICALKERNEL_H
#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
#include "src/dynamic_fusion/sketch/gpu/GpuKernelSourceCode.h"
@@ -52,7 +52,7 @@ public:
* @param[in] services @ref GpuComponentServices to be used
* @param[in] components Component group from which this logical kernel is initialized
*/
- explicit GpuLogicalKernel(GpuComponentServices *services, const GpuKernelComponentGroup &components);
+ explicit GpuLogicalKernel(GpuComponentServices *services, GpuKernelComponentGroup components); // NOLINT
/** Allow instances of this class to be copy constructed */
GpuLogicalKernel(const GpuLogicalKernel &) = default;
/** Allow instances of this class to be copied */
@@ -71,4 +71,4 @@ private:
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_GPULOGICALKERNEL */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPULOGICALKERNEL_H
diff --git a/src/dynamic_fusion/sketch/gpu/GpuWorkloadSourceCode.h b/src/dynamic_fusion/sketch/gpu/GpuWorkloadSourceCode.h
index 43bcc47fa..5d75bcaaa 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuWorkloadSourceCode.h
+++ b/src/dynamic_fusion/sketch/gpu/GpuWorkloadSourceCode.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSOURCECODE
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSOURCECODE
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSOURCECODE_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSOURCECODE_H
#include "arm_compute/core/experimental/Types.h"
#include "arm_compute/dynamic_fusion/sketch/MemoryDescriptor.h"
@@ -36,7 +36,6 @@ namespace experimental
{
namespace dynamic_fusion
{
-#ifdef ACL_INTERNAL_TEST_CKW_IN_DF
namespace
{
/** Extract kernel arguments of one tensor from a flat list of kernel arguments.
@@ -70,7 +69,6 @@ GpuKernelArgumentList extract_kernel_args_for_one_tensor(GpuKernelArgumentList &
return tensor_kargs;
}
} // namespace
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
/** Uniquely identifies a @ref GpuUnitWorkload within a @ref GpuWorkloadSourceCode */
using UnitWorkloadId = int32_t;
@@ -83,25 +81,11 @@ class GpuWorkloadArgument
public:
/** Default constructor */
GpuWorkloadArgument() = default;
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
/** Constructor
*
- * @param[in] tensor_info @ref ITensorInfo of the workload argument
- * @param[in] mem_desc @ref MemoryDescriptor of the workload argument
- * @param[in] kernel_arg_info @ref GpuKernelArgumentInfo of the workload argument
- */
- GpuWorkloadArgument(const ITensorInfo &tensor_info,
- const MemoryDescriptor &mem_desc,
- const GpuKernelArgumentInfo &kernel_arg_info)
- : _tensor_info{tensor_info}, _mem_desc{mem_desc}, _kernel_arg_info{kernel_arg_info}
- {
- }
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
- /** Constructor
- *
- * @param[in] tensor_info @ref ITensorInfo of the workload argument
- * @param[in] mem_desc @ref MemoryDescriptor of the workload argument
- * @param[in] kernel_arg_list @ref GpuKernelArgumentList of the workload argument
+ * @param[in] tensor_info @ref ITensorInfo of the workload argument
+ * @param[in] mem_desc @ref MemoryDescriptor of the workload argument
+ * @param[in] kernel_args @ref GpuKernelArgumentList of the workload argument
*/
GpuWorkloadArgument(const ITensorInfo &tensor_info,
const MemoryDescriptor &mem_desc,
@@ -109,7 +93,6 @@ public:
: _tensor_info{tensor_info}, _mem_desc{mem_desc}, _kernel_args{kernel_args}
{
}
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
/** Get tensor id within workload */
ITensorInfo::Id id() const
{
@@ -135,18 +118,6 @@ public:
{
return &_mem_desc;
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- /** Get @ref GpuKernelArgumentInfo of the argument */
- GpuKernelArgumentInfo *kernel_argument_info()
- {
- return &_kernel_arg_info;
- }
- /** Get @ref GpuKernelArgumentInfo of the argument */
- const GpuKernelArgumentInfo *kernel_argument_info() const
- {
- return &_kernel_arg_info;
- }
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
/** Get @ref GpuKernelArgumentList of the workload tensor */
GpuKernelArgumentList *kernel_argument_list()
{
@@ -157,7 +128,6 @@ public:
{
return &_kernel_args;
}
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
/** Check if the workload argument has valid id
*
* @return true If has valid id
@@ -169,13 +139,9 @@ public:
}
private:
- TensorInfo _tensor_info{};
- MemoryDescriptor _mem_desc{};
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- GpuKernelArgumentInfo _kernel_arg_info{};
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
+ TensorInfo _tensor_info{};
+ MemoryDescriptor _mem_desc{};
GpuKernelArgumentList _kernel_args{};
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
};
/** Describes when a unit workload is run.
@@ -259,22 +225,7 @@ public:
const auto uwk_id = static_cast<UnitWorkloadId>(_unit_workloads.size());
const auto unit_work = GpuUnitWorkload(uwk_id, kernel_code, stage);
_unit_workloads.push_back(unit_work);
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- ARM_COMPUTE_UNUSED(context);
- // Assemble kernel argument with memory descriptor to form workload argument
- for (const auto &id_arg : kernel_code.arguments())
- {
- const auto arg_id = id_arg.first;
- const auto arg = id_arg.second;
- _workload_arguments[arg_id] =
- GpuWorkloadArgument{*arg.tensor_info(), mem_map.at(arg_id), *arg.kernel_argument_info()};
- if (_tensor_uwork_map.find(arg_id) == _tensor_uwork_map.end())
- {
- _tensor_uwork_map[arg_id] = std::set<UnitWorkloadId>();
- }
- _tensor_uwork_map[arg_id].insert(uwk_id);
- }
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
+
GpuKernelArgumentList flat_kernel_args = kernel_code.arguments();
GpuKernelArgumentList tensor_kargs{};
while (true)
@@ -296,7 +247,7 @@ public:
_tensor_uwork_map[tensor_id].insert(uwk_id);
}
}
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
+
return uwk_id;
}
/** Get a unit workload from its id */
@@ -346,4 +297,4 @@ private:
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSOURCECODE */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSOURCECODE_H
diff --git a/src/dynamic_fusion/sketch/gpu/IGpuKernelWriter.h b/src/dynamic_fusion/sketch/gpu/IGpuKernelWriter.h
index ad474674f..84972501d 100644
--- a/src/dynamic_fusion/sketch/gpu/IGpuKernelWriter.h
+++ b/src/dynamic_fusion/sketch/gpu/IGpuKernelWriter.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_IGPUKERNELWRITER
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_IGPUKERNELWRITER
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_IGPUKERNELWRITER_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_IGPUKERNELWRITER_H
#include "arm_compute/core/CL/CLCompileContext.h"
#include "arm_compute/core/Window.h"
@@ -62,23 +62,14 @@ public:
virtual std::string get_config_id() = 0;
/** Generate execution window */
virtual Window get_window() const = 0;
- /** Get the kernel argument lists of the kernel
- * @deprecated To be removed along with ClTemplateWriter
- */
- virtual std::map<ITensorInfo::Id, GpuKernelArgument> get_tensors()
- {
- return {};
- }
-#ifdef ACL_INTERNAL_TEST_CKW_IN_DF
/** Get the flat list of arguments of the kernel*/
virtual GpuKernelArgumentList get_kernel_arguments()
{
return {};
}
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
};
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_IGPUKERNELWRITER */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_IGPUKERNELWRITER_H
diff --git a/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.h b/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.h
index b80ce0d81..f8770920b 100644
--- a/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.h
+++ b/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwDriver.h
@@ -24,15 +24,12 @@
#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_CKW_DRIVER_GPUCKWDRIVER_H
#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_CKW_DRIVER_GPUCKWDRIVER_H
-#include "ckw/Kernel.h"
-
#include "src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h"
#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
#include "src/dynamic_fusion/sketch/gpu/IGpuKernelWriter.h"
#include "compute_kernel_writer/include/ckw/Kernel.h"
#include "compute_kernel_writer/include/ckw/KernelArgument.h"
-#include <map>
#include <string>
namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwStore.h b/src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwStore.h
index f1f0e6747..c9ce7eb26 100644
--- a/src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwStore.h
+++ b/src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwStore.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_CKW_DRIVER_COMPONENTS_GPUCKWSTORE
-#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_CKW_DRIVER_COMPONENTS_GPUCKWSTORE
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_CKW_DRIVER_COMPONENTS_GPUCKWSTORE_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_CKW_DRIVER_COMPONENTS_GPUCKWSTORE_H
#include "src/core/common/Macros.h"
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/IGpuCkwComponentDriver.h"
@@ -33,8 +33,6 @@ namespace experimental
{
namespace dynamic_fusion
{
-/** An interface used by @ref ClTemplateWriter to write source code for a kernel component
- */
class GpuCkwStore : public IGpuCkwComponentDriver
{
public:
@@ -61,4 +59,4 @@ private:
} // namespace experimental
} // namespace arm_compute
-#endif /* ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_CKW_DRIVER_COMPONENTS_GPUCKWSTORE */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_CKW_DRIVER_COMPONENTS_GPUCKWSTORE_H
diff --git a/src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h b/src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h
index 4b8eea2f5..6678c929e 100644
--- a/src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h
+++ b/src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_IGPUKERNELCOMPONENT
-#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_IGPUKERNELCOMPONENT
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_IGPUKERNELCOMPONENT_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_IGPUKERNELCOMPONENT_H
#include "src/dynamic_fusion/sketch/ArgumentPack.h"
#include "src/dynamic_fusion/sketch/gpu/GpuWorkloadSourceCode.h"
@@ -100,10 +100,6 @@ public:
return _properties;
}
/** Get writer for the component */
- virtual const IGpuTemplateComponentWriter *template_writer() const
- {
- return nullptr;
- }
virtual const IGpuCkwComponentDriver *ckw_component_driver() const
{
return nullptr;
@@ -119,4 +115,4 @@ private:
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_IGPUKERNELCOMPONENT */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_IGPUKERNELCOMPONENT_H
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.cpp
index fdf528a65..e316bdf46 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,11 +24,7 @@
#include "ClComponentActivation.h"
#include "src/core/CL/CLValidate.h"
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.h"
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwActivation.h"
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
namespace arm_compute
{
@@ -69,11 +65,7 @@ ClComponentActivation::ClComponentActivation(ComponentId
const ArgumentPack<ITensorInfo> &tensors,
const Attributes &attributes)
: IGpuKernelComponent{id, properties, tensors},
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer{std::make_unique<ClTemplateActivation>(id, tensors, attributes)}
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
_component_writer{std::make_unique<GpuCkwActivation>(id, tensors, attributes)}
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
}
@@ -81,11 +73,7 @@ ClComponentActivation::~ClComponentActivation()
{
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-const IGpuTemplateComponentWriter *ClComponentActivation::template_writer() const
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
const IGpuCkwComponentDriver *ClComponentActivation::ckw_component_driver() const
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
return _component_writer.get();
}
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.h
index 02c854356..b8185158f 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.h
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTACTIVATION
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTACTIVATION
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTACTIVATION_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTACTIVATION_H
#include "arm_compute/function_info/ActivationLayerInfo.h"
@@ -41,11 +41,7 @@ template <typename T>
class ArgumentPack;
/** Forward declaration */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-class ClTemplateActivation;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
class GpuCkwActivation;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
class ClComponentActivation final : public IGpuKernelComponent
{
@@ -106,11 +102,7 @@ public:
ClComponentActivation &operator=(ClComponentActivation &&component) = default;
/** Get writer for the component */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuTemplateComponentWriter *template_writer() const override;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuCkwComponentDriver *ckw_component_driver() const override;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
+ const IGpuCkwComponentDriver *ckw_component_driver() const override;
/** Get component type */
GpuComponentType type() const override
@@ -119,13 +111,9 @@ public:
}
private:
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- std::unique_ptr<ClTemplateActivation> _component_writer;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
std::unique_ptr<GpuCkwActivation> _component_writer;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
};
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTACTIVATION */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTACTIVATION_H
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.cpp
index b1636795a..e1850d78c 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,11 +27,7 @@
#include "src/core/CL/CLValidate.h"
#include "src/dynamic_fusion/sketch/ArgumentPack.h"
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.h"
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwCast.h"
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
namespace arm_compute
{
@@ -72,22 +68,16 @@ ClComponentCast::ClComponentCast(ComponentId id,
const Attributes &attributes,
const Settings &settings)
: IGpuKernelComponent{id, properties, tensors},
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer{std::make_unique<ClTemplateCast>(id, tensors, attributes)}
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
_component_writer{std::make_unique<GpuCkwCast>(id, tensors, attributes)}
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
ARM_COMPUTE_UNUSED(attributes, settings);
}
+
ClComponentCast::~ClComponentCast()
{
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-const IGpuTemplateComponentWriter *ClComponentCast::template_writer() const
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
+
const IGpuCkwComponentDriver *ClComponentCast::ckw_component_driver() const
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
return _component_writer.get();
}
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.h
index ed77b1203..201dacc28 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.h
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTCAST
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTCAST
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTCAST_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTCAST_H
#include "arm_compute/dynamic_fusion/sketch/attributes/CastAttributes.h"
@@ -49,11 +49,7 @@ private:
};
/** Forward declaration */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-class ClTemplateCast;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
class GpuCkwCast;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
class ClComponentCast final : public IGpuKernelComponent
{
@@ -120,11 +116,7 @@ public:
/** Allow instances of this class to be moved */
ClComponentCast &operator=(ClComponentCast &&component) = default;
/** Get writer for the component */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuTemplateComponentWriter *template_writer() const override;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
const IGpuCkwComponentDriver *ckw_component_driver() const override;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
/** Get component type */
GpuComponentType type() const override
{
@@ -132,14 +124,10 @@ public:
}
private:
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- std::unique_ptr<ClTemplateCast> _component_writer;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
- std::unique_ptr<GpuCkwCast> _component_writer;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
+ std::unique_ptr<GpuCkwCast> _component_writer;
};
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTCAST */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTCAST_H
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.cpp
index ca8037c39..7cd23d611 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,11 +28,7 @@
#include "arm_compute/dynamic_fusion/sketch/attributes/DepthwiseConv2dAttributes.h"
#include "src/core/CL/CLValidate.h"
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.h"
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwDepthwiseConv2d.h"
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
namespace arm_compute
{
@@ -212,22 +208,14 @@ ClComponentDepthwiseConv2d::ClComponentDepthwiseConv2d(ComponentId
const Attributes &attributes,
const Settings &settings)
: IGpuKernelComponent{id, properties, tensors},
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer{std::make_unique<ClTemplateDepthwiseConv2d>(id, tensors, attributes, settings)}
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
_component_writer{std::make_unique<GpuCkwDepthwiseConv2d>(id, tensors, attributes, settings)}
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
ARM_COMPUTE_UNUSED(attributes, settings);
}
ClComponentDepthwiseConv2d::~ClComponentDepthwiseConv2d()
{
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-const IGpuTemplateComponentWriter *ClComponentDepthwiseConv2d::template_writer() const
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
const IGpuCkwComponentDriver *ClComponentDepthwiseConv2d::ckw_component_driver() const
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
return _component_writer.get();
}
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.h
index 01168e9de..7526361f1 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.h
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,11 +44,7 @@ class ArgumentPack;
class DepthwiseConv2dAttributes;
/** Forward declaration */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-class ClTemplateDepthwiseConv2d;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
class GpuCkwDepthwiseConv2d;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
/** Component specific settings
*/
@@ -161,13 +157,8 @@ public:
ClComponentDepthwiseConv2d(ClComponentDepthwiseConv2d &&component) = default;
/** Allow instances of this class to be moved */
ClComponentDepthwiseConv2d &operator=(ClComponentDepthwiseConv2d &&component) = default;
- /** Get template writer for the component */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuTemplateComponentWriter *template_writer() const override;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuCkwComponentDriver *ckw_component_driver() const override;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
-
+ /** Get writer for the component */
+ const IGpuCkwComponentDriver *ckw_component_driver() const override;
/** Get component type */
GpuComponentType type() const override
{
@@ -175,11 +166,7 @@ public:
}
private:
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- std::unique_ptr<ClTemplateDepthwiseConv2d> _component_writer;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
std::unique_ptr<GpuCkwDepthwiseConv2d> _component_writer;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
};
} // namespace dynamic_fusion
} // namespace experimental
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp
index 98f3d6a88..783a17df3 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,12 +28,7 @@
#include "arm_compute/dynamic_fusion/sketch/attributes/Conv2dAttributes.h"
#include "src/core/CL/CLValidate.h"
-
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.h"
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwDirectConv2d.h"
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
namespace arm_compute
{
@@ -153,11 +148,7 @@ ClComponentDirectConv2d::ClComponentDirectConv2d(ComponentId
const Attributes &attributes,
const Settings &settings)
: IGpuKernelComponent{id, properties, tensors},
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer{std::make_unique<ClTemplateDirectConv2d>(id, tensors, attributes, settings)}
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
_component_writer{std::make_unique<GpuCkwDirectConv2d>(id, tensors, attributes, settings)}
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
{
}
@@ -165,11 +156,7 @@ ClComponentDirectConv2d::~ClComponentDirectConv2d()
{
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-const IGpuTemplateComponentWriter *ClComponentDirectConv2d::template_writer() const
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
const IGpuCkwComponentDriver *ClComponentDirectConv2d::ckw_component_driver() const
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
{
return _component_writer.get();
}
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h
index d6d9705d3..c50b0fa0c 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTDIRECTCONV2D
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTDIRECTCONV2D
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTDIRECTCONV2D_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTDIRECTCONV2D_H
#include "arm_compute/core/Error.h"
#include "arm_compute/core/KernelDescriptors.h"
@@ -68,11 +68,7 @@ private:
};
/** Forward declaration */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-class ClTemplateDirectConv2d;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
class GpuCkwDirectConv2d;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
class ClComponentDirectConv2d final : public IGpuKernelComponent
{
@@ -139,11 +135,7 @@ public:
/** Allow instances of this class to be moved */
ClComponentDirectConv2d &operator=(ClComponentDirectConv2d &&component) = default;
/** Get writer for the component */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuTemplateComponentWriter *template_writer() const override;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuCkwComponentDriver *ckw_component_driver() const override;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
+ const IGpuCkwComponentDriver *ckw_component_driver() const override;
/** Get component type */
GpuComponentType type() const override
{
@@ -151,13 +143,9 @@ public:
}
private:
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- std::unique_ptr<ClTemplateDirectConv2d> _component_writer;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
std::unique_ptr<GpuCkwDirectConv2d> _component_writer;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
};
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTDIRECTCONV2D */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTDIRECTCONV2D_H
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp
index 5b136427e..209c73dbe 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,11 +26,7 @@
#include "arm_compute/core/Validate.h"
#include "src/core/CL/CLValidate.h"
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.h"
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwElementwiseBinary.h"
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
namespace arm_compute
{
@@ -117,19 +113,11 @@ ClComponentElementwiseBinary::ClComponentElementwiseBinary(ComponentId
const ArgumentPack<ITensorInfo> &tensors,
const Attributes &attributes)
: IGpuKernelComponent{id, properties, tensors},
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer{std::make_unique<ClTemplateElementwiseBinary>(id, tensors, attributes)}
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
_component_writer{std::make_unique<GpuCkwElementwiseBinary>(id, tensors, attributes)}
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-const IGpuTemplateComponentWriter *ClComponentElementwiseBinary::template_writer() const
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
const IGpuCkwComponentDriver *ClComponentElementwiseBinary::ckw_component_driver() const
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
return _component_writer.get();
}
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.h
index 7589b9732..a4395a621 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.h
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTELEMENTWISEBINARY
-#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTELEMENTWISEBINARY
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTELEMENTWISEBINARY_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTELEMENTWISEBINARY_H
#include "src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h"
#include "src/dynamic_fusion/sketch/gpu/operators/internal/GpuElementwiseBinaryCommon.h"
@@ -40,11 +40,7 @@ template <typename T>
class ArgumentPack;
/** Forward declaration */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-class ClTemplateElementwiseBinary;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
class GpuCkwElementwiseBinary;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
class ClComponentElementwiseBinary final : public IGpuKernelComponent
{
@@ -105,12 +101,7 @@ public:
/** Allow instances of this class to be moved */
ClComponentElementwiseBinary &operator=(ClComponentElementwiseBinary &&component) = default;
/** Get writer for the component */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuTemplateComponentWriter *template_writer() const override;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuCkwComponentDriver *ckw_component_driver() const override;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
-
+ const IGpuCkwComponentDriver *ckw_component_driver() const override;
/** Get component type */
GpuComponentType type() const override
{
@@ -118,13 +109,9 @@ public:
}
private:
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- std::unique_ptr<ClTemplateElementwiseBinary> _component_writer;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
std::unique_ptr<GpuCkwElementwiseBinary> _component_writer;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
};
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTELEMENTWISEBINARY */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTELEMENTWISEBINARY_H
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.cpp
deleted file mode 100644
index 27c13bd65..000000000
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.cpp
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.h"
-
-#include "arm_compute/core/CL/CLHelpers.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/core/Validate.h"
-#include "arm_compute/dynamic_fusion/sketch/attributes/SoftmaxAttributes.h"
-
-#include "src/core/CL/CLValidate.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-Status ClComponentLogits1DMaxShiftExpSum::validate(const Properties &properties,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes)
-{
- ARM_COMPUTE_UNUSED(properties, attributes);
-
- const ITensorInfo *src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
- const ITensorInfo *sum = tensors.get_const_tensor(TensorType::ACL_DST_0);
- const ITensorInfo *dst = tensors.get_const_tensor(TensorType::ACL_DST_1);
-
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(sum);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(dst);
-
- // 1. Check validity
- // All tensor infos are initialized
- ARM_COMPUTE_RETURN_ERROR_ON(src->tensor_shape().total_size() == 0);
- ARM_COMPUTE_RETURN_ERROR_ON(sum->tensor_shape().total_size() == 0);
- ARM_COMPUTE_RETURN_ERROR_ON(dst->tensor_shape().total_size() == 0);
-
- // Check for mismatches in shapes and data types
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst, sum);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(src, dst);
-
- // Device requirements are met
- ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src);
-
- // 2. Check support level
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
-
- return Status{};
-}
-
-ClComponentLogits1DMaxShiftExpSum::ClComponentLogits1DMaxShiftExpSum(ComponentId id,
- const Properties &properties,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes)
- : IGpuKernelComponent{id, properties, tensors},
- _component_writer{std::make_unique<ClTemplateLogits1DMaxShiftExpSum>(id, tensors, attributes)}
-{
-}
-
-ClComponentLogits1DMaxShiftExpSum::~ClComponentLogits1DMaxShiftExpSum()
-{
-}
-
-const IGpuTemplateComponentWriter *ClComponentLogits1DMaxShiftExpSum::template_writer() const
-{
- return _component_writer.get();
-}
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.h
deleted file mode 100644
index 91ab5de3b..000000000
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.h
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTLOGITS1DMAXSHIFTEXPSUM
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTLOGITS1DMAXSHIFTEXPSUM
-
-#include "arm_compute/dynamic_fusion/sketch/attributes/SoftmaxAttributes.h"
-
-#include "src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h"
-
-namespace arm_compute
-{
-/** Forward declaration */
-class ITensorInfo;
-namespace experimental
-{
-namespace dynamic_fusion
-{
-/** Forward declaration */
-template <typename T>
-class ArgumentPack;
-
-/** Forward declaration */
-class ClTemplateLogits1DMaxShiftExpSum;
-
-/** Component to calculate max-shifted exponentials and their sum
- *
- * 1D example:
- * input: [x1, x2, ... , xn], shape: (1 x d)
- *
- * Let max(x1...xn) = m
- *
- * (output) sum: [exp(x1-m) + ... + exp(xn-m)], shape: (1 x 1)
- * (output) dst: [exp(x1-m) ... exp(xn-m)], shape: (1 x d)
- *
- * This component is used by the softmax operator. The subsequent
- * operation normalizes dst with sum, therefore the max-shifting
- * since exp(m) will be cancelled in numerator and denominator.
-*/
-class ClComponentLogits1DMaxShiftExpSum final : public IGpuKernelComponent
-{
-public:
- /** Attributes are a set of backend-agnostic parameters that define what a component does */
- using Attributes = SoftmaxAttributes;
-
- /** Validate the component
- *
- * @param[in] properties Component properties @ref Properties
- * @param[in] tensors Tensor arguments to the component
- * @param[in] attributes Component attributes @ref Attributes
- *
- * @return Status Validation results
- *
- * Tensor argument names:
- * - ACL_SRC_0: Input
- * - ACL_DST_0: Output
- * - ACL_DST_1: Output
- *
- * Tensor argument constness:
- * - ACL_SRC_0: Const
- * - ACL_DST_0: Const
- * - ACL_DST_1: Const
- *
- * Valid data layouts:
- * - All
- *
- ** Valid data type configurations:
- * |ACL_SRC_0 |ACL_DST_0 |ACL_DST_1 |
- * |:----------|:----------|:----------|
- * |F16 | F16 | F16 |
- * |F32 | F32 | F32 |
- */
- static Status
- validate(const Properties &properties, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes);
-
- /** Constructor
- *
- * Similar to @ref ClComponentLogits1DMaxShiftExpSum::validate()
- */
- ClComponentLogits1DMaxShiftExpSum(ComponentId id,
- const Properties &properties,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes);
-
- /** Destructor */
- ~ClComponentLogits1DMaxShiftExpSum() override;
- /** Prevent instances of this class from being copy constructed */
- ClComponentLogits1DMaxShiftExpSum(const ClComponentLogits1DMaxShiftExpSum &component) = delete;
- /** Prevent instances of this class from being copied */
- ClComponentLogits1DMaxShiftExpSum &operator=(const ClComponentLogits1DMaxShiftExpSum &component) = delete;
- /** Allow instances of this class to be move constructed */
- ClComponentLogits1DMaxShiftExpSum(ClComponentLogits1DMaxShiftExpSum &&component) = default;
- /** Allow instances of this class to be moved */
- ClComponentLogits1DMaxShiftExpSum &operator=(ClComponentLogits1DMaxShiftExpSum &&component) = default;
- /** Get template writer for the component */
- const IGpuTemplateComponentWriter *template_writer() const override;
- /** Get component type */
- GpuComponentType type() const override
- {
- return GpuComponentType::Unfusable;
- }
-
-private:
- std::unique_ptr<ClTemplateLogits1DMaxShiftExpSum> _component_writer;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTLOGITS1DMAXSHIFTEXPSUM */
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.cpp
deleted file mode 100644
index fb2544385..000000000
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.cpp
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.h"
-
-#include "arm_compute/core/CL/CLHelpers.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/core/Validate.h"
-#include "arm_compute/dynamic_fusion/sketch/attributes/SoftmaxAttributes.h"
-
-#include "src/core/CL/CLValidate.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-Status ClComponentLogits1DNorm::validate(const Properties &properties,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes)
-{
- ARM_COMPUTE_UNUSED(properties, attributes);
-
- const ITensorInfo *src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
- const ITensorInfo *sum = tensors.get_const_tensor(TensorType::ACL_SRC_1);
- const ITensorInfo *dst = tensors.get_const_tensor(TensorType::ACL_DST_0);
-
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(sum);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(dst);
-
- // 1. Check validity
- // All tensor infos are initialized
- ARM_COMPUTE_RETURN_ERROR_ON(src->tensor_shape().total_size() == 0);
- ARM_COMPUTE_RETURN_ERROR_ON(sum->tensor_shape().total_size() == 0);
- ARM_COMPUTE_RETURN_ERROR_ON(dst->tensor_shape().total_size() == 0);
-
- // Check for mismatches in shapes and data types
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst, sum);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(src, dst);
-
- ARM_COMPUTE_RETURN_ERROR_ON(attributes.is_log_softmax() && !is_data_type_float(src->data_type()));
-
- // Device requirements are met
- ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src);
-
- // 2. Check support level
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
-
- return Status{};
-}
-
-ClComponentLogits1DNorm::ClComponentLogits1DNorm(ComponentId id,
- const Properties &properties,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes)
- : IGpuKernelComponent{id, properties, tensors},
- _component_writer{std::make_unique<ClTemplateLogits1DNorm>(id, tensors, attributes)}
-{
-}
-
-ClComponentLogits1DNorm::~ClComponentLogits1DNorm()
-{
-}
-
-const IGpuTemplateComponentWriter *ClComponentLogits1DNorm::template_writer() const
-{
- return _component_writer.get();
-}
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.h
deleted file mode 100644
index 74c027360..000000000
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.h
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTLOGITS1DNORM
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTLOGITS1DNORM
-
-#include "arm_compute/dynamic_fusion/sketch/attributes/SoftmaxAttributes.h"
-
-#include "src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h"
-
-namespace arm_compute
-{
-/** Forward declaration */
-class ITensorInfo;
-namespace experimental
-{
-namespace dynamic_fusion
-{
-/** Forward declaration */
-template <typename T>
-class ArgumentPack;
-
-/** Forward declaration */
-class ClTemplateLogits1DNorm;
-
-/** Component to calculate the final step of the Softmax Layer
- * where each logit value is multiplied by the inverse of the sum of the logits.
- *
- * 1D example:
- *
- * (input) src: [x1 x2 ... xn], shape: (1 x d)
- * (input) sum: [x1 + x2 + ... + xn], shape: (1 x 1)
- * (output) dst: [x1/sum x2/sum ... xn/sum], shape: (1 x d)
- *
- * This component is used by the softmax operator to get the final result.
-*/
-class ClComponentLogits1DNorm final : public IGpuKernelComponent
-{
-public:
- /** Attributes are a set of backend-agnostic parameters that define what a component does */
- using Attributes = SoftmaxAttributes;
-
- /** Validate the component
- *
- * @param[in] properties Component properties @ref Properties
- * @param[in] tensors Tensor arguments to the component
- * @param[in] attributes Component attributes @ref Attributes
- *
- * @return Status Validation results
- *
- * Tensor argument names:
- * - ACL_SRC_0: Input
- * - ACL_SRC_1: Input
- * - ACL_DST_0: Output
- *
- * Tensor argument constness:
- * - ACL_SRC_0: Const
- * - ACL_SRC_1: Const
- * - ACL_DST_0: Const
- *
- * Valid data layouts:
- * - All
- *
- ** Valid data type configurations:
- * |ACL_SRC_0 |ACL_SRC_1 |ACL_DST_0 |
- * |:----------|:----------|:----------|
- * |F16 | F16 | F16 |
- * |F32 | F32 | F32 |
- */
- static Status
- validate(const Properties &properties, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes);
-
- /** Constructor
- *
- * Similar to @ref ClComponentLogits1DNorm::validate()
- */
- ClComponentLogits1DNorm(ComponentId id,
- const Properties &properties,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes);
-
- /** Destructor */
- ~ClComponentLogits1DNorm() override;
- /** Prevent instances of this class from being copy constructed */
- ClComponentLogits1DNorm(const ClComponentLogits1DNorm &component) = delete;
- /** Prevent instances of this class from being copied */
- ClComponentLogits1DNorm &operator=(const ClComponentLogits1DNorm &component) = delete;
- /** Allow instances of this class to be move constructed */
- ClComponentLogits1DNorm(ClComponentLogits1DNorm &&component) = default;
- /** Allow instances of this class to be moved */
- ClComponentLogits1DNorm &operator=(ClComponentLogits1DNorm &&component) = default;
- /** Get template writer for the component */
- const IGpuTemplateComponentWriter *template_writer() const override;
- /** Get component type */
- GpuComponentType type() const override
- {
- return GpuComponentType::Unfusable;
- }
-
-private:
- std::unique_ptr<ClTemplateLogits1DNorm> _component_writer;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTLOGITS1DNORM */
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
index f238d42d9..53ac8da41 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,7 +21,6 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifdef ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.h"
@@ -147,5 +146,3 @@ const IGpuCkwComponentDriver *ClComponentMatMul::ckw_component_driver() const
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.cpp
index 5544963b3..6e7243dc0 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.cpp
@@ -30,7 +30,6 @@
#include "src/core/CL/CLValidate.h"
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwPool2d.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.h"
#include "src/dynamic_fusion/utils/Utils.h"
#include <memory>
@@ -93,27 +92,16 @@ ClComponentPool2d::ClComponentPool2d(ComponentId id,
const Attributes &attributes,
const Settings &settings)
: IGpuKernelComponent{id, properties, tensors},
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer{std::make_unique<ClTemplatePool2d>(id, tensors, attributes, settings)}
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
_component_writer{std::make_unique<GpuCkwPool2d>(id, tensors, attributes, settings)}
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
}
ClComponentPool2d::~ClComponentPool2d()
{
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-const IGpuTemplateComponentWriter *ClComponentPool2d::template_writer() const
-{
- return _component_writer.get();
-}
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
const IGpuCkwComponentDriver *ClComponentPool2d::ckw_component_driver() const
{
return _component_writer.get();
}
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.h
index 98fed6500..d33e601f1 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.h
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,11 +42,7 @@ class ArgumentPack;
class Pool2dAttributes;
/** Forward declaration */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-class ClTemplatePool2d;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
class GpuCkwPool2d;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
class ClComponentPool2d final : public IGpuKernelComponent
{
@@ -116,13 +112,9 @@ public:
/** Allow instances of this class to be moved */
ClComponentPool2d &operator=(ClComponentPool2d &&component) = default;
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- /** Get template writer for the component */
- const IGpuTemplateComponentWriter *template_writer() const override;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
+
/** Get GPU kernel writer for the component */
const IGpuCkwComponentDriver *ckw_component_driver() const override;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
/** Get component type */
GpuComponentType type() const override
@@ -131,11 +123,7 @@ public:
}
private:
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- std::unique_ptr<ClTemplatePool2d> _component_writer;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
std::unique_ptr<GpuCkwPool2d> _component_writer;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
};
} // namespace dynamic_fusion
} // namespace experimental
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.cpp
index 0ece9de97..dce85c424 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,7 +27,6 @@
#include "arm_compute/core/Validate.h"
#include "src/core/CL/CLValidate.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.h"
namespace arm_compute
{
@@ -54,15 +53,16 @@ Status ClComponentReshape::validate(const ArgumentPack<ITensorInfo> &tensors)
ClComponentReshape::ClComponentReshape(ComponentId id,
const Properties &properties,
const ArgumentPack<ITensorInfo> &tensors)
- : IGpuKernelComponent{id, properties, tensors}, _component_writer{std::make_unique<ClTemplateReshape>(id, tensors)}
+ : IGpuKernelComponent{id, properties, tensors}
{
}
ClComponentReshape::~ClComponentReshape()
{
}
-const IGpuTemplateComponentWriter *ClComponentReshape::template_writer() const
+const IGpuCkwComponentDriver *ClComponentReshape::ckw_component_driver() const
{
- return _component_writer.get();
+ /* NOT IMPLEMENTED */
+ return nullptr;
}
} // namespace dynamic_fusion
} // namespace experimental
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.h
index 78163d660..fd0f966da 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.h
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESHAPE
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESHAPE
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESHAPE_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESHAPE_H
#include "src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h"
@@ -85,8 +85,8 @@ public:
ClComponentReshape(ClComponentReshape &&component) = default;
/** Allow instances of this class to be moved */
ClComponentReshape &operator=(ClComponentReshape &&component) = default;
- /** Get template writer for the component */
- const IGpuTemplateComponentWriter *template_writer() const override;
+ /** Get writer for the component */
+ const IGpuCkwComponentDriver *ckw_component_driver() const override;
/** Get component type */
GpuComponentType type() const override
{
@@ -94,10 +94,9 @@ public:
}
private:
- std::unique_ptr<ClTemplateReshape> _component_writer;
};
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESHAPE */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESHAPE_H
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.cpp
index b05eb0469..411eeca80 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -29,12 +29,7 @@
#include "src/core/CL/CLValidate.h"
#include "src/core/utils/ScaleUtils.h"
#include "src/dynamic_fusion/sketch/ArgumentPack.h"
-
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.h"
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwResize.h"
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
namespace arm_compute
{
@@ -43,11 +38,7 @@ namespace experimental
namespace dynamic_fusion
{
/** Forward declaration */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-class ClTemplateResize;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
class GpuCkwResize;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
Status ClComponentResize::validate(const IGpuKernelComponent::Properties &properties,
const ArgumentPack<ITensorInfo> &tensors,
@@ -82,11 +73,7 @@ ClComponentResize::ClComponentResize(ComponentId id,
const ArgumentPack<ITensorInfo> &tensors,
const ClComponentResize::Attributes &attributes)
: IGpuKernelComponent{id, properties, tensors},
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer{std::make_unique<ClTemplateResize>(id, tensors, attributes)}
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
_component_writer{std::make_unique<GpuCkwResize>(id, tensors, attributes)}
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
{
}
@@ -94,11 +81,7 @@ ClComponentResize::~ClComponentResize()
{
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-const IGpuTemplateComponentWriter *ClComponentResize::template_writer() const
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
const IGpuCkwComponentDriver *ClComponentResize::ckw_component_driver() const
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
{
return _component_writer.get();
}
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.h
index 29276c325..9a1169c45 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.h
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -22,8 +22,8 @@
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESIZE
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESIZE
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESIZE_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESIZE_H
#include "arm_compute/dynamic_fusion/sketch/attributes/ResizeAttributes.h"
@@ -42,11 +42,7 @@ template <typename T>
class ArgumentPack;
/** Forward declaration */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-class ClTemplateResize;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
class GpuCkwResize;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
class ClComponentResize final : public IGpuKernelComponent
{
@@ -111,11 +107,7 @@ public:
ClComponentResize &operator=(ClComponentResize &&component) = default;
/** Get writer for the component */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuTemplateComponentWriter *template_writer() const override;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
const IGpuCkwComponentDriver *ckw_component_driver() const override;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
/** Get component type */
GpuComponentType type() const override
@@ -124,15 +116,11 @@ public:
}
private:
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- std::unique_ptr<ClTemplateResize> _component_writer;
-#else // ACL_INTERNAL_TEST_CKW_IN_DF
std::unique_ptr<GpuCkwResize> _component_writer;
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
};
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESIZE */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTRESIZE_H
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.cpp
index dcbecaff3..3db6c5cd2 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,11 +24,7 @@
#include "ClComponentStore.h"
#include "src/dynamic_fusion/sketch/ArgumentPack.h"
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.h"
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwStore.h"
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
#include <memory>
@@ -46,22 +42,13 @@ Status ClComponentStore::validate(const Properties &properties, const ArgumentPa
ClComponentStore::ClComponentStore(ComponentId id,
const Properties &properties,
const ArgumentPack<ITensorInfo> &tensors)
- : IGpuKernelComponent{id, properties, tensors},
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer{std::make_unique<ClTemplateStore>(id, tensors)}
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer{std::make_unique<GpuCkwStore>(id, tensors)}
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
+ : IGpuKernelComponent{id, properties, tensors}, _component_writer{std::make_unique<GpuCkwStore>(id, tensors)}
{
}
ClComponentStore::~ClComponentStore()
{
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-const IGpuTemplateComponentWriter *ClComponentStore::template_writer() const
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
const IGpuCkwComponentDriver *ClComponentStore::ckw_component_driver() const
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
return _component_writer.get();
}
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.h
index 948785c48..2c1dd0f6f 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.h
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentStore.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTSTORE
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTSTORE
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTSTORE_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTSTORE_H
#include "src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h"
@@ -39,11 +39,7 @@ namespace dynamic_fusion
/** Forward declaration */
template <typename T>
class ArgumentPack;
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
-class ClTemplateStore;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
class GpuCkwStore;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
class ClComponentStore final : public IGpuKernelComponent
{
@@ -88,11 +84,7 @@ public:
/** Allow instances of this class to be moved */
ClComponentStore &operator=(ClComponentStore &&component) = default;
/** Get writer for the component */
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- const IGpuTemplateComponentWriter *template_writer() const override;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
const IGpuCkwComponentDriver *ckw_component_driver() const override;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
/** Get component type */
GpuComponentType type() const override
{
@@ -100,13 +92,9 @@ public:
}
private:
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- std::unique_ptr<ClTemplateStore> _component_writer;
-#else //ACL_INTERNAL_TEST_CKW_IN_DF
- std::unique_ptr<GpuCkwStore> _component_writer;
-#endif //ACL_INTERNAL_TEST_CKW_IN_DF
+ std::unique_ptr<GpuCkwStore> _component_writer;
};
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTSTORE */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTSTORE_H
diff --git a/src/dynamic_fusion/sketch/gpu/operators/GpuClamp.cpp b/src/dynamic_fusion/sketch/gpu/operators/GpuClamp.cpp
index 697b7d4e1..4d6e7f81b 100644
--- a/src/dynamic_fusion/sketch/gpu/operators/GpuClamp.cpp
+++ b/src/dynamic_fusion/sketch/gpu/operators/GpuClamp.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,7 +30,6 @@
#include "src/dynamic_fusion/sketch/ArgumentPack.h"
#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.h"
#include "src/dynamic_fusion/sketch/gpu/GpuWorkloadSketchImpl.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.h"
namespace arm_compute
{
diff --git a/src/dynamic_fusion/sketch/gpu/operators/GpuMatMul.cpp b/src/dynamic_fusion/sketch/gpu/operators/GpuMatMul.cpp
index e24629a03..2997b28ec 100644
--- a/src/dynamic_fusion/sketch/gpu/operators/GpuMatMul.cpp
+++ b/src/dynamic_fusion/sketch/gpu/operators/GpuMatMul.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,7 +21,6 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifdef ACL_INTERNAL_TEST_CKW_IN_DF
#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuMatMul.h"
@@ -244,4 +243,3 @@ ITensorInfo *GpuMatMul::create_op(GpuWorkloadSketch &sketch,
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
diff --git a/src/dynamic_fusion/sketch/gpu/operators/GpuSoftmax.cpp b/src/dynamic_fusion/sketch/gpu/operators/GpuSoftmax.cpp
index 431c9110f..d38575220 100644
--- a/src/dynamic_fusion/sketch/gpu/operators/GpuSoftmax.cpp
+++ b/src/dynamic_fusion/sketch/gpu/operators/GpuSoftmax.cpp
@@ -28,8 +28,6 @@
#include "src/common/utils/Log.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/dynamic_fusion/sketch/ArgumentPack.h"
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.h"
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.h"
#include "src/dynamic_fusion/sketch/gpu/GpuOperatorProperties.h"
#include "src/dynamic_fusion/sketch/gpu/GpuWorkloadSketchImpl.h"
@@ -88,9 +86,8 @@ Status GpuSoftmax::is_supported_op(const GpuWorkloadContext &context,
arguments_norm.add_const_tensor(ACL_SRC_1, &sum);
arguments_norm.add_const_tensor(ACL_DST_0, &dst_info_to_validate);
- ARM_COMPUTE_RETURN_ON_ERROR(
- ClComponentLogits1DMaxShiftExpSum::validate(properties, arguments_exp_sum, attributes));
- ARM_COMPUTE_RETURN_ON_ERROR(ClComponentLogits1DNorm::validate(properties, arguments_norm, attributes));
+ ARM_COMPUTE_UNUSED(properties, attributes);
+ return Status(ErrorCode::RUNTIME_ERROR, "GpuSoftmax is not implemented");
}
else
{
@@ -177,8 +174,8 @@ void GpuSoftmax::create_op(GpuWorkloadSketch &sketch, ITensorInfo *src, ITensorI
arguments_norm.add_const_tensor(ACL_SRC_1, sum);
arguments_norm.add_const_tensor(ACL_DST_0, dst);
- comp_graph.add_new_component<ClComponentLogits1DMaxShiftExpSum>(properties, arguments_exp_sum, attributes);
- comp_graph.add_new_component<ClComponentLogits1DNorm>(properties, arguments_norm, attributes);
+ // Add to component graph -- NOT IMPLEMENTED
+ ARM_COMPUTE_UNUSED(comp_graph, attributes);
}
}
else
diff --git a/src/dynamic_fusion/sketch/gpu/operators/GpuTanh.cpp b/src/dynamic_fusion/sketch/gpu/operators/GpuTanh.cpp
index bf0f274c5..b9d01966b 100644
--- a/src/dynamic_fusion/sketch/gpu/operators/GpuTanh.cpp
+++ b/src/dynamic_fusion/sketch/gpu/operators/GpuTanh.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,7 +31,6 @@
#include "src/dynamic_fusion/sketch/ArgumentPack.h"
#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.h"
#include "src/dynamic_fusion/sketch/gpu/GpuWorkloadSketchImpl.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.h"
namespace arm_compute
{
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp
deleted file mode 100644
index 775b0a0c8..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "GpuKernelVariableTable.h"
-
-#include "arm_compute/core/CL/CLHelpers.h"
-#include "arm_compute/core/ITensorInfo.h"
-
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-void GpuKernelVariableTable::declare_variable(const GpuKernelComponentGroup &comp_group,
- const ITensorInfo *tensor,
- GpuKernelArgumentInfo argument_info,
- const std::string &alias)
-{
- ARM_COMPUTE_ERROR_ON_MSG(!tensor->has_valid_id(), "Tensor info with valid id expected");
-
- // Do not re-declare if the variable associated with the tensor has already been declared
- auto it = _vars.find(tensor->id());
-
- if (it != _vars.end())
- {
- ARM_COMPUTE_ERROR_ON(!(it->second.kernel_argument_info == argument_info));
- return;
- }
-
- const auto target = comp_group.get_tile_for_tensor(tensor);
-
- if (target != tensor)
- {
- // If the tensor uses a shared tile, don't declare another variable.
- it = _vars.find(target->id());
-
- ARM_COMPUTE_ERROR_ON_MSG(it == _vars.end(), "The variable used for this tensor must have been declared.");
-
- _vars[tensor->id()] = it->second;
- }
- else
- {
- // Declare variable associated with the tensor
- std::stringstream ss;
- ss << alias << "_t" << abs(tensor->id());
- const auto uniq_name = ss.str();
- TensorVariable var{tensor->id(), uniq_name, argument_info};
-
- _vars.emplace(tensor->id(), var);
- }
-}
-
-GpuKernelVariableTable::TensorVariable GpuKernelVariableTable::get_variable(const ITensorInfo *tensor) const
-{
- const auto var = _vars.at(tensor->id());
- return var;
-}
-
-GpuKernelVariableTable::VariableList
-GpuKernelVariableTable::get_variable_list(const std::vector<const ITensorInfo *> &tensors) const
-{
- VariableList vars{};
- for (const auto &tensor : tensors)
- {
- if (!tensor->has_valid_id())
- {
- continue;
- }
- vars.push_back(get_variable(tensor));
- }
- return vars;
-}
-
-TagVal::TagVal(const GpuKernelVariableTable::TensorVariable &var) : value{var.uniq_name}
-{
-}
-
-TagVal::TagVal(const std::string &val) : value{val}
-{
-}
-
-TagVal::TagVal(const char *val) : value{std::string(val)}
-{
-}
-
-TagVal::TagVal(const DataType &data_type) : value{get_cl_type_from_data_type(data_type)}
-{
-}
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h
deleted file mode 100644
index c17f131ad..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_GPUKERNELVARIABLETABLE
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_GPUKERNELVARIABLETABLE
-
-#include "arm_compute/core/ITensorInfo.h"
-
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h"
-#include "support/AclRequires.h"
-#include "support/StringSupport.h"
-
-#include <set>
-#include <string>
-#include <type_traits>
-#include <unordered_map>
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class GpuKernelComponentGroup;
-
-/** A table of all the variables used in the kernel.
- * Each kernel has exactly one variable table.
- */
-class GpuKernelVariableTable
-{
-public:
- /** A tensor variable whose main purposes are:
- * - Hold the newly assigned @ref GpuKernelArgumentInfo for the associated tensor info
- * - Hold the generated variable name for the associated tensor info
- */
- struct TensorVariable
- {
- public:
- TensorVariable() = default;
- TensorVariable(const TensorVariable &) = default;
- TensorVariable &operator=(const TensorVariable &) = default;
- ITensorInfo::Id id{ITensorInfo::invalid_tensor_id};
- std::string uniq_name{"empty"}; // Unique name, also the final variable name used in the built code
- GpuKernelArgumentInfo kernel_argument_info{};
- bool has_valid_id() const
- {
- return id != ITensorInfo::invalid_tensor_id;
- }
- };
- using VariableList = std::vector<TensorVariable>;
-
-public:
- /** Declare a @ref TensorVariable for a corresponding tensor info.
- *
- * @param[in] comp_group Component group the tensor belongs to
- * @param[in] tensor Tensor info with which the new variable is associated
- * @param[in] argument_info Kernel argument information
- * @param[in] alias Alias for the variable. Will be used as part of the variable name
- */
- void declare_variable(const GpuKernelComponentGroup &comp_group,
- const ITensorInfo *tensor,
- GpuKernelArgumentInfo argument_info,
- const std::string &alias = "unnamed");
- /** Get the @ref TensorVariable associated with @p tensor
- *
- * @param[in] tensor Tensor info to be queried
- *
- * @return TensorVariable
- */
- TensorVariable get_variable(const ITensorInfo *tensor) const;
- /** Get the @ref TensorVariable list associated with @p tensors
- * @note Empty tensors are skipped
- *
- * @param[in] tensors List of tensor infos to be queried
- *
- * @return VariableList
- */
- VariableList get_variable_list(const std::vector<const ITensorInfo *> &tensors) const;
-
-private:
- std::map<ITensorInfo::Id, TensorVariable> _vars{};
-};
-
-/** A tag value will substitute a tag in a string template during its instantiation */
-struct TagVal
-{
- /** Default constructor */
- TagVal() = default;
- /** Construct a @ref TagVal from a @ref GpuKernelVariableTable::TensorVariable */
- TagVal(const GpuKernelVariableTable::TensorVariable &var);
- /** Construct a @ref TagVal from an integral type */
- template <typename T, ARM_COMPUTE_REQUIRES_TA(std::is_integral<T>::value)>
- TagVal(T val) : value{support::cpp11::to_string(val)}
- {
- }
- /** Construct a @ref TagVal from a string */
- TagVal(const std::string &val);
- /** Construct a @ref TagVal from a c-style string */
- TagVal(const char *val);
- /** Construct a @ref TagVal from a @ref DataType */
- TagVal(const DataType &data_type);
- /** Get the value of the TagVal as a converted string */
- std::string value{};
-};
-
-/** A tag used in a string template is a placeholder string to be substituted by real values during template instantiation */
-using Tag = std::string;
-
-/** Tag lookup table. It is used to instantiate a string template */
-using TagLUT = std::unordered_map<Tag, TagVal>;
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_GPUKERNELVARIABLETABLE */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h b/src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h
deleted file mode 100644
index 9d0b4f592..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_IGPUTEMPLATECOMPONENTWRITER
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_IGPUTEMPLATECOMPONENTWRITER
-
-#include "arm_compute/core/CL/CLCompileContext.h"
-#include "arm_compute/core/ITensorInfo.h"
-#include "arm_compute/core/Window.h"
-
-#include "src/dynamic_fusion/sketch/ArgumentPack.h"
-#include "src/dynamic_fusion/sketch/gpu/components/Types.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-/** Forward declaration */
-class GpuKernelComponentGroup;
-class GpuKernelVariableTable;
-
-/** An interface used by @ref ClTemplateWriter to write source code for a kernel component
- */
-class IGpuTemplateComponentWriter
-{
-public:
- using ComponentGroup = GpuKernelComponentGroup;
-
- /**For now all kernel intermeditate/destination tensors are expected to be of type Tensor_4D_t_Buffer*/
- static constexpr GpuKernelArgumentInfo::Type common_tensor_type = GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer;
-
-public:
- /** Constructor
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- */
- IGpuTemplateComponentWriter(ComponentId id, const ArgumentPack<ITensorInfo> &tensors) : _id{id}, _tensors{tensors}
- {
- }
- /** Destructor */
- virtual ~IGpuTemplateComponentWriter()
- {
- }
- /** Generate kernel component name */
- virtual std::string get_name() const = 0;
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- virtual std::string get_component_code(const ComponentGroup &comp_group) const = 0;
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- virtual void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const = 0;
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- virtual TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const = 0;
- /** Generate additional macros used in the component */
- virtual std::string get_additional_macros() const
- {
- return "";
- }
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- virtual CLBuildOptions get_build_options(const ComponentGroup &comp_group) const
- {
- ARM_COMPUTE_UNUSED(comp_group);
- return CLBuildOptions{};
- }
- /** Generate the component config id string used for tuning */
- virtual std::string get_config_id() const
- {
- return "";
- }
- /** Generate the header list used in the component */
- virtual std::set<std::string> get_headers_list() const
- {
- return std::set<std::string>{};
- }
- /** Generate the execution window for the component */
- virtual Window get_window() const
- {
- return Window{};
- }
- /** Get tensor arguments */
- ArgumentPack<ITensorInfo> tensors() const
- {
- return _tensors;
- }
- /** Get component id */
- ComponentId id() const
- {
- return _id;
- }
-
-private:
- ComponentId _id{-1};
- ArgumentPack<ITensorInfo> _tensors{};
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_IGPUTEMPLATECOMPONENTWRITER */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.cpp
deleted file mode 100644
index c165fb5f3..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.cpp
+++ /dev/null
@@ -1,181 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "ClTemplateActivation.h"
-
-#include "arm_compute/core/utils/ActivationFunctionUtils.h"
-#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
-#include "arm_compute/core/utils/StringUtils.h"
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-#include "support/StringSupport.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-ClTemplateActivation::ClTemplateActivation(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes)
- : IGpuTemplateComponentWriter{id, tensors}, _src{}, _dst{}, _attributes{attributes}
-{
- _src = this->tensors().get_const_tensor(TensorType::ACL_SRC);
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
-}
-
-std::string ClTemplateActivation::get_name() const
-{
- return "activation";
-}
-
-std::string ClTemplateActivation::get_component_code(const ComponentGroup &comp_group) const
-{
- std::string code;
- const bool is_root = (comp_group.get_root_component()->id() == this->id());
-
- code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
-)_";
- if (is_root)
- {
- code += R"_(
-// IN(src) {{src}}
-// OUT(dst, accum) {{dst}}
-
-TILE({{DATA_TYPE}}, M0, N0, {{src}});
-TILE(uint, M0, 1, g_dst_indirect_y);
-{
- {{src}}_offset_first_element_in_bytes += g_ind_2 * {{src}}_stride_z;
-
- T_LOAD({{DATA_TYPE}}, M0, N0, {{TENSOR_TYPE}}, {{src}}, g_ind_0, g_ind_1, 1, {{src}}_stride_y, {{src}});
-
- T_ACTIVATION({{DATA_TYPE}}, M0, N0, {{ACT}}, {{A_VAL}}, {{B_VAL}}, {{src}}, {{dst}});
-}
-
-LOOP_UNROLLING(int, i, 0, 1, M0,
-{
- g_dst_indirect_y[i].v = (uint)min((int)(g_ind_1 + i), (int)({{arg_dst}}_w) - 1);
- g_dst_indirect_y[i].v += (int)(g_ind_2 % {{arg_dst}}_h) * (int)({{arg_dst}}_w);
- g_dst_indirect_y[i].v += (int)(g_ind_2 / {{arg_dst}}_h) * (int)({{arg_dst}}_w * {{arg_dst}}_h);
-})
-)_";
- }
- else
- {
- code += R"_(
-// IN/OUT(src, accum) {{src}}
-
-{
- T_ACTIVATION({{DATA_TYPE}}, M0, N0, {{ACT}}, {{A_VAL}}, {{B_VAL}}, {{src}}, {{dst}});
-}
-)_";
- }
- code += R"_(
-//------------------ END KERNEL {{meta_kernel_id}} ---------------------
-)_";
- return code;
-}
-
-void ClTemplateActivation::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- vtable.declare_variable(comp_group, _src, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "src");
-
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "dst");
-}
-
-TagLUT ClTemplateActivation::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- TagLUT lut{};
- // Arguments and global shared variables
- lut["src"] = vtable.get_variable(_src);
- lut["dst"] = vtable.get_variable(_dst);
-
- const auto dst_argument = vtable.get_variable(comp_group.get_any_dst_tensor());
- lut["arg_dst"] = dst_argument.uniq_name;
-
- // Local build options
- lut["meta_kernel_id"] = id();
- lut["DATA_TYPE"] = get_cl_type_from_data_type(_src->data_type());
- lut["TENSOR_TYPE"] = "BUFFER";
-
- const auto f_act = lower_string(string_from_activation_func(_attributes.activation()));
-
- lut["ACT"] = f_act;
- lut["A_VAL"] = float_to_string_with_full_precision(_attributes.a());
- lut["B_VAL"] = float_to_string_with_full_precision(_attributes.b());
-
- return lut;
-}
-
-CLBuildOptions ClTemplateActivation::get_build_options(const ComponentGroup &comp_group) const
-{
- /// NOTE: For now tile sizes (n0, m0) are set by the execution window. This may change in the future
- const auto root_window = comp_group.get_root_component()->template_writer()->get_window();
- const unsigned int n0 = root_window.x().step();
- const unsigned int m0 = root_window.y().step();
- const unsigned int partial_store_n0 = _dst->dimension(0) % n0;
-
- CLBuildOptions build_opts;
- build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
- build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0));
-
- return build_opts;
-}
-
-std::string ClTemplateActivation::get_config_id() const
-{
- std::string config_id{};
- config_id += "activation_";
- config_id += lower_string(string_from_data_type(_src->data_type()));
- config_id += "_";
- config_id += support::cpp11::to_string(_src->dimension(0));
- config_id += "_";
- config_id += support::cpp11::to_string(_src->dimension(1));
- return config_id;
-}
-
-std::set<std::string> ClTemplateActivation::get_headers_list() const
-{
- return std::set<std::string>{"helpers.h", "tile_helpers.h", "activation_float_helpers.h"};
-}
-
-Window ClTemplateActivation::get_window() const
-{
- ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
- const unsigned int n0 = adjust_vec_size(16 / _dst->element_size(), _dst->dimension(0));
- Window win = calculate_max_window(*_dst, Steps(n0));
- return win.collapse(win, Window::DimZ);
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.h
deleted file mode 100644
index 88ee37034..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.h
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEACTIVATION
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEACTIVATION
-
-#include "arm_compute/core/experimental/Types.h"
-#include "arm_compute/function_info/ActivationLayerInfo.h"
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentActivation.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplateActivation final : public IGpuTemplateComponentWriter
-{
-public:
- using Attributes = ClComponentActivation::Attributes;
-
- /** Constructor
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- * @param[in] attributes Component attributes
- */
- ClTemplateActivation(ComponentId id, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes);
-
- /** Destructor */
- ~ClTemplateActivation() override = default;
-
- /** Prevent instances of this class from being copy constructed */
- ClTemplateActivation(const ClTemplateActivation &activation) = delete;
-
- /** Prevent instances of this class from being copied */
- ClTemplateActivation &operator=(const ClTemplateActivation &activation) = delete;
-
- /** Allow instances of this class to be move constructed */
- ClTemplateActivation(ClTemplateActivation &&activation) = default;
-
- /** Allow instances of this class to be moved */
- ClTemplateActivation &operator=(ClTemplateActivation &&activation) = default;
-
- /** Generate kernel component name */
- std::string get_name() const override;
-
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
-
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
-
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
-
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- CLBuildOptions get_build_options(const ComponentGroup &comp_group) const override;
-
- /** Generate the component config id string used for tuning */
- std::string get_config_id() const override;
-
- /** Generate the header list used in the component */
- std::set<std::string> get_headers_list() const override;
-
- /** Generate the execution window for the component */
- Window get_window() const override;
-
-private:
- const ITensorInfo *_src;
- const ITensorInfo *_dst;
- Attributes _attributes;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEACTIVATION */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp
deleted file mode 100644
index 0da3a7380..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp
+++ /dev/null
@@ -1,212 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "ClTemplateCast.h"
-
-#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
-#include "arm_compute/core/utils/StringUtils.h"
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-ClTemplateCast::ClTemplateCast(ComponentId id, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes)
- : IGpuTemplateComponentWriter{id, tensors}, _src{}, _dst{}, _attributes{attributes}
-{
- _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
-
- ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
-}
-
-std::string ClTemplateCast::get_name() const
-{
- const size_t src_size = data_size_from_type(_src->data_type());
- const size_t dst_size = data_size_from_type(_dst->data_type());
-
- return (src_size >= dst_size) ? "cast_down" : "cast_up";
-}
-
-std::string ClTemplateCast::get_component_code(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- const std::string kernel_name = get_name();
- const auto is_root = (comp_group.get_root_component()->id() == this->id());
-
- std::string code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} CAST ---------------------
-)_";
-
- if (is_root)
- {
- code += R"_(
-// IN_0(src) {{src}}
-// OUT(dst, accum) {{dst}}
-
-TILE(uint, M0, 1, g_dst_indirect_y);
-{
- {{src}}_offset_first_element_in_bytes += get_global_id(2) * {{src}}_stride_z;
-
- TILE({{DATA_TYPE_IN}}, M0, N0, {{tmp}});
- T_LOAD({{DATA_TYPE_IN}}, M0, N0, BUFFER, {{src}}, g_ind_0, g_ind_1, 1, {{src}}_stride_y, {{tmp}});
-)_";
- }
-
- code += R"_(
- LOOP_UNROLLING(int, m0, 0, 1, M0,
- {
-)_";
-
- if (kernel_name == "cast_down" && is_data_type_quantized(_src->data_type()))
- {
- code += R"_(
- {{tmp}}[m0].v ^= (VEC_DATA_TYPE({{DATA_TYPE_IN}}, N0))0x80;
-)_";
- }
-
- if (kernel_name == "cast_down" &&
- (is_data_type_float(_src->data_type()) || _attributes.convert_policy() == ConvertPolicy::SATURATE))
- {
- code += R"_(
- {{dst}}[m0].v = CONVERT_SAT({{tmp}}[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0));
-)_";
- }
- else
- {
- code += R"_(
- {{dst}}[m0].v = CONVERT({{tmp}}[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0));
-)_";
- }
-
- code += R"_(
- })
-)_";
-
- if (is_root)
- {
- code += R"_(
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- g_dst_indirect_y[i].v = (uint)min((int)(g_ind_1 + i), (int)({{arg_dst}}_w) - 1);
- g_dst_indirect_y[i].v += (int)(g_ind_2 % {{arg_dst}}_h) * (int)({{arg_dst}}_w);
- g_dst_indirect_y[i].v += (int)(g_ind_2 / {{arg_dst}}_h) * (int)({{arg_dst}}_w * {{arg_dst}}_h);
- })
-}
-)_";
- }
-
- code += R"_(
-//------------------ END KERNEL {{meta_kernel_id}} CAST ---------------------
-)_";
-
- return code;
-}
-
-void ClTemplateCast::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- vtable.declare_variable(comp_group, _src, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "src");
-
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "dst");
-}
-
-TagLUT ClTemplateCast::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- const auto is_root = (comp_group.get_root_component()->id() == this->id());
-
- TagLUT lut{};
-
- // Arguments and global shared variables
- lut["src"] = vtable.get_variable(_src);
- lut["dst"] = vtable.get_variable(_dst);
- lut["tmp"] = (is_root) ? lut["src"].value + "_in_data" : lut["src"];
-
- const auto dst_argument = vtable.get_variable(comp_group.get_any_dst_tensor());
- lut["arg_dst"] = dst_argument.uniq_name;
-
- // Local build options
- lut["meta_kernel_id"] = id();
-
- lut["DATA_TYPE_IN"] = get_cl_type_from_data_type(_src->data_type());
- lut["DATA_TYPE_OUT"] = get_cl_type_from_data_type(_dst->data_type());
-
- return lut;
-}
-
-CLBuildOptions ClTemplateCast::get_build_options(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- const auto root_window = comp_group.get_root_component()->template_writer()->get_window();
- const unsigned int n0 = root_window.x().step();
- const unsigned int m0 = root_window.y().step();
-
- // Set build options
- CLBuildOptions build_opts{};
- build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(_src->dimension(0) % n0));
- build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
-
- return build_opts;
-}
-
-std::string ClTemplateCast::get_config_id() const
-{
- std::string config_id{};
-
- config_id += "_";
- config_id += lower_string(string_from_data_type(_src->data_type()));
- config_id += "_";
- config_id += lower_string(string_from_data_type(_dst->data_type()));
- config_id += "_";
- config_id += support::cpp11::to_string(_src->dimension(0));
- config_id += "_";
- config_id += support::cpp11::to_string(_src->dimension(1));
-
- return config_id;
-}
-
-std::set<std::string> ClTemplateCast::get_headers_list() const
-{
- return std::set<std::string>{"helpers.h", "tile_helpers.h"};
-}
-
-Window ClTemplateCast::get_window() const
-{
- ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
-
- const unsigned int n0 = adjust_vec_size(16 / _dst->element_size(), _dst->dimension(0));
- Window win = calculate_max_window(*_dst, Steps(n0));
- return win.collapse(win, Window::DimZ);
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.h
deleted file mode 100644
index 3adca4edc..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.h
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATECAST
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATECAST
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplateCast final : public IGpuTemplateComponentWriter
-{
-public:
- using Attributes = ClComponentCast::Attributes;
-
- /** Constructor
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- * @param[in] attributes Component attributes
- */
- ClTemplateCast(ComponentId id, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes);
- /** Prevent instances of this class from being copy constructed */
- ClTemplateCast(const ClTemplateCast &cast) = delete;
- /** Prevent instances of this class from being copied */
- ClTemplateCast &operator=(const ClTemplateCast &cast) = delete;
- /** Allow instances of this class to be move constructed */
- ClTemplateCast(ClTemplateCast &&cast) = default;
- /** Allow instances of this class to be moved */
- ClTemplateCast &operator=(ClTemplateCast &&cast) = default;
- /** Generate kernel component name */
- std::string get_name() const override;
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- CLBuildOptions get_build_options(const ComponentGroup &comp_group) const override;
- /** Generate the component config id string used for tuning */
- std::string get_config_id() const override;
- /** Generate the header list used in the component */
- std::set<std::string> get_headers_list() const override;
- /** Generate the execution window for the component */
- Window get_window() const override;
-
-private:
- const ITensorInfo *_src;
- const ITensorInfo *_dst;
- Attributes _attributes;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATECAST */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.cpp
deleted file mode 100644
index 8380620ab..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.cpp
+++ /dev/null
@@ -1,364 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "ClTemplateDepthwiseConv2d.h"
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-ClTemplateDepthwiseConv2d::ClTemplateDepthwiseConv2d(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes,
- const Settings &settings)
- : IGpuTemplateComponentWriter{id, tensors},
- _src{},
- _weight{},
- _bias{},
- _dst{},
- _attributes{attributes},
- _settings{settings}
-{
- _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
- _weight = this->tensors().get_const_tensor(TensorType::ACL_SRC_1);
- if (this->tensors().get_const_tensor(TensorType::ACL_SRC_2))
- {
- _bias = this->tensors().get_const_tensor(TensorType::ACL_SRC_2);
- }
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _weight, _dst);
-}
-
-std::string ClTemplateDepthwiseConv2d::get_name() const
-{
- return "depthwise_conv2d";
-}
-
-std::string ClTemplateDepthwiseConv2d::get_component_code(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- constexpr int height_idx = 2; // Data Layout is NHWC
-
- std::string code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
-// IN_0(src) {{src}}
-// IN_1(wei) {{weight}}
-)_";
-
- if (_bias != nullptr && _bias->has_valid_id())
- {
- code += R"_(
-// IN_1(bia) {{bias}}
-)_";
- }
-
- code += R"_(
-// OUT(dst, accum) {{dst}}
-
-TILE(uint, M0, 1, g_dst_indirect_y);
-
-{
-#define _IWEI_WIDTH {{WEI_WIDTH}}
-#define _IWEI_HEIGHT {{WEI_HEIGHT}}
-#define _IDST_WIDTH {{arg_dst}}_w
-#define _IDST_HEIGHT {{arg_dst}}_h
-#define _IM0_A M0_A
-#define _IN0_A N0_A
-#define _IM0_B _IWEI_WIDTH
-#define _IN0_B N0
-#define _IBOUNDARY_CHECK (!((_IWEI_WIDTH == 1 && _IWEI_HEIGHT == 1 && {{PAD_LEFT}} == 0 && {{PAD_TOP}} == 0 && M0 == 1)))
-)_";
-
- code += R"_(
- const int yo = g_ind_2 % {{arg_dst}}_h;
- const int bout = g_ind_2 / {{arg_dst}}_h;
-)_";
-
- code += R"_(
-
- int xi = g_ind_1 * {{STRIDE_X}};
- int yi = yo * {{STRIDE_Y}};
- xi -= {{PAD_LEFT}};
- yi -= {{PAD_TOP}};
-
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- {{dst}}[i].v = 0;
- })
-)_";
-
- if (_weight->dimension(height_idx) < 5)
- {
- code += R"_(
- LOOP_UNROLLING(int, yk, 0, 1, _IWEI_HEIGHT,
-)_";
- }
- else
- {
- code += R"_(
- for(int yk = 0; yk < _IWEI_HEIGHT; ++yk)
-)_";
- }
-
- code += R"_(
- {
- TILE({{SRC_DATA_TYPE}}, _IM0_A, _IN0_A, a);
-
- LOOP_UNROLLING(int, i, 0, 1, _IM0_A,
- {
- a[i].v = 0;
- })
-
- T_LOAD_NHWC_WITH_DILATION({{SRC_DATA_TYPE}}, 1, _IM0_A, _IN0_A, {{SRC_TENSOR_TYPE}}, {{src}}, bout, yi + yk * {{DILATION_Y}}, xi, (g_ind_0 / {{DEPTH_MULTIPLIER}}), {{src}}_w, {{src}}_h, {{DILATION_X}}, 1, _IBOUNDARY_CHECK, a);
-
- TILE({{WEI_DATA_TYPE}}, _IM0_B, _IN0_B, b);
-
- T_LOAD({{WEI_DATA_TYPE}}, _IM0_B, _IN0_B, {{WEI_TENSOR_TYPE}}, {{weight}}, g_ind_0, yk * _IM0_B, 1, {{weight}}_stride_y, b);
-
- LOOP_UNROLLING(int, m0, 0, 1, M0,
- {
- LOOP_UNROLLING(int, xk, 0, 1, _IWEI_WIDTH,
- {
-)_";
-
- if (!_settings.is_fma_available())
- {
- code += R"_(
- {{dst}}[m0].v += a[xk + m0].v * b[xk].v;
-)_";
- }
- else
- {
- code += R"_(
- {{dst}}[m0].v = fma(a[xk + m0].v, b[xk].v, {{dst}}[m0].v);
-)_";
- }
-
- code += R"_(
- })
- })
- }
-)_";
-
- if (_weight->dimension(height_idx) < 5)
- {
- code += R"_(
- )
-)_";
- }
-
- if (_bias && _bias->has_valid_id())
- {
- code += R"_(
- TILE({{BIA_DATA_TYPE}}, 1, N0, {{bias}});
-
- T_LOAD({{BIA_DATA_TYPE}}, 1, N0, BUFFER, {{bias}}, g_ind_0, 0, 0, 0, {{bias}});
-
- T_ELTWISE_BROADCAST_ADD_X({{ACC_DATA_TYPE}}, M0, N0, {{dst}}, {{bias}}, {{dst}});
-)_";
- }
-
- code += R"_(
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- g_dst_indirect_y[i].v = (uint)min((int)(g_ind_1 + i), (int)({{arg_dst}}_w) - 1);
- g_dst_indirect_y[i].v += (int)(g_ind_2 % {{arg_dst}}_h) * (int)({{arg_dst}}_w);
- g_dst_indirect_y[i].v += (int)(g_ind_2 / {{arg_dst}}_h) * (int)({{arg_dst}}_w * {{arg_dst}}_h);
- })
-}
-//------------------ END KERNEL {{meta_kernel_id}} ---------------------
-)_";
-
- return code;
-}
-
-void ClTemplateDepthwiseConv2d::declare_variables(GpuKernelVariableTable &vtable,
- const ComponentGroup &comp_group) const
-{
- const GpuKernelArgumentInfo::Type input_type = _settings.export_input_to_cl_image()
- ? GpuKernelArgumentInfo::Type::Tensor_4D_t_Image
- : GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer;
-
- vtable.declare_variable(comp_group, _src, GpuKernelArgumentInfo(input_type), "src");
-
- const GpuKernelArgumentInfo::Type weight_type = _settings.export_weights_to_cl_image()
- ? GpuKernelArgumentInfo::Type::Tensor_4D_t_Image
- : GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer;
-
- vtable.declare_variable(comp_group, _weight, GpuKernelArgumentInfo(weight_type), "weight");
-
- if (_bias != nullptr && _bias->has_valid_id()) // optional bias
- {
- vtable.declare_variable(comp_group, _bias, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Vector), "bias");
- }
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "dst");
-}
-
-TagLUT ClTemplateDepthwiseConv2d::get_tag_lut(const GpuKernelVariableTable &vtable,
- const ComponentGroup &comp_group) const
-{
- TagLUT lut{};
-
- // Arguments and global shared variables
- lut["src"] = vtable.get_variable(_src);
- lut["weight"] = vtable.get_variable(_weight);
-
- if (_bias != nullptr && _bias->has_valid_id()) // optional bias
- {
- lut["bias"] = vtable.get_variable(_bias);
- lut["BIA_DATA_TYPE"] = get_cl_type_from_data_type(_bias->data_type());
- }
- lut["dst"] = vtable.get_variable(_dst);
-
- const auto dst_argument = vtable.get_variable(comp_group.get_any_dst_tensor());
- lut["arg_dst"] = dst_argument.uniq_name;
-
- // Local build options
- lut["meta_kernel_id"] = id();
- lut["ACC_DATA_TYPE"] = _src->data_type();
- lut["SRC_DATA_TYPE"] = _src->data_type();
- lut["WEI_DATA_TYPE"] = _weight->data_type();
-
- switch (vtable.get_variable(_src).kernel_argument_info.type)
- {
- case GpuKernelArgumentInfo::Type::Image_Export_To_ClImage2D:
- case GpuKernelArgumentInfo::Type::Image_3D_Export_To_ClImage2D:
- case GpuKernelArgumentInfo::Type::Tensor_4D_t_Image:
- lut["SRC_TENSOR_TYPE"] = "IMAGE";
- break;
- default:
- lut["SRC_TENSOR_TYPE"] = "BUFFER";
- break;
- }
-
- switch (vtable.get_variable(_weight).kernel_argument_info.type)
- {
- case GpuKernelArgumentInfo::Type::Image_Export_To_ClImage2D:
- case GpuKernelArgumentInfo::Type::Image_3D_Export_To_ClImage2D:
- case GpuKernelArgumentInfo::Type::Tensor_4D_t_Image:
- lut["WEI_TENSOR_TYPE"] = "IMAGE";
- break;
- default:
- lut["WEI_TENSOR_TYPE"] = "BUFFER";
- break;
- }
-
- // Data Layout is NHWC
- constexpr int width_idx = 1;
- constexpr int height_idx = 2;
-
- lut["WEI_WIDTH"] = _weight->dimension(width_idx);
- lut["WEI_HEIGHT"] = _weight->dimension(height_idx);
-
- lut["STRIDE_X"] = _attributes.stride().x();
- lut["STRIDE_Y"] = _attributes.stride().y();
-
- lut["PAD_LEFT"] = _attributes.pad().left;
- lut["PAD_TOP"] = _attributes.pad().top;
-
- lut["DILATION_X"] = _attributes.dilation().x();
- lut["DILATION_Y"] = _attributes.dilation().y();
-
- lut["DEPTH_MULTIPLIER"] = _attributes.depth_multiplier();
-
- return lut;
-}
-
-CLBuildOptions ClTemplateDepthwiseConv2d::get_build_options(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- constexpr unsigned int width_idx = 1; // Data Layout is NHWC
-
- const unsigned int n0 = _settings.n0();
- const unsigned int m0 = _settings.m0();
- const unsigned int m0_a = _weight->dimension(width_idx) + m0 - 1;
- const unsigned int n0_a = _attributes.depth_multiplier() > 1 ? 1 : n0;
- const unsigned int partial_store_n0 = _dst->dimension(0) % n0;
-
- CLBuildOptions build_opts{};
-
- if (_settings.fast_relaxed_math())
- {
- build_opts.add_option("-cl-fast-relaxed-math");
- }
- else
- {
- // -cl-fast-relaxed-math also sets -cl-finite-math-only and -cl-unsafe-math-optimizations
- // to disable -cl-finite-math-only, we only include -cl-unsafe-math-optimizations
- build_opts.add_option("-cl-unsafe-math-optimizations");
- }
-
- build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
- build_opts.add_option("-DN0_A=" + support::cpp11::to_string(n0_a));
- build_opts.add_option("-DM0_A=" + support::cpp11::to_string(m0_a));
- build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0));
-
- return build_opts;
-}
-
-std::string ClTemplateDepthwiseConv2d::get_config_id() const
-{
- std::string config_id{};
-
- config_id += support::cpp11::to_string(_src->dimension(0));
- config_id += "_";
- config_id += support::cpp11::to_string(_src->dimension(1));
- config_id += "_";
- config_id += support::cpp11::to_string(_src->dimension(2));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(0));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(1));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(2));
- config_id += "_";
- config_id += string_from_data_type(_src->data_type());
-
- return config_id;
-}
-
-std::set<std::string> ClTemplateDepthwiseConv2d::get_headers_list() const
-{
- return std::set<std::string>{"helpers.h", "tile_helpers.h"};
-}
-
-Window ClTemplateDepthwiseConv2d::get_window() const
-{
- ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
-
- Window win = calculate_max_window(*_dst, Steps(_settings.n0(), _settings.m0()));
- return win.collapse(win, Window::DimZ);
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.h
deleted file mode 100644
index 5d04c687c..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.h
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEDEPTHWISECONV2D
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEDEPTHWISECONV2D
-
-#include "arm_compute/dynamic_fusion/sketch/attributes/DepthwiseConv2dAttributes.h"
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDepthwiseConv2d.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplateDepthwiseConv2d final : public IGpuTemplateComponentWriter
-{
-public:
- using Attributes = ClComponentDepthwiseConv2d::Attributes;
- using Settings = ClComponentDepthwiseConv2d::Settings;
- /** Constructor
- *
- * Similar to @ref ClComponentDepthwiseConv2d::validate()
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- * @param[in] attributes Component attributes
- * @param[in] settings Component settings
- */
- ClTemplateDepthwiseConv2d(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes,
- const Settings &settings);
- /** Prevent instances of this class from being copy constructed */
- ClTemplateDepthwiseConv2d(const ClTemplateDepthwiseConv2d &depthwise_conv2d) = delete;
- /** Prevent instances of this class from being copied */
- ClTemplateDepthwiseConv2d &operator=(const ClTemplateDepthwiseConv2d &depthwise_conv2d) = delete;
- /** Allow instances of this class to be move constructed */
- ClTemplateDepthwiseConv2d(ClTemplateDepthwiseConv2d &&depthwise_conv2d) = default;
- /** Allow instances of this class to be moved */
- ClTemplateDepthwiseConv2d &operator=(ClTemplateDepthwiseConv2d &&depthwise_conv2d) = default;
- /** Generate kernel component name */
- std::string get_name() const override;
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- CLBuildOptions get_build_options(const ComponentGroup &comp_group) const override;
- /** Generate the component config id string used for tuning */
- std::string get_config_id() const override;
- /** Generate the header list used in the component */
- std::set<std::string> get_headers_list() const override;
- /** Generate the execution window for the component */
- Window get_window() const override;
-
-private:
- const ITensorInfo *_src;
- const ITensorInfo *_weight;
- const ITensorInfo *_bias;
- const ITensorInfo *_dst;
- Attributes _attributes;
- Settings _settings;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEDEPTHWISECONV2D */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.cpp
deleted file mode 100644
index f6a7a58d1..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.cpp
+++ /dev/null
@@ -1,393 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "ClTemplateDirectConv2d.h"
-
-#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/core/utils/StringUtils.h"
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-#include "support/StringSupport.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-ClTemplateDirectConv2d::ClTemplateDirectConv2d(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes,
- const Settings &settings)
- : IGpuTemplateComponentWriter{id, tensors},
- _src{},
- _weight{},
- _bias{},
- _dst{},
- _attributes{attributes},
- _settings{settings}
-{
- _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
- _weight = this->tensors().get_const_tensor(TensorType::ACL_SRC_1);
- if (this->tensors().get_const_tensor(TensorType::ACL_SRC_2))
- {
- _bias = this->tensors().get_const_tensor(TensorType::ACL_SRC_2);
- }
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _weight, _dst);
-}
-
-std::string ClTemplateDirectConv2d::get_name() const
-{
- return "direct_conv2d";
-}
-
-std::string ClTemplateDirectConv2d::get_component_code(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- const auto channel_idx = get_data_layout_dimension_index(_src->data_layout(), DataLayoutDimension::CHANNEL);
- const auto k0 = adjust_vec_size(_settings.direct_conv_descriptor().k0, _src->dimension(channel_idx));
- const bool leftover_loop = (_src->dimension(channel_idx) % k0) != 0;
-
- std::string code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
-// IN_0(src) {{src}}
-// IN_1(wei) {{weight}}
-)_";
- if (_bias && _bias->has_valid_id())
- {
- code += R"_(
-// IN_1(bia) {{bias}}
-)_";
- }
- code += R"_(
-// OUT(dst, accum) {{dst}}
-
-TILE(uint, M0, 1, g_dst_indirect_y);
-
-{
-#define _IWEI_WIDTH {{WEI_WIDTH}}
-#define _IWEI_HEIGHT {{WEI_HEIGHT}}
-#define _ISRC_WIDTH {{SRC_WIDTH}}
-#define _ISRC_HEIGHT {{SRC_HEIGHT}}
-#define _ISRC_CHANNELS {{SRC_CHANNELS}}
-#define _IDST_WIDTH {{DST_WIDTH}}
-#define _IDST_HEIGHT {{DST_HEIGHT}}
-#define _IDST_CHANNELS {{DST_CHANNELS}}
-#define _IY_MULTIPLIER (_IWEI_WIDTH * _IWEI_HEIGHT)
-
- TILE(int, M0, 1, xi);
- TILE(int, M0, 1, yi);
-
- // Convert the linear index to coordinate
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- xi[0].s[i] = ((g_ind_1 + i) % _IDST_WIDTH) * {{STRIDE_X}};
- yi[0].s[i] = ((g_ind_1 + i) / _IDST_WIDTH) * {{STRIDE_Y}};
- xi[0].s[i] -= {{PAD_LEFT}};
- yi[0].s[i] -= {{PAD_TOP}};
- })
-
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- {{dst}}[i].v = 0;
- })
-
- for(int i = 0; i < (_IWEI_WIDTH * _IWEI_HEIGHT); ++i)
- {
- int xk = i % _IWEI_WIDTH;
- int yk = i / _IWEI_WIDTH;
-
- TILE(int, 1, M0, my);
-
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- int x_s = xi[0].s[i] + xk;
- int y_s = yi[0].s[i] + yk;
- my[0].s[i] = x_s + y_s *_ISRC_WIDTH;
- my[0].s[i] = my[0].s[i] + g_ind_2 * (int)(_ISRC_WIDTH * _ISRC_HEIGHT);
- my[0].s[i] = select(-1, my[0].s[i], x_s >= 0);
- my[0].s[i] = select(-1, my[0].s[i], x_s < _ISRC_WIDTH);
- my[0].s[i] = select(-1, my[0].s[i], y_s >= 0);
- my[0].s[i] = select(-1, my[0].s[i], y_s < _ISRC_HEIGHT);
- })
-
- int ck = 0;
- for(; ck <= (_ISRC_CHANNELS - K0); ck += K0)
- {
- TILE({{SRC_DATA_TYPE}}, M0, K0, a);
- TILE({{WEI_DATA_TYPE}}, N0, K0, b);
-
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- a[i].v = {{ZERO_VALUE}};
- })
-
- LOOP_UNROLLING(int, i, 0, 1, N0,
- {
- b[i].v = {{ZERO_VALUE}};
- })
-
- T_LOAD2D_INDIRECT({{SRC_DATA_TYPE}}, M0, K0, {{SRC_TENSOR_TYPE}}, {{src}}, ck, {{src}}_stride_y, my, a);
-
- T_LOAD({{WEI_DATA_TYPE}}, N0, K0, {{WEI_TENSOR_TYPE}}, {{weight}}, ck, g_ind_0 * _IY_MULTIPLIER + i, _IY_MULTIPLIER, {{weight}}_stride_y, b);
-
- T_MMUL({{SRC_DATA_TYPE}}, {{WEI_DATA_TYPE}}, {{ACC_DATA_TYPE}}, M0, N0, K0, NT, T, a, b, {{dst}});
- }
-)_";
-
- if (leftover_loop)
- {
- code += R"_(
- for(; ck < _ISRC_CHANNELS; ++ck)
- {
- TILE({{SRC_DATA_TYPE}}, M0, 1, a);
- TILE({{WEI_DATA_TYPE}}, N0, 1, b);
-
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- a[i].v = {{ZERO_VALUE}};
- })
-
- LOOP_UNROLLING(int, i, 0, 1, N0,
- {
- b[i].v = {{ZERO_VALUE}};
- })
-
- T_LOAD2D_INDIRECT({{SRC_DATA_TYPE}}, M0, 1, {{SRC_TENSOR_TYPE}}, {{src}}, ck, {{src}}_stride_y, my, a);
-
- T_LOAD({{WEI_DATA_TYPE}}, N0, 1, BUFFER, {{weight}}, ck, g_ind_0 * _IY_MULTIPLIER + i, _IY_MULTIPLIER, {{weight}}_stride_y, b);
-
- T_MMUL({{SRC_DATA_TYPE}}, {{WEI_DATA_TYPE}}, {{ACC_DATA_TYPE}}, M0, N0, 1, NT, T, a, b, {{dst}});
- }
- )_";
- }
-
- code += R"_(
-#undef _I_WEI_WIDTH
-#undef _I_WEI_HEIGHT
-#undef _ISRC_WIDTH
-#undef _ISRC_HEIGHT
-#undef _ISRC_CHANNELS
-#undef _IDST_WIDTH
-#undef _IDST_HEIGHT
-#undef _IDST_CHANNELS
-#undef _IY_MULTIPLIER
-
- }
-)_";
-
- if (_bias && _bias->has_valid_id())
- {
- code += R"_(
- TILE({{BIA_DATA_TYPE}}, 1, N0, bias0);
-
- T_LOAD({{BIA_DATA_TYPE}}, 1, N0, BUFFER, {{bias}}, g_ind_0, 0, 1, 0, bias0);
-
- T_ELTWISE_BROADCAST_ADD_X({{ACC_DATA_TYPE}}, M0, N0, {{dst}}, bias0, {{dst}});
- )_";
- }
-
- code += R"_(
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- g_dst_indirect_y[i].v = (uint)min(g_ind_1 + i, (int)({{DST_WIDTH}} * {{DST_HEIGHT}}) - 1);
- g_dst_indirect_y[i].v += g_ind_2 * (int)({{DST_WIDTH}} * {{DST_HEIGHT}});
- })
-}
-//------------------ END KERNEL {{meta_kernel_id}} ---------------------
-)_";
- return code;
-}
-
-void ClTemplateDirectConv2d::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- vtable.declare_variable(comp_group, _src, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "src");
-
- const GpuKernelArgumentInfo::Type weight_type = _settings.export_to_cl_image()
- ? GpuKernelArgumentInfo::Type::Tensor_4D_t_Image
- : GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer;
- vtable.declare_variable(comp_group, _weight, GpuKernelArgumentInfo(weight_type), "weight");
-
- if (_bias && _bias->has_valid_id()) // optional bias
- {
- vtable.declare_variable(comp_group, _bias, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Vector), "bias");
- }
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(common_tensor_type), "dst");
-}
-
-TagLUT ClTemplateDirectConv2d::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- TagLUT lut{};
- // Arguments and global shared variables
- lut["src"] = vtable.get_variable(_src);
- lut["weight"] = vtable.get_variable(_weight);
-
- if (_bias && _bias->has_valid_id()) // optional bias
- {
- lut["bias"] = vtable.get_variable(_bias);
- lut["BIA_DATA_TYPE"] = get_cl_type_from_data_type(_bias->data_type());
- }
- lut["dst"] = vtable.get_variable(_dst);
-
- const auto dst_argument = vtable.get_variable(comp_group.get_any_dst_tensor());
- lut["arg_dst"] = dst_argument.uniq_name;
-
- // Local build options
- lut["meta_kernel_id"] = id();
- lut["ACC_DATA_TYPE"] = _src->data_type();
- lut["SRC_DATA_TYPE"] = _src->data_type();
- lut["WEI_DATA_TYPE"] = _weight->data_type();
-
- lut["SRC_TENSOR_TYPE"] = "BUFFER";
- switch (vtable.get_variable(_weight).kernel_argument_info.type)
- {
- case GpuKernelArgumentInfo::Type::Image_Export_To_ClImage2D:
- case GpuKernelArgumentInfo::Type::Image_3D_Export_To_ClImage2D:
- case GpuKernelArgumentInfo::Type::Tensor_4D_t_Image:
- {
- lut["WEI_TENSOR_TYPE"] = "IMAGE";
- break;
- }
- default:
- {
- lut["WEI_TENSOR_TYPE"] = "BUFFER";
- break;
- }
- }
- const auto width_idx = 1;
- const auto height_idx = 2;
- const auto channel_idx = 0;
-
- lut["SRC_WIDTH"] = _src->dimension(width_idx);
- lut["SRC_HEIGHT"] = _src->dimension(height_idx);
- lut["SRC_CHANNELS"] = _src->dimension(channel_idx);
-
- lut["WEI_WIDTH"] = _weight->dimension(width_idx);
- lut["WEI_HEIGHT"] = _weight->dimension(height_idx);
-
- lut["DST_WIDTH"] = _dst->dimension(width_idx);
- lut["DST_HEIGHT"] = _dst->dimension(height_idx);
- lut["DST_CHANNELS"] = _dst->dimension(channel_idx);
-
- lut["STRIDE_X"] = _attributes.stride().x();
- lut["STRIDE_Y"] = _attributes.stride().y();
-
- lut["PAD_LEFT"] = _attributes.pad().left;
- lut["PAD_TOP"] = _attributes.pad().top;
-
- lut["ZERO_VALUE"] = 0;
-
- return lut;
-}
-
-CLBuildOptions ClTemplateDirectConv2d::get_build_options(const ComponentGroup &comp_group) const
-{
- const unsigned int channel_idx = get_data_layout_dimension_index(_src->data_layout(), DataLayoutDimension::CHANNEL);
-
- const auto root_window = comp_group.get_root_component()->template_writer()->get_window();
- const unsigned int n0 = root_window.x().step();
- const unsigned int m0 = root_window.y().step();
- const unsigned int k0 = adjust_vec_size(_settings.direct_conv_descriptor().k0, _src->dimension(channel_idx));
- const unsigned int partial_store_n0 = _dst->dimension(0) % n0;
-
- CLBuildOptions build_opts{};
- if (_settings.fast_relaxed_math())
- {
- build_opts.add_option("-cl-fast-relaxed-math");
- }
- else
- {
- // -cl-fast-relaxed-math also sets -cl-finite-math-only and -cl-unsafe-math-optimizations
- // to disable -cl-finite-math-only, we only include -cl-unsafe-math-optimizations
- build_opts.add_option("-cl-unsafe-math-optimizations");
- }
-
- build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
- build_opts.add_option("-DK0=" + support::cpp11::to_string(k0));
- build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0));
-
- return build_opts;
-}
-
-std::string ClTemplateDirectConv2d::get_config_id() const
-{
- const DataType data_type = _src->data_type();
- const DataLayout data_layout = _src->data_layout();
-
- const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
-
- const unsigned int kernel_size = _weight->dimension(width_idx);
-
- std::string config_id{};
- config_id += lower_string(string_from_data_type(data_type));
- config_id += "_";
- config_id += support::cpp11::to_string(kernel_size);
- config_id += "_";
- config_id += support::cpp11::to_string(_attributes.stride().x());
- config_id += "_";
- config_id += support::cpp11::to_string(_attributes.stride().y());
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(width_idx));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(height_idx));
- config_id += "_";
- config_id += lower_string(string_from_data_layout(data_layout));
- return config_id;
-}
-
-std::set<std::string> ClTemplateDirectConv2d::get_headers_list() const
-{
- return std::set<std::string>{"helpers.h", "tile_helpers.h"};
-}
-
-Window ClTemplateDirectConv2d::get_window() const
-{
- ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
-
- const auto output_shape = _dst->tensor_shape();
- const auto desc = _settings.direct_conv_descriptor();
-
- const unsigned int n0 = adjust_vec_size(desc.n0, output_shape[0]);
- const unsigned int m0 = adjust_vec_size(desc.m0, output_shape[1] * output_shape[2]);
-
- // Create and configure kernel window
- Window win = calculate_max_window(output_shape, Steps(n0, m0));
-
- const size_t dim_y_collapsed = ceil_to_multiple(output_shape[1] * output_shape[2], m0);
- win.set(Window::DimY, Window::Dimension(0, dim_y_collapsed, m0));
- win.set(Window::DimZ, Window::Dimension(0, output_shape.total_size_upper(3), 1));
-
- return win;
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.h
deleted file mode 100644
index 03c8cd2f1..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.h
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEDIRECTCONV2D
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEDIRECTCONV2D
-
-#include "arm_compute/core/experimental/Types.h"
-#include "arm_compute/dynamic_fusion/sketch/attributes/Conv2dAttributes.h"
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplateDirectConv2d final : public IGpuTemplateComponentWriter
-{
-public:
- using Attributes = ClComponentDirectConv2d::Attributes;
- using Settings = ClComponentDirectConv2d::Settings;
- /** Constructor
- *
- * Similar to @ref ClComponentDirectConv2d::validate()
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- * @param[in] attributes Component attributes
- * @param[in] settings Component settings
- */
- ClTemplateDirectConv2d(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes,
- const Settings &settings);
- /** Destructor */
- ~ClTemplateDirectConv2d() override = default;
- /** Prevent instances of this class from being copy constructed */
- ClTemplateDirectConv2d(const ClTemplateDirectConv2d &direct_conv2d) = delete;
- /** Prevent instances of this class from being copied */
- ClTemplateDirectConv2d &operator=(const ClTemplateDirectConv2d &direct_conv2d) = delete;
- /** Allow instances of this class to be move constructed */
- ClTemplateDirectConv2d(ClTemplateDirectConv2d &&direct_conv2d) = default;
- /** Allow instances of this class to be moved */
- ClTemplateDirectConv2d &operator=(ClTemplateDirectConv2d &&direct_conv2d) = default;
- /** Generate kernel component name */
- std::string get_name() const override;
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- CLBuildOptions get_build_options(const ComponentGroup &comp_group) const override;
- /** Generate the component config id string used for tuning */
- std::string get_config_id() const override;
- /** Generate the header list used in the component */
- std::set<std::string> get_headers_list() const override;
- /** Generate the execution window for the component */
- Window get_window() const override;
-
-private:
- const ITensorInfo *_src;
- const ITensorInfo *_weight;
- const ITensorInfo *_bias;
- const ITensorInfo *_dst;
- Attributes _attributes;
- Settings _settings;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEDIRECTCONV2D */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
deleted file mode 100644
index 78bff3c3f..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
+++ /dev/null
@@ -1,274 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "ClTemplateElementwiseBinary.h"
-
-#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/core/utils/StringUtils.h"
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-#include "support/StringSupport.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-constexpr unsigned int vector_size_byte_opencl = 16;
-
-ClTemplateElementwiseBinary::ClTemplateElementwiseBinary(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes)
- : IGpuTemplateComponentWriter{id, tensors}, _lhs{}, _rhs{}, _dst{}, _attributes{attributes}
-{
- _lhs = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
- _rhs = this->tensors().get_const_tensor(TensorType::ACL_SRC_1);
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_lhs, _rhs, _dst);
-}
-
-std::string ClTemplateElementwiseBinary::get_name() const
-{
- return "elementwise_binary";
-}
-
-std::string ClTemplateElementwiseBinary::get_component_code(const ComponentGroup &comp_group) const
-{
- std::string code;
- const bool is_root = (comp_group.get_root_component()->id() == this->id());
- const bool is_lhs_input = comp_group.is_input_tensor(_lhs);
- const bool is_rhs_input = comp_group.is_input_tensor(_rhs);
-
- code =
- R"_(
- //------------------ START KERNEL {{meta_kernel_id}} {{ELTWISE_OP}} ---------------------
-)_";
-
- if (is_root)
- {
- code +=
- R"_(
- TILE(uint, M0, 1, g_dst_indirect_y);
-)_";
- }
-
- if (is_lhs_input)
- {
- code +=
- R"_(
- TILE({{DATA_TYPE}}, {{lhs_m0}}, N0, {{lhs}});
-)_";
- }
-
- if (is_rhs_input)
- {
- code +=
- R"_(
- TILE({{DATA_TYPE}}, {{rhs_m0}}, N0, {{rhs}});
-)_";
- }
-
- code +=
- R"_(
- {
-)_";
-
- if (is_lhs_input)
- {
- code +=
- R"_(
- {{lhs}}_offset_first_element_in_bytes += g_ind_2 * {{lhs}}_stride_w;
- T_LOAD({{DATA_TYPE}}, {{lhs_m0}}, {{lhs_n0}}, BUFFER, {{lhs}}, {{lhs_start_ind_0}}, {{lhs_start_ind_1}}, 1, {{lhs}}_stride_y, {{lhs}});
-)_";
- }
-
- if (is_rhs_input)
- {
- code +=
- R"_(
- {{rhs}}_offset_first_element_in_bytes += g_ind_2 * {{rhs}}_stride_w;
- T_LOAD({{DATA_TYPE}}, {{rhs_m0}}, {{rhs_n0}}, BUFFER, {{rhs}}, {{rhs_start_ind_0}}, {{rhs_start_ind_1}}, 1, {{rhs}}_stride_y, {{rhs}});
-)_";
- }
-
- code +=
- R"_(
- T_ELTWISE_{{BROADCAST_OP}}{{ELTWISE_OP}}({{DATA_TYPE}}, M0, N0, {{lhs}}, {{rhs}}, {{dst}});
-)_";
-
- if (is_root)
- {
- // Calculate the destination indirect Y
- code +=
- R"_(
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- g_dst_indirect_y[i].v = (uint)min(g_ind_1 + i, (int)({{arg_dst}}_w * {{arg_dst}}_h) - 1);
- g_dst_indirect_y[i].v += g_ind_2 * (int)({{arg_dst}}_w * {{arg_dst}}_h);
- })
-)_";
- }
-
- code +=
- R"_(
- }
- //------------------ END KERNEL {{meta_kernel_id}} {{ELTWISE_OP}} ---------------------
-)_";
-
- return code;
-}
-
-void ClTemplateElementwiseBinary::declare_variables(GpuKernelVariableTable &vtable,
- const ComponentGroup &comp_group) const
-{
- vtable.declare_variable(comp_group, _lhs, GpuKernelArgumentInfo(common_tensor_type), "lhs");
-
- vtable.declare_variable(comp_group, _rhs, GpuKernelArgumentInfo(common_tensor_type), "rhs");
-
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(common_tensor_type), "dst");
-}
-
-TagLUT ClTemplateElementwiseBinary::get_tag_lut(const GpuKernelVariableTable &vtable,
- const ComponentGroup &comp_group) const
-{
- TagLUT lut{};
-
- // Local build options
- lut["meta_kernel_id"] = id();
- lut["DATA_TYPE"] = get_cl_type_from_data_type(_lhs->data_type());
- // Arguments and global shared variables
-
- lut["lhs"] = vtable.get_variable(_lhs);
- lut["rhs"] = vtable.get_variable(_rhs);
- lut["dst"] = vtable.get_variable(_dst);
- lut["arg_dst"] = vtable.get_variable(comp_group.get_any_dst_tensor());
-
- switch (_attributes.operation())
- {
- case Attributes::ElementwiseOp::Add:
- lut["ELTWISE_OP"] = "ADD";
- break;
- case Attributes::ElementwiseOp::Sub:
- lut["ELTWISE_OP"] = "SUB";
- break;
- case Attributes::ElementwiseOp::Mul:
- lut["ELTWISE_OP"] = "MUL";
- break;
- default:
- ARM_COMPUTE_ERROR("Arithmetic Operation not supported");
- }
-
- ARM_COMPUTE_ERROR_ON(comp_group.is_intermediate_tensor(_lhs) &&
- detail::have_different_dimensions(_lhs->tensor_shape(), _dst->tensor_shape(), 0));
- ARM_COMPUTE_ERROR_ON(comp_group.is_intermediate_tensor(_rhs) &&
- detail::have_different_dimensions(_rhs->tensor_shape(), _dst->tensor_shape(), 0));
-
- // Set broadcast parameters
- // PRE: All tensors are broadcast-compatible
- const auto &lhs_dims = _lhs->tensor_shape();
- const auto &rhs_dims = _rhs->tensor_shape();
- const auto &dst_dims = _dst->tensor_shape();
-
- const auto lhs_broadcast_x = dst_dims[0] != 1 && lhs_dims[0] == 1;
- const auto rhs_broadcast_x = dst_dims[0] != 1 && rhs_dims[0] == 1;
- const auto lhs_broadcast_y = dst_dims[1] != 1 && lhs_dims[1] == 1;
- const auto rhs_broadcast_y = dst_dims[1] != 1 && rhs_dims[1] == 1;
- const auto lhs_broadcast_z = dst_dims[2] != 1 && lhs_dims[2] == 1;
- const auto rhs_broadcast_z = dst_dims[2] != 1 && rhs_dims[2] == 1;
-
- const auto lhs_broadcast_yz = lhs_broadcast_y && lhs_broadcast_z;
- const auto rhs_broadcast_yz = rhs_broadcast_y && rhs_broadcast_z;
-
- lut["lhs_n0"] = (lhs_broadcast_x) ? "1" : "N0";
- lut["lhs_start_ind_0"] = (lhs_broadcast_x) ? "0" : "g_ind_0";
- lut["rhs_n0"] = (rhs_broadcast_x) ? "1" : "N0";
- lut["rhs_start_ind_0"] = (rhs_broadcast_x) ? "0" : "g_ind_0";
-
- lut["lhs_m0"] = (lhs_broadcast_yz) ? "1" : "M0";
- lut["lhs_start_ind_1"] = (lhs_broadcast_yz) ? "0" : "g_ind_1";
- lut["rhs_m0"] = (rhs_broadcast_yz) ? "1" : "M0";
- lut["rhs_start_ind_1"] = (rhs_broadcast_yz) ? "0" : "g_ind_1";
-
- lut["BROADCAST_OP"] = (lhs_broadcast_yz) ? "BROADCAST_LHS_X_" : (rhs_broadcast_yz) ? "BROADCAST_RHS_X_" : "";
-
- return lut;
-}
-
-CLBuildOptions ClTemplateElementwiseBinary::get_build_options(const ComponentGroup &comp_group) const
-{
- CLBuildOptions build_opts{};
- /// NOTE: For now tile sizes (n0, m0) are set by the execution window. This may change in the future
- const auto root_window = comp_group.get_root_component()->template_writer()->get_window();
- const unsigned int n0 = root_window.x().step();
- const unsigned int m0 = root_window.y().step();
- const unsigned int partial_store_n0 = _dst->dimension(0) % n0;
-
- build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
- build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(_lhs->data_type()));
- build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0));
-
- return build_opts;
-}
-
-std::string ClTemplateElementwiseBinary::get_config_id() const
-{
- std::string config_id{};
- config_id += lower_string(string_from_data_type(_dst->data_type()));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(0));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(1));
- config_id += "_";
- config_id += lower_string(string_from_data_layout(_dst->data_layout()));
-
- return config_id;
-}
-
-std::set<std::string> ClTemplateElementwiseBinary::get_headers_list() const
-{
- return std::set<std::string>{"helpers.h", "tile_helpers.h"};
-}
-
-Window ClTemplateElementwiseBinary::get_window() const
-{
- ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
-
- TensorShape output_shape = _dst->tensor_shape();
- // Collapse Dim 1 (W) and Dim 2 (H) together, leave Dim 0 (C) and upper dimensions unchanged
- // This is in line with the collapsing convention used by operators like Conv2d
- output_shape.collapse(2U, 1U);
- const unsigned int num_elems_processed_per_iteration =
- adjust_vec_size(vector_size_byte_opencl / _dst->element_size(), _dst->dimension(0));
- Window win = calculate_max_window(output_shape, Steps(num_elems_processed_per_iteration));
-
- return win;
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.h
deleted file mode 100644
index 991c0eca4..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.h
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEELEMENTWISEBINARY
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEELEMENTWISEBINARY
-
-#include "arm_compute/core/experimental/Types.h"
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplateElementwiseBinary final : public IGpuTemplateComponentWriter
-{
-public:
- using Attributes = ClComponentElementwiseBinary::Attributes;
-
- /** Constructor
- *
- * Similar to @ref ClComponentElementwiseBinary::validate()
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- * @param[in] attributes Component attributes
- */
- ClTemplateElementwiseBinary(ComponentId id, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes);
- /** Prevent instances of this class from being copy constructed */
- ClTemplateElementwiseBinary(const ClTemplateElementwiseBinary &elementwise) = delete;
- /** Prevent instances of this class from being copied */
- ClTemplateElementwiseBinary &operator=(const ClTemplateElementwiseBinary &elementwise) = delete;
- /** Allow instances of this class to be move constructed */
- ClTemplateElementwiseBinary(ClTemplateElementwiseBinary &&elementwise) = default;
- /** Allow instances of this class to be moved */
- ClTemplateElementwiseBinary &operator=(ClTemplateElementwiseBinary &&elementwise) = default;
-
- /** Generate kernel component name */
- std::string get_name() const override;
-
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
-
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
-
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
-
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- CLBuildOptions get_build_options(const ComponentGroup &comp_group) const override;
-
- /** Generate the component config id string used for tuning */
- std::string get_config_id() const override;
-
- /** Generate the header list used in the component */
- std::set<std::string> get_headers_list() const override;
-
- /** Generate the execution window for the component */
- Window get_window() const override;
-
-private:
- const ITensorInfo *_lhs;
- const ITensorInfo *_rhs;
- const ITensorInfo *_dst;
- Attributes _attributes;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEELEMENTWISEBINARY */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.cpp
deleted file mode 100644
index 522c33a02..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.cpp
+++ /dev/null
@@ -1,267 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.h"
-
-#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
-#include "arm_compute/core/utils/StringUtils.h"
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-#include "support/StringSupport.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-namespace
-{
-constexpr unsigned int serial_vector_size = 8;
-} // namespace
-ClTemplateLogits1DMaxShiftExpSum::ClTemplateLogits1DMaxShiftExpSum(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes)
- : IGpuTemplateComponentWriter{id, tensors}, _src{}, _sum{}, _dst{}, _attributes{attributes}
-{
- _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
- _sum = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_1);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_src);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_sum);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_dst);
-}
-
-std::string ClTemplateLogits1DMaxShiftExpSum::get_name() const
-{
- return "logits_1d_max_shift_exp_sum";
-}
-
-std::string ClTemplateLogits1DMaxShiftExpSum::get_component_code(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- std::string code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
-#define VEC_TYPE VEC_DATA_TYPE({{DATA_TYPE}}, N0)
-#define SELECT_TYPE SELECT_VEC_DATA_TYPE({{DATA_TYPE}}, N0)
-{
- __global uchar *src_addr = {{src}}_ptr + {{src}}_offset_first_element_in_bytes + g_ind_1 * {{src}}_stride_y + g_ind_2 * {{src}}_stride_z;
- __global uchar *dst_addr = {{dst}}_ptr + {{dst}}_offset_first_element_in_bytes + g_ind_1 * {{dst}}_stride_y + g_ind_2 * {{dst}}_stride_z;
- Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT({{sum}});
- VEC_TYPE max_val_vec = (VEC_TYPE)({{MINVAL}});
-)_";
-
- const bool beta_defined = (_attributes.beta() != 1.f);
-
- if (beta_defined)
- {
- code += R"_(
- VEC_TYPE beta = (VEC_TYPE){{BETA}};
-)_";
- }
-
- constexpr unsigned int _serial_vector_size = 8;
- const unsigned int reduction_dim_size = _src->dimension(0);
- const unsigned int vector_size = adjust_vec_size(_serial_vector_size, reduction_dim_size);
- const bool non_multiple_of_n0 = ((reduction_dim_size % vector_size) != 0);
-
- if (non_multiple_of_n0)
- {
- code += R"_(
- VEC_TYPE data = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)src_addr);
- SELECT_TYPE widx = (SELECT_TYPE)PARTIAL_N0 > VEC_OFFS(SELECT_DATA_TYPE({{DATA_TYPE}}), N0);
- max_val_vec = max(max_val_vec, select((VEC_TYPE)({{MINVAL}}), data, widx));
-)_";
- }
-
- code += R"_(
- for(uint i = PARTIAL_N0; i < {{SRC_WIDTH}}; i += N0)
- {
- VEC_TYPE data = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(src_addr + i * sizeof({{DATA_TYPE}})));
- max_val_vec = max(data, max_val_vec);
- }
-
- {{DATA_TYPE}} max_val = MAX_REDUCE(max_val_vec, N0);
- VEC_TYPE sum1D = 0;
-)_";
-
- if (non_multiple_of_n0)
- {
- code += R"_(
- data -= max_val;
-)_";
- if (beta_defined)
- {
- code += R"_(
- data *= beta;
-)_";
- }
-
- if (_attributes.is_log_softmax())
- {
- code += R"_(
- VSTORE_PARTIAL(N0, PARTIAL_N0)
- (data, 0, (__global {{DATA_TYPE}} *)dst_addr);
- data = exp(data);
- data = select(0, data, widx);
-)_";
- }
- else
- {
- code += R"_(
- data = exp(data);
- data = select(0, data, widx);
- VSTORE_PARTIAL(N0, PARTIAL_N0)
- (data, 0, (__global {{DATA_TYPE}} *)dst_addr);
-)_";
- }
-
- code += R"_(
- sum1D += data;
-)_";
- }
- code += R"_(
- for(uint i = PARTIAL_N0; i < {{SRC_WIDTH}}; i += N0)
- {
- VEC_TYPE data = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(src_addr + i * sizeof({{DATA_TYPE}})));
- data -= max_val;
-)_";
-
- if (beta_defined)
- {
- code += R"_(
- data *= beta;
-)_";
- }
-
- if (_attributes.is_log_softmax())
- {
- code += R"_(
- VSTORE(N0)
- (data, 0, (__global {{DATA_TYPE}} *)(dst_addr + i * sizeof({{DATA_TYPE}})));
- data = exp(data);
-)_";
- }
- else
- {
- code += R"_(
- data = exp(data);
- VSTORE(N0)
- (data, 0, (__global {{DATA_TYPE}} *)(dst_addr + i * sizeof({{DATA_TYPE}})));
-)_";
- }
-
- code += R"_(
- sum1D += data;
- }
-)_";
-
- code += R"_(
- *((__global {{DATA_TYPE}} *)sum.ptr) = SUM_REDUCE(sum1D, N0);
-}
-//------------------ END KERNEL {{meta_kernel_id}} ---------------------
-)_";
-
- return code;
-}
-
-void ClTemplateLogits1DMaxShiftExpSum::declare_variables(GpuKernelVariableTable &vtable,
- const ComponentGroup &comp_group) const
-{
- vtable.declare_variable(comp_group, _src, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_3D), "src");
-
- vtable.declare_variable(comp_group, _sum, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_3D), "sum");
-
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_3D), "dst");
-}
-
-TagLUT ClTemplateLogits1DMaxShiftExpSum::get_tag_lut(const GpuKernelVariableTable &vtable,
- const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- TagLUT lut{};
-
- // Arguments and global shared variables
- lut["src"] = vtable.get_variable(_src);
- lut["sum"] = vtable.get_variable(_sum);
- lut["dst"] = vtable.get_variable(_dst);
-
- // Local build options
- lut["meta_kernel_id"] = id();
-
- const DataType data_type = _src->data_type();
-
- lut["DATA_TYPE"] = get_cl_type_from_data_type(data_type);
- lut["BETA"] = float_to_string_with_full_precision(_attributes.beta());
- lut["MINVAL"] = (data_type == DataType::F16) ? std::string("-HALF_MAX") : std::string("-FLT_MAX");
- lut["SRC_WIDTH"] = support::cpp11::to_string(_src->dimension(0));
-
- return lut;
-}
-
-CLBuildOptions ClTemplateLogits1DMaxShiftExpSum::get_build_options(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
- CLBuildOptions build_opts{};
-
- const unsigned int reduction_dim_size = _src->dimension(0);
- const unsigned int vector_size = adjust_vec_size(serial_vector_size, reduction_dim_size);
-
- build_opts.add_option("-DN0=" + support::cpp11::to_string(vector_size));
- build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string((reduction_dim_size % vector_size)));
-
- return build_opts;
-}
-
-std::string ClTemplateLogits1DMaxShiftExpSum::get_config_id() const
-{
- std::string config_id = get_name();
-
- config_id += "_";
- config_id += support::cpp11::to_string(_src->dimension(0));
- config_id += "_";
- config_id += string_from_data_type(_src->data_type());
-
- return config_id;
-}
-
-std::set<std::string> ClTemplateLogits1DMaxShiftExpSum::get_headers_list() const
-{
- return std::set<std::string>{"helpers.h", "tile_helpers.h"};
-}
-
-Window ClTemplateLogits1DMaxShiftExpSum::get_window() const
-{
- ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
-
- Window win = calculate_max_window(*_dst, Steps(_src->dimension(0)));
- return win.collapse(win, Window::DimZ);
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.h
deleted file mode 100644
index ac9ddaa9d..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.h
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATELOGITS1DMAXSHIFTEXPSUM
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATELOGITS1DMAXSHIFTEXPSUM
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DMaxShiftExpSum.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplateLogits1DMaxShiftExpSum final : public IGpuTemplateComponentWriter
-{
-public:
- using Attributes = ClComponentLogits1DMaxShiftExpSum::Attributes;
-
- /** Constructor
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- * @param[in] attributes Component attributes
- */
- ClTemplateLogits1DMaxShiftExpSum(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes);
- /** Prevent instances of this class from being copy constructed */
- ClTemplateLogits1DMaxShiftExpSum(const ClTemplateLogits1DMaxShiftExpSum &) = delete;
- /** Prevent instances of this class from being copied */
- ClTemplateLogits1DMaxShiftExpSum &operator=(const ClTemplateLogits1DMaxShiftExpSum &) = delete;
- /** Allow instances of this class to be move constructed */
- ClTemplateLogits1DMaxShiftExpSum(ClTemplateLogits1DMaxShiftExpSum &&) = default;
- /** Allow instances of this class to be moved */
- ClTemplateLogits1DMaxShiftExpSum &operator=(ClTemplateLogits1DMaxShiftExpSum &&) = default;
- /** Generate kernel component name */
- std::string get_name() const override;
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- CLBuildOptions get_build_options(const ComponentGroup &comp_group) const override;
- /** Generate the component config id string used for tuning */
- std::string get_config_id() const override;
- /** Generate the header list used in the component */
- std::set<std::string> get_headers_list() const override;
- /** Generate the execution window for the component */
- Window get_window() const override;
-
-private:
- const ITensorInfo *_src; // input
- const ITensorInfo *_sum; // exponentiated and summed input
- const ITensorInfo *_dst; // exponentiated input
- Attributes _attributes;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATELOGITS1DMAXSHIFTEXPSUM */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.cpp
deleted file mode 100644
index 7d7c3e667..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.cpp
+++ /dev/null
@@ -1,171 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.h"
-
-#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-#include "support/StringSupport.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-ClTemplateLogits1DNorm::ClTemplateLogits1DNorm(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes)
- : IGpuTemplateComponentWriter{id, tensors}, _src{}, _sum{}, _dst{}, _attributes{attributes}
-{
- _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
- _sum = this->tensors().get_const_tensor(TensorType::ACL_SRC_1);
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_src);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_sum);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_dst);
-}
-
-std::string ClTemplateLogits1DNorm::get_name() const
-{
- return "logits_1d_norm";
-}
-
-std::string ClTemplateLogits1DNorm::get_component_code(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- std::string code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
-{
- const int x_offs = g_ind_0 * sizeof({{DATA_TYPE}});
- __global uchar *src_addr = {{src}}_ptr + {{src}}_offset_first_element_in_bytes + x_offs + g_ind_1 * {{src}}_stride_y + g_ind_2 * {{src}}_stride_z;
- __global uchar *dst_addr = {{dst}}_ptr + {{dst}}_offset_first_element_in_bytes + x_offs + g_ind_1 * {{dst}}_stride_y + g_ind_2 * {{dst}}_stride_z;
- Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP({{sum}});
-)_";
- // Load max value of 1D logits vector (row)
- code += R"_(
- {{DATA_TYPE}} sum_val = *((__global {{DATA_TYPE}} *)offset(&sum, 0, g_ind_1));
- VEC_DATA_TYPE({{DATA_TYPE}}, N0)
- data0 = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)src_addr);
-)_";
-
- if (_attributes.is_log_softmax())
- {
- code += R"_(
- sum_val = log(sum_val);
- data0 -= sum_val;
-)_";
- }
- else
- {
- code += R"_(
- data0 /= sum_val;
-)_";
- }
-
- code += R"_(
- STORE_VECTOR_SELECT(data, {{DATA_TYPE}}, dst_addr, N0, PARTIAL_N0, PARTIAL_N0 != 0 && g_ind_0 == 0);
-}
-//------------------ END KERNEL {{meta_kernel_id}} ---------------------
-)_";
-
- return code;
-}
-
-void ClTemplateLogits1DNorm::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- vtable.declare_variable(comp_group, _src, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_3D), "src");
-
- vtable.declare_variable(comp_group, _sum, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_3D), "sum");
-
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_3D), "dst");
-}
-
-TagLUT ClTemplateLogits1DNorm::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- TagLUT lut{};
-
- // Arguments and global shared variables
- lut["src"] = vtable.get_variable(_src);
- lut["sum"] = vtable.get_variable(_sum);
- lut["dst"] = vtable.get_variable(_dst);
-
- // Local build options
- lut["meta_kernel_id"] = id();
-
- const DataType data_type = _src->data_type();
-
- lut["DATA_TYPE"] = get_cl_type_from_data_type(data_type);
-
- return lut;
-}
-
-CLBuildOptions ClTemplateLogits1DNorm::get_build_options(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
- CLBuildOptions build_opts{};
-
- const auto root_window = comp_group.get_root_component()->template_writer()->get_window();
- const unsigned int n0 = root_window.x().step();
- build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string((_src->dimension(0) % n0)));
-
- return build_opts;
-}
-
-std::string ClTemplateLogits1DNorm::get_config_id() const
-{
- std::string config_id = get_name();
-
- config_id += "_";
- config_id += support::cpp11::to_string(_src->dimension(0));
- config_id += "_";
- config_id += string_from_data_type(_src->data_type());
-
- return config_id;
-}
-
-std::set<std::string> ClTemplateLogits1DNorm::get_headers_list() const
-{
- return std::set<std::string>{"helpers.h", "tile_helpers.h"};
-}
-
-Window ClTemplateLogits1DNorm::get_window() const
-{
- ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
- constexpr unsigned int serial_vector_size = 16;
- const unsigned int vector_size = adjust_vec_size(serial_vector_size, _src->dimension(0));
-
- Window win = calculate_max_window(*_src, Steps(vector_size));
- return win.collapse(win, Window::DimZ);
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.h
deleted file mode 100644
index 5a74be584..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DNorm.h
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATELOGITS1DNORM
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATELOGITS1DNORM
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentLogits1DNorm.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplateLogits1DNorm final : public IGpuTemplateComponentWriter
-{
-public:
- using Attributes = ClComponentLogits1DNorm::Attributes;
-
- /** Constructor
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- * @param[in] attributes Component attributes
- */
- ClTemplateLogits1DNorm(ComponentId id, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes);
- /** Prevent instances of this class from being copy constructed */
- ClTemplateLogits1DNorm(const ClTemplateLogits1DNorm &) = delete;
- /** Prevent instances of this class from being copied */
- ClTemplateLogits1DNorm &operator=(const ClTemplateLogits1DNorm &) = delete;
- /** Allow instances of this class to be move constructed */
- ClTemplateLogits1DNorm(ClTemplateLogits1DNorm &&) = default;
- /** Allow instances of this class to be moved */
- ClTemplateLogits1DNorm &operator=(ClTemplateLogits1DNorm &&) = default;
- /** Generate kernel component name */
- std::string get_name() const override;
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- CLBuildOptions get_build_options(const ComponentGroup &comp_group) const override;
- /** Generate the component config id string used for tuning */
- std::string get_config_id() const override;
- /** Generate the header list used in the component */
- std::set<std::string> get_headers_list() const override;
- /** Generate the execution window for the component */
- Window get_window() const override;
-
-private:
- const ITensorInfo *_src; // exponentiated input
- const ITensorInfo *_sum; // exponentiated and summed input
- const ITensorInfo *_dst; // normalization of input with _sum
-
- Attributes _attributes;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATELOGITS1DNORM */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp
deleted file mode 100644
index 8936db6ab..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp
+++ /dev/null
@@ -1,470 +0,0 @@
-/*
- * Copyright (c) 2023-2024 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "ClTemplatePool2d.h"
-
-#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/core/utils/StringUtils.h"
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-#include "support/StringSupport.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-namespace
-{
-// Shape indexes for NHWC Datalayout
-constexpr static int32_t height_idx = 2;
-constexpr static int32_t width_idx = 1;
-constexpr static int32_t channel_idx = 0;
-} // namespace
-ClTemplatePool2d::ClTemplatePool2d(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes,
- const Settings &settings)
- : IGpuTemplateComponentWriter{id, tensors}, _src{}, _dst{}, _attributes{attributes}, _settings{settings}
-{
- _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
-}
-
-std::string ClTemplatePool2d::get_name() const
-{
- return "pool2d";
-}
-
-std::string ClTemplatePool2d::get_component_code(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- // Condition to use 2x2 optimized kernel
- if (_attributes.pool_size() == Size2D(2, 2))
- {
- return get_2x2_kernel_code();
- }
- else
- {
- return get_MxN_kernel_code();
- }
-}
-
-std::string ClTemplatePool2d::get_MxN_kernel_code() const
-{
- const auto pool_type = _attributes.pool_type();
- const bool fp_mixed_precision = (_src->data_type() == DataType::F16) && pool_type != PoolingType::MAX;
-
- // Define pool op macro.
- std::string pool_op = (pool_type == PoolingType::AVG) ? R"_(#define POOL_OP(x,y) ((x) + (y)))_"
- : R"_(#define POOL_OP(x,y) (fmax((x), (y))) )_";
-
- // Kernel start
- // Note: If C is not multiple of N0, we shift back of PARTIAL_N0 elements to compute the leftover elements for get_global_id(0) == 0
- // Note: If C is less than N0, N0 should be SHRINKED to the closest smaller N0. This operation is performed on the host side
- std::string code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
-// IN_0(src) {{src}}
-// OUT(dst, accum) {{dst}}
-
-{
- const int idx_out_c = g_ind_0;
- const int idx_out_w = g_ind_1;
-)_";
-
- // Add macro for POOL_OP
- code += "\n" + pool_op + "\n";
-
- code += R"_(
- const int idx_out_h = g_ind_2 % {{DST_HEIGHT}};
- const int idx_out_n = g_ind_2 / {{DST_HEIGHT}};
-)_";
-
- // Define common variables.
- code += R"_(
- __global unsigned char *in_base_ptr = {{src}}_ptr + {{src}}_offset_first_element_in_bytes + idx_out_c * sizeof({{DATA_TYPE}}) + idx_out_n * {{src}}_stride_w;
-
- __global unsigned char *out_base_ptr = {{dst}}_ptr + {{dst}}_offset_first_element_in_bytes + idx_out_c * sizeof({{DATA_TYPE}}) + idx_out_w * {{dst}}_stride_y + idx_out_h * {{dst}}_stride_z + idx_out_n * {{dst}}_stride_w;
-
- VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0)
- res0 = {{INITIAL_VALUE}};
-
- const int idx_in_w = idx_out_w * {{STRIDE_X}} - {{PAD_X}};
- const int idx_in_h = idx_out_h * {{STRIDE_Y}} - {{PAD_Y}};
-
- const int pool_x_s = max((int)0, -idx_in_w);
- const int pool_x_e = min((int){{POOL_SIZE_X}}, (int){{SRC_WIDTH}} - idx_in_w);
- const int pool_y_s = max((int)0, -idx_in_h);
- const int pool_y_e = min((int){{POOL_SIZE_Y}}, (int){{SRC_HEIGHT}} - idx_in_h);
-)_";
-
- // Determine filter size depending on if padding is excluded or not
- if (_attributes.exclude_padding())
- {
- code += R"_(
- const int filter_size = (pool_y_e - pool_y_s) * (pool_x_e - pool_x_s);
-)_";
- }
- else
- {
- code += R"_(
- const int filter_size = {{POOL_SIZE_X}} * {{POOL_SIZE_Y}};
-)_";
- }
-
- // Loop through pool size
- // if global pooling
- if (_attributes.pool_size().x() == _src->dimension(width_idx) &&
- _attributes.pool_size().y() == _src->dimension(height_idx))
- {
- // Begin loop
- code += R"_(
- // Global pooling path
- for(int y = 0; y < {{POOL_SIZE_Y}}; ++y)
- {
- #pragma unroll 8
- for(int x = 0; x < {{POOL_SIZE_X}}; ++x)
- {
- VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0)
- data0;
-)_";
- }
- else // if local pooling size
- {
- code += R"_(
- for(int y = pool_y_s; y < pool_y_e; ++y)
- {
- #pragma unroll 8
- for(int x = pool_x_s; x < pool_x_e; ++x)
- {
- VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0)
- data0;
-)_";
- } // end else
-
- // if condition inside loop - use 32bit acc if mixed_precision.
- // End loop through pooling section.
- if (fp_mixed_precision)
- {
- // In case of FP_MIXED_PRECISION, ACC_DATA_TYPE is != DATA_TYPE
- code += R"_(
- data0 = CONVERT(VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(in_base_ptr + (x + idx_in_w) * {{src}}_stride_y + (y + idx_in_h) * {{src}}_stride_z)), VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0));
- res0 = POOL_OP(res0, data0);
- }
- }
-)_";
- }
- else // load data, compute result and end loop
- {
- code += R"_(
- data0 = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(in_base_ptr + (x + idx_in_w) * {{src}}_stride_y + (y + idx_in_h) * {{src}}_stride_z));
- res0 = POOL_OP(res0, data0);
- }
- }
-)_";
- }
-
- // For Pool AVG ONLY, divide pool output by filter size
- if (pool_type == PoolingType::AVG)
- {
- code += R"_(
- res0 /= (VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0))filter_size;
-)_";
- }
-
- // If mixed precision convert datatype before storing. Then end kernel.
- if (fp_mixed_precision)
- {
- code += R"_(
- VEC_DATA_TYPE({{DATA_TYPE}}, N0)
- res_converted0 = CONVERT(res0, VEC_DATA_TYPE({{DATA_TYPE}}, N0));
- STORE_VECTOR_SELECT(res_converted, {{DATA_TYPE}}, out_base_ptr, N0, PARTIAL_N0, (PARTIAL_N0 != 0) && g_ind_0 == 0);
-)_";
- }
- else
- {
- // Store data
- code += R"_(
- STORE_VECTOR_SELECT(res, {{DATA_TYPE}}, out_base_ptr, N0, PARTIAL_N0, (PARTIAL_N0 != 0) && g_ind_0 == 0);
-)_";
- }
-
- code += R"_(
-//------------------ END KERNEL {{meta_kernel_id}} ---------------------
-}
-)_";
-
- return code;
-}
-
-std::string ClTemplatePool2d::get_2x2_kernel_code() const
-{
- const auto pool_type = _attributes.pool_type();
- const bool fp_mixed_precision = (_src->data_type() == DataType::F16) && pool_type != PoolingType::MAX;
- std::string pool_op = (pool_type == PoolingType::AVG) ? R"_(#define POOL_OP(x,y) ((x) + (y)))_"
- : R"_(#define POOL_OP(x,y) (fmax((x), (y))) )_";
-
- std::string code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
-// IN_0(src) {{src}}
-// OUT(dst, accum) {{dst}}
-
-#define SELECT_TYPE SELECT_VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0)
-
-{
- const int idx_out_c = g_ind_0;
- const int idx_out_w = g_ind_1;
-)_";
-
- // Add pool op macro
- code += "\n" + pool_op + "\n";
-
- // If batch size != 1, the batch size dimension is collapsed over the height dimension
- code += R"_(
- const int idx_out_h = g_ind_2 % {{DST_HEIGHT}};
- const int idx_out_n = g_ind_2 / {{DST_HEIGHT}};
-)_";
-
- code += R"_(
- const int idx_in_w = idx_out_w * {{STRIDE_X}} - {{PAD_X}};
- const int idx_in_h = idx_out_h * {{STRIDE_Y}} - {{PAD_Y}};
-
- __global unsigned char *in_base_ptr = {{src}}_ptr + {{src}}_offset_first_element_in_bytes + idx_out_c * sizeof({{DATA_TYPE}}) + idx_out_n * {{src}}_stride_w;
- __global unsigned char *out_base_ptr = {{dst}}_ptr + {{dst}}_offset_first_element_in_bytes + idx_out_c * sizeof({{DATA_TYPE}}) + idx_out_w * {{dst}}_stride_y + idx_out_h * {{dst}}_stride_z + idx_out_n *
- {{dst}}_stride_w;
- const int pool_x_s = max((int)0, -idx_in_w);
- const int pool_x_e = min((int)2, (int){{SRC_WIDTH}} - idx_in_w);
- const int pool_y_s = max((int)0, -idx_in_h);
- const int pool_y_e = min((int)2, (int){{SRC_HEIGHT}} - idx_in_h);
-
- const int filter_size = (pool_x_e - pool_x_s) * (pool_y_e - pool_y_s);
- const int x0 = pool_x_s + idx_in_w;
- const int y0 = pool_y_s + idx_in_h;
- const int x1 = pool_x_e - 1 + idx_in_w;
- const int y1 = pool_y_e - 1 + idx_in_h;
-
- REPEAT_VAR_INIT_TO_CONST(4, VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0), data, 0);
-)_";
-
- if (fp_mixed_precision)
- {
- // In case of FP_MIXED_PRECISION, ACC_DATA_TYPE is != DATA_TYPE
- code += R"_(
- data0 = CONVERT(VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(in_base_ptr + x0 * {{src}}_stride_y + y0 * {{src}}_stride_z)), VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0));
- data1 = CONVERT(VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(in_base_ptr + x1 * {{src}}_stride_y + y0 * {{src}}_stride_z)), VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0));
- data2 = CONVERT(VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(in_base_ptr + x0 * {{src}}_stride_y + y1 * {{src}}_stride_z)), VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0));
- data3 = CONVERT(VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(in_base_ptr + x1 * {{src}}_stride_y + y1 * {{src}}_stride_z)), VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0));
-)_";
- }
- else
- {
- code += R"_(
- data0 = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(in_base_ptr + x0 * {{src}}_stride_y + y0 * {{src}}_stride_z));
- data1 = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(in_base_ptr + x1 * {{src}}_stride_y + y0 * {{src}}_stride_z));
- data2 = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(in_base_ptr + x0 * {{src}}_stride_y + y1 * {{src}}_stride_z));
- data3 = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(in_base_ptr + x1 * {{src}}_stride_y + y1 * {{src}}_stride_z));
-)_";
- }
-
- if (pool_type != PoolingType::MAX)
- {
- // Make invalid the values loaded if the x or y coordinate was clamped (out-of-bound)
- code += R"_(
- if(filter_size != 4)
- {
- SELECT_TYPE cond_w_s = (SELECT_TYPE)idx_in_w < (SELECT_TYPE)0;
- SELECT_TYPE cond_w_e = (SELECT_TYPE)idx_in_w >= (SELECT_TYPE)({{SRC_WIDTH}} - 1);
- SELECT_TYPE cond_h_s = (SELECT_TYPE)idx_in_h < (SELECT_TYPE)0;
- SELECT_TYPE cond_h_e = (SELECT_TYPE)idx_in_h >= (SELECT_TYPE)({{SRC_HEIGHT}} - 1);
-
- data0 = select(data0, (VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0)){{INITIAL_VALUE}}, (SELECT_TYPE)(cond_w_s | cond_h_s));
- data1 = select(data1, (VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0)){{INITIAL_VALUE}}, (SELECT_TYPE)(cond_w_e | cond_h_s));
- data2 = select(data2, (VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0)){{INITIAL_VALUE}}, (SELECT_TYPE)(cond_w_s | cond_h_e));
- data3 = select(data3, (VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0)){{INITIAL_VALUE}}, (SELECT_TYPE)(cond_w_e | cond_h_e));
- }
-)_";
- }
-
- code += R"_(
- VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0)
- res0 = data0;
- res0 = POOL_OP(res0, data1);
- res0 = POOL_OP(res0, data2);
- res0 = POOL_OP(res0, data3);
-)_";
-
- if (pool_type == PoolingType::AVG)
- {
- // If avg pooling divide result accordingly.
- if (_attributes.exclude_padding())
- {
- code += R"_(
- res0 /= (VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0))filter_size;
-)_";
- }
- else
- {
- code += R"_(
- res0 /= (VEC_DATA_TYPE({{ACC_DATA_TYPE}}, N0))4;
-)_";
- }
- }
-
- // Store result
- if (fp_mixed_precision)
- {
- code += R"_(
- VEC_DATA_TYPE({{DATA_TYPE}}, N0)
- res_converted0 = CONVERT(res0, VEC_DATA_TYPE({{DATA_TYPE}}, N0));
- STORE_VECTOR_SELECT(res_converted, {{DATA_TYPE}}, out_base_ptr, N0, PARTIAL_N0, (PARTIAL_N0 != 0) && g_ind_0 == 0);
-)_";
- }
- else
- {
- code += R"_(
- STORE_VECTOR_SELECT(res, {{DATA_TYPE}}, out_base_ptr, N0, PARTIAL_N0, (PARTIAL_N0 != 0) && g_ind_0 == 0);
-)_";
- }
-
- code += R"_(
- //------------------ END KERNEL {{meta_kernel_id}} ---------------------
-}
-#undef SELECT_TYPE
-)_";
-
- return code;
-}
-
-void ClTemplatePool2d::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- vtable.declare_variable(comp_group, _src, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "src");
-
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "dst");
-}
-
-TagLUT ClTemplatePool2d::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- TagLUT lut{};
- // Arguments and global shared variables
- lut["src"] = vtable.get_variable(_src);
- lut["dst"] = vtable.get_variable(_dst);
-
- // Local build options
- lut["meta_kernel_id"] = id();
-
- // Retrieve relevant data
- const auto padding = _attributes.pad();
- const auto stride = _attributes.stride();
- const auto pool_size = _attributes.pool_size();
- const auto data_type = _src->data_type();
- const auto use_fp_mixed_precision =
- (_src->data_type() == DataType::F16) && _attributes.pool_type() != PoolingType::MAX;
- const std::string max_initial_value =
- _settings.use_inf_as_limit() ? "(-INFINITY)"
- : float_to_string_with_full_precision(std::numeric_limits<float>::lowest());
-
- // pool specific
- lut["STRIDE_X"] = stride.x();
- lut["STRIDE_Y"] = stride.y();
- lut["PAD_X"] = padding.left;
- lut["PAD_Y"] = padding.top;
- lut["POOL_SIZE_X"] = pool_size.width;
- lut["POOL_SIZE_Y"] = pool_size.height;
-
- // Datatypes and variables
- lut["ACC_DATA_TYPE"] = get_cl_type_from_data_type(
- (use_fp_mixed_precision) ? (DataType::F32) : (data_type)); // Type of accumulators to use.
- lut["DATA_TYPE"] = get_cl_type_from_data_type(data_type);
- lut["SRC_WIDTH"] = _src->dimension(width_idx);
- lut["SRC_HEIGHT"] = _src->dimension(height_idx);
- lut["INITIAL_VALUE"] = (_attributes.pool_type() == PoolingType::MAX) ? max_initial_value : std::string("0");
-
- // Tensor specific data
- lut["DST_HEIGHT"] = _dst->dimension(height_idx);
-
- return lut;
-}
-
-CLBuildOptions ClTemplatePool2d::get_build_options(const ComponentGroup &comp_group) const
-{
- const auto root_window = comp_group.get_root_component()->template_writer()->get_window();
- const unsigned int n0 = root_window.x().step();
- const unsigned int partial_store_n0 = _dst->dimension(0) % n0;
-
- CLBuildOptions build_opts{};
- build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0));
-
- return build_opts;
-}
-
-std::string ClTemplatePool2d::get_config_id() const
-{
- const DataType data_type = _src->data_type();
- const DataLayout data_layout = _src->data_layout();
-
- std::string config_id{};
- config_id += "pooling_layer_2d_";
- config_id += lower_string(string_from_data_type(data_type));
- config_id += "_";
- config_id += lower_string(string_from_data_layout(data_layout));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(width_idx));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(height_idx));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(channel_idx));
-
- return config_id;
-}
-
-std::set<std::string> ClTemplatePool2d::get_headers_list() const
-{
- return std::set<std::string>{"helpers.h", "tile_helpers.h", "repeat.h"};
-}
-
-Window ClTemplatePool2d::get_window() const
-{
- ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
- const auto output_shape = _dst->tensor_shape();
- const unsigned int vec_size = adjust_vec_size(((_dst->data_type() == DataType::F32) ? 2 : 4), _dst->dimension(0));
-
- // Create and configure kernel window
- auto win = calculate_max_window(output_shape, Steps(vec_size));
- win = win.collapse_if_possible(win, Window::DimZ); // collapse window on batch size.
- return win;
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.h
deleted file mode 100644
index d1d3c0166..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.h
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEPOOL2D
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEPOOL2D
-
-#include "arm_compute/core/experimental/Types.h"
-#include "arm_compute/dynamic_fusion/sketch/attributes/Pool2dAttributes.h"
-#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuPool2d.h"
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentPool2d.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplatePool2d final : public IGpuTemplateComponentWriter
-{
-public:
- using Attributes = ClComponentPool2d::Attributes;
- using Settings = ClComponentPool2d::Settings;
- /** Constructor
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- * @param[in] attributes Component attributes
- * @param[in] settings Component settings
- */
- ClTemplatePool2d(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes,
- const Settings &settings);
-
- /** Prevent instances of this class from being copy constructed */
- ClTemplatePool2d(const ClTemplatePool2d &direct_conv2d) = delete;
-
- /** Prevent instances of this class from being copied */
- ClTemplatePool2d &operator=(const ClTemplatePool2d &direct_conv2d) = delete;
-
- /** Allow instances of this class to be move constructed */
- ClTemplatePool2d(ClTemplatePool2d &&direct_conv2d) = default;
-
- /** Allow instances of this class to be moved */
- ClTemplatePool2d &operator=(ClTemplatePool2d &&direct_conv2d) = default;
-
- /** Generate kernel component name */
- std::string get_name() const override;
-
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- CLBuildOptions get_build_options(const ComponentGroup &comp_group) const override;
-
- /** Generate the component config id string used for tuning */
- std::string get_config_id() const override;
-
- /** Generate the header list used in the component */
- std::set<std::string> get_headers_list() const override;
-
- /** Generate the execution window for the component */
- Window get_window() const override;
-
-private:
- /** Generate pooling kernel template code optimized for 2x2 pooling
- *
- * @return std::String Component code
- */
- std::string get_2x2_kernel_code() const;
-
- /** Generate generalised pooling kernel template code for MxN pooling
- *
- * @return std::String Component code
- */
- std::string get_MxN_kernel_code() const;
-
- const ITensorInfo *_src;
- const ITensorInfo *_dst;
- Attributes _attributes;
- Settings _settings;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEPOOL2D */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.cpp
deleted file mode 100644
index c882353fc..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.cpp
+++ /dev/null
@@ -1,161 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "ClTemplateReshape.h"
-
-#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
-#include "arm_compute/core/utils/StringUtils.h"
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-constexpr unsigned int vector_size_byte_opencl = 16;
-
-ClTemplateReshape::ClTemplateReshape(ComponentId id, const ArgumentPack<ITensorInfo> &tensors)
- : IGpuTemplateComponentWriter{id, tensors}, _src{}, _dst{}
-{
- _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
-}
-
-std::string ClTemplateReshape::get_name() const
-{
- return "reshape";
-}
-
-std::string ClTemplateReshape::get_component_code(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
- std::string code;
-
- code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
-
-// IN(src) {{src}}
-// OUT(dst, accum) {{dst}}
-
-TILE(uint, M0, 1, g_dst_indirect_y);
-{
- __global uchar * base_src_ptr = {{src}}_ptr + {{src}}_offset_first_element_in_bytes;
- const int tile_vertical_idx = g_ind_1 * {{arg_dst}}_c + g_ind_2 * {{arg_dst}}_c * {{arg_dst}}_w;
- LOOP_UNROLLING(int, _m0, 0, 1, M0,
- {
- const int row_idx = _m0 * {{arg_dst}}_c + tile_vertical_idx;
- const int tile_horizontal_idx = g_ind_0 + row_idx;
- LOOP_UNROLLING(int, _n0, 0, 1, N0,
- {
- {{src}}_ptr = base_src_ptr;
- const int linear_idx = tile_horizontal_idx + _n0;
- const int in_id_x = linear_idx % {{src}}_c;
- const int in_id_y = (linear_idx / {{src}}_c) % {{src}}_w;
- const int in_id_z = linear_idx / ({{src}}_c * {{src}}_w);
- {{src}}_ptr += in_id_x * sizeof({{DATA_TYPE}}) + in_id_y * {{src}}_stride_y + in_id_z * {{src}}_stride_z;
- {{dst}}[_m0].s[_n0] = *((__global {{DATA_TYPE}} *){{src}}_ptr);
- })
- })
-
- LOOP_UNROLLING(int, i, 0, 1, M0,
- {
- g_dst_indirect_y[i].v = (uint)min((int)(g_ind_1 + i), (int)({{arg_dst}}_w) - 1);
- g_dst_indirect_y[i].v += (int)(g_ind_2 % {{arg_dst}}_h) * (int)({{arg_dst}}_w);
- g_dst_indirect_y[i].v += (int)(g_ind_2 / {{arg_dst}}_h) * (int)({{arg_dst}}_w * {{arg_dst}}_h);
- })
-}
-//------------------ END KERNEL {{meta_kernel_id}} ---------------------
-)_";
- return code;
-}
-
-void ClTemplateReshape::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- vtable.declare_variable(comp_group, _src,
- GpuKernelArgumentInfo(common_tensor_type), // GpuKernelArgumentInfo::Type::Image_3D
- "src");
-
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(common_tensor_type), "dst");
-}
-
-TagLUT ClTemplateReshape::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
- TagLUT lut{};
-
- // Arguments and global shared variables
- lut["src"] = vtable.get_variable(_src);
- lut["dst"] = vtable.get_variable(_dst);
- lut["arg_dst"] = vtable.get_variable(comp_group.get_any_dst_tensor());
- lut["meta_kernel_id"] = id();
- lut["DATA_TYPE"] = get_cl_type_from_data_type(_dst->data_type());
-
- return lut;
-}
-
-CLBuildOptions ClTemplateReshape::get_build_options(const ComponentGroup &comp_group) const
-{
- CLBuildOptions build_opts{};
- const auto root_window = comp_group.get_root_component()->template_writer()->get_window();
- const unsigned int n0 = root_window.x().step();
- const unsigned int m0 = root_window.y().step();
- const unsigned int partial_store_n0 = _dst->dimension(0) % n0;
- build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
- build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0));
-
- return build_opts;
-}
-
-std::string ClTemplateReshape::get_config_id() const
-{
- std::string config_id{};
- config_id += lower_string(string_from_data_type(_dst->data_type()));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(0));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(1));
-
- return config_id;
-}
-
-std::set<std::string> ClTemplateReshape::get_headers_list() const
-{
- return std::set<std::string>{"helpers.h", "tile_helpers.h"};
-}
-
-Window ClTemplateReshape::get_window() const
-{
- ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
- const unsigned int n0 = adjust_vec_size(vector_size_byte_opencl / _dst->element_size(), _dst->dimension(0));
- Window win = calculate_max_window(*_dst, Steps(n0));
- return win.collapse(win, Window::DimZ);
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.h
deleted file mode 100644
index 838a21db6..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateReshape.h
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright (c) 2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATERESHAPE
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATERESHAPE
-
-#include "arm_compute/core/experimental/Types.h"
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentReshape.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplateReshape final : public IGpuTemplateComponentWriter
-{
-public:
- /** Constructor
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- */
- ClTemplateReshape(ComponentId id, const ArgumentPack<ITensorInfo> &tensors);
- /** Prevent instances of this class from being copy constructed */
- ClTemplateReshape(const ClTemplateReshape &reshape) = delete;
- /** Prevent instances of this class from being copied */
- ClTemplateReshape &operator=(const ClTemplateReshape &reshape) = delete;
- /** Allow instances of this class to be move constructed */
- ClTemplateReshape(ClTemplateReshape &&reshape) = default;
- /** Allow instances of this class to be moved */
- ClTemplateReshape &operator=(ClTemplateReshape &&reshape) = default;
-
- /** Generate kernel component name */
- std::string get_name() const override;
-
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
-
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
-
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
-
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- CLBuildOptions get_build_options(const ComponentGroup &comp_group) const override;
-
- /** Generate the component config id string used for tuning */
- std::string get_config_id() const override;
-
- /** Generate the header list used in the component */
- std::set<std::string> get_headers_list() const override;
-
- /** Generate the execution window for the component */
- Window get_window() const override;
-
-private:
- const ITensorInfo *_src;
- const ITensorInfo *_dst;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATERESHAPE */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.cpp
deleted file mode 100644
index 846c712ce..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.cpp
+++ /dev/null
@@ -1,279 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "ClTemplateResize.h"
-
-#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
-#include "arm_compute/core/utils/StringUtils.h"
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/core/utils/ScaleUtils.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-ClTemplateResize::ClTemplateResize(ComponentId id,
- const ArgumentPack<ITensorInfo> &tensors,
- const ClTemplateResize::Attributes &attributes)
- : IGpuTemplateComponentWriter{id, tensors}, _src{}, _dst{}, _attributes{attributes}
-{
- _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
-
- ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
-}
-
-std::string ClTemplateResize::get_name() const
-{
- return _attributes.interpolation_policy() == InterpolationPolicy::BILINEAR ? "resize_bilinear" : "resize_nearest";
-}
-
-std::string ClTemplateResize::get_component_code(const IGpuTemplateComponentWriter::ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- std::string code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
-TILE(uint, 1, 1, g_dst_indirect_y);
-{
- const int yo = g_ind_2 % {{arg_dst}}_h;
- const int bout = g_ind_2 / {{arg_dst}}_h;
-)_";
-
- if (_attributes.interpolation_policy() == InterpolationPolicy::NEAREST_NEIGHBOR)
- {
- if (_attributes.sampling_policy() == SamplingPolicy::TOP_LEFT)
- {
- code += R"_(
- float xi_f = (g_ind_1 * {{SCALE_X}});
- float yi_f = (yo * {{SCALE_Y}});
-)_";
- }
- else
- {
- code += R"_(
- float xi_f = ((g_ind_1 + 0.5f) * {{SCALE_X}});
- float yi_f = ((yo + 0.5f) * {{SCALE_Y}});
-)_";
- }
-
- if (_attributes.align_corners())
- {
- code += R"_(
- xi_f = round(xi_f);
- yi_f = round(yi_f);
-)_";
- }
-
- code += R"_(
- const int xi0 = clamp((int)xi_f, 0, (int){{src}}_w - 1);
- const int yi0 = clamp((int)yi_f, 0, (int){{src}}_h - 1);
-
- T_LOAD_NHWC_WITH_DILATION({{SRC_DATA_TYPE}}, 1, 1, N0, {{SRC_TENSOR_TYPE}}, {{src}}, bout, yi0, xi0, g_ind_0, {{src}}_w, {{src}}_h, 1, 1, false, {{dst}});
-)_";
- }
- else if (_attributes.interpolation_policy() == InterpolationPolicy::BILINEAR)
- {
- if (_attributes.sampling_policy() == SamplingPolicy::TOP_LEFT)
- {
- code += R"_(
- float xi_f = (g_ind_1 * {{SCALE_X}});
- float yi_f = (yo * {{SCALE_Y}});
-)_";
- }
- else
- {
- code += R"_(
- float xi_f = ((g_ind_1 + 0.5f) * {{SCALE_X}} - 0.5f);
- float yi_f = ((yo + 0.5f) * {{SCALE_Y}} - 0.5f);
-)_";
- }
-
- code += R"_(
- const int xi = (int)floor(xi_f);
- const int yi = (int)floor(yi_f);
-
- TILE({{SRC_DATA_TYPE}}, 1, N0, in00);
- TILE({{SRC_DATA_TYPE}}, 1, N0, in01);
- TILE({{SRC_DATA_TYPE}}, 1, N0, in10);
- TILE({{SRC_DATA_TYPE}}, 1, N0, in11);
-
- in00[0].v = {{CONSTANT_VALUE}};
- in01[0].v = {{CONSTANT_VALUE}};
- in10[0].v = {{CONSTANT_VALUE}};
- in11[0].v = {{CONSTANT_VALUE}};
-
- const int xi0 = clamp(xi, 0, (int){{src}}_w - 1);
- const int yi0 = clamp(yi, 0, (int){{src}}_h - 1);
- const int xi1 = clamp(xi + 1, 0, (int){{src}}_w - 1);
- const int yi1 = clamp(yi + 1, 0, (int){{src}}_h - 1);
-
- T_LOAD_NHWC_WITH_DILATION({{SRC_DATA_TYPE}}, 1, 1, N0, {{SRC_TENSOR_TYPE}}, {{src}}, bout, yi0, xi0, g_ind_0, {{src}}_w, {{src}}_h, 1, 1, false, in00);
- T_LOAD_NHWC_WITH_DILATION({{SRC_DATA_TYPE}}, 1, 1, N0, {{SRC_TENSOR_TYPE}}, {{src}}, bout, yi0, xi1, g_ind_0, {{src}}_w, {{src}}_h, 1, 1, false, in01);
- T_LOAD_NHWC_WITH_DILATION({{SRC_DATA_TYPE}}, 1, 1, N0, {{SRC_TENSOR_TYPE}}, {{src}}, bout, yi1, xi0, g_ind_0, {{src}}_w, {{src}}_h, 1, 1, false, in10);
- T_LOAD_NHWC_WITH_DILATION({{SRC_DATA_TYPE}}, 1, 1, N0, {{SRC_TENSOR_TYPE}}, {{src}}, bout, yi1, xi1, g_ind_0, {{src}}_w, {{src}}_h, 1, 1, false, in11);
-)_";
-
- if (is_data_type_float(_src->data_type()))
- {
- code += R"_(
- const {{SRC_DATA_TYPE}} a = ({{SRC_DATA_TYPE}})(xi_f - (float)xi);
- const {{SRC_DATA_TYPE}} b = ({{SRC_DATA_TYPE}})(1.f - a);
- const {{SRC_DATA_TYPE}} a1 = ({{SRC_DATA_TYPE}})(yi_f - (float)yi);
- const {{SRC_DATA_TYPE}} b1 = ({{SRC_DATA_TYPE}})(1.f - a1);
-
- // Calculate the output
- {{dst}}[0].v = ((in00[0].v * b * b1) + (in01[0].v * a * b1) + (in10[0].v * b * a1) + (in11[0].v * a * a1));
-)_";
- }
- else
- {
- code += R"_(
- const float a = (xi_f - (float)xi);
- const float b = (1.f - a);
- const float a1 = (yi_f - (float)yi);
- const float b1 = (1.f - a1);
-
- {{dst}}[0].v = CONVERT_SAT(
- (CONVERT(in00[0].v, VEC_DATA_TYPE(float, N0)) * b * b1) +
- (CONVERT(in01[0].v, VEC_DATA_TYPE(float, N0)) * a * b1) +
- (CONVERT(in10[0].v, VEC_DATA_TYPE(float, N0)) * b * a1) +
- (CONVERT(in11[0].v, VEC_DATA_TYPE(float, N0)) * a * a1), VEC_DATA_TYPE({{DST_DATA_TYPE}}, N0));
-)_";
- }
- }
- else
- {
- ARM_COMPUTE_ERROR("Unsupported interpolation policy");
- }
-
- code += R"_(
- g_dst_indirect_y[0].v = g_ind_1 + (yo * (int)({{arg_dst}}_w)) + bout * (int)({{arg_dst}}_w * {{arg_dst}}_h);
-}
-//------------------ END KERNEL {{meta_kernel_id}} ---------------------
-)_";
-
- return code;
-}
-
-void ClTemplateResize::declare_variables(GpuKernelVariableTable &vtable,
- const IGpuTemplateComponentWriter::ComponentGroup &comp_group) const
-{
- vtable.declare_variable(comp_group, _src, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "src");
-
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "dst");
-}
-
-TagLUT ClTemplateResize::get_tag_lut(const GpuKernelVariableTable &vtable,
- const IGpuTemplateComponentWriter::ComponentGroup &comp_group) const
-{
- TagLUT lut{};
-
- // Arguments and global shared variables
- lut["src"] = vtable.get_variable(_src);
- lut["dst"] = vtable.get_variable(_dst);
-
- const auto dst_argument = vtable.get_variable(comp_group.get_any_dst_tensor());
- lut["arg_dst"] = dst_argument.uniq_name;
-
- // Local build options
- lut["meta_kernel_id"] = id();
- lut["SRC_DATA_TYPE"] = get_cl_type_from_data_type(_src->data_type());
- lut["SRC_TENSOR_TYPE"] = "BUFFER";
- lut["DST_DATA_TYPE"] = get_cl_type_from_data_type(_dst->data_type());
- lut["CONSTANT_VALUE"] = string_from_pixel_value(0, _src->data_type());
-
- const float scale_x =
- scale_utils::calculate_resize_ratio(_src->dimension(1), _dst->dimension(1), _attributes.align_corners());
- const float scale_y =
- scale_utils::calculate_resize_ratio(_src->dimension(2), _dst->dimension(2), _attributes.align_corners());
-
- lut["SCALE_X"] = float_to_string_with_full_precision(scale_x);
- lut["SCALE_Y"] = float_to_string_with_full_precision(scale_y);
-
- return lut;
-}
-
-CLBuildOptions ClTemplateResize::get_build_options(const IGpuTemplateComponentWriter::ComponentGroup &comp_group) const
-{
- const Window root_window = comp_group.get_root_component()->template_writer()->get_window();
- const unsigned int n0 = root_window.x().step();
- const unsigned int m0 = root_window.y().step();
- const unsigned int partial_n0 = _dst->dimension(0) % n0;
-
- CLBuildOptions build_opts;
-
- build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
- build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_n0));
-
- return build_opts;
-}
-
-std::string ClTemplateResize::get_config_id() const
-{
- std::string config_id{};
-
- config_id += "resize_";
- config_id +=
- (_attributes.interpolation_policy() == InterpolationPolicy::NEAREST_NEIGHBOR ? "NEAREST_NEIGHBOR" : "");
- config_id += (_attributes.interpolation_policy() == InterpolationPolicy::BILINEAR ? "BILINEAR" : "");
- config_id += "_";
- config_id += (_attributes.sampling_policy() == SamplingPolicy::CENTER ? "center" : "topleft");
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(0));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(1));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(2));
- config_id += "_";
- config_id += support::cpp11::to_string(_dst->dimension(3));
-
- return config_id;
-}
-
-std::set<std::string> ClTemplateResize::get_headers_list() const
-{
- return std::set<std::string>{"helpers.h", "tile_helpers.h"};
-}
-
-Window ClTemplateResize::get_window() const
-{
- ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
-
- const unsigned int n0 = adjust_vec_size(16 / _src->element_size(), _src->dimension(0));
- Window win = calculate_max_window(*_dst, Steps(n0));
- return win.collapse(win, Window::DimZ);
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.h
deleted file mode 100644
index 4c6900718..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.h
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATERESIZE
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATERESIZE
-
-#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentResize.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplateResize final : public IGpuTemplateComponentWriter
-{
-public:
- using Attributes = ClComponentResize::Attributes;
-
- /** Constructor
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- * @param[in] attributes Component attributes
- */
- ClTemplateResize(ComponentId id, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes);
-
- /** Destructor */
- ~ClTemplateResize() override = default;
-
- /** Prevent instances of this class from being copy constructed */
- ClTemplateResize(const ClTemplateResize &resize) = delete;
-
- /** Prevent instances of this class from being copied */
- ClTemplateResize &operator=(const ClTemplateResize &resize) = delete;
-
- /** Allow instances of this class to be move constructed */
- ClTemplateResize(ClTemplateResize &&resize) = default;
-
- /** Allow instances of this class to be moved */
- ClTemplateResize &operator=(ClTemplateResize &&resize) = default;
-
- /** Generate kernel component name */
- std::string get_name() const override;
-
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
-
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
-
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
-
- /** Generate the build options used in the component
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return CLBuildOptions Build options
- */
- CLBuildOptions get_build_options(const ComponentGroup &comp_group) const override;
-
- /** Generate the component config id string used for tuning */
- std::string get_config_id() const override;
-
- /** Generate the header list used in the component */
- std::set<std::string> get_headers_list() const override;
-
- /** Generate the execution window for the component */
- Window get_window() const override;
-
-private:
- const ITensorInfo *_src;
- const ITensorInfo *_dst;
- Attributes _attributes;
-};
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATERESIZE */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.cpp
deleted file mode 100644
index d0ec91e0a..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.cpp
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "ClTemplateStore.h"
-
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-ClTemplateStore::ClTemplateStore(ComponentId id, const ArgumentPack<ITensorInfo> &tensors)
- : IGpuTemplateComponentWriter{id, tensors}, _src{}, _dst{}
-{
- _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
- _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
-}
-
-std::string ClTemplateStore::get_name() const
-{
- return "store";
-}
-
-std::string ClTemplateStore::get_component_code(const ComponentGroup &comp_group) const
-{
- ARM_COMPUTE_UNUSED(comp_group);
-
- return R"_(
-//------------------ START KERNEL {{meta_kernel_id}} STORE ---------------------
-{
- bool x_cond = PARTIAL_N0 != 0 && get_global_id(0) == 0;
-
- T_STORE_INDIRECT_WIDTH_SELECT({{DST_DATA_TYPE}}, M0, N0, PARTIAL_N0, {{DST_TENSOR_TYPE}}, {{dst}}, g_ind_0, {{dst}}_stride_y, x_cond, {{src}}, g_dst_indirect_y);
-//------------------ END KERNEL {{meta_kernel_id}} STORE ---------------------
-}
-
-)_";
-}
-
-void ClTemplateStore::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- vtable.declare_variable(comp_group, _src, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "src");
- vtable.declare_variable(comp_group, _dst, GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
- "dst");
-}
-
-TagLUT ClTemplateStore::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
-{
- TagLUT lut{};
-
- // Arguments and global shared variables
- lut["src"] = vtable.get_variable(_src);
- lut["dst"] = vtable.get_variable(_dst);
-
- // Local build options
- lut["meta_kernel_id"] = id();
- lut["DST_TENSOR_TYPE"] = "BUFFER";
- lut["DST_DATA_TYPE"] = _dst->data_type();
-
- ARM_COMPUTE_UNUSED(comp_group);
- return lut;
-}
-
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.h
deleted file mode 100644
index b8c82cead..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.h
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATESTORE
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATESTORE
-
-#include "arm_compute/core/experimental/Types.h"
-
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-class ClTemplateStore final : public IGpuTemplateComponentWriter
-{
-public:
- /** Constructor
- *
- * @param[in] id Component id
- * @param[in] tensors Tensor arguments to the components
- */
- ClTemplateStore(ComponentId id, const ArgumentPack<ITensorInfo> &tensors);
- /** Prevent instances of this class from being copy constructed */
- ClTemplateStore(const ClTemplateStore &store) = delete;
- /** Prevent instances of this class from being copied */
- ClTemplateStore &operator=(const ClTemplateStore &store) = delete;
- /** Allow instances of this class to be move constructed */
- ClTemplateStore(ClTemplateStore &&store) = default;
- /** Allow instances of this class to be moved */
- ClTemplateStore &operator=(ClTemplateStore &&store) = default;
- /** Generate kernel component name */
- std::string get_name() const override;
- /** Generate kernel component code template
- *
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return std::string Component code
- */
- std::string get_component_code(const ComponentGroup &comp_group) const override;
- /** Declare all variables used by the component in the @p vtable
- *
- * @param[out] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- */
- void declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
- /** Generate the tag look-up table used to instantiate the component code.
- *
- * @param[in] vtable Variable table
- * @param[in] comp_group Component group of which the component is a part of
- *
- * @return TagLUT Tag lookup table
- */
- TagLUT get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const override;
-
-private:
- const ITensorInfo *_src;
- const ITensorInfo *_dst;
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATESTORE */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp
deleted file mode 100644
index d3d7c8db8..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp
+++ /dev/null
@@ -1,325 +0,0 @@
-/*
- * Copyright (c) 2022-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "ClTemplateWriter.h"
-
-#include "arm_compute/core/CL/CLKernelLibrary.h"
-
-#include "src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/IGpuTemplateComponentWriter.h"
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-/// @note: some tags can be unused since they could be used only for the macros, or only for the component code
-std::string ClTemplateWriter::replace_tags(const std::string &code_template, const TagLUT &tags)
-{
- std::string replaced_code = "";
- bool scanning_pattern = false;
- std::string pattern_found = "";
- for (size_t i = 0; i < code_template.size() - 1; ++i)
- {
- if (!scanning_pattern)
- {
- if (code_template[i] == '{' && code_template[i + 1] == '{')
- {
- i += 1;
- scanning_pattern = true;
- pattern_found = "";
- }
- else
- {
- replaced_code += code_template[i];
- }
- }
- else
- {
- if (code_template[i] == '}' && code_template[i + 1] == '}')
- {
- i += 1;
- scanning_pattern = false;
- std::string err = "Pattern " + pattern_found + " not found in tags";
- ARM_COMPUTE_ERROR_ON_MSG(tags.find(pattern_found) == tags.end(), err.c_str());
- replaced_code += tags.find(pattern_found)->second.value;
- }
- else
- {
- pattern_found += code_template[i];
- }
- }
- }
-
- return replaced_code;
-}
-ClTemplateWriter::~ClTemplateWriter()
-{
-}
-ClTemplateWriter::ClTemplateWriter(const GpuKernelComponentGroup &components) : _components{components}
-{
-}
-std::string ClTemplateWriter::get_name()
-{
- return write_kernel_name();
-}
-std::string ClTemplateWriter::get_code()
-{
- return write_code();
-}
-std::string ClTemplateWriter::get_config_id()
-{
- std::string config_id = get_name();
- for (const auto &comp : _components)
- {
- config_id += "--" + comp->template_writer()->get_config_id() + "--";
- }
-
- return config_id;
-}
-
-CLBuildOptions ClTemplateWriter::get_build_options()
-{
- CLBuildOptions build_opts{};
-
- for (const auto &comp : _components)
- {
- build_opts.add_options(comp->template_writer()->get_build_options(_components).options());
- }
-
- return build_opts;
-}
-
-Window ClTemplateWriter::get_window() const
-{
- const auto root_comp = _components.get_root_component();
- ARM_COMPUTE_ERROR_ON_MSG(root_comp == nullptr, "No root component found");
- return root_comp->template_writer()->get_window();
-}
-
-std::map<ITensorInfo::Id, GpuKernelArgument> ClTemplateWriter::get_tensors()
-{
- // Assemble GpuKernelArguments
- std::map<ITensorInfo::Id, GpuKernelArgument> tensors;
- for (const auto t : _components.get_argument_tensors())
- {
- tensors.emplace(t->id(), GpuKernelArgument{*t, _vtable.get_variable(t).kernel_argument_info});
- }
- return tensors;
-}
-
-std::string ClTemplateWriter::write_code()
-{
- ARM_COMPUTE_ERROR_ON_MSG(_components.empty(), "No components found");
-
- // These data structures will hold the data from all the components in the blueprint
- std::set<std::string> headers_list{};
- std::set<std::string> additional_macros{};
- std::vector<std::string> component_codes{}; // vector because order matters
-
- // Pass 1: Declare all kernel variables
- for (auto &component : _components)
- {
- component->template_writer()->declare_variables(_vtable, _components);
- }
- // Pass 2: Generate component codes
- for (auto &component : _components)
- {
- const auto component_writer = component->template_writer();
- auto curr_headers_list = component_writer->get_headers_list();
- auto curr_additional_macros = component_writer->get_additional_macros();
- auto curr_component_code = component_writer->get_component_code(_components);
- const auto var_lut = component_writer->get_tag_lut(
- _vtable,
- _components); // Ideally can be merged with get_component_code once we have finer-grained code generation technique
- component_codes.push_back(replace_tags(curr_component_code, var_lut));
-
- headers_list.insert(curr_headers_list.begin(), curr_headers_list.end());
- if (!additional_macros.empty()) // Some components might not have any
- {
- additional_macros.insert(replace_tags(curr_additional_macros, var_lut));
- }
- }
-
- // Step 3: Assemble the data gathered by traversing the graph into the string "code"
- std::string code = "";
-
- for (auto &header : headers_list)
- {
-#if defined(EMBEDDED_KERNELS)
- code += CLKernelLibrary::get().get_program(header).first;
-#else // defined(EMBEDDED_KERNELS)
- code += "#include \"" + header + "\"\n";
-#endif // defined(EMBEDDED_KERNELS)
- }
-
- for (auto &macros : additional_macros)
- {
- code += macros;
- }
-
- auto arguments = _components.get_argument_tensors();
- std::sort(arguments.begin(), arguments.end(),
- [](const ITensorInfo *l, const ITensorInfo *r) { return l->id() < r->id(); });
- code += write_kernel_signature(_vtable.get_variable_list(arguments));
-
- code += "\n{\n\n";
-
- code += " //------------------ START KERNEL_BUILDER_COORDINATE ---------------------\n\n";
- code += write_global_section();
- code += " //------------------ END KERNEL_BUILDER_COORDINATE ---------------------\n";
-
- {
- const auto tiles = _components.get_tiles();
- std::stringstream tiles_ss;
-
- tiles_ss << " //------------------ START TILE DECLARATION ---------------------\n";
-
- for (auto tile : tiles)
- {
- const auto var = _vtable.get_variable(tile);
- const auto data_type = get_cl_type_from_data_type(tile->data_type());
- const auto var_name = var.uniq_name;
-
- tiles_ss << " TILE(" << data_type << ", M0, N0, " << var_name << ");\n";
- }
-
- tiles_ss << " //------------------ END TILE DECLARATION ---------------------\n";
-
- code += tiles_ss.str();
- }
-
- for (const auto &component_code : component_codes)
- {
- code += component_code;
- code += "\n";
- }
-
- code += "}\n";
-
- return code;
-}
-std::string ClTemplateWriter::write_global_section() const
-{
- const auto dst_info = _components.get_any_dst_tensor();
- const auto dst_w = dst_info->dimension(0);
- const auto tile_w = std::max(1, get_window().x().step());
- const auto tile_h = std::max(1, get_window().y().step());
- auto leftover_w = dst_w % tile_w;
-
- std::string code = "";
- code += std::string(" int g_ind_0 = GET_SPATIAL_IDX(0, ") + std::to_string(tile_w) + ", " +
- std::to_string(leftover_w) + ");\n";
- code += std::string(" int g_ind_1 = GET_SPATIAL_IDX(1, ") + std::to_string(tile_h) + ", " + "0);\n";
- code += std::string(" int g_ind_2 = GET_SPATIAL_IDX(2, 1, 0);\n\n");
-
- code += " const bool g_cond_x = (g_ind_0 == 0);\n";
- code += " const bool g_cond_y = (g_ind_1 == 0);\n";
-
- return code;
-}
-std::string ClTemplateWriter::write_argument_declaration(const GpuKernelVariableTable::TensorVariable &var) const
-{
- std::string code;
- switch (var.kernel_argument_info.type)
- {
- case GpuKernelArgumentInfo::Type::Vector:
- {
- code += "\n VECTOR_DECLARATION(" + var.uniq_name + ")";
- break;
- }
- case GpuKernelArgumentInfo::Type::Image:
- {
- code += "\n IMAGE_DECLARATION(" + var.uniq_name + ")";
- break;
- }
- case GpuKernelArgumentInfo::Type::Image_3D:
- {
- code += "\n IMAGE_DECLARATION(" + var.uniq_name + "),";
- code += "\n unsigned int " + var.uniq_name + "_stride_z";
- break;
- }
- case GpuKernelArgumentInfo::Type::Image_3D_Export_To_ClImage2D:
- {
- code += "\n __read_only image2d_t " + var.uniq_name + "_img,";
- code += "\n unsigned int " + var.uniq_name + "_stride_z";
- break;
- }
- case GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer:
- {
- code += "\n TENSOR4D_T(" + var.uniq_name + ", BUFFER)";
- break;
- }
- case GpuKernelArgumentInfo::Type::Tensor_4D_t_Image:
- {
- code += "\n TENSOR4D_T(" + var.uniq_name + ", IMAGE)";
- break;
- }
- case GpuKernelArgumentInfo::Type::Tensor_3D:
- {
- code += "\n TENSOR3D_DECLARATION(" + var.uniq_name + ")";
- break;
- }
- default:
- {
- ARM_COMPUTE_ERROR("Unsupported declaration generation for GpuKernelArgumentInfo::Type");
- }
- }
- return code;
-}
-std::string ClTemplateWriter::write_kernel_signature(const GpuKernelVariableTable::VariableList &argument_list) const
-{
- std::string code = "\n__kernel void " + write_kernel_name() + "(";
-
- for (int i = 0; i < static_cast<int>(argument_list.size()) - 1; ++i)
- {
- code += write_argument_declaration(argument_list[i]) + ",";
- }
- if (static_cast<int>(argument_list.size()) - 1 >= 0)
- {
- code += write_argument_declaration(argument_list[argument_list.size() - 1]);
- }
-
- code += ')';
-
- return code;
-}
-std::string ClTemplateWriter::write_kernel_name() const
-{
- if (_components.empty())
- {
- return "empty_kernel";
- }
- std::string name = _components.empty() ? "" : _components[0]->template_writer()->get_name();
- for (size_t i = 1; i < _components.size(); ++i)
- {
- name += "___";
- name += _components[i]->template_writer()->get_name();
- }
-
- return name;
-}
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.h b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.h
deleted file mode 100644
index 83f617b6c..000000000
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.h
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEWRITER
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEWRITER
-
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
-#include "src/dynamic_fusion/sketch/gpu/IGpuKernelWriter.h"
-#include "src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h"
-
-#include <map>
-
-namespace arm_compute
-{
-namespace experimental
-{
-namespace dynamic_fusion
-{
-/** Use a templated-string-based method to write kernel code
- * It stitches the component code templates together based on the valid fusion configuration.
- * It then instantiates the actual kernel code from the template and the generated tag lookup table.
- */
-class ClTemplateWriter : public IGpuKernelWriter
-{
-public:
- /** Instantiates a kernel code string from the kernel code template
- * @note: some tags can be unused since they could be used only for the macros, or only for the component code
- *
- * @param[in] code_template Kernel code template
- * @param[in] tags Tag lookup table
- *
- * @return std::string Instantiated kernel string
- */
- static std::string replace_tags(const std::string &code_template, const TagLUT &tags);
- /** Default constructor */
- ClTemplateWriter() = default;
- /** Constructor
- *
- * @param[in] components Kernel component group from which the kernel will be generated
- */
- ClTemplateWriter(const GpuKernelComponentGroup &components);
- /** Destructor */
- ~ClTemplateWriter() override;
- /** Generate kernel name */
- std::string get_name() override;
- /** Generate kernel code */
- std::string get_code() override;
- /** Generate build options */
- CLBuildOptions get_build_options() override;
- /** Generate config id string of the entire kernel. This is used for tuning */
- std::string get_config_id() override;
- /** Generate execution window */
- Window get_window() const override;
- /** Get the kernel argument lists of the kernel*/
- std::map<ITensorInfo::Id, GpuKernelArgument> get_tensors() override;
-
-private:
- std::string write_kernel_name() const;
- std::string write_code();
- std::string write_global_section() const;
- std::string write_argument_declaration(const GpuKernelVariableTable::TensorVariable &var) const;
- std::string write_kernel_signature(const GpuKernelVariableTable::VariableList &argument_list) const;
-
-private:
- GpuKernelComponentGroup _components{};
- GpuKernelVariableTable _vtable{};
-};
-} // namespace dynamic_fusion
-} // namespace experimental
-} // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_TEMPLATE_WRITER_CL_CLTEMPLATEWRITER */
diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp
new file mode 100644
index 000000000..720164366
--- /dev/null
+++ b/src/gpu/cl/kernels/ClScatterKernel.cpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/gpu/cl/kernels/ClScatterKernel.h"
+
+#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/ITensorPack.h"
+#include "arm_compute/core/TensorInfo.h"
+
+namespace arm_compute
+{
+namespace opencl
+{
+namespace kernels
+{
+ClScatterKernel::ClScatterKernel()
+{
+}
+
+Status ClScatterKernel::validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *dst,
+ const ScatterInfo &info)
+{
+ ARM_COMPUTE_UNUSED(src);
+ ARM_COMPUTE_UNUSED(updates);
+ ARM_COMPUTE_UNUSED(indices);
+ ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_UNUSED(info);
+
+ return Status{};
+}
+void ClScatterKernel::configure(const ClCompileContext &compile_context,
+ const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ ITensorInfo *dst,
+ const ScatterInfo &info)
+{
+ ARM_COMPUTE_UNUSED(compile_context);
+ ARM_COMPUTE_UNUSED(src);
+ ARM_COMPUTE_UNUSED(updates);
+ ARM_COMPUTE_UNUSED(indices);
+ ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_UNUSED(info);
+}
+
+void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue)
+{
+ ARM_COMPUTE_UNUSED(tensors);
+ ARM_COMPUTE_UNUSED(window);
+ ARM_COMPUTE_UNUSED(queue);
+}
+
+} // namespace kernels
+} // namespace opencl
+} // namespace arm_compute
diff --git a/src/gpu/cl/kernels/ClScatterKernel.h b/src/gpu/cl/kernels/ClScatterKernel.h
new file mode 100644
index 000000000..dda614ff3
--- /dev/null
+++ b/src/gpu/cl/kernels/ClScatterKernel.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_SRC_GPU_CL_KERNELS_CLSCATTERKERNEL_H
+#define ACL_SRC_GPU_CL_KERNELS_CLSCATTERKERNEL_H
+
+#include "arm_compute/function_info/ScatterInfo.h"
+
+#include "src/core/common/Macros.h"
+#include "src/gpu/cl/ClCompileContext.h"
+#include "src/gpu/cl/IClKernel.h"
+
+namespace arm_compute
+{
+namespace opencl
+{
+namespace kernels
+{
+class ClScatterKernel : public IClKernel
+{
+public:
+ ClScatterKernel();
+ ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClScatterKernel);
+ /** Initialise the kernel's input and output.
+ *
+ * @param[in] compile_context The compile context to be used.
+ * @param[in] src Input tensor info for the source matrix.
+ * @param[in] updates Input tensor info for the Update matrix. Data type supported: same as @p src
+ * @param[in] indices Input tensor info for the Indices matrix. Data type supported: U32.
+ * @param[out] dst Output tensor info. Data type supported: same as @p src
+ * @param[in] info Attributes for Scatter Kernel
+ */
+ void configure(const ClCompileContext &compile_context,
+ const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ ITensorInfo *dst,
+ const ScatterInfo &info);
+ /** Static function to check if given info will lead to a valid configuration
+ *
+ * Similar to @ref ClScatterKernel::configure()
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *dst,
+ const ScatterInfo &info);
+
+ // Inherited methods overridden:
+ void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
+};
+} // namespace kernels
+} // namespace opencl
+} // namespace arm_compute
+
+#endif // ACL_SRC_GPU_CL_KERNELS_CLSCATTERKERNEL_H
diff --git a/src/gpu/cl/operators/ClScatter.cpp b/src/gpu/cl/operators/ClScatter.cpp
new file mode 100644
index 000000000..af5fbb86f
--- /dev/null
+++ b/src/gpu/cl/operators/ClScatter.cpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/gpu/cl/operators/ClScatter.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
+
+#include "src/common/utils/Log.h"
+#include "src/gpu/cl/kernels/ClFillKernel.h"
+#include "src/gpu/cl/kernels/ClScatterKernel.h"
+
+namespace arm_compute
+{
+namespace opencl
+{
+using namespace arm_compute::opencl::kernels;
+
+ClScatter::ClScatter()
+{
+}
+
+Status ClScatter::validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *dst,
+ const ScatterInfo &info)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(updates, indices, dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32);
+
+ return kernels::ClScatterKernel::validate(src, updates, indices, dst, info);
+}
+
+void ClScatter::configure(const CLCompileContext &compile_context,
+ const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ ITensorInfo *dst,
+ const ScatterInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(updates, indices, dst);
+ ARM_COMPUTE_LOG_PARAMS(src, indices, dst, info);
+ ARM_COMPUTE_UNUSED(src);
+ ARM_COMPUTE_UNUSED(updates);
+ ARM_COMPUTE_UNUSED(indices);
+ ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_UNUSED(info);
+
+ // Perform validation step
+ ARM_COMPUTE_ERROR_THROW_ON(validate(src, updates, indices, dst, info));
+ _fill_zero = info.zero_initialization;
+
+ // If necessary, create fill kernel to fill dst tensor.
+ if (_fill_zero)
+ {
+ _fill_kernel = std::make_unique<kernels::ClFillKernel>();
+ }
+
+ // Configure ClScatterKernel
+ auto k = std::make_unique<kernels::ClScatterKernel>();
+ k->set_target(CLScheduler::get().target());
+ k->configure(compile_context, src, updates, indices, dst, info);
+ _scatter_kernel = std::move(k);
+}
+
+void ClScatter::run(ITensorPack &tensors)
+{
+ ARM_COMPUTE_UNUSED(tensors);
+}
+
+} // namespace opencl
+} // namespace arm_compute
diff --git a/src/gpu/cl/operators/ClScatter.h b/src/gpu/cl/operators/ClScatter.h
new file mode 100644
index 000000000..433f7ca3a
--- /dev/null
+++ b/src/gpu/cl/operators/ClScatter.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_SRC_GPU_CL_OPERATORS_CLSCATTER_H
+#define ACL_SRC_GPU_CL_OPERATORS_CLSCATTER_H
+
+#include "arm_compute/function_info/ScatterInfo.h"
+
+#include "src/gpu/cl/IClKernel.h"
+#include "src/gpu/cl/IClOperator.h"
+
+#include <memory>
+
+namespace arm_compute
+{
+namespace opencl
+{
+// Forward declaration
+class ClFillKernel;
+class ClScatterKernel;
+
+/** Basic operator to execute Scatter on OpenCL. This operator calls the following OpenCL kernels:
+ *
+ * -# @ref kernels::ClScatterKernel
+ */
+class ClScatter : public IClOperator
+{
+public:
+ /** Constructor */
+ ClScatter();
+ /** Default destructor */
+ ~ClScatter() = default;
+ /** Initialise the kernel's inputs and output
+ *
+ * Valid data layouts:
+ * - All
+ *
+ * @note indices must always be U32
+ * @note src, updates and dst tensors must be same datatype.
+ *
+ * @param[in] compile_context The compile context to be used.
+ * @param[in] src Source input tensor info. Can be nullptr when using "Add" Scatter Function with zero initialization.
+ * @param[in] updates Tensor info for tensor storing update values to use for scatter function. Data types supported: same as @p src.
+ * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: U32 only.
+ * @param[out] dst Output tensor to store the result of the Scatter Function. Data types supported: same as @p src and @p updates.
+ * @param[in] Scatter_info Contains Scatter operation information described in @ref ScatterInfo.
+ */
+ void configure(const CLCompileContext &compile_context,
+ const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ ITensorInfo *dst,
+ const ScatterInfo &Scatter_info);
+ /** Static function to check if given info will lead to a valid configuration
+ *
+ * Similar to @ref ClScatter::configure()
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *dst,
+ const ScatterInfo &Scatter_info);
+ // Inherited methods overridden:
+ void run(ITensorPack &tensors) override;
+
+private:
+ std::unique_ptr<opencl::IClKernel> _scatter_kernel{nullptr};
+ std::unique_ptr<opencl::IClKernel> _fill_kernel{nullptr};
+ bool _fill_zero{false};
+};
+} // namespace opencl
+} // namespace arm_compute
+#endif // ACL_SRC_GPU_CL_OPERATORS_CLSCATTER_H
diff --git a/src/runtime/CL/functions/CLScatter.cpp b/src/runtime/CL/functions/CLScatter.cpp
new file mode 100644
index 000000000..e16fcc4cc
--- /dev/null
+++ b/src/runtime/CL/functions/CLScatter.cpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/runtime/CL/functions/CLScatter.h"
+
+#include "arm_compute/function_info/ScatterInfo.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
+
+#include "src/gpu/cl/operators/ClScatter.h"
+
+namespace arm_compute
+{
+using OperatorType = opencl::ClScatter;
+
+struct CLScatter::Impl
+{
+ std::unique_ptr<OperatorType> op{nullptr};
+ ITensorPack run_pack{};
+};
+
+CLScatter::CLScatter() : _impl(std::make_unique<Impl>())
+{
+}
+
+CLScatter::~CLScatter() = default;
+
+void CLScatter::configure(const ICLTensor *src,
+ const ICLTensor *updates,
+ const ICLTensor *indices,
+ ICLTensor *output,
+ const ScatterInfo &info)
+{
+ ARM_COMPUTE_UNUSED(info);
+ configure(CLKernelLibrary::get().get_compile_context(), src, updates, indices, output, info);
+}
+
+void CLScatter::configure(const CLCompileContext &compile_context,
+ const ICLTensor *src,
+ const ICLTensor *updates,
+ const ICLTensor *indices,
+ ICLTensor *output,
+ const ScatterInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(updates, indices, output);
+
+ _impl->op = std::make_unique<OperatorType>();
+ if (src)
+ { // Src not nullptr.
+ _impl->op->configure(compile_context, src->info(), updates->info(), indices->info(), output->info(), info);
+ }
+ else
+ {
+ _impl->op->configure(compile_context, nullptr, updates->info(), indices->info(), output->info(), info);
+ }
+ _impl->run_pack = {{ACL_SRC_0, src}, {ACL_SRC_1, updates}, {ACL_SRC_2, indices}, {ACL_DST, output}};
+}
+
+Status CLScatter::validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *output,
+ const ScatterInfo &info)
+{
+ return OperatorType::validate(src, updates, indices, output, info);
+}
+
+void CLScatter::run()
+{
+ _impl->op->run(_impl->run_pack);
+}
+
+} // namespace arm_compute
diff --git a/support/Bfloat16.h b/support/Bfloat16.h
index 17013294e..02772898a 100644
--- a/support/Bfloat16.h
+++ b/support/Bfloat16.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2022 Arm Limited.
+ * Copyright (c) 2020-2022,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,12 +21,12 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_BFLOAT16_H
-#define ARM_COMPUTE_BFLOAT16_H
+#ifndef ACL_SUPPORT_BFLOAT16_H
+#define ACL_SUPPORT_BFLOAT16_H
#include <cstdint>
#include <cstring>
-
+#include <ostream>
namespace arm_compute
{
namespace
@@ -131,8 +131,16 @@ public:
return val;
}
+ bfloat16 &operator+=(float v)
+ {
+ value = float_to_bf16(bf16_to_float(value) + v);
+ return *this;
+ }
+
+ friend std::ostream &operator<<(std::ostream &os, const bfloat16 &arg);
+
private:
uint16_t value;
};
} // namespace arm_compute
-#endif /* ARM_COMPUTE_BFLOAT16_H */
+#endif // ACL_SUPPORT_BFLOAT16_H
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 3f2223596..20a010f38 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 Arm Limited.
+# Copyright (c) 2023-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -100,6 +100,7 @@ target_sources(
validation/reference/Floor.cpp
validation/reference/PriorBoxLayer.cpp
validation/reference/Scale.cpp
+ validation/reference/ScatterLayer.cpp
validation/reference/ReorgLayer.cpp
validation/reference/Range.cpp
validation/reference/ArithmeticDivision.cpp
diff --git a/tests/SConscript b/tests/SConscript
index 305f1693d..0907c5713 100644
--- a/tests/SConscript
+++ b/tests/SConscript
@@ -1,7 +1,7 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
-# Copyright (c) 2017-2023 Arm Limited.
+# Copyright (c) 2017-2023,2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -81,6 +81,9 @@ if 'macos' in test_env['os']:
load_whole_archive = '-Wl,-force_load'
noload_whole_archive = ''
+if (env['multi_isa']):
+ test_env.Append(CPPDEFINES=['ARM_COMPUTE_ENABLE_BF16'])
+
if env['os'] in ['android', 'macos', 'bare_metal'] or env['standalone']:
Import("arm_compute_a")
Import("arm_compute_graph_a")
diff --git a/tests/datasets/LargeConvolutionLayerDataset.h b/tests/datasets/LargeConvolutionLayerDataset.h
index 72f73ba6d..c299f2460 100644
--- a/tests/datasets/LargeConvolutionLayerDataset.h
+++ b/tests/datasets/LargeConvolutionLayerDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -294,6 +294,16 @@ public:
}
};
+class VeryLargeConvolutionLayerDataset final : public ConvolutionLayerDataset
+{
+public:
+ VeryLargeConvolutionLayerDataset()
+ {
+ // Tensor size > 1e7 bytes && weight dimensions > 7
+ add_config(TensorShape(336U, 336U, 32U), TensorShape(9U, 9U, 32U, 64U), TensorShape(64U), TensorShape(168U, 168U, 64U), PadStrideInfo(2, 2, 4, 4));
+ }
+};
+
class LargeGroupedConvolutionLayerDataset final : public ConvolutionLayerDataset
{
public:
diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h
index 6cdff7f55..e45319ef5 100644
--- a/tests/datasets/LargeGEMMDataset.h
+++ b/tests/datasets/LargeGEMMDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2019, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_LARGE_GEMM_DATASET
-#define ARM_COMPUTE_TEST_LARGE_GEMM_DATASET
+#ifndef ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
+#define ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
#include "tests/datasets/GEMMDataset.h"
@@ -79,7 +79,20 @@ public:
add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f);
}
};
+
+class LargeAccumulateGEMMDataset final : public GEMMDataset
+{
+public:
+ LargeAccumulateGEMMDataset()
+ {
+ add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 429U), TensorShape(871U, 429U), 1.0f, 0.0f);
+ add_config(TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(941U, 1U), TensorShape(623U, 941U), TensorShape(623U, 1U), TensorShape(623U, 1U), 1.0f, 0.0f);
+ }
+};
+
} // namespace datasets
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_LARGE_GEMM_DATASET */
+#endif // ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
new file mode 100644
index 000000000..d204d1785
--- /dev/null
+++ b/tests/datasets/ScatterDataset.h
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_TESTS_DATASETS_SCATTERDATASET_H
+#define ACL_TESTS_DATASETS_SCATTERDATASET_H
+
+#include "arm_compute/core/TensorShape.h"
+#include "utils/TypePrinter.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace datasets
+{
+
+class ScatterDataset
+{
+public:
+ using type = std::tuple<TensorShape, TensorShape, TensorShape, TensorShape>;
+
+ struct iterator
+ {
+ iterator(std::vector<TensorShape>::const_iterator src_it,
+ std::vector<TensorShape>::const_iterator updates_it,
+ std::vector<TensorShape>::const_iterator indices_it,
+ std::vector<TensorShape>::const_iterator dst_it)
+ : _src_it{ std::move(src_it) },
+ _updates_it{ std::move(updates_it) },
+ _indices_it{std::move(indices_it)},
+ _dst_it{ std::move(dst_it) }
+ {
+ }
+
+ std::string description() const
+ {
+ std::stringstream description;
+ description << "A=" << *_src_it << ":";
+ description << "B=" << *_updates_it << ":";
+ description << "C=" << *_indices_it << ":";
+ description << "Out=" << *_dst_it << ":";
+ return description.str();
+ }
+
+ ScatterDataset::type operator*() const
+ {
+ return std::make_tuple(*_src_it, *_updates_it, *_indices_it, *_dst_it);
+ }
+
+ iterator &operator++()
+ {
+ ++_src_it;
+ ++_updates_it;
+ ++_indices_it;
+ ++_dst_it;
+
+ return *this;
+ }
+
+ private:
+ std::vector<TensorShape>::const_iterator _src_it;
+ std::vector<TensorShape>::const_iterator _updates_it;
+ std::vector<TensorShape>::const_iterator _indices_it;
+ std::vector<TensorShape>::const_iterator _dst_it;
+ };
+
+ iterator begin() const
+ {
+ return iterator(_src_shapes.begin(), _update_shapes.begin(), _indices_shapes.begin(), _dst_shapes.begin());
+ }
+
+ int size() const
+ {
+ return std::min(_src_shapes.size(), std::min(_indices_shapes.size(), std::min(_update_shapes.size(), _dst_shapes.size())));
+ }
+
+ void add_config(TensorShape a, TensorShape b, TensorShape c, TensorShape dst)
+ {
+ _src_shapes.emplace_back(std::move(a));
+ _update_shapes.emplace_back(std::move(b));
+ _indices_shapes.emplace_back(std::move(c));
+ _dst_shapes.emplace_back(std::move(dst));
+ }
+
+protected:
+ ScatterDataset() = default;
+ ScatterDataset(ScatterDataset &&) = default;
+
+private:
+ std::vector<TensorShape> _src_shapes{};
+ std::vector<TensorShape> _update_shapes{};
+ std::vector<TensorShape> _indices_shapes{};
+ std::vector<TensorShape> _dst_shapes{};
+};
+
+class Small1DScatterDataset final : public ScatterDataset
+{
+public:
+ Small1DScatterDataset()
+ {
+ add_config(TensorShape(6U), TensorShape(6U), TensorShape(6U), TensorShape(6U));
+ add_config(TensorShape(10U), TensorShape(2U), TensorShape(2U), TensorShape(10U));
+ }
+};
+} // namespace datasets
+} // namespace test
+} // namespace arm_compute
+#endif // ACL_TESTS_DATASETS_SCATTERDATASET_H
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index c12f57b26..99c7abbf6 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_SMALL_GEMM_DATASET
-#define ARM_COMPUTE_TEST_SMALL_GEMM_DATASET
+#ifndef ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
+#define ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
#include "tests/datasets/GEMMDataset.h"
@@ -97,7 +97,18 @@ public:
}
};
+class SmallAccumulateGEMMDataset final : public GEMMDataset
+{
+public:
+ SmallAccumulateGEMMDataset()
+ {
+ add_config(TensorShape(8U, 2U), TensorShape(16U, 8U), TensorShape(16U, 2U), TensorShape(16U, 2U), 1.0f, 0.0f);
+ add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U), TensorShape(23U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(21U, 13U), TensorShape(33U, 21U), TensorShape(33U, 13U), TensorShape(33U, 13U), 1.0f, 0.0f);
+ }
+};
+
} // namespace datasets
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_SMALL_GEMM_DATASET */
+#endif // ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
diff --git a/tests/validation/CL/DepthwiseConvolutionLayer.cpp b/tests/validation/CL/DepthwiseConvolutionLayer.cpp
index 04612a689..d4dbcec9d 100644
--- a/tests/validation/CL/DepthwiseConvolutionLayer.cpp
+++ b/tests/validation/CL/DepthwiseConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,7 +54,9 @@ constexpr float tolerance_num = 0.05f; /**<
const auto depth_multipliers = make("DepthMultiplier", { 1, 4 });
const auto large_depth_multipliers = make("DepthMultiplier", { 2, 5, 8 });
-//Activation Functions
+// Activation Functions
+const auto NoActivation = make("ActivationInfo", ActivationLayerInfo());
+
const auto ActivationFunctionsSmallDataset = make("ActivationInfo",
{
ActivationLayerInfo(),
@@ -77,11 +79,19 @@ const auto ActivationFunctionsDataset = make("ActivationInfo",
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::GELU)
});
+const auto ActivationFunctionsQuantizedSmallDataset = make("ActivationInfo",
+{
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 2.f, 0.f)
+});
+
const auto ActivationFunctionsQuantizedDataset = make("ActivationInfo",
{
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 2.3f, -1.5f),
});
+
+const auto IgnoredQuantizationInfo = make("IgnoredQuantizationInfo", QuantizationInfo());
+
} // namespace
TEST_SUITE(CL)
@@ -507,24 +517,35 @@ TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
TEST_SUITE(Generic)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
- depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 128), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(1.f, 128) })),
- make("DataLayout", { DataLayout::NHWC })), // NCHW is tested with int8
- ActivationFunctionsSmallDataset))
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }), // NCHW is tested with int8
+ NoActivation))
+{
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 128), QuantizationInfo(2.2f, 10) }),
+ make("DstQuantizationInfo", { QuantizationInfo(1.f, 128) }),
+ make("DataLayout", { DataLayout::NHWC }), // NCHW is tested with int8
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset(),
- large_depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.7f, 2) })),
- make("DataLayout", { DataLayout::NHWC })),
- make("ActivationInfo", ActivationLayerInfo())))
+ combine(datasets::LargeDepthwiseConvolutionLayerDataset(),
+ large_depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
@@ -545,24 +566,35 @@ FIXTURE_DATA_TEST_CASE_NEW(RunActivations, CLDepthwiseConvolutionLayerQuantizedF
}
TEST_SUITE(Dilation)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
- depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.8, 1) })),
- make("DataLayout", { DataLayout::NHWC })), // NCHW is tested with int8
- ActivationFunctionsSmallDataset))
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }), // NCHW is tested with int8
+ NoActivation))
+{
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(2.2f, 10) }),
+ make("DstQuantizationInfo", { QuantizationInfo(0.8, 1) }),
+ make("DataLayout", { DataLayout::NHWC }), // NCHW is tested with int8
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset(),
- large_depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(1.3f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.9f, 11) })),
- make("DataLayout", { DataLayout::NHWC })),
- make("ActivationInfo", ActivationLayerInfo())))
+ combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset(),
+ large_depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
@@ -570,58 +602,80 @@ TEST_SUITE_END() // Dilation
TEST_SUITE_END() // Generic
TEST_SUITE(W3x3)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(),
- depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.3f, 10), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- ActivationFunctionsSmallDataset))
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
+{
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ make("SrcQuantizationInfo", { QuantizationInfo(0.3f, 10), QuantizationInfo(2.2f, 10) }),
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) }),
+ make("DataLayout", { DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
- large_depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- make("ActivationInfo", ActivationLayerInfo())))
+ combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
+ large_depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
TEST_SUITE(Dilation)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
- depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- ActivationFunctionsSmallDataset))
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ NoActivation))
+{
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(2.2f, 10) }),
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) }),
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunMixedDataLayout, CLDepthwiseConvolutionLayerQuantizedMixedDataLayoutFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
- make("DepthMultiplier", { 2 })),
- make("DataType", DataType::QASYMM8)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- make("ActivationInfo", ActivationLayerInfo())))
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
+ make("DepthMultiplier", { 2 }),
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ NoActivation))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge, CLDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
- large_depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- make("ActivationInfo", ActivationLayerInfo())))
+ combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
+ large_depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
@@ -632,24 +686,35 @@ TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
TEST_SUITE(Generic)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, CLDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
- depth_multipliers),
- make("DataType", DataType::QASYMM8_SIGNED)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.3f, 10), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 4) })),
- make("DataLayout", { DataLayout::NCHW })),
- ActivationFunctionsSmallDataset))
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW }),
+ NoActivation))
+{
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, CLDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ make("SrcQuantizationInfo", { QuantizationInfo(0.3f, 10), QuantizationInfo(2.2f, 10) }),
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 4) }),
+ make("DataLayout", { DataLayout::NCHW }),
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunMixedDataLayout, CLDepthwiseConvolutionLayerQuantizedMixedDataLayoutFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
- make("DepthMultiplier", { 2 })),
- make("DataType", DataType::QASYMM8_SIGNED)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.3f, 10), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 4) })),
- make("DataLayout", { DataLayout::NCHW })),
- make("ActivationInfo", ActivationLayerInfo())))
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ make("DepthMultiplier", { 2 }),
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW }),
+ NoActivation))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
@@ -670,13 +735,24 @@ FIXTURE_DATA_TEST_CASE_NEW(RunActivations, CLDepthwiseConvolutionLayerQuantizedF
}
TEST_SUITE(Dilation)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, CLDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
- depth_multipliers),
- make("DataType", DataType::QASYMM8_SIGNED)),
- make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(2.2f, 10) })),
- make("DstQuantizationInfo", { QuantizationInfo(0.8, 1) })),
- make("DataLayout", { DataLayout::NCHW })),
- ActivationFunctionsSmallDataset))
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW }),
+ NoActivation))
+{
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, CLDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ make("SrcQuantizationInfo", { QuantizationInfo(0.5f, 10), QuantizationInfo(2.2f, 10) }),
+ make("DstQuantizationInfo", { QuantizationInfo(0.8, 1) }),
+ make("DataLayout", { DataLayout::NCHW }),
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
diff --git a/tests/validation/CL/GEMMLowp.cpp b/tests/validation/CL/GEMMLowp.cpp
index 1ae9e9662..78d794a9b 100644
--- a/tests/validation/CL/GEMMLowp.cpp
+++ b/tests/validation/CL/GEMMLowp.cpp
@@ -71,7 +71,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpMatrixMultiplyCoreFixture, framework:
}
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
TEST_SUITE(BatchedMatMul)
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned, framework::DatasetMode::ALL,
@@ -84,7 +84,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFi
TEST_SUITE_END() // QASYMM8
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
@@ -98,7 +98,7 @@ TEST_SUITE_END() // BatchedMatMul
TEST_SUITE(FusedOffsetOutput)
TEST_SUITE(QASYMM8)
-using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUint8Fixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore>;
+using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUint8Fixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUint8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -110,7 +110,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUi
TEST_SUITE(Output3D)
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputOutput3DUint8Fixture =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, true>;
+ GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputOutput3DUint8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputOutput3DUint8Dataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -123,7 +123,7 @@ TEST_SUITE_END() // Output3D
TEST_SUITE(InputOutput3D)
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInputOutput3DUint8Fixture =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, true, true>;
+ GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, true, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInputOutput3DUint8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputInputOutput3DUint8Dataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -148,7 +148,8 @@ using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInt8Fixture =
GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInt8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputInt8Dataset(),
- make("DataType", { DataType::QASYMM8_SIGNED })))
+ make("DataType", { DataType::QASYMM8_SIGNED }),
+ make("reshape_b_only_on_first_run", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_quant);
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp
new file mode 100644
index 000000000..56338f489
--- /dev/null
+++ b/tests/validation/CL/ScatterLayer.cpp
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/CL/CLTensor.h"
+#include "arm_compute/runtime/CL/functions/CLScatter.h"
+#include "tests/validation/fixtures/ScatterLayerFixture.h"
+#include "tests/datasets/ScatterDataset.h"
+#include "tests/CL/CLAccessor.h"
+#include "arm_compute/function_info/ScatterInfo.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Validation.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+
+template <typename T>
+using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;
+
+using framework::dataset::make;
+
+TEST_SUITE(CL)
+TEST_SUITE(Scatter)
+DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip(
+ make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types
+ TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
+ TensorInfo(TensorShape(8U), 1, DataType::F32),
+ TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims.
+ TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims.
+ TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype.
+ }),
+ make("UpdatesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::F16),
+ TensorInfo(TensorShape(15U), 1, DataType::F32),
+ TensorInfo(TensorShape(2U), 1, DataType::F32),
+ TensorInfo(TensorShape(217U), 1, DataType::F32),
+ TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(2U), 1, DataType::F32),
+ }),
+ make("IndicesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U32),
+ TensorInfo(TensorShape(15U), 1, DataType::U32),
+ TensorInfo(TensorShape(2U), 1, DataType::U32),
+ TensorInfo(TensorShape(271U), 1, DataType::U32),
+ TensorInfo(TensorShape(271U), 1, DataType::U32),
+ TensorInfo(TensorShape(2U), 1 , DataType::S32)
+ }),
+ make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16),
+ TensorInfo(TensorShape(15U), 1, DataType::F32),
+ TensorInfo(TensorShape(8U), 1, DataType::F32),
+ TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(271U), 1, DataType::F32),
+ TensorInfo(TensorShape(12U), 1, DataType::F32)
+ }),
+ make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
+ }),
+ make("Expected", { false, true, true, false, false, false })),
+ input_info, updates_info, indices_info, output_info, scatter_info, expected)
+{
+ // TODO: Enable validation tests.
+ ARM_COMPUTE_UNUSED(input_info);
+ ARM_COMPUTE_UNUSED(updates_info);
+ ARM_COMPUTE_UNUSED(indices_info);
+ ARM_COMPUTE_UNUSED(output_info);
+ ARM_COMPUTE_UNUSED(scatter_info);
+ ARM_COMPUTE_UNUSED(expected);
+}
+
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
+ make("DataType", {DataType::F32}),
+ make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max}),
+ make("ZeroInit", {false})))
+{
+ // TODO: Add validate() here.
+}
+
+// With this test, src should be passed as nullptr.
+FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
+ make("DataType", {DataType::F32}),
+ make("ScatterFunction", {ScatterFunction::Add}),
+ make("ZeroInit", {true})))
+{
+ // TODO: Add validate() here
+}
+TEST_SUITE_END() // FP32
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // Scatter
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/CPP/DFT.cpp b/tests/validation/CPP/DFT.cpp
index e19e85058..84431399b 100644
--- a/tests/validation/CPP/DFT.cpp
+++ b/tests/validation/CPP/DFT.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -125,7 +125,7 @@ DATA_TEST_CASE(Real, framework::DatasetMode::ALL, shapes_2d_dft,
auto backward = reference::ridft_2d(forward, is_odd);
// Validate with input
- validate(SimpleTensorAccessor<float>(src), backward, RelativeTolerance<float>(0.1f));
+ validate(SimpleTensorAccessor<float>(src), backward, RelativeTolerance<float>(0.1f), 0.f, AbsoluteTolerance<float>(0.001f));
}
DATA_TEST_CASE(Complex, framework::DatasetMode::ALL, shapes_2d_dft,
diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h
index 647adcdb6..e04462055 100644
--- a/tests/validation/Helpers.h
+++ b/tests/validation/Helpers.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2023,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/function_info/ActivationLayerInfo.h"
+
#include "support/Half.h"
#include "tests/Globals.h"
#include "tests/SimpleTensor.h"
@@ -52,6 +53,10 @@ template <>
struct is_floating_point<half> : public std::true_type
{
};
+template <>
+struct is_floating_point<bfloat16> : public std::true_type
+{
+};
/** Helper struct to store the hints for
* - destination quantization info
@@ -78,13 +83,13 @@ std::pair<T, T> get_activation_layer_test_bounds(ActivationLayerInfo::Activation
{
std::pair<T, T> bounds;
- switch(data_type)
+ switch (data_type)
{
case DataType::F16:
{
using namespace half_float::literal;
- switch(activation)
+ switch (activation)
{
case ActivationLayerInfo::ActivationFunction::TANH:
case ActivationLayerInfo::ActivationFunction::SQUARE:
@@ -104,7 +109,7 @@ std::pair<T, T> get_activation_layer_test_bounds(ActivationLayerInfo::Activation
break;
}
case DataType::F32:
- switch(activation)
+ switch (activation)
{
case ActivationLayerInfo::ActivationFunction::SOFT_RELU:
// Reduce range as exponent overflows
@@ -227,7 +232,8 @@ std::pair<int, int> get_quantized_qasymm8_signed_bounds(const QuantizationInfo &
* @param[in] max Floating point maximum value to be quantized
* @param[in] channel_id Channel id for per channel quantization info.
*/
-std::pair<int, int> get_symm_quantized_per_channel_bounds(const QuantizationInfo &quant_info, float min, float max, size_t channel_id = 0);
+std::pair<int, int>
+get_symm_quantized_per_channel_bounds(const QuantizationInfo &quant_info, float min, float max, size_t channel_id = 0);
/** Add random padding along the X axis (between 1 and 16 columns per side) to all the input tensors.
* This is used in our validation suite in order to simulate implicit padding addition after configuring, but before allocating.
@@ -238,7 +244,9 @@ std::pair<int, int> get_symm_quantized_per_channel_bounds(const QuantizationInfo
*
* @note This function adds padding to the input tensors only if data_layout == DataLayout::NHWC
*/
-void add_padding_x(std::initializer_list<ITensor *> tensors, const DataLayout &data_layout = DataLayout::NHWC, bool only_right_pad = false);
+void add_padding_x(std::initializer_list<ITensor *> tensors,
+ const DataLayout &data_layout = DataLayout::NHWC,
+ bool only_right_pad = false);
/** For 2d convolution, given the Lhs/Rhs matrix quantization informations and the convolution dimension,
* calculate a suitable output quantization and suggested bias range for obtaining non-saturated outputs with high probability.
@@ -255,11 +263,11 @@ void add_padding_x(std::initializer_list<ITensor *> tensors, const DataLayout &d
*/
QuantizationHint suggest_conv_dst_q_info_and_bias(const QuantizationInfo &in_q_info,
const QuantizationInfo &weight_q_info,
- int32_t height,
- int32_t width,
- int32_t channels,
- DataType data_type,
- float bias_fraction);
+ int32_t height,
+ int32_t width,
+ int32_t channels,
+ DataType data_type,
+ float bias_fraction);
/** For a matrix multiplication, given the Lhs/Rhs matrix quantization informations and the matrix multiplication dimensions,
* calculate a suitable output quantization and suggested bias range for obtaining non-saturated outputs with high probability.
@@ -275,8 +283,12 @@ QuantizationHint suggest_conv_dst_q_info_and_bias(const QuantizationInfo &in_q_i
* @return QuantizationHint object containing the suggested output quantization info and min/max bias range
*/
QuantizationHint suggest_matmul_dst_q_info_and_bias(const QuantizationInfo &lhs_q_info,
- const QuantizationInfo &rhs_q_info, int32_t m, int32_t n, int32_t k, DataType data_type,
- float bias_fraction);
+ const QuantizationInfo &rhs_q_info,
+ int32_t m,
+ int32_t n,
+ int32_t k,
+ DataType data_type,
+ float bias_fraction);
/** For a multiply-accumulate (mac), given the Lhs/Rhs vector quantization informations and the dot product dimensions,
* calculate a suitable output quantization and suggested bias range for obtaining non-saturated outputs with high probability.
@@ -291,8 +303,11 @@ QuantizationHint suggest_matmul_dst_q_info_and_bias(const QuantizationInfo &lhs_
* @return QuantizationHint object containing the suggested output quantization info and min/max bias range
*/
QuantizationHint suggest_mac_dst_q_info_and_bias(const QuantizationInfo &lhs_q_info,
- const QuantizationInfo &rhs_q_info, int32_t k, DataType data_type, float bias_fraction,
- int num_sd = 2);
+ const QuantizationInfo &rhs_q_info,
+ int32_t k,
+ DataType data_type,
+ float bias_fraction,
+ int num_sd = 2);
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 62690c053..d739d4e1a 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -109,6 +109,11 @@ const auto ActivationFunctionsDataset = make("ActivationInfo",
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 0.5f)
});
+const auto NoActivation = make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+});
+
const auto ActivationFunctionsDatasetNightly = make("ActivationInfo",
{
ActivationLayerInfo(),
@@ -762,21 +767,33 @@ FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixtureNoFastMath
}
#if defined(ARM_COMPUTE_ENABLE_BF16)
-
+// These tests currently only works with SVE length 256
+// If other SVE length is used a kernel will fail to be found
+// This needs to be addressed in order to ensure it doesn't revert to FP32 kernels for systems with SVE length other than 256
FIXTURE_DATA_TEST_CASE(UC2_2_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo8i4_bf16 })))
{
- ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT_EQUAL(_computed_weight_format, arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+ if(Scheduler::get().cpu_info().has_bf16() && (arm_gemm::utils::get_vector_length<float>() == 8)){
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT_EQUAL(_computed_weight_format, arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+ }
+ else{
+ ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+ }
}
FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo8i4_bf16 })))
{
- ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+ if(Scheduler::get().cpu_info().has_bf16() && (arm_gemm::utils::get_vector_length<float>() == 8)){
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+ }
+ else{
+ ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+ }
}
#endif // ARM_COMPUTE_ENABLE_BF16
@@ -847,20 +864,36 @@ FIXTURE_DATA_TEST_CASE(UC3_2_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath<c
combine(framework::dataset::make("DataType", { DataType::F32 }),
framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
- ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ if(Scheduler::get().cpu_info().has_bf16()){
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ }
+ else{
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ }
}
FIXTURE_DATA_TEST_CASE(UC3_2_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
- ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ if(Scheduler::get().cpu_info().has_bf16()){
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ }
+ else{
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ }
}
#endif // ARM_COMPUTE_ENABLE_BF16
@@ -1136,7 +1169,7 @@ TEST_SUITE(Float)
TEST_SUITE(BFLOAT16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::BFLOAT16)),
+ framework::dataset::make("DataType", Scheduler::get().cpu_info().has_bf16() ? DataType::BFLOAT16 : DataType::F32)),
framework::dataset::make("DataLayout", { DataLayout::NHWC })),
ActivationFunctionsDataset))
{
@@ -1201,6 +1234,20 @@ FIXTURE_DATA_TEST_CASE(RunPaddedWeights, NEGEMMConvolutionLayerPaddedWeightsFixt
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
}
+
+// This very large shape test is required to test heuristic paths where the tensor size is > 1e7 bytes
+// and weight dimensions larger than 7
+FIXTURE_DATA_TEST_CASE(RunVeryLarge, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::NIGHTLY,
+ combine(datasets::VeryLargeConvolutionLayerDataset(),
+ framework::dataset::make("ReshapeWeights", { true }),
+ framework::dataset::make("DataType", DataType::F32),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ NoActivation))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
+}
+
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
@@ -1310,6 +1357,27 @@ FIXTURE_DATA_TEST_CASE(RunSmallSigned, NEGEMMConvolutionLayerQuantizedPerChannel
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
+
+FIXTURE_DATA_TEST_CASE(MemoryStressLargeChannels, NEGEMMConvolutionLayerQuantizedPerChannelFixture<int8_t>,
+ framework::DatasetMode::ALL,
+ combine(
+ make("In", TensorShape(1U)),
+ make("Weights", TensorShape(1U, 1U, 1U, 17000U)),
+ make("Biases", TensorShape(17000U)),
+ make("Out", TensorShape(1U, 1U, 17000U)),
+ make("Info", PadStrideInfo(1, 1, 0, 0)),
+ make("Dilation", Size2D(1, 1)),
+ make("ReshapeWeights", { true }),
+ make("DataType", { DataType::QASYMM8_SIGNED }),
+ make("DataLayout", { DataLayout::NHWC }),
+ make("QuantizationInfo", QuantizationInfo(0.5f, 10)),
+ make("ActivationInfo", ActivationLayerInfo()),
+ make("WeightsDataType", { DataType::QSYMM8_PER_CHANNEL })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+
TEST_SUITE_END() // QSYMM8_PER_CHANNEL
TEST_SUITE_END() // Quantized
diff --git a/tests/validation/NEON/DepthwiseConvolutionLayer.cpp b/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
index a5d7a31cf..e9609b7b7 100644
--- a/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
+++ b/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -47,8 +47,9 @@ using namespace arm_compute::misc::shape_calculator;
namespace
{
-constexpr RelativeTolerance<float> tolerance_f32(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
-constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QASYMM8 */
+constexpr RelativeTolerance<float> tolerance_f32(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
+constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QASYMM8 */
+constexpr AbsoluteTolerance<int8_t> tolerance_qasymm8_signed(1); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QASYMM8_SIGNED */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
RelativeTolerance<half_float::half> tolerance_f16(half_float::half(0.02)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
constexpr float tolerance_num = 0.05f; /**< Tolerance number */
@@ -57,7 +58,9 @@ constexpr float tolerance_num = 0.05f; /**<
const auto depth_multipliers = make("DepthMultiplier", { 1, 2, 8 });
const auto large_depth_multipliers = make("DepthMultiplier", { 5, 32 });
-//Activation Functions
+// Activation Functions
+const auto NoActivation = make("ActivationInfo", ActivationLayerInfo());
+
const auto ActivationFunctionsDataset = make("ActivationInfo",
{
ActivationLayerInfo(),
@@ -83,17 +86,26 @@ const auto ActivationFunctionsDatasetNightly = make("ActivationInfo",
#endif // __aarch64__
});
+const auto ActivationFunctionsQuantizedSmallDataset = make("ActivationInfo",
+{
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+});
+
const auto ActivationFunctionsQuantizedDataset = make("ActivationInfo",
{
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 0.5f),
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 0.5f, -0.5f),
});
+// This is only used when there is fused activation
const auto input_qinfo_dataset = make("InputQInfo",
{
QuantizationInfo(0.3f, 10),
QuantizationInfo(2.2f, 10),
});
+
+const auto IgnoredQuantizationInfo = make("IgnoredQuantizationInfo", QuantizationInfo());
+
} // namespace
TEST_SUITE(NEON)
@@ -629,47 +641,69 @@ FIXTURE_DATA_TEST_CASE_NEW(RunActivations, NEDepthwiseConvolutionLayerQuantizedF
TEST_SUITE(Generic)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
- depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 4) })),
- make("DataLayout", { DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 4) }),
+ make("DataLayout", { DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunMixedDataLayout, NEDepthwiseConvolutionLayerQuantizedMixedDataLayoutFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
- make("DepthMultiplier", { 2 })),
- make("DataType", DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 4) })),
- make("DataLayout", { DataLayout::NHWC })),
- make("ActivationInfo", ActivationLayerInfo())))
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ make("DepthMultiplier", { 2 }),
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
TEST_SUITE(Dilation)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
- depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.8f, 1) })),
- make("DataLayout", { DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.8f, 1) }),
+ make("DataLayout", { DataLayout::NHWC }),
+ ActivationFunctionsDataset))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset(),
- large_depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.9f, 11) })),
- make("DataLayout", { DataLayout::NHWC })),
- make("ActivationInfo", { ActivationLayerInfo() })))
+ combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset(),
+ large_depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
@@ -677,47 +711,66 @@ TEST_SUITE_END() // Dilation
TEST_SUITE_END() // Generic
TEST_SUITE(W3x3)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) }),
+ make("DataLayout", { DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
- large_depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- make("ActivationInfo", { ActivationLayerInfo() })))
+ combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
+ large_depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
TEST_SUITE(Dilation)
-
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(), depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.7f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(), depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(), depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.7f, 10) }),
+ make("DataLayout", { DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
- large_depth_multipliers),
- make("DataType", DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- make("ActivationInfo", { ActivationLayerInfo() })))
+ combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
+ large_depth_multipliers,
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
@@ -726,48 +779,68 @@ TEST_SUITE_END() // W3x3
TEST_SUITE(Optimized)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall3x3, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset3x3(),
- make("DepthMultiplier", 1)),
- make("DataType", DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset3x3(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmall3x3WithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset3x3(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) }),
+ make("DataLayout", { DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunMixedDataLayout3x3, NEDepthwiseConvolutionLayerQuantizedMixedDataLayoutFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset3x3(),
- make("DepthMultiplier", 1)),
- make("DataType", DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- make("ActivationInfo", ActivationLayerInfo())))
+ combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset3x3(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunSmall5x5, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset5x5(),
- make("DepthMultiplier", 1)),
- make("DataType",
- DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset5x5(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmall5x5WithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset5x5(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) }),
+ make("DataLayout", { DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge3x3, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeOptimizedDepthwiseConvolutionLayerDataset3x3(),
- make("DepthMultiplier", 1)),
- make("DataType",
- DataType::QASYMM8)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NHWC })),
- make("ActivationInfo", { ActivationLayerInfo() })))
+ combine(datasets::LargeOptimizedDepthwiseConvolutionLayerDataset3x3(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NHWC }),
+ NoActivation))
{
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
@@ -794,127 +867,191 @@ FIXTURE_DATA_TEST_CASE_NEW(RunActivations, NEDepthwiseConvolutionLayerQuantizedF
TEST_SUITE(Generic)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
- depth_multipliers),
- make("DataType", DataType::QASYMM8_SIGNED)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 4) })),
- make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ NoActivation))
{
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 4) }),
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
-
TEST_SUITE(Dilation)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
- depth_multipliers),
- make("DataType", DataType::QASYMM8_SIGNED)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.8f, 1) })),
- make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ NoActivation))
{
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.8f, 1) }),
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset(),
- large_depth_multipliers),
- make("DataType", DataType::QASYMM8_SIGNED)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.9f, 11) })),
- make("DataLayout", { DataLayout::NCHW })),
- make("ActivationInfo", { ActivationLayerInfo() })))
+ combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset(),
+ large_depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW }),
+ NoActivation))
{
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
TEST_SUITE_END() // Dilation
TEST_SUITE_END() // Generic
TEST_SUITE(W3x3)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), depth_multipliers),
- make("DataType", DataType::QASYMM8_SIGNED)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ NoActivation))
{
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) }),
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
- large_depth_multipliers),
- make("DataType", DataType::QASYMM8_SIGNED)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NCHW })),
- make("ActivationInfo", { ActivationLayerInfo() })))
+ combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
+ large_depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW }),
+ NoActivation))
{
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
TEST_SUITE(Dilation)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(), depth_multipliers),
- make("DataType", DataType::QASYMM8_SIGNED)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.7f, 10) })),
- make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ NoActivation))
{
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmallWithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
+ depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.7f, 10) }),
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
- large_depth_multipliers),
- make("DataType", DataType::QASYMM8_SIGNED)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NCHW })),
- make("ActivationInfo", { ActivationLayerInfo() })))
+ combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
+ large_depth_multipliers,
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW }),
+ NoActivation))
{
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
TEST_SUITE_END() // Dilation
TEST_SUITE_END() // W3x3
TEST_SUITE(Optimized)
FIXTURE_DATA_TEST_CASE_NEW(RunSmall3x3, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset3x3(),
- make("DepthMultiplier", 1)),
- make("DataType",
- DataType::QASYMM8_SIGNED)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset3x3(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ NoActivation))
{
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmall3x3WithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset3x3(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8_SIGNED),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) }),
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
FIXTURE_DATA_TEST_CASE_NEW(RunSmall5x5, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset5x5(),
- make("DepthMultiplier", 1)),
- make("DataType",
- DataType::QASYMM8_SIGNED)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset5x5(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ NoActivation))
{
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
+}
+FIXTURE_DATA_TEST_CASE_NEW(RunSmall5x5WithActivation, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallOptimizedDepthwiseConvolutionLayerDataset5x5(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8_SIGNED),
+ input_qinfo_dataset,
+ make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) }),
+ make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }),
+ ActivationFunctionsQuantizedSmallDataset))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
FIXTURE_DATA_TEST_CASE_NEW(RunLarge3x3, NEDepthwiseConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(datasets::LargeOptimizedDepthwiseConvolutionLayerDataset3x3(),
- make("DepthMultiplier", 1)),
- make("DataType",
- DataType::QASYMM8_SIGNED)),
- input_qinfo_dataset),
- make("DstQuantizationInfo", { QuantizationInfo(0.5f, 10) })),
- make("DataLayout", { DataLayout::NCHW })),
- make("ActivationInfo", { ActivationLayerInfo() })))
+ combine(datasets::LargeOptimizedDepthwiseConvolutionLayerDataset3x3(),
+ make("DepthMultiplier", 1),
+ make("DataType", DataType::QASYMM8_SIGNED),
+ IgnoredQuantizationInfo,
+ IgnoredQuantizationInfo,
+ make("DataLayout", { DataLayout::NCHW }),
+ NoActivation))
{
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
TEST_SUITE_END() // Optimized
TEST_SUITE_END() // QASYMM8_SIGNED
diff --git a/tests/validation/NEON/DilatedConvolutionLayer.cpp b/tests/validation/NEON/DilatedConvolutionLayer.cpp
index 2ede4fac4..fbfe8b8a7 100644
--- a/tests/validation/NEON/DilatedConvolutionLayer.cpp
+++ b/tests/validation/NEON/DilatedConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021, 2023 Arm Limited.
+ * Copyright (c) 2018-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,7 +50,7 @@ const AbsoluteTolerance<float> abs_tolerance_f16(0.3f);
const RelativeTolerance<half_float::half> rel_tolerance_f16(half_float::half(0.2f)); /**< Relative tolerance value for comparing reference's output against implementation's output for DataType::F16 */
constexpr float tolerance_num_f16 = 0.07f; /**< Tolerance number for FP16 */
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-constexpr AbsoluteTolerance<float> tolerance_qasymm8(0.0); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
+constexpr AbsoluteTolerance<int32_t> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
/** CNN data types */
const auto CNNDataTypes = framework::dataset::make("DataType",
diff --git a/tests/validation/NEON/ElementwiseDivision.cpp b/tests/validation/NEON/ElementwiseDivision.cpp
index 5f0224c91..95db4ad5f 100644
--- a/tests/validation/NEON/ElementwiseDivision.cpp
+++ b/tests/validation/NEON/ElementwiseDivision.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -43,7 +43,7 @@ namespace validation
namespace
{
RelativeTolerance<float> tolerance_fp32(0.000001f);
-AbsoluteTolerance<int> tolerance_zero_s32(1); // Tolerance for S32 division
+AbsoluteTolerance<int> tolerance_zero_s32(0); // Tolerance for S32 division
/** Input data sets **/
const auto ElementwiseDivisionS32Dataset = combine(combine(framework::dataset::make("DataType", DataType::S32),
@@ -177,7 +177,7 @@ TEST_SUITE_END() // S32
TEST_SUITE_END() // Integer
TEST_SUITE_END() // ElementwiseDivision
-TEST_SUITE_END() // Neon
+TEST_SUITE_END() // NEON
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp
index 31db8f0f8..ee7e56227 100644
--- a/tests/validation/NEON/FullyConnectedLayer.cpp
+++ b/tests/validation/NEON/FullyConnectedLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021, 2023 Arm Limited.
+ * Copyright (c) 2017-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -487,7 +487,7 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayoutWithActivation, NEFullyConnectedLayerQu
make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))))
{
// Validate output
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
FIXTURE_DATA_TEST_CASE(RunWithActivation, NEFullyConnectedLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
combine(datasets::FullyConnectedLayerWithActivationDataset(),
@@ -529,7 +529,7 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEFullyConnectedLayerQuantizedMixedDa
NoActivationFunctionDataset))
{
// Validate output
- validate(Accessor(_target), _reference, tolerance_qasymm8);
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallFullyConnectedLayerDataset(),
make("DataType", DataType::QASYMM8_SIGNED),
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index f956cdfed..5f6a40220 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,6 +51,8 @@ namespace test
{
namespace validation
{
+using framework::dataset::make;
+
namespace
{
constexpr AbsoluteTolerance<float> tolerance_f(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for FP32 data types */
@@ -60,7 +62,7 @@ const AbsoluteTolerance<float> abs_tolerance_f16(0.2f); /**< Absolute
constexpr float tolerance_num = 0.07f; /**< Tolerance number for FP16 data types */
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
/** CNN data types */
-const auto CNNDataTypes = framework::dataset::make("DataType",
+const auto CNNDataTypes = make("DataType",
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DataType::F16,
@@ -68,8 +70,8 @@ const auto CNNDataTypes = framework::dataset::make("DataType",
DataType::F32,
});
-const auto data_interleave = framework::dataset::make("M", 8, 12) * framework::dataset::make("N", 8, 12);
-const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
+const auto data_interleave = make("M", 8, 12) * make("N", 8, 12);
+const auto data_transpose = make("M", 8, 14) * make("N", 7, 14);
/** Zero padding test */
template <typename FunctionType>
@@ -204,16 +206,16 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("LhsInfo", { TensorInfo(TensorShape(27U, 13U), 1, DataType::S32), // Unsupported data type
+ make("LhsInfo", { TensorInfo(TensorShape(27U, 13U), 1, DataType::S32), // Unsupported data type
TensorInfo(TensorShape(27U, 13U), 1, DataType::F32),
}),
- framework::dataset::make("RhsInfo",{ TensorInfo(TensorShape(8U, 27U), 1, DataType::S32),
+ make("RhsInfo",{ TensorInfo(TensorShape(8U, 27U), 1, DataType::S32),
TensorInfo(TensorShape(8U, 27U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(8U, 13U), 1, DataType::S32),
+ make("OutputInfo",{ TensorInfo(TensorShape(8U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(8U, 13U), 1, DataType::F32),
})),
- framework::dataset::make("Expected", { false, true })),
+ make("Expected", { false, true })),
lhs_info, rhs_info, output_info, expected)
{
constexpr float alpha = 1.0;
@@ -226,8 +228,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
// *INDENT-ON*
TEST_SUITE(KERNEL_SELECTION)
DATA_TEST_CASE(KernelSelection_mul_and_add, framework::DatasetMode::ALL,
- combine(framework::dataset::make("CpuExt", std::string("NEON")),
- framework::dataset::make("DataType", { DataType::F32,
+ combine(make("CpuExt", std::string("NEON")),
+ make("DataType", { DataType::F32,
DataType::F16
})),
cpu_ext, data_type)
@@ -261,8 +263,8 @@ TEST_SUITE_END() // KERNEL_SELECTION
TEST_SUITE(TRANSPOSE_1XW)
using CpuGemmTranspose1xW = NESynthetizeFunctionWithZeroConstantKernelBorder<cpu::kernels::CpuGemmTranspose1xWKernel>;
DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
- framework::dataset::make("N", { 1, 23, 63, 101 }),
- framework::dataset::make("K", { 1, 47, 29, 27 })),
+ make("N", { 1, 23, 63, 101 }),
+ make("K", { 1, 47, 29, 27 })),
n_value, k_value)
{
bool status = validate_zero_padding<CpuGemmTranspose1xW>(n_value, k_value);
@@ -271,7 +273,7 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
TEST_SUITE(U32)
using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint32_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U32))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * make("DataType", DataType::U32))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -280,7 +282,7 @@ TEST_SUITE_END() // U32
TEST_SUITE(U16)
using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint16_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U16))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * make("DataType", DataType::U16))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -289,7 +291,7 @@ TEST_SUITE_END() // U16
TEST_SUITE(U8)
using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint8_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U8))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * make("DataType", DataType::U8))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -302,8 +304,8 @@ TEST_SUITE(INTERLEAVE_4X4)
using CpuGemmInterleave4x4 = NESynthetizeFunctionWithZeroConstantKernelBorder<cpu::kernels::CpuGemmInterleave4x4Kernel>;
DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
- framework::dataset::make("M", { 1, 23, 63, 101 }),
- framework::dataset::make("K", { 1, 47, 29, 27 })),
+ make("M", { 1, 23, 63, 101 }),
+ make("K", { 1, 47, 29, 27 })),
m_value, k_value)
{
bool status = validate_zero_padding<cpu::kernels::CpuGemmInterleave4x4Kernel>(m_value, k_value);
@@ -312,7 +314,7 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
TEST_SUITE(U32)
using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint32_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::U32))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * make("DataType", DataType::U32))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -321,7 +323,7 @@ TEST_SUITE_END() // U32
TEST_SUITE(U16)
using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint16_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::U16))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * make("DataType", DataType::U16))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -330,7 +332,7 @@ TEST_SUITE_END() // U16
TEST_SUITE(U8)
using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint8_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::QASYMM8))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * make("DataType", DataType::QASYMM8))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -345,15 +347,18 @@ using NEGEMMFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T>;
template <typename T>
using NEBatchedMatMulFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T, true, false, false, false, false, true>;
+template <typename T>
+using NEGEMMAccumulateFixture = GEMMAccumulateValidationFixture<Tensor, Accessor, NEGEMM, T>;
+
TEST_SUITE(Float)
-DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(framework::dataset::make("In0", { TensorShape(21U, 13U),
+DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(make("In0", { TensorShape(21U, 13U),
TensorShape(31U, 1U),
TensorShape(31U, 1U),
TensorShape(8U, 2U),
TensorShape(38U, 12U),
TensorShape(32U, 1U)
}),
- framework::dataset::make("In1", { TensorShape(33U, 21U),
+ make("In1", { TensorShape(33U, 21U),
TensorShape(23U, 31U),
TensorShape(23U, 31U),
TensorShape(16U, 8U),
@@ -366,75 +371,111 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(framework::
ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
}
+DATA_TEST_CASE(ValidateAccumulate, framework::DatasetMode::ALL, combine(
+ zip(make("In0",{ TensorShape(21U, 13U) }),
+ make("In1", { TensorShape(33U, 21U) }),
+ make("Dst", { TensorShape(33U, 13U) })),
+ zip(
+ make("alpha", { 1.0, 100.0, 1.0, 1.0 }),
+ make("beta", { 0.0, 0.0, 1.0, 1.0 }),
+ make("is_c_null", { false, false, false, true }),
+ make("Expected", { true, false, false, true }))),
+ shape_a, shape_b, shape_dst, alpha, beta, is_c_null, expected)
+{
+ /* Accumulation test for GEMM kernels */
+ // Create tensors
+ TensorInfo in_a(shape_a, 1, DataType::F32);
+ TensorInfo in_b(shape_b, 1, DataType::F32);
+ TensorInfo in_c(shape_dst, 1, DataType::F32);
+ TensorInfo dst(shape_dst, 1, DataType::F32);
+
+ GEMMInfo gemm_info = GEMMInfo();
+ gemm_info.set_accumulate(true);
+
+ // Validate accumulation
+ cpu::CpuGemm gemm;
+ Status status = gemm.validate(&in_a, &in_b, (is_c_null ? nullptr : &in_c), &dst, alpha, beta, gemm_info);
+ ARM_COMPUTE_EXPECT((expected == bool(status)), framework::LogLevel::ERRORS);
+}
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
- framework::dataset::make("DataType", DataType::F16)))
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
-
-TEST_SUITE(BATCHED_MATMUL)
-
-FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
- framework::dataset::make("ReshapeWeights", { false })),
- framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
-TEST_SUITE_END()
-FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
-
- framework::dataset::make("DataType", DataType::F16)))
+TEST_SUITE(BATCHED_MATMUL)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
+ make("ReshapeWeights", { false })),
+ make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
-TEST_SUITE_END()
+TEST_SUITE_END() // BATCHED_MATMUL
+
+TEST_SUITE_END() // FP16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
-
- framework::dataset::make("DataType", DataType::F32)))
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
-
- framework::dataset::make("DataType", DataType::F32)))
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
TEST_SUITE(BATCHED_MATMUL)
-
-TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
- framework::dataset::make("ReshapeWeights", { false })),
- framework::dataset::make("DataType", DataType::F32)))
+ make("ReshapeWeights", { false })),
+ make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
-TEST_SUITE_END()
+TEST_SUITE_END() // BATCHED_MATMUL
-TEST_SUITE_END()
+TEST_SUITE(ACCUMULATE)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMAccumulateFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallAccumulateGEMMDataset(),
+ make("ReshapeWeights", { false }),
+ make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMAccumulateFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeAccumulateGEMMDataset(),
+ make("ReshapeWeights", { false }),
+ make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f);
+}
+TEST_SUITE_END() // ACCUMULATE
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE_END() // FP32
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // GEMM
+TEST_SUITE_END() // NEON
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp
index 9c4d1741e..9b1da61ed 100644
--- a/tests/validation/NEON/GEMMLowp.cpp
+++ b/tests/validation/NEON/GEMMLowp.cpp
@@ -47,12 +47,24 @@ namespace test
{
namespace validation
{
+using framework::dataset::make;
+
+namespace
+{
+ constexpr AbsoluteTolerance<float> tolerance_batched(1);
+ constexpr AbsoluteTolerance<float> tolerance_quant(1);
+} // namespace
+
+
TEST_SUITE(NEON)
TEST_SUITE(GEMMLowp)
TEST_SUITE(MatrixMultiplyCore)
using NEGEMMLowpMatrixMultiplyCoreFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
+using NEGEMMLowpMatrixMultiplyCoreAccumulateFixture = GEMMLowpMatrixMultiplyAccumulateValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
using NEGEMMLowpBatchedMatMulFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, true>;
+using NEGEMMLowpMatrixMultiplyCoreDynamicQuantizationFixture = GEMMLowpMatrixMultiplyCoreDynamicQuantizationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
+using NEGEMMLowpDequantizedMatrixMultiplyValidationFixture = GEMMLowpDequantizedMatrixMultiplyValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
using framework::dataset::make;
@@ -80,6 +92,46 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, framework::dataset::c
validate(b.info()->padding(), PaddingSize());
validate(c.info()->padding(), PaddingSize());
}
+// accumulation is not supported for Int8/UInt8 in aarch32
+#ifdef __aarch64__
+DATA_TEST_CASE(ValidateAccumulate, framework::DatasetMode::ALL, combine(
+ zip(
+ make("In0",{ TensorShape(21U, 1U) }),
+ make("In1", { TensorShape(1U, 21U) }),
+ make("Dst", { TensorShape(1U, 1U) }),
+ make("a_offset", { -2 }),
+ make("a_offset", { 13 })
+ ),
+ zip(
+ make("OutputDataType", { DataType::S32, DataType::QASYMM8, DataType::QASYMM8_SIGNED}),
+ make("Expected", { true, false, false })
+ )),
+ shape_a, shape_b, shape_dst, a_offset, b_offset, output_data_type, expected)
+{
+ DataType input_data_type = (output_data_type == DataType::S32 ? DataType::QASYMM8 : output_data_type);
+ // Accumulation test for GEMM kernels
+ TensorInfo a(shape_a, 1, input_data_type, QuantizationInfo(1.0f / 255, a_offset));
+ TensorInfo b(shape_b, 1, input_data_type, QuantizationInfo(1.0f / 255, b_offset));
+ TensorInfo dst(shape_dst, 1, output_data_type, QuantizationInfo());
+
+ // Create and configure function
+ GEMMInfo gemm_info = GEMMInfo();
+ gemm_info.set_accumulate(true);
+
+ if (is_data_type_quantized(output_data_type))
+ {
+ GEMMLowpOutputStageInfo gemmLowpOutputStageInfo = GEMMLowpOutputStageInfo();
+ gemmLowpOutputStageInfo.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+
+ gemm_info.set_gemmlowp_output_stage(gemmLowpOutputStageInfo);
+ }
+
+ cpu::CpuGemmLowpMatrixMultiplyCore gemmlowp_mm;
+ Status status = gemmlowp_mm.validate(&a, &b, nullptr, &dst, gemm_info);
+
+ ARM_COMPUTE_EXPECT((expected == bool(status)), framework::LogLevel::ERRORS);
+}
+#endif // __arch64__
// *INDENT-OFF*
// clang-format off
@@ -226,13 +278,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreFixture, framework:
validate(Accessor(_target), _reference);
}
-constexpr AbsoluteTolerance<float> tolerance_batched(1);
-
-using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
-
TEST_SUITE(BatchedMatMul)
TEST_SUITE(QASYMM8)
+using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned =
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -242,9 +291,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFi
}
TEST_SUITE_END() // QASYMM8
-using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
TEST_SUITE(QASYMM8_SIGNED)
+using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned =
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
make("DataType", { DataType::QASYMM8_SIGNED }),
@@ -255,26 +304,76 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFi
TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE_END() // BatchedMatMul
-using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
-constexpr AbsoluteTolerance<float> tolerance_quant(1);
-
TEST_SUITE(FusedOffsetOutput)
+using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
- make("DataType", { DataType::QASYMM8 })))
+ make("DataType", { DataType::QASYMM8 }),
+ make("reshape_b_only_on_first_run", { false })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_quant);
}
-
FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture, framework::DatasetMode::NIGHTLY,
combine(datasets::LargeGEMMLowpFusedOffsetOutputUint8Dataset(),
- make("DataType", { DataType::QASYMM8 })))
+ make("DataType", { DataType::QASYMM8 }),
+ make("reshape_b_only_on_first_run", { false })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_quant);
}
TEST_SUITE_END() // FusedOffsetOutput
+
+// accumulation is not supported for Int8/UInt8 in aarch32
+#ifdef __aarch64__
+TEST_SUITE(ACCUMULATION)
+TEST_SUITE(S32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreAccumulateFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreAccumulateFixture, framework::DatasetMode::NIGHTLY, datasets::LargeGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // S32
+TEST_SUITE_END() // ACCUMULATION
+#endif // __arch64__
+
+TEST_SUITE(DynamicQuantization)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreDynamicQuantizationFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreDynamicQuantizationFixture, framework::DatasetMode::NIGHTLY, datasets::LargeGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // DynamicQuantization
+
+#ifdef __aarch64__
+// Deqaunt tests involve returning F32 from the MatrixMultiplyCore kernels and is only implemented in aarch64
+TEST_SUITE(Dequant)
+constexpr AbsoluteTolerance<float> tolerance_dequantized(0.01f);
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_dequantized);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::NIGHTLY, datasets::LargeGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_dequantized);
+}
+TEST_SUITE_END() // Dequant
+#endif // __aarch64__
+
TEST_SUITE_END() // MatrixMultiplyCore
TEST_SUITE_END() // GEMMLowp
TEST_SUITE_END() // NEON
diff --git a/tests/validation/NEON/LSTMLayerQuantized.cpp b/tests/validation/NEON/LSTMLayerQuantized.cpp
index d391267e3..6b98ee2b6 100644
--- a/tests/validation/NEON/LSTMLayerQuantized.cpp
+++ b/tests/validation/NEON/LSTMLayerQuantized.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -64,11 +64,7 @@ inline void fill_tensor(SimpleTensor<T> &tensor, const std::vector<T> &v)
}
/** Tolerance for quantized asymmetric operations */
-#if defined(__aarch64__)
-constexpr AbsoluteTolerance<int16_t> tolerance_qsymm16(0);
-#else // defined(__aarch64__)
constexpr AbsoluteTolerance<int16_t> tolerance_qsymm16(1);
-#endif // defined(__aarch64__)
} // namespace
diff --git a/tests/validation/NEON/MatMul.cpp b/tests/validation/NEON/MatMul.cpp
index 5577a9bb9..f22bd9e86 100644
--- a/tests/validation/NEON/MatMul.cpp
+++ b/tests/validation/NEON/MatMul.cpp
@@ -24,15 +24,14 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEMatMul.h"
-#include "tests/NEON/Accessor.h"
-#include "tests/framework/Asserts.h"
-#include "tests/framework/Macros.h"
-#include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
-
#include "tests/datasets/LargeMatMulDataset.h"
#include "tests/datasets/SmallMatMulDataset.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Macros.h"
+#include "tests/NEON/Accessor.h"
#include "tests/validation/fixtures/MatMulFixture.h"
+#include "tests/validation/Validation.h"
namespace arm_compute
{
@@ -45,11 +44,12 @@ using framework::dataset::make;
TEST_SUITE(NEON)
TEST_SUITE(MatMul)
-constexpr AbsoluteTolerance<float> tolerance_fp32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for FP32 data types */
-const AbsoluteTolerance<half> tolerance_fp16(half(0.1f));
+constexpr AbsoluteTolerance<float> tolerance_fp32(
+ 0.001f); /**< Tolerance value for comparing reference's output against implementation's output for FP32 data types */
+const AbsoluteTolerance<half> tolerance_fp16(half(0.1f));
#ifdef __aarch64__
-constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(0);
-constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8_signed(0);
+constexpr AbsoluteTolerance<int32_t> tolerance_qasymm8(1);
+constexpr AbsoluteTolerance<int32_t> tolerance_qasymm8_signed(1);
#endif // __aarch64__
// clang-format off
@@ -120,55 +120,79 @@ template <typename T>
using NEMatMulFastMathFixture = MatMulGenericValidationFixture<Tensor, Accessor, NEMatMul, CpuMatMulSettings, T>;
template <typename T>
-using NEMatMulDynamicTensorsFixture = MatMulValidationWithDynamicTensorsFixture<Tensor, Accessor, NEMatMul, CpuMatMulSettings, T>;
+using NEMatMulFixedFormatFixture = MatMulFixedFormatFixture<Tensor, Accessor, NEMatMul, CpuMatMulSettings, T>;
+
+template <typename T>
+using NEMatMulDynamicTensorsFixture =
+ MatMulValidationWithDynamicTensorsFixture<Tensor, Accessor, NEMatMul, CpuMatMulSettings, T>;
template <typename T>
using NEQuantizedMatMulFixture = QuantizedMatMulValidationFixture<Tensor, Accessor, NEMatMul, CpuMatMulSettings, T>;
TEST_SUITE(Float)
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEMatMulFixture<float>, framework::DatasetMode::PRECOMMIT,
- combine(
- datasets::SmallMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::F32),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ NEMatMulFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::F32),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp32);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEMatMulFixture<float>, framework::DatasetMode::NIGHTLY,
- combine(
- datasets::LargeMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::F32),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })))
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ NEMatMulFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(datasets::LargeMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::F32),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp32);
}
-FIXTURE_DATA_TEST_CASE(RunHighDimensions, NEMatMulFixture<float>, framework::DatasetMode::NIGHTLY,
- combine(
- datasets::HighDimensionalMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::F32),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })))
+FIXTURE_DATA_TEST_CASE(RunHighDimensions,
+ NEMatMulFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(datasets::HighDimensionalMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::F32),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp32);
}
-FIXTURE_DATA_TEST_CASE(RunStressDynamicTensors, NEMatMulDynamicTensorsFixture<float>, framework::DatasetMode::PRECOMMIT,
- combine(
- datasets::SmallMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::F32),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) }),
- make("NumberOfRuns", 5)))
+FIXTURE_DATA_TEST_CASE(RunStressDynamicTensors,
+ NEMatMulDynamicTensorsFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::F32),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+}),
+make("NumberOfRuns", 5)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp32);
@@ -179,37 +203,60 @@ TEST_SUITE_END() // FP32
/* Note : MatMul BF16 is enabled by specifying FP32 datatype and enabling the fast math setting */
constexpr AbsoluteTolerance<float> tolerance_bf16(0.02f);
TEST_SUITE(BF16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEMatMulFastMathFixture<float>, framework::DatasetMode::PRECOMMIT,
- combine(
- datasets::SmallMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::F32),
- make("ActivationInfo", { ActivationLayerInfo() }),
- make("RunTimes", { 0 }),
- make("Settings", { CpuMatMulSettings().fast_math(true) }),
- make("LhsQInfo", { QuantizationInfo() }),
- make("RhsQInfo", { QuantizationInfo() }),
- make("OutQInfo", { QuantizationInfo() }))
-)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ NEMatMulFastMathFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::F32),
+ make("ActivationInfo", {ActivationLayerInfo()}),
+ make("RunTimes", {0}),
+ make("Settings", {CpuMatMulSettings().fast_math(true)}),
+ make("LhsQInfo", {QuantizationInfo()}),
+ make("RhsQInfo", {QuantizationInfo()}),
+ make("OutQInfo", {QuantizationInfo()})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_bf16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEMatMulFastMathFixture<float>, framework::DatasetMode::NIGHTLY,
- combine(
- datasets::LargeMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::F32),
- make("ActivationInfo", { ActivationLayerInfo() }),
- make("RunTimes", { 0 }),
- make("Settings", { CpuMatMulSettings().fast_math(true) }),
- make("LhsQInfo", { QuantizationInfo() }),
- make("RhsQInfo", { QuantizationInfo() }),
- make("OutQInfo", { QuantizationInfo() }))
-)
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+FIXTURE_DATA_TEST_CASE(RunTinyFixedFormat,
+ NEMatMulFixedFormatFixture<bfloat16>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::TinyMatMulDataset(),
+ make("TransposeA", {false}),
+ make("TransposeB", {false}),
+ make("DataType", DataType::BFLOAT16),
+ make("ActivationInfo", {ActivationLayerInfo()}),
+ make("RunTimes", {0}),
+ make("Settings", {CpuMatMulSettings().fast_math(true).fixed_format(true)}),
+ make("LhsQInfo", {QuantizationInfo()}),
+ make("RhsQInfo", {QuantizationInfo()}),
+ make("OutQInfo", {QuantizationInfo()})))
+{
+ if (CPUInfo::get().has_bf16())
+ {
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_bf16);
+ }
+}
+#endif /* ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS */
+
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ NEMatMulFastMathFixture<float>,
+ framework::DatasetMode::NIGHTLY,
+ combine(datasets::LargeMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::F32),
+ make("ActivationInfo", {ActivationLayerInfo()}),
+ make("RunTimes", {0}),
+ make("Settings", {CpuMatMulSettings().fast_math(true)}),
+ make("LhsQInfo", {QuantizationInfo()}),
+ make("RhsQInfo", {QuantizationInfo()}),
+ make("OutQInfo", {QuantizationInfo()})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_bf16, 0.01 /* tolerance_num */);
@@ -219,36 +266,51 @@ TEST_SUITE_END() // BF16
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEMatMulFixture<half>, framework::DatasetMode::PRECOMMIT,
- combine(
- datasets::SmallMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::F16),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ NEMatMulFixture<half>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::F16),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEMatMulFixture<half>, framework::DatasetMode::NIGHTLY,
- combine(
- datasets::LargeMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::F16),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })))
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ NEMatMulFixture<half>,
+ framework::DatasetMode::NIGHTLY,
+ combine(datasets::LargeMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::F16),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp16);
}
-FIXTURE_DATA_TEST_CASE(RunStressDynamicTensors, NEMatMulDynamicTensorsFixture<half>, framework::DatasetMode::PRECOMMIT,
- combine(
- datasets::SmallMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::F16),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) }),
- make("NumberOfRuns", 5)))
+FIXTURE_DATA_TEST_CASE(RunStressDynamicTensors,
+ NEMatMulDynamicTensorsFixture<half>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::F16),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+}),
+make("NumberOfRuns", 5)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp16);
@@ -263,52 +325,64 @@ TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEQuantizedMatMulFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(
- datasets::SmallMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::QASYMM8),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) }),
- make("NumberOfExtraRuns", { 0, 1 }),
- make("LhsQInfo", { QuantizationInfo(1.f / 50, 1) }),
- make("RhsQInfo", { QuantizationInfo(1.f / 30, -1) }),
- make("OutQInfo", { QuantizationInfo(1.f, 2) }))
-)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ NEQuantizedMatMulFixture<uint8_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::QASYMM8),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+}),
+make("NumberOfExtraRuns", {0, 1}),
+make("LhsQInfo", {QuantizationInfo(1.f / 50, 1)}),
+make("RhsQInfo", {QuantizationInfo(1.f / 30, -1)}),
+make("OutQInfo", {QuantizationInfo(1.f, 2)})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
-FIXTURE_DATA_TEST_CASE(RunSmallExtraActivation, NEQuantizedMatMulFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(
- datasets::SmallerMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::QASYMM8),
- make("ActivationInfo", { ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU) }),
- make("NumberOfExtraRuns", { 0, 1 }),
- make("LhsQInfo", { QuantizationInfo(1.f / 50, 1) }),
- make("RhsQInfo", { QuantizationInfo(1.f / 30, -1) }),
- make("OutQInfo", { QuantizationInfo(1.f, 2) }))
-)
+FIXTURE_DATA_TEST_CASE(RunSmallExtraActivation,
+ NEQuantizedMatMulFixture<uint8_t>,
+ framework::DatasetMode::NIGHTLY,
+ combine(datasets::SmallerMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::QASYMM8),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU)
+}),
+make("NumberOfExtraRuns", {0, 1}),
+make("LhsQInfo", {QuantizationInfo(1.f / 50, 1)}),
+make("RhsQInfo", {QuantizationInfo(1.f / 30, -1)}),
+make("OutQInfo", {QuantizationInfo(1.f, 2)})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEQuantizedMatMulFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(
- datasets::LargeMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::QASYMM8),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) }),
- make("NumberOfExtraRuns", { 0, 1 }),
- make("LhsQInfo", { QuantizationInfo(1.f / 100, 1) }),
- make("RhsQInfo", { QuantizationInfo(1.f / 200, -1) }),
- make("OutQInfo", { QuantizationInfo(1.f, 2) }))
-)
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ NEQuantizedMatMulFixture<uint8_t>,
+ framework::DatasetMode::NIGHTLY,
+ combine(datasets::LargeMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::QASYMM8),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+}),
+make("NumberOfExtraRuns", {0, 1}),
+make("LhsQInfo", {QuantizationInfo(1.f / 100, 1)}),
+make("RhsQInfo", {QuantizationInfo(1.f / 200, -1)}),
+make("OutQInfo", {QuantizationInfo(1.f, 2)})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -318,52 +392,64 @@ TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEQuantizedMatMulFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
- combine(
- datasets::SmallMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::QASYMM8_SIGNED),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) }),
- make("NumberOfExtraRuns", { 0, 1 }),
- make("LhsQInfo", { QuantizationInfo(1.f / 40, -2) }),
- make("RhsQInfo", { QuantizationInfo(1.f / 50, 1) }),
- make("OutQInfo", { QuantizationInfo(1.f, 1) }))
-)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ NEQuantizedMatMulFixture<int8_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::QASYMM8_SIGNED),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+}),
+make("NumberOfExtraRuns", {0, 1}),
+make("LhsQInfo", {QuantizationInfo(1.f / 40, -2)}),
+make("RhsQInfo", {QuantizationInfo(1.f / 50, 1)}),
+make("OutQInfo", {QuantizationInfo(1.f, 1)})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
-FIXTURE_DATA_TEST_CASE(RunSmallExtraActivation, NEQuantizedMatMulFixture<int8_t>, framework::DatasetMode::NIGHTLY,
- combine(
- datasets::SmallerMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::QASYMM8_SIGNED),
- make("ActivationInfo", { ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU) }),
- make("NumberOfExtraRuns", { 0, 1 }),
- make("LhsQInfo", { QuantizationInfo(1.f / 40, -2) }),
- make("RhsQInfo", { QuantizationInfo(1.f / 50, 1) }),
- make("OutQInfo", { QuantizationInfo(1.f, 1) }))
-)
+FIXTURE_DATA_TEST_CASE(RunSmallExtraActivation,
+ NEQuantizedMatMulFixture<int8_t>,
+ framework::DatasetMode::NIGHTLY,
+ combine(datasets::SmallerMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::QASYMM8_SIGNED),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU)
+}),
+make("NumberOfExtraRuns", {0, 1}),
+make("LhsQInfo", {QuantizationInfo(1.f / 40, -2)}),
+make("RhsQInfo", {QuantizationInfo(1.f / 50, 1)}),
+make("OutQInfo", {QuantizationInfo(1.f, 1)})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEQuantizedMatMulFixture<int8_t>, framework::DatasetMode::NIGHTLY,
- combine(
- datasets::LargeMatMulDataset(),
- make("TransposeA", { false, true }),
- make("TransposeB", { false, true }),
- make("DataType", DataType::QASYMM8_SIGNED),
- make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) }),
- make("NumberOfExtraRuns", { 0, 1 }),
- make("LhsQInfo", { QuantizationInfo(1.f / 150, -2) }),
- make("RhsQInfo", { QuantizationInfo(1.f / 250, 1) }),
- make("OutQInfo", { QuantizationInfo(1.f, 1) }))
-)
+FIXTURE_DATA_TEST_CASE(RunLarge,
+ NEQuantizedMatMulFixture<int8_t>,
+ framework::DatasetMode::NIGHTLY,
+ combine(datasets::LargeMatMulDataset(),
+ make("TransposeA", {false, true}),
+ make("TransposeB", {false, true}),
+ make("DataType", DataType::QASYMM8_SIGNED),
+ make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
+}),
+make("NumberOfExtraRuns", {0, 1}),
+make("LhsQInfo", {QuantizationInfo(1.f / 150, -2)}),
+make("RhsQInfo", {QuantizationInfo(1.f / 250, 1)}),
+make("OutQInfo", {QuantizationInfo(1.f, 1)})))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
@@ -372,7 +458,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEQuantizedMatMulFixture<int8_t>, framework::Da
TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE_END() // Quantized
-#endif // __aarch64__
+#endif // __aarch64__
TEST_SUITE_END() // MatMul
TEST_SUITE_END() // NEON
diff --git a/tests/validation/NEON/PoolingLayer.cpp b/tests/validation/NEON/PoolingLayer.cpp
index 3acd453ea..161fe627c 100644
--- a/tests/validation/NEON/PoolingLayer.cpp
+++ b/tests/validation/NEON/PoolingLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021, 2023 Arm Limited.
+ * Copyright (c) 2017-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -105,7 +105,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
TensorInfo(TensorShape(13U, 13U, 5U), 1, DataType::QASYMM8), // Invalid exclude_padding = false with quantized type, no actual padding and NHWC
TensorInfo(TensorShape(13U, 13U, 5U), 1, DataType::F32),
TensorInfo(TensorShape(1U, 16U, 1U), 1, DataType::F32),
- }),
+ TensorInfo(TensorShape(112, 112, 64,1), 1, DataType::F32, DataLayout::NHWC), // Mismatching number of channels
+ TensorInfo(TensorShape(112, 112, 64,1), 1, DataType::F32, DataLayout::NHWC), // Mismatching width
+ }),
framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(25U, 11U, 2U), 1, DataType::F16),
TensorInfo(TensorShape(25U, 10U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(30U, 11U, 2U), 1, DataType::F32),
@@ -115,7 +117,10 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
TensorInfo(TensorShape(12U, 12U, 5U), 1, DataType::QASYMM8),
TensorInfo(TensorShape(25U, 11U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(1U, 15U, 1U), 1, DataType::F32),
- })),
+ TensorInfo(TensorShape(56, 56, 64,1), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(56, 51, 64,1), 1, DataType::F32, DataLayout::NHWC),
+
+ })),
framework::dataset::make("PoolInfo", { PoolingLayerInfo(PoolingType::AVG, 3, DataLayout::NCHW, PadStrideInfo(1, 1, 0, 0)),
PoolingLayerInfo(PoolingType::AVG, 3, DataLayout::NCHW, PadStrideInfo(1, 1, 0, 0)),
PoolingLayerInfo(PoolingType::AVG, 2, DataLayout::NCHW, PadStrideInfo(1, 1, 2, 0)),
@@ -125,8 +130,11 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
PoolingLayerInfo(PoolingType::AVG, 2, DataLayout::NHWC, PadStrideInfo(), false),
PoolingLayerInfo(PoolingType::AVG, DataLayout::NCHW),
PoolingLayerInfo(PoolingType::MAX, 2, DataLayout::NHWC, PadStrideInfo(1, 1, 0, 0), false),
+ PoolingLayerInfo(PoolingType::MAX,3,DataLayout::NHWC,PadStrideInfo(2,2,1,1)),
+ PoolingLayerInfo(PoolingType::MAX,3,DataLayout::NHWC,PadStrideInfo(2,2,1,1)),
+
})),
- framework::dataset::make("Expected", { false, false, false, false, true, false, true, false, false})),
+ framework::dataset::make("Expected", { false, false, false, false, true, false, true, false, false, false, false})),
input_info, output_info, pool_info, expected)
{
bool is_valid = bool(NEPoolingLayer::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), pool_info));
diff --git a/tests/validation/NEON/QuantizationLayer.cpp b/tests/validation/NEON/QuantizationLayer.cpp
index aeee54c83..bab749076 100644
--- a/tests/validation/NEON/QuantizationLayer.cpp
+++ b/tests/validation/NEON/QuantizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,6 +34,7 @@
#include "tests/validation/Validation.h"
#include "tests/validation/fixtures/QuantizationLayerFixture.h"
+
namespace arm_compute
{
namespace test
@@ -182,7 +183,16 @@ FIXTURE_DATA_TEST_CASE(RunSmallQASYMM8, NEQuantizationLayerQASYMM8GenFixture<uin
framework::dataset::make("DataType", DataType::QASYMM8)),
framework::dataset::make("DataTypeOut", { DataType::QASYMM8 })),
framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(0.5f, 10) })),
- framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(2.0f, 15) })))
+ framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(2.0f, 15), QuantizationInfo(0.5f, 25) })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_u8);
+}
+FIXTURE_DATA_TEST_CASE(ConvertUint8toInt8, NEQuantizationLayerQASYMM8GenFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(QuantizationSmallShapes,
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", { DataType::QASYMM8_SIGNED })),
+ framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(2.0f, -1) })),
+ framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(2.0f, 127) })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_u8);
@@ -191,7 +201,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallQASYMM8_SIGNED, NEQuantizationLayerQASYMM8_SIGNED
framework::dataset::make("DataTypeIn", DataType::QASYMM8)),
framework::dataset::make("DataTypeOut", { DataType::QASYMM8_SIGNED })),
framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(1.0f, 10), QuantizationInfo(2.0f, -25) })),
- framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(1.0f, 15) })))
+ framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(1.0f, 15), QuantizationInfo(1.0f, 127) })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_s8);
@@ -211,7 +221,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallQASYMM8_SIGNED, NEQuantizationLayerQASYMM8_SIGNED
framework::dataset::make("DataTypeIn", DataType::QASYMM8_SIGNED)),
framework::dataset::make("DataTypeOut", { DataType::QASYMM8_SIGNED })),
framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(1.0f, 10) })),
- framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(2.0f, -5) })))
+ framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(2.0f, -5), QuantizationInfo(1.0f, 43) })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_s8);
@@ -220,11 +230,21 @@ FIXTURE_DATA_TEST_CASE(RunSmallQASYMM8, NEQuantizationLayerQASYMM8GenFixture<int
framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
framework::dataset::make("DataTypeOut", { DataType::QASYMM8 })),
framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(2.0f, 10), QuantizationInfo(2.0f, -25) })),
- framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(1.0f, 30) })))
+ framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(1.0f, 30), QuantizationInfo(2.0f, -128) })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_u8);
}
+FIXTURE_DATA_TEST_CASE(ConvertInt8toUint8, NEQuantizationLayerQASYMM8_SIGNEDGenFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(QuantizationSmallShapes,
+ framework::dataset::make("DataTypeIn", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataTypeOut", { DataType::QASYMM8 })),
+ framework::dataset::make("QuantizationInfoOutput", { QuantizationInfo(1.0f, 0) })),
+ framework::dataset::make("QuantizationInfoInput", { QuantizationInfo(1.0f, -128) })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_s8);
+}
+
TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE_END() // Quantized
diff --git a/tests/validation/NEON/RNNLayer.cpp b/tests/validation/NEON/RNNLayer.cpp
index 14d9a5d14..979aa0f2c 100644
--- a/tests/validation/NEON/RNNLayer.cpp
+++ b/tests/validation/NEON/RNNLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -139,7 +139,7 @@ TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, NERNNLayerFixture<half>, framework::DatasetMode::ALL, combine(datasets::SmallRNNLayerDataset(), framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
- validate(Accessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16);
+ validate(Accessor(_target), _reference, tolerance_f16, 0.02f, abs_tolerance_f16);
}
TEST_SUITE_END() // FP16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp
index 2397d8154..8da5a0d95 100644
--- a/tests/validation/NEON/SoftmaxLayer.cpp
+++ b/tests/validation/NEON/SoftmaxLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -122,40 +122,35 @@ template <typename T>
using NESoftmaxLayerFixture = SoftmaxValidationFixture<Tensor, Accessor, NESoftmaxLayer, T>;
DATA_TEST_CASE(KernelSelection, framework::DatasetMode::ALL,
- concat(concat(
+ concat(
combine(
- make("CpuExt", std::string("NEON")),
+ make("CpuExt", std::string("neon")),
make("DataType", { DataType::F32,
DataType::F16,
DataType::QASYMM8,
DataType::QASYMM8_SIGNED})
),
combine(
- make("CpuExt", std::string("SVE")),
+ make("CpuExt", std::string("sme2")),
make("DataType", { DataType::F32,
DataType::F16}))
),
- combine(
- make("CpuExt", std::string("SVE2")),
- make("DataType", { DataType::QASYMM8,
- DataType::QASYMM8_SIGNED}))
- ),
cpu_ext, data_type)
{
using namespace cpu::kernels;
cpuinfo::CpuIsaInfo cpu_isa{};
- cpu_isa.neon = (cpu_ext == "NEON");
- cpu_isa.sve = (cpu_ext == "SVE");
- cpu_isa.sve2 = (cpu_ext == "SVE2");
+ cpu_isa.neon = (cpu_ext == "neon");
+ cpu_isa.sme2 = (cpu_ext == "sme2");
cpu_isa.fp16 = (data_type == DataType::F16);
const auto *selected_impl = CpuSoftmaxKernel::get_implementation(
- SoftmaxKernelDataTypeISASelectorData{ data_type, cpu_isa, false /* is_log */ }, cpu::KernelSelectionType::Preferred);
+ SoftmaxKernelDataTypeISASelectorData{ data_type, cpu_isa, false /* is_log */, 0 /* axis */},
+ cpu::KernelSelectionType::Preferred);
ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl);
- std::string expected = "neon_" + cpu_impl_dt(data_type) + "_softmax";
+ std::string expected = cpu_ext + "_" + cpu_impl_dt(data_type) + "_softmax";
std::string actual = selected_impl->name;
ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
@@ -164,9 +159,19 @@ DATA_TEST_CASE(KernelSelection, framework::DatasetMode::ALL,
TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall2D, NESoftmaxLayerFixture<half>, framework::DatasetMode::PRECOMMIT,
+ combine(
+ datasets::SoftmaxLayerSmallShapes(),
+ make("DataType", DataType::F16),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0, -1 })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f16);
+}
FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture<half>, framework::DatasetMode::PRECOMMIT,
combine(
- datasets::Small4DShapes(),
+ datasets::SmallShapes(),
make("DataType", DataType::F16),
make("Beta", { 1.0f, 2.0f }),
make("Axis", { 0, 1 })))
@@ -178,7 +183,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall4D, NESoftmaxLayerFixture<half>, framework::Datas
combine(
datasets::Small4DShapes(),
make("DataType", DataType::F16),
- make("Beta", { 1.0f, 2.0f }),
+ make("Beta", { 1.0f }),
make("Axis", { 0, 2, -1 })))
{
// Validate output
diff --git a/tests/validation/dynamic_fusion/gpu/Integration.cpp b/tests/validation/dynamic_fusion/gpu/Integration.cpp
index 80dcaa8f9..453983c07 100644
--- a/tests/validation/dynamic_fusion/gpu/Integration.cpp
+++ b/tests/validation/dynamic_fusion/gpu/Integration.cpp
@@ -63,7 +63,7 @@ namespace validation
TEST_SUITE(CL)
TEST_SUITE(INTEGRATION)
TEST_SUITE(DYNAMIC_FUSION)
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Conv2d is not ported to ckw yet. COMPMID-6259
+
TEST_CASE(Conv2d, framework::DatasetMode::ALL)
{
/* Computation:
@@ -156,7 +156,7 @@ TEST_CASE(Conv2d, framework::DatasetMode::ALL)
0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
validate(CLAccessor(t_dst), ref_t_dst_nchw, tolerance_f32);
}
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
+
TEST_CASE(Add_Output_Add_Output, framework::DatasetMode::ALL)
{
/* Computation:
@@ -368,8 +368,9 @@ TEST_CASE(Add_Output_Add_Cast_Cast_Output, framework::DatasetMode::ALL)
validate(CLAccessor(t_out_1), ref_t_out_1, tolerance_cast_f32);
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Conv2d is not ported to ckw yet. COMPMID-6259
-TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL)
+/// TODO: COMPMID-6593 : This integration test fails with CKW backend.
+/// It was not enabled for CKW before, therefore went unnoticed.
+TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::DISABLED)
{
// (tensor0)
// |
@@ -580,7 +581,6 @@ TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL)
constexpr RelativeTolerance<float> tolerance(0.001f);
validate(CLAccessor(tensor6), ref_mul_dst_nchw, tolerance);
}
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
TEST_SUITE(Invalid_Fusion_Should_Fail)
TEST_CASE(Multiple_Complex_Ops_0, framework::DatasetMode::ALL)
diff --git a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
index 40e1ea892..2f8c639ce 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
@@ -290,7 +290,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
{
validate(CLAccessor(_target), _reference, tolerance_f16);
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel
+
TEST_SUITE(Dilation)
FIXTURE_DATA_TEST_CASE(RunSmall,
DynamicFusionGpuDepthwiseConv2dFixture<half>,
@@ -313,7 +313,6 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
validate(CLAccessor(_target), _reference, tolerance_f16);
}
TEST_SUITE_END() // Dilation
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
TEST_SUITE_END() // W3x3
TEST_SUITE(Generic)
@@ -336,7 +335,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
{
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel
+
TEST_SUITE(Dilation)
FIXTURE_DATA_TEST_CASE(RunSmall,
DynamicFusionGpuDepthwiseConv2dFixture<half>,
@@ -359,7 +358,6 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
}
TEST_SUITE_END() // Dilation
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
TEST_SUITE_END() // Generic
TEST_SUITE_END() // FP16
@@ -385,7 +383,6 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
validate(CLAccessor(_target), _reference, tolerance_f32);
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel
TEST_SUITE(Dilation)
FIXTURE_DATA_TEST_CASE(RunSmall,
@@ -409,7 +406,6 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
validate(CLAccessor(_target), _reference, tolerance_f32);
}
TEST_SUITE_END() // Dilation
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
TEST_SUITE_END() // W3x3
TEST_SUITE(Generic)
@@ -445,7 +441,6 @@ FIXTURE_DATA_TEST_CASE(RunLargeKernelSize,
validate(CLAccessor(_target), _reference, tolerance_f32);
}
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel
TEST_SUITE(Dilation)
FIXTURE_DATA_TEST_CASE(RunSmall,
DynamicFusionGpuDepthwiseConv2dFixture<float>,
@@ -468,7 +463,6 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
validate(CLAccessor(_target), _reference, tolerance_f32);
}
TEST_SUITE_END() // Dilation
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
TEST_SUITE_END() // Generic
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
diff --git a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
index dae550003..b84376478 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
@@ -177,7 +177,8 @@ template <typename T>
using DynamicFusionGpuDirectConv2dFixture = DynamicFusionDirectConv2dValidationFixture<CLTensor, CLAccessor, GpuConv2d, T>;
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDirectConv2dFixture<half>, framework::DatasetMode::PRECOMMIT,
+/// TODO: COMPMID-6877: Once the issue in Conv2d is resolved, re-enable these
+FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDirectConv2dFixture<half>, framework::DatasetMode::DISABLED,
combine(combine(combine(zip(zip(zip(zip(zip(
framework::dataset::make("InputShape", { TensorShape(27U, 13U, 23U),
TensorShape(19U, 5U, 16U, 4U),
@@ -213,7 +214,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDirectConv2dFixture<half>, fram
TEST_SUITE_END() // FP16
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDirectConv2dFixture<float>, framework::DatasetMode::PRECOMMIT,
+/// TODO: COMPMID-6877: Once the issue in Conv2d is resolved, re-enable these
+FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDirectConv2dFixture<float>, framework::DatasetMode::DISABLED,
combine(combine(combine(zip(zip(zip(zip(zip(
framework::dataset::make("InputShape", { TensorShape(27U, 13U, 23U),
TensorShape(19U, 5U, 16U, 4U),
diff --git a/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp b/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp
index 96b79679c..82d66ca6c 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp
@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifdef ACL_INTERNAL_TEST_CKW_IN_DF
+
#include "tests/AssetsLibrary.h"
#include "tests/CL/CLAccessor.h"
#include "tests/datasets/LargeMatMulDataset.h"
@@ -333,4 +333,3 @@ TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
index e537826c7..be816b32b 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifdef ACL_INTERNAL_TEST_CKW_IN_DF
+
#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuPool2d.h"
#include "tests/CL/CLAccessor.h"
@@ -217,4 +217,3 @@ TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
index 43617fe1b..d46754ccc 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test if ACL_INTERNAL_TEST_CKW_IN_DF and the op has not been ported to ckw
+
#include "tests/CL/CLAccessor.h"
#include "tests/datasets/ReshapeLayerDataset.h"
#include "tests/framework/datasets/Datasets.h"
@@ -40,7 +40,7 @@ TEST_SUITE(DYNAMIC_FUSION)
TEST_SUITE(RESHAPE)
DATA_TEST_CASE(Validate,
- framework::DatasetMode::ALL,
+ framework::DatasetMode::DISABLED,
zip(zip(framework::dataset::make(
"InputInfo",
{
@@ -82,7 +82,7 @@ using DynamicFusionGpuReshapeLayerFixture =
TEST_SUITE(F32)
FIXTURE_DATA_TEST_CASE(RunSmall,
DynamicFusionGpuReshapeLayerFixture<float>,
- framework::DatasetMode::ALL,
+ framework::DatasetMode::DISABLED,
combine(datasets::SmallReshapeLayerDataset(),
framework::dataset::make("DataType", DataType::F32)))
{
@@ -94,7 +94,7 @@ TEST_SUITE_END() // F32
TEST_SUITE(F16)
FIXTURE_DATA_TEST_CASE(RunSmall,
DynamicFusionGpuReshapeLayerFixture<half>,
- framework::DatasetMode::ALL,
+ framework::DatasetMode::DISABLED,
combine(datasets::SmallReshapeLayerDataset(),
framework::dataset::make("DataType", DataType::F16)))
{
@@ -106,7 +106,7 @@ TEST_SUITE_END() // F16
TEST_SUITE(U8)
FIXTURE_DATA_TEST_CASE(RunSmall,
DynamicFusionGpuReshapeLayerFixture<uint8_t>,
- framework::DatasetMode::ALL,
+ framework::DatasetMode::DISABLED,
combine(datasets::SmallReshapeLayerDataset(),
framework::dataset::make("DataType", DataType::U8)))
{
@@ -118,7 +118,7 @@ TEST_SUITE_END() // U8
TEST_SUITE(S8)
FIXTURE_DATA_TEST_CASE(RunSmall,
DynamicFusionGpuReshapeLayerFixture<int8_t>,
- framework::DatasetMode::ALL,
+ framework::DatasetMode::DISABLED,
combine(datasets::SmallReshapeLayerDataset(),
framework::dataset::make("DataType", DataType::S8)))
{
@@ -130,7 +130,7 @@ TEST_SUITE_END() // S8
TEST_SUITE(S16)
FIXTURE_DATA_TEST_CASE(RunSmall,
DynamicFusionGpuReshapeLayerFixture<int16_t>,
- framework::DatasetMode::ALL,
+ framework::DatasetMode::DISABLED,
combine(datasets::SmallReshapeLayerDataset(),
framework::dataset::make("DataType", DataType::S16)))
{
@@ -145,5 +145,3 @@ TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute
-
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
index b7cb6bace..8f5a1ed14 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test if ACL_INTERNAL_TEST_CKW_IN_DF and the op has not been ported to ckw
+
#include "arm_compute/core/Types.h"
#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSoftmax.h"
@@ -46,62 +46,70 @@ namespace validation
RelativeTolerance<half> tolerance_f16(half(0.2));
RelativeTolerance<float> tolerance_f32(0.001f);
+using framework::dataset::make;
+
+/// TODO: COMPMID-6713
+/// Softmax is not implemented in CKW. Therefore, the tests are DISABLED.
+/// Enable the tests when Softmax is implemented in CKW.
+
TEST_SUITE(CL)
TEST_SUITE(DYNAMIC_FUSION)
TEST_SUITE(SOFTMAX)
// *INDENT-OFF*
// clang-format off
-DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
- framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U), 1, DataType::F32), // Mismatching data types
- TensorInfo(TensorShape(27U, 13U), 1, DataType::F32), // Mismatching shapes
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::S32), // Unsupported data type
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F16),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
-
- }),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U), 1, DataType::F16),
- TensorInfo(TensorShape(27U, 11U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::QASYMM16), // Unsupported data type
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
- TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
-
- })),
- framework::dataset::make("beta", { 1.0,
- 2.0,
- 2.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- })),
- framework::dataset::make("axis", {
- 0,
- 0,
- 1, // Invalid as axis != 0
- 0,
- 0,
- 0,
- -3, // Invalid as axis != 0
- 2, // Invalid as axis != 0
- 1, // Invalid as axis != 0
- -1, // Invalid as axis != 0
- })),
- framework::dataset::make("Expected", { false, false, false, true, false, false, false, false, false, false})),
- input_info, output_info, beta, axis, expected)
+DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED,
+ zip(
+ make("InputInfo", {
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F32), // Mismatching data types
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F32), // Mismatching shapes
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::S32), // Unsupported data type
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F16),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ }),
+ make("OutputInfo",{
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F16),
+ TensorInfo(TensorShape(27U, 11U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::QASYMM16), // Unsupported data type
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
+ }),
+ make("beta", {
+ 1.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ }),
+ make("axis", {
+ 0,
+ 0,
+ 1, // Invalid as axis != 0
+ 0,
+ 0,
+ 0,
+ -3, // Invalid as axis != 0
+ 2, // Invalid as axis != 0
+ 1, // Invalid as axis != 0
+ -1, // Invalid as axis != 0
+ }),
+ make("Expected", { false, false, false, true, false, false, false, false, false, false})),
+ input_info, output_info, beta, axis, expected)
{
// Create a new workload sketch
CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
@@ -122,33 +130,39 @@ using DynamicFusionSoftmaxLayerFixture = DynamicFusionSoftmaxValidationFixture<C
TEST_SUITE(FLOAT)
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionSoftmaxLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("Beta", { 1.0f, 2.0f })),
- framework::dataset::make("Axis", { 0 })),
- framework::dataset::make("is_log", {false, true})))
+FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionSoftmaxLayerFixture<float>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayerSmallShapes(),
+ make("DataType", DataType::F32),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("Beta", { 1.0f, 2.0f })),
- framework::dataset::make("Axis", { 0 })),
- framework::dataset::make("is_log", {false, true})))
+FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionSoftmaxLayerFixture<float>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayerLargeShapes(),
+ make("DataType", DataType::F32),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(Run4D, DynamicFusionSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::SoftmaxLayer4DShapes(),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("Beta", { 1.0f, 2.0f })),
- framework::dataset::make("Axis", { 0 })),
- framework::dataset::make("is_log", {false, true})))
+FIXTURE_DATA_TEST_CASE(Run4D, DynamicFusionSoftmaxLayerFixture<float>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayer4DShapes(),
+ make("DataType", DataType::F32),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -156,33 +170,39 @@ FIXTURE_DATA_TEST_CASE(Run4D, DynamicFusionSoftmaxLayerFixture<float>, framework
TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionSoftmaxLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("Beta", { 1.0f, 2.0f })),
- framework::dataset::make("Axis", { 0 })),
- framework::dataset::make("is_log", {false, true})))
+FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionSoftmaxLayerFixture<half>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayerSmallShapes(),
+ make("DataType", DataType::F16),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionSoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("Beta", { 1.0f, 2.0f })),
- framework::dataset::make("Axis", { 0 })),
- framework::dataset::make("is_log", {false, true})))
+FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionSoftmaxLayerFixture<half>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayerLargeShapes(),
+ make("DataType", DataType::F16),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16);
}
-FIXTURE_DATA_TEST_CASE(Run4D, DynamicFusionSoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::SoftmaxLayer4DShapes(),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("Beta", { 1.0f, 2.0f })),
- framework::dataset::make("Axis", { 0 })),
- framework::dataset::make("is_log", {false, true})))
+FIXTURE_DATA_TEST_CASE(Run4D, DynamicFusionSoftmaxLayerFixture<half>, framework::DatasetMode::DISABLED,
+ combine(
+ datasets::SoftmaxLayer4DShapes(),
+ make("DataType", DataType::F16),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0 }),
+ make("is_log", {false, true})))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16);
@@ -197,5 +217,3 @@ TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute
-
-#endif // ACL_INTERNAL_TEST_CKW_IN_DF
diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
index d291a7736..6e2e3a384 100644
--- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_DEPTHWISE_CONVOLUTION_FIXTURE
-#define ARM_COMPUTE_TEST_DEPTHWISE_CONVOLUTION_FIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DEPTHWISECONVOLUTIONLAYERFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DEPTHWISECONVOLUTIONLAYERFIXTURE_H
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
@@ -38,6 +38,7 @@
#include "utils/Utils.h"
+#include <cstdint>
#include <random>
namespace arm_compute
@@ -54,6 +55,35 @@ class DepthwiseConvolutionLayerValidationGenericFixture : public framework::Fixt
public:
using TBias = typename std::conditional < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T >::type;
+ void setup_quantization(TensorShape input_shape, TensorShape weights_shape, QuantizationInfo &input_q_info,
+ QuantizationInfo &weights_q_info, DataType data_type)
+ {
+ ARM_COMPUTE_UNUSED(input_shape);
+ const int32_t t_max = static_cast<int32_t>(std::numeric_limits<T>::max());
+ const int32_t t_min = static_cast<int32_t>(std::numeric_limits<T>::min());
+
+ std::mt19937 generator(library->seed() + _hash);
+ std::uniform_real_distribution<float> distribution_float(-5.0f, 3.0f);
+ std::uniform_int_distribution<int32_t> distribution_t(t_min, t_max);
+
+ const float scale_lhs = pow(2, distribution_float(generator)); // [2^-5, 2^3]
+ const float scale_rhs = pow(2, distribution_float(generator)); // [2^-5, 2^3]
+
+ const int32_t offset_lhs = distribution_t(generator);
+ const int32_t offset_rhs = distribution_t(generator);
+
+ _input_quantization_info = QuantizationInfo(scale_lhs, offset_lhs);
+ _weights_quantization_info = QuantizationInfo(scale_rhs, offset_rhs);
+
+ QuantizationHint q_hint = suggest_conv_dst_q_info_and_bias(input_q_info, weights_q_info,
+ weights_shape.y() /* heights */, weights_shape.x() /* width */, 1 /* channels */,
+ data_type, 0.5f /* bias_fraction */);
+
+ _output_quantization_info = q_hint.q_info;
+ _min_bias = q_hint.bias_min;
+ _max_bias = q_hint.bias_max;
+ }
+
public:
void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation,
unsigned int depth_multiplier, DataType input_data_type, DataType weights_data_type,
@@ -61,13 +91,18 @@ public:
DataLayout data_layout, ActivationLayerInfo act_info, bool mixed_layout = false, bool in_place = false, bool run_twice = false)
{
ARM_COMPUTE_ERROR_ON(mixed_layout && in_place);
+ // This hash is used by random generators. There may be hash collisions but
+ // this is intentional as it's a very easy way to make the the current
+ // random generation process almost different for many test configurations,
+ // which were using the same set of values before.
+ _hash = in_shape[0] + in_shape[1] + in_shape[2] + in_shape[3] +
+ kernel_size.width + kernel_size.height + dilation.x() +
+ dilation.y() + pad_stride_info.pad_bottom() + pad_stride_info.pad_left() + pad_stride_info.pad_right() + pad_stride_info.pad_top();
+
_mixed_layout = mixed_layout;
_input_shape = in_shape;
_input_data_type = input_data_type;
_weights_data_type = weights_data_type;
- _input_quantization_info = input_quantization_info;
- _weights_quantization_info = weights_quantization_info;
- _output_quantization_info = output_quantization_info;
_data_layout = data_layout;
_pad_stride_info = pad_stride_info;
_act_info = act_info;
@@ -87,6 +122,16 @@ public:
_weights_shape.set(2, _output_shape.z());
_biases_shape = TensorShape(_weights_shape[2]);
+
+ _input_quantization_info = input_quantization_info;
+ _weights_quantization_info = weights_quantization_info;
+ _output_quantization_info = output_quantization_info;
+
+ if(is_data_type_quantized(_input_data_type) && !is_data_type_quantized_symmetric(weights_data_type) && (!act_info.enabled() || act_info.activation() == ActivationFunction::IDENTITY))
+ {
+ setup_quantization(in_shape, _weights_shape, _input_quantization_info, _weights_quantization_info, _input_data_type);
+ _use_dynamic_output_quant = true;
+ }
}
void configure_target()
@@ -150,18 +195,18 @@ public:
}
// Fill tensors
- fill(AccessorType(_src), 0);
- fill(AccessorType(_weights), 1);
- fill(AccessorType(_biases), 2);
+ fill(AccessorType(_src), 0 + _hash);
+ fill(AccessorType(_weights), 1 + _hash);
+ fill(AccessorType(_biases), 2 + _hash);
// Run with variable input
if(_run_twice) {
_dwc.run();
// Fill tensors with a new seed
- fill(AccessorType(_src), 3);
- fill(AccessorType(_weights), 4);
- fill(AccessorType(_biases), 5);
+ fill(AccessorType(_src), 3 + _hash);
+ fill(AccessorType(_weights), 4 + _hash);
+ fill(AccessorType(_biases), 5 + _hash);
}
if(_mixed_layout)
@@ -181,18 +226,19 @@ public:
SimpleTensor<TW> weights{ _weights_shape, _weights_data_type, 1, _weights_quantization_info };
SimpleTensor<TBias> biases{ _biases_shape, _bias_data_type, 1, _input_quantization_info };
- fill(src, 0);
- fill(weights, 1);
- fill(biases, 2);
+ fill(src, 0 + _hash);
+ fill(weights, 1 + _hash);
+ fill(biases, 2 + _hash);
+
if(_run_twice) {
SimpleTensor<T> depth_out = reference::depthwise_convolution(src, weights, biases, _output_shape, _pad_stride_info, _depth_multiplier, _dilation, _output_quantization_info);
if(_act_info.enabled()) {
reference::activation_layer<T>(depth_out, _act_info);
}
- fill(src, 3);
- fill(weights, 4);
- fill(biases, 5);
+ fill(src, 3 + _hash);
+ fill(weights, 4 + _hash);
+ fill(biases, 5 + _hash);
}
SimpleTensor<T> depth_out = reference::depthwise_convolution(src, weights, biases, _output_shape, _pad_stride_info, _depth_multiplier, _dilation, _output_quantization_info);
@@ -222,32 +268,77 @@ protected:
{
case DataType::QASYMM8:
{
- std::uniform_int_distribution<uint32_t> distribution(0, 15);
- library->fill(tensor, distribution, i);
+ if(_use_dynamic_output_quant)
+ {
+ std::uniform_int_distribution<int32_t> distribution(0, 255);
+ library->fill(tensor, distribution, i);
+ }
+ else
+ {
+ // Legacy initialization in case the output quantization info can't be reliably estimated
+ std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::uniform_int_distribution<uint32_t> distribution(bounds.first, bounds.second);
+ library->fill(tensor, distribution, i);
+ }
break;
}
case DataType::QASYMM8_SIGNED:
+ {
+ if(_use_dynamic_output_quant)
+ {
+ std::uniform_int_distribution<int32_t> distribution(-128, 127);
+ library->fill(tensor, distribution, i);
+ }
+ else
+ {
+ // Legacy initialization in case the output quantization info can't be reliably estimated
+ std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::uniform_int_distribution<int32_t> distribution(bounds.first, bounds.second);
+ library->fill(tensor, distribution, i);
+ }
+ break;
+ }
case DataType::QSYMM8_PER_CHANNEL:
{
- std::uniform_int_distribution<int32_t> distribution(-10, 10);
+ int min_bound = 128;
+ int max_bound = -127;
+ for(size_t i = 0; i < _weights_quantization_info.scale().size(); i++)
+ {
+ std::pair<int, int> bounds = get_symm_quantized_per_channel_bounds(tensor.quantization_info(), -1.0f, 1.0f, i);
+ if(bounds.first < min_bound)
+ {
+ min_bound = bounds.first;
+ }
+ if(bounds.second > max_bound)
+ {
+ max_bound = bounds.second;
+ }
+ }
+ std::uniform_int_distribution<int32_t> distribution(min_bound, max_bound);
library->fill(tensor, distribution, i);
break;
}
- case DataType::F16:
+ case DataType::S32:
{
- arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
+ std::uniform_int_distribution<int32_t> distribution(_min_bias, _max_bias);
library->fill(tensor, distribution, i);
break;
}
- case DataType::F32:
+ case DataType::BFLOAT16:
{
- std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
+ arm_compute::utils::uniform_real_distribution_16bit<bfloat16> distribution{ -1.0f, 1.0f };
library->fill(tensor, distribution, i);
break;
}
- case DataType::S32:
+ case DataType::F16:
+ {
+ arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::F32:
{
- std::uniform_int_distribution<int32_t> distribution(-100, 100);
+ std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
library->fill(tensor, distribution, i);
break;
}
@@ -282,6 +373,18 @@ protected:
bool _mixed_layout{ false };
bool _in_place{ false };
bool _run_twice{ false };
+ bool _use_dynamic_output_quant{false};
+
+ int32_t _hash{0};
+ // Random initialization limits
+ // Default values are previously handcrafted limits
+ // that sould be used when we don't use dynamic quantization
+ int32_t _min_bias{-100};
+ int32_t _max_bias{100};
+ int32_t _min_u8{0};
+ int32_t _max_u8{50};
+ int32_t _min_s8{-25};
+ int32_t _max_s8{25};
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false, bool in_place = false, bool run_twice = false>
@@ -671,4 +774,4 @@ public:
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_DEPTHWISE_CONVOLUTION_FIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DEPTHWISECONVOLUTIONLAYERFIXTURE_H
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index 05f20ac12..344187868 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -414,7 +414,7 @@ private:
void validate_with_tolerance(TensorType &target, SimpleTensor<int8_t> &ref)
{
- constexpr AbsoluteTolerance<uint32_t> tolerance_qasymm8_signed(1);
+ constexpr AbsoluteTolerance<int32_t> tolerance_qasymm8_signed(1);
validate(AccessorType(target), ref, tolerance_qasymm8_signed);
}
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index afde3d806..94bedc83e 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,14 +46,14 @@ namespace test
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
-class GEMMValidationFixture : public framework::Fixture
+class GEMMGenericValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type, bool accumulate=false)
{
ARM_COMPUTE_UNUSED(pretranspose);
- _target = compute_target(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type);
- _reference = compute_reference(shape_a, shape_b, output_shape, alpha, beta, data_type);
+ _target = compute_target(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type, accumulate);
+ _reference = compute_reference(shape_a, shape_b, output_shape, alpha, beta, data_type, accumulate);
}
protected:
@@ -80,7 +80,7 @@ protected:
}
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &output_shape, float alpha, float beta,
- DataType data_type)
+ DataType data_type, bool accumulate=false)
{
// Create tensors
TensorType a = create_tensor<TensorType>(shape_a, data_type, 1);
@@ -99,7 +99,7 @@ protected:
&dst,
alpha, beta,
GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d, false, GEMMLowpOutputStageInfo(), false, false, (reinterpret_input_as_3d
- || reinterpret_output_as_3d)));
+ || reinterpret_output_as_3d), arm_compute::ActivationLayerInfo(), false /* fixed_format */, arm_compute::WeightFormat::UNSPECIFIED, false /* pretranspose_B */, accumulate));
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
ARM_COMPUTE_ASSERT(b.info()->is_resizable());
ARM_COMPUTE_ASSERT(c.info()->is_resizable());
@@ -121,11 +121,14 @@ protected:
// Fill tensors
fill(AccessorType(a), 0);
fill(AccessorType(b), 1);
+ if (accumulate)
+ {
+ fill(AccessorType(dst), 6);
+ }
if(!disable_c)
{
fill(AccessorType(c), 2);
}
-
// Run with variable inputs.
if(run_twice)
{
@@ -145,7 +148,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, float alpha, float beta,
- DataType data_type)
+ DataType data_type, bool accumulate=false)
{
TensorShape shape_a_to_use = shape_a;
if(reinterpret_input_as_3d)
@@ -158,6 +161,7 @@ protected:
SimpleTensor<T> a{ shape_a_to_use, data_type, 1 };
SimpleTensor<T> b{ shape_b, data_type, 1 };
SimpleTensor<T> c{ output_shape, data_type, 1 };
+ SimpleTensor<T> dst{ output_shape, data_type, 1 };
// Fill reference
fill(a, 0);
@@ -211,17 +215,51 @@ protected:
fill(c, 5);
}
+ // Do in place summation
+ if (accumulate)
+ {
+ fill(dst, 6);
+ }
+
// Setting beta to 0 will effectively disable C for the
// computation of the reference: alpha * A * B + 0 * C
// Use transposed tensors if boolean enabled else use original tensors
- auto r = reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
- return r;
+ if (accumulate)
+ {
+ reference::gemm_accumulate<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta, dst);
+ return dst;
+ }
+ else
+ {
+ return reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
+ }
}
TensorType _target{};
SimpleTensor<T> _reference{};
};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
+class GEMMValidationFixture : protected GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ {
+ GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>::setup(shape_a, shape_b, shape_c, output_shape, alpha, beta, pretranspose, data_type, false /*accumulate*/);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
+class GEMMAccumulateValidationFixture : protected GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ {
+ bool accumulate = true;
+ GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>::setup(shape_a, shape_b, shape_c, output_shape, alpha, beta, pretranspose, data_type, accumulate);
+ }
+};
+
template <typename TensorType, typename AccessorType, typename T, typename GEMMOperatorType>
class GEMMMatrixMultiplyValidationFixture : public framework::Fixture
{
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index a65a1e6bd..11a491faa 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -30,6 +30,8 @@
#include "tests/framework/Fixture.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/GEMMLowp.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
+#include "tests/validation/reference/QuantizationLayer.h"
#include <cstdint>
#include <vector>
@@ -42,20 +44,35 @@ namespace validation
{
namespace
{
-
template <typename U>
void fill(U &&tensor, int i)
{
+ library->fill_tensor_uniform(tensor, i);
+}
+
+template <typename U>
+void fill_quantized(U &&tensor, int i)
+{
ARM_COMPUTE_ASSERT(is_data_type_quantized(tensor.data_type()));
library->fill_tensor_uniform(tensor, i);
}
template <typename U>
-void fill_bias_s32(U &&tensor, int i, int32_t min, int32_t max)
+void fill(U &&tensor, int i, int32_t min, int32_t max)
{
- ARM_COMPUTE_ASSERT(tensor.data_type() == DataType::S32);
- std::uniform_int_distribution<int32_t> distribution(min, max);
- library->fill(tensor, distribution, i);
+ if (tensor.data_type() == DataType::S32) {
+ std::uniform_int_distribution<int32_t> distribution(min, max);
+ library->fill(tensor, distribution, i);
+ }
+ else if(tensor.data_type() == DataType::F32)
+ {
+ std::uniform_real_distribution<float> distribution((float)min, (float)max);
+ library->fill(tensor, distribution, i);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("NOT SUPPORTED!");
+ }
}
/** Information about how to fill tensors */
@@ -64,6 +81,11 @@ struct TensorFillInfo
// Bias fill range. Default values are arbitrary
int32_t min_bias {-20000};
int32_t max_bias {20000};
+
+ // Output fill range. Default values are arbitrary
+ int32_t min_output {-20000};
+ int32_t max_output {20000};
+
// Optional extra hash to randomize tensor filling
int32_t hash {0};
};
@@ -71,29 +93,42 @@ struct TensorFillInfo
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false, bool run_twice = false>
TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo,
const QuantizationInfo& output_qinfo, DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8,
- GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), bool reshape_b_only_on_first_run = false, const TensorFillInfo& finfo = TensorFillInfo() )
+ GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), bool reshape_b_only_on_first_run = false, const TensorFillInfo& finfo = TensorFillInfo(),
+ bool accumulate = false, bool dynamic_qinfo = false, DataType data_type_output = DataType::UNKNOWN)
{
ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
- // Create tensors
- const DataType data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;
+ // If unknown, set to sensible defaults
+ if (data_type_output == DataType::UNKNOWN) {
+ data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;
+ }
- TensorType a = create_tensor<TensorType>(shape_a, data_type_a, 1, a_qinfo);
- TensorType b = create_tensor<TensorType>(shape_b, data_type_b, 1, b_qinfo); // gemm output before output stage mismatch if i pass data_layout_output here. to be investigated
+ // Create tensors
+ TensorType a = create_tensor<TensorType>(shape_a, data_type_a, 1, dynamic_qinfo ? QuantizationInfo(1.0,0,true) : a_qinfo);
+ TensorType b = create_tensor<TensorType>(shape_b, data_type_b, 1, dynamic_qinfo ? QuantizationInfo(1.0,0,true) : b_qinfo); // gemm output before output stage mismatch if i pass data_layout_output here. to be investigated
TensorType output = create_tensor<TensorType>(shape_output, data_type_output, 1, output_qinfo /* output_qinfo will be ignored when output stage type is None */);
TensorType bias;
if(is_fused)
{
TensorShape bias_shape(shape_b[0]);
- bias = create_tensor<TensorType>(bias_shape, DataType::S32, 1);
+ bias = create_tensor<TensorType>(bias_shape,data_type_output == DataType::F32 ? DataType::F32 : DataType::S32, 1);
}
// Create and configure function
// The GEMMinfo includes the values of the depth in case of reinterpreted 3d input/output
FunctionType gemmlowp;
gemmlowp.configure(&a, &b, is_fused ? &bias : nullptr, &output, GEMMInfo(false, false, reshape_b_only_on_first_run, (reinterpret_output_as_3d ? shape_output[2] : 0), reinterpret_input_as_3d, false,
- output_stage));
+ output_stage, false /*fp_mixed_precision*/, false /*fast_math*/, false /*broadcast_bias*/,
+ arm_compute::ActivationLayerInfo(), false /* fixed_format */, arm_compute::WeightFormat::UNSPECIFIED,
+ false /* pretranspose_B */, accumulate));
+
+ // If the QuantizationInfo is dynamic, it needs to be settable after configure (note that we also force it to be dynamic)
+ if (dynamic_qinfo)
+ {
+ a.info()->set_quantization_info(QuantizationInfo(a_qinfo.scale(), a_qinfo.offset(), true));
+ b.info()->set_quantization_info(QuantizationInfo(b_qinfo.scale(), b_qinfo.offset(), true));
+ }
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
ARM_COMPUTE_ASSERT(b.info()->is_resizable());
@@ -111,26 +146,32 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
ARM_COMPUTE_ASSERT(!output.info()->is_resizable());
// Fill tensors
- fill(AccessorType(a), 0 + finfo.hash);
- fill(AccessorType(b), 1 + finfo.hash);
+ fill_quantized(AccessorType(a), 0 + finfo.hash);
+ fill_quantized(AccessorType(b), 1 + finfo.hash);
+
+ if (accumulate)
+ {
+ ARM_COMPUTE_ASSERT(accumulate != run_twice);
+ fill(AccessorType(output), 6 + finfo.hash, finfo.min_output, finfo.max_output);
+ }
if(is_fused)
{
ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
bias.allocator()->allocate();
ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
- fill_bias_s32(AccessorType(bias), 2 + finfo.hash, finfo.min_bias, finfo.max_bias);
+ fill(AccessorType(bias), 2 + finfo.hash, finfo.min_bias, finfo.max_bias);
}
// Run with variable inputs.
if(run_twice)
{
gemmlowp.run();
- fill(AccessorType(a), 3 + finfo.hash); // Fill tensors with new seed after run
- fill(AccessorType(b), 4 + finfo.hash);
+ fill_quantized(AccessorType(a), 3 + finfo.hash); // Fill tensors with new seed after run
+ fill_quantized(AccessorType(b), 4 + finfo.hash);
if(is_fused)
{
- fill_bias_s32(AccessorType(bias), 5 + finfo.hash, finfo.min_bias, finfo.max_bias);
+ fill(AccessorType(bias), 5 + finfo.hash, finfo.min_bias, finfo.max_bias);
}
}
@@ -168,8 +209,8 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
SimpleTensor<TW> b_transposed{ shape_b_transposed, data_type_b, 1, b_qinfo };
// Fill reference
- fill(a, 0 + finfo.hash);
- fill(b, 1 + finfo.hash);
+ fill_quantized(a, 0 + finfo.hash);
+ fill_quantized(b, 1 + finfo.hash);
// Transpose reference if required
/* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M),
@@ -189,11 +230,12 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
// Run with variable inputs.
const int32_t a_offset = a_qinfo.uniform().offset;
const int32_t b_offset = b_qinfo.uniform().offset;
+
if(run_twice)
{
reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
- fill((pretranspose_A) ? a_transposed : a, 3 + finfo.hash);
- fill((pretranspose_B) ? b_transposed : b, 4 + finfo.hash);
+ fill_quantized((pretranspose_A) ? a_transposed : a, 3 + finfo.hash);
+ fill_quantized((pretranspose_B) ? b_transposed : b, 4 + finfo.hash);
}
return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
@@ -201,35 +243,77 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
} // namespace
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
-class GEMMLowpMatrixMultiplyCoreValidationFixture : public framework::Fixture
+class GEMMLowpGenericMatrixMultiplyCoreValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, bool accumulate=false, bool dynamic_qinfo = false)
{
const auto a_qinfo = QuantizationInfo(1.0f / 255, a_offset);
const auto b_qinfo = QuantizationInfo(1.0f / 255, b_offset);
- _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo);
- _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo);
+ TensorFillInfo finfo;
+ _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
+ _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate);
}
protected:
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo)
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
{
const auto output_qinfo = QuantizationInfo(); // No output stage
- return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo);
+ return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, DataType::QASYMM8, DataType::QASYMM8, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo);
}
- SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo)
+ SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate)
{
- return compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo);
+ SimpleTensor<int32_t> ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo,
+ DataType::QASYMM8, DataType::QASYMM8, finfo);
+
+ if (accumulate)
+ {
+ SimpleTensor<int32_t> output{ shape_output, DataType::S32, 1 };
+ fill(output, 6 + finfo.hash, finfo.min_output, finfo.max_output);
+ reference::arithmetic_operation<int32_t>(reference::ArithmeticOperation::ADD, output, ref_output, output, ConvertPolicy::SATURATE);
+ return output;
+ }
+
+ return ref_output;
}
TensorType _target{};
SimpleTensor<int32_t> _reference{};
};
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
+class GEMMLowpMatrixMultiplyCoreValidationFixture : protected GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>::setup(shape_a, shape_b, shape_output, a_offset, b_offset, false /* accumulate */);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
+class GEMMLowpMatrixMultiplyAccumulateValidationFixture : protected GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>::setup(shape_a, shape_b, shape_output, a_offset, b_offset, true /* accumulate */);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
+class GEMMLowpMatrixMultiplyCoreDynamicQuantizationFixture : protected GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>::setup(shape_a, shape_b, shape_output, a_offset, b_offset, false /* accumulate */, true /* dynamic_qinfo */);
+ }
+};
+
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
-class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture : public framework::Fixture
+class GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public framework::Fixture
{
public:
/** Dynamically initialize the quantization info with saturation awareness
@@ -363,16 +447,16 @@ protected:
TensorShape bias_shape(shape_b[0]);
SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };
- (run_twice) ? fill_bias_s32(bias, 5 + finfo.hash, finfo.min_bias, finfo.max_bias) : fill_bias_s32(bias, 2 + finfo.hash, finfo.min_bias, finfo.max_bias); // Fill bias with same seed as last run of gemmlowp_target
+ (run_twice) ? fill(bias, 5 + finfo.hash, finfo.min_bias, finfo.max_bias) : fill(bias, 2 + finfo.hash, finfo.min_bias, finfo.max_bias); // Fill bias with same seed as last run of gemmlowp_target
switch(output_stage.type)
{
case GEMMLowpOutputStageType::QUANTIZE_DOWN:
- return reference::gemmlowp_quantize_down_scale<int32_t, TW>(output, bias,
+ return reference::gemmlowp_quantize_down_scale<int32_t, TI>(output, bias,
output_stage.gemmlowp_offset, output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
break;
case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
- return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TW>(output, bias,
+ return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TI>(output, bias,
output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_offset, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
break;
default:
@@ -384,15 +468,75 @@ protected:
SimpleTensor<TI> _reference{};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t>
-class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
+class GEMMLowpDequantizedMatrixMultiplyValidationFixture : public framework::Fixture
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ {
+ // Accumulation is supported for Int8/UInt8 only in aarch64
+ bool accumulate = true;
+ // Accumulation is not supported for Int8/UInt8 in aarch32
+#ifdef __arm__
+ accumulate = false;
+#endif //__arm__
+ bool dynamic_qinfo = false;
+ const auto a_qinfo = QuantizationInfo(1.0f / 255, a_offset);
+ const auto b_qinfo = QuantizationInfo(5.0f / 255, b_offset);
+ TensorFillInfo finfo;
+ _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
+ _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate);
+ }
+
+protected:
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
+ {
+ const auto output_qinfo = QuantizationInfo();
+ return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
+ }
+
+ SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate)
+ {
+ SimpleTensor<int32_t> s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo,
+ DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, finfo);
+
+ SimpleTensor<float> f32_ref_output(s32_ref_output.shape(), DataType::F32);
+ QuantizationInfo dst_quant_info = QuantizationInfo(a_qinfo.uniform().scale * b_qinfo.uniform().scale, 0);
+ f32_ref_output = reference::quantization_layer<int32_t, float>(s32_ref_output, DataType::F32, dst_quant_info);
+
+ if (accumulate)
+ {
+ SimpleTensor<float> output{ shape_output, DataType::F32, 1 };
+ fill(output, 6 + finfo.hash, finfo.min_output, finfo.max_output);
+ reference::arithmetic_operation<float>(reference::ArithmeticOperation::ADD, output, f32_ref_output, output, ConvertPolicy::SATURATE);
+ return output;
+ }
+
+ return f32_ref_output;
+ }
+
+ TensorType _target{};
+ SimpleTensor<float> _reference{};
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
+class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type, bool reshape_b_only_on_first_run)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>::setup(shape_a, shape_b,
+ shape_output, output_stage_type, data_type, reshape_b_only_on_first_run);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
+class GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture : public GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type, bool reshape_b_only_on_first_run)
{
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>::setup(shape_a, shape_b,
- shape_output, output_stage_type, data_type, false /* reshape_b_only_on_first_run */);
+ GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>::setup(shape_a, shape_b, shape_output, output_stage_type, data_type, reshape_b_only_on_first_run);
}
};
diff --git a/tests/validation/fixtures/MatMulFixture.h b/tests/validation/fixtures/MatMulFixture.h
index 2e79612a3..ffd12e56d 100644
--- a/tests/validation/fixtures/MatMulFixture.h
+++ b/tests/validation/fixtures/MatMulFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,15 +27,17 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+
#include "src/core/utils/quantization/AsymmHelpers.h"
#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT
#include "tests/framework/Fixture.h"
-#include "tests/validation/Validation.h"
#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/GEMM.h"
#include "tests/validation/reference/GEMMLowp.h"
#include "tests/validation/reference/Permute.h"
#include "tests/validation/reference/ReshapeLayer.h"
+#include "tests/validation/Validation.h"
+
#include <limits>
#include <random>
#include <type_traits>
@@ -50,32 +52,50 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class MatMulGenericValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs,
- Settings settings, QuantizationInfo a_qinfo = QuantizationInfo(), QuantizationInfo b_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
+ void setup(TensorShape shape_a,
+ TensorShape shape_b,
+ TensorShape output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type,
+ ActivationLayerInfo act_info,
+ int num_extra_runs,
+ Settings settings,
+ QuantizationInfo a_qinfo = QuantizationInfo(),
+ QuantizationInfo b_qinfo = QuantizationInfo(),
+ QuantizationInfo o_qinfo = QuantizationInfo())
{
// For brevity, the input shapes are assumed to be not-transposed for both a and b matrices.
- if(transpose_a)
+ if (transpose_a)
{
permute(shape_a, PermutationVector(1U, 0U));
}
- if(transpose_b)
+ if (transpose_b)
{
permute(shape_b, PermutationVector(1U, 0U));
}
- _target = compute_target(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, settings, a_qinfo, b_qinfo, o_qinfo);
- _reference = compute_reference(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, a_qinfo, b_qinfo, o_qinfo);
+ _target = compute_target(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info,
+ num_extra_runs, settings, a_qinfo, b_qinfo, o_qinfo);
+ _reference = compute_reference(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info,
+ a_qinfo, b_qinfo, o_qinfo);
}
protected:
template <typename U>
void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
{
- switch(tensor.data_type())
+ switch (tensor.data_type())
{
+ case DataType::BFLOAT16:
+ {
+ arm_compute::utils::uniform_real_distribution_16bit<bfloat16> distribution{float(lo), float(hi)};
+ library->fill(tensor, distribution, i);
+ break;
+ }
case DataType::F16:
{
- arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ float(lo), float(hi) };
+ arm_compute::utils::uniform_real_distribution_16bit<half> distribution{float(lo), float(hi)};
library->fill(tensor, distribution, i);
break;
}
@@ -98,8 +118,18 @@ protected:
}
}
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool transpose_a, bool transpose_b, DataType data_type,
- ActivationLayerInfo act_info, int num_extra_runs, const Settings &settings, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
+ virtual TensorType compute_target(const TensorShape &shape_a,
+ const TensorShape &shape_b,
+ const TensorShape &output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type,
+ ActivationLayerInfo act_info,
+ int num_extra_runs,
+ const Settings &settings,
+ QuantizationInfo a_qinfo,
+ QuantizationInfo b_qinfo,
+ QuantizationInfo o_qinfo)
{
// 1. Create Classes and configure function
// ----------------------------------------------------
@@ -137,7 +167,7 @@ protected:
ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
// For multiple runs.
- for(int i = 0; i < num_extra_runs; i++)
+ for (int i = 0; i < num_extra_runs; i++)
{
// Stress dynamic tensors by running multiple times.
// --------------------------------------------------------
@@ -164,7 +194,12 @@ protected:
template <typename TT>
typename std::enable_if < !std::is_integral<TT>::value, SimpleTensor<TT >>::type
- compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo)
+ compute_reference_gemm(const SimpleTensor<TT> &a,
+ const SimpleTensor<TT> &b,
+ const SimpleTensor<TT> &c,
+ float alpha,
+ float beta,
+ const QuantizationInfo &o_qinfo)
{
ARM_COMPUTE_UNUSED(o_qinfo);
@@ -173,7 +208,12 @@ protected:
template <typename TT>
typename std::enable_if<std::is_integral<TT>::value, SimpleTensor<TT>>::type
- compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo)
+ compute_reference_gemm(const SimpleTensor<TT> &a,
+ const SimpleTensor<TT> &b,
+ const SimpleTensor<TT> &c,
+ float alpha,
+ float beta,
+ const QuantizationInfo &o_qinfo)
{
ARM_COMPUTE_UNUSED(alpha, beta);
@@ -186,23 +226,30 @@ protected:
int32_t output_multiplier = 0;
int32_t output_shift = 0;
quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
- std::vector<int32_t> output_multipliers{ output_multiplier };
- std::vector<int32_t> output_shifts{ output_shift };
+ std::vector<int32_t> output_multipliers{output_multiplier};
+ std::vector<int32_t> output_shifts{output_shift};
//The lhs and rhs offsets are negated here to keep the reference aligned with the function implementation where the lhs and rhs offsets are also negated.
- const auto tmp = reference::gemmlowp_matrix_multiply_core<int32_t>(
- a, b, c.shape(), -aq.offset, -bq.offset);
+ const auto tmp = reference::gemmlowp_matrix_multiply_core<int32_t>(a, b, c.shape(), -aq.offset, -bq.offset);
auto output = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TT>(
- tmp, output_multipliers, output_shifts, oq.offset,
- std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max());
+ tmp, output_multipliers, output_shifts, oq.offset, std::numeric_limits<int32_t>::lowest(),
+ std::numeric_limits<int32_t>::max());
output.quantization_info(o_qinfo);
return output;
}
- SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &output_shape, bool transpose_a, bool transpose_b, DataType data_type,
- ActivationLayerInfo act_info, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
+ SimpleTensor<T> compute_reference(const TensorShape &a_shape,
+ const TensorShape &b_shape,
+ const TensorShape &output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type,
+ ActivationLayerInfo act_info,
+ QuantizationInfo a_qinfo,
+ QuantizationInfo b_qinfo,
+ QuantizationInfo o_qinfo)
{
// We collapse dimensions > 2 onto dimension 2, i.e. 4D+ tensors will look like 3D
// This is necessary unless we choose to extend gemm reference for 4D+ tensors
@@ -211,9 +258,9 @@ protected:
TensorShape b_shape_collapsed = b_shape.collapsed_from(Window::DimZ);
// Create reference
- SimpleTensor<T> a{ a_shape_collapsed, data_type, 1, a_qinfo };
- SimpleTensor<T> b{ b_shape_collapsed, data_type, 1, b_qinfo };
- SimpleTensor<T> c{ output_shape_collapsed, data_type, 1 };
+ SimpleTensor<T> a{a_shape_collapsed, data_type, 1, a_qinfo};
+ SimpleTensor<T> b{b_shape_collapsed, data_type, 1, b_qinfo};
+ SimpleTensor<T> c{output_shape_collapsed, data_type, 1};
// Fill reference
fill(a, 2);
@@ -234,16 +281,16 @@ protected:
b_transposed_shape.set(1, b.shape().x());
// Define transposed tensors
- SimpleTensor<T> a_transposed{ a_transposed_shape, data_type };
- SimpleTensor<T> b_transposed{ b_transposed_shape, data_type };
+ SimpleTensor<T> a_transposed{a_transposed_shape, data_type};
+ SimpleTensor<T> b_transposed{b_transposed_shape, data_type};
// pretranspose a if necessary
- if(transpose_a)
+ if (transpose_a)
{
a_transposed = reference::permute<T>(a, PermutationVector(1U, 0U));
}
// pretranspose b if necessary
- if(transpose_b)
+ if (transpose_b)
{
b_transposed = reference::permute<T>(b, PermutationVector(1U, 0U));
}
@@ -251,12 +298,13 @@ protected:
// Setting beta to 0 will effectively disable C for the
// computation of the reference: alpha * A * B + 0 * C
// Use transposed tensors if boolean enabled else use original tensors
- auto result = compute_reference_gemm<T>((transpose_a) ? a_transposed : a, (transpose_b) ? b_transposed : b, c, 1.0f, 0.f, o_qinfo);
+ auto result = compute_reference_gemm<T>((transpose_a) ? a_transposed : a, (transpose_b) ? b_transposed : b, c,
+ 1.0f, 0.f, o_qinfo);
result = reference::activation_layer<T>(result, act_info, o_qinfo);
// We reshape the gemm output back if the tensor is high dimensional
- if(output_shape_collapsed != output_shape)
+ if (output_shape_collapsed != output_shape)
{
result = reference::reshape_layer(result, output_shape);
}
@@ -268,72 +316,293 @@ protected:
SimpleTensor<T> _reference{};
};
+/// TODO: (ONCPUML-1451) The current state of this fixture is interim and a longer-term testing method will be implemented later.
+/// @note: Currently we support only a 2x2 test due to the lack of reorder ref. implementation.
+template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
+class MatMulFixedFormatFixture
+ : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+{
+public:
+ TensorType compute_target(const TensorShape &shape_a,
+ const TensorShape &shape_b,
+ const TensorShape &output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type,
+ ActivationLayerInfo act_info,
+ int num_extra_runs,
+ const Settings &settings,
+ QuantizationInfo a_qinfo,
+ QuantizationInfo b_qinfo,
+ QuantizationInfo o_qinfo) override
+ {
+ // 1. Create Classes and configure function
+ // ----------------------------------------------------
+ // Create tensors
+ // Configure relevant classes and matmul function
+ TensorType a = create_tensor<TensorType>(shape_a, data_type, 1, a_qinfo);
+ TensorType b = create_tensor<TensorType>(shape_b, data_type, 1, b_qinfo);
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, o_qinfo);
+
+ const auto weight_tensor_info = TensorInfo(*b.info());
+ const TensorInfo new_tensor_info = prepare_weights(weight_tensor_info);
+ TensorType weights_transformed = create_tensor<TensorType>(new_tensor_info);
+
+ // Configure MatMulInfo class
+ MatMulInfo mm_info;
+ mm_info.adj_lhs(transpose_a).adj_rhs(transpose_b);
+
+ // Ensure values are dynamic
+ a.info()->set_are_values_constant(false);
+ b.info()->set_are_values_constant(false);
+ weights_transformed.info()->set_are_values_constant(false);
+
+ FunctionType matmul;
+
+ // Configure operator
+ matmul.configure(&a, &weights_transformed, &dst, mm_info, settings, act_info);
+
+ // Assertions
+ ARM_COMPUTE_ASSERT(a.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(b.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(weights_transformed.info()->is_resizable());
+
+ // Allocate tensors
+ a.allocator()->allocate();
+ b.allocator()->allocate();
+ dst.allocator()->allocate();
+ weights_transformed.allocator()->allocate();
+
+ ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!b.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!weights_transformed.info()->is_resizable());
+
+ // For multiple runs.
+ for (int i = 0; i < num_extra_runs; i++)
+ {
+ // Stress dynamic tensors by running multiple times.
+ // --------------------------------------------------------
+ // Fill tensors with new seed
+ // Run function
+ const int seed_offset = num_extra_runs * 100;
+ this->fill(AccessorType(a), seed_offset);
+ this->fill(AccessorType(b), seed_offset + 1);
+
+ matmul.run();
+ }
+
+ // 2. Final Run for reference comparison
+ // --------------------------------------------------------
+ // Re-fill tensors same seed as reference run
+ // Compute MatMul operation
+ this->fill(AccessorType(a), 2);
+ this->fill(AccessorType(b), 3);
+
+ rearrange_data(AccessorType(b), AccessorType(weights_transformed));
+
+ matmul.run();
+
+ return dst;
+ }
+
+ void setup(TensorShape shape_a,
+ TensorShape shape_b,
+ TensorShape output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type,
+ ActivationLayerInfo act_info,
+ int num_extra_runs,
+ Settings settings,
+ QuantizationInfo a_qinfo,
+ QuantizationInfo b_qinfo,
+ QuantizationInfo o_qinfo)
+ {
+ if (CPUInfo::get().has_bf16())
+ {
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(
+ shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, settings,
+ a_qinfo, b_qinfo, o_qinfo);
+ }
+ }
+
+private:
+ TensorInfo prepare_weights(const TensorInfo tensor_info)
+ {
+ const DataLayout data_layout = tensor_info.data_layout();
+ ARM_COMPUTE_EXPECT(data_layout == DataLayout::NCHW, framework::LogLevel::ERRORS);
+ const DataType data_type = tensor_info.data_type();
+ const TensorShape tensor_shape = tensor_info.tensor_shape();
+ const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+ const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+ ARM_COMPUTE_EXPECT(H <= 2 && W <= 2, framework::LogLevel::ERRORS);
+
+ arm_compute::Strides strides_in_bytes = tensor_info.strides_in_bytes();
+ strides_in_bytes.set(1, 32);
+ strides_in_bytes.set(2, 32);
+
+ const size_t offset_first_element_in_bytes = tensor_info.offset_first_element_in_bytes();
+ const size_t total_size_in_bytes = 32;
+
+ const TensorShape TS(H, W);
+
+ TensorInfo new_tensor_info = tensor_info;
+ new_tensor_info.init(TS, tensor_info.num_channels(), data_type, strides_in_bytes, offset_first_element_in_bytes,
+ total_size_in_bytes);
+
+ return new_tensor_info;
+ }
+
+ void rearrange_data(const AccessorType src, AccessorType dst)
+ {
+ const TensorShape src_tensor_shape = src.shape();
+ const DataLayout data_layout = src.data_layout();
+ ARM_COMPUTE_EXPECT(data_layout == DataLayout::NCHW, framework::LogLevel::ERRORS);
+ const unsigned int O =
+ src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)]; // N=O
+ const unsigned int H =
+ src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+ const unsigned int W =
+ src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+ const unsigned int I =
+ src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I
+ ARM_COMPUTE_EXPECT(H <= 2 && W <= 2, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(I == 1 && O == 1, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(src.num_elements() <= dst.num_elements(), framework::LogLevel::ERRORS);
+
+ const T *src_ptr = reinterpret_cast<const T *>(src.data());
+ T *dst_ptr = reinterpret_cast<T *>(dst.data());
+
+ // rearrange indexes for 2x2 input and weight
+ int dst_idx[] = {0, 4, 1, 5};
+ for (int i = 0; i < 4; i++)
+ {
+ dst_ptr[dst_idx[i]] = src_ptr[i];
+ }
+ }
+};
+
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
-class MatMulValidationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+class MatMulValidationFixture
+ : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type)
+ void setup(TensorShape shape_a,
+ TensorShape shape_b,
+ TensorShape output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type)
{
- MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, ActivationLayerInfo(), 0,
- Settings());
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(
+ shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, ActivationLayerInfo(), 0, Settings());
}
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
-class MatMulValidationWithDynamicTensorsFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+class MatMulValidationWithDynamicTensorsFixture
+ : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs)
+ void setup(TensorShape shape_a,
+ TensorShape shape_b,
+ TensorShape output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type,
+ ActivationLayerInfo act_info,
+ int num_extra_runs)
{
- MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings());
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(
+ shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings());
}
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
-class QuantizedMatMulValidationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+class QuantizedMatMulValidationFixture
+ : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs,
- QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
+ void setup(TensorShape shape_a,
+ TensorShape shape_b,
+ TensorShape output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type,
+ ActivationLayerInfo act_info,
+ int num_extra_runs,
+ QuantizationInfo a_qinfo,
+ QuantizationInfo b_qinfo,
+ QuantizationInfo o_qinfo)
{
- MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(),
- a_qinfo, b_qinfo, o_qinfo);
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(
+ shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(),
+ a_qinfo, b_qinfo, o_qinfo);
}
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
-class MatMulValidationWithActivationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+class MatMulValidationWithActivationFixture
+ : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info)
+ void setup(TensorShape shape_a,
+ TensorShape shape_b,
+ TensorShape output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type,
+ ActivationLayerInfo act_info)
{
- MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings());
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(
+ shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings());
}
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
-class MatMulValidationWithActivationAlphaBetaFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+class MatMulValidationWithActivationAlphaBetaFixture
+ : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo::ActivationFunction function,
- float alpha_beta)
+ void setup(TensorShape shape_a,
+ TensorShape shape_b,
+ TensorShape output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type,
+ ActivationLayerInfo::ActivationFunction function,
+ float alpha_beta)
{
ActivationLayerInfo act_info(function, alpha_beta, alpha_beta);
- MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings());
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(
+ shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings());
}
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
-class QuantizedMatMulValidationWithActivationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+class QuantizedMatMulValidationWithActivationFixture
+ : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo::ActivationFunction function,
- float alpha_beta, int num_extra_runs,
- QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
+ void setup(TensorShape shape_a,
+ TensorShape shape_b,
+ TensorShape output_shape,
+ bool transpose_a,
+ bool transpose_b,
+ DataType data_type,
+ ActivationLayerInfo::ActivationFunction function,
+ float alpha_beta,
+ int num_extra_runs,
+ QuantizationInfo a_qinfo,
+ QuantizationInfo b_qinfo,
+ QuantizationInfo o_qinfo)
{
ActivationLayerInfo act_info(function, alpha_beta, alpha_beta);
- MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(),
- a_qinfo, b_qinfo, o_qinfo);
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(
+ shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(),
+ a_qinfo, b_qinfo, o_qinfo);
}
};
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
new file mode 100644
index 000000000..bda5532a5
--- /dev/null
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H
+
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/runtime/CL/CLTensorAllocator.h"
+#include "tests/Globals.h"
+#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT
+#include "tests/framework/Fixture.h"
+#include "tests/validation/Validation.h"
+#include "tests/validation/reference/ScatterLayer.h"
+#include "tests/SimpleTensor.h"
+
+#include <random>
+#include <cstdint>
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ScatterGenericValidationFixture : public framework::Fixture
+{
+public:
+ void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
+ {
+ _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, src_qinfo, o_qinfo);
+ _reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
+ }
+
+protected:
+ template <typename U>
+ void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
+ {
+ switch(tensor.data_type())
+ {
+ case DataType::F32:
+ {
+ std::uniform_real_distribution<float> distribution(lo, hi);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported data type.");
+ }
+ }
+ }
+
+ // This is used to fill indices tensor with U32 datatype.
+ // Used to prevent ONLY having values that are out of bounds.
+ template <typename U>
+ void fill_indices(U &&tensor, int i, const TensorShape &shape)
+ {
+ // Calculate max indices the shape should contain. Add an arbitrary constant to allow testing for some out of bounds values.
+ const uint32_t max = std::max({shape[0] , shape[1], shape[2]}) + 5;
+ library->fill_tensor_uniform(tensor, i, static_cast<uint32_t>(0), static_cast<uint32_t>(max));
+ }
+
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+ {
+ // 1. Create relevant tensors using ScatterInfo data structure.
+ // ----------------------------------------------------
+ // In order - src, updates, indices, output.
+ TensorType src = create_tensor<TensorType>(shape_a, data_type, 1, a_qinfo);
+ TensorType updates = create_tensor<TensorType>(shape_b, data_type, 1, a_qinfo);
+ TensorType indices = create_tensor<TensorType>(shape_c, DataType::U32, 1, QuantizationInfo());
+ TensorType dst = create_tensor<TensorType>(out_shape, data_type, 1, o_qinfo);
+
+ FunctionType scatter;
+
+ // Configure operator
+ // When scatter_info.zero_initialization is true, pass nullptr to scatter function.
+ if(info.zero_initialization)
+ {
+ scatter.configure(nullptr, &updates, &indices, &dst, info);
+ }
+ else
+ {
+ scatter.configure(&src, &updates, &indices, &dst, info);
+ }
+
+ // Assertions
+ ARM_COMPUTE_ASSERT(src.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(updates.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
+
+ // Allocate tensors
+ src.allocator()->allocate();
+ updates.allocator()->allocate();
+ indices.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!updates.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!indices.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+
+ // Fill update (a) and indices (b) tensors.
+ fill(AccessorType(src), 0);
+ fill(AccessorType(updates), 1);
+ fill_indices(AccessorType(indices), 2, out_shape);
+
+ scatter.run();
+
+ return dst;
+ }
+
+ SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type,
+ ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+ {
+ // Output Quantization not currently in use - fixture should be extended to support this.
+ ARM_COMPUTE_UNUSED(o_qinfo);
+
+ // Create reference tensors
+ SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo };
+ SimpleTensor<T> updates{b_shape, data_type, 1, QuantizationInfo() };
+ SimpleTensor<uint32_t> indices{ c_shape, DataType::U32, 1, QuantizationInfo() };
+
+ // Fill reference
+ fill(src, 0);
+ fill(updates, 1);
+ fill_indices(indices, 2, out_shape);
+
+ // Calculate individual reference.
+ auto result = reference::scatter_layer<T>(src, updates, indices, out_shape, info);
+
+ return result;
+ }
+
+ TensorType _target{};
+ SimpleTensor<T> _reference{};
+};
+
+// This fixture will use the same shape for updates as indices.
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init)
+ {
+ ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo());
+ }
+};
+
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
+#endif // ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H
diff --git a/tests/validation/reference/ActivationLayer.cpp b/tests/validation/reference/ActivationLayer.cpp
index 664b96912..2172362bd 100644
--- a/tests/validation/reference/ActivationLayer.cpp
+++ b/tests/validation/reference/ActivationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#include "ActivationLayer.h"
#include "arm_compute/core/Types.h"
+
#include "tests/validation/Helpers.h"
namespace arm_compute
@@ -40,7 +41,7 @@ SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo
ARM_COMPUTE_UNUSED(oq_info);
// Create reference
- SimpleTensor<T> dst{ src.shape(), src.data_type(), 1 };
+ SimpleTensor<T> dst{src.shape(), src.data_type(), 1};
// Compute reference
const T a(info.a());
@@ -48,7 +49,7 @@ SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo
#if defined(_OPENMP)
#pragma omp parallel for
#endif /* _OPENMP */
- for(int i = 0; i < src.num_elements(); ++i)
+ for (int i = 0; i < src.num_elements(); ++i)
{
dst[i] = activate_float<T>(src[i], a, b, info.activation());
}
@@ -57,7 +58,8 @@ SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo
}
template <>
-SimpleTensor<uint8_t> activation_layer<uint8_t>(const SimpleTensor<uint8_t> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info)
+SimpleTensor<uint8_t>
+activation_layer<uint8_t>(const SimpleTensor<uint8_t> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info)
{
const QuantizationInfo dst_qinfo = oq_info.empty() ? src.quantization_info() : oq_info;
@@ -68,7 +70,8 @@ SimpleTensor<uint8_t> activation_layer<uint8_t>(const SimpleTensor<uint8_t> &src
}
template <>
-SimpleTensor<int8_t> activation_layer<int8_t>(const SimpleTensor<int8_t> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info)
+SimpleTensor<int8_t>
+activation_layer<int8_t>(const SimpleTensor<int8_t> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info)
{
const QuantizationInfo dst_qinfo = oq_info.empty() ? src.quantization_info() : oq_info;
@@ -79,7 +82,8 @@ SimpleTensor<int8_t> activation_layer<int8_t>(const SimpleTensor<int8_t> &src, A
}
template <>
-SimpleTensor<int16_t> activation_layer<int16_t>(const SimpleTensor<int16_t> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info)
+SimpleTensor<int16_t>
+activation_layer<int16_t>(const SimpleTensor<int16_t> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info)
{
const QuantizationInfo dst_qinfo = oq_info.empty() ? src.quantization_info() : oq_info;
@@ -88,9 +92,14 @@ SimpleTensor<int16_t> activation_layer<int16_t>(const SimpleTensor<int16_t> &src
SimpleTensor<int16_t> dst = convert_to_symmetric<int16_t>(dst_tmp, dst_qinfo);
return dst;
}
-template SimpleTensor<int32_t> activation_layer(const SimpleTensor<int32_t> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
-template SimpleTensor<float> activation_layer(const SimpleTensor<float> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
-template SimpleTensor<half> activation_layer(const SimpleTensor<half> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
+template SimpleTensor<int32_t>
+activation_layer(const SimpleTensor<int32_t> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
+template SimpleTensor<float>
+activation_layer(const SimpleTensor<float> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
+template SimpleTensor<half>
+activation_layer(const SimpleTensor<half> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
+template SimpleTensor<bfloat16>
+activation_layer(const SimpleTensor<bfloat16> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/ActivationLayer.h b/tests/validation/reference/ActivationLayer.h
index a813ba503..7f896bd69 100644
--- a/tests/validation/reference/ActivationLayer.h
+++ b/tests/validation/reference/ActivationLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2022 Arm Limited.
+ * Copyright (c) 2017-2020,2022,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_ACTIVATION_LAYER_H
-#define ARM_COMPUTE_TEST_ACTIVATION_LAYER_H
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_ACTIVATIONLAYER_H
+#define ACL_TESTS_VALIDATION_REFERENCE_ACTIVATIONLAYER_H
#include "tests/SimpleTensor.h"
#include "tests/validation/Helpers.h"
@@ -40,7 +40,7 @@ inline T activate_float(T x, T a, T b, ActivationLayerInfo::ActivationFunction a
{
T ret;
- switch(activation)
+ switch (activation)
{
case ActivationLayerInfo::ActivationFunction::ABS:
ret = std::abs(x);
@@ -61,13 +61,13 @@ inline T activate_float(T x, T a, T b, ActivationLayerInfo::ActivationFunction a
ret = std::min<T>(a, std::max<T>(b, x));
break;
case ActivationLayerInfo::ActivationFunction::LEAKY_RELU:
- ret = (x > 0) ? x : a * x;
+ ret = x > static_cast<T>(0) ? x : static_cast<T>(a * x);
break;
case ActivationLayerInfo::ActivationFunction::SOFT_RELU:
ret = std::log(static_cast<T>(1) + std::exp(static_cast<double>(x)));
break;
case ActivationLayerInfo::ActivationFunction::ELU:
- ret = (x > 0) ? x : a * (std::exp(x) - static_cast<T>(1));
+ ret = x > static_cast<T>(0) ? x : static_cast<T>(a * (std::exp(x) - static_cast<T>(1)));
break;
case ActivationLayerInfo::ActivationFunction::SQRT:
ret = std::sqrt(x);
@@ -82,10 +82,11 @@ inline T activate_float(T x, T a, T b, ActivationLayerInfo::ActivationFunction a
ret = x;
break;
case ActivationLayerInfo::ActivationFunction::HARD_SWISH:
- ret = x * ((std::min(std::max(static_cast<T>(x + 3), static_cast<T>(0.0f)), static_cast<T>(6.0f))) * 0.166666667f);
+ ret = x * ((std::min(std::max(static_cast<T>(x + 3), static_cast<T>(0.0f)), static_cast<T>(6.0f))) *
+ 0.166666667f);
break;
case ActivationLayerInfo::ActivationFunction::SWISH:
- ret = static_cast<T>(x) / (static_cast<T>(1) + std::exp(-a*x));
+ ret = static_cast<T>(x) / (static_cast<T>(1) + std::exp(-a * x));
break;
case ActivationLayerInfo::ActivationFunction::GELU:
ret = x * 0.5f * (1 + erf(x / std::sqrt(2.0f)));
@@ -99,9 +100,11 @@ inline T activate_float(T x, T a, T b, ActivationLayerInfo::ActivationFunction a
}
template <typename T>
-SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info = QuantizationInfo());
+SimpleTensor<T> activation_layer(const SimpleTensor<T> &src,
+ ActivationLayerInfo info,
+ const QuantizationInfo &oq_info = QuantizationInfo());
} // namespace reference
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_ACTIVATION_LAYER_H */
+#endif // ACL_TESTS_VALIDATION_REFERENCE_ACTIVATIONLAYER_H
diff --git a/tests/validation/reference/DepthConvertLayer.cpp b/tests/validation/reference/DepthConvertLayer.cpp
index 1e4939129..3f88897f8 100644
--- a/tests/validation/reference/DepthConvertLayer.cpp
+++ b/tests/validation/reference/DepthConvertLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -165,7 +165,7 @@ template SimpleTensor<half> depth_convert(const SimpleTensor<int32_t> &src, Data
template SimpleTensor<float> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
// BFLOAT16
-template SimpleTensor<float> depth_convert(const SimpleTensor<bfloat16> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<bfloat16> depth_convert(const SimpleTensor<bfloat16> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
// F16
template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
diff --git a/tests/validation/reference/ElementwiseOperations.cpp b/tests/validation/reference/ElementwiseOperations.cpp
index f22c84e15..edbbab860 100644
--- a/tests/validation/reference/ElementwiseOperations.cpp
+++ b/tests/validation/reference/ElementwiseOperations.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -74,15 +74,6 @@ T arithm_op(ArithmeticOperation op, T src1, T src2, ConvertPolicy convert_policy
case ArithmeticOperation::DIV:
{
val = (static_cast<intermediate_type>(src1) / static_cast<intermediate_type>(src2));
- if(std::is_integral<T>::value)
- {
- // Implement flooring division
- val = (src2 == 0) ? 0 : val;
- if(static_cast<int32_t>(src1) % static_cast<int32_t>(src2) != 0 && ((src1 < 0) != (src2 < 0)))
- {
- --val;
- }
- }
break;
}
case ArithmeticOperation::POWER:
diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp
index f7e97e47b..d51334379 100644
--- a/tests/validation/reference/GEMM.cpp
+++ b/tests/validation/reference/GEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,6 +25,7 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
namespace arm_compute
{
@@ -35,10 +36,11 @@ namespace validation
namespace reference
{
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
-SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
+SimpleTensor<T>
+gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
{
// Create reference
- SimpleTensor<T> dst{ c.shape(), c.data_type(), 1 };
+ SimpleTensor<T> dst{c.shape(), c.data_type(), 1};
// Compute reference
const int M = a.shape().y();
@@ -50,15 +52,22 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
const int a_stride_z = K * M;
const int a_stride_w = K * M * D;
- const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions
- int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions
+ const int b_stride_z =
+ b.shape().num_dimensions() > 2
+ ? N * K
+ : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions
+ int b_stride_w =
+ b.shape().num_dimensions() > 3
+ ? K * N * D
+ : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions
// Note: There are 3 gemm types: batched-gemm, multi-gemm, and batched of multi-gemms. The third dimension of tensor b is overloaded when tensor b has exactly 3 dimensions:
// it can be either number of batches or multis. Batched-GEMM computation is detected only when the third dimension of "a" and "c" tensors is 1 and the number of dimensions is 4
- const bool is_batched_gemm = b.shape().num_dimensions() == 3 && a.shape().num_dimensions() == 4 && c.shape().num_dimensions() == 4 && a.shape()[2] == 1 && c.shape()[2] == 1;
+ const bool is_batched_gemm = b.shape().num_dimensions() == 3 && a.shape().num_dimensions() == 4 &&
+ c.shape().num_dimensions() == 4 && a.shape()[2] == 1 && c.shape()[2] == 1;
// Batched-GEMM
- if(is_batched_gemm)
+ if (is_batched_gemm)
{
b_stride_w = b_stride_z;
}
@@ -69,21 +78,21 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
#if defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__))
#pragma omp parallel for collapse(2)
#endif /* _OPENMP */
- for(int w = 0; w < W; ++w)
+ for (int w = 0; w < W; ++w)
{
- for(int depth = 0; depth < D; ++depth)
+ for (int depth = 0; depth < D; ++depth)
{
const int base_addr_a = depth * a_stride_z + w * a_stride_w;
const int base_addr_b = depth * b_stride_z + w * b_stride_w;
const int base_addr_c = depth * c_stride_z + w * c_stride_w;
- for(int row = 0; row < M; ++row)
+ for (int row = 0; row < M; ++row)
{
- for(int col = 0; col < N; ++col)
+ for (int col = 0; col < N; ++col)
{
T acc(0);
- for(int k = 0; k < K; ++k)
+ for (int k = 0; k < K; ++k)
{
acc += a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N];
}
@@ -99,11 +108,12 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
}
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
-SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
+SimpleTensor<T> gemm_mixed_precision(
+ const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
{
// GEMM mixed-precision combines F32 accumulators with F16 multiplications
// Create reference
- SimpleTensor<T> dst{ c.shape(), c.data_type(), 1 };
+ SimpleTensor<T> dst{c.shape(), c.data_type(), 1};
// Compute reference
const int M = a.shape().y();
@@ -115,15 +125,22 @@ SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTenso
const int a_stride_z = K * M;
const int a_stride_w = K * M * D;
- const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions
- int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions
+ const int b_stride_z =
+ b.shape().num_dimensions() > 2
+ ? N * K
+ : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions
+ int b_stride_w =
+ b.shape().num_dimensions() > 3
+ ? K * N * D
+ : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions
// Note: There are 3 gemm types: batched-gemm, multi-gemm, and batched of multi-gemms. The third dimension of tensor b is overloaded when tensor b has exactly 3 dimensions:
// it can be either number of batches or multis. Batched-GEMM computation is detected only when the third dimension of "a" and "c" tensors is 1 and the number of dimensions is 4
- const bool is_batched_gemm = b.shape().num_dimensions() == 3 && a.shape().num_dimensions() == 4 && c.shape().num_dimensions() == 4 && a.shape()[2] == 1 && c.shape()[2] == 1;
+ const bool is_batched_gemm = b.shape().num_dimensions() == 3 && a.shape().num_dimensions() == 4 &&
+ c.shape().num_dimensions() == 4 && a.shape()[2] == 1 && c.shape()[2] == 1;
// Batched-GEMM
- if(is_batched_gemm)
+ if (is_batched_gemm)
{
b_stride_w = b_stride_z;
}
@@ -134,27 +151,28 @@ SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTenso
#if defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__))
#pragma omp parallel for collapse(2)
#endif /* _OPENMP */
- for(int w = 0; w < W; ++w)
+ for (int w = 0; w < W; ++w)
{
- for(int depth = 0; depth < D; ++depth)
+ for (int depth = 0; depth < D; ++depth)
{
const int base_addr_a = depth * a_stride_z + w * a_stride_w;
const int base_addr_b = depth * b_stride_z + w * b_stride_w;
const int base_addr_c = depth * c_stride_z + w * c_stride_w;
- for(int row = 0; row < M; ++row)
+ for (int row = 0; row < M; ++row)
{
- for(int col = 0; col < N; ++col)
+ for (int col = 0; col < N; ++col)
{
float acc(0);
- for(int k = 0; k < K; ++k)
+ for (int k = 0; k < K; ++k)
{
acc += static_cast<float>(a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N]);
}
// Finalize the result: alpha * A * B + beta * C
- dst[base_addr_c + col + row * N] = static_cast<T>(alpha * acc + beta * c[base_addr_c + col + row * N]);
+ dst[base_addr_c + col + row * N] =
+ static_cast<T>(alpha * acc + beta * c[base_addr_c + col + row * N]);
}
}
}
@@ -163,8 +181,21 @@ SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTenso
return dst;
}
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
+void gemm_accumulate(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta, SimpleTensor<T> &dst)
+{
+ // Compute reference
+ SimpleTensor<T> dst_gemm = gemm(a, b, c, alpha, beta);
+ reference::arithmetic_operation<T>(reference::ArithmeticOperation::ADD, dst, dst_gemm, dst, ConvertPolicy::SATURATE);
+}
+
+template SimpleTensor<bfloat16> gemm(const SimpleTensor<bfloat16> &a, const SimpleTensor<bfloat16> &b, const SimpleTensor<bfloat16> &c, float alpha, float beta);
template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
+
+template void gemm_accumulate(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta, SimpleTensor<float> &dst);
+template void gemm_accumulate(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta, SimpleTensor<half> &dst);
+
template SimpleTensor<half> gemm_mixed_precision(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
} // namespace reference
} // namespace validation
diff --git a/tests/validation/reference/GEMM.h b/tests/validation/reference/GEMM.h
index 5feaeda58..1b9757012 100644
--- a/tests/validation/reference/GEMM.h
+++ b/tests/validation/reference/GEMM.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2019, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_GEMM_H
-#define ARM_COMPUTE_TEST_GEMM_H
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_GEMM_H
+#define ACL_TESTS_VALIDATION_REFERENCE_GEMM_H
#include "tests/SimpleTensor.h"
#include "tests/validation/Helpers.h"
@@ -41,8 +41,11 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta);
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
+void gemm_accumulate(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta, SimpleTensor<T> &dst);
+
} // namespace reference
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_GEMM_H */
+#endif // ACL_TESTS_VALIDATION_REFERENCE_GEMM_H
diff --git a/tests/validation/reference/GEMMLowp.cpp b/tests/validation/reference/GEMMLowp.cpp
index 1615b51e7..30c577d85 100644
--- a/tests/validation/reference/GEMMLowp.cpp
+++ b/tests/validation/reference/GEMMLowp.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#include "GEMMLowp.h"
#include "arm_compute/core/Types.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
#include "tests/validation/reference/UtilsQuantizedAsymm.h"
#include "support/ToolchainSupport.h"
@@ -230,6 +231,13 @@ SimpleTensor<T_out> gemmlowp_matrix_multiply_core(const SimpleTensor<T_in> &a, c
return c;
}
+template <typename T_out, typename T_in, typename T_in_1>
+void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<T_in> &a, const SimpleTensor<T_in_1> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<T_out> &dst)
+{
+ SimpleTensor<T_out> dst_gemm = gemmlowp_matrix_multiply_core<T_out, T_in, T_in_1>(a, b, shape_c, a_offset, b_offset);
+ reference::arithmetic_operation<T_out>(reference::ArithmeticOperation::ADD, dst, dst_gemm, dst, ConvertPolicy::SATURATE);
+}
+
// used to validate assembly kernels which don't know anything about offsets
template <typename T1, typename T2, typename T3>
SimpleTensor<T1> gemmlowp(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c)
@@ -336,6 +344,8 @@ template SimpleTensor<int8_t> gemmlowp_quantize_down_scale(const SimpleTensor<in
std::vector<int32_t> result_shift, int32_t min, int32_t max);
template SimpleTensor<int32_t> gemmlowp_matrix_multiply_core(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
template SimpleTensor<int32_t> gemmlowp_matrix_multiply_core(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
+template void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<int32_t> &dst);
+template void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<int32_t> &dst);
template SimpleTensor<int32_t> gemmlowp<int32_t, int8_t, int8_t>(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c);
template SimpleTensor<int32_t> gemmlowp<int32_t, uint8_t, uint8_t>(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c);
template SimpleTensor<int32_t> gemmlowp<int32_t, uint8_t, int8_t>(const SimpleTensor<uint8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c);
diff --git a/tests/validation/reference/GEMMLowp.h b/tests/validation/reference/GEMMLowp.h
index 99015d71f..6e471fdad 100644
--- a/tests/validation/reference/GEMMLowp.h
+++ b/tests/validation/reference/GEMMLowp.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_GEMMLOWP_H
-#define ARM_COMPUTE_TEST_GEMMLOWP_H
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_GEMMLOWP_H
+#define ACL_TESTS_VALIDATION_REFERENCE_GEMMLOWP_H
#include "tests/SimpleTensor.h"
#include "tests/validation/Helpers.h"
@@ -38,6 +38,9 @@ namespace reference
template <typename T1, typename T2, typename T3>
SimpleTensor<T1> gemmlowp_matrix_multiply_core(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
+template <typename T1, typename T2, typename T3>
+void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<T1> &dst_);
+
template <typename T1, typename T2, typename T3 = T2>
SimpleTensor<T1> gemmlowp(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c);
@@ -71,4 +74,4 @@ SimpleTensor<TOut> gemmlowp_quantize_down_scale_by_float(const SimpleTensor<TIn>
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_GEMMLOWP_H */
+#endif // ACL_TESTS_VALIDATION_REFERENCE_GEMMLOWP_H
diff --git a/tests/validation/reference/Permute.cpp b/tests/validation/reference/Permute.cpp
index 6f122b1bf..7aa3011d8 100644
--- a/tests/validation/reference/Permute.cpp
+++ b/tests/validation/reference/Permute.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2019,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#include "Permute.h"
#include "arm_compute/core/Types.h"
+
#include "tests/validation/Helpers.h"
namespace arm_compute
@@ -42,11 +43,11 @@ SimpleTensor<T> permute(const SimpleTensor<T> &src, PermutationVector perm)
permute(dst_shape, perm);
// Create reference
- SimpleTensor<T> dst{ dst_shape, src.data_type(), src.num_channels(), src.quantization_info() };
+ SimpleTensor<T> dst{dst_shape, src.data_type(), src.num_channels(), src.quantization_info()};
// Compute reference
const uint32_t num_elements = src.num_elements();
- for(uint32_t i = 0; i < num_elements; ++i)
+ for (uint32_t i = 0; i < num_elements; ++i)
{
const Coordinates src_coords = index2coord(src.shape(), i);
Coordinates dst_coords = src_coords;
@@ -58,13 +59,14 @@ SimpleTensor<T> permute(const SimpleTensor<T> &src, PermutationVector perm)
return dst;
}
-template SimpleTensor<int8_t> permute(const SimpleTensor<int8_t> &src, PermutationVector perm);
-template SimpleTensor<uint8_t> permute(const SimpleTensor<uint8_t> &src, PermutationVector perm);
-template SimpleTensor<int16_t> permute(const SimpleTensor<int16_t> &src, PermutationVector perm);
+template SimpleTensor<int8_t> permute(const SimpleTensor<int8_t> &src, PermutationVector perm);
+template SimpleTensor<uint8_t> permute(const SimpleTensor<uint8_t> &src, PermutationVector perm);
+template SimpleTensor<int16_t> permute(const SimpleTensor<int16_t> &src, PermutationVector perm);
template SimpleTensor<uint16_t> permute(const SimpleTensor<uint16_t> &src, PermutationVector perm);
template SimpleTensor<uint32_t> permute(const SimpleTensor<uint32_t> &src, PermutationVector perm);
-template SimpleTensor<float> permute(const SimpleTensor<float> &src, PermutationVector perm);
-template SimpleTensor<half> permute(const SimpleTensor<half> &src, PermutationVector perm);
+template SimpleTensor<float> permute(const SimpleTensor<float> &src, PermutationVector perm);
+template SimpleTensor<half> permute(const SimpleTensor<half> &src, PermutationVector perm);
+template SimpleTensor<bfloat16> permute(const SimpleTensor<bfloat16> &src, PermutationVector perm);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/QuantizationLayer.cpp b/tests/validation/reference/QuantizationLayer.cpp
index 27665375c..b76263bf9 100644
--- a/tests/validation/reference/QuantizationLayer.cpp
+++ b/tests/validation/reference/QuantizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -80,6 +80,15 @@ SimpleTensor<Tout> quantization_layer(const SimpleTensor<Tin> &src, DataType out
dst[i] = quantize_qasymm16((src[i]), qinfo, rounding_policy);
}
break;
+ case DataType::F32:
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
+ for(int i = 0; i < src.num_elements(); ++i)
+ {
+ dst[i] = dequantize_s32((src[i]), qinfo);
+ }
+ break;
default:
ARM_COMPUTE_ERROR("Unsupported output data type");
}
@@ -127,6 +136,7 @@ template SimpleTensor<uint8_t> quantization_layer(const SimpleTensor<half> &src,
template SimpleTensor<uint8_t> quantization_layer(const SimpleTensor<float> &src, DataType output_data_type, const QuantizationInfo &quantization_info);
template SimpleTensor<uint16_t> quantization_layer(const SimpleTensor<half> &src, DataType output_data_type, const QuantizationInfo &quantization_info);
template SimpleTensor<uint16_t> quantization_layer(const SimpleTensor<float> &src, DataType output_data_type, const QuantizationInfo &quantization_info);
+template SimpleTensor<float> quantization_layer(const SimpleTensor<int32_t> &src, DataType output_data_type, const QuantizationInfo &quantization_info);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/ReshapeLayer.cpp b/tests/validation/reference/ReshapeLayer.cpp
index daea001be..30a58dd65 100644
--- a/tests/validation/reference/ReshapeLayer.cpp
+++ b/tests/validation/reference/ReshapeLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 Arm Limited.
+ * Copyright (c) 2017,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,14 +44,15 @@ SimpleTensor<T> reshape_layer(const SimpleTensor<T> &src, const TensorShape &out
return dst;
}
-template SimpleTensor<uint8_t> reshape_layer(const SimpleTensor<uint8_t> &src, const TensorShape &output_shape);
-template SimpleTensor<int8_t> reshape_layer(const SimpleTensor<int8_t> &src, const TensorShape &output_shape);
+template SimpleTensor<uint8_t> reshape_layer(const SimpleTensor<uint8_t> &src, const TensorShape &output_shape);
+template SimpleTensor<int8_t> reshape_layer(const SimpleTensor<int8_t> &src, const TensorShape &output_shape);
template SimpleTensor<uint16_t> reshape_layer(const SimpleTensor<uint16_t> &src, const TensorShape &output_shape);
-template SimpleTensor<int16_t> reshape_layer(const SimpleTensor<int16_t> &src, const TensorShape &output_shape);
+template SimpleTensor<int16_t> reshape_layer(const SimpleTensor<int16_t> &src, const TensorShape &output_shape);
template SimpleTensor<uint32_t> reshape_layer(const SimpleTensor<uint32_t> &src, const TensorShape &output_shape);
-template SimpleTensor<int32_t> reshape_layer(const SimpleTensor<int32_t> &src, const TensorShape &output_shape);
-template SimpleTensor<half> reshape_layer(const SimpleTensor<half> &src, const TensorShape &output_shape);
-template SimpleTensor<float> reshape_layer(const SimpleTensor<float> &src, const TensorShape &output_shape);
+template SimpleTensor<int32_t> reshape_layer(const SimpleTensor<int32_t> &src, const TensorShape &output_shape);
+template SimpleTensor<half> reshape_layer(const SimpleTensor<half> &src, const TensorShape &output_shape);
+template SimpleTensor<float> reshape_layer(const SimpleTensor<float> &src, const TensorShape &output_shape);
+template SimpleTensor<bfloat16> reshape_layer(const SimpleTensor<bfloat16> &src, const TensorShape &output_shape);
/** [ReshapeLayer] **/
} // namespace reference
} // namespace validation
diff --git a/tests/validation/reference/ScatterLayer.cpp b/tests/validation/reference/ScatterLayer.cpp
new file mode 100644
index 000000000..920f2b999
--- /dev/null
+++ b/tests/validation/reference/ScatterLayer.cpp
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "ScatterLayer.h"
+#include "tests/validation/Helpers.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace reference
+{
+namespace
+{
+
+template <typename T>
+T reduce_op(const T &current,const T &update,const ScatterFunction func)
+{
+ switch(func)
+ {
+ case ScatterFunction::Update:
+ return update;
+ break;
+ case ScatterFunction::Add:
+ return current + update;
+ break;
+ case ScatterFunction::Sub:
+ return current - update;
+ break;
+ case ScatterFunction::Max:
+ return std::max(current, update);
+ break;
+ case ScatterFunction::Min:
+ return std::min(current, update);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported Scatter function");
+ break;
+ }
+}
+
+template float reduce_op(const float &current,const float &update,const ScatterFunction func);
+}
+
+// Note : This function currently only supports 1D src, 1D updates, 2D indices, 1D output tensors.
+template <typename T>
+SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
+{
+ SimpleTensor<T> dst{ out_shape, src.data_type(), 1 };
+
+ // 1. If zero initialization variable is true, fill dst with 0 values. Else copy src data to dst.
+ if(info.zero_initialization)
+ {
+ for (int i = 0; i < src.num_elements(); ++i)
+ {
+ dst[i] = static_cast<T>(0);
+ }
+ }
+ else
+ {
+ std::copy_n(src.data(), src.num_elements(), dst.data());
+ }
+
+ // 2. Get max index of output tensor, then iterate over index tensor.
+ const auto x_bound = dst.shape().x();
+
+
+ for(int i = 0; i < indices.num_elements(); ++i)
+ {
+ // 3. Check whether index is out of bounds for dst, if not then apply reduce op.
+ const auto index = indices[i];
+ if (index < x_bound) // Note : index is always >= 0 as datatype is unsigned.
+ {
+ dst[index] = reduce_op(dst[index], updates[i], info.func);
+ }
+ }
+ return dst;
+}
+
+template <typename T>
+SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
+{
+ return scatter_layer_internal<T>(src, updates, indices, out_shape, info);
+}
+
+template SimpleTensor<float> scatter_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
+
+} // namespace reference
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/GpuKernelArgument.cpp b/tests/validation/reference/ScatterLayer.h
index 9cecfc2ff..dc441a889 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuKernelArgument.cpp
+++ b/tests/validation/reference/ScatterLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022 Arm Limited.
+ * Copyright (c) 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,17 +21,28 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#include "src/dynamic_fusion/sketch/gpu/GpuKernelArgument.h"
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_SCATTERLAYER_H
+#define ACL_TESTS_VALIDATION_REFERENCE_SCATTERLAYER_H
+
+#include "Utils.h"
+#include "arm_compute/function_info/ScatterInfo.h"
+#include "tests/SimpleTensor.h"
+
namespace arm_compute
{
-namespace experimental
+namespace test
{
-namespace dynamic_fusion
+namespace validation
{
-bool operator==(const GpuKernelArgumentInfo &info0, const GpuKernelArgumentInfo &info1)
+namespace reference
{
- return info0.type == info1.type;
-}
-} // namespace dynamic_fusion
-} // namespace experimental
+template <typename T>
+SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<uint32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
+
+template <typename T>
+SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<uint32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
+} // namespace reference
+} // namespace validation
+} // namespace test
} // namespace arm_compute
+#endif // ACL_TESTS_VALIDATION_REFERENCE_SCATTERLAYER_H
diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp
index ca8e14abb..7e618c9de 100644
--- a/utils/GraphUtils.cpp
+++ b/utils/GraphUtils.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -711,13 +711,13 @@ bool RandomAccessor::access_tensor(ITensor &tensor)
case DataType::QASYMM8:
case DataType::U8:
{
- std::uniform_int_distribution<uint8_t> distribution_u8(_lower.get<uint8_t>(), _upper.get<uint8_t>());
+ std::uniform_int_distribution<uint32_t> distribution_u8(_lower.get<uint8_t>(), _upper.get<uint8_t>());
fill<uint8_t>(tensor, distribution_u8);
break;
}
case DataType::S8:
{
- std::uniform_int_distribution<int8_t> distribution_s8(_lower.get<int8_t>(), _upper.get<int8_t>());
+ std::uniform_int_distribution<int32_t> distribution_s8(_lower.get<int8_t>(), _upper.get<int8_t>());
fill<int8_t>(tensor, distribution_s8);
break;
}
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 23e28d68a..2d106d849 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -49,6 +49,7 @@
#include "arm_compute/function_info/FullyConnectedLayerInfo.h"
#include "arm_compute/function_info/GEMMInfo.h"
#include "arm_compute/function_info/MatMulInfo.h"
+#include "arm_compute/function_info/ScatterInfo.h"
#include "arm_compute/runtime/CL/CLTunerTypes.h"
#include "arm_compute/runtime/CL/CLTypes.h"
#include "arm_compute/runtime/common/LSTMParams.h"
@@ -3601,7 +3602,7 @@ inline ::std::ostream &operator<<(::std::ostream &os, const arm_compute::CpuMatM
{
os << "CpuMatMulSettings="
<< "["
- << "fast_math=" << settings.fast_math() << "]";
+ << "fast_math=" << settings.fast_math() << ",fixed_format=" << settings.fixed_format() << "]";
return os;
}
@@ -3618,6 +3619,77 @@ inline std::string to_string(const arm_compute::CpuMatMulSettings &settings)
return str.str();
}
+/** Formatted output of the scatter function type.
+ *
+ * @param[out] os Output stream.
+ * @param[in] function arm_compute::ScatterFunction type to output.
+ *
+ * @return Modified output stream.
+ */
+inline ::std::ostream &operator<<(::std::ostream &os, const ScatterFunction &function)
+{
+ switch (function)
+ {
+ case ScatterFunction::Update:
+ os << "UPDATE";
+ break;
+ case ScatterFunction::Add:
+ os << "ADD";
+ break;
+ case ScatterFunction::Sub:
+ os << "SUB";
+ break;
+ case ScatterFunction::Max:
+ os << "MAX";
+ break;
+ case ScatterFunction::Min:
+ os << "MIN";
+ break;
+ default:
+ ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
+ }
+ return os;
+}
+/** Formatted output of the arm_compute::ScatterFunction type.
+ *
+ * @param[in] func arm_compute::ScatterFunction type to output.
+ *
+ * @return Formatted string.
+ */
+inline std::string to_string(const arm_compute::ScatterFunction &func)
+{
+ std::stringstream str;
+ str << func;
+ return str.str();
+}
+/** Formatted output of the arm_compute::ScatterInfo type.
+ *
+ * @param[out] os Output stream.
+ * @param[in] info arm_compute::ScatterInfo type to output.
+ *
+ * @return Modified output stream.
+ */
+inline ::std::ostream &operator<<(::std::ostream &os, const arm_compute::ScatterInfo &info)
+{
+ os << "ScatterInfo="
+ << "["
+ << "Function=" << info.func << ", "
+ << "InitialiseZero=" << info.zero_initialization << "] ";
+ return os;
+}
+/** Formatted output of the arm_compute::ScatterInfo type.
+ *
+ * @param[in] info arm_compute::ScatterInfo type to output.
+ *
+ * @return Formatted string.
+ */
+inline std::string to_string(const arm_compute::ScatterInfo &info)
+{
+ std::stringstream str;
+ str << info;
+ return str.str();
+}
+
} // namespace arm_compute
#endif // ACL_UTILS_TYPEPRINTER_H