CVE-2026-5241
Description
Hugging Face Transformers LightGlue model loading allows RCE by overriding trust_remote_code via a malicious config.json.
AI Insight
LLM-synthesized narrative grounded in this CVE's description and references.
Vulnerability
A vulnerability exists in the LightGlueConfig class of the huggingface/transformers library, affecting version 5.2.0. The trust_remote_code parameter, which is intended to prevent arbitrary code execution, is handled improperly while a LightGlue model is loaded. LightGlueConfig exposes its own trust_remote_code field, and that field is populated from the repository's config.json, not from the caller's argument. Even when the caller passes trust_remote_code=False, an untrusted config.json can set this field and have it forwarded to the nested AutoConfig.from_pretrained() call for the keypoint detector, and to AutoModelForKeypointDetection.from_config() in the model constructor, leading to the execution of attacker-controlled code [1].
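The flawed pattern can be illustrated with a small toy model. This is only a sketch: it does not import transformers, and `ToyConfig`, `nested_load`, `load_model`, and `KNOWN_TYPES` are hypothetical stand-ins, under the assumption that fields found in config.json are passed to the config class as keyword arguments.

```python
# Toy model of the flaw; none of these names are real transformers APIs.
from dataclasses import dataclass, field

KNOWN_TYPES = {"superpoint"}   # stand-in for the library's CONFIG_MAPPING
loaded_with_remote_code = []   # records nested loads that were allowed to run repo code


def nested_load(path: str, trust_remote_code: bool) -> dict:
    """Stand-in for a nested AutoConfig.from_pretrained() call."""
    if trust_remote_code:
        loaded_with_remote_code.append(path)
    return {"path": path}


@dataclass
class ToyConfig:
    keypoint_detector_config: dict = field(default_factory=dict)
    # Flaw: this value is populated from the untrusted config file.
    trust_remote_code: bool = False

    def __post_init__(self):
        detector = self.keypoint_detector_config
        if detector.get("model_type", "superpoint") not in KNOWN_TYPES:
            # Only the file's value is consulted here, never the caller's choice.
            nested_load(detector["_name_or_path"], trust_remote_code=self.trust_remote_code)


def load_model(config_json: dict, trust_remote_code: bool = False) -> ToyConfig:
    """The caller's trust_remote_code is accepted but never reaches the nested load."""
    return ToyConfig(**config_json)


untrusted = {
    "trust_remote_code": True,
    "keypoint_detector_config": {"model_type": "custom", "_name_or_path": "attacker/repo"},
}
load_model(untrusted, trust_remote_code=False)
print(loaded_with_remote_code)  # ['attacker/repo']
```

The caller asked for `trust_remote_code=False`, yet the nested load ran with the value taken from the file. The fix shown in the patch below removes the config-level field altogether.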
Exploitation
An attacker can exploit this vulnerability by publishing a malicious model repository containing a crafted config.json file. The vulnerability is triggered when a user or application loads a LightGlue model from this repository with AutoModel.from_pretrained(), even if trust_remote_code=False is set explicitly. The crafted config.json causes LightGlueConfig to read and propagate an attacker-controlled trust_remote_code value, enabling the execution of arbitrary Python modules embedded within the model repository [1, 2].
Impact
Successful exploitation allows an attacker to achieve arbitrary code execution within the environment where the model is being loaded. This can lead to severe consequences such as the theft of sensitive credentials, lateral movement within the network, or the establishment of persistence mechanisms like backdoors. The impact is particularly high in environments like API inference servers, research notebooks, CI/CD pipelines, and model evaluation workers, where models are frequently loaded from potentially untrusted sources [2].
Mitigation
The vulnerability has been addressed in huggingface/transformers via commit 676559d5022b74aaa0cee1cee0842b7f27c5320e [1]. Users are advised to update to a patched version of the library. No specific workaround is mentioned in the available references, and the end-of-life status or KEV listing for affected versions is not yet disclosed.
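Where an immediate upgrade is not possible, one conceivable defensive step is to inspect a downloaded config.json before handing it to the library. The sketch below is not taken from the references. The function names are hypothetical, and treating any keypoint detector other than SuperPoint as worth a manual review is an assumption that may flag legitimate configurations.

```python
import json
from pathlib import Path


def suspicious_fields(config: dict) -> list[str]:
    """Return reasons a parsed LightGlue config deserves manual review before loading."""
    reasons = []
    if config.get("model_type") != "lightglue":
        return reasons  # only LightGlue configs are in scope for this check
    if config.get("trust_remote_code"):
        reasons.append("config sets trust_remote_code to a truthy value")
    detector = config.get("keypoint_detector_config")
    if isinstance(detector, dict):
        detector_type = detector.get("model_type", "superpoint")
        if detector_type != "superpoint":
            reasons.append(f"keypoint detector model_type is {detector_type!r}")
    return reasons


def check_config_file(path: str) -> list[str]:
    """Parse a config.json as plain JSON (no model loading) and run the check."""
    return suspicious_fields(json.loads(Path(path).read_text(encoding="utf-8")))


# Illustrative input only; the repository name is made up.
example = {
    "model_type": "lightglue",
    "trust_remote_code": True,
    "keypoint_detector_config": {"model_type": "custom_detector", "_name_or_path": "some/repo"},
}
for reason in suspicious_fields(example):
    print(reason)
```

A check like this only narrows the window; updating to a release that contains the fix commit remains the remedy the references point to.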
AI Insight generated on Jun 3, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.
Affected products
- (no CPE)
- (no CPE), range: =5.2.0
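A quick way to compare an environment against the listed range is to read the installed package version. This is a minimal sketch: `AFFECTED_VERSIONS` simply mirrors the `=5.2.0` entry above, and the helper names are hypothetical. A version outside that set is not proof of safety, since the listing may be incomplete.

```python
from importlib.metadata import PackageNotFoundError, version

AFFECTED_VERSIONS = {"5.2.0"}  # mirrors the "range: =5.2.0" entry in this advisory


def is_affected(installed: str) -> bool:
    """Exact string match against the listed versions (local suffixes will not match)."""
    return installed in AFFECTED_VERSIONS


def check_installed(package: str = "transformers") -> str:
    """Describe whether the installed package version is in the listed range."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return f"{package} is not installed"
    status = "listed as affected" if is_affected(installed) else "not in the listed range"
    return f"{package} {installed}: {status}"


print(check_installed())
```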
Patches
676559d5022b: 🚨 [`LightGlue`] Remove remote code execution (#45122)
5 files changed · +106 −34
src/transformers/models/lightglue/configuration_lightglue.py +3 −14 modified
@@ -40,8 +40,6 @@ class LightGlueConfig(PreTrainedConfig):
             The confidence threshold used to prune points
         filter_threshold (`float`, *optional*, defaults to 0.1):
             The confidence threshold used to filter matches
-        trust_remote_code (`bool`, *optional*, defaults to `False`):
-            Whether to trust remote code when using other models than SuperPoint as keypoint detector.

     Examples:
         ```python
@@ -73,10 +71,6 @@ class LightGlueConfig(PreTrainedConfig):
     hidden_act: str = "gelu"
     attention_dropout: float | int = 0.0
     attention_bias: bool = True
-    # LightGlue can be used with other models than SuperPoint as keypoint detector
-    # We provide the trust_remote_code argument to allow the use of other models
-    # that are not registered in the CONFIG_MAPPING dictionary (for example DISK)
-    trust_remote_code: bool = False

     def __post_init__(self, **kwargs):
         if self.num_key_value_heads is None:
@@ -86,14 +80,9 @@ def __post_init__(self, **kwargs):
         # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153
         if isinstance(self.keypoint_detector_config, dict):
             self.keypoint_detector_config["model_type"] = self.keypoint_detector_config.get("model_type", "superpoint")
-            if self.keypoint_detector_config["model_type"] not in CONFIG_MAPPING:
-                self.keypoint_detector_config = AutoConfig.from_pretrained(
-                    self.keypoint_detector_config["_name_or_path"], trust_remote_code=self.trust_remote_code
-                )
-            else:
-                self.keypoint_detector_config = CONFIG_MAPPING[self.keypoint_detector_config["model_type"]](
-                    **self.keypoint_detector_config, attn_implementation="eager"
-                )
+            self.keypoint_detector_config = CONFIG_MAPPING[self.keypoint_detector_config["model_type"]](
+                **self.keypoint_detector_config, attn_implementation="eager"
+            )
         elif self.keypoint_detector_config is None:
             self.keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager")
src/transformers/models/lightglue/modeling_lightglue.py +1 −3 modified
@@ -502,9 +502,7 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
     def __init__(self, config: LightGlueConfig):
         super().__init__(config)
-        self.keypoint_detector = AutoModelForKeypointDetection.from_config(
-            config.keypoint_detector_config, trust_remote_code=config.trust_remote_code
-        )
+        self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
         self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
         self.descriptor_dim = config.descriptor_dim
src/transformers/models/lightglue/modular_lightglue.py +4 −17 modified
@@ -53,8 +53,6 @@ class LightGlueConfig(PreTrainedConfig):
             The confidence threshold used to prune points
         filter_threshold (`float`, *optional*, defaults to 0.1):
             The confidence threshold used to filter matches
-        trust_remote_code (`bool`, *optional*, defaults to `False`):
-            Whether to trust remote code when using other models than SuperPoint as keypoint detector.

     Examples:
         ```python
@@ -86,10 +84,6 @@ class LightGlueConfig(PreTrainedConfig):
     hidden_act: str = "gelu"
     attention_dropout: float | int = 0.0
     attention_bias: bool = True
-    # LightGlue can be used with other models than SuperPoint as keypoint detector
-    # We provide the trust_remote_code argument to allow the use of other models
-    # that are not registered in the CONFIG_MAPPING dictionary (for example DISK)
-    trust_remote_code: bool = False

     def __post_init__(self, **kwargs):
         if self.num_key_value_heads is None:
@@ -99,14 +93,9 @@ def __post_init__(self, **kwargs):
         # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153
         if isinstance(self.keypoint_detector_config, dict):
             self.keypoint_detector_config["model_type"] = self.keypoint_detector_config.get("model_type", "superpoint")
-            if self.keypoint_detector_config["model_type"] not in CONFIG_MAPPING:
-                self.keypoint_detector_config = AutoConfig.from_pretrained(
-                    self.keypoint_detector_config["_name_or_path"], trust_remote_code=self.trust_remote_code
-                )
-            else:
-                self.keypoint_detector_config = CONFIG_MAPPING[self.keypoint_detector_config["model_type"]](
-                    **self.keypoint_detector_config, attn_implementation="eager"
-                )
+            self.keypoint_detector_config = CONFIG_MAPPING[self.keypoint_detector_config["model_type"]](
+                **self.keypoint_detector_config, attn_implementation="eager"
+            )
         elif self.keypoint_detector_config is None:
             self.keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager")
@@ -520,9 +509,7 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
     def __init__(self, config: LightGlueConfig):
         super().__init__(config)
-        self.keypoint_detector = AutoModelForKeypointDetection.from_config(
-            config.keypoint_detector_config, trust_remote_code=config.trust_remote_code
-        )
+        self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
         self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
         self.descriptor_dim = config.descriptor_dim
utils/mlinter/rules.toml +21 −0 modified
@@ -241,3 +241,24 @@ class FooModel(FooPreTrainedModel):
         self.layers = nn.ModuleList(...)
         self.post_init()
 '''
+
+[rules.TRF014]
+description = "`trust_remote_code` should never be used in native model integrations."
+default_enabled = true
+allowlist_models = []
+
+[rules.TRF014.explanation]
+what_it_does = "Checks whether `trust_remote_code` is passed or used in code (e.g. as kwarg) within native model integration files."
+why_bad = "`trust_remote_code` allows arbitrary loading, including binaries, which should only be a power feature for users, not a standard use-case. Native integrations must not depend on it, as remote code cannot be reviewed or maintained within transformers."
+bad_example = '''
+class FooModel(FooPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = AutoModel.from_pretrained(..., trust_remote_code=True)
+'''
+good_example = '''
+class FooModel(FooPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = AutoModel.from_pretrained(...)
+'''
utils/mlinter/trf014.py +77 −0 added
@@ -0,0 +1,77 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TRF014: `trust_remote_code` should never be used in native model integrations."""
+
+import ast
+from pathlib import Path
+
+from ._helpers import Violation
+
+
+RULE_ID = ""  # Set by discovery
+
+
+class TrustRemoteCodeVisitor(ast.NodeVisitor):
+    def __init__(self, file_path: Path):
+        self.file_path = file_path
+        self.violations: list[Violation] = []
+
+    def _add(self, node: ast.AST, message: str) -> None:
+        self.violations.append(
+            Violation(
+                file_path=self.file_path,
+                line_number=node.lineno,
+                message=f"{RULE_ID}: {message}",
+            )
+        )
+
+    def visit_Call(self, node: ast.Call) -> None:
+        """
+        Three cases covered by this
+        1. `foo(..., trust_remote_code=...)`
+        2. `foo(**{..., "trust_remote_code": ...})`
+        3. `foo(**dict(trust_remote_code=...))`
+
+        Not covered:
+        `kwargs = {"trust_remote_code": True}; foo(**kwargs)`
+        """
+        for keyword in node.keywords:
+            if keyword.arg == "trust_remote_code":
+                self._add(node, "`trust_remote_code` must not be passed as a keyword argument.")
+
+            elif keyword.arg is None:
+                value = keyword.value
+
+                if isinstance(value, ast.Dict):
+                    for key in value.keys:
+                        if isinstance(key, ast.Constant) and key.value == "trust_remote_code":
+                            self._add(node, "`trust_remote_code` must not be passed through `**kwargs`.")
+
+                elif isinstance(value, ast.Call):
+                    if isinstance(value.func, ast.Name) and value.func.id == "dict":
+                        for kw in value.keywords:
+                            if kw.arg == "trust_remote_code":
+                                self._add(
+                                    node,
+                                    "`trust_remote_code` must not be passed through `**kwargs` (dict constructor).",
+                                )
+
+        self.generic_visit(node)
+
+
+def check(tree: ast.Module, file_path: Path, source_lines: list[str]) -> list[Violation]:
+    visitor = TrustRemoteCodeVisitor(file_path)
+    visitor.visit(tree)
+    return visitor.violations
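The three covered call shapes, and the one the docstring lists as not covered, can be tried out with a condensed standalone version of the same AST checks. This is a sketch for experimentation and not the linter itself: `find_trust_remote_code_calls` and `SAMPLE` are hypothetical names, and it returns de-duplicated line numbers where the rule above records `Violation` objects.

```python
import ast


def find_trust_remote_code_calls(source: str) -> list[int]:
    """Return sorted, de-duplicated line numbers of calls that pass trust_remote_code."""
    flagged = set()
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Call):
            continue
        for keyword in node.keywords:
            if keyword.arg == "trust_remote_code":      # foo(..., trust_remote_code=...)
                flagged.add(node.lineno)
            elif keyword.arg is None:                   # foo(**something)
                value = keyword.value
                if isinstance(value, ast.Dict):         # foo(**{"trust_remote_code": ...})
                    if any(isinstance(k, ast.Constant) and k.value == "trust_remote_code" for k in value.keys):
                        flagged.add(node.lineno)
                elif (
                    isinstance(value, ast.Call)         # foo(**dict(trust_remote_code=...))
                    and isinstance(value.func, ast.Name)
                    and value.func.id == "dict"
                ):
                    if any(kw.arg == "trust_remote_code" for kw in value.keywords):
                        flagged.add(node.lineno)
    return sorted(flagged)


SAMPLE = '''\
load(path, trust_remote_code=True)
load(path, **{"trust_remote_code": True})
load(path, **dict(trust_remote_code=True))
kwargs = {"trust_remote_code": True}
load(path, **kwargs)
'''
print(find_trust_remote_code_calls(SAMPLE))  # [1, 2, 3]
```

Line 5 is not reported because the dictionary is built in a separate statement, which matches the "Not covered" note in the rule's docstring. Catching that form would require tracking assignments, which a purely syntactic check does not do.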
e5a9ce48f711: Add LightGlue model (#31718)
20 files changed · +3632 −2
docs/source/en/model_doc/lightglue.md+104 −0 added@@ -0,0 +1,104 @@ +<!--Copyright 2025 The HuggingFace Team. All rights reserved. + +Licensed under the MIT License; you may not use this file except in compliance with +the License. + +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. + +⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be +rendered properly in your Markdown viewer. + + +--> + +# LightGlue + +## Overview + +The LightGlue model was proposed in [LightGlue: Local Feature Matching at Light Speed](https://arxiv.org/abs/2306.13643) +by Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. + +Similar to [SuperGlue](https://huggingface.co/magic-leap-community/superglue_outdoor), this model consists of matching +two sets of local features extracted from two images, its goal is to be faster than SuperGlue. Paired with the +[SuperPoint model](https://huggingface.co/magic-leap-community/superpoint), it can be used to match two images and +estimate the pose between them. This model is useful for tasks such as image matching, homography estimation, etc. + +The abstract from the paper is the following: + +*We introduce LightGlue, a deep neural network that learns to match local features across images. We revisit multiple +design decisions of SuperGlue, the state of the art in sparse matching, and derive simple but effective improvements. +Cumulatively, they make LightGlue more efficient - in terms of both memory and computation, more accurate, and much +easier to train. 
One key property is that LightGlue is adaptive to the difficulty of the problem: the inference is much +faster on image pairs that are intuitively easy to match, for example because of a larger visual overlap or limited +appearance change. This opens up exciting prospects for deploying deep matchers in latency-sensitive applications like +3D reconstruction. The code and trained models are publicly available at this [https URL](https://github.com/cvg/LightGlue)* + +## How to use + +Here is a quick example of using the model. Since this model is an image matching model, it requires pairs of images to be matched. +The raw outputs contain the list of keypoints detected by the keypoint detector as well as the list of matches with their corresponding +matching scores. +```python +from transformers import AutoImageProcessor, AutoModel +import torch +from PIL import Image +import requests + +url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg" +image1 = Image.open(requests.get(url_image1, stream=True).raw) +url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg" +image2 = Image.open(requests.get(url_image2, stream=True).raw) + +images = [image1, image2] + +processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint") +model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint") + +inputs = processor(images, return_tensors="pt") +with torch.no_grad(): + outputs = model(**inputs) +``` + +You can use the `post_process_keypoint_matching` method from the `LightGlueImageProcessor` to get the keypoints and matches in a readable format: +```python +image_sizes = [[(image.height, image.width) for image in images]] +outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2) +for 
i, output in enumerate(outputs): + print("For the image pair", i) + for keypoint0, keypoint1, matching_score in zip( + output["keypoints0"], output["keypoints1"], output["matching_scores"] + ): + print( + f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}." + ) +``` + +You can visualize the matches between the images by providing the original images as well as the outputs to this method: +```python +processor.plot_keypoint_matching(images, outputs) +``` + + + +This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille). +The original code can be found [here](https://github.com/cvg/LightGlue). + +## LightGlueConfig + +[[autodoc]] LightGlueConfig + +## LightGlueImageProcessor + +[[autodoc]] LightGlueImageProcessor + +- preprocess +- post_process_keypoint_matching +- plot_keypoint_matching + +## LightGlueForKeypointMatching + +[[autodoc]] LightGlueForKeypointMatching + +- forward
docs/source/en/_toctree.yml+2 −0 modified@@ -743,6 +743,8 @@ title: ImageGPT - local: model_doc/levit title: LeViT + - local: model_doc/lightglue + title: LightGlue - local: model_doc/mask2former title: Mask2Former - local: model_doc/maskformer
src/transformers/__init__.py+2 −0 modified@@ -231,6 +231,7 @@ "is_faiss_available", "is_flax_available", "is_keras_nlp_available", + "is_matplotlib_available", "is_phonemizer_available", "is_psutil_available", "is_py3nvml_available", @@ -728,6 +729,7 @@ is_faiss_available, is_flax_available, is_keras_nlp_available, + is_matplotlib_available, is_phonemizer_available, is_psutil_available, is_py3nvml_available,
src/transformers/models/auto/configuration_auto.py+2 −0 modified@@ -185,6 +185,7 @@ ("layoutlmv3", "LayoutLMv3Config"), ("led", "LEDConfig"), ("levit", "LevitConfig"), + ("lightglue", "LightGlueConfig"), ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), ("llama4", "Llama4Config"), @@ -556,6 +557,7 @@ ("layoutxlm", "LayoutXLM"), ("led", "LED"), ("levit", "LeViT"), + ("lightglue", "LightGlue"), ("lilt", "LiLT"), ("llama", "LLaMA"), ("llama2", "Llama2"),
src/transformers/models/auto/image_processing_auto.py+1 −0 modified@@ -106,6 +106,7 @@ ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")), ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")), ("levit", ("LevitImageProcessor", "LevitImageProcessorFast")), + ("lightglue", ("LightGlueImageProcessor",)), ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")), ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")), ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
src/transformers/models/auto/modeling_auto.py+1 −0 modified@@ -175,6 +175,7 @@ ("layoutlmv3", "LayoutLMv3Model"), ("led", "LEDModel"), ("levit", "LevitModel"), + ("lightglue", "LightGlueForKeypointMatching"), ("lilt", "LiltModel"), ("llama", "LlamaModel"), ("llama4", "Llama4ForConditionalGeneration"),
src/transformers/models/__init__.py+1 −0 modified@@ -162,6 +162,7 @@ from .layoutxlm import * from .led import * from .levit import * + from .lightglue import * from .lilt import * from .llama import * from .llama4 import *
src/transformers/models/lightglue/configuration_lightglue.py+143 −0 added@@ -0,0 +1,143 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightglue.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...configuration_utils import PretrainedConfig +from ..auto import CONFIG_MAPPING, AutoConfig +from ..superpoint import SuperPointConfig + + +class LightGlueConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LightGlueForKeypointMatching`]. It is used to + instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the LightGlue + [ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`): + The config object or dictionary of the keypoint detector. + descriptor_dim (`int`, *optional*, defaults to 256): + The dimension of the descriptors. + num_hidden_layers (`int`, *optional*, defaults to 9): + The number of self and cross attention layers. + num_attention_heads (`int`, *optional*, defaults to 4): + The number of heads in the multi-head attention. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + depth_confidence (`float`, *optional*, defaults to 0.95): + The confidence threshold used to perform early stopping + width_confidence (`float`, *optional*, defaults to 0.99): + The confidence threshold used to prune points + filter_threshold (`float`, *optional*, defaults to 0.1): + The confidence threshold used to filter matches + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function to be used in the hidden layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ attention_bias (`bool`, *optional*, defaults to `True`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + + Examples: + ```python + >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching + + >>> # Initializing a LightGlue style configuration + >>> configuration = LightGlueConfig() + + >>> # Initializing a model from the LightGlue style configuration + >>> model = LightGlueForKeypointMatching(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "lightglue" + sub_configs = {"keypoint_detector_config": AutoConfig} + + def __init__( + self, + keypoint_detector_config: SuperPointConfig = None, + descriptor_dim: int = 256, + num_hidden_layers: int = 9, + num_attention_heads: int = 4, + num_key_value_heads=None, + depth_confidence: float = 0.95, + width_confidence: float = 0.99, + filter_threshold: float = 0.1, + initializer_range: float = 0.02, + hidden_act: str = "gelu", + attention_dropout=0.0, + attention_bias=True, + **kwargs, + ): + if descriptor_dim % num_attention_heads != 0: + raise ValueError("descriptor_dim % num_heads is different from zero") + + self.descriptor_dim = descriptor_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + + self.depth_confidence = depth_confidence + self.width_confidence = width_confidence + self.filter_threshold = filter_threshold + self.initializer_range = initializer_range + + # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention + # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153 + if isinstance(keypoint_detector_config, dict): + keypoint_detector_config["model_type"] = ( + 
keypoint_detector_config["model_type"] if "model_type" in keypoint_detector_config else "superpoint" + ) + keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]]( + **keypoint_detector_config, attn_implementation="eager" + ) + if keypoint_detector_config is None: + keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager") + + self.keypoint_detector_config = keypoint_detector_config + + self.hidden_size = descriptor_dim + self.intermediate_size = descriptor_dim * 2 + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + super().__init__(**kwargs) + + +__all__ = ["LightGlueConfig"]
src/transformers/models/lightglue/convert_lightglue_to_hf.py+281 −0 added@@ -0,0 +1,281 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import os +import re +from typing import List + +import torch +from datasets import load_dataset + +from transformers import ( + AutoModelForKeypointDetection, + LightGlueForKeypointMatching, + LightGlueImageProcessor, +) +from transformers.models.lightglue.configuration_lightglue import LightGlueConfig + + +DEFAULT_CHECKPOINT_URL = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_lightglue.pth" + + +def prepare_imgs(): + dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train") + image0 = dataset[0]["image"] + image1 = dataset[1]["image"] + image2 = dataset[2]["image"] + # [image1, image1] on purpose to test the model early stopping + return [[image2, image0], [image1, image1]] + + +def verify_model_outputs(model, device): + images = prepare_imgs() + preprocessor = LightGlueImageProcessor() + inputs = preprocessor(images=images, return_tensors="pt").to(device) + model.to(device) + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_matches_values = outputs.matches[0, 0, 20:30] + predicted_matching_scores_values = outputs.matching_scores[0, 0, 20:30] + + predicted_number_of_matches = torch.sum(outputs.matches[0][0] != 
-1).item() + + expected_max_number_keypoints = 866 + expected_matches_shape = torch.Size((len(images), 2, expected_max_number_keypoints)) + expected_matching_scores_shape = torch.Size((len(images), 2, expected_max_number_keypoints)) + + expected_matches_values = torch.tensor([-1, -1, 5, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64).to(device) + expected_matching_scores_values = torch.tensor([0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583]).to(device) + + expected_number_of_matches = 140 + + assert outputs.matches.shape == expected_matches_shape + assert outputs.matching_scores.shape == expected_matching_scores_shape + + assert torch.allclose(predicted_matches_values, expected_matches_values, atol=1e-2) + assert torch.allclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-2) + + assert predicted_number_of_matches == expected_number_of_matches + + +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"posenc.Wr": r"positional_encoder.projector", + r"self_attn.(\d+).Wqkv": r"transformer_layers.\1.self_attention.Wqkv", + r"self_attn.(\d+).out_proj": r"transformer_layers.\1.self_attention.o_proj", + r"self_attn.(\d+).ffn.0": r"transformer_layers.\1.self_mlp.fc1", + r"self_attn.(\d+).ffn.1": r"transformer_layers.\1.self_mlp.layer_norm", + r"self_attn.(\d+).ffn.3": r"transformer_layers.\1.self_mlp.fc2", + r"cross_attn.(\d+).to_qk": r"transformer_layers.\1.cross_attention.to_qk", + r"cross_attn.(\d+).to_v": r"transformer_layers.\1.cross_attention.v_proj", + r"cross_attn.(\d+).to_out": r"transformer_layers.\1.cross_attention.o_proj", + r"cross_attn.(\d+).ffn.0": r"transformer_layers.\1.cross_mlp.fc1", + r"cross_attn.(\d+).ffn.1": r"transformer_layers.\1.cross_mlp.layer_norm", + r"cross_attn.(\d+).ffn.3": r"transformer_layers.\1.cross_mlp.fc2", + r"log_assignment.(\d+).matchability": r"match_assignment_layers.\1.matchability", + r"log_assignment.(\d+).final_proj": r"match_assignment_layers.\1.final_projection", + r"token_confidence.(\d+).token.0": 
r"token_confidence.\1.token", +} + + +def convert_old_keys_to_new_keys(state_dict_keys: List[str]): + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. + """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # an empty line + continue + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +def add_keypoint_detector_state_dict(lightglue_state_dict): + keypoint_detector = AutoModelForKeypointDetection.from_pretrained("magic-leap-community/superpoint") + keypoint_detector_state_dict = keypoint_detector.state_dict() + for k, v in keypoint_detector_state_dict.items(): + lightglue_state_dict[f"keypoint_detector.{k}"] = v + return lightglue_state_dict + + +def split_weights(state_dict): + for i in range(9): + # Remove unused r values + log_assignment_r_key = f"log_assignment.{i}.r" + if state_dict.get(log_assignment_r_key, None) is not None: + state_dict.pop(log_assignment_r_key) + + Wqkv_weight = state_dict.pop(f"transformer_layers.{i}.self_attention.Wqkv.weight") + Wqkv_bias = state_dict.pop(f"transformer_layers.{i}.self_attention.Wqkv.bias") + Wqkv_weight = Wqkv_weight.reshape(256, 3, 256) + Wqkv_bias = Wqkv_bias.reshape(256, 3) + query_weight, key_weight, value_weight = Wqkv_weight[:, 0], Wqkv_weight[:, 1], Wqkv_weight[:, 2] + query_bias, key_bias, value_bias = Wqkv_bias[:, 0], Wqkv_bias[:, 1], Wqkv_bias[:, 2] + state_dict[f"transformer_layers.{i}.self_attention.q_proj.weight"] = query_weight + state_dict[f"transformer_layers.{i}.self_attention.k_proj.weight"] = key_weight + state_dict[f"transformer_layers.{i}.self_attention.v_proj.weight"] = value_weight + 
state_dict[f"transformer_layers.{i}.self_attention.q_proj.bias"] = query_bias + state_dict[f"transformer_layers.{i}.self_attention.k_proj.bias"] = key_bias + state_dict[f"transformer_layers.{i}.self_attention.v_proj.bias"] = value_bias + + to_qk_weight = state_dict.pop(f"transformer_layers.{i}.cross_attention.to_qk.weight") + to_qk_bias = state_dict.pop(f"transformer_layers.{i}.cross_attention.to_qk.bias") + state_dict[f"transformer_layers.{i}.cross_attention.q_proj.weight"] = to_qk_weight + state_dict[f"transformer_layers.{i}.cross_attention.q_proj.bias"] = to_qk_bias + state_dict[f"transformer_layers.{i}.cross_attention.k_proj.weight"] = to_qk_weight + state_dict[f"transformer_layers.{i}.cross_attention.k_proj.bias"] = to_qk_bias + + return state_dict + + +@torch.no_grad() +def write_model( + model_path, + checkpoint_url, + organization, + safe_serialization=True, + push_to_hub=False, +): + os.makedirs(model_path, exist_ok=True) + + # ------------------------------------------------------------ + # LightGlue config + # ------------------------------------------------------------ + + config = LightGlueConfig( + descriptor_dim=256, + num_hidden_layers=9, + num_attention_heads=4, + ) + config.architectures = ["LightGlueForKeypointMatching"] + config.save_pretrained(model_path) + print("Model config saved successfully...") + + # ------------------------------------------------------------ + # Convert weights + # ------------------------------------------------------------ + + print(f"Fetching all parameters from the checkpoint at {checkpoint_url}...") + original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url) + + print("Converting model...") + all_keys = list(original_state_dict.keys()) + new_keys = convert_old_keys_to_new_keys(all_keys) + + state_dict = {} + for key in all_keys: + new_key = new_keys[key] + state_dict[new_key] = original_state_dict.pop(key).contiguous().clone() + + del original_state_dict + gc.collect() + state_dict = 
split_weights(state_dict) + state_dict = add_keypoint_detector_state_dict(state_dict) + + print("Loading the checkpoint in a LightGlue model...") + device = "cuda" + with torch.device(device): + model = LightGlueForKeypointMatching(config) + model.load_state_dict(state_dict) + print("Checkpoint loaded successfully...") + del model.config._name_or_path + + print("Saving the model...") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + del state_dict, model + + # Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + model = LightGlueForKeypointMatching.from_pretrained(model_path) + print("Model reloaded successfully.") + + model_name = "lightglue" + if "superpoint" in checkpoint_url: + model_name += "_superpoint" + if checkpoint_url == DEFAULT_CHECKPOINT_URL: + print("Checking the model outputs...") + verify_model_outputs(model, device) + print("Model outputs verified successfully.") + + if push_to_hub: + print("Pushing model to the hub...") + model.push_to_hub( + repo_id=f"{organization}/{model_name}", + commit_message="Add model", + ) + config.push_to_hub(repo_id=f"{organization}/{model_name}", commit_message="Add config") + + write_image_processor(model_path, model_name, organization, push_to_hub=push_to_hub) + + +def write_image_processor(save_dir, model_name, organization, push_to_hub=False): + if "superpoint" in model_name: + image_processor = LightGlueImageProcessor(do_grayscale=True) + else: + image_processor = LightGlueImageProcessor() + image_processor.save_pretrained(save_dir) + + if push_to_hub: + print("Pushing image processor to the hub...") + image_processor.push_to_hub( + repo_id=f"{organization}/{model_name}", + commit_message="Add image processor", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint_url", + default=DEFAULT_CHECKPOINT_URL, + type=str, + help="URL of the 
original LightGlue checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument("--save_model", action="store_true", help="Save model to local") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Push model and image preprocessor to the hub", + ) + parser.add_argument( + "--organization", + default="ETH-CVG", + type=str, + help="Hub organization in which you want the model to be uploaded.", + ) + + args = parser.parse_args() + write_model( + args.pytorch_dump_folder_path, + args.checkpoint_url, + args.organization, + safe_serialization=True, + push_to_hub=args.push_to_hub, + )
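The conversion script above renames checkpoint keys by joining them into one newline-separated string and applying each regex once, rather than looping over every key for every pattern. Below is a minimal sketch of that idea using only the standard library. The mapping `EXAMPLE_KEY_MAPPING` and the helper `rename_keys` are hypothetical names made up for illustration; they are not the script's actual `ORIGINAL_TO_CONVERTED_KEY_MAPPING`, and the `None` replacement case is left out.

```python
import re

# Hypothetical patterns for illustration only; the real script defines its own
# ORIGINAL_TO_CONVERTED_KEY_MAPPING with different entries.
EXAMPLE_KEY_MAPPING = {
    r"transformers\.(\d+)\.self_attn": r"transformer_layers.\1.self_attention",
    r"posenc\.Wr": r"positional_encoder.projector",
}


def rename_keys(keys, mapping):
    """Rename every key with one re.sub call per pattern.

    All keys are joined with newlines so each pattern scans the whole key set
    in a single pass; the result is split back and zipped with the old keys.
    """
    old_text = "\n".join(keys)
    new_text = old_text
    for pattern, replacement in mapping.items():
        new_text = re.sub(pattern, replacement, new_text)
    return dict(zip(old_text.split("\n"), new_text.split("\n")))


renamed = rename_keys(
    ["transformers.0.self_attn.Wqkv.weight", "posenc.Wr.weight"],
    EXAMPLE_KEY_MAPPING,
)
for old_key, new_key in renamed.items():
    print(f"{old_key} -> {new_key}")
```

This approach relies on no key containing a newline and on no replacement adding or removing one, otherwise the final `zip` would pair old and new keys incorrectly.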
src/transformers/models/lightglue/image_processing_lightglue.py+452 −0 added@@ -0,0 +1,452 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightglue.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + ImageInput, + ImageType, + PILImageResampling, + get_image_type, + infer_channel_dimension_format, + is_pil_image, + is_scaled_image, + is_valid_image, + is_vision_available, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_matplotlib_available, logging, requires_backends +from ...utils.import_utils import requires +from .modeling_lightglue import LightGlueKeypointMatchingOutput + + +if is_vision_available(): + import PIL + +logger = logging.get_logger(__name__) + + +def is_grayscale( + image: ImageInput, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +): + if input_data_format == ChannelDimension.FIRST: + if image.shape[0] == 1: + return True + return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]) + elif input_data_format == ChannelDimension.LAST: + if image.shape[-1] == 1: + return True + return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2]) + + +def convert_to_grayscale( + image: ImageInput, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> ImageInput: + """ + Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch + and tensorflow grayscale conversion + + This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each + channel, because of an issue that is discussed in : + https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446 + + Args: + image (Image): + The image to convert. 
+ input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. + """ + requires_backends(convert_to_grayscale, ["vision"]) + + if isinstance(image, np.ndarray): + if is_grayscale(image, input_data_format=input_data_format): + return image + if input_data_format == ChannelDimension.FIRST: + gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140 + gray_image = np.stack([gray_image] * 3, axis=0) + elif input_data_format == ChannelDimension.LAST: + gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140 + gray_image = np.stack([gray_image] * 3, axis=-1) + return gray_image + + if not isinstance(image, PIL.Image.Image): + return image + + image = image.convert("L") + return image + + +def validate_and_format_image_pairs(images: ImageInput): + error_message = ( + "Input images must be a one of the following :", + " - A pair of PIL images.", + " - A pair of 3D arrays.", + " - A list of pairs of PIL images.", + " - A list of pairs of 3D arrays.", + ) + + def _is_valid_image(image): + """images is a PIL Image or a 3D array.""" + return is_pil_image(image) or ( + is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3 + ) + + if isinstance(images, list): + if len(images) == 2 and all((_is_valid_image(image)) for image in images): + return images + if all( + isinstance(image_pair, list) + and len(image_pair) == 2 + and all(_is_valid_image(image) for image in image_pair) + for image_pair in images + ): + return [image for image_pair in images for image in image_pair] + raise ValueError(error_message) + + +@requires(backends=("torch",)) +class LightGlueImageProcessor(BaseImageProcessor): + r""" + Constructs a LightGlue image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. 
Can be overridden + by `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 480, "width": 640}`): + Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to + `True`. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_grayscale (`bool`, *optional*, defaults to `True`): + Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_grayscale: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 480, "width": 640} + size = get_size_dict(size, default_to_square=False) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_grayscale = do_grayscale + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Resize an image. + + Args: + image (`np.ndarray`): + Image to resize. 
+ size (`Dict[str, int]`): + Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the output image. If not provided, it will be inferred from the input + image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + size = get_size_dict(size, default_to_square=False) + + return resize( + image, + size=(size["height"], size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_grayscale: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image pairs to preprocess. Expects either a list of 2 images or a list of list of 2 images list with + pixel values ranging from 0 to 255. 
If passing in images with pixel values between 0 and 1, set + `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the output image after `resize` has been applied, as a dictionary of the form + `{"height": int, "width": int}`. Only has an effect if `do_resize` is set to `True`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values to the range [0, 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`): + Whether to convert the image to grayscale. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + + # Validate and convert the input images into a flattened list of images for all subsequent processing steps. + images = validate_and_format_image_pairs(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
+ ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_grayscale: + image = convert_to_grayscale(image, input_data_format=input_data_format) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + all_images.append(image) + + # Convert back the flattened list of images into a list of pairs of images. + image_pairs = [all_images[i : i + 2] for i in range(0, len(all_images), 2)] + + data = {"pixel_values": image_pairs} + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_keypoint_matching( + self, + outputs: LightGlueKeypointMatchingOutput, + target_sizes: Union[TensorType, List[Tuple]], + threshold: float = 0.0, + ) -> List[Dict[str, torch.Tensor]]: + """ + Converts the raw output of [`KeypointMatchingOutput`] into lists of keypoints, scores and descriptors + with coordinates absolute to the original image sizes. + Args: + outputs ([`KeypointMatchingOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*): + Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the + target size `(height, width)` of each image in the batch. This must be the original image size (before + any processing). + threshold (`float`, *optional*, defaults to 0.0): + Threshold to filter out the matches with low scores. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image + of the pair, the matching scores and the matching indices. 
+ """ + if outputs.mask.shape[0] != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") + if not all(len(target_size) == 2 for target_size in target_sizes): + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + if isinstance(target_sizes, List): + image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device) + else: + if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2: + raise ValueError( + "Each element of target_sizes must contain the size (h, w) of each image of the batch" + ) + image_pair_sizes = target_sizes + + keypoints = outputs.keypoints.clone() + keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2) + keypoints = keypoints.to(torch.int32) + + results = [] + for mask_pair, keypoints_pair, matches, scores in zip( + outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0] + ): + mask0 = mask_pair[0] > 0 + mask1 = mask_pair[1] > 0 + keypoints0 = keypoints_pair[0][mask0] + keypoints1 = keypoints_pair[1][mask1] + matches0 = matches[mask0] + scores0 = scores[mask0] + + # Filter out matches with low scores + valid_matches = torch.logical_and(scores0 > threshold, matches0 > -1) + + matched_keypoints0 = keypoints0[valid_matches] + matched_keypoints1 = keypoints1[matches0[valid_matches]] + matching_scores = scores0[valid_matches] + + results.append( + { + "keypoints0": matched_keypoints0, + "keypoints1": matched_keypoints1, + "matching_scores": matching_scores, + } + ) + + return results + + def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput): + """ + Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires + matplotlib to be installed. + + Args: + images (`ImageInput`): + Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. 
Expects either a list of 2 images or + a list of lists of 2 images, with pixel values ranging from 0 to 255. + keypoint_matching_output (`List[Dict[str, torch.Tensor]]`): + The post-processed matches for each image pair, as returned by `post_process_keypoint_matching`. + """ + if is_matplotlib_available(): + import matplotlib.pyplot as plt + else: + raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method") + + images = validate_and_format_image_pairs(images) + images = [to_numpy_array(image) for image in images] + image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] + + for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): + height0, width0 = image_pair[0].shape[:2] + height1, width1 = image_pair[1].shape[:2] + plot_image = np.zeros((max(height0, height1), width0 + width1, 3)) + plot_image[:height0, :width0] = image_pair[0] / 255.0 + plot_image[:height1, width0:] = image_pair[1] / 255.0 + plt.imshow(plot_image) + plt.axis("off") + + keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) + keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) + for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( + keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] + ): + plt.plot( + [keypoint0_x, keypoint1_x + width0], + [keypoint0_y, keypoint1_y], + color=plt.get_cmap("RdYlGn")(matching_score.item()), + alpha=0.9, + linewidth=0.5, + ) + plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2) + plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2) + plt.show() + + +__all__ = ["LightGlueImageProcessor"]
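The filtering step in `post_process_keypoint_matching` above keeps a keypoint from the first image only when its score exceeds `threshold` and its match index is not `-1`, then uses that index to look up the partner keypoint in the second image. The sketch below restates that rule with plain Python lists instead of tensors. The helper `filter_matches` and the sample values are hypothetical, and the padding-mask and coordinate-rescaling steps of the real method are deliberately omitted.

```python
def filter_matches(keypoints0, keypoints1, matches0, scores0, threshold=0.0):
    """Return (keypoint0, keypoint1, score) triples for confident matches.

    matches0[i] is the index into keypoints1 matched to keypoints0[i],
    or -1 when keypoints0[i] has no match.
    """
    results = []
    for keypoint0, match_index, score in zip(keypoints0, matches0, scores0):
        # Same condition as the tensor version: score above threshold AND a valid index.
        if score > threshold and match_index > -1:
            results.append((keypoint0, keypoints1[match_index], score))
    return results


# Made-up sample data: three keypoints in image 0, two in image 1.
pairs = filter_matches(
    keypoints0=[(10, 20), (30, 40), (50, 60)],
    keypoints1=[(11, 21), (52, 61)],
    matches0=[0, -1, 1],
    scores0=[0.9, 0.0, 0.4],
    threshold=0.5,
)
print(pairs)
```

With these sample values only the first keypoint survives: the second has no match, and the third has a valid match but a score below the threshold.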
src/transformers/models/lightglue/__init__.py +28 −0 added
@@ -0,0 +1,28 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_lightglue import *
+    from .image_processing_lightglue import *
+    from .modeling_lightglue import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
src/transformers/models/lightglue/modeling_lightglue.py+926 −0 added@@ -0,0 +1,926 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightglue.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence + +from ...activations import ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, auto_docstring +from ...utils.generic import can_return_tuple +from ..auto.modeling_auto import AutoModelForKeypointDetection +from .configuration_lightglue import LightGlueConfig + + +@dataclass +class LightGlueKeypointMatchingOutput(ModelOutput): + """ + Base class for outputs of LightGlue keypoint matching models. 
Due to the nature of keypoint detection and matching, + the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the + batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask + tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint + matching information. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Index of keypoint matched in the other image. + matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Scores of predicted matches. + keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Absolute (x, y) coordinates of predicted keypoints in a given image. + prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): + Pruning mask indicating which keypoints are removed and at which layer. + mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching + information. 
+ hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, + num_keypoints)` returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True` + attentions (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, + num_keypoints)` returned when `output_attentions=True` is passed or when + `config.output_attentions=True` + """ + + loss: Optional[torch.FloatTensor] = None + matches: Optional[torch.FloatTensor] = None + matching_scores: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.FloatTensor] = None + prune: Optional[torch.IntTensor] = None + mask: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class LightGluePositionalEncoder(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False) + + def forward( + self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + projected_keypoints = self.projector(keypoints) + embeddings = projected_keypoints.repeat_interleave(2, dim=-1) + cosines = torch.cos(embeddings) + sines = torch.sin(embeddings) + embeddings = (cosines, sines) + output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,) + return output + + +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. 
+ x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + dtype = q.dtype + q = q.float() + k = k.float() + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class LightGlueAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LightGlueConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias 
+ ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + current_attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, 
-1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LightGlueMLP(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class LightGlueTransformerLayer(nn.Module): + def __init__(self, config: LightGlueConfig, layer_idx: int): + super().__init__() + self.self_attention = LightGlueAttention(config, layer_idx) + self.self_mlp = LightGlueMLP(config) + self.cross_attention = LightGlueAttention(config, layer_idx) + self.cross_mlp = LightGlueMLP(config) + + def forward( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + attention_mask: torch.Tensor, + output_hidden_states: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor]]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (descriptors,) + + batch_size, num_keypoints, descriptor_dim = descriptors.shape + + # Self attention block + attention_output, self_attentions = self.self_attention( + descriptors, + position_embeddings=keypoints, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + intermediate_states = torch.cat([descriptors, attention_output], dim=-1) + output_states = 
self.self_mlp(intermediate_states) + self_attention_descriptors = descriptors + output_states + + if output_hidden_states: + self_attention_hidden_states = (intermediate_states, output_states) + + # Reshape hidden_states to group by image_pairs : + # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim) + # Flip dimension 1 to perform cross attention : + # (image0, image1) -> (image1, image0) + # Reshape back to original shape : + # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim) + encoder_hidden_states = ( + self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim) + .flip(1) + .reshape(batch_size, num_keypoints, descriptor_dim) + ) + # Same for mask + encoder_attention_mask = ( + attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints) + if attention_mask is not None + else None + ) + + # Cross attention block + cross_attention_output, cross_attentions = self.cross_attention( + self_attention_descriptors, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1) + cross_output_states = self.cross_mlp(cross_intermediate_states) + descriptors = self_attention_descriptors + cross_output_states + + if output_hidden_states: + cross_attention_hidden_states = (cross_intermediate_states, cross_output_states) + all_hidden_states = ( + all_hidden_states + + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + self_attention_hidden_states + + (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + cross_attention_hidden_states + ) + + if output_attentions: + all_attentions = all_attentions + (self_attentions,) + (cross_attentions,) + + return descriptors, all_hidden_states, all_attentions + + +def 
sigmoid_log_double_softmax( + similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" + batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape + certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2) + scores0 = nn.functional.log_softmax(similarity, 2) + scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0) + scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties + scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1)) + scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1)) + return scores + + +class LightGlueMatchAssignmentLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.descriptor_dim = config.descriptor_dim + self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True) + self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True) + + def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + batch_size, num_keypoints, descriptor_dim = descriptors.shape + # Final projection and similarity computation + m_descriptors = self.final_projection(descriptors) + m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25 + m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim) + m_descriptors0 = m_descriptors[:, 0] + m_descriptors1 = m_descriptors[:, 1] + similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2) + if mask is not None: + mask = mask.reshape(batch_size // 2, 2, num_keypoints) + mask0 = mask[:, 0].unsqueeze(-1) + mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2) + mask = mask0 * mask1 + similarity = 
similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min) + + # Compute matchability of descriptors + matchability = self.matchability(descriptors) + matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1) + matchability_0 = matchability[:, 0] + matchability_1 = matchability[:, 1] + + # Compute scores from similarity and matchability + scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1) + return scores + + def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor: + """Get matchability of descriptors as a probability""" + matchability = self.matchability(descriptors) + matchability = nn.functional.sigmoid(matchability).squeeze(-1) + return matchability + + +class LightGlueTokenConfidenceLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.token = nn.Linear(config.descriptor_dim, 1) + + def forward(self, descriptors: torch.Tensor) -> torch.Tensor: + token = self.token(descriptors.detach()) + token = nn.functional.sigmoid(token).squeeze(-1) + return token + + +@auto_docstring +class LightGluePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = LightGlueConfig + base_model_prefix = "lightglue" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]: + """obtain matches from a score matrix [Bx M+1 x N+1]""" + batch_size, _, _ = scores.shape + # For each keypoint, get the best match + max0 = scores[:, :-1, :-1].max(2) + max1 = scores[:, :-1, :-1].max(1) + matches0 = max0.indices + matches1 = max1.indices + + # Mutual check for matches + indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None] + indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None] + mutual0 = indices0 == matches1.gather(1, matches0) + mutual1 = indices1 == matches0.gather(1, matches1) + + # Get matching scores and filter based on mutual check and thresholding + max0 = max0.values.exp() + zero = max0.new_tensor(0) + matching_scores0 = torch.where(mutual0, max0, zero) + matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero) + valid0 = mutual0 & (matching_scores0 > threshold) + valid1 = mutual1 & valid0.gather(1, matches1) + + # Filter matches based on mutual check and thresholding of scores + matches0 = torch.where(valid0, matches0, -1) + matches1 = torch.where(valid1, matches1, -1) + matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1) + matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1) + + return matches, matching_scores + + 
+def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + Normalize keypoint locations based on the image shape + + Args: + keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`): + Keypoint locations in (x, y) format. + height (`int`): + Image height. + width (`int`): + Image width. + + Returns: + Normalized keypoint locations (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`). + """ + size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None] + shift = size / 2 + scale = size.max(-1).values / 2 + keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None] + return keypoints + + +@auto_docstring( + custom_intro=""" + LightGlue model taking images as inputs and outputting the matching of them. + """ +) +class LightGlueForKeypointMatching(LightGluePreTrainedModel): + """ + LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as + SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient. + It consists of: + 1. Keypoint Encoder + 2. A Graph Neural Network with self and cross attention layers + 3. Matching Assignment layers + + The correspondence ids use -1 to indicate non-matching points. + + Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed. + In ICCV 2023.
https://arxiv.org/pdf/2306.13643.pdf + """ + + def __init__(self, config: LightGlueConfig): + super().__init__(config) + + self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config) + + self.descriptor_dim = config.descriptor_dim + self.num_layers = config.num_hidden_layers + self.filter_threshold = config.filter_threshold + self.depth_confidence = config.depth_confidence + self.width_confidence = config.width_confidence + + if self.descriptor_dim != config.keypoint_detector_config.descriptor_decoder_dim: + self.input_projection = nn.Linear( + config.keypoint_detector_config.descriptor_decoder_dim, self.descriptor_dim, bias=True + ) + else: + self.input_projection = nn.Identity() + + self.positional_encoder = LightGluePositionalEncoder(config) + + self.transformer_layers = nn.ModuleList( + [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + self.match_assignment_layers = nn.ModuleList( + [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.token_confidence = nn.ModuleList( + [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)] + ) + + self.post_init() + + def _get_confidence_threshold(self, layer_index: int) -> float: + """scaled confidence threshold for a given layer""" + threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers) + return np.clip(threshold, 0, 1) + + def _keypoint_processing( + self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + descriptors = descriptors.detach().contiguous() + projected_descriptors = self.input_projection(descriptors) + keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states) + return projected_descriptors, keypoint_encoding_output + + def _get_early_stopped_image_pairs( + self, keypoint_confidences: 
torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor + ) -> torch.Tensor: + """evaluate whether we should stop inference based on the confidence of the keypoints""" + batch_size, _ = mask.shape + if layer_index < self.num_layers - 1: + # If the current layer is not the last layer, we compute the confidence of the keypoints and check + # if we should stop the forward pass through the transformer layers for each pair of images. + keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1) + keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1) + threshold = self._get_confidence_threshold(layer_index) + ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points + early_stopped_pairs = ratio_confident > self.depth_confidence + else: + # If the current layer is the last layer, we stop the forward pass through the transformer layers for + # all pairs of images. + early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + return early_stopped_pairs + + def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None): + if early_stops is not None: + descriptors = descriptors[early_stops] + mask = mask[early_stops] + scores = self.match_assignment_layers[layer_index](descriptors, mask) + matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold) + return matches, matching_scores + + def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor: + """mask points which should be removed""" + keep = scores > (1 - self.width_confidence) + if confidences is not None: # Low-confidence points are never pruned. 
+ keep |= confidences <= self._get_confidence_threshold(layer_index) + return keep + + def _do_layer_keypoint_pruning( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + mask: torch.Tensor, + indices: torch.Tensor, + prune_output: torch.Tensor, + keypoint_confidences: torch.Tensor, + layer_index: int, + ): + """ + For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the + descriptors. + """ + batch_size, _, _ = descriptors.shape + descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors) + pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index) + pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False)) + + # For each image, we extract the pruned indices and the corresponding descriptors and keypoints. + pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = ( + [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)] + for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices] + ) + for i in range(batch_size): + prune_output[i, pruned_indices[i]] += 1 + + # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch. 
+ pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = ( + pad_sequence(pruned_tensor, batch_first=True) + for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask] + ) + pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1) + pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1) + + return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output + + def _concat_early_stopped_outputs( + self, + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ): + early_stops_indices = torch.stack(early_stops_indices) + matches, final_pruned_keypoints_indices = ( + pad_sequence(tensor, batch_first=True, padding_value=-1) + for tensor in [matches, final_pruned_keypoints_indices] + ) + matching_scores, final_pruned_keypoints_iterations = ( + pad_sequence(tensor, batch_first=True, padding_value=0) + for tensor in [matching_scores, final_pruned_keypoints_iterations] + ) + matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = ( + tensor[early_stops_indices] + for tensor in [ + matches, + matching_scores, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + ] + ) + return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores + + def _do_final_keypoint_pruning( + self, + indices: torch.Tensor, + matches: torch.Tensor, + matching_scores: torch.Tensor, + num_keypoints: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to + # have tensors from + batch_size, _ = indices.shape + indices, matches, matching_scores = ( + tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores] + ) + indices0 = indices[:, 0] + indices1 = indices[:, 1] + matches0 = matches[:, 0] + 
matches1 = matches[:, 1] + matching_scores0 = matching_scores[:, 0] + matching_scores1 = matching_scores[:, 1] + + # Prepare final matches and matching scores + _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype) + _matching_scores = torch.zeros( + (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype + ) + # Fill the matches and matching scores for each image pair + for i in range(batch_size // 2): + _matches[i, 0, indices0[i]] = torch.where( + matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0)) + ) + _matches[i, 1, indices1[i]] = torch.where( + matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0)) + ) + _matching_scores[i, 0, indices0[i]] = matching_scores0[i] + _matching_scores[i, 1, indices1[i]] = matching_scores1[i] + return _matches, _matching_scores + + def _match_image_pair( + self, + keypoints: torch.Tensor, + descriptors: torch.Tensor, + height: int, + width: int, + mask: torch.Tensor = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple, Tuple]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if keypoints.shape[2] == 0: # no keypoints + shape = keypoints.shape[:-1] + return ( + keypoints.new_full(shape, -1, dtype=torch.int), + keypoints.new_zeros(shape), + keypoints.new_zeros(shape), + all_hidden_states, + all_attentions, + ) + + device = keypoints.device + batch_size, _, initial_num_keypoints, _ = keypoints.shape + num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1) + # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2) + keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2) + mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None + descriptors = 
descriptors.reshape(batch_size * 2, initial_num_keypoints, self.descriptor_dim) + image_indices = torch.arange(batch_size * 2, device=device) + # Keypoint normalization + keypoints = normalize_keypoints(keypoints, height, width) + + descriptors, keypoint_encoding_output = self._keypoint_processing( + descriptors, keypoints, output_hidden_states=output_hidden_states + ) + + keypoints = keypoint_encoding_output[0] + + # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the + # keypoints is above a certain threshold. + do_early_stop = self.depth_confidence > 0 + # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of + # the keypoints is below a certain threshold. + do_keypoint_pruning = self.width_confidence > 0 + + early_stops_indices = [] + matches = [] + matching_scores = [] + final_pruned_keypoints_indices = [] + final_pruned_keypoints_iterations = [] + + pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1) + pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices) + + for layer_index in range(self.num_layers): + input_shape = descriptors.size() + if mask is not None: + extended_attention_mask = self.get_extended_attention_mask(mask, input_shape) + else: + extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device) + layer_output = self.transformer_layers[layer_index]( + descriptors, + keypoints, + attention_mask=extended_attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + descriptors, hidden_states, attention = layer_output + if output_hidden_states: + all_hidden_states = all_hidden_states + hidden_states + if output_attentions: + all_attentions = all_attentions + attention + + if do_early_stop: + if layer_index < self.num_layers - 1: + # Get the confidence of the keypoints for the current layer + 
keypoint_confidences = self.token_confidence[layer_index](descriptors) + + # Determine which pairs of images should be early stopped based on the confidence of the keypoints for + # the current layer. + early_stopped_pairs = self._get_early_stopped_image_pairs( + keypoint_confidences, layer_index, mask, num_points=num_points_per_pair + ) + else: + # Early stopping always occurs at the last layer + early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + + if torch.any(early_stopped_pairs): + # If a pair of images is considered early stopped, we compute the matches for the remaining + # keypoints and stop the forward pass through the transformer layers for this pair of images. + early_stops = early_stopped_pairs.repeat_interleave(2) + early_stopped_image_indices = image_indices[early_stops] + early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching( + descriptors, mask, layer_index, early_stops=early_stops + ) + early_stops_indices.extend(list(early_stopped_image_indices)) + matches.extend(list(early_stopped_matches)) + matching_scores.extend(list(early_stopped_matching_scores)) + if do_keypoint_pruning: + final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops])) + final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops])) + + # Remove image pairs that have been early stopped from the forward pass + num_points_per_pair = num_points_per_pair[~early_stopped_pairs] + descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple( + ( + tensor[~early_stops] + for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices] + ) + ) + keypoints = (keypoints_0, keypoint_1) + if do_keypoint_pruning: + pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple( + ( + tensor[~early_stops] + for tensor in [ + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + ] + ) + ) + # If all pairs of images are early stopped, we 
stop the forward pass through the transformer + # layers for all pairs of images. + if torch.all(early_stopped_pairs): + break + + if do_keypoint_pruning: + # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of + # the keypoints is below a certain threshold. + descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = ( + self._do_layer_keypoint_pruning( + descriptors, + keypoints, + mask, + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + layer_index, + ) + ) + + if do_early_stop and do_keypoint_pruning: + # Concatenate early stopped outputs together and perform final keypoint pruning + final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = ( + self._concat_early_stopped_outputs( + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ) + ) + matches, matching_scores = self._do_final_keypoint_pruning( + final_pruned_keypoints_indices, + matches, + matching_scores, + initial_num_keypoints, + ) + else: + matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1) + final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers + + final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape( + batch_size, 2, initial_num_keypoints + ) + + return ( + matches, + matching_scores, + final_pruned_keypoints_iterations, + all_hidden_states, + all_attentions, + ) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[Tuple, LightGlueKeypointMatchingOutput]: + loss = None + if labels is not None: + raise ValueError("LightGlue is not trainable, no labels should be provided.") + + output_attentions 
= output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if pixel_values.ndim != 5 or pixel_values.size(1) != 2: + raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") + + batch_size, _, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + keypoint_detections = self.keypoint_detector(pixel_values) + + keypoints, _, descriptors, mask = keypoint_detections[:4] + keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values) + descriptors = descriptors.reshape(batch_size, 2, -1, self.descriptor_dim).to(pixel_values) + mask = mask.reshape(batch_size, 2, -1) + + absolute_keypoints = keypoints.clone() + absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width + absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height + + matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair( + absolute_keypoints, + descriptors, + height, + width, + mask=mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + return LightGlueKeypointMatchingOutput( + loss=loss, + matches=matches, + matching_scores=matching_scores, + keypoints=keypoints, + prune=prune, + mask=mask, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching"]
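The mutual check performed by `get_matches_from_scores` in the file above can be hard to follow in its batched tensor form. The following is a minimal pure-Python sketch of the same idea for a single image pair, not the library code: the function name `mutual_matches` and the nested-list input are assumptions made for illustration, and it assumes each image has at least one keypoint (the model handles the empty case separately). A keypoint in image 0 keeps its best candidate in image 1 only if that candidate also picks it back and the exponentiated score exceeds the threshold.

```python
import math


def mutual_matches(log_scores, threshold):
    """Single-pair restatement of the mutual nearest-neighbour check.

    `log_scores` is an (M+1) x (N+1) nested list of log-scores. The last row
    and column are the "unmatched" bins and are ignored, as in the batched
    code, which slices them away with `scores[:, :-1, :-1]`.
    Hypothetical helper: assumes M >= 1 and N >= 1.
    """
    m = len(log_scores) - 1
    n = len(log_scores[0]) - 1

    # Best column for every row, and best row for every column.
    best_col = [max(range(n), key=lambda j: log_scores[i][j]) for i in range(m)]
    best_row = [max(range(m), key=lambda i: log_scores[i][j]) for j in range(n)]

    matches0, scores0 = [], []
    for i in range(m):
        j = best_col[i]
        mutual = best_row[j] == i  # does j pick i back?
        score = math.exp(log_scores[i][j]) if mutual else 0.0
        keep = mutual and score > threshold
        matches0.append(j if keep else -1)
        scores0.append(score)

    matches1, scores1 = [], []
    for j in range(n):
        i = best_row[j]
        mutual = best_col[i] == j
        # Side 1 is valid only if the partner on side 0 survived the threshold.
        keep = mutual and matches0[i] == j
        matches1.append(i if keep else -1)
        scores1.append(scores0[i] if mutual else 0.0)

    return matches0, scores0, matches1, scores1


if __name__ == "__main__":
    probs = [
        [0.70, 0.10, 0.20],
        [0.20, 0.05, 0.75],
        [0.10, 0.85, 0.05],
    ]
    logs = [[math.log(p) for p in row] for row in probs]
    m0, s0, m1, s1 = mutual_matches(logs, threshold=0.1)
    print(m0, m1)  # matches0 == [0, -1], matches1 == [0, -1]
```

In this toy input both rows prefer column 0, but column 0 prefers row 0, so only the pair (0, 0) survives; every other keypoint is reported as `-1`, which matches the "-1 indicates non-matching points" convention stated in the model docstring.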
src/transformers/models/lightglue/modular_lightglue.py+1000 −0 added@@ -0,0 +1,1000 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import PretrainedConfig +from ...image_utils import ImageInput, to_numpy_array +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TensorType, auto_docstring, is_matplotlib_available, logging +from ...utils.generic import can_return_tuple +from ..auto import CONFIG_MAPPING, AutoConfig +from ..auto.modeling_auto import AutoModelForKeypointDetection +from ..clip.modeling_clip import CLIPMLP +from ..cohere.modeling_cohere import apply_rotary_pos_emb +from ..llama.modeling_llama import LlamaAttention, eager_attention_forward +from ..superglue.image_processing_superglue import SuperGlueImageProcessor, validate_and_format_image_pairs +from ..superpoint import SuperPointConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC_ = "LightGlueConfig" +_CHECKPOINT_FOR_DOC_ = "ETH-CVG/lightglue_superpoint" + + +class LightGlueConfig(PretrainedConfig): + r""" + This is the configuration 
class to store the configuration of a [`LightGlueForKeypointMatching`]. It is used to + instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the LightGlue + [ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`): + The config object or dictionary of the keypoint detector. + descriptor_dim (`int`, *optional*, defaults to 256): + The dimension of the descriptors. + num_hidden_layers (`int`, *optional*, defaults to 9): + The number of self and cross attention layers. + num_attention_heads (`int`, *optional*, defaults to 4): + The number of heads in the multi-head attention. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. 
+ depth_confidence (`float`, *optional*, defaults to 0.95): + The confidence threshold used to perform early stopping + width_confidence (`float`, *optional*, defaults to 0.99): + The confidence threshold used to prune points + filter_threshold (`float`, *optional*, defaults to 0.1): + The confidence threshold used to filter matches + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function to be used in the hidden layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + + Examples: + ```python + >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching + + >>> # Initializing a LightGlue style configuration + >>> configuration = LightGlueConfig() + + >>> # Initializing a model from the LightGlue style configuration + >>> model = LightGlueForKeypointMatching(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "lightglue" + sub_configs = {"keypoint_detector_config": AutoConfig} + + def __init__( + self, + keypoint_detector_config: SuperPointConfig = None, + descriptor_dim: int = 256, + num_hidden_layers: int = 9, + num_attention_heads: int = 4, + num_key_value_heads=None, + depth_confidence: float = 0.95, + width_confidence: float = 0.99, + filter_threshold: float = 0.1, + initializer_range: float = 0.02, + hidden_act: str = "gelu", + attention_dropout=0.0, + attention_bias=True, + **kwargs, + ): + if descriptor_dim % num_attention_heads != 0: + raise ValueError("descriptor_dim % num_heads is different from zero") + + self.descriptor_dim = descriptor_dim + self.num_hidden_layers 
= num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + + self.depth_confidence = depth_confidence + self.width_confidence = width_confidence + self.filter_threshold = filter_threshold + self.initializer_range = initializer_range + + # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention + # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153 + if isinstance(keypoint_detector_config, dict): + keypoint_detector_config["model_type"] = ( + keypoint_detector_config["model_type"] if "model_type" in keypoint_detector_config else "superpoint" + ) + keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]]( + **keypoint_detector_config, attn_implementation="eager" + ) + if keypoint_detector_config is None: + keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager") + + self.keypoint_detector_config = keypoint_detector_config + + self.hidden_size = descriptor_dim + self.intermediate_size = descriptor_dim * 2 + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + super().__init__(**kwargs) + + +@dataclass +class LightGlueKeypointMatchingOutput(ModelOutput): + """ + Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching, + the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the + batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask + tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint + matching information. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Index of keypoint matched in the other image. + matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Scores of predicted matches. + keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Absolute (x, y) coordinates of predicted keypoints in a given image. + prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): + Pruning mask indicating which keypoints are removed and at which layer. + mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching + information. + hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, + num_keypoints)` returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True` + attentions (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, + num_keypoints)` returned when `output_attentions=True` is passed or when + `config.output_attentions=True` + """ + + loss: Optional[torch.FloatTensor] = None + matches: Optional[torch.FloatTensor] = None + matching_scores: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.FloatTensor] = None + prune: Optional[torch.IntTensor] = None + mask: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class LightGlueImageProcessor(SuperGlueImageProcessor): + def post_process_keypoint_matching( + self, + outputs: LightGlueKeypointMatchingOutput, + target_sizes: Union[TensorType, List[Tuple]], + threshold: float = 0.0, + ) -> 
List[Dict[str, torch.Tensor]]: + return super().post_process_keypoint_matching(outputs, target_sizes, threshold) + + def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput): + """ + Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires + matplotlib to be installed. + + Args: + images (`ImageInput`): + Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or + a list of list of 2 images list with pixel values ranging from 0 to 255. + outputs ([`LightGlueKeypointMatchingOutput`]): + Raw outputs of the model. + """ + if is_matplotlib_available(): + import matplotlib.pyplot as plt + else: + raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method") + + images = validate_and_format_image_pairs(images) + images = [to_numpy_array(image) for image in images] + image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] + + for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): + height0, width0 = image_pair[0].shape[:2] + height1, width1 = image_pair[1].shape[:2] + plot_image = np.zeros((max(height0, height1), width0 + width1, 3)) + plot_image[:height0, :width0] = image_pair[0] / 255.0 + plot_image[:height1, width0:] = image_pair[1] / 255.0 + plt.imshow(plot_image) + plt.axis("off") + + keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) + keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) + for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( + keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] + ): + plt.plot( + [keypoint0_x, keypoint1_x + width0], + [keypoint0_y, keypoint1_y], + color=plt.get_cmap("RdYlGn")(matching_score.item()), + alpha=0.9, + linewidth=0.5, + ) + plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2) + plt.scatter(keypoint1_x + width0, keypoint1_y, 
c="black", s=2) + plt.show() + + +class LightGluePositionalEncoder(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False) + + def forward( + self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + projected_keypoints = self.projector(keypoints) + embeddings = projected_keypoints.repeat_interleave(2, dim=-1) + cosines = torch.cos(embeddings) + sines = torch.sin(embeddings) + embeddings = (cosines, sines) + output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,) + return output + + +class LightGlueAttention(LlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attention_interface: Callable = eager_attention_forward + if 
self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + current_attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LightGlueMLP(CLIPMLP): + def __init__(self, config: LightGlueConfig): + super().__init__(config) + self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size) + self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class LightGlueTransformerLayer(nn.Module): + def __init__(self, config: LightGlueConfig, layer_idx: int): + super().__init__() + self.self_attention = LightGlueAttention(config, layer_idx) + self.self_mlp = LightGlueMLP(config) + self.cross_attention = LightGlueAttention(config, layer_idx) + self.cross_mlp = LightGlueMLP(config) + + def forward( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + attention_mask: torch.Tensor, + output_hidden_states: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor]]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (descriptors,) + + batch_size, num_keypoints, descriptor_dim = descriptors.shape + + # Self attention block + attention_output, 
self_attentions = self.self_attention( + descriptors, + position_embeddings=keypoints, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + intermediate_states = torch.cat([descriptors, attention_output], dim=-1) + output_states = self.self_mlp(intermediate_states) + self_attention_descriptors = descriptors + output_states + + if output_hidden_states: + self_attention_hidden_states = (intermediate_states, output_states) + + # Reshape hidden_states to group by image_pairs : + # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim) + # Flip dimension 1 to perform cross attention : + # (image0, image1) -> (image1, image0) + # Reshape back to original shape : + # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim) + encoder_hidden_states = ( + self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim) + .flip(1) + .reshape(batch_size, num_keypoints, descriptor_dim) + ) + # Same for mask + encoder_attention_mask = ( + attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints) + if attention_mask is not None + else None + ) + + # Cross attention block + cross_attention_output, cross_attentions = self.cross_attention( + self_attention_descriptors, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1) + cross_output_states = self.cross_mlp(cross_intermediate_states) + descriptors = self_attention_descriptors + cross_output_states + + if output_hidden_states: + cross_attention_hidden_states = (cross_intermediate_states, cross_output_states) + all_hidden_states = ( + all_hidden_states + + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + self_attention_hidden_states + + 
(descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + cross_attention_hidden_states + ) + + if output_attentions: + all_attentions = all_attentions + (self_attentions,) + (cross_attentions,) + + return descriptors, all_hidden_states, all_attentions + + +def sigmoid_log_double_softmax( + similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" + batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape + certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2) + scores0 = nn.functional.log_softmax(similarity, 2) + scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0) + scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties + scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1)) + scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1)) + return scores + + +class LightGlueMatchAssignmentLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.descriptor_dim = config.descriptor_dim + self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True) + self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True) + + def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + batch_size, num_keypoints, descriptor_dim = descriptors.shape + # Final projection and similarity computation + m_descriptors = self.final_projection(descriptors) + m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25 + m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim) + m_descriptors0 = m_descriptors[:, 0] + m_descriptors1 = m_descriptors[:, 1] 
+ similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2) + if mask is not None: + mask = mask.reshape(batch_size // 2, 2, num_keypoints) + mask0 = mask[:, 0].unsqueeze(-1) + mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2) + mask = mask0 * mask1 + similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min) + + # Compute matchability of descriptors + matchability = self.matchability(descriptors) + matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1) + matchability_0 = matchability[:, 0] + matchability_1 = matchability[:, 1] + + # Compute scores from similarity and matchability + scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1) + return scores + + def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor: + """Get matchability of descriptors as a probability""" + matchability = self.matchability(descriptors) + matchability = nn.functional.sigmoid(matchability).squeeze(-1) + return matchability + + +class LightGlueTokenConfidenceLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.token = nn.Linear(config.descriptor_dim, 1) + + def forward(self, descriptors: torch.Tensor) -> torch.Tensor: + token = self.token(descriptors.detach()) + token = nn.functional.sigmoid(token).squeeze(-1) + return token + + +@auto_docstring +class LightGluePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = LightGlueConfig + base_model_prefix = "lightglue" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]: + """obtain matches from a score matrix [Bx M+1 x N+1]""" + batch_size, _, _ = scores.shape + # For each keypoint, get the best match + max0 = scores[:, :-1, :-1].max(2) + max1 = scores[:, :-1, :-1].max(1) + matches0 = max0.indices + matches1 = max1.indices + + # Mutual check for matches + indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None] + indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None] + mutual0 = indices0 == matches1.gather(1, matches0) + mutual1 = indices1 == matches0.gather(1, matches1) + + # Get matching scores and filter based on mutual check and thresholding + max0 = max0.values.exp() + zero = max0.new_tensor(0) + matching_scores0 = torch.where(mutual0, max0, zero) + matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero) + valid0 = mutual0 & (matching_scores0 > threshold) + valid1 = mutual1 & valid0.gather(1, matches1) + + # Filter matches based on mutual check and thresholding of scores + matches0 = torch.where(valid0, matches0, -1) + matches1 = torch.where(valid1, matches1, -1) + matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1) + matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1) + + return matches, matching_scores + + 
+def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + Normalize keypoints locations based on image image_shape + + Args: + keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`): + Keypoints locations in (x, y) format. + height (`int`): + Image height. + width (`int`): + Image width. + + Returns: + Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`). + """ + size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None] + shift = size / 2 + scale = size.max(-1).values / 2 + keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None] + return keypoints + + +@auto_docstring( + custom_intro=""" + LightGlue model taking images as inputs and outputting the matching of them. + """ +) +class LightGlueForKeypointMatching(LightGluePreTrainedModel): + """ + LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as + SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient. + It consists of : + 1. Keypoint Encoder + 2. A Graph Neural Network with self and cross attention layers + 3. Matching Assignment layers + + The correspondence ids use -1 to indicate non-matching points. + + Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed. + In ICCV 2023. 
https://arxiv.org/pdf/2306.13643.pdf + """ + + def __init__(self, config: LightGlueConfig): + super().__init__(config) + + self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config) + + self.descriptor_dim = config.descriptor_dim + self.num_layers = config.num_hidden_layers + self.filter_threshold = config.filter_threshold + self.depth_confidence = config.depth_confidence + self.width_confidence = config.width_confidence + + if self.descriptor_dim != config.keypoint_detector_config.descriptor_decoder_dim: + self.input_projection = nn.Linear( + config.keypoint_detector_config.descriptor_decoder_dim, self.descriptor_dim, bias=True + ) + else: + self.input_projection = nn.Identity() + + self.positional_encoder = LightGluePositionalEncoder(config) + + self.transformer_layers = nn.ModuleList( + [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + self.match_assignment_layers = nn.ModuleList( + [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.token_confidence = nn.ModuleList( + [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)] + ) + + self.post_init() + + def _get_confidence_threshold(self, layer_index: int) -> float: + """scaled confidence threshold for a given layer""" + threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers) + return np.clip(threshold, 0, 1) + + def _keypoint_processing( + self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + descriptors = descriptors.detach().contiguous() + projected_descriptors = self.input_projection(descriptors) + keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states) + return projected_descriptors, keypoint_encoding_output + + def _get_early_stopped_image_pairs( + self, keypoint_confidences: 
torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor + ) -> torch.Tensor: + """evaluate whether we should stop inference based on the confidence of the keypoints""" + batch_size, _ = mask.shape + if layer_index < self.num_layers - 1: + # If the current layer is not the last layer, we compute the confidence of the keypoints and check + # if we should stop the forward pass through the transformer layers for each pair of images. + keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1) + keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1) + threshold = self._get_confidence_threshold(layer_index) + ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points + early_stopped_pairs = ratio_confident > self.depth_confidence + else: + # If the current layer is the last layer, we stop the forward pass through the transformer layers for + # all pairs of images. + early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + return early_stopped_pairs + + def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None): + if early_stops is not None: + descriptors = descriptors[early_stops] + mask = mask[early_stops] + scores = self.match_assignment_layers[layer_index](descriptors, mask) + matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold) + return matches, matching_scores + + def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor: + """mask points which should be removed""" + keep = scores > (1 - self.width_confidence) + if confidences is not None: # Low-confidence points are never pruned. 
+ keep |= confidences <= self._get_confidence_threshold(layer_index) + return keep + + def _do_layer_keypoint_pruning( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + mask: torch.Tensor, + indices: torch.Tensor, + prune_output: torch.Tensor, + keypoint_confidences: torch.Tensor, + layer_index: int, + ): + """ + For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the + descriptors. + """ + batch_size, _, _ = descriptors.shape + descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors) + pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index) + pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False)) + + # For each image, we extract the pruned indices and the corresponding descriptors and keypoints. + pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = ( + [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)] + for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices] + ) + for i in range(batch_size): + prune_output[i, pruned_indices[i]] += 1 + + # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch. 
+ pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = ( + pad_sequence(pruned_tensor, batch_first=True) + for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask] + ) + pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1) + pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1) + + return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output + + def _concat_early_stopped_outputs( + self, + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ): + early_stops_indices = torch.stack(early_stops_indices) + matches, final_pruned_keypoints_indices = ( + pad_sequence(tensor, batch_first=True, padding_value=-1) + for tensor in [matches, final_pruned_keypoints_indices] + ) + matching_scores, final_pruned_keypoints_iterations = ( + pad_sequence(tensor, batch_first=True, padding_value=0) + for tensor in [matching_scores, final_pruned_keypoints_iterations] + ) + matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = ( + tensor[early_stops_indices] + for tensor in [ + matches, + matching_scores, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + ] + ) + return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores + + def _do_final_keypoint_pruning( + self, + indices: torch.Tensor, + matches: torch.Tensor, + matching_scores: torch.Tensor, + num_keypoints: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to + # have tensors from + batch_size, _ = indices.shape + indices, matches, matching_scores = ( + tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores] + ) + indices0 = indices[:, 0] + indices1 = indices[:, 1] + matches0 = matches[:, 0] + 
matches1 = matches[:, 1] + matching_scores0 = matching_scores[:, 0] + matching_scores1 = matching_scores[:, 1] + + # Prepare final matches and matching scores + _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype) + _matching_scores = torch.zeros( + (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype + ) + # Fill the matches and matching scores for each image pair + for i in range(batch_size // 2): + _matches[i, 0, indices0[i]] = torch.where( + matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0)) + ) + _matches[i, 1, indices1[i]] = torch.where( + matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0)) + ) + _matching_scores[i, 0, indices0[i]] = matching_scores0[i] + _matching_scores[i, 1, indices1[i]] = matching_scores1[i] + return _matches, _matching_scores + + def _match_image_pair( + self, + keypoints: torch.Tensor, + descriptors: torch.Tensor, + height: int, + width: int, + mask: torch.Tensor = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple, Tuple]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if keypoints.shape[2] == 0: # no keypoints + shape = keypoints.shape[:-1] + return ( + keypoints.new_full(shape, -1, dtype=torch.int), + keypoints.new_zeros(shape), + keypoints.new_zeros(shape), + all_hidden_states, + all_attentions, + ) + + device = keypoints.device + batch_size, _, initial_num_keypoints, _ = keypoints.shape + num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1) + # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2) + keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2) + mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None + descriptors = 
descriptors.reshape(batch_size * 2, initial_num_keypoints, self.descriptor_dim) + image_indices = torch.arange(batch_size * 2, device=device) + # Keypoint normalization + keypoints = normalize_keypoints(keypoints, height, width) + + descriptors, keypoint_encoding_output = self._keypoint_processing( + descriptors, keypoints, output_hidden_states=output_hidden_states + ) + + keypoints = keypoint_encoding_output[0] + + # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the + # keypoints is above a certain threshold. + do_early_stop = self.depth_confidence > 0 + # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of + # the keypoints is below a certain threshold. + do_keypoint_pruning = self.width_confidence > 0 + + early_stops_indices = [] + matches = [] + matching_scores = [] + final_pruned_keypoints_indices = [] + final_pruned_keypoints_iterations = [] + + pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1) + pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices) + + for layer_index in range(self.num_layers): + input_shape = descriptors.size() + if mask is not None: + extended_attention_mask = self.get_extended_attention_mask(mask, input_shape) + else: + extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device) + layer_output = self.transformer_layers[layer_index]( + descriptors, + keypoints, + attention_mask=extended_attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + descriptors, hidden_states, attention = layer_output + if output_hidden_states: + all_hidden_states = all_hidden_states + hidden_states + if output_attentions: + all_attentions = all_attentions + attention + + if do_early_stop: + if layer_index < self.num_layers - 1: + # Get the confidence of the keypoints for the current layer + 
keypoint_confidences = self.token_confidence[layer_index](descriptors) + + # Determine which pairs of images should be early stopped based on the confidence of the keypoints for + # the current layer. + early_stopped_pairs = self._get_early_stopped_image_pairs( + keypoint_confidences, layer_index, mask, num_points=num_points_per_pair + ) + else: + # Early stopping always occurs at the last layer + early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + + if torch.any(early_stopped_pairs): + # If a pair of images is considered early stopped, we compute the matches for the remaining + # keypoints and stop the forward pass through the transformer layers for this pair of images. + early_stops = early_stopped_pairs.repeat_interleave(2) + early_stopped_image_indices = image_indices[early_stops] + early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching( + descriptors, mask, layer_index, early_stops=early_stops + ) + early_stops_indices.extend(list(early_stopped_image_indices)) + matches.extend(list(early_stopped_matches)) + matching_scores.extend(list(early_stopped_matching_scores)) + if do_keypoint_pruning: + final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops])) + final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops])) + + # Remove image pairs that have been early stopped from the forward pass + num_points_per_pair = num_points_per_pair[~early_stopped_pairs] + descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple( + ( + tensor[~early_stops] + for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices] + ) + ) + keypoints = (keypoints_0, keypoint_1) + if do_keypoint_pruning: + pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple( + ( + tensor[~early_stops] + for tensor in [ + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + ] + ) + ) + # If all pairs of images are early stopped, we 
stop the forward pass through the transformer + # layers for all pairs of images. + if torch.all(early_stopped_pairs): + break + + if do_keypoint_pruning: + # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of + # the keypoints is below a certain threshold. + descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = ( + self._do_layer_keypoint_pruning( + descriptors, + keypoints, + mask, + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + layer_index, + ) + ) + + if do_early_stop and do_keypoint_pruning: + # Concatenate early stopped outputs together and perform final keypoint pruning + final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = ( + self._concat_early_stopped_outputs( + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ) + ) + matches, matching_scores = self._do_final_keypoint_pruning( + final_pruned_keypoints_indices, + matches, + matching_scores, + initial_num_keypoints, + ) + else: + matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1) + final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers + + final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape( + batch_size, 2, initial_num_keypoints + ) + + return ( + matches, + matching_scores, + final_pruned_keypoints_iterations, + all_hidden_states, + all_attentions, + ) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[Tuple, LightGlueKeypointMatchingOutput]: + loss = None + if labels is not None: + raise ValueError("LightGlue is not trainable, no labels should be provided.") + + output_attentions 
= output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if pixel_values.ndim != 5 or pixel_values.size(1) != 2: + raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") + + batch_size, _, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + keypoint_detections = self.keypoint_detector(pixel_values) + + keypoints, _, descriptors, mask = keypoint_detections[:4] + keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values) + descriptors = descriptors.reshape(batch_size, 2, -1, self.descriptor_dim).to(pixel_values) + mask = mask.reshape(batch_size, 2, -1) + + absolute_keypoints = keypoints.clone() + absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width + absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height + + matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair( + absolute_keypoints, + descriptors, + height, + width, + mask=mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + return LightGlueKeypointMatchingOutput( + loss=loss, + matches=matches, + matching_scores=matching_scores, + keypoints=keypoints, + prune=prune, + mask=mask, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching", "LightGlueConfig", "LightGlueImageProcessor"]
src/transformers/models/superglue/image_processing_superglue.py +2 −1 modified
@@ -17,7 +17,6 @@
 import numpy as np
-from ... import is_torch_available, is_vision_available
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import resize, to_channel_dimension_format
 from ...image_utils import (
@@ -29,7 +28,9 @@
     infer_channel_dimension_format,
     is_pil_image,
     is_scaled_image,
+    is_torch_available,
     is_valid_image,
+    is_vision_available,
     to_numpy_array,
     valid_images,
     validate_preprocess_arguments,
src/transformers/models/superpoint/modeling_superpoint.py +1 −1 modified
@@ -253,7 +253,7 @@ def _extract_keypoints(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.
         keypoints, scores = top_k_keypoints(keypoints, scores, self.max_keypoints)
         # Convert (y, x) to (x, y)
-        keypoints = torch.flip(keypoints, [1]).float()
+        keypoints = torch.flip(keypoints, [1]).to(scores.dtype)
         return keypoints, scores
src/transformers/utils/import_utils.py +5 −0 modified
@@ -225,6 +225,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _spqr_available = _is_package_available("spqr_quant")
 _rich_available = _is_package_available("rich")
 _kernels_available = _is_package_available("kernels")
+_matplotlib_available = _is_package_available("matplotlib")
 _torch_version = "N/A"
 _torch_available = False
@@ -1443,6 +1444,10 @@ def is_rich_available():
     return _rich_available
+def is_matplotlib_available():
+    return _matplotlib_available
+
+
 def check_torch_load_is_safe():
     if not is_torch_greater_or_equal("2.6"):
         raise ValueError(
src/transformers/utils/__init__.py +1 −0 modified
@@ -179,6 +179,7 @@
     is_librosa_available,
     is_liger_kernel_available,
     is_lomo_available,
+    is_matplotlib_available,
     is_mlx_available,
     is_natten_available,
     is_ninja_available,
tests/models/lightglue/__init__.py+0 −0 addedtests/models/lightglue/test_image_processing_lightglue.py+96 −0 added@@ -0,0 +1,96 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from tests.models.superglue.test_image_processing_superglue import ( + SuperGlueImageProcessingTest, + SuperGlueImageProcessingTester, +) +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + + +if is_torch_available(): + import numpy as np + import torch + + from transformers.models.lightglue.modeling_lightglue import LightGlueKeypointMatchingOutput + +if is_vision_available(): + from transformers import LightGlueImageProcessor + + +def random_array(size): + return np.random.randint(255, size=size) + + +def random_tensor(size): + return torch.rand(size) + + +class LightGlueImageProcessingTester(SuperGlueImageProcessingTester): + """Tester for LightGlueImageProcessor""" + + def __init__( + self, + parent, + batch_size=6, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + do_grayscale=True, + ): + super().__init__( + parent, batch_size, num_channels, image_size, min_resolution, max_resolution, do_resize, size, do_grayscale + ) + + def prepare_keypoint_matching_output(self, pixel_values): + """Prepare a fake output for the keypoint matching model with random matches 
between 50 keypoints per image.""" + max_number_keypoints = 50 + batch_size = len(pixel_values) + mask = torch.zeros((batch_size, 2, max_number_keypoints), dtype=torch.int) + keypoints = torch.zeros((batch_size, 2, max_number_keypoints, 2)) + matches = torch.full((batch_size, 2, max_number_keypoints), -1, dtype=torch.int) + scores = torch.zeros((batch_size, 2, max_number_keypoints)) + prune = torch.zeros((batch_size, 2, max_number_keypoints), dtype=torch.int) + for i in range(batch_size): + random_number_keypoints0 = np.random.randint(10, max_number_keypoints) + random_number_keypoints1 = np.random.randint(10, max_number_keypoints) + random_number_matches = np.random.randint(5, min(random_number_keypoints0, random_number_keypoints1)) + mask[i, 0, :random_number_keypoints0] = 1 + mask[i, 1, :random_number_keypoints1] = 1 + keypoints[i, 0, :random_number_keypoints0] = torch.rand((random_number_keypoints0, 2)) + keypoints[i, 1, :random_number_keypoints1] = torch.rand((random_number_keypoints1, 2)) + random_matches_indices0 = torch.randperm(random_number_keypoints1, dtype=torch.int)[:random_number_matches] + random_matches_indices1 = torch.randperm(random_number_keypoints0, dtype=torch.int)[:random_number_matches] + matches[i, 0, random_matches_indices1] = random_matches_indices0 + matches[i, 1, random_matches_indices0] = random_matches_indices1 + scores[i, 0, random_matches_indices1] = torch.rand((random_number_matches,)) + scores[i, 1, random_matches_indices0] = torch.rand((random_number_matches,)) + return LightGlueKeypointMatchingOutput( + mask=mask, keypoints=keypoints, matches=matches, matching_scores=scores, prune=prune + ) + + +@require_torch +@require_vision +class LightGlueImageProcessingTest(SuperGlueImageProcessingTest, unittest.TestCase): + image_processing_class = LightGlueImageProcessor if is_vision_available() else None + + def setUp(self) -> None: + super().setUp() + self.image_processor_tester = LightGlueImageProcessingTester(self)
tests/models/lightglue/test_modeling_lightglue.py+584 −0 added@@ -0,0 +1,584 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import unittest + +from datasets import load_dataset + +from transformers.models.lightglue.configuration_lightglue import LightGlueConfig +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor + + +if is_torch_available(): + import torch + + from transformers import LightGlueForKeypointMatching + +if is_vision_available(): + from transformers import AutoImageProcessor + + +class LightGlueModelTester: + def __init__( + self, + parent, + batch_size=2, + image_width=80, + image_height=60, + keypoint_detector_config={ + "encoder_hidden_sizes": [32, 32, 64], + "decoder_hidden_size": 64, + "keypoint_decoder_dim": 65, + "descriptor_decoder_dim": 64, + "keypoint_threshold": 0.005, + "max_keypoints": 256, + "nms_radius": 4, + "border_removal_distance": 4, + }, + descriptor_dim: int = 64, + num_layers: int = 2, + num_heads: int = 4, + depth_confidence: float = 1.0, + width_confidence: float = 1.0, + filter_threshold: float = 0.1, + matching_threshold: float = 0.0, + ): + self.parent = parent + self.batch_size = batch_size + 
self.image_width = image_width + self.image_height = image_height + + self.keypoint_detector_config = keypoint_detector_config + self.descriptor_dim = descriptor_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.depth_confidence = depth_confidence + self.width_confidence = width_confidence + self.filter_threshold = filter_threshold + self.matching_threshold = matching_threshold + + def prepare_config_and_inputs(self): + # LightGlue expects a grayscale image as input + pixel_values = floats_tensor([self.batch_size, 2, 3, self.image_height, self.image_width]) + config = self.get_config() + return config, pixel_values + + def get_config(self): + return LightGlueConfig( + keypoint_detector_config=self.keypoint_detector_config, + descriptor_dim=self.descriptor_dim, + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + depth_confidence=self.depth_confidence, + width_confidence=self.width_confidence, + filter_threshold=self.filter_threshold, + matching_threshold=self.matching_threshold, + attn_implementation="eager", + ) + + def create_and_check_model(self, config, pixel_values): + model = LightGlueForKeypointMatching(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + maximum_num_matches = result.mask.shape[-1] + self.parent.assertEqual( + result.keypoints.shape, + (self.batch_size, 2, maximum_num_matches, 2), + ) + self.parent.assertEqual( + result.matches.shape, + (self.batch_size, 2, maximum_num_matches), + ) + self.parent.assertEqual( + result.matching_scores.shape, + (self.batch_size, 2, maximum_num_matches), + ) + self.parent.assertEqual( + result.prune.shape, + (self.batch_size, 2, maximum_num_matches), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class 
LightGlueModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (LightGlueForKeypointMatching,) if is_torch_available() else () + all_generative_model_classes = () if is_torch_available() else () + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = True + + def setUp(self): + self.model_tester = LightGlueModelTester(self) + self.config_tester = ConfigTester(self, config_class=LightGlueConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + @unittest.skip(reason="LightGlueForKeypointMatching does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching does not use feedforward chunking") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching is not trainable") + def test_training(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="LightGlue does not output any loss term in the forward pass") + def test_retain_grad_hidden_states_attentions(self): + 
pass + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + maximum_num_matches = outputs.mask.shape[-1] + + hidden_states_sizes = [ + self.model_tester.descriptor_dim, + self.model_tester.descriptor_dim, + self.model_tester.descriptor_dim * 2, + self.model_tester.descriptor_dim, + self.model_tester.descriptor_dim, + self.model_tester.descriptor_dim * 2, + self.model_tester.descriptor_dim, + ] * self.model_tester.num_layers + + for i, hidden_states_size in enumerate(hidden_states_sizes): + self.assertListEqual( + list(hidden_states[i].shape[-2:]), + [maximum_num_matches, hidden_states_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_attention_outputs(self): + def check_attention_output(inputs_dict, 
config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + attentions = outputs.attentions + maximum_num_matches = outputs.mask.shape[-1] + + expected_attention_shape = [self.model_tester.num_heads, maximum_num_matches, maximum_num_matches] + + for i, attention in enumerate(attentions): + self.assertListEqual( + list(attention.shape[-3:]), + expected_attention_shape, + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + check_attention_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + + check_attention_output(inputs_dict, config, model_class) + + @slow + def test_model_from_pretrained(self): + from_pretrained_ids = ["ETH-CVG/lightglue_superpoint"] + for model_name in from_pretrained_ids: + model = LightGlueForKeypointMatching.from_pretrained(model_name) + self.assertIsNotNone(model) + + # Copied from tests.models.superglue.test_modeling_superglue.SuperGlueModelTest.test_forward_labels_should_be_none + def test_forward_labels_should_be_none(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + model_inputs = self._prepare_for_class(inputs_dict, model_class) + # Provide an arbitrary sized Tensor as labels to model inputs + model_inputs["labels"] = torch.rand((128, 128)) + + with self.assertRaises(ValueError) as cm: + model(**model_inputs) + self.assertEqual(ValueError, cm.exception.__class__) + + +def prepare_imgs(): + dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train") + image0 
= dataset[0]["image"] + image1 = dataset[1]["image"] + image2 = dataset[2]["image"] + # [image1, image1] on purpose to test the model early stopping + return [[image2, image0], [image1, image1]] + + +@require_torch +@require_vision +class LightGlueModelIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint") if is_vision_available() else None + + @slow + def test_inference(self): + model = LightGlueForKeypointMatching.from_pretrained( + "ETH-CVG/lightglue_superpoint", attn_implementation="eager" + ).to(torch_device) + preprocessor = self.default_image_processor + images = prepare_imgs() + inputs = preprocessor(images=images, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item() + predicted_matches_values0 = outputs.matches[0, 0, 10:30] + predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30] + + predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item() + predicted_matches_values1 = outputs.matches[1, 0, 10:30] + predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30] + + expected_number_of_matches0 = 140 + expected_matches_values0 = torch.tensor( + [14, -1, -1, 15, 17, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11], + dtype=torch.int64, + device=torch_device, + ) + expected_matching_scores_values0 = torch.tensor( + [0.3796, 0, 0, 0.3772, 0.4439, 0.2411, 0, 0, 0.0032, 0, 0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583], + device=torch_device, + ) + + expected_number_of_matches1 = 866 + expected_matches_values1 = torch.tensor( + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], + dtype=torch.int64, + device=torch_device, + ) + expected_matching_scores_values1 = torch.tensor( + [ + 
0.6188,0.7817,0.5686,0.9353,0.9801,0.9193,0.8632,0.9111,0.9821,0.5496, + 0.9906,0.8682,0.9679,0.9914,0.9318,0.1910,0.9669,0.3240,0.9971,0.9923, + ], + device=torch_device + ) # fmt:skip + + # expected_early_stopping_layer = 2 + # predicted_early_stopping_layer = torch.max(outputs.prune[1]).item() + # self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer) + # self.assertEqual(predicted_number_of_matches, expected_second_number_of_matches) + + """ + Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies + on SuperPoint, which may, depending on CUDA version, return different number of keypoints (866 or 867 in this + specific test example). The consequence of having different number of keypoints is that the number of matches + will also be different. In the 20 first matches being checked, having one keypoint less will result in 1 less + match. The matching scores will also be different, as the keypoints are different. The checks here are less + strict to account for these inconsistencies. + Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the + expected values, individually. Here, the tolerance of the number of values changing is set to 2. 
+ + This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787) + Such CUDA inconsistencies can be found + [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300) + """ + + self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4) + self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2)) + < 4 + ) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2)) + < 4 + ) + self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4) + self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4) + + @slow + def test_inference_without_early_stop(self): + model = LightGlueForKeypointMatching.from_pretrained( + "ETH-CVG/lightglue_superpoint", attn_implementation="eager", depth_confidence=1.0 + ).to(torch_device) + preprocessor = self.default_image_processor + images = prepare_imgs() + inputs = preprocessor(images=images, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item() + predicted_matches_values0 = outputs.matches[0, 0, 10:30] + predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30] + + predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item() + predicted_matches_values1 = outputs.matches[1, 0, 10:30] + predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30] + + expected_number_of_matches0 = 134 + expected_matches_values0 = torch.tensor( + [-1, -1, 17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64 + ).to(torch_device) + 
expected_matching_scores_values0 = torch.tensor( + [0.0083, 0, 0.2022, 0.0621, 0, 0.0828, 0, 0, 0.0003, 0, 0, 0, 0.0960, 0, 0, 0.6940, 0, 0.7167, 0, 0.1512] + ).to(torch_device) + + expected_number_of_matches1 = 862 + expected_matches_values1 = torch.tensor( + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=torch.int64 + ).to(torch_device) + expected_matching_scores_values1 = torch.tensor( + [ + 0.4772, + 0.3781, + 0.0631, + 0.9559, + 0.8746, + 0.9271, + 0.4882, + 0.5406, + 0.9439, + 0.1526, + 0.5028, + 0.4107, + 0.5591, + 0.9130, + 0.7572, + 0.0302, + 0.4532, + 0.0893, + 0.9490, + 0.4880, + ] + ).to(torch_device) + + # expected_early_stopping_layer = 2 + # predicted_early_stopping_layer = torch.max(outputs.prune[1]).item() + # self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer) + # self.assertEqual(predicted_number_of_matches, expected_second_number_of_matches) + + """ + Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies + on SuperPoint, which may, depending on CUDA version, return different number of keypoints (866 or 867 in this + specific test example). The consequence of having different number of keypoints is that the number of matches + will also be different. In the 20 first matches being checked, having one keypoint less will result in 1 less + match. The matching scores will also be different, as the keypoints are different. The checks here are less + strict to account for these inconsistencies. + Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the + expected values, individually. Here, the tolerance of the number of values changing is set to 2. 
+ + This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787) + Such CUDA inconsistencies can be found + [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300) + """ + + self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4) + self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2)) + < 4 + ) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2)) + < 4 + ) + self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4) + self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4) + + @slow + def test_inference_without_early_stop_and_keypoint_pruning(self): + model = LightGlueForKeypointMatching.from_pretrained( + "ETH-CVG/lightglue_superpoint", + attn_implementation="eager", + depth_confidence=1.0, + width_confidence=1.0, + ).to(torch_device) + preprocessor = self.default_image_processor + images = prepare_imgs() + inputs = preprocessor(images=images, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item() + predicted_matches_values0 = outputs.matches[0, 0, 10:30] + predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30] + + predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item() + predicted_matches_values1 = outputs.matches[1, 0, 10:30] + predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30] + + expected_number_of_matches0 = 144 + expected_matches_values0 = torch.tensor( + [-1, -1, 17, -1, -1, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64 + 
).to(torch_device) + expected_matching_scores_values0 = torch.tensor( + [ + 0.0699, + 0.0302, + 0.3356, + 0.0820, + 0, + 0.2266, + 0, + 0, + 0.0241, + 0, + 0, + 0, + 0.1674, + 0, + 0, + 0.8114, + 0, + 0.8120, + 0, + 0.2936, + ] + ).to(torch_device) + + expected_number_of_matches1 = 862 + expected_matches_values1 = torch.tensor( + [10, 11, -1, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, -1, 26, -1, 28, 29], dtype=torch.int64 + ).to(torch_device) + expected_matching_scores_values1 = torch.tensor( + [ + 0.4772, + 0.3781, + 0.0631, + 0.9559, + 0.8746, + 0.9271, + 0.4882, + 0.5406, + 0.9439, + 0.1526, + 0.5028, + 0.4107, + 0.5591, + 0.9130, + 0.7572, + 0.0302, + 0.4532, + 0.0893, + 0.9490, + 0.4880, + ] + ).to(torch_device) + + # expected_early_stopping_layer = 2 + # predicted_early_stopping_layer = torch.max(outputs.prune[1]).item() + # self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer) + # self.assertEqual(predicted_number_of_matches, expected_second_number_of_matches) + + """ + Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies + on SuperPoint, which may, depending on CUDA version, return different number of keypoints (866 or 867 in this + specific test example). The consequence of having different number of keypoints is that the number of matches + will also be different. In the 20 first matches being checked, having one keypoint less will result in 1 less + match. The matching scores will also be different, as the keypoints are different. The checks here are less + strict to account for these inconsistencies. + Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the + expected values, individually. Here, the tolerance of the number of values changing is set to 2. 
+ + This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787) + Such CUDA inconsistencies can be found + [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300) + """ + + self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4) + self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2)) + < 4 + ) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2)) + < 4 + ) + self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4) + self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)
Vulnerability mechanics
Root cause
"The `trust_remote_code` parameter is improperly handled in nested model loading, allowing it to be overridden by untrusted configuration data."
Attack vector
An attacker who controls a model repository can trick a victim into loading a malicious LightGlue model. When `AutoModel.from_pretrained()` is called with `trust_remote_code=False`, `LightGlueConfig` reads a `trust_remote_code` value from the untrusted `config.json` file and propagates it into nested `AutoConfig.from_pretrained()` calls. This bypasses the caller's `trust_remote_code=False` setting and allows arbitrary code execution [1].
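The propagation described above can be illustrated with a toy model. This is a minimal sketch that does not use the Transformers library: `nested_from_pretrained`, `VulnerableToyConfig`, and `load_toy_model` are hypothetical stand-ins, and the only assumption carried over from the advisory is that a flag stored in `config.json` reaches the nested loader in place of the caller's value.

```python
import json


# Hypothetical stand-in for a nested loader such as AutoConfig.from_pretrained().
# It only records whether remote code would have been allowed to run.
def nested_from_pretrained(repo_id: str, trust_remote_code: bool = False) -> dict:
    return {"repo_id": repo_id, "remote_code_executed": trust_remote_code}


class VulnerableToyConfig:
    """Toy config mimicking the flawed pattern: the flag comes from config.json."""

    def __init__(self, keypoint_detector_repo: str, trust_remote_code: bool = False, **kwargs):
        # Illustrative bug: this value originates from the untrusted file,
        # not from the caller's trust_remote_code=False.
        self.trust_remote_code = trust_remote_code
        self.keypoint_detector = nested_from_pretrained(
            keypoint_detector_repo, trust_remote_code=self.trust_remote_code
        )


def load_toy_model(config_json: str, trust_remote_code: bool = False) -> VulnerableToyConfig:
    # In this toy, the caller's flag is deliberately not forwarded:
    # every key found in config.json becomes a constructor argument.
    fields = json.loads(config_json)
    return VulnerableToyConfig(**fields)


# Attacker-supplied configuration (hypothetical repository name).
malicious_config = json.dumps(
    {"keypoint_detector_repo": "attacker/payload", "trust_remote_code": True}
)

config = load_toy_model(malicious_config, trust_remote_code=False)
print(config.keypoint_detector["remote_code_executed"])  # True
```

Even though the caller passes `trust_remote_code=False`, the nested loader in this sketch sees `True` because the file-supplied value wins.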
Affected code
The vulnerability lies in the `LightGlueConfig` class and its handling of the `trust_remote_code` parameter during model initialization. Specifically, the `__post_init__` method in `LightGlueConfig` and the `__init__` method in `LightGlueForKeypointMatching` take part in the nested loading process, where the `trust_remote_code` flag can be influenced by untrusted data [1].
What the fix does
The patch removes the `trust_remote_code` parameter from the `LightGlueConfig` class and from the call to `AutoModelForKeypointDetection.from_config` [1]. The value can therefore no longer be read from the untrusted configuration and passed to nested loading functions, which closes this route to remote code execution.
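The effect of the fix can be sketched in the same toy terms. This is not the actual patch: `PatchedToyConfig` and `nested_from_pretrained` are hypothetical names, and the sketch only assumes what the paragraph above states, namely that the config no longer accepts or forwards a `trust_remote_code` value.

```python
# Hypothetical stand-in for a nested loader; it records whether remote
# code would have been allowed to run.
def nested_from_pretrained(repo_id: str, trust_remote_code: bool = False) -> dict:
    return {"repo_id": repo_id, "remote_code_executed": trust_remote_code}


class PatchedToyConfig:
    """Toy config without a trust_remote_code parameter."""

    def __init__(self, keypoint_detector_repo: str, **kwargs):
        # A trust_remote_code key in the file lands in kwargs and is dropped,
        # so it can never reach the nested loader.
        kwargs.pop("trust_remote_code", None)
        self.keypoint_detector = nested_from_pretrained(keypoint_detector_repo)


# The same attacker-supplied fields as before (hypothetical repository name).
fields = {"keypoint_detector_repo": "attacker/payload", "trust_remote_code": True}
patched = PatchedToyConfig(**fields)
print(patched.keypoint_detector["remote_code_executed"])  # False
```

The design choice mirrored here is to remove the field entirely instead of validating it, so there is no file-controlled value left to trust.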
Preconditions
- input: Attacker controls a model repository with a malicious `config.json` file.
- auth: No authentication is required for the attacker to host the malicious model repository.
- network: The victim's environment must be able to reach the attacker-controlled model repository over the network.
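Given these preconditions, one defensive option is to inspect a downloaded `config.json` before handing it to any loader. The following is a minimal sketch using only the standard library: the helper `find_suspicious_keys` and the contents of `SUSPICIOUS_KEYS` are assumptions chosen for illustration and do not come from the patch or the advisory.

```python
import json
from typing import Any, Iterator

# Assumed list of keys worth flagging in an untrusted config (illustrative only).
SUSPICIOUS_KEYS = {"trust_remote_code", "auto_map"}


def find_suspicious_keys(node: Any, path: str = "") -> Iterator[str]:
    """Yield dotted paths of suspicious keys anywhere in a parsed config."""
    if isinstance(node, dict):
        for key, value in node.items():
            child = f"{path}.{key}" if path else key
            if key in SUSPICIOUS_KEYS:
                yield child
            # Recurse so that nested sub-configs are inspected as well.
            yield from find_suspicious_keys(value, child)
    elif isinstance(node, list):
        for index, item in enumerate(node):
            yield from find_suspicious_keys(item, f"{path}[{index}]")


# Hypothetical config content resembling the attack described above.
sample = json.loads(
    """
    {
      "model_type": "lightglue",
      "trust_remote_code": true,
      "keypoint_detector_config": {"auto_map": {"AutoConfig": "evil.EvilConfig"}}
    }
    """
)

findings = list(find_suspicious_keys(sample))
print(findings)  # ['trust_remote_code', 'keypoint_detector_config.auto_map']
```

A check like this does not replace upgrading to a patched release; it only flags configurations that deserve manual review before loading.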
Generated on Jun 3, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References (2)
News mentions (0): No linked articles in our index yet.