VYPR
High severity · NVD Advisory · Published Sep 25, 2020 · Updated Aug 4, 2024

Segfault in TensorFlow

CVE-2020-15200

Description

In TensorFlow before version 2.3.1, the `RaggedCountSparseOutput` implementation does not validate that the input arguments form a valid ragged tensor. In particular, there is no validation that the values in the `splits` tensor generate a valid partitioning of the `values` tensor. As a result, crafted inputs can drive the code into a heap buffer overflow. A `BatchedMap` is equivalent to a vector where each element is a hashmap. However, if the first element of `splits_values` is not 0, `batch_idx` will never be 1, hence there will be no hashmap at index 0 in `per_batch_counts`. Trying to access that entry in the user code results in a segmentation fault. The issue is patched in commit 3cbb917b4714766030b28eba9fb41bb97ce9ee02, and the fix is released in TensorFlow version 2.3.1.

Affected packages

Versions sourced from the GitHub Security Advisory.

| Package | Ecosystem | Affected versions | Patched versions |
|---|---|---|---|
| tensorflow | PyPI | >= 2.3.0, < 2.3.1 | 2.3.1 |
| tensorflow-cpu | PyPI | >= 2.3.0, < 2.3.1 | 2.3.1 |
| tensorflow-gpu | PyPI | >= 2.3.0, < 2.3.1 | 2.3.1 |

Affected products

1

Patches

1
3cbb917b4714

Fix multiple vulnerabilities in `tf.raw_ops.*CountSparseOutput`.

https://github.com/tensorflow/tensorflow · Mihai Maruseac · Sep 19, 2020 · via GHSA
2 files changed · +159 additions, −0 deletions
  • tensorflow/core/kernels/count_ops.cc (+41 −0, modified)
    @@ -178,10 +178,30 @@ class SparseCount : public OpKernel {
         const Tensor& weights = context->input(3);
         bool use_weights = weights.NumElements() > 0;
     
    +    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
    +                errors::InvalidArgument(
    +                    "Input indices must be a 2-dimensional tensor. Got: ",
    +                    indices.shape().DebugString()));
    +
    +    if (use_weights) {
    +      OP_REQUIRES(
    +          context, weights.shape() == values.shape(),
    +          errors::InvalidArgument(
    +              "Weights and values must have the same shape. Weight shape: ",
    +              weights.shape().DebugString(),
    +              "; values shape: ", values.shape().DebugString()));
    +    }
    +
         bool is_1d = shape.NumElements() == 1;
         int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
         int num_values = values.NumElements();
     
    +    OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
    +                errors::InvalidArgument(
    +                    "Number of values must match first dimension of indices.",
    +                    "Got ", num_values,
    +                    " values, indices shape: ", indices.shape().DebugString()));
    +
         const auto indices_values = indices.matrix<int64>();
         const auto values_values = values.flat<T>();
         const auto weight_values = weights.flat<W>();
    @@ -235,12 +255,33 @@ class RaggedCount : public OpKernel {
         bool use_weights = weights.NumElements() > 0;
         bool is_1d = false;
     
    +    if (use_weights) {
    +      OP_REQUIRES(
    +          context, weights.shape() == values.shape(),
    +          errors::InvalidArgument(
    +              "Weights and values must have the same shape. Weight shape: ",
    +              weights.shape().DebugString(),
    +              "; values shape: ", values.shape().DebugString()));
    +    }
    +
         const auto splits_values = splits.flat<int64>();
         const auto values_values = values.flat<T>();
         const auto weight_values = weights.flat<W>();
         int num_batches = splits.NumElements() - 1;
         int num_values = values.NumElements();
     
    +    OP_REQUIRES(
    +        context, num_batches > 0,
    +        errors::InvalidArgument(
    +            "Must provide at least 2 elements for the splits argument"));
    +    OP_REQUIRES(context, splits_values(0) == 0,
    +                errors::InvalidArgument("Splits must start with 0, not with ",
    +                                        splits_values(0)));
    +    OP_REQUIRES(context, splits_values(num_batches) == num_values,
    +                errors::InvalidArgument(
    +                    "Splits must end with the number of values, got ",
    +                    splits_values(num_batches), " instead of ", num_values));
    +
         auto per_batch_counts = BatchedMap<W>(num_batches);
         T max_value = 0;
         int batch_idx = 0;
    
  • tensorflow/python/ops/bincount_ops_test.py (+118 −0, modified)
    @@ -25,7 +25,9 @@
     from tensorflow.python.framework import errors
     from tensorflow.python.framework import ops
     from tensorflow.python.framework import sparse_tensor
    +from tensorflow.python.framework import test_util
     from tensorflow.python.ops import bincount_ops
    +from tensorflow.python.ops import gen_count_ops
     from tensorflow.python.ops import sparse_ops
     from tensorflow.python.ops.ragged import ragged_factory_ops
     from tensorflow.python.ops.ragged import ragged_tensor
    @@ -834,5 +836,121 @@ def test_ragged_input_different_shape_fails(self):
           self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
     
     
    +@test_util.run_all_in_graph_and_eager_modes
    +@test_util.disable_tfrt
    +class RawOpsTest(test.TestCase, parameterized.TestCase):
    +
    +  def testSparseCountSparseOutputBadIndicesShape(self):
    +    indices = [[[0], [0]], [[0], [1]], [[1], [0]], [[1], [2]]]
    +    values = [1, 1, 1, 10]
    +    weights = [1, 2, 4, 6]
    +    dense_shape = [2, 3]
    +    with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                "Input indices must be a 2-dimensional tensor"):
    +      self.evaluate(
    +          gen_count_ops.SparseCountSparseOutput(
    +              indices=indices,
    +              values=values,
    +              dense_shape=dense_shape,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testSparseCountSparseOutputBadWeightsShape(self):
    +    indices = [[0, 0], [0, 1], [1, 0], [1, 2]]
    +    values = [1, 1, 1, 10]
    +    weights = [1, 2, 4]
    +    dense_shape = [2, 3]
    +    with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                "Weights and values must have the same shape"):
    +      self.evaluate(
    +          gen_count_ops.SparseCountSparseOutput(
    +              indices=indices,
    +              values=values,
    +              dense_shape=dense_shape,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testSparseCountSparseOutputBadNumberOfValues(self):
    +    indices = [[0, 0], [0, 1], [1, 0]]
    +    values = [1, 1, 1, 10]
    +    weights = [1, 2, 4, 6]
    +    dense_shape = [2, 3]
    +    with self.assertRaisesRegex(
    +        errors.InvalidArgumentError,
    +        "Number of values must match first dimension of indices"):
    +      self.evaluate(
    +          gen_count_ops.SparseCountSparseOutput(
    +              indices=indices,
    +              values=values,
    +              dense_shape=dense_shape,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testRaggedCountSparseOutput(self):
    +    splits = [0, 4, 7]
    +    values = [1, 1, 2, 1, 2, 10, 5]
    +    weights = [1, 2, 3, 4, 5, 6, 7]
    +    output_indices, output_values, output_shape = self.evaluate(
    +        gen_count_ops.RaggedCountSparseOutput(
    +            splits=splits, values=values, weights=weights, binary_output=False))
    +    self.assertAllEqual([[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]],
    +                        output_indices)
    +    self.assertAllEqual([7, 3, 5, 7, 6], output_values)
    +    self.assertAllEqual([2, 11], output_shape)
    +
    +  def testRaggedCountSparseOutputBadWeightsShape(self):
    +    splits = [0, 4, 7]
    +    values = [1, 1, 2, 1, 2, 10, 5]
    +    weights = [1, 2, 3, 4, 5, 6]
    +    with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                "Weights and values must have the same shape"):
    +      self.evaluate(
    +          gen_count_ops.RaggedCountSparseOutput(
    +              splits=splits,
    +              values=values,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testRaggedCountSparseOutputEmptySplits(self):
    +    splits = []
    +    values = [1, 1, 2, 1, 2, 10, 5]
    +    weights = [1, 2, 3, 4, 5, 6, 7]
    +    with self.assertRaisesRegex(
    +        errors.InvalidArgumentError,
    +        "Must provide at least 2 elements for the splits argument"):
    +      self.evaluate(
    +          gen_count_ops.RaggedCountSparseOutput(
    +              splits=splits,
    +              values=values,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testRaggedCountSparseOutputBadSplitsStart(self):
    +    splits = [1, 7]
    +    values = [1, 1, 2, 1, 2, 10, 5]
    +    weights = [1, 2, 3, 4, 5, 6, 7]
    +    with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                "Splits must start with 0"):
    +      self.evaluate(
    +          gen_count_ops.RaggedCountSparseOutput(
    +              splits=splits,
    +              values=values,
    +              weights=weights,
    +              binary_output=False))
    +
    +  def testRaggedCountSparseOutputBadSplitsEnd(self):
    +    splits = [0, 5]
    +    values = [1, 1, 2, 1, 2, 10, 5]
    +    weights = [1, 2, 3, 4, 5, 6, 7]
    +    with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                "Splits must end with the number of values"):
    +      self.evaluate(
    +          gen_count_ops.RaggedCountSparseOutput(
    +              splits=splits,
    +              values=values,
    +              weights=weights,
    +              binary_output=False))
    +
    +
     if __name__ == "__main__":
       test.main()
    

Vulnerability mechanics

Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References

8

News mentions

0

No linked articles in our index yet.