High severity · OSV Advisory · Published Jan 15, 2026 · Updated Jan 15, 2026

Denial of Service in Keras via Excessive Memory Allocation in HDF5 Metadata

CVE-2026-0897

Description

Allocation of Resources Without Limits or Throttling in the HDF5 weight loading component in Google Keras 3.0.0 through 3.13.0 on all platforms allows a remote attacker to cause a Denial of Service (DoS) through memory exhaustion and a crash of the Python interpreter via a crafted .keras archive containing a valid model.weights.h5 file whose dataset declares an extremely large shape.
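
The mechanics follow from the HDF5 format itself: a dataset's shape and dtype live in file metadata, separate from the data bytes, so a file of a few kilobytes can declare an array that would require petabytes to materialize. A minimal sketch of the mechanism, assuming a simplified file layout (our illustration, not a published proof of concept):

    import h5py

    with h5py.File("model.weights.h5", "w") as f:
        vars_group = f.create_group("dense").create_group("vars")
        # The declared shape implies 2**49 float64 elements (~4 PiB), but
        # almost nothing is written to disk: HDF5 allocates chunk storage
        # lazily, so the file stays tiny.
        vars_group.create_dataset("0", shape=(2**49,), dtype="f8", chunks=(1024,))

    with h5py.File("model.weights.h5", "r") as f:
        d = f["dense/vars/0"]
        print(d.shape, d.dtype)  # reading metadata is cheap
        # d[()] would ask NumPy to materialize the full declared array,
        # exhausting memory and crashing the interpreter.

Before the patches below, Keras read weight datasets with exactly that `value[()]` pattern, without first checking the declared size.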

Affected packages

Versions sourced from the GitHub Security Advisory.

Package          Affected versions      Patched versions
keras (PyPI)     >= 3.0.0, < 3.12.1     3.12.1
keras (PyPI)     >= 3.13.0, < 3.13.2    3.13.2
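
For pinned environments, the ranges above translate directly into a version check; a minimal sketch using the `packaging` library (the helper name is ours):

    from packaging.version import Version

    def keras_is_affected(version_string: str) -> bool:
        v = Version(version_string)
        return (Version("3.0.0") <= v < Version("3.12.1")) or (
            Version("3.13.0") <= v < Version("3.13.2")
        )

    print(keras_is_affected("3.12.0"))  # True
    print(keras_is_affected("3.12.1"))  # False: first patched release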

Patches (2)
f704c887bf45

3.12.1 cherry pick changes for patch release. (#22081)

https://github.com/keras-team/keras · Sachin Prasad · Jan 30, 2026 · via GHSA
6 files changed · +200 −75
  • keras/src/export/tfsm_layer.py · +34 −0 · modified
    @@ -2,6 +2,7 @@
     from keras.src import layers
     from keras.src.api_export import keras_export
     from keras.src.export.saved_model import _list_variables_used_by_fns
    +from keras.src.saving import serialization_lib
     from keras.src.utils.module_utils import tensorflow as tf
     
     
    @@ -146,3 +147,36 @@ def get_config(self):
                 "call_training_endpoint": self.call_training_endpoint,
             }
             return {**base_config, **config}
    +
    +    @classmethod
    +    def from_config(cls, config, custom_objects=None, safe_mode=None):
    +        """Creates a TFSMLayer from its config.
    +        Args:
    +            config: A Python dictionary, typically the output of `get_config`.
    +            custom_objects: Optional dictionary mapping names to custom objects.
    +            safe_mode: Boolean, whether to disallow loading TFSMLayer.
    +                When `safe_mode=True`, loading is disallowed because TFSMLayer
    +                loads external SavedModels that may contain attacker-controlled
    +                executable graph code. Defaults to `True`.
    +        Returns:
    +            A TFSMLayer instance.
    +        """
    +        # Follow the same pattern as Lambda layer for safe_mode handling
    +        effective_safe_mode = (
    +            safe_mode
    +            if safe_mode is not None
    +            else serialization_lib.in_safe_mode()
    +        )
    +
    +        if effective_safe_mode is not False:
    +            raise ValueError(
    +                "Requested the deserialization of a `TFSMLayer`, which "
    +                "loads an external SavedModel. This carries a potential risk "
    +                "of arbitrary code execution and thus it is disallowed by "
    +                "default. If you trust the source of the artifact, you can "
    +                "override this error by passing `safe_mode=False` to the "
    +                "loading function, or calling "
    +                "`keras.config.enable_unsafe_deserialization()`."
    +            )
    +
    +        return cls(**config)
    
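In practice the new `from_config` gate surfaces at load time: deserializing a .keras archive that contains a `TFSMLayer` now raises unless the caller opts out. A sketch of the expected post-patch behavior (the file path is hypothetical):

    import keras

    # Default safe_mode=True: deserializing a TFSMLayer is refused.
    try:
        keras.saving.load_model("tfsm_model.keras")
    except ValueError as err:
        print(err)  # "...potential risk of arbitrary code execution..."

    # Explicit opt-out, only for artifacts from a trusted source:
    model = keras.saving.load_model("tfsm_model.keras", safe_mode=False)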
  • keras/src/export/tfsm_layer_test.py · +32 −3 · modified
    @@ -114,19 +114,48 @@ def test_serialization(self):
     
             # Test reinstantiation from config
             config = reloaded_layer.get_config()
    -        rereloaded_layer = tfsm_layer.TFSMLayer.from_config(config)
    +        rereloaded_layer = tfsm_layer.TFSMLayer.from_config(
    +            config, safe_mode=False
    +        )
             self.assertAllClose(rereloaded_layer(ref_input), ref_output, atol=1e-7)
     
             # Test whole model saving with reloaded layer inside
             model = models.Sequential([reloaded_layer])
             temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras")
             model.save(temp_model_filepath, save_format="keras_v3")
             reloaded_model = saving_lib.load_model(
    -            temp_model_filepath,
    -            custom_objects={"TFSMLayer": tfsm_layer.TFSMLayer},
    +            temp_model_filepath, safe_mode=False
             )
             self.assertAllClose(reloaded_model(ref_input), ref_output, atol=1e-7)
     
    +    def test_safe_mode_blocks_model_loading(self):
    +        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
    +
    +        # Create and export a model
    +        model = get_model()
    +        model(tf.random.normal((1, 10)))
    +        saved_model.export_saved_model(model, temp_filepath)
    +
    +        # Wrap SavedModel in TFSMLayer and save as .keras
    +        reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath)
    +        wrapper_model = models.Sequential([reloaded_layer])
    +
    +        model_path = os.path.join(self.get_temp_dir(), "tfsm_model.keras")
    +        wrapper_model.save(model_path)
    +
    +        # Default safe_mode=True should block loading
    +        with self.assertRaisesRegex(
    +            ValueError,
    +            "arbitrary code execution",
    +        ):
    +            saving_lib.load_model(model_path)
    +
    +        # Explicit opt-out should allow loading
    +        loaded_model = saving_lib.load_model(model_path, safe_mode=False)
    +
    +        x = tf.random.normal((2, 10))
    +        self.assertAllClose(loaded_model(x), wrapper_model(x))
    +
         def test_errors(self):
             # Test missing call endpoint
             temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
    
  • keras/src/saving/file_editor.py · +87 −6 · modified
    @@ -455,33 +455,114 @@ def resave_weights(self, filepath):
         def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
             metadata = metadata or {}
     
    +        # ------------------------------------------------------
    +        # Collect metadata for this HDF5 group
    +        # ------------------------------------------------------
             object_metadata = {}
             for k, v in data.attrs.items():
                 object_metadata[k] = v
             if object_metadata:
                 metadata[inner_path] = object_metadata
     
             result = collections.OrderedDict()
    +
    +        # ------------------------------------------------------
    +        # Iterate over all keys in this HDF5 group
    +        # ------------------------------------------------------
             for key in data.keys():
    -            inner_path = f"{inner_path}/{key}"
    +            # IMPORTANT:
    +            # Never mutate inner_path; use local variable.
    +            current_inner_path = f"{inner_path}/{key}"
                 value = data[key]
    +
    +            # ------------------------------------------------------
    +            # CASE 1 — HDF5 GROUP → RECURSE
    +            # ------------------------------------------------------
                 if isinstance(value, h5py.Group):
    +                # Skip empty groups
                     if len(value) == 0:
                         continue
    +
    +                # Skip empty "vars" groups
                     if "vars" in value.keys() and len(value["vars"]) == 0:
                         continue
     
    -            if hasattr(value, "keys"):
    +                # Recurse into "vars" subgroup when present
                     if "vars" in value.keys():
                         result[key], metadata = self._extract_weights_from_store(
    -                        value["vars"], metadata=metadata, inner_path=inner_path
    +                        value["vars"],
    +                        metadata=metadata,
    +                        inner_path=current_inner_path,
                         )
                     else:
    +                    # Recurse normally
                         result[key], metadata = self._extract_weights_from_store(
    -                        value, metadata=metadata, inner_path=inner_path
    +                        value,
    +                        metadata=metadata,
    +                        inner_path=current_inner_path,
                         )
    -            else:
    -                result[key] = value[()]
    +
    +                continue  # finished processing this key
    +
    +            # ------------------------------------------------------
    +            # CASE 2 — HDF5 DATASET → SAFE LOADING
    +            # ------------------------------------------------------
    +
    +            # Skip any objects that are not proper datasets
    +            if not isinstance(value, h5py.Dataset):
    +                continue
    +
    +            if value.external:
    +                raise ValueError(
    +                    "Not allowed: H5 file Dataset with external links: "
    +                    f"{value.external}"
    +                )
    +
    +            shape = value.shape
    +            dtype = value.dtype
    +
    +            # ------------------------------------------------------
    +            # Validate SHAPE (avoid malformed / malicious metadata)
    +            # ------------------------------------------------------
    +
    +            # No negative dimensions
    +            if any(dim < 0 for dim in shape):
    +                raise ValueError(
    +                    "Malformed HDF5 dataset shape encountered in .keras file; "
    +                    "negative dimension detected."
    +                )
    +
    +            # Prevent absurdly high-rank tensors
    +            if len(shape) > 64:
    +                raise ValueError(
    +                    "Malformed HDF5 dataset shape encountered in .keras file; "
    +                    "tensor rank exceeds safety limit."
    +                )
    +
    +            # Safe product computation (Python int is unbounded)
    +            num_elems = int(np.prod(shape))
    +
    +            # ------------------------------------------------------
    +            # Validate TOTAL memory size
    +            # ------------------------------------------------------
    +            MAX_BYTES = 1 << 32  # 4 GiB
    +
    +            size_bytes = num_elems * dtype.itemsize
    +
    +            if size_bytes > MAX_BYTES:
    +                raise ValueError(
    +                    f"HDF5 dataset too large to load safely "
    +                    f"({size_bytes} bytes; limit is {MAX_BYTES})."
    +                )
    +
    +            # ------------------------------------------------------
    +            # SAFE — load dataset (guaranteed ≤ 4 GiB)
    +            # ------------------------------------------------------
    +            result[key] = value[()]
    +
    +        # ------------------------------------------------------
    +        # Return final tree and metadata
    +        # ------------------------------------------------------
             return result, metadata
     
         def _generate_filepath_info(self, rich_style=False):
    
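The guard added above is self-contained enough to restate on its own. A standalone sketch of the same checks (the helper name and the `math.prod` call are ours; the patch computes the product as `int(np.prod(shape))`):

    import math
    import h5py

    MAX_BYTES = 1 << 32  # 4 GiB, matching the patch
    MAX_RANK = 64

    def check_h5_dataset(value: h5py.Dataset) -> None:
        if value.external:
            raise ValueError(f"External storage not allowed: {value.external}")
        if any(dim < 0 for dim in value.shape):
            raise ValueError("negative dimension detected")
        if len(value.shape) > MAX_RANK:
            raise ValueError("tensor rank exceeds safety limit")
        # math.prod over Python ints cannot overflow, so the byte count is
        # exact even for absurd declared shapes.
        size_bytes = math.prod(value.shape) * value.dtype.itemsize
        if size_bytes > MAX_BYTES:
            raise ValueError(f"dataset too large ({size_bytes} bytes)")

Every check reads only HDF5 metadata; the expensive `value[()]` read happens only after the declared size is known to fit the 4 GiB budget.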
  • keras/src/saving/saving_lib.py · +46 −47 · modified
    @@ -796,7 +796,8 @@ def _load_state(
                 try:
                     saveable.load_own_variables(weights_store.get(inner_path))
                 except Exception as e:
    -                failed_saveables.add(id(saveable))
    +                if failed_saveables is not None:
    +                    failed_saveables.add(id(saveable))
                     error_msgs[id(saveable)] = saveable, e
                     failure = True
             else:
    @@ -807,7 +808,8 @@ def _load_state(
                 try:
                     saveable.load_assets(assets_store.get(inner_path))
                 except Exception as e:
    -                failed_saveables.add(id(saveable))
    +                if failed_saveables is not None:
    +                    failed_saveables.add(id(saveable))
                     error_msgs[id(saveable)] = saveable, e
                     failure = True
             else:
    @@ -855,7 +857,7 @@ def _load_state(
         if not failure:
             if visited_saveables is not None and newly_failed <= 0:
                 visited_saveables.add(id(saveable))
    -        if id(saveable) in failed_saveables:
    +        if failed_saveables is not None and id(saveable) in failed_saveables:
                 failed_saveables.remove(id(saveable))
                 error_msgs.pop(id(saveable))
     
    @@ -1035,6 +1037,25 @@ def __bool__(self):
             # will mistakenly using `__len__` to determine the value.
             return self.h5_file.__bool__()
     
    +    def _verify_group(self, group):
    +        if not isinstance(group, h5py.Group):
    +            raise ValueError(
    +                f"Invalid H5 file, expected Group but received {type(group)}"
    +            )
    +        return group
    +
    +    def _verify_dataset(self, dataset):
    +        if not isinstance(dataset, h5py.Dataset):
    +            raise ValueError(
    +                f"Invalid H5 file, expected Dataset, received {type(dataset)}"
    +            )
    +        if dataset.external:
    +            raise ValueError(
    +                "Not allowed: H5 file Dataset with external links: "
    +                f"{dataset.external}"
    +            )
    +        return dataset
    +
         def _get_h5_file(self, path_or_io, mode=None):
             mode = mode or self.mode
             if mode not in ("r", "w", "a"):
    @@ -1094,15 +1115,19 @@ def get(self, path):
             self._h5_entry_group = {}  # Defaults to an empty dict if not found.
             if not path:
                 if "vars" in self.h5_file:
    -                self._h5_entry_group = self.h5_file["vars"]
    +                self._h5_entry_group = self._verify_group(self.h5_file["vars"])
             elif path in self.h5_file and "vars" in self.h5_file[path]:
    -            self._h5_entry_group = self.h5_file[path]["vars"]
    +            self._h5_entry_group = self._verify_group(
    +                self._verify_group(self.h5_file[path])["vars"]
    +            )
             else:
                 # No hit. Fix for 2.13 compatibility.
                 if "_layer_checkpoint_dependencies" in self.h5_file:
                     path = path.replace("layers", "_layer_checkpoint_dependencies")
                     if path in self.h5_file and "vars" in self.h5_file[path]:
    -                    self._h5_entry_group = self.h5_file[path]["vars"]
    +                    self._h5_entry_group = self._verify_group(
    +                        self._verify_group(self.h5_file[path])["vars"]
    +                    )
             self._h5_entry_initialized = True
             return self
     
    @@ -1134,25 +1159,15 @@ def __len__(self):
         def keys(self):
             return self._h5_entry_group.keys()
     
    -    def items(self):
    -        return self._h5_entry_group.items()
    -
    -    def values(self):
    -        return self._h5_entry_group.values()
    -
         def __getitem__(self, key):
    -        value = self._h5_entry_group[key]
    +        value = self._verify_dataset(self._h5_entry_group[key])
             if (
                 hasattr(value, "attrs")
                 and "dtype" in value.attrs
                 and value.attrs["dtype"] == "bfloat16"
             ):
                 value = np.array(value, dtype=ml_dtypes.bfloat16)
    -        elif (
    -            hasattr(value, "shape")
    -            and hasattr(value, "dtype")
    -            and not isinstance(value, np.ndarray)
    -        ):
    +        elif not isinstance(value, np.ndarray):
                 value = np.array(value)
             return value
     
    @@ -1355,25 +1370,25 @@ def _switch_h5_file(self, filename, mode):
             self._get_h5_group(self._h5_entry_path)
     
         def _restore_h5_file(self):
    -        """Ensure the current shard is the last one created.
    -
    -        We use mode="a" to avoid truncating the file during the switching.
    -        """
    +        """Ensure the current shard is the last one created."""
             if (
                 pathlib.Path(self.h5_file.filename).name
                 != self.current_shard_path.name
             ):
    -            self._switch_h5_file(self.current_shard_path.name, mode="a")
    +            mode = "a" if self.mode == "w" else "r"
    +            self._switch_h5_file(self.current_shard_path.name, mode=mode)
     
         # H5 entry level methods.
     
         def _get_h5_group(self, path):
             """Get the H5 entry group. If it doesn't exist, return an empty dict."""
             try:
                 if not path:
    -                self._h5_entry_group = self.h5_file["vars"]
    +                self._h5_entry_group = self._verify_group(self.h5_file["vars"])
                 else:
    -                self._h5_entry_group = self.h5_file[path]["vars"]
    +                self._h5_entry_group = self._verify_group(
    +                    self._verify_group(self.h5_file[path])["vars"]
    +                )
                 self._h5_entry_initialized = True
             except KeyError:
                 self._h5_entry_group = {}
    @@ -1392,33 +1407,17 @@ def __len__(self):
             return total_len
     
         def keys(self):
    -        keys = set(self._h5_entry_group.keys())
    +        keys = []
    +        current_shard_keys = list(self._h5_entry_group.keys())
             for filename in self.current_shard_filenames:
                 if filename == self.current_shard_path.name:
    -                continue
    -            self._switch_h5_file(filename, mode="r")
    -            keys.update(self._h5_entry_group.keys())
    +                keys += current_shard_keys
    +            else:
    +                self._switch_h5_file(filename, mode="r")
    +                keys += list(self._h5_entry_group.keys())
             self._restore_h5_file()
             return keys
     
    -    def items(self):
    -        yield from self._h5_entry_group.items()
    -        for filename in self.current_shard_filenames:
    -            if filename == self.current_shard_path.name:
    -                continue
    -            self._switch_h5_file(filename, mode="r")
    -            yield from self._h5_entry_group.items()
    -        self._restore_h5_file()
    -
    -    def values(self):
    -        yield from self._h5_entry_group.values()
    -        for filename in self.current_shard_filenames:
    -            if filename == self.current_shard_path.name:
    -                continue
    -            self._switch_h5_file(filename, mode="r")
    -            yield from self._h5_entry_group.values()
    -        self._restore_h5_file()
    -
         def __getitem__(self, key):
             if key in self._h5_entry_group:
                 return super().__getitem__(key)
    
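The `external` check in `_verify_dataset` deserves a note: an HDF5 dataset may declare that its raw bytes live in arbitrary non-HDF5 files on disk, identified by path, offset, and size. A sketch of what the guard rejects (file names are hypothetical):

    import h5py

    with h5py.File("weights.h5", "w") as f:
        # The dataset's bytes are declared to live outside the .h5 file.
        f.create_dataset(
            "vars/0", shape=(4,), dtype="f4",
            external=[("attacker_chosen.bin", 0, 16)],  # (name, offset, size)
        )

    with h5py.File("weights.h5", "r") as f:
        d = f["vars/0"]
        print(d.external)  # [('attacker_chosen.bin', 0, 16)] -> guard raises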
  • keras/src/saving/saving_lib_test.py · +0 −18 · modified
    @@ -1319,24 +1319,6 @@ def test_sharded_h5_io_store_basics(self):
             for key in ["a", "b"]:
                 self.assertIn(key, vars_store.keys())
     
    -        # Items.
    -        for key, value in vars_store.items():
    -            if key == "a":
    -                self.assertAllClose(value, a)
    -            elif key == "b":
    -                self.assertAllClose(value, b)
    -            else:
    -                raise ValueError(f"Unexpected key: {key}")
    -
    -        # Values.
    -        for value in vars_store.values():
    -            if backend.standardize_dtype(value.dtype) == "float32":
    -                self.assertAllClose(value, a)
    -            elif backend.standardize_dtype(value.dtype) == "int32":
    -                self.assertAllClose(value, b)
    -            else:
    -                raise ValueError(f"Unexpected value: {value}")
    -
         def test_sharded_h5_io_store_exception_raised(self):
             temp_filepath = Path(os.path.join(self.get_temp_dir(), "store.h5"))
     
    
  • keras/src/version.py · +1 −1 · modified
    @@ -1,7 +1,7 @@
     from keras.src.api_export import keras_export
     
     # Unique source of truth for the version number.
    -__version__ = "3.12.0"
    +__version__ = "3.12.1"
     
     
     @keras_export("keras.version")
    
7360d4f0d764

Fix DoS via malicious HDF5 dataset metadata in KerasFileEditor (#21880)

https://github.com/keras-team/keras · sarvesh patil · Dec 29, 2025 · via GHSA
1 file changed · +81 −6
  • keras/src/saving/file_editor.py · +81 −6 · modified
    @@ -455,33 +455,108 @@ def resave_weights(self, filepath):
         def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
             metadata = metadata or {}
     
    +        # ------------------------------------------------------
    +        # Collect metadata for this HDF5 group
    +        # ------------------------------------------------------
             object_metadata = {}
             for k, v in data.attrs.items():
                 object_metadata[k] = v
             if object_metadata:
                 metadata[inner_path] = object_metadata
     
             result = collections.OrderedDict()
    +
    +        # ------------------------------------------------------
    +        # Iterate over all keys in this HDF5 group
    +        # ------------------------------------------------------
             for key in data.keys():
    -            inner_path = f"{inner_path}/{key}"
    +            # IMPORTANT:
    +            # Never mutate inner_path; use local variable.
    +            current_inner_path = f"{inner_path}/{key}"
                 value = data[key]
    +
    +            # ------------------------------------------------------
    +            # CASE 1 — HDF5 GROUP → RECURSE
    +            # ------------------------------------------------------
                 if isinstance(value, h5py.Group):
    +                # Skip empty groups
                     if len(value) == 0:
                         continue
    +
    +                # Skip empty "vars" groups
                     if "vars" in value.keys() and len(value["vars"]) == 0:
                         continue
     
    -            if hasattr(value, "keys"):
    +                # Recurse into "vars" subgroup when present
                     if "vars" in value.keys():
                         result[key], metadata = self._extract_weights_from_store(
    -                        value["vars"], metadata=metadata, inner_path=inner_path
    +                        value["vars"],
    +                        metadata=metadata,
    +                        inner_path=current_inner_path,
                         )
                     else:
    +                    # Recurse normally
                         result[key], metadata = self._extract_weights_from_store(
    -                        value, metadata=metadata, inner_path=inner_path
    +                        value,
    +                        metadata=metadata,
    +                        inner_path=current_inner_path,
                         )
    -            else:
    -                result[key] = value[()]
    +
    +                continue  # finished processing this key
    +
    +            # ------------------------------------------------------
    +            # CASE 2 — HDF5 DATASET → SAFE LOADING
    +            # ------------------------------------------------------
    +
    +            # Skip any objects that are not proper datasets
    +            if not hasattr(value, "shape") or not hasattr(value, "dtype"):
    +                continue
    +
    +            shape = value.shape
    +            dtype = value.dtype
    +
    +            # ------------------------------------------------------
    +            # Validate SHAPE (avoid malformed / malicious metadata)
    +            # ------------------------------------------------------
    +
    +            # No negative dimensions
    +            if any(dim < 0 for dim in shape):
    +                raise ValueError(
    +                    "Malformed HDF5 dataset shape encountered in .keras file; "
    +                    "negative dimension detected."
    +                )
    +
    +            # Prevent absurdly high-rank tensors
    +            if len(shape) > 64:
    +                raise ValueError(
    +                    "Malformed HDF5 dataset shape encountered in .keras file; "
    +                    "tensor rank exceeds safety limit."
    +                )
    +
    +            # Safe product computation (Python int is unbounded)
    +            num_elems = int(np.prod(shape))
    +
    +            # ------------------------------------------------------
    +            # Validate TOTAL memory size
    +            # ------------------------------------------------------
    +            MAX_BYTES = 1 << 32  # 4 GiB
    +
    +            size_bytes = num_elems * dtype.itemsize
    +
    +            if size_bytes > MAX_BYTES:
    +                raise ValueError(
    +                    f"HDF5 dataset too large to load safely "
    +                    f"({size_bytes} bytes; limit is {MAX_BYTES})."
    +                )
    +
    +            # ------------------------------------------------------
    +            # SAFE — load dataset (guaranteed ≤ 4 GiB)
    +            # ------------------------------------------------------
    +            result[key] = value[()]
    +
    +        # ------------------------------------------------------
    +        # Return final tree and metadata
    +        # ------------------------------------------------------
             return result, metadata
     
         def _generate_filepath_info(self, rich_style=False):
    
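This earlier commit is the standalone KerasFileEditor fix; the 3.12.1 cherry-pick above additionally tightens the dataset check from duck typing (`hasattr(value, "shape")`) to a strict `isinstance(value, h5py.Dataset)` plus the external-storage check. With either version, opening a crafted archive should fail fast instead of exhausting memory; a sketch of the expected behavior (the file name is hypothetical, and we assume the editor reads weights on construction):

    import keras

    try:
        keras.saving.KerasFileEditor("malicious.keras")
    except ValueError as err:
        print(err)  # "HDF5 dataset too large to load safely (...)"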
