`CHECK` failures in `UnbatchGradOp` in TensorFlow
Description
TensorFlow is an open source platform for machine learning. `UnbatchGradOp` takes an argument `id` that is assumed to be a scalar, and a non-scalar `id` can trigger a `CHECK` failure and crash the program. The op also requires its argument `batch_index` to contain three times as many elements as its leading dimension, `batch_index.dim_size(0)`, because the kernel views it as a `[dim_size(0), 3]` matrix. A `batch_index` of any other size can likewise trigger a `CHECK` failure and crash the program. We have patched the issue in GitHub commit 5f945fc6409a3c1e90d6970c9292f805f6e6ddf2. The fix will be included in TensorFlow 2.10.0. We will also cherry-pick this commit onto TensorFlow 2.9.1, TensorFlow 2.8.1, and TensorFlow 2.7.2, as these are also affected and still in the supported range. There are no known workarounds for this issue.
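The two input checks described above can be mirrored in plain Python to make the size rule concrete. This is a minimal sketch, not TensorFlow code:

- The function name `check_unbatch_grad_args` is hypothetical.
- It works on shape tuples rather than tensors.
- It assumes `batch_index` has rank at least 1, as the kernel's `dim_size(0)` call does.
- It returns an error message where the patched kernel returns `errors::InvalidArgument`, and `None` when both checks pass.

```python
from math import prod
from typing import Optional, Tuple


def check_unbatch_grad_args(
    id_shape: Tuple[int, ...], batch_index_shape: Tuple[int, ...]
) -> Optional[str]:
    """Hypothetical shape-only mirror of the checks added in 5f945fc6409a.

    Returns an error message, or None if both inputs pass.
    """
    # `id` is read with scalar<int64_t>(), so it must hold exactly one element.
    id_elements = prod(id_shape)
    if id_elements != 1:
        return f"Expected `id` to be scalar. Received shape {id_shape}"

    # batch_index is viewed as a [dim_size(0), 3] matrix, so its element
    # count must be exactly three times its leading dimension.
    # Assumes rank >= 1: a rank-0 shape raises IndexError here.
    rows = batch_index_shape[0]
    expected = rows * 3
    actual = prod(batch_index_shape)
    if actual != expected:
        return (
            f"batch_index should contain {expected} elements. "
            f"Received {actual}"
        )
    return None


# Shapes taken from the regression tests added by the patch.
print(check_unbatch_grad_args((), (1, 3)))      # well-formed inputs
print(check_unbatch_grad_args((2,), (1, 3)))    # testUnbatchGradInvalidId
print(check_unbatch_grad_args((1,), (1, 2)))    # testUnbatchGradInvalidBatchId
print(check_unbatch_grad_args((3, 1), (3, 1)))  # testUnbatchGradInvalidArgs
```

Like the patch, the sketch checks only the element count of `id`, so a shape of `(1,)` is accepted. The patch's own `testUnbatchGradInvalidBatchId` passes an `id` of that shape and expects the failure to come from `batch_index` instead.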
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| tensorflow (PyPI) | < 2.7.2 | 2.7.2 |
| tensorflow (PyPI) | >= 2.8.0, < 2.8.1 | 2.8.1 |
| tensorflow (PyPI) | >= 2.9.0, < 2.9.1 | 2.9.1 |
| tensorflow-cpu (PyPI) | < 2.7.2 | 2.7.2 |
| tensorflow-cpu (PyPI) | >= 2.8.0, < 2.8.1 | 2.8.1 |
| tensorflow-cpu (PyPI) | >= 2.9.0, < 2.9.1 | 2.9.1 |
| tensorflow-gpu (PyPI) | < 2.7.2 | 2.7.2 |
| tensorflow-gpu (PyPI) | >= 2.8.0, < 2.8.1 | 2.8.1 |
| tensorflow-gpu (PyPI) | >= 2.9.0, < 2.9.1 | 2.9.1 |
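The affected ranges in the table can be encoded as half-open intervals to check whether a given version string is vulnerable. This is a hypothetical helper: `is_affected` and `parse_version` are not part of any TensorFlow or pip API. It assumes plain `major.minor.patch` version strings and does not handle pre-release suffixes such as `rc0`.

```python
from typing import List, Tuple

Version = Tuple[int, int, int]

# Half-open [low, high) affected ranges copied from the advisory table.
# The same ranges apply to tensorflow, tensorflow-cpu and tensorflow-gpu.
AFFECTED_RANGES: List[Tuple[Version, Version]] = [
    ((0, 0, 0), (2, 7, 2)),
    ((2, 8, 0), (2, 8, 1)),
    ((2, 9, 0), (2, 9, 1)),
]


def parse_version(version: str) -> Version:
    """Parse 'X.Y.Z'; missing parts default to 0. No rc/dev suffixes."""
    parts = [int(p) for p in version.split(".")[:3]]
    while len(parts) < 3:
        parts.append(0)
    return (parts[0], parts[1], parts[2])


def is_affected(version: str) -> bool:
    """True if the version falls inside any affected range."""
    v = parse_version(version)
    return any(low <= v < high for low, high in AFFECTED_RANGES)


for ver in ["2.6.0", "2.7.1", "2.7.2", "2.8.0", "2.9.1", "2.10.0"]:
    print(ver, "affected" if is_affected(ver) else "not affected")
```

You can obtain the installed version string with the standard library's `importlib.metadata.version("tensorflow")`, then pass it to the helper.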
Affected products
- Range: < 2.7.2
Patches
- 5f945fc6409a: Fix security vulnerability with UnbatchGradKernel
2 files changed · +63 −0
`tensorflow/core/kernels/batch_kernels.cc` (+10 −0, modified)

```diff
@@ -885,8 +885,13 @@ class UnbatchGradResource : public ResourceBase {
     const Tensor& data_t = context->input(0);
     const Tensor& batch_index_t = context->input(1);
     const Tensor& grad_t = context->input(2);
+    const Tensor& batch_key_t = context->input(3);
 
     mutex_lock ml(mu_);
+    if (batch_key_t.NumElements() != 1) {
+      return errors::InvalidArgument("Expected `id` to be scalar. Received ",
+                                     batch_key_t.DebugString());
+    }
 
     const int64_t batch_key = context->input(3).scalar<int64_t>()();
     // Mark our tensor as available.
@@ -902,6 +907,11 @@ class UnbatchGradResource : public ResourceBase {
           "batch_index is empty while the tensor isn't.");
     }
     std::unordered_set<int64_t> missing_tensors;
+    if (batch_index_t.NumElements() != batch_index_t.dim_size(0) * 3) {
+      return errors::InvalidArgument(
+          "batch_index should contain ", batch_index_t.dim_size(0) * 3,
+          " elements. Received ", batch_index_t.NumElements());
+    }
     const auto batch_index =
         batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3});
     for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
```
`tensorflow/python/ops/batch_ops_test.py` (+53 −0, modified)

```diff
@@ -20,7 +20,9 @@
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
@@ -30,6 +32,7 @@
 from tensorflow.python.ops import gen_batch_ops
 from tensorflow.python.ops import gen_functional_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import variables
@@ -557,6 +560,56 @@ def worker():
     # The thread's call should hit the timeout, and thus get 0 results.
     self.assertEqual(len(thread_results), 0)
 
+  def testUnbatchGradInvalidId(self):
+    with self.assertRaises(errors.InvalidArgumentError):
+      self.evaluate(
+          gen_batch_ops.unbatch_grad(
+              original_input=constant_op.constant([1]),
+              batch_index=constant_op.constant([
+                  [0, 0, 0],
+              ], dtype=dtypes.int64),
+              grad=constant_op.constant([
+                  1,
+              ]),
+              id=constant_op.constant([
+                  1,
+                  1,
+              ], dtype=dtypes.int64)))
+
+  def testUnbatchGradInvalidBatchId(self):
+    with self.assertRaises(errors.InvalidArgumentError):
+      self.evaluate(
+          gen_batch_ops.unbatch_grad(
+              original_input=constant_op.constant([1]),
+              batch_index=constant_op.constant([
+                  [0, 0],
+              ], dtype=dtypes.int64),
+              grad=constant_op.constant([
+                  1,
+              ]),
+              id=constant_op.constant([
+                  1,
+              ], dtype=dtypes.int64)))
+
+  def testUnbatchGradInvalidArgs(self):
+    original_input = random_ops.random_uniform(
+        shape=(3, 1), dtype=dtypes.float64, maxval=None)
+    batch_index = random_ops.random_uniform(
+        shape=(3, 1), dtype=dtypes.int64, maxval=65536)
+    grad = random_ops.random_uniform(
+        shape=(3, 1), dtype=dtypes.float64, maxval=None)
+    batch_id = random_ops.random_uniform(
+        shape=(3, 1), dtype=dtypes.int64, maxval=65536)
+    with self.assertRaises(errors.InvalidArgumentError):
+      self.evaluate(
+          gen_batch_ops.unbatch_grad(
+              original_input=original_input,
+              batch_index=batch_index,
+              grad=grad,
+              id=batch_id,
+              container="",
+              shared_name="",
+              name=""))
 
 if __name__ == "__main__":
   test.main()
```
References
- github.com/advisories/GHSA-h5vq-gw2c-pq47 (ADVISORY)
- nvd.nist.gov/vuln/detail/CVE-2022-35952 (ADVISORY)
- github.com/tensorflow/tensorflow/blob/769eddaf479c8debead9a59a72617d6ed6f0fe10/tensorflow/core/kernels/batch_kernels.cc (WEB)
- github.com/tensorflow/tensorflow/commit/5f945fc6409a3c1e90d6970c9292f805f6e6ddf2 (WEB)
- github.com/tensorflow/tensorflow/releases/tag/v2.10.0 (WEB)
- github.com/tensorflow/tensorflow/security/advisories/GHSA-h5vq-gw2c-pq47 (WEB)