VYPR
Moderate severity · NVD Advisory · Published Nov 5, 2021 · Updated Aug 4, 2024

Deadlock in mutually recursive `tf.function` objects

CVE-2021-41213

Description

TensorFlow is an open source platform for machine learning. In affected versions, the code behind the `tf.function` API can deadlock when two `tf.function`-decorated Python functions are mutually recursive. This occurs because the implementation uses a non-reentrant Python `threading.Lock` object (replaced by a reentrant `threading.RLock` in the fix). When a recursive call re-enters the function on the same thread, it blocks waiting for a lock that the same thread already holds. Loading any model that contains mutually recursive functions is vulnerable. An attacker can cause a denial of service by getting users to load such a model and call a recursive `tf.function`, although this is not a frequent scenario. The fix will be included in TensorFlow 2.7.0. We will also cherry-pick this commit onto TensorFlow 2.6.1, TensorFlow 2.5.2, and TensorFlow 2.4.4, as these are also affected and still in the supported range.

Affected packages

Versions sourced from the GitHub Security Advisory.

Package | Ecosystem | Affected versions | Patched version
tensorflow | PyPI | >= 2.6.0, < 2.6.1 | 2.6.1
tensorflow | PyPI | >= 2.5.0, < 2.5.2 | 2.5.2
tensorflow | PyPI | < 2.4.4 | 2.4.4
tensorflow-cpu | PyPI | >= 2.6.0, < 2.6.1 | 2.6.1
tensorflow-cpu | PyPI | >= 2.5.0, < 2.5.2 | 2.5.2
tensorflow-cpu | PyPI | < 2.4.4 | 2.4.4
tensorflow-gpu | PyPI | >= 2.6.0, < 2.6.1 | 2.6.1
tensorflow-gpu | PyPI | >= 2.5.0, < 2.5.2 | 2.5.2
tensorflow-gpu | PyPI | < 2.4.4 | 2.4.4

Affected products

1

Patches

1
afac8158d436

Fix the deadlock issue of recursive tf.function.

3 files changed · +116 −3
  • tensorflow/python/eager/def_function.py (+2 −2, modified)
    @@ -572,7 +572,7 @@ def __init__(self,
           ValueError: if `input_signature` is not None and the `python_function`'s
             argspec has keyword arguments.
         """
    -    self._lock = threading.Lock()
    +    self._lock = threading.RLock()
         self._python_function = python_function
         self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
             python_function,
    @@ -613,7 +613,7 @@ def __getstate__(self):
       def __setstate__(self, state):
         """Restore from pickled state."""
         self.__dict__ = state
    -    self._lock = threading.Lock()
    +    self._lock = threading.RLock()
         self._descriptor_cache = weakref.WeakKeyDictionary()
         self._key_for_call_stats = self._get_key_for_call_stats()
     
    
  • tensorflow/python/eager/def_function_test.py (+113 −0, modified)
    @@ -25,6 +25,7 @@
     from six.moves import range
     
     from tensorflow.python.autograph.core import converter
    +from tensorflow.python.eager import backprop
     from tensorflow.python.eager import def_function
     from tensorflow.python.eager import lift_to_graph
     from tensorflow.python.framework import constant_op
    @@ -36,6 +37,7 @@
     from tensorflow.python.framework import test_util
     from tensorflow.python.module import module
     from tensorflow.python.ops import array_ops
    +from tensorflow.python.ops import cond_v2
     from tensorflow.python.ops import control_flow_ops
     from tensorflow.python.ops import math_ops
     from tensorflow.python.ops import random_ops
    @@ -1261,6 +1263,117 @@ def testDouble(self, a):
         self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
         self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
     
    +  def test_recursive_tf_function(self):
    +
    +    @def_function.function
    +    def recursive_fn(n):
    +      if n > 0:
    +        return recursive_fn(n - 1)
    +      return 1
    +
    +    self.assertEqual(recursive_fn(5).numpy(), 1)
    +
    +  def test_recursive_tf_function_with_gradients(self):
    +
    +    @def_function.function
    +    def recursive_fn(n, x):
    +      if n > 0:
    +        return n * recursive_fn(n - 1, x)
    +      else:
    +        return x
    +
    +    x = variables.Variable(1.0)
    +    with backprop.GradientTape() as tape:
    +      g = recursive_fn(5, x)
    +
    +    dg_dx = tape.gradient(g, x)
    +    self.assertEqual(dg_dx.numpy(), 120)
    +
    +  def test_recursive_python_function(self):
    +
    +    def recursive_py_fn(n):
    +      if n > 0:
    +        return recursive_py_fn(n - 1)
    +      return 1
    +
    +    @def_function.function
    +    def recursive_fn(n):
    +      return recursive_py_fn(n)
    +
    +    self.assertEqual(recursive_fn(5).numpy(), 1)
    +
    +  def test_recursive_python_function_with_gradients(self):
    +
    +    def recursive_py_fn(n, x):
    +      if n > 0:
    +        return n * recursive_py_fn(n - 1, x)
    +      return x
    +
    +    @def_function.function
    +    def recursive_fn(n, x):
    +      return recursive_py_fn(n, x)
    +
    +    x = variables.Variable(1.0)
    +    with backprop.GradientTape() as tape:
    +      g = recursive_fn(5, x)
    +
    +    dg_dx = tape.gradient(g, x)
    +    self.assertEqual(dg_dx.numpy(), 120)
    +
    +  def test_recursive_tf_function_call_each_other(self):
    +
    +    @def_function.function
    +    def recursive_fn1(n):
    +      if n <= 1:
    +        return 1
    +      return recursive_fn2(n - 1)
    +
    +    @def_function.function
    +    def recursive_fn2(n):
    +      if n <= 1:
    +        return 2
    +      return recursive_fn1(n - 1)
    +
    +    self.assertEqual(recursive_fn1(5).numpy(), 1)
    +    self.assertEqual(recursive_fn1(6).numpy(), 2)
    +    self.assertEqual(recursive_fn2(5).numpy(), 2)
    +    self.assertEqual(recursive_fn2(6).numpy(), 1)
    +
    +  def test_recursive_tf_function_call_each_other_with_gradients(self):
    +
    +    @def_function.function
    +    def recursive_fn1(n, x):
    +      if n <= 1:
    +        return x
    +      return n * recursive_fn2(n - 1, x)
    +
    +    @def_function.function
    +    def recursive_fn2(n, x):
    +      if n <= 1:
    +        return 2 * x
    +      return n * recursive_fn1(n - 1, x)
    +
    +    x = variables.Variable(1.0)
    +    with backprop.GradientTape() as tape:
    +      g1 = recursive_fn1(5, x)
    +
    +    dg1_dx = tape.gradient(g1, x)
    +    self.assertEqual(dg1_dx.numpy(), 120)
    +
    +    with backprop.GradientTape() as tape:
    +      g2 = recursive_fn2(5, x)
    +
    +    dg2_dx = tape.gradient(g2, x)
    +    self.assertEqual(dg2_dx.numpy(), 240)
    +
    +  def test_recursive_tf_function_with_cond(self):
    +    @def_function.function(autograph=False)
    +    def recursive_fn(n):
    +      return cond_v2.cond_v2(n > 0, recursive_fn(n - 1), 1)
    +
    +    with self.assertRaises(RecursionError):
    +      recursive_fn(constant_op.constant(5))
    +
     
     if __name__ == '__main__':
       ops.enable_eager_execution()
    
  • tensorflow/python/eager/function.py (+1 −1, modified)
    @@ -3037,7 +3037,7 @@ def __init__(self,
         if self.input_signature is not None:
           self._hashable_input_signature = hash(self.flat_input_signature)
     
    -    self._lock = threading.Lock()
    +    self._lock = threading.RLock()
         # _descriptor_cache is a of instance of a class to an instance-specific
         # `Function`, used to make sure defun-decorated methods create different
         # functions for each instance.
    

Vulnerability mechanics

Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References

7

News mentions

0

No linked articles in our index yet.