diff options
author | learning-to-play <66660475+learning-to-play@users.noreply.github.com> | 2022-10-11 13:18:54 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-11 13:18:54 -0700 |
commit | f0de9759fd48d8020f4992f7894ce575d7d10cab (patch) | |
tree | 68bc9427137de66fddd549dd93f36ecbeda17ce1 | |
parent | 1c1b08c1dcc50ab1fa3abf631e2394d9f90d0c91 (diff) | |
parent | baf015c98d77f1b1be373ff1bae7f453b5ca4343 (diff) | |
download | tensorflow-f0de9759fd48d8020f4992f7894ce575d7d10cab.tar.gz |
Merge pull request #58062 from vinila21/cherrypick-b389f5c944cadfdfe599b3f1e4026e036f30d2d4-on-r2.9
Add true_classes input validation for candidate sampler ops.
-rw-r--r-- | tensorflow/core/kernels/candidate_sampler_ops.cc | 8 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/random/candidate_sampler_ops_test.py | 22 |
2 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc index 872e805873f..94eb7f2738e 100644 --- a/tensorflow/core/kernels/candidate_sampler_ops.cc +++ b/tensorflow/core/kernels/candidate_sampler_ops.cc @@ -73,6 +73,14 @@ class BaseCandidateSamplerOp : public OpKernel { gtl::ArraySlice<int64_t> true_candidate( true_classes.matrix<int64_t>().data(), batch_size * num_true_); + + for (const auto& candidate : true_candidate) { + OP_REQUIRES(context, candidate >= 0 && candidate < sampler_->range(), + errors::InvalidArgument("`true_candidate` out of range [", 0, + ", ", sampler_->range(), + "), received ", candidate)); + } + gtl::MutableArraySlice<int64_t> sampled_candidate( out_sampled_candidates->vec<int64_t>().data(), num_sampled_); gtl::MutableArraySlice<float> true_expected_count( diff --git a/tensorflow/python/kernel_tests/random/candidate_sampler_ops_test.py b/tensorflow/python/kernel_tests/random/candidate_sampler_ops_test.py index b70a30f4606..396843ace3a 100644 --- a/tensorflow/python/kernel_tests/random/candidate_sampler_ops_test.py +++ b/tensorflow/python/kernel_tests/random/candidate_sampler_ops_test.py @@ -18,6 +18,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import candidate_sampling_ops @@ -127,6 +128,27 @@ class RangeSamplerOpsTest(test.TestCase): # twice very rarely. self.assertLessEqual(num_same, 2) + def testCandidateOutOfRange(self): + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + "out of range"): + self.evaluate( + candidate_sampling_ops.log_uniform_candidate_sampler( + true_classes=[[0, 10]], + num_true=2, + num_sampled=1000, + unique=False, + range_max=2)) + + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + "out of range"): + self.evaluate( + candidate_sampling_ops.log_uniform_candidate_sampler( + true_classes=[[0, -10]], + num_true=2, + num_sampled=1000, + unique=False, + range_max=2)) + if __name__ == "__main__": test.main() |