VYPR
High severity · NVD Advisory · Published Sep 8, 2025 · Updated Sep 9, 2025

MONAI's unsafe torch usage may lead to arbitrary code execution

CVE-2025-58756

Description

MONAI (Medical Open Network for AI) is an AI toolkit for health care imaging. In versions up to and including 1.5.0, the call `model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)` in `monai/bundle/scripts.py` loads model weights securely because `weights_only=True` is set. However, insecure loading calls still exist elsewhere in the project, such as when loading checkpoints. Loading pre-trained checkpoints downloaded from other platforms is a common practice when users want to reduce training time and cost. Loading a checkpoint containing malicious content can trigger a pickle deserialization vulnerability, leading to arbitrary code execution. As of the time of publication, no fixed versions were known (the affected-packages table below was subsequently updated to list 1.5.1 as patched).

Affected packages

Versions sourced from the GitHub Security Advisory.

Package | Affected versions | Patched versions
monai (PyPI) | < 1.5.1 | 1.5.1

Affected products

1

Patches

1
948fbb703adc

Torch and Pickle Safe Load Fixes (#8566)

https://github.com/Project-MONAI/MONAI · Eric Kerfoot · Sep 17, 2025 · via GHSA
12 files changed · +95 −77
  • monai/apps/nnunet/nnunet_bundle.py · +17 −9 · modified
    @@ -133,7 +133,7 @@ def get_nnunet_trainer(
             cudnn.benchmark = True
     
         if pretrained_model is not None:
    -        state_dict = torch.load(pretrained_model)
    +        state_dict = torch.load(pretrained_model, weights_only=True)
             if "network_weights" in state_dict:
                 nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
         return nnunet_trainer
    @@ -182,7 +182,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
             parameters = []
     
             checkpoint = torch.load(
    -            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
    +            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"),
    +            map_location=torch.device("cpu"),
    +            weights_only=True,
             )
             trainer_name = checkpoint["trainer_name"]
             configuration_name = checkpoint["init_args"]["configuration"]
    @@ -192,7 +194,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
                 else None
             )
             if Path(model_training_output_dir).joinpath(model_name).is_file():
    -            monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
    +            monai_checkpoint = torch.load(
    +                join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True
    +            )
                 if "network_weights" in monai_checkpoint.keys():
                     parameters.append(monai_checkpoint["network_weights"])
                 else:
    @@ -383,8 +387,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str,
             dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
         )
     
    -    nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
    -    nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
    +    nnunet_checkpoint_final = torch.load(
    +        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True
    +    )
    +    nnunet_checkpoint_best = torch.load(
    +        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True
    +    )
     
         nnunet_checkpoint = {}
         nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
    @@ -470,7 +478,7 @@ def get_network_from_nnunet_plans(
         if model_ckpt is None:
             return network
         else:
    -        state_dict = torch.load(model_ckpt)
    +        state_dict = torch.load(model_ckpt, weights_only=True)
             network.load_state_dict(state_dict[model_key_in_ckpt])
             return network
     
    @@ -534,7 +542,7 @@ def subfiles(
     
         Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)
     
    -    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
    +    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=True)
         latest_checkpoints: list[str] = subfiles(
             Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True
         )
    @@ -545,7 +553,7 @@ def subfiles(
         epochs.sort()
         final_epoch: int = epochs[-1]
         monai_last_checkpoint: dict = torch.load(
    -        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
    +        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=True
         )
     
         best_checkpoints: list[str] = subfiles(
    @@ -558,7 +566,7 @@ def subfiles(
         key_metrics.sort()
         best_key_metric: str = key_metrics[-1]
         monai_best_checkpoint: dict = torch.load(
    -        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
    +        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=True
         )
     
         nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
    
  • monai/data/dataset.py · +34 −16 · modified
    @@ -13,7 +13,6 @@
     
     import collections.abc
     import math
    -import pickle
     import shutil
     import sys
     import tempfile
    @@ -22,9 +21,11 @@
     import warnings
     from collections.abc import Callable, Sequence
     from copy import copy, deepcopy
    +from io import BytesIO
     from multiprocessing.managers import ListProxy
     from multiprocessing.pool import ThreadPool
     from pathlib import Path
    +from pickle import UnpicklingError
     from typing import IO, TYPE_CHECKING, Any, cast
     
     import numpy as np
    @@ -207,6 +208,11 @@ class PersistentDataset(Dataset):
             not guaranteed, so caution should be used when modifying transforms to avoid unexpected
             errors. If in doubt, it is advisable to clear the cache directory.
     
    +        Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will
    +        be converted to tensors, however any other object type returned by transforms will not be loadable since
    +        `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
    +        Legacy cache files may not be loadable and may need to be recomputed.
    +
         Lazy Resampling:
             If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
             its documentation to familiarize yourself with the interaction between `PersistentDataset` and
    @@ -248,8 +254,8 @@ def __init__(
                     this arg is used by `torch.save`, for more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                     and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
    -            pickle_protocol: can be specified to override the default protocol, default to `2`.
    -                this arg is used by `torch.save`, for more details, please check:
    +            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
    +                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
                 hash_transform: a callable to compute hash from the transform information when caching.
                     This may reduce errors due to transforms changing during experiments. Default to None (no hash).
    @@ -371,12 +377,12 @@ def _cachecheck(self, item_transformed):
     
             if hashfile is not None and hashfile.is_file():  # cache hit
                 try:
    -                return torch.load(hashfile, weights_only=False)
    +                return torch.load(hashfile, weights_only=True)
                 except PermissionError as e:
                     if sys.platform != "win32":
                         raise e
    -            except RuntimeError as e:
    -                if "Invalid magic number; corrupt file" in str(e):
    +            except (UnpicklingError, RuntimeError) as e:  # corrupt or unloadable cached files are recomputed
    +                if "Invalid magic number; corrupt file" in str(e) or isinstance(e, UnpicklingError):
                         warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.")
                         hashfile.unlink()
                     else:
    @@ -392,7 +398,7 @@ def _cachecheck(self, item_transformed):
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     temp_hash_file = Path(tmpdirname) / hashfile.name
                     torch.save(
    -                    obj=_item_transformed,
    +                    obj=convert_to_tensor(_item_transformed, convert_numeric=False),
                         f=temp_hash_file,
                         pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                         pickle_protocol=self.pickle_protocol,
    @@ -455,8 +461,8 @@ def __init__(
                     this arg is used by `torch.save`, for more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                     and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
    -            pickle_protocol: can be specified to override the default protocol, default to `2`.
    -                this arg is used by `torch.save`, for more details, please check:
    +            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
    +                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
                 hash_transform: a callable to compute hash from the transform information when caching.
                     This may reduce errors due to transforms changing during experiments. Default to None (no hash).
    @@ -531,7 +537,7 @@ def __init__(
             hash_func: Callable[..., bytes] = pickle_hashing,
             db_name: str = "monai_cache",
             progress: bool = True,
    -        pickle_protocol=pickle.HIGHEST_PROTOCOL,
    +        pickle_protocol=DEFAULT_PROTOCOL,
             hash_transform: Callable[..., bytes] | None = None,
             reset_ops_id: bool = True,
             lmdb_kwargs: dict | None = None,
    @@ -551,8 +557,9 @@ def __init__(
                     defaults to `monai.data.utils.pickle_hashing`.
                 db_name: lmdb database file name. Defaults to "monai_cache".
                 progress: whether to display a progress bar.
    -            pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL.
    -                https://docs.python.org/3/library/pickle.html#pickle-protocols
    +            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
    +                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
    +                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
                 hash_transform: a callable to compute hash from the transform information when caching.
                     This may reduce errors due to transforms changing during experiments. Default to None (no hash).
                     Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
    @@ -594,6 +601,15 @@ def set_data(self, data: Sequence):
             super().set_data(data=data)
             self._read_env = self._fill_cache_start_reader(show_progress=self.progress)
     
    +    def _safe_serialize(self, val):
    +        out = BytesIO()
    +        torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol)
    +        out.seek(0)
    +        return out.read()
    +
    +    def _safe_deserialize(self, val):
    +        return torch.load(BytesIO(val), map_location="cpu", weights_only=True)
    +
         def _fill_cache_start_reader(self, show_progress=True):
             """
             Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
    @@ -619,7 +635,8 @@ def _fill_cache_start_reader(self, show_progress=True):
                                 continue
                             if val is None:
                                 val = self._pre_transform(deepcopy(item))  # keep the original hashed
    -                            val = pickle.dumps(val, protocol=self.pickle_protocol)
    +                            # val = pickle.dumps(val, protocol=self.pickle_protocol)
    +                            val = self._safe_serialize(val)
                             with env.begin(write=True) as txn:
                                 txn.put(key, val)
                             done = True
    @@ -664,7 +681,8 @@ def _cachecheck(self, item_transformed):
                 warnings.warn("LMDBDataset: cache key not found, running fallback caching.")
                 return super()._cachecheck(item_transformed)
             try:
    -            return pickle.loads(data)
    +            # return pickle.loads(data)
    +            return self._safe_deserialize(data)
             except Exception as err:
                 raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err
     
    @@ -1650,7 +1668,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
                     meta_hash_file = self.cache_dir / meta_hash_file_name
                     temp_hash_file = Path(tmpdirname) / meta_hash_file_name
                     torch.save(
    -                    obj=self._meta_cache[meta_hash_file_name],
    +                    obj=convert_to_tensor(self._meta_cache[meta_hash_file_name], convert_numeric=False),
                         f=temp_hash_file,
                         pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                         pickle_protocol=self.pickle_protocol,
    @@ -1670,4 +1688,4 @@ def _load_meta_cache(self, meta_hash_file_name):
             if meta_hash_file_name in self._meta_cache:
                 return self._meta_cache[meta_hash_file_name]
             else:
    -            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
    +            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True)
    
  • monai/data/__init__.py · +0 −1 · modified
    @@ -78,7 +78,6 @@
     from .thread_buffer import ThreadBuffer, ThreadDataLoader
     from .torchscript_utils import load_net_with_metadata, save_net_with_metadata
     from .utils import (
    -    PICKLE_KEY_SUFFIX,
         affine_to_spacing,
         compute_importance_map,
         compute_shape_offset,
    
  • monai/data/meta_tensor.py · +1 −1 · modified
    @@ -611,4 +611,4 @@ def print_verbose(self) -> None:
     
     # needed in later versions of Pytorch to indicate the class is safe for serialisation
     if hasattr(torch.serialization, "add_safe_globals"):
    -    torch.serialization.add_safe_globals([MetaTensor])
    +    torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys])
    
  • monai/data/utils.py · +10 −36 · modified
    @@ -30,7 +30,6 @@
     import torch
     from torch.utils.data._utils.collate import default_collate
     
    -from monai import config
     from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
     from monai.data.meta_obj import MetaObj
     from monai.utils import (
    @@ -93,7 +92,6 @@
         "remove_keys",
         "remove_extra_metadata",
         "get_extra_metadata_keys",
    -    "PICKLE_KEY_SUFFIX",
         "is_no_channel",
     ]
     
    @@ -418,32 +416,6 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"):
         return
     
     
    -PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX
    -
    -
    -def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True):
    -    """
    -    Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate.
    -
    -    Args:
    -        data: a list or dictionary with substructures to be pickled/unpickled.
    -        key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`).
    -        is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False).
    -    """
    -    if isinstance(data, Mapping):
    -        data = dict(data)
    -        for k in data:
    -            if f"{k}".endswith(key):
    -                if is_encode and not isinstance(data[k], bytes):
    -                    data[k] = pickle.dumps(data[k], 0)
    -                if not is_encode and isinstance(data[k], bytes):
    -                    data[k] = pickle.loads(data[k])
    -        return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()}
    -    elif isinstance(data, (list, tuple)):
    -        return [pickle_operations(item, key=key, is_encode=is_encode) for item in data]
    -    return data
    -
    -
     def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
         """
         Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
    @@ -500,8 +472,8 @@ def list_data_collate(batch: Sequence):
         key = None
         collate_fn = default_collate
         try:
    -        if config.USE_META_DICT:
    -            data = pickle_operations(data)  # bc 0.9.0
    +        # if config.USE_META_DICT:
    +        # data = pickle_operations(data)  # bc 0.9.0
             if isinstance(elem, Mapping):
                 ret = {}
                 for k in elem:
    @@ -654,15 +626,17 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
         if isinstance(deco, Mapping):
             _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values())
             ret = [dict(zip(deco, item)) for item in _gen]
    -        if not config.USE_META_DICT:
    -            return ret
    -        return pickle_operations(ret, is_encode=False)  # bc 0.9.0
    +        # if not config.USE_META_DICT:
    +        # return ret
    +        # return pickle_operations(ret, is_encode=False)  # bc 0.9.0
    +        return ret
         if isinstance(deco, Iterable):
             _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco)
             ret_list = [list(item) for item in _gen]
    -        if not config.USE_META_DICT:
    -            return ret_list
    -        return pickle_operations(ret_list, is_encode=False)  # bc 0.9.0
    +        # if not config.USE_META_DICT:
    +        # return ret_list
    +        # return pickle_operations(ret_list, is_encode=False)  # bc 0.9.0
    +        return ret_list
         raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.")
     
     
    
  • monai/handlers/checkpoint_loader.py · +1 −1 · modified
    @@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None:
             Args:
                 engine: Ignite Engine, it can be a trainer, validator or evaluator.
             """
    -        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)
    +        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True)
     
             k, _ = list(self.load_dict.items())[0]
             # single object and checkpoint is directly a state_dict
    
  • monai/utils/state_cacher.py · +3 −3 · modified
    @@ -64,8 +64,8 @@ def __init__(
                 pickle_module: module used for pickling metadata and objects, default to `pickle`.
                     this arg is used by `torch.save`, for more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
    -            pickle_protocol: can be specified to override the default protocol, default to `2`.
    -                this arg is used by `torch.save`, for more details, please check:
    +            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
    +                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                     https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
     
             """
    @@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any:
             fn = self.cached[key]["obj"]  # pytype: disable=attribute-error
             if not os.path.exists(fn):  # pytype: disable=wrong-arg-types
                 raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.")
    -        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False)
    +        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True)
             # copy back to device if necessary
             if "device" in self.cached[key]:
                 data_obj = data_obj.to(self.cached[key]["device"])
    
  • monai/utils/type_conversion.py · +16 −4 · modified
    @@ -117,6 +117,7 @@ def convert_to_tensor(
         wrap_sequence: bool = False,
         track_meta: bool = False,
         safe: bool = False,
    +    convert_numeric: bool = True,
     ) -> Any:
         """
         Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`,
    @@ -136,6 +137,7 @@ def convert_to_tensor(
             safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
                 E.g., `[256, -12]` -> `[tensor(0), tensor(244)]`.
                 If `True`, then `[256, -12]` -> `[tensor(255), tensor(0)]`.
    +        convert_numeric: if `True`, convert numeric Python values to tensors.
     
         """
     
    @@ -156,6 +158,7 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
         if safe:
             data = safe_dtype_range(data, dtype)
         dtype = get_equivalent_dtype(dtype, torch.Tensor)
    +
         if isinstance(data, torch.Tensor):
             return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
         if isinstance(data, np.ndarray):
    @@ -167,16 +170,25 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
                 if data.ndim > 0:
                     data = np.ascontiguousarray(data)
                 return _convert_tensor(data, dtype=dtype, device=device)
    -    elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)):
    +    elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))):
             return _convert_tensor(data, dtype=dtype, device=device)
         elif isinstance(data, list):
    -        list_ret = [convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data]
    +        list_ret = [
    +            convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
    +            for i in data
    +        ]
             return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret
         elif isinstance(data, tuple):
    -        tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data)
    +        tuple_ret = tuple(
    +            convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
    +            for i in data
    +        )
             return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret
         elif isinstance(data, dict):
    -        return {k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta) for k, v in data.items()}
    +        return {
    +            k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
    +            for k, v in data.items()
    +        }
     
         return data
     
    
  • tests/data/meta_tensor/test_meta_tensor.py · +1 −1 · modified
    @@ -245,7 +245,7 @@ def test_pickling(self):
             with tempfile.TemporaryDirectory() as tmp_dir:
                 fname = os.path.join(tmp_dir, "im.pt")
                 torch.save(m, fname)
    -            m2 = torch.load(fname, weights_only=False)
    +            m2 = torch.load(fname, weights_only=True)
             self.check(m2, m, ids=False)
     
         @skip_if_no_cuda
    
  • tests/data/test_gdsdataset.py · +2 −2 · modified
    @@ -12,7 +12,6 @@
     from __future__ import annotations
     
     import os
    -import pickle
     import tempfile
     import unittest
     
    @@ -86,7 +85,8 @@ def test_cache(self):
                         cache_dir=tempdir,
                         device=0,
                         pickle_module="pickle",
    -                    pickle_protocol=pickle.HIGHEST_PROTOCOL,
    +                    # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility
    +                    pickle_protocol=torch.serialization.DEFAULT_PROTOCOL,
                     )
                     assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape)))
                     ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
    
  • tests/data/test_persistentdataset.py · +3 −2 · modified
    @@ -12,12 +12,12 @@
     from __future__ import annotations
     
     import os
    -import pickle
     import tempfile
     import unittest
     
     import nibabel as nib
     import numpy as np
    +import torch
     from parameterized import parameterized
     
     from monai.data import PersistentDataset, json_hashing
    @@ -66,7 +66,8 @@ def test_cache(self):
                     transform=_InplaceXform(),
                     cache_dir=tempdir,
                     pickle_module="pickle",
    -                pickle_protocol=pickle.HIGHEST_PROTOCOL,
    +                # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility
    +                pickle_protocol=torch.serialization.DEFAULT_PROTOCOL,
                 )
                 self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
                 ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir)
    
  • tests/utils/test_state_cacher.py · +7 −1 · modified
    @@ -27,7 +27,13 @@
     TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}]
     TEST_CASE_1 = [
         torch.Tensor([1]).to(DEVICE),
    -    {"in_memory": False, "cache_dir": gettempdir(), "pickle_module": None, "pickle_protocol": pickle.HIGHEST_PROTOCOL},
    +    {
    +        "in_memory": False,
    +        "cache_dir": gettempdir(),
    +        "pickle_module": None,
    +        # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility
    +        "pickle_protocol": torch.serialization.DEFAULT_PROTOCOL,
    +    },
     ]
     TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}]
     TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}]
    

Vulnerability mechanics

Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References

5

News mentions

0

No linked articles in our index yet.