VYPR
Moderate severity · NVD Advisory · Published Mar 20, 2025 · Updated Mar 20, 2025

CSRF in mlflow/mlflow

CVE-2025-1473

Description

A Cross-Site Request Forgery (CSRF) vulnerability exists in the Signup feature of mlflow/mlflow versions 2.17.0 to 2.20.1. This vulnerability allows an attacker to create a new account, which may be used to perform unauthorized actions on behalf of the malicious user.

AI Insight

LLM-synthesized narrative grounded in this CVE's description and references.

A CSRF vulnerability in MLflow's Signup feature allows an attacker to create unauthorized accounts.

A Cross-Site Request Forgery (CSRF) vulnerability exists in the Signup feature of MLflow versions 2.17.0 to 2.20.1 [1][2]. This flaw allows an attacker to forge a request that, when executed by an authenticated user's browser, creates a new account controlled by the attacker. The root cause is the absence of anti-CSRF tokens or other origin-validation checks on the signup endpoint.
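The missing control can be illustrated with a session-bound anti-CSRF token. The following is a minimal sketch using only the Python standard library, not MLflow's actual fix; `SECRET_KEY`, `issue_csrf_token`, and `is_valid_csrf_token` are hypothetical names introduced for illustration.

```python
import hashlib
import hmac
import secrets
from typing import Optional

# Hypothetical server-side secret; a real deployment would load it from configuration.
SECRET_KEY = secrets.token_bytes(32)


def issue_csrf_token(session_id: str) -> str:
    """Return a token bound to the session, to be embedded in the signup form."""
    return hmac.new(SECRET_KEY, session_id.encode(), hashlib.sha256).hexdigest()


def is_valid_csrf_token(session_id: str, submitted: Optional[str]) -> bool:
    """Accept the request only if the submitted token matches the session's token."""
    if not submitted:
        return False
    expected = issue_csrf_token(session_id)
    # Constant-time comparison avoids leaking how many characters matched.
    return hmac.compare_digest(expected.encode(), submitted.encode())
```

A page on another origin cannot read the token out of the legitimate form, so a forged signup request arrives without a valid token and can be rejected.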

To exploit this vulnerability, an attacker must trick a victim user who is authenticated to an MLflow instance into visiting a malicious page. No additional authentication is needed for the attacker, as the forged request leverages the victim's existing session. The attack can be performed remotely via any standard CSRF vector, such as embedding a form submission or a cross-origin image request.
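Because the forged request originates from a page on a different site, a server can also refuse state-changing requests whose `Origin` header does not match its own origin. The sketch below shows one way such a check might look; `is_same_origin` and the example host names are hypothetical, and the comparison is deliberately strict (an explicit default port such as `:443` is treated as a mismatch).

```python
from typing import Optional
from urllib.parse import urlsplit

SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}


def is_same_origin(method: str, origin_header: Optional[str], expected_origin: str) -> bool:
    """Return True if the request may proceed under a strict same-origin policy."""
    if method.upper() in SAFE_METHODS:
        return True  # read-only methods are not expected to change state
    if not origin_header:
        return False  # state-changing request without an Origin header: reject
    got = urlsplit(origin_header)
    want = urlsplit(expected_origin)
    return (got.scheme, got.hostname, got.port) == (want.scheme, want.hostname, want.port)
```

For example, `is_same_origin("POST", "https://evil.example", "https://mlflow.example.com")` returns `False`, while the same call with a matching origin returns `True`.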

Successful exploitation enables the attacker to create a new account on the MLflow instance. This rogue account can then be used to perform any actions permitted to a normal user, such as accessing or modifying models, experiments, and other artifacts, depending on the instance's permissions. This could lead to data exposure, model tampering, or further privilege escalation if combined with other vulnerabilities.

The MLflow maintainers have addressed this vulnerability; users should update to a version containing the fix [1]. As of the publication date, no workaround has been provided, and upgrading is the recommended mitigation. The vulnerability is not listed on CISA's Known Exploited Vulnerabilities (KEV) catalog.
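The commit listed under Patches documents an environment variable, `MLFLOW_FLASK_SERVER_SECRET_KEY`, as necessary for enabling CSRF protection on the signup page when basic authentication is enabled. A minimal sketch of generating such a key follows; how and when the server reads the variable is an assumption here, and a real deployment would store the key persistently rather than regenerate it on each start.

```python
import os
import secrets

# Generate a random 256-bit key, hex-encoded (64 characters).
secret_key = secrets.token_hex(32)

# Variable name taken from the patch's mlflow/environment_variables.py.
# Assumption: it must be set in the environment of the server process before it starts.
os.environ["MLFLOW_FLASK_SERVER_SECRET_KEY"] = secret_key
```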

AI Insight generated on May 20, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.

Affected packages

Versions sourced from the GitHub Security Advisory.

Package | Affected versions | Patched versions
mlflow (PyPI) | >= 2.17.0, < 2.20.3 | 2.20.3
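The range above can be checked programmatically. This is a rough sketch that handles only plain `X.Y.Z` release strings (pre-release or local version suffixes are not parsed); `parse_release` and `is_affected` are hypothetical helper names.

```python
from typing import Tuple

AFFECTED_MIN = (2, 17, 0)   # inclusive lower bound from the advisory
FIRST_PATCHED = (2, 20, 3)  # exclusive upper bound (first patched release)


def parse_release(version: str) -> Tuple[int, ...]:
    """Parse a plain 'X.Y.Z' release string into a tuple of integers."""
    return tuple(int(part) for part in version.split("."))


def is_affected(version: str) -> bool:
    """Return True if the version falls inside the advisory's affected range."""
    return AFFECTED_MIN <= parse_release(version) < FIRST_PATCHED
```

The installed version string could be obtained with `importlib.metadata.version("mlflow")` and passed to `is_affected`.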

Affected products

4
  • Mlflow/Mlflow (llm-fuzzy)
    Range: >=2.17.0, <=2.20.1
  • osv-coords (2 versions)
    • (no CPE) range: >= 2.17.0, < 2.20.1
    • (no CPE) range: >= 2.17.0, < 2.20.3
  • mlflow/mlflow/mlflow (v5)
    Range: unspecified

Patches

1
ecfa61cb43d3

Merge branch 'master' into daniellok-db/fix-huntr

https://github.com/mlflow/mlflow · Daniel Lok · Feb 5, 2025 · via ghsa
36 files changed · +581 -270
  • dev/build.py +8 -1 modified
    @@ -26,7 +26,14 @@ def restore_changes():
         try:
             yield
         finally:
    -        subprocess.check_call(["git", "restore", ":^dev/build.py"])
    +        subprocess.check_call(
    +            [
    +                "git",
    +                "restore",
    +                "README.md",
    +                "pyproject.toml",
    +            ]
    +        )
     
     
     def main():
    
  • docs/docs/llms/dspy/index.mdx +22 -20 modified
    @@ -48,25 +48,41 @@ use a DSPy program.
     
     ## Concepts
     
    -### `Module`
    +#### `Module`
     
     [Modules](https://dspy.ai/learn/programming/modules) are components that handle specific text transformations, like answering questions or summarizing. They replace traditional hand-written prompts and can learn from examples, making them more adaptable.
     
    -### `Signature`
    +#### `Signature`
     
     A [signature](https://dspy.ai/learn/programming/signatures) is a natural language description of a module's input and output behavior. For example, _"question -> answer"_ specifies that the module should take a question as input and return an answer.
     
    -### `Optimizer`
    +#### `Optimizer`
     
     A [optimizer](https://dspy.ai/learn/optimization/optimizers) improves LM pipelines by adjusting modules to meet a performance metric, either by generating better prompts or fine-tuning models.
     
    -### `Program`
    +#### `Program`
     
     A program is a a set of modules connected into a pipeline to perform complex tasks. DSPy programs are flexible, allowing you to optimize and adapt them using the compiler.
     
    -## Usage
     
    -### Saving and Loading DSPy Program in MLflow Experiment
    +## Automatic Tracing
    +
    +![DSPy Tracing via autolog](/images/llms/tracing/dspy-tracing.gif)
    +
    +
    +[MLflow Tracing](/tracing) tracing is a powerful feature that allows you to monitor and debug your DSPy programs. With MLflow, you can enable auto tracing just by calling the <APILink fn="mlflow.dspy.autolog" /> function in your code.
    +
    +```python
    +import mlflow
    +
    +mlflow.dspy.autolog()
    +```
    +
    +Once enabled, MLflow will generate traces whenever your DSPy program is executed and record them in your MLflow Experiment.
    +
    +Learn more about MLflow DSPy tracing capabilities [here](/tracing/integrations/dspy).
    +
    +## Tracking DSPy Program in MLflow Experiment
     
     #### Creating a DSPy Program
     
    @@ -172,20 +188,6 @@ To load the DSPy program itself back instead of the PyFunc-wrapped model, use th
     model = mlflow.dspy.load_model(model_uri)
     ```
     
    -### Enabling and Disabling Auto Tracing for DSPy Programs
    -
    -Auto tracing is a powerful feature that allows you to monitor and debug your DSPy programs. With MLflow, you can enable auto tracing just by calling the <APILink fn="mlflow.dspy.autolog" /> function in your code.
    -
    -```python
    -import mlflow
    -
    -mlflow.dspy.autolog()
    -```
    -
    -Once enabled, MLflow will generate traces whenever your DSPy program is executed and record them in your MLflow Experiment.
    -
    -You can disable auto-tracing for DSPy by calling _mlflow.dspy.autolog(disabled=True)_.
    -
     ## FAQ
     
     ### How can I save a compiled vs. uncompiled model?
    
  • docs/docs/tracing/integrations/dspy.mdx +2 -2 modified
    @@ -11,7 +11,7 @@ import TabItem from "@theme/TabItem";
     
     # Tracing DSPy🧩
     
    -![LlamaIndex Tracing via autolog](/images/llms/tracing/llamaindex-tracing.gif)
    +![DSPy Tracing via autolog](/images/llms/tracing/dspy-tracing.gif)
     
     [DSPy](https://dspy.ai/) is an open-source framework for building modular AI systems and offers algorithms for optimizing their prompts and weights.
     
    @@ -21,7 +21,7 @@ for DSPy by calling the <APILink fn="mlflow.dspy.autolog" /> function, and neste
     ```python
     import mlflow
     
    -mlflow.llama_index.autolog()
    +mlflow.dspy.autolog()
     ```
     
     :::tip
    
  • docs/static/images/llms/tracing/dspy-tracing.gif +0 -0 added
  • docs/static/images/llms/tracing/dspy-tracing.png +0 -0 removed
  • mlflow/entities/trace_data.py +21 -0 modified
    @@ -2,6 +2,7 @@
     from typing import Any, Optional
     
     from mlflow.entities import Span
    +from mlflow.tracing.constant import SpanAttributeKey
     
     
     @dataclass
    @@ -36,3 +37,23 @@ def to_dict(self) -> dict[str, Any]:
                 "request": self.request,
                 "response": self.response,
             }
    +
    +    @property
    +    def intermediate_outputs(self) -> Optional[dict[str, Any]]:
    +        """
    +        Returns intermediate outputs produced by the model or agent while handling the request.
    +        There are mainly two flows to return intermediate outputs:
    +        1. When a trace is generate by the `mlflow.log_trace` API,
    +        return `intermediate_outputs` attribute of the span.
    +        2. When a trace is created normally with a tree of spans,
    +        aggregate the outputs of non-root spans.
    +        """
    +        root_span = self._get_root_span()
    +        if root_span and root_span.get_attribute(SpanAttributeKey.INTERMEDIATE_OUTPUTS):
    +            return root_span.get_attribute(SpanAttributeKey.INTERMEDIATE_OUTPUTS)
    +        # TODO: handle the second case for a normal trace with spans
    +
    +    def _get_root_span(self) -> Optional[Span]:
    +        for span in self.spans:
    +            if span.parent_id is None:
    +                return span
    
  • mlflow/entities/trace.py +0 -21 modified
    @@ -12,7 +12,6 @@
     from mlflow.entities.trace_info import TraceInfo
     from mlflow.exceptions import MlflowException
     from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
    -from mlflow.tracing.constant import SpanAttributeKey
     
     _logger = logging.getLogger(__name__)
     
    @@ -232,23 +231,3 @@ def pandas_dataframe_columns() -> list[str]:
                 "spans",
                 "tags",
             ]
    -
    -    @property
    -    def intermediate_outputs(self) -> Optional[dict[str, Any]]:
    -        """
    -        Returns intermediate outputs within the trace.
    -        There are mainly two flows to return intermediate outputs:
    -        1. When a trace only has one root span,
    -        return `intermediate_outputs` attribute of the span.
    -        2. When a trace is created normally with a tree of spans,
    -        aggregate the outputs of non-root spans.
    -        """
    -        root_span = self._get_root_span()
    -        if root_span and root_span.get_attribute(SpanAttributeKey.INTERMEDIATE_OUTPUTS):
    -            return root_span.get_attribute(SpanAttributeKey.INTERMEDIATE_OUTPUTS)
    -        # TODO: handle the second case for a normal trace with spans
    -
    -    def _get_root_span(self) -> Optional[Span]:
    -        for span in self.data.spans:
    -            if span.parent_id is None:
    -                return span
    
  • mlflow/environment_variables.py +11 -3 modified
    @@ -551,7 +551,9 @@ def get(self):
     #: (default: ``True``)
     MLFLOW_ALLOW_HTTP_REDIRECTS = _BooleanEnvironmentVariable("MLFLOW_ALLOW_HTTP_REDIRECTS", True)
     
    -# Specifies the timeout for deployment client APIs to declare a request has timed out
    +#: Specifies the client-based timeout (in seconds) when making an HTTP request to a deployment
    +#: target. Used within the `predict` and `predict_stream` APIs.
    +#: (default: ``120``)
     MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT = _EnvironmentVariable(
         "MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT", int, 120
     )
    @@ -711,6 +713,12 @@ def get(self):
         "_MLFLOW_IS_IN_SERVING_ENVIRONMENT", None
     )
     
    -# Secret key for the Flask app. This is necessary for enabling CSRF protection
    -# in the UI signup page when running the app with basic authentication enabled
    +#: Secret key for the Flask app. This is necessary for enabling CSRF protection
    +#: in the UI signup page when running the app with basic authentication enabled
     MLFLOW_FLASK_SERVER_SECRET_KEY = _EnvironmentVariable("MLFLOW_FLASK_SERVER_SECRET_KEY", str, None)
    +
    +#: Specifies the max length (in chars) of an experiment's artifact location.
    +#: The default is 2048.
    +MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH = _EnvironmentVariable(
    +    "MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH", int, 2048
    +)
    
  • mlflow/__init__.py +2 -0 modified
    @@ -130,6 +130,7 @@
         get_current_active_span,
         get_last_active_trace,
         get_trace,
    +    log_trace,
         search_traces,
         start_span,
         trace,
    @@ -226,6 +227,7 @@
         "log_params",
         "log_table",
         "log_text",
    +    "log_trace",
         "login",
         "pyfunc",
         "register_model",
    
  • mlflow/ml-package-versions.yml +2 -0 modified
    @@ -740,6 +740,8 @@ langchain:
               # Required to run tests/openai/mock_openai.py
               "fastapi",
               "uvicorn",
    +          # Some model logging/loading requires langchain community
    +          "langchain-community",
             ]
           "< 0.3.0": ["langchain_openai<0.2.0"]
           ">= 0.3.0": [
    
  • mlflow/openai/__init__.py +12 -0 modified
    @@ -899,6 +899,18 @@ def autolog(
                 patched_call,
             )
     
    +    try:
    +        from openai.resources.beta.chat.completions import Completions as BetaChatCompletions
    +    except ImportError:
    +        pass
    +    else:
    +        safe_patch(
    +            FLAVOR_NAME,
    +            BetaChatCompletions,
    +            "parse",
    +            patched_call,
    +        )
    +
         # Patch Swarm agent to generate traces
         try:
             from swarm import Swarm
    
  • mlflow/openai/_openai_autolog.py +10 -1 modified
    @@ -73,7 +73,7 @@ def _set_api_key_env_var(client):
             os.environ.pop("OPENAI_API_KEY")
     
     
    -def _get_span_type(task) -> str:
    +def _get_span_type(task: type) -> str:
         from openai.resources.chat.completions import Completions as ChatCompletions
         from openai.resources.completions import Completions
         from openai.resources.embeddings import Embeddings
    @@ -83,6 +83,15 @@ def _get_span_type(task) -> str:
             Completions: SpanType.LLM,
             Embeddings: SpanType.EMBEDDING,
         }
    +
    +    try:
    +        # Only available in openai>=1.40.0
    +        from openai.resources.beta.chat.completions import Completions as BetaChatCompletions
    +
    +        span_type_mapping[BetaChatCompletions] = SpanType.CHAT_MODEL
    +    except ImportError:
    +        pass
    +
         return span_type_mapping.get(task, SpanType.UNKNOWN)
     
     
    
  • mlflow/pyfunc/scoring_server/__init__.py +1 -1 modified
    @@ -512,7 +512,7 @@ async def transformation(request: Request):
     
             data = await request.body()
             content_type = request.headers.get("content-type")
    -        # TODO: convert "invocations" to an async method to make iternal logic fully non-blocking.
    +        # TODO: convert "invocations" to an async method to make internal logic fully non-blocking.
             result = await asyncio.to_thread(invocations, data, content_type, model, input_schema)
     
             return Response(
    
  • mlflow/pytorch/__init__.py +10 -1 modified
    @@ -695,7 +695,7 @@ def load_model(model_uri, dst_path=None, **kwargs):
         return _load_model(path=torch_model_artifacts_path, **kwargs)
     
     
    -def _load_pyfunc(path, model_config=None):  # noqa: D417
    +def _load_pyfunc(path, model_config=None, weights_only=False):  # noqa: D417
         """
         Load PyFunc implementation. Called by ``pyfunc.load_model``.
     
    @@ -717,6 +717,15 @@ def _load_pyfunc(path, model_config=None):  # noqa: D417
             else:
                 device = _TORCH_CPU_DEVICE_NAME
     
    +    # in pytorch >= 2.6.0, the `weights_only` kwarg default has been changed from
    +    # `False` to `True`. this can cause pickle deserialization errors when loading
    +    # models, unless the model classes have been explicitly marked as safe using
    +    # `torch.serialization.add_safe_globals()`
    +    if Version(torch.__version__) >= Version("2.6.0"):
    +        return _PyTorchWrapper(
    +            _load_model(path, device=device, weights_only=weights_only), device=device
    +        )
    +
         return _PyTorchWrapper(_load_model(path, device=device), device=device)
     
     
    
  • mlflow/store/tracking/file_store.py +5 -0 modified
    @@ -93,6 +93,7 @@
     from mlflow.utils.validation import (
         _validate_batch_log_data,
         _validate_batch_log_limits,
    +    _validate_experiment_artifact_location_length,
         _validate_experiment_id,
         _validate_experiment_name,
         _validate_metric,
    @@ -403,6 +404,10 @@ def _validate_experiment_does_not_exist(self, name):
         def create_experiment(self, name, artifact_location=None, tags=None):
             self._check_root_dir()
             _validate_experiment_name(name)
    +
    +        if artifact_location:
    +            _validate_experiment_artifact_location_length(artifact_location)
    +
             self._validate_experiment_does_not_exist(name)
             experiment_id = _generate_unique_integer_id()
             return self._create_experiment_with_id(name, str(experiment_id), artifact_location, tags)
    
  • mlflow/store/tracking/sqlalchemy_store.py +2 -0 modified
    @@ -87,6 +87,7 @@
         _validate_batch_log_data,
         _validate_batch_log_limits,
         _validate_dataset_inputs,
    +    _validate_experiment_artifact_location_length,
         _validate_experiment_name,
         _validate_experiment_tag,
         _validate_metric,
    @@ -267,6 +268,7 @@ def create_experiment(self, name, artifact_location=None, tags=None):
             _validate_experiment_name(name)
             if artifact_location:
                 artifact_location = resolve_uri_if_local(artifact_location)
    +            _validate_experiment_artifact_location_length(artifact_location)
             with self.ManagedSessionMaker() as session:
                 try:
                     creation_time = get_current_time_millis()
    
  • mlflow/tracing/constant.py +4 -2 modified
    @@ -26,8 +26,10 @@ class SpanAttributeKey:
         # such as evaluation
         CHAT_MESSAGES = "mlflow.chat.messages"
         CHAT_TOOLS = "mlflow.chat.tools"
    -    # This attribute is not empty only on the root span.
    -    # Used to populate `intermediate_output` field of a trace.
    +    # This attribute is used to populate `intermediate_outputs` property of a trace data
    +    # representing intermediate outputs of the trace. This attribute is not empty only on
    +    # the root span of a trace created by the `mlflow.log_trace` API. The `intermediate_outputs`
    +    # property of the normal trace is generated by the outputs of non-root spans.
         INTERMEDIATE_OUTPUTS = "mlflow.trace.intermediate_outputs"
     
     
    
  • mlflow/tracing/fluent.py +79 -0 modified
    @@ -727,6 +727,85 @@ def predict(input):
             )
     
     
    +@experimental
    +def log_trace(
    +    name: str = "Task",
    +    request: Optional[Any] = None,
    +    response: Optional[Any] = None,
    +    intermediate_outputs: Optional[dict[str, Any]] = None,
    +    attributes: Optional[dict[str, Any]] = None,
    +    tags: Optional[dict[str, str]] = None,
    +    start_time_ms: Optional[int] = None,
    +    execution_time_ms: Optional[int] = None,
    +) -> str:
    +    """
    +    Create a trace with a single root span.
    +    This API is useful when you want to log an arbitrary (request, response) pair
    +    without structured OpenTelemetry spans. The trace is linked to the active experiment.
    +
    +    Args:
    +        name: The name of the trace (and the root span). Default to "Task".
    +        request: Input data for the entire trace. This is also set on the root span of the trace.
    +        response: Output data for the entire trace. This is also set on the root span of the trace.
    +        intermediate_outputs: A dictionary of intermediate outputs produced by the model or agent
    +            while handling the request. Keys are the names of the outputs,
    +            and values are the outputs themselves. Values must be JSON-serializable.
    +        attributes: A dictionary of attributes to set on the root span of the trace.
    +        tags: A dictionary of tags to set on the trace.
    +        start_time_ms: The start time of the trace in milliseconds since the UNIX epoch.
    +            When not specified, current time is used for start and end time of the trace.
    +        execution_time_ms: The execution time of the trace in milliseconds since the UNIX epoch.
    +
    +    Returns:
    +        The request ID of the logged trace.
    +
    +    Example:
    +
    +    .. code-block:: python
    +        :test:
    +
    +        import time
    +        import mlflow
    +
    +        request_id = mlflow.log_trace(
    +            request="Does mlflow support tracing?",
    +            response="Yes",
    +            intermediate_outputs={
    +                "retrieved_documents": ["mlflow documentation"],
    +                "system_prompt": ["answer the question with yes or no"],
    +            },
    +            start_time_ms=int(time.time() * 1000),
    +            execution_time_ms=5129,
    +        )
    +        trace = mlflow.get_trace(request_id)
    +
    +        print(trace.data.intermediate_outputs)
    +    """
    +    client = MlflowClient()
    +    if intermediate_outputs:
    +        if attributes:
    +            attributes.update(SpanAttributeKey.INTERMEDIATE_OUTPUTS, intermediate_outputs)
    +        else:
    +            attributes = {SpanAttributeKey.INTERMEDIATE_OUTPUTS: intermediate_outputs}
    +
    +    span = client.start_trace(
    +        name=name,
    +        inputs=request,
    +        attributes=attributes,
    +        tags=tags,
    +        start_time_ns=start_time_ms * 1000000 if start_time_ms else None,
    +    )
    +    client.end_trace(
    +        request_id=span.request_id,
    +        outputs=response,
    +        end_time_ns=(start_time_ms + execution_time_ms) * 1000000
    +        if start_time_ms and execution_time_ms
    +        else None,
    +    )
    +
    +    return span.request_id
    +
    +
     def _merge_trace(
         trace: Trace,
         target_request_id: str,
    
  • mlflow/tracing/processor/inference_table.py +5 -26 modified
    @@ -21,17 +21,6 @@
     _logger = logging.getLogger(__name__)
     
     
    -_HEADER_REQUEST_ID_KEY = "X-Request-Id"
    -
    -
    -# Extracting for testing purposes
    -def _get_flask_request():
    -    import flask
    -
    -    if flask.has_request_context():
    -        return flask.request
    -
    -
     class InferenceTableSpanProcessor(SimpleSpanProcessor):
         """
         Defines custom hooks to be executed when a span is started or ended (before exporting).
    @@ -56,21 +45,11 @@ def on_start(self, span: OTelSpan, parent_context: Optional[Context] = None):
             """
             request_id = maybe_get_request_id()
             if request_id is None:
    -            # If this is invoked outside of a flask request, it raises error about
    -            # outside of request context. We should avoid this by skipping the trace processing
    -            if flask_request := _get_flask_request():
    -                request_id = flask_request.headers.get(_HEADER_REQUEST_ID_KEY)
    -                if not request_id:
    -                    _logger.warning(
    -                        "Request ID not found in the request headers. Skipping trace processing."
    -                    )
    -                    return
    -            else:
    -                _logger.warning(
    -                    "Failed to get request ID from the request headers because "
    -                    "request context is not available. Skipping trace processing."
    -                )
    -                return
    +            _logger.warning(
    +                "Failed to get request ID from the request headers because "
    +                "request context is not available. Skipping trace processing."
    +            )
    +
             span.set_attribute(SpanAttributeKey.REQUEST_ID, json.dumps(request_id))
             tags = {}
             if dependencies_schema := maybe_get_dependencies_schemas():
    
  • mlflow/utils/exception_utils.py +1 -1 modified
    @@ -9,6 +9,6 @@ def get_stacktrace(error):
                 tb = traceback.format_exception(error.__class__, error, error.__traceback__)
             else:
                 tb = traceback.format_exception(error)
    -        return (msg + "\n\n".join(tb)).strip()
    +        return (msg + "".join(tb)).strip()
         except Exception:
             return msg
    
  • mlflow/utils/validation.py +15 -1 modified
    @@ -9,7 +9,10 @@
     import re
     
     from mlflow.entities import Dataset, DatasetInput, InputTag, Param, RunTag
    -from mlflow.environment_variables import MLFLOW_TRUNCATE_LONG_VALUES
    +from mlflow.environment_variables import (
    +    MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH,
    +    MLFLOW_TRUNCATE_LONG_VALUES,
    +)
     from mlflow.exceptions import MlflowException
     from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
     from mlflow.store.db.db_types import DATABASE_ENGINES
    @@ -617,3 +620,14 @@ def _validate_trace_tag(key, value):
         key = _validate_length_limit("key", MAX_TRACE_TAG_KEY_LENGTH, key)
         value = _validate_length_limit("value", MAX_TRACE_TAG_VAL_LENGTH, value, truncate=True)
         return key, value
    +
    +
    +def _validate_experiment_artifact_location_length(artifact_location: str):
    +    max_length = MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH.get()
    +    if len(artifact_location) > max_length:
    +        raise MlflowException(
    +            "Invalid artifact path length. The length of the artifact path cannot be "
    +            f"greater than {max_length} characters. To configure this limit, please set the "
    +            "MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH environment variable.",
    +            INVALID_PARAMETER_VALUE,
    +        )
    
  • pyproject.skinny.toml +1 -1 modified
    @@ -35,7 +35,7 @@ dependencies = [
       "opentelemetry-sdk<3,>=1.9.0",
       "packaging<25",
       "protobuf<6,>=3.12.0",
    -  "pydantic<3,>=1.0",
    +  "pydantic<3,>=1.10.8",
       "pyyaml<7,>=5.1",
       "requests<3,>=2.17.3",
       "sqlparse<1,>=0.4.0",
    
  • pyproject.toml +1 -1 modified
    @@ -50,7 +50,7 @@ dependencies = [
       "pandas<3",
       "protobuf<6,>=3.12.0",
       "pyarrow<19,>=4.0.0",
    -  "pydantic<3,>=1.0",
    +  "pydantic<3,>=1.10.8",
       "pyyaml<7,>=5.1",
       "requests<3,>=2.17.3",
       "scikit-learn<2",
    
  • requirements/skinny-requirements.txt +1 -1 modified
    @@ -15,5 +15,5 @@ cachetools<6,>=5.0.0
     opentelemetry-api<3,>=1.9.0
     opentelemetry-sdk<3,>=1.9.0
     databricks-sdk<1,>=0.20.0
    -pydantic<3,>=1.0
    +pydantic<3,>=1.10.8
     typing-extensions<5,>=4.0.0
    
  • requirements/skinny-requirements.yaml +1 -1 modified
    @@ -77,7 +77,7 @@ databricks-sdk:
     
     pydantic:
       pip_release: pydantic
    -  minimum: "1.0"
    +  minimum: "1.10.8"
       max_major_version: 2
     
     typing-extensions:
    
  • tests/entities/test_trace_data.py +27 -0 modified
    @@ -123,3 +123,30 @@ def always_fail(self):
         # Convert back from dict to TraceData and compare
         trace_data_from_dict = TraceData.from_dict(trace_data.to_dict())
         assert trace_data.to_dict() == trace_data_from_dict.to_dict()
    +
    +
    +def test_intermediate_outputs_from_attribute():
    +    intermediate_outputs = {
    +        "retrieved_documents": ["document 1", "document 2"],
    +        "generative_prompt": "prompt",
    +    }
    +
    +    def run():
    +        with mlflow.start_span(name="run") as span:
    +            span.set_attribute("mlflow.trace.intermediate_outputs", intermediate_outputs)
    +
    +    run()
    +    trace = mlflow.get_last_active_trace()
    +
    +    assert trace.data.intermediate_outputs == intermediate_outputs
    +
    +
    +def test_intermediate_outputs_no_value():
    +    def run():
    +        with mlflow.start_span(name="run") as span:
    +            span.set_outputs(1)
    +
    +    run()
    +    trace = mlflow.get_last_active_trace()
    +
    +    assert trace.data.intermediate_outputs is None
    
  • tests/entities/test_trace.py +0 -27 modified
    @@ -296,30 +296,3 @@ def run(x: int) -> int:
     
         with pytest.raises(MlflowException, match="Invalid type for 'name'"):
             trace.search_spans(name=123)
    -
    -
    -def test_intermediate_outputs_from_attribute():
    -    intermediate_outputs = {
    -        "retrieved_documents": ["document 1", "document 2"],
    -        "generative_prompt": "prompt",
    -    }
    -
    -    def run():
    -        with mlflow.start_span(name="run") as span:
    -            span.set_attribute("mlflow.trace.intermediate_outputs", intermediate_outputs)
    -
    -    run()
    -    trace = mlflow.get_last_active_trace()
    -
    -    assert trace.intermediate_outputs == intermediate_outputs
    -
    -
    -def test_intermediate_outputs_no_value():
    -    def run():
    -        with mlflow.start_span(name="run") as span:
    -            span.set_outputs(1)
    -
    -    run()
    -    trace = mlflow.get_last_active_trace()
    -
    -    assert trace.intermediate_outputs is None
    
  • tests/langchain/test_langchain_autolog.py +90 -2 modified
    @@ -29,6 +29,8 @@
     from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
     from langchain_core.runnables.config import RunnableConfig
     
    +from mlflow.entities.trace import Trace
    +
     # NB: We run this test suite twice - once with langchain_community installed and once without.
     try:
         from langchain_community.chat_models import ChatOpenAI
    @@ -61,11 +63,11 @@
     from mlflow.models.dependencies_schemas import DependenciesSchemasType, set_retriever_schema
     from mlflow.models.signature import infer_signature
     from mlflow.models.utils import _read_example
    -from mlflow.tracing.constant import SpanAttributeKey, TraceMetadataKey
    +from mlflow.tracing.constant import TRACE_SCHEMA_VERSION_KEY, SpanAttributeKey, TraceMetadataKey
     
     from tests.langchain.conftest import DeterministicDummyEmbeddings
     from tests.tracing.conftest import async_logging_enabled
    -from tests.tracing.helper import get_traces
    +from tests.tracing.helper import get_traces, score_in_model_serving
     
     MODEL_DIR = "model"
     # The mock OpenAI endpoint simply echos the prompt back as the completion.
    @@ -1171,3 +1173,89 @@ def _mock_import(name, *args):
             traces = get_traces()
             assert len(traces) == 2
             assert all(len(trace.data.spans) == 11 for trace in traces)
    +
    +
    +def test_langchain_auto_tracing_in_serving_runnable():
    +    mlflow.langchain.autolog()
    +
    +    chain = create_openai_runnable()
    +
    +    with mlflow.start_run():
    +        model_info = mlflow.langchain.log_model(
    +            chain,
    +            "model",
    +            input_example={"product": "MLflow"},
    +        )
    +
    +    expected_output = '[{"role": "user", "content": "What is MLflow?"}]'
    +
    +    request_id, predictions, trace = score_in_model_serving(
    +        model_info.model_uri,
    +        [{"product": "MLflow"}],
    +    )
    +
    +    assert predictions == [expected_output]
    +    trace = Trace.from_dict(trace)
    +    assert trace.info.request_id == request_id
    +    assert trace.info.request_metadata[TRACE_SCHEMA_VERSION_KEY] == "2"
    +    spans = trace.data.spans
    +    assert len(spans) == 4
    +
    +    root_span = spans[0]
    +    assert root_span.start_time_ns // 1_000_000 == trace.info.timestamp_ms
    +    # there might be slight difference when we truncate nano seconds to milliseconds
    +    assert (
    +        root_span.end_time_ns // 1_000_000
    +        - (trace.info.timestamp_ms + trace.info.execution_time_ms)
    +    ) <= 1
    +    assert root_span.inputs == {"product": "MLflow"}
    +    assert root_span.outputs == expected_output
    +    assert root_span.span_type == "CHAIN"
    +
    +    root_span_id = root_span.span_id
    +    child_span = spans[2]
    +    assert child_span.parent_id == root_span_id
    +    assert child_span.inputs[0][0]["content"] == "What is MLflow?"
    +    assert child_span.outputs["generations"][0][0]["text"] == expected_output
    +    assert child_span.span_type == "CHAT_MODEL"
    +
    +
    +@pytest.mark.skipif(
    +    Version(langchain.__version__) < Version("0.2.0"),
    +    reason="ToolCall message is not available in older versions",
    +)
    +def test_langchain_auto_tracing_in_serving_agent():
    +    mlflow.langchain.autolog()
    +
    +    input_example = {"input": "What is 2 * 3?"}
    +    expected_output = {"output": "The result of 2 * 3 is 6."}
    +
    +    with mlflow.start_run():
    +        model_info = mlflow.langchain.log_model(
    +            "tests/langchain/sample_code/openai_agent.py",
    +            "langchain_model",
    +            input_example=input_example,
    +        )
    +
    +    request_id, response, trace_dict = score_in_model_serving(
    +        model_info.model_uri,
    +        input_example,
    +    )
    +
    +    trace = Trace.from_dict(trace_dict)
    +    assert trace.info.request_id == request_id
    +    assert trace.info.status == "OK"
    +
    +    spans = trace.data.spans
    +    assert len(spans) == 16
    +
    +    root_span = spans[0]
    +    assert root_span.name == "AgentExecutor"
    +    assert root_span.span_type == "CHAIN"
    +    assert root_span.inputs == input_example
    +    assert root_span.outputs == expected_output
    +    assert root_span.start_time_ns // 1_000_000 == trace.info.timestamp_ms
    +    assert (
    +        root_span.end_time_ns // 1_000_000
    +        - (trace.info.timestamp_ms + trace.info.execution_time_ms)
    +    ) <= 1
    
  • tests/langchain/test_langchain_model_export.py +0 -37 modified
    @@ -96,14 +96,11 @@
     from mlflow.models.signature import ModelSignature, Schema, infer_signature
     from mlflow.models.utils import load_serving_example
     from mlflow.pyfunc.context import Context
    -from mlflow.tracing.constant import TRACE_SCHEMA_VERSION, TRACE_SCHEMA_VERSION_KEY
    -from mlflow.tracing.processor.inference_table import _HEADER_REQUEST_ID_KEY
     from mlflow.tracking.artifact_utils import _download_artifact_from_uri
     from mlflow.types.schema import AnyType, Array, ColSpec, DataType, Object, Property
     
     from tests.helper_functions import _compare_logged_code_paths, pyfunc_serve_and_score_model
     from tests.langchain.conftest import DeterministicDummyEmbeddings
    -from tests.tracing.export.test_inference_table_exporter import _REQUEST_ID
     
     # this kwarg was added in langchain_community 0.0.27, and
     # prevents the use of pickled objects if not provided.
    @@ -3074,40 +3071,6 @@ def retrieve_history(input):
         }
     
     
    -@pytest.mark.parametrize("enable_mlflow_tracing", [True, False])
    -def test_langchain_model_inject_callback_in_model_serving(
    -    monkeypatch, model_path, enable_mlflow_tracing
    -):
    -    # Emulate the model serving environment
    -    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    -    monkeypatch.setenv("MLFLOW_ENABLE_TRACE_IN_SERVING", "true")
    -    monkeypatch.setenv("ENABLE_MLFLOW_TRACING", str(enable_mlflow_tracing).lower())
    -
    -    model = create_openai_runnable()
    -    mlflow.langchain.save_model(model, model_path)
    -
    -    loaded_model = mlflow.pyfunc.load_model(model_path)
    -
    -    # Mock Flask context
    -    with mock.patch("mlflow.tracing.processor.inference_table._get_flask_request") as mock_request:
    -        mock_request.return_value.headers = {_HEADER_REQUEST_ID_KEY: _REQUEST_ID}
    -
    -        loaded_model.predict({"product": "shoe"})
    -
    -    # Trace should be logged to the inference table
    -    from mlflow.tracing.export.inference_table import _TRACE_BUFFER
    -
    -    if enable_mlflow_tracing:
    -        assert len(_TRACE_BUFFER) == 1
    -        assert _REQUEST_ID in _TRACE_BUFFER
    -        trace = _TRACE_BUFFER[_REQUEST_ID]
    -        assert trace["info"]["request_metadata"][TRACE_SCHEMA_VERSION_KEY] == str(
    -            TRACE_SCHEMA_VERSION
    -        )
    -    else:
    -        assert len(_TRACE_BUFFER) == 0
    -
    -
     @pytest.mark.parametrize("env_var", ["MLFLOW_ENABLE_TRACE_IN_SERVING", "ENABLE_MLFLOW_TRACING"])
     def test_langchain_model_not_inject_callback_when_disabled(monkeypatch, model_path, env_var):
         # Emulate the model serving environment
    
  • tests/langchain/test_langchain_tracer.py +6 -92 modified
    @@ -31,9 +31,7 @@
     from mlflow.exceptions import MlflowException
     from mlflow.langchain import _LangChainModelWrapper
     from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
    -from mlflow.pyfunc.context import Context
    -from mlflow.tracing.constant import TRACE_SCHEMA_VERSION_KEY, SpanAttributeKey
    -from mlflow.tracing.export.inference_table import pop_trace
    +from mlflow.tracing.constant import SpanAttributeKey
     from mlflow.tracing.provider import trace_disabled
     
     from tests.tracing.helper import get_traces
    @@ -508,103 +506,19 @@ def test_multiple_components():
         _validate_trace_json_serialization(trace)
     
     
    -def _predict_with_callbacks(lc_model, request_id, data):
    -    model = _LangChainModelWrapper(lc_model)
    -    tracer = MlflowLangchainTracer(prediction_context=Context(request_id=request_id))
    -    response = model._predict_with_callbacks(
    -        data, callback_handlers=[tracer], convert_chat_responses=True
    -    )
    -    trace_dict = pop_trace(request_id)
    -    return response, trace_dict
    -
    -
    -def test_e2e_rag_model_tracing_in_serving(mock_databricks_serving_with_tracing_env, monkeypatch):
    -    monkeypatch.setenv("RAG_TRACE_V2_ENABLED", "true")
    -
    -    llm_chain = create_openai_llmchain()
    -
    -    request_id = "test_request_id"
    -    response, trace_dict = _predict_with_callbacks(llm_chain, request_id, ["MLflow"])
    -
    -    assert response == [{"text": TEST_CONTENT}]
    -    trace = Trace.from_dict(trace_dict)
    -    assert trace.info.request_id == request_id
    -    assert trace.info.request_metadata[TRACE_SCHEMA_VERSION_KEY] == "2"
    -    spans = trace.data.spans
    -    assert len(spans) == 2
    -
    -    root_span = spans[0]
    -    assert root_span.start_time_ns // 1_000_000 == trace.info.timestamp_ms
    -    # there might be slight difference when we truncate nano seconds to milliseconds
    -    assert (
    -        root_span.end_time_ns // 1_000_000
    -        - (trace.info.timestamp_ms + trace.info.execution_time_ms)
    -    ) <= 1
    -    assert root_span.inputs == {"product": "MLflow"}
    -    assert root_span.outputs == {"text": TEST_CONTENT}
    -    assert root_span.span_type == "CHAIN"
    -
    -    root_span_id = root_span.span_id
    -    child_span = spans[1]
    -    assert child_span.parent_id == root_span_id
    -    assert child_span.inputs == ["What is MLflow?"]
    -    assert child_span.outputs["generations"][0][0]["text"] == TEST_CONTENT
    -    assert child_span.span_type == "LLM"
    -
    -    _validate_trace_json_serialization(trace)
    -
    -
    -@pytest.mark.skipif(
    -    Version(langchain.__version__) < Version("0.2.0"),
    -    reason="ToolCall message is not available in older versions",
    -)
    -def test_agent_success(mock_databricks_serving_with_tracing_env):
    -    # Load the agent definition (with OpenAI mock) from the sample script
    -    from tests.langchain.sample_code.openai_agent import create_openai_agent
    -
    -    agent = create_openai_agent()
    -
    -    langchain_input = {"input": "what is the value of magic_function(3)?"}
    -    expected_output = {"output": "The result of 2 * 3 is 6."}
    -    request_id = "test_request_id"
    -    response, trace_dict = _predict_with_callbacks(agent, request_id, langchain_input)
    -
    -    assert response == expected_output
    -
    -    trace = Trace.from_dict(trace_dict)
    -    assert trace.info.status == "OK"
    -
    -    spans = trace.data.spans
    -    assert len(spans) == 16
    -
    -    root_span = spans[0]
    -    assert root_span.name == "AgentExecutor"
    -    assert root_span.span_type == "CHAIN"
    -    assert root_span.inputs == langchain_input
    -    assert root_span.outputs == expected_output
    -    assert root_span.start_time_ns // 1_000_000 == trace.info.timestamp_ms
    -    assert (
    -        root_span.end_time_ns // 1_000_000
    -        - (trace.info.timestamp_ms + trace.info.execution_time_ms)
    -    ) <= 1
    -
    -    _validate_trace_json_serialization(trace)
    -
    -
    -def test_tool_success(mock_databricks_serving_with_tracing_env):
    +def test_tool_success():
    +    callback = MlflowLangchainTracer()
         prompt = SystemMessagePromptTemplate.from_template("You are a nice assistant.") + "{question}"
         llm = OpenAI(temperature=0.9)
     
         chain = prompt | llm | StrOutputParser()
         chain_tool = tool("chain_tool", chain)
     
         tool_input = {"question": "What up"}
    -    request_id = "test_request_id"
    -    response, trace_dict = _predict_with_callbacks(chain_tool, request_id, tool_input)
    +    response = chain_tool.invoke(tool_input, config={"callbacks": [callback]})
     
         # str output is converted to _ChatResponse
    -    output = response["choices"][0]["message"]["content"]
    -    trace = Trace.from_dict(trace_dict)
    +    trace = mlflow.get_last_active_trace()
         spans = trace.data.spans
         assert len(spans) == 5
     
    @@ -631,7 +545,7 @@ def test_tool_success(mock_databricks_serving_with_tracing_env):
         # StrOutputParser
         output_parser_span = spans[4]
         assert output_parser_span.span_type == "CHAIN"
    -    assert output_parser_span.outputs == output
    +    assert output_parser_span.outputs == response
     
         _validate_trace_json_serialization(trace)
     
    
  • tests/openai/test_openai_autolog.py +73 -0 modified
    @@ -1,10 +1,15 @@
     import json
    +from unittest import mock
     
    +import httpx
     import openai
     import pytest
    +from packaging.version import Version
    +from pydantic import BaseModel
     
     import mlflow
     from mlflow import MlflowClient
    +from mlflow.entities.span import SpanType
     from mlflow.exceptions import MlflowException
     from mlflow.tracing.constant import SpanAttributeKey, TraceMetadataKey
     
    @@ -461,3 +466,71 @@ def test_autolog_raw_response_stream(client):
             messages + [{"role": "assistant", "content": "Hello world"}]
         )
         assert span.attributes[SpanAttributeKey.CHAT_TOOLS] == MOCK_TOOLS
    +
    +
    +@pytest.mark.skipif(
    +    Version(openai.__version__) < Version("1.40"), reason="Requires OpenAI SDK >= 1.40"
    +)
    +def test_response_format(client):
    +    mlflow.openai.autolog()
    +
    +    class Person(BaseModel):
    +        name: str
    +        age: int
    +
    +    def send_patch(self, request, *args, **kwargs):
    +        return httpx.Response(
    +            status_code=200,
    +            request=request,
    +            json={
    +                "id": "chatcmpl-Ax4UAd5xf32KjgLkS1SEEY9oorI9m",
    +                "object": "chat.completion",
    +                "created": 1738641958,
    +                "model": "gpt-4o-2024-08-06",
    +                "choices": [
    +                    {
    +                        "index": 0,
    +                        "message": {
    +                            "role": "assistant",
    +                            "content": '{"name":"Angelo","age":42}',
    +                            "refusal": None,
    +                        },
    +                        "logprobs": None,
    +                        "finish_reason": "stop",
    +                    }
    +                ],
    +                "usage": {
    +                    "prompt_tokens": 68,
    +                    "completion_tokens": 11,
    +                    "total_tokens": 79,
    +                    "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
    +                    "completion_tokens_details": {
    +                        "reasoning_tokens": 0,
    +                        "audio_tokens": 0,
    +                        "accepted_prediction_tokens": 0,
    +                        "rejected_prediction_tokens": 0,
    +                    },
    +                },
    +                "service_tier": "default",
    +                "system_fingerprint": "fp_50cad350e4",
    +            },
    +        )
    +
    +    with mock.patch("httpx.Client.send", send_patch):
    +        client = openai.OpenAI()
    +        response = client.beta.chat.completions.parse(
    +            messages=[
    +                {"role": "system", "content": "Extract info from text"},
    +                {"role": "user", "content": "I am Angelo and I am 42."},
    +            ],
    +            model="gpt-4o",
    +            temperature=0,
    +            response_format=Person,
    +        )
    +
    +    assert response.choices[0].message.parsed == Person(name="Angelo", age=42)
    +    trace = mlflow.get_last_active_trace()
    +    assert len(trace.data.spans) == 1
    +    span = trace.data.spans[0]
    +    assert span.outputs["choices"][0]["message"]["content"] == '{"name":"Angelo","age":42}'
    +    assert span.span_type == SpanType.CHAT_MODEL
    
  • tests/pytorch/test_pytorch_model_export.py +26 -8 modified
    @@ -12,6 +12,7 @@
     import pytest
     import torch
     import yaml
    +from packaging.version import Version
     from sklearn import datasets
     from torch import nn
     from torch.utils.data import DataLoader
    @@ -58,6 +59,12 @@
         [] if _is_available_on_pypi("torch") else ["--env-manager", "local"]
     )
     
    +# in pytorch >= 2.6.0, the `weights_only` kwarg default has been changed from
    +# `False` to `True`. this can cause pickle deserialization errors when loading
    +# models, unless the model classes have been explicitly marked as safe using
    +# `torch.serialization.add_safe_globals()`
    +ENABLE_LEGACY_DESERIALIZATION = Version(torch.__version__) >= Version("2.6.0")
    +
     
     @pytest.fixture(scope="module")
     def data():
    @@ -697,7 +704,11 @@ def track_module_imports(module_name):
             import_mock.side_effect = track_module_imports
             pyfunc.load_model(model_path)
     
    -    torch_load_mock.assert_called_with(mock.ANY, pickle_module=custom_pickle_module)
    +    expected_kwargs = {"pickle_module": custom_pickle_module}
    +    if ENABLE_LEGACY_DESERIALIZATION:
    +        expected_kwargs["weights_only"] = False
    +
    +    torch_load_mock.assert_called_with(mock.ANY, **expected_kwargs)
         assert custom_pickle_module.__name__ in imported_modules
     
     
    @@ -729,7 +740,11 @@ def track_module_imports(module_name):
             import_mock.side_effect = track_module_imports
             pyfunc.load_model(model_uri=model_uri)
     
    -    torch_load_mock.assert_called_with(mock.ANY, pickle_module=custom_pickle_module)
    +    expected_kwargs = {"pickle_module": custom_pickle_module}
    +    if ENABLE_LEGACY_DESERIALIZATION:
    +        expected_kwargs["weights_only"] = False
    +
    +    torch_load_mock.assert_called_with(mock.ANY, **expected_kwargs)
         assert custom_pickle_module.__name__ in imported_modules
     
     
    @@ -1261,9 +1276,12 @@ def test_load_model_to_device(sequential_model):
         with mock.patch("mlflow.pytorch._load_model") as load_model_mock:
             with mlflow.start_run():
                 model_info = mlflow.pytorch.log_model(sequential_model, "pytorch")
    -            mlflow.pyfunc.load_model(
    -                model_uri=model_info.model_uri, model_config={"device": "cuda"}
    -            )
    -            load_model_mock.assert_called_with(mock.ANY, device="cuda")
    -            mlflow.pytorch.load_model(model_uri=model_info.model_uri, device="cuda")
    -            load_model_mock.assert_called_with(path=mock.ANY, device="cuda")
    +            model_config = {"device": "cuda"}
    +            if ENABLE_LEGACY_DESERIALIZATION:
    +                model_config["weights_only"] = False
    +
    +            mlflow.pyfunc.load_model(model_uri=model_info.model_uri, model_config=model_config)
    +
    +            load_model_mock.assert_called_with(mock.ANY, **model_config)
    +            mlflow.pytorch.load_model(model_uri=model_info.model_uri, **model_config)
    +            load_model_mock.assert_called_with(path=mock.ANY, **model_config)
    
  • tests/tracing/helper.py +38 -0 modified
    @@ -1,6 +1,10 @@
    +import os
     import time
    +import uuid
    +from concurrent.futures import ThreadPoolExecutor
     from dataclasses import dataclass
     from typing import Optional
    +from unittest import mock
     
     import opentelemetry.trace as trace_api
     import pytest
    @@ -10,6 +14,7 @@
     from mlflow.entities import Trace, TraceData, TraceInfo
     from mlflow.entities.trace_status import TraceStatus
     from mlflow.ml_package_versions import FLAVOR_TO_MODULE_NAME
    +from mlflow.tracing.export.inference_table import pop_trace
     from mlflow.tracing.processor.mlflow import MlflowSpanProcessor
     from mlflow.tracing.provider import _get_tracer
     from mlflow.tracking.default_experiment import DEFAULT_EXPERIMENT_ID
    @@ -165,3 +170,36 @@ def reset_autolog_state():
             revert_patches(flavor)
     
         AUTOLOGGING_INTEGRATIONS.clear()
    +
    +
    +def score_in_model_serving(model_uri: str, model_input: dict):
    +    """
    +    A helper function to emulate model prediction inside a Databricks model serving environment.
    +
    +    This is highly simplified version, but captures important aspects for testing tracing:
    +      1. Setting env vars that users set for enable tracing in model serving
    +      2. Load the model in a background thread
    +    """
    +    from mlflow.pyfunc.context import Context, set_prediction_context
    +
    +    with mock.patch.dict(
    +        "os.environ",
    +        os.environ | {"IS_IN_DB_MODEL_SERVING_ENV": "true", "ENABLE_MLFLOW_TRACING": "true"},
    +        clear=True,
    +    ):
    +        # Reset tracing setup to start fresh w/ model serving environment
    +        mlflow.tracing.reset()
    +
    +        def _load_model():
    +            return mlflow.pyfunc.load_model(model_uri)
    +
    +        with ThreadPoolExecutor(max_workers=1) as executor:
    +            model = executor.submit(_load_model).result()
    +
    +        # Score the model
    +        request_id = uuid.uuid4().hex
    +        with set_prediction_context(Context(request_id=request_id)):
    +            predictions = model.predict(model_input)
    +
    +        trace = pop_trace(request_id)
    +        return (request_id, predictions, trace)
    
  • tests/tracing/processor/test_inference_table_processor.py +8 -19 modified
    @@ -1,15 +1,11 @@
     import json
     from unittest import mock
     
    -import pytest
    -
     from mlflow.entities.span import LiveSpan
     from mlflow.entities.trace_status import TraceStatus
    +from mlflow.pyfunc.context import Context, set_prediction_context
     from mlflow.tracing.constant import SpanAttributeKey
    -from mlflow.tracing.processor.inference_table import (
    -    _HEADER_REQUEST_ID_KEY,
    -    InferenceTableSpanProcessor,
    -)
    +from mlflow.tracing.processor.inference_table import InferenceTableSpanProcessor
     from mlflow.tracing.trace_manager import InMemoryTraceManager
     
     from tests.tracing.helper import create_mock_otel_span, create_test_trace_info
    @@ -18,25 +14,16 @@
     _REQUEST_ID = f"tr-{_TRACE_ID}"
     
     
    -@pytest.fixture
    -def flask_request():
    -    with mock.patch(
    -        "mlflow.tracing.processor.inference_table._get_flask_request"
    -    ) as mock_get_flask_request:
    -        request = mock_get_flask_request.return_value
    -        request.headers = {_HEADER_REQUEST_ID_KEY: _REQUEST_ID}
    -        yield request
    -
    -
    -def test_on_start(flask_request):
    +def test_on_start():
         # Root span should create a new trace on start
         span = create_mock_otel_span(
             trace_id=_TRACE_ID, span_id=1, parent_id=None, start_time=5_000_000
         )
         trace_manager = InMemoryTraceManager.get_instance()
         processor = InferenceTableSpanProcessor(span_exporter=mock.MagicMock())
     
    -    processor.on_start(span)
    +    with set_prediction_context(Context(request_id=_REQUEST_ID)):
    +        processor.on_start(span)
     
         assert span.attributes.get(SpanAttributeKey.REQUEST_ID) == json.dumps(_REQUEST_ID)
         assert _REQUEST_ID in InMemoryTraceManager.get_instance()._traces
    @@ -52,7 +39,9 @@ def test_on_start(flask_request):
         child_span = create_mock_otel_span(
             trace_id=_TRACE_ID, span_id=2, parent_id=1, start_time=8_000_000
         )
    -    processor.on_start(child_span)
    +
    +    with set_prediction_context(Context(request_id=_REQUEST_ID)):
    +        processor.on_start(child_span)
     
         assert child_span.attributes.get(SpanAttributeKey.REQUEST_ID) == json.dumps(_REQUEST_ID)
     
    
  • tests/tracing/test_fluent.py +57 -1 modified
    @@ -235,7 +235,8 @@ def predict():
             data = json.loads(flask.request.data.decode("utf-8"))
             request_id = flask.request.headers.get("X-Request-ID")
     
    -        prediction = TestModel().predict(**data)
    +        with set_prediction_context(Context(request_id=request_id)):
    +            prediction = TestModel().predict(**data)
     
             trace = pop_trace(request_id=request_id)
     
    @@ -1475,3 +1476,58 @@ def test_add_trace_logging_model_from_code():
         trace = mlflow.get_last_active_trace()
         assert trace is not None
         assert len(trace.data.spans) == 2
    +
    +
    +@pytest.mark.parametrize(
    +    "inputs", [{"question": "Does mlflow support tracing?"}, "Does mlflow support tracing?", None]
    +)
    +@pytest.mark.parametrize("outputs", [{"answer": "Yes"}, "Yes", None])
    +@pytest.mark.parametrize(
    +    "intermediate_outputs",
    +    [
    +        {
    +            "retrieved_documents": ["mlflow documentation"],
    +            "system_prompt": ["answer the question with yes or no"],
    +        },
    +        None,
    +    ],
    +)
    +def test_log_trace_success(inputs, outputs, intermediate_outputs):
    +    start_time_ms = 1736144700
    +    execution_time_ms = 5129
    +
    +    mlflow.log_trace(
    +        name="test",
    +        request=inputs,
    +        response=outputs,
    +        intermediate_outputs=intermediate_outputs,
    +        start_time_ms=start_time_ms,
    +        execution_time_ms=execution_time_ms,
    +    )
    +
    +    trace = mlflow.get_last_active_trace()
    +    if inputs is not None:
    +        assert trace.data.request == json.dumps(inputs)
    +    else:
    +        assert trace.data.request is None
    +    if outputs is not None:
    +        assert trace.data.response == json.dumps(outputs)
    +    else:
    +        assert trace.data.response is None
    +    if intermediate_outputs is not None:
    +        assert trace.data.intermediate_outputs == intermediate_outputs
    +    spans = trace.data.spans
    +    assert len(spans) == 1
    +    root_span = spans[0]
    +    assert root_span.name == "test"
    +    assert root_span.start_time_ns == start_time_ms * 1000000
    +    assert root_span.end_time_ns == (start_time_ms + execution_time_ms) * 1000000
    +
    +
    +def test_log_trace_fail_within_span_context():
    +    with pytest.raises(MlflowException, match="Another trace is already set in the global context"):
    +        with mlflow.start_span("span"):
    +            mlflow.log_trace(
    +                request="Does mlflow support tracing?",
    +                response="Yes",
    +            )
    
  • tests/utils/test_validation.py +40 -0 modified
    @@ -3,6 +3,7 @@
     import pytest
     
     from mlflow.entities import Metric, Param, RunTag
    +from mlflow.environment_variables import MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH
     from mlflow.exceptions import MlflowException
     from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, ErrorCode
     from mlflow.utils.os import is_windows
    @@ -13,6 +14,7 @@
         _validate_batch_log_limits,
         _validate_db_type_string,
         _validate_experiment_artifact_location,
    +    _validate_experiment_artifact_location_length,
         _validate_experiment_name,
         _validate_metric_name,
         _validate_model_alias_name,
    @@ -333,3 +335,41 @@ def test_validate_db_type_string_bad(db_type):
         with pytest.raises(MlflowException, match="Invalid database engine") as e:
             _validate_db_type_string(db_type)
         assert "Invalid database engine" in e.value.message
    +
    +
    +@pytest.mark.parametrize(
    +    "artifact_location",
    +    [
    +        "s3://test-bucket/",
    +        "file:///path/to/artifacts",
    +        "mlflow-artifacts:/path/to/artifacts",
    +        "dbfs:/databricks/mlflow-tracking/some-id",
    +    ],
    +)
    +def test_validate_experiment_artifact_location_length_good(artifact_location):
    +    _validate_experiment_artifact_location_length(artifact_location)
    +
    +
    +@pytest.mark.parametrize(
    +    "artifact_location",
    +    ["s3://test-bucket/" + "a" * 10000, "file:///path/to/" + "directory" * 1111],
    +)
    +def test_validate_experiment_artifact_location_length_bad(artifact_location):
    +    with pytest.raises(MlflowException, match="Invalid artifact path length"):
    +        _validate_experiment_artifact_location_length(artifact_location)
    +
    +
    +def test_setting_experiment_artifact_location_env_var_works(monkeypatch):
    +    artifact_location = "file://aaaa"  # length 11
    +
    +    # should not throw
    +    _validate_experiment_artifact_location_length(artifact_location)
    +
    +    # reduce limit to 10
    +    monkeypatch.setenv(MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH.name, "10")
    +    with pytest.raises(MlflowException, match="Invalid artifact path length"):
    +        _validate_experiment_artifact_location_length(artifact_location)
    +
    +    # increase limit to 11
    +    monkeypatch.setenv(MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH.name, "11")
    +    _validate_experiment_artifact_location_length(artifact_location)
    

Vulnerability mechanics

Generated on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
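
The advisory attributes the issue to a signup endpoint that accepts requests without an anti-CSRF token or origin check. As a rough illustration of the general mitigation class only, and not MLflow's actual fix, a session-bound token check might look like the sketch below. The names `SERVER_SECRET`, `issue_csrf_token`, and `is_valid_csrf_token` are hypothetical and do not come from the MLflow codebase.

```python
import hashlib
import hmac
import secrets
from typing import Optional

# Hypothetical server-side secret (assumption: a real deployment would load
# a persistent value from configuration rather than generate one at import).
SERVER_SECRET = secrets.token_bytes(32)


def issue_csrf_token(session_id: str) -> str:
    """Derive a token bound to the session, to be embedded in the signup form."""
    return hmac.new(SERVER_SECRET, session_id.encode(), hashlib.sha256).hexdigest()


def is_valid_csrf_token(session_id: str, submitted: Optional[str]) -> bool:
    """Reject requests whose token is missing or does not match the session."""
    if not submitted:
        return False
    expected = issue_csrf_token(session_id)
    # Compare as bytes in constant time to avoid leaking match position.
    return hmac.compare_digest(expected.encode(), submitted.encode())
```

A forged cross-site form submission cannot read the token rendered into the legitimate page, so under this scheme the handler would reject it before creating an account.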

References

4
