High severity · OSV Advisory · Published Jan 15, 2026 · Updated Jan 15, 2026
Denial of Service in Keras via Excessive Memory Allocation in HDF5 Metadata
CVE-2026-0897
Description
Allocation of Resources Without Limits or Throttling in the HDF5 weight loading component in Google Keras 3.0.0 through 3.13.0 on all platforms allows a remote attacker to cause a Denial of Service (DoS) through memory exhaustion and a crash of the Python interpreter via a crafted .keras archive containing a valid model.weights.h5 file whose dataset declares an extremely large shape.
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| keras (PyPI) | >= 3.0.0, < 3.12.1 | 3.12.1 |
| keras (PyPI) | >= 3.13.0, < 3.13.2 | 3.13.2 |
Affected products
- Range: v3.0.0, v3.0.1, v3.0.2, …
Patches
f704c887bf45 — 3.12.1 cherry pick changes for patch release. (#22081)
6 files changed · +200 −75
keras/src/export/tfsm_layer.py+34 −0 modified@@ -2,6 +2,7 @@ from keras.src import layers from keras.src.api_export import keras_export from keras.src.export.saved_model import _list_variables_used_by_fns +from keras.src.saving import serialization_lib from keras.src.utils.module_utils import tensorflow as tf @@ -146,3 +147,36 @@ def get_config(self): "call_training_endpoint": self.call_training_endpoint, } return {**base_config, **config} + + @classmethod + def from_config(cls, config, custom_objects=None, safe_mode=None): + """Creates a TFSMLayer from its config. + Args: + config: A Python dictionary, typically the output of `get_config`. + custom_objects: Optional dictionary mapping names to custom objects. + safe_mode: Boolean, whether to disallow loading TFSMLayer. + When `safe_mode=True`, loading is disallowed because TFSMLayer + loads external SavedModels that may contain attacker-controlled + executable graph code. Defaults to `True`. + Returns: + A TFSMLayer instance. + """ + # Follow the same pattern as Lambda layer for safe_mode handling + effective_safe_mode = ( + safe_mode + if safe_mode is not None + else serialization_lib.in_safe_mode() + ) + + if effective_safe_mode is not False: + raise ValueError( + "Requested the deserialization of a `TFSMLayer`, which " + "loads an external SavedModel. This carries a potential risk " + "of arbitrary code execution and thus it is disallowed by " + "default. If you trust the source of the artifact, you can " + "override this error by passing `safe_mode=False` to the " + "loading function, or calling " + "`keras.config.enable_unsafe_deserialization()." + ) + + return cls(**config)
keras/src/export/tfsm_layer_test.py+32 −3 modified@@ -114,19 +114,48 @@ def test_serialization(self): # Test reinstantiation from config config = reloaded_layer.get_config() - rereloaded_layer = tfsm_layer.TFSMLayer.from_config(config) + rereloaded_layer = tfsm_layer.TFSMLayer.from_config( + config, safe_mode=False + ) self.assertAllClose(rereloaded_layer(ref_input), ref_output, atol=1e-7) # Test whole model saving with reloaded layer inside model = models.Sequential([reloaded_layer]) temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras") model.save(temp_model_filepath, save_format="keras_v3") reloaded_model = saving_lib.load_model( - temp_model_filepath, - custom_objects={"TFSMLayer": tfsm_layer.TFSMLayer}, + temp_model_filepath, safe_mode=False ) self.assertAllClose(reloaded_model(ref_input), ref_output, atol=1e-7) + def test_safe_mode_blocks_model_loading(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + # Create and export a model + model = get_model() + model(tf.random.normal((1, 10))) + saved_model.export_saved_model(model, temp_filepath) + + # Wrap SavedModel in TFSMLayer and save as .keras + reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath) + wrapper_model = models.Sequential([reloaded_layer]) + + model_path = os.path.join(self.get_temp_dir(), "tfsm_model.keras") + wrapper_model.save(model_path) + + # Default safe_mode=True should block loading + with self.assertRaisesRegex( + ValueError, + "arbitrary code execution", + ): + saving_lib.load_model(model_path) + + # Explicit opt-out should allow loading + loaded_model = saving_lib.load_model(model_path, safe_mode=False) + + x = tf.random.normal((2, 10)) + self.assertAllClose(loaded_model(x), wrapper_model(x)) + def test_errors(self): # Test missing call endpoint temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
keras/src/saving/file_editor.py+87 −6 modified@@ -455,33 +455,114 @@ def resave_weights(self, filepath): def _extract_weights_from_store(self, data, metadata=None, inner_path=""): metadata = metadata or {} + # ------------------------------------------------------ + # Collect metadata for this HDF5 group + # ------------------------------------------------------ object_metadata = {} for k, v in data.attrs.items(): object_metadata[k] = v if object_metadata: metadata[inner_path] = object_metadata result = collections.OrderedDict() + + # ------------------------------------------------------ + # Iterate over all keys in this HDF5 group + # ------------------------------------------------------ for key in data.keys(): - inner_path = f"{inner_path}/{key}" + # IMPORTANT: + # Never mutate inner_path; use local variable. + current_inner_path = f"{inner_path}/{key}" value = data[key] + + # ------------------------------------------------------ + # CASE 1 — HDF5 GROUP → RECURSE + # ------------------------------------------------------ if isinstance(value, h5py.Group): + # Skip empty groups if len(value) == 0: continue + + # Skip empty "vars" groups if "vars" in value.keys() and len(value["vars"]) == 0: continue - if hasattr(value, "keys"): + # Recurse into "vars" subgroup when present if "vars" in value.keys(): result[key], metadata = self._extract_weights_from_store( - value["vars"], metadata=metadata, inner_path=inner_path + value["vars"], + metadata=metadata, + inner_path=current_inner_path, ) else: + # Recurse normally result[key], metadata = self._extract_weights_from_store( - value, metadata=metadata, inner_path=inner_path + value, + metadata=metadata, + inner_path=current_inner_path, ) - else: - result[key] = value[()] + + continue # finished processing this key + + # ------------------------------------------------------ + # CASE 2 — HDF5 DATASET → SAFE LOADING + # ------------------------------------------------------ + + # Skip any objects that are not proper 
datasets + if not isinstance(value, h5py.Dataset): + continue + + if value.external: + raise ValueError( + "Not allowed: H5 file Dataset with external links: " + f"{value.external}" + ) + + shape = value.shape + dtype = value.dtype + + # ------------------------------------------------------ + # Validate SHAPE (avoid malformed / malicious metadata) + # ------------------------------------------------------ + + # No negative dimensions + if any(dim < 0 for dim in shape): + raise ValueError( + "Malformed HDF5 dataset shape encountered in .keras file; " + "negative dimension detected." + ) + + # Prevent absurdly high-rank tensors + if len(shape) > 64: + raise ValueError( + "Malformed HDF5 dataset shape encountered in .keras file; " + "tensor rank exceeds safety limit." + ) + + # Safe product computation (Python int is unbounded) + num_elems = int(np.prod(shape)) + + # ------------------------------------------------------ + # Validate TOTAL memory size + # ------------------------------------------------------ + MAX_BYTES = 1 << 32 # 4 GiB + + size_bytes = num_elems * dtype.itemsize + + if size_bytes > MAX_BYTES: + raise ValueError( + f"HDF5 dataset too large to load safely " + f"({size_bytes} bytes; limit is {MAX_BYTES})." + ) + + # ------------------------------------------------------ + # SAFE — load dataset (guaranteed ≤ 4 GiB) + # ------------------------------------------------------ + result[key] = value[()] + + # ------------------------------------------------------ + # Return final tree and metadata + # ------------------------------------------------------ return result, metadata def _generate_filepath_info(self, rich_style=False):
keras/src/saving/saving_lib.py+46 −47 modified@@ -796,7 +796,8 @@ def _load_state( try: saveable.load_own_variables(weights_store.get(inner_path)) except Exception as e: - failed_saveables.add(id(saveable)) + if failed_saveables is not None: + failed_saveables.add(id(saveable)) error_msgs[id(saveable)] = saveable, e failure = True else: @@ -807,7 +808,8 @@ def _load_state( try: saveable.load_assets(assets_store.get(inner_path)) except Exception as e: - failed_saveables.add(id(saveable)) + if failed_saveables is not None: + failed_saveables.add(id(saveable)) error_msgs[id(saveable)] = saveable, e failure = True else: @@ -855,7 +857,7 @@ def _load_state( if not failure: if visited_saveables is not None and newly_failed <= 0: visited_saveables.add(id(saveable)) - if id(saveable) in failed_saveables: + if failed_saveables is not None and id(saveable) in failed_saveables: failed_saveables.remove(id(saveable)) error_msgs.pop(id(saveable)) @@ -1035,6 +1037,25 @@ def __bool__(self): # will mistakenly using `__len__` to determine the value. return self.h5_file.__bool__() + def _verify_group(self, group): + if not isinstance(group, h5py.Group): + raise ValueError( + f"Invalid H5 file, expected Group but received {type(group)}" + ) + return group + + def _verify_dataset(self, dataset): + if not isinstance(dataset, h5py.Dataset): + raise ValueError( + f"Invalid H5 file, expected Dataset, received {type(dataset)}" + ) + if dataset.external: + raise ValueError( + "Not allowed: H5 file Dataset with external links: " + f"{dataset.external}" + ) + return dataset + def _get_h5_file(self, path_or_io, mode=None): mode = mode or self.mode if mode not in ("r", "w", "a"): @@ -1094,15 +1115,19 @@ def get(self, path): self._h5_entry_group = {} # Defaults to an empty dict if not found. 
if not path: if "vars" in self.h5_file: - self._h5_entry_group = self.h5_file["vars"] + self._h5_entry_group = self._verify_group(self.h5_file["vars"]) elif path in self.h5_file and "vars" in self.h5_file[path]: - self._h5_entry_group = self.h5_file[path]["vars"] + self._h5_entry_group = self._verify_group( + self._verify_group(self.h5_file[path])["vars"] + ) else: # No hit. Fix for 2.13 compatibility. if "_layer_checkpoint_dependencies" in self.h5_file: path = path.replace("layers", "_layer_checkpoint_dependencies") if path in self.h5_file and "vars" in self.h5_file[path]: - self._h5_entry_group = self.h5_file[path]["vars"] + self._h5_entry_group = self._verify_group( + self._verify_group(self.h5_file[path])["vars"] + ) self._h5_entry_initialized = True return self @@ -1134,25 +1159,15 @@ def __len__(self): def keys(self): return self._h5_entry_group.keys() - def items(self): - return self._h5_entry_group.items() - - def values(self): - return self._h5_entry_group.values() - def __getitem__(self, key): - value = self._h5_entry_group[key] + value = self._verify_dataset(self._h5_entry_group[key]) if ( hasattr(value, "attrs") and "dtype" in value.attrs and value.attrs["dtype"] == "bfloat16" ): value = np.array(value, dtype=ml_dtypes.bfloat16) - elif ( - hasattr(value, "shape") - and hasattr(value, "dtype") - and not isinstance(value, np.ndarray) - ): + elif not isinstance(value, np.ndarray): value = np.array(value) return value @@ -1355,25 +1370,25 @@ def _switch_h5_file(self, filename, mode): self._get_h5_group(self._h5_entry_path) def _restore_h5_file(self): - """Ensure the current shard is the last one created. - - We use mode="a" to avoid truncating the file during the switching. 
- """ + """Ensure the current shard is the last one created.""" if ( pathlib.Path(self.h5_file.filename).name != self.current_shard_path.name ): - self._switch_h5_file(self.current_shard_path.name, mode="a") + mode = "a" if self.mode == "w" else "r" + self._switch_h5_file(self.current_shard_path.name, mode=mode) # H5 entry level methods. def _get_h5_group(self, path): """Get the H5 entry group. If it doesn't exist, return an empty dict.""" try: if not path: - self._h5_entry_group = self.h5_file["vars"] + self._h5_entry_group = self._verify_group(self.h5_file["vars"]) else: - self._h5_entry_group = self.h5_file[path]["vars"] + self._h5_entry_group = self._verify_group( + self._verify_group(self.h5_file[path])["vars"] + ) self._h5_entry_initialized = True except KeyError: self._h5_entry_group = {} @@ -1392,33 +1407,17 @@ def __len__(self): return total_len def keys(self): - keys = set(self._h5_entry_group.keys()) + keys = [] + current_shard_keys = list(self._h5_entry_group.keys()) for filename in self.current_shard_filenames: if filename == self.current_shard_path.name: - continue - self._switch_h5_file(filename, mode="r") - keys.update(self._h5_entry_group.keys()) + keys += current_shard_keys + else: + self._switch_h5_file(filename, mode="r") + keys += list(self._h5_entry_group.keys()) self._restore_h5_file() return keys - def items(self): - yield from self._h5_entry_group.items() - for filename in self.current_shard_filenames: - if filename == self.current_shard_path.name: - continue - self._switch_h5_file(filename, mode="r") - yield from self._h5_entry_group.items() - self._restore_h5_file() - - def values(self): - yield from self._h5_entry_group.values() - for filename in self.current_shard_filenames: - if filename == self.current_shard_path.name: - continue - self._switch_h5_file(filename, mode="r") - yield from self._h5_entry_group.values() - self._restore_h5_file() - def __getitem__(self, key): if key in self._h5_entry_group: return super().__getitem__(key)
keras/src/saving/saving_lib_test.py+0 −18 modified@@ -1319,24 +1319,6 @@ def test_sharded_h5_io_store_basics(self): for key in ["a", "b"]: self.assertIn(key, vars_store.keys()) - # Items. - for key, value in vars_store.items(): - if key == "a": - self.assertAllClose(value, a) - elif key == "b": - self.assertAllClose(value, b) - else: - raise ValueError(f"Unexpected key: {key}") - - # Values. - for value in vars_store.values(): - if backend.standardize_dtype(value.dtype) == "float32": - self.assertAllClose(value, a) - elif backend.standardize_dtype(value.dtype) == "int32": - self.assertAllClose(value, b) - else: - raise ValueError(f"Unexpected value: {value}") - def test_sharded_h5_io_store_exception_raised(self): temp_filepath = Path(os.path.join(self.get_temp_dir(), "store.h5"))
keras/src/version.py+1 −1 modified@@ -1,7 +1,7 @@ from keras.src.api_export import keras_export # Unique source of truth for the version number. -__version__ = "3.12.0" +__version__ = "3.12.1" @keras_export("keras.version")
7360d4f0d764 — Fix DoS via malicious HDF5 dataset metadata in KerasFileEditor (#21880)
1 file changed · +81 −6
keras/src/saving/file_editor.py+81 −6 modified@@ -455,33 +455,108 @@ def resave_weights(self, filepath): def _extract_weights_from_store(self, data, metadata=None, inner_path=""): metadata = metadata or {} + # ------------------------------------------------------ + # Collect metadata for this HDF5 group + # ------------------------------------------------------ object_metadata = {} for k, v in data.attrs.items(): object_metadata[k] = v if object_metadata: metadata[inner_path] = object_metadata result = collections.OrderedDict() + + # ------------------------------------------------------ + # Iterate over all keys in this HDF5 group + # ------------------------------------------------------ for key in data.keys(): - inner_path = f"{inner_path}/{key}" + # IMPORTANT: + # Never mutate inner_path; use local variable. + current_inner_path = f"{inner_path}/{key}" value = data[key] + + # ------------------------------------------------------ + # CASE 1 — HDF5 GROUP → RECURSE + # ------------------------------------------------------ if isinstance(value, h5py.Group): + # Skip empty groups if len(value) == 0: continue + + # Skip empty "vars" groups if "vars" in value.keys() and len(value["vars"]) == 0: continue - if hasattr(value, "keys"): + # Recurse into "vars" subgroup when present if "vars" in value.keys(): result[key], metadata = self._extract_weights_from_store( - value["vars"], metadata=metadata, inner_path=inner_path + value["vars"], + metadata=metadata, + inner_path=current_inner_path, ) else: + # Recurse normally result[key], metadata = self._extract_weights_from_store( - value, metadata=metadata, inner_path=inner_path + value, + metadata=metadata, + inner_path=current_inner_path, ) - else: - result[key] = value[()] + + continue # finished processing this key + + # ------------------------------------------------------ + # CASE 2 — HDF5 DATASET → SAFE LOADING + # ------------------------------------------------------ + + # Skip any objects that are not proper 
datasets + if not hasattr(value, "shape") or not hasattr(value, "dtype"): + continue + + shape = value.shape + dtype = value.dtype + + # ------------------------------------------------------ + # Validate SHAPE (avoid malformed / malicious metadata) + # ------------------------------------------------------ + + # No negative dimensions + if any(dim < 0 for dim in shape): + raise ValueError( + "Malformed HDF5 dataset shape encountered in .keras file; " + "negative dimension detected." + ) + + # Prevent absurdly high-rank tensors + if len(shape) > 64: + raise ValueError( + "Malformed HDF5 dataset shape encountered in .keras file; " + "tensor rank exceeds safety limit." + ) + + # Safe product computation (Python int is unbounded) + num_elems = int(np.prod(shape)) + + # ------------------------------------------------------ + # Validate TOTAL memory size + # ------------------------------------------------------ + MAX_BYTES = 1 << 32 # 4 GiB + + size_bytes = num_elems * dtype.itemsize + + if size_bytes > MAX_BYTES: + raise ValueError( + f"HDF5 dataset too large to load safely " + f"({size_bytes} bytes; limit is {MAX_BYTES})." + ) + + # ------------------------------------------------------ + # SAFE — load dataset (guaranteed ≤ 4 GiB) + # ------------------------------------------------------ + result[key] = value[()] + + # ------------------------------------------------------ + # Return final tree and metadata + # ------------------------------------------------------ return result, metadata def _generate_filepath_info(self, rich_style=False):
Vulnerability mechanics
Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References
- github.com/advisories/GHSA-mgx6-5cf9-rr43 (GHSA — ADVISORY)
- nvd.nist.gov/vuln/detail/CVE-2026-0897 (GHSA — ADVISORY)
- github.com/keras-team/keras/commit/7360d4f0d764fbb1fa9c6408fe53da41974dd4f6 (GHSA — WEB)
- github.com/keras-team/keras/commit/f704c887bf459b42769bfc8a9182f838009afddb (GHSA — WEB)
- github.com/keras-team/keras/pull/21880 (GHSA — WEB)
- github.com/keras-team/keras/pull/22081 (GHSA — WEB)
- github.com/keras-team/keras/security/advisories/GHSA-mgx6-5cf9-rr43 (GHSA — WEB)
News mentions
No linked articles in our index yet.