aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSwachhand Lokhande <swachhand@google.com>2022-09-29 15:59:27 -0700
committerTensorFlow Release Automation <jenkins@tensorflow.org>2022-10-21 18:21:57 +0000
commit188ecda393d62df628e1254b416e5b3ba400f427 (patch)
treee53ce2be91d29c8c248b8d2b1012159907b07cae
parentee897ca3146c2d41848225f16df6d1cd0948151b (diff)
downloadtensorflow-upstream-r2.10-39ec7eaf142.tar.gz
Make MfccMelFilterbank fail initialization if num_channels is > max int value.upstream-r2.10-39ec7eaf142
Also initialize MfccDct only if MfccMelFilterbank initialization was successful. PiperOrigin-RevId: 477844246
-rw-r--r--tensorflow/core/kernels/mfcc.cc6
-rw-r--r--tensorflow/core/kernels/mfcc_mel_filterbank.cc14
-rw-r--r--tensorflow/core/kernels/mfcc_mel_filterbank_test.cc34
-rw-r--r--tensorflow/core/kernels/mfcc_op.cc12
4 files changed, 58 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/mfcc.cc b/tensorflow/core/kernels/mfcc.cc
index 8c755e0df87..cb4416f7bd3 100644
--- a/tensorflow/core/kernels/mfcc.cc
+++ b/tensorflow/core/kernels/mfcc.cc
@@ -38,8 +38,10 @@ bool Mfcc::Initialize(int input_length, double input_sample_rate) {
bool initialized = mel_filterbank_.Initialize(
input_length, input_sample_rate, filterbank_channel_count_,
lower_frequency_limit_, upper_frequency_limit_);
- initialized &=
- dct_.Initialize(filterbank_channel_count_, dct_coefficient_count_);
+ if (initialized) {
+ initialized =
+ dct_.Initialize(filterbank_channel_count_, dct_coefficient_count_);
+ }
initialized_ = initialized;
return initialized;
}
diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.cc b/tensorflow/core/kernels/mfcc_mel_filterbank.cc
index 8eb2d9d8309..c5c2d29d37b 100644
--- a/tensorflow/core/kernels/mfcc_mel_filterbank.cc
+++ b/tensorflow/core/kernels/mfcc_mel_filterbank.cc
@@ -32,6 +32,8 @@ limitations under the License.
#include <math.h>
+#include <limits>
+
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@@ -74,7 +76,17 @@ bool MfccMelFilterbank::Initialize(int input_length, double input_sample_rate,
// An extra center frequency is computed at the top to get the upper
// limit on the high side of the final triangular filter.
- center_frequencies_.resize(num_channels_ + 1);
+ std::size_t center_frequencies_size = std::size_t(num_channels_) + 1;
+ if (center_frequencies_size >= std::numeric_limits<int>::max() ||
+ center_frequencies_size > center_frequencies_.max_size()) {
+ LOG(ERROR) << "Number of filterbank channels must be less than "
+ << std::numeric_limits<int>::max()
+ << " and less than or equal to "
+ << center_frequencies_.max_size();
+ return false;
+ }
+ center_frequencies_.resize(center_frequencies_size);
+
const double mel_low = FreqToMel(lower_frequency_limit);
const double mel_hi = FreqToMel(upper_frequency_limit);
const double mel_span = mel_hi - mel_low;
diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc
index 54f31e1699e..26b5afed135 100644
--- a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc
+++ b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/kernels/mfcc_mel_filterbank.h"
+#include <limits>
#include <vector>
#include "tensorflow/core/platform/test.h"
@@ -85,4 +86,37 @@ TEST(MfccMelFilterbankTest, IgnoresExistingContentOfOutputVector) {
}
}
+TEST(MfccMelFilterbankTest, FailsWhenChannelsGreaterThanMaxIntValue) {
+ // Test for bug where vector throws a length_error when it suspects the size
+ // to be more than it's max_size. For now, we fail initialization when the
+ // number of requested channels is >= the maximum value int can take (since
+ // num_channels_ is an int).
+ MfccMelFilterbank filterbank;
+
+ const int kSampleCount = 513;
+ std::size_t num_channels = std::numeric_limits<int>::max();
+ bool initialized = filterbank.Initialize(
+ kSampleCount, 2 /* sample rate */, num_channels /* channels */,
+ 1.0 /* lower frequency limit */, 5.0 /* upper frequency limit */);
+
+ EXPECT_FALSE(initialized);
+}
+
+TEST(MfccMelFilterbankTest, FailsWhenChannelsGreaterThanMaxSize) {
+ // Test for bug where vector throws a length_error when it suspects the size
+ // to be more than it's max_size. For now, we fail initialization when the
+ // number of requested channels is > than std::vector<double>::max_size().
+ MfccMelFilterbank filterbank;
+
+ const int kSampleCount = 513;
+ // Set num_channels to exceed the max_size a double vector can
+ // theoretically take.
+ std::size_t num_channels = std::vector<double>().max_size() + 1;
+ bool initialized = filterbank.Initialize(
+ kSampleCount, 2 /* sample rate */, num_channels /* channels */,
+ 1.0 /* lower frequency limit */, 5.0 /* upper frequency limit */);
+
+ EXPECT_FALSE(initialized);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mfcc_op.cc b/tensorflow/core/kernels/mfcc_op.cc
index 358a420c160..2c5f9560aaa 100644
--- a/tensorflow/core/kernels/mfcc_op.cc
+++ b/tensorflow/core/kernels/mfcc_op.cc
@@ -25,7 +25,7 @@ limitations under the License.
namespace tensorflow {
-// Create a speech fingerpring from spectrogram data.
+// Create a speech fingerprint from spectrogram data.
class MfccOp : public OpKernel {
public:
explicit MfccOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -60,10 +60,12 @@ class MfccOp : public OpKernel {
mfcc.set_lower_frequency_limit(lower_frequency_limit_);
mfcc.set_filterbank_channel_count(filterbank_channel_count_);
mfcc.set_dct_coefficient_count(dct_coefficient_count_);
- OP_REQUIRES(context, mfcc.Initialize(spectrogram_channels, sample_rate),
- errors::InvalidArgument(
- "Mfcc initialization failed for channel count ",
- spectrogram_channels, " and sample rate ", sample_rate));
+ OP_REQUIRES(
+ context, mfcc.Initialize(spectrogram_channels, sample_rate),
+ errors::InvalidArgument("Mfcc initialization failed for channel count ",
+ spectrogram_channels, ", sample rate ",
+ sample_rate, " and filterbank_channel_count ",
+ filterbank_channel_count_));
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,