VYPR
Medium severity 5.5 · NVD Advisory · Published Jun 16, 2026 · Updated Jun 16, 2026

CVE-2026-47748

Description

Out-of-bounds read in stable-diffusion.cpp's pickle parser allows denial of service via crafted .ckpt file.

AI Insight

LLM-synthesized narrative grounded in this CVE's description and references.

Vulnerability

The pickle .ckpt parser in src/model.cpp of stable-diffusion.cpp versions prior to master-584-0a7ae07 does not consistently validate that sufficient input remains before reading opcode arguments or advancing the parser buffer. Opcode handlers use expressions such as buffer += N without checking buffer + N <= buffer_end, allowing a crafted or truncated .ckpt file to cause out-of-bounds reads past the end of the metadata buffer [1][2].
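The missing check can be illustrated with a minimal sketch. It assumes a hypothetical BoundedCursor helper; the names are illustrative and are not taken from the patch or from the project's code. Every read or skip first compares the requested length against the bytes that remain, so a truncated argument is reported as an error instead of being read.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>

// Hypothetical helper (not from the patch): a read cursor that refuses to
// advance past the end of the buffer it was given.
struct BoundedCursor {
    const uint8_t* pos;  // next unread byte
    const uint8_t* end;  // one past the last valid byte

    // Number of unread bytes.
    size_t remaining() const { return static_cast<size_t>(end - pos); }

    // Skip n bytes; returns false and leaves pos unchanged if fewer remain.
    bool skip(size_t n) {
        if (n > remaining()) return false;
        pos += n;
        return true;
    }

    // Read a little-endian 32-bit value; returns false on truncated input.
    bool read_u32(uint32_t& out) {
        if (remaining() < 4) return false;
        out = static_cast<uint32_t>(pos[0]) |
              (static_cast<uint32_t>(pos[1]) << 8) |
              (static_cast<uint32_t>(pos[2]) << 16) |
              (static_cast<uint32_t>(pos[3]) << 24);
        pos += 4;
        return true;
    }

    // Read a counted string (4-byte length, then that many bytes), the
    // layout used by the BINUNICODE opcode. On failure pos is restored.
    bool read_counted_string(std::string& out) {
        const uint8_t* saved = pos;
        uint32_t len = 0;
        if (!read_u32(len) || len > remaining()) {
            pos = saved;
            return false;
        }
        out.assign(reinterpret_cast<const char*>(pos), len);
        pos += len;
        return true;
    }
};
```

Comparing the requested length against remaining() rather than computing pos + n keeps the check itself free of out-of-range pointer arithmetic.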

Exploitation

An attacker can craft a malformed or truncated .ckpt file that triggers the out-of-bounds read when an affected application loads it. No authentication is required, but the victim or application must load the file from an untrusted source, such as a model-sharing site. Fuzzing with libFuzzer produced crashes in under one second using malformed checkpoint inputs [2].
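For readers unfamiliar with libFuzzer, the shape of such a harness is sketched below. The harness used in the referenced report is not reproduced here; parse_pickle_stub is a hypothetical, bounds-checked stand-in for the parser under test and handles only a few opcodes for illustration. The only real interface is the LLVMFuzzerTestOneInput entry point that libFuzzer calls with each generated input.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for the parser under test (not the project's code).
// It walks a few pickle opcodes and rejects input whose argument is truncated.
static bool parse_pickle_stub(const uint8_t* data, size_t size) {
    const uint8_t* pos = data;
    const uint8_t* end = data + size;
    while (pos < end) {
        const uint8_t opcode = *pos++;
        size_t arg_len = 0;
        switch (opcode) {
            case 'q':  // BINPUT: 1-byte argument
            case 'h':  // BINGET: 1-byte argument
                arg_len = 1;
                break;
            case 'M':  // BININT2: 2-byte argument
                arg_len = 2;
                break;
            case 'J':  // BININT: 4-byte argument
            case 'r':  // LONG_BINPUT: 4-byte argument
                arg_len = 4;
                break;
            case '.':  // STOP: end of pickle
                return true;
            default:   // opcodes without arguments in this sketch
                break;
        }
        // Reject the input instead of reading past the end.
        if (arg_len > static_cast<size_t>(end - pos)) return false;
        pos += arg_len;
    }
    return false;  // ran out of input before STOP
}

// libFuzzer entry point: called once per generated input.
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
    (void)parse_pickle_stub(data, size);
    return 0;  // non-crashing inputs always return 0
}
```

A harness like this is typically built with clang++ -g -fsanitize=fuzzer,address so that AddressSanitizer reports any out-of-bounds read as a crash.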

Impact

Successful exploitation results in an out-of-bounds read, which can cause the process to crash, leading to a denial of service. The advisory does not indicate that arbitrary code execution or information disclosure is achievable [2].

Mitigation

The issue is fixed in version master-584-0a7ae07 (commit 0a7ae07) [1][2]. Developers unable to immediately update should avoid loading .ckpt checkpoint files from untrusted sources and prefer safer formats such as .safetensors where possible [2].
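One way an application could enforce the "prefer .safetensors" advice before handing a file to an unpatched loader is sketched below. This is an assumption-laden illustration, not part of the fix: looks_like_safetensors is a hypothetical helper, and it relies only on the published safetensors layout (an 8-byte little-endian header length followed by a JSON header that begins with '{'). It is a coarse pre-filter, not a full validator.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical pre-filter (not from the patch): returns true only if the
// first bytes of a file are consistent with the safetensors container layout.
//   prefix     - the first prefix_len bytes of the file
//   file_size  - total size of the file in bytes
static bool looks_like_safetensors(const uint8_t* prefix,
                                   size_t prefix_len,
                                   uint64_t file_size) {
    // Need the 8-byte length field plus at least one header byte.
    if (prefix == nullptr || prefix_len < 9 || file_size < 9) return false;

    // Decode the little-endian 64-bit header length.
    uint64_t header_len = 0;
    for (int i = 7; i >= 0; --i) {
        header_len = (header_len << 8) | prefix[i];
    }

    // The JSON header must be non-empty and fit inside the file.
    if (header_len == 0 || header_len > file_size - 8) return false;

    // The JSON header must start with an opening brace.
    return prefix[8] == '{';
}
```

An application that accepts models from untrusted sources could call this on the first bytes of each file and refuse anything that fails the check, which would exclude ZIP-based .ckpt archives in most cases.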

AI Insight generated on Jun 16, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.

Affected products

2

Patches

1
0a7ae07f948e

feat: add restricted torch legacy checkpoint loading (#1443)

https://github.com/leejet/stable-diffusion.cpp · leejet · Apr 19, 2026 · via body-scan
14 files changed · +1638 -471
  • src/denoiser.hpp +8 -11 modified
    @@ -1523,12 +1523,10 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
                                                   const std::vector<float>& sigmas,
                                                   std::shared_ptr<RNG> rng,
                                                   float eta) {
    -
         int steps = static_cast<int>(sigmas.size()) - 1;
         for (int i = 0; i < steps; i++) {
    -
    -        float sigma       = sigmas[i];
    -        float sigma_to    = sigmas[i + 1];
    +        float sigma    = sigmas[i];
    +        float sigma_to = sigmas[i + 1];
     
             auto model_output_opt = model(x, sigma, i + 1);
             if (model_output_opt.empty()) {
    @@ -1551,12 +1549,11 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
             float std_dev_t = eta * std::sqrt(variance);
     
             x = pred_original_sample +
    -            std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2))/ alpha_prod_t_prev) * model_output;
    +            std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) / alpha_prod_t_prev) * model_output;
     
             if (eta > 0) {
    -            x+= std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
    +            x += std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
             }
    -
         }
         return x;
     }
    @@ -1584,8 +1581,10 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
     
         auto get_timestep_from_sigma = [&](float s) -> int {
             auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s);
    -        if (it == compvis_sigmas.begin()) return 0;
    -        if (it == compvis_sigmas.end()) return TIMESTEPS - 1;
    +        if (it == compvis_sigmas.begin())
    +            return 0;
    +        if (it == compvis_sigmas.end())
    +            return TIMESTEPS - 1;
             int idx_high = static_cast<int>(std::distance(compvis_sigmas.begin(), it));
             int idx_low  = idx_high - 1;
             if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) {
    @@ -1596,7 +1595,6 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
     
         int steps = static_cast<int>(sigmas.size()) - 1;
         for (int i = 0; i < steps; i++) {
    -
             float sigma_to    = sigmas[i + 1];
             int prev_timestep = get_timestep_from_sigma(sigma_to);
             int timestep_s    = (int)floor((1 - eta) * prev_timestep);
    @@ -1626,7 +1624,6 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
                 x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x +
                     std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
             }
    -
         }
         return x;
     }
    
  • src/model.cpp +63 -26 modified
    @@ -2,6 +2,7 @@
     #include <atomic>
     #include <chrono>
     #include <cstdarg>
    +#include <cstdlib>
     #include <fstream>
     #include <functional>
     #include <mutex>
    @@ -13,9 +14,10 @@
     #include <vector>
     
     #include "model.h"
    -#include "model_io/ckpt_io.h"
     #include "model_io/gguf_io.h"
     #include "model_io/safetensors_io.h"
    +#include "model_io/torch_legacy_io.h"
    +#include "model_io/torch_zip_io.h"
     #include "stable-diffusion.h"
     #include "util.h"
     
    @@ -229,9 +231,12 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
         } else if (is_safetensors_file(file_path)) {
             LOG_INFO("load %s using safetensors format", file_path.c_str());
             return init_from_safetensors_file(file_path, prefix);
    -    } else if (is_ckpt_file(file_path)) {
    -        LOG_INFO("load %s using checkpoint format", file_path.c_str());
    -        return init_from_ckpt_file(file_path, prefix);
    +    } else if (is_torch_zip_file(file_path)) {
    +        LOG_INFO("load %s using torch zip format", file_path.c_str());
    +        return init_from_torch_zip_file(file_path, prefix);
    +    } else if (init_from_torch_legacy_file(file_path, prefix)) {
    +        LOG_INFO("load %s using torch legacy format", file_path.c_str());
    +        return true;
         } else {
             if (file_exists(file_path)) {
                 LOG_WARN("unknown format %s", file_path.c_str());
    @@ -329,40 +334,47 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         return true;
     }
     
    -/*================================================= DiffusersModelLoader ==================================================*/
    +/*================================================= TorchLegacyModelLoader ==================================================*/
     
    -bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
    -    std::string unet_path   = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
    -    std::string vae_path    = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
    -    std::string clip_path   = path_join(file_path, "text_encoder/model.safetensors");
    -    std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
    +bool ModelLoader::init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix) {
    +    LOG_DEBUG("init from torch legacy '%s'", file_path.c_str());
     
    -    if (!init_from_safetensors_file(unet_path, "unet.")) {
    +    std::vector<TensorStorage> tensor_storages;
    +    std::string error;
    +    if (!read_torch_legacy_file(file_path, tensor_storages, &error)) {
    +        if ((!error.empty()) && (ends_with(file_path, ".pt") || ends_with(file_path, ".pth"))) {
    +            LOG_WARN("%s", error.c_str());
    +        }
             return false;
         }
     
    -    if (!init_from_safetensors_file(vae_path, "vae.")) {
    -        LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
    -        // return false;
    -    }
    -    if (!init_from_safetensors_file(clip_path, "te.")) {
    -        LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
    -        // return false;
    -    }
    -    if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
    -        LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
    +    file_paths_.push_back(file_path);
    +    size_t file_index = file_paths_.size() - 1;
    +
    +    for (auto& tensor_storage : tensor_storages) {
    +        if (is_unused_tensor(tensor_storage.name)) {
    +            continue;
    +        }
    +
    +        if (!starts_with(tensor_storage.name, prefix)) {
    +            tensor_storage.name = prefix + tensor_storage.name;
    +        }
    +        tensor_storage.file_index = file_index;
    +
    +        add_tensor_storage(tensor_storage);
         }
    +
         return true;
     }
     
    -/*================================================= CkptModelLoader ==================================================*/
    +/*================================================= TorchZipModelLoader ==================================================*/
     
    -bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) {
    +bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const std::string& prefix) {
         LOG_DEBUG("init from '%s'", file_path.c_str());
     
         std::vector<TensorStorage> tensor_storages;
         std::string error;
    -    if (!read_ckpt_file(file_path, tensor_storages, &error)) {
    +    if (!read_torch_zip_file(file_path, tensor_storages, &error)) {
             LOG_ERROR("%s", error.c_str());
             return false;
         }
    @@ -384,6 +396,32 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
         return true;
     }
     
    +/*================================================= DiffusersModelLoader ==================================================*/
    +
    +bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
    +    std::string unet_path   = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
    +    std::string vae_path    = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
    +    std::string clip_path   = path_join(file_path, "text_encoder/model.safetensors");
    +    std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
    +
    +    if (!init_from_safetensors_file(unet_path, "unet.")) {
    +        return false;
    +    }
    +
    +    if (!init_from_safetensors_file(vae_path, "vae.")) {
    +        LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
    +        // return false;
    +    }
    +    if (!init_from_safetensors_file(clip_path, "te.")) {
    +        LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
    +        // return false;
    +    }
    +    if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
    +        LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
    +    }
    +    return true;
    +}
    +
     SDVersion ModelLoader::get_sd_version() {
         TensorStorage token_embedding_weight, input_block_weight;
     
    @@ -1210,6 +1248,5 @@ bool convert(const char* input_path,
         if (convert_name) {
             model_loader.convert_tensors_name();
         }
    -    bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
    -    return success;
    +    return model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
     }
    
  • src/model.h +2 -1 modified
    @@ -200,7 +200,8 @@ class ModelLoader {
     
         bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
         bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
    -    bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
    +    bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
    +    bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
         bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
     
     public:
    
  • src/model_io/binary_io.h +57 -0 added
    @@ -0,0 +1,57 @@
    +#ifndef __SD_MODEL_IO_BINARY_IO_H__
    +#define __SD_MODEL_IO_BINARY_IO_H__
    +
    +#include <cstdint>
    +#include <ostream>
    +
    +namespace model_io {
    +
    +    inline int32_t read_int(const uint8_t* buffer) {
    +        uint32_t value = 0;
    +        value |= static_cast<uint32_t>(buffer[3]) << 24;
    +        value |= static_cast<uint32_t>(buffer[2]) << 16;
    +        value |= static_cast<uint32_t>(buffer[1]) << 8;
    +        value |= static_cast<uint32_t>(buffer[0]);
    +        return static_cast<int32_t>(value);
    +    }
    +
    +    inline uint16_t read_short(const uint8_t* buffer) {
    +        uint16_t value = 0;
    +        value |= static_cast<uint16_t>(buffer[1]) << 8;
    +        value |= static_cast<uint16_t>(buffer[0]);
    +        return value;
    +    }
    +
    +    inline uint64_t read_u64(const uint8_t* buffer) {
    +        uint64_t value = 0;
    +        value |= static_cast<uint64_t>(buffer[7]) << 56;
    +        value |= static_cast<uint64_t>(buffer[6]) << 48;
    +        value |= static_cast<uint64_t>(buffer[5]) << 40;
    +        value |= static_cast<uint64_t>(buffer[4]) << 32;
    +        value |= static_cast<uint64_t>(buffer[3]) << 24;
    +        value |= static_cast<uint64_t>(buffer[2]) << 16;
    +        value |= static_cast<uint64_t>(buffer[1]) << 8;
    +        value |= static_cast<uint64_t>(buffer[0]);
    +        return value;
    +    }
    +
    +    inline void write_u64(std::ostream& stream, uint64_t value) {
    +        uint8_t buffer[8];
    +        for (int i = 0; i < 8; ++i) {
    +            buffer[i] = static_cast<uint8_t>((value >> (8 * i)) & 0xFF);
    +        }
    +        stream.write((const char*)buffer, sizeof(buffer));
    +    }
    +
    +    inline int find_char(const uint8_t* buffer, int len, char c) {
    +        for (int pos = 0; pos < len; pos++) {
    +            if (buffer[pos] == (uint8_t)c) {
    +                return pos;
    +            }
    +        }
    +        return -1;
    +    }
    +
    +}  // namespace model_io
    +
    +#endif  // __SD_MODEL_IO_BINARY_IO_H__
    
  • src/model_io/ckpt_io.cpp +0 -403 removed
    @@ -1,403 +0,0 @@
    -#include "ckpt_io.h"
    -
    -#include <cstdint>
    -#include <cstdio>
    -#include <cstdlib>
    -#include <cstring>
    -#include <string>
    -#include <vector>
    -
    -#include "zip.h"
    -
    -static constexpr int MAX_STRING_BUFFER = 512;
    -
    -static void set_error(std::string* error, const std::string& message) {
    -    if (error != nullptr) {
    -        *error = message;
    -    }
    -}
    -
    -static int32_t read_int(const uint8_t* buffer) {
    -    // little endian
    -    uint32_t value = 0;
    -    value |= static_cast<uint32_t>(buffer[3]) << 24;
    -    value |= static_cast<uint32_t>(buffer[2]) << 16;
    -    value |= static_cast<uint32_t>(buffer[1]) << 8;
    -    value |= static_cast<uint32_t>(buffer[0]);
    -    return static_cast<int32_t>(value);
    -}
    -
    -static uint16_t read_short(const uint8_t* buffer) {
    -    // little endian
    -    uint16_t value = 0;
    -    value |= static_cast<uint16_t>(buffer[1]) << 8;
    -    value |= static_cast<uint16_t>(buffer[0]);
    -    return value;
    -}
    -
    -bool is_ckpt_file(const std::string& file_path) {
    -    zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
    -    if (zip == nullptr) {
    -        return false;
    -    }
    -    zip_close(zip);
    -    return true;
    -}
    -
    -/*================================================= CkptModelLoader ==================================================*/
    -
    -// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100
    -//     0: \x80 PROTO      2
    -//     2: }    EMPTY_DICT
    -//     3: q    BINPUT     0
    -//     5: (    MARK
    -//     6: X        BINUNICODE 'epoch'
    -//    16: q        BINPUT     1
    -//    18: K        BININT1    6
    -//    20: X        BINUNICODE 'global_step'
    -//    36: q        BINPUT     2
    -//    38: J        BININT     470000
    -//    43: X        BINUNICODE 'pytorch-lightning_version'
    -//    73: q        BINPUT     3
    -//    75: X        BINUNICODE '1.4.2'
    -//    85: q        BINPUT     4
    -//    87: X        BINUNICODE 'state_dict'
    -//   102: q        BINPUT     5
    -//   104: }        EMPTY_DICT
    -//   105: q        BINPUT     6
    -//   107: (        MARK
    -//   108: X            BINUNICODE 'betas'
    -//   118: q            BINPUT     7
    -//   120: c            GLOBAL     'torch._utils _rebuild_tensor_v2'
    -//   153: q            BINPUT     8
    -//   155: (            MARK
    -//   156: (                MARK
    -//   157: X                    BINUNICODE 'storage'
    -//   169: q                    BINPUT     9
    -//   171: c                    GLOBAL     'torch FloatStorage'
    -//   191: q                    BINPUT     10
    -//   193: X                    BINUNICODE '0'
    -//   199: q                    BINPUT     11
    -//   201: X                    BINUNICODE 'cpu'
    -//   209: q                    BINPUT     12
    -//   211: M                    BININT2    1000
    -//   214: t                    TUPLE      (MARK at 156)
    -//   215: q                BINPUT     13
    -//   217: Q                BINPERSID
    -//   218: K                BININT1    0
    -//   220: M                BININT2    1000
    -//  ...............................
    -//  3201: q            BINPUT     250
    -//  3203: R            REDUCE
    -//  3204: q            BINPUT     251
    -//  3206: X            BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight'
    -//  3264: q            BINPUT     252
    -//  3266: h            BINGET     8
    -//  3268: (            MARK
    -//  3269: (                MARK
    -//  3270: h                    BINGET     9
    -//  3272: h                    BINGET     10
    -//  3274: X                    BINUNICODE '30'
    -//  3281: q                    BINPUT     253
    -//  3283: h                    BINGET     12
    -//  3285: J                    BININT     102400
    -//  3290: t                    TUPLE      (MARK at 3269)
    -//  3291: q                BINPUT     254
    -//  3293: Q                BINPERSID
    -//  3294: K                BININT1    0
    -//  3296: (                MARK
    -//  3297: M                    BININT2    320
    -//  3300: M                    BININT2    320
    -//  3303: K                    BININT1    1
    -//  3305: K                    BININT1    1
    -//  3307: t                    TUPLE      (MARK at 3296)
    -//  3308: q                BINPUT     255
    -//  3310: (                MARK
    -//  3311: M                    BININT2    320
    -//  3314: K                    BININT1    1
    -//  3316: K                    BININT1    1
    -//  3318: K                    BININT1    1
    -//  3320: t                    TUPLE      (MARK at 3310)
    -//  3321: r                LONG_BINPUT 256
    -//  3326: \x89             NEWFALSE
    -//  3327: h                BINGET     16
    -//  3329: )                EMPTY_TUPLE
    -//  3330: R                REDUCE
    -//  3331: r                LONG_BINPUT 257
    -//  3336: t                TUPLE      (MARK at 3268)
    -//  3337: r            LONG_BINPUT 258
    -//  3342: R            REDUCE
    -//  3343: r            LONG_BINPUT 259
    -//  3348: X            BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias'
    -//  3404: r            LONG_BINPUT 260
    -//  3409: h            BINGET     8
    -//  3411: (            MARK
    -//  3412: (                MARK
    -//  3413: h                    BINGET     9
    -//  3415: h                    BINGET     10
    -//  3417: X                    BINUNICODE '31'
    -
    -struct PickleTensorReader {
    -    enum ReadPhase {
    -        READ_NAME,
    -        READ_DATA,
    -        CHECK_SIZE,
    -        READ_DIMENS
    -    };
    -    ReadPhase phase   = READ_NAME;
    -    size_t entry_size = 0;
    -    int32_t nelements = 0;
    -
    -    TensorStorage tensor_storage;
    -
    -    static ggml_type global_type;  // all pickle_tensors data type
    -    static bool read_global_type;
    -
    -    bool read_int_value(uint32_t value) {
    -        if (phase == CHECK_SIZE) {
    -            if (entry_size == value * ggml_type_size(tensor_storage.type)) {
    -                nelements = value;
    -                phase     = READ_DIMENS;
    -                return true;
    -            } else {
    -                phase = READ_NAME;
    -            }
    -        } else if (phase == READ_DIMENS) {
    -            if (tensor_storage.n_dims + 1 > SD_MAX_DIMS) {  // too many dimens
    -                phase                 = READ_NAME;
    -                tensor_storage.n_dims = 0;
    -            }
    -            if (nelements % value == 0) {
    -                tensor_storage.ne[tensor_storage.n_dims] = value;
    -                tensor_storage.n_dims++;
    -            }
    -        }
    -        return false;
    -    }
    -
    -    void read_global(const std::string& str) {
    -        if (str == "FloatStorage") {
    -            if (read_global_type) {
    -                global_type      = GGML_TYPE_F32;
    -                read_global_type = false;
    -            }
    -            tensor_storage.type = GGML_TYPE_F32;
    -        } else if (str == "HalfStorage") {
    -            if (read_global_type) {
    -                global_type      = GGML_TYPE_F16;
    -                read_global_type = false;
    -            }
    -            tensor_storage.type = GGML_TYPE_F16;
    -        }
    -    }
    -
    -    void read_string(const std::string& str, zip_t* zip, std::string dir) {
    -        if (str == "storage") {
    -            read_global_type = true;
    -        } else if (str != "state_dict") {
    -            if (phase == READ_DATA) {
    -                std::string entry_name = dir + "data/" + std::string(str);
    -
    -                size_t i, n = zip_entries_total(zip);
    -                for (i = 0; i < n; ++i) {
    -                    zip_entry_openbyindex(zip, i);
    -                    {
    -                        std::string name = zip_entry_name(zip);
    -                        if (name == entry_name) {
    -                            tensor_storage.index_in_zip = (int)i;
    -                            entry_size                  = zip_entry_size(zip);
    -                            zip_entry_close(zip);
    -                            break;
    -                        }
    -                    }
    -                    zip_entry_close(zip);
    -                }
    -
    -                phase = entry_size > 0 ? CHECK_SIZE : READ_NAME;
    -            }
    -            if (!read_global_type && phase == READ_NAME) {
    -                tensor_storage.name = str;
    -                phase               = READ_DATA;
    -                tensor_storage.type = global_type;
    -            }
    -        }
    -    }
    -};
    -
    -ggml_type PickleTensorReader::global_type = GGML_TYPE_F32;  // all pickle_tensors data type
    -bool PickleTensorReader::read_global_type = false;
    -
    -static int find_char(uint8_t* buffer, int len, char c) {
    -    for (int pos = 0; pos < len; pos++) {
    -        if (buffer[pos] == c) {
    -            return pos;
    -        }
    -    }
    -    return -1;
    -}
    -
    -static bool parse_data_pkl(uint8_t* buffer,
    -                           size_t buffer_size,
    -                           zip_t* zip,
    -                           std::string dir,
    -                           std::vector<TensorStorage>& tensor_storages,
    -                           std::string* error) {
    -    uint8_t* buffer_end = buffer + buffer_size;
    -    if (buffer[0] == 0x80) {  // proto
    -        if (buffer[1] != 2) {
    -            set_error(error, "unsupported pickle protocol");
    -            return false;
    -        }
    -        buffer += 2;  // 0x80 and version
    -        char string_buffer[MAX_STRING_BUFFER];
    -        bool finish = false;
    -        PickleTensorReader reader;
    -        // read pickle binary file
    -        while (!finish && buffer < buffer_end) {
    -            uint8_t opcode = *buffer;
    -            buffer++;
    -            // https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048
    -            // https://github.com/python/cpython/blob/main/Lib/pickle.py#L105
    -            switch (opcode) {
    -                case '}':  // EMPTY_DICT     = b'}'   # push empty dict
    -                    break;
    -                case ']':  // EMPTY_LIST     = b']'   # push empty list
    -                    break;
    -                // skip unused sections
    -                case 'h':  // BINGET         = b'h'   #   "    "    "    "   "   "  ;   "    " 1-byte arg
    -                case 'q':  // BINPUT         = b'q'   #   "     "    "   "   " ;   "    " 1-byte arg
    -                case 'Q':  // BINPERSID      = b'Q'   #  "       "         "  ;  "  "   "     "  stack
    -                    buffer++;
    -                    break;
    -                case 'r':  // LONG_BINPUT    = b'r'   #   "     "    "   "   " ;   "    " 4-byte arg
    -                    buffer += 4;
    -                    break;
    -                case 0x95:  // FRAME            = b'\x95'  # indicate the beginning of a new frame
    -                    buffer += 8;
    -                    break;
    -                case 0x94:  // MEMOIZE          = b'\x94'  # store top of the stack in memo
    -                    break;
    -                case '(':  // MARK           = b'('   # push special markobject on stack
    -                    break;
    -                case 'K':  // BININT1        = b'K'   # push 1-byte unsigned int
    -                {
    -                    uint8_t value = *buffer;
    -                    if (reader.read_int_value(value)) {
    -                        buffer++;
    -                    }
    -                    buffer++;
    -                } break;
    -                case 'M':  // BININT2        = b'M'   # push 2-byte unsigned int
    -                {
    -                    uint16_t value = read_short(buffer);
    -                    if (reader.read_int_value(value)) {
    -                        buffer++;
    -                    }
    -                    buffer += 2;
    -                } break;
    -                case 'J':  // BININT         = b'J'   # push four-byte signed int
    -                {
    -                    const int32_t value = read_int(buffer);
    -                    if (reader.read_int_value(value)) {
    -                        buffer++;  // skip tuple after read num_elements
    -                    }
    -                    buffer += 4;
    -                } break;
    -                case 'X':  // BINUNICODE     = b'X'   #   "     "       "  ; counted UTF-8 string argument
    -                {
    -                    const int32_t len = read_int(buffer);
    -                    buffer += 4;
    -                    memset(string_buffer, 0, MAX_STRING_BUFFER);
    -                    if (len > MAX_STRING_BUFFER) {
    -                        // keep truncated names null-terminated, matching the old parser behavior
    -                    }
    -                    memcpy(string_buffer, buffer, len < MAX_STRING_BUFFER ? len : (MAX_STRING_BUFFER - 1));
    -                    buffer += len;
    -                    reader.read_string(string_buffer, zip, dir);
    -                } break;
    -                case 0x8C:  // SHORT_BINUNICODE = b'\x8c'  # push short string; UTF-8 length < 256 bytes
    -                {
    -                    const int8_t len = *buffer;
    -                    buffer++;
    -                    memset(string_buffer, 0, MAX_STRING_BUFFER);
    -                    memcpy(string_buffer, buffer, len);
    -                    buffer += len;
    -                    // printf("String: '%s'\n", string_buffer);
    -                } break;
    -                case 'c':  // GLOBAL         = b'c'   # push self.find_class(modname, name); 2 string args
    -                {
    -                    int len = find_char(buffer, MAX_STRING_BUFFER, '\n');
    -
    -                    buffer += len + 1;
    -                    len = find_char(buffer, MAX_STRING_BUFFER, '\n');
    -
    -                    memset(string_buffer, 0, MAX_STRING_BUFFER);
    -                    memcpy(string_buffer, buffer, len);
    -                    buffer += len + 1;
    -                    reader.read_global(string_buffer);
    -                } break;
    -                case 0x86:  // TUPLE2         = b'\x86'  # build 2-tuple from two topmost stack items
    -                case 0x85:  // TUPLE1         = b'\x85'  # build 1-tuple from stack top
    -                case 't':   // TUPLE          = b't'   # build tuple from topmost stack items
    -                    if (reader.phase == PickleTensorReader::READ_DIMENS) {
    -                        reader.tensor_storage.reverse_ne();
    -                        tensor_storages.push_back(reader.tensor_storage);
    -
    -                        // LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
    -                        // reset
    -                        reader = PickleTensorReader();
    -                    }
    -                    break;
    -                case '.':  // STOP           = b'.'   # every pickle ends with STOP
    -                    finish = true;
    -                    break;
    -                default:
    -                    break;
    -            }
    -        }
    -    }
    -    return true;
    -}
    -
    -bool read_ckpt_file(const std::string& file_path,
    -                    std::vector<TensorStorage>& tensor_storages,
    -                    std::string* error) {
    -    zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
    -    if (zip == nullptr) {
    -        set_error(error, "failed to open '" + file_path + "'");
    -        return false;
    -    }
    -
    -    tensor_storages.clear();
    -    bool success = true;
    -    int n        = (int)zip_entries_total(zip);
    -    for (int i = 0; i < n; ++i) {
    -        zip_entry_openbyindex(zip, i);
    -        {
    -            std::string name = zip_entry_name(zip);
    -            size_t pos       = name.find("data.pkl");
    -            if (pos != std::string::npos) {
    -                std::string dir = name.substr(0, pos);
    -                printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str());
    -                void* pkl_data = nullptr;
    -                size_t pkl_size;
    -                zip_entry_read(zip, &pkl_data, &pkl_size);
    -
    -                // LOG_DEBUG("%lld", pkl_size);
    -
    -                if (!parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) {
    -                    success = false;
    -                }
    -
    -                free(pkl_data);
    -            }
    -        }
    -        zip_entry_close(zip);
    -
    -        if (!success) {
    -            break;
    -        }
    -    }
    -    zip_close(zip);
    -    return success;
    -}
    
  • src/model_io/ckpt_io.h +0 -14 removed
    @@ -1,14 +0,0 @@
    -#ifndef __SD_MODEL_IO_CKPT_IO_H__
    -#define __SD_MODEL_IO_CKPT_IO_H__
    -
    -#include <string>
    -#include <vector>
    -
    -#include "tensor_storage.h"
    -
    -bool is_ckpt_file(const std::string& file_path);
    -bool read_ckpt_file(const std::string& file_path,
    -                    std::vector<TensorStorage>& tensor_storages,
    -                    std::string* error = nullptr);
    -
    -#endif  // __SD_MODEL_IO_CKPT_IO_H__
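
The pickle_io.cpp file added by this patch checks the remaining input before reading opcode arguments, for example `if (p + 4 > end) return false;`. As a minimal sketch of the same idea, the helper below tracks the number of bytes left instead of comparing pointers, so a length field taken from the file is compared against the remaining size before the cursor moves. `ByteCursor` and `skip_len_prefixed` are hypothetical names used only for illustration. They are not part of stable-diffusion.cpp, and this is not the patch's implementation.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not from the patch): a read cursor that carries the
// number of bytes still available, so every advance is validated first.
struct ByteCursor {
    const uint8_t* p;
    size_t remaining;

    // Advance by n bytes. Fails and leaves the cursor unchanged if fewer
    // than n bytes remain.
    bool skip(uint64_t n) {
        if (n > remaining) {
            return false;
        }
        p += (size_t)n;
        remaining -= (size_t)n;
        return true;
    }

    // Read a little-endian 32-bit value, the layout pickle uses for 4-byte
    // length fields. Fails if fewer than 4 bytes remain.
    bool read_u32_le(uint32_t* out) {
        if (remaining < 4) {
            return false;
        }
        *out = (uint32_t)p[0] | ((uint32_t)p[1] << 8) |
               ((uint32_t)p[2] << 16) | ((uint32_t)p[3] << 24);
        return skip(4);
    }
};

// Skip one length-prefixed payload: a 4-byte length followed by that many
// bytes (the shape of a BINUNICODE argument).
bool skip_len_prefixed(ByteCursor* cur) {
    uint32_t n = 0;
    return cur->read_u32_le(&n) && cur->skip(n);
}
```

Because the declared length is compared with `remaining` before any pointer arithmetic, a truncated file or an oversized length field is rejected without ever forming a pointer past the end of the buffer.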
    
  • src/model_io/pickle_io.cpp +1064 -0 added
    @@ -0,0 +1,1064 @@
    +#include "pickle_io.h"
    +
    +#include <cstdlib>
    +#include <cstring>
    +#include <string>
    +#include <unordered_map>
    +#include <utility>
    +#include <vector>
    +
    +#include "binary_io.h"
    +#include "util.h"
    +
    +// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100
    +//     0: \x80 PROTO      2
    +//     2: }    EMPTY_DICT
    +//     3: q    BINPUT     0
    +//     5: (    MARK
    +//     6: X        BINUNICODE 'epoch'
    +//    16: q        BINPUT     1
    +//    18: K        BININT1    6
    +//    20: X        BINUNICODE 'global_step'
    +//    36: q        BINPUT     2
    +//    38: J        BININT     470000
    +//    43: X        BINUNICODE 'pytorch-lightning_version'
    +//    73: q        BINPUT     3
    +//    75: X        BINUNICODE '1.4.2'
    +//    85: q        BINPUT     4
    +//    87: X        BINUNICODE 'state_dict'
    +//   102: q        BINPUT     5
    +//   104: }        EMPTY_DICT
    +//   105: q        BINPUT     6
    +//   107: (        MARK
    +//   108: X            BINUNICODE 'betas'
    +//   118: q            BINPUT     7
    +//   120: c            GLOBAL     'torch._utils _rebuild_tensor_v2'
    +//   153: q            BINPUT     8
    +//   155: (            MARK
    +//   156: (                MARK
    +//   157: X                    BINUNICODE 'storage'
    +//   169: q                    BINPUT     9
    +//   171: c                    GLOBAL     'torch FloatStorage'
    +//   191: q                    BINPUT     10
    +//   193: X                    BINUNICODE '0'
    +//   199: q                    BINPUT     11
    +//   201: X                    BINUNICODE 'cpu'
    +//   209: q                    BINPUT     12
    +//   211: M                    BININT2    1000
    +//   214: t                    TUPLE      (MARK at 156)
    +//   215: q                BINPUT     13
    +//   217: Q                BINPERSID
    +//   218: K                BININT1    0
    +//   220: M                BININT2    1000
    +//  ...............................
    +//  3201: q            BINPUT     250
    +//  3203: R            REDUCE
    +//  3204: q            BINPUT     251
    +//  3206: X            BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight'
    +//  3264: q            BINPUT     252
    +//  3266: h            BINGET     8
    +//  3268: (            MARK
    +//  3269: (                MARK
    +//  3270: h                    BINGET     9
    +//  3272: h                    BINGET     10
    +//  3274: X                    BINUNICODE '30'
    +//  3281: q                    BINPUT     253
    +//  3283: h                    BINGET     12
    +//  3285: J                    BININT     102400
    +//  3290: t                    TUPLE      (MARK at 3269)
    +//  3291: q                BINPUT     254
    +//  3293: Q                BINPERSID
    +//  3294: K                BININT1    0
    +//  3296: (                MARK
    +//  3297: M                    BININT2    320
    +//  3300: M                    BININT2    320
    +//  3303: K                    BININT1    1
    +//  3305: K                    BININT1    1
    +//  3307: t                    TUPLE      (MARK at 3296)
    +//  3308: q                BINPUT     255
    +//  3310: (                MARK
    +//  3311: M                    BININT2    320
    +//  3314: K                    BININT1    1
    +//  3316: K                    BININT1    1
    +//  3318: K                    BININT1    1
    +//  3320: t                    TUPLE      (MARK at 3310)
    +//  3321: r                LONG_BINPUT 256
    +//  3326: \x89             NEWFALSE
    +//  3327: h                BINGET     16
    +//  3329: )                EMPTY_TUPLE
    +//  3330: R                REDUCE
    +//  3331: r                LONG_BINPUT 257
    +//  3336: t                TUPLE      (MARK at 3268)
    +//  3337: r            LONG_BINPUT 258
    +//  3342: R            REDUCE
    +//  3343: r            LONG_BINPUT 259
    +//  3348: X            BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias'
    +//  3404: r            LONG_BINPUT 260
    +//  3409: h            BINGET     8
    +//  3411: (            MARK
    +//  3412: (                MARK
    +//  3413: h                    BINGET     9
    +//  3415: h                    BINGET     10
    +//  3417: X                    BINUNICODE '31'
    +// https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048
    +// https://github.com/python/cpython/blob/main/Lib/pickle.py#L105
    +
    +using model_io::find_char;
    +using model_io::read_int;
    +using model_io::read_short;
    +using model_io::read_u64;
    +
    +static void set_error(std::string* error, const std::string& message) {
    +    if (error != nullptr) {
    +        *error = message;
    +    }
    +}
    +
    +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size) {
    +    const uint8_t* p   = buffer;
    +    const uint8_t* end = buffer + buffer_size;
    +
    +    while (p < end) {
    +        uint8_t opcode = *p++;
    +        switch (opcode) {
    +            case '.':  // STOP             = b'.'   # every pickle ends with STOP
    +                *object_size = (size_t)(p - buffer);
    +                return true;
    +            case 0x80:  // PROTO            = b'\x80'  # protocol version indicator
    +            case 'K':   // BININT1          = b'K'   # push 1-byte unsigned int
    +            case 'h':   // BINGET           = b'h'   # read memo index, 1-byte arg
    +            case 'q':   // BINPUT           = b'q'   # write memo index, 1-byte arg
    +            case 'C':   // SHORT_BINBYTES   = b'C'   # push bytes; length < 256
    +            case 0x82:  // EXT1             = b'\x82'  # extension code, 1-byte arg
    +                p += 1;
    +                break;
    +            case 'M':   // BININT2          = b'M'   # push 2-byte unsigned int
    +            case 0x83:  // EXT2             = b'\x83'  # extension code, 2-byte arg
    +                p += 2;
    +                break;
    +            case 'J':   // BININT           = b'J'   # push 4-byte signed int
    +            case 'j':   // LONG_BINGET      = b'j'   # read memo index, 4-byte arg
    +            case 'r':   // LONG_BINPUT      = b'r'   # write memo index, 4-byte arg
    +            case 0x84:  // EXT4             = b'\x84'  # extension code, 4-byte arg
    +                p += 4;
    +                break;
    +            case 'I':    // INT              = b'I'   # push decimal integer line
    +            case 'L':    // LONG             = b'L'   # push decimal long integer line
    +            case 'F':    // FLOAT            = b'F'   # push decimal float line
    +            case 'S':    // STRING           = b'S'   # push quoted string line
    +            case 'V': {  // UNICODE          = b'V'   # push raw-unicode string line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +            } break;
    +            case 'G':  // BINFLOAT         = b'G'   # push 8-byte binary float
    +                p += 8;
    +                break;
    +            case 0x8A:  // LONG1            = b'\x8a'  # push long integer; 1-byte length
    +                if (p >= end) {
    +                    return false;
    +                }
    +                p += 1 + p[0];
    +                break;
    +            case 0x8B: {  // LONG4            = b'\x8b'  # push long integer; 4-byte length
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                uint32_t n = (uint32_t)read_int(p);
    +                p += 4 + n;
    +            } break;
    +            case 'B': {  // BINBYTES         = b'B'   # push bytes; 4-byte length
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                uint32_t n = (uint32_t)read_int(p);
    +                p += 4 + n;
    +            } break;
    +            case 'T':    // BINSTRING        = b'T'   # push string; 4-byte length
    +            case 'X': {  // BINUNICODE       = b'X'   # push UTF-8 string; 4-byte length
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                uint32_t n = (uint32_t)read_int(p);
    +                p += 4 + n;
    +            } break;
    +            case 0x8D:    // BINUNICODE8      = b'\x8d'  # push UTF-8 string; 8-byte length
    +            case 0x8E:    // BINBYTES8        = b'\x8e'  # push bytes; 8-byte length
    +            case 0x96: {  // BYTEARRAY8       = b'\x96'  # push bytearray; 8-byte length
    +                if (p + 8 > end) {
    +                    return false;
    +                }
    +                uint64_t n = read_u64(p);
    +                p += 8;
    +                if (n > (uint64_t)(end - p)) {
    +                    return false;
    +                }
    +                p += n;
    +            } break;
    +            case 'U':   // SHORT_BINSTRING  = b'U'   # push string; length < 256
    +            case 0x8C:  // SHORT_BINUNICODE = b'\x8c'  # push UTF-8 string; length < 256
    +                if (p >= end) {
    +                    return false;
    +                }
    +                p += 1 + p[0];
    +                break;
    +            case 'P': {  // PERSID           = b'P'   # persistent id, newline-terminated
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +            } break;
    +            case 0x95:  // FRAME            = b'\x95'  # indicate the beginning of a new frame
    +                p += 8;
    +                break;
    +            case 'c': {  // GLOBAL           = b'c'   # push module/name global reference
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +                len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +            } break;
    +            case '}':   // EMPTY_DICT       = b'}'   # push empty dict
    +            case ']':   // EMPTY_LIST       = b']'   # push empty list
    +            case '(':   // MARK             = b'('   # push markobject
    +            case 't':   // TUPLE            = b't'   # build tuple from mark
    +            case 0x85:  // TUPLE1           = b'\x85'  # build 1-tuple from stack
    +            case 0x86:  // TUPLE2           = b'\x86'  # build 2-tuple from stack
    +            case 0x87:  // TUPLE3           = b'\x87'  # build 3-tuple from stack
    +            case ')':   // EMPTY_TUPLE      = b')'   # push empty tuple
    +            case 'l':   // LIST             = b'l'   # build list from mark
    +            case 'Q':   // BINPERSID        = b'Q'   # persistent id from stack
    +            case 0x94:  // MEMOIZE          = b'\x94'  # store top of stack in memo
    +            case 0x88:  // NEWTRUE          = b'\x88'  # push True
    +            case 0x89:  // NEWFALSE         = b'\x89'  # push False
    +            case 'R':   // REDUCE           = b'R'   # apply callable to args
    +            case 'u':   // SETITEMS         = b'u'   # add mark-delimited items to dict
    +            case 's':   // SETITEM          = b's'   # add key/value to dict
    +            case 'e':   // APPENDS          = b'e'   # extend list with mark-delimited items
    +            case 'a':   // APPEND           = b'a'   # append item to list
    +            case 'b':   // BUILD            = b'b'   # build object state
    +            case 0x81:  // NEWOBJ           = b'\x81'  # build object via __new__
    +            case 0x8F:  // EMPTY_SET        = b'\x8f'  # push empty set
    +            case 0x90:  // ADDITEMS         = b'\x90'  # add mark-delimited items to set
    +            case 0x91:  // FROZENSET        = b'\x91'  # build frozenset from mark
    +            case 0x92:  // NEWOBJ_EX        = b'\x92'  # build object with kwargs
    +            case 0x93:  // STACK_GLOBAL     = b'\x93'  # build global from module/name strings
    +            case 0x97:  // NEXT_BUFFER      = b'\x97'  # out-of-band buffer marker
    +            case 0x98:  // READONLY_BUFFER  = b'\x98'  # mark buffer readonly
    +            case 'N':   // NONE             = b'N'   # push None
    +            case '0':   // POP              = b'0'   # discard top stack item
    +            case '1':   // POP_MARK         = b'1'   # discard stack through topmost mark
    +            case '2':   // DUP              = b'2'   # duplicate top stack item
    +            case 'o':   // OBJ              = b'o'   # build class instance from mark
    +                break;
    +            case 'i': {  // INST             = b'i'   # build class instance from module/name
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +                len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +            } break;
    +            default:
    +                return false;
    +        }
    +        if (p > end) {
    +            return false;
    +        }
    +    }
    +
    +    return false;
    +}
    +
    +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size) {
    +    static const uint8_t torch_magic_bytes[] = {0x6C, 0xFC, 0x9C, 0x46, 0xF9, 0x20, 0x6A, 0xA8, 0x50, 0x19};
    +
    +    if (buffer_size < 5 || buffer[0] != 0x80) {
    +        return false;
    +    }
    +
    +    size_t pos = 2;
    +    if (pos >= buffer_size) {
    +        return false;
    +    }
    +
    +    uint8_t opcode = buffer[pos++];
    +    if (opcode != 0x8A || pos >= buffer_size) {
    +        return false;
    +    }
    +
    +    uint8_t len = buffer[pos++];
    +    if (len != sizeof(torch_magic_bytes) || pos + len >= buffer_size) {
    +        return false;
    +    }
    +
    +    if (memcmp(buffer + pos, torch_magic_bytes, sizeof(torch_magic_bytes)) != 0) {
    +        return false;
    +    }
    +    pos += len;
    +
    +    return pos < buffer_size && buffer[pos] == '.';
    +}
    +
    +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value) {
    +    if (buffer_size < 4 || buffer[0] != 0x80) {
    +        return false;
    +    }
    +
    +    size_t pos = 2;
    +    if (pos >= buffer_size) {
    +        return false;
    +    }
    +
    +    uint8_t opcode = buffer[pos++];
    +    switch (opcode) {
    +        case 'K':  // BININT1          = b'K'   # push 1-byte unsigned int
    +            if (pos + 1 >= buffer_size) {
    +                return false;
    +            }
    +            *value = buffer[pos];
    +            pos += 1;
    +            break;
    +        case 'M':  // BININT2          = b'M'   # push 2-byte unsigned int
    +            if (pos + 2 >= buffer_size) {
    +                return false;
    +            }
    +            *value = read_short(buffer + pos);
    +            pos += 2;
    +            break;
    +        case 'J':  // BININT           = b'J'   # push 4-byte signed int
    +            if (pos + 4 >= buffer_size) {
    +                return false;
    +            }
    +            *value = (uint32_t)read_int(buffer + pos);
    +            pos += 4;
    +            break;
    +        default:
    +            return false;
    +    }
    +
    +    return pos < buffer_size && buffer[pos] == '.';
    +}
    +
    +struct PickleStorageInfo {
    +    std::string key;
    +    ggml_type type              = GGML_TYPE_COUNT;
    +    bool is_f64                 = false;
    +    bool is_i64                 = false;
    +    uint64_t raw_element_nbytes = 0;
    +    uint64_t nbytes             = 0;
    +};
    +
    +struct PickleTensorInfo {
    +    TensorStorage tensor_storage;
    +    int stride_n_dims = 0;
    +    int64_t stride[SD_MAX_DIMS]{1, 1, 1, 1, 1};
    +};
    +
    +struct PickleValue {
    +    enum Kind {
    +        MARK,
    +        NONE,
    +        BOOL,
    +        INT,
    +        STRING,
    +        GLOBAL,
    +        TUPLE,
    +        LIST,
    +        DICT,
    +        ORDERED_DICT,
    +        STORAGE,
    +        TENSOR,
    +    };
    +
    +    Kind kind         = NONE;
    +    int64_t int_value = 0;
    +    bool bool_value   = false;
    +    std::string str_value;
    +    std::vector<PickleValue> items;
    +    std::vector<std::pair<PickleValue, PickleValue>> dict_items;
    +    PickleStorageInfo storage;
    +    PickleTensorInfo tensor;
    +};
    +
    +static PickleValue make_mark_value() {
    +    PickleValue value;
    +    value.kind = PickleValue::MARK;
    +    return value;
    +}
    +
    +static PickleValue make_none_value() {
    +    PickleValue value;
    +    value.kind = PickleValue::NONE;
    +    return value;
    +}
    +
    +static PickleValue make_bool_value(bool b) {
    +    PickleValue value;
    +    value.kind       = PickleValue::BOOL;
    +    value.bool_value = b;
    +    return value;
    +}
    +
    +static PickleValue make_int_value(int64_t x) {
    +    PickleValue value;
    +    value.kind      = PickleValue::INT;
    +    value.int_value = x;
    +    return value;
    +}
    +
    +static PickleValue make_string_value(const std::string& s) {
    +    PickleValue value;
    +    value.kind      = PickleValue::STRING;
    +    value.str_value = s;
    +    return value;
    +}
    +
    +static PickleValue make_global_value(const std::string& s) {
    +    PickleValue value;
    +    value.kind      = PickleValue::GLOBAL;
    +    value.str_value = s;
    +    return value;
    +}
    +
    +static PickleValue make_tuple_value(std::vector<PickleValue> items) {
    +    PickleValue value;
    +    value.kind  = PickleValue::TUPLE;
    +    value.items = std::move(items);
    +    return value;
    +}
    +
    +static PickleValue make_list_value() {
    +    PickleValue value;
    +    value.kind = PickleValue::LIST;
    +    return value;
    +}
    +
    +static PickleValue make_dict_value(bool ordered) {
    +    PickleValue value;
    +    value.kind = ordered ? PickleValue::ORDERED_DICT : PickleValue::DICT;
    +    return value;
    +}
    +
    +static PickleValue make_storage_value(const PickleStorageInfo& storage) {
    +    PickleValue value;
    +    value.kind    = PickleValue::STORAGE;
    +    value.storage = storage;
    +    return value;
    +}
    +
    +static PickleValue make_tensor_value(const PickleTensorInfo& tensor) {
    +    PickleValue value;
    +    value.kind   = PickleValue::TENSOR;
    +    value.tensor = tensor;
    +    return value;
    +}
    +
    +static std::string pickle_value_to_string(const PickleValue& value) {
    +    if (value.kind == PickleValue::STRING) {
    +        return value.str_value;
    +    }
    +    if (value.kind == PickleValue::INT) {
    +        return std::to_string(value.int_value);
    +    }
    +    return "";
    +}
    +
    +static bool parse_storage_type(const std::string& global_name, PickleStorageInfo* storage) {
    +    if (global_name == "torch.FloatStorage") {
    +        storage->type               = GGML_TYPE_F32;
    +        storage->raw_element_nbytes = 4;
    +        return true;
    +    }
    +    if (global_name == "torch.DoubleStorage") {
    +        storage->type               = GGML_TYPE_F32;
    +        storage->is_f64             = true;
    +        storage->raw_element_nbytes = 8;
    +        return true;
    +    }
    +    if (global_name == "torch.HalfStorage") {
    +        storage->type               = GGML_TYPE_F16;
    +        storage->raw_element_nbytes = 2;
    +        return true;
    +    }
    +    if (global_name == "torch.BFloat16Storage") {
    +        storage->type               = GGML_TYPE_BF16;
    +        storage->raw_element_nbytes = 2;
    +        return true;
    +    }
    +    if (global_name == "torch.IntStorage") {
    +        storage->type               = GGML_TYPE_I32;
    +        storage->raw_element_nbytes = 4;
    +        return true;
    +    }
    +    if (global_name == "torch.LongStorage") {
    +        storage->type               = GGML_TYPE_I32;
    +        storage->is_i64             = true;
    +        storage->raw_element_nbytes = 8;
    +        return true;
    +    }
    +    return false;
    +}
    +
    +static bool tensor_is_contiguous(const PickleTensorInfo& tensor) {
    +    if (tensor.tensor_storage.nelements() == 0) {
    +        return true;
    +    }
    +    if (tensor.stride_n_dims != tensor.tensor_storage.n_dims) {
    +        return false;
    +    }
    +
    +    int64_t expected_stride = 1;
    +    for (int i = tensor.tensor_storage.n_dims - 1; i >= 0; --i) {
    +        if (tensor.stride[i] != expected_stride) {
    +            return false;
    +        }
    +        expected_stride *= tensor.tensor_storage.ne[i];
    +    }
    +    return true;
    +}
    +
    +static void collect_tensors_from_pickle_value(const PickleValue& value,
    +                                              std::vector<TensorStorage>& tensor_storages) {
    +    if (value.kind != PickleValue::DICT && value.kind != PickleValue::ORDERED_DICT) {
    +        return;
    +    }
    +
    +    for (const auto& item : value.dict_items) {
    +        if (item.first.kind == PickleValue::STRING && item.second.kind == PickleValue::TENSOR) {
    +            TensorStorage tensor_storage = item.second.tensor.tensor_storage;
    +            tensor_storage.name          = item.first.str_value;
    +            tensor_storage.reverse_ne();
    +            tensor_storages.push_back(tensor_storage);
    +        } else if (item.second.kind == PickleValue::DICT || item.second.kind == PickleValue::ORDERED_DICT) {
    +            collect_tensors_from_pickle_value(item.second, tensor_storages);
    +        }
    +    }
    +}
    +
    +bool parse_torch_state_dict_pickle(const uint8_t* buffer,
    +                                   size_t buffer_size,
    +                                   std::vector<TensorStorage>& tensor_storages,
    +                                   std::unordered_map<std::string, uint64_t>& storage_nbytes,
    +                                   std::string* error) {
    +    if (buffer_size < 2 || buffer[0] != 0x80 || buffer[1] < 2 || buffer[1] > 5) {
    +        set_error(error, "unsupported torch pickle protocol");
    +        return false;
    +    }
    +
    +    const uint8_t* p   = buffer + 2;
    +    const uint8_t* end = buffer + buffer_size;
    +    std::vector<PickleValue> stack;
    +    std::unordered_map<int32_t, PickleValue> memo;
    +
    +    while (p < end) {
    +        uint8_t opcode = *p++;
    +        switch (opcode) {
    +            case '.': {  // STOP             = b'.'   # every pickle ends with STOP
    +                if (stack.empty()) {
    +                    set_error(error, "empty torch pickle stack");
    +                    return false;
    +                }
    +                size_t old_tensor_count = tensor_storages.size();
    +                collect_tensors_from_pickle_value(stack.back(), tensor_storages);
    +                if (tensor_storages.size() == old_tensor_count) {
    +                    set_error(error, "torch pickle does not contain a supported state_dict");
    +                    return false;
    +                }
    +                return true;
    +            }
    +            case '}':  // EMPTY_DICT       = b'}'   # push empty dict
    +                stack.push_back(make_dict_value(false));
    +                break;
    +            case ']':  // EMPTY_LIST       = b']'   # push empty list
    +                stack.push_back(make_list_value());
    +                break;
    +            case 'l': {  // LIST             = b'l'   # build list from mark
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx < 0) {
    +                    set_error(error, "torch pickle list without mark");
    +                    return false;
    +                }
    +                std::vector<PickleValue> items(stack.begin() + mark_idx + 1, stack.end());
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +                PickleValue list_value = make_list_value();
    +                list_value.items       = std::move(items);
    +                stack.push_back(std::move(list_value));
    +            } break;
    +            case '(':  // MARK             = b'('   # push markobject
    +                stack.push_back(make_mark_value());
    +                break;
    +            case ')':  // EMPTY_TUPLE      = b')'   # push empty tuple
    +                stack.push_back(make_tuple_value({}));
    +                break;
    +            case 'N':  // NONE             = b'N'   # push None
    +                stack.push_back(make_none_value());
    +                break;
    +            case 0x88:  // NEWTRUE          = b'\x88'  # push True
    +                stack.push_back(make_bool_value(true));
    +                break;
    +            case 0x89:  // NEWFALSE         = b'\x89'  # push False
    +                stack.push_back(make_bool_value(false));
    +                break;
    +            case 'K':  // BININT1          = b'K'   # push 1-byte unsigned int
    +                if (p >= end) {
    +                    return false;
    +                }
    +                stack.push_back(make_int_value(*p++));
    +                break;
    +            case 'M':  // BININT2          = b'M'   # push 2-byte unsigned int
    +                if (p + 2 > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_int_value(read_short(p)));
    +                p += 2;
    +                break;
    +            case 'J':  // BININT           = b'J'   # push 4-byte signed int
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_int_value(read_int(p)));
    +                p += 4;
    +                break;
    +            case 'I': {  // INT              = b'I'   # push decimal integer line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                std::string s((const char*)p, len);
    +                p += len + 1;
    +                if (s == "01") {
    +                    stack.push_back(make_bool_value(true));
    +                } else if (s == "00") {
    +                    stack.push_back(make_bool_value(false));
    +                } else {
    +                    stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10)));
    +                }
    +            } break;
    +            case 'L': {  // LONG             = b'L'   # push decimal long integer line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                std::string s((const char*)p, len);
    +                p += len + 1;
    +                if (!s.empty() && s.back() == 'L') {
    +                    s.pop_back();
    +                }
    +                stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10)));
    +            } break;
    +            case 'F': {  // FLOAT            = b'F'   # push decimal float line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +                stack.push_back(make_none_value());
    +            } break;
    +            case 'G':  // BINFLOAT         = b'G'   # push 8-byte binary float
    +                if (p + 8 > end) {
    +                    return false;
    +                }
    +                p += 8;
    +                stack.push_back(make_none_value());
    +                break;
    +            case 0x8A: {  // LONG1            = b'\x8a'  # push long integer; 1-byte length
    +                if (p >= end) {
    +                    return false;
    +                }
    +                uint8_t n = *p++;
    +                if (p + n > end || n > 8) {
    +                    return false;
    +                }
    +                int64_t value = 0;
    +                for (uint8_t i = 0; i < n; ++i) {
    +                    value |= (int64_t)p[i] << (i * 8);
    +                }
    +                p += n;
    +                stack.push_back(make_int_value(value));
    +            } break;
    +            case 'C': {  // SHORT_BINBYTES   = b'C'   # push bytes; length < 256
    +                if (p >= end) {
    +                    return false;
    +                }
    +                uint8_t len = *p++;
    +                if (p + len > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, len)));
    +                p += len;
    +            } break;
    +            case 'B': {  // BINBYTES         = b'B'   # push bytes; 4-byte length
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                int32_t len = read_int(p);
    +                p += 4;
    +                if (len < 0 || p + len > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, len)));
    +                p += len;
    +            } break;
    +            case 'T':    // BINSTRING        = b'T'   # push string; 4-byte length
    +            case 'X': {  // BINUNICODE       = b'X'   # push UTF-8 string; 4-byte length
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                int32_t len = read_int(p);
    +                p += 4;
    +                if (len < 0 || p + len > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, len)));
    +                p += len;
    +            } break;
    +            case 0x8D:    // BINUNICODE8      = b'\x8d'  # push UTF-8 string; 8-byte length
    +            case 0x8E:    // BINBYTES8        = b'\x8e'  # push bytes; 8-byte length
    +            case 0x96: {  // BYTEARRAY8       = b'\x96'  # push bytearray; 8-byte length
    +                if (p + 8 > end) {
    +                    return false;
    +                }
    +                uint64_t len = read_u64(p);
    +                p += 8;
    +                if (len > (uint64_t)(end - p)) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, (size_t)len)));
    +                p += len;
    +            } break;
    +            case 'U':     // SHORT_BINSTRING  = b'U'   # push string; length < 256
    +            case 0x8C: {  // SHORT_BINUNICODE = b'\x8c'  # push UTF-8 string; length < 256
    +                if (p >= end) {
    +                    return false;
    +                }
    +                uint8_t len = *p++;
    +                if (p + len > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, len)));
    +                p += len;
    +            } break;
    +            case 'S': {  // STRING           = b'S'   # push quoted string line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                std::string s((const char*)p, len);
    +                p += len + 1;
    +                if (s.size() >= 2 && (s[0] == '\'' || s[0] == '"') && s.back() == s[0]) {
    +                    s = s.substr(1, s.size() - 2);
    +                }
    +                stack.push_back(make_string_value(s));
    +            } break;
    +            case 'V': {  // UNICODE          = b'V'   # push raw-unicode string line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, len)));
    +                p += len + 1;
    +            } break;
    +            case 'c': {  // GLOBAL           = b'c'   # push module/name global reference
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                std::string module((const char*)p, len);
    +                p += len + 1;
    +                len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                std::string name((const char*)p, len);
    +                p += len + 1;
    +                stack.push_back(make_global_value(module + "." + name));
    +            } break;
    +            case 0x93: {  // STACK_GLOBAL     = b'\x93'  # build global from module/name strings
    +                if (stack.size() < 2 || stack[stack.size() - 2].kind != PickleValue::STRING ||
    +                    stack.back().kind != PickleValue::STRING) {
    +                    return false;
    +                }
    +                std::string name = stack.back().str_value;
    +                stack.pop_back();
    +                std::string module = stack.back().str_value;
    +                stack.pop_back();
    +                stack.push_back(make_global_value(module + "." + name));
    +            } break;
    +            case 'h':  // BINGET           = b'h'   # read memo index, 1-byte arg
    +                if (p >= end || !memo.count(*p)) {
    +                    return false;
    +                }
    +                stack.push_back(memo[*p++]);
    +                break;
    +            case 'j': {  // LONG_BINGET      = b'j'   # read memo index, 4-byte arg
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                int32_t memo_idx = read_int(p);
    +                if (!memo.count(memo_idx)) {
    +                    return false;
    +                }
    +                stack.push_back(memo[memo_idx]);
    +                p += 4;
    +            } break;
    +            case 'q':  // BINPUT           = b'q'   # write memo index, 1-byte arg
    +                if (p >= end || stack.empty()) {
    +                    return false;
    +                }
    +                memo[*p++] = stack.back();
    +                break;
    +            case 'r':  // LONG_BINPUT      = b'r'   # write memo index, 4-byte arg
    +                if (p + 4 > end || stack.empty()) {
    +                    return false;
    +                }
    +                memo[read_int(p)] = stack.back();
    +                p += 4;
    +                break;
    +            case 0x94:  // MEMOIZE          = b'\x94'  # store top of stack in memo
    +                if (stack.empty()) {
    +                    return false;
    +                }
    +                memo[(int32_t)memo.size()] = stack.back();
    +                break;
    +            case 0x95:  // FRAME            = b'\x95'  # indicate the beginning of a new frame
    +                if (p + 8 > end) {
    +                    return false;
    +                }
    +                p += 8;
    +                break;
    +            case '0':  // POP              = b'0'   # discard top stack item
    +                if (stack.empty()) {
    +                    return false;
    +                }
    +                stack.pop_back();
    +                break;
    +            case '1': {  // POP_MARK         = b'1'   # discard stack through topmost mark
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx < 0) {
    +                    return false;
    +                }
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +            } break;
    +            case '2':  // DUP              = b'2'   # duplicate top stack item
    +                if (stack.empty()) {
    +                    return false;
    +                }
    +                stack.push_back(stack.back());
    +                break;
    +            case 0x8F:  // EMPTY_SET        = b'\x8f'  # push empty set
    +                stack.push_back(make_list_value());
    +                break;
    +            case 0x90: {  // ADDITEMS         = b'\x90'  # add mark-delimited items to set
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) {
    +                    return false;
    +                }
    +                PickleValue& set_value = stack[mark_idx - 1];
    +                set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end());
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +            } break;
    +            case 0x91: {  // FROZENSET        = b'\x91'  # build frozenset from mark
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx < 0) {
    +                    return false;
    +                }
    +                PickleValue set_value = make_list_value();
    +                set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end());
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +                stack.push_back(std::move(set_value));
    +            } break;
    +            case 0x85:    // TUPLE1           = b'\x85'  # build 1-tuple from stack
    +            case 0x86:    // TUPLE2           = b'\x86'  # build 2-tuple from stack
    +            case 0x87: {  // TUPLE3           = b'\x87'  # build 3-tuple from stack
    +                int tuple_size = opcode == 0x85 ? 1 : (opcode == 0x86 ? 2 : 3);
    +                if ((int)stack.size() < tuple_size) {
    +                    return false;
    +                }
    +                std::vector<PickleValue> items(stack.end() - tuple_size, stack.end());
    +                stack.erase(stack.end() - tuple_size, stack.end());
    +                stack.push_back(make_tuple_value(std::move(items)));
    +            } break;
    +            case 't': {  // TUPLE            = b't'   # build tuple from mark
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx < 0) {
    +                    return false;
    +                }
    +                std::vector<PickleValue> items(stack.begin() + mark_idx + 1, stack.end());
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +                stack.push_back(make_tuple_value(std::move(items)));
    +            } break;
    +            case 'Q': {  // BINPERSID        = b'Q'   # persistent id from stack
    +                if (stack.empty()) {
    +                    return false;
    +                }
    +                PickleValue pid = stack.back();
    +                stack.pop_back();
    +                if (pid.kind != PickleValue::TUPLE || pid.items.size() < 5 || pid.items[0].kind != PickleValue::STRING ||
    +                    pid.items[1].kind != PickleValue::GLOBAL || pid.items[4].kind != PickleValue::INT ||
    +                    pid.items[0].str_value != "storage") {
    +                    return false;
    +                }
    +
    +                PickleStorageInfo storage;
    +                storage.key = pickle_value_to_string(pid.items[2]);
    +                if (storage.key.empty() || !parse_storage_type(pid.items[1].str_value, &storage)) {
    +                    return false;
    +                }
    +                storage.nbytes              = (uint64_t)pid.items[4].int_value * storage.raw_element_nbytes;
    +                storage_nbytes[storage.key] = storage.nbytes;
    +                stack.push_back(make_storage_value(storage));
    +            } break;
    +            case 'R': {  // REDUCE           = b'R'   # apply callable to args
    +                if (stack.size() < 2) {
    +                    return false;
    +                }
    +                PickleValue args = stack.back();
    +                stack.pop_back();
    +                PickleValue callable = stack.back();
    +                stack.pop_back();
    +                if (callable.kind != PickleValue::GLOBAL || args.kind != PickleValue::TUPLE) {
    +                    stack.push_back(make_none_value());
    +                    break;
    +                }
    +
    +                if (callable.str_value == "collections.OrderedDict" && args.items.empty()) {
    +                    stack.push_back(make_dict_value(true));
    +                    break;
    +                }
    +
    +                if ((callable.str_value == "torch._utils._rebuild_tensor_v2" || callable.str_value == "torch._utils._rebuild_tensor") &&
    +                    args.items.size() >= 4 && args.items[0].kind == PickleValue::STORAGE &&
    +                    args.items[1].kind == PickleValue::INT && args.items[2].kind == PickleValue::TUPLE &&
    +                    args.items[3].kind == PickleValue::TUPLE) {
    +                    PickleTensorInfo tensor;
    +                    tensor.tensor_storage.type        = args.items[0].storage.type;
    +                    tensor.tensor_storage.is_f64      = args.items[0].storage.is_f64;
    +                    tensor.tensor_storage.is_i64      = args.items[0].storage.is_i64;
    +                    tensor.tensor_storage.storage_key = args.items[0].storage.key;
    +                    tensor.tensor_storage.offset      = (uint64_t)args.items[1].int_value * args.items[0].storage.raw_element_nbytes;
    +
    +                    for (const auto& item : args.items[2].items) {
    +                        if (item.kind != PickleValue::INT || tensor.tensor_storage.n_dims >= SD_MAX_DIMS) {
    +                            return false;
    +                        }
    +                        tensor.tensor_storage.ne[tensor.tensor_storage.n_dims++] = item.int_value;
    +                    }
    +
    +                    for (const auto& item : args.items[3].items) {
    +                        if (item.kind != PickleValue::INT || tensor.stride_n_dims >= SD_MAX_DIMS) {
    +                            return false;
    +                        }
    +                        tensor.stride[tensor.stride_n_dims++] = item.int_value;
    +                    }
    +
    +                    if (!tensor_is_contiguous(tensor)) {
    +                        return false;
    +                    }
    +                    stack.push_back(make_tensor_value(tensor));
    +                    break;
    +                }
    +
    +                // Non-tensor checkpoint metadata can use REDUCE for arbitrary
    +                // Python objects. Do not execute it; keep stack shape only.
    +                stack.push_back(make_none_value());
    +                break;
    +            }
    +            case 'b':  // BUILD            = b'b'   # build object state
    +                if (stack.size() < 2) {
    +                    return false;
    +                }
    +                stack.pop_back();
    +                break;
    +            case 'u': {  // SETITEMS         = b'u'   # add mark-delimited items to dict
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx <= 0) {
    +                    return false;
    +                }
    +                PickleValue& dict = stack[mark_idx - 1];
    +                if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) {
    +                    return false;
    +                }
    +                for (int i = mark_idx + 1; i + 1 < (int)stack.size(); i += 2) {
    +                    dict.dict_items.emplace_back(stack[i], stack[i + 1]);
    +                }
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +            } break;
    +            case 's': {  // SETITEM          = b's'   # add key/value to dict
    +                if (stack.size() < 3) {
    +                    return false;
    +                }
    +                PickleValue value = stack.back();
    +                stack.pop_back();
    +                PickleValue key = stack.back();
    +                stack.pop_back();
    +                PickleValue& dict = stack.back();
    +                if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) {
    +                    return false;
    +                }
    +                dict.dict_items.emplace_back(key, value);
    +            } break;
    +            case 'e': {  // APPENDS          = b'e'   # extend list with mark-delimited items
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) {
    +                    return false;
    +                }
    +                PickleValue& list_value = stack[mark_idx - 1];
    +                list_value.items.insert(list_value.items.end(), stack.begin() + mark_idx + 1, stack.end());
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +            } break;
    +            case 'a': {  // APPEND           = b'a'   # append item to list
    +                if (stack.size() < 2) {
    +                    return false;
    +                }
    +                PickleValue item = stack.back();
    +                stack.pop_back();
    +                if (stack.back().kind != PickleValue::LIST) {
    +                    return false;
    +                }
    +                stack.back().items.push_back(item);
    +            } break;
    +            default:
    +                set_error(error,
    +                          "unsupported torch pickle opcode 0x" + sd_format("%02X", opcode) +
    +                              " at offset " + std::to_string((p - buffer) - 1));
    +                return false;
    +        }
    +    }
    +
    +    set_error(error, "unterminated torch state_dict pickle");
    +    return false;
    +}
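
The handlers above guard length-prefixed reads in two ways: the 4-byte opcodes test `p + len > end`, while the 8-byte opcodes compare `len` against the remaining size `end - p`. The second form never computes a pointer past the buffer. A minimal sketch of that remaining-size check as a reusable helper follows; `has_bytes` and `read_len_prefixed` are hypothetical names used for illustration and do not appear in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not in the patch): true when at least `n` bytes remain
// in [p, end). Comparing against the remaining size avoids forming `p + n`.
static bool has_bytes(const uint8_t* p, const uint8_t* end, uint64_t n) {
    assert(p <= end);
    return n <= static_cast<uint64_t>(end - p);
}

// Hypothetical reader for a 4-byte little-endian length followed by a payload.
// On success it reports the payload and advances `p`; on failure `p` is unchanged.
static bool read_len_prefixed(const uint8_t*& p,
                              const uint8_t* end,
                              const uint8_t** payload,
                              uint32_t* len) {
    if (!has_bytes(p, end, 4)) {
        return false;  // truncated length field
    }
    uint32_t n = static_cast<uint32_t>(p[0]) |
                 (static_cast<uint32_t>(p[1]) << 8) |
                 (static_cast<uint32_t>(p[2]) << 16) |
                 (static_cast<uint32_t>(p[3]) << 24);
    if (!has_bytes(p + 4, end, n)) {
        return false;  // declared length exceeds the remaining input
    }
    *payload = p + 4;
    *len     = n;
    p += 4 + static_cast<size_t>(n);
    return true;
}
```

Reading the length as unsigned also removes the need for a separate `len < 0` test, since an oversized value simply fails the remaining-size comparison.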
    
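The `LONG1` handler in this file assembles up to eight little-endian bytes with a signed shift and does not sign-extend shorter payloads. Python's pickle format defines the `LONG1` payload as a little-endian two's-complement integer, so a one-byte payload of `0xFF` denotes -1. The sketch below shows one possible way to decode it by accumulating in an unsigned type first; `decode_long1` is a hypothetical helper name and is not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not in the patch): decode an n-byte little-endian
// two's-complement integer into *out. Returns false when n exceeds 8 bytes.
static bool decode_long1(const uint8_t* p, size_t n, int64_t* out) {
    assert(out != nullptr);
    if (n > 8) {
        return false;
    }
    uint64_t bits = 0;
    for (size_t i = 0; i < n; ++i) {
        bits |= static_cast<uint64_t>(p[i]) << (8 * i);
    }
    // Sign-extend when the top bit of the last payload byte is set.
    if (n > 0 && n < 8 && (p[n - 1] & 0x80) != 0) {
        bits |= ~uint64_t{0} << (8 * n);
    }
    // Conversion is modular in C++20; earlier standards leave it
    // implementation-defined (two's complement on common compilers).
    *out = static_cast<int64_t>(bits);
    return true;
}
```

Shifting in `uint64_t` keeps every intermediate value well defined, including the case where the eighth byte has its high bit set.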
  • src/model_io/pickle_io.h +21 -0 added
    @@ -0,0 +1,21 @@
    +#ifndef __SD_MODEL_IO_PICKLE_IO_H__
    +#define __SD_MODEL_IO_PICKLE_IO_H__
    +
    +#include <cstddef>
    +#include <cstdint>
    +#include <string>
    +#include <unordered_map>
    +#include <vector>
    +
    +#include "tensor_storage.h"
    +
    +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size);
    +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size);
    +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value);
    +bool parse_torch_state_dict_pickle(const uint8_t* buffer,
    +                                   size_t buffer_size,
    +                                   std::vector<TensorStorage>& tensor_storages,
    +                                   std::unordered_map<std::string, uint64_t>& storage_nbytes,
    +                                   std::string* error = nullptr);
    +
    +#endif  // __SD_MODEL_IO_PICKLE_IO_H__
    
  • src/model_io/safetensors_io.cpp +3 -16 modified
    @@ -6,6 +6,7 @@
     #include <string>
     #include <vector>
     
    +#include "binary_io.h"
     #include "json.hpp"
     
     static constexpr size_t ST_HEADER_SIZE_LEN = 8;
    @@ -16,20 +17,6 @@ static void set_error(std::string* error, const std::string& message) {
         }
     }
     
    -static uint64_t read_u64(const uint8_t* buffer) {
    -    // little endian
    -    uint64_t value = 0;
    -    value |= static_cast<uint64_t>(buffer[7]) << 56;
    -    value |= static_cast<uint64_t>(buffer[6]) << 48;
    -    value |= static_cast<uint64_t>(buffer[5]) << 40;
    -    value |= static_cast<uint64_t>(buffer[4]) << 32;
    -    value |= static_cast<uint64_t>(buffer[3]) << 24;
    -    value |= static_cast<uint64_t>(buffer[2]) << 16;
    -    value |= static_cast<uint64_t>(buffer[1]) << 8;
    -    value |= static_cast<uint64_t>(buffer[0]);
    -    return value;
    -}
    -
     bool is_safetensors_file(const std::string& file_path) {
         std::ifstream file(file_path, std::ios::binary);
         if (!file.is_open()) {
    @@ -52,7 +39,7 @@ bool is_safetensors_file(const std::string& file_path) {
             return false;
         }
     
    -    size_t header_size_ = read_u64(header_size_buf);
    +    size_t header_size_ = model_io::read_u64(header_size_buf);
         if (header_size_ >= file_size_ || header_size_ <= 2) {
             return false;
         }
    @@ -123,7 +110,7 @@ bool read_safetensors_file(const std::string& file_path,
             return false;
         }
     
    -    size_t header_size_ = read_u64(header_size_buf);
    +    size_t header_size_ = model_io::read_u64(header_size_buf);
         if (header_size_ >= file_size_) {
             set_error(error, "invalid safetensor file '" + file_path + "'");
             return false;
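
The hunks above replace the file-local `read_u64` with a shared `model_io::read_u64` and keep the check that the decoded header size is smaller than the file. As a rough illustration only, the sketch below decodes the 8-byte little-endian length with a loop and applies the same two comparisons used in `is_safetensors_file`. The names `read_u64_le` and `header_size_is_plausible` are hypothetical, and the real helper in `binary_io.h` may be written differently.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical loop-based equivalent of the removed read_u64: byte i
// contributes bits [8*i, 8*i + 7] of the result (little endian).
static uint64_t read_u64_le(const uint8_t* buffer) {
    assert(buffer != nullptr);
    uint64_t value = 0;
    for (size_t i = 0; i < 8; ++i) {
        value |= static_cast<uint64_t>(buffer[i]) << (8 * i);
    }
    return value;
}

// Mirrors the condition in is_safetensors_file: reject a header that is
// at least as large as the file, or too small to hold any JSON (<= 2 bytes).
static bool header_size_is_plausible(uint64_t header_size, uint64_t file_size) {
    return header_size > 2 && header_size < file_size;
}
```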
    
  • src/model_io/tensor_storage.h +1 -0 modified
    @@ -24,6 +24,7 @@ struct TensorStorage {
         int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
         int n_dims              = 0;
     
    +    std::string storage_key;
         size_t file_index = 0;
         int index_in_zip  = -1;  // >= 0 means stored in a zip file
         uint64_t offset   = 0;   // offset in file
    
  • src/model_io/torch_legacy_io.cpp +252 -0 added
    @@ -0,0 +1,252 @@
    +#include "torch_legacy_io.h"
    +
    +#include <algorithm>
    +#include <cstdint>
    +#include <fstream>
    +#include <string>
    +#include <unordered_map>
    +#include <vector>
    +
    +#include "pickle_io.h"
    +#include "util.h"
    +
    +// torch.save format background:
    +//
    +//   - Before PyTorch 1.6.0, torch.save used this legacy non-zip format by
    +//     default.
    +//   - Since PyTorch 1.6.0, torch.save defaults to an uncompressed ZIP64 archive
    +//     containing data.pkl, data/, version, and, since PyTorch 2.1.0, byteorder.
    +//   - The old format can still be produced explicitly with:
    +//       torch.save(obj, path, _use_new_zipfile_serialization=False)
    +//
    +// Whether obj is a state_dict or a whole nn.Module does not change the outer
    +// container format selected by torch.save. It changes the pickled object inside:
    +//
    +//   - state_dict: usually an OrderedDict[str, Tensor]. pickle_io.cpp supports a
    +//     restricted subset of this layout because tensor metadata and raw storages
    +//     can be recovered without executing pickle callables.
    +//   - whole module/checkpoint object: arbitrary Python object graph. This may
    +//     require importing user classes and executing pickle GLOBAL/REDUCE rebuild
    +//     logic, so it is intentionally not supported here.
    +//
    +// Legacy non-zip PyTorch files are not a single pickle object:
    +//
    +//   1. pickle object: PyTorch legacy magic number
    +//   2. pickle object: legacy protocol version, expected to be 1001
    +//   3. pickle object: sys_info metadata, ignored by this reader
    +//   4. pickle object: state_dict metadata, parsed by pickle_io.cpp
    +//   5. pickle object: serialized storage key list, skipped here
    +//   6. raw storage data payloads
    +//      - PyTorch writes storages after the pickles, ordered by storage key
    +//      - each storage has an 8-byte legacy storage header followed by raw bytes
    +static constexpr size_t LEGACY_STORAGE_HEADER_SIZE = 8;
    +
    +static void set_error(std::string* error, const std::string& message) {
    +    if (error != nullptr) {
    +        *error = message;
    +    }
    +}
    +
    +static std::string bytes_to_hex(const std::vector<uint8_t>& bytes) {
    +    static const char* hex = "0123456789ABCDEF";
    +    std::string result;
    +    result.reserve(bytes.size() * 3);
    +    for (size_t i = 0; i < bytes.size(); ++i) {
    +        if (i > 0) {
    +            result.push_back('-');
    +        }
    +        result.push_back(hex[(bytes[i] >> 4) & 0x0F]);
    +        result.push_back(hex[bytes[i] & 0x0F]);
    +    }
    +    return result;
    +}
    +
    +static bool is_probably_tar_file(const std::vector<uint8_t>& header) {
    +    return header.size() >= 262 &&
    +           header[257] == 'u' &&
    +           header[258] == 's' &&
    +           header[259] == 't' &&
    +           header[260] == 'a' &&
    +           header[261] == 'r';
    +}
    +
    +static std::string torch_legacy_diagnostics(const std::string& file_path, const std::vector<uint8_t>& buffer) {
    +    if (!ends_with(file_path, ".pt") && !ends_with(file_path, ".pth")) {
    +        return "";
    +    }
    +    if (buffer.empty()) {
    +        return "unsupported PyTorch file '" + file_path + "': empty file";
    +    }
    +
    +    size_t short_len = std::min<size_t>(buffer.size(), 32);
    +    std::vector<uint8_t> short_header(buffer.begin(), buffer.begin() + short_len);
    +    const bool raw_pickle = buffer[0] == 0x80;
    +    const bool tar_file   = is_probably_tar_file(buffer);
    +
    +    std::string message = "unsupported PyTorch file '" + file_path + "': first bytes " +
    +                          bytes_to_hex(short_header) +
    +                          ", raw_pickle=" + (raw_pickle ? "true" : "false") +
    +                          ", tar=" + (tar_file ? "true" : "false");
    +    if (raw_pickle) {
    +        message += "; raw pickle did not match the restricted state_dict layouts currently supported";
    +    } else if (tar_file) {
    +        message += "; legacy tar PyTorch checkpoints are not supported yet";
    +    }
    +    return message;
    +}
    +
    +bool read_torch_legacy_file(const std::string& file_path,
    +                            std::vector<TensorStorage>& tensor_storages,
    +                            std::string* error) {
    +    std::ifstream file(file_path, std::ios::binary);
    +    if (!file.is_open()) {
    +        set_error(error, "failed to open '" + file_path + "'");
    +        return false;
    +    }
    +
    +    file.seekg(0, file.end);
    +    size_t file_size = (size_t)file.tellg();
    +    file.seekg(0, file.beg);
    +    if (file_size == 0) {
    +        set_error(error, "empty file '" + file_path + "'");
    +        return false;
    +    }
    +
    +    std::vector<uint8_t> buffer(file_size);
    +    file.read((char*)buffer.data(), file_size);
    +    if (!file) {
    +        set_error(error, "failed to read '" + file_path + "'");
    +        return false;
    +    }
    +
    +    auto finalize_tensor_offsets = [&](size_t storage_data_offset,
    +                                       const std::unordered_map<std::string, uint64_t>& legacy_storage_map) -> bool {
    +        if (storage_data_offset > file_size) {
    +            return false;
    +        }
    +
    +        std::vector<std::string> storage_keys;
    +        storage_keys.reserve(legacy_storage_map.size());
    +        for (const auto& [storage_key, _] : legacy_storage_map) {
    +            storage_keys.push_back(storage_key);
    +        }
    +        std::sort(storage_keys.begin(), storage_keys.end());
    +
    +        std::unordered_map<std::string, uint64_t> storage_offsets;
    +        uint64_t current_offset = storage_data_offset;
    +        for (const auto& storage_key : storage_keys) {
    +            auto it = legacy_storage_map.find(storage_key);
    +            if (it == legacy_storage_map.end()) {
    +                return false;
    +            }
    +            if (current_offset + LEGACY_STORAGE_HEADER_SIZE + it->second > file_size) {
    +                return false;
    +            }
    +            storage_offsets[storage_key] = current_offset + LEGACY_STORAGE_HEADER_SIZE;
    +            current_offset += LEGACY_STORAGE_HEADER_SIZE + it->second;
    +        }
    +
    +        for (auto& tensor_storage : tensor_storages) {
    +            if (tensor_storage.storage_key.empty()) {
    +                continue;
    +            }
    +
    +            auto it_offset = storage_offsets.find(tensor_storage.storage_key);
    +            auto it_size   = legacy_storage_map.find(tensor_storage.storage_key);
    +            if (it_offset == storage_offsets.end() || it_size == legacy_storage_map.end()) {
    +                return false;
    +            }
    +
    +            uint64_t base_offset    = it_offset->second;
    +            uint64_t storage_nbytes = it_size->second;
    +            uint64_t tensor_nbytes  = tensor_storage.nbytes_to_read();
    +            if (tensor_storage.offset + tensor_nbytes > storage_nbytes) {
    +                return false;
    +            }
    +
    +            tensor_storage.offset = base_offset + tensor_storage.offset;
    +            tensor_storage.storage_key.clear();
    +        }
    +
    +        return true;
    +    };
    +
    +    auto parse_state_dict_at = [&](size_t state_dict_offset, size_t state_dict_size, size_t* storage_data_offset) -> bool {
    +        tensor_storages.clear();
    +        std::unordered_map<std::string, uint64_t> legacy_storage_map;
    +        if (!parse_torch_state_dict_pickle(buffer.data() + state_dict_offset,
    +                                           state_dict_size,
    +                                           tensor_storages,
    +                                           legacy_storage_map,
    +                                           error)) {
    +            return false;
    +        }
    +
    +        size_t offset_after_state_dict = state_dict_offset + state_dict_size;
    +        size_t storage_keys_size       = 0;
    +        if (!skip_pickle_object(buffer.data() + offset_after_state_dict,
    +                                buffer.size() - offset_after_state_dict,
    +                                &storage_keys_size)) {
    +            return false;
    +        }
    +
    +        *storage_data_offset = offset_after_state_dict + storage_keys_size;
    +        return finalize_tensor_offsets(*storage_data_offset, legacy_storage_map);
    +    };
    +
    +    size_t object_size_1 = 0;
    +    size_t offset        = 0;
    +
    +    if (skip_pickle_object(buffer.data(), buffer.size(), &object_size_1) &&
    +        pickle_object_is_torch_magic_number(buffer.data(), object_size_1)) {
    +        offset += object_size_1;
    +
    +        size_t object_size_2 = 0;
    +        if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_2)) {
    +            set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +            return false;
    +        }
    +        uint32_t protocol_version = 0;
    +        if (!parse_pickle_uint32_object(buffer.data() + offset, object_size_2, &protocol_version) || protocol_version != 1001) {
    +            set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +            return false;
    +        }
    +        offset += object_size_2;
    +
    +        size_t object_size_3 = 0;
    +        if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_3)) {
    +            set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +            return false;
    +        }
    +        offset += object_size_3;
    +
    +        size_t state_dict_size = 0;
    +        if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &state_dict_size)) {
    +            set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +            return false;
    +        }
    +
    +        size_t storage_data_offset = 0;
    +        if (parse_state_dict_at(offset, state_dict_size, &storage_data_offset)) {
    +            return true;
    +        }
    +
    +        if (error != nullptr && error->empty()) {
    +            set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +        }
    +        return false;
    +    }
    +
    +    size_t state_dict_size = 0;
    +    if (skip_pickle_object(buffer.data(), buffer.size(), &state_dict_size)) {
    +        size_t storage_data_offset = 0;
    +        if (parse_state_dict_at(0, state_dict_size, &storage_data_offset)) {
    +            return true;
    +        }
    +    }
    +
    +    if (error != nullptr && error->empty()) {
    +        set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +    }
    +    return false;
    +}
    
  • src/model_io/torch_legacy_io.h+13 0 added
    @@ -0,0 +1,13 @@
    +#ifndef __SD_MODEL_IO_TORCH_LEGACY_IO_H__
    +#define __SD_MODEL_IO_TORCH_LEGACY_IO_H__
    +
    +#include <string>
    +#include <vector>
    +
    +#include "tensor_storage.h"
    +
    +bool read_torch_legacy_file(const std::string& file_path,
    +                            std::vector<TensorStorage>& tensor_storages,
    +                            std::string* error = nullptr);
    +
    +#endif  // __SD_MODEL_IO_TORCH_LEGACY_IO_H__
    
  • src/model_io/torch_zip_io.cpp+140 0 added
    @@ -0,0 +1,140 @@
    +#include "torch_zip_io.h"
    +
    +#include <cstdint>
    +#include <cstdlib>
    +#include <string>
    +#include <unordered_map>
    +#include <vector>
    +
    +#include "pickle_io.h"
    +
    +#include "zip.h"
    +
    +static void set_error(std::string* error, const std::string& message) {
    +    if (error != nullptr) {
    +        *error = message;
    +    }
    +}
    +
    +bool is_torch_zip_file(const std::string& file_path) {
    +    zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
    +    if (zip == nullptr) {
    +        return false;
    +    }
    +    zip_close(zip);
    +    return true;
    +}
    +
    +static bool find_zip_entry(zip_t* zip, const std::string& entry_name, int* index, uint64_t* size) {
    +    size_t n = zip_entries_total(zip);
    +    for (size_t i = 0; i < n; ++i) {
    +        zip_entry_openbyindex(zip, i);
    +        std::string name = zip_entry_name(zip);
    +        if (name == entry_name) {
    +            *index = (int)i;
    +            *size  = zip_entry_size(zip);
    +            zip_entry_close(zip);
    +            return true;
    +        }
    +        zip_entry_close(zip);
    +    }
    +    return false;
    +}
    +
    +static bool parse_zip_data_pkl(const uint8_t* buffer,
    +                               size_t buffer_size,
    +                               zip_t* zip,
    +                               const std::string& dir,
    +                               std::vector<TensorStorage>& tensor_storages,
    +                               std::string* error) {
    +    std::vector<TensorStorage> parsed_tensors;
    +    std::unordered_map<std::string, uint64_t> storage_nbytes;
    +    if (!parse_torch_state_dict_pickle(buffer, buffer_size, parsed_tensors, storage_nbytes, error)) {
    +        if (error != nullptr && error->empty()) {
    +            *error = "failed to parse torch zip pickle metadata";
    +        }
    +        return false;
    +    }
    +
    +    for (auto& tensor_storage : parsed_tensors) {
    +        if (tensor_storage.storage_key.empty()) {
    +            set_error(error, "tensor '" + tensor_storage.name + "' has no storage key");
    +            return false;
    +        }
    +
    +        const std::string entry_name = dir + "data/" + tensor_storage.storage_key;
    +        int zip_index                = -1;
    +        uint64_t entry_size          = 0;
    +        if (!find_zip_entry(zip, entry_name, &zip_index, &entry_size)) {
    +            set_error(error, "storage entry '" + entry_name + "' was not found");
    +            return false;
    +        }
    +
    +        auto it_storage_size = storage_nbytes.find(tensor_storage.storage_key);
    +        if (it_storage_size != storage_nbytes.end() && entry_size < it_storage_size->second) {
    +            set_error(error, "storage entry '" + entry_name + "' is smaller than pickle metadata");
    +            return false;
    +        }
    +
    +        uint64_t tensor_nbytes = tensor_storage.nbytes_to_read();
    +        if (tensor_storage.offset + tensor_nbytes > entry_size) {
    +            set_error(error, "tensor '" + tensor_storage.name + "' exceeds storage entry '" + entry_name + "'");
    +            return false;
    +        }
    +
    +        tensor_storage.index_in_zip = zip_index;
    +        tensor_storage.storage_key.clear();
    +        tensor_storages.push_back(tensor_storage);
    +    }
    +
    +    return true;
    +}
    +
    +bool read_torch_zip_file(const std::string& file_path,
    +                         std::vector<TensorStorage>& tensor_storages,
    +                         std::string* error) {
    +    zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
    +    if (zip == nullptr) {
    +        set_error(error, "failed to open '" + file_path + "'");
    +        return false;
    +    }
    +
    +    tensor_storages.clear();
    +    bool success        = true;
    +    bool found_data_pkl = false;
    +    int n               = (int)zip_entries_total(zip);
    +    for (int i = 0; i < n; ++i) {
    +        zip_entry_openbyindex(zip, i);
    +        std::string name = zip_entry_name(zip);
    +        size_t pos       = name.find("data.pkl");
    +        if (pos != std::string::npos) {
    +            found_data_pkl  = true;
    +            std::string dir = name.substr(0, pos);
    +            void* pkl_data  = nullptr;
    +            size_t pkl_size = 0;
    +            zip_entry_read(zip, &pkl_data, &pkl_size);
    +
    +            if (pkl_data == nullptr || pkl_size == 0) {
    +                set_error(error, "failed to read '" + name + "' from '" + file_path + "'");
    +                success = false;
    +            } else if (!parse_zip_data_pkl((const uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) {
    +                success = false;
    +            }
    +
    +            free(pkl_data);
    +        }
    +        zip_entry_close(zip);
    +
    +        if (!success) {
    +            break;
    +        }
    +    }
    +
    +    if (success && !found_data_pkl) {
    +        set_error(error, "data.pkl was not found in '" + file_path + "'");
    +        success = false;
    +    }
    +
    +    zip_close(zip);
    +    return success;
    +}
    
  • src/model_io/torch_zip_io.h+14 0 added
    @@ -0,0 +1,14 @@
    +#ifndef __SD_MODEL_IO_TORCH_ZIP_IO_H__
    +#define __SD_MODEL_IO_TORCH_ZIP_IO_H__
    +
    +#include <string>
    +#include <vector>
    +
    +#include "tensor_storage.h"
    +
    +bool is_torch_zip_file(const std::string& file_path);
    +bool read_torch_zip_file(const std::string& file_path,
    +                         std::vector<TensorStorage>& tensor_storages,
    +                         std::string* error = nullptr);
    +
    +#endif  // __SD_MODEL_IO_TORCH_ZIP_IO_H__
    

Vulnerability mechanics

Root cause

"The pickle parser in src/model.cpp advances the buffer pointer by opcode argument sizes without checking that enough input remains, causing out-of-bounds reads on truncated .ckpt files."

Attack vector

An attacker crafts a truncated or malformed `.ckpt` file whose pickle opcode stream omits expected argument bytes. When stable-diffusion.cpp loads this file, the pickle parser in `src/model.cpp` advances the buffer pointer (e.g. `buffer += N`) without first verifying that enough input remains, causing reads past the end of the metadata buffer. The victim triggers the bug by loading the malicious `.ckpt` file — for example, by downloading a model from an untrusted model-sharing site and opening it in an affected application. This is a classic out-of-bounds read vulnerability [CWE-125], exploitable without any special privileges or authentication. [ref_id=2]
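To make the truncation condition concrete, the following is a minimal sketch, not code from the project. It models only a handful of real pickle opcodes with fixed-size arguments (`PROTO`, `BININT1`, `BININT2`, `BININT`, `STOP`). The helper names `fixed_arg_size` and `is_truncated_at` are hypothetical. The sketch reports whether the opcode at a given position claims more argument bytes than the buffer still holds, which is the situation a crafted `.ckpt` file sets up.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Argument sizes for a few pickle opcodes with fixed-length arguments.
// Returns -1 for opcodes this sketch does not model.
static int fixed_arg_size(uint8_t opcode) {
    switch (opcode) {
        case 0x80: return 1;  // PROTO: 1-byte protocol version
        case 'K':  return 1;  // BININT1: 1-byte unsigned integer
        case 'M':  return 2;  // BININT2: 2-byte unsigned integer
        case 'J':  return 4;  // BININT: 4-byte signed integer
        case '.':  return 0;  // STOP: no argument
        default:   return -1;
    }
}

// Hypothetical helper: true if `pos` is past the end of the buffer, or if the
// opcode at `pos` needs more argument bytes than remain after it.
static bool is_truncated_at(const std::vector<uint8_t>& data, size_t pos) {
    if (pos >= data.size()) {
        return true;
    }
    int need = fixed_arg_size(data[pos]);
    if (need < 0) {
        return false;  // opcode not modelled here; nothing to conclude
    }
    size_t remaining = data.size() - pos - 1;  // bytes after the opcode byte
    return static_cast<size_t>(need) > remaining;
}
```

For example, in the byte sequence `80 02 4A 01 02` the `BININT` opcode (`0x4A`, `'J'`) at index 2 expects four argument bytes but only two remain. A parser that advances by four without checking would read past the end of the buffer.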

What the fix does

The patch replaces the ad‑hoc pickle parser in `src/model.cpp` with a completely rewritten, bounds‑checked `skip_pickle_object` function in the new file `src/model_io/pickle_io.cpp`. Every opcode handler now verifies that sufficient input remains before advancing the buffer — for example, `if (p + 4 > end) { return false; }` guards 4‑byte reads. The commit [patch_id=6192989] also refactors checkpoint loading into restricted, format‑specific paths (`torch_legacy_io`, `torch_zip_io`) so that legacy pickle parsing is isolated and can be rigorously validated. This eliminates the root cause by ensuring that no opcode argument is ever read from beyond the allocated buffer. [ref_id=1]
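The check-before-advance pattern described above can be sketched as follows. This is an illustration under stated assumptions, not the contents of `pickle_io.cpp`: the `ByteCursor` type and the helpers `has_bytes`, `read_u32_le`, and `skip_bytes` are hypothetical names introduced here. The sketch compares the requested length against the remaining byte count rather than computing `p + n`, because forming a pointer more than one past the end of the buffer is undefined behaviour in C++.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical cursor over a read-only buffer; assumes p <= end at all times.
struct ByteCursor {
    const uint8_t* p;
    const uint8_t* end;
};

// True if at least n bytes remain between p and end.
static bool has_bytes(const ByteCursor& c, size_t n) {
    return static_cast<size_t>(c.end - c.p) >= n;
}

// Reads a little-endian 32-bit value and advances the cursor.
// On failure the cursor and *out are left unchanged.
static bool read_u32_le(ByteCursor& c, uint32_t* out) {
    if (!has_bytes(c, 4)) {
        return false;
    }
    *out = static_cast<uint32_t>(c.p[0]) |
           (static_cast<uint32_t>(c.p[1]) << 8) |
           (static_cast<uint32_t>(c.p[2]) << 16) |
           (static_cast<uint32_t>(c.p[3]) << 24);
    c.p += 4;
    return true;
}

// Advances the cursor by n bytes only if that many bytes remain.
static bool skip_bytes(ByteCursor& c, size_t n) {
    if (!has_bytes(c, n)) {
        return false;
    }
    c.p += n;
    return true;
}
```

Returning `false` instead of advancing lets every caller propagate a parse failure upward, which matches the way the patched loaders shown in the diff bail out with an error string when `skip_pickle_object` fails.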

Preconditions

  • input: The victim must load a .ckpt checkpoint file from an untrusted source (e.g. downloaded from a model sharing site).
  • input: The .ckpt file must be truncated or crafted to remove or corrupt pickle opcode argument bytes.
  • auth: No authentication or special privileges are required; the bug triggers during normal file parsing.

Generated on Jun 16, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References

2
