Segfault in TensorFlow
Description
In TensorFlow before version 2.3.1, the `RaggedCountSparseOutput` implementation does not validate that its input arguments form a valid ragged tensor. In particular, it does not check that the values in the `splits` tensor describe a valid partitioning of the `values` tensor, so malformed input can trigger a heap buffer overflow. The kernel stores per-batch counts in a `BatchedMap`, which is equivalent to a vector whose elements are hashmaps. However, if the first element of `splits_values` is not 0, `batch_idx` never becomes 1, so no hashmap exists at index 0 of `per_batch_counts`. Accessing that entry results in a segmentation fault. The patch rejects `splits` inputs that have fewer than two elements, do not start with 0, or do not end with the number of values. The issue is patched in commit 3cbb917b4714766030b28eba9fb41bb97ce9ee02 and released in TensorFlow 2.3.1.
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| tensorflow (PyPI) | >= 2.3.0, < 2.3.1 | 2.3.1 |
| tensorflow-cpu (PyPI) | >= 2.3.0, < 2.3.1 | 2.3.1 |
| tensorflow-gpu (PyPI) | >= 2.3.0, < 2.3.1 | 2.3.1 |
Affected products
1 affected range:
- Range: `= 2.3.0` (exactly version 2.3.0)
Patches
1 patch:
- `3cbb917b4714` — Fix multiple vulnerabilities in `tf.raw_ops.*CountSparseOutput`.
2 files changed · +159 −0
`tensorflow/core/kernels/count_ops.cc` (modified, +41 −0)

```diff
@@ -178,10 +178,30 @@ class SparseCount : public OpKernel {
     const Tensor& weights = context->input(3);
     bool use_weights = weights.NumElements() > 0;
 
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
+                errors::InvalidArgument(
+                    "Input indices must be a 2-dimensional tensor. Got: ",
+                    indices.shape().DebugString()));
+
+    if (use_weights) {
+      OP_REQUIRES(
+          context, weights.shape() == values.shape(),
+          errors::InvalidArgument(
+              "Weights and values must have the same shape. Weight shape: ",
+              weights.shape().DebugString(),
+              "; values shape: ", values.shape().DebugString()));
+    }
+
     bool is_1d = shape.NumElements() == 1;
     int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
     int num_values = values.NumElements();
 
+    OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
+                errors::InvalidArgument(
+                    "Number of values must match first dimension of indices.",
+                    "Got ", num_values,
+                    " values, indices shape: ", indices.shape().DebugString()));
+
     const auto indices_values = indices.matrix<int64>();
     const auto values_values = values.flat<T>();
     const auto weight_values = weights.flat<W>();
@@ -235,12 +255,33 @@ class RaggedCount : public OpKernel {
     bool use_weights = weights.NumElements() > 0;
     bool is_1d = false;
 
+    if (use_weights) {
+      OP_REQUIRES(
+          context, weights.shape() == values.shape(),
+          errors::InvalidArgument(
+              "Weights and values must have the same shape. Weight shape: ",
+              weights.shape().DebugString(),
+              "; values shape: ", values.shape().DebugString()));
+    }
+
     const auto splits_values = splits.flat<int64>();
     const auto values_values = values.flat<T>();
     const auto weight_values = weights.flat<W>();
     int num_batches = splits.NumElements() - 1;
     int num_values = values.NumElements();
 
+    OP_REQUIRES(
+        context, num_batches > 0,
+        errors::InvalidArgument(
+            "Must provide at least 2 elements for the splits argument"));
+    OP_REQUIRES(context, splits_values(0) == 0,
+                errors::InvalidArgument("Splits must start with 0, not with ",
+                                        splits_values(0)));
+    OP_REQUIRES(context, splits_values(num_batches) == num_values,
+                errors::InvalidArgument(
+                    "Splits must end with the number of values, got ",
+                    splits_values(num_batches), " instead of ", num_values));
+
     auto per_batch_counts = BatchedMap<W>(num_batches);
     T max_value = 0;
     int batch_idx = 0;
```
`tensorflow/python/ops/bincount_ops_test.py` (modified, +118 −0)

```diff
@@ -25,7 +25,9 @@
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import bincount_ops
+from tensorflow.python.ops import gen_count_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor
@@ -834,5 +836,121 @@ def test_ragged_input_different_shape_fails(self):
       self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
 
 
+@test_util.run_all_in_graph_and_eager_modes
+@test_util.disable_tfrt
+class RawOpsTest(test.TestCase, parameterized.TestCase):
+
+  def testSparseCountSparseOutputBadIndicesShape(self):
+    indices = [[[0], [0]], [[0], [1]], [[1], [0]], [[1], [2]]]
+    values = [1, 1, 1, 10]
+    weights = [1, 2, 4, 6]
+    dense_shape = [2, 3]
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Input indices must be a 2-dimensional tensor"):
+      self.evaluate(
+          gen_count_ops.SparseCountSparseOutput(
+              indices=indices,
+              values=values,
+              dense_shape=dense_shape,
+              weights=weights,
+              binary_output=False))
+
+  def testSparseCountSparseOutputBadWeightsShape(self):
+    indices = [[0, 0], [0, 1], [1, 0], [1, 2]]
+    values = [1, 1, 1, 10]
+    weights = [1, 2, 4]
+    dense_shape = [2, 3]
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Weights and values must have the same shape"):
+      self.evaluate(
+          gen_count_ops.SparseCountSparseOutput(
+              indices=indices,
+              values=values,
+              dense_shape=dense_shape,
+              weights=weights,
+              binary_output=False))
+
+  def testSparseCountSparseOutputBadNumberOfValues(self):
+    indices = [[0, 0], [0, 1], [1, 0]]
+    values = [1, 1, 1, 10]
+    weights = [1, 2, 4, 6]
+    dense_shape = [2, 3]
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError,
+        "Number of values must match first dimension of indices"):
+      self.evaluate(
+          gen_count_ops.SparseCountSparseOutput(
+              indices=indices,
+              values=values,
+              dense_shape=dense_shape,
+              weights=weights,
+              binary_output=False))
+
+  def testRaggedCountSparseOutput(self):
+    splits = [0, 4, 7]
+    values = [1, 1, 2, 1, 2, 10, 5]
+    weights = [1, 2, 3, 4, 5, 6, 7]
+    output_indices, output_values, output_shape = self.evaluate(
+        gen_count_ops.RaggedCountSparseOutput(
+            splits=splits, values=values, weights=weights, binary_output=False))
+    self.assertAllEqual([[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]],
+                        output_indices)
+    self.assertAllEqual([7, 3, 5, 7, 6], output_values)
+    self.assertAllEqual([2, 11], output_shape)
+
+  def testRaggedCountSparseOutputBadWeightsShape(self):
+    splits = [0, 4, 7]
+    values = [1, 1, 2, 1, 2, 10, 5]
+    weights = [1, 2, 3, 4, 5, 6]
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Weights and values must have the same shape"):
+      self.evaluate(
+          gen_count_ops.RaggedCountSparseOutput(
+              splits=splits,
+              values=values,
+              weights=weights,
+              binary_output=False))
+
+  def testRaggedCountSparseOutputEmptySplits(self):
+    splits = []
+    values = [1, 1, 2, 1, 2, 10, 5]
+    weights = [1, 2, 3, 4, 5, 6, 7]
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError,
+        "Must provide at least 2 elements for the splits argument"):
+      self.evaluate(
+          gen_count_ops.RaggedCountSparseOutput(
+              splits=splits,
+              values=values,
+              weights=weights,
+              binary_output=False))
+
+  def testRaggedCountSparseOutputBadSplitsStart(self):
+    splits = [1, 7]
+    values = [1, 1, 2, 1, 2, 10, 5]
+    weights = [1, 2, 3, 4, 5, 6, 7]
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Splits must start with 0"):
+      self.evaluate(
+          gen_count_ops.RaggedCountSparseOutput(
+              splits=splits,
+              values=values,
+              weights=weights,
+              binary_output=False))
+
+  def testRaggedCountSparseOutputBadSplitsEnd(self):
+    splits = [0, 5]
+    values = [1, 1, 2, 1, 2, 10, 5]
+    weights = [1, 2, 3, 4, 5, 6, 7]
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Splits must end with the number of values"):
+      self.evaluate(
+          gen_count_ops.RaggedCountSparseOutput(
+              splits=splits,
+              values=values,
+              weights=weights,
+              binary_output=False))
+
+
 if __name__ == "__main__":
   test.main()
```
Vulnerability mechanics
Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References
8 references:
- github.com/advisories/GHSA-x7rp-74x2-mjf3 (source: ghsa; type: ADVISORY)
- nvd.nist.gov/vuln/detail/CVE-2020-15200 (source: ghsa; type: ADVISORY)
- github.com/pypa/advisory-database/tree/main/vulns/tensorflow-cpu/PYSEC-2020-280.yaml (source: ghsa; type: WEB)
- github.com/pypa/advisory-database/tree/main/vulns/tensorflow-gpu/PYSEC-2020-315.yaml (source: ghsa; type: WEB)
- github.com/pypa/advisory-database/tree/main/vulns/tensorflow/PYSEC-2020-123.yaml (source: ghsa; type: WEB)
- github.com/tensorflow/tensorflow/commit/3cbb917b4714766030b28eba9fb41bb97ce9ee02 (source: ghsa, x_refsource_MISC; type: WEB)
- github.com/tensorflow/tensorflow/releases/tag/v2.3.1 (source: ghsa, x_refsource_MISC; type: WEB)
- github.com/tensorflow/tensorflow/security/advisories/GHSA-x7rp-74x2-mjf3 (source: ghsa, x_refsource_CONFIRM; type: WEB)
News mentions
0 articles — no linked articles in our index yet.