VYPR
Unrated severity · NVD Advisory · Published Jun 1, 2026 · Updated Jun 1, 2026

CVE-2026-38950

Description

ESA AnomalyMatch versions before 1.3.1 are vulnerable to arbitrary code execution due to unsafe deserialization of model checkpoint files using torch.load().

AI Insight

LLM-synthesized narrative grounded in this CVE's description and references.

Vulnerability

ESA AnomalyMatch versions prior to 1.3.1 contain an unsafe deserialization vulnerability in the model checkpoint loading mechanism. The application loads model files from session directories with torch.load(..., weights_only=False), and one code path calls pickle.load() directly. Both rely on Python's pickle module, which can invoke arbitrary callables named in the file while it is being read. A maliciously crafted checkpoint file therefore executes attacker-chosen code as soon as the application loads it [1], [2], [3].
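
To make the mechanism concrete, the following minimal sketch shows how unpickling can invoke a callable chosen by whoever wrote the file. The Payload class is hypothetical and uses the harmless os.getcwd as the callable; it is not taken from AnomalyMatch or from any real exploit.

```python
import os
import pickle


class Payload:
    """Hypothetical stand-in for attacker-controlled content in a checkpoint."""

    def __reduce__(self):
        # pickle records "call this callable with these arguments" and replays
        # the call on load. A real attack would name something like os.system;
        # os.getcwd is used here because it has no side effects.
        return (os.getcwd, ())


blob = pickle.dumps(Payload())

# Loading never constructs a Payload: it calls os.getcwd() and returns the result.
result = pickle.loads(blob)
print(type(result).__name__, result == os.getcwd())
```

The loaded object is whatever the named callable returned, which is why a pickle-backed checkpoint has to be treated as executable content rather than as data.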

Exploitation

An attacker must be able to place a crafted model checkpoint file into a directory that the AnomalyMatch application accesses for loading. Once the application attempts to load this file through its standard workflow, the underlying pickle deserialization process triggers the execution of arbitrary code [2], [3].
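
Because exploitation depends on a pickle-backed file being present where the application looks for checkpoints, it can help to inventory such files before upgrading. The sketch below is an assumption-laden helper, not part of AnomalyMatch: find_legacy_checkpoints and LEGACY_SUFFIXES are hypothetical names, and the extension list is taken from the file names the patch replaces (.pth and .pkl). A file extension is only a hint about the content, so treat the result as a starting point for review.

```python
import tempfile
from pathlib import Path

# Extensions the pre-1.3.1 code wrote with torch.save / pickle.dump (per the patch).
LEGACY_SUFFIXES = {".pth", ".pkl"}


def find_legacy_checkpoints(session_root):
    """Return files under session_root whose extension marks them as pickle-era."""
    root = Path(session_root)
    return sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in LEGACY_SUFFIXES
    )


# Usage against a throwaway directory that mimics a session layout.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "checkpoints").mkdir()
    (root / "checkpoints" / "model_iter_3.pkl").write_bytes(b"")
    (root / "model_iteration_1.pth").write_bytes(b"")
    (root / "model_iteration_2.safetensors").write_bytes(b"")
    found = [p.relative_to(root).as_posix() for p in find_legacy_checkpoints(root)]

print(found)
```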

Impact

Successful exploitation allows an attacker to execute arbitrary code on the host system with the privileges of the user running the AnomalyMatch application. This can lead to a full compromise of the application environment and potential lateral movement depending on the host's configuration [2].

Mitigation

The vulnerability is addressed in version 1.3.1, released on May 11, 2026. The fix replaces the pickle-based torch.save() and torch.load() calls with the safetensors library, which stores tensors as raw binary data and all non-tensor checkpoint state as JSON-encoded metadata. Loading such a file parses JSON and copies tensor bytes instead of unpickling objects, which removes the arbitrary code execution path [2], [3].
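
The reason this format is safe to read can be sketched with the standard library alone. A safetensors file begins with an 8-byte little-endian header length, followed by a JSON header (tensor names mapped to dtype, shape and byte offsets, plus an optional "__metadata__" string map), followed by the raw tensor bytes. The writer and reader below are illustrative only: the function names and the "weights" tensor name are hypothetical, the metadata keys mirror the per-key JSON encoding used in the patch, and the hand-written file is not guaranteed to satisfy every check the real library applies.

```python
import json
import struct
import tempfile
from pathlib import Path


def write_minimal_file(path, tensor_bytes, shape, metadata):
    """Write one float32 tensor plus string metadata in a safetensors-style layout."""
    header = {
        "__metadata__": metadata,  # free-form str -> str map
        "weights": {  # hypothetical tensor name
            "dtype": "F32",
            "shape": shape,
            "data_offsets": [0, len(tensor_bytes)],
        },
    }
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # 8-byte little-endian length
        f.write(header_bytes)
        f.write(tensor_bytes)


def read_file(path):
    """Return (header, tensor values) using only JSON parsing and byte slicing."""
    data = Path(path).read_bytes()
    (header_len,) = struct.unpack("<Q", data[:8])
    header = json.loads(data[8 : 8 + header_len].decode("utf-8"))
    start, end = header["weights"]["data_offsets"]
    body = data[8 + header_len :]
    values = struct.unpack(f"<{(end - start) // 4}f", body[start:end])
    return header, values


with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.safetensors"
    write_minimal_file(
        demo_path,
        struct.pack("<4f", 1.0, 2.0, 3.0, 4.0),
        [2, 2],
        {"it": json.dumps(10), "net": json.dumps("efficientnet-lite0")},
    )
    header, values = read_file(demo_path)

print(header["__metadata__"], values)
```

Nothing in the read path looks up or calls a name taken from the file, so the worst a malformed file can do in this sketch is raise a parsing error.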

AI Insight generated on Jun 1, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.

Affected products

1

Patches

1
d63b1543208b

fix(security): replace pickle-based torch.save/load with safetensors (#9)

https://github.com/esa/AnomalyMatch · Pablo Gómez · Mar 27, 2026 · via body-scan
15 files changed · +696 -298
  • anomaly_match/data_io/checkpoint_io.py +326 -0 added
    @@ -0,0 +1,326 @@
    +#   Copyright (c) European Space Agency, 2025.
    +#
    +#   This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
    +#   is part of this source code package. No part of the package, including
    +#   this file, may be copied, modified, propagated, or distributed except according to
    +#   the terms contained in the file 'LICENCE.txt'.
    +
    +"""Checkpoint I/O using safetensors for secure model serialization.
    +
    +Replaces pickle-based ``torch.save`` / ``torch.load`` with safetensors to
    +prevent arbitrary code execution when loading untrusted model files.
    +
    +Checkpoint layout inside a single ``.safetensors`` file:
    +
    +* **Binary section** — all ``torch.Tensor`` values (model weights, optimizer
    +  momentum buffers, …) stored under namespaced keys
    +  (``train_model.<name>``, ``optimizer.state.<idx>.<buf>``, …).
    +* **Metadata header** — every non-tensor value is JSON-encoded into the
    +  ``Dict[str, str]`` metadata that safetensors carries in its header.
    +"""
    +
    +from __future__ import annotations
    +
    +import json
    +from enum import Enum
    +from pathlib import Path
    +from typing import Any
    +
    +import numpy as np
    +import torch
    +from loguru import logger
    +
    +# ---------------------------------------------------------------------------
    +# JSON helpers for types that appear in checkpoint metadata
    +# ---------------------------------------------------------------------------
    +
    +
    +def _nullify_empty_dicts(obj: Any) -> Any:
    +    """Recursively replace empty dicts with ``None``.
    +
    +    DotMap auto-creates empty child maps when accessing missing keys.  After
    +    ``toDict()`` these become ``{}``, which breaks fitsbolt's
    +    ``validate_config`` on reload (e.g. ``channel_combination`` is expected to
    +    be ``None`` or ``np.ndarray``, not ``{}``).
    +    """
    +    if isinstance(obj, dict):
    +        if len(obj) == 0:
    +            return None
    +        return {k: _nullify_empty_dicts(v) for k, v in obj.items()}
    +    if isinstance(obj, (list, tuple)):
    +        return [_nullify_empty_dicts(v) for v in obj]
    +    return obj
    +
    +
    +def _prepare_for_json(obj: Any) -> Any:
    +    """Recursively convert non-JSON-native types to tagged representations.
    +
    +    This is needed because ``IntEnum`` (which ``NormalisationMethod`` inherits
    +    from) is a subclass of ``int`` — the standard JSON encoder serializes it
    +    as a plain integer and never calls ``default()``.  By walking the
    +    structure up-front we ensure *all* special types are tagged.
    +
    +    """
    +    # Enum check MUST come before int/float because IntEnum is also an int
    +    if isinstance(obj, Enum):
    +        return {"__enum__": type(obj).__name__, "name": obj.name}
    +    if isinstance(obj, np.dtype):
    +        return {"__numpy_dtype__": str(obj)}
    +    if isinstance(obj, type) and issubclass(obj, np.generic):
    +        return {"__numpy_dtype_type__": np.dtype(obj).str}
    +    if isinstance(obj, np.ndarray):
    +        return {"__numpy_array__": obj.tolist(), "dtype": str(obj.dtype)}
    +    if isinstance(obj, np.integer):
    +        return int(obj)
    +    if isinstance(obj, np.floating):
    +        return float(obj)
    +    if isinstance(obj, np.bool_):
    +        return bool(obj)
    +    if isinstance(obj, dict):
    +        return {k: _prepare_for_json(v) for k, v in obj.items()}
    +    if isinstance(obj, (list, tuple)):
    +        return [_prepare_for_json(v) for v in obj]
    +    return obj
    +
    +
    +class _CheckpointEncoder(json.JSONEncoder):
    +    """JSON encoder that handles checkpoint-specific types.
    +
    +    Note: ``IntEnum`` values bypass ``default()`` because they *are* ints.
    +    Use :func:`_prepare_for_json` on the data **before** calling
    +    ``json.dumps`` to ensure those types are correctly tagged.
    +    """
    +
    +    def default(self, obj: Any) -> Any:
    +        if isinstance(obj, Enum):
    +            return {"__enum__": type(obj).__name__, "name": obj.name}
    +        if isinstance(obj, np.dtype):
    +            return {"__numpy_dtype__": str(obj)}
    +        if isinstance(obj, type) and issubclass(obj, np.generic):
    +            return {"__numpy_dtype_type__": np.dtype(obj).str}
    +        if isinstance(obj, np.ndarray):
    +            return {"__numpy_array__": obj.tolist(), "dtype": str(obj.dtype)}
    +        if isinstance(obj, np.integer):
    +            return int(obj)
    +        if isinstance(obj, np.floating):
    +            return float(obj)
    +        if isinstance(obj, np.bool_):
    +            return bool(obj)
    +        return super().default(obj)
    +
    +
    +def _checkpoint_object_hook(obj: dict) -> Any:
    +    """JSON object-hook that restores checkpoint-specific types."""
    +    if "__enum__" in obj:
    +        enum_name = obj["__enum__"]
    +        if enum_name == "NormalisationMethod":
    +            from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
    +
    +            return NormalisationMethod[obj["name"]]
    +        return f"{enum_name}.{obj['name']}"
    +    if "__numpy_dtype__" in obj:
    +        return np.dtype(obj["__numpy_dtype__"])
    +    if "__numpy_dtype_type__" in obj:
    +        return np.dtype(obj["__numpy_dtype_type__"]).type
    +    if "__numpy_array__" in obj:
    +        return np.array(obj["__numpy_array__"], dtype=obj["dtype"])
    +    return obj
    +
    +
    +# ---------------------------------------------------------------------------
    +# Public API
    +# ---------------------------------------------------------------------------
    +
    +
    +def save_checkpoint(save_state: dict[str, Any], path: str | Path) -> Path:
    +    """Save a model checkpoint in safetensors format.
    +
    +    Tensors are stored in the safetensors binary section; everything else is
    +    JSON-encoded into the safetensors metadata header.
    +
    +    Args:
    +        save_state: Checkpoint dictionary (same keys as previously passed to
    +            ``torch.save``).
    +        path: Destination file path. The extension is forced to
    +            ``.safetensors``.
    +
    +    Returns:
    +        The actual path written (with ``.safetensors`` extension).
    +    """
    +    from safetensors.torch import save_file
    +
    +    path = Path(path).with_suffix(".safetensors")
    +
    +    tensors: dict[str, torch.Tensor] = {}
    +    metadata: dict[str, str] = {}
    +
    +    # ---- model state-dicts ------------------------------------------------
    +    for model_key in ("train_model", "eval_model"):
    +        state_dict = save_state.get(model_key)
    +        if state_dict is None:
    +            continue
    +        for param_name, tensor in state_dict.items():
    +            tensors[f"{model_key}.{param_name}"] = tensor.detach().clone().contiguous()
    +
    +    # ---- optimizer state --------------------------------------------------
    +    opt_state = save_state.get("optimizer")
    +    if opt_state is not None:
    +        opt_skeleton: dict[str, Any] = {
    +            "state": {},
    +            "param_groups": opt_state.get("param_groups", []),
    +        }
    +        for param_idx, state in opt_state.get("state", {}).items():
    +            idx_str = str(param_idx)
    +            opt_skeleton["state"][idx_str] = {}
    +            for key, val in state.items():
    +                if isinstance(val, torch.Tensor):
    +                    tensors[f"optimizer.state.{param_idx}.{key}"] = (
    +                        val.detach().clone().contiguous()
    +                    )
    +                    opt_skeleton["state"][idx_str][key] = "__tensor__"
    +                else:
    +                    opt_skeleton["state"][idx_str][key] = val
    +        metadata["optimizer"] = json.dumps(_prepare_for_json(opt_skeleton), cls=_CheckpointEncoder)
    +    else:
    +        metadata["optimizer"] = "null"
    +
    +    # ---- scheduler state --------------------------------------------------
    +    sched_state = save_state.get("scheduler")
    +    metadata["scheduler"] = (
    +        json.dumps(_prepare_for_json(sched_state), cls=_CheckpointEncoder)
    +        if sched_state is not None
    +        else "null"
    +    )
    +
    +    # ---- scalar / enum metadata -------------------------------------------
    +    for key in (
    +        "it",
    +        "total_it",
    +        "best_eval_acc",
    +        "best_it",
    +        "num_channels",
    +        "net",
    +        "normalisation_method",
    +        "last_normalisation_method",
    +    ):
    +        metadata[key] = json.dumps(_prepare_for_json(save_state.get(key)), cls=_CheckpointEncoder)
    +
    +    # ---- fitsbolt config (DotMap → dict → JSON) ---------------------------
    +    fb_cfg = save_state.get("fitsbolt_cfg")
    +    if fb_cfg is not None:
    +        cfg_dict = fb_cfg.toDict() if hasattr(fb_cfg, "toDict") else fb_cfg
    +        # DotMap auto-creates empty child maps on missing-key access (e.g.
    +        # channel_combination).  After toDict() these become empty dicts {},
    +        # which break fitsbolt's validate_config on reload.  Normalize
    +        # leaf-level empty dicts to None.
    +        cfg_dict = _nullify_empty_dicts(cfg_dict)
    +        metadata["fitsbolt_cfg"] = json.dumps(_prepare_for_json(cfg_dict), cls=_CheckpointEncoder)
    +    else:
    +        metadata["fitsbolt_cfg"] = "null"
    +
    +    # ---- labeled-data CSV -------------------------------------------------
    +    csv_str = save_state.get("labeled_data_csv")
    +    if csv_str is not None:
    +        metadata["labeled_data_csv"] = csv_str
    +
    +    # safetensors requires at least one tensor
    +    if not tensors:
    +        tensors["__placeholder__"] = torch.zeros(1)
    +
    +    save_file(tensors, str(path), metadata=metadata)
    +    logger.debug(f"Saved checkpoint in safetensors format: {path}")
    +    return path
    +
    +
    +def load_checkpoint(path: str | Path, device: str = "cpu") -> dict[str, Any]:
    +    """Load a model checkpoint from a ``.safetensors`` file.
    +
    +    Args:
    +        path: Path to the ``.safetensors`` checkpoint file.
    +        device: Device to map tensors to (default ``"cpu"``).
    +
    +    Returns:
    +        Checkpoint dictionary with the same structure as originally saved.
    +
    +    Raises:
    +        FileNotFoundError: If *path* does not exist.
    +    """
    +    from safetensors import safe_open
    +    from safetensors.torch import load_file
    +
    +    path = Path(path)
    +    if not path.exists():
    +        raise FileNotFoundError(f"Checkpoint not found: {path}")
    +
    +    all_tensors = load_file(str(path), device=device)
    +
    +    with safe_open(str(path), framework="pt", device=device) as f:
    +        raw_metadata = f.metadata() or {}
    +
    +    checkpoint: dict[str, Any] = {}
    +
    +    # ---- model state-dicts ------------------------------------------------
    +    for model_key in ("train_model", "eval_model"):
    +        prefix = f"{model_key}."
    +        state_dict = {k[len(prefix) :]: v for k, v in all_tensors.items() if k.startswith(prefix)}
    +        if state_dict:
    +            checkpoint[model_key] = state_dict
    +
    +    # ---- optimizer state --------------------------------------------------
    +    opt_skeleton = json.loads(
    +        raw_metadata.get("optimizer", "null"), object_hook=_checkpoint_object_hook
    +    )
    +    if opt_skeleton is not None:
    +        new_state: dict[int, dict] = {}
    +        for idx_str, state in opt_skeleton.get("state", {}).items():
    +            restored: dict[str, Any] = {}
    +            for key, val in state.items():
    +                if val == "__tensor__":
    +                    restored[key] = all_tensors[f"optimizer.state.{idx_str}.{key}"]
    +                else:
    +                    restored[key] = val
    +            new_state[int(idx_str)] = restored
    +        opt_skeleton["state"] = new_state
    +        checkpoint["optimizer"] = opt_skeleton
    +    else:
    +        checkpoint["optimizer"] = None
    +
    +    # ---- scheduler state --------------------------------------------------
    +    checkpoint["scheduler"] = json.loads(
    +        raw_metadata.get("scheduler", "null"), object_hook=_checkpoint_object_hook
    +    )
    +
    +    # ---- scalar / enum metadata -------------------------------------------
    +    for key in (
    +        "it",
    +        "total_it",
    +        "best_eval_acc",
    +        "best_it",
    +        "num_channels",
    +        "net",
    +        "normalisation_method",
    +        "last_normalisation_method",
    +    ):
    +        checkpoint[key] = json.loads(
    +            raw_metadata.get(key, "null"), object_hook=_checkpoint_object_hook
    +        )
    +
    +    # ---- fitsbolt config --------------------------------------------------
    +    fb_data = json.loads(
    +        raw_metadata.get("fitsbolt_cfg", "null"), object_hook=_checkpoint_object_hook
    +    )
    +    if fb_data is not None:
    +        from dotmap import DotMap
    +
    +        # _dynamic=False prevents DotMap from auto-creating empty child maps
    +        # on missing-key access, which would break fitsbolt's validate_config
    +        # (e.g. channel_combination should stay absent, not become DotMap()).
    +        checkpoint["fitsbolt_cfg"] = DotMap(fb_data, _dynamic=False)
    +    else:
    +        checkpoint["fitsbolt_cfg"] = None
    +
    +    # ---- labeled-data CSV -------------------------------------------------
    +    if "labeled_data_csv" in raw_metadata:
    +        checkpoint["labeled_data_csv"] = raw_metadata["labeled_data_csv"]
    +
    +    return checkpoint
    
  • anomaly_match/data_io/SessionIOHandler.py +11 -77 modified
    @@ -7,14 +7,13 @@
     
     import json
     import os
    -import pickle
     from pathlib import Path
     from typing import Any, Dict, List, Optional
     
     import pandas as pd
    -import torch
     from loguru import logger
     
    +from anomaly_match.data_io.checkpoint_io import load_checkpoint, save_checkpoint
     from anomaly_match.data_io.save_config import save_config_toml
     from anomaly_match.pipeline.SessionTracker import IterationInfo, SessionTracker
     
    @@ -185,43 +184,6 @@ def save_iteration_scores(
                 except Exception as e:
                     logger.warning(f"Failed to save test scores: {e}")
     
    -    def save_model_checkpoint(
    -        self,
    -        model_state: Dict[str, Any],
    -        session_tracker: SessionTracker,
    -        checkpoint_name: str = None,
    -    ) -> str:
    -        """
    -        Save a model checkpoint within the session directory.
    -
    -        Args:
    -            model_state: Model state dictionary to save.
    -            session_tracker: Associated session tracker.
    -            checkpoint_name: Optional custom checkpoint name.
    -
    -        Returns:
    -            Path to saved checkpoint.
    -        """
    -        save_path = self.get_session_save_path(session_tracker)
    -        save_path.mkdir(parents=True, exist_ok=True)
    -
    -        checkpoints_dir = save_path / "checkpoints"
    -        checkpoints_dir.mkdir(exist_ok=True)
    -
    -        if checkpoint_name is None:
    -            checkpoint_name = f"model_iter_{session_tracker.total_model_iterations}.pkl"
    -
    -        checkpoint_path = checkpoints_dir / checkpoint_name
    -
    -        with open(checkpoint_path, "wb") as f:
    -            pickle.dump(model_state, f)
    -
    -        # Update the session tracker with the checkpoint path
    -        session_tracker.update_model_state_path(str(checkpoint_path))
    -
    -        logger.debug(f"Saved model checkpoint to: {checkpoint_path}")
    -        return str(checkpoint_path)
    -
         def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str:
             """
             Save the model to the session directory if session_tracker is available,
    @@ -246,7 +208,7 @@ def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str:
                     if session_tracker.session_iterations
                     else 0
                 )
    -            model_filename = f"model_iteration_{iteration_num}.pth"
    +            model_filename = f"model_iteration_{iteration_num}.safetensors"
                 model_path = save_path / model_filename
             else:
                 if cfg.model_path is None:
    @@ -287,8 +249,9 @@ def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str:
                 "fitsbolt_cfg": fitsbolt_cfg,
             }
     
    -        # Save model
    -        torch.save(save_state, model_path)
    +        # Save model (save_checkpoint forces .safetensors extension)
    +        save_checkpoint(save_state, model_path)
    +        model_path = Path(model_path).with_suffix(".safetensors")
     
             if session_tracker is not None:
                 # Ensure there's an active session iteration
    @@ -331,7 +294,7 @@ def load_model(self, model, cfg, model_path: str = None) -> bool:
     
             try:
                 # Load checkpoint
    -            checkpoint = torch.load(load_path, weights_only=False)
    +            checkpoint = load_checkpoint(load_path)
     
                 # Handle distributed training case
                 train_model = (
    @@ -426,37 +389,6 @@ def load_model(self, model, cfg, model_path: str = None) -> bool:
                 logger.error(f"Failed to load model from {load_path}: {e}")
                 return False
     
    -    def load_model_checkpoint(self, checkpoint_path: str) -> Optional[Dict[str, Any]]:
    -        """
    -        Load a model checkpoint from the specified path.
    -
    -        Args:
    -            checkpoint_path: Path to the checkpoint file
    -
    -        Returns:
    -            Dictionary containing the checkpoint data, or None if loading failed
    -        """
    -        try:
    -            if not os.path.exists(checkpoint_path):
    -                logger.error(f"Checkpoint path does not exist: {checkpoint_path}")
    -                return None
    -
    -            # Try loading as pickle first (new format), then as torch (legacy)
    -            try:
    -                with open(checkpoint_path, "rb") as f:
    -                    checkpoint = pickle.load(f)
    -                logger.debug(f"Loaded checkpoint from pickle format: {checkpoint_path}")
    -            except (pickle.UnpicklingError, EOFError):
    -                # Fall back to torch format
    -                checkpoint = torch.load(checkpoint_path, weights_only=False, map_location="cpu")
    -                logger.debug(f"Loaded checkpoint from torch format: {checkpoint_path}")
    -
    -            return checkpoint
    -
    -        except Exception as e:
    -            logger.error(f"Failed to load checkpoint from {checkpoint_path}: {e}")
    -            return None
    -
         def load_session(self, session_path: Path) -> SessionTracker:
             """
             Load a session from disk.
    @@ -611,7 +543,9 @@ def save_run(
                 "fitsbolt_cfg": fitsbolt_cfg,
             }
     
    -        torch.save(save_state, save_filename)
    +        save_checkpoint(save_state, save_filename)
    +        # save_checkpoint forces .safetensors extension; update save_filename to match
    +        save_filename = str(Path(save_filename).with_suffix(".safetensors"))
     
             # Update session tracker if provided
             if session_tracker is not None:
    @@ -706,7 +640,7 @@ def update_config_paths_for_session(self, cfg, session_tracker: SessionTracker)
     
             # Update model path to session directory only if not already set by user
             if cfg.model_path is None:
    -            cfg.model_path = str(session_path / "model.pth")
    +            cfg.model_path = str(session_path / "model.safetensors")
     
             # Update output directory to session directory
             cfg.output_dir = str(session_path)
    @@ -805,7 +739,7 @@ def print_session(filepath: str) -> None:
     
             checkpoints_dir = session_path / "checkpoints"
             if checkpoints_dir.exists():
    -            checkpoints = list(checkpoints_dir.glob("*.pkl"))
    +            checkpoints = list(checkpoints_dir.glob("*.safetensors"))
                 print(f"✓ {len(checkpoints)} model checkpoint(s)")
     
             print("=" * 60)
    
  • anomaly_match/utils/get_default_cfg.py +1 -1 modified
    @@ -32,7 +32,7 @@ def get_default_cfg():
         cfg.metadata_file = None  # Path to the metadata CSV file
         cfg.prediction_search_dir = None
         cfg.save_path = os.path.join(cfg.save_dir)
    -    cfg.save_file = create_model_string(cfg) + ".pth"
    +    cfg.save_file = create_model_string(cfg) + ".safetensors"
         cfg.model_path = None  # Will be set by SessionIOHandler when session is active
         cfg.N_batch_prediction = None  # User specified batch size for evaluating a directory, if None: determined automatically
         cfg.subprocess_buffer_size = (
    
  • environment_CI.yml +1 -0 modified
    @@ -35,6 +35,7 @@ dependencies:
       - pip:
           - opencv-python-headless
           - albumentations
    +      - safetensors
           - timm
           - fitsbolt>=0.2
           - cutana>=0.2.1
    
  • environment.yml +1 -0 modified
    @@ -38,4 +38,5 @@ dependencies:
           - cutana>=0.2.1
           - fitsbolt>=0.2
           - opencv-python-headless
    +      - safetensors
           - timm
    
  • prediction_utils.py +3 -4 modified
    @@ -24,6 +24,7 @@
     from loguru import logger
     from turbojpeg import TurboJPEG
     
    +from anomaly_match.data_io.checkpoint_io import load_checkpoint
     from anomaly_match.data_io.load_images import get_fitsbolt_config, process_single_wrapper
     from anomaly_match.utils.get_default_cfg import get_default_cfg
     
    @@ -189,10 +190,8 @@ def load_model(cfg):
         else:
             logger.info("Using CPU for inference")
     
    -    if torch.cuda.is_available():
    -        checkpoint = torch.load(model_path, weights_only=False)
    -    else:
    -        checkpoint = torch.load(model_path, weights_only=False, map_location=torch.device("cpu"))
    +    device = "cuda" if torch.cuda.is_available() else "cpu"
    +    checkpoint = load_checkpoint(model_path, device=device)
     
         if "eval_model" not in checkpoint:
             raise KeyError(
    
  • pyproject.toml +1 -0 modified
    @@ -46,6 +46,7 @@ dependencies = [
         "psutil",
         "pyarrow",
         "pyturbojpeg",
    +    "safetensors",
         "scikit-image",
         "scikit-learn",
         "scipy",
    
  • tests/e2e/test_prediction_process.py +1 -1 modified
    @@ -40,7 +40,7 @@ def test_config():
         cfg.net = "efficientnet-lite0"
         cfg.pretrained = True
         cfg.num_channels = 3
    -    cfg.model_path = "tests/test_data/test_model.pth"
    +    cfg.model_path = "tests/test_data/test_model.safetensors"
         cfg.gpu = 0
         cfg.output_dir = tempfile.mkdtemp()
         cfg.normalisation.normalisation_method = NormalisationMethod.CONVERSION_ONLY
    
  • tests/integration/test_fitsbolt_config_persistence.py +90 -124 modified
    @@ -7,8 +7,8 @@
     
     """Tests for fitsbolt configuration persistence in model checkpoints.
     
    -The fitsbolt DotMap configuration can be pickled directly via torch.save/load
    -without explicit serialization.
    +The fitsbolt DotMap configuration is serialized via safetensors metadata
    +(JSON-encoded) through save_checkpoint/load_checkpoint.
     """
     
     import shutil
    @@ -22,14 +22,36 @@
     from fitsbolt.cfg.create_config import validate_config
     from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
     
    +from anomaly_match.data_io.checkpoint_io import load_checkpoint, save_checkpoint
     from anomaly_match.data_io.load_images import get_fitsbolt_config
     
     
    -class TestFitsboltConfigPickling:
    -    """Test cases for fitsbolt config pickling via torch.save/load."""
    -
    -    def test_pickle_roundtrip_basic(self):
    -        """Test basic pickle roundtrip via torch checkpoint."""
    +def _make_checkpoint(fitsbolt_cfg=None, **extra):
    +    """Create a minimal checkpoint dict suitable for save_checkpoint."""
    +    checkpoint = {
    +        "train_model": {"dummy.weight": torch.randn(2, 2)},
    +        "eval_model": {"dummy.weight": torch.randn(2, 2)},
    +        "optimizer": None,
    +        "scheduler": None,
    +        "it": 0,
    +        "total_it": 0,
    +        "best_eval_acc": None,
    +        "best_it": None,
    +        "num_channels": 3,
    +        "net": "efficientnet-lite0",
    +        "normalisation_method": None,
    +        "last_normalisation_method": None,
    +        "fitsbolt_cfg": fitsbolt_cfg,
    +    }
    +    checkpoint.update(extra)
    +    return checkpoint
    +
    +
    +class TestFitsboltConfigSafetensors:
    +    """Test cases for fitsbolt config persistence via safetensors."""
    +
    +    def test_roundtrip_basic(self):
    +        """Test basic roundtrip via safetensors checkpoint."""
             original_cfg = fb_create_cfg(
                 output_dtype=np.uint8,
                 size=[64, 64],
    @@ -38,44 +60,34 @@ def test_pickle_roundtrip_basic(self):
                 num_workers=4,
             )
     
    -        # Save via torch
    -        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    -            checkpoint_path = f.name
    -
    -        try:
    -            torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path)
    -            loaded = torch.load(checkpoint_path, weights_only=False)
    +        with tempfile.TemporaryDirectory() as tmp:
    +            checkpoint_path = Path(tmp) / "model.safetensors"
    +            save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path)
    +            loaded = load_checkpoint(checkpoint_path)
                 loaded_cfg = loaded["fitsbolt_cfg"]
     
                 assert loaded_cfg.size == original_cfg.size
                 assert loaded_cfg.n_output_channels == original_cfg.n_output_channels
    -            assert loaded_cfg.num_workers == original_cfg.num_workers
                 assert loaded_cfg.normalisation_method == original_cfg.normalisation_method
    -        finally:
    -            Path(checkpoint_path).unlink(missing_ok=True)
     
    -    def test_pickle_numpy_dtype(self):
    -        """Test pickling of numpy dtypes."""
    +    def test_numpy_dtype(self):
    +        """Test persistence of numpy dtypes."""
             original_cfg = fb_create_cfg(
                 output_dtype=np.float32,
                 size=[128, 128],
                 n_output_channels=3,
             )
     
    -        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    -            checkpoint_path = f.name
    -
    -        try:
    -            torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path)
    -            loaded = torch.load(checkpoint_path, weights_only=False)
    +        with tempfile.TemporaryDirectory() as tmp:
    +            checkpoint_path = Path(tmp) / "model.safetensors"
    +            save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path)
    +            loaded = load_checkpoint(checkpoint_path)
                 loaded_cfg = loaded["fitsbolt_cfg"]
     
                 assert loaded_cfg.output_dtype == np.float32
    -        finally:
    -            Path(checkpoint_path).unlink(missing_ok=True)
     
    -    def test_pickle_all_normalisation_methods(self):
    -        """Test pickling with all normalisation methods."""
    +    def test_all_normalisation_methods(self):
    +        """Test persistence with all normalisation methods."""
             for method in NormalisationMethod:
                 original_cfg = fb_create_cfg(
                     output_dtype=np.uint8,
    @@ -84,20 +96,16 @@ def test_pickle_all_normalisation_methods(self):
                     normalisation_method=method,
                 )
     
    -            with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    -                checkpoint_path = f.name
    -
    -            try:
    -                torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path)
    -                loaded = torch.load(checkpoint_path, weights_only=False)
    +            with tempfile.TemporaryDirectory() as tmp:
    +                checkpoint_path = Path(tmp) / "model.safetensors"
    +                save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path)
    +                loaded = load_checkpoint(checkpoint_path)
                     loaded_cfg = loaded["fitsbolt_cfg"]
     
                     assert loaded_cfg.normalisation_method == method
    -            finally:
    -                Path(checkpoint_path).unlink(missing_ok=True)
     
    -    def test_pickle_channel_combination(self):
    -        """Test pickling of numpy array channel_combination."""
    +    def test_channel_combination(self):
    +        """Test persistence of numpy array channel_combination."""
             original_cfg = fb_create_cfg(
                 output_dtype=np.uint8,
                 size=[64, 64],
    @@ -106,22 +114,18 @@ def test_pickle_channel_combination(self):
                 channel_combination=np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
             )
     
    -        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    -            checkpoint_path = f.name
    -
    -        try:
    -            torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path)
    -            loaded = torch.load(checkpoint_path, weights_only=False)
    +        with tempfile.TemporaryDirectory() as tmp:
    +            checkpoint_path = Path(tmp) / "model.safetensors"
    +            save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path)
    +            loaded = load_checkpoint(checkpoint_path)
                 loaded_cfg = loaded["fitsbolt_cfg"]
     
                 np.testing.assert_array_equal(
                     loaded_cfg.channel_combination, original_cfg.channel_combination
                 )
    -        finally:
    -            Path(checkpoint_path).unlink(missing_ok=True)
     
    -    def test_pickle_asinh_settings(self):
    -        """Test pickling of asinh normalisation settings."""
    +    def test_asinh_settings(self):
    +        """Test persistence of asinh normalisation settings."""
             original_cfg = fb_create_cfg(
                 output_dtype=np.uint8,
                 size=[64, 64],
    @@ -131,25 +135,21 @@ def test_pickle_asinh_settings(self):
                 norm_asinh_clip=[99.0, 99.5, 99.8],
             )
     
    -        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    -            checkpoint_path = f.name
    -
    -        try:
    -            torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path)
    -            loaded = torch.load(checkpoint_path, weights_only=False)
    +        with tempfile.TemporaryDirectory() as tmp:
    +            checkpoint_path = Path(tmp) / "model.safetensors"
    +            save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path)
    +            loaded = load_checkpoint(checkpoint_path)
                 loaded_cfg = loaded["fitsbolt_cfg"]
     
                 assert loaded_cfg.normalisation.asinh_scale == original_cfg.normalisation.asinh_scale
                 assert loaded_cfg.normalisation.asinh_clip == original_cfg.normalisation.asinh_clip
    -        finally:
    -            Path(checkpoint_path).unlink(missing_ok=True)
     
     
     class TestFitsboltConfigValidation:
    -    """Test cases for fitsbolt config validation after pickling."""
    +    """Test cases for fitsbolt config validation after safetensors roundtrip."""
     
    -    def test_validate_pickled_config(self):
    -        """Test that pickled config passes fitsbolt validation."""
    +    def test_validate_roundtripped_config(self):
    +        """Test that roundtripped config passes fitsbolt validation."""
             original_cfg = fb_create_cfg(
                 output_dtype=np.uint8,
                 size=[64, 64],
    @@ -158,25 +158,21 @@ def test_validate_pickled_config(self):
                 num_workers=4,
             )
     
    -        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    -            checkpoint_path = f.name
    -
    -        try:
    -            torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path)
    -            loaded = torch.load(checkpoint_path, weights_only=False)
    +        with tempfile.TemporaryDirectory() as tmp:
    +            checkpoint_path = Path(tmp) / "model.safetensors"
    +            save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path)
    +            loaded = load_checkpoint(checkpoint_path)
                 loaded_cfg = loaded["fitsbolt_cfg"]
     
                 # Validate using fitsbolt's validate_config
                 validate_config(loaded_cfg)
    -        finally:
    -            Path(checkpoint_path).unlink(missing_ok=True)
     
     
     class TestFitsboltConfigCompatibility:
         """Test compatibility with fitsbolt's create_config function."""
     
         def test_compatibility_with_fits_extension_settings(self):
    -        """Test pickling with various fits_extension configurations."""
    +        """Test persistence with various fits_extension configurations."""
             # Single integer extension
             cfg1 = fb_create_cfg(
                 output_dtype=np.uint8,
    @@ -185,15 +181,11 @@ def test_compatibility_with_fits_extension_settings(self):
                 fits_extension=0,
             )
     
    -        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    -            checkpoint_path = f.name
    -
    -        try:
    -            torch.save({"fitsbolt_cfg": cfg1}, checkpoint_path)
    -            loaded = torch.load(checkpoint_path, weights_only=False)
    +        with tempfile.TemporaryDirectory() as tmp:
    +            checkpoint_path = Path(tmp) / "model.safetensors"
    +            save_checkpoint(_make_checkpoint(fitsbolt_cfg=cfg1), checkpoint_path)
    +            loaded = load_checkpoint(checkpoint_path)
                 validate_config(loaded["fitsbolt_cfg"])
    -        finally:
    -            Path(checkpoint_path).unlink(missing_ok=True)
     
             # List of extensions
             cfg2 = fb_create_cfg(
    @@ -203,22 +195,18 @@ def test_compatibility_with_fits_extension_settings(self):
                 fits_extension=[0, 1, 2],
             )
     
    -        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    -            checkpoint_path = f.name
    -
    -        try:
    -            torch.save({"fitsbolt_cfg": cfg2}, checkpoint_path)
    -            loaded = torch.load(checkpoint_path, weights_only=False)
    +        with tempfile.TemporaryDirectory() as tmp:
    +            checkpoint_path = Path(tmp) / "model.safetensors"
    +            save_checkpoint(_make_checkpoint(fitsbolt_cfg=cfg2), checkpoint_path)
    +            loaded = load_checkpoint(checkpoint_path)
                 validate_config(loaded["fitsbolt_cfg"])
    -        finally:
    -            Path(checkpoint_path).unlink(missing_ok=True)
     
     
     class TestGetFitsboltConfigIntegration:
    -    """Test get_fitsbolt_config integration with pickling."""
    +    """Test get_fitsbolt_config integration with safetensors persistence."""
     
    -    def test_get_fitsbolt_config_pickling(self):
    -        """Test that config from get_fitsbolt_config can be pickled."""
    +    def test_get_fitsbolt_config_roundtrip(self):
    +        """Test that config from get_fitsbolt_config survives safetensors roundtrip."""
             # Create an AnomalyMatch-style config
             cfg = DotMap()
             cfg.normalisation = DotMap()
    @@ -240,13 +228,10 @@ def test_get_fitsbolt_config_pickling(self):
             # Get fitsbolt config
             cfg = get_fitsbolt_config(cfg)
     
    -        # Save and load via torch
    -        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    -            checkpoint_path = f.name
    -
    -        try:
    -            torch.save({"fitsbolt_cfg": cfg.fitsbolt_cfg}, checkpoint_path)
    -            loaded = torch.load(checkpoint_path, weights_only=False)
    +        with tempfile.TemporaryDirectory() as tmp:
    +            checkpoint_path = Path(tmp) / "model.safetensors"
    +            save_checkpoint(_make_checkpoint(fitsbolt_cfg=cfg.fitsbolt_cfg), checkpoint_path)
    +            loaded = load_checkpoint(checkpoint_path)
                 loaded_cfg = loaded["fitsbolt_cfg"]
     
                 # Validate
    @@ -256,8 +241,6 @@ def test_get_fitsbolt_config_pickling(self):
                 assert loaded_cfg.size == [64, 64]
                 assert loaded_cfg.n_output_channels == 3
                 assert loaded_cfg.normalisation_method == NormalisationMethod.CONVERSION_ONLY
    -        finally:
    -            Path(checkpoint_path).unlink(missing_ok=True)
     
     
     class TestFitsboltConfigE2EWithCheckpoint:
    @@ -272,7 +255,7 @@ def teardown_method(self):
             shutil.rmtree(self.temp_dir)
     
         def test_fitsbolt_config_in_checkpoint_dict(self):
    -        """Test that fitsbolt config can be saved and loaded in a checkpoint-like dict."""
    +        """Test that fitsbolt config can be saved and loaded in a checkpoint dict."""
             # Create a fitsbolt config
             fitsbolt_cfg = fb_create_cfg(
                 output_dtype=np.uint8,
    @@ -283,19 +266,12 @@ def test_fitsbolt_config_in_checkpoint_dict(self):
                 norm_asinh_clip=[99.0, 99.5, 99.8],
             )
     
    -        # Create a mock checkpoint
    -        checkpoint = {
    -            "model_state": {"dummy": "data"},
    -            "optimizer_state": None,
    -            "fitsbolt_cfg": fitsbolt_cfg,
    -        }
    -
             # Save checkpoint
    -        checkpoint_path = Path(self.temp_dir) / "test_checkpoint.pth"
    -        torch.save(checkpoint, checkpoint_path)
    +        checkpoint_path = Path(self.temp_dir) / "test_checkpoint.safetensors"
    +        save_checkpoint(_make_checkpoint(fitsbolt_cfg=fitsbolt_cfg), checkpoint_path)
     
             # Load checkpoint
    -        loaded_checkpoint = torch.load(checkpoint_path, weights_only=False)
    +        loaded_checkpoint = load_checkpoint(checkpoint_path)
             loaded_fitsbolt_cfg = loaded_checkpoint["fitsbolt_cfg"]
     
             # Verify
    @@ -310,22 +286,12 @@ def test_fitsbolt_config_in_checkpoint_dict(self):
     
         def test_backward_compatibility_checkpoint_without_fitsbolt(self):
             """Test loading checkpoints that don't have fitsbolt_cfg."""
    -        # Create a mock checkpoint without fitsbolt_cfg (legacy format)
    -        checkpoint = {
    -            "model_state": {"dummy": "data"},
    -            "optimizer_state": None,
    -        }
    -
    -        # Save checkpoint
    -        checkpoint_path = Path(self.temp_dir) / "legacy_checkpoint.pth"
    -        torch.save(checkpoint, checkpoint_path)
    +        # Save checkpoint without fitsbolt_cfg
    +        checkpoint_path = Path(self.temp_dir) / "legacy_checkpoint.safetensors"
    +        save_checkpoint(_make_checkpoint(fitsbolt_cfg=None), checkpoint_path)
     
             # Load checkpoint
    -        loaded_checkpoint = torch.load(checkpoint_path, weights_only=False)
    -
    -        # Check that fitsbolt_cfg is not present
    -        assert "fitsbolt_cfg" not in loaded_checkpoint
    +        loaded_checkpoint = load_checkpoint(checkpoint_path)
     
    -        # Accessing non-existent key should return None via .get()
    -        result = loaded_checkpoint.get("fitsbolt_cfg")
    -        assert result is None
    +        # fitsbolt_cfg should be None (not missing)
    +        assert loaded_checkpoint["fitsbolt_cfg"] is None
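
The rewritten tests above expect enum members, `None` values, and nested config values to survive a save/load cycle without pickle. A minimal sketch of one way such values could be encoded as JSON, under an assumed tagging scheme. `Method`, `encode`, and `decode` are hypothetical stand-ins for illustration, not the AnomalyMatch or fitsbolt API:

```python
import json
from enum import Enum


class Method(Enum):
    """Hypothetical stand-in for fitsbolt's NormalisationMethod."""

    CONVERSION_ONLY = 0
    LOG = 1


def encode(value):
    """Turn a config value into JSON-safe data, tagging enum members."""
    if isinstance(value, Method):
        return {"__enum__": value.name}
    if isinstance(value, dict):
        return {key: encode(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [encode(item) for item in value]
    return value  # str, int, float, bool and None are already JSON-safe


def decode(value):
    """Invert encode(); only names defined on Method are accepted."""
    if isinstance(value, dict):
        if set(value) == {"__enum__"}:
            return Method[value["__enum__"]]  # KeyError on unknown names
        return {key: decode(item) for key, item in value.items()}
    if isinstance(value, list):
        return [decode(item) for item in value]
    return value


cfg = {"size": [64, 64], "normalisation_method": Method.LOG, "fitsbolt_cfg": None}
text = json.dumps(encode(cfg))
restored = decode(json.loads(text))
```

Because decoding only looks names up in a fixed enum, a crafted file can at worst raise an error; it cannot name an arbitrary callable the way a pickle stream can.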
    
  • tests/integration/test_model_io_integration.py +8 -35 modified
    @@ -15,6 +15,7 @@
     from dotmap import DotMap
     from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
     
    +from anomaly_match.data_io.checkpoint_io import load_checkpoint
     from anomaly_match.data_io.SessionIOHandler import SessionIOHandler
     from anomaly_match.pipeline.SessionTracker import SessionTracker
     from anomaly_match.utils.get_net_builder import get_net_builder
    @@ -59,7 +60,7 @@ def setup_method(self):
             from anomaly_match.utils.get_default_cfg import get_default_cfg
     
             self.cfg = get_default_cfg()
    -        self.cfg.model_path = str(self.temp_dir / "test_model.pth")
    +        self.cfg.model_path = str(self.temp_dir / "test_model.safetensors")
     
         def teardown_method(self):
             """Clean up test fixtures."""
    @@ -135,55 +136,27 @@ def test_load_model_with_normalisation_update(self):
     
         def test_load_model_nonexistent_file(self):
             """Test loading from nonexistent file."""
    -        self.cfg.model_path = str(self.temp_dir / "nonexistent.pth")
    +        self.cfg.model_path = str(self.temp_dir / "nonexistent.safetensors")
     
             success = self.session_io.load_model(self.mock_model, self.cfg)
     
             assert not success
     
    -    def test_load_model_checkpoint(self):
    -        """Test loading model checkpoint."""
    -        # Create and save a checkpoint
    -        model_state = {
    -            "train_model_state_dict": self.mock_model.train_model.state_dict(),
    -            "eval_model_state_dict": self.mock_model.eval_model.state_dict(),
    -            "total_it": self.mock_model.total_it,
    -        }
    -
    -        checkpoint_path = self.session_io.save_model_checkpoint(
    -            model_state, self.session_tracker, "test_checkpoint.pkl"
    -        )
    -
    -        # Load checkpoint
    -        loaded_checkpoint = self.session_io.load_model_checkpoint(checkpoint_path)
    -
    -        # Verify checkpoint was loaded
    -        assert loaded_checkpoint is not None
    -        assert "train_model_state_dict" in loaded_checkpoint
    -        assert "total_it" in loaded_checkpoint
    -        assert loaded_checkpoint["total_it"] == self.mock_model.total_it
    -
    -    def test_load_model_checkpoint_nonexistent(self):
    -        """Test loading nonexistent checkpoint."""
    -        checkpoint = self.session_io.load_model_checkpoint(str(self.temp_dir / "nonexistent.pkl"))
    -
    -        assert checkpoint is None
    -
     
    -TEST_MODEL_PATH = Path(__file__).parent.parent / "test_data" / "test_model.pth"
    +TEST_MODEL_PATH = Path(__file__).parent.parent / "test_data" / "test_model.safetensors"
     
     
    -@pytest.mark.skipif(not TEST_MODEL_PATH.exists(), reason="test_model.pth not available")
    +@pytest.mark.skipif(not TEST_MODEL_PATH.exists(), reason="test_model.safetensors not available")
     class TestStoredModelLoading:
    -    """Regression tests for loading the stored test_model.pth checkpoint.
    +    """Regression tests for loading the stored test_model.safetensors checkpoint.
     
         These tests verify that the checked-in test model remains compatible
         with the current model architecture (timm-based EfficientNet).
         """
     
         def test_stored_model_has_expected_keys(self):
             """Verify the stored checkpoint contains expected top-level keys."""
    -        checkpoint = torch.load(str(TEST_MODEL_PATH), weights_only=False, map_location="cpu")
    +        checkpoint = load_checkpoint(TEST_MODEL_PATH)
     
             assert "eval_model" in checkpoint, (
                 f"Checkpoint missing 'eval_model' key. Found: {list(checkpoint.keys())}"
    @@ -192,7 +165,7 @@ def test_stored_model_has_expected_keys(self):
     
         def test_stored_model_loads_into_efficientnet_lite0(self):
             """Verify stored model state_dict is compatible with the current architecture."""
    -        checkpoint = torch.load(str(TEST_MODEL_PATH), weights_only=False, map_location="cpu")
    +        checkpoint = load_checkpoint(TEST_MODEL_PATH)
     
             net_builder = get_net_builder("efficientnet-lite0", pretrained=False, in_channels=3)
             model = net_builder(num_classes=2, in_channels=3)
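
The removed `torch.load(..., weights_only=False)` calls perform full unpickling of the checkpoint file. A minimal standard-library sketch of why that is dangerous for untrusted files; `Payload` is a hypothetical class and the evaluated expression is deliberately benign:

```python
import pickle


class Payload:
    """Object whose pickle stream names a callable to run at load time."""

    def __reduce__(self):
        # Benign stand-in: an attacker could name any importable callable here.
        return (eval, ("6 * 7",))


blob = pickle.dumps(Payload())
result = pickle.loads(blob)  # calls eval("6 * 7") during deserialization
```

The loader never reconstructs a `Payload` instance. It simply calls the function recorded in the stream, which is the behaviour the advisory describes for crafted checkpoint files.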
    
  • tests/integration/test_run_label_migration.py +9 -8 modified
    @@ -13,6 +13,7 @@
     import pytest
     import torch
     
    +from anomaly_match.data_io.checkpoint_io import load_checkpoint
     from anomaly_match.data_io.SessionIOHandler import SessionIOHandler
     from anomaly_match.pipeline.SessionTracker import SessionTracker
     
    @@ -67,25 +68,25 @@ def mock_config(self):
             """Create a mock configuration."""
             config = Mock()
             config.normalisation_method = "min_max"
    -        config.model_path = "test_model.pth"
    +        config.model_path = "test_model.safetensors"
             # Explicitly set fitsbolt_cfg to None to avoid pickling issues with Mock
             config.fitsbolt_cfg = None
             return config
     
         def test_save_run_basic(self, session_io, mock_model, temp_dir):
             """Test basic save_run functionality."""
    -        save_name = "test_model.pth"
    +        save_name = "test_model.safetensors"
             save_path = temp_dir
     
             result_path = session_io.save_run(mock_model, save_name, save_path)
     
    -        # Check that the model was saved
    +        # Check that the model was saved (save_checkpoint forces .safetensors extension)
             expected_path = os.path.join(save_path, save_name)
             assert result_path == expected_path
             assert os.path.exists(expected_path)
     
             # Verify the saved model can be loaded
    -        checkpoint = torch.load(expected_path, weights_only=False)
    +        checkpoint = load_checkpoint(expected_path)
             assert "train_model" in checkpoint
             assert "eval_model" in checkpoint
             assert "optimizer" in checkpoint
    @@ -95,7 +96,7 @@ def test_save_run_basic(self, session_io, mock_model, temp_dir):
     
         def test_save_run_with_session_tracker(self, session_io, mock_model, session_tracker, temp_dir):
             """Test save_run with session tracker integration."""
    -        save_name = "test_model.pth"
    +        save_name = "test_model.safetensors"
             save_path = temp_dir
     
             # Start a session iteration
    @@ -111,7 +112,7 @@ def test_save_run_with_session_tracker(self, session_io, mock_model, session_tra
     
         def test_save_run_with_config(self, session_io, mock_model, mock_config, temp_dir):
             """Test save_run with configuration saving."""
    -        save_name = "test_model.pth"
    +        save_name = "test_model.safetensors"
             save_path = temp_dir
     
             # Mock the config saving function
    @@ -182,7 +183,7 @@ def test_integration_training_run_flow(
             self, session_io, session_tracker, mock_model, mock_config, temp_dir
         ):
             """Test the complete integration flow for training run saving."""
    -        save_name = "final_model.pth"
    +        save_name = "final_model.safetensors"
             save_path = temp_dir
     
             # Simulate a training session
    @@ -199,7 +200,7 @@ def test_integration_training_run_flow(
             assert session_tracker.session_iterations[0].model_state_path == model_path
     
             # Verify model checkpoint structure
    -        checkpoint = torch.load(model_path, weights_only=False)
    +        checkpoint = load_checkpoint(model_path)
             assert all(key in checkpoint for key in ["train_model", "eval_model", "optimizer", "it"])
     
         def test_integration_label_saving_flow(self, session_io, session_tracker, temp_dir):
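
The comment added in `test_save_run_basic` notes that `save_checkpoint` forces a `.safetensors` extension. How the project implements this is not shown in this diff; a minimal sketch of one way to do it with `pathlib`, where `force_safetensors_suffix` is a hypothetical helper name:

```python
from pathlib import Path


def force_safetensors_suffix(path):
    """Return path as a Path whose final suffix is '.safetensors'."""
    # with_suffix replaces the last suffix, or appends one if none exists.
    return Path(path).with_suffix(".safetensors")


renamed = force_safetensors_suffix("models/final_model.pth")
kept = force_safetensors_suffix("models/final_model.safetensors")
bare = force_safetensors_suffix("models/model")
```

Only the last suffix is replaced, so a name such as `model.v1.pth` would become `model.v1.safetensors` under this sketch.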
    
  • tests/test_data/test_model.safetensors +0 -0 renamed
  • tests/unit/test_checkpoint_io.py +242 -0 added
    @@ -0,0 +1,242 @@
    +#   Copyright (c) European Space Agency, 2025.
    +#
    +#   This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
    +#   is part of this source code package. No part of the package, including
    +#   this file, may be copied, modified, propagated, or distributed except according to
    +#   the terms contained in the file 'LICENCE.txt'.
    +
    +"""Unit tests for checkpoint_io: safetensors-based model checkpoint serialization."""
    +
    +import numpy as np
    +import pytest
    +import torch
    +from dotmap import DotMap
    +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
    +
    +from anomaly_match.data_io.checkpoint_io import load_checkpoint, save_checkpoint
    +
    +
    +def _make_state_dict(seed=0):
    +    """Create a small deterministic state_dict for testing."""
    +    torch.manual_seed(seed)
    +    return {
    +        "layer.weight": torch.randn(4, 3),
    +        "layer.bias": torch.randn(4),
    +        "bn.running_mean": torch.zeros(4),
    +        "bn.running_var": torch.ones(4),
    +        "bn.num_batches_tracked": torch.tensor(0, dtype=torch.long),
    +    }
    +
    +
    +def _make_full_checkpoint(**overrides):
    +    """Create a complete checkpoint dict with sensible defaults."""
    +    checkpoint = {
    +        "train_model": _make_state_dict(seed=0),
    +        "eval_model": _make_state_dict(seed=1),
    +        "optimizer": None,
    +        "scheduler": None,
    +        "it": 42,
    +        "total_it": 100,
    +        "best_eval_acc": 0.95,
    +        "best_it": 80,
    +        "num_channels": 3,
    +        "net": "efficientnet-lite0",
    +        "normalisation_method": NormalisationMethod.CONVERSION_ONLY,
    +        "last_normalisation_method": NormalisationMethod.LOG,
    +        "fitsbolt_cfg": None,
    +    }
    +    checkpoint.update(overrides)
    +    return checkpoint
    +
    +
    +class TestSaveLoadRoundTrip:
    +    """Test that save_checkpoint → load_checkpoint round-trips all data correctly."""
    +
    +    def test_model_weights_roundtrip(self, tmp_path):
    +        """Verify train_model and eval_model state_dicts survive round-trip."""
    +        original = _make_full_checkpoint()
    +        path = save_checkpoint(original, tmp_path / "model")
    +
    +        loaded = load_checkpoint(path)
    +
    +        for key in ("train_model", "eval_model"):
    +            for param_name in original[key]:
    +                assert torch.equal(original[key][param_name], loaded[key][param_name]), (
    +                    f"{key}.{param_name} mismatch after round-trip"
    +                )
    +
    +    def test_scalar_metadata_roundtrip(self, tmp_path):
    +        """Verify scalar metadata (it, total_it, etc.) survives round-trip."""
    +        original = _make_full_checkpoint()
    +        path = save_checkpoint(original, tmp_path / "model")
    +        loaded = load_checkpoint(path)
    +
    +        assert loaded["it"] == 42
    +        assert loaded["total_it"] == 100
    +        assert loaded["best_eval_acc"] == 0.95
    +        assert loaded["best_it"] == 80
    +        assert loaded["num_channels"] == 3
    +        assert loaded["net"] == "efficientnet-lite0"
    +
    +    def test_normalisation_enum_roundtrip(self, tmp_path):
    +        """Verify NormalisationMethod enum values survive round-trip."""
    +        original = _make_full_checkpoint()
    +        path = save_checkpoint(original, tmp_path / "model")
    +        loaded = load_checkpoint(path)
    +
    +        assert loaded["normalisation_method"] == NormalisationMethod.CONVERSION_ONLY
    +        assert loaded["last_normalisation_method"] == NormalisationMethod.LOG
    +        assert isinstance(loaded["normalisation_method"], NormalisationMethod)
    +
    +    def test_optimizer_state_roundtrip(self, tmp_path):
    +        """Verify optimizer state (including momentum tensors) survives round-trip."""
    +        # Build a real optimizer state
    +        model = torch.nn.Linear(3, 2)
    +        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    +        # Step once to create momentum buffers
    +        loss = model(torch.randn(1, 3)).sum()
    +        loss.backward()
    +        optimizer.step()
    +
    +        opt_state = optimizer.state_dict()
    +        original = _make_full_checkpoint(optimizer=opt_state)
    +        path = save_checkpoint(original, tmp_path / "model")
    +        loaded = load_checkpoint(path)
    +
    +        # Check param_groups
    +        assert loaded["optimizer"]["param_groups"][0]["lr"] == 0.01
    +        assert loaded["optimizer"]["param_groups"][0]["momentum"] == 0.9
    +
    +        # Check state tensors
    +        for param_idx in opt_state["state"]:
    +            for key in opt_state["state"][param_idx]:
    +                orig_val = opt_state["state"][param_idx][key]
    +                loaded_val = loaded["optimizer"]["state"][param_idx][key]
    +                if isinstance(orig_val, torch.Tensor):
    +                    assert torch.equal(orig_val, loaded_val)
    +
    +    def test_scheduler_state_roundtrip(self, tmp_path):
    +        """Verify scheduler state survives round-trip."""
    +        sched_state = {
    +            "T_max": 200,
    +            "eta_min": 0,
    +            "last_epoch": 50,
    +            "_step_count": 51,
    +            "base_lrs": [0.01],
    +            "_last_lr": [0.005],
    +        }
    +        original = _make_full_checkpoint(scheduler=sched_state)
    +        path = save_checkpoint(original, tmp_path / "model")
    +        loaded = load_checkpoint(path)
    +
    +        assert loaded["scheduler"]["T_max"] == 200
    +        assert loaded["scheduler"]["last_epoch"] == 50
    +
    +    def test_fitsbolt_cfg_roundtrip(self, tmp_path):
    +        """Verify fitsbolt DotMap config survives round-trip."""
    +        fb_cfg = DotMap(
    +            {
    +                "output_dtype": np.uint8,
    +                "size": [64, 64],
    +                "normalisation_method": NormalisationMethod.CONVERSION_ONLY,
    +                "n_output_channels": 3,
    +                "channel_combination": np.array([[1, 0], [0, 1], [0.5, 0.5]]),
    +            }
    +        )
    +        original = _make_full_checkpoint(fitsbolt_cfg=fb_cfg)
    +        path = save_checkpoint(original, tmp_path / "model")
    +        loaded = load_checkpoint(path)
    +
    +        loaded_fb = loaded["fitsbolt_cfg"]
    +        assert isinstance(loaded_fb, DotMap)
    +        assert loaded_fb.normalisation_method == NormalisationMethod.CONVERSION_ONLY
    +        assert loaded_fb.output_dtype == np.uint8
    +        assert np.array_equal(loaded_fb.channel_combination, fb_cfg.channel_combination)
    +
    +    def test_labeled_data_csv_roundtrip(self, tmp_path):
    +        """Verify labeled_data_csv string survives round-trip."""
    +        csv = "filename,label\nimg1.jpg,anomaly\nimg2.jpg,normal\n"
    +        original = _make_full_checkpoint(labeled_data_csv=csv)
    +        path = save_checkpoint(original, tmp_path / "model")
    +        loaded = load_checkpoint(path)
    +
    +        assert loaded["labeled_data_csv"] == csv
    +
    +    def test_none_values_roundtrip(self, tmp_path):
    +        """Verify None values survive round-trip correctly."""
    +        original = _make_full_checkpoint(
    +            optimizer=None,
    +            scheduler=None,
    +            fitsbolt_cfg=None,
    +            best_eval_acc=None,
    +            normalisation_method=None,
    +        )
    +        path = save_checkpoint(original, tmp_path / "model")
    +        loaded = load_checkpoint(path)
    +
    +        assert loaded["optimizer"] is None
    +        assert loaded["scheduler"] is None
    +        assert loaded["fitsbolt_cfg"] is None
    +        assert loaded["best_eval_acc"] is None
    +        assert loaded["normalisation_method"] is None
    +
    +
    +class TestFileFormat:
    +    """Test file format details."""
    +
    +    def test_extension_forced_to_safetensors(self, tmp_path):
    +        """save_checkpoint forces .safetensors extension."""
    +        path = save_checkpoint(_make_full_checkpoint(), tmp_path / "model.pth")
    +        assert path.suffix == ".safetensors"
    +        assert path.exists()
    +
    +    def test_safetensors_extension_preserved(self, tmp_path):
    +        """If .safetensors extension is already correct, it's preserved."""
    +        path = save_checkpoint(_make_full_checkpoint(), tmp_path / "model.safetensors")
    +        assert path.suffix == ".safetensors"
    +
    +    def test_load_nonexistent_raises(self, tmp_path):
    +        """Loading a nonexistent file raises FileNotFoundError."""
    +        with pytest.raises(FileNotFoundError):
    +            load_checkpoint(tmp_path / "nonexistent.safetensors")
    +
    +    def test_shared_memory_tensors(self, tmp_path):
    +        """Tensors that share memory (e.g. EMA copy) are saved without error."""
    +        shared = _make_state_dict(seed=0)
    +        original = _make_full_checkpoint(
    +            train_model=shared,
    +            eval_model=shared,  # same object, shares memory
    +        )
    +        # Should not raise RuntimeError about shared tensors
    +        path = save_checkpoint(original, tmp_path / "model")
    +        loaded = load_checkpoint(path)
    +        assert "train_model" in loaded
    +        assert "eval_model" in loaded
    +
    +
    +class TestSecurity:
    +    """Verify the format is safe against code execution attacks."""
    +
    +    def test_no_pickle_in_file(self, tmp_path):
    +        """The saved file must not contain pickle opcodes."""
    +        path = save_checkpoint(_make_full_checkpoint(), tmp_path / "model")
    +        data = path.read_bytes()
    +        # Pickle protocol markers (0x80 = protocol 2+, 'cos\n' = protocol 0)
    +        # safetensors files start with a little-endian u64 header size
    +        assert not data[8:].startswith(b"\x80\x02")  # not pickle protocol 2
    +        assert not data[8:].startswith(b"cos\n")  # not pickle protocol 0
    +
    +    def test_metadata_is_plain_json(self, tmp_path):
    +        """All metadata in the safetensors header is valid JSON strings."""
    +        import json
    +
    +        from safetensors import safe_open
    +
    +        path = save_checkpoint(_make_full_checkpoint(), tmp_path / "model")
    +        with safe_open(str(path), framework="pt") as f:
    +            metadata = f.metadata()
    +
    +        for key, value in metadata.items():
    +            # Every metadata value must be a valid JSON string
    +            parsed = json.loads(value)
    +            assert parsed is not None or value == "null"
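
The security tests above rely on the safetensors layout: an 8-byte little-endian header size, a JSON header, then raw tensor bytes. A minimal standard-library sketch that builds and parses such a header by hand. The metadata keys `it` and `fitsbolt_cfg` are assumptions chosen to mirror the test names, not the confirmed AnomalyMatch layout, and `build_file` and `read_header` are hypothetical helpers:

```python
import json
import struct


def build_file(header, payload):
    """Assemble bytes in the safetensors layout: u64 size, JSON header, data."""
    header_bytes = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + payload


def read_header(data):
    """Parse only the JSON header; no tensor data is interpreted."""
    (size,) = struct.unpack("<Q", data[:8])
    return json.loads(data[8 : 8 + size].decode("utf-8"))


# One float32 tensor of shape [2] (8 bytes) plus string-valued metadata.
header = {
    "__metadata__": {"it": "42", "fitsbolt_cfg": "null"},
    "w": {"dtype": "F32", "shape": [2], "data_offsets": [0, 8]},
}
data = build_file(header, struct.pack("<2f", 1.0, 2.0))

parsed = read_header(data)
metadata = {key: json.loads(value) for key, value in parsed["__metadata__"].items()}
```

Reading the file involves only integer unpacking and JSON parsing, so nothing in the header can trigger code execution. In real code the `safetensors` library should be used instead of hand-rolled parsing.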
    
  • tests/unit/test_session_io_handler.py +2 -46 modified
    @@ -6,7 +6,6 @@
     #   the terms contained in the file 'LICENCE.txt'.
     
     import json
    -import pickle
     import shutil
     import tempfile
     from pathlib import Path
    @@ -101,39 +100,6 @@ def test_save_session_custom_path(self):
             assert save_path.exists()
             assert (save_path / "session_metadata.json").exists()
     
    -    def test_save_model_checkpoint(self):
    -        """Test saving model checkpoint."""
    -        model_state = {"weights": [1, 2, 3], "epoch": 10}
    -
    -        checkpoint_path = self.io_handler.save_model_checkpoint(model_state, self.session_tracker)
    -
    -        # Check checkpoint was saved
    -        assert Path(checkpoint_path).exists()
    -        assert "checkpoints" in checkpoint_path
    -        assert checkpoint_path.endswith(".pkl")
    -
    -        # Verify checkpoint content
    -        with open(checkpoint_path, "rb") as f:
    -            loaded_state = pickle.load(f)
    -        assert loaded_state == model_state
    -
    -        # Verify that session tracker was updated - check the last iteration
    -        assert len(self.session_tracker.session_iterations) > 0
    -        last_iter = self.session_tracker.session_iterations[-1]
    -        assert last_iter.model_state_path == checkpoint_path
    -
    -    def test_save_model_checkpoint_custom_name(self):
    -        """Test saving model checkpoint with custom name."""
    -        model_state = {"test": "data"}
    -        custom_name = "custom_checkpoint.pkl"
    -
    -        checkpoint_path = self.io_handler.save_model_checkpoint(
    -            model_state, self.session_tracker, checkpoint_name=custom_name
    -        )
    -
    -        assert checkpoint_path.endswith(custom_name)
    -        assert Path(checkpoint_path).exists()
    -
         def test_load_session_complete_cycle(self):
             """Test complete save/load cycle."""
             # First save a session
    @@ -221,7 +187,7 @@ def setup_method(self):
             session_tracker.add_labeled_sample("img1.jpg", "anomaly")
             session_tracker.add_labeled_sample("img2.jpg", "normal")
             session_tracker.update_test_performance({"AUROC": 0.92, "AUPRC": 0.88})
    -        session_tracker.update_model_state_path("models/final_model.pth")
    +        session_tracker.update_model_state_path("models/final_model.safetensors")
     
             # Start second iteration
             session_tracker.start_new_session_iteration()
    @@ -347,15 +313,11 @@ def test_full_workflow_integration(self):
             tracker.update_model_iteration(0.5)
             tracker.add_labeled_sample("img4.jpg", "anomaly")
             tracker.update_test_performance({"AUROC": 0.93, "AUPRC": 0.89})
    -        tracker.update_model_state_path("models/best_model.pth")
    +        tracker.update_model_state_path("models/best_model.safetensors")
     
             # Save session
             saved_path = self.io_handler.save_session(tracker)
     
    -        # Save model checkpoint
    -        model_state = {"epoch": 50, "weights": [1, 2, 3, 4]}
    -        checkpoint_path = self.io_handler.save_model_checkpoint(model_state, tracker)
    -
             # Load session back
             loaded_tracker = self.io_handler.load_session(saved_path)
     
    @@ -365,12 +327,6 @@ def test_full_workflow_integration(self):
             assert len(loaded_tracker.get_labeled_data_df()) == 4
             assert len(loaded_tracker.session_iterations) == 2
     
    -        # Check model checkpoint exists
    -        assert Path(checkpoint_path).exists()
    -        with open(checkpoint_path, "rb") as f:
    -            loaded_model = pickle.load(f)
    -        assert loaded_model == model_state
    -
         def test_multiple_sessions_management(self):
             """Test managing multiple sessions."""
             # Create multiple sessions
    
  • .vulture_whitelist.py +0 -2 modified
    @@ -12,8 +12,6 @@
     """
     
     # SessionIOHandler methods - public API used in tests
    -save_model_checkpoint  # noqa - Used in test_session_io_handler.py, test_model_io_integration.py
    -load_model_checkpoint  # noqa - Used in test_model_io_integration.py
     list_sessions  # noqa - Used in test_session_io_handler.py
     save_run  # noqa - Used in test_run_label_migration.py
     save_labels_to_output_dir  # noqa - Used in test_run_label_migration.py
    

Vulnerability mechanics

Root cause

"The application uses torch.load() to deserialize untrusted model checkpoint files, which leads to arbitrary code execution via pickle deserialization."

Attack vector

An attacker can trigger this vulnerability by providing a crafted model checkpoint file to the application [ref_id=1]. The application loads these files from session directories using `torch.load()` with unrestricted deserialization, allowing the execution of arbitrary code embedded within the file [ref_id=1].
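
The mechanism behind this attack vector is pickle's `__reduce__` protocol: a pickled object can name any importable callable together with its arguments, and the unpickler calls it while loading. The following is a minimal, deliberately harmless sketch using only the standard library. The class name `Demo` is hypothetical, and the built-in `sorted` stands in for whatever callable an attacker would choose; it is not taken from the advisory or the patch.

```python
import pickle


class Demo:
    """Hypothetical class; __reduce__ tells pickle how to rebuild the object."""

    def __reduce__(self):
        # pickle stores a reference to the callable plus its arguments and
        # calls callable(*args) at load time. Here the callable is harmless.
        return (sorted, ([3, 1, 2],))


payload = pickle.dumps(Demo())

# Loading never constructs a Demo. The unpickler calls sorted([3, 1, 2])
# and returns whatever that call produced.
result = pickle.loads(payload)
print(result)
```

Because the callable is chosen by whoever wrote the file, loading a pickle-based checkpoint is equivalent to running code supplied by the file's author.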

Affected code

The vulnerability lies in the model-loading logic, which used `torch.load()` to process checkpoint files [ref_id=1]. The patch introduces `anomaly_match/data_io/checkpoint_io.py` to handle checkpoint operations securely and removes the insecure `save_model_checkpoint` and `load_model_checkpoint` methods from `SessionIOHandler` [ref_id=1].

What the fix does

The patch replaces the use of `torch.load()` with the `safetensors` library [ref_id=1]. By switching to `safetensors`, the application now stores tensors in a binary format and metadata as JSON, eliminating the risk associated with pickle-based deserialization of untrusted files [ref_id=1]. Additionally, all legacy `.pth` and `.pkl` file loading has been removed in favor of the new `.safetensors` format [ref_id=1].
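
To illustrate why this format avoids the problem, the sketch below builds and parses a safetensors-style blob by hand with the standard library: an 8-byte little-endian header length, a JSON header, then raw tensor bytes. It is an illustration of the layout only, not the patch's code; the helper names `build_blob` and `read_blob` are hypothetical, and real projects should rely on the `safetensors` library itself. Metadata values are JSON-encoded strings, mirroring what the patch's unit test checks.

```python
import json
import struct


def build_blob(metadata: dict) -> bytes:
    """Build a minimal safetensors-style blob holding one float32 tensor."""
    data = struct.pack("<2f", 1.0, 2.0)  # raw little-endian tensor bytes
    header = {
        # Metadata is a string-to-string map, so each value is JSON-encoded.
        "__metadata__": {k: json.dumps(v) for k, v in metadata.items()},
        "weights": {"dtype": "F32", "shape": [2], "data_offsets": [0, len(data)]},
    }
    header_bytes = json.dumps(header).encode("utf-8")
    # Layout: 8-byte little-endian header length, JSON header, tensor data.
    return struct.pack("<Q", len(header_bytes)) + header_bytes + data


def read_blob(blob: bytes):
    """Parse the JSON header and return it with the raw data buffer."""
    (header_len,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8 : 8 + header_len].decode("utf-8"))
    return header, blob[8 + header_len :]


blob = build_blob({"epoch": 10, "note": None})
header, buffer = read_blob(blob)

# Decoding is plain JSON parsing and byte slicing; no callable is invoked.
metadata = {k: json.loads(v) for k, v in header["__metadata__"].items()}
start, end = header["weights"]["data_offsets"]
values = struct.unpack("<2f", buffer[start:end])
print(metadata, values)
```

Reading the file involves only JSON parsing and fixed-width byte decoding, so a crafted file can at worst produce malformed data rather than run code.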

Preconditions

  • Input: The attacker must be able to provide a crafted model checkpoint file to the application.
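
Since the precondition is a file placed where the application will load it, one way to audit an existing installation is to look for leftover pickle-based checkpoints in session directories. The sketch below is a hypothetical helper, not part of AnomalyMatch; it assumes that the `.pth` and `.pkl` suffixes named in the fix description identify the legacy files, and the directory layout in the usage example is invented for illustration.

```python
import tempfile
from pathlib import Path

# Assumption: these suffixes mark pickle-based checkpoints, following the
# legacy formats that the patch stops loading.
LEGACY_SUFFIXES = {".pth", ".pkl"}


def find_legacy_checkpoints(root):
    """Return sorted paths under root whose suffix marks a legacy checkpoint."""
    return sorted(
        p
        for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in LEGACY_SUFFIXES
    )


# Usage example on a throwaway directory with an invented layout.
with tempfile.TemporaryDirectory() as tmp:
    models_dir = Path(tmp) / "session" / "models"
    models_dir.mkdir(parents=True)
    for name in ("final_model.safetensors", "old_model.pth", "state.PKL"):
        (models_dir / name).touch()
    flagged = [p.name for p in find_legacy_checkpoints(tmp)]

print(flagged)
```

A suffix check only flags candidates for review; it says nothing about whether a given file is malicious, and a renamed pickle file would not be caught.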

Generated on Jun 1, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References

3
