VYPR
High severity · NVD Advisory · Published Nov 5, 2021 · Updated Aug 4, 2024

Missing validation during checkpoint loading

CVE-2021-41203

Description

TensorFlow is an open source platform for machine learning. In affected versions, an attacker can trigger undefined behavior, integer overflows, segfaults, and CHECK-fail crashes if they can change saved checkpoints from outside of TensorFlow. This is because the checkpoint-loading infrastructure is missing validation for invalid file formats. The fixes will be included in TensorFlow 2.7.0. We will also cherry-pick these commits onto TensorFlow 2.6.1, TensorFlow 2.5.2, and TensorFlow 2.4.4, as these are also affected and still in the supported range.

Affected packages

Versions sourced from the GitHub Security Advisory.

| Package | Ecosystem | Affected versions | Patched version |
|---|---|---|---|
| tensorflow | PyPI | >= 2.6.0, < 2.6.1 | 2.6.1 |
| tensorflow | PyPI | >= 2.5.0, < 2.5.2 | 2.5.2 |
| tensorflow | PyPI | < 2.4.4 | 2.4.4 |
| tensorflow-cpu | PyPI | >= 2.6.0, < 2.6.1 | 2.6.1 |
| tensorflow-cpu | PyPI | >= 2.5.0, < 2.5.2 | 2.5.2 |
| tensorflow-cpu | PyPI | < 2.4.4 | 2.4.4 |
| tensorflow-gpu | PyPI | >= 2.6.0, < 2.6.1 | 2.6.1 |
| tensorflow-gpu | PyPI | >= 2.5.0, < 2.5.2 | 2.5.2 |
| tensorflow-gpu | PyPI | < 2.4.4 | 2.4.4 |

Affected products: 1

Patches: 4
368af875869a

Avoid buffer overflow when loading tensors with insufficient data from checkpoints.

https://github.com/tensorflow/tensorflow · A. Unique TensorFlower · Aug 25, 2021 · via GHSA
3 files changed · +69 -0
  • tensorflow/core/util/saved_tensor_slice_util.h (+26 -0, modified)
    @@ -59,6 +59,9 @@ Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
     template <typename T>
     struct SaveTypeTraits;
     
    +template <typename T>
    +int TensorProtoDataSize(const TensorProto& t);
    +
     template <typename T>
     const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
         const TensorProto& t);
    @@ -95,6 +98,10 @@ void Fill(T* data, size_t n, TensorProto* t);
     #define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE)             \
       TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE)     \
       template <>                                                     \
    +  inline int TensorProtoDataSize<TYPE>(const TensorProto& t) {    \
    +    return t.FIELD##_val_size();                                  \
    +  }                                                               \
    +  template <>                                                     \
       inline void Fill(const TYPE* data, size_t n, TensorProto* t) {  \
         typename protobuf::RepeatedField<FTYPE> copy(data, data + n); \
         t->mutable_##FIELD##_val()->Swap(&copy);                      \
    @@ -104,6 +111,10 @@ void Fill(T* data, size_t n, TensorProto* t);
     #define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE)       \
       TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE)        \
       template <>                                                       \
    +  inline int TensorProtoDataSize<TYPE>(const TensorProto& t) {      \
    +    return t.FIELD##_val_size() / 2;                                \
    +  }                                                                 \
    +  template <>                                                       \
       inline void Fill(const TYPE* data, size_t n, TensorProto* t) {    \
         const FTYPE* sub = reinterpret_cast<const FTYPE*>(data);        \
         typename protobuf::RepeatedField<FTYPE> copy(sub, sub + 2 * n); \
    @@ -136,6 +147,11 @@ TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32);
     template <>
     struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};
     
    +template <>
    +inline int TensorProtoDataSize<qint32>(const TensorProto& t) {
    +  return t.int_val_size();
    +}
    +
     template <>
     inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
       static_assert(SaveTypeTraits<qint32>::supported,
    @@ -158,6 +174,11 @@ struct SaveTypeTraits<Eigen::half> {
       typedef protobuf::RepeatedField<int32> RepeatedField;
     };
     
    +template <>
    +inline int TensorProtoDataSize<Eigen::half>(const TensorProto& t) {
    +  return t.half_val_size();
    +}
    +
     template <>
     inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) {
       return t.half_val().data();
    @@ -187,6 +208,11 @@ struct SaveTypeTraits<tstring> {
       typedef protobuf::RepeatedPtrField<string> RepeatedField;
     };
     
    +template <>
    +inline int TensorProtoDataSize<tstring>(const TensorProto& t) {
    +  return t.string_val_size();
    +}
    +
     template <>
     inline const string* const* TensorProtoData<tstring>(const TensorProto& t) {
       static_assert(SaveTypeTraits<tstring>::supported,
    
  • tensorflow/core/util/tensor_slice_reader.h (+16 -0, modified)
    @@ -181,6 +181,22 @@ bool TensorSliceReader::CopySliceData(const string& name,
                   << slice_s.DebugString() << ": computed key = " << key;
           return false;
         }
    +    // Ensure the TensorSlice contains the expected amount of data.
    +    TensorShape shp_s;
    +    Status s = slice_s.SliceTensorShape(tss->shape(), &shp_s);
    +    if (!s.ok()) {
    +      VLOG(1) << "Failed to slice tensor " << name << ", slice "
    +              << slice_s.DebugString() << ": " << s;
    +      return false;
    +    }
    +    if (checkpoint::TensorProtoDataSize<T>(sts.data().data()) !=
    +        shp_s.num_elements()) {
    +      VLOG(1) << "Tensor " << name << ", slice " << slice_s.DebugString()
    +              << " had an unexpected amount of data: expected = "
    +              << shp_s.num_elements() << ", got = "
    +              << checkpoint::TensorProtoDataSize<T>(sts.data().data());
    +      return false;
    +    }
         CopyDataFromTensorSliceToTensorSlice(
             tss->shape(), slice_s, slice,
             checkpoint::TensorProtoData<T>(sts.data().data()), data);
    
  • tensorflow/core/util/tensor_slice_reader_test.cc (+27 -0, modified)
    @@ -459,6 +459,33 @@ TEST(TensorSliceReaderTest, InvalidTensorSlice) {
       EXPECT_FALSE(reader.status().ok());
     }
     
    +TEST(TensorSliceReaderTest, MissingTensorData) {
    +  const string fname =
    +      io::JoinPath(testing::TmpDir(), "missing_data_checkpoint");
    +  TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
    +  const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    +  TF_ASSERT_OK(writer.Add("test", TensorShape({4, 5}),
    +                          TensorSlice::ParseOrDie("0,2:-"), data));
    +  TF_ASSERT_OK(writer.Finish());
    +
    +  MutateSavedTensorSlices(fname, [&](SavedTensorSlices sts) {
    +    if (sts.has_data()) {
    +      // Replace the data with only 4 elements.
    +      Fill(data, 4, sts.mutable_data()->mutable_data());
    +    }
    +    return sts.SerializeAsString();
    +  });
    +
    +  TensorSliceReader reader(fname, OpenTableTensorSliceReader);
    +  TF_ASSERT_OK(reader.status());
    +
    +  // The tensor should be present, but loading it should fail due to the missing
    +  // data.
    +  EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
    +  std::unique_ptr<Tensor> tensor;
    +  EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
    +}
    +
     void CachedTensorSliceReaderTesterHelper(
         const TensorSliceWriter::CreateBuilderFunction& create_function,
         const TensorSliceReader::OpenTableFunction& open_function) {
    
e8dc63704c88

Add BuildTensorSlice for building from unvalidated TensorSliceProtos.

https://github.com/tensorflow/tensorflow · A. Unique TensorFlower · Aug 25, 2021 · via GHSA
5 files changed · +107 -1
  • tensorflow/core/framework/tensor_slice.cc (+31 -0, modified)
    @@ -14,7 +14,10 @@ limitations under the License.
     ==============================================================================*/
     
     #include "tensorflow/core/framework/tensor_slice.h"
    +
    +#include <limits>
     #include <vector>
    +
     #include "tensorflow/core/lib/core/errors.h"
     #include "tensorflow/core/lib/strings/numbers.h"
     #include "tensorflow/core/lib/strings/str_util.h"
    @@ -44,6 +47,34 @@ TensorSlice::TensorSlice(
       }
     }
     
    +Status TensorSlice::BuildTensorSlice(const TensorSliceProto& proto,
    +                                     TensorSlice* output) {
    +  output->Clear();
    +  output->starts_.reserve(proto.extent_size());
    +  output->lengths_.reserve(proto.extent_size());
    +  for (const auto& e : proto.extent()) {
    +    int64_t l = GetExtentLength(e);
    +    if (e.start() != 0 || l != kFullExtent) {
    +      if (e.start() < 0 || l <= 0) {
    +        return errors::InvalidArgument(
    +            "Expected non-negative start and positive length but got start = ",
    +            e.start(), ", length = ", l, ": extent = ", e.ShortDebugString());
    +      }
    +      // Calculating the extent end must not cause signed integer overflow.
    +      if (static_cast<uint64_t>(e.start()) + static_cast<uint64_t>(e.length()) >
    +          std::numeric_limits<int64_t>::max()) {
    +        return errors::InvalidArgument(
    +            "Extent end exceeds the maximum possible size: extent = ",
    +            e.ShortDebugString());
    +      }
    +    }
    +    output->starts_.push_back(e.start());
    +    output->lengths_.push_back(l);
    +  }
    +
    +  return Status::OK();
    +}
    +
     Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
       std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
       slice->starts_.reserve(items.size());
    
  • tensorflow/core/framework/tensor_slice.h (+6 -0, modified)
    @@ -48,6 +48,12 @@ class TensorSlice {
       explicit TensorSlice(
           std::initializer_list<std::pair<int64_t, int64_t>> extents);
     
    +  // This factory methods should be used instead of the constructor that takes a
    +  // `TensorSliceProto` if calling code cannot validate that the sizes specify a
    +  // valid `TensorSlice`.
    +  static Status BuildTensorSlice(const TensorSliceProto& proto,
    +                                 TensorSlice* output);
    +
       static Status Parse(const string& str, TensorSlice* output);
       static TensorSlice ParseOrDie(const string& str) {
         TensorSlice ret;
    
  • tensorflow/core/framework/tensor_slice_test.cc (+44 -0, modified)
    @@ -15,6 +15,8 @@ limitations under the License.
     
     #include "tensorflow/core/framework/tensor_slice.h"
     
    +#include <limits>
    +
     #include "tensorflow/core/lib/core/status_test_util.h"
     #include "tensorflow/core/platform/logging.h"
     #include "tensorflow/core/platform/protobuf.h"
    @@ -125,6 +127,48 @@ TEST(TensorSliceTest, Serialization) {
       }
     }
     
    +// Testing `BuildTensorSlice` with valid and invalid input protos.
    +TEST(TensorSliceTest, BuildTensorSlice) {
    +  TensorSliceProto proto;
    +  TensorSlice({{0, -1}, {0, 10}, {14, 1}}).AsProto(&proto);
    +  TensorSlice s;
    +
    +  // Successful building.
    +  {
    +    TF_ASSERT_OK(TensorSlice::BuildTensorSlice(proto, &s));
    +    EXPECT_EQ("-:0,10:14,1", s.DebugString());
    +  }
    +
    +  // Failed building due to negative extent start.
    +  {
    +    TensorSliceProto invalid_proto = proto;
    +    invalid_proto.mutable_extent(0)->set_start(-1);
    +    EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
    +  }
    +
    +  // Failed building due to negative extent length.
    +  {
    +    TensorSliceProto invalid_proto = proto;
    +    invalid_proto.mutable_extent(2)->set_length(-1);
    +    EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
    +  }
    +
    +  // Failed building due to missing extent length.
    +  {
    +    TensorSliceProto invalid_proto = proto;
    +    invalid_proto.mutable_extent(2)->clear_length();
    +    EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
    +  }
    +
    +  // Failed building due to extent end overflowing.
    +  {
    +    TensorSliceProto invalid_proto = proto;
    +    invalid_proto.mutable_extent(2)->set_length(
    +        std::numeric_limits<int64_t>::max());
    +    EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
    +  }
    +}
    +
     // Testing the slice intersection
     TEST(TensorSliceTest, Intersection) {
       // "EVERYTHING" intersects with everything
    
  • tensorflow/core/util/tensor_slice_reader.cc (+3 -1, modified)
    @@ -172,7 +172,9 @@ void TensorSliceReader::LoadShard(int shard) const {
         status_ = TensorShape::BuildTensorShapeBase(ssm.shape(), &ssm_shape);
         if (!status_.ok()) return;
         for (const TensorSliceProto& tsp : ssm.slice()) {
    -      TensorSlice ss_slice(tsp);
    +      TensorSlice ss_slice;
    +      status_ = TensorSlice::BuildTensorSlice(tsp, &ss_slice);
    +      if (!status_.ok()) return;
           status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
                                         ss_slice, &tensors_);
           if (!status_.ok()) return;
    
  • tensorflow/core/util/tensor_slice_reader_test.cc (+23 -0, modified)
    @@ -436,6 +436,29 @@ TEST(TensorSliceReaderTest, NegativeTensorShapeDimension) {
       EXPECT_FALSE(reader.status().ok());
     }
     
    +TEST(TensorSliceReaderTest, InvalidTensorSlice) {
    +  const string fname =
    +      io::JoinPath(testing::TmpDir(), "invalid_slice_checkpoint");
    +  TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
    +  const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    +  TF_CHECK_OK(writer.Add("test", TensorShape({4, 5}),
    +                         TensorSlice::ParseOrDie("0,2:-"), data));
    +  TF_CHECK_OK(writer.Finish());
    +
    +  MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
    +    if (sts.has_meta()) {
    +      for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
    +        tensor.mutable_slice(0)->mutable_extent(0)->set_length(-10);
    +      }
    +    }
    +    return sts.SerializeAsString();
    +  });
    +
    +  TensorSliceReader reader(fname, OpenTableTensorSliceReader);
    +  // The negative exent length should cause loading to fail.
    +  EXPECT_FALSE(reader.status().ok());
    +}
    +
     void CachedTensorSliceReaderTesterHelper(
         const TensorSliceWriter::CreateBuilderFunction& create_function,
         const TensorSliceReader::OpenTableFunction& open_function) {
    
b619c6f86571

Use BuildTensorShapeBase when parsing unverified TensorShapes during checkpoint loading.

https://github.com/tensorflow/tensorflow · A. Unique TensorFlower · Aug 24, 2021 · via GHSA
2 files changed · +29 -1
  • tensorflow/core/util/tensor_slice_reader.cc (+3 -1, modified)
    @@ -168,7 +168,9 @@ void TensorSliceReader::LoadShard(int shard) const {
                               "checkpoint");
       if (!status_.ok()) return;
       for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
    -    TensorShape ssm_shape(ssm.shape());
    +    TensorShape ssm_shape;
    +    status_ = TensorShape::BuildTensorShapeBase(ssm.shape(), &ssm_shape);
    +    if (!status_.ok()) return;
         for (const TensorSliceProto& tsp : ssm.slice()) {
           TensorSlice ss_slice(tsp);
           status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
    
  • tensorflow/core/util/tensor_slice_reader_test.cc (+26 -0, modified)
    @@ -18,6 +18,7 @@ limitations under the License.
     #include <utility>
     #include <vector>
     
    +#include "tensorflow/core/framework/tensor_shape.pb.h"
     #include "tensorflow/core/framework/types.h"
     #include "tensorflow/core/framework/versions.pb.h"
     #include "tensorflow/core/lib/core/status_test_util.h"
    @@ -410,6 +411,31 @@ TEST(TensorSliceReaderTest, UnsupportedTensorType) {
       EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
     }
     
    +TEST(TensorSliceReaderTest, NegativeTensorShapeDimension) {
    +  const string fname =
    +      io::JoinPath(testing::TmpDir(), "negative_dim_checkpoint");
    +  TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
    +  const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    +  TF_CHECK_OK(writer.Add("test", TensorShape({4, 5}),
    +                         TensorSlice::ParseOrDie("0,2:-"), data));
    +  TF_CHECK_OK(writer.Finish());
    +
    +  MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
    +    if (sts.has_meta()) {
    +      for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
    +        for (auto& dim : *tensor.mutable_shape()->mutable_dim()) {
    +          dim.set_size(-dim.size());
    +        }
    +      }
    +    }
    +    return sts.SerializeAsString();
    +  });
    +
    +  TensorSliceReader reader(fname, OpenTableTensorSliceReader);
    +  // The negative dimension should cause loading to fail.
    +  EXPECT_FALSE(reader.status().ok());
    +}
    +
     void CachedTensorSliceReaderTesterHelper(
         const TensorSliceWriter::CreateBuilderFunction& create_function,
         const TensorSliceReader::OpenTableFunction& open_function) {
    
abcced051cb1

Prevent crashes when loading tensor slices with unsupported types.

https://github.com/tensorflow/tensorflow · A. Unique TensorFlower · Aug 24, 2021 · via GHSA
5 files changed · +134 -11
  • tensorflow/core/framework/BUILD (+1 -0, modified)
    @@ -835,6 +835,7 @@ tf_cuda_library(
             "//tensorflow/core/lib/strings:str_util",
             "//tensorflow/core/lib/strings:strcat",
             "//tensorflow/core/platform:abi",
    +        "//tensorflow/core/platform:errors",
             "//tensorflow/core/platform:logging",
             "//tensorflow/core/platform:macros",
             "//tensorflow/core/platform:platform_port",
    
  • tensorflow/core/framework/tensor.cc (+18 -8, modified)
    @@ -52,6 +52,7 @@ limitations under the License.
     #include "tensorflow/core/lib/gtl/inlined_vector.h"
     #include "tensorflow/core/lib/strings/str_util.h"
     #include "tensorflow/core/lib/strings/strcat.h"
    +#include "tensorflow/core/platform/errors.h"
     #include "tensorflow/core/platform/logging.h"
     #include "tensorflow/core/platform/macros.h"
     #include "tensorflow/core/platform/protobuf.h"
    @@ -723,11 +724,11 @@ bool Tensor::RefCountIsOne() const {
     // The macro CASES() expands to a switch statement conditioned on
     // TYPE_ENUM. Each case expands the STMTS after a typedef for T.
     #define SINGLE_ARG(...) __VA_ARGS__
    -#define CASE(TYPE, STMTS)             \
    -  case DataTypeToEnum<TYPE>::value: { \
    -    typedef TYPE T;                   \
    -    STMTS;                            \
    -    break;                            \
    +#define CASE(TYPE, STMTS)               \
    +  case DataTypeToEnum<TYPE>::value: {   \
    +    typedef TF_ATTRIBUTE_UNUSED TYPE T; \
    +    STMTS;                              \
    +    break;                              \
       }
     #define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
       switch (TYPE_ENUM) {                                         \
    @@ -763,9 +764,8 @@ bool Tensor::RefCountIsOne() const {
       }
     
     #define CASES(TYPE_ENUM, STMTS)                                      \
    -  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS,                               \
    -                     LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
    -                     , LOG(FATAL) << "Type not set";)
    +  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
    +                     , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
     
     Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
         : shape_(shape), buf_(nullptr) {
    @@ -795,6 +795,16 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
       }
     }
     
    +Status Tensor::BuildTensor(DataType type, const TensorShape& shape,
    +                           Tensor* out_tensor) {
    +  // Avoid crashes due to invalid or unsupported types.
    +  CASES_WITH_DEFAULT(
    +      type, {}, return errors::InvalidArgument("Type not set"),
    +      return errors::InvalidArgument("Unexpected type: ", DataType_Name(type)));
    +  *out_tensor = Tensor(type, shape);
    +  return Status::OK();
    +}
    +
     // NOTE(mrry): The default allocator for a Tensor (when none is specified) is
     // the default CPU allocator for NUMA zone 0. Accessing that currently involves
     // acquiring a lock, which guards initialization of the per-NUMA zone
    
  • tensorflow/core/framework/tensor.h (+9 -0, modified)
    @@ -170,6 +170,15 @@ class Tensor {
       /// for details.
       explicit Tensor(DataType type);
     
    +  /// \brief Initializes a tensor with the input `type` and `shape`, or returns
    +  /// an error and leaves `out_tensor` unmodified. This factory method should be
    +  /// used instead of the corresponding constructor if calling code cannot
    +  /// validate that the `DataType` is valid and supported.
    +  ///
    +  /// The underlying buffer is allocated using a `CPUAllocator`.
    +  static Status BuildTensor(DataType type, const TensorShape& shape,
    +                            Tensor* out_tensor);
    +
      private:
       // A tag type for selecting the `Tensor` constructor overload that creates a
       // scalar tensor in host memory.
    
  • tensorflow/core/util/tensor_slice_reader.cc (+3 -1, modified)
    @@ -248,7 +248,9 @@ Status TensorSliceReader::GetTensor(
         slice = tss->Slices().begin()->second.slice;
       }
     
    -  std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor(type, shape));
    +  std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor);
    +  Status s = tensorflow::Tensor::BuildTensor(type, shape, t.get());
    +  if (!s.ok()) return s;
       bool success = false;
     
     #define READER_COPY(dt)                                                  \
    
  • tensorflow/core/util/tensor_slice_reader_test.cc (+103 -2, modified)
    @@ -13,15 +13,19 @@ See the License for the specific language governing permissions and
     limitations under the License.
     ==============================================================================*/
     
    -#include <utility>
    -
     #include "tensorflow/core/util/tensor_slice_reader.h"
     
    +#include <utility>
    +#include <vector>
    +
     #include "tensorflow/core/framework/types.h"
     #include "tensorflow/core/framework/versions.pb.h"
     #include "tensorflow/core/lib/core/status_test_util.h"
     #include "tensorflow/core/lib/core/stringpiece.h"
    +#include "tensorflow/core/lib/io/iterator.h"
     #include "tensorflow/core/lib/io/path.h"
    +#include "tensorflow/core/lib/io/table.h"
    +#include "tensorflow/core/lib/io/table_builder.h"
     #include "tensorflow/core/lib/strings/str_util.h"
     #include "tensorflow/core/lib/strings/strcat.h"
     #include "tensorflow/core/platform/env.h"
    @@ -30,6 +34,7 @@ limitations under the License.
     #include "tensorflow/core/platform/test.h"
     #include "tensorflow/core/platform/types.h"
     #include "tensorflow/core/public/version.h"
    +#include "tensorflow/core/util/saved_tensor_slice.pb.h"
     #include "tensorflow/core/util/saved_tensor_slice_util.h"
     #include "tensorflow/core/util/tensor_slice_reader_cache.h"
     #include "tensorflow/core/util/tensor_slice_writer.h"
    @@ -309,6 +314,102 @@ TEST_SIMPLE_INT(int16, int32)
     TEST_SIMPLE_INT(int8, int32)
     TEST_SIMPLE_INT(uint8, int32)
     
    +// Modifies the SavedTensorSlices messages in a checkpoint to allow creating
    +// malformed or unsupported checkpoints.
    +void MutateSavedTensorSlices(
    +    const std::string& fname,
    +    const std::function<std::string(SavedTensorSlices)>& mutator) {
    +  table::Options options;
    +  options.compression = table::kNoCompression;
    +
    +  // Read all entres from the table.
    +  std::vector<std::pair<std::string, std::string>> entries;
    +  {
    +    std::unique_ptr<RandomAccessFile> file;
    +    TF_CHECK_OK(Env::Default()->NewRandomAccessFile(fname, &file));
    +    uint64 file_size;
    +    TF_CHECK_OK(Env::Default()->GetFileSize(fname, &file_size));
    +    table::Table* t;
    +    TF_CHECK_OK(table::Table::Open(options, file.get(), file_size, &t));
    +    std::unique_ptr<table::Table> table(t);
    +    std::unique_ptr<table::Iterator> it(table->NewIterator());
    +    for (it->Seek(""); it->Valid(); it->Next()) {
    +      entries.emplace_back(it->key(), it->value());
    +    }
    +    TF_CHECK_OK(it->status());
    +  }
    +
    +  // Rewrite the table, mutating each value.
    +  {
    +    std::unique_ptr<WritableFile> file;
    +    TF_CHECK_OK(Env::Default()->NewWritableFile(fname, &file));
    +    table::TableBuilder builder(options, file.get());
    +    for (const auto& entry : entries) {
    +      SavedTensorSlices sts;
    +      CHECK(sts.ParseFromString(entry.second));
    +      builder.Add(entry.first, mutator(std::move(sts)));
    +    }
    +    TF_CHECK_OK(builder.Finish());
    +    TF_CHECK_OK(file->Close());
    +  }
    +}
    +
    +TEST(TensorSliceReaderTest, MissingTensorType) {
    +  const string fname = io::JoinPath(testing::TmpDir(), "invalid_checkpoint");
    +  TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
    +  const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    +  TensorShape shape({4, 5});
    +  TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
    +  TF_CHECK_OK(writer.Add("test", shape, slice, data));
    +  TF_CHECK_OK(writer.Finish());
    +
    +  MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
    +    if (sts.has_meta()) {
    +      for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
    +        tensor.clear_type();
    +      }
    +    }
    +    return sts.SerializeAsString();
    +  });
    +
    +  TensorSliceReader reader(fname, OpenTableTensorSliceReader);
    +  TF_CHECK_OK(reader.status());
    +
    +  // The tensor should be present, but loading it should fail due to the
    +  // unset (invalid) type.
    +  EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
    +  std::unique_ptr<Tensor> tensor;
    +  EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
    +}
    +
    +TEST(TensorSliceReaderTest, UnsupportedTensorType) {
    +  const string fname = io::JoinPath(testing::TmpDir(), "int32_ref_checkpoint");
    +  TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
    +  const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    +  TensorShape shape({4, 5});
    +  TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
    +  TF_CHECK_OK(writer.Add("test", shape, slice, data));
    +  TF_CHECK_OK(writer.Finish());
    +
    +  MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
    +    if (sts.has_meta()) {
    +      for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
    +        tensor.set_type(DT_INT32_REF);
    +      }
    +    }
    +    return sts.SerializeAsString();
    +  });
    +
    +  TensorSliceReader reader(fname, OpenTableTensorSliceReader);
    +  TF_CHECK_OK(reader.status());
    +
    +  // The tensor should be present, but loading it should fail due to the
    +  // unsupported type.
    +  EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
    +  std::unique_ptr<Tensor> tensor;
    +  EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
    +}
    +
     void CachedTensorSliceReaderTesterHelper(
         const TensorSliceWriter::CreateBuilderFunction& create_function,
         const TensorSliceReader::OpenTableFunction& open_function) {
    

Vulnerability mechanics

Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References: 10

News mentions: 0

No linked articles in our index yet.