diff options
Diffstat (limited to 'src/runtime/NEON/functions/NELSTMLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NELSTMLayer.cpp | 102 |
1 files changed, 51 insertions, 51 deletions
diff --git a/src/runtime/NEON/functions/NELSTMLayer.cpp b/src/runtime/NEON/functions/NELSTMLayer.cpp index f9d445fe7..dca274acd 100644 --- a/src/runtime/NEON/functions/NELSTMLayer.cpp +++ b/src/runtime/NEON/functions/NELSTMLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 ARM Limited. + * Copyright (c) 2018-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -347,7 +347,7 @@ void NELSTMLayer::configure(const ITensor *input, _copy_output.configure(output_state_out, output); // Vector for holding the tensors to store in scratch buffer - std::vector<ITensor *> scratch_inputs; + std::vector<const ITensor *> scratch_inputs; if(!lstm_params.has_cifg_opt()) { scratch_inputs.emplace_back(input_gate_out); @@ -464,17 +464,17 @@ Status NELSTMLayer::validate(const ITensorInfo *input, if(lstm_params.has_peephole_opt()) { - ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); + ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE)); } if(lstm_params.use_layer_norm()) { ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&forget_gate)); - ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, - RoundingPolicy::TO_ZERO)); + ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, + RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); // Validate input gate if(!lstm_params.has_cifg_opt()) @@ -498,21 +498,21 @@ Status NELSTMLayer::validate(const ITensorInfo *input, { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights()); ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1); - ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); + ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE)); } if(lstm_params.use_layer_norm()) { ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&input_gate)); - ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); + ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); } else { - ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtractionKernel::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE)); + ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtraction::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE)); } // Validate cell state @@ -522,18 +522,18 @@ Status NELSTMLayer::validate(const ITensorInfo *input, if(lstm_params.use_layer_norm()) { ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&cell_state_tmp)); - ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE, - RoundingPolicy::TO_ZERO)); + ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE, + RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&cell_state_tmp, nullptr, activation_info)); - ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); - ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, activation_info)); + ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); + ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE)); if(cell_threshold != 0.f) { - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, - cell_threshold))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, + cell_threshold))); } // Validate output gate tmp @@ -548,29 +548,29 @@ Status NELSTMLayer::validate(const ITensorInfo *input, if(lstm_params.has_peephole_opt()) { - ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE, - RoundingPolicy::TO_ZERO)); + ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE, + RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE)); } if(lstm_params.use_layer_norm()) { ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&output_gate_tmp)); - ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE, - RoundingPolicy::TO_ZERO)); + ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE, + RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); // Validate output state - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&cell_state_tmp, &cell_state_tmp, activation_info)); - ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info)); + ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); if(lstm_params.has_projection()) { ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out)); if(projection_threshold != 0.f) { - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(output_state_out, output_state_out, - ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output_state_out, output_state_out, + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold))); } } @@ -579,7 +579,7 @@ Status NELSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(NECopyKernel::validate(output_state_out, output)); // Validate scratch concatenation - std::vector<ITensorInfo *> inputs_vector_info_raw; + std::vector<const ITensorInfo *> inputs_vector_info_raw; if(!lstm_params.has_cifg_opt()) { inputs_vector_info_raw.push_back(&input_gate); @@ -603,16 +603,16 @@ void NELSTMLayer::run() if(_run_peephole_opt) { - NEScheduler::get().schedule(&_pixelwise_mul_forget_gate, Window::DimY); + _pixelwise_mul_forget_gate.run(); _accum_forget_gate1.run(); } if(_is_layer_norm_lstm) { _mean_std_norm_forget_gate.run(); - NEScheduler::get().schedule(&_pixelwise_mul_forget_gate_coeff, Window::DimY); - NEScheduler::get().schedule(&_accum_forget_gate_bias, Window::DimY); + _pixelwise_mul_forget_gate_coeff.run(); + _accum_forget_gate_bias.run(); } - NEScheduler::get().schedule(&_activation_forget_gate, Window::DimY); + _activation_forget_gate.run(); if(_run_cifg_opt) { @@ -624,7 +624,7 @@ void NELSTMLayer::run() { std::fill_n(reinterpret_cast<float *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 1); } - NEScheduler::get().schedule(&_subtract_input_gate, Window::DimY); + _subtract_input_gate.run(); } else { @@ -632,62 +632,62 @@ void NELSTMLayer::run() if(_run_peephole_opt) { - NEScheduler::get().schedule(&_pixelwise_mul_input_gate, Window::DimY); + _pixelwise_mul_input_gate.run(); _accum_input_gate1.run(); } if(_is_layer_norm_lstm) { _mean_std_norm_input_gate.run(); - NEScheduler::get().schedule(&_pixelwise_mul_input_gate_coeff, Window::DimY); - NEScheduler::get().schedule(&_accum_input_gate_bias, Window::DimY); + _pixelwise_mul_input_gate_coeff.run(); + _accum_input_gate_bias.run(); } - NEScheduler::get().schedule(&_activation_input_gate, Window::DimY); + _activation_input_gate.run(); } _fully_connected_cell_state.run(); NEScheduler::get().schedule(&_transpose_cell_state, Window::DimY); _gemm_cell_state1.run(); - NEScheduler::get().schedule(&_accum_cell_state1, Window::DimY); + _accum_cell_state1.run(); if(_is_layer_norm_lstm) { _mean_std_norm_cell_gate.run(); - NEScheduler::get().schedule(&_pixelwise_mul_cell_gate_coeff, Window::DimY); - NEScheduler::get().schedule(&_accum_cell_gate_bias, Window::DimY); + _pixelwise_mul_cell_gate_coeff.run(); + _accum_cell_gate_bias.run(); } - NEScheduler::get().schedule(&_activation_cell_state, Window::DimY); - NEScheduler::get().schedule(&_pixelwise_mul_cell_state1, Window::DimY); - NEScheduler::get().schedule(&_pixelwise_mul_cell_state2, Window::DimY); - NEScheduler::get().schedule(&_accum_cell_state2, Window::DimY); + _activation_cell_state.run(); + _pixelwise_mul_cell_state1.run(); + _pixelwise_mul_cell_state2.run(); + _accum_cell_state2.run(); if(_perform_cell_clipping) { - NEScheduler::get().schedule(&_cell_clip, Window::DimY); + _cell_clip.run(); } _fully_connected_output.run(); if(_run_peephole_opt) { - NEScheduler::get().schedule(&_pixelwise_mul_output_state1, Window::DimY); + _pixelwise_mul_output_state1.run(); _accum_output1.run(); } if(_is_layer_norm_lstm) { _mean_std_norm_output_gate.run(); - NEScheduler::get().schedule(&_pixelwise_mul_output_gate_coeff, Window::DimY); - NEScheduler::get().schedule(&_accum_output_gate_bias, Window::DimY); + _pixelwise_mul_output_gate_coeff.run(); + _accum_output_gate_bias.run(); } - NEScheduler::get().schedule(&_activation_output, Window::DimY); + _activation_output.run(); - NEScheduler::get().schedule(&_activation_output_state, Window::DimY); - NEScheduler::get().schedule(&_pixelwise_mul_output_state2, Window::DimY); + _activation_output_state.run(); + _pixelwise_mul_output_state2.run(); if(_has_projection_weights) { _fully_connected_output_state.run(); if(_perform_projection_clipping) { - NEScheduler::get().schedule(&_projection_clip, Window::DimY); + _projection_clip.run(); } } |