aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NELSTMLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NELSTMLayer.cpp')
-rw-r--r--src/runtime/NEON/functions/NELSTMLayer.cpp102
1 files changed, 51 insertions, 51 deletions
diff --git a/src/runtime/NEON/functions/NELSTMLayer.cpp b/src/runtime/NEON/functions/NELSTMLayer.cpp
index f9d445fe7..dca274acd 100644
--- a/src/runtime/NEON/functions/NELSTMLayer.cpp
+++ b/src/runtime/NEON/functions/NELSTMLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 ARM Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -347,7 +347,7 @@ void NELSTMLayer::configure(const ITensor *input,
_copy_output.configure(output_state_out, output);
// Vector for holding the tensors to store in scratch buffer
- std::vector<ITensor *> scratch_inputs;
+ std::vector<const ITensor *> scratch_inputs;
if(!lstm_params.has_cifg_opt())
{
scratch_inputs.emplace_back(input_gate_out);
@@ -464,17 +464,17 @@ Status NELSTMLayer::validate(const ITensorInfo *input,
if(lstm_params.has_peephole_opt())
{
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
}
if(lstm_params.use_layer_norm())
{
ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&forget_gate));
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
- RoundingPolicy::TO_ZERO));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
+ RoundingPolicy::TO_ZERO));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
}
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
// Validate input gate
if(!lstm_params.has_cifg_opt())
@@ -498,21 +498,21 @@ Status NELSTMLayer::validate(const ITensorInfo *input,
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
}
if(lstm_params.use_layer_norm())
{
ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&input_gate));
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE));
}
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
}
else
{
- ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtractionKernel::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtraction::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
}
// Validate cell state
@@ -522,18 +522,18 @@ Status NELSTMLayer::validate(const ITensorInfo *input,
if(lstm_params.use_layer_norm())
{
ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
- RoundingPolicy::TO_ZERO));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
+ RoundingPolicy::TO_ZERO));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
}
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&cell_state_tmp, nullptr, activation_info));
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, activation_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
if(cell_threshold != 0.f)
{
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
- cell_threshold)));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
+ cell_threshold)));
}
// Validate output gate tmp
@@ -548,29 +548,29 @@ Status NELSTMLayer::validate(const ITensorInfo *input,
if(lstm_params.has_peephole_opt())
{
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
- RoundingPolicy::TO_ZERO));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
+ RoundingPolicy::TO_ZERO));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
}
if(lstm_params.use_layer_norm())
{
ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
- RoundingPolicy::TO_ZERO));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
+ RoundingPolicy::TO_ZERO));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE));
}
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
// Validate output state
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
if(lstm_params.has_projection())
{
ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
if(projection_threshold != 0.f)
{
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(output_state_out, output_state_out,
- ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output_state_out, output_state_out,
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)));
}
}
@@ -579,7 +579,7 @@ Status NELSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(NECopyKernel::validate(output_state_out, output));
// Validate scratch concatenation
- std::vector<ITensorInfo *> inputs_vector_info_raw;
+ std::vector<const ITensorInfo *> inputs_vector_info_raw;
if(!lstm_params.has_cifg_opt())
{
inputs_vector_info_raw.push_back(&input_gate);
@@ -603,16 +603,16 @@ void NELSTMLayer::run()
if(_run_peephole_opt)
{
- NEScheduler::get().schedule(&_pixelwise_mul_forget_gate, Window::DimY);
+ _pixelwise_mul_forget_gate.run();
_accum_forget_gate1.run();
}
if(_is_layer_norm_lstm)
{
_mean_std_norm_forget_gate.run();
- NEScheduler::get().schedule(&_pixelwise_mul_forget_gate_coeff, Window::DimY);
- NEScheduler::get().schedule(&_accum_forget_gate_bias, Window::DimY);
+ _pixelwise_mul_forget_gate_coeff.run();
+ _accum_forget_gate_bias.run();
}
- NEScheduler::get().schedule(&_activation_forget_gate, Window::DimY);
+ _activation_forget_gate.run();
if(_run_cifg_opt)
{
@@ -624,7 +624,7 @@ void NELSTMLayer::run()
{
std::fill_n(reinterpret_cast<float *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 1);
}
- NEScheduler::get().schedule(&_subtract_input_gate, Window::DimY);
+ _subtract_input_gate.run();
}
else
{
@@ -632,62 +632,62 @@ void NELSTMLayer::run()
if(_run_peephole_opt)
{
- NEScheduler::get().schedule(&_pixelwise_mul_input_gate, Window::DimY);
+ _pixelwise_mul_input_gate.run();
_accum_input_gate1.run();
}
if(_is_layer_norm_lstm)
{
_mean_std_norm_input_gate.run();
- NEScheduler::get().schedule(&_pixelwise_mul_input_gate_coeff, Window::DimY);
- NEScheduler::get().schedule(&_accum_input_gate_bias, Window::DimY);
+ _pixelwise_mul_input_gate_coeff.run();
+ _accum_input_gate_bias.run();
}
- NEScheduler::get().schedule(&_activation_input_gate, Window::DimY);
+ _activation_input_gate.run();
}
_fully_connected_cell_state.run();
NEScheduler::get().schedule(&_transpose_cell_state, Window::DimY);
_gemm_cell_state1.run();
- NEScheduler::get().schedule(&_accum_cell_state1, Window::DimY);
+ _accum_cell_state1.run();
if(_is_layer_norm_lstm)
{
_mean_std_norm_cell_gate.run();
- NEScheduler::get().schedule(&_pixelwise_mul_cell_gate_coeff, Window::DimY);
- NEScheduler::get().schedule(&_accum_cell_gate_bias, Window::DimY);
+ _pixelwise_mul_cell_gate_coeff.run();
+ _accum_cell_gate_bias.run();
}
- NEScheduler::get().schedule(&_activation_cell_state, Window::DimY);
- NEScheduler::get().schedule(&_pixelwise_mul_cell_state1, Window::DimY);
- NEScheduler::get().schedule(&_pixelwise_mul_cell_state2, Window::DimY);
- NEScheduler::get().schedule(&_accum_cell_state2, Window::DimY);
+ _activation_cell_state.run();
+ _pixelwise_mul_cell_state1.run();
+ _pixelwise_mul_cell_state2.run();
+ _accum_cell_state2.run();
if(_perform_cell_clipping)
{
- NEScheduler::get().schedule(&_cell_clip, Window::DimY);
+ _cell_clip.run();
}
_fully_connected_output.run();
if(_run_peephole_opt)
{
- NEScheduler::get().schedule(&_pixelwise_mul_output_state1, Window::DimY);
+ _pixelwise_mul_output_state1.run();
_accum_output1.run();
}
if(_is_layer_norm_lstm)
{
_mean_std_norm_output_gate.run();
- NEScheduler::get().schedule(&_pixelwise_mul_output_gate_coeff, Window::DimY);
- NEScheduler::get().schedule(&_accum_output_gate_bias, Window::DimY);
+ _pixelwise_mul_output_gate_coeff.run();
+ _accum_output_gate_bias.run();
}
- NEScheduler::get().schedule(&_activation_output, Window::DimY);
+ _activation_output.run();
- NEScheduler::get().schedule(&_activation_output_state, Window::DimY);
- NEScheduler::get().schedule(&_pixelwise_mul_output_state2, Window::DimY);
+ _activation_output_state.run();
+ _pixelwise_mul_output_state2.run();
if(_has_projection_weights)
{
_fully_connected_output_state.run();
if(_perform_projection_clipping)
{
- NEScheduler::get().schedule(&_projection_clip, Window::DimY);
+ _projection_clip.run();
}
}