"""Utilities to convert TIMM models and push them to the Hugging Face Hub (stable_pretraining.utils.timm_to_hf_hub)."""

import os
import torch
import timm
from typing import Optional, Tuple
from huggingface_hub import HfFolder, create_repo, upload_folder
from transformers import (
    ViTConfig,
    ViTModel,
    DeiTConfig,
    DeiTModel,
    SwinConfig,
    SwinModel,
    ConvNextConfig,
    ConvNextModel,
)
import safetensors.torch


def push_timm_to_hf(
    model_name: str,
    model: torch.nn.Module,
    repo_id: str,
    hf_token: Optional[str] = None,
    private: bool = False,
    validate: bool = True,
    batch_size: int = 2,
    atol: float = 1e-4,
    rtol: float = 1e-4,
    device: Optional[str] = None,
    strict: bool = False,
) -> str:
    """Push a TIMM model to the Hugging Face Hub.

    When the architecture family (vit / deit / swin / convnext, detected from
    ``model_name``) is recognized, the weights are mapped onto the matching
    Transformers model class and uploaded as safetensors together with a
    model config and image-processor config. Otherwise the raw PyTorch state
    dict is uploaded as-is.

    Args:
        model_name: TIMM model name; used to detect the architecture family.
        model: Instantiated TIMM model whose weights are exported.
        repo_id: Target Hub repository id, e.g. ``"user/repo"``.
        hf_token: Hugging Face token. Falls back to the
            ``HUGGINGFACE_HUB_TOKEN`` env var, then the locally stored token.
        private: Create the repository as private.
        validate: Run a forward-pass sanity check (TIMM vs converted HF).
        batch_size: Batch size of the random input used by the sanity check.
        atol: Absolute tolerance for the sanity check.
        rtol: Relative tolerance for the sanity check.
        device: Device for the sanity check; defaults to CUDA if available.
        strict: Raise instead of warn when the sanity check fails.

    Returns:
        The URL of the Hub repository.

    Raises:
        RuntimeError: If no Hugging Face token can be found.
    """
    family_map = {
        "vit": (ViTConfig, ViTModel),
        "deit": (DeiTConfig, DeiTModel),
        "swin": (SwinConfig, SwinModel),
        "convnext": (ConvNextConfig, ConvNextModel),
    }
    hf_token = (
        hf_token
        or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        or HfFolder.get_token()
    )
    if not hf_token:
        raise RuntimeError(
            "Hugging Face token not found. Pass hf_token or set HUGGINGFACE_HUB_TOKEN."
        )
    # First family keyword contained in the model name decides the path.
    family = next((fam for fam in family_map if fam in model_name.lower()), None)
    repo_url = f"https://huggingface.co/{repo_id}"
    local_dir = f"./{repo_id.replace('/', '__')}"
    os.makedirs(local_dir, exist_ok=True)
    try:
        create_repo(repo_id, token=hf_token, private=private, exist_ok=True)
    except Exception as e:
        # Best-effort: the repo may already exist or creation may be flaky.
        print(f"Repo creation warning: {e}")

    # Model card (README). The example code fence is explicitly closed so the
    # rendered card is valid markdown.
    readme = f"""---
tags:
- timm
- vision
- {family or "custom"}
license: apache-2.0
---

# Model: {repo_id}

- TIMM model: `{model_name}`
- TIMM version: {timm.__version__}
- Architecture: {family or "custom"}
- Converted: {"Transformers" if family else "PyTorch"}
- Example usage:

```python
from transformers import AutoModel, AutoImageProcessor
model = AutoModel.from_pretrained("{repo_id}", trust_remote_code=True)
processor = AutoImageProcessor.from_pretrained("{repo_id}")
```
"""
    with open(f"{local_dir}/README.md", "w", encoding="utf-8") as f:
        f.write(readme)

    if family:
        config_cls, model_cls = family_map[family]
        image_size = _normalize_img_size(getattr(model, "img_size", 224))
        in_chans = getattr(model, "in_chans", 3)
        num_labels = getattr(model, "num_classes", 1000)
        config = config_cls(
            image_size=image_size[0] if isinstance(image_size, tuple) else image_size,
            num_channels=in_chans,
            num_labels=num_labels,
        )
        hf_model = model_cls(config)
        try:
            # strict=False because TIMM and Transformers use different
            # parameter names; only matching keys are loaded.
            hf_model.load_state_dict(model.state_dict(), strict=False)
        except Exception as e:
            # Conversion failed: fall through to the plain-PyTorch path.
            print(f"State dict mapping failed: {e}")
            family = None

    if family and validate:
        _validate_timm_vs_hf(
            timm_model=model,
            hf_model=hf_model,
            family=family,
            batch_size=batch_size,
            atol=atol,
            rtol=rtol,
            device=device,
            strict=strict,
        )

    if family:
        safetensors.torch.save_file(
            hf_model.state_dict(), f"{local_dir}/model.safetensors"
        )
        hf_model.config.save_pretrained(local_dir)
        _save_image_processor(local_dir, image_size, family)
        upload_folder(
            repo_id=repo_id,
            folder_path=local_dir,
            token=hf_token,
            commit_message="Push TIMM-converted model",
        )
        return repo_url

    # Fallback: no Transformers equivalent — upload plain PyTorch weights.
    torch.save(model.state_dict(), f"{local_dir}/pytorch_model.bin")
    with open(f"{local_dir}/model_type.txt", "w", encoding="utf-8") as f:
        f.write(
            f"TIMM model: {model_name}\nNo direct Transformers conversion available.\n"
        )
    upload_folder(
        repo_id=repo_id,
        folder_path=local_dir,
        token=hf_token,
        commit_message="Push plain TIMM PyTorch model",
    )
    print(
        f"WARNING: {model_name} not natively supported for Transformers conversion. Uploaded as PyTorch weights."
    )
    return repo_url
def _normalize_img_size(img_size) -> Tuple[int, int]: if isinstance(img_size, (list, tuple)): if len(img_size) == 2: return int(img_size[0]), int(img_size[1]) return int(img_size[0]), int(img_size[0]) return int(img_size), int(img_size) def _validate_timm_vs_hf( timm_model: torch.nn.Module, hf_model: torch.nn.Module, family: str, batch_size: int, atol: float, rtol: float, device: Optional[str], strict: bool, ): device = device or ("cuda" if torch.cuda.is_available() else "cpu") timm_model = timm_model.to(device).eval() hf_model = hf_model.to(device).eval() torch.manual_seed(42) img_size = _normalize_img_size(getattr(timm_model, "img_size", 224)) h, w = img_size x = torch.rand(batch_size, 3, h, w, device=device, dtype=torch.float32) mean = torch.tensor([0.485, 0.456, 0.406], device=device, dtype=torch.float32).view( 1, 3, 1, 1 ) std = torch.tensor([0.229, 0.224, 0.225], device=device, dtype=torch.float32).view( 1, 3, 1, 1 ) x_norm = (x - mean) / std with torch.no_grad(): feats_timm = _extract_features_timm(timm_model, x_norm, family) feats_hf = _extract_features_hf(hf_model, x_norm, family) if feats_timm.shape != feats_hf.shape: min_last = min(feats_timm.shape[-1], feats_hf.shape[-1]) feats_timm = feats_timm[..., :min_last] feats_hf = feats_hf[..., :min_last] cos_sim, max_abs_diff = _compare_tensors(feats_timm, feats_hf) print( f"[Sanity Check] Cosine similarity: {cos_sim:.6f}, Max abs diff: {max_abs_diff:.6g}" ) thresh = atol + rtol * feats_timm.abs().max().item() if not (cos_sim >= 0.999 or max_abs_diff <= thresh): msg = f"Sanity check failed: cosine={cos_sim:.6f}, max_abs_diff={max_abs_diff:.6g}, thresh={thresh:.6g}" if strict: raise ValueError(msg) print("WARNING:", msg) def _extract_features_timm( model: torch.nn.Module, x: torch.Tensor, family: str ) -> torch.Tensor: if family in ("vit", "deit"): out = model.forward_features(x) if isinstance(out, dict): if "cls_token" in out: return out["cls_token"] if "x" in out and out["x"].ndim == 3: return out["x"][:, 0] 
if "x_norm_cls" in out: return out["x_norm_cls"] raise RuntimeError("Unexpected forward_features dict for ViT/DeiT.") return out[:, 0] if out.ndim == 3 else out elif family == "swin": out = model.forward_features(x) if isinstance(out, dict) and "x" in out: tokens = out["x"] return tokens.mean(dim=1) return out.mean(dim=1) if out.ndim == 3 else out.mean(dim=[2, 3]) elif family == "convnext": features = model.forward_features(x) if isinstance(features, dict) and "x" in features: features = features["x"] return features.mean(dim=[2, 3]) raise NotImplementedError(f"Sanity check not implemented for {family}") def _extract_features_hf( model: torch.nn.Module, x: torch.Tensor, family: str ) -> torch.Tensor: if family in ("vit", "deit"): out = model(x) return out.last_hidden_state[:, 0] elif family == "swin": out = model(x) if hasattr(out, "pooler_output") and out.pooler_output is not None: return out.pooler_output return out.last_hidden_state.mean(dim=1) elif family == "convnext": out = model(x) if hasattr(out, "pooler_output") and out.pooler_output is not None: return out.pooler_output return out.last_hidden_state.mean(dim=[2, 3]) raise NotImplementedError(f"Sanity check not implemented for {family}") def _compare_tensors(a: torch.Tensor, b: torch.Tensor) -> Tuple[float, float]: a = a.view(a.size(0), -1).float().cpu() b = b.view(b.size(0), -1).float().cpu() cos = torch.nn.functional.cosine_similarity(a, b, dim=1) return float(cos.mean()), float((a - b).abs().max()) def _save_image_processor( local_dir: str, img_size: Tuple[int, int], family: str ) -> None: h, w = img_size image_mean = [0.485, 0.456, 0.406] image_std = [0.229, 0.224, 0.225] try: import json proc_cfg = { "_class_name": "AutoImageProcessor", "do_resize": True, "size": {"height": h, "width": w}, "do_center_crop": False, "do_normalize": True, "image_mean": image_mean, "image_std": image_std, } with open(os.path.join(local_dir, "preprocessor_config.json"), "w") as f: json.dump(proc_cfg, f, indent=2) except 
Exception as e: print(f"Image processor fallback save warning: {e}") if __name__ == "__main__": timm_model_name = "vit_base_patch16_224" model = timm.create_model(timm_model_name, pretrained=True) repo_id = "your-username/my-vit-base-patch16-224" url = push_timm_to_hf( timm_model_name, model, repo_id, private=True, validate=True, strict=False ) print(f"Model pushed to: {url}")