VYPR
High severity · 7.8 · NVD Advisory · Published Jun 16, 2026 · Updated Jun 16, 2026

CVE-2026-47749

Description

Heap buffer overflow in SHORT_BINUNICODE parsing in stable-diffusion.cpp before master-584-0a7ae07 allows code execution via crafted .ckpt file.

AI Insight

LLM-synthesized narrative grounded in this CVE's description and references.

Vulnerability

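The vulnerability is in the SHORT_BINUNICODE (opcode `0x8C`) handler of the pickle ".ckpt" parser in stable-diffusion.cpp versions prior to master-584-0a7ae07. It was reported in src/model.cpp. At the time of the fix, the same code was `parse_data_pkl` in src/model_io/ckpt_io.cpp. The flaw is a sign-confusion bug. The one-byte length argument is stored in a signed `int8_t`, so every byte from 0x80 to 0xFF becomes negative, and 0xFF becomes -1. In the removed handler shown in the patch below, no bounds check runs before the copy. The negative length goes straight to `memcpy`, where it is implicitly converted to `size_t`. A length of -1 therefore becomes SIZE_MAX. The copy then runs past the end of the heap-allocated pickle data and past the fixed 512-byte `string_buffer` destination, and memory is overflowed immediately [2].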
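The conversion chain can be checked on its own. The sketch below uses two hypothetical helpers, `vulnerable_copy_size` and `intended_copy_size`. These names are illustrative and do not come from stable-diffusion.cpp. The first helper reproduces the removed handler's conversions from raw byte to `int8_t` to `size_t`. Neither helper copies anything; each only computes the size that `memcpy` would receive. The sketch assumes C++14 or later and a two's-complement target. Since C++20 the standard guarantees that 0xFF converts to -1.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not upstream code): mirrors the removed handler's
// `const int8_t len = *buffer;` followed by the implicit conversion of `len`
// to memcpy's size_t parameter. Only the resulting size is computed.
constexpr std::size_t vulnerable_copy_size(std::uint8_t length_byte) {
    // Bytes 0x80..0xFF become -128..-1 when stored in a signed 8-bit type.
    const std::int8_t len = static_cast<std::int8_t>(length_byte);
    // Converting a negative value to size_t wraps modulo 2^N, landing near SIZE_MAX.
    return static_cast<std::size_t>(len);
}

// The pickle format defines the SHORT_BINUNICODE length as an unsigned byte (0..255).
constexpr std::size_t intended_copy_size(std::uint8_t length_byte) {
    return length_byte;
}

// Lengths below 0x80 agree; from 0x80 upward the signed read wraps to huge sizes.
static_assert(vulnerable_copy_size(0x7F) == 127, "small lengths are unaffected");
static_assert(vulnerable_copy_size(0x80) == SIZE_MAX - 127, "0x80 wraps to SIZE_MAX - 127");
static_assert(vulnerable_copy_size(0xFF) == SIZE_MAX, "0xFF wraps to SIZE_MAX");
static_assert(intended_copy_size(0xFF) == 255, "an unsigned read stays at 255");
```

The `static_assert` lines show that only half of the possible length bytes are misread, which is why 0xFF is the canonical trigger. Even the largest correct value, 255, would fit in the 512-byte destination.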

Exploitation

An attacker must craft a malicious ".ckpt" checkpoint file containing a SHORT_BINUNICODE opcode with a length byte of 0xFF. Any byte from 0x80 upward is also misread as negative. The victim, or an application using an affected version of stable-diffusion.cpp, must then load this file from an untrusted source, such as a model-sharing website [2]. No authentication, special network position, or user interaction beyond loading the file is required. With a length byte of 0xFF, parsing the opcode calls `memcpy` with a size of SIZE_MAX, which corrupts memory adjacent to the destination buffer [2].

Impact

Successful exploitation causes heap corruption via an oversized memcpy operation. The immediate consequence is a process crash due to memory access violation. Depending on heap layout and the attacker's ability to control data adjacent to the destination buffer, this vulnerability could potentially be leveraged for arbitrary code execution [2].

Mitigation

The issue is fixed in version master-584-0a7ae07, which corresponds to commit 0a7ae07f948eff4611968a65a22bd7c7031ad74f [1][2][3]. That commit removes the vulnerable legacy checkpoint parser (src/model_io/ckpt_io.cpp). In its place, it adds a restricted torch checkpoint loader built from the new pickle_io, torch_zip_io, and torch_legacy_io modules, which validates lengths. If upgrading immediately is not feasible, do not load ".ckpt" files from untrusted sources, and use the safer ".safetensors" format where possible [2].
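The missing validation can be shown with a minimal sketch of a bounds-checked SHORT_BINUNICODE reader. This is not the upstream fix, which replaces the parser wholesale. The name `read_short_binunicode` and its cursor/end-pointer interface are hypothetical and exist only for illustration. The sketch reads the length as an unsigned byte, as the pickle format specifies. It rejects any length that would run past the end of the input, and it copies into a `std::string` rather than a fixed-size array.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>

// Hypothetical hardened reader (not upstream code). `cur` points just past the
// 0x8C opcode byte; `end` is one past the last byte of the pickle buffer.
// On success the string is stored in `out` and `cur` is advanced past it;
// on failure `cur` is left unchanged.
bool read_short_binunicode(const std::uint8_t*& cur,
                           const std::uint8_t* end,
                           std::string& out) {
    if (cur >= end) {
        return false;  // truncated input: no length byte available
    }
    // Read the length as unsigned (0..255), as the pickle format defines it.
    const std::size_t len = cur[0];
    const std::uint8_t* data = cur + 1;
    // Refuse lengths that would read past the end of the input buffer.
    if (static_cast<std::size_t>(end - data) < len) {
        return false;
    }
    // std::string sizes its own storage, so there is no fixed buffer to overflow.
    out.assign(reinterpret_cast<const char*>(data), len);
    cur = data + len;
    return true;
}
```

Given the bytes `{0xFF, 'x'}`, this reader returns `false` and leaves the cursor where it was. On the same input, the removed handler would have attempted a copy of SIZE_MAX bytes.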

AI Insight generated on Jun 16, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.

Affected products: 2

Patches: 1

0a7ae07f948e

feat: add restricted torch legacy checkpoint loading (#1443)

https://github.com/leejet/stable-diffusion.cpp · leejet · Apr 19, 2026 · via body-scan
14 files changed · +1638 -471
  • src/denoiser.hpp · +8 -11 · modified
    @@ -1523,12 +1523,10 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
                                                   const std::vector<float>& sigmas,
                                                   std::shared_ptr<RNG> rng,
                                                   float eta) {
    -
         int steps = static_cast<int>(sigmas.size()) - 1;
         for (int i = 0; i < steps; i++) {
    -
    -        float sigma       = sigmas[i];
    -        float sigma_to    = sigmas[i + 1];
    +        float sigma    = sigmas[i];
    +        float sigma_to = sigmas[i + 1];
     
             auto model_output_opt = model(x, sigma, i + 1);
             if (model_output_opt.empty()) {
    @@ -1551,12 +1549,11 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
             float std_dev_t = eta * std::sqrt(variance);
     
             x = pred_original_sample +
    -            std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2))/ alpha_prod_t_prev) * model_output;
    +            std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) / alpha_prod_t_prev) * model_output;
     
             if (eta > 0) {
    -            x+= std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
    +            x += std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
             }
    -
         }
         return x;
     }
    @@ -1584,8 +1581,10 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
     
         auto get_timestep_from_sigma = [&](float s) -> int {
             auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s);
    -        if (it == compvis_sigmas.begin()) return 0;
    -        if (it == compvis_sigmas.end()) return TIMESTEPS - 1;
    +        if (it == compvis_sigmas.begin())
    +            return 0;
    +        if (it == compvis_sigmas.end())
    +            return TIMESTEPS - 1;
             int idx_high = static_cast<int>(std::distance(compvis_sigmas.begin(), it));
             int idx_low  = idx_high - 1;
             if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) {
    @@ -1596,7 +1595,6 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
     
         int steps = static_cast<int>(sigmas.size()) - 1;
         for (int i = 0; i < steps; i++) {
    -
             float sigma_to    = sigmas[i + 1];
             int prev_timestep = get_timestep_from_sigma(sigma_to);
             int timestep_s    = (int)floor((1 - eta) * prev_timestep);
    @@ -1626,7 +1624,6 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
                 x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x +
                     std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
             }
    -
         }
         return x;
     }
    
  • src/model.cpp · +63 -26 · modified
    @@ -2,6 +2,7 @@
     #include <atomic>
     #include <chrono>
     #include <cstdarg>
    +#include <cstdlib>
     #include <fstream>
     #include <functional>
     #include <mutex>
    @@ -13,9 +14,10 @@
     #include <vector>
     
     #include "model.h"
    -#include "model_io/ckpt_io.h"
     #include "model_io/gguf_io.h"
     #include "model_io/safetensors_io.h"
    +#include "model_io/torch_legacy_io.h"
    +#include "model_io/torch_zip_io.h"
     #include "stable-diffusion.h"
     #include "util.h"
     
    @@ -229,9 +231,12 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
         } else if (is_safetensors_file(file_path)) {
             LOG_INFO("load %s using safetensors format", file_path.c_str());
             return init_from_safetensors_file(file_path, prefix);
    -    } else if (is_ckpt_file(file_path)) {
    -        LOG_INFO("load %s using checkpoint format", file_path.c_str());
    -        return init_from_ckpt_file(file_path, prefix);
    +    } else if (is_torch_zip_file(file_path)) {
    +        LOG_INFO("load %s using torch zip format", file_path.c_str());
    +        return init_from_torch_zip_file(file_path, prefix);
    +    } else if (init_from_torch_legacy_file(file_path, prefix)) {
    +        LOG_INFO("load %s using torch legacy format", file_path.c_str());
    +        return true;
         } else {
             if (file_exists(file_path)) {
                 LOG_WARN("unknown format %s", file_path.c_str());
    @@ -329,40 +334,47 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         return true;
     }
     
    -/*================================================= DiffusersModelLoader ==================================================*/
    +/*================================================= TorchLegacyModelLoader ==================================================*/
     
    -bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
    -    std::string unet_path   = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
    -    std::string vae_path    = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
    -    std::string clip_path   = path_join(file_path, "text_encoder/model.safetensors");
    -    std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
    +bool ModelLoader::init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix) {
    +    LOG_DEBUG("init from torch legacy '%s'", file_path.c_str());
     
    -    if (!init_from_safetensors_file(unet_path, "unet.")) {
    +    std::vector<TensorStorage> tensor_storages;
    +    std::string error;
    +    if (!read_torch_legacy_file(file_path, tensor_storages, &error)) {
    +        if ((!error.empty()) && (ends_with(file_path, ".pt") || ends_with(file_path, ".pth"))) {
    +            LOG_WARN("%s", error.c_str());
    +        }
             return false;
         }
     
    -    if (!init_from_safetensors_file(vae_path, "vae.")) {
    -        LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
    -        // return false;
    -    }
    -    if (!init_from_safetensors_file(clip_path, "te.")) {
    -        LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
    -        // return false;
    -    }
    -    if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
    -        LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
    +    file_paths_.push_back(file_path);
    +    size_t file_index = file_paths_.size() - 1;
    +
    +    for (auto& tensor_storage : tensor_storages) {
    +        if (is_unused_tensor(tensor_storage.name)) {
    +            continue;
    +        }
    +
    +        if (!starts_with(tensor_storage.name, prefix)) {
    +            tensor_storage.name = prefix + tensor_storage.name;
    +        }
    +        tensor_storage.file_index = file_index;
    +
    +        add_tensor_storage(tensor_storage);
         }
    +
         return true;
     }
     
    -/*================================================= CkptModelLoader ==================================================*/
    +/*================================================= TorchZipModelLoader ==================================================*/
     
    -bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) {
    +bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const std::string& prefix) {
         LOG_DEBUG("init from '%s'", file_path.c_str());
     
         std::vector<TensorStorage> tensor_storages;
         std::string error;
    -    if (!read_ckpt_file(file_path, tensor_storages, &error)) {
    +    if (!read_torch_zip_file(file_path, tensor_storages, &error)) {
             LOG_ERROR("%s", error.c_str());
             return false;
         }
    @@ -384,6 +396,32 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
         return true;
     }
     
    +/*================================================= DiffusersModelLoader ==================================================*/
    +
    +bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
    +    std::string unet_path   = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
    +    std::string vae_path    = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
    +    std::string clip_path   = path_join(file_path, "text_encoder/model.safetensors");
    +    std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
    +
    +    if (!init_from_safetensors_file(unet_path, "unet.")) {
    +        return false;
    +    }
    +
    +    if (!init_from_safetensors_file(vae_path, "vae.")) {
    +        LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
    +        // return false;
    +    }
    +    if (!init_from_safetensors_file(clip_path, "te.")) {
    +        LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
    +        // return false;
    +    }
    +    if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
    +        LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
    +    }
    +    return true;
    +}
    +
     SDVersion ModelLoader::get_sd_version() {
         TensorStorage token_embedding_weight, input_block_weight;
     
    @@ -1210,6 +1248,5 @@ bool convert(const char* input_path,
         if (convert_name) {
             model_loader.convert_tensors_name();
         }
    -    bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
    -    return success;
    +    return model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
     }
    
  • src/model.h · +2 -1 · modified
    @@ -200,7 +200,8 @@ class ModelLoader {
     
         bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
         bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
    -    bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
    +    bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
    +    bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
         bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
     
     public:
    
  • src/model_io/binary_io.h · +57 -0 · added
    @@ -0,0 +1,57 @@
    +#ifndef __SD_MODEL_IO_BINARY_IO_H__
    +#define __SD_MODEL_IO_BINARY_IO_H__
    +
    +#include <cstdint>
    +#include <ostream>
    +
    +namespace model_io {
    +
    +    inline int32_t read_int(const uint8_t* buffer) {
    +        uint32_t value = 0;
    +        value |= static_cast<uint32_t>(buffer[3]) << 24;
    +        value |= static_cast<uint32_t>(buffer[2]) << 16;
    +        value |= static_cast<uint32_t>(buffer[1]) << 8;
    +        value |= static_cast<uint32_t>(buffer[0]);
    +        return static_cast<int32_t>(value);
    +    }
    +
    +    inline uint16_t read_short(const uint8_t* buffer) {
    +        uint16_t value = 0;
    +        value |= static_cast<uint16_t>(buffer[1]) << 8;
    +        value |= static_cast<uint16_t>(buffer[0]);
    +        return value;
    +    }
    +
    +    inline uint64_t read_u64(const uint8_t* buffer) {
    +        uint64_t value = 0;
    +        value |= static_cast<uint64_t>(buffer[7]) << 56;
    +        value |= static_cast<uint64_t>(buffer[6]) << 48;
    +        value |= static_cast<uint64_t>(buffer[5]) << 40;
    +        value |= static_cast<uint64_t>(buffer[4]) << 32;
    +        value |= static_cast<uint64_t>(buffer[3]) << 24;
    +        value |= static_cast<uint64_t>(buffer[2]) << 16;
    +        value |= static_cast<uint64_t>(buffer[1]) << 8;
    +        value |= static_cast<uint64_t>(buffer[0]);
    +        return value;
    +    }
    +
    +    inline void write_u64(std::ostream& stream, uint64_t value) {
    +        uint8_t buffer[8];
    +        for (int i = 0; i < 8; ++i) {
    +            buffer[i] = static_cast<uint8_t>((value >> (8 * i)) & 0xFF);
    +        }
    +        stream.write((const char*)buffer, sizeof(buffer));
    +    }
    +
    +    inline int find_char(const uint8_t* buffer, int len, char c) {
    +        for (int pos = 0; pos < len; pos++) {
    +            if (buffer[pos] == (uint8_t)c) {
    +                return pos;
    +            }
    +        }
    +        return -1;
    +    }
    +
    +}  // namespace model_io
    +
    +#endif  // __SD_MODEL_IO_BINARY_IO_H__
    
  • src/model_io/ckpt_io.cpp · +0 -403 · removed
    @@ -1,403 +0,0 @@
    -#include "ckpt_io.h"
    -
    -#include <cstdint>
    -#include <cstdio>
    -#include <cstdlib>
    -#include <cstring>
    -#include <string>
    -#include <vector>
    -
    -#include "zip.h"
    -
    -static constexpr int MAX_STRING_BUFFER = 512;
    -
    -static void set_error(std::string* error, const std::string& message) {
    -    if (error != nullptr) {
    -        *error = message;
    -    }
    -}
    -
    -static int32_t read_int(const uint8_t* buffer) {
    -    // little endian
    -    uint32_t value = 0;
    -    value |= static_cast<uint32_t>(buffer[3]) << 24;
    -    value |= static_cast<uint32_t>(buffer[2]) << 16;
    -    value |= static_cast<uint32_t>(buffer[1]) << 8;
    -    value |= static_cast<uint32_t>(buffer[0]);
    -    return static_cast<int32_t>(value);
    -}
    -
    -static uint16_t read_short(const uint8_t* buffer) {
    -    // little endian
    -    uint16_t value = 0;
    -    value |= static_cast<uint16_t>(buffer[1]) << 8;
    -    value |= static_cast<uint16_t>(buffer[0]);
    -    return value;
    -}
    -
    -bool is_ckpt_file(const std::string& file_path) {
    -    zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
    -    if (zip == nullptr) {
    -        return false;
    -    }
    -    zip_close(zip);
    -    return true;
    -}
    -
    -/*================================================= CkptModelLoader ==================================================*/
    -
    -// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100
    -//     0: \x80 PROTO      2
    -//     2: }    EMPTY_DICT
    -//     3: q    BINPUT     0
    -//     5: (    MARK
    -//     6: X        BINUNICODE 'epoch'
    -//    16: q        BINPUT     1
    -//    18: K        BININT1    6
    -//    20: X        BINUNICODE 'global_step'
    -//    36: q        BINPUT     2
    -//    38: J        BININT     470000
    -//    43: X        BINUNICODE 'pytorch-lightning_version'
    -//    73: q        BINPUT     3
    -//    75: X        BINUNICODE '1.4.2'
    -//    85: q        BINPUT     4
    -//    87: X        BINUNICODE 'state_dict'
    -//   102: q        BINPUT     5
    -//   104: }        EMPTY_DICT
    -//   105: q        BINPUT     6
    -//   107: (        MARK
    -//   108: X            BINUNICODE 'betas'
    -//   118: q            BINPUT     7
    -//   120: c            GLOBAL     'torch._utils _rebuild_tensor_v2'
    -//   153: q            BINPUT     8
    -//   155: (            MARK
    -//   156: (                MARK
    -//   157: X                    BINUNICODE 'storage'
    -//   169: q                    BINPUT     9
    -//   171: c                    GLOBAL     'torch FloatStorage'
    -//   191: q                    BINPUT     10
    -//   193: X                    BINUNICODE '0'
    -//   199: q                    BINPUT     11
    -//   201: X                    BINUNICODE 'cpu'
    -//   209: q                    BINPUT     12
    -//   211: M                    BININT2    1000
    -//   214: t                    TUPLE      (MARK at 156)
    -//   215: q                BINPUT     13
    -//   217: Q                BINPERSID
    -//   218: K                BININT1    0
    -//   220: M                BININT2    1000
    -//  ...............................
    -//  3201: q            BINPUT     250
    -//  3203: R            REDUCE
    -//  3204: q            BINPUT     251
    -//  3206: X            BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight'
    -//  3264: q            BINPUT     252
    -//  3266: h            BINGET     8
    -//  3268: (            MARK
    -//  3269: (                MARK
    -//  3270: h                    BINGET     9
    -//  3272: h                    BINGET     10
    -//  3274: X                    BINUNICODE '30'
    -//  3281: q                    BINPUT     253
    -//  3283: h                    BINGET     12
    -//  3285: J                    BININT     102400
    -//  3290: t                    TUPLE      (MARK at 3269)
    -//  3291: q                BINPUT     254
    -//  3293: Q                BINPERSID
    -//  3294: K                BININT1    0
    -//  3296: (                MARK
    -//  3297: M                    BININT2    320
    -//  3300: M                    BININT2    320
    -//  3303: K                    BININT1    1
    -//  3305: K                    BININT1    1
    -//  3307: t                    TUPLE      (MARK at 3296)
    -//  3308: q                BINPUT     255
    -//  3310: (                MARK
    -//  3311: M                    BININT2    320
    -//  3314: K                    BININT1    1
    -//  3316: K                    BININT1    1
    -//  3318: K                    BININT1    1
    -//  3320: t                    TUPLE      (MARK at 3310)
    -//  3321: r                LONG_BINPUT 256
    -//  3326: \x89             NEWFALSE
    -//  3327: h                BINGET     16
    -//  3329: )                EMPTY_TUPLE
    -//  3330: R                REDUCE
    -//  3331: r                LONG_BINPUT 257
    -//  3336: t                TUPLE      (MARK at 3268)
    -//  3337: r            LONG_BINPUT 258
    -//  3342: R            REDUCE
    -//  3343: r            LONG_BINPUT 259
    -//  3348: X            BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias'
    -//  3404: r            LONG_BINPUT 260
    -//  3409: h            BINGET     8
    -//  3411: (            MARK
    -//  3412: (                MARK
    -//  3413: h                    BINGET     9
    -//  3415: h                    BINGET     10
    -//  3417: X                    BINUNICODE '31'
    -
    -struct PickleTensorReader {
    -    enum ReadPhase {
    -        READ_NAME,
    -        READ_DATA,
    -        CHECK_SIZE,
    -        READ_DIMENS
    -    };
    -    ReadPhase phase   = READ_NAME;
    -    size_t entry_size = 0;
    -    int32_t nelements = 0;
    -
    -    TensorStorage tensor_storage;
    -
    -    static ggml_type global_type;  // all pickle_tensors data type
    -    static bool read_global_type;
    -
    -    bool read_int_value(uint32_t value) {
    -        if (phase == CHECK_SIZE) {
    -            if (entry_size == value * ggml_type_size(tensor_storage.type)) {
    -                nelements = value;
    -                phase     = READ_DIMENS;
    -                return true;
    -            } else {
    -                phase = READ_NAME;
    -            }
    -        } else if (phase == READ_DIMENS) {
    -            if (tensor_storage.n_dims + 1 > SD_MAX_DIMS) {  // too many dimens
    -                phase                 = READ_NAME;
    -                tensor_storage.n_dims = 0;
    -            }
    -            if (nelements % value == 0) {
    -                tensor_storage.ne[tensor_storage.n_dims] = value;
    -                tensor_storage.n_dims++;
    -            }
    -        }
    -        return false;
    -    }
    -
    -    void read_global(const std::string& str) {
    -        if (str == "FloatStorage") {
    -            if (read_global_type) {
    -                global_type      = GGML_TYPE_F32;
    -                read_global_type = false;
    -            }
    -            tensor_storage.type = GGML_TYPE_F32;
    -        } else if (str == "HalfStorage") {
    -            if (read_global_type) {
    -                global_type      = GGML_TYPE_F16;
    -                read_global_type = false;
    -            }
    -            tensor_storage.type = GGML_TYPE_F16;
    -        }
    -    }
    -
    -    void read_string(const std::string& str, zip_t* zip, std::string dir) {
    -        if (str == "storage") {
    -            read_global_type = true;
    -        } else if (str != "state_dict") {
    -            if (phase == READ_DATA) {
    -                std::string entry_name = dir + "data/" + std::string(str);
    -
    -                size_t i, n = zip_entries_total(zip);
    -                for (i = 0; i < n; ++i) {
    -                    zip_entry_openbyindex(zip, i);
    -                    {
    -                        std::string name = zip_entry_name(zip);
    -                        if (name == entry_name) {
    -                            tensor_storage.index_in_zip = (int)i;
    -                            entry_size                  = zip_entry_size(zip);
    -                            zip_entry_close(zip);
    -                            break;
    -                        }
    -                    }
    -                    zip_entry_close(zip);
    -                }
    -
    -                phase = entry_size > 0 ? CHECK_SIZE : READ_NAME;
    -            }
    -            if (!read_global_type && phase == READ_NAME) {
    -                tensor_storage.name = str;
    -                phase               = READ_DATA;
    -                tensor_storage.type = global_type;
    -            }
    -        }
    -    }
    -};
    -
    -ggml_type PickleTensorReader::global_type = GGML_TYPE_F32;  // all pickle_tensors data type
    -bool PickleTensorReader::read_global_type = false;
    -
    -static int find_char(uint8_t* buffer, int len, char c) {
    -    for (int pos = 0; pos < len; pos++) {
    -        if (buffer[pos] == c) {
    -            return pos;
    -        }
    -    }
    -    return -1;
    -}
    -
    -static bool parse_data_pkl(uint8_t* buffer,
    -                           size_t buffer_size,
    -                           zip_t* zip,
    -                           std::string dir,
    -                           std::vector<TensorStorage>& tensor_storages,
    -                           std::string* error) {
    -    uint8_t* buffer_end = buffer + buffer_size;
    -    if (buffer[0] == 0x80) {  // proto
    -        if (buffer[1] != 2) {
    -            set_error(error, "unsupported pickle protocol");
    -            return false;
    -        }
    -        buffer += 2;  // 0x80 and version
    -        char string_buffer[MAX_STRING_BUFFER];
    -        bool finish = false;
    -        PickleTensorReader reader;
    -        // read pickle binary file
    -        while (!finish && buffer < buffer_end) {
    -            uint8_t opcode = *buffer;
    -            buffer++;
    -            // https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048
    -            // https://github.com/python/cpython/blob/main/Lib/pickle.py#L105
    -            switch (opcode) {
    -                case '}':  // EMPTY_DICT     = b'}'   # push empty dict
    -                    break;
    -                case ']':  // EMPTY_LIST     = b']'   # push empty list
    -                    break;
    -                // skip unused sections
    -                case 'h':  // BINGET         = b'h'   #   "    "    "    "   "   "  ;   "    " 1-byte arg
    -                case 'q':  // BINPUT         = b'q'   #   "     "    "   "   " ;   "    " 1-byte arg
    -                case 'Q':  // BINPERSID      = b'Q'   #  "       "         "  ;  "  "   "     "  stack
    -                    buffer++;
    -                    break;
    -                case 'r':  // LONG_BINPUT    = b'r'   #   "     "    "   "   " ;   "    " 4-byte arg
    -                    buffer += 4;
    -                    break;
    -                case 0x95:  // FRAME            = b'\x95'  # indicate the beginning of a new frame
    -                    buffer += 8;
    -                    break;
    -                case 0x94:  // MEMOIZE          = b'\x94'  # store top of the stack in memo
    -                    break;
    -                case '(':  // MARK           = b'('   # push special markobject on stack
    -                    break;
    -                case 'K':  // BININT1        = b'K'   # push 1-byte unsigned int
    -                {
    -                    uint8_t value = *buffer;
    -                    if (reader.read_int_value(value)) {
    -                        buffer++;
    -                    }
    -                    buffer++;
    -                } break;
    -                case 'M':  // BININT2        = b'M'   # push 2-byte unsigned int
    -                {
    -                    uint16_t value = read_short(buffer);
    -                    if (reader.read_int_value(value)) {
    -                        buffer++;
    -                    }
    -                    buffer += 2;
    -                } break;
    -                case 'J':  // BININT         = b'J'   # push four-byte signed int
    -                {
    -                    const int32_t value = read_int(buffer);
    -                    if (reader.read_int_value(value)) {
    -                        buffer++;  // skip tuple after read num_elements
    -                    }
    -                    buffer += 4;
    -                } break;
    -                case 'X':  // BINUNICODE     = b'X'   #   "     "       "  ; counted UTF-8 string argument
    -                {
    -                    const int32_t len = read_int(buffer);
    -                    buffer += 4;
    -                    memset(string_buffer, 0, MAX_STRING_BUFFER);
    -                    if (len > MAX_STRING_BUFFER) {
    -                        // keep truncated names null-terminated, matching the old parser behavior
    -                    }
    -                    memcpy(string_buffer, buffer, len < MAX_STRING_BUFFER ? len : (MAX_STRING_BUFFER - 1));
    -                    buffer += len;
    -                    reader.read_string(string_buffer, zip, dir);
    -                } break;
    -                case 0x8C:  // SHORT_BINUNICODE = b'\x8c'  # push short string; UTF-8 length < 256 bytes
    -                {
    -                    const int8_t len = *buffer;
    -                    buffer++;
    -                    memset(string_buffer, 0, MAX_STRING_BUFFER);
    -                    memcpy(string_buffer, buffer, len);
    -                    buffer += len;
    -                    // printf("String: '%s'\n", string_buffer);
    -                } break;
    -                case 'c':  // GLOBAL         = b'c'   # push self.find_class(modname, name); 2 string args
    -                {
    -                    int len = find_char(buffer, MAX_STRING_BUFFER, '\n');
    -
    -                    buffer += len + 1;
    -                    len = find_char(buffer, MAX_STRING_BUFFER, '\n');
    -
    -                    memset(string_buffer, 0, MAX_STRING_BUFFER);
    -                    memcpy(string_buffer, buffer, len);
    -                    buffer += len + 1;
    -                    reader.read_global(string_buffer);
    -                } break;
    -                case 0x86:  // TUPLE2         = b'\x86'  # build 2-tuple from two topmost stack items
    -                case 0x85:  // TUPLE1         = b'\x85'  # build 1-tuple from stack top
    -                case 't':   // TUPLE          = b't'   # build tuple from topmost stack items
    -                    if (reader.phase == PickleTensorReader::READ_DIMENS) {
    -                        reader.tensor_storage.reverse_ne();
    -                        tensor_storages.push_back(reader.tensor_storage);
    -
    -                        // LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
    -                        // reset
    -                        reader = PickleTensorReader();
    -                    }
    -                    break;
    -                case '.':  // STOP           = b'.'   # every pickle ends with STOP
    -                    finish = true;
    -                    break;
    -                default:
    -                    break;
    -            }
    -        }
    -    }
    -    return true;
    -}
    -
    -bool read_ckpt_file(const std::string& file_path,
    -                    std::vector<TensorStorage>& tensor_storages,
    -                    std::string* error) {
    -    zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
    -    if (zip == nullptr) {
    -        set_error(error, "failed to open '" + file_path + "'");
    -        return false;
    -    }
    -
    -    tensor_storages.clear();
    -    bool success = true;
    -    int n        = (int)zip_entries_total(zip);
    -    for (int i = 0; i < n; ++i) {
    -        zip_entry_openbyindex(zip, i);
    -        {
    -            std::string name = zip_entry_name(zip);
    -            size_t pos       = name.find("data.pkl");
    -            if (pos != std::string::npos) {
    -                std::string dir = name.substr(0, pos);
    -                printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str());
    -                void* pkl_data = nullptr;
    -                size_t pkl_size;
    -                zip_entry_read(zip, &pkl_data, &pkl_size);
    -
    -                // LOG_DEBUG("%lld", pkl_size);
    -
    -                if (!parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) {
    -                    success = false;
    -                }
    -
    -                free(pkl_data);
    -            }
    -        }
    -        zip_entry_close(zip);
    -
    -        if (!success) {
    -            break;
    -        }
    -    }
    -    zip_close(zip);
    -    return success;
    -}
    
  • src/model_io/ckpt_io.h +0 -14 (removed)
    @@ -1,14 +0,0 @@
    -#ifndef __SD_MODEL_IO_CKPT_IO_H__
    -#define __SD_MODEL_IO_CKPT_IO_H__
    -
    -#include <string>
    -#include <vector>
    -
    -#include "tensor_storage.h"
    -
    -bool is_ckpt_file(const std::string& file_path);
    -bool read_ckpt_file(const std::string& file_path,
    -                    std::vector<TensorStorage>& tensor_storages,
    -                    std::string* error = nullptr);
    -
    -#endif  // __SD_MODEL_IO_CKPT_IO_H__
    
  • src/model_io/pickle_io.cpp +1064 -0 (added)
    @@ -0,0 +1,1064 @@
    +#include "pickle_io.h"
    +
    +#include <cstdlib>
    +#include <cstring>
    +#include <string>
    +#include <unordered_map>
    +#include <utility>
    +#include <vector>
    +
    +#include "binary_io.h"
    +#include "util.h"
    +
    +// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100
    +//     0: \x80 PROTO      2
    +//     2: }    EMPTY_DICT
    +//     3: q    BINPUT     0
    +//     5: (    MARK
    +//     6: X        BINUNICODE 'epoch'
    +//    16: q        BINPUT     1
    +//    18: K        BININT1    6
    +//    20: X        BINUNICODE 'global_step'
    +//    36: q        BINPUT     2
    +//    38: J        BININT     470000
    +//    43: X        BINUNICODE 'pytorch-lightning_version'
    +//    73: q        BINPUT     3
    +//    75: X        BINUNICODE '1.4.2'
    +//    85: q        BINPUT     4
    +//    87: X        BINUNICODE 'state_dict'
    +//   102: q        BINPUT     5
    +//   104: }        EMPTY_DICT
    +//   105: q        BINPUT     6
    +//   107: (        MARK
    +//   108: X            BINUNICODE 'betas'
    +//   118: q            BINPUT     7
    +//   120: c            GLOBAL     'torch._utils _rebuild_tensor_v2'
    +//   153: q            BINPUT     8
    +//   155: (            MARK
    +//   156: (                MARK
    +//   157: X                    BINUNICODE 'storage'
    +//   169: q                    BINPUT     9
    +//   171: c                    GLOBAL     'torch FloatStorage'
    +//   191: q                    BINPUT     10
    +//   193: X                    BINUNICODE '0'
    +//   199: q                    BINPUT     11
    +//   201: X                    BINUNICODE 'cpu'
    +//   209: q                    BINPUT     12
    +//   211: M                    BININT2    1000
    +//   214: t                    TUPLE      (MARK at 156)
    +//   215: q                BINPUT     13
    +//   217: Q                BINPERSID
    +//   218: K                BININT1    0
    +//   220: M                BININT2    1000
    +//  ...............................
    +//  3201: q            BINPUT     250
    +//  3203: R            REDUCE
    +//  3204: q            BINPUT     251
    +//  3206: X            BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight'
    +//  3264: q            BINPUT     252
    +//  3266: h            BINGET     8
    +//  3268: (            MARK
    +//  3269: (                MARK
    +//  3270: h                    BINGET     9
    +//  3272: h                    BINGET     10
    +//  3274: X                    BINUNICODE '30'
    +//  3281: q                    BINPUT     253
    +//  3283: h                    BINGET     12
    +//  3285: J                    BININT     102400
    +//  3290: t                    TUPLE      (MARK at 3269)
    +//  3291: q                BINPUT     254
    +//  3293: Q                BINPERSID
    +//  3294: K                BININT1    0
    +//  3296: (                MARK
    +//  3297: M                    BININT2    320
    +//  3300: M                    BININT2    320
    +//  3303: K                    BININT1    1
    +//  3305: K                    BININT1    1
    +//  3307: t                    TUPLE      (MARK at 3296)
    +//  3308: q                BINPUT     255
    +//  3310: (                MARK
    +//  3311: M                    BININT2    320
    +//  3314: K                    BININT1    1
    +//  3316: K                    BININT1    1
    +//  3318: K                    BININT1    1
    +//  3320: t                    TUPLE      (MARK at 3310)
    +//  3321: r                LONG_BINPUT 256
    +//  3326: \x89             NEWFALSE
    +//  3327: h                BINGET     16
    +//  3329: )                EMPTY_TUPLE
    +//  3330: R                REDUCE
    +//  3331: r                LONG_BINPUT 257
    +//  3336: t                TUPLE      (MARK at 3268)
    +//  3337: r            LONG_BINPUT 258
    +//  3342: R            REDUCE
    +//  3343: r            LONG_BINPUT 259
    +//  3348: X            BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias'
    +//  3404: r            LONG_BINPUT 260
    +//  3409: h            BINGET     8
    +//  3411: (            MARK
    +//  3412: (                MARK
    +//  3413: h                    BINGET     9
    +//  3415: h                    BINGET     10
    +//  3417: X                    BINUNICODE '31'
    +// https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048
    +// https://github.com/python/cpython/blob/main/Lib/pickle.py#L105
    +
    +using model_io::find_char;
    +using model_io::read_int;
    +using model_io::read_short;
    +using model_io::read_u64;
    +
    +static void set_error(std::string* error, const std::string& message) {
    +    if (error != nullptr) {
    +        *error = message;
    +    }
    +}
    +
    +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size) {
    +    const uint8_t* p   = buffer;
    +    const uint8_t* end = buffer + buffer_size;
    +
    +    while (p < end) {
    +        uint8_t opcode = *p++;
    +        switch (opcode) {
    +            case '.':  // STOP             = b'.'   # every pickle ends with STOP
    +                *object_size = (size_t)(p - buffer);
    +                return true;
    +            case 0x80:  // PROTO            = b'\x80'  # protocol version indicator
    +            case 'K':   // BININT1          = b'K'   # push 1-byte unsigned int
    +            case 'h':   // BINGET           = b'h'   # read memo index, 1-byte arg
    +            case 'q':   // BINPUT           = b'q'   # write memo index, 1-byte arg
    +            case 'C':   // SHORT_BINBYTES   = b'C'   # push bytes; length < 256
    +            case 0x82:  // EXT1             = b'\x82'  # extension code, 1-byte arg
    +                p += 1;
    +                break;
    +            case 'M':   // BININT2          = b'M'   # push 2-byte unsigned int
    +            case 0x83:  // EXT2             = b'\x83'  # extension code, 2-byte arg
    +                p += 2;
    +                break;
    +            case 'J':   // BININT           = b'J'   # push 4-byte signed int
    +            case 'j':   // LONG_BINGET      = b'j'   # read memo index, 4-byte arg
    +            case 'r':   // LONG_BINPUT      = b'r'   # write memo index, 4-byte arg
    +            case 0x84:  // EXT4             = b'\x84'  # extension code, 4-byte arg
    +                p += 4;
    +                break;
    +            case 'I':    // INT              = b'I'   # push decimal integer line
    +            case 'L':    // LONG             = b'L'   # push decimal long integer line
    +            case 'F':    // FLOAT            = b'F'   # push decimal float line
    +            case 'S':    // STRING           = b'S'   # push quoted string line
    +            case 'V': {  // UNICODE          = b'V'   # push raw-unicode string line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +            } break;
    +            case 'G':  // BINFLOAT         = b'G'   # push 8-byte binary float
    +                p += 8;
    +                break;
    +            case 0x8A:  // LONG1            = b'\x8a'  # push long integer; 1-byte length
    +                if (p >= end) {
    +                    return false;
    +                }
    +                p += 1 + p[0];
    +                break;
    +            case 0x8B: {  // LONG4            = b'\x8b'  # push long integer; 4-byte length
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                uint32_t n = (uint32_t)read_int(p);
    +                p += 4 + n;
    +            } break;
    +            case 'B': {  // BINBYTES         = b'B'   # push bytes; 4-byte length
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                uint32_t n = (uint32_t)read_int(p);
    +                p += 4 + n;
    +            } break;
    +            case 'T':    // BINSTRING        = b'T'   # push string; 4-byte length
    +            case 'X': {  // BINUNICODE       = b'X'   # push UTF-8 string; 4-byte length
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                uint32_t n = (uint32_t)read_int(p);
    +                p += 4 + n;
    +            } break;
    +            case 0x8D:    // BINUNICODE8      = b'\x8d'  # push UTF-8 string; 8-byte length
    +            case 0x8E:    // BINBYTES8        = b'\x8e'  # push bytes; 8-byte length
    +            case 0x96: {  // BYTEARRAY8       = b'\x96'  # push bytearray; 8-byte length
    +                if (p + 8 > end) {
    +                    return false;
    +                }
    +                uint64_t n = read_u64(p);
    +                p += 8;
    +                if (n > (uint64_t)(end - p)) {
    +                    return false;
    +                }
    +                p += n;
    +            } break;
    +            case 'U':   // SHORT_BINSTRING  = b'U'   # push string; length < 256
    +            case 0x8C:  // SHORT_BINUNICODE = b'\x8c'  # push UTF-8 string; length < 256
    +                if (p >= end) {
    +                    return false;
    +                }
    +                p += 1 + p[0];
    +                break;
    +            case 'P': {  // PERSID           = b'P'   # persistent id, newline-terminated
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +            } break;
    +            case 0x95:  // FRAME            = b'\x95'  # indicate the beginning of a new frame
    +                p += 8;
    +                break;
    +            case 'c': {  // GLOBAL           = b'c'   # push module/name global reference
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +                len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +            } break;
    +            case '}':   // EMPTY_DICT       = b'}'   # push empty dict
    +            case ']':   // EMPTY_LIST       = b']'   # push empty list
    +            case '(':   // MARK             = b'('   # push markobject
    +            case 't':   // TUPLE            = b't'   # build tuple from mark
    +            case 0x85:  // TUPLE1           = b'\x85'  # build 1-tuple from stack
    +            case 0x86:  // TUPLE2           = b'\x86'  # build 2-tuple from stack
    +            case 0x87:  // TUPLE3           = b'\x87'  # build 3-tuple from stack
    +            case ')':   // EMPTY_TUPLE      = b')'   # push empty tuple
    +            case 'l':   // LIST             = b'l'   # build list from mark
    +            case 'Q':   // BINPERSID        = b'Q'   # persistent id from stack
    +            case 0x94:  // MEMOIZE          = b'\x94'  # store top of stack in memo
    +            case 0x88:  // NEWTRUE          = b'\x88'  # push True
    +            case 0x89:  // NEWFALSE         = b'\x89'  # push False
    +            case 'R':   // REDUCE           = b'R'   # apply callable to args
    +            case 'u':   // SETITEMS         = b'u'   # add mark-delimited items to dict
    +            case 's':   // SETITEM          = b's'   # add key/value to dict
    +            case 'e':   // APPENDS          = b'e'   # extend list with mark-delimited items
    +            case 'a':   // APPEND           = b'a'   # append item to list
    +            case 'b':   // BUILD            = b'b'   # build object state
    +            case 0x81:  // NEWOBJ           = b'\x81'  # build object via __new__
    +            case 0x8F:  // EMPTY_SET        = b'\x8f'  # push empty set
    +            case 0x90:  // ADDITEMS         = b'\x90'  # add mark-delimited items to set
    +            case 0x91:  // FROZENSET        = b'\x91'  # build frozenset from mark
    +            case 0x92:  // NEWOBJ_EX        = b'\x92'  # build object with kwargs
    +            case 0x93:  // STACK_GLOBAL     = b'\x93'  # build global from module/name strings
    +            case 0x97:  // NEXT_BUFFER      = b'\x97'  # out-of-band buffer marker
    +            case 0x98:  // READONLY_BUFFER  = b'\x98'  # mark buffer readonly
    +            case 'N':   // NONE             = b'N'   # push None
    +            case '0':   // POP              = b'0'   # discard top stack item
    +            case '1':   // POP_MARK         = b'1'   # discard stack through topmost mark
    +            case '2':   // DUP              = b'2'   # duplicate top stack item
    +            case 'o':   // OBJ              = b'o'   # build class instance from mark
    +                break;
    +            case 'i': {  // INST             = b'i'   # build class instance from module/name
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +                len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +            } break;
    +            default:
    +                return false;
    +        }
    +        if (p > end) {
    +            return false;
    +        }
    +    }
    +
    +    return false;
    +}
    +
    +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size) {
    +    static const uint8_t torch_magic_bytes[] = {0x6C, 0xFC, 0x9C, 0x46, 0xF9, 0x20, 0x6A, 0xA8, 0x50, 0x19};
    +
    +    if (buffer_size < 5 || buffer[0] != 0x80) {
    +        return false;
    +    }
    +
    +    size_t pos = 2;
    +    if (pos >= buffer_size) {
    +        return false;
    +    }
    +
    +    uint8_t opcode = buffer[pos++];
    +    if (opcode != 0x8A || pos >= buffer_size) {
    +        return false;
    +    }
    +
    +    uint8_t len = buffer[pos++];
    +    if (len != sizeof(torch_magic_bytes) || pos + len >= buffer_size) {
    +        return false;
    +    }
    +
    +    if (memcmp(buffer + pos, torch_magic_bytes, sizeof(torch_magic_bytes)) != 0) {
    +        return false;
    +    }
    +    pos += len;
    +
    +    return pos < buffer_size && buffer[pos] == '.';
    +}
    +
    +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value) {
    +    if (buffer_size < 4 || buffer[0] != 0x80) {
    +        return false;
    +    }
    +
    +    size_t pos = 2;
    +    if (pos >= buffer_size) {
    +        return false;
    +    }
    +
    +    uint8_t opcode = buffer[pos++];
    +    switch (opcode) {
    +        case 'K':  // BININT1          = b'K'   # push 1-byte unsigned int
    +            if (pos + 1 >= buffer_size) {
    +                return false;
    +            }
    +            *value = buffer[pos];
    +            pos += 1;
    +            break;
    +        case 'M':  // BININT2          = b'M'   # push 2-byte unsigned int
    +            if (pos + 2 >= buffer_size) {
    +                return false;
    +            }
    +            *value = read_short(buffer + pos);
    +            pos += 2;
    +            break;
    +        case 'J':  // BININT           = b'J'   # push 4-byte signed int
    +            if (pos + 4 >= buffer_size) {
    +                return false;
    +            }
    +            *value = (uint32_t)read_int(buffer + pos);
    +            pos += 4;
    +            break;
    +        default:
    +            return false;
    +    }
    +
    +    return pos < buffer_size && buffer[pos] == '.';
    +}
    +
    +struct PickleStorageInfo {
    +    std::string key;
    +    ggml_type type              = GGML_TYPE_COUNT;
    +    bool is_f64                 = false;
    +    bool is_i64                 = false;
    +    uint64_t raw_element_nbytes = 0;
    +    uint64_t nbytes             = 0;
    +};
    +
    +struct PickleTensorInfo {
    +    TensorStorage tensor_storage;
    +    int stride_n_dims = 0;
    +    int64_t stride[SD_MAX_DIMS]{1, 1, 1, 1, 1};
    +};
    +
    +struct PickleValue {
    +    enum Kind {
    +        MARK,
    +        NONE,
    +        BOOL,
    +        INT,
    +        STRING,
    +        GLOBAL,
    +        TUPLE,
    +        LIST,
    +        DICT,
    +        ORDERED_DICT,
    +        STORAGE,
    +        TENSOR,
    +    };
    +
    +    Kind kind         = NONE;
    +    int64_t int_value = 0;
    +    bool bool_value   = false;
    +    std::string str_value;
    +    std::vector<PickleValue> items;
    +    std::vector<std::pair<PickleValue, PickleValue>> dict_items;
    +    PickleStorageInfo storage;
    +    PickleTensorInfo tensor;
    +};
    +
    +static PickleValue make_mark_value() {
    +    PickleValue value;
    +    value.kind = PickleValue::MARK;
    +    return value;
    +}
    +
    +static PickleValue make_none_value() {
    +    PickleValue value;
    +    value.kind = PickleValue::NONE;
    +    return value;
    +}
    +
    +static PickleValue make_bool_value(bool b) {
    +    PickleValue value;
    +    value.kind       = PickleValue::BOOL;
    +    value.bool_value = b;
    +    return value;
    +}
    +
    +static PickleValue make_int_value(int64_t x) {
    +    PickleValue value;
    +    value.kind      = PickleValue::INT;
    +    value.int_value = x;
    +    return value;
    +}
    +
    +static PickleValue make_string_value(const std::string& s) {
    +    PickleValue value;
    +    value.kind      = PickleValue::STRING;
    +    value.str_value = s;
    +    return value;
    +}
    +
    +static PickleValue make_global_value(const std::string& s) {
    +    PickleValue value;
    +    value.kind      = PickleValue::GLOBAL;
    +    value.str_value = s;
    +    return value;
    +}
    +
    +static PickleValue make_tuple_value(std::vector<PickleValue> items) {
    +    PickleValue value;
    +    value.kind  = PickleValue::TUPLE;
    +    value.items = std::move(items);
    +    return value;
    +}
    +
    +static PickleValue make_list_value() {
    +    PickleValue value;
    +    value.kind = PickleValue::LIST;
    +    return value;
    +}
    +
    +static PickleValue make_dict_value(bool ordered) {
    +    PickleValue value;
    +    value.kind = ordered ? PickleValue::ORDERED_DICT : PickleValue::DICT;
    +    return value;
    +}
    +
    +static PickleValue make_storage_value(const PickleStorageInfo& storage) {
    +    PickleValue value;
    +    value.kind    = PickleValue::STORAGE;
    +    value.storage = storage;
    +    return value;
    +}
    +
    +static PickleValue make_tensor_value(const PickleTensorInfo& tensor) {
    +    PickleValue value;
    +    value.kind   = PickleValue::TENSOR;
    +    value.tensor = tensor;
    +    return value;
    +}
    +
    +static std::string pickle_value_to_string(const PickleValue& value) {
    +    if (value.kind == PickleValue::STRING) {
    +        return value.str_value;
    +    }
    +    if (value.kind == PickleValue::INT) {
    +        return std::to_string(value.int_value);
    +    }
    +    return "";
    +}
    +
    +static bool parse_storage_type(const std::string& global_name, PickleStorageInfo* storage) {
    +    if (global_name == "torch.FloatStorage") {
    +        storage->type               = GGML_TYPE_F32;
    +        storage->raw_element_nbytes = 4;
    +        return true;
    +    }
    +    if (global_name == "torch.DoubleStorage") {
    +        storage->type               = GGML_TYPE_F32;
    +        storage->is_f64             = true;
    +        storage->raw_element_nbytes = 8;
    +        return true;
    +    }
    +    if (global_name == "torch.HalfStorage") {
    +        storage->type               = GGML_TYPE_F16;
    +        storage->raw_element_nbytes = 2;
    +        return true;
    +    }
    +    if (global_name == "torch.BFloat16Storage") {
    +        storage->type               = GGML_TYPE_BF16;
    +        storage->raw_element_nbytes = 2;
    +        return true;
    +    }
    +    if (global_name == "torch.IntStorage") {
    +        storage->type               = GGML_TYPE_I32;
    +        storage->raw_element_nbytes = 4;
    +        return true;
    +    }
    +    if (global_name == "torch.LongStorage") {
    +        storage->type               = GGML_TYPE_I32;
    +        storage->is_i64             = true;
    +        storage->raw_element_nbytes = 8;
    +        return true;
    +    }
    +    return false;
    +}
    +
    +static bool tensor_is_contiguous(const PickleTensorInfo& tensor) {
    +    if (tensor.tensor_storage.nelements() == 0) {
    +        return true;
    +    }
    +    if (tensor.stride_n_dims != tensor.tensor_storage.n_dims) {
    +        return false;
    +    }
    +
    +    int64_t expected_stride = 1;
    +    for (int i = tensor.tensor_storage.n_dims - 1; i >= 0; --i) {
    +        if (tensor.stride[i] != expected_stride) {
    +            return false;
    +        }
    +        expected_stride *= tensor.tensor_storage.ne[i];
    +    }
    +    return true;
    +}
    +
    +static void collect_tensors_from_pickle_value(const PickleValue& value,
    +                                              std::vector<TensorStorage>& tensor_storages) {
    +    if (value.kind != PickleValue::DICT && value.kind != PickleValue::ORDERED_DICT) {
    +        return;
    +    }
    +
    +    for (const auto& item : value.dict_items) {
    +        if (item.first.kind == PickleValue::STRING && item.second.kind == PickleValue::TENSOR) {
    +            TensorStorage tensor_storage = item.second.tensor.tensor_storage;
    +            tensor_storage.name          = item.first.str_value;
    +            tensor_storage.reverse_ne();
    +            tensor_storages.push_back(tensor_storage);
    +        } else if (item.second.kind == PickleValue::DICT || item.second.kind == PickleValue::ORDERED_DICT) {
    +            collect_tensors_from_pickle_value(item.second, tensor_storages);
    +        }
    +    }
    +}
    +
    +bool parse_torch_state_dict_pickle(const uint8_t* buffer,
    +                                   size_t buffer_size,
    +                                   std::vector<TensorStorage>& tensor_storages,
    +                                   std::unordered_map<std::string, uint64_t>& storage_nbytes,
    +                                   std::string* error) {
    +    if (buffer_size < 2 || buffer[0] != 0x80 || buffer[1] < 2 || buffer[1] > 5) {
    +        set_error(error, "unsupported torch pickle protocol");
    +        return false;
    +    }
    +
    +    const uint8_t* p   = buffer + 2;
    +    const uint8_t* end = buffer + buffer_size;
    +    std::vector<PickleValue> stack;
    +    std::unordered_map<int32_t, PickleValue> memo;
    +
    +    while (p < end) {
    +        uint8_t opcode = *p++;
    +        switch (opcode) {
    +            case '.': {  // STOP             = b'.'   # every pickle ends with STOP
    +                if (stack.empty()) {
    +                    set_error(error, "empty torch pickle stack");
    +                    return false;
    +                }
    +                size_t old_tensor_count = tensor_storages.size();
    +                collect_tensors_from_pickle_value(stack.back(), tensor_storages);
    +                if (tensor_storages.size() == old_tensor_count) {
    +                    set_error(error, "torch pickle does not contain a supported state_dict");
    +                    return false;
    +                }
    +                return true;
    +            }
    +            case '}':  // EMPTY_DICT       = b'}'   # push empty dict
    +                stack.push_back(make_dict_value(false));
    +                break;
    +            case ']':  // EMPTY_LIST       = b']'   # push empty list
    +                stack.push_back(make_list_value());
    +                break;
    +            case 'l': {  // LIST             = b'l'   # build list from mark
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx < 0) {
    +                    set_error(error, "torch pickle list without mark");
    +                    return false;
    +                }
    +                std::vector<PickleValue> items(stack.begin() + mark_idx + 1, stack.end());
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +                PickleValue list_value = make_list_value();
    +                list_value.items       = std::move(items);
    +                stack.push_back(std::move(list_value));
    +            } break;
    +            case '(':  // MARK             = b'('   # push markobject
    +                stack.push_back(make_mark_value());
    +                break;
    +            case ')':  // EMPTY_TUPLE      = b')'   # push empty tuple
    +                stack.push_back(make_tuple_value({}));
    +                break;
    +            case 'N':  // NONE             = b'N'   # push None
    +                stack.push_back(make_none_value());
    +                break;
    +            case 0x88:  // NEWTRUE          = b'\x88'  # push True
    +                stack.push_back(make_bool_value(true));
    +                break;
    +            case 0x89:  // NEWFALSE         = b'\x89'  # push False
    +                stack.push_back(make_bool_value(false));
    +                break;
    +            case 'K':  // BININT1          = b'K'   # push 1-byte unsigned int
    +                if (p >= end) {
    +                    return false;
    +                }
    +                stack.push_back(make_int_value(*p++));
    +                break;
    +            case 'M':  // BININT2          = b'M'   # push 2-byte unsigned int
    +                if (p + 2 > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_int_value(read_short(p)));
    +                p += 2;
    +                break;
    +            case 'J':  // BININT           = b'J'   # push 4-byte signed int
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_int_value(read_int(p)));
    +                p += 4;
    +                break;
    +            case 'I': {  // INT              = b'I'   # push decimal integer line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                std::string s((const char*)p, len);
    +                p += len + 1;
    +                if (s == "01") {
    +                    stack.push_back(make_bool_value(true));
    +                } else if (s == "00") {
    +                    stack.push_back(make_bool_value(false));
    +                } else {
    +                    stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10)));
    +                }
    +            } break;
    +            case 'L': {  // LONG             = b'L'   # push decimal long integer line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                std::string s((const char*)p, len);
    +                p += len + 1;
    +                if (!s.empty() && s.back() == 'L') {
    +                    s.pop_back();
    +                }
    +                stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10)));
    +            } break;
    +            case 'F': {  // FLOAT            = b'F'   # push decimal float line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                p += len + 1;
    +                stack.push_back(make_none_value());
    +            } break;
    +            case 'G':  // BINFLOAT         = b'G'   # push 8-byte binary float
    +                if (p + 8 > end) {
    +                    return false;
    +                }
    +                p += 8;
    +                stack.push_back(make_none_value());
    +                break;
    +            case 0x8A: {  // LONG1            = b'\x8a'  # push long integer; 1-byte length
    +                if (p >= end) {
    +                    return false;
    +                }
    +                uint8_t n = *p++;
    +                if (p + n > end || n > 8) {
    +                    return false;
    +                }
    +                int64_t value = 0;
    +                for (uint8_t i = 0; i < n; ++i) {
    +                    value |= (int64_t)p[i] << (i * 8);
    +                }
    +                p += n;
    +                stack.push_back(make_int_value(value));
    +            } break;
    +            case 'C': {  // SHORT_BINBYTES   = b'C'   # push bytes; length < 256
    +                if (p >= end) {
    +                    return false;
    +                }
    +                uint8_t len = *p++;
    +                if (p + len > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, len)));
    +                p += len;
    +            } break;
    +            case 'B': {  // BINBYTES         = b'B'   # push bytes; 4-byte length
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                int32_t len = read_int(p);
    +                p += 4;
    +                if (len < 0 || p + len > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, len)));
    +                p += len;
    +            } break;
    +            case 'T':    // BINSTRING        = b'T'   # push string; 4-byte length
    +            case 'X': {  // BINUNICODE       = b'X'   # push UTF-8 string; 4-byte length
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                int32_t len = read_int(p);
    +                p += 4;
    +                if (len < 0 || p + len > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, len)));
    +                p += len;
    +            } break;
    +            case 0x8D:    // BINUNICODE8      = b'\x8d'  # push UTF-8 string; 8-byte length
    +            case 0x8E:    // BINBYTES8        = b'\x8e'  # push bytes; 8-byte length
    +            case 0x96: {  // BYTEARRAY8       = b'\x96'  # push bytearray; 8-byte length
    +                if (p + 8 > end) {
    +                    return false;
    +                }
    +                uint64_t len = read_u64(p);
    +                p += 8;
    +                if (len > (uint64_t)(end - p)) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, (size_t)len)));
    +                p += len;
    +            } break;
    +            case 'U':     // SHORT_BINSTRING  = b'U'   # push string; length < 256
    +            case 0x8C: {  // SHORT_BINUNICODE = b'\x8c'  # push UTF-8 string; length < 256
    +                if (p >= end) {
    +                    return false;
    +                }
    +                uint8_t len = *p++;
    +                if (p + len > end) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, len)));
    +                p += len;
    +            } break;
    +            case 'S': {  // STRING           = b'S'   # push quoted string line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                std::string s((const char*)p, len);
    +                p += len + 1;
    +                if (s.size() >= 2 && (s[0] == '\'' || s[0] == '"') && s.back() == s[0]) {
    +                    s = s.substr(1, s.size() - 2);
    +                }
    +                stack.push_back(make_string_value(s));
    +            } break;
    +            case 'V': {  // UNICODE          = b'V'   # push raw-unicode string line
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                stack.push_back(make_string_value(std::string((const char*)p, len)));
    +                p += len + 1;
    +            } break;
    +            case 'c': {  // GLOBAL           = b'c'   # push module/name global reference
    +                int len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                std::string module((const char*)p, len);
    +                p += len + 1;
    +                len = find_char(p, (int)(end - p), '\n');
    +                if (len < 0) {
    +                    return false;
    +                }
    +                std::string name((const char*)p, len);
    +                p += len + 1;
    +                stack.push_back(make_global_value(module + "." + name));
    +            } break;
    +            case 0x93: {  // STACK_GLOBAL     = b'\x93'  # build global from module/name strings
    +                if (stack.size() < 2 || stack[stack.size() - 2].kind != PickleValue::STRING ||
    +                    stack.back().kind != PickleValue::STRING) {
    +                    return false;
    +                }
    +                std::string name = stack.back().str_value;
    +                stack.pop_back();
    +                std::string module = stack.back().str_value;
    +                stack.pop_back();
    +                stack.push_back(make_global_value(module + "." + name));
    +            } break;
    +            case 'h':  // BINGET           = b'h'   # read memo index, 1-byte arg
    +                if (p >= end || !memo.count(*p)) {
    +                    return false;
    +                }
    +                stack.push_back(memo[*p++]);
    +                break;
    +            case 'j': {  // LONG_BINGET      = b'j'   # read memo index, 4-byte arg
    +                if (p + 4 > end) {
    +                    return false;
    +                }
    +                int32_t memo_idx = read_int(p);
    +                if (!memo.count(memo_idx)) {
    +                    return false;
    +                }
    +                stack.push_back(memo[memo_idx]);
    +                p += 4;
    +            } break;
    +            case 'q':  // BINPUT           = b'q'   # write memo index, 1-byte arg
    +                if (p >= end || stack.empty()) {
    +                    return false;
    +                }
    +                memo[*p++] = stack.back();
    +                break;
    +            case 'r':  // LONG_BINPUT      = b'r'   # write memo index, 4-byte arg
    +                if (p + 4 > end || stack.empty()) {
    +                    return false;
    +                }
    +                memo[read_int(p)] = stack.back();
    +                p += 4;
    +                break;
    +            case 0x94:  // MEMOIZE          = b'\x94'  # store top of stack in memo
    +                if (stack.empty()) {
    +                    return false;
    +                }
    +                memo[(int32_t)memo.size()] = stack.back();
    +                break;
    +            case 0x95:  // FRAME            = b'\x95'  # indicate the beginning of a new frame
    +                if (p + 8 > end) {
    +                    return false;
    +                }
    +                p += 8;
    +                break;
    +            case '0':  // POP              = b'0'   # discard top stack item
    +                if (stack.empty()) {
    +                    return false;
    +                }
    +                stack.pop_back();
    +                break;
    +            case '1': {  // POP_MARK         = b'1'   # discard stack through topmost mark
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx < 0) {
    +                    return false;
    +                }
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +            } break;
    +            case '2':  // DUP              = b'2'   # duplicate top stack item
    +                if (stack.empty()) {
    +                    return false;
    +                }
    +                stack.push_back(stack.back());
    +                break;
    +            case 0x8F:  // EMPTY_SET        = b'\x8f'  # push empty set
    +                stack.push_back(make_list_value());
    +                break;
    +            case 0x90: {  // ADDITEMS         = b'\x90'  # add mark-delimited items to set
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) {
    +                    return false;
    +                }
    +                PickleValue& set_value = stack[mark_idx - 1];
    +                set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end());
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +            } break;
    +            case 0x91: {  // FROZENSET        = b'\x91'  # build frozenset from mark
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx < 0) {
    +                    return false;
    +                }
    +                PickleValue set_value = make_list_value();
    +                set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end());
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +                stack.push_back(std::move(set_value));
    +            } break;
    +            case 0x85:    // TUPLE1           = b'\x85'  # build 1-tuple from stack
    +            case 0x86:    // TUPLE2           = b'\x86'  # build 2-tuple from stack
    +            case 0x87: {  // TUPLE3           = b'\x87'  # build 3-tuple from stack
    +                int tuple_size = opcode == 0x85 ? 1 : (opcode == 0x86 ? 2 : 3);
    +                if ((int)stack.size() < tuple_size) {
    +                    return false;
    +                }
    +                std::vector<PickleValue> items(stack.end() - tuple_size, stack.end());
    +                stack.erase(stack.end() - tuple_size, stack.end());
    +                stack.push_back(make_tuple_value(std::move(items)));
    +            } break;
    +            case 't': {  // TUPLE            = b't'   # build tuple from mark
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx < 0) {
    +                    return false;
    +                }
    +                std::vector<PickleValue> items(stack.begin() + mark_idx + 1, stack.end());
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +                stack.push_back(make_tuple_value(std::move(items)));
    +            } break;
    +            case 'Q': {  // BINPERSID        = b'Q'   # persistent id from stack
    +                if (stack.empty()) {
    +                    return false;
    +                }
    +                PickleValue pid = stack.back();
    +                stack.pop_back();
    +                if (pid.kind != PickleValue::TUPLE || pid.items.size() < 5 || pid.items[0].kind != PickleValue::STRING ||
    +                    pid.items[1].kind != PickleValue::GLOBAL || pid.items[4].kind != PickleValue::INT ||
    +                    pid.items[0].str_value != "storage") {
    +                    return false;
    +                }
    +
    +                PickleStorageInfo storage;
    +                storage.key = pickle_value_to_string(pid.items[2]);
    +                if (storage.key.empty() || !parse_storage_type(pid.items[1].str_value, &storage)) {
    +                    return false;
    +                }
    +                storage.nbytes              = (uint64_t)pid.items[4].int_value * storage.raw_element_nbytes;
    +                storage_nbytes[storage.key] = storage.nbytes;
    +                stack.push_back(make_storage_value(storage));
    +            } break;
    +            case 'R': {  // REDUCE           = b'R'   # apply callable to args
    +                if (stack.size() < 2) {
    +                    return false;
    +                }
    +                PickleValue args = stack.back();
    +                stack.pop_back();
    +                PickleValue callable = stack.back();
    +                stack.pop_back();
    +                if (callable.kind != PickleValue::GLOBAL || args.kind != PickleValue::TUPLE) {
    +                    stack.push_back(make_none_value());
    +                    break;
    +                }
    +
    +                if (callable.str_value == "collections.OrderedDict" && args.items.empty()) {
    +                    stack.push_back(make_dict_value(true));
    +                    break;
    +                }
    +
    +                if ((callable.str_value == "torch._utils._rebuild_tensor_v2" || callable.str_value == "torch._utils._rebuild_tensor") &&
    +                    args.items.size() >= 4 && args.items[0].kind == PickleValue::STORAGE &&
    +                    args.items[1].kind == PickleValue::INT && args.items[2].kind == PickleValue::TUPLE &&
    +                    args.items[3].kind == PickleValue::TUPLE) {
    +                    PickleTensorInfo tensor;
    +                    tensor.tensor_storage.type        = args.items[0].storage.type;
    +                    tensor.tensor_storage.is_f64      = args.items[0].storage.is_f64;
    +                    tensor.tensor_storage.is_i64      = args.items[0].storage.is_i64;
    +                    tensor.tensor_storage.storage_key = args.items[0].storage.key;
    +                    tensor.tensor_storage.offset      = (uint64_t)args.items[1].int_value * args.items[0].storage.raw_element_nbytes;
    +
    +                    for (const auto& item : args.items[2].items) {
    +                        if (item.kind != PickleValue::INT || tensor.tensor_storage.n_dims >= SD_MAX_DIMS) {
    +                            return false;
    +                        }
    +                        tensor.tensor_storage.ne[tensor.tensor_storage.n_dims++] = item.int_value;
    +                    }
    +
    +                    for (const auto& item : args.items[3].items) {
    +                        if (item.kind != PickleValue::INT || tensor.stride_n_dims >= SD_MAX_DIMS) {
    +                            return false;
    +                        }
    +                        tensor.stride[tensor.stride_n_dims++] = item.int_value;
    +                    }
    +
    +                    if (!tensor_is_contiguous(tensor)) {
    +                        return false;
    +                    }
    +                    stack.push_back(make_tensor_value(tensor));
    +                    break;
    +                }
    +
    +                // Non-tensor checkpoint metadata can use REDUCE for arbitrary
    +                // Python objects. Do not execute it; keep stack shape only.
    +                stack.push_back(make_none_value());
    +                break;
    +            }
    +            case 'b':  // BUILD            = b'b'   # build object state
    +                if (stack.size() < 2) {
    +                    return false;
    +                }
    +                stack.pop_back();
    +                break;
    +            case 'u': {  // SETITEMS         = b'u'   # add mark-delimited items to dict
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx <= 0) {
    +                    return false;
    +                }
    +                PickleValue& dict = stack[mark_idx - 1];
    +                if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) {
    +                    return false;
    +                }
    +                for (int i = mark_idx + 1; i + 1 < (int)stack.size(); i += 2) {
    +                    dict.dict_items.emplace_back(stack[i], stack[i + 1]);
    +                }
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +            } break;
    +            case 's': {  // SETITEM          = b's'   # add key/value to dict
    +                if (stack.size() < 3) {
    +                    return false;
    +                }
    +                PickleValue value = stack.back();
    +                stack.pop_back();
    +                PickleValue key = stack.back();
    +                stack.pop_back();
    +                PickleValue& dict = stack.back();
    +                if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) {
    +                    return false;
    +                }
    +                dict.dict_items.emplace_back(key, value);
    +            } break;
    +            case 'e': {  // APPENDS          = b'e'   # extend list with mark-delimited items
    +                int mark_idx = (int)stack.size() - 1;
    +                while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) {
    +                    --mark_idx;
    +                }
    +                if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) {
    +                    return false;
    +                }
    +                PickleValue& list_value = stack[mark_idx - 1];
    +                list_value.items.insert(list_value.items.end(), stack.begin() + mark_idx + 1, stack.end());
    +                stack.erase(stack.begin() + mark_idx, stack.end());
    +            } break;
    +            case 'a': {  // APPEND           = b'a'   # append item to list
    +                if (stack.size() < 2) {
    +                    return false;
    +                }
    +                PickleValue item = stack.back();
    +                stack.pop_back();
    +                if (stack.back().kind != PickleValue::LIST) {
    +                    return false;
    +                }
    +                stack.back().items.push_back(item);
    +            } break;
    +            default:
    +                set_error(error,
    +                          "unsupported torch pickle opcode 0x" + sd_format("%02X", opcode) +
    +                              " at offset " + std::to_string((p - buffer) - 1));
    +                return false;
    +        }
    +    }
    +
    +    set_error(error, "unterminated torch state_dict pickle");
    +    return false;
    +}
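The `0x8C` (SHORT_BINUNICODE) case above reads its length byte as `uint8_t`, so a length can never exceed 255. It also rejects any payload that would run past `end` before it builds the string. The advisory traces the CVE to the pre-fix parser reading that same byte as a signed `int8_t`: `0xFF` becomes -1, passes a `len > 512` check, and reaches `memcpy` as `SIZE_MAX`. The standalone sketch below reproduces both behaviors. `length_seen_by_memcpy_if_signed` and `read_short_string` are hypothetical helpers written for illustration, not functions from stable-diffusion.cpp. Its bounds check also differs slightly from the diff's `p + len > end`, because it compares byte counts rather than pointers.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>

// Hypothetical illustration of the pre-fix bug: the length byte is read as a
// signed int8_t, then implicitly converted to size_t when handed to memcpy.
size_t length_seen_by_memcpy_if_signed(uint8_t length_byte) {
    int8_t len = static_cast<int8_t>(length_byte);  // 0xFF -> -1 (two's complement)
    return static_cast<size_t>(len);                // -1 -> SIZE_MAX
}

// Hypothetical safe reader for a SHORT_BINUNICODE-style payload: one unsigned
// length byte followed by that many bytes. On success it copies the bytes into
// *out and advances *pos. On truncation it returns false and leaves *pos alone.
bool read_short_string(const uint8_t* buf, size_t size, size_t* pos, std::string* out) {
    if (*pos >= size) {
        return false;  // no length byte available
    }
    uint8_t len      = buf[*pos];           // unsigned: always 0..255
    size_t remaining = size - (*pos + 1);   // bytes left after the length byte
    if (len > remaining) {
        return false;  // compare counts, never form a pointer past the buffer
    }
    out->assign(reinterpret_cast<const char*>(buf + *pos + 1), len);
    *pos += 1 + static_cast<size_t>(len);
    return true;
}
```

Comparing `len` against the remaining byte count avoids forming a pointer past the end of the buffer. The diff already uses this shape for its 8-byte-length opcodes (`len > (uint64_t)(end - p)`).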
    
  • src/model_io/pickle_io.h +21 -0 added
    @@ -0,0 +1,21 @@
    +#ifndef __SD_MODEL_IO_PICKLE_IO_H__
    +#define __SD_MODEL_IO_PICKLE_IO_H__
    +
    +#include <cstddef>
    +#include <cstdint>
    +#include <string>
    +#include <unordered_map>
    +#include <vector>
    +
    +#include "tensor_storage.h"
    +
    +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size);
    +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size);
    +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value);
    +bool parse_torch_state_dict_pickle(const uint8_t* buffer,
    +                                   size_t buffer_size,
    +                                   std::vector<TensorStorage>& tensor_storages,
    +                                   std::unordered_map<std::string, uint64_t>& storage_nbytes,
    +                                   std::string* error = nullptr);
    +
    +#endif  // __SD_MODEL_IO_PICKLE_IO_H__
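The LONG1 (`0x8A`) handler in pickle_io.cpp ORs the payload bytes together without sign extension. Python's pickle format encodes LONG1 as a little-endian two's-complement integer, so -1 is the single byte `0xFF`, and the diff's loop would read that byte as 255. This likely makes no difference for the non-negative sizes, strides, and offsets a state_dict carries. For comparison, the sketch below shows a sign-extending decode. `decode_long1` is a hypothetical helper, not part of stable-diffusion.cpp.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical decoder for a pickle LONG1 payload: n little-endian bytes,
// two's complement. Returns false if n exceeds what fits in int64_t.
bool decode_long1(const uint8_t* bytes, size_t n, int64_t* out) {
    if (n > 8) {
        return false;
    }
    if (n == 0) {
        *out = 0;  // pickle encodes zero as an empty LONG1 payload
        return true;
    }
    uint64_t value = 0;
    for (size_t i = 0; i < n; ++i) {
        // Shift in unsigned arithmetic to avoid signed-overflow issues.
        value |= static_cast<uint64_t>(bytes[i]) << (8 * i);
    }
    // A set top bit in the most significant byte means a negative number:
    // fill the unused high bytes with ones.
    if (n < 8 && (bytes[n - 1] & 0x80)) {
        value |= ~uint64_t{0} << (8 * n);
    }
    *out = static_cast<int64_t>(value);
    return true;
}
```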
    
  • src/model_io/safetensors_io.cpp +3 -16 modified
    @@ -6,6 +6,7 @@
     #include <string>
     #include <vector>
     
    +#include "binary_io.h"
     #include "json.hpp"
     
     static constexpr size_t ST_HEADER_SIZE_LEN = 8;
    @@ -16,20 +17,6 @@ static void set_error(std::string* error, const std::string& message) {
         }
     }
     
    -static uint64_t read_u64(const uint8_t* buffer) {
    -    // little endian
    -    uint64_t value = 0;
    -    value |= static_cast<uint64_t>(buffer[7]) << 56;
    -    value |= static_cast<uint64_t>(buffer[6]) << 48;
    -    value |= static_cast<uint64_t>(buffer[5]) << 40;
    -    value |= static_cast<uint64_t>(buffer[4]) << 32;
    -    value |= static_cast<uint64_t>(buffer[3]) << 24;
    -    value |= static_cast<uint64_t>(buffer[2]) << 16;
    -    value |= static_cast<uint64_t>(buffer[1]) << 8;
    -    value |= static_cast<uint64_t>(buffer[0]);
    -    return value;
    -}
    -
     bool is_safetensors_file(const std::string& file_path) {
         std::ifstream file(file_path, std::ios::binary);
         if (!file.is_open()) {
    @@ -52,7 +39,7 @@ bool is_safetensors_file(const std::string& file_path) {
             return false;
         }
     
    -    size_t header_size_ = read_u64(header_size_buf);
    +    size_t header_size_ = model_io::read_u64(header_size_buf);
         if (header_size_ >= file_size_ || header_size_ <= 2) {
             return false;
         }
    @@ -123,7 +110,7 @@ bool read_safetensors_file(const std::string& file_path,
             return false;
         }
     
    -    size_t header_size_ = read_u64(header_size_buf);
    +    size_t header_size_ = model_io::read_u64(header_size_buf);
         if (header_size_ >= file_size_) {
             set_error(error, "invalid safetensor file '" + file_path + "'");
             return false;
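The safetensors hunk drops the file-local `read_u64` in favor of `model_io::read_u64` from `binary_io.h`. The contents of `binary_io.h` are not part of this excerpt. For reference, a safetensors file begins with an 8-byte little-endian unsigned integer giving the JSON header size, and the removed helper decodes it with byte 0 as least significant. The loop below is a hypothetical equivalent formulation. `read_u64_le` is our own name, not the project's.

```cpp
#include <cstdint>

// Hypothetical little-endian 64-bit read, equivalent in effect to the static
// helper removed above: buffer[0] is the least significant byte.
uint64_t read_u64_le(const uint8_t* buffer) {
    uint64_t value = 0;
    for (int i = 7; i >= 0; --i) {
        value = (value << 8) | buffer[i];  // most significant byte first
    }
    return value;
}
```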
    
  • src/model_io/tensor_storage.h +1 -0 modified
    @@ -24,6 +24,7 @@ struct TensorStorage {
         int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
         int n_dims              = 0;
     
    +    std::string storage_key;
         size_t file_index = 0;
         int index_in_zip  = -1;  // >= means stored in a zip file
         uint64_t offset   = 0;   // offset in file
    
  • src/model_io/torch_legacy_io.cpp +252 -0 added
    @@ -0,0 +1,252 @@
    +#include "torch_legacy_io.h"
    +
    +#include <algorithm>
    +#include <cstdint>
    +#include <fstream>
    +#include <string>
    +#include <unordered_map>
    +#include <vector>
    +
    +#include "pickle_io.h"
    +#include "util.h"
    +
    +// torch.save format background:
    +//
    +//   - Before PyTorch 1.6.0, torch.save used this legacy non-zip format by
    +//     default.
    +//   - Since PyTorch 1.6.0, torch.save defaults to an uncompressed ZIP64 archive
    +//     containing data.pkl, data/, version, and, since PyTorch 2.1.0, byteorder.
    +//   - The old format can still be produced explicitly with:
    +//       torch.save(obj, path, _use_new_zipfile_serialization=False)
    +//
    +// Whether obj is a state_dict or a whole nn.Module does not change the outer
    +// container format selected by torch.save. It changes the pickled object inside:
    +//
    +//   - state_dict: usually an OrderedDict[str, Tensor]. pickle_io.cpp supports a
    +//     restricted subset of this layout because tensor metadata and raw storages
    +//     can be recovered without executing pickle callables.
    +//   - whole module/checkpoint object: arbitrary Python object graph. This may
    +//     require importing user classes and executing pickle GLOBAL/REDUCE rebuild
    +//     logic, so it is intentionally not supported here.
    +//
    +// Legacy non-zip PyTorch files are not a single pickle object:
    +//
    +//   1. pickle object: PyTorch legacy magic number
    +//   2. pickle object: legacy protocol version, expected to be 1001
    +//   3. pickle object: sys_info metadata, ignored by this reader
    +//   4. pickle object: state_dict metadata, parsed by pickle_io.cpp
    +//   5. pickle object: serialized storage key list, skipped here
    +//   6. raw storage data payloads
    +//      - PyTorch writes storages after the pickles, ordered by storage key
    +//      - each storage has an 8-byte legacy storage header followed by raw bytes
    +static constexpr size_t LEGACY_STORAGE_HEADER_SIZE = 8;
    +
    +static void set_error(std::string* error, const std::string& message) {
    +    if (error != nullptr) {
    +        *error = message;
    +    }
    +}
    +
    +static std::string bytes_to_hex(const std::vector<uint8_t>& bytes) {
    +    static const char* hex = "0123456789ABCDEF";
    +    std::string result;
    +    result.reserve(bytes.size() * 3);
    +    for (size_t i = 0; i < bytes.size(); ++i) {
    +        if (i > 0) {
    +            result.push_back('-');
    +        }
    +        result.push_back(hex[(bytes[i] >> 4) & 0x0F]);
    +        result.push_back(hex[bytes[i] & 0x0F]);
    +    }
    +    return result;
    +}
    +
    +static bool is_probably_tar_file(const std::vector<uint8_t>& header) {
    +    return header.size() >= 262 &&
    +           header[257] == 'u' &&
    +           header[258] == 's' &&
    +           header[259] == 't' &&
    +           header[260] == 'a' &&
    +           header[261] == 'r';
    +}
    +
    +static std::string torch_legacy_diagnostics(const std::string& file_path, const std::vector<uint8_t>& buffer) {
    +    if (!ends_with(file_path, ".pt") && !ends_with(file_path, ".pth")) {
    +        return "";
    +    }
    +    if (buffer.empty()) {
    +        return "unsupported PyTorch file '" + file_path + "': empty file";
    +    }
    +
    +    size_t short_len = std::min<size_t>(buffer.size(), 32);
    +    std::vector<uint8_t> short_header(buffer.begin(), buffer.begin() + short_len);
    +    const bool raw_pickle = buffer[0] == 0x80;
    +    const bool tar_file   = is_probably_tar_file(buffer);
    +
    +    std::string message = "unsupported PyTorch file '" + file_path + "': first bytes " +
    +                          bytes_to_hex(short_header) +
    +                          ", raw_pickle=" + (raw_pickle ? "true" : "false") +
    +                          ", tar=" + (tar_file ? "true" : "false");
    +    if (raw_pickle) {
    +        message += "; raw pickle did not match the restricted state_dict layouts currently supported";
    +    } else if (tar_file) {
    +        message += "; legacy tar PyTorch checkpoints are not supported yet";
    +    }
    +    return message;
    +}
    +
    +bool read_torch_legacy_file(const std::string& file_path,
    +                            std::vector<TensorStorage>& tensor_storages,
    +                            std::string* error) {
    +    std::ifstream file(file_path, std::ios::binary);
    +    if (!file.is_open()) {
    +        set_error(error, "failed to open '" + file_path + "'");
    +        return false;
    +    }
    +
    +    file.seekg(0, file.end);
    +    size_t file_size = (size_t)file.tellg();
    +    file.seekg(0, file.beg);
    +    if (file_size == 0) {
    +        set_error(error, "empty file '" + file_path + "'");
    +        return false;
    +    }
    +
    +    std::vector<uint8_t> buffer(file_size);
    +    file.read((char*)buffer.data(), file_size);
    +    if (!file) {
    +        set_error(error, "failed to read '" + file_path + "'");
    +        return false;
    +    }
    +
    +    auto finalize_tensor_offsets = [&](size_t storage_data_offset,
    +                                       const std::unordered_map<std::string, uint64_t>& legacy_storage_map) -> bool {
    +        if (storage_data_offset > file_size) {
    +            return false;
    +        }
    +
    +        std::vector<std::string> storage_keys;
    +        storage_keys.reserve(legacy_storage_map.size());
    +        for (const auto& [storage_key, _] : legacy_storage_map) {
    +            storage_keys.push_back(storage_key);
    +        }
    +        std::sort(storage_keys.begin(), storage_keys.end());
    +
    +        std::unordered_map<std::string, uint64_t> storage_offsets;
    +        uint64_t current_offset = storage_data_offset;
    +        for (const auto& storage_key : storage_keys) {
    +            auto it = legacy_storage_map.find(storage_key);
    +            if (it == legacy_storage_map.end()) {
    +                return false;
    +            }
    +            if (current_offset + LEGACY_STORAGE_HEADER_SIZE + it->second > file_size) {
    +                return false;
    +            }
    +            storage_offsets[storage_key] = current_offset + LEGACY_STORAGE_HEADER_SIZE;
    +            current_offset += LEGACY_STORAGE_HEADER_SIZE + it->second;
    +        }
    +
    +        for (auto& tensor_storage : tensor_storages) {
    +            if (tensor_storage.storage_key.empty()) {
    +                continue;
    +            }
    +
    +            auto it_offset = storage_offsets.find(tensor_storage.storage_key);
    +            auto it_size   = legacy_storage_map.find(tensor_storage.storage_key);
    +            if (it_offset == storage_offsets.end() || it_size == legacy_storage_map.end()) {
    +                return false;
    +            }
    +
    +            uint64_t base_offset    = it_offset->second;
    +            uint64_t storage_nbytes = it_size->second;
    +            uint64_t tensor_nbytes  = tensor_storage.nbytes_to_read();
    +            if (tensor_storage.offset + tensor_nbytes > storage_nbytes) {
    +                return false;
    +            }
    +
    +            tensor_storage.offset = base_offset + tensor_storage.offset;
    +            tensor_storage.storage_key.clear();
    +        }
    +
    +        return true;
    +    };
    +
    +    auto parse_state_dict_at = [&](size_t state_dict_offset, size_t state_dict_size, size_t* storage_data_offset) -> bool {
    +        tensor_storages.clear();
    +        std::unordered_map<std::string, uint64_t> legacy_storage_map;
    +        if (!parse_torch_state_dict_pickle(buffer.data() + state_dict_offset,
    +                                           state_dict_size,
    +                                           tensor_storages,
    +                                           legacy_storage_map,
    +                                           error)) {
    +            return false;
    +        }
    +
    +        size_t offset_after_state_dict = state_dict_offset + state_dict_size;
    +        size_t storage_keys_size       = 0;
    +        if (!skip_pickle_object(buffer.data() + offset_after_state_dict,
    +                                buffer.size() - offset_after_state_dict,
    +                                &storage_keys_size)) {
    +            return false;
    +        }
    +
    +        *storage_data_offset = offset_after_state_dict + storage_keys_size;
    +        return finalize_tensor_offsets(*storage_data_offset, legacy_storage_map);
    +    };
    +
    +    size_t object_size_1 = 0;
    +    size_t offset        = 0;
    +
    +    if (skip_pickle_object(buffer.data(), buffer.size(), &object_size_1) &&
    +        pickle_object_is_torch_magic_number(buffer.data(), object_size_1)) {
    +        offset += object_size_1;
    +
    +        size_t object_size_2 = 0;
    +        if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_2)) {
    +            set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +            return false;
    +        }
    +        uint32_t protocol_version = 0;
    +        if (!parse_pickle_uint32_object(buffer.data() + offset, object_size_2, &protocol_version) || protocol_version != 1001) {
    +            set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +            return false;
    +        }
    +        offset += object_size_2;
    +
    +        size_t object_size_3 = 0;
    +        if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_3)) {
    +            set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +            return false;
    +        }
    +        offset += object_size_3;
    +
    +        size_t state_dict_size = 0;
    +        if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &state_dict_size)) {
    +            set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +            return false;
    +        }
    +
    +        size_t storage_data_offset = 0;
    +        if (parse_state_dict_at(offset, state_dict_size, &storage_data_offset)) {
    +            return true;
    +        }
    +
    +        if (error != nullptr && error->empty()) {
    +            set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +        }
    +        return false;
    +    }
    +
    +    size_t state_dict_size = 0;
    +    if (skip_pickle_object(buffer.data(), buffer.size(), &state_dict_size)) {
    +        size_t storage_data_offset = 0;
    +        if (parse_state_dict_at(0, state_dict_size, &storage_data_offset)) {
    +            return true;
    +        }
    +    }
    +
    +    if (error != nullptr && error->empty()) {
    +        set_error(error, torch_legacy_diagnostics(file_path, buffer));
    +    }
    +    return false;
    +}
    
  • src/model_io/torch_legacy_io.h (+13 -0, added)
    @@ -0,0 +1,13 @@
    +#ifndef __SD_MODEL_IO_TORCH_LEGACY_IO_H__
    +#define __SD_MODEL_IO_TORCH_LEGACY_IO_H__
    +
    +#include <string>
    +#include <vector>
    +
    +#include "tensor_storage.h"
    +
    +bool read_torch_legacy_file(const std::string& file_path,
    +                            std::vector<TensorStorage>& tensor_storages,
    +                            std::string* error = nullptr);
    +
    +#endif  // __SD_MODEL_IO_TORCH_LEGACY_IO_H__
    
  • src/model_io/torch_zip_io.cpp (+140 -0, added)
    @@ -0,0 +1,140 @@
    +#include "torch_zip_io.h"
    +
    +#include <cstdint>
    +#include <cstdlib>
    +#include <string>
    +#include <unordered_map>
    +#include <vector>
    +
    +#include "pickle_io.h"
    +
    +#include "zip.h"
    +
    +static void set_error(std::string* error, const std::string& message) {
    +    if (error != nullptr) {
    +        *error = message;
    +    }
    +}
    +
    +bool is_torch_zip_file(const std::string& file_path) {
    +    zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
    +    if (zip == nullptr) {
    +        return false;
    +    }
    +    zip_close(zip);
    +    return true;
    +}
    +
    +static bool find_zip_entry(zip_t* zip, const std::string& entry_name, int* index, uint64_t* size) {
    +    size_t n = zip_entries_total(zip);
    +    for (size_t i = 0; i < n; ++i) {
    +        zip_entry_openbyindex(zip, i);
    +        std::string name = zip_entry_name(zip);
    +        if (name == entry_name) {
    +            *index = (int)i;
    +            *size  = zip_entry_size(zip);
    +            zip_entry_close(zip);
    +            return true;
    +        }
    +        zip_entry_close(zip);
    +    }
    +    return false;
    +}
    +
    +static bool parse_zip_data_pkl(const uint8_t* buffer,
    +                               size_t buffer_size,
    +                               zip_t* zip,
    +                               const std::string& dir,
    +                               std::vector<TensorStorage>& tensor_storages,
    +                               std::string* error) {
    +    std::vector<TensorStorage> parsed_tensors;
    +    std::unordered_map<std::string, uint64_t> storage_nbytes;
    +    if (!parse_torch_state_dict_pickle(buffer, buffer_size, parsed_tensors, storage_nbytes, error)) {
    +        if (error != nullptr && error->empty()) {
    +            *error = "failed to parse torch zip pickle metadata";
    +        }
    +        return false;
    +    }
    +
    +    for (auto& tensor_storage : parsed_tensors) {
    +        if (tensor_storage.storage_key.empty()) {
    +            set_error(error, "tensor '" + tensor_storage.name + "' has no storage key");
    +            return false;
    +        }
    +
    +        const std::string entry_name = dir + "data/" + tensor_storage.storage_key;
    +        int zip_index                = -1;
    +        uint64_t entry_size          = 0;
    +        if (!find_zip_entry(zip, entry_name, &zip_index, &entry_size)) {
    +            set_error(error, "storage entry '" + entry_name + "' was not found");
    +            return false;
    +        }
    +
    +        auto it_storage_size = storage_nbytes.find(tensor_storage.storage_key);
    +        if (it_storage_size != storage_nbytes.end() && entry_size < it_storage_size->second) {
    +            set_error(error, "storage entry '" + entry_name + "' is smaller than pickle metadata");
    +            return false;
    +        }
    +
    +        uint64_t tensor_nbytes = tensor_storage.nbytes_to_read();
    +        if (tensor_storage.offset + tensor_nbytes > entry_size) {
    +            set_error(error, "tensor '" + tensor_storage.name + "' exceeds storage entry '" + entry_name + "'");
    +            return false;
    +        }
    +
    +        tensor_storage.index_in_zip = zip_index;
    +        tensor_storage.storage_key.clear();
    +        tensor_storages.push_back(tensor_storage);
    +    }
    +
    +    return true;
    +}
    +
    +bool read_torch_zip_file(const std::string& file_path,
    +                         std::vector<TensorStorage>& tensor_storages,
    +                         std::string* error) {
    +    zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
    +    if (zip == nullptr) {
    +        set_error(error, "failed to open '" + file_path + "'");
    +        return false;
    +    }
    +
    +    tensor_storages.clear();
    +    bool success        = true;
    +    bool found_data_pkl = false;
    +    int n               = (int)zip_entries_total(zip);
    +    for (int i = 0; i < n; ++i) {
    +        zip_entry_openbyindex(zip, i);
    +        std::string name = zip_entry_name(zip);
    +        size_t pos       = name.find("data.pkl");
    +        if (pos != std::string::npos) {
    +            found_data_pkl  = true;
    +            std::string dir = name.substr(0, pos);
    +            void* pkl_data  = nullptr;
    +            size_t pkl_size = 0;
    +            zip_entry_read(zip, &pkl_data, &pkl_size);
    +
    +            if (pkl_data == nullptr || pkl_size == 0) {
    +                set_error(error, "failed to read '" + name + "' from '" + file_path + "'");
    +                success = false;
    +            } else if (!parse_zip_data_pkl((const uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) {
    +                success = false;
    +            }
    +
    +            free(pkl_data);
    +        }
    +        zip_entry_close(zip);
    +
    +        if (!success) {
    +            break;
    +        }
    +    }
    +
    +    if (success && !found_data_pkl) {
    +        set_error(error, "data.pkl was not found in '" + file_path + "'");
    +        success = false;
    +    }
    +
    +    zip_close(zip);
    +    return success;
    +}
    
  • src/model_io/torch_zip_io.h (+14 -0, added)
    @@ -0,0 +1,14 @@
    +#ifndef __SD_MODEL_IO_TORCH_ZIP_IO_H__
    +#define __SD_MODEL_IO_TORCH_ZIP_IO_H__
    +
    +#include <string>
    +#include <vector>
    +
    +#include "tensor_storage.h"
    +
    +bool is_torch_zip_file(const std::string& file_path);
    +bool read_torch_zip_file(const std::string& file_path,
    +                         std::vector<TensorStorage>& tensor_storages,
    +                         std::string* error = nullptr);
    +
    +#endif  // __SD_MODEL_IO_TORCH_ZIP_IO_H__
    

Vulnerability mechanics

Root cause

Sign confusion in SHORT_BINUNICODE length parsing: a one-byte length read as a signed int8_t (0xFF → -1) passes the length check but is converted to size_t (SIZE_MAX) for memcpy, causing a heap buffer overflow.
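The conversion chain can be worked through by hand. This assumes a two's-complement target (guaranteed since C++20, and true of all mainstream compilers before that) and an N-bit `size_t`; on common 64-bit builds N = 64, so `SIZE_MAX` is 18446744073709551615.

$$
\begin{aligned}
\texttt{(int8\_t)}\,\mathtt{0xFF} &= 255 - 2^{8} = -1 \\
-1 > 512 &\;\Rightarrow\; \text{false, so the length is accepted} \\
\texttt{(size\_t)}(-1) &= 2^{N} - 1 = \texttt{SIZE\_MAX}
\end{aligned}
$$

Because an `int8_t` can never exceed 127, a `len > 512` guard on such a variable cannot fire for any byte value, not only for 0xFF.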

Attack vector

An attacker crafts a malicious .ckpt checkpoint file containing a SHORT_BINUNICODE opcode whose one-byte length field is set to `0xFF`. Because the affected pickle parser in `src/model.cpp` reads this byte as a signed `int8_t`, it becomes `-1`. The sanity check `if (len > 512)` does not reject it, since `-1 > 512` is false for signed integers. The length is then passed to `memcpy`, where it is implicitly converted to `size_t`, producing `SIZE_MAX` and causing a heap buffer overflow (CWE-190 / CWE-122). The attack only requires the victim to load the crafted .ckpt file, for example a model downloaded from a model-sharing site and run with stable-diffusion.cpp. [ref_id=2]
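The vulnerable pattern can be modelled in a few lines. This is a hypothetical reconstruction based only on the description above, not the original `src/model.cpp` code; the function name `legacy_short_binunicode_copy_size` is invented for illustration. Instead of calling `memcpy`, it returns the byte count that would have reached it, so it is safe to run.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr uint8_t kShortBinUnicode = 0x8C;  // pickle SHORT_BINUNICODE opcode

// Hypothetical sketch of the flawed handler (illustrative, not the original code).
// p points at the opcode byte; remaining is the number of bytes left in the buffer.
// Returns false if the record is not SHORT_BINUNICODE or the guard rejects it;
// otherwise stores in *copy_size the size that would be handed to memcpy.
bool legacy_short_binunicode_copy_size(const uint8_t* p, size_t remaining, size_t* copy_size) {
    assert(copy_size != nullptr);
    if (remaining < 2 || p[0] != kShortBinUnicode) {
        return false;
    }
    int8_t len = static_cast<int8_t>(p[1]);  // BUG: length byte read as signed, 0xFF -> -1
    if (len > 512) {                          // can never be true for an int8_t
        return false;
    }
    // BUG: len is never compared with remaining, and the conversion that
    // happens at the memcpy call site turns -1 into SIZE_MAX.
    *copy_size = static_cast<size_t>(len);
    return true;
}
```

With the record `{0x8C, 0xFF, 'x'}` the sketch reports a copy size of `SIZE_MAX` even though only one payload byte exists, while a benign record such as `{0x8C, 0x02, 'h', 'i'}` yields 2.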

What the fix does

The commit [patch_id=6192985] replaces the old inline pickle parser in `src/model.cpp` with a new restricted Torch legacy checkpoint loader in `src/model_io/pickle_io.cpp`. The new `SHORT_BINUNICODE` (opcode `0x8C`) handler reads the one-byte length as an unsigned `uint8_t` value (`p[0]`) and advances by `1 + p[0]`, so no signed-to-unsigned conversion can produce `SIZE_MAX`. This eliminates the sign-confusion bug entirely. The vulnerable `init_from_ckpt_file` function is removed, and the model loader now refuses untrusted .ckpt files by default unless they match the Torch zip format. [ref_id=1][ref_id=2]
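For contrast, a bounds-checked handler in the spirit of the fix might look like the sketch below. The function name `skip_short_binunicode` and its signature are assumptions, not the API of `src/model_io/pickle_io.cpp`. Besides reading the length unsigned, it compares the declared length with the bytes actually remaining, which is an assumption about what validating the length entails. Here `p` points at the opcode byte, so a complete record is `2 + len` bytes: opcode, length byte, payload.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr uint8_t kShortBinUnicode = 0x8C;  // pickle SHORT_BINUNICODE opcode

// Hypothetical sketch: skip one SHORT_BINUNICODE record without copying it.
// On success, *consumed receives the full record size (opcode + length byte + payload).
bool skip_short_binunicode(const uint8_t* p, size_t remaining, size_t* consumed) {
    assert(consumed != nullptr);
    if (remaining < 2 || p[0] != kShortBinUnicode) {
        return false;  // need at least the opcode and the length byte
    }
    const size_t len = p[1];  // uint8_t -> size_t: always 0..255, never negative
    if (len > remaining - 2) {
        return false;  // declared payload runs past the end of the buffer
    }
    *consumed = 2 + len;
    return true;
}
```

The truncated record `{0x8C, 0xFF, 'x'}` is now rejected instead of producing a huge copy, while `{0x8C, 0x03, 'a', 'b', 'c'}` is accepted with 5 bytes consumed. Reading the length through `uint8_t` caps it at 255, and the remaining-bytes check turns a malformed file into an ordinary parse failure.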

Preconditions

  • Input: The victim or application must load a .ckpt checkpoint file from an untrusted source.
  • Input: The attacker must control the file contents to set the SHORT_BINUNICODE length byte to 0xFF.

Reproduction

Build with AddressSanitizer: `cmake -B build -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS="-fsanitize=address -fno-omit-frame-pointer" -DCMAKE_C_FLAGS="-fsanitize=address -fno-omit-frame-pointer"`, then `cmake --build build`. Run `./build/bin/sd-cli -m critical2_short_binunicode.ckpt -p "test"` with the PoC file `critical2_short_binunicode.ckpt`. AddressSanitizer reports `negative-size-param: (size=-1)`. [ref_id=2]

Generated on Jun 16, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References

3
