VYPR
High severity · NVD Advisory · Published Feb 3, 2022 · Updated May 5, 2025

`CHECK`-failures in TensorFlow

CVE-2022-21734

Description

TensorFlow's MapStage implementation crashes via a CHECK-fail when the key tensor is not a scalar, allowing denial of service.

AI Insight

LLM-synthesized narrative grounded in this CVE's description and references.

Vulnerability

The MapStage kernel in TensorFlow can be made to fail a CHECK assertion when the key tensor is not a scalar [1]; the patch applies the same fix to the OrderedMapStage op. Before the fix, the kernel only required the key to be non-empty (`key_tensor->NumElements() > 0`) and never verified that it held exactly one element, so a multi-element key reached code that expects a scalar and tripped the assertion when the op ran. Affected releases are all versions before 2.5.3, versions 2.6.0 through 2.6.2, and 2.7.0; the fix first shipped in 2.8.0 [1].
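The validation gap can be modelled without TensorFlow. The sketch below is a pure-Python approximation, not TensorFlow code: `InvalidArgumentError`, `KeyCheckAbort`, `map_stage_key_check_pre_fix`, `map_stage_key_check_post_fix`, and `outcome` are hypothetical names. It assumes, consistent with the CHECK-fail described above, that the kernel's later scalar read of the key aborts unless the key holds exactly one element. The pre-fix path mirrors the only key check visible in the patch context (`NumElements() > 0`); the post-fix path adds the `NumElements() == 1` check the patch introduces.

```python
import math


class InvalidArgumentError(Exception):
    """Stands in for a recoverable error reported through OP_REQUIRES."""


class KeyCheckAbort(Exception):
    """Stands in for a failed CHECK, which in the real kernel aborts the process."""


def _num_elements(shape):
    # A tensor's element count is the product of its dimensions; () -> 1.
    return math.prod(shape)


def _scalar_access(shape):
    # Assumed behaviour of the later scalar read: anything other than exactly
    # one element trips a CHECK instead of returning an error.
    if _num_elements(shape) != 1:
        raise KeyCheckAbort(f"CHECK failed for key with shape {shape}")


def map_stage_key_check_pre_fix(shape):
    # Only emptiness was rejected before the patch.
    if _num_elements(shape) <= 0:
        raise InvalidArgumentError("key must not be empty")
    _scalar_access(shape)


def map_stage_key_check_post_fix(shape):
    if _num_elements(shape) <= 0:
        raise InvalidArgumentError("key must not be empty")
    # The patch rejects multi-element keys before any scalar access happens.
    if _num_elements(shape) != 1:
        raise InvalidArgumentError(
            f"key must be an int64 scalar, got tensor with shape: {shape}")
    _scalar_access(shape)


def outcome(check, shape):
    # Classifies what happens to a key of the given shape under a check.
    try:
        check(shape)
        return "ok"
    except InvalidArgumentError:
        return "InvalidArgumentError"
    except KeyCheckAbort:
        return "abort"


if __name__ == "__main__":
    # (1, 3) is the key shape used by the regression tests in the patch.
    for shape in [(), (1,), (0,), (1, 3)]:
        print(shape,
              outcome(map_stage_key_check_pre_fix, shape),
              outcome(map_stage_key_check_post_fix, shape))
```

Because the patched condition counts elements rather than checking rank, a one-element key of shape `(1,)` still passes the check the patch adds, both in this model and going by the C++ condition in the diff.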

Exploitation

An attacker who controls the key tensor passed to a MapStage (or OrderedMapStage) op, for example by supplying crafted input to a model that stages data with it during training or inference, can pass a non-scalar key and crash the process. No privileges are required beyond that control over the op's inputs; if the attacker can feed tensors locally, no authentication or network access is involved.

Impact

Successful exploitation causes a denial of service: the failed CHECK terminates the whole TensorFlow process instead of returning an error, which can interrupt training pipelines or production inference services. The impact is limited to availability; no confidentiality or integrity compromise is reported.
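The sketch below illustrates why a CHECK failure amounts to a denial of service rather than a recoverable error. It assumes, as is typical for CHECK-style macros, that a failed check ends in the C library's `abort()`, and it uses Python's `os.abort()` in a child process as a stand-in; `RECOVERABLE`, `FATAL`, and `run_child` are illustrative names, and no TensorFlow code is involved. A caught exception lets the child exit cleanly, whereas the abort ends the child even though the call sits inside `try`/`except`, so only a supervising process can observe the failure.

```python
import subprocess
import sys

# Child 1 raises an ordinary exception and catches it: the process survives,
# much like a caller catching InvalidArgumentError from a patched kernel.
RECOVERABLE = (
    "try:\n"
    "    raise ValueError('key must be an int64 scalar')\n"
    "except ValueError:\n"
    "    pass\n"
)

# Child 2 calls os.abort(), standing in for the abort() a failed CHECK is
# assumed to end in; the surrounding try/except cannot intercept it.
FATAL = (
    "import os\n"
    "try:\n"
    "    os.abort()\n"
    "except BaseException:\n"
    "    pass\n"
)


def run_child(source):
    # Returns the child's exit status. On POSIX a negative value -N means the
    # child was killed by signal N (SIGABRT for os.abort()); on Windows
    # os.abort() exits with code 3.
    return subprocess.run([sys.executable, "-c", source]).returncode


if __name__ == "__main__":
    print("recoverable error exit status:", run_child(RECOVERABLE))
    print("abort exit status:", run_child(FATAL))
```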

Mitigation

The fix is included in TensorFlow 2.8.0 and has been cherry-picked to TensorFlow 2.7.1, 2.6.3, and 2.5.3, since those release lines were also affected and still within the supported range [1]. Users should upgrade to one of these patched versions or later. No workaround is described for unpatched versions.

AI Insight generated on May 21, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.

Affected packages

Versions sourced from the GitHub Security Advisory.

Package         Ecosystem  Affected versions    Patched version
tensorflow      PyPI       < 2.5.3              2.5.3
tensorflow      PyPI       >= 2.6.0, < 2.6.3    2.6.3
tensorflow      PyPI       >= 2.7.0, < 2.7.1    2.7.1
tensorflow-cpu  PyPI       < 2.5.3              2.5.3
tensorflow-cpu  PyPI       >= 2.6.0, < 2.6.3    2.6.3
tensorflow-cpu  PyPI       >= 2.7.0, < 2.7.1    2.7.1
tensorflow-gpu  PyPI       < 2.5.3              2.5.3
tensorflow-gpu  PyPI       >= 2.6.0, < 2.6.3    2.6.3
tensorflow-gpu  PyPI       >= 2.7.0, < 2.7.1    2.7.1
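The ranges in the table are identical for `tensorflow`, `tensorflow-cpu`, and `tensorflow-gpu`, so a single check covers all three packages. The sketch below encodes those ranges; `AFFECTED_RANGES`, `parse_release`, and `is_affected` are illustrative names. As a simplifying assumption it accepts only plain `X.Y.Z` release strings and rejects pre-release or local suffixes (such as `2.7.1rc0`) rather than guessing how they order; a full PEP 440 parser would be needed for those.

```python
import re

# Affected ranges copied from the table above, as (inclusive_low, exclusive_high).
# A lower bound of None means "every earlier release".
AFFECTED_RANGES = [
    (None, (2, 5, 3)),
    ((2, 6, 0), (2, 6, 3)),
    ((2, 7, 0), (2, 7, 1)),
]


def parse_release(version):
    # Only plain X.Y.Z strings are accepted; anything else raises ValueError.
    match = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unsupported version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def is_affected(version):
    # Tuples compare element by element, so (2, 6, 2) < (2, 6, 3) is True.
    v = parse_release(version)
    return any((low is None or low <= v) and v < high
               for low, high in AFFECTED_RANGES)


if __name__ == "__main__":
    for v in ["2.4.4", "2.5.2", "2.5.3", "2.6.2", "2.6.3", "2.7.0", "2.7.1", "2.8.0"]:
        print(v, "affected" if is_affected(v) else "not affected")
```

In an environment with TensorFlow installed, `is_affected(tf.__version__)` applies the same check to the running build, raising `ValueError` if the version string carries a suffix.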

Patches

f57315566d70

Add a check for Key being scalar tensor for MapStage and OrderedMapStage ops.

https://github.com/tensorflow/tensorflow · Isha Arkatkar · Dec 3, 2021 · via ghsa
2 files changed · +167 -124
  • tensorflow/core/kernels/map_stage_op.cc · +5 -0 · modified
    @@ -536,6 +536,11 @@ class MapStageOp : public OpKernel {
         OP_REQUIRES(ctx, key_tensor->NumElements() > 0,
                     errors::InvalidArgument("key must not be empty"));
     
    +    OP_REQUIRES(ctx, key_tensor->NumElements() == 1,
    +                errors::InvalidArgument(
    +                    "key must be an int64 scalar, got tensor with shape: ",
    +                    key_tensor->shape()));
    +
         // Create copy for insertion into Staging Area
         Tensor key(*key_tensor);
     
    
  • tensorflow/python/kernel_tests/data_structures/map_stage_op_test.py · +162 -124 · modified
    @@ -12,8 +12,11 @@
     # See the License for the specific language governing permissions and
     # limitations under the License.
     # ==============================================================================
    -from tensorflow.python.framework import errors
    +import numpy as np
    +
    +from tensorflow.python.framework import constant_op
     from tensorflow.python.framework import dtypes
    +from tensorflow.python.framework import errors
     from tensorflow.python.framework import ops
     from tensorflow.python.framework import test_util
     from tensorflow.python.ops import array_ops
    @@ -28,7 +31,7 @@ class MapStageTest(test.TestCase):
     
       @test_util.run_deprecated_v1
       def testSimple(self):
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.float32)
             pi = array_ops.placeholder(dtypes.int64)
    @@ -40,17 +43,17 @@ def testSimple(self):
             k, y = stager.get(gi)
             y = math_ops.reduce_max(math_ops.matmul(y, y))
     
    -    G.finalize()
    +    g.finalize()
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           sess.run(stage, feed_dict={x: -1, pi: 0})
           for i in range(10):
             _, yval = sess.run([stage, y], feed_dict={x: i, pi: i + 1, gi: i})
             self.assertAllClose(4 * (i - 1) * (i - 1) * 128, yval, rtol=1e-4)
     
       @test_util.run_deprecated_v1
       def testMultiple(self):
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.float32)
             pi = array_ops.placeholder(dtypes.int64)
    @@ -62,9 +65,9 @@ def testMultiple(self):
             k, (z, y) = stager.get(gi)
             y = math_ops.reduce_max(z * math_ops.matmul(y, y))
     
    -    G.finalize()
    +    g.finalize()
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           sess.run(stage, feed_dict={x: -1, pi: 0})
           for i in range(10):
             _, yval = sess.run([stage, y], feed_dict={x: i, pi: i + 1, gi: i})
    @@ -73,26 +76,25 @@ def testMultiple(self):
     
       @test_util.run_deprecated_v1
       def testDictionary(self):
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.float32)
             pi = array_ops.placeholder(dtypes.int64)
             gi = array_ops.placeholder(dtypes.int64)
             v = 2. * (array_ops.zeros([128, 128]) + x)
           with ops.device(test.gpu_device_name()):
    -        stager = data_flow_ops.MapStagingArea(
    -            [dtypes.float32, dtypes.float32],
    -            shapes=[[], [128, 128]],
    -            names=['x', 'v'])
    +        stager = data_flow_ops.MapStagingArea([dtypes.float32, dtypes.float32],
    +                                              shapes=[[], [128, 128]],
    +                                              names=['x', 'v'])
             stage = stager.put(pi, {'x': x, 'v': v})
             key, ret = stager.get(gi)
             z = ret['x']
             y = ret['v']
             y = math_ops.reduce_max(z * math_ops.matmul(y, y))
     
    -    G.finalize()
    +    g.finalize()
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           sess.run(stage, feed_dict={x: -1, pi: 0})
           for i in range(10):
             _, yval = sess.run([stage, y], feed_dict={x: i, pi: i + 1, gi: i})
    @@ -102,7 +104,7 @@ def testDictionary(self):
       def testColocation(self):
         gpu_dev = test.gpu_device_name()
     
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.float32)
             v = 2. * (array_ops.zeros([128, 128]) + x)
    @@ -119,58 +121,56 @@ def testColocation(self):
             self.assertEqual(y.device, '/device:CPU:0')
             self.assertEqual(z[0].device, '/device:CPU:0')
     
    -    G.finalize()
    +    g.finalize()
     
       @test_util.run_deprecated_v1
       def testPeek(self):
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.int32, name='x')
             pi = array_ops.placeholder(dtypes.int64)
             gi = array_ops.placeholder(dtypes.int64)
             p = array_ops.placeholder(dtypes.int32, name='p')
           with ops.device(test.gpu_device_name()):
    -        stager = data_flow_ops.MapStagingArea(
    -            [
    -                dtypes.int32,
    -            ], shapes=[[]])
    +        stager = data_flow_ops.MapStagingArea([
    +            dtypes.int32,
    +        ], shapes=[[]])
             stage = stager.put(pi, [x], [0])
             peek = stager.peek(gi)
             size = stager.size()
     
    -    G.finalize()
    +    g.finalize()
     
         n = 10
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           for i in range(n):
             sess.run(stage, feed_dict={x: i, pi: i})
     
           for i in range(n):
    -        self.assertTrue(sess.run(peek, feed_dict={gi: i})[0] == i)
    +        self.assertEqual(sess.run(peek, feed_dict={gi: i})[0], i)
     
    -      self.assertTrue(sess.run(size) == 10)
    +      self.assertEqual(sess.run(size), 10)
     
       @test_util.run_deprecated_v1
       def testSizeAndClear(self):
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.float32, name='x')
             pi = array_ops.placeholder(dtypes.int64)
             gi = array_ops.placeholder(dtypes.int64)
             v = 2. * (array_ops.zeros([128, 128]) + x)
           with ops.device(test.gpu_device_name()):
    -        stager = data_flow_ops.MapStagingArea(
    -            [dtypes.float32, dtypes.float32],
    -            shapes=[[], [128, 128]],
    -            names=['x', 'v'])
    +        stager = data_flow_ops.MapStagingArea([dtypes.float32, dtypes.float32],
    +                                              shapes=[[], [128, 128]],
    +                                              names=['x', 'v'])
             stage = stager.put(pi, {'x': x, 'v': v})
             size = stager.size()
             clear = stager.clear()
     
    -    G.finalize()
    +    g.finalize()
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           sess.run(stage, feed_dict={x: -1, pi: 3})
           self.assertEqual(sess.run(size), 1)
           sess.run(stage, feed_dict={x: -1, pi: 1})
    @@ -182,30 +182,31 @@ def testSizeAndClear(self):
       def testCapacity(self):
         capacity = 3
     
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.int32, name='x')
             pi = array_ops.placeholder(dtypes.int64, name='pi')
             gi = array_ops.placeholder(dtypes.int64, name='gi')
           with ops.device(test.gpu_device_name()):
    -        stager = data_flow_ops.MapStagingArea(
    -            [
    -                dtypes.int32,
    -            ], capacity=capacity, shapes=[[]])
    +        stager = data_flow_ops.MapStagingArea([
    +            dtypes.int32,
    +        ],
    +                                              capacity=capacity,
    +                                              shapes=[[]])
     
           stage = stager.put(pi, [x], [0])
           get = stager.get()
           size = stager.size()
     
    -    G.finalize()
    +    g.finalize()
     
         from six.moves import queue as Queue
         import threading
     
         queue = Queue.Queue()
         n = 8
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           # Stage data in a separate thread which will block
           # when it hits the staging area's capacity and thus
           # not fill the queue with n tokens
    @@ -234,42 +235,42 @@ def thread_run():
                                                  capacity))
     
           # Should have capacity elements in the staging area
    -      self.assertTrue(sess.run(size) == capacity)
    +      self.assertEqual(sess.run(size), capacity)
     
           # Clear the staging area completely
           for i in range(n):
             sess.run(get)
     
    -      self.assertTrue(sess.run(size) == 0)
    +      self.assertEqual(sess.run(size), 0)
     
       @test_util.run_deprecated_v1
       def testMemoryLimit(self):
         memory_limit = 512 * 1024  # 512K
         chunk = 200 * 1024  # 256K
         capacity = memory_limit // chunk
     
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.uint8, name='x')
             pi = array_ops.placeholder(dtypes.int64, name='pi')
             gi = array_ops.placeholder(dtypes.int64, name='gi')
           with ops.device(test.gpu_device_name()):
    -        stager = data_flow_ops.MapStagingArea(
    -            [dtypes.uint8], memory_limit=memory_limit, shapes=[[]])
    +        stager = data_flow_ops.MapStagingArea([dtypes.uint8],
    +                                              memory_limit=memory_limit,
    +                                              shapes=[[]])
             stage = stager.put(pi, [x], [0])
             get = stager.get()
             size = stager.size()
     
    -    G.finalize()
    +    g.finalize()
     
         from six.moves import queue as Queue
         import threading
    -    import numpy as np
     
         queue = Queue.Queue()
         n = 8
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           # Stage data in a separate thread which will block
           # when it hits the staging area's capacity and thus
           # not fill the queue with n tokens
    @@ -299,56 +300,57 @@ def thread_run():
                                                  capacity))
     
           # Should have capacity elements in the staging area
    -      self.assertTrue(sess.run(size) == capacity)
    +      self.assertEqual(sess.run(size), capacity)
     
           # Clear the staging area completely
           for i in range(n):
             sess.run(get)
     
    -      self.assertTrue(sess.run(size) == 0)
    +      self.assertEqual(sess.run(size), 0)
     
       @test_util.run_deprecated_v1
       def testOrdering(self):
         import six
         import random
     
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.int32, name='x')
             pi = array_ops.placeholder(dtypes.int64, name='pi')
             gi = array_ops.placeholder(dtypes.int64, name='gi')
           with ops.device(test.gpu_device_name()):
    -        stager = data_flow_ops.MapStagingArea(
    -            [
    -                dtypes.int32,
    -            ], shapes=[[]], ordered=True)
    +        stager = data_flow_ops.MapStagingArea([
    +            dtypes.int32,
    +        ],
    +                                              shapes=[[]],
    +                                              ordered=True)
             stage = stager.put(pi, [x], [0])
             get = stager.get()
             size = stager.size()
     
    -    G.finalize()
    +    g.finalize()
     
         n = 10
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           # Keys n-1..0
           keys = list(reversed(six.moves.range(n)))
     
           for i in keys:
             sess.run(stage, feed_dict={pi: i, x: i})
     
    -      self.assertTrue(sess.run(size) == n)
    +      self.assertEqual(sess.run(size), n)
     
           # Check that key, values come out in ascending order
           for i, k in enumerate(reversed(keys)):
             get_key, values = sess.run(get)
             self.assertTrue(i == k == get_key == values)
     
    -      self.assertTrue(sess.run(size) == 0)
    +      self.assertEqual(sess.run(size), 0)
     
       @test_util.run_deprecated_v1
       def testPartialDictInsert(self):
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.float32)
             f = array_ops.placeholder(dtypes.float32)
    @@ -366,49 +368,47 @@ def testPartialDictInsert(self):
             size = stager.size()
             isize = stager.incomplete_size()
     
    -    G.finalize()
    +    g.finalize()
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           # 0 complete and incomplete entries
    -      self.assertTrue(sess.run([size, isize]) == [0, 0])
    +      self.assertEqual(sess.run([size, isize]), [0, 0])
           # Stage key 0, x and f tuple entries
           sess.run(stage_xf, feed_dict={pi: 0, x: 1, f: 2})
    -      self.assertTrue(sess.run([size, isize]) == [0, 1])
    +      self.assertEqual(sess.run([size, isize]), [0, 1])
           # Stage key 1, x and f tuple entries
           sess.run(stage_xf, feed_dict={pi: 1, x: 1, f: 2})
    -      self.assertTrue(sess.run([size, isize]) == [0, 2])
    +      self.assertEqual(sess.run([size, isize]), [0, 2])
     
           # Now complete key 0 with tuple entry v
           sess.run(stage_v, feed_dict={pi: 0, v: 1})
           # 1 complete and 1 incomplete entry
    -      self.assertTrue(sess.run([size, isize]) == [1, 1])
    +      self.assertEqual(sess.run([size, isize]), [1, 1])
           # We can now obtain tuple associated with key 0
    -      self.assertTrue(
    -          sess.run([key, ret], feed_dict={
    -              gi: 0
    -          }) == [0, {
    +      self.assertEqual(
    +          sess.run([key, ret], feed_dict={gi: 0}),
    +          [0, {
                   'x': 1,
                   'f': 2,
                   'v': 1
               }])
     
           # 0 complete and 1 incomplete entry
    -      self.assertTrue(sess.run([size, isize]) == [0, 1])
    +      self.assertEqual(sess.run([size, isize]), [0, 1])
           # Now complete key 1 with tuple entry v
           sess.run(stage_v, feed_dict={pi: 1, v: 3})
           # We can now obtain tuple associated with key 1
    -      self.assertTrue(
    -          sess.run([key, ret], feed_dict={
    -              gi: 1
    -          }) == [1, {
    +      self.assertEqual(
    +          sess.run([key, ret], feed_dict={gi: 1}),
    +          [1, {
                   'x': 1,
                   'f': 2,
                   'v': 3
               }])
     
       @test_util.run_deprecated_v1
       def testPartialIndexInsert(self):
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.float32)
             f = array_ops.placeholder(dtypes.float32)
    @@ -424,35 +424,35 @@ def testPartialIndexInsert(self):
             size = stager.size()
             isize = stager.incomplete_size()
     
    -    G.finalize()
    +    g.finalize()
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           # 0 complete and incomplete entries
    -      self.assertTrue(sess.run([size, isize]) == [0, 0])
    +      self.assertEqual(sess.run([size, isize]), [0, 0])
           # Stage key 0, x and f tuple entries
           sess.run(stage_xf, feed_dict={pi: 0, x: 1, f: 2})
    -      self.assertTrue(sess.run([size, isize]) == [0, 1])
    +      self.assertEqual(sess.run([size, isize]), [0, 1])
           # Stage key 1, x and f tuple entries
           sess.run(stage_xf, feed_dict={pi: 1, x: 1, f: 2})
    -      self.assertTrue(sess.run([size, isize]) == [0, 2])
    +      self.assertEqual(sess.run([size, isize]), [0, 2])
     
           # Now complete key 0 with tuple entry v
           sess.run(stage_v, feed_dict={pi: 0, v: 1})
           # 1 complete and 1 incomplete entry
    -      self.assertTrue(sess.run([size, isize]) == [1, 1])
    +      self.assertEqual(sess.run([size, isize]), [1, 1])
           # We can now obtain tuple associated with key 0
    -      self.assertTrue(sess.run([key, ret], feed_dict={gi: 0}) == [0, [1, 1, 2]])
    +      self.assertEqual(sess.run([key, ret], feed_dict={gi: 0}), [0, [1, 1, 2]])
     
           # 0 complete and 1 incomplete entry
    -      self.assertTrue(sess.run([size, isize]) == [0, 1])
    +      self.assertEqual(sess.run([size, isize]), [0, 1])
           # Now complete key 1 with tuple entry v
           sess.run(stage_v, feed_dict={pi: 1, v: 3})
           # We can now obtain tuple associated with key 1
    -      self.assertTrue(sess.run([key, ret], feed_dict={gi: 1}) == [1, [1, 3, 2]])
    +      self.assertEqual(sess.run([key, ret], feed_dict={gi: 1}), [1, [1, 3, 2]])
     
       @test_util.run_deprecated_v1
       def testPartialDictGetsAndPeeks(self):
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.float32)
             f = array_ops.placeholder(dtypes.float32)
    @@ -476,81 +476,75 @@ def testPartialDictGetsAndPeeks(self):
             size = stager.size()
             isize = stager.incomplete_size()
     
    -    G.finalize()
    +    g.finalize()
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           # 0 complete and incomplete entries
    -      self.assertTrue(sess.run([size, isize]) == [0, 0])
    +      self.assertEqual(sess.run([size, isize]), [0, 0])
           # Stage key 0, x and f tuple entries
           sess.run(stage_xf, feed_dict={pi: 0, x: 1, f: 2})
    -      self.assertTrue(sess.run([size, isize]) == [0, 1])
    +      self.assertEqual(sess.run([size, isize]), [0, 1])
           # Stage key 1, x and f tuple entries
           sess.run(stage_xf, feed_dict={pi: 1, x: 1, f: 2})
    -      self.assertTrue(sess.run([size, isize]) == [0, 2])
    +      self.assertEqual(sess.run([size, isize]), [0, 2])
     
           # Now complete key 0 with tuple entry v
           sess.run(stage_v, feed_dict={pi: 0, v: 1})
           # 1 complete and 1 incomplete entry
    -      self.assertTrue(sess.run([size, isize]) == [1, 1])
    +      self.assertEqual(sess.run([size, isize]), [1, 1])
     
           # We can now peek at 'x' and 'f' values associated with key 0
    -      self.assertTrue(sess.run(peek_xf, feed_dict={pei: 0}) == {'x': 1, 'f': 2})
    +      self.assertEqual(sess.run(peek_xf, feed_dict={pei: 0}), {'x': 1, 'f': 2})
           # Peek at 'v' value associated with key 0
    -      self.assertTrue(sess.run(peek_v, feed_dict={pei: 0}) == {'v': 1})
    +      self.assertEqual(sess.run(peek_v, feed_dict={pei: 0}), {'v': 1})
           # 1 complete and 1 incomplete entry
    -      self.assertTrue(sess.run([size, isize]) == [1, 1])
    +      self.assertEqual(sess.run([size, isize]), [1, 1])
     
           # We can now obtain 'x' and 'f' values associated with key 0
    -      self.assertTrue(
    -          sess.run([key_xf, get_xf], feed_dict={
    -              gi: 0
    -          }) == [0, {
    +      self.assertEqual(
    +          sess.run([key_xf, get_xf], feed_dict={gi: 0}), [0, {
                   'x': 1,
                   'f': 2
               }])
           # Still have 1 complete and 1 incomplete entry
    -      self.assertTrue(sess.run([size, isize]) == [1, 1])
    +      self.assertEqual(sess.run([size, isize]), [1, 1])
     
           # We can no longer get 'x' and 'f' from key 0
           with self.assertRaises(errors.InvalidArgumentError) as cm:
             sess.run([key_xf, get_xf], feed_dict={gi: 0})
     
           exc_str = ("Tensor at index '0' for key '0' " 'has already been removed.')
     
    -      self.assertTrue(exc_str in cm.exception.message)
    +      self.assertIn(exc_str, cm.exception.message)
     
           # Obtain 'v' value associated with key 0
    -      self.assertTrue(
    -          sess.run([key_v, get_v], feed_dict={
    -              gi: 0
    -          }) == [0, {
    +      self.assertEqual(
    +          sess.run([key_v, get_v], feed_dict={gi: 0}), [0, {
                   'v': 1
               }])
           # 0 complete and 1 incomplete entry
    -      self.assertTrue(sess.run([size, isize]) == [0, 1])
    +      self.assertEqual(sess.run([size, isize]), [0, 1])
     
           # Now complete key 1 with tuple entry v
           sess.run(stage_v, feed_dict={pi: 1, v: 1})
           # 1 complete and 1 incomplete entry
    -      self.assertTrue(sess.run([size, isize]) == [1, 0])
    +      self.assertEqual(sess.run([size, isize]), [1, 0])
     
           # Pop without key to obtain 'x' and 'f' values associated with key 1
    -      self.assertTrue(sess.run([pop_key_xf, pop_xf]) == [1, {'x': 1, 'f': 2}])
    +      self.assertEqual(sess.run([pop_key_xf, pop_xf]), [1, {'x': 1, 'f': 2}])
           # still 1 complete and 1 incomplete entry
    -      self.assertTrue(sess.run([size, isize]) == [1, 0])
    +      self.assertEqual(sess.run([size, isize]), [1, 0])
           # We can now obtain 'x' and 'f' values associated with key 1
    -      self.assertTrue(
    -          sess.run([pop_key_v, pop_v], feed_dict={
    -              pi: 1
    -          }) == [1, {
    +      self.assertEqual(
    +          sess.run([pop_key_v, pop_v], feed_dict={pi: 1}), [1, {
                   'v': 1
               }])
           # Nothing is left
    -      self.assertTrue(sess.run([size, isize]) == [0, 0])
    +      self.assertEqual(sess.run([size, isize]), [0, 0])
     
       @test_util.run_deprecated_v1
       def testPartialIndexGets(self):
    -    with ops.Graph().as_default() as G:
    +    with ops.Graph().as_default() as g:
           with ops.device('/cpu:0'):
             x = array_ops.placeholder(dtypes.float32)
             f = array_ops.placeholder(dtypes.float32)
    @@ -568,28 +562,72 @@ def testPartialIndexGets(self):
             size = stager.size()
             isize = stager.incomplete_size()
     
    -    G.finalize()
    +    g.finalize()
     
    -    with self.session(graph=G) as sess:
    +    with self.session(graph=g) as sess:
           # Stage complete tuple
           sess.run(stage_xvf, feed_dict={pi: 0, x: 1, f: 2, v: 3})
     
    -      self.assertTrue(sess.run([size, isize]) == [1, 0])
    +      self.assertEqual(sess.run([size, isize]), [1, 0])
     
           # Partial get using indices
    -      self.assertTrue(
    -          sess.run([key_xf, get_xf], feed_dict={
    -              gi: 0
    -          }) == [0, [1, 2]])
    +      self.assertEqual(
    +          sess.run([key_xf, get_xf], feed_dict={gi: 0}), [0, [1, 2]])
     
           # Still some of key 0 left
    -      self.assertTrue(sess.run([size, isize]) == [1, 0])
    +      self.assertEqual(sess.run([size, isize]), [1, 0])
     
           # Partial get of remaining index
    -      self.assertTrue(sess.run([key_v, get_v], feed_dict={gi: 0}) == [0, [3]])
    +      self.assertEqual(sess.run([key_v, get_v], feed_dict={gi: 0}), [0, [3]])
     
           # All gone
    -      self.assertTrue(sess.run([size, isize]) == [0, 0])
    +      self.assertEqual(sess.run([size, isize]), [0, 0])
    +
    +  @test_util.run_deprecated_v1
    +  def testNonScalarKeyOrderedMap(self):
    +    with ops.Graph().as_default() as g:
    +      x = array_ops.placeholder(dtypes.float32)
    +      v = 2. * (array_ops.zeros([128, 128]) + x)
    +      t = data_flow_ops.gen_data_flow_ops.ordered_map_stage(
    +          key=constant_op.constant(value=[1], shape=(1, 3), dtype=dtypes.int64),
    +          indices=np.array([[6]]),
    +          values=[x, v],
    +          dtypes=[dtypes.int64],
    +          capacity=0,
    +          memory_limit=0,
    +          container='container1',
    +          shared_name='',
    +          name=None)
    +
    +    g.finalize()
    +
    +    with self.session(graph=g) as sess:
    +      with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                  'key must be an int64 scalar'):
    +        sess.run(t, feed_dict={x: 1})
    +
    +  @test_util.run_deprecated_v1
    +  def testNonScalarKeyUnorderedMap(self):
    +    with ops.Graph().as_default() as g:
    +      x = array_ops.placeholder(dtypes.float32)
    +      v = 2. * (array_ops.zeros([128, 128]) + x)
    +      t = data_flow_ops.gen_data_flow_ops.map_stage(
    +          key=constant_op.constant(value=[1], shape=(1, 3), dtype=dtypes.int64),
    +          indices=np.array([[6]]),
    +          values=[x, v],
    +          dtypes=[dtypes.int64],
    +          capacity=0,
    +          memory_limit=0,
    +          container='container1',
    +          shared_name='',
    +          name=None)
    +
    +    g.finalize()
    +
    +    with self.session(graph=g) as sess:
    +      with self.assertRaisesRegex(errors.InvalidArgumentError,
    +                                  'key must be an int64 scalar'):
    +        sess.run(t, feed_dict={x: 1})
     
     
     if __name__ == '__main__':
    

Vulnerability mechanics

Generated on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
