diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2022-07-19 09:32:39 -0700 |
---|---|---|
committer | TensorFlow Release Automation <jenkins@tensorflow.org> | 2022-08-20 00:32:34 +0000 |
commit | 89104e563d186a5ef836bbebcc332892ba42c9c7 (patch) | |
tree | dff2096f526979e7eeeba9f21a3d3808b07c6e68 | |
parent | 29d40476cc6db058a37b6c3963f7325d78af5a5a (diff) | |
download | tensorflow-upstream-r2.7-49b3824d83a.tar.gz |
Add IsScalar (rank == 0) check to min/max input tensors for QuantizedAdd/Relu/Relu6 op.upstream-r2.7-49b3824d83a
PiperOrigin-RevId: 461902847
4 files changed, 112 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/quantized_activation_ops.cc b/tensorflow/core/kernels/quantized_activation_ops.cc index 2896c3d45a7..36d321a8e17 100644 --- a/tensorflow/core/kernels/quantized_activation_ops.cc +++ b/tensorflow/core/kernels/quantized_activation_ops.cc @@ -32,8 +32,21 @@ class QuantizedReluOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); - const float min_input = context->input(1).flat<float>()(0); - const float max_input = context->input(2).flat<float>()(0); + const Tensor& min_input_tensor = context->input(1); + const Tensor& max_input_tensor = context->input(2); + + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(min_input_tensor.shape()), + errors::InvalidArgument("`min_input` must be rank 0 but is rank ", + min_input_tensor.dims())); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(max_input_tensor.shape()), + errors::InvalidArgument("`max_input` must be rank 0 but is rank ", + max_input_tensor.dims())); + + const float min_input = min_input_tensor.scalar<float>()(); + const float max_input = max_input_tensor.scalar<float>()(); + Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); @@ -65,8 +78,21 @@ class QuantizedRelu6Op : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); - const float min_input = context->input(1).flat<float>()(0); - const float max_input = context->input(2).flat<float>()(0); + const Tensor& min_input_tensor = context->input(1); + const Tensor& max_input_tensor = context->input(2); + + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(min_input_tensor.shape()), + errors::InvalidArgument("`min_input` must be rank 0 but is rank ", + min_input_tensor.dims())); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(max_input_tensor.shape()), + errors::InvalidArgument("`max_input` must be rank 0 but is rank ", + max_input_tensor.dims())); + + const float min_input = min_input_tensor.scalar<float>()(); + const float max_input = max_input_tensor.scalar<float>()(); + Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); diff --git a/tensorflow/core/kernels/quantized_activation_ops_test.cc b/tensorflow/core/kernels/quantized_activation_ops_test.cc index b3b7cb58b9a..34c5130f475 100644 --- a/tensorflow/core/kernels/quantized_activation_ops_test.cc +++ b/tensorflow/core/kernels/quantized_activation_ops_test.cc @@ -55,8 +55,8 @@ TEST_F(QuantizedActivationsTest, TestRelu) { AddInputFromArray<quint8>(input_quantized.shape(), input_quantized.flat<quint8>()); - AddInputFromArray<float>(TensorShape({1}), {input_min}); - AddInputFromArray<float>(TensorShape({1}), {input_max}); + AddInputFromArray<float>(TensorShape({}), {input_min}); + AddInputFromArray<float>(TensorShape({}), {input_max}); TF_ASSERT_OK(RunOpKernel()); const Tensor& output_quantized = *GetOutput(0); const float output_min = GetOutput(1)->flat<float>()(0); @@ -86,8 +86,8 @@ TEST_F(QuantizedActivationsTest, TestRelu6) { AddInputFromArray<quint8>(input_quantized.shape(), input_quantized.flat<quint8>()); - AddInputFromArray<float>(TensorShape({1}), {input_min}); - AddInputFromArray<float>(TensorShape({1}), {input_max}); + AddInputFromArray<float>(TensorShape({}), {input_min}); + AddInputFromArray<float>(TensorShape({}), {input_max}); TF_ASSERT_OK(RunOpKernel()); const Tensor& output_quantized = *GetOutput(0); const float output_min = GetOutput(1)->flat<float>()(0); diff --git a/tensorflow/core/kernels/quantized_add_op.cc b/tensorflow/core/kernels/quantized_add_op.cc index 21e07671ed8..4ddb6e04223 100644 --- a/tensorflow/core/kernels/quantized_add_op.cc +++ b/tensorflow/core/kernels/quantized_add_op.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/meta_support.h" #include "tensorflow/core/kernels/quantization_utils.h" #include "tensorflow/core/lib/core/errors.h" @@ -457,10 +458,28 @@ class QuantizedAddOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& x = context->input(0); const Tensor& y = context->input(1); - const float min_x = context->input(2).flat<float>()(0); - const float max_x = context->input(3).flat<float>()(0); - const float min_y = context->input(4).flat<float>()(0); - const float max_y = context->input(5).flat<float>()(0); + const Tensor& min_x_tensor = context->input(2); + const Tensor& max_x_tensor = context->input(3); + const Tensor& min_y_tensor = context->input(4); + const Tensor& max_y_tensor = context->input(5); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_x_tensor.shape()), + errors::InvalidArgument("`min_x` must be rank 0 but is rank ", + min_x_tensor.dims())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_x_tensor.shape()), + errors::InvalidArgument("`max_x` must be rank 0 but is rank ", + max_x_tensor.dims())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_y_tensor.shape()), + errors::InvalidArgument("`min_y` must be rank 0 but is rank ", + min_y_tensor.dims())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_y_tensor.shape()), + errors::InvalidArgument("`max_y` must be rank 0 but is rank ", + max_y_tensor.dims())); + + const float min_x = min_x_tensor.scalar<float>()(); + const float max_x = max_x_tensor.scalar<float>()(); + const float min_y = min_y_tensor.scalar<float>()(); + const float max_y = max_y_tensor.scalar<float>()(); BCast bcast(BCast::FromShape(x.shape()), BCast::FromShape(y.shape())); if (!bcast.IsValid()) { diff --git a/tensorflow/python/kernel_tests/quantization_ops/quantization_ops_test.py b/tensorflow/python/kernel_tests/quantization_ops/quantization_ops_test.py index af7cecf2ffb..8be78a0e0b2 100644 --- a/tensorflow/python/kernel_tests/quantization_ops/quantization_ops_test.py +++ b/tensorflow/python/kernel_tests/quantization_ops/quantization_ops_test.py @@ -206,5 +206,60 @@ class RequantizeOpTest(test_util.TensorFlowTestCase): out_type=dtypes.qint8)) +class QuantizedAddOpTest(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def test_invalid_inputs(self): + x = constant_op.constant( + np.int8(0), shape=[3, 3, 3, 3], dtype=dtypes.quint8) + y = constant_op.constant(np.int8(0), shape=[3], dtype=dtypes.quint8) + + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + "must be rank 0"): + self.evaluate( + math_ops.quantized_add( + x=x, + y=y, + min_x=[], + max_x=1.0, + min_y=0.0, + max_y=1.0, + Toutput=dtypes.qint32)) + + +class QuantizedReluOpTest(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def test_invalid_inputs(self): + inputs = constant_op.constant( + np.int8(0), shape=[3, 3, 3, 3], dtype=dtypes.quint8) + + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + "must be rank 0"): + self.evaluate( + nn_ops.quantized_relu( + features=inputs, + min_features=[], + max_features=127.0, + out_type=dtypes.quint8)) + + +class QuantizedRelu6OpTest(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def test_invalid_inputs(self): + inputs = constant_op.constant( + np.int8(0), shape=[3, 3, 3, 3], dtype=dtypes.quint8) + + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + "must be rank 0"): + self.evaluate( + nn_ops.quantized_relu6( + features=inputs, + min_features=[], + max_features=127.0, + out_type=dtypes.quint8)) + + if __name__ == "__main__": googletest.main() |