diff options
18 files changed, 583 insertions, 9 deletions
diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 505cbdbcd12..1a3ec6b359e 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -173,6 +173,7 @@ typedef enum { kTfLiteBuiltinReadVariable = 143, kTfLiteBuiltinAssignVariable = 144, kTfLiteBuiltinBroadcastArgs = 145, + kTfLiteBuiltinGelu = 150, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h index ed5ac004cbd..8e49c91ad5f 100644 --- a/tensorflow/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -502,6 +502,10 @@ typedef struct { const char* shared_name; } TfLiteVarHandleParams; +typedef struct { + bool approximate; +} TfLiteGeluParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index da714794a12..2897728123e 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -795,6 +795,21 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } + case BuiltinOperator_GELU: { + auto params = safe_allocator.Allocate<TfLiteGeluParams>(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* gelu_params = op->builtin_options_as_GeluOptions()) { + params->approximate = gelu_params->approximate(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + // Unsupported builtins. + case BuiltinOperator_RANDOM_STANDARD_NORMAL: + case BuiltinOperator_BUCKETIZE: + case BuiltinOperator_RANDOM_UNIFORM: + case BuiltinOperator_MULTINOMIAL: + return kTfLiteError; // Below are the ops with no builtin_data structure. // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are // ok for now, since there is no call implementation either. diff --git a/tensorflow/lite/core/shims/builtin_ops_list.inc b/tensorflow/lite/core/shims/builtin_ops_list.inc index b96e60afa6e..bb5110a3fc3 100644 --- a/tensorflow/lite/core/shims/builtin_ops_list.inc +++ b/tensorflow/lite/core/shims/builtin_ops_list.inc @@ -158,3 +158,4 @@ TFLITE_OP(Register_VAR_HANDLE) TFLITE_OP(Register_READ_VARIABLE) TFLITE_OP(Register_ASSIGN_VARIABLE) TFLITE_OP(Register_BROADCAST_ARGS) +TFLITE_OP(Register_GELU) diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc index b13a2dac992..f4ded17d8f6 100644 --- a/tensorflow/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/binary_function.h" +#include "tensorflow/lite/kernels/internal/reference/gelu.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h" @@ -1501,6 +1502,53 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) { } } +TfLiteStatus GeluPrepare(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); + OpData* data = reinterpret_cast<OpData*>(node->user_data); + auto* params = reinterpret_cast<TfLiteGeluParams*>(node->builtin_data); + + if (input->type == kTfLiteInt8) { + PopulateLookupTable<int8_t>( + data, input, output, reference_ops::GeluTransform(params->approximate)); + } else if (input->type == kTfLiteUInt8) { + PopulateLookupTable<uint8_t>( + data, input, output, reference_ops::GeluTransform(params->approximate)); + } + return GenericPrepare(context, node); +} + +TfLiteStatus GeluEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast<TfLiteGeluParams*>(node->builtin_data); + const TfLiteTensor* input; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); + + switch (input->type) { + case kTfLiteFloat32: { + reference_ops::Gelu(GetTensorShape(input), GetTensorData<float>(input), + params->approximate, GetTensorShape(output), + GetTensorData<float>(output)); + return kTfLiteOk; + } + case kTfLiteInt8: + case kTfLiteUInt8: { + OpData* data = reinterpret_cast<OpData*>(node->user_data); + EvalUsingLookupTable(data, input, output); + return kTfLiteOk; + } + default: + TF_LITE_KERNEL_LOG( + context, "Only float32, int8 and uint8 supported currently, got %s.", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + } // namespace activations TfLiteRegistration* Register_ELU() { @@ -1661,6 +1709,13 @@ TfLiteRegistration* Register_HARD_SWISH_REF() { return &r; } +TfLiteRegistration* Register_GELU() { + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::GeluPrepare, + activations::GeluEval}; + return &r; +} + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc index f629775dd4a..3bf96fadff3 100644 --- a/tensorflow/lite/kernels/activations_test.cc +++ b/tensorflow/lite/kernels/activations_test.cc @@ -2564,6 +2564,171 @@ TEST(FloatActivationsOpTest, LeakyRelu) { })); } + +class GeluOpModel : public SingleOpModel { + public: + GeluOpModel(const TensorData& input, bool approximate) { + input_ = AddInput(input); + output_ = AddOutput(input); + SetBuiltinOp(BuiltinOperator_GELU, BuiltinOptions_GeluOptions, + CreateGeluOptions(builder_, approximate).Union()); + BuildInterpreter({GetShape(input_)}); + } + void SetInput(std::initializer_list<float> data) { + PopulateTensor(input_, data); + } + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + + protected: + int input_; + int output_; +}; + +class BaseGeluOpModel : public SingleOpModel { + public: + BaseGeluOpModel(const TensorData& input, bool approximate) { + input_ = AddInput(input); + approximate_ = approximate; + output_ = AddOutput({input.type, input.shape, input.min, input.max}); + SetBuiltinOp(BuiltinOperator_GELU, BuiltinOptions_GeluOptions, + CreateGeluOptions(builder_, approximate).Union()); + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + + bool approximate_; + int output_; +}; + +// The FloatGeluOpModel class handles float input and output. +class FloatGeluOpModel : public BaseGeluOpModel { + public: + using BaseGeluOpModel::BaseGeluOpModel; + + void SetInput(std::initializer_list<float> data) { + PopulateTensor(input_, data); + } + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } +}; + +// The QuantizedGeluOpModel class handles quantized input and output. +class QuantizedGeluOpModel : public BaseGeluOpModel { + public: + using BaseGeluOpModel::BaseGeluOpModel; + + template <typename T> + void SetInput(std::initializer_list<float> data) { + QuantizeAndPopulate<T>(input_, data); + } + template <typename T> + std::vector<T> GetOutput() { + return ExtractVector<T>(output_); + } + template <typename T> + std::vector<float> GetDequantizedOutput() { + return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_), + GetZeroPoint(output_)); + } +}; + +TEST(FloatActivationsOpTest, Gelu) { + FloatGeluOpModel m({TensorType_FLOAT32, {2, 3}}, /*approximate=*/false); + + m.SetInput({ + 0.0f, 1.0f, 3.0f, // Row 1 + 1.0f, -1.0f, -2.0f, // Row 2 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.0f, 0.841345f, 2.99595f, // Row 1 + 0.841345f, -0.158655f, -0.0455003f, // Row 2 + }))); +} + +TEST(FloatActivationsOpTest, GeluApproximate) { + FloatGeluOpModel m({TensorType_FLOAT32, {2, 3}}, /*approximate=*/true); + + m.SetInput({ + 0.0f, 1.0f, 3.0f, // Row 1 + 1.0f, -1.0f, -2.0f, // Row 2 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.0f, 0.841192f, 2.99636f, // Row 1 + 0.841192f, -0.158808f, -0.0454023f, // Row 2 + }))); +} + +TEST(QuantizedGeluOpTest, GeluInt8) { + const float kMin = -1; + const float kMax = 127.f / 128.f; + QuantizedGeluOpModel m({TensorType_INT8, {2, 3}, 3 * kMin, 3 * kMax}, + /*approximate=*/false); + m.SetInput<int8_t>({ + 0.0f, 1.0f, 3.0f, // Row 1 + 1.0f, -1.0f, -2.0f, // Row 2 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput<int8_t>(), + ElementsAreArray(ArrayFloatNear({ + 0.f, 0.84375f, 2.97656f, // Row 1 + 0.84375f, -0.164062f, -0.046875f // Row 2 + }))); +} + +TEST(QuantizedGeluOpTest, GeluInt8Approximate) { + const float kMin = -1; + const float kMax = 127.f / 128.f; + QuantizedGeluOpModel m({TensorType_INT8, {2, 3}, 3 * kMin, 3 * kMax}, + /*approximate=*/true); + m.SetInput<int8_t>({ + 0.0f, 1.0f, 3.0f, // Row 1 + 1.0f, -1.0f, -2.0f, // Row 2 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput<int8_t>(), + ElementsAreArray(ArrayFloatNear({ + 0.f, 0.84375f, 2.97656f, // Row 1 + 0.84375f, -0.164062f, -0.046875f // Row 2 + }))); +} +TEST(QuantizedGeluOpTest, GeluUInt8) { + const float kMin = -1; + const float kMax = 127.f / 128.f; + QuantizedGeluOpModel m({TensorType_UINT8, {2, 3}, 3 * kMin, 3 * kMax}, + /*approximate=*/false); + m.SetInput<uint8_t>({ + 0.0f, 1.0f, 3.0f, // Row 1 + 1.0f, -1.0f, -2.0f, // Row 2 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear({ + 0.f, 0.84375f, 2.97656f, // Row 1 + 0.84375f, -0.164062f, -0.046875f // Row 2 + }))); +} + +TEST(QuantizedGeluOpTest, GeluUInt8Approximate) { + const float kMin = -1; + const float kMax = 127.f / 128.f; + QuantizedGeluOpModel m({TensorType_UINT8, {2, 3}, 3 * kMin, 3 * kMax}, + /*approximate=*/true); + m.SetInput<uint8_t>({ + 0.0f, 1.0f, 3.0f, // Row 1 + 1.0f, -1.0f, -2.0f, // Row 2 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear({ + 0.f, 0.84375f, 2.97656f, // Row 1 + 0.84375f, -0.164062f, -0.046875f // Row 2 + }))); +} + + INSTANTIATE_TEST_SUITE_P( TanhOpTest, TanhOpTest, ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kTanhKernelMap))); diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h index 85cc9b92a0d..834045621d2 100644 --- a/tensorflow/lite/kernels/builtin_op_kernels.h +++ b/tensorflow/lite/kernels/builtin_op_kernels.h @@ -74,6 +74,7 @@ TfLiteRegistration* Register_FLOOR_MOD(); TfLiteRegistration* Register_FULLY_CONNECTED(); TfLiteRegistration* Register_GATHER(); TfLiteRegistration* Register_GATHER_ND(); +TfLiteRegistration* Register_GELU(); TfLiteRegistration* Register_GREATER(); TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_HARD_SWISH(); diff --git a/tensorflow/lite/kernels/internal/constants.h b/tensorflow/lite/kernels/internal/constants.h new file mode 100644 index 00000000000..aa8bd5f0860 --- /dev/null +++ b/tensorflow/lite/kernels/internal/constants.h @@ -0,0 +1,61 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_CONSTANTS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_CONSTANTS_H_ + +// Maths constants. +// The following macros are not always available on all platforms. +// E.g. MSVC requires additional compile flag to export those. +#ifndef M_E +#define M_E 2.7182818284590452354 /* e */ +#endif +#ifndef M_LOG2E +#define M_LOG2E 1.4426950408889634074 /* log_2 e */ +#endif +#ifndef M_LOG10E +#define M_LOG10E 0.43429448190325182765 /* log_10 e */ +#endif +#ifndef M_LN2 +#define M_LN2 0.69314718055994530942 /* log_e 2 */ +#endif +#ifndef M_LN10 +#define M_LN10 2.30258509299404568402 /* log_e 10 */ +#endif +#ifndef M_PI +#define M_PI 3.14159265358979323846 /* pi */ +#endif +#ifndef M_PI_2 +#define M_PI_2 1.57079632679489661923 /* pi/2 */ +#endif +#ifndef M_PI_4 +#define M_PI_4 0.78539816339744830962 /* pi/4 */ +#endif +#ifndef M_1_PI +#define M_1_PI 0.31830988618379067154 /* 1/pi */ +#endif +#ifndef M_2_PI +#define M_2_PI 0.63661977236758134308 /* 2/pi */ +#endif +#ifndef M_2_SQRTPI +#define M_2_SQRTPI 1.12837916709551257390 /* 2/sqrt(pi) */ +#endif +#ifndef M_SQRT2 +#define M_SQRT2 1.41421356237309504880 /* sqrt(2) */ +#endif +#ifndef M_SQRT1_2 +#define M_SQRT1_2 0.70710678118654752440 /* 1/sqrt(2) */ +#endif + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_CONSTANTS_H_
\ No newline at end of file diff --git a/tensorflow/lite/kernels/internal/reference/gelu.h b/tensorflow/lite/kernels/internal/reference/gelu.h new file mode 100644 index 00000000000..08e5a33241d --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/gelu.h @@ -0,0 +1,82 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_GELU_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_GELU_H_ + +#include <cmath> +#include <functional> + +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/constants.h" +#include "tensorflow/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +namespace gelu_internal { + +constexpr float kSqrt2dPi = M_2_SQRTPI * M_SQRT1_2; // sqrt( 2 / pi ) + +} // namespace gelu_internal + +// Plain implementation for GELU. Used for populating lookup table. +inline std::function<float(float)> GeluTransform(bool approximate) { + if (approximate) { + return [](float in) { + // 0.5 * x * ( 1 + tanh( sqrt( 2 / pi ) * ( x + 0.044715 * x^3 ) ) ) + return 0.5f * in * + (1.f + std::tanh(gelu_internal::kSqrt2dPi * + // Note: Avoid std::pow for integer exponents + // as it leads to much slower performance. + (in + 0.044715f * in * in * in))); + }; + } else { + return [](float in) { + // 0.5 * x * ( 1 + erf( x / sqrt( 2 ) ) ) + return 0.5f * in * (1.f + std::erf(in * M_SQRT1_2)); + }; + } +} + +template <typename T> +inline void Gelu(const RuntimeShape& input_shape, const T* input_data, + bool approximate, const RuntimeShape& output_shape, + T* output_data) { + auto matching_size = MatchingFlatSize(input_shape, output_shape); + + for (int i = 0; i < matching_size; i++) { + const T in = input_data[i]; + if (approximate) { + // 0.5 * x * ( 1 + tanh( sqrt( 2 / pi ) * ( x + 0.044715 * x^3 ) ) ) + output_data[i] = + static_cast<T>(0.5) * in * + (static_cast<T>(1) + + std::tanh(static_cast<T>(gelu_internal::kSqrt2dPi) * + // Note: Avoid std::pow for integer exponents + // as it leads to much slower performance. + (in + static_cast<T>(0.044715) * in * in * in))); + } else { + // 0.5 * x * ( 1 + erf( x / sqrt( 2 ) ) ) + output_data[i] = + static_cast<T>(0.5) * in * + (static_cast<T>(1) + std::erf(in * static_cast<T>(M_SQRT1_2))); + } + } +} + +} // namespace reference_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_GELU_H_
\ No newline at end of file diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 4f5fc7faf78..8e26f1d4849 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -329,6 +329,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_VAR_HANDLE, Register_VAR_HANDLE()); AddBuiltin(BuiltinOperator_READ_VARIABLE, Register_READ_VARIABLE()); AddBuiltin(BuiltinOperator_ASSIGN_VARIABLE, Register_ASSIGN_VARIABLE()); + AddBuiltin(BuiltinOperator_GELU, Register_GELU(), + /* min_version = */ 1, + /* max_version = */ 2); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index 889e003e404..d59eb9bdb0d 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -164,6 +164,7 @@ TfLiteRegistration* Register_REAL(); TfLiteRegistration* Register_COMPLEX_ABS(); TfLiteRegistration* Register_CONV_3D_TRANSPOSE_REF(); TfLiteRegistration* Register_BROADCAST_ARGS(); +TfLiteRegistration* Register_GELU(); namespace { @@ -478,6 +479,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddBuiltin(BuiltinOperator_CONV_3D_TRANSPOSE, Register_CONV_3D_TRANSPOSE_REF()); AddBuiltin(BuiltinOperator_BROADCAST_ARGS, Register_BROADCAST_ARGS()); + AddBuiltin(BuiltinOperator_GELU, Register_GELU(), + /* min_version = */ 1, + /* max_version = */ 2); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY_REF()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index abd8db0012d..d2262de438f 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -379,6 +379,7 @@ enum BuiltinOperator : int32 { READ_VARIABLE = 143, ASSIGN_VARIABLE = 144, BROADCAST_ARGS = 145, + GELU = 150, } // LINT.ThenChange(nnapi_linter/linter.proto) @@ -497,6 +498,7 @@ union BuiltinOptions { VarHandleOptions, ReadVariableOptions, AssignVariableOptions, + GeluOptions, } enum Padding : byte { SAME, VALID } @@ -1082,6 +1084,10 @@ table ReadVariableOptions { table AssignVariableOptions { } +table GeluOptions { + approximate: bool; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index 77253a4e667..6f4e93aae64 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -385,6 +385,9 @@ struct ReadVariableOptionsT; struct AssignVariableOptions; struct AssignVariableOptionsT; +struct GeluOptions; +struct GeluOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -854,11 +857,16 @@ enum BuiltinOperator { BuiltinOperator_READ_VARIABLE = 143, BuiltinOperator_ASSIGN_VARIABLE = 144, BuiltinOperator_BROADCAST_ARGS = 145, + BuiltinOperator_RANDOM_STANDARD_NORMAL = 146, + BuiltinOperator_BUCKETIZE = 147, + BuiltinOperator_RANDOM_UNIFORM = 148, + BuiltinOperator_MULTINOMIAL = 149, + BuiltinOperator_GELU = 150, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_BROADCAST_ARGS + BuiltinOperator_MAX = BuiltinOperator_GELU }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[146] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[151] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -1005,13 +1013,18 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[146] { BuiltinOperator_VAR_HANDLE, BuiltinOperator_READ_VARIABLE, BuiltinOperator_ASSIGN_VARIABLE, - BuiltinOperator_BROADCAST_ARGS + BuiltinOperator_BROADCAST_ARGS, + BuiltinOperator_RANDOM_STANDARD_NORMAL, + BuiltinOperator_BUCKETIZE, + BuiltinOperator_RANDOM_UNIFORM, + BuiltinOperator_MULTINOMIAL, + BuiltinOperator_GELU }; return values; } inline const char * const *EnumNamesBuiltinOperator() { - static const char * const names[147] = { + static const char * const names[152] = { "ADD", "AVERAGE_POOL_2D", "CONCATENATION", @@ -1158,13 +1171,18 @@ inline const char * const *EnumNamesBuiltinOperator() { "READ_VARIABLE", "ASSIGN_VARIABLE", "BROADCAST_ARGS", + "RANDOM_STANDARD_NORMAL", + "BUCKETIZE", + "RANDOM_UNIFORM", + "MULTINOMIAL", + "GELU", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BROADCAST_ARGS)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_GELU)) return ""; const size_t index = static_cast<size_t>(e); return EnumNamesBuiltinOperator()[index]; } @@ -1284,11 +1302,14 @@ enum BuiltinOptions { BuiltinOptions_VarHandleOptions = 111, BuiltinOptions_ReadVariableOptions = 112, BuiltinOptions_AssignVariableOptions = 113, + BuiltinOptions_RandomOptions = 114, + BuiltinOptions_BucketizeOptions = 115, + BuiltinOptions_GeluOptions = 116, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_AssignVariableOptions + BuiltinOptions_MAX = BuiltinOptions_GeluOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[114] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[117] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1403,13 +1424,16 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[114] { BuiltinOptions_HashtableSizeOptions, BuiltinOptions_VarHandleOptions, BuiltinOptions_ReadVariableOptions, - BuiltinOptions_AssignVariableOptions + BuiltinOptions_AssignVariableOptions, + BuiltinOptions_RandomOptions, + BuiltinOptions_BucketizeOptions, + BuiltinOptions_GeluOptions }; return values; } inline const char * const *EnumNamesBuiltinOptions() { - static const char * const names[115] = { + static const char * const names[118] = { "NONE", "Conv2DOptions", "DepthwiseConv2DOptions", @@ -1524,6 +1548,9 @@ inline const char * const *EnumNamesBuiltinOptions() { "VarHandleOptions", "ReadVariableOptions", "AssignVariableOptions", + "RandomOptions", + "BucketizeOptions", + "GeluOptions", nullptr }; return names; @@ -1991,6 +2018,10 @@ template<> struct BuiltinOptionsTraits<tflite::AssignVariableOptions> { static const BuiltinOptions enum_value = BuiltinOptions_AssignVariableOptions; }; +template<> struct BuiltinOptionsTraits<tflite::GeluOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_GeluOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2927,6 +2958,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_AssignVariableOptions ? reinterpret_cast<const tflite::AssignVariableOptionsT *>(value) : nullptr; } + tflite::GeluOptionsT *AsGeluOptions() { + return type == BuiltinOptions_GeluOptions ? + reinterpret_cast<tflite::GeluOptionsT *>(value) : nullptr; + } + const tflite::GeluOptionsT *AsGeluOptions() const { + return type == BuiltinOptions_GeluOptions ? + reinterpret_cast<const tflite::GeluOptionsT *>(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -10343,6 +10382,60 @@ inline flatbuffers::Offset<AssignVariableOptions> CreateAssignVariableOptions( flatbuffers::Offset<AssignVariableOptions> CreateAssignVariableOptions(flatbuffers::FlatBufferBuilder &_fbb, const AssignVariableOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct GeluOptionsT : public flatbuffers::NativeTable { + typedef GeluOptions TableType; + bool approximate; + GeluOptionsT() + : approximate(false) { + } +}; + +struct GeluOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef GeluOptionsT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_APPROXIMATE = 4 + }; + bool approximate() const { + return GetField<uint8_t>(VT_APPROXIMATE, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<uint8_t>(verifier, VT_APPROXIMATE) && + verifier.EndTable(); + } + GeluOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GeluOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<GeluOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GeluOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_approximate(bool approximate) { + fbb_.AddElement<uint8_t>(GeluOptions::VT_APPROXIMATE, static_cast<uint8_t>(approximate), 0); + } + explicit GeluOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + GeluOptionsBuilder &operator=(const GeluOptionsBuilder &); + flatbuffers::Offset<GeluOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<GeluOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<GeluOptions> CreateGeluOptions( + flatbuffers::FlatBufferBuilder &_fbb, + bool approximate = false) { + GeluOptionsBuilder builder_(_fbb); + builder_.add_approximate(approximate); + return builder_.Finish(); +} + +flatbuffers::Offset<GeluOptions> CreateGeluOptions(flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; int8_t deprecated_builtin_code; @@ -10832,6 +10925,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const tflite::AssignVariableOptions *builtin_options_as_AssignVariableOptions() const { return builtin_options_type() == tflite::BuiltinOptions_AssignVariableOptions ? static_cast<const tflite::AssignVariableOptions *>(builtin_options()) : nullptr; } + const tflite::GeluOptions *builtin_options_as_GeluOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_GeluOptions ? static_cast<const tflite::GeluOptions *>(builtin_options()) : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -11320,6 +11416,10 @@ template<> inline const tflite::AssignVariableOptions *Operator::builtin_options return builtin_options_as_AssignVariableOptions(); } +template<> inline const tflite::GeluOptions *Operator::builtin_options_as<tflite::GeluOptions>() const { + return builtin_options_as_GeluOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -15325,6 +15425,32 @@ inline flatbuffers::Offset<AssignVariableOptions> CreateAssignVariableOptions(fl _fbb); } +inline GeluOptionsT *GeluOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new GeluOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void GeluOptions::UnPackTo(GeluOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = approximate(); _o->approximate = _e; } +} + +inline flatbuffers::Offset<GeluOptions> GeluOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateGeluOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<GeluOptions> CreateGeluOptions(flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GeluOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _approximate = _o->approximate; + return tflite::CreateGeluOptions( + _fbb, + _approximate); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -16252,6 +16378,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const tflite::AssignVariableOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_GeluOptions: { + auto ptr = reinterpret_cast<const tflite::GeluOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } @@ -16722,6 +16852,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast<const tflite::AssignVariableOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_GeluOptions: { + auto ptr = reinterpret_cast<const tflite::GeluOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -17180,6 +17314,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast<const tflite::AssignVariableOptionsT *>(value); return CreateAssignVariableOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_GeluOptions: { + auto ptr = reinterpret_cast<const tflite::GeluOptionsT *>(value); + return CreateGeluOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -17638,6 +17776,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new tflite::AssignVariableOptionsT(*reinterpret_cast<tflite::AssignVariableOptionsT *>(u.value)); break; } + case BuiltinOptions_GeluOptions: { + value = new tflite::GeluOptionsT(*reinterpret_cast<tflite::GeluOptionsT *>(u.value)); + break; + } default: break; } @@ -18210,6 +18352,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_GeluOptions: { + auto ptr = reinterpret_cast<tflite::GeluOptionsT *>(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index 1bcd0f27997..b18dcf78c7b 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -1041,6 +1041,11 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) { property.outputs = {{0, {}}}; property.version = 1; break; + case BuiltinOperator_GELU: + property.inputs = {{0, {}}}; + property.outputs = {{0, {}}}; + property.version = 2; + break; default: // No quantized implementation exists for this operation. property.quantizable = false; diff --git a/tensorflow/lite/tools/serialization/option_writer_generator.cc b/tensorflow/lite/tools/serialization/option_writer_generator.cc index 8875e287609..f798cd710b8 100644 --- a/tensorflow/lite/tools/serialization/option_writer_generator.cc +++ b/tensorflow/lite/tools/serialization/option_writer_generator.cc @@ -41,6 +41,7 @@ static const char* param_structs[] = {"TfLiteAddParams", "TfLiteFakeQuantParams", "TfLiteFullyConnectedParams", "TfLiteGatherParams", + "TfLiteGeluParams", "TfLiteIfParams", "TfLiteL2NormParams", "TfLiteLeakyReluParams", @@ -205,6 +206,7 @@ class OpOptionData { op_to_option_["IMAG"] = ""; op_to_option_["COMPLEX_ABS"] = ""; op_to_option_["BROADCAST_ARGS"] = ""; + op_to_option_["GELU"] = ""; // TODO(aselle): These are undesirable hacks. Consider changing C structs option_to_struct_["Pool2DOptions"] = "TfLitePoolParams"; diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 21891754ccd..96f2261ea63 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -785,6 +785,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 3; } return 2; + case BuiltinOperator_GELU: + if (op_sig.inputs.at(0).type == kTfLiteInt8 || + op_sig.inputs.at(0).type == kTfLiteUInt8) { + return 2; + } + return 1; default: return 1; } diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index f605b976be8..47a92aa057d 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -1065,4 +1065,18 @@ TEST(OpVersionTest, VersioningBroadcastToTest) { }; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); } +TEST(OpVersionTest, VersioningGeluTest) { + OpSignature fake_op_sig; + fake_op_sig.op = BuiltinOperator_GELU; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig.op = BuiltinOperator_GELU; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.op = BuiltinOperator_GELU; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); +} } // namespace tflite diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index b82f0a2a748..c602a9d9119 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -360,6 +360,8 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_READ_VARIABLE, 1}, "2.6.0"}, {{BuiltinOperator_ASSIGN_VARIABLE, 1}, "2.6.0"}, {{BuiltinOperator_BROADCAST_ARGS, 1}, "2.6.0"}, + {{BuiltinOperator_GELU, 1}, "2.9.0"}, + {{BuiltinOperator_GELU, 2}, "2.9.0"}, }); std::pair<BuiltinOperator, int> version_key = {op_code, op_version}; |