Heap buffer overflow in TensorFlow
Description
In TensorFlow before version 2.3.1, the `SparseCountSparseOutput` implementation does not validate that the input arguments form a valid sparse tensor. In particular, there is no validation that the indices tensor has the same shape as the values tensor. The values in these tensors are always accessed in parallel. Thus, a shape mismatch can result in accesses outside the bounds of heap-allocated buffers. The issue is patched in commit 3cbb917b4714766030b28eba9fb41bb97ce9ee02 and is released in TensorFlow version 2.3.1.
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| tensorflow (PyPI) | >= 2.3.0, < 2.3.1 | 2.3.1 |
| tensorflow-cpu (PyPI) | >= 2.3.0, < 2.3.1 | 2.3.1 |
| tensorflow-gpu (PyPI) | >= 2.3.0, < 2.3.1 | 2.3.1 |
Affected products (1)
- Range: >= 2.3.0, < 2.3.1
Patches (1)
- 3cbb917b4714: Fix multiple vulnerabilities in `tf.raw_ops.*CountSparseOutput`.
2 files changed · +159 −0
tensorflow/core/kernels/count_ops.cc+41 −0 modified@@ -178,10 +178,30 @@ class SparseCount : public OpKernel { const Tensor& weights = context->input(3); bool use_weights = weights.NumElements() > 0; + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()), + errors::InvalidArgument( + "Input indices must be a 2-dimensional tensor. Got: ", + indices.shape().DebugString())); + + if (use_weights) { + OP_REQUIRES( + context, weights.shape() == values.shape(), + errors::InvalidArgument( + "Weights and values must have the same shape. Weight shape: ", + weights.shape().DebugString(), + "; values shape: ", values.shape().DebugString())); + } + bool is_1d = shape.NumElements() == 1; int num_batches = is_1d ? 1 : shape.flat<int64>()(0); int num_values = values.NumElements(); + OP_REQUIRES(context, num_values == indices.shape().dim_size(0), + errors::InvalidArgument( + "Number of values must match first dimension of indices.", + "Got ", num_values, + " values, indices shape: ", indices.shape().DebugString())); + const auto indices_values = indices.matrix<int64>(); const auto values_values = values.flat<T>(); const auto weight_values = weights.flat<W>(); @@ -235,12 +255,33 @@ class RaggedCount : public OpKernel { bool use_weights = weights.NumElements() > 0; bool is_1d = false; + if (use_weights) { + OP_REQUIRES( + context, weights.shape() == values.shape(), + errors::InvalidArgument( + "Weights and values must have the same shape. 
Weight shape: ", + weights.shape().DebugString(), + "; values shape: ", values.shape().DebugString())); + } + const auto splits_values = splits.flat<int64>(); const auto values_values = values.flat<T>(); const auto weight_values = weights.flat<W>(); int num_batches = splits.NumElements() - 1; int num_values = values.NumElements(); + OP_REQUIRES( + context, num_batches > 0, + errors::InvalidArgument( + "Must provide at least 2 elements for the splits argument")); + OP_REQUIRES(context, splits_values(0) == 0, + errors::InvalidArgument("Splits must start with 0, not with ", + splits_values(0))); + OP_REQUIRES(context, splits_values(num_batches) == num_values, + errors::InvalidArgument( + "Splits must end with the number of values, got ", + splits_values(num_batches), " instead of ", num_values)); + auto per_batch_counts = BatchedMap<W>(num_batches); T max_value = 0; int batch_idx = 0;
tensorflow/python/ops/bincount_ops_test.py+118 −0 modified@@ -25,7 +25,9 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import bincount_ops +from tensorflow.python.ops import gen_count_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -834,5 +836,121 @@ def test_ragged_input_different_shape_fails(self): self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) +@test_util.run_all_in_graph_and_eager_modes +@test_util.disable_tfrt +class RawOpsTest(test.TestCase, parameterized.TestCase): + + def testSparseCountSparseOutputBadIndicesShape(self): + indices = [[[0], [0]], [[0], [1]], [[1], [0]], [[1], [2]]] + values = [1, 1, 1, 10] + weights = [1, 2, 4, 6] + dense_shape = [2, 3] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Input indices must be a 2-dimensional tensor"): + self.evaluate( + gen_count_ops.SparseCountSparseOutput( + indices=indices, + values=values, + dense_shape=dense_shape, + weights=weights, + binary_output=False)) + + def testSparseCountSparseOutputBadWeightsShape(self): + indices = [[0, 0], [0, 1], [1, 0], [1, 2]] + values = [1, 1, 1, 10] + weights = [1, 2, 4] + dense_shape = [2, 3] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Weights and values must have the same shape"): + self.evaluate( + gen_count_ops.SparseCountSparseOutput( + indices=indices, + values=values, + dense_shape=dense_shape, + weights=weights, + binary_output=False)) + + def testSparseCountSparseOutputBadNumberOfValues(self): + indices = [[0, 0], [0, 1], [1, 0]] + values = [1, 1, 1, 10] + weights = [1, 2, 4, 6] + dense_shape = [2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Number of values must match first dimension of 
indices"): + self.evaluate( + gen_count_ops.SparseCountSparseOutput( + indices=indices, + values=values, + dense_shape=dense_shape, + weights=weights, + binary_output=False)) + + def testRaggedCountSparseOutput(self): + splits = [0, 4, 7] + values = [1, 1, 2, 1, 2, 10, 5] + weights = [1, 2, 3, 4, 5, 6, 7] + output_indices, output_values, output_shape = self.evaluate( + gen_count_ops.RaggedCountSparseOutput( + splits=splits, values=values, weights=weights, binary_output=False)) + self.assertAllEqual([[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]], + output_indices) + self.assertAllEqual([7, 3, 5, 7, 6], output_values) + self.assertAllEqual([2, 11], output_shape) + + def testRaggedCountSparseOutputBadWeightsShape(self): + splits = [0, 4, 7] + values = [1, 1, 2, 1, 2, 10, 5] + weights = [1, 2, 3, 4, 5, 6] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Weights and values must have the same shape"): + self.evaluate( + gen_count_ops.RaggedCountSparseOutput( + splits=splits, + values=values, + weights=weights, + binary_output=False)) + + def testRaggedCountSparseOutputEmptySplits(self): + splits = [] + values = [1, 1, 2, 1, 2, 10, 5] + weights = [1, 2, 3, 4, 5, 6, 7] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Must provide at least 2 elements for the splits argument"): + self.evaluate( + gen_count_ops.RaggedCountSparseOutput( + splits=splits, + values=values, + weights=weights, + binary_output=False)) + + def testRaggedCountSparseOutputBadSplitsStart(self): + splits = [1, 7] + values = [1, 1, 2, 1, 2, 10, 5] + weights = [1, 2, 3, 4, 5, 6, 7] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Splits must start with 0"): + self.evaluate( + gen_count_ops.RaggedCountSparseOutput( + splits=splits, + values=values, + weights=weights, + binary_output=False)) + + def testRaggedCountSparseOutputBadSplitsEnd(self): + splits = [0, 5] + values = [1, 1, 2, 1, 2, 10, 5] + weights = [1, 2, 3, 4, 5, 6, 7] + with 
self.assertRaisesRegex(errors.InvalidArgumentError, + "Splits must end with the number of values"): + self.evaluate( + gen_count_ops.RaggedCountSparseOutput( + splits=splits, + values=values, + weights=weights, + binary_output=False)) + + if __name__ == "__main__": test.main()
Vulnerability mechanics
Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References (8)
- github.com/advisories/GHSA-jc87-6vpp-7ff3 [ghsa] [ADVISORY]
- nvd.nist.gov/vuln/detail/CVE-2020-15198 [ghsa] [ADVISORY]
- github.com/pypa/advisory-database/tree/main/vulns/tensorflow-cpu/PYSEC-2020-278.yaml [ghsa] [WEB]
- github.com/pypa/advisory-database/tree/main/vulns/tensorflow-gpu/PYSEC-2020-313.yaml [ghsa] [WEB]
- github.com/pypa/advisory-database/tree/main/vulns/tensorflow/PYSEC-2020-121.yaml [ghsa] [WEB]
- github.com/tensorflow/tensorflow/commit/3cbb917b4714766030b28eba9fb41bb97ce9ee02 [ghsa] [x_refsource_MISC] [WEB]
- github.com/tensorflow/tensorflow/releases/tag/v2.3.1 [ghsa] [x_refsource_MISC] [WEB]
- github.com/tensorflow/tensorflow/security/advisories/GHSA-jc87-6vpp-7ff3 [ghsa] [x_refsource_CONFIRM] [WEB]
News mentions (0)
No linked articles in our index yet.