High severity · NVD Advisory · Published Mar 24, 2023 · Updated Feb 19, 2025
TensorFlow vulnerable to Out-of-Bounds Read in GRUBlockCellGrad
CVE-2023-25658
Description
TensorFlow is an open source platform for machine learning. Prior to versions 2.12.0 and 2.11.1, there is an out-of-bounds read in `GRUBlockCellGrad`. A fix is included in TensorFlow 2.12.0 and 2.11.1.
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| tensorflow (PyPI) | < 2.11.1 | 2.11.1 |
| tensorflow-cpu (PyPI) | < 2.11.1 | 2.11.1 |
| tensorflow-gpu (PyPI) | < 2.11.1 | 2.11.1 |
Affected products
1. Range: < 2.11.1
Patches
Commit 1ff459137c271 — the merged commit includes the following changes:
81 files changed · +472 −484
.bazelrc+0 −1 modified@@ -242,7 +242,6 @@ build:mkl_aarch64 -c opt # Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL). # with Eigen threadpool support build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true -build:mkl_aarch64_threadpool --define=build_with_acl=true build:mkl_aarch64_threadpool -c opt # This config refers to building CUDA op kernels with nvcc.
tensorflow/compiler/jit/BUILD+1 −0 modified@@ -411,6 +411,7 @@ cc_library( ":internal", # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp. "//learning/brain/tfrt/tf_tpu:__pkg__", + "//learning/brain/tfrt/tpu_plugin:__pkg__", "//learning/brain/tfrt/tpu_common:__pkg__", "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ],
tensorflow/compiler/xla/backends/interpreter/BUILD+1 −1 modified@@ -138,7 +138,7 @@ cc_library( ":executor", ":platform_id", "//tensorflow/compiler/xla/stream_executor", - "//tensorflow/compiler/xla/stream_executor/lib", + "//tensorflow/compiler/xla/stream_executor/platform", "//tensorflow/tsl/platform:status", "@com_google_absl//absl/strings:str_format", ],
tensorflow/compiler/xla/backends/interpreter/platform.cc+1 −1 modified@@ -21,9 +21,9 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/backends/interpreter/executor.h" #include "tensorflow/compiler/xla/stream_executor/device_options.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" #include "tensorflow/compiler/xla/stream_executor/multi_platform_manager.h" #include "tensorflow/compiler/xla/stream_executor/platform.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/tsl/platform/status.h" namespace stream_executor {
tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.cc+1 −1 modified@@ -222,7 +222,7 @@ Status MakeEvalErrorDueToParamOrInfeed(const HloInstruction& eval_instruction) { absl::little_endian::Store32( const_cast<char*>(error_payload.data()), static_cast<uint32_t>(EvalErrorDetail::kDynamicValueDependence)); - error.SetPayload(kEvalErrorDetailUrl, error_payload); + error.SetPayload(kEvalErrorDetailUrl, absl::Cord(error_payload)); return error; }
tensorflow/compiler/xla/mlir/tools/mlir_replay/BUILD+1 −0 modified@@ -9,6 +9,7 @@ xla_cc_binary( "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla/mlir/runtime/ir:rt", "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc_impl", "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:execution_trace_proto_cc", "//tensorflow/compiler/xla/mlir_hlo:gml_st", "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration",
tensorflow/compiler/xla/service/hlo_graph_dumper.cc+4 −0 modified@@ -1251,6 +1251,10 @@ ExtractGemmBackendConfigProps(const gpu::GemmBackendConfig& config, if (config.algorithm_case() == gpu::GemmBackendConfig::kSelectedAlgorithm) { props.emplace_back("algorithm", StrCat(config.selected_algorithm())); } + if (config.epilogue() != gpu::GemmBackendConfig::DEFAULT) { + props.emplace_back( + "epilogue", gpu::GemmBackendConfig::Epilogue_Name(config.epilogue())); + } return props; }
tensorflow/compiler/xla/stream_executor/BUILD+3 −1 modified@@ -239,6 +239,7 @@ cc_library( ":stream_executor_headers", "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/compiler/xla/stream_executor/platform", + "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/strings", @@ -450,6 +451,7 @@ tsl_gpu_library( "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/compiler/xla/stream_executor/platform", "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:stacktrace", "//tensorflow/tsl/platform:status", @@ -661,7 +663,7 @@ cc_library( ":platform", ":plugin", ":stream_executor_headers", - "//tensorflow/compiler/xla/stream_executor/lib", + "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers",
tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc+1 −1 modified@@ -63,7 +63,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_types.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc+58 −59 modified@@ -36,8 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_stream.h" #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_timer.h" #include "tensorflow/compiler/xla/stream_executor/dnn.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h" #include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h" @@ -85,7 +84,7 @@ static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher"); std::ostringstream oss; \ oss << CudnnStatusToString(_status) << "\nin " << __FILE__ << "(" \ << __LINE__ << "): '" << #expr << "'"; \ - return tsl::Status(port::error::UNKNOWN, oss.str()); \ + return tsl::Status(tsl::error::UNKNOWN, oss.str()); \ } \ } while (false) @@ -96,7 +95,7 @@ static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher"); std::ostringstream oss; \ oss << CudnnStatusToString(_status) << "\nin " << __FILE__ << "(" \ << __LINE__ << "): '" << #expr << "' " << (expr).get_error(); \ - return tsl::Status(port::error::UNKNOWN, oss.str()); \ + return tsl::Status(tsl::error::UNKNOWN, oss.str()); \ } \ } while (false) @@ -417,7 +416,7 @@ tsl::Status CudnnSupport::Init() { "configuration."); LOG(ERROR) << error; cudnnDestroy(cudnn_handle); - return tsl::Status(port::error::INTERNAL, error); + return tsl::Status(tsl::error::INTERNAL, error); } cudnn_.reset(new CudnnAccess(cudnn_handle)); @@ -441,7 +440,7 @@ tsl::Status CudnnSupport::Init() { } } - return tsl::Status(port::error::INTERNAL, + return tsl::Status(tsl::error::INTERNAL, absl::StrCat("cudnn library could not create a handle: ", CudnnStatusToString(status))); 
} @@ -1299,7 +1298,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { ? algorithm_config.algorithm()->tensor_ops_enabled() : allow_tensor_ops; if (use_tensor_ops && !allow_tensor_ops) { - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "Algo requests disallowed tensor op evaluation."); } @@ -1658,7 +1657,7 @@ class CudnnRnnSequenceTensorDescriptor GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, cudnnDataType_t data_type) { if (max_seq_length <= 0) { - return tsl::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0"); + return tsl::Status(tsl::error::INVALID_ARGUMENT, "max_seq_length <= 0"); } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; @@ -1677,7 +1676,7 @@ class CudnnRnnSequenceTensorDescriptor const absl::Span<const int>& seq_lengths, bool time_major, cudnnDataType_t data_type) { if (max_seq_length <= 0) { - return tsl::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0"); + return tsl::Status(tsl::error::INVALID_ARGUMENT, "max_seq_length <= 0"); } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; @@ -1804,30 +1803,30 @@ tsl::StatusOr<RnnModelDims> ExtractAndCheckRnnForward( model_dims.num_layers * model_dims.dir_count && input_h_desc.batch_size() == model_dims.batch_size && input_h_desc.data_size() == model_dims.hidden_size)) { - return tsl::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape"); + return tsl::Status(tsl::error::INVALID_ARGUMENT, "Invalid input_h shape"); } // The LSTM projection will be used if input_h_desc.data_size() < // input_c_desc.data_size() if (!(input_h_desc.num_layers() == input_c_desc.num_layers() && input_h_desc.batch_size() == input_c_desc.batch_size() && input_h_desc.data_size() <= input_c_desc.data_size())) { - return tsl::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape"); + return tsl::Status(tsl::error::INVALID_ARGUMENT, 
"Invalid input_c shape"); } if (!(output_desc.max_seq_length() == model_dims.max_seq_length && output_desc.batch_size() == model_dims.batch_size && output_desc.data_size() == model_dims.hidden_size * model_dims.dir_count)) { - return tsl::Status(port::error::INVALID_ARGUMENT, "Invalid output shape"); + return tsl::Status(tsl::error::INVALID_ARGUMENT, "Invalid output shape"); } if (!(input_h_desc.num_layers() == output_h_desc.num_layers() && input_h_desc.batch_size() == output_h_desc.batch_size() && input_h_desc.data_size() == output_h_desc.data_size())) { - return tsl::Status(port::error::INVALID_ARGUMENT, "Invalid output_h shape"); + return tsl::Status(tsl::error::INVALID_ARGUMENT, "Invalid output_h shape"); } if (!(input_h_desc.num_layers() == output_c_desc.num_layers() && input_h_desc.batch_size() == output_c_desc.batch_size() && input_h_desc.data_size() <= output_c_desc.data_size())) { - return tsl::Status(port::error::INVALID_ARGUMENT, "Invalid output_c shape"); + return tsl::Status(tsl::error::INVALID_ARGUMENT, "Invalid output_c shape"); } return model_dims; @@ -1849,7 +1848,7 @@ tsl::Status CheckRNNParameterSize( #endif if (static_cast<int64_t>(params_size_in_bytes) != rnn_desc.ParamsSizeInBytes()) { - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "Mismatching RNN parameter size"); } return ::tsl::OkStatus(); @@ -1997,7 +1996,7 @@ tsl::Status CudnnSupport::DoRnnForwardImpl( // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. 
if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to start timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to start timer"); } } @@ -2020,7 +2019,7 @@ tsl::Status CudnnSupport::DoRnnForwardImpl( if (is_profiling) { if (!timer->Stop(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to stop timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to stop timer"); } auto algo_desc = *rnn_desc.algorithm_config().algorithm(); output_profile_result->set_algorithm(algo_desc); @@ -2058,7 +2057,7 @@ tsl::Status CudnnSupport::DoRnnForwardImpl( // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to start timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to start timer"); } } @@ -2130,7 +2129,7 @@ tsl::Status CudnnSupport::DoRnnForwardImpl( if (is_profiling) { if (!timer->Stop(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to stop timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to stop timer"); } auto algo_desc = *rnn_desc.algorithm_config().algorithm(); output_profile_result->set_algorithm(algo_desc); @@ -2204,7 +2203,7 @@ tsl::Status CudnnSupport::DoRnnBackwardImpl( // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. 
if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to start timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to start timer"); } } @@ -2253,7 +2252,7 @@ tsl::Status CudnnSupport::DoRnnBackwardImpl( if (is_profiling) { if (!timer->Stop(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to stop timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to stop timer"); } auto algo_desc = *rnn_desc.algorithm_config().algorithm(); output_profile_result->set_algorithm(algo_desc); @@ -2275,7 +2274,7 @@ tsl::Status CudnnSupport::DoRnnBackwardImpl( // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to start timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to start timer"); } } @@ -2362,7 +2361,7 @@ tsl::Status CudnnSupport::DoRnnBackwardImpl( if (is_profiling) { if (!timer->Stop(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to stop timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to stop timer"); } auto algo_desc = *rnn_desc.algorithm_config().algorithm(); output_profile_result->set_algorithm(algo_desc); @@ -2404,7 +2403,7 @@ tsl::Status CudnnSupport::DoCtcLossImpl( /*workspace=*/scratch_memory.opaque(), /*workSpaceSizeInBytes=*/scratch_memory.size())); #else - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "No supported cudnnCTCLoss when " "CUDNN_VERSION < 7.6.3"); #endif @@ -2786,7 +2785,7 @@ tsl::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo( return perf_results[r].algo; } } - return tsl::Status(port::error::INTERNAL, + return tsl::Status(tsl::error::INTERNAL, "cudnnGetConvolutionForwardAlgorithm_v7 returned " "no suitable algorithms. 
This could be a cudnn bug."); #else @@ -2828,7 +2827,7 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn, return perf_results[r].algo; } } - return tsl::Status(port::error::INTERNAL, + return tsl::Status(tsl::error::INTERNAL, "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned " "no suitable algorithms. This could be a cudnn bug."); #else @@ -2870,7 +2869,7 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, return perf_results[r].algo; } } - return tsl::Status(port::error::INTERNAL, + return tsl::Status(tsl::error::INTERNAL, "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned " "no suitable algorithms. This could be a cudnn bug."); #else @@ -2895,7 +2894,7 @@ tsl::StatusOr<DeviceMemory<uint8_t>> AllocateCudnnConvolutionForwardWorkspace( ScratchAllocator* scratch_allocator) { if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, "Mismatch between cudnn conv and algorithm descriptors."); } @@ -2917,7 +2916,7 @@ tsl::StatusOr<DeviceMemory<uint8_t>> AllocateCudnnConvolutionForwardWorkspace( if (ABSL_PREDICT_FALSE(size_in_bytes_int64_t < 0)) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, "cudnnGetConvolutionForwardWorkspaceSize() returned " "negative sizeInBytes value. 
This could be a cudnn bug."); } @@ -2927,7 +2926,7 @@ tsl::StatusOr<DeviceMemory<uint8_t>> AllocateCudnnConvolutionForwardWorkspace( } if (ABSL_PREDICT_FALSE(!scratch_allocator)) { - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "No scratch allocator provided"); } @@ -2944,7 +2943,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace( ScratchAllocator* scratch_allocator) { if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, "Mismatch between cudnn conv and algorithm descriptors."); } @@ -2967,7 +2966,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace( if (ABSL_PREDICT_FALSE(size_in_bytes_int64_t < 0)) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, "cudnnGetConvolutionBackwardDataWorkspaceSize() returned " "negative sizeInBytes value. This could be a cudnn bug."); } @@ -2977,7 +2976,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace( } if (ABSL_PREDICT_FALSE(!scratch_allocator)) { - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "No scratch allocator provided"); } @@ -2994,7 +2993,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( ScratchAllocator* scratch_allocator) { if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, "Mismatch between cudnn conv and algorithm descriptors."); } @@ -3017,7 +3016,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( if (ABSL_PREDICT_FALSE(size_in_bytes_int64_t < 0)) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned " "negative sizeInBytes value. 
This could be a cudnn bug."); } @@ -3027,7 +3026,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( } if (ABSL_PREDICT_FALSE(!scratch_allocator)) { - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "No scratch allocator provided"); } @@ -3040,7 +3039,7 @@ tsl::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type, if (desc.has_value()) { use_tensor_ops = desc->tensor_ops_enabled(); if (use_tensor_ops && !IsTensorMathEnabled(stream, type)) { - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "Algo requests disabled tensor op evaluation."); } } else { @@ -3162,7 +3161,7 @@ tsl::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm( // no_scratch algorithm. if (!algo_desc.has_value()) { return tsl::Status( - port::error::INVALID_ARGUMENT, + tsl::error::INVALID_ARGUMENT, "The primary convolution algorithm failed memory allocation, " "while a secondary algorithm is not provided."); } @@ -3224,7 +3223,7 @@ tsl::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm( // no_scratch algorithm. if (!algo_desc.has_value()) { return tsl::Status( - port::error::INVALID_ARGUMENT, + tsl::error::INVALID_ARGUMENT, absl::StrCat( "The primary convolution algorithm failed memory allocation, " "while a secondary algorithm is not provided. Actual error: ", @@ -4254,7 +4253,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. 
if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to start timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to start timer"); } } @@ -4264,7 +4263,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { ToCudnnDataType(input_type_) == CUDNN_DATA_INT8 && ToCudnnDataType(output_type_) == CUDNN_DATA_FLOAT) { return tsl::Status( - port::error::FAILED_PRECONDITION, + tsl::error::FAILED_PRECONDITION, "This configuration potentially produces incorrect results."); } #else @@ -4336,7 +4335,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { if (is_profiling) { if (!timer->Stop(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to stop timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to stop timer"); } profile_result->set_algorithm(algo); profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds()); @@ -4631,7 +4630,7 @@ class CudnnExecutionPlanRunner<void(Args...)> // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. 
if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to start timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to start timer"); } } @@ -4641,7 +4640,7 @@ class CudnnExecutionPlanRunner<void(Args...)> if (is_profiling) { if (!timer->Stop(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to stop timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to stop timer"); } TF_ASSIGN_OR_RETURN(auto desc, ToAlgorithmDesc()); profile_result->set_algorithm(desc); @@ -4868,7 +4867,7 @@ tsl::Status CudnnSupport::GetConvolveRunners( } if (!got_algos) { return tsl::Status( - port::error::UNKNOWN, + tsl::error::UNKNOWN, absl::StrFormat("Listing algorithms failed for kind %d", kind)); } @@ -5037,7 +5036,7 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. 
if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to start timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to start timer"); } } auto side_input_data_ptr = (side_input_scale_ == 0) @@ -5065,7 +5064,7 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { << "\noutput_data.opaque() = " << output_data.opaque(); if (IsTensorMathOpSet(conv_) != tensor_ops_enabled_) { - return tsl::Status(port::error::FAILED_PRECONDITION, + return tsl::Status(tsl::error::FAILED_PRECONDITION, "Tensor op math type in dnn::AlgorithmDesc does not " "match that of the CudnnConvolutionDescriptor"); } @@ -5095,7 +5094,7 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { if (profile_result) { if (!timer->Stop(AsGpuStream(stream))) { - return tsl::Status(port::error::INTERNAL, "Failed to stop timer"); + return tsl::Status(tsl::error::INTERNAL, "Failed to stop timer"); } profile_result->set_algorithm(algo); profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds()); @@ -5308,7 +5307,7 @@ tsl::Status CudnnSupport::GetFusedConvolveRunners( activation_mode != dnn::ActivationMode::kElu && activation_mode != dnn::ActivationMode::kLeakyRelu && activation_mode != dnn::ActivationMode::kNone) { - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "CuDNN fusion only supports activations of " "{Relu, Relu6, Elu, <None>}."); } @@ -5319,7 +5318,7 @@ tsl::Status CudnnSupport::GetFusedConvolveRunners( auto cuda_compute_capability = stream->GetCudaComputeCapability(); if (!GetConvolveAlgorithms(cuda_compute_capability, input_type, &algorithms)) { - return tsl::Status(port::error::UNKNOWN, + return tsl::Status(tsl::error::UNKNOWN, "Listing fused convolve algorithms failed."); } @@ -5354,7 +5353,7 @@ tsl::Status CudnnSupport::GetFusedConvolveRunners( leakyrelu_alpha, input_descriptor, filter_descriptor, bias_descriptor, output_descriptor, 
convolution_descriptor, activation_mode, cudnn); if (!op_graph_status.status().ok()) { - return tsl::Status(port::error::INTERNAL, + return tsl::Status(tsl::error::INTERNAL, absl::StrCat("Cudnn graph failed to build: ", op_graph_status.status().ToString())); } @@ -5391,7 +5390,7 @@ tsl::Status CudnnSupport::GetFusedMatmulRunners( input_type, bias_type, output_type, trans_a, trans_b, m, n, k, lda, ldb, ldc, activation_mode, cudnn); if (!op_graph_status.status().ok()) { - return tsl::Status(port::error::INTERNAL, + return tsl::Status(tsl::error::INTERNAL, absl::StrCat("Cudnn graph failed to build: ", op_graph_status.status().ToString())); } @@ -5685,7 +5684,7 @@ tsl::Status CudnnSupport::DoBatchNormalizationForwardImpl( if (activation_mode != dnn::ActivationMode::kNone || !side_input.is_null()) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrCat( "Side input and activation are not supported by cuDNN version: ", CUDNN_VERSION)); @@ -5968,7 +5967,7 @@ tsl::Status CudnnSupport::DoFusedConvolve( if (activation_mode != dnn::ActivationMode::kRelu && activation_mode != dnn::ActivationMode::kNone) { - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "cudnnConvolutionBiasActivationForward() only supports " "Relu or None activation."); } @@ -6070,7 +6069,7 @@ tsl::Status CudnnSupport::DoPrepareForCtcLoss( } *ctc_loss_algo_id = algo; #else - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "No supported cudnnGetCTCLossWorkspaceSize when " "CUDNN_VERSION < 7.6.3"); #endif @@ -6100,7 +6099,7 @@ tsl::Status CudnnSupport::DoCtcLoss( int ctc_loss_algo_id) { // Current cuDNN CTC Loss only supports the float datatype if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) { - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "CudnnCtcLossDescriptor is supported only when the " 
"CUDNN_VERSION >= 7.6.3 and DataType is float"); } @@ -6382,7 +6381,7 @@ tsl::StatusOr<std::vector<PoolingSplitsSpec>> GetTensorSplits( if (max_batches_per_split == 0) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrCat( "Tensor has too many elements for int32 indexing: batches=", num_batches, " elements_per_batch=", elements_per_batch_input, @@ -6442,7 +6441,7 @@ tsl::Status CudnnSupport::DoPoolForward( auto splits_or = GetTensorSplits(input_dimensions, output_dimensions, element_type); if (!splits_or.ok()) { - return tsl::Status(port::error::INTERNAL, "Cudnn pooling failed to split"); + return tsl::Status(tsl::error::INTERNAL, "Cudnn pooling failed to split"); } auto splits = std::move(splits_or.value()); @@ -6511,7 +6510,7 @@ tsl::Status CudnnSupport::DoPoolBackward( auto splits_or = GetTensorSplits(input_dimensions, output_dimensions, element_type); if (!splits_or.ok()) { - return tsl::Status(port::error::INTERNAL, "Cudnn pooling failed to split"); + return tsl::Status(tsl::error::INTERNAL, "Cudnn pooling failed to split"); } auto splits = std::move(splits_or.value());
tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc+20 −20 modified@@ -35,10 +35,10 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/stacktrace.h" #include "tensorflow/tsl/platform/static_threadlocal.h" #include "tensorflow/tsl/platform/threadpool.h" @@ -267,7 +267,7 @@ static tsl::Status InternalInit() { } Diagnostician::LogDiagnosticInformation(); - return tsl::Status(port::error::ABORTED, + return tsl::Status(tsl::error::ABORTED, absl::StrCat("failed call to cuInit: ", ToString(res))); } @@ -400,7 +400,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, } } - return tsl::Status(port::error::INTERNAL, message); + return tsl::Status(tsl::error::INTERNAL, message); } /* static */ void GpuDriver::DestroyContext(GpuContext* context) { @@ -673,7 +673,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, } return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrCat("failed to get device for context: ", ToString(result))); } @@ -972,7 +972,7 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { /* static */ tsl::Status GpuDriver::DestroyEvent(GpuContext* context, CUevent* event) { if (*event == nullptr) { - return tsl::Status(port::error::INVALID_ARGUMENT, + return tsl::Status(tsl::error::INVALID_ARGUMENT, "input event cannot be null"); } @@ -997,7 +997,7 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { CUresult res = cuEventQuery(event); if (res != CUDA_SUCCESS && res != 
CUDA_ERROR_NOT_READY) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat("failed to query event: %s", ToString(res))); } @@ -1263,11 +1263,11 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { if (res == CUDA_SUCCESS) { return ::tsl::OkStatus(); } else if (res == CUDA_ERROR_OUT_OF_MEMORY) { - return tsl::Status(port::error::RESOURCE_EXHAUSTED, + return tsl::Status(tsl::error::RESOURCE_EXHAUSTED, "could not create CUDA event: out of device memory"); } else { return tsl::Status( - port::error::FAILED_PRECONDITION, + tsl::error::FAILED_PRECONDITION, absl::StrCat("could not create CUDA event: ", ToString(res))); } } @@ -1299,14 +1299,14 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { // error then the original one. if (context == nullptr) { return tsl::Status( - port::error::UNAVAILABLE, + tsl::error::UNAVAILABLE, "Empty context returned while querying context for device pointer"); } return context; } return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrCat("failed to query context for device pointer: ", ToString(result))); } @@ -1324,13 +1324,13 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return MemorySpace::kHost; default: return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrCat("unknown memory space provided by CUDA API: ", value)); } } return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrCat("failed to query device pointer for memory space: ", ToString(result))); } @@ -1346,13 +1346,13 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { // "there was an internal error while performing this operation" (return // below). 
return tsl::Status( - port::error::NOT_FOUND, + tsl::error::NOT_FOUND, absl::StrFormat("not a device pointer %p; %s", reinterpret_cast<void*>(dptr), ToString(result))); } return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat("failed to get pointer into for device pointer %p; %s", reinterpret_cast<void*>(dptr), ToString(result))); } @@ -1377,7 +1377,7 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); if (res != CUDA_SUCCESS) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat( "failed to get compute capability major for device: %s; %d", ToString(res), device)); @@ -1387,7 +1387,7 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); if (res != CUDA_SUCCESS) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat( "failed to get compute capability minor for device: %s; %d", ToString(res), device)); @@ -1399,13 +1399,13 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { /* static */ tsl::Status GpuDriver::GetGpuISAVersion(int* version, CUdevice device) { return tsl::Status{ - port::error::INTERNAL, + tsl::error::INTERNAL, "Feature not supported on CUDA platform (GetGpuISAVersion)"}; } /* static */ tsl::Status GpuDriver::GetGpuGCNArchName(CUdevice, std::string*) { return tsl::Status{ - port::error::INTERNAL, + tsl::error::INTERNAL, "Feature not supported on CUDA platform (GetGpuGCNArchName)"}; } @@ -1519,7 +1519,7 @@ static tsl::StatusOr<T> GetSimpleAttribute(CUdevice device, CUresult res = cuDeviceGetAttribute(&val, attribute, device); if (res != CUDA_SUCCESS) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat("failed to get device attribute %d for device %d: %s", attribute, device, ToString(res))); } @@ -1628,7 +1628,7 @@ static tsl::StatusOr<T> 
GetSimpleAttribute(CUdevice device, if (result != CUDA_SUCCESS && result != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat("failed to enable peer access from %p to %p: %s", from, to, ToString(result))); }
tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc+1 −1 modified@@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_platform_id.h" #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_stream.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc+3 −3 modified@@ -42,9 +42,8 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_stream.h" #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_timer.h" #include "tensorflow/compiler/xla/stream_executor/kernel_cache_config.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h" @@ -53,6 +52,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/compiler/xla/stream_executor/timer.h" #include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/numbers.h" #include "tensorflow/tsl/platform/statusor.h" @@ -745,7 +745,7 @@ tsl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) { return ::tsl::OkStatus(); } else { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat("error recording waiting for CUDA event on stream %p", stream)); }
tensorflow/compiler/xla/stream_executor/cuda/cuda_platform.cc+4 −4 modified@@ -23,8 +23,8 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.h" #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.h" #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_platform_id.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" +#include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/status.h" namespace stream_executor { @@ -117,7 +117,7 @@ tsl::StatusOr<StreamExecutor*> CudaPlatform::FirstExecutorForBus( } return tsl::Status( - port::error::NOT_FOUND, + tsl::error::NOT_FOUND, absl::StrFormat("Executor for bus %d not found.", bus_ordinal)); } @@ -177,7 +177,7 @@ CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { auto init_status = executor->Init(config.device_options); if (!init_status.ok()) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat( "failed initializing StreamExecutor for CUDA device ordinal %d: %s", config.ordinal, init_status.ToString()));
tensorflow/compiler/xla/stream_executor/cuda/cuda_rng.cc+1 −1 modified@@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_platform_id.h" #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_stream.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/rng.h" #include "tensorflow/tsl/platform/status.h"
tensorflow/compiler/xla/stream_executor/host/BUILD+3 −4 modified@@ -46,7 +46,7 @@ cc_library( "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/compiler/xla/stream_executor/platform", - "@com_google_absl//absl/base:core_headers", + "//tensorflow/tsl/platform:errors", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", ], @@ -104,9 +104,8 @@ cc_library( "//tensorflow/compiler/xla/stream_executor:kernel", "//tensorflow/compiler/xla/stream_executor:rng", "//tensorflow/compiler/xla/stream_executor:stream_executor_internal", - "//tensorflow/compiler/xla/stream_executor:stream_executor_pimpl", - "//tensorflow/compiler/xla/stream_executor:timer", - "//tensorflow/compiler/xla/stream_executor/lib", + "//tensorflow/compiler/xla/stream_executor:stream_executor_pimpl", # fixdeps: keep + "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:platform_port", "//tensorflow/tsl/platform/profile_utils:profile_utils_cpu_utils", "@com_google_absl//absl/functional:any_invocable",
tensorflow/compiler/xla/stream_executor/host/host_gpu_executor.h+1 −1 modified@@ -25,10 +25,10 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/blas.h" #include "tensorflow/compiler/xla/stream_executor/host/host_stream.h" #include "tensorflow/compiler/xla/stream_executor/host/host_timer.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" #include "tensorflow/compiler/xla/stream_executor/rng.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h" +#include "tensorflow/tsl/platform/errors.h" namespace stream_executor { namespace host {
tensorflow/compiler/xla/stream_executor/host/host_platform.cc+3 −3 modified@@ -21,8 +21,8 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/stream_executor/host/host_gpu_executor.h" #include "tensorflow/compiler/xla/stream_executor/host/host_platform_id.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" +#include "tensorflow/tsl/platform/errors.h" namespace stream_executor { namespace host { @@ -75,7 +75,7 @@ HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { auto init_status = executor->Init(config.device_options); if (!init_status.ok()) { return tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", config.ordinal, init_status.ToString().c_str()));
tensorflow/compiler/xla/stream_executor/lib/BUILD+0 −1 modified@@ -31,7 +31,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "//tensorflow/compiler/xla/stream_executor/platform", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:stacktrace", "//tensorflow/tsl/platform:status",
tensorflow/compiler/xla/stream_executor/lib/error.h+0 −30 removed@@ -1,30 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - - -#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_LIB_ERROR_H_ -#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_LIB_ERROR_H_ - -#include "tensorflow/tsl/protobuf/error_codes.pb.h" // IWYU pragma: export - -namespace stream_executor { -namespace port { - -namespace error = tensorflow::error; - -} // namespace port -} // namespace stream_executor - -#endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_LIB_ERROR_H_
tensorflow/compiler/xla/stream_executor/lib/initialize.h+0 −21 removed@@ -1,21 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_LIB_INITIALIZE_H_ -#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_LIB_INITIALIZE_H_ - -#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" - -#endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_LIB_INITIALIZE_H_
tensorflow/compiler/xla/stream_executor/multi_platform_manager.cc+6 −7 modified@@ -24,8 +24,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/tsl/platform/errors.h" namespace stream_executor { @@ -96,7 +95,7 @@ tsl::Status MultiPlatformManagerImpl::RegisterPlatform( std::string key = absl::AsciiStrToLower(platform->Name()); absl::MutexLock lock(&mu_); if (name_map_.find(key) != name_map_.end()) { - return tsl::Status(port::error::INTERNAL, + return tsl::Status(tsl::error::INTERNAL, "platform is already registered with name: \"" + platform->Name() + "\""); } @@ -156,7 +155,7 @@ tsl::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName( TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); if (platform->Initialized()) { return tsl::Status( - port::error::FAILED_PRECONDITION, + tsl::error::FAILED_PRECONDITION, absl::StrCat("platform \"", target, "\" is already initialized")); } @@ -172,7 +171,7 @@ tsl::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId( TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); if (platform->Initialized()) { return tsl::Status( - port::error::FAILED_PRECONDITION, + tsl::error::FAILED_PRECONDITION, absl::StrFormat("platform with id %p is already initialized", id)); } @@ -232,7 +231,7 @@ tsl::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked( auto it = name_map_.find(absl::AsciiStrToLower(target)); if (it == name_map_.end()) { return tsl::Status( - port::error::NOT_FOUND, + tsl::error::NOT_FOUND, absl::StrCat("Could not find registered platform with name: \"", target, "\". 
Available platform names are: ", absl::StrJoin(InitializedPlatformNamesWithFilter(), " "))); @@ -245,7 +244,7 @@ tsl::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByIdLocked( auto it = id_map_.find(id); if (it == id_map_.end()) { return tsl::Status( - port::error::NOT_FOUND, + tsl::error::NOT_FOUND, absl::StrFormat("could not find registered platform with id: %p", id)); } return it->second;
tensorflow/compiler/xla/stream_executor/multi_platform_manager.h+1 −1 modified@@ -70,8 +70,8 @@ limitations under the License. #include <vector> #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/tsl/platform/status.h" #include "tensorflow/tsl/platform/statusor.h"
tensorflow/compiler/xla/stream_executor/platform.cc+3 −3 modified@@ -16,10 +16,10 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/platform.h" #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" +#include "tensorflow/tsl/platform/errors.h" namespace stream_executor { @@ -92,14 +92,14 @@ bool Platform::Initialized() const { return true; } tsl::Status Platform::Initialize( const std::map<std::string, std::string> &platform_options) { if (!platform_options.empty()) { - return tsl::Status(port::error::UNIMPLEMENTED, + return tsl::Status(tsl::error::UNIMPLEMENTED, "this platform does not support custom initialization"); } return ::tsl::OkStatus(); } tsl::Status Platform::ForceExecutorShutdown() { - return tsl::Status(port::error::UNIMPLEMENTED, + return tsl::Status(tsl::error::UNIMPLEMENTED, "executor shutdown is not supported on this platform"); }
tensorflow/compiler/xla/stream_executor/plugin_registry.cc+5 −5 modified@@ -19,8 +19,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" #include "tensorflow/compiler/xla/stream_executor/multi_platform_manager.h" +#include "tensorflow/tsl/platform/errors.h" namespace stream_executor { @@ -76,7 +76,7 @@ tsl::Status PluginRegistry::RegisterFactoryInternal( if (factories->find(plugin_id) != factories->end()) { return tsl::Status( - port::error::ALREADY_EXISTS, + tsl::error::ALREADY_EXISTS, absl::StrFormat("Attempting to register factory for plugin %s when " "one has already been registered", plugin_name)); @@ -96,7 +96,7 @@ tsl::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal( iter = generic_factories.find(plugin_id); if (iter == generic_factories.end()) { return tsl::Status( - port::error::NOT_FOUND, + tsl::error::NOT_FOUND, absl::StrFormat("Plugin ID %p not registered.", plugin_id)); } } @@ -217,7 +217,7 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id, \ if (plugin_id == kNullPlugin) { \ return tsl::Status( \ - port::error::FAILED_PRECONDITION, \ + tsl::error::FAILED_PRECONDITION, \ "No suitable " PLUGIN_STRING \ " plugin registered. Have you linked in a " PLUGIN_STRING \ "-providing plugin?"); \ @@ -236,7 +236,7 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id, PlatformKind platform_kind, PluginId plugin_id) { \ auto iter = platform_id_by_kind_.find(platform_kind); \ if (iter == platform_id_by_kind_.end()) { \ - return tsl::Status(port::error::FAILED_PRECONDITION, \ + return tsl::Status(tsl::error::FAILED_PRECONDITION, \ absl::StrFormat("Platform kind %d not registered.", \ static_cast<int>(platform_kind))); \ } \
tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc+1 −1 modified@@ -32,8 +32,8 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_helpers.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
tensorflow/compiler/xla/stream_executor/rocm/rocm_diagnostics.cc+1 −1 modified@@ -35,8 +35,8 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/strip.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" +#include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/host_info.h" namespace stream_executor {
tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc+2 −2 modified@@ -31,9 +31,8 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_executor.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h" #include "tensorflow/compiler/xla/stream_executor/rocm/rocm_diagnostics.h" @@ -42,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/stream.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/hash.h" #include "tensorflow/tsl/util/determinism.h" #include "tensorflow/tsl/util/env_var.h"
tensorflow/compiler/xla/stream_executor/rocm/rocm_driver.cc+1 −1 modified@@ -28,11 +28,11 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_diagnostics.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/numbers.h" #include "tensorflow/tsl/platform/stacktrace.h" #include "tensorflow/tsl/platform/static_threadlocal.h"
tensorflow/compiler/xla/stream_executor/rocm/rocm_fft.cc+1 −1 modified@@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_executor.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_helpers.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
tensorflow/compiler/xla/stream_executor/rocm/rocm_gpu_executor.cc+2 −2 modified@@ -28,10 +28,9 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h" #include "tensorflow/compiler/xla/stream_executor/kernel_cache_config.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform.h" #include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h" @@ -42,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/compiler/xla/stream_executor/timer.h" #include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/errors.h" #ifdef PLATFORMS_GPUS_ROCM_DYNAMIC_LIBROCM_DYNAMIC_LIBROCM_H_ #error \
tensorflow/compiler/xla/stream_executor/rocm/rocm_platform.cc+2 −2 modified@@ -21,9 +21,9 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_executor.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/rocm/rocm_platform_id.h" +#include "tensorflow/tsl/platform/errors.h" namespace stream_executor { namespace gpu {
tensorflow/compiler/xla/stream_executor/rocm/rocm_rng.cc+1 −1 modified@@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_helpers.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_rng.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/rng.h" #include "tensorflow/compiler/xla/stream_executor/rocm/rocm_platform_id.h"
tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc+8 −8 modified@@ -32,11 +32,11 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "tensorflow/compiler/xla/stream_executor/blas.h" #include "tensorflow/compiler/xla/stream_executor/fft.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/compiler/xla/stream_executor/rng.h" #include "tensorflow/compiler/xla/stream_executor/stream.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h" +#include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/stacktrace.h" #include "tensorflow/tsl/platform/statusor.h" #include "tensorflow/tsl/platform/threadpool.h" @@ -405,7 +405,7 @@ StreamExecutor::createRnnDescriptor( bool use_padded_io) { dnn::DnnSupport* dnn_support = AsDnn(); if (!dnn_support) { - return tsl::Status(port::error::UNKNOWN, + return tsl::Status(tsl::error::UNKNOWN, "Fail to find the dnn implementation."); } return dnn_support->createRnnDescriptor( @@ -420,7 +420,7 @@ StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length, dnn::DataType data_type) { dnn::DnnSupport* dnn_support = AsDnn(); if (!dnn_support) { - return tsl::Status(port::error::UNKNOWN, + return tsl::Status(tsl::error::UNKNOWN, "Fail to find the dnn implementation."); } return dnn_support->createRnnSequenceTensorDescriptor( @@ -434,7 +434,7 @@ StreamExecutor::createRnnSequenceTensorDescriptor( dnn::DataType data_type) { dnn::DnnSupport* dnn_support = AsDnn(); if (!dnn_support) { - return tsl::Status(port::error::UNKNOWN, + return tsl::Status(tsl::error::UNKNOWN, "Fail to find the dnn implementation."); } return dnn_support->createRnnSequenceTensorDescriptor( @@ -448,7 +448,7 @@ StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size, dnn::DataType data_type) { dnn::DnnSupport* dnn_support = AsDnn(); if (!dnn_support) { - 
return tsl::Status(port::error::UNKNOWN, + return tsl::Status(tsl::error::UNKNOWN, "Fail to find the dnn implementation."); } return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size, @@ -546,7 +546,7 @@ tsl::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol( } return tsl::Status( - port::error::NOT_FOUND, + tsl::error::NOT_FOUND, absl::StrCat("Check if module containing symbol ", symbol_name, " is loaded (module_handle = ", reinterpret_cast<uintptr_t>(module_handle.id()), ")")); @@ -691,7 +691,7 @@ tsl::Status StreamExecutor::SynchronousMemcpyD2H( result = implementation_->SynchronousMemcpy(host_dst, device_src, size); if (!result.ok()) { result = tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat("failed to synchronously memcpy device-to-host: device " "%p to host %p size %d: %s", device_src.opaque(), host_dst, size, @@ -715,7 +715,7 @@ tsl::Status StreamExecutor::SynchronousMemcpyH2D(const void* host_src, result = implementation_->SynchronousMemcpy(device_dst, host_src, size); if (!result.ok()) { result = tsl::Status( - port::error::INTERNAL, + tsl::error::INTERNAL, absl::StrFormat("failed to synchronously memcpy host-to-device: host " "%p to device %p size %d: %s", host_src, device_dst->opaque(), size,
tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.cc+0 −1 modified@@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h" #include "absl/synchronization/mutex.h" -#include "tensorflow/compiler/xla/stream_executor/lib/error.h" #include "tensorflow/compiler/xla/stream_executor/stream.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/tsl/platform/errors.h"
tensorflow/core/common_runtime/function.cc+1 −1 modified@@ -535,7 +535,7 @@ class CallOp : public AsyncOpKernel { OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library is provided."), done); - FunctionLibraryRuntime::Options opts; + FunctionLibraryRuntime::Options opts(ctx->step_id()); opts.rendezvous = ctx->rendezvous(); opts.cancellation_manager = ctx->cancellation_manager(); opts.step_container = ctx->step_container();
tensorflow/core/data/BUILD+1 −0 modified@@ -257,6 +257,7 @@ cc_library( # copybara:uncomment copts = ["-Wthread-safety-analysis"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/platform:env", "//tensorflow/core/platform:mutex",
tensorflow/core/data/service/server_lib.cc+1 −1 modified@@ -37,7 +37,7 @@ constexpr char kPortPlaceholder[] = "%port%"; } GrpcDataServerBase::GrpcDataServerBase( - int port, const std::string& protocol, const std::string server_type, + int port, const std::string& protocol, const std::string& server_type, std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options) : requested_port_(port), protocol_(protocol),
tensorflow/core/data/service/server_lib.h+1 −1 modified@@ -44,7 +44,7 @@ class GrpcDataServerBase { // found by calling `BoundPort()`. GrpcDataServerBase( int requested_port, const std::string& protocol, - const std::string server_type, + const std::string& server_type, std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options = {}); virtual ~GrpcDataServerBase() = default;
tensorflow/core/data/service/snapshot/BUILD+4 −0 modified@@ -84,7 +84,9 @@ cc_library( hdrs = ["path_utils.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/strings", ], ) @@ -96,6 +98,8 @@ tf_cc_test( ":path_utils", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/tsl/platform:status_matchers", + "//tensorflow/tsl/protobuf:protos_all_cc", ], )
tensorflow/core/data/service/snapshot/path_utils.cc+27 −0 modified@@ -15,10 +15,15 @@ limitations under the License. #include "tensorflow/core/data/service/snapshot/path_utils.h" #include <string> +#include <utility> +#include <vector> #include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/path.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace data { @@ -65,6 +70,28 @@ std::string SplitPath(absl::string_view snapshot_path, int64_t stream_index, absl::StrCat("split_", local_index, "_", global_index)); } +tsl::StatusOr<std::pair<int64_t, int64_t>> SplitIndex( + absl::string_view split_path) { + std::vector<std::string> tokens = absl::StrSplit(split_path, '_'); + int64_t local_split_index = 0, global_split_index = 0; + if (tokens.size() != 3 || tokens[0] != "split" || + !absl::SimpleAtoi(tokens[1], &local_split_index) || + local_split_index < 0 || + !absl::SimpleAtoi(tokens[2], &global_split_index) || + global_split_index < 0) { + return tsl::errors::InvalidArgument( + "Invalid split file name: ", split_path, + ". Expected split_<local_split_index>_<global_split_index>."); + } + if (local_split_index > global_split_index) { + return tsl::errors::InvalidArgument( + "Invalid split file name: ", split_path, ". The local split index ", + local_split_index, " exceeds the global split index ", + global_split_index, "."); + } + return std::make_pair(local_split_index, global_split_index); +} + std::string SnapshotMetadataFilePath(absl::string_view snapshot_path_) { return tsl::io::JoinPath(snapshot_path_, kSnapshotMetadataFileName); }
tensorflow/core/data/service/snapshot/path_utils.h+8 −0 modified@@ -16,8 +16,10 @@ limitations under the License. #define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_PATH_UTILS_H_ #include <string> +#include <utility> #include "absl/strings/string_view.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace data { @@ -45,6 +47,12 @@ std::string SplitPath(absl::string_view snapshot_path, int64_t stream_index, int64_t source_id, int64_t local_index, int64_t global_index); +// Returns a pair of {local_split_index, global_split_index} of the split. The +// expected format of `split_path` is: +// split_<local_split_index>_<global_split_index> +tsl::StatusOr<std::pair<int64_t, int64_t>> SplitIndex( + absl::string_view split_path); + // Returns the path of the DONE file of a snapshot stream. std::string StreamDoneFilePath(absl::string_view snapshot_path, int64_t stream_index);
tensorflow/core/data/service/snapshot/path_utils_test.cc+34 −0 modified@@ -14,13 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/data/service/snapshot/path_utils.h" +#include "tensorflow/tsl/platform/status_matchers.h" #include "tensorflow/tsl/platform/test.h" +#include "tensorflow/tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace data { namespace { +using ::testing::HasSubstr; using ::testing::MatchesRegex; +using ::testing::Pair; +using tsl::testing::IsOkAndHolds; +using tsl::testing::StatusIs; TEST(PathUtilsTest, StreamsDirectory) { EXPECT_THAT(StreamsDirectory("/path/to/snapshot"), @@ -51,6 +57,34 @@ TEST(PathUtilsTest, SplitPath) { "/path/to/snapshot.streams.stream_0.splits.source_1.split_2_3")); } +TEST(PathUtilsTest, SplitIndex) { + EXPECT_THAT(SplitIndex("split_0_1"), IsOkAndHolds(Pair(0, 1))); +} + +TEST(PathUtilsTest, InvalidSplitFile) { + EXPECT_THAT( + SplitIndex(""), + StatusIs(error::INVALID_ARGUMENT, + HasSubstr( + "Expected split_<local_split_index>_<global_split_index>"))); + EXPECT_THAT( + SplitIndex("split_123"), + StatusIs(error::INVALID_ARGUMENT, + HasSubstr( + "Expected split_<local_split_index>_<global_split_index>"))); + EXPECT_THAT( + SplitIndex("split_-1_(-1)"), + StatusIs(error::INVALID_ARGUMENT, + HasSubstr( + "Expected split_<local_split_index>_<global_split_index>"))); + EXPECT_THAT( + SplitIndex("split_5_0"), + StatusIs( + error::INVALID_ARGUMENT, + HasSubstr( + "The local split index 5 exceeds the global split index 0"))); +} + TEST(PathUtilsTest, StreamDoneFilePath) { EXPECT_THAT(StreamDoneFilePath("/path/to/snapshot", /*stream_index=*/0), MatchesRegex("/path/to/snapshot.streams.stream_0.DONE"));
tensorflow/core/data/service/snapshot/snapshot_manager.cc+2 −15 modified@@ -187,21 +187,8 @@ Status SnapshotManager::ReadOnDiskSource( // `split_filename` must have this format: // "split_<local_split_index>_<global_split_index>". - std::vector<std::string> tokens = absl::StrSplit(split_filename, '_'); - int64_t local_split_index; - int64_t global_split_index; - if (tokens.size() != 3 || - !absl::SimpleAtoi(tokens[1], &local_split_index) || - local_split_index < 0 || - !absl::SimpleAtoi(tokens[2], &global_split_index) || - global_split_index < 0) { - return InvalidArgument("can't parse the name of ", split_path); - } - if (local_split_index > global_split_index) { - return InvalidArgument( - "found conflict between local split index and global split index in ", - "name of ", split_path); - } + TF_ASSIGN_OR_RETURN(auto split_index, SplitIndex(split_filename)); + auto [local_split_index, global_split_index] = split_index; if (local_split_index > split_filenames.size() - 1) { return InvalidArgument( "found conflict between the number of splits and name of ",
tensorflow/core/data/tfdataz_metrics.cc+7 −2 modified@@ -88,8 +88,9 @@ absl::Duration ApproximateLatencyEstimator::GetAverageLatency(Duration duration) return absl::Duration(absl::Microseconds(interval_latency)) / interval_count; } -TfDatazMetricsCollector::TfDatazMetricsCollector(const Env& env) - : latency_estimator_(env) {} +TfDatazMetricsCollector::TfDatazMetricsCollector(const Env& env, + IteratorBase* iterator) + : iterator_(iterator), latency_estimator_(env) {} void TfDatazMetricsCollector::RecordGetNextLatency( int64_t get_next_latency_usec) { @@ -113,6 +114,10 @@ absl::Duration TfDatazMetricsCollector::GetAverageLatencyForLastSixtyMinutes() { ApproximateLatencyEstimator::Duration::kSixtyMinutes); } +int64_t TfDatazMetricsCollector::GetIteratorTotalMemoryUsage() { + return iterator_->TotalBufferedBytes(); +} + namespace { static mutex* get_tfdataz_metrics_registry_lock() { static mutex tfdataz_metrics_registry_lock(LINKER_INITIALIZED);
tensorflow/core/data/tfdataz_metrics.h+8 −1 modified@@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/time/time.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -95,7 +96,7 @@ class TfDatazMetricsCollector { // We only collect metrics for CPU devices. This is a heuristic to avoid // collecting metrics for device-side iterators created by the multi-device // iterator mechanism. - explicit TfDatazMetricsCollector(const Env& env); + TfDatazMetricsCollector(const Env& env, IteratorBase* iterator); // Records `GetNext` call latency. void RecordGetNextLatency(int64_t get_next_latency_usec); @@ -109,7 +110,13 @@ class TfDatazMetricsCollector { // Returns the average `GetNext` latency for past 60 minutes. absl::Duration GetAverageLatencyForLastSixtyMinutes(); + // Returns the total memory (in bytes) used by the iterator. + // Total memory used by the iterator includes the total number of bytes + // buffered in all nodes in the subtree. + int64_t GetIteratorTotalMemoryUsage(); + private: + IteratorBase* iterator_; // not owned ApproximateLatencyEstimator latency_estimator_; };
tensorflow/core/data/tfdataz_metrics_test.cc+16 −12 modified@@ -18,7 +18,7 @@ limitations under the License. #include <utility> #include "absl/time/time.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/fake_clock_env.h" @@ -41,14 +41,16 @@ class TfDatazMetricsTest : public ::testing::Test { protected: void SetUp() override { env_ = std::make_unique<FakeClockEnv>(Env::Default()); - tfdataz_metrics_ = std::make_unique<TfDatazMetricsCollector>(*env_); + tfdataz_metrics_ = + std::make_unique<TfDatazMetricsCollector>(*env_, iterator_.get()); } void TearDown() override { env_.reset(); tfdataz_metrics_.reset(); } + std::unique_ptr<IteratorBase> iterator_; std::unique_ptr<FakeClockEnv> env_; std::unique_ptr<TfDatazMetricsCollector> tfdataz_metrics_; }; @@ -184,10 +186,11 @@ class ScopedTfDataMetricsRegistration { }; TEST(TfDatazMetricsRegistryTest, Register) { - auto collector_one = - std::make_shared<TfDatazMetricsCollector>(*Env::Default()); - auto collector_two = - std::make_shared<TfDatazMetricsCollector>(*Env::Default()); + std::unique_ptr<IteratorBase> iterator; + auto collector_one = std::make_shared<TfDatazMetricsCollector>( + *Env::Default(), iterator.get()); + auto collector_two = std::make_shared<TfDatazMetricsCollector>( + *Env::Default(), iterator.get()); ScopedTfDataMetricsRegistration scoped_registration_one(collector_one); ScopedTfDataMetricsRegistration scoped_registration_two(collector_two); @@ -196,12 +199,13 @@ TEST(TfDatazMetricsRegistryTest, Register) { } TEST(TfDatazMetricsRegistryTest, Deregister) { - auto collector_one = - std::make_shared<TfDatazMetricsCollector>(*Env::Default()); - auto collector_two = - std::make_shared<TfDatazMetricsCollector>(*Env::Default()); - auto collector_three = - std::make_shared<TfDatazMetricsCollector>(*Env::Default()); + std::unique_ptr<IteratorBase> 
iterator; + auto collector_one = std::make_shared<TfDatazMetricsCollector>( + *Env::Default(), iterator.get()); + auto collector_two = std::make_shared<TfDatazMetricsCollector>( + *Env::Default(), iterator.get()); + auto collector_three = std::make_shared<TfDatazMetricsCollector>( + *Env::Default(), iterator.get()); ScopedTfDataMetricsRegistration scoped_registration_one(collector_one); ScopedTfDataMetricsRegistration scoped_registration_two(collector_two); ScopedTfDataMetricsRegistration scoped_registration_three(collector_three);
tensorflow/core/distributed_runtime/eager/remote_mgr.cc+1 −1 modified@@ -33,7 +33,7 @@ Status WithErrorSourcePayload(Status error) { error_source_proto.set_error_source( core::platform::ErrorSourceProto::EAGER_REMOTE_MGR); error.SetPayload(tensorflow::kErrorSource, - error_source_proto.SerializeAsString()); + absl::Cord(error_source_proto.SerializeAsString())); return error; } } // namespace
tensorflow/core/distributed_runtime/integration_test/coordination_test_opkernel_registration.cc+2 −1 modified@@ -149,7 +149,8 @@ class TestReportErrorToClusterOp : public OpKernel { } tensorflow::Status s(static_cast<tensorflow::error::Code>(error_code), error_message); - s.SetPayload(tsl::CoordinationErrorPayloadKey(), "testing error payload"); + s.SetPayload(tsl::CoordinationErrorPayloadKey(), + absl::Cord("testing error payload")); OP_REQUIRES_OK(ctx, coord_agent->ReportError(s)); } };
tensorflow/core/framework/dataset.h+7 −0 modified@@ -959,6 +959,13 @@ class IteratorBase : public Checkpointable { return OkStatus(); } + // Returns the total number of bytes buffered by the iterator across all nodes + // in the subtree for which autotuning is enabled. + int64_t TotalBufferedBytes() const { + if (node_) return node_->TotalBufferedBytes(); + return 0; + } + protected: // Returns a node that models this iterator. virtual std::shared_ptr<model::Node> CreateNode(
tensorflow/core/framework/op_requires.h+1 −1 modified@@ -62,7 +62,7 @@ namespace tensorflow { if (!TF_PREDICT_TRUE(STATUS.ok())) { \ CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \ if (!PAYLOAD_VALUE.empty()) { \ - STATUS.SetPayload(PAYLOAD_KEY, PAYLOAD_VALUE); \ + STATUS.SetPayload(PAYLOAD_KEY, absl::Cord(PAYLOAD_VALUE)); \ } \ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, STATUS); \ return; \
tensorflow/core/kernels/data/iterator_ops.cc+3 −2 modified@@ -93,8 +93,6 @@ IteratorResource::IteratorResource( /*iterator=*/nullptr)), output_dtypes_(output_dtypes), output_shapes_(output_shapes) { - tf_dataz_metrics_collector_ = std::make_shared<TfDatazMetricsCollector>(*env); - TfDatazMetricsRegistry::Register(tf_dataz_metrics_collector_); VLOG(2) << "creating iterator resource"; } @@ -274,6 +272,9 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx, new_state->MergeCheckpoint(iter_ctx.checkpoint()); mutex_lock l(mu_); std::swap(iterator_state_, new_state); + tf_dataz_metrics_collector_ = + std::make_shared<TfDatazMetricsCollector>(env_, iterator.get()); + TfDatazMetricsRegistry::Register(tf_dataz_metrics_collector_); return OkStatus(); }
tensorflow/core/kernels/functional_ops.cc+7 −3 modified@@ -160,7 +160,8 @@ class IfOp : public AsyncOpKernel { then_handle_(then_handle), else_handle_(else_handle), done_(std::move(done)), - lib_(CHECK_NOTNULL(ctx_->function_library())) { + lib_(CHECK_NOTNULL(ctx_->function_library())), + opts_(ctx->step_id()) { SetRunOptions(ctx_, &opts_, true /* always_collect_stats */); for (int i = 1; i < ctx_->num_inputs(); ++i) { args_.push_back(ctx_->input(i)); @@ -286,7 +287,8 @@ class CaseOp : public AsyncOpKernel { branch_(branch), branch_handles_(branch_handles), done_(std::move(done)), - lib_(CHECK_NOTNULL(ctx_->function_library())) { + lib_(CHECK_NOTNULL(ctx_->function_library())), + opts_(ctx->step_id()) { SetRunOptions(ctx_, &opts_, true /* always_collect_stats */); for (int i = 1; i < ctx_->num_inputs(); ++i) { args_.push_back(ctx_->input(i)); @@ -507,7 +509,8 @@ class WhileOp : public AsyncOpKernel { cond_handle_(cond_handle), body_handle_(body_handle), done_(std::move(done)), - lib_(CHECK_NOTNULL(ctx_->function_library())) { + lib_(CHECK_NOTNULL(ctx_->function_library())), + opts_(ctx->step_id()) { SetRunOptions(ctx_, &opts_, false /* always_collect_stats */); GetArgsFromContext(ctx, &args_, &loop_var_types_); body_frame_ = @@ -751,6 +754,7 @@ class ForOp : public AsyncOpKernel { ctx_(ctx), done_(std::move(done)), lib_(CHECK_NOTNULL(ctx_->function_library())), + opts_(ctx->step_id()), args_(1 + ctx_->num_inputs() - 3) { args_[0] = Tensor(DT_INT32, {}); iter_ = &args_[0].scalar<int32>()();
tensorflow/core/kernels/function_ops.cc+1 −1 modified@@ -238,7 +238,7 @@ class SymbolicGradientOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC( ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done); - FunctionLibraryRuntime::Options opts; + FunctionLibraryRuntime::Options opts(ctx->step_id()); opts.rendezvous = ctx->rendezvous(); opts.cancellation_manager = ctx->cancellation_manager(); opts.collective_executor = ctx->collective_executor();
tensorflow/core/kernels/rnn/gru_ops.cc+55 −33 modified@@ -49,61 +49,68 @@ class GRUCellBlockOp : public OpKernel { const Tensor* b_c_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("b_c", &b_c_tensor)); + // Sanity checks for input shapes. + + // Shape of 'x' must be [batch_size, input_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(x_tensor->shape()), + errors::InvalidArgument("Rank of x must be 2", x_tensor->dims(), + " vs. 2")); const int64_t batch_size = x_tensor->dim_size(0); const int64_t input_size = x_tensor->dim_size(1); - const int64_t cell_size = h_prev_tensor->dim_size(1); - - // Sanity checks for input shapes. // Shape of 'h' must be [batch_size, cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(h_prev_tensor->shape()), + errors::InvalidArgument("Rank of h_prev must be 2, got ", + h_prev_tensor->dims())); OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size, errors::InvalidArgument("h_prev.dims(0) != batch_size: ", h_prev_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument( - "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), - " vs. ", cell_size)); + const int64_t cell_size = h_prev_tensor->dim_size(1); // Shape of 'w_ru' must be [input_size+cell_size, 2*cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_ru_tensor->shape()), + errors::InvalidArgument("Rank of w_ru_ must be 2, got ", + w_ru_tensor->dims())); OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size, errors::InvalidArgument( "w_ru.dim_size(0) != input_size + cell_size: ", w_ru_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES(ctx, w_ru_tensor->dim_size(1) == cell_size * 2, errors::InvalidArgument("w_ru.dim_size(1) != cell_size * 2: ", w_ru_tensor->dim_size(1), " vs. 
", cell_size * 2)); // Shape of 'w_c' must be [input_size+cell_size, cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_c_tensor->shape()), + errors::InvalidArgument("Rank of w_c must be 2, got ", + w_c_tensor->dims())); OP_REQUIRES(ctx, w_c_tensor->dim_size(0) == input_size + cell_size, errors::InvalidArgument( "w_c.dim_size(0) != input_size + cell_size: ", w_c_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size, errors::InvalidArgument( "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1), " vs. ", cell_size)); // Shape of 'b_ru' must be [2*cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(b_ru_tensor->shape()), + errors::InvalidArgument("Rank of b_ru must be 1, got ", + b_ru_tensor->dims())); OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2, errors::InvalidArgument("b_ru.dim_size(0) != cell_size * 2: ", b_ru_tensor->dim_size(0), " vs. ", cell_size * 2)); - OP_REQUIRES(ctx, b_ru_tensor->dims() == 1, - errors::InvalidArgument("Rank of b_ru must be 1", - b_ru_tensor->dims(), " vs. 1", 1)); // Shape of 'b_c' must be [cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(b_c_tensor->shape()), + errors::InvalidArgument("Rank of b_c must be 1, got ", + b_c_tensor->dims())); OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size, errors::InvalidArgument( "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0), " vs. ", cell_size)); - OP_REQUIRES(ctx, b_c_tensor->dims() == 1, - errors::InvalidArgument("Rank of b_c must be 1", - b_c_tensor->dims(), " vs. 1")); // Create output tensors. 
Tensor* r_tensor = nullptr; @@ -204,65 +211,71 @@ class GRUBlockCellGradOp : public OpKernel { const Tensor* d_h_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("d_h", &d_h_tensor)); + // Shape of 'x' must be [batch_size, input_size] + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrix(x_tensor->shape()), + errors::InvalidArgument("Rank of x must be 2, got ", x_tensor->dims())); const int64_t batch_size = x_tensor->dim_size(0); const int64_t input_size = x_tensor->dim_size(1); - const int64_t cell_size = h_prev_tensor->dim_size(1); - // Sanity checks for input shapes. - - // Shape of 'h_prev' must be [batch_size, cell_size] + // Shape of 'h' must be [batch_size, cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(h_prev_tensor->shape()), + errors::InvalidArgument("Rank of h_prev must be 2, got ", + h_prev_tensor->dims())); OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size, errors::InvalidArgument("h_prev.dims(0) != batch_size: ", h_prev_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument( - "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), - " vs. ", cell_size)); + const int64_t cell_size = h_prev_tensor->dim_size(1); // Shape of 'w_ru' must be [input_size+cell_size, 2*cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_ru_tensor->shape()), + errors::InvalidArgument("Rank of w_ru_ must be 2, got ", + w_ru_tensor->dims())); OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size, errors::InvalidArgument( "w_ru.dim_size(0) != input_size + cell_size: ", w_ru_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES(ctx, w_ru_tensor->dim_size(1) == cell_size * 2, errors::InvalidArgument("w_ru.dim_size(1) != cell_size * 2: ", w_ru_tensor->dim_size(1), " vs. 
", cell_size * 2)); // Shape of 'w_c' must be [input_size+cell_size, cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_c_tensor->shape()), + errors::InvalidArgument("Rank of w_c must be 2, got ", + w_c_tensor->dims())); OP_REQUIRES(ctx, w_c_tensor->dim_size(0) == input_size + cell_size, errors::InvalidArgument( "w_c.dim_size(0) != input_size + cell_size: ", w_c_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size, errors::InvalidArgument( "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1), " vs. ", cell_size)); // Shape of 'b_ru' must be [2*cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(b_ru_tensor->shape()), + errors::InvalidArgument("Rank of b_ru must be 1, got ", + b_ru_tensor->dims())); OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2, errors::InvalidArgument("b_ru.dim_size(0) != cell_size * 2: ", b_ru_tensor->dim_size(0), " vs. ", cell_size * 2)); - OP_REQUIRES(ctx, b_ru_tensor->dims() == 1, - errors::InvalidArgument("Rank of b_ru must be 1", - b_ru_tensor->dims(), " vs. 1")); - // Shape of 'b_c' must be [cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(b_c_tensor->shape()), + errors::InvalidArgument("Rank of b_c must be 1, got ", + b_c_tensor->dims())); OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size, errors::InvalidArgument( "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0), " vs. ", cell_size)); - OP_REQUIRES(ctx, b_c_tensor->dims() == 1, - errors::InvalidArgument("Rank of b_c must be 1 ", - b_c_tensor->dims(), " vs. 1")); - // Shape of 'r' must be [batch_size, cell_size] + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrix(r_tensor->shape()), + errors::InvalidArgument("Rank of r must be 2, got ", r_tensor->dims())); OP_REQUIRES(ctx, r_tensor->dim_size(0) == batch_size, errors::InvalidArgument( "r.dims(0) != batch_size: ", r_tensor->dim_size(0), " vs. 
", @@ -273,6 +286,9 @@ class GRUBlockCellGradOp : public OpKernel { cell_size)); // Shape of 'u' must be [batch_size, cell_size] + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrix(u_tensor->shape()), + errors::InvalidArgument("Rank of u must be 2, got ", u_tensor->dims())); OP_REQUIRES(ctx, u_tensor->dim_size(0) == batch_size, errors::InvalidArgument( "u.dims(0) != batch_size: ", u_tensor->dim_size(0), " vs. ", @@ -283,6 +299,9 @@ class GRUBlockCellGradOp : public OpKernel { cell_size)); // Shape of 'c' must be [batch_size, cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(c_tensor->shape()), + errors::InvalidArgument("Rank of w_c must be 2, got ", + c_tensor->dims())); OP_REQUIRES(ctx, c_tensor->dim_size(0) == batch_size, errors::InvalidArgument( "c.dims(0) != batch_size: ", c_tensor->dim_size(0), " vs. ", @@ -293,6 +312,9 @@ class GRUBlockCellGradOp : public OpKernel { cell_size)); // Shape of 'd_h' must be [batch_size, cell_size] + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(d_h_tensor->shape()), + errors::InvalidArgument("Rank of d_h must be 2, got ", + d_h_tensor->dims())); OP_REQUIRES(ctx, d_h_tensor->dim_size(0) == batch_size, errors::InvalidArgument( "d_h.dims(0) != batch_size: ", d_h_tensor->dim_size(0),
tensorflow/core/lib/core/status_test.cc+4 −4 modified@@ -177,23 +177,23 @@ TEST(StatusGroup, AggregateWithMultipleErrorStatus) { TEST(Status, InvalidPayloadGetsIgnored) { Status s = Status(); - s.SetPayload("Invalid", "Invalid Val"); + s.SetPayload("Invalid", absl::Cord("Invalid Val")); ASSERT_FALSE(s.GetPayload("Invalid").has_value()); bool is_err_erased = s.ErasePayload("Invalid"); ASSERT_EQ(is_err_erased, false); } TEST(Status, SetPayloadSetsOrUpdatesIt) { Status s(error::INTERNAL, "Error message"); - s.SetPayload("Error key", "Original"); + s.SetPayload("Error key", absl::Cord("Original")); ASSERT_EQ(s.GetPayload("Error key"), absl::Cord("Original")); - s.SetPayload("Error key", "Updated"); + s.SetPayload("Error key", absl::Cord("Updated")); ASSERT_EQ(s.GetPayload("Error key"), absl::Cord("Updated")); } TEST(Status, ErasePayloadRemovesIt) { Status s(error::INTERNAL, "Error message"); - s.SetPayload("Error key", "Original"); + s.SetPayload("Error key", absl::Cord("Original")); bool is_err_erased = s.ErasePayload("Error key"); ASSERT_EQ(is_err_erased, true);
tensorflow/core/platform/error_payloads.cc+1 −1 modified@@ -27,7 +27,7 @@ void OkOrSetErrorCounterPayload( ErrorSourceProto error_source_proto; error_source_proto.set_error_source(error_source); status.SetPayload(tensorflow::kErrorSource, - error_source_proto.SerializeAsString()); + absl::Cord(error_source_proto.SerializeAsString())); } }
tensorflow/core/tpu/kernels/tpu_compile_op_common.cc+1 −1 modified@@ -408,7 +408,7 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) { SerializeToTString(proto, &output.scalar<tstring>()()); ctx->set_output(0, output); status.SetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey, - output.scalar<tstring>()()); + absl::Cord(output.scalar<tstring>()())); } if (status.ok()) {
tensorflow/core/tpu/kernels/tpu_functional_ops.cc+3 −3 modified@@ -1410,7 +1410,7 @@ Status TPUPartitionedCallOp::InitializeVarOnTPU( TF_RETURN_IF_ERROR( InstantiatePartition(*init_graph, fname, device, &fhandle, nullptr)); - FunctionLibraryRuntime::Options opts; + FunctionLibraryRuntime::Options opts(ctx->step_id()); opts.step_container = ctx->step_container(); opts.cancellation_manager = ctx->cancellation_manager(); opts.stats_collector = ctx->stats_collector(); @@ -1569,7 +1569,7 @@ Status TPUPartitionedCallOp::InitializeShardedVarOnTPU( functions.push_back(DeviceAndFHandle{.device = target, .handle = handle}); } - FunctionLibraryRuntime::Options opts; + FunctionLibraryRuntime::Options opts(ctx->step_id()); // Blocking on threads in the same thread pool is disallowed because // concurrent warm-up requests can exhaust the default thread pool. @@ -2702,7 +2702,7 @@ void TPUPartitionedCallOp::ExecuteFunctions( const std::vector<DeviceAndFHandle>& functions, OpKernelContext* ctx, int device_ordinal, int64_t ordinal_selector_req_id, DoneCallback done) { profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteFunctions"); - FunctionLibraryRuntime::Options opts; + FunctionLibraryRuntime::Options opts(ctx->step_id()); opts.step_container = ctx->step_container(); opts.stats_collector = ctx->stats_collector(); // TODO(akshayka): Consider selecting a runner on a per-device basis,
tensorflow/core/tpu/tpu_embedding_errors.cc+2 −1 modified@@ -29,7 +29,8 @@ Status AppendTpuEmbeddingErrorPayload(Status obj) { absl::StrCat(kTpuEmbeddingErrorMessage, ". ", obj.error_message()); Status status(obj.code(), error_message); TPUEmbeddingError error_payload; - status.SetPayload(kTpuEmbeddingErrorUrl, error_payload.SerializeAsString()); + status.SetPayload(kTpuEmbeddingErrorUrl, + absl::Cord(error_payload.SerializeAsString())); return status; } }
tensorflow/core/tpu/tpu_embedding_errors.h+2 −1 modified@@ -50,7 +50,8 @@ StatusOr<T> AppendTpuEmbeddingErrorPayload(StatusOr<T> obj) { kTpuEmbeddingErrorMessage, ". ", obj.status().error_message()); Status status(obj.status().code(), error_message); TPUEmbeddingError error_payload; - status.SetPayload(kTpuEmbeddingErrorUrl, error_payload.SerializeAsString()); + status.SetPayload(kTpuEmbeddingErrorUrl, + absl::Cord(error_payload.SerializeAsString())); return status; } }
tensorflow/core/util/zen_util.h+2 −2 modified@@ -23,7 +23,7 @@ limitations under the License. namespace tensorflow { -int64_t GetMempool() { +inline int64_t GetMempool() { static absl::once_flag once; static int64_t mempool = 1; absl::call_once(once, [&] { @@ -34,7 +34,7 @@ int64_t GetMempool() { return mempool; } -bool IsBlockedFormatEnabled() { +inline bool IsBlockedFormatEnabled() { static absl::once_flag once; static bool blocked_format = false; absl::call_once(once, [&] {
tensorflow/lite/delegates/xnnpack/README.md+24 −173 modified@@ -454,23 +454,34 @@ Below is the list of currently supported floating-point operators: * Output size, filter and bias (if present) must be static (use `kTfLiteMmapRo` allocation type). -### Floating-Point (IEEE FP16) Operators (experimental) +### Floating-Point (IEEE FP16) Operators XNNPACK supports half-precision (using IEEE FP16 format) inference for a subset of floating-point operators. XNNPACK automatically enables half-precision inference when the following conditions are met: * XNNPACK runs on hardware that natively supports computations in IEEE FP16 -format. Currently, this hardware is limited to ARM64 devices with ARMv8.2 FP16 -arithmetics extension, and includes Android phones starting with Pixel 3, -Galaxy S9 (Snapdragon SoC), Galaxy S10 (Exynos SoC), iOS devices with A11 or -newer SoCs, and all Apple Silicon Macs. +format. Currently, this hardware is limited to ARM & ARM64 devices with +ARMv8.2 FP16 arithmetics extension, and includes Android phones starting with +Pixel 3, Galaxy S9 (Snapdragon SoC), Galaxy S10 (Exynos SoC), iOS devices with +A11 or newer SoCs, all Apple Silicon Macs, and Windows ARM64 laptops based with +Snapdragon 850 SoC or newer. * IEEE FP16 inference is supported for every floating-point operator in the model. * The model's "reduced_precision_support" metadata indicates that the model -is compatible with FP16 inference. +is compatible with FP16 inference. The metadata can be added during model +conversion using the `_experimental_supported_accumulation_type` attribute +of the [tf.lite.TargetSpec](https://www.tensorflow.org/api_docs/python/tf/lite/TargetSpec) +object: + +```python +converter.optimizations = [tf.lite.Optimize.DEFAULT] +... 
+converter.target_spec.supported_types = [tf.float16] +converter.target_spec._experimental_supported_accumulation_type = tf.dtypes.float16 +``` When the above conditions are met, XNNPACK replace FP32 operators with their FP16 equivalents, and insert additional operators to convert model inputs @@ -486,7 +497,7 @@ is used. Forcing FP16 inference has several effects: * Besides ARM64 devices with ARMv8.2 FP16 arithmetics extension, forced FP16 inference is supported on x86/x86-64 devices with AVX2 extension in emulation mode: all elementary floating-point operations are computed in FP32, then -converted to FP16 and back to FP32. Note that such simulation is not exactly +converted to FP16 and back to FP32. Note that such simulation is not bit-exact equivalent to native FP16 inference, but simulates the effects of restricted mantissa precision and exponent range in the native FP16 arithmetics. @@ -512,171 +523,10 @@ TfLiteDelegate* xnnpack_delegate = TfLiteXNNPackDelegateCreate(&xnnpack_options); ``` -Below is the list of operators supported in IEEE FP16 inference: - -#### `ABS` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `ADD` - -* Must satisfy constraints on the floating-point (FP32) operator. -* Neither of the inputs can be static (use `kTfLiteMmapRo` allocation type). - -#### `AVERAGE_POOL_2D` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `CEIL` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `CONV_2D` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `CONCATENATION` - -* Must satisfy constraints on the floating-point (FP32) operator. -* Neither of the inputs can be static (use `kTfLiteMmapRo` allocation type). - -#### `DEPTH_TO_SPACE` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `DEPTHWISE_CONV_2D` - -* Must satisfy constraints on the floating-point (FP32) operator. 
- -#### `DIV` - -* Must satisfy constraints on the floating-point (FP32) operator. -* Neither of the inputs can be static (use `kTfLiteMmapRo` allocation type). - -#### `FLOOR` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `FULLY_CONNECTED` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `HARD_SWISH` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `LEAKY_RELU` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `LOGISTIC` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `MAX_POOL_2D` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `MAXIMUM` - -* Must satisfy constraints on the floating-point (FP32) operator. -* Neither of the inputs can be static (use `kTfLiteMmapRo` allocation type). - -#### `MEAN` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `MINIMUM` - -* Must satisfy constraints on the floating-point (FP32) operator. -* Neither of the inputs can be static (use `kTfLiteMmapRo` allocation type). - -#### `MUL` - -* Must satisfy constraints on the floating-point (FP32) operator. -* Neither of the inputs can be static (use `kTfLiteMmapRo` allocation type). - -#### `NEG` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `PAD` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `PRELU` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `RELU` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `RELU6` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `RELU_N1_TO_1` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `RESHAPE` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `RESIZE_BILINEAR` - -* Must satisfy constraints on the floating-point (FP32) operator. 
- -#### `ROUND` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `SLICE` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `SOFTMAX` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `SPACE_TO_DEPTH` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `SPLIT` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `SQRT` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `SQUARE` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `SQUARED_DIFFERENCE` - -* Must satisfy constraints on the floating-point (FP32) operator. -* Neither of the inputs can be static (use `kTfLiteMmapRo` allocation type). - -#### `STRIDED_SLICE` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `SUB` - -* Must satisfy constraints on the floating-point (FP32) operator. -* Neither of the inputs can be static (use `kTfLiteMmapRo` allocation type). - -#### `TRANSPOSE` - -* Must satisfy constraints on the floating-point (FP32) operator. - -#### `TRANSPOSE_CONV` - -* Must satisfy constraints on the floating-point (FP32) operator. +XNNPACK has full feature parity between FP32 and FP16 operators: all operators +that are supported for FP32 inference are also supported for FP16 inference, +and vice versa. In particular, sparse inference operators are supported for FP16 +inference on ARM processors. ### Quantized Operators @@ -855,7 +705,8 @@ Below is the list of currently supported quantized operators: XNNPACK backend supports sparse inference for CNN models described in the [Fast Sparse ConvNets](https://arxiv.org/abs/1911.09723) paper. 
Sparse -inference is restricted to subgraphs with the following operators: +inference is restricted to subgraphs with the following floating-point +operators: * Sparse subgraph must store its weights in sparse representation (using `DENSIFY` operators in the TensorFlow Lite schema).
tensorflow/python/data/experimental/kernel_tests/service/snapshot_ft_test.py+4 −2 modified@@ -131,14 +131,16 @@ def testSnapshotRecoveryFailsWithOutOfBoundsSourceName(self): def testSnapshotRecoveryFailsWithBadSplitNames(self, bad_split_filename): cluster, _ = self.setup() write_file(os.path.join(self.source_dir(), bad_split_filename)) - with self.assertRaisesRegex(ValueError, "can't parse"): + with self.assertRaisesRegex( + ValueError, "Expected split_<local_split_index>_<global_split_index>"): cluster.restart_dispatcher() @combinations.generate(test_base.eager_only_combinations()) def testSnapshotRecoveryFailsWithOutOfOrderSplitName(self): cluster, _ = self.setup() write_file(os.path.join(self.source_dir(), "split_1_0")) - with self.assertRaisesRegex(ValueError, "found conflict"): + with self.assertRaisesRegex( + ValueError, "The local split index 1 exceeds the global split index 0"): cluster.restart_dispatcher() @combinations.generate(test_base.eager_only_combinations())
tensorflow/python/framework/BUILD+2 −0 modified@@ -1094,6 +1094,8 @@ py_library( name = "tensor_conversion_registry", srcs = ["tensor_conversion_registry.py"], srcs_version = "PY3", + # TODO(b/266747022): remove extra visibility + visibility = visibility + ["//learning/brain/experimental:__subpackages__"], deps = [ "//tensorflow/python/eager:context", ],
tensorflow/python/framework/errors_test_helper.cc+2 −2 modified@@ -21,8 +21,8 @@ PYBIND11_MODULE(_errors_test_helper, m) { m.def("TestRaiseFromStatus", [](int code) { tensorflow::Status status(static_cast<tensorflow::error::Code>(code), "test message"); - status.SetPayload("key1", "value1"); - status.SetPayload("key2", "value2"); + status.SetPayload("key1", absl::Cord("value1")); + status.SetPayload("key2", absl::Cord("value2")); MaybeRaiseRegisteredFromStatus(status); return 0; });
tensorflow/python/tpu/tpu_strategy_util.py+29 −3 modified@@ -22,6 +22,8 @@ from tensorflow.python.eager import context from tensorflow.python.eager import monitoring from tensorflow.python.eager.def_function import function +from tensorflow.python.eager.def_function import functions_run_eagerly +from tensorflow.python.eager.def_function import run_functions_eagerly from tensorflow.python.framework import device from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -111,6 +113,15 @@ def _tpu_init_fn(): # The TPU_SYSTEM device must match the device used in tpu.initialize_system # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM # devices available. + run_eagerly = functions_run_eagerly() + if run_eagerly: + logging.warning( + "It looks like tf.function behavior was disabled, perhaps using" + " tf.config.run_functions_eagerly." + " tf.tpu.experimental.initialize_tpu_system requires tf.function to" + " work. This primitive will override the disable." + ) + run_functions_eagerly(False) try: with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access output = _tpu_init_fn() @@ -120,7 +131,9 @@ def _tpu_init_fn(): None, None, "TPUs not found in the cluster. Failed in initialization: " + str(e)) - + finally: + if run_eagerly is not None: + run_functions_eagerly(run_eagerly) # Clear out the eager context caches since the memory is invalid now. context.context()._initialize_logical_devices() # pylint: disable=protected-access @@ -221,8 +234,21 @@ def _tpu_shutdown_fn(): # The TPU_SYSTEM device must match the device used in tpu.shutdown_system # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM # devices available. 
- with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access - _tpu_shutdown_fn() + run_eagerly = functions_run_eagerly() + if run_eagerly: + logging.warning( + "It looks like tf.function behavior was disabled, perhaps using" + " tf.config.run_functions_eagerly." + " tf.tpu.experimental.shutdown_tpu_system requires tf.function to" + " work. This primitive will override the disable." + ) + run_functions_eagerly(False) + try: + with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access + _tpu_shutdown_fn() + finally: + if run_eagerly is not None: + run_functions_eagerly(run_eagerly) # Clear out the eager context caches since the memory is invalid now. logging.info("Clearing out eager caches")
tensorflow/tools/ci_build/release/requirements_common.txt+3 −3 modified@@ -8,14 +8,14 @@ google_pasta ~= 0.2 h5py ~= 3.8.0 # Earliest version for Python 3.11 # TODO(b/262592253): Support older versions of NumPy for Python 3.10 and lower # to support TFX. Remove when Apache Beam upgrades to newer NumPy. -numpy ~= 1.21.4; python_version < '3.11' +numpy ~= 1.22.0; python_version < '3.11' numpy ~= 1.23.2; python_version >= '3.11' # Earliest version for Python 3.11 opt_einsum ~= 3.3.0 protobuf ~= 3.19.3 # NOTE: Earliest version for Python 3.10 six ~= 1.16.0 termcolor ~= 2.1.1 typing_extensions ~= 3.10.0.0 -wheel ~= 0.36.2 +wheel ~= 0.38.1 wrapt ~= 1.14.1 # We need to pin the gast dependency exactly @@ -37,4 +37,4 @@ scipy ~= 1.9.2; python_version >= '3.11' # Earliest version for Python 3.11 # This is usually vendored in setuptools but ensure it gets installed in CI anyway # No bound here, we prefer the one in setuptools -packaging \ No newline at end of file +packaging
tensorflow/tools/ci_build/release/requirements_mac.txt+1 −1 modified@@ -1,7 +1,7 @@ -r requirements_common.txt # Dependencies only required for Mac -certifi ~= 2020.12.5 +certifi ~= 2022.12.07 # Install build related dependencies twine ~= 3.6.0
tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt+3 −3 modified@@ -11,15 +11,15 @@ google_pasta ~= 0.2 h5py ~= 3.8.0 # Earliest version for Python 3.11 # TODO(b/262592253): Support older versions of NumPy for Python 3.10 and lower # to support TFX. Remove when Apache Beam upgrades to newer NumPy. -numpy ~= 1.21.4; python_version < '3.11' +numpy ~= 1.22.0; python_version < '3.11' numpy ~= 1.23.2; python_version >= '3.11' # Earliest version for Python 3.11 opt_einsum ~= 3.3.0 packaging ~= 21.3 protobuf ~= 3.20.1 six ~= 1.16.0 termcolor ~= 2.1.1 typing_extensions ~= 3.10.0.0 -wheel ~= 0.36.2 +wheel ~= 0.38.1 wrapt ~= 1.14.1 # We need to pin the gast dependency exactly gast == 0.4.0 @@ -48,4 +48,4 @@ twine ~= 3.6.0 # For user tool scripts junitparser ~= 2.2.0 lxml ~= 4.9.1 -pylint ~= 2.13.9 \ No newline at end of file +pylint ~= 2.13.9
tensorflow/tsl/c/tsl_status.cc+1 −1 modified@@ -36,7 +36,7 @@ void TSL_SetStatus(TSL_Status* s, TSL_Code code, const char* msg) { } void TSL_SetPayload(TSL_Status* s, const char* key, const char* value) { - s->status.SetPayload(key, value); + s->status.SetPayload(key, absl::Cord(absl::string_view(value))); } void TSL_SetStatusFromIOError(TSL_Status* s, int error_code,
tensorflow/tsl/c/tsl_status_helper_test.cc+2 −2 modified@@ -24,8 +24,8 @@ namespace { TEST(StatusHelper, TestStatusHelper) { TSL_Status* s = TSL_NewStatus(); Status cc_status(errors::InvalidArgument("some error")); - cc_status.SetPayload("key1", "value1"); - cc_status.SetPayload("key2", "value2"); + cc_status.SetPayload("key1", absl::Cord("value1")); + cc_status.SetPayload("key2", absl::Cord("value2")); Set_TSL_Status_from_Status(s, cc_status); ASSERT_EQ(TSL_INVALID_ARGUMENT, TSL_GetCode(s)); ASSERT_EQ(std::string("some error"), TSL_Message(s));
tensorflow/tsl/distributed_runtime/coordination/coordination_service_error_util.h+5 −3 modified@@ -29,7 +29,7 @@ constexpr absl::string_view CoordinationErrorPayloadKey() { // Mark error as a coordination service error (as opposed to RPC // errors). inline Status MakeCoordinationError(Status s) { - s.SetPayload(CoordinationErrorPayloadKey(), ""); + s.SetPayload(CoordinationErrorPayloadKey(), absl::Cord("")); return s; } @@ -43,14 +43,16 @@ inline Status MakeCoordinationError(Status s, tensorflow::CoordinationServiceError error; *error.mutable_source_task() = origin; error.set_is_reported_error(is_reported_error); - s.SetPayload(CoordinationErrorPayloadKey(), error.SerializeAsString()); + s.SetPayload(CoordinationErrorPayloadKey(), + absl::Cord(error.SerializeAsString())); return s; } // Mark error as a coordination service error with payload. inline Status MakeCoordinationError( Status s, const tensorflow::CoordinationServiceError& payload) { - s.SetPayload(CoordinationErrorPayloadKey(), payload.SerializeAsString()); + s.SetPayload(CoordinationErrorPayloadKey(), + absl::Cord(payload.SerializeAsString())); return s; } } // namespace tsl
tensorflow/tsl/distributed_runtime/rpc/grpc_util.h+3 −3 modified@@ -71,12 +71,12 @@ inline void InsertSerializedPayloads(Status& s, std::string payloads) { tensorflow::distributed_runtime::GrpcPayloadContainer container; if (container.ParseFromString(payloads)) { for (const auto& key_val : container.payloads()) { - s.SetPayload(key_val.first, key_val.second); + s.SetPayload(key_val.first, absl::Cord(key_val.second)); } } else { s.SetPayload(kGrpcPayloadsLost, - tensorflow::distributed_runtime::GrpcPayloadsLost() - .SerializeAsString()); + absl::Cord(tensorflow::distributed_runtime::GrpcPayloadsLost() + .SerializeAsString())); } }
tensorflow/tsl/distributed_runtime/rpc/grpc_util_test.cc+1 −1 modified@@ -71,7 +71,7 @@ TestRequest MakeProto(int size) { TEST(PayloadSerialization, PayloadsAreTransmitted) { Status status = errors::InvalidArgument("invalid arg message"); - status.SetPayload("a", "\\xFF\\x02\\x03"); + status.SetPayload("a", absl::Cord("\\xFF\\x02\\x03")); Status status_recovered = FromGrpcStatus(ToGrpcStatus(status)); ASSERT_TRUE(status_recovered.GetPayload("a").has_value());
tensorflow/tsl/platform/errors.h+3 −2 modified@@ -21,6 +21,7 @@ limitations under the License. #include <utility> #include "absl/base/attributes.h" +#include "absl/strings/cord.h" #include "absl/strings/str_join.h" #include "tensorflow/tsl/platform/logging.h" #include "tensorflow/tsl/platform/macros.h" @@ -102,15 +103,15 @@ inline void InsertPayloads( ::tsl::Status& status, const std::unordered_map<std::string, std::string>& payloads) { for (const auto& payload : payloads) { - status.SetPayload(payload.first, payload.second); + status.SetPayload(payload.first, absl::Cord(payload.second)); } } // Copies all payloads from one Status to another. Will overwrite existing // payloads in the destination if they exist with the same key. inline void CopyPayloads(const ::tsl::Status& from, ::tsl::Status& to) { from.ForEachPayload([&to](tsl::StringPiece key, tsl::StringPiece value) { - to.SetPayload(key, value); + to.SetPayload(key, absl::Cord(value)); }); }
tensorflow/tsl/platform/test.h+7 −0 modified@@ -85,6 +85,13 @@ int RandomSeed(); // NOTE: This function is not thread-safe. int PickUnusedPortOrDie(); +// Constant which is false internally and true in open source. +#ifdef PLATFORM_GOOGLE +inline constexpr bool kIsOpenSource = false; +#else +inline constexpr bool kIsOpenSource = true; +#endif // PLATFORM_GOOGLE + } // namespace testing } // namespace tsl
tensorflow/tsl/profiler/lib/BUILD+2 −0 modified@@ -213,6 +213,7 @@ cc_library( deps = [ "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:macros", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], ) @@ -226,6 +227,7 @@ tsl_cc_test( "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], )
tensorflow/tsl/profiler/lib/traceme_encode.h+13 −5 modified@@ -20,6 +20,7 @@ limitations under the License. #include <initializer_list> #include <string> +#include "absl/base/attributes.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -31,14 +32,21 @@ namespace profiler { // An argument passed to TraceMeEncode. struct TraceMeArg { - // This constructor is required because absl::AlphaNum is non-copyable. - template <typename Value> - TraceMeArg(absl::string_view k, Value v) : key(k), value(v) {} + // String conversions of value types are supported via AlphaNum. We keep a + // reference to the AlphaNum's internal buffer here, so it must remain valid + // for the lifetime of this object. We cannot store it by value because it is + // not safe to construct an AlphaNum as a member of a class, particularly when + // AbslStringify is being used (it may reference default arguments that are on + // the caller's stack, if we constructed it here those default arguments would + // be destroyed before they are used). + TraceMeArg(absl::string_view k, + const absl::AlphaNum& v ABSL_ATTRIBUTE_LIFETIME_BOUND) + : key(k), value(v.Piece()) {} TF_DISALLOW_COPY_AND_ASSIGN(TraceMeArg); absl::string_view key; - absl::AlphaNum value; + absl::string_view value; }; namespace traceme_internal { @@ -74,7 +82,7 @@ TF_ATTRIBUTE_ALWAYS_INLINE inline std::string AppendArgs( for (const auto& arg : args) { out = Append(out, arg.key); *out++ = '='; - out = Append(out, arg.value.Piece()); + out = Append(out, arg.value); *out++ = ','; } *(out - 1) = '#';
tensorflow/tsl/profiler/lib/traceme_encode_test.cc+22 −0 modified@@ -17,6 +17,7 @@ limitations under the License. #include <string> #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/tsl/platform/platform.h" #include "tensorflow/tsl/platform/test.h" @@ -53,6 +54,27 @@ TEST(TraceMeEncodeTest, TemporaryStringTest) { } #endif +// This can be removed when the absl version has been updated to include +// AbslStringify for open source builds. +#if defined(PLATFORM_GOOGLE) + +struct Point { + template <typename Sink> + friend void AbslStringify(Sink& sink, const Point& p) { + absl::Format(&sink, "(%d, %d)", p.x, p.y); + } + + int x; + int y; +}; + +TEST(TraceMeEncodeTest, AbslStringifyTest) { + EXPECT_EQ(TraceMeEncode("Plot", {{"point", Point{10, 20}}}), + "Plot#point=(10, 20)#"); +} + +#endif + TEST(TraceMeEncodeTest, NoNameTest) { EXPECT_EQ(TraceMeEncode({{"context", "World"}, {"request_id", 42}}), "#context=World,request_id=42#");
third_party/tf_runtime/workspace.bzl+2 −2 modified@@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "c1248a015d23949afa2471bb21f6f52850aead7d" - TFRT_SHA256 = "8cdd8ea905478ac4ffd36ffb39cebe288d3b840d71a02d418bc6a8a760f92af8" + TFRT_COMMIT = "c653281a1a23c0c3d41536a983c7d10fcc5b1fbf" + TFRT_SHA256 = "3d1edd27c4e36d9cfc9493aef7088489babb370d2a7955bab3545acfbb024ccf" tf_http_archive( name = "tf_runtime",
Vulnerability mechanics
Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References
- github.com/advisories/GHSA-68v3-g9cm-rmm6 (GHSA advisory)
- nvd.nist.gov/vuln/detail/CVE-2023-25658 (NVD advisory)
- github.com/tensorflow/tensorflow/commit/ff459137c2716a2a60f7d441b855fcb466d778cb (fix commit; x_refsource_MISC, web)
- github.com/tensorflow/tensorflow/security/advisories/GHSA-68v3-g9cm-rmm6 (vendor advisory; x_refsource_CONFIRM, web)
News mentions
0 — no linked articles in our index yet.