aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorlearning-to-play <66660475+learning-to-play@users.noreply.github.com>2022-10-11 13:18:54 -0700
committerGitHub <noreply@github.com>2022-10-11 13:18:54 -0700
commitf0de9759fd48d8020f4992f7894ce575d7d10cab (patch)
tree68bc9427137de66fddd549dd93f36ecbeda17ce1
parent1c1b08c1dcc50ab1fa3abf631e2394d9f90d0c91 (diff)
parentbaf015c98d77f1b1be373ff1bae7f453b5ca4343 (diff)
downloadtensorflow-f0de9759fd48d8020f4992f7894ce575d7d10cab.tar.gz
Merge pull request #58062 from vinila21/cherrypick-b389f5c944cadfdfe599b3f1e4026e036f30d2d4-on-r2.9
Add true_classes input validation for candidate sampler ops.
-rw-r--r--tensorflow/core/kernels/candidate_sampler_ops.cc8
-rw-r--r--tensorflow/python/kernel_tests/random/candidate_sampler_ops_test.py22
2 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc
index 872e805873f..94eb7f2738e 100644
--- a/tensorflow/core/kernels/candidate_sampler_ops.cc
+++ b/tensorflow/core/kernels/candidate_sampler_ops.cc
@@ -73,6 +73,14 @@ class BaseCandidateSamplerOp : public OpKernel {
gtl::ArraySlice<int64_t> true_candidate(
true_classes.matrix<int64_t>().data(), batch_size * num_true_);
+
+ for (const auto& candidate : true_candidate) {
+ OP_REQUIRES(context, candidate >= 0 && candidate < sampler_->range(),
+ errors::InvalidArgument("`true_candidate` out of range [", 0,
+ ", ", sampler_->range(),
+ "), received ", candidate));
+ }
+
gtl::MutableArraySlice<int64_t> sampled_candidate(
out_sampled_candidates->vec<int64_t>().data(), num_sampled_);
gtl::MutableArraySlice<float> true_expected_count(
diff --git a/tensorflow/python/kernel_tests/random/candidate_sampler_ops_test.py b/tensorflow/python/kernel_tests/random/candidate_sampler_ops_test.py
index b70a30f4606..396843ace3a 100644
--- a/tensorflow/python/kernel_tests/random/candidate_sampler_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/candidate_sampler_ops_test.py
@@ -18,6 +18,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import candidate_sampling_ops
@@ -127,6 +128,27 @@ class RangeSamplerOpsTest(test.TestCase):
# twice very rarely.
self.assertLessEqual(num_same, 2)
+ def testCandidateOutOfRange(self):
+ with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
+ "out of range"):
+ self.evaluate(
+ candidate_sampling_ops.log_uniform_candidate_sampler(
+ true_classes=[[0, 10]],
+ num_true=2,
+ num_sampled=1000,
+ unique=False,
+ range_max=2))
+
+ with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
+ "out of range"):
+ self.evaluate(
+ candidate_sampling_ops.log_uniform_candidate_sampler(
+ true_classes=[[0, -10]],
+ num_true=2,
+ num_sampled=1000,
+ unique=False,
+ range_max=2))
+
if __name__ == "__main__":
test.main()