diff options
author | Swachhand Lokhande <swachhand@google.com> | 2022-09-29 15:59:27 -0700 |
---|---|---|
committer | TensorFlow Release Automation <jenkins@tensorflow.org> | 2022-10-21 18:21:57 +0000 |
commit | 188ecda393d62df628e1254b416e5b3ba400f427 (patch) | |
tree | e53ce2be91d29c8c248b8d2b1012159907b07cae | |
parent | ee897ca3146c2d41848225f16df6d1cd0948151b (diff) | |
download | tensorflow-upstream-r2.10-39ec7eaf142.tar.gz |
Make MfccMelFilterbank fail initialization if num_channels is > max int value. (ref: upstream-r2.10-39ec7eaf142)
Also initialize MfccDct only if MfccMelFilterbank initialization was successful.
PiperOrigin-RevId: 477844246
-rw-r--r-- | tensorflow/core/kernels/mfcc.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/kernels/mfcc_mel_filterbank.cc | 14 | ||||
-rw-r--r-- | tensorflow/core/kernels/mfcc_mel_filterbank_test.cc | 34 | ||||
-rw-r--r-- | tensorflow/core/kernels/mfcc_op.cc | 12 |
4 files changed, 58 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/mfcc.cc b/tensorflow/core/kernels/mfcc.cc index 8c755e0df87..cb4416f7bd3 100644 --- a/tensorflow/core/kernels/mfcc.cc +++ b/tensorflow/core/kernels/mfcc.cc @@ -38,8 +38,10 @@ bool Mfcc::Initialize(int input_length, double input_sample_rate) { bool initialized = mel_filterbank_.Initialize( input_length, input_sample_rate, filterbank_channel_count_, lower_frequency_limit_, upper_frequency_limit_); - initialized &= - dct_.Initialize(filterbank_channel_count_, dct_coefficient_count_); + if (initialized) { + initialized = + dct_.Initialize(filterbank_channel_count_, dct_coefficient_count_); + } initialized_ = initialized; return initialized; } diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.cc b/tensorflow/core/kernels/mfcc_mel_filterbank.cc index 8eb2d9d8309..c5c2d29d37b 100644 --- a/tensorflow/core/kernels/mfcc_mel_filterbank.cc +++ b/tensorflow/core/kernels/mfcc_mel_filterbank.cc @@ -32,6 +32,8 @@ limitations under the License. #include <math.h> +#include <limits> + #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -74,7 +76,17 @@ bool MfccMelFilterbank::Initialize(int input_length, double input_sample_rate, // An extra center frequency is computed at the top to get the upper // limit on the high side of the final triangular filter. 
- center_frequencies_.resize(num_channels_ + 1); + std::size_t center_frequencies_size = std::size_t(num_channels_) + 1; + if (center_frequencies_size >= std::numeric_limits<int>::max() || + center_frequencies_size > center_frequencies_.max_size()) { + LOG(ERROR) << "Number of filterbank channels must be less than " + << std::numeric_limits<int>::max() + << " and less than or equal to " + << center_frequencies_.max_size(); + return false; + } + center_frequencies_.resize(center_frequencies_size); + const double mel_low = FreqToMel(lower_frequency_limit); const double mel_hi = FreqToMel(upper_frequency_limit); const double mel_span = mel_hi - mel_low; diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc index 54f31e1699e..26b5afed135 100644 --- a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc +++ b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/kernels/mfcc_mel_filterbank.h" +#include <limits> #include <vector> #include "tensorflow/core/platform/test.h" @@ -85,4 +86,37 @@ TEST(MfccMelFilterbankTest, IgnoresExistingContentOfOutputVector) { } } +TEST(MfccMelFilterbankTest, FailsWhenChannelsGreaterThanMaxIntValue) { + // Test for bug where vector throws a length_error when it suspects the size + // to be more than it's max_size. For now, we fail initialization when the + // number of requested channels is >= the maximum value int can take (since + // num_channels_ is an int). 
+ MfccMelFilterbank filterbank; + + const int kSampleCount = 513; + std::size_t num_channels = std::numeric_limits<int>::max(); + bool initialized = filterbank.Initialize( + kSampleCount, 2 /* sample rate */, num_channels /* channels */, + 1.0 /* lower frequency limit */, 5.0 /* upper frequency limit */); + + EXPECT_FALSE(initialized); +} + +TEST(MfccMelFilterbankTest, FailsWhenChannelsGreaterThanMaxSize) { + // Test for bug where vector throws a length_error when it suspects the size + // to be more than it's max_size. For now, we fail initialization when the + // number of requested channels is > than std::vector<double>::max_size(). + MfccMelFilterbank filterbank; + + const int kSampleCount = 513; + // Set num_channels to exceed the max_size a double vector can + // theoretically take. + std::size_t num_channels = std::vector<double>().max_size() + 1; + bool initialized = filterbank.Initialize( + kSampleCount, 2 /* sample rate */, num_channels /* channels */, + 1.0 /* lower frequency limit */, 5.0 /* upper frequency limit */); + + EXPECT_FALSE(initialized); +} + } // namespace tensorflow diff --git a/tensorflow/core/kernels/mfcc_op.cc b/tensorflow/core/kernels/mfcc_op.cc index 358a420c160..2c5f9560aaa 100644 --- a/tensorflow/core/kernels/mfcc_op.cc +++ b/tensorflow/core/kernels/mfcc_op.cc @@ -25,7 +25,7 @@ limitations under the License. namespace tensorflow { -// Create a speech fingerpring from spectrogram data. +// Create a speech fingerprint from spectrogram data. 
class MfccOp : public OpKernel { public: explicit MfccOp(OpKernelConstruction* context) : OpKernel(context) { @@ -60,10 +60,12 @@ class MfccOp : public OpKernel { mfcc.set_lower_frequency_limit(lower_frequency_limit_); mfcc.set_filterbank_channel_count(filterbank_channel_count_); mfcc.set_dct_coefficient_count(dct_coefficient_count_); - OP_REQUIRES(context, mfcc.Initialize(spectrogram_channels, sample_rate), - errors::InvalidArgument( - "Mfcc initialization failed for channel count ", - spectrogram_channels, " and sample rate ", sample_rate)); + OP_REQUIRES( + context, mfcc.Initialize(spectrogram_channels, sample_rate), + errors::InvalidArgument("Mfcc initialization failed for channel count ", + spectrogram_channels, ", sample rate ", + sample_rate, " and filterbank_channel_count ", + filterbank_channel_count_)); Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, |