VYPR
High severity · NVD Advisory · Published Sep 8, 2025 · Updated Sep 9, 2025

MONAI's unsafe use of Pickle deserialization may lead to RCE

CVE-2025-58757

Description

MONAI (Medical Open Network for AI) is an AI toolkit for health care imaging. In versions up to and including 1.5.0, the pickle_operations function in monai/data/utils.py automatically handles dictionary key-value pairs ending with a specific suffix and deserializes them using pickle.loads(). This function also lacks any security measures. The deserialization may lead to code execution. As of time of publication, no known fixed versions are available.

Affected packages

Versions sourced from the GitHub Security Advisory.

Package | Affected versions | Patched versions
monai (PyPI) | < 1.5.1 | 1.5.1

Affected products

1

Patches

1
948fbb703adc

Torch and Pickle Safe Load Fixes (#8566)

https://github.com/Project-MONAI/MONAI · Eric Kerfoot · Sep 17, 2025 · via ghsa
12 files changed · +95 -77
  • monai/apps/nnunet/nnunet_bundle.py · +17 -9 · modified
    @@ -133,7 +133,7 @@ def get_nnunet_trainer(
             cudnn.benchmark = True
     
         if pretrained_model is not None:
    -        state_dict = torch.load(pretrained_model)
    +        state_dict = torch.load(pretrained_model, weights_only=True)
             if "network_weights" in state_dict:
                 nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
         return nnunet_trainer
    @@ -182,7 +182,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
             parameters = []
     
             checkpoint = torch.load(
    -            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
    +            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"),
    +            map_location=torch.device("cpu"),
    +            weights_only=True,
             )
             trainer_name = checkpoint["trainer_name"]
             configuration_name = checkpoint["init_args"]["configuration"]
    @@ -192,7 +194,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
                 else None
             )
             if Path(model_training_output_dir).joinpath(model_name).is_file():
    -            monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
    +            monai_checkpoint = torch.load(
    +                join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True
    +            )
                 if "network_weights" in monai_checkpoint.keys():
                     parameters.append(monai_checkpoint["network_weights"])
                 else:
    @@ -383,8 +387,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str,
             dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
         )
     
    -    nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
    -    nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
    +    nnunet_checkpoint_final = torch.load(
    +        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True
    +    )
    +    nnunet_checkpoint_best = torch.load(
    +        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True
    +    )
     
         nnunet_checkpoint = {}
         nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
    @@ -470,7 +478,7 @@ def get_network_from_nnunet_plans(
         if model_ckpt is None:
             return network
         else:
    -        state_dict = torch.load(model_ckpt)
    +        state_dict = torch.load(model_ckpt, weights_only=True)
             network.load_state_dict(state_dict[model_key_in_ckpt])
             return network
     
    @@ -534,7 +542,7 @@ def subfiles(
     
         Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)
     
    -    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
    +    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=True)
         latest_checkpoints: list[str] = subfiles(
             Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True
         )
    @@ -545,7 +553,7 @@ def subfiles(
         epochs.sort()
         final_epoch: int = epochs[-1]
         monai_last_checkpoint: dict = torch.load(
    -        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
    +        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=True
         )
     
         best_checkpoints: list[str] = subfiles(
    @@ -558,7 +566,7 @@ def subfiles(
         key_metrics.sort()
         best_key_metric: str = key_metrics[-1]
         monai_best_checkpoint: dict = torch.load(
    -        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
    +        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=True
         )
     
         nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
    
  • monai/data/dataset.py · +34 -16 · modified
    @@ -13,7 +13,6 @@
     
     import collections.abc
     import math
    -import pickle
     import shutil
     import sys
     import tempfile
    @@ -22,9 +21,11 @@
     import warnings
     from collections.abc import Callable, Sequence
     from copy import copy, deepcopy
    +from io import BytesIO
     from multiprocessing.managers import ListProxy
     from multiprocessing.pool import ThreadPool
     from pathlib import Path
    +from pickle import UnpicklingError
     from typing import IO, TYPE_CHECKING, Any, cast
     
     import numpy as np
    @@ -207,6 +208,11 @@ class PersistentDataset(Dataset):
             not guaranteed, so caution should be used when modifying transforms to avoid unexpected
             errors. If in doubt, it is advisable to clear the cache directory.
     
    +        Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will
    +        be converted to tensors, however any other object type returned by transforms will not be loadable since
    +        `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
    +        Legacy cache files may not be loadable and may need to be recomputed.
    +
         Lazy Resampling:
             If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
             its documentation to familiarize yourself with the interaction between `PersistentDataset` and
    @@ -248,8 +254,8 @@ def __init__(
                     this arg is used by `torch.save`, for more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                     and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
    -            pickle_protocol: can be specified to override the default protocol, default to `2`.
    -                this arg is used by `torch.save`, for more details, please check:
    +            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
    +                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
                 hash_transform: a callable to compute hash from the transform information when caching.
                     This may reduce errors due to transforms changing during experiments. Default to None (no hash).
    @@ -371,12 +377,12 @@ def _cachecheck(self, item_transformed):
     
             if hashfile is not None and hashfile.is_file():  # cache hit
                 try:
    -                return torch.load(hashfile, weights_only=False)
    +                return torch.load(hashfile, weights_only=True)
                 except PermissionError as e:
                     if sys.platform != "win32":
                         raise e
    -            except RuntimeError as e:
    -                if "Invalid magic number; corrupt file" in str(e):
    +            except (UnpicklingError, RuntimeError) as e:  # corrupt or unloadable cached files are recomputed
    +                if "Invalid magic number; corrupt file" in str(e) or isinstance(e, UnpicklingError):
                         warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.")
                         hashfile.unlink()
                     else:
    @@ -392,7 +398,7 @@ def _cachecheck(self, item_transformed):
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     temp_hash_file = Path(tmpdirname) / hashfile.name
                     torch.save(
    -                    obj=_item_transformed,
    +                    obj=convert_to_tensor(_item_transformed, convert_numeric=False),
                         f=temp_hash_file,
                         pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                         pickle_protocol=self.pickle_protocol,
    @@ -455,8 +461,8 @@ def __init__(
                     this arg is used by `torch.save`, for more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                     and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
    -            pickle_protocol: can be specified to override the default protocol, default to `2`.
    -                this arg is used by `torch.save`, for more details, please check:
    +            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
    +                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
                 hash_transform: a callable to compute hash from the transform information when caching.
                     This may reduce errors due to transforms changing during experiments. Default to None (no hash).
    @@ -531,7 +537,7 @@ def __init__(
             hash_func: Callable[..., bytes] = pickle_hashing,
             db_name: str = "monai_cache",
             progress: bool = True,
    -        pickle_protocol=pickle.HIGHEST_PROTOCOL,
    +        pickle_protocol=DEFAULT_PROTOCOL,
             hash_transform: Callable[..., bytes] | None = None,
             reset_ops_id: bool = True,
             lmdb_kwargs: dict | None = None,
    @@ -551,8 +557,9 @@ def __init__(
                     defaults to `monai.data.utils.pickle_hashing`.
                 db_name: lmdb database file name. Defaults to "monai_cache".
                 progress: whether to display a progress bar.
    -            pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL.
    -                https://docs.python.org/3/library/pickle.html#pickle-protocols
    +            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
    +                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
    +                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
                 hash_transform: a callable to compute hash from the transform information when caching.
                     This may reduce errors due to transforms changing during experiments. Default to None (no hash).
                     Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
    @@ -594,6 +601,15 @@ def set_data(self, data: Sequence):
             super().set_data(data=data)
             self._read_env = self._fill_cache_start_reader(show_progress=self.progress)
     
    +    def _safe_serialize(self, val):
    +        out = BytesIO()
    +        torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol)
    +        out.seek(0)
    +        return out.read()
    +
    +    def _safe_deserialize(self, val):
    +        return torch.load(BytesIO(val), map_location="cpu", weights_only=True)
    +
         def _fill_cache_start_reader(self, show_progress=True):
             """
             Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
    @@ -619,7 +635,8 @@ def _fill_cache_start_reader(self, show_progress=True):
                                 continue
                             if val is None:
                                 val = self._pre_transform(deepcopy(item))  # keep the original hashed
    -                            val = pickle.dumps(val, protocol=self.pickle_protocol)
    +                            # val = pickle.dumps(val, protocol=self.pickle_protocol)
    +                            val = self._safe_serialize(val)
                             with env.begin(write=True) as txn:
                                 txn.put(key, val)
                             done = True
    @@ -664,7 +681,8 @@ def _cachecheck(self, item_transformed):
                 warnings.warn("LMDBDataset: cache key not found, running fallback caching.")
                 return super()._cachecheck(item_transformed)
             try:
    -            return pickle.loads(data)
    +            # return pickle.loads(data)
    +            return self._safe_deserialize(data)
             except Exception as err:
                 raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err
     
    @@ -1650,7 +1668,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
                     meta_hash_file = self.cache_dir / meta_hash_file_name
                     temp_hash_file = Path(tmpdirname) / meta_hash_file_name
                     torch.save(
    -                    obj=self._meta_cache[meta_hash_file_name],
    +                    obj=convert_to_tensor(self._meta_cache[meta_hash_file_name], convert_numeric=False),
                         f=temp_hash_file,
                         pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                         pickle_protocol=self.pickle_protocol,
    @@ -1670,4 +1688,4 @@ def _load_meta_cache(self, meta_hash_file_name):
             if meta_hash_file_name in self._meta_cache:
                 return self._meta_cache[meta_hash_file_name]
             else:
    -            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
    +            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True)
    
  • monai/data/__init__.py · +0 -1 · modified
    @@ -78,7 +78,6 @@
     from .thread_buffer import ThreadBuffer, ThreadDataLoader
     from .torchscript_utils import load_net_with_metadata, save_net_with_metadata
     from .utils import (
    -    PICKLE_KEY_SUFFIX,
         affine_to_spacing,
         compute_importance_map,
         compute_shape_offset,
    
  • monai/data/meta_tensor.py · +1 -1 · modified
    @@ -611,4 +611,4 @@ def print_verbose(self) -> None:
     
     # needed in later versions of Pytorch to indicate the class is safe for serialisation
     if hasattr(torch.serialization, "add_safe_globals"):
    -    torch.serialization.add_safe_globals([MetaTensor])
    +    torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys])
    
  • monai/data/utils.py · +10 -36 · modified
    @@ -30,7 +30,6 @@
     import torch
     from torch.utils.data._utils.collate import default_collate
     
    -from monai import config
     from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
     from monai.data.meta_obj import MetaObj
     from monai.utils import (
    @@ -93,7 +92,6 @@
         "remove_keys",
         "remove_extra_metadata",
         "get_extra_metadata_keys",
    -    "PICKLE_KEY_SUFFIX",
         "is_no_channel",
     ]
     
    @@ -418,32 +416,6 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"):
         return
     
     
    -PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX
    -
    -
    -def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True):
    -    """
    -    Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate.
    -
    -    Args:
    -        data: a list or dictionary with substructures to be pickled/unpickled.
    -        key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`).
    -        is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False).
    -    """
    -    if isinstance(data, Mapping):
    -        data = dict(data)
    -        for k in data:
    -            if f"{k}".endswith(key):
    -                if is_encode and not isinstance(data[k], bytes):
    -                    data[k] = pickle.dumps(data[k], 0)
    -                if not is_encode and isinstance(data[k], bytes):
    -                    data[k] = pickle.loads(data[k])
    -        return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()}
    -    elif isinstance(data, (list, tuple)):
    -        return [pickle_operations(item, key=key, is_encode=is_encode) for item in data]
    -    return data
    -
    -
     def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
         """
         Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
    @@ -500,8 +472,8 @@ def list_data_collate(batch: Sequence):
         key = None
         collate_fn = default_collate
         try:
    -        if config.USE_META_DICT:
    -            data = pickle_operations(data)  # bc 0.9.0
    +        # if config.USE_META_DICT:
    +        # data = pickle_operations(data)  # bc 0.9.0
             if isinstance(elem, Mapping):
                 ret = {}
                 for k in elem:
    @@ -654,15 +626,17 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
         if isinstance(deco, Mapping):
             _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values())
             ret = [dict(zip(deco, item)) for item in _gen]
    -        if not config.USE_META_DICT:
    -            return ret
    -        return pickle_operations(ret, is_encode=False)  # bc 0.9.0
    +        # if not config.USE_META_DICT:
    +        # return ret
    +        # return pickle_operations(ret, is_encode=False)  # bc 0.9.0
    +        return ret
         if isinstance(deco, Iterable):
             _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco)
             ret_list = [list(item) for item in _gen]
    -        if not config.USE_META_DICT:
    -            return ret_list
    -        return pickle_operations(ret_list, is_encode=False)  # bc 0.9.0
    +        # if not config.USE_META_DICT:
    +        # return ret_list
    +        # return pickle_operations(ret_list, is_encode=False)  # bc 0.9.0
    +        return ret_list
         raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.")
     
     
    
  • monai/handlers/checkpoint_loader.py · +1 -1 · modified
    @@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None:
             Args:
                 engine: Ignite Engine, it can be a trainer, validator or evaluator.
             """
    -        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)
    +        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True)
     
             k, _ = list(self.load_dict.items())[0]
             # single object and checkpoint is directly a state_dict
    
  • monai/utils/state_cacher.py · +3 -3 · modified
    @@ -64,8 +64,8 @@ def __init__(
                 pickle_module: module used for pickling metadata and objects, default to `pickle`.
                     this arg is used by `torch.save`, for more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
    -            pickle_protocol: can be specified to override the default protocol, default to `2`.
    -                this arg is used by `torch.save`, for more details, please check:
    +            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
    +                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
     
             """
    @@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any:
             fn = self.cached[key]["obj"]  # pytype: disable=attribute-error
             if not os.path.exists(fn):  # pytype: disable=wrong-arg-types
                 raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.")
    -        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False)
    +        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True)
             # copy back to device if necessary
             if "device" in self.cached[key]:
                 data_obj = data_obj.to(self.cached[key]["device"])
    
  • monai/utils/type_conversion.py · +16 -4 · modified
    @@ -117,6 +117,7 @@ def convert_to_tensor(
         wrap_sequence: bool = False,
         track_meta: bool = False,
         safe: bool = False,
    +    convert_numeric: bool = True,
     ) -> Any:
         """
         Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`,
    @@ -136,6 +137,7 @@ def convert_to_tensor(
             safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
                 E.g., `[256, -12]` -> `[tensor(0), tensor(244)]`.
                 If `True`, then `[256, -12]` -> `[tensor(255), tensor(0)]`.
    +        convert_numeric: if `True`, convert numeric Python values to tensors.
     
         """
     
    @@ -156,6 +158,7 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
         if safe:
             data = safe_dtype_range(data, dtype)
         dtype = get_equivalent_dtype(dtype, torch.Tensor)
    +
         if isinstance(data, torch.Tensor):
             return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
         if isinstance(data, np.ndarray):
    @@ -167,16 +170,25 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
                 if data.ndim > 0:
                     data = np.ascontiguousarray(data)
                 return _convert_tensor(data, dtype=dtype, device=device)
    -    elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)):
    +    elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))):
             return _convert_tensor(data, dtype=dtype, device=device)
         elif isinstance(data, list):
    -        list_ret = [convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data]
    +        list_ret = [
    +            convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
    +            for i in data
    +        ]
             return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret
         elif isinstance(data, tuple):
    -        tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data)
    +        tuple_ret = tuple(
    +            convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
    +            for i in data
    +        )
             return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret
         elif isinstance(data, dict):
    -        return {k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta) for k, v in data.items()}
    +        return {
    +            k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
    +            for k, v in data.items()
    +        }
     
         return data
     
    
  • tests/data/meta_tensor/test_meta_tensor.py · +1 -1 · modified
    @@ -245,7 +245,7 @@ def test_pickling(self):
             with tempfile.TemporaryDirectory() as tmp_dir:
                 fname = os.path.join(tmp_dir, "im.pt")
                 torch.save(m, fname)
    -            m2 = torch.load(fname, weights_only=False)
    +            m2 = torch.load(fname, weights_only=True)
             self.check(m2, m, ids=False)
     
         @skip_if_no_cuda
    
  • tests/data/test_gdsdataset.py · +2 -2 · modified
    @@ -12,7 +12,6 @@
     from __future__ import annotations
     
     import os
    -import pickle
     import tempfile
     import unittest
     
    @@ -86,7 +85,8 @@ def test_cache(self):
                         cache_dir=tempdir,
                         device=0,
                         pickle_module="pickle",
    -                    pickle_protocol=pickle.HIGHEST_PROTOCOL,
    +                    # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility
    +                    pickle_protocol=torch.serialization.DEFAULT_PROTOCOL,
                     )
                     assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape)))
                     ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
    
  • tests/data/test_persistentdataset.py · +3 -2 · modified
    @@ -12,12 +12,12 @@
     from __future__ import annotations
     
     import os
    -import pickle
     import tempfile
     import unittest
     
     import nibabel as nib
     import numpy as np
    +import torch
     from parameterized import parameterized
     
     from monai.data import PersistentDataset, json_hashing
    @@ -66,7 +66,8 @@ def test_cache(self):
                     transform=_InplaceXform(),
                     cache_dir=tempdir,
                     pickle_module="pickle",
    -                pickle_protocol=pickle.HIGHEST_PROTOCOL,
    +                # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility
    +                pickle_protocol=torch.serialization.DEFAULT_PROTOCOL,
                 )
                 self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
                 ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir)
    
  • tests/utils/test_state_cacher.py · +7 -1 · modified
    @@ -27,7 +27,13 @@
     TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}]
     TEST_CASE_1 = [
         torch.Tensor([1]).to(DEVICE),
    -    {"in_memory": False, "cache_dir": gettempdir(), "pickle_module": None, "pickle_protocol": pickle.HIGHEST_PROTOCOL},
    +    {
    +        "in_memory": False,
    +        "cache_dir": gettempdir(),
    +        "pickle_module": None,
    +        # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility
    +        "pickle_protocol": torch.serialization.DEFAULT_PROTOCOL,
    +    },
     ]
     TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}]
     TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}]
    

Vulnerability mechanics

Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References

5

News mentions

0

No linked articles in our index yet.