VYPR
Moderate severity · NVD Advisory · Published Sep 25, 2020 · Updated Aug 4, 2024

Heap buffer overflow in Tensorflow

CVE-2020-15201

Description

In TensorFlow before version 2.3.1, the `RaggedCountSparseOutput` implementation does not validate that the input arguments form a valid ragged tensor. In particular, there is no validation that the values in the `splits` tensor generate a valid partitioning of the `values` tensor. Hence, the code is prone to heap buffer overflow. If `splits_values` does not end with a value that is at least `num_values`, then the `while` loop condition will trigger a read outside of the bounds of `splits_values` once `batch_idx` grows too large. The issue is patched in commit 3cbb917b4714766030b28eba9fb41bb97ce9ee02, and the fix is released in TensorFlow version 2.3.1.

Affected packages

Versions sourced from the GitHub Security Advisory.

| Package | Affected versions | Patched versions |
| --- | --- | --- |
| tensorflow (PyPI) | >= 2.3.0, < 2.3.1 | 2.3.1 |
| tensorflow-cpu (PyPI) | >= 2.3.0, < 2.3.1 | 2.3.1 |
| tensorflow-gpu (PyPI) | >= 2.3.0, < 2.3.1 | 2.3.1 |

Affected products

1

Patches

1
3cbb917b4714

Fix multiple vulnerabilities in `tf.raw_ops.*CountSparseOutput`.

https://github.com/tensorflow/tensorflow · Mihai Maruseac · Sep 19, 2020 · via ghsa
2 files changed · +159 −0
  • tensorflow/core/kernels/count_ops.cc · +41 −0 · modified
    @@ -178,10 +178,30 @@ class SparseCount : public OpKernel {
         const Tensor& weights = context->input(3);
         bool use_weights = weights.NumElements() > 0;
     
    +    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
    +                errors::InvalidArgument(
    +                    "Input indices must be a 2-dimensional tensor. Got: ",
    +                    indices.shape().DebugString()));
    +
    +    if (use_weights) {
    +      OP_REQUIRES(
    +          context, weights.shape() == values.shape(),
    +          errors::InvalidArgument(
    +              "Weights and values must have the same shape. Weight shape: ",
    +              weights.shape().DebugString(),
    +              "; values shape: ", values.shape().DebugString()));
    +    }
    +
         bool is_1d = shape.NumElements() == 1;
         int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
         int num_values = values.NumElements();
     
    +    OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
    +                errors::InvalidArgument(
    +                    "Number of values must match first dimension of indices.",
    +                    "Got ", num_values,
    +                    " values, indices shape: ", indices.shape().DebugString()));
    +
         const auto indices_values = indices.matrix<int64>();
         const auto values_values = values.flat<T>();
         const auto weight_values = weights.flat<W>();
    @@ -235,12 +255,33 @@ class RaggedCount : public OpKernel {
         bool use_weights = weights.NumElements() > 0;
         bool is_1d = false;
     
    +    if (use_weights) {
    +      OP_REQUIRES(
    +          context, weights.shape() == values.shape(),
    +          errors::InvalidArgument(
    +              "Weights and values must have the same shape. Weight shape: ",
    +              weights.shape().DebugString(),
    +              "; values shape: ", values.shape().DebugString()));
    +    }
    +
         const auto splits_values = splits.flat<int64>();
         const auto values_values = values.flat<T>();
         const auto weight_values = weights.flat<W>();
         int num_batches = splits.NumElements() - 1;
         int num_values = values.NumElements();
     
    +    OP_REQUIRES(
    +        context, num_batches > 0,
    +        errors::InvalidArgument(
    +            "Must provide at least 2 elements for the splits argument"));
    +    OP_REQUIRES(context, splits_values(0) == 0,
    +                errors::InvalidArgument("Splits must start with 0, not with ",
    +                                        splits_values(0)));
    +    OP_REQUIRES(context, splits_values(num_batches) == num_values,
    +                errors::InvalidArgument(
    +                    "Splits must end with the number of values, got ",
    +                    splits_values(num_batches), " instead of ", num_values));
    +
         auto per_batch_counts = BatchedMap<W>(num_batches);
         T max_value = 0;
         int batch_idx = 0;
    
  • tensorflow/python/ops/bincount_ops_test.py · +118 −0 · modified
    @@ -25,7 +25,9 @@
     from tensorflow.python.framework import errors
     from tensorflow.python.framework import ops
     from tensorflow.python.framework import sparse_tensor
    +from tensorflow.python.framework import test_util
     from tensorflow.python.ops import bincount_ops
    +from tensorflow.python.ops import gen_count_ops
     from tensorflow.python.ops import sparse_ops
     from tensorflow.python.ops.ragged import ragged_factory_ops
     from tensorflow.python.ops.ragged import ragged_tensor
    @@ -834,5 +836,121 @@ def test_ragged_input_different_shape_fails(self):
           self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
     
     
    +@test_util.run_all_in_graph_and_eager_modes
    +@test_util.disable_tfrt
    +class RawOpsTest(test.TestCase, parameterized.TestCase):
    +
    +  def testSparseCountSparseOutputBadIndicesShape(self):
    +    indices = [[[0], [0]], [[0], [1]], [[1], [0]], [[1], [2]]]
    +    values = [1, 1, 1, 10]
    +    weights = [1, 2, 4, 6]
    +    dense_shape = [2, 3]
    +    with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                "Input indices must be a 2-dimensional tensor"):
    +      self.evaluate(
    +          gen_count_ops.SparseCountSparseOutput(
    +              indices=indices,
    +              values=values,
    +              dense_shape=dense_shape,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testSparseCountSparseOutputBadWeightsShape(self):
    +    indices = [[0, 0], [0, 1], [1, 0], [1, 2]]
    +    values = [1, 1, 1, 10]
    +    weights = [1, 2, 4]
    +    dense_shape = [2, 3]
    +    with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                "Weights and values must have the same shape"):
    +      self.evaluate(
    +          gen_count_ops.SparseCountSparseOutput(
    +              indices=indices,
    +              values=values,
    +              dense_shape=dense_shape,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testSparseCountSparseOutputBadNumberOfValues(self):
    +    indices = [[0, 0], [0, 1], [1, 0]]
    +    values = [1, 1, 1, 10]
    +    weights = [1, 2, 4, 6]
    +    dense_shape = [2, 3]
    +    with self.assertRaisesRegex(
    +        errors.InvalidArgumentError,
    +        "Number of values must match first dimension of indices"):
    +      self.evaluate(
    +          gen_count_ops.SparseCountSparseOutput(
    +              indices=indices,
    +              values=values,
    +              dense_shape=dense_shape,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testRaggedCountSparseOutput(self):
    +    splits = [0, 4, 7]
    +    values = [1, 1, 2, 1, 2, 10, 5]
    +    weights = [1, 2, 3, 4, 5, 6, 7]
    +    output_indices, output_values, output_shape = self.evaluate(
    +        gen_count_ops.RaggedCountSparseOutput(
    +            splits=splits, values=values, weights=weights, binary_output=False))
    +    self.assertAllEqual([[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]],
    +                        output_indices)
    +    self.assertAllEqual([7, 3, 5, 7, 6], output_values)
    +    self.assertAllEqual([2, 11], output_shape)
    +
    +  def testRaggedCountSparseOutputBadWeightsShape(self):
    +    splits = [0, 4, 7]
    +    values = [1, 1, 2, 1, 2, 10, 5]
    +    weights = [1, 2, 3, 4, 5, 6]
    +    with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                "Weights and values must have the same shape"):
    +      self.evaluate(
    +          gen_count_ops.RaggedCountSparseOutput(
    +              splits=splits,
    +              values=values,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testRaggedCountSparseOutputEmptySplits(self):
    +    splits = []
    +    values = [1, 1, 2, 1, 2, 10, 5]
    +    weights = [1, 2, 3, 4, 5, 6, 7]
    +    with self.assertRaisesRegex(
    +        errors.InvalidArgumentError,
    +        "Must provide at least 2 elements for the splits argument"):
    +      self.evaluate(
    +          gen_count_ops.RaggedCountSparseOutput(
    +              splits=splits,
    +              values=values,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testRaggedCountSparseOutputBadSplitsStart(self):
    +    splits = [1, 7]
    +    values = [1, 1, 2, 1, 2, 10, 5]
    +    weights = [1, 2, 3, 4, 5, 6, 7]
    +    with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                "Splits must start with 0"):
    +      self.evaluate(
    +          gen_count_ops.RaggedCountSparseOutput(
    +              splits=splits,
    +              values=values,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testRaggedCountSparseOutputBadSplitsEnd(self):
    +    splits = [0, 5]
    +    values = [1, 1, 2, 1, 2, 10, 5]
    +    weights = [1, 2, 3, 4, 5, 6, 7]
    +    with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                "Splits must end with the number of values"):
    +      self.evaluate(
    +          gen_count_ops.RaggedCountSparseOutput(
    +              splits=splits,
    +              values=values,
    +              weights=weights,
    +              binary_output=False))
    +
    +
     if __name__ == "__main__":
       test.main()
    

Vulnerability mechanics

Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References

8

News mentions

0

No linked articles in our index yet.