
Commit ee3af60

Add support for fine-tuning CLIP-like models using contrastive-image-text example (#29070)
* add support for siglip and chinese-clip model training with contrastive-image-text example
* codebase fixups
1 parent 0996a10 commit ee3af60
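For context, the assembly step the contrastive-image-text example relies on, adapted from that example's README (checkpoint names and the output directory are illustrative, and the README may phrase it slightly differently): build a VisionTextDualEncoderModel from a pretrained vision tower and text encoder, then save it for run_clip.py to fine-tune. The changes in this commit are what let the same flow accept SigLIP and Chinese-CLIP vision configs.

from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    VisionTextDualEncoderModel,
    VisionTextDualEncoderProcessor,
)

# Assemble a dual encoder from a pretrained vision tower and text encoder
# (checkpoint names are illustrative; a SigLIP or Chinese-CLIP tower can be
# swapped in once the changes below are in place).
model = VisionTextDualEncoderModel.from_vision_text_pretrained(
    "openai/clip-vit-base-patch32", "roberta-base"
)
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
processor = VisionTextDualEncoderProcessor(image_processor, tokenizer)

# Save both so the example's run_clip.py can pick them up via --model_name_or_path.
model.save_pretrained("clip-roberta")
processor.save_pretrained("clip-roberta")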

File tree

6 files changed: +20 -7 lines changed

src/transformers/models/auto/configuration_auto.py

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,7 @@
         ("camembert", "CamembertConfig"),
         ("canine", "CanineConfig"),
         ("chinese_clip", "ChineseCLIPConfig"),
+        ("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
         ("clap", "ClapConfig"),
         ("clip", "CLIPConfig"),
         ("clip_vision_model", "CLIPVisionConfig"),
@@ -512,6 +513,7 @@
         ("camembert", "CamemBERT"),
         ("canine", "CANINE"),
         ("chinese_clip", "Chinese-CLIP"),
+        ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
         ("clap", "CLAP"),
         ("clip", "CLIP"),
         ("clip_vision_model", "CLIPVisionModel"),
@@ -773,6 +775,7 @@
         ("xclip", "x_clip"),
         ("clip_vision_model", "clip"),
         ("siglip_vision_model", "siglip"),
+        ("chinese_clip_vision_model", "chinese_clip"),
     ]
 )
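A quick illustrative check of the new registration (not part of the diff): with "chinese_clip_vision_model" in CONFIG_MAPPING_NAMES, AutoConfig can build the standalone vision config by model type, while the last hunk maps that type back to the chinese_clip module for imports.

from transformers import AutoConfig, ChineseCLIPVisionConfig

# "chinese_clip_vision_model" now resolves through the auto config mapping.
vision_cfg = AutoConfig.for_model("chinese_clip_vision_model")
assert isinstance(vision_cfg, ChineseCLIPVisionConfig)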

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@
         ("camembert", "CamembertModel"),
         ("canine", "CanineModel"),
         ("chinese_clip", "ChineseCLIPModel"),
+        ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
         ("clap", "ClapModel"),
         ("clip", "CLIPModel"),
         ("clip_vision_model", "CLIPVisionModel"),

src/transformers/models/chinese_clip/configuration_chinese_clip.py

Lines changed: 1 addition & 2 deletions
@@ -171,8 +171,7 @@ class ChineseCLIPVisionConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate an
     ChineseCLIP model according to the specified arguments, defining the model architecture. Instantiating a
     configuration with the defaults will yield a similar configuration to that of the ChineseCLIP
-    [OFA-Sys/chinese-clip-vit-base-patch16](https:
-    //huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.
+    [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.

     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.

src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py

Lines changed: 13 additions & 4 deletions
@@ -18,11 +18,19 @@
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
 from ..auto.configuration_auto import AutoConfig
+from ..chinese_clip.configuration_chinese_clip import ChineseCLIPVisionConfig
 from ..clip.configuration_clip import CLIPVisionConfig
+from ..siglip.configuration_siglip import SiglipVisionConfig


 logger = logging.get_logger(__name__)

+VISION_MODEL_CONFIGS = {
+    "clip_vision_model": CLIPVisionConfig,
+    "chinese_clip_vision_model": ChineseCLIPVisionConfig,
+    "siglip_vision_model": SiglipVisionConfig,
+}
+

 class VisionTextDualEncoderConfig(PretrainedConfig):
     r"""
@@ -85,12 +93,13 @@ def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
         vision_model_type = vision_config.pop("model_type")
         text_model_type = text_config.pop("model_type")

-        if vision_model_type == "clip":
-            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
-        elif vision_model_type == "clip_vision_model":
-            self.vision_config = CLIPVisionConfig(**vision_config)
+        vision_config_class = VISION_MODEL_CONFIGS.get(vision_model_type)
+        if vision_config_class is not None:
+            self.vision_config = vision_config_class(**vision_config)
         else:
             self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
+            if hasattr(self.vision_config, "vision_config"):
+                self.vision_config = self.vision_config.vision_config

         self.text_config = AutoConfig.for_model(text_model_type, **text_config)
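A sketch of the new lookup path (illustrative, not part of the commit): a SigLIP vision config now survives the round trip through VisionTextDualEncoderConfig via VISION_MODEL_CONFIGS, where the old code only special-cased CLIP; full-model configs such as ChineseCLIPConfig fall into the else branch and have their nested vision_config extracted.

from transformers import BertConfig, SiglipVisionConfig, VisionTextDualEncoderConfig

# Default-valued configs are enough to exercise the dispatch logic.
config = VisionTextDualEncoderConfig.from_vision_text_configs(
    vision_config=SiglipVisionConfig(),  # previously only CLIP variants were handled here
    text_config=BertConfig(),
)
print(type(config.vision_config).__name__)  # SiglipVisionConfig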

utils/check_copies.py

Lines changed: 1 addition & 0 deletions
@@ -1070,6 +1070,7 @@ def check_model_list_copy(overwrite: bool = False):
     "VisionTextDualEncoder",
     "CLIPVisionModel",
     "SiglipVisionModel",
+    "ChineseCLIPVisionModel",
 ]

 # Template for new entries to add in the main README when we have missing models.

utils/check_table.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def _center_text(text: str, width: int) -> str:
     "XLS-R": "Wav2Vec2",
     "XLSR-Wav2Vec2": "Wav2Vec2",
 }
-MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel"]
+MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel", "ChineseCLIPVisionModel"]


 def get_model_table_from_auto_modules() -> str:
