MONAI's unsafe torch usage may lead to arbitrary code execution
Description
MONAI (Medical Open Network for AI) is an AI toolkit for health care imaging. In versions up to and including 1.5.0, the call `model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)` in `monai/bundle/scripts.py` loads data securely because it passes `weights_only=True`. However, insecure loading methods still exist elsewhere in the project, such as when loading checkpoints. Loading checkpoints is a common practice when users want to reduce training time and costs by using pre-trained models downloaded from other platforms. Loading a checkpoint containing malicious content can trigger a deserialization vulnerability, leading to code execution. As of the time of publication, no known fixed versions are available.
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| monai (PyPI) | < 1.5.1 | 1.5.1 |
Affected products
1 affected product:
- Range: <= 1.5.0
Patches
1 patch: 948fbb703adc — Torch and Pickle Safe Load Fixes (#8566)
12 files changed · +95 −77
monai/apps/nnunet/nnunet_bundle.py+17 −9 modified@@ -133,7 +133,7 @@ def get_nnunet_trainer( cudnn.benchmark = True if pretrained_model is not None: - state_dict = torch.load(pretrained_model) + state_dict = torch.load(pretrained_model, weights_only=True) if "network_weights" in state_dict: nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"]) return nnunet_trainer @@ -182,7 +182,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name parameters = [] checkpoint = torch.load( - join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") + join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), + map_location=torch.device("cpu"), + weights_only=True, ) trainer_name = checkpoint["trainer_name"] configuration_name = checkpoint["init_args"]["configuration"] @@ -192,7 +194,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name else None ) if Path(model_training_output_dir).joinpath(model_name).is_file(): - monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu")) + monai_checkpoint = torch.load( + join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True + ) if "network_weights" in monai_checkpoint.keys(): parameters.append(monai_checkpoint["network_weights"]) else: @@ -383,8 +387,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str, dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}" ) - nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth")) - nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth")) + nnunet_checkpoint_final = torch.load( + Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True + ) + nnunet_checkpoint_best = torch.load( 
+ Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True + ) nnunet_checkpoint = {} nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"] @@ -470,7 +478,7 @@ def get_network_from_nnunet_plans( if model_ckpt is None: return network else: - state_dict = torch.load(model_ckpt) + state_dict = torch.load(model_ckpt, weights_only=True) network.load_state_dict(state_dict[model_key_in_ckpt]) return network @@ -534,7 +542,7 @@ def subfiles( Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True) - nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth") + nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=True) latest_checkpoints: list[str] = subfiles( Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True ) @@ -545,7 +553,7 @@ def subfiles( epochs.sort() final_epoch: int = epochs[-1] monai_last_checkpoint: dict = torch.load( - f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt" + f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=True ) best_checkpoints: list[str] = subfiles( @@ -558,7 +566,7 @@ def subfiles( key_metrics.sort() best_key_metric: str = key_metrics[-1] monai_best_checkpoint: dict = torch.load( - f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt" + f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=True ) nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
monai/data/dataset.py+34 −16 modified@@ -13,7 +13,6 @@ import collections.abc import math -import pickle import shutil import sys import tempfile @@ -22,9 +21,11 @@ import warnings from collections.abc import Callable, Sequence from copy import copy, deepcopy +from io import BytesIO from multiprocessing.managers import ListProxy from multiprocessing.pool import ThreadPool from pathlib import Path +from pickle import UnpicklingError from typing import IO, TYPE_CHECKING, Any, cast import numpy as np @@ -207,6 +208,11 @@ class PersistentDataset(Dataset): not guaranteed, so caution should be used when modifying transforms to avoid unexpected errors. If in doubt, it is advisable to clear the cache directory. + Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will + be converted to tensors, however any other object type returned by transforms will not be loadable since + `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects. + Legacy cache files may not be loadable and may need to be recomputed. + Lazy Resampling: If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to its documentation to familiarize yourself with the interaction between `PersistentDataset` and @@ -248,8 +254,8 @@ def __init__( this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. - pickle_protocol: can be specified to override the default protocol, default to `2`. - this arg is used by `torch.save`, for more details, please check: + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. 
hash_transform: a callable to compute hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Default to None (no hash). @@ -371,12 +377,12 @@ def _cachecheck(self, item_transformed): if hashfile is not None and hashfile.is_file(): # cache hit try: - return torch.load(hashfile, weights_only=False) + return torch.load(hashfile, weights_only=True) except PermissionError as e: if sys.platform != "win32": raise e - except RuntimeError as e: - if "Invalid magic number; corrupt file" in str(e): + except (UnpicklingError, RuntimeError) as e: # corrupt or unloadable cached files are recomputed + if "Invalid magic number; corrupt file" in str(e) or isinstance(e, UnpicklingError): warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.") hashfile.unlink() else: @@ -392,7 +398,7 @@ def _cachecheck(self, item_transformed): with tempfile.TemporaryDirectory() as tmpdirname: temp_hash_file = Path(tmpdirname) / hashfile.name torch.save( - obj=_item_transformed, + obj=convert_to_tensor(_item_transformed, convert_numeric=False), f=temp_hash_file, pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, @@ -455,8 +461,8 @@ def __init__( this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. - pickle_protocol: can be specified to override the default protocol, default to `2`. - this arg is used by `torch.save`, for more details, please check: + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. 
This may reduce errors due to transforms changing during experiments. Default to None (no hash). @@ -531,7 +537,7 @@ def __init__( hash_func: Callable[..., bytes] = pickle_hashing, db_name: str = "monai_cache", progress: bool = True, - pickle_protocol=pickle.HIGHEST_PROTOCOL, + pickle_protocol=DEFAULT_PROTOCOL, hash_transform: Callable[..., bytes] | None = None, reset_ops_id: bool = True, lmdb_kwargs: dict | None = None, @@ -551,8 +557,9 @@ def __init__( defaults to `monai.data.utils.pickle_hashing`. db_name: lmdb database file name. Defaults to "monai_cache". progress: whether to display a progress bar. - pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL. - https://docs.python.org/3/library/pickle.html#pickle-protocols + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Default to None (no hash). Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`. @@ -594,6 +601,15 @@ def set_data(self, data: Sequence): super().set_data(data=data) self._read_env = self._fill_cache_start_reader(show_progress=self.progress) + def _safe_serialize(self, val): + out = BytesIO() + torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol) + out.seek(0) + return out.read() + + def _safe_deserialize(self, val): + return torch.load(BytesIO(val), map_location="cpu", weights_only=True) + def _fill_cache_start_reader(self, show_progress=True): """ Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write. 
@@ -619,7 +635,8 @@ def _fill_cache_start_reader(self, show_progress=True): continue if val is None: val = self._pre_transform(deepcopy(item)) # keep the original hashed - val = pickle.dumps(val, protocol=self.pickle_protocol) + # val = pickle.dumps(val, protocol=self.pickle_protocol) + val = self._safe_serialize(val) with env.begin(write=True) as txn: txn.put(key, val) done = True @@ -664,7 +681,8 @@ def _cachecheck(self, item_transformed): warnings.warn("LMDBDataset: cache key not found, running fallback caching.") return super()._cachecheck(item_transformed) try: - return pickle.loads(data) + # return pickle.loads(data) + return self._safe_deserialize(data) except Exception as err: raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err @@ -1650,7 +1668,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): meta_hash_file = self.cache_dir / meta_hash_file_name temp_hash_file = Path(tmpdirname) / meta_hash_file_name torch.save( - obj=self._meta_cache[meta_hash_file_name], + obj=convert_to_tensor(self._meta_cache[meta_hash_file_name], convert_numeric=False), f=temp_hash_file, pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, @@ -1670,4 +1688,4 @@ def _load_meta_cache(self, meta_hash_file_name): if meta_hash_file_name in self._meta_cache: return self._meta_cache[meta_hash_file_name] else: - return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False) + return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True)
monai/data/__init__.py+0 −1 modified@@ -78,7 +78,6 @@ from .thread_buffer import ThreadBuffer, ThreadDataLoader from .torchscript_utils import load_net_with_metadata, save_net_with_metadata from .utils import ( - PICKLE_KEY_SUFFIX, affine_to_spacing, compute_importance_map, compute_shape_offset,
monai/data/meta_tensor.py+1 −1 modified@@ -611,4 +611,4 @@ def print_verbose(self) -> None: # needed in later versions of Pytorch to indicate the class is safe for serialisation if hasattr(torch.serialization, "add_safe_globals"): - torch.serialization.add_safe_globals([MetaTensor]) + torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys])
monai/data/utils.py+10 −36 modified@@ -30,7 +30,6 @@ import torch from torch.utils.data._utils.collate import default_collate -from monai import config from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike from monai.data.meta_obj import MetaObj from monai.utils import ( @@ -93,7 +92,6 @@ "remove_keys", "remove_extra_metadata", "get_extra_metadata_keys", - "PICKLE_KEY_SUFFIX", "is_no_channel", ] @@ -418,32 +416,6 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"): return -PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX - - -def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): - """ - Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate. - - Args: - data: a list or dictionary with substructures to be pickled/unpickled. - key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`). - is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False). - """ - if isinstance(data, Mapping): - data = dict(data) - for k in data: - if f"{k}".endswith(key): - if is_encode and not isinstance(data[k], bytes): - data[k] = pickle.dumps(data[k], 0) - if not is_encode and isinstance(data[k], bytes): - data[k] = pickle.loads(data[k]) - return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()} - elif isinstance(data, (list, tuple)): - return [pickle_operations(item, key=key, is_encode=is_encode) for item in data] - return data - - def collate_meta_tensor_fn(batch, *, collate_fn_map=None): """ Collate a sequence of meta tensor into a single batched metatensor. 
This is called by `collage_meta_tensor` @@ -500,8 +472,8 @@ def list_data_collate(batch: Sequence): key = None collate_fn = default_collate try: - if config.USE_META_DICT: - data = pickle_operations(data) # bc 0.9.0 + # if config.USE_META_DICT: + # data = pickle_operations(data) # bc 0.9.0 if isinstance(elem, Mapping): ret = {} for k in elem: @@ -654,15 +626,17 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if isinstance(deco, Mapping): _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values()) ret = [dict(zip(deco, item)) for item in _gen] - if not config.USE_META_DICT: - return ret - return pickle_operations(ret, is_encode=False) # bc 0.9.0 + # if not config.USE_META_DICT: + # return ret + # return pickle_operations(ret, is_encode=False) # bc 0.9.0 + return ret if isinstance(deco, Iterable): _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco) ret_list = [list(item) for item in _gen] - if not config.USE_META_DICT: - return ret_list - return pickle_operations(ret_list, is_encode=False) # bc 0.9.0 + # if not config.USE_META_DICT: + # return ret_list + # return pickle_operations(ret_list, is_encode=False) # bc 0.9.0 + return ret_list raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.")
monai/handlers/checkpoint_loader.py+1 −1 modified@@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False) + checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True) k, _ = list(self.load_dict.items())[0] # single object and checkpoint is directly a state_dict
monai/utils/state_cacher.py+3 −3 modified@@ -64,8 +64,8 @@ def __init__( pickle_module: module used for pickling metadata and objects, default to `pickle`. this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. - pickle_protocol: can be specified to override the default protocol, default to `2`. - this arg is used by `torch.save`, for more details, please check: + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. """ @@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any: fn = self.cached[key]["obj"] # pytype: disable=attribute-error if not os.path.exists(fn): # pytype: disable=wrong-arg-types raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.") - data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False) + data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True) # copy back to device if necessary if "device" in self.cached[key]: data_obj = data_obj.to(self.cached[key]["device"])
monai/utils/type_conversion.py+16 −4 modified@@ -117,6 +117,7 @@ def convert_to_tensor( wrap_sequence: bool = False, track_meta: bool = False, safe: bool = False, + convert_numeric: bool = True, ) -> Any: """ Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`, @@ -136,6 +137,7 @@ def convert_to_tensor( safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`. E.g., `[256, -12]` -> `[tensor(0), tensor(244)]`. If `True`, then `[256, -12]` -> `[tensor(255), tensor(0)]`. + convert_numeric: if `True`, convert numeric Python values to tensors. """ @@ -156,6 +158,7 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any: if safe: data = safe_dtype_range(data, dtype) dtype = get_equivalent_dtype(dtype, torch.Tensor) + if isinstance(data, torch.Tensor): return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format) if isinstance(data, np.ndarray): @@ -167,16 +170,25 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any: if data.ndim > 0: data = np.ascontiguousarray(data) return _convert_tensor(data, dtype=dtype, device=device) - elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)): + elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))): return _convert_tensor(data, dtype=dtype, device=device) elif isinstance(data, list): - list_ret = [convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data] + list_ret = [ + convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric) + for i in data + ] return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret elif isinstance(data, tuple): - tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data) + tuple_ret = tuple( + convert_to_tensor(i, dtype=dtype, 
device=device, track_meta=track_meta, convert_numeric=convert_numeric) + for i in data + ) return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret elif isinstance(data, dict): - return {k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta) for k, v in data.items()} + return { + k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric) + for k, v in data.items() + } return data
tests/data/meta_tensor/test_meta_tensor.py+1 −1 modified@@ -245,7 +245,7 @@ def test_pickling(self): with tempfile.TemporaryDirectory() as tmp_dir: fname = os.path.join(tmp_dir, "im.pt") torch.save(m, fname) - m2 = torch.load(fname, weights_only=False) + m2 = torch.load(fname, weights_only=True) self.check(m2, m, ids=False) @skip_if_no_cuda
tests/data/test_gdsdataset.py+2 −2 modified@@ -12,7 +12,6 @@ from __future__ import annotations import os -import pickle import tempfile import unittest @@ -86,7 +85,8 @@ def test_cache(self): cache_dir=tempdir, device=0, pickle_module="pickle", - pickle_protocol=pickle.HIGHEST_PROTOCOL, + # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility + pickle_protocol=torch.serialization.DEFAULT_PROTOCOL, ) assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
tests/data/test_persistentdataset.py+3 −2 modified@@ -12,12 +12,12 @@ from __future__ import annotations import os -import pickle import tempfile import unittest import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.data import PersistentDataset, json_hashing @@ -66,7 +66,8 @@ def test_cache(self): transform=_InplaceXform(), cache_dir=tempdir, pickle_module="pickle", - pickle_protocol=pickle.HIGHEST_PROTOCOL, + # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility + pickle_protocol=torch.serialization.DEFAULT_PROTOCOL, ) self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir)
tests/utils/test_state_cacher.py+7 −1 modified@@ -27,7 +27,13 @@ TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}] TEST_CASE_1 = [ torch.Tensor([1]).to(DEVICE), - {"in_memory": False, "cache_dir": gettempdir(), "pickle_module": None, "pickle_protocol": pickle.HIGHEST_PROTOCOL}, + { + "in_memory": False, + "cache_dir": gettempdir(), + "pickle_module": None, + # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility + "pickle_protocol": torch.serialization.DEFAULT_PROTOCOL, + }, ] TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}] TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}]
Vulnerability mechanics
Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References
5 references:
- github.com/advisories/GHSA-6vm5-6jv9-rjpj (source: ghsa; type: ADVISORY)
- nvd.nist.gov/vuln/detail/CVE-2025-58756 (source: ghsa; type: ADVISORY)
- github.com/Project-MONAI/MONAI/commit/948fbb703adcb87cd04ebd83d20dcd8d73bf6259 (source: ghsa; type: WEB)
- github.com/Project-MONAI/MONAI/pull/8566 (source: ghsa; type: WEB)
- github.com/Project-MONAI/MONAI/security/advisories/GHSA-6vm5-6jv9-rjpj (source: ghsa, x_refsource_CONFIRM; type: WEB)
News mentions
0 — No linked articles in our index yet.