diff --git a/backend/README.md b/backend/README.md
deleted file mode 100644
index 4e6bf71c7455538db98de6e1adc059efef73b083..0000000000000000000000000000000000000000
--- a/backend/README.md
+++ /dev/null
@@ -1,24 +0,0 @@
-# Backend API
-
-FastAPI backend for serving model data to the React frontend.
-
-## Structure
-
-- `api/` - API routes and main application
-- `services/` - External service integrations (arXiv, model tracking, scheduling)
-- `utils/` - Utility modules (data loading, embeddings, dimensionality reduction, clustering, network analysis)
-- `config/` - Configuration files (requirements.txt, etc.)
-- `cache/` - Cached data (embeddings, reduced dimensions)
-
-## Running
-
-```bash
-cd backend
-uvicorn api.main:app --reload --host 0.0.0.0 --port 8000
-```
-
-## Environment Variables
-
-- `SAMPLE_SIZE` - Limit number of models to load (for development). Set to 0 or leave unset to load all models.
-
-
diff --git a/backend/api/dependencies.py b/backend/api/dependencies.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b755ecace186f059d133d1fb23ef8a8625c0bcb
--- /dev/null
+++ b/backend/api/dependencies.py
@@ -0,0 +1,23 @@
+"""Shared dependencies for API routes."""
+import pandas as pd
+import numpy as np
+from typing import Optional, Dict
+from utils.data_loader import ModelDataLoader
+from utils.embeddings import ModelEmbedder
+from utils.dimensionality_reduction import DimensionReducer
+from utils.graph_embeddings import GraphEmbedder
+
+# Global state (initialized in startup) - these are module-level variables
+# that will be updated by main.py during startup
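+# Import this module itself (e.g. "import api.dependencies as deps") and read attributes at
+# call time; "from api.dependencies import df" would bind the initial None and never see updates.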
+data_loader = ModelDataLoader()
+embedder: Optional[ModelEmbedder] = None
+graph_embedder: Optional[GraphEmbedder] = None
+reducer: Optional[DimensionReducer] = None
+df: Optional[pd.DataFrame] = None
+embeddings: Optional[np.ndarray] = None
+graph_embeddings_dict: Optional[Dict[str, np.ndarray]] = None
+combined_embeddings: Optional[np.ndarray] = None
+reduced_embeddings: Optional[np.ndarray] = None
+reduced_embeddings_graph: Optional[np.ndarray] = None
+cluster_labels: Optional[np.ndarray] = None
+
diff --git a/backend/api/main.py b/backend/api/main.py
index 106b009ab38b4976fd0e9d36b885c899c6573f85..884e92852f48cd0a1465db49243eed30e4a7f463 100644
--- a/backend/api/main.py
+++ b/backend/api/main.py
@@ -1,202 +1,216 @@
-"""
-FastAPI backend for serving model data to React/Visx frontend.
-"""
import sys
import os
 backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 if backend_dir not in sys.path:
     sys.path.insert(0, backend_dir)
+import pickle
+import tempfile
+import logging
+from typing import Optional, List, Dict
+from datetime import datetime, timedelta
+import pandas as pd
+import numpy as np
+import httpx
from fastapi import FastAPI, HTTPException, Query, BackgroundTasks, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
-from typing import Optional, List, Dict
-import pandas as pd
-import numpy as np
from pydantic import BaseModel
from umap import UMAP
-import tempfile
-import traceback
-import httpx
from utils.data_loader import ModelDataLoader
from utils.embeddings import ModelEmbedder
from utils.dimensionality_reduction import DimensionReducer
from utils.network_analysis import ModelNetworkBuilder
+from utils.graph_embeddings import GraphEmbedder
from services.model_tracker import get_tracker
-from services.model_tracker_improved import get_improved_tracker
from services.arxiv_api import extract_arxiv_ids, fetch_arxiv_papers
+from core.config import settings
+from core.exceptions import DataNotLoadedError, EmbeddingsNotReadyError
+from models.schemas import ModelPoint
+from utils.family_tree import calculate_family_depths
+import api.dependencies as deps
+from api.routes import models, stats, clusters
+
+# Alias for backward compatibility with existing routes.
+# data_loader is a plain object, so binding it at import time is safe; mutable state
+# (df, embeddings, ...) must be read via deps.* or the module-level aliases refreshed in startup.
+data_loader = deps.data_loader
+
-app = FastAPI(title="HF Model Ecosystem API")
+logger = logging.getLogger(__name__)
+
+app = FastAPI(title="HF Model Ecosystem API", version="2.0.0")
app.add_middleware(GZipMiddleware, minimum_size=1000)
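+
+# Reused by the exception handlers below so error responses still include CORS headers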
+CORS_HEADERS = {
+ "Access-Control-Allow-Origin": "*",
+ "Access-Control-Allow-Methods": "*",
+ "Access-Control-Allow-Headers": "*",
+}
+
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
- """Global exception handler that ensures CORS headers are included even on errors."""
- import traceback
- error_detail = str(exc)
- traceback_str = traceback.format_exc()
- import sys
- sys.stderr.write(f"Unhandled exception: {error_detail}\n{traceback_str}\n")
+ logger.exception("Unhandled exception", exc_info=exc)
return JSONResponse(
status_code=500,
- content={"detail": error_detail, "error": "Internal server error"},
- headers={
- "Access-Control-Allow-Origin": "*",
- "Access-Control-Allow-Methods": "*",
- "Access-Control-Allow-Headers": "*",
- }
+ content={"detail": "Internal server error"},
+ headers=CORS_HEADERS,
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
- """HTTP exception handler with CORS headers."""
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail},
- headers={
- "Access-Control-Allow-Origin": "*",
- "Access-Control-Allow-Methods": "*",
- "Access-Control-Allow-Headers": "*",
- }
+ headers=CORS_HEADERS,
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
- """Validation exception handler with CORS headers."""
return JSONResponse(
status_code=422,
content={"detail": exc.errors()},
- headers={
- "Access-Control-Allow-Origin": "*",
- "Access-Control-Allow-Methods": "*",
- "Access-Control-Allow-Headers": "*",
- }
+ headers=CORS_HEADERS,
)
-# CORS middleware for React frontend
-# Update allow_origins with your Netlify URL in production
-# Note: Add your specific Netlify URL after deployment
-FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:3000")
-# Allow all origins for development (restrict in production)
-ALLOW_ALL_ORIGINS = os.getenv("ALLOW_ALL_ORIGINS", "true").lower() == "true"
-if ALLOW_ALL_ORIGINS:
+if settings.ALLOW_ALL_ORIGINS:
app.add_middleware(
CORSMiddleware,
- allow_origins=["*"], # Allow all origins in development
- allow_credentials=False, # Must be False when allow_origins is ["*"]
+ allow_origins=["*"],
+ allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)
else:
app.add_middleware(
CORSMiddleware,
- allow_origins=[
- "http://localhost:3000", # Local development
- FRONTEND_URL, # Production frontend URL
- # Add your Netlify URL here after deployment, e.g.:
- # "https://your-app-name.netlify.app",
- ],
+ allow_origins=["http://localhost:3000", settings.FRONTEND_URL],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
-data_loader = ModelDataLoader()
-embedder: Optional[ModelEmbedder] = None
-reducer: Optional[DimensionReducer] = None
-df: Optional[pd.DataFrame] = None
-embeddings: Optional[np.ndarray] = None
-reduced_embeddings: Optional[np.ndarray] = None
-cluster_labels: Optional[np.ndarray] = None # Cached cluster assignments
-
-
-class FilterParams(BaseModel):
- min_downloads: int = 0
- min_likes: int = 0
- search_query: Optional[str] = None
- libraries: Optional[List[str]] = None
- pipeline_tags: Optional[List[str]] = None
-
-
-class ModelPoint(BaseModel):
- model_id: str
- x: float
- y: float
- z: float # 3D coordinate
- library_name: Optional[str]
- pipeline_tag: Optional[str]
- downloads: int
- likes: int
- trending_score: Optional[float]
- tags: Optional[str]
- parent_model: Optional[str] = None
- licenses: Optional[str] = None
- family_depth: Optional[int] = None # Generation depth in family tree (0 = root)
- cluster_id: Optional[int] = None # Cluster assignment for visualization
+# Include routers
+app.include_router(models.router)
+app.include_router(stats.router)
+app.include_router(clusters.router)
@app.on_event("startup")
async def startup_event():
- """Initialize data and models on startup with caching."""
- global df, embedder, reducer, embeddings, reduced_embeddings
+    # Shared state is written to the deps module; module-level aliases are refreshed at the end of startup
- import os
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
root_dir = os.path.dirname(backend_dir)
cache_dir = os.path.join(root_dir, "cache")
os.makedirs(cache_dir, exist_ok=True)
embeddings_cache = os.path.join(cache_dir, "embeddings.pkl")
+ graph_embeddings_cache = os.path.join(cache_dir, "graph_embeddings.pkl")
+ combined_embeddings_cache = os.path.join(cache_dir, "combined_embeddings.pkl")
reduced_cache_umap = os.path.join(cache_dir, "reduced_umap_3d.pkl")
+ reduced_cache_umap_graph = os.path.join(cache_dir, "reduced_umap_3d_graph.pkl")
reducer_cache_umap = os.path.join(cache_dir, "reducer_umap_3d.pkl")
+ reducer_cache_umap_graph = os.path.join(cache_dir, "reducer_umap_3d_graph.pkl")
- sample_size_env = os.getenv("SAMPLE_SIZE")
- if sample_size_env is None:
- sample_size = None
+ sample_size = settings.get_sample_size()
+ if sample_size:
+ logger.info(f"Loading limited dataset: {sample_size} models (SAMPLE_SIZE={sample_size})")
else:
- sample_size = int(sample_size_env)
- if sample_size == 0:
- sample_size = None
- df = data_loader.load_data(sample_size=sample_size)
- df = data_loader.preprocess_for_embedding(df)
-
- if 'model_id' in df.columns:
- df.set_index('model_id', drop=False, inplace=True)
+ logger.info("No SAMPLE_SIZE set, loading full dataset")
+
+ deps.df = deps.data_loader.load_data(sample_size=sample_size)
+ deps.df = deps.data_loader.preprocess_for_embedding(deps.df)
+
+ if 'model_id' in deps.df.columns:
+ deps.df.set_index('model_id', drop=False, inplace=True)
for col in ['downloads', 'likes']:
- if col in df.columns:
- df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0).astype(int)
+ if col in deps.df.columns:
+ deps.df[col] = pd.to_numeric(deps.df[col], errors='coerce').fillna(0).astype(int)
- embedder = ModelEmbedder()
+ deps.embedder = ModelEmbedder()
+ # Load or generate text embeddings
if os.path.exists(embeddings_cache):
try:
- embeddings = embedder.load_embeddings(embeddings_cache)
+ deps.embeddings = deps.embedder.load_embeddings(embeddings_cache)
+ except (IOError, pickle.UnpicklingError, EOFError) as e:
+ logger.warning(f"Failed to load cached embeddings: {e}")
+ deps.embeddings = None
+
+ if deps.embeddings is None:
+ texts = deps.df['combined_text'].tolist()
+ deps.embeddings = deps.embedder.generate_embeddings(texts, batch_size=128)
+ deps.embedder.save_embeddings(deps.embeddings, embeddings_cache)
+
+ # Initialize graph embedder and generate graph embeddings (optional, lazy-loaded)
+ if settings.USE_GRAPH_EMBEDDINGS:
+ try:
+ deps.graph_embedder = GraphEmbedder()
+ logger.info("Building family graph for graph embeddings...")
+ graph = deps.graph_embedder.build_family_graph(deps.df)
+
+ if os.path.exists(graph_embeddings_cache):
+ try:
+ deps.graph_embeddings_dict = deps.graph_embedder.load_embeddings(graph_embeddings_cache)
+ logger.info(f"Loaded cached graph embeddings for {len(deps.graph_embeddings_dict)} models")
+ except (IOError, pickle.UnpicklingError, EOFError) as e:
+ logger.warning(f"Failed to load cached graph embeddings: {e}")
+ deps.graph_embeddings_dict = None
+
+ if deps.graph_embeddings_dict is None or len(deps.graph_embeddings_dict) == 0:
+ logger.info("Generating graph embeddings (this may take a while)...")
+ deps.graph_embeddings_dict = deps.graph_embedder.generate_graph_embeddings(graph, workers=4)
+ if deps.graph_embeddings_dict:
+ deps.graph_embedder.save_embeddings(deps.graph_embeddings_dict, graph_embeddings_cache)
+ logger.info(f"Generated graph embeddings for {len(deps.graph_embeddings_dict)} models")
+
+ # Combine text and graph embeddings
+ if deps.graph_embeddings_dict and len(deps.graph_embeddings_dict) > 0:
+ model_ids = deps.df['model_id'].astype(str).tolist()
+ if os.path.exists(combined_embeddings_cache):
+ try:
+ with open(combined_embeddings_cache, 'rb') as f:
+ deps.combined_embeddings = pickle.load(f)
+ logger.info("Loaded cached combined embeddings")
+ except (IOError, pickle.UnpicklingError, EOFError) as e:
+ logger.warning(f"Failed to load cached combined embeddings: {e}")
+ deps.combined_embeddings = None
+
+ if deps.combined_embeddings is None:
+ logger.info("Combining text and graph embeddings...")
+ deps.combined_embeddings = deps.graph_embedder.combine_embeddings(
+ deps.embeddings, deps.graph_embeddings_dict, model_ids,
+ text_weight=0.7, graph_weight=0.3
+ )
+ with open(combined_embeddings_cache, 'wb') as f:
+ pickle.dump(deps.combined_embeddings, f)
+ logger.info("Combined embeddings saved")
except Exception as e:
- embeddings = None
-
- if embeddings is None:
- texts = df['combined_text'].tolist()
- embeddings = embedder.generate_embeddings(texts, batch_size=128)
- embedder.save_embeddings(embeddings, embeddings_cache)
+ logger.warning(f"Graph embeddings not available: {e}. Continuing with text-only embeddings.")
+ deps.graph_embedder = None
+ deps.graph_embeddings_dict = None
+ deps.combined_embeddings = None
- reducer = DimensionReducer(method="umap", n_components=3)
+ # Initialize reducer for text embeddings
+ deps.reducer = DimensionReducer(method="umap", n_components=3)
if os.path.exists(reduced_cache_umap) and os.path.exists(reducer_cache_umap):
try:
- import pickle
with open(reduced_cache_umap, 'rb') as f:
- reduced_embeddings = pickle.load(f)
- reducer.load_reducer(reducer_cache_umap)
- except Exception as e:
- reduced_embeddings = None
-
- if reduced_embeddings is None:
- reducer.reducer = UMAP(
+ deps.reduced_embeddings = pickle.load(f)
+ deps.reducer.load_reducer(reducer_cache_umap)
+ except (IOError, pickle.UnpicklingError, EOFError) as e:
+ logger.warning(f"Failed to load cached reduced embeddings: {e}")
+ deps.reduced_embeddings = None
+
+ if deps.reduced_embeddings is None:
+ deps.reducer.reducer = UMAP(
n_components=3,
n_neighbors=30,
min_dist=0.3,
@@ -206,61 +220,57 @@ async def startup_event():
low_memory=True,
spread=1.5
)
- reduced_embeddings = reducer.fit_transform(embeddings)
- import pickle
+ deps.reduced_embeddings = deps.reducer.fit_transform(deps.embeddings)
with open(reduced_cache_umap, 'wb') as f:
- pickle.dump(reduced_embeddings, f)
- reducer.save_reducer(reducer_cache_umap)
-
-
-def calculate_family_depths(df: pd.DataFrame) -> Dict[str, int]:
- """
- Calculate family tree depth for each model.
- Returns a dictionary mapping model_id to depth (0 = root, 1 = first generation, etc.)
- """
- depths = {}
- visited = set()
+ pickle.dump(deps.reduced_embeddings, f)
+ deps.reducer.save_reducer(reducer_cache_umap)
- def get_depth(model_id: str) -> int:
- if model_id in depths:
- return depths[model_id]
- if model_id in visited:
- # Circular reference, treat as root
- depths[model_id] = 0
- return 0
+ # Initialize reducer for graph-aware embeddings if available
+ if deps.combined_embeddings is not None:
+ reducer_graph = DimensionReducer(method="umap", n_components=3)
- visited.add(model_id)
-
- if model_id not in df.index:
- depths[model_id] = 0
- return 0
-
- parent_id = df.loc[model_id].get('parent_model')
- if parent_id and pd.notna(parent_id) and str(parent_id) != 'nan' and str(parent_id) != '':
- parent_id_str = str(parent_id)
- if parent_id_str in df.index:
- depth = get_depth(parent_id_str) + 1
- else:
- depth = 0 # Parent not in dataset, treat as root
- else:
- depth = 0 # No parent, this is a root
+ if os.path.exists(reduced_cache_umap_graph) and os.path.exists(reducer_cache_umap_graph):
+ try:
+ with open(reduced_cache_umap_graph, 'rb') as f:
+ deps.reduced_embeddings_graph = pickle.load(f)
+ reducer_graph.load_reducer(reducer_cache_umap_graph)
+ except (IOError, pickle.UnpicklingError, EOFError) as e:
+ logger.warning(f"Failed to load cached graph-aware reduced embeddings: {e}")
+ deps.reduced_embeddings_graph = None
- depths[model_id] = depth
- return depth
-
- for model_id in df.index:
- if model_id not in depths:
- visited = set() # Reset for each tree
- get_depth(model_id)
+ if deps.reduced_embeddings_graph is None:
+ reducer_graph.reducer = UMAP(
+ n_components=3,
+ n_neighbors=30,
+ min_dist=0.3,
+ metric='cosine',
+ random_state=42,
+ n_jobs=-1,
+ low_memory=True,
+ spread=1.5
+ )
+ deps.reduced_embeddings_graph = reducer_graph.fit_transform(deps.combined_embeddings)
+ with open(reduced_cache_umap_graph, 'wb') as f:
+ pickle.dump(deps.reduced_embeddings_graph, f)
+ reducer_graph.save_reducer(reducer_cache_umap_graph)
+ logger.info("Graph-aware embeddings reduced and cached")
- return depths
+    # Refresh module-level aliases so routes that still reference bare names see the loaded state
+    global df, embedder, graph_embedder, reducer, embeddings
+    global graph_embeddings_dict, combined_embeddings, reduced_embeddings, reduced_embeddings_graph
+    df = deps.df
+    embedder = deps.embedder
+    graph_embedder = deps.graph_embedder
+    reducer = deps.reducer
+    embeddings = deps.embeddings
+    graph_embeddings_dict = deps.graph_embeddings_dict
+    combined_embeddings = deps.combined_embeddings
+    reduced_embeddings = deps.reduced_embeddings
+    reduced_embeddings_graph = deps.reduced_embeddings_graph
+
+
def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np.ndarray:
- """
- Compute clusters using KMeans on reduced embeddings.
- Returns cluster labels for each point.
- """
from sklearn.cluster import KMeans
n_samples = len(reduced_embeddings)
@@ -268,8 +278,7 @@ def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np
n_clusters = max(1, n_samples // 10)
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
- cluster_labels = kmeans.fit_predict(reduced_embeddings)
- return cluster_labels
+ return kmeans.fit_predict(reduced_embeddings)
@app.get("/")
@@ -284,24 +293,16 @@ async def get_models(
search_query: Optional[str] = Query(None),
color_by: str = Query("library_name"),
size_by: str = Query("downloads"),
- max_points: Optional[int] = Query(None), # Optional limit (None = all points)
- projection_method: str = Query("umap"), # umap or tsne
- base_models_only: bool = Query(False) # Only show root models (no parent)
+ max_points: Optional[int] = Query(None),
+ projection_method: str = Query("umap"),
+ base_models_only: bool = Query(False),
+ max_hierarchy_depth: Optional[int] = Query(None, ge=0, description="Filter to models at or below this hierarchy depth."),
+ use_graph_embeddings: bool = Query(False, description="Use graph-aware embeddings that respect family tree structure")
):
- """
- Get filtered models with 3D coordinates for visualization.
- Supports multiple projection methods: UMAP or t-SNE.
- If base_models_only=True, only returns root models (models without a parent_model).
-
- Returns a JSON object with:
- - models: List of ModelPoint objects
- - filtered_count: Number of models matching filters (before max_points sampling)
- - returned_count: Number of models actually returned (after max_points sampling)
- """
- global df, embedder, reducer, embeddings, reduced_embeddings
+ if deps.df is None:
+ raise DataNotLoadedError()
- if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+    # Bind local snapshots of shared state so later reassignments stay function-local
+    df = deps.df
+    embeddings = deps.embeddings
+    reduced_embeddings = deps.reduced_embeddings
+    combined_embeddings = deps.combined_embeddings
+    reduced_embeddings_graph = deps.reduced_embeddings_graph
+    reducer = deps.reducer
# Filter data
filtered_df = data_loader.filter_data(
@@ -321,7 +322,12 @@ async def get_models(
(filtered_df['parent_model'].astype(str) == 'nan')
]
- # Store the filtered count BEFORE sampling
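+    # Keep only models whose family-tree depth is <= max_hierarchy_depth (0 = root)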
+ if max_hierarchy_depth is not None:
+ family_depths = calculate_family_depths(df)
+ filtered_df = filtered_df[
+ filtered_df['model_id'].astype(str).map(lambda x: family_depths.get(x, 0) <= max_hierarchy_depth)
+ ]
+
filtered_count = len(filtered_df)
if len(filtered_df) == 0:
@@ -332,42 +338,53 @@ async def get_models(
}
if max_points is not None and len(filtered_df) > max_points:
- # Use stratified sampling to preserve distribution of important attributes
- # Sample proportionally from different libraries/pipelines for better representation
if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
- # Stratified sampling by library
- filtered_df = filtered_df.groupby('library_name', group_keys=False).apply(
- lambda x: x.sample(min(len(x), max(1, int(max_points * len(x) / len(filtered_df)))), random_state=42)
- ).reset_index(drop=True)
- # If still too many, random sample the rest
+ # Sample proportionally by library, preserving all columns
+ sampled_dfs = []
+ for lib_name, group in filtered_df.groupby('library_name', group_keys=False):
+ n_samples = max(1, int(max_points * len(group) / len(filtered_df)))
+ sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42))
+ filtered_df = pd.concat(sampled_dfs, ignore_index=True)
if len(filtered_df) > max_points:
- filtered_df = filtered_df.sample(n=max_points, random_state=42)
+ filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
+ else:
+ filtered_df = filtered_df.reset_index(drop=True)
else:
- filtered_df = filtered_df.sample(n=max_points, random_state=42)
-
- if embeddings is None:
- raise HTTPException(status_code=503, detail="Embeddings not loaded")
+ filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
- if reduced_embeddings is None or (reducer and reducer.method != projection_method.lower()):
- import os
+ # Determine which embeddings to use
+ if use_graph_embeddings and combined_embeddings is not None:
+ current_embeddings = combined_embeddings
+ current_reduced = reduced_embeddings_graph
+ embedding_type = "graph-aware"
+ else:
+ if embeddings is None:
+ raise EmbeddingsNotReadyError()
+ current_embeddings = embeddings
+ current_reduced = reduced_embeddings
+ embedding_type = "text-only"
+
+ # Handle reduced embeddings loading/generation
+ if current_reduced is None or (reducer and reducer.method != projection_method.lower()):
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
root_dir = os.path.dirname(backend_dir)
cache_dir = os.path.join(root_dir, "cache")
- reduced_cache = os.path.join(cache_dir, f"reduced_{projection_method.lower()}_3d.pkl")
- reducer_cache = os.path.join(cache_dir, f"reducer_{projection_method.lower()}_3d.pkl")
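+        # Graph-aware projections are cached separately via the "_graph" suffix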
+ cache_suffix = "_graph" if use_graph_embeddings and combined_embeddings is not None else ""
+ reduced_cache = os.path.join(cache_dir, f"reduced_{projection_method.lower()}_3d{cache_suffix}.pkl")
+ reducer_cache = os.path.join(cache_dir, f"reducer_{projection_method.lower()}_3d{cache_suffix}.pkl")
if os.path.exists(reduced_cache) and os.path.exists(reducer_cache):
try:
- import pickle
with open(reduced_cache, 'rb') as f:
- reduced_embeddings = pickle.load(f)
+ current_reduced = pickle.load(f)
if reducer is None or reducer.method != projection_method.lower():
reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
reducer.load_reducer(reducer_cache)
- except Exception as e:
- reduced_embeddings = None
+ except (IOError, pickle.UnpicklingError, EOFError) as e:
+ logger.warning(f"Failed to load cached reduced embeddings: {e}")
+ current_reduced = None
- if reduced_embeddings is None:
+ if current_reduced is None:
if reducer is None or reducer.method != projection_method.lower():
reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
if projection_method.lower() == "umap":
@@ -381,52 +398,91 @@ async def get_models(
low_memory=True,
spread=1.5
)
- reduced_embeddings = reducer.fit_transform(embeddings)
- import pickle
+ current_reduced = reducer.fit_transform(current_embeddings)
with open(reduced_cache, 'wb') as f:
- pickle.dump(reduced_embeddings, f)
+ pickle.dump(current_reduced, f)
reducer.save_reducer(reducer_cache)
+
+ # Update global variable
+ if use_graph_embeddings and deps.combined_embeddings is not None:
+ deps.reduced_embeddings_graph = current_reduced
+ else:
+ deps.reduced_embeddings = current_reduced
+
+ # Get indices for filtered data
+ # Use model_id column to map between filtered_df and original df
+ # This is safer than using index positions which can change after filtering
+ filtered_model_ids = filtered_df['model_id'].astype(str).values
- # Get coordinates for filtered data - optimized vectorized approach
- # Map filtered dataframe indices to original dataframe integer positions
- # Since df is indexed by model_id, we need to get the integer positions
+ # Map model_ids to positions in original df
if df.index.name == 'model_id' or 'model_id' in df.index.names:
- # Get integer positions of filtered rows in original dataframe
- # Use vectorized lookup for better performance
- filtered_indices = np.array([df.index.get_loc(idx) for idx in filtered_df.index], dtype=np.int32)
+ # When df is indexed by model_id, use get_loc directly
+ filtered_indices = []
+ for model_id in filtered_model_ids:
+ try:
+ pos = df.index.get_loc(model_id)
+ # Handle both single position and array of positions
+ if isinstance(pos, (int, np.integer)):
+ filtered_indices.append(int(pos))
+                    elif isinstance(pos, slice):
+                        # Duplicate labels in a monotonic index: take the first position
+                        filtered_indices.append(int(pos.start))
+                    elif isinstance(pos, np.ndarray):
+                        # Duplicate labels returned as a boolean mask: take the first True position
+                        filtered_indices.append(int(np.flatnonzero(pos)[0]))
+ except (KeyError, TypeError):
+ continue
+ filtered_indices = np.array(filtered_indices, dtype=np.int32)
else:
- # If using integer index, use directly
- filtered_indices = filtered_df.index.values.astype(np.int32)
-
- # Use advanced indexing for faster access
- filtered_reduced = reduced_embeddings[filtered_indices]
+ # When df is not indexed by model_id, find positions by matching model_id column
+ df_model_ids = df['model_id'].astype(str).values
+ model_id_to_pos = {mid: pos for pos, mid in enumerate(df_model_ids)}
+ filtered_indices = np.array([
+ model_id_to_pos[mid] for mid in filtered_model_ids
+ if mid in model_id_to_pos
+ ], dtype=np.int32)
+
+ if len(filtered_indices) == 0:
+ return {
+ "models": [],
+ "embedding_type": embedding_type,
+ "filtered_count": filtered_count,
+ "returned_count": 0
+ }
+ filtered_reduced = current_reduced[filtered_indices]
family_depths = calculate_family_depths(df)
- global cluster_labels
- if cluster_labels is None or len(cluster_labels) != len(reduced_embeddings):
- cluster_labels = compute_clusters(reduced_embeddings, n_clusters=min(50, len(reduced_embeddings) // 100))
+    # Cluster over the full reduced embedding matrix; recompute when the cache is missing or stale
+    clustering_embeddings = current_reduced
+    if deps.cluster_labels is None or len(deps.cluster_labels) != len(clustering_embeddings):
+        deps.cluster_labels = compute_clusters(clustering_embeddings, n_clusters=min(50, len(clustering_embeddings) // 100))
-    filtered_clusters = cluster_labels[filtered_indices]
+    # Guard against stale labels before fancy-indexing with df positions
+    if len(deps.cluster_labels) > 0 and int(filtered_indices.max()) < len(deps.cluster_labels):
+        filtered_clusters = deps.cluster_labels[filtered_indices]
+    else:
+        filtered_clusters = np.zeros(len(filtered_indices), dtype=int)
- # Build response with optimized vectorized operations
- # Pre-extract arrays for faster access
model_ids = filtered_df['model_id'].astype(str).values
- library_names = filtered_df['library_name'].values
- pipeline_tags = filtered_df['pipeline_tag'].values
- downloads_arr = filtered_df['downloads'].fillna(0).astype(int).values
- likes_arr = filtered_df['likes'].fillna(0).astype(int).values
- trending_scores = filtered_df.get('trendingScore', pd.Series()).values
- tags_arr = filtered_df.get('tags', pd.Series()).values
- parent_models = filtered_df.get('parent_model', pd.Series()).values
- licenses_arr = filtered_df.get('licenses', pd.Series()).values
-
- # Vectorized coordinate extraction
+ library_names = filtered_df.get('library_name', pd.Series([None] * len(filtered_df))).values
+ pipeline_tags = filtered_df.get('pipeline_tag', pd.Series([None] * len(filtered_df))).values
+ downloads_arr = filtered_df.get('downloads', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values
+ likes_arr = filtered_df.get('likes', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values
+ trending_scores = filtered_df.get('trendingScore', pd.Series([None] * len(filtered_df))).values
+ tags_arr = filtered_df.get('tags', pd.Series([None] * len(filtered_df))).values
+ parent_models = filtered_df.get('parent_model', pd.Series([None] * len(filtered_df))).values
+ licenses_arr = filtered_df.get('licenses', pd.Series([None] * len(filtered_df))).values
+ created_at_arr = filtered_df.get('createdAt', pd.Series([None] * len(filtered_df))).values
+
x_coords = filtered_reduced[:, 0].astype(float)
y_coords = filtered_reduced[:, 1].astype(float)
z_coords = filtered_reduced[:, 2].astype(float) if filtered_reduced.shape[1] > 2 else np.zeros(len(filtered_reduced), dtype=float)
-
- # Build models list with optimized operations
models = [
ModelPoint(
model_id=model_ids[idx],
@@ -442,28 +498,42 @@ async def get_models(
parent_model=parent_models[idx] if idx < len(parent_models) and pd.notna(parent_models[idx]) else None,
licenses=licenses_arr[idx] if idx < len(licenses_arr) and pd.notna(licenses_arr[idx]) else None,
family_depth=family_depths.get(model_ids[idx], None),
- cluster_id=int(filtered_clusters[idx]) if idx < len(filtered_clusters) else None
+ cluster_id=int(filtered_clusters[idx]) if idx < len(filtered_clusters) else None,
+ created_at=str(created_at_arr[idx]) if idx < len(created_at_arr) and pd.notna(created_at_arr[idx]) else None
)
for idx in range(len(filtered_df))
]
- return models
+ # Return models with metadata about embedding type
+ return {
+ "models": models,
+ "embedding_type": embedding_type,
+ "filtered_count": filtered_count,
+ "returned_count": len(models)
+ }
@app.get("/api/stats")
async def get_stats():
"""Get dataset statistics."""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
- # Use len(df.index) to handle both regular and indexed DataFrames correctly
total_models = len(df.index) if hasattr(df, 'index') else len(df)
+ # Get unique licenses with counts
+ licenses = {}
+ if 'license' in df.columns:
+ license_counts = df['license'].value_counts().to_dict()
+ licenses = {str(k): int(v) for k, v in license_counts.items() if pd.notna(k) and str(k) != 'nan'}
+
return {
"total_models": total_models,
"unique_libraries": int(df['library_name'].nunique()) if 'library_name' in df.columns else 0,
"unique_pipelines": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0,
"unique_task_types": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0, # Alias for clarity
+ "unique_licenses": len(licenses),
+ "licenses": licenses, # License name -> count mapping
"avg_downloads": float(df['downloads'].mean()) if 'downloads' in df.columns else 0,
"avg_likes": float(df['likes'].mean()) if 'likes' in df.columns else 0
}
@@ -473,7 +543,7 @@ async def get_stats():
async def get_model_details(model_id: str):
"""Get detailed information about a specific model."""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
model = df[df.get('model_id', '') == model_id]
if len(model) == 0:
@@ -481,11 +551,9 @@ async def get_model_details(model_id: str):
model = model.iloc[0]
- # Extract arXiv IDs from tags
tags_str = str(model.get('tags', '')) if pd.notna(model.get('tags')) else ''
arxiv_ids = extract_arxiv_ids(tags_str)
- # Fetch arXiv papers if any IDs found
papers = []
if arxiv_ids:
papers = await fetch_arxiv_papers(arxiv_ids[:5]) # Limit to 5 papers
@@ -505,6 +573,8 @@ async def get_model_details(model_id: str):
}
+# Clusters endpoint is handled by routes/clusters.py router
+
@app.get("/api/family/stats")
async def get_family_stats():
"""
@@ -512,9 +582,8 @@ async def get_family_stats():
Returns family size distribution, depth statistics, model card length by depth, etc.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
- # Calculate family sizes
family_sizes = {}
root_models = set()
@@ -528,14 +597,13 @@ async def get_family_stats():
family_sizes[model_id] = 0
else:
parent_id_str = str(parent_id)
- # Find root of this family
root = parent_id_str
visited = set()
while root in df.index and pd.notna(df.loc[root].get('parent_model')):
parent = df.loc[root].get('parent_model')
if pd.isna(parent) or str(parent) == 'nan' or str(parent) == '':
break
- if str(parent) in visited: # Circular reference
+ if str(parent) in visited:
break
visited.add(root)
root = str(parent)
@@ -544,18 +612,15 @@ async def get_family_stats():
family_sizes[root] = 0
family_sizes[root] += 1
- # Count family sizes
size_distribution = {}
for root, size in family_sizes.items():
size_distribution[size] = size_distribution.get(size, 0) + 1
- # Calculate depth statistics
depths = calculate_family_depths(df)
depth_counts = {}
for depth in depths.values():
depth_counts[depth] = depth_counts.get(depth, 0) + 1
- # Calculate model card length by depth
model_card_lengths_by_depth = {}
if 'modelCard' in df.columns:
for idx, row in df.iterrows():
@@ -568,7 +633,6 @@ async def get_family_stats():
model_card_lengths_by_depth[depth] = []
model_card_lengths_by_depth[depth].append(card_length)
- # Calculate statistics for each depth
model_card_stats = {}
for depth, lengths in model_card_lengths_by_depth.items():
if lengths:
@@ -593,99 +657,218 @@ async def get_family_stats():
}
+@app.get("/api/family/path/{model_id}")
+async def get_family_path(
+ model_id: str,
+ target_id: Optional[str] = Query(None, description="Target model ID. If None, returns path to root.")
+):
+ """
+ Get path from model to root or to target model.
+ Returns list of model IDs representing the path.
+ """
+ if df is None:
+ raise DataNotLoadedError()
+
+ model_id_str = str(model_id)
+
+ if df.index.name == 'model_id':
+ if model_id_str not in df.index:
+ raise HTTPException(status_code=404, detail="Model not found")
+ else:
+ model_rows = df[df.get('model_id', '') == model_id_str]
+ if len(model_rows) == 0:
+ raise HTTPException(status_code=404, detail="Model not found")
+
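+    # Walk up the parent chain; "visited" guards against circular parent references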
+ path = [model_id_str]
+ visited = set([model_id_str])
+ current = model_id_str
+
+ if target_id:
+ target_str = str(target_id)
+ if df.index.name == 'model_id':
+ if target_str not in df.index:
+ raise HTTPException(status_code=404, detail="Target model not found")
+
+        while current != target_str:
+ try:
+ if df.index.name == 'model_id':
+ row = df.loc[current]
+ else:
+ rows = df[df.get('model_id', '') == current]
+ if len(rows) == 0:
+ break
+ row = rows.iloc[0]
+
+ parent_id = row.get('parent_model')
+ if parent_id and pd.notna(parent_id):
+ parent_str = str(parent_id)
+ if parent_str == target_str:
+ path.append(parent_str)
+ break
+ if parent_str not in visited:
+ path.append(parent_str)
+ visited.add(parent_str)
+ current = parent_str
+ else:
+ break
+ else:
+ break
+ except (KeyError, IndexError):
+ break
+ else:
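+        # No target: walk to the root, stopping at a model with no parent or on a cycle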
+ while True:
+ try:
+ if df.index.name == 'model_id':
+ row = df.loc[current]
+ else:
+ rows = df[df.get('model_id', '') == current]
+ if len(rows) == 0:
+ break
+ row = rows.iloc[0]
+
+ parent_id = row.get('parent_model')
+ if parent_id and pd.notna(parent_id):
+ parent_str = str(parent_id)
+ if parent_str not in visited:
+ path.append(parent_str)
+ visited.add(parent_str)
+ current = parent_str
+ else:
+ break
+ else:
+ break
+ except (KeyError, IndexError):
+ break
+
+ return {
+ "path": path,
+ "source": model_id_str,
+ "target": target_id if target_id else "root",
+ "path_length": len(path) - 1
+ }
+
+
@app.get("/api/family/{model_id}")
-async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10)):
+async def get_family_tree(
+ model_id: str,
+ max_depth: Optional[int] = Query(None, ge=1, le=100, description="Maximum depth to traverse. If None, traverses entire tree without limit."),
+ max_depth_filter: Optional[int] = Query(None, ge=0, description="Filter results to models at or below this hierarchy depth.")
+):
"""
Get family tree for a model (ancestors and descendants).
Returns the model, its parent chain, and all children.
+
+ If max_depth is None, traverses the entire family tree without depth limits.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
-
- # Find the model
- model_row = df[df.get('model_id', '') == model_id]
- if len(model_row) == 0:
- raise HTTPException(status_code=404, detail="Model not found")
-
- family_models = []
- visited = set()
+ raise DataNotLoadedError()
- # Get coordinates for family members
if reduced_embeddings is None:
raise HTTPException(status_code=503, detail="Embeddings not ready")
- # Optimize: create parent_model index for faster lookups
- if 'parent_model' not in df.index.names and 'parent_model' in df.columns:
- # Create a reverse index for faster parent lookups
- parent_index = df[df['parent_model'].notna()].set_index('parent_model', drop=False, append=True)
+ model_id_str = str(model_id)
- def get_ancestors(current_id: str, depth: int):
- """Recursively get parent chain - optimized with index lookup."""
- if depth <= 0 or current_id in visited:
+ if df.index.name == 'model_id':
+ if model_id_str not in df.index:
+ raise HTTPException(status_code=404, detail="Model not found")
+ model_lookup = df.loc
+ else:
+ model_rows = df[df.get('model_id', '') == model_id_str]
+ if len(model_rows) == 0:
+ raise HTTPException(status_code=404, detail="Model not found")
+ model_lookup = lambda x: df[df.get('model_id', '') == x]
+
+    from utils.network_analysis import _get_all_parents
+
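+    # Build a reverse index: parent model id -> child model ids, across all relationship types
+    # (finetune, quantized, adapter, merge, and plain parent_model)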
+ children_index: Dict[str, List[str]] = {}
+
+ for idx, row in df.iterrows():
+ model_id_from_row = str(row.get('model_id', idx))
+ all_parents = _get_all_parents(row)
+
+ for rel_type, parent_list in all_parents.items():
+ for parent_str in parent_list:
+ if parent_str not in children_index:
+ children_index[parent_str] = []
+ children_index[parent_str].append(model_id_from_row)
+
+ visited = set()
+
+ def get_ancestors(current_id: str, depth: Optional[int]):
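+        # depth=None means unlimited traversal; otherwise stop once the depth budget is exhausted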
+ if current_id in visited:
+ return
+ if depth is not None and depth <= 0:
return
visited.add(current_id)
- # Use index lookup if available, otherwise fallback to query
- if 'model_id' in df.index.names or df.index.name == 'model_id':
- try:
- model = df.loc[[current_id]]
- except KeyError:
- return
- else:
- model = df[df.get('model_id', '') == current_id]
- if len(model) == 0:
- return
- model = model.iloc[[0]]
-
- parent_id = model.iloc[0].get('parent_model')
-
- if parent_id and pd.notna(parent_id) and str(parent_id) != 'nan':
- get_ancestors(str(parent_id), depth - 1)
+ try:
+ if df.index.name == 'model_id':
+ row = df.loc[current_id]
+ else:
+ rows = model_lookup(current_id)
+ if len(rows) == 0:
+ return
+ row = rows.iloc[0]
+
+ all_parents = _get_all_parents(row)
+ for rel_type, parent_list in all_parents.items():
+ for parent_str in parent_list:
+ if parent_str != 'nan' and parent_str != '':
+ next_depth = depth - 1 if depth is not None else None
+ get_ancestors(parent_str, next_depth)
+ except (KeyError, IndexError):
+ return
- def get_descendants(current_id: str, depth: int):
- """Recursively get all children - optimized with index lookup."""
- if depth <= 0 or current_id in visited:
+ def get_descendants(current_id: str, depth: Optional[int]):
+ if current_id in visited:
+ return
+ if depth is not None and depth <= 0:
return
visited.add(current_id)
- # Use optimized parent lookup
- if 'parent_model' in df.columns:
- children = df[df['parent_model'] == current_id]
- # Use vectorized iteration
- child_ids = children['model_id'].dropna().astype(str).unique()
- for child_id in child_ids:
- if child_id not in visited:
- get_descendants(child_id, depth - 1)
-
- # Get ancestors (parents)
- get_ancestors(model_id, max_depth)
-
- # Get descendants (children)
- visited = set() # Reset for descendants
- get_descendants(model_id, max_depth)
-
- # Add the root model
- visited.add(model_id)
-
- # Get all family members with coordinates - optimized
- if 'model_id' in df.index.names or df.index.name == 'model_id':
- # Use index lookup if available
+ children = children_index.get(current_id, [])
+ for child_id in children:
+ if child_id not in visited:
+ next_depth = depth - 1 if depth is not None else None
+ get_descendants(child_id, next_depth)
+
+ get_ancestors(model_id_str, max_depth)
+ visited = set()
+ get_descendants(model_id_str, max_depth)
+ visited.add(model_id_str)
+
+ if df.index.name == 'model_id':
try:
family_df = df.loc[list(visited)]
except KeyError:
- # Fallback to isin if some IDs not in index
- family_df = df[df.get('model_id', '').isin(visited)]
+ missing = [v for v in visited if v not in df.index]
+ if missing:
+ logger.warning(f"Some family members not found in index: {missing}")
+ family_df = df.loc[[v for v in visited if v in df.index]]
else:
family_df = df[df.get('model_id', '').isin(visited)]
- family_indices = family_df.index.values # Use values instead of tolist() for speed
+ if len(family_df) == 0:
+ raise HTTPException(status_code=404, detail="Family tree data not available")
+
+    # Convert index labels (model ids) to integer positions before indexing the embedding matrix
+    if df.index.name == 'model_id':
+        family_indices = np.array([df.index.get_loc(v) for v in family_df.index], dtype=np.int64)
+    else:
+        family_indices = family_df.index.values.astype(np.int64)
+    if len(family_indices) > len(reduced_embeddings):
+        raise HTTPException(status_code=503, detail="Embedding indices mismatch")
+
     family_reduced = reduced_embeddings[family_indices]
- # Build family tree structure - optimized with vectorized operations
family_map = {}
+    depths = calculate_family_depths(df)
     for idx, (i, row) in enumerate(family_df.iterrows()):
- model_id_val = str(row.get('model_id', 'Unknown'))
- parent_id = row.get('parent_model') if pd.notna(row.get('parent_model')) else None
+ model_id_val = str(row.get('model_id', i))
+ parent_id = row.get('parent_model')
+ parent_id_str = str(parent_id) if parent_id and pd.notna(parent_id) else None
+
+ model_depth = depths.get(model_id_val, 0)
+
+ if max_depth_filter is not None and model_depth > max_depth_filter:
+ continue
family_map[model_id_val] = {
"model_id": model_id_val,
@@ -696,12 +879,12 @@ async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10))
"pipeline_tag": str(row.get('pipeline_tag')) if pd.notna(row.get('pipeline_tag')) else None,
"downloads": int(row.get('downloads', 0)) if pd.notna(row.get('downloads')) else 0,
"likes": int(row.get('likes', 0)) if pd.notna(row.get('likes')) else 0,
- "parent_model": str(parent_id) if parent_id else None,
+ "parent_model": parent_id_str,
"licenses": str(row.get('licenses')) if pd.notna(row.get('licenses')) else None,
+ "family_depth": model_depth,
"children": []
}
- # Build tree structure
root_models = []
for model_id_val, model_data in family_map.items():
parent_id = model_data["parent_model"]
@@ -711,7 +894,7 @@ async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10))
root_models.append(model_id_val)
return {
- "root_model": model_id,
+ "root_model": model_id_str,
"family": list(family_map.values()),
"family_map": family_map,
"root_models": root_models
@@ -720,7 +903,9 @@ async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10))
@app.get("/api/search")
async def search_models(
- query: str = Query(..., min_length=1),
+    q: Optional[str] = Query(None, min_length=1),
+    query: Optional[str] = Query(None, min_length=1),
+ limit: int = Query(20, ge=1, le=100),
graph_aware: bool = Query(False),
include_neighbors: bool = Query(True)
):
@@ -729,47 +914,79 @@ async def search_models(
Enhanced with graph-aware search option that includes network relationships.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
+
+    # Support both 'q' and 'query' query parameters; at least one is required
+    search_query = query or q
+    if not search_query:
+        raise HTTPException(status_code=422, detail="Provide a search term via 'q' or 'query'")
if graph_aware:
- # Use graph-aware search
try:
network_builder = ModelNetworkBuilder(df)
- # Build network for top models (for performance)
top_models = network_builder.get_top_models_by_field(n=1000)
model_ids = [mid for mid, _ in top_models]
graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined')
results = network_builder.search_graph_aware(
- query=query,
+ query=search_query,
graph=graph,
- max_results=20,
+ max_results=limit,
include_neighbors=include_neighbors
)
- return {"results": results, "search_type": "graph_aware"}
- except Exception as e:
- pass
+ return {"results": results, "search_type": "graph_aware", "query": search_query}
+ except (ValueError, KeyError, AttributeError) as e:
+ logger.warning(f"Graph-aware search failed, falling back to basic search: {e}")
+
+ query_lower = search_query.lower()
+
+ # Enhanced search: search model_id, org, tags, library, pipeline
+ model_id_col = df.get('model_id', '').astype(str).str.lower()
+ library_col = df.get('library_name', '').astype(str).str.lower()
+ pipeline_col = df.get('pipeline_tag', '').astype(str).str.lower()
+ tags_col = df.get('tags', '').astype(str).str.lower()
+ license_col = df.get('license', '').astype(str).str.lower()
+
+ # Extract org from model_id
+ org_col = model_id_col.str.split('/').str[0]
+
+ # Multi-field search
+ mask = (
+ model_id_col.str.contains(query_lower, na=False) |
+ org_col.str.contains(query_lower, na=False) |
+ library_col.str.contains(query_lower, na=False) |
+ pipeline_col.str.contains(query_lower, na=False) |
+ tags_col.str.contains(query_lower, na=False) |
+ license_col.str.contains(query_lower, na=False)
+ )
- query_lower = query.lower()
- matches = df[
- df.get('model_id', '').astype(str).str.lower().str.contains(query_lower, na=False)
- ].head(20) # Limit to 20 results
+ matches = df[mask].head(limit)
results = []
for _, row in matches.iterrows():
+ model_id = str(row.get('model_id', ''))
+ org = model_id.split('/')[0] if '/' in model_id else ''
+
+ # Get coordinates if available
+ x = float(row.get('x', 0.0)) if 'x' in row else None
+ y = float(row.get('y', 0.0)) if 'y' in row else None
+ z = float(row.get('z', 0.0)) if 'z' in row else None
+
results.append({
- "model_id": row.get('model_id'),
- "title": row.get('model_id', '').split('/')[-1] if '/' in str(row.get('model_id', '')) else str(row.get('model_id', '')),
- "library_name": row.get('library_name'),
- "pipeline_tag": row.get('pipeline_tag'),
+ "model_id": model_id,
+ "x": x,
+ "y": y,
+ "z": z,
+ "org": org,
+ "library": row.get('library_name'),
+ "pipeline": row.get('pipeline_tag'),
+ "license": row.get('license') if pd.notna(row.get('license')) else None,
"downloads": int(row.get('downloads', 0)),
"likes": int(row.get('likes', 0)),
"parent_model": row.get('parent_model') if pd.notna(row.get('parent_model')) else None,
"match_type": "direct"
})
- return {"results": results, "search_type": "basic"}
+ return {"results": results, "search_type": "basic", "query": search_query}
@app.get("/api/similar/{model_id}")
@@ -778,12 +995,12 @@ async def get_similar_models(model_id: str, k: int = Query(10, ge=1, le=50)):
Get k-nearest neighbors of a model based on embedding similarity.
Returns similar models with distance scores.
"""
- global df, embedder, embeddings, reduced_embeddings
-
- if df is None or embeddings is None:
+ if deps.df is None or deps.embeddings is None:
raise HTTPException(status_code=503, detail="Data not loaded")
- # Find the model - optimized with index lookup
+ df = deps.df
+ embeddings = deps.embeddings
+
if 'model_id' in df.index.names or df.index.name == 'model_id':
try:
model_row = df.loc[[model_id]]
@@ -797,16 +1014,11 @@ async def get_similar_models(model_id: str, k: int = Query(10, ge=1, le=50)):
model_idx = model_row.index[0]
model_embedding = embeddings[model_idx]
- # Calculate cosine similarity to all other models - optimized
from sklearn.metrics.pairwise import cosine_similarity
- # Use vectorized operations for better performance
model_embedding_2d = model_embedding.reshape(1, -1)
similarities = cosine_similarity(model_embedding_2d, embeddings)[0]
- # Get top k similar models (excluding itself) - use argpartition for speed
- # argpartition is faster than full sort for top-k
top_k_indices = np.argpartition(similarities, -k-1)[-k-1:-1]
- # Sort only the top k (much faster than sorting all)
top_k_indices = top_k_indices[np.argsort(similarities[top_k_indices])][::-1]
similar_models = []
@@ -817,7 +1029,7 @@ async def get_similar_models(model_id: str, k: int = Query(10, ge=1, le=50)):
similar_models.append({
"model_id": row.get('model_id', 'Unknown'),
"similarity": float(similarities[idx]),
- "distance": float(1 - similarities[idx]), # Convert similarity to distance
+ "distance": float(1 - similarities[idx]),
"library_name": row.get('library_name'),
"pipeline_tag": row.get('pipeline_tag'),
"downloads": int(row.get('downloads', 0)),
@@ -843,11 +1055,12 @@ async def get_models_by_semantic_similarity(
Returns models with their similarity scores and coordinates.
Useful for exploring the embedding space around a specific model.
"""
- global df, embedder, embeddings, reduced_embeddings
-
- if df is None or embeddings is None:
+ if deps.df is None or deps.embeddings is None:
raise HTTPException(status_code=503, detail="Data not loaded")
+    df = deps.df
+    embeddings = deps.embeddings
+    reduced_embeddings = deps.reduced_embeddings
+
# Find the query model
if 'model_id' in df.index.names or df.index.name == 'model_id':
try:
@@ -863,7 +1076,6 @@ async def get_models_by_semantic_similarity(
query_embedding = embeddings[model_idx]
- # Filter by downloads/likes first for performance
filtered_df = data_loader.filter_data(
df=df,
min_downloads=min_downloads,
@@ -873,32 +1085,26 @@ async def get_models_by_semantic_similarity(
pipeline_tags=None
)
- # Get indices of filtered models
if df.index.name == 'model_id' or 'model_id' in df.index.names:
filtered_indices = [df.index.get_loc(idx) for idx in filtered_df.index]
filtered_indices = np.array(filtered_indices, dtype=int)
else:
filtered_indices = filtered_df.index.values.astype(int)
- # Calculate similarities only for filtered models
filtered_embeddings = embeddings[filtered_indices]
from sklearn.metrics.pairwise import cosine_similarity
query_embedding_2d = query_embedding.reshape(1, -1)
similarities = cosine_similarity(query_embedding_2d, filtered_embeddings)[0]
- # Get top k similar models
top_k_local_indices = np.argpartition(similarities, -k)[-k:]
top_k_local_indices = top_k_local_indices[np.argsort(similarities[top_k_local_indices])][::-1]
- # Get reduced embeddings for visualization
if reduced_embeddings is None:
raise HTTPException(status_code=503, detail="Reduced embeddings not ready")
- # Map back to original indices
top_k_original_indices = filtered_indices[top_k_local_indices]
top_k_reduced = reduced_embeddings[top_k_original_indices]
- # Build response
similar_models = []
for i, orig_idx in enumerate(top_k_original_indices):
row = df.iloc[orig_idx]
@@ -935,11 +1141,12 @@ async def get_distance(
"""
Calculate distance/similarity between two models.
"""
- global df, embedder, embeddings
-
- if df is None or embeddings is None:
+ if deps.df is None or deps.embeddings is None:
raise HTTPException(status_code=503, detail="Data not loaded")
+ df = deps.df
+ embeddings = deps.embeddings
+
# Find both models - optimized with index lookup
if 'model_id' in df.index.names or df.index.name == 'model_id':
try:
@@ -976,7 +1183,7 @@ async def export_models(model_ids: List[str]):
Export selected models as JSON with full metadata.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
# Optimized export with index lookup
if 'model_id' in df.index.names or df.index.name == 'model_id':
@@ -991,7 +1198,6 @@ async def export_models(model_ids: List[str]):
if len(exported) == 0:
return {"models": []}
- # Use list comprehension for faster building
models = [
{
"model_id": str(row.get('model_id', '')),
@@ -1029,12 +1235,10 @@ async def get_cooccurrence_network(
Returns network graph data suitable for visualization.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
try:
network_builder = ModelNetworkBuilder(df)
-
- # Get top models by field
top_models = network_builder.get_top_models_by_field(
library=library,
pipeline_tag=pipeline_tag,
@@ -1051,14 +1255,11 @@ async def get_cooccurrence_network(
}
model_ids = [mid for mid, _ in top_models]
-
- # Build co-occurrence network
graph = network_builder.build_cooccurrence_network(
model_ids=model_ids,
cooccurrence_method=cooccurrence_method
)
- # Convert to JSON-serializable format
nodes = []
for node_id, attrs in graph.nodes(data=True):
nodes.append({
@@ -1086,45 +1287,70 @@ async def get_cooccurrence_network(
"links": links,
"statistics": stats
}
-
- except Exception as e:
+ except (ValueError, KeyError, AttributeError) as e:
+ logger.error(f"Error building network: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error building network: {str(e)}")
@app.get("/api/network/family/{model_id}")
async def get_family_network(
model_id: str,
- max_depth: int = Query(5, ge=1, le=10)
+ max_depth: Optional[int] = Query(None, ge=1, le=100, description="Maximum depth to traverse. If None, traverses entire tree without limit."),
+ edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."),
+ include_edge_attributes: bool = Query(True, description="Whether to include edge attributes (change in likes, downloads, etc.)")
):
"""
Build family tree network for a model (directed graph).
- Returns network graph data showing parent-child relationships.
+ Returns network graph data showing parent-child relationships with multiple relationship types.
+ Supports filtering by edge type (finetune, quantized, adapter, merge, parent).
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
try:
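+        # Parse the optional comma-separated edge-type filter (e.g. "finetune,quantized")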
+ filter_types = None
+ if edge_types:
+ filter_types = [t.strip() for t in edge_types.split(',') if t.strip()]
+
network_builder = ModelNetworkBuilder(df)
graph = network_builder.build_family_tree_network(
root_model_id=model_id,
- max_depth=max_depth
+ max_depth=max_depth,
+ include_edge_attributes=include_edge_attributes,
+ filter_edge_types=filter_types
)
- # Convert to JSON-serializable format
nodes = []
for node_id, attrs in graph.nodes(data=True):
nodes.append({
"id": node_id,
"title": attrs.get('title', node_id),
- "freq": attrs.get('freq', 0)
+ "freq": attrs.get('freq', 0),
+ "likes": attrs.get('likes', 0),
+ "downloads": attrs.get('downloads', 0),
+ "library": attrs.get('library', ''),
+ "pipeline": attrs.get('pipeline', '')
})
links = []
- for source, target in graph.edges():
- links.append({
+ for source, target, edge_attrs in graph.edges(data=True):
+ link_data = {
"source": source,
- "target": target
- })
+ "target": target,
+ "edge_type": edge_attrs.get('edge_type'),
+ "edge_types": edge_attrs.get('edge_types', [])
+ }
+
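+                # Optional per-edge deltas (likes, downloads, createdAt) supplied by the network builder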
+ if include_edge_attributes:
+ link_data.update({
+ "change_in_likes": edge_attrs.get('change_in_likes'),
+ "percentage_change_in_likes": edge_attrs.get('percentage_change_in_likes'),
+ "change_in_downloads": edge_attrs.get('change_in_downloads'),
+ "percentage_change_in_downloads": edge_attrs.get('percentage_change_in_downloads'),
+ "change_in_createdAt_days": edge_attrs.get('change_in_createdAt_days')
+ })
+
+ links.append(link_data)
stats = network_builder.get_network_statistics(graph)
@@ -1134,8 +1360,8 @@ async def get_family_network(
"statistics": stats,
"root_model": model_id
}
-
- except Exception as e:
+ except (ValueError, KeyError, AttributeError) as e:
+ logger.error(f"Error building family network: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error building family network: {str(e)}")
@@ -1150,11 +1376,10 @@ async def get_model_neighbors(
Similar to graph database queries for finding connected nodes.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
try:
network_builder = ModelNetworkBuilder(df)
- # Build network for top models (for performance)
top_models = network_builder.get_top_models_by_field(n=1000)
model_ids = [mid for mid, _ in top_models]
graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined')
@@ -1171,8 +1396,8 @@ async def get_model_neighbors(
"neighbors": neighbors,
"count": len(neighbors)
}
-
- except Exception as e:
+ except (ValueError, KeyError, AttributeError) as e:
+ logger.error(f"Error finding neighbors: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error finding neighbors: {str(e)}")
@@ -1187,7 +1412,7 @@ async def find_path_between_models(
Similar to graph database path queries.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
try:
network_builder = ModelNetworkBuilder(df)
@@ -1235,7 +1460,7 @@ async def search_by_cooccurrence(
Similar to graph database queries for co-assignment patterns.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
try:
network_builder = ModelNetworkBuilder(df)
@@ -1272,7 +1497,7 @@ async def get_model_relationships(
Similar to graph database relationship queries.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
try:
network_builder = ModelNetworkBuilder(df)
@@ -1297,32 +1522,57 @@ async def get_model_relationships(
async def get_current_model_count(
use_cache: bool = Query(True),
force_refresh: bool = Query(False),
- use_dataset_snapshot: bool = Query(False)
+ use_dataset_snapshot: bool = Query(False),
+ use_models_page: bool = Query(True)
):
"""
Get the current number of models on Hugging Face Hub.
- Fetches live data from the Hub API or uses dataset snapshot (faster but may be outdated).
+ Uses multiple strategies: models page scraping (fastest), dataset snapshot, or API.
Query Parameters:
use_cache: Use cached results if available (default: True)
force_refresh: Force refresh even if cache is valid (default: False)
- use_dataset_snapshot: Use dataset snapshot instead of API (faster, default: False)
+ use_dataset_snapshot: Use dataset snapshot for breakdowns (default: False)
+ use_models_page: Try to get count from HF models page first (default: True)
"""
try:
+ tracker = get_tracker()
+
if use_dataset_snapshot:
- # Use improved tracker with dataset snapshot (like ai-ecosystem repo)
- tracker = get_improved_tracker()
- count_data = tracker.get_count_from_dataset_snapshot()
+ count_data = tracker.get_count_from_models_page()
if count_data is None:
- # Fallback to API if dataset unavailable
- count_data = tracker.get_current_model_count(use_cache=use_cache, force_refresh=force_refresh)
+ count_data = tracker.get_current_model_count(use_models_page=False)
+ else:
+ try:
+ from utils.data_loader import ModelDataLoader
+ data_loader = ModelDataLoader()
+ df = data_loader.load_data(sample_size=10000)
+ library_counts = {}
+ pipeline_counts = {}
+
+ for _, row in df.iterrows():
+ if pd.notna(row.get('library_name')):
+ lib = str(row.get('library_name'))
+ library_counts[lib] = library_counts.get(lib, 0) + 1
+ if pd.notna(row.get('pipeline_tag')):
+ pipeline = str(row.get('pipeline_tag'))
+ pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
+
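+                # Extrapolate breakdowns from the 10k-row sample up to the scraped total count.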
+ if len(df) > 0 and count_data["total_models"] > len(df):
+ scale_factor = count_data["total_models"] / len(df)
+ library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
+ pipeline_counts = {k: int(v * scale_factor) for k, v in pipeline_counts.items()}
+
+ count_data["models_by_library"] = library_counts
+ count_data["models_by_pipeline"] = pipeline_counts
+ except Exception as e:
+ logger.warning(f"Could not get breakdowns from dataset: {e}")
else:
- # Use improved tracker with API (has caching)
- tracker = get_improved_tracker()
- count_data = tracker.get_current_model_count(use_cache=use_cache, force_refresh=force_refresh)
+ count_data = tracker.get_current_model_count(use_models_page=use_models_page)
return count_data
except Exception as e:
+ logger.error(f"Error fetching model count: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error fetching model count: {str(e)}")
@@ -1343,7 +1593,7 @@ async def get_historical_model_counts(
try:
from datetime import datetime
- tracker = get_improved_tracker()
+ tracker = get_tracker()
start = None
end = None
@@ -1373,7 +1623,7 @@ async def get_historical_model_counts(
async def get_latest_model_count():
"""Get the most recently recorded model count from database."""
try:
- tracker = get_improved_tracker()
+ tracker = get_tracker()
latest = tracker.get_latest_count()
if latest is None:
raise HTTPException(status_code=404, detail="No model counts recorded yet")
@@ -1397,16 +1647,14 @@ async def record_model_count(
use_dataset_snapshot: Use dataset snapshot instead of API (faster, default: False)
"""
try:
- tracker = get_improved_tracker()
+ tracker = get_tracker()
- # Fetch and record in background to avoid blocking
def record():
if use_dataset_snapshot:
count_data = tracker.get_count_from_dataset_snapshot()
if count_data:
tracker.record_count(count_data, source="dataset_snapshot")
else:
- # Fallback to API
count_data = tracker.get_current_model_count(use_cache=False)
tracker.record_count(count_data, source="api")
else:
@@ -1433,7 +1681,7 @@ async def get_growth_stats(days: int = Query(7, ge=1, le=365)):
days: Number of days to analyze
"""
try:
- tracker = get_improved_tracker()
+ tracker = get_tracker()
stats = tracker.get_growth_stats(days)
return stats
except Exception as e:
@@ -1455,12 +1703,11 @@ async def export_network_graphml(
Similar to Open Syllabus graph export functionality.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
try:
network_builder = ModelNetworkBuilder(df)
- # Get top models by field
top_models = network_builder.get_top_models_by_field(
library=library,
pipeline_tag=pipeline_tag,
@@ -1473,29 +1720,24 @@ async def export_network_graphml(
raise HTTPException(status_code=404, detail="No models found matching criteria")
model_ids = [mid for mid, _ in top_models]
-
- # Build co-occurrence network
graph = network_builder.build_cooccurrence_network(
model_ids=model_ids,
cooccurrence_method=cooccurrence_method
)
- # Create temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.graphml', delete=False) as tmp_file:
tmp_path = tmp_file.name
network_builder.export_graphml(graph, tmp_path)
- # Schedule cleanup after response is sent
background_tasks.add_task(os.unlink, tmp_path)
- # Return file for download
return FileResponse(
tmp_path,
media_type='application/xml',
filename=f'network_{cooccurrence_method}_{n}_models.graphml'
)
-
- except Exception as e:
+ except (ValueError, KeyError, AttributeError, IOError) as e:
+ logger.error(f"Error exporting network: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error exporting network: {str(e)}")
@@ -1506,7 +1748,7 @@ async def get_model_papers(model_id: str):
Extracts arXiv IDs from model tags and fetches paper information.
"""
if df is None:
- raise HTTPException(status_code=503, detail="Data not loaded")
+ raise DataNotLoadedError()
model = df[df.get('model_id', '') == model_id]
if len(model) == 0:
@@ -1535,36 +1777,131 @@ async def get_model_papers(model_id: str):
}
+@app.get("/api/models/minimal.bin")
+async def get_minimal_binary():
+ """
+ Serve the binary minimal dataset file.
+ This is optimized for fast client-side loading.
+ """
+ backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ root_dir = os.path.dirname(backend_dir)
+ binary_path = os.path.join(root_dir, "cache", "binary", "embeddings.bin")
+
+ if not os.path.exists(binary_path):
+ raise HTTPException(status_code=404, detail="Binary dataset not found. Run export_binary.py first.")
+
+ return FileResponse(
+ binary_path,
+ media_type="application/octet-stream",
+ headers={
+ "Content-Disposition": "attachment; filename=embeddings.bin",
+ "Cache-Control": "public, max-age=3600"
+ }
+ )
+
+
+@app.get("/api/models/model_ids.json")
+async def get_model_ids_json():
+ """Serve the model IDs JSON file."""
+ backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ root_dir = os.path.dirname(backend_dir)
+ json_path = os.path.join(root_dir, "cache", "binary", "model_ids.json")
+
+ if not os.path.exists(json_path):
+ raise HTTPException(status_code=404, detail="Model IDs file not found.")
+
+ return FileResponse(
+ json_path,
+ media_type="application/json",
+ headers={"Cache-Control": "public, max-age=3600"}
+ )
+
+
+@app.get("/api/models/metadata.json")
+async def get_metadata_json():
+ """Serve the metadata JSON file with lookup tables."""
+ backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ root_dir = os.path.dirname(backend_dir)
+ json_path = os.path.join(root_dir, "cache", "binary", "metadata.json")
+
+ if not os.path.exists(json_path):
+ raise HTTPException(status_code=404, detail="Metadata file not found.")
+
+ return FileResponse(
+ json_path,
+ media_type="application/json",
+ headers={"Cache-Control": "public, max-age=3600"}
+ )
+
+
@app.get("/api/model/{model_id}/files")
async def get_model_files(model_id: str, branch: str = Query("main")):
"""
Get file tree for a model from Hugging Face.
Proxies the request to avoid CORS issues.
+ Returns a flat list of files with path and size information.
"""
+ if not model_id or not model_id.strip():
+ raise HTTPException(status_code=400, detail="Invalid model ID")
+
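+    # Try the requested branch first, then fall back to the other common default branch name.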
+ branches_to_try = [branch, "main", "master"] if branch not in ["main", "master"] else [branch, "main" if branch == "master" else "master"]
+
try:
- # Try main branch first, then master
- branches_to_try = [branch, "main", "master"] if branch not in ["main", "master"] else [branch, "main" if branch == "master" else "master"]
-
- async with httpx.AsyncClient(timeout=10.0) as client:
+ async with httpx.AsyncClient(timeout=15.0) as client:
for branch_name in branches_to_try:
try:
url = f"https://huggingface.co/api/models/{model_id}/tree/{branch_name}"
response = await client.get(url)
+
if response.status_code == 200:
- return response.json()
- except Exception:
+ data = response.json()
+ # Ensure we return an array
+ if isinstance(data, list):
+ return data
+ elif isinstance(data, dict) and 'tree' in data:
+ return data['tree']
+ else:
+ return []
+
+ elif response.status_code == 404:
+ # Try next branch
+ continue
+ else:
+ logger.warning(f"Unexpected status {response.status_code} for {url}")
+ continue
+
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code == 404:
+ continue # Try next branch
+ logger.warning(f"HTTP error for branch {branch_name}: {e}")
+ continue
+ except httpx.HTTPError as e:
+ logger.warning(f"HTTP error for branch {branch_name}: {e}")
continue
- raise HTTPException(status_code=404, detail="File tree not found for this model")
+ # All branches failed
+ raise HTTPException(
+ status_code=404,
+ detail=f"File tree not found for model '{model_id}'. The model may not exist or may not have any files."
+ )
+
except httpx.TimeoutException:
- raise HTTPException(status_code=504, detail="Request to Hugging Face timed out")
+ raise HTTPException(
+ status_code=504,
+ detail="Request to Hugging Face timed out. Please try again later."
+ )
+ except HTTPException:
+ raise # Re-raise HTTP exceptions
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Error fetching file tree: {str(e)}")
+ logger.error(f"Error fetching file tree: {e}", exc_info=True)
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error fetching file tree: {str(e)}"
+ )
if __name__ == "__main__":
import uvicorn
- # Use PORT environment variable for cloud platforms (Railway, Render, Heroku)
port = int(os.getenv("PORT", 8000))
uvicorn.run(app, host="0.0.0.0", port=port)
diff --git a/backend/api/routes/__init__.py b/backend/api/routes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3b07a869e15406215660fc24e61e49a1cb246ff
--- /dev/null
+++ b/backend/api/routes/__init__.py
@@ -0,0 +1,6 @@
+"""
+API route modules.
+"""
+from . import models, stats, clusters
+
+__all__ = ['models', 'stats', 'clusters']
diff --git a/backend/api/routes/clusters.py b/backend/api/routes/clusters.py
new file mode 100644
index 0000000000000000000000000000000000000000..a54280b8f2bb8aa7e80758d4d1a39db0e81e9a4d
--- /dev/null
+++ b/backend/api/routes/clusters.py
@@ -0,0 +1,102 @@
+"""
+API routes for cluster endpoints.
+"""
+from fastapi import APIRouter
+import numpy as np
+import pandas as pd
+from core.exceptions import DataNotLoadedError
+import api.dependencies as deps
+
+router = APIRouter(prefix="/api", tags=["clusters"])
+
+
+@router.get("/clusters")
+async def get_clusters():
+ """Get all clusters with metadata and hierarchical labels."""
+ if deps.df is None:
+ raise DataNotLoadedError()
+
+ # Import cluster_labels from models route
+ from api.routes.models import cluster_labels
+
+ # If clusters haven't been computed yet, return empty list instead of error
+ # This allows the frontend to work while data is still loading
+ if cluster_labels is None:
+ return {"clusters": []}
+
+ df = deps.df
+
+ # Generate hierarchical labels for clusters
+ clusters = []
+ unique_clusters = np.unique(cluster_labels)
+
+ for cluster_id in unique_clusters:
+ cluster_mask = cluster_labels == cluster_id
+ cluster_models = df[cluster_mask]
+
+ if len(cluster_models) == 0:
+ continue
+
+ # Generate hierarchical label
+ library_counts = cluster_models['library_name'].value_counts()
+ pipeline_counts = cluster_models['pipeline_tag'].value_counts()
+
+ # Determine primary domain/library
+ if len(library_counts) > 0:
+ primary_lib = library_counts.index[0]
+ if primary_lib and pd.notna(primary_lib):
+ if 'transformers' in str(primary_lib).lower():
+ domain = "NLP"
+ elif 'diffusers' in str(primary_lib).lower():
+ domain = "Multimodal"
+ elif 'timm' in str(primary_lib).lower():
+ domain = "Computer Vision"
+ else:
+ domain = str(primary_lib).replace('_', ' ').title()
+ else:
+ domain = "Other"
+ else:
+ domain = "Other"
+
+ # Determine subdomain from pipeline
+ if len(pipeline_counts) > 0:
+ primary_pipeline = pipeline_counts.index[0]
+ if primary_pipeline and pd.notna(primary_pipeline):
+ subdomain = str(primary_pipeline).replace('-', ' ').replace('_', ' ').title()
+ else:
+ subdomain = "General"
+ else:
+ subdomain = "General"
+
+ # Determine characteristics
+ characteristics = []
+ model_ids_lower = cluster_models['model_id'].astype(str).str.lower()
+ if model_ids_lower.str.contains('gpt', na=False).any():
+ characteristics.append("GPT-based")
+ if cluster_models['parent_model'].notna().any():
+ characteristics.append("Fine-tuned")
+ if not characteristics:
+ characteristics.append("Base Models")
+
+ char_str = "; ".join(characteristics)
+ label = f"{domain} — {subdomain} ({char_str})"
+
+ # Generate color (use consistent colors based on cluster_id)
+ colors = [
+ "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
+ "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
+ ]
+ color = colors[cluster_id % len(colors)]
+
+ clusters.append({
+ "cluster_id": int(cluster_id),
+ "cluster_label": label,
+ "count": int(len(cluster_models)),
+ "color": color
+ })
+
+ # Sort by count descending
+ clusters.sort(key=lambda x: x["count"], reverse=True)
+
+ return {"clusters": clusters}
+
diff --git a/backend/api/routes/models.py b/backend/api/routes/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..652553fa8ceea430d2f3b2fdeb735197d61d326e
--- /dev/null
+++ b/backend/api/routes/models.py
@@ -0,0 +1,247 @@
+"""
+API routes for model data endpoints.
+"""
+from typing import Optional
+from fastapi import APIRouter, Query, HTTPException
+import numpy as np
+import pandas as pd
+import pickle
+import os
+import logging
+
+from umap import UMAP
+from models.schemas import ModelPoint
+from utils.family_tree import calculate_family_depths
+from utils.dimensionality_reduction import DimensionReducer
+from core.exceptions import DataNotLoadedError, EmbeddingsNotReadyError
+import api.dependencies as deps
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/api", tags=["models"])
+
+# Global cluster labels cache (shared across routes)
+cluster_labels = None
+
+
+def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np.ndarray:
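+    """Cluster points with KMeans on the reduced embeddings; shrinks n_clusters for small samples."""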
+ from sklearn.cluster import KMeans
+
+ n_samples = len(reduced_embeddings)
+ if n_samples < n_clusters:
+ n_clusters = max(1, n_samples // 10)
+
+ kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
+ return kmeans.fit_predict(reduced_embeddings)
+
+
+@router.get("/models")
+async def get_models(
+ min_downloads: int = Query(0),
+ min_likes: int = Query(0),
+ search_query: Optional[str] = Query(None),
+ color_by: str = Query("library_name"),
+ size_by: str = Query("downloads"),
+ max_points: Optional[int] = Query(None),
+ projection_method: str = Query("umap"),
+ base_models_only: bool = Query(False),
+ max_hierarchy_depth: Optional[int] = Query(None, ge=0, description="Filter to models at or below this hierarchy depth."),
+ use_graph_embeddings: bool = Query(False, description="Use graph-aware embeddings that respect family tree structure")
+):
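+    """
+    Return filtered model points with 3D coordinates for client-side rendering.
+    Applies download/like/search filters, optional stratified sampling down to max_points,
+    and serves either text-only or graph-aware embeddings depending on the request.
+    """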
+ if deps.df is None:
+ raise DataNotLoadedError()
+
+ df = deps.df
+ data_loader = deps.data_loader
+
+ # Filter data
+ filtered_df = data_loader.filter_data(
+ df=df,
+ min_downloads=min_downloads,
+ min_likes=min_likes,
+ search_query=search_query,
+ libraries=None,
+ pipeline_tags=None
+ )
+
+ if base_models_only:
+ if 'parent_model' in filtered_df.columns:
+ filtered_df = filtered_df[
+ filtered_df['parent_model'].isna() |
+ (filtered_df['parent_model'].astype(str).str.strip() == '') |
+ (filtered_df['parent_model'].astype(str) == 'nan')
+ ]
+
+ if max_hierarchy_depth is not None:
+ family_depths = calculate_family_depths(df)
+ filtered_df = filtered_df[
+ filtered_df['model_id'].astype(str).map(lambda x: family_depths.get(x, 0) <= max_hierarchy_depth)
+ ]
+
+ filtered_count = len(filtered_df)
+
+ if len(filtered_df) == 0:
+ return {
+ "models": [],
+ "filtered_count": 0,
+ "returned_count": 0
+ }
+
+ if max_points is not None and len(filtered_df) > max_points:
+ if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
+ sampled_dfs = []
+ for lib_name, group in filtered_df.groupby('library_name', group_keys=False):
+ n_samples = max(1, int(max_points * len(group) / len(filtered_df)))
+ sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42))
+ filtered_df = pd.concat(sampled_dfs, ignore_index=True)
+ if len(filtered_df) > max_points:
+ filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
+ else:
+ filtered_df = filtered_df.reset_index(drop=True)
+ else:
+ filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
+
+ # Determine which embeddings to use
+ if use_graph_embeddings and deps.combined_embeddings is not None:
+ current_embeddings = deps.combined_embeddings
+ current_reduced = deps.reduced_embeddings_graph
+ embedding_type = "graph-aware"
+ else:
+ if deps.embeddings is None:
+ raise EmbeddingsNotReadyError()
+ current_embeddings = deps.embeddings
+ current_reduced = deps.reduced_embeddings
+ embedding_type = "text-only"
+
+ # Handle reduced embeddings loading/generation
+ reducer = deps.reducer
+ if current_reduced is None or (reducer and reducer.method != projection_method.lower()):
+ backend_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ root_dir = os.path.dirname(backend_dir)
+ cache_dir = os.path.join(root_dir, "cache")
+ cache_suffix = "_graph" if use_graph_embeddings and deps.combined_embeddings is not None else ""
+ reduced_cache = os.path.join(cache_dir, f"reduced_{projection_method.lower()}_3d{cache_suffix}.pkl")
+ reducer_cache = os.path.join(cache_dir, f"reducer_{projection_method.lower()}_3d{cache_suffix}.pkl")
+
+ if os.path.exists(reduced_cache) and os.path.exists(reducer_cache):
+ try:
+ with open(reduced_cache, 'rb') as f:
+ current_reduced = pickle.load(f)
+ if reducer is None or reducer.method != projection_method.lower():
+ reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
+ reducer.load_reducer(reducer_cache)
+ except (IOError, pickle.UnpicklingError, EOFError) as e:
+ logger.warning(f"Failed to load cached reduced embeddings: {e}")
+ current_reduced = None
+
+ if current_reduced is None:
+ if reducer is None or reducer.method != projection_method.lower():
+ reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
+ if projection_method.lower() == "umap":
+ reducer.reducer = UMAP(
+ n_components=3,
+ n_neighbors=30,
+ min_dist=0.3,
+ metric='cosine',
+ random_state=42,
+ n_jobs=-1,
+ low_memory=True,
+ spread=1.5
+ )
+ current_reduced = reducer.fit_transform(current_embeddings)
+ with open(reduced_cache, 'wb') as f:
+ pickle.dump(current_reduced, f)
+ reducer.save_reducer(reducer_cache)
+
+ # Update global variable
+ if use_graph_embeddings and deps.combined_embeddings is not None:
+ deps.reduced_embeddings_graph = current_reduced
+ else:
+ deps.reduced_embeddings = current_reduced
+
+ # Get indices for filtered data
+ filtered_model_ids = filtered_df['model_id'].astype(str).values
+
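+    # df may be indexed by model_id (positional lookup via get_loc) or keep model_id as a
+    # plain column (dict lookup); both branches produce row positions into current_reduced.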
+ if df.index.name == 'model_id' or 'model_id' in df.index.names:
+ filtered_indices = []
+ for model_id in filtered_model_ids:
+ try:
+ pos = df.index.get_loc(model_id)
+ if isinstance(pos, (int, np.integer)):
+ filtered_indices.append(int(pos))
+ elif isinstance(pos, (slice, np.ndarray)):
+ if isinstance(pos, slice):
+ filtered_indices.append(int(pos.start))
+ else:
+ filtered_indices.append(int(pos[0]))
+ except (KeyError, TypeError):
+ continue
+ filtered_indices = np.array(filtered_indices, dtype=np.int32)
+ else:
+ df_model_ids = df['model_id'].astype(str).values
+ model_id_to_pos = {mid: pos for pos, mid in enumerate(df_model_ids)}
+ filtered_indices = np.array([
+ model_id_to_pos[mid] for mid in filtered_model_ids
+ if mid in model_id_to_pos
+ ], dtype=np.int32)
+
+ if len(filtered_indices) == 0:
+ return {
+ "models": [],
+ "embedding_type": embedding_type,
+ "filtered_count": filtered_count,
+ "returned_count": 0
+ }
+
+ filtered_reduced = current_reduced[filtered_indices]
+ family_depths = calculate_family_depths(df)
+
+ global cluster_labels
+ clustering_embeddings = current_reduced
+ if cluster_labels is None or len(cluster_labels) != len(clustering_embeddings):
+        cluster_labels = compute_clusters(clustering_embeddings, n_clusters=max(1, min(50, len(clustering_embeddings) // 100)))
+
+ filtered_clusters = cluster_labels[filtered_indices]
+
+ model_ids = filtered_df['model_id'].astype(str).values
+ library_names = filtered_df.get('library_name', pd.Series([None] * len(filtered_df))).values
+ pipeline_tags = filtered_df.get('pipeline_tag', pd.Series([None] * len(filtered_df))).values
+ downloads_arr = filtered_df.get('downloads', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values
+ likes_arr = filtered_df.get('likes', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values
+ trending_scores = filtered_df.get('trendingScore', pd.Series([None] * len(filtered_df))).values
+ tags_arr = filtered_df.get('tags', pd.Series([None] * len(filtered_df))).values
+ parent_models = filtered_df.get('parent_model', pd.Series([None] * len(filtered_df))).values
+ licenses_arr = filtered_df.get('licenses', pd.Series([None] * len(filtered_df))).values
+ created_at_arr = filtered_df.get('createdAt', pd.Series([None] * len(filtered_df))).values
+
+ x_coords = filtered_reduced[:, 0].astype(float)
+ y_coords = filtered_reduced[:, 1].astype(float)
+ z_coords = filtered_reduced[:, 2].astype(float) if filtered_reduced.shape[1] > 2 else np.zeros(len(filtered_reduced), dtype=float)
+ models = [
+ ModelPoint(
+ model_id=model_ids[idx],
+ x=float(x_coords[idx]),
+ y=float(y_coords[idx]),
+ z=float(z_coords[idx]),
+ library_name=library_names[idx] if pd.notna(library_names[idx]) else None,
+ pipeline_tag=pipeline_tags[idx] if pd.notna(pipeline_tags[idx]) else None,
+ downloads=int(downloads_arr[idx]),
+ likes=int(likes_arr[idx]),
+ trending_score=float(trending_scores[idx]) if idx < len(trending_scores) and pd.notna(trending_scores[idx]) else None,
+ tags=tags_arr[idx] if idx < len(tags_arr) and pd.notna(tags_arr[idx]) else None,
+ parent_model=parent_models[idx] if idx < len(parent_models) and pd.notna(parent_models[idx]) else None,
+ licenses=licenses_arr[idx] if idx < len(licenses_arr) and pd.notna(licenses_arr[idx]) else None,
+ family_depth=family_depths.get(model_ids[idx], None),
+ cluster_id=int(filtered_clusters[idx]) if idx < len(filtered_clusters) else None,
+ created_at=str(created_at_arr[idx]) if idx < len(created_at_arr) and pd.notna(created_at_arr[idx]) else None
+ )
+ for idx in range(len(filtered_df))
+ ]
+
+ return {
+ "models": models,
+ "embedding_type": embedding_type,
+ "filtered_count": filtered_count,
+ "returned_count": len(models)
+ }
+
diff --git a/backend/api/routes/stats.py b/backend/api/routes/stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..860cf3ce4bff0f6f09a507de89bf4286bbdfc353
--- /dev/null
+++ b/backend/api/routes/stats.py
@@ -0,0 +1,37 @@
+"""
+API routes for statistics endpoints.
+"""
+from fastapi import APIRouter
+from core.exceptions import DataNotLoadedError
+import api.dependencies as deps
+
+router = APIRouter(prefix="/api", tags=["stats"])
+
+
+@router.get("/stats")
+async def get_stats():
+ """Get dataset statistics."""
+ if deps.df is None:
+ raise DataNotLoadedError()
+
+ df = deps.df
+ total_models = len(df.index) if hasattr(df, 'index') else len(df)
+
+ # Get unique licenses with counts
+ licenses = {}
+ if 'license' in df.columns:
+ import pandas as pd
+ license_counts = df['license'].value_counts().to_dict()
+ licenses = {str(k): int(v) for k, v in license_counts.items() if pd.notna(k) and str(k) != 'nan'}
+
+ return {
+ "total_models": total_models,
+ "unique_libraries": int(df['library_name'].nunique()) if 'library_name' in df.columns else 0,
+ "unique_pipelines": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0,
+ "unique_task_types": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0,
+ "unique_licenses": len(licenses),
+ "licenses": licenses,
+ "avg_downloads": float(df['downloads'].mean()) if 'downloads' in df.columns else 0,
+ "avg_likes": float(df['likes'].mean()) if 'likes' in df.columns else 0
+ }
+
diff --git a/backend/config/requirements.txt b/backend/config/requirements.txt
index e3fc460ac083eeb1f04d3a75dd2197b054ceba11..8eaf5280322ea510e16ac3585888adff4098f9c3 100644
--- a/backend/config/requirements.txt
+++ b/backend/config/requirements.txt
@@ -11,5 +11,6 @@ huggingface-hub>=0.17.0
schedule>=1.2.0
tqdm>=4.66.0
networkx>=3.0
+node2vec>=0.4.6
httpx>=0.24.0
diff --git a/backend/core/__init__.py b/backend/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e62adb7dc738fd84ba2f2b6464adff697c3fd40
--- /dev/null
+++ b/backend/core/__init__.py
@@ -0,0 +1,2 @@
+"""Core configuration and utilities."""
+
diff --git a/backend/core/config.py b/backend/core/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d6be5ca4796f84263b138e672be392078791bac
--- /dev/null
+++ b/backend/core/config.py
@@ -0,0 +1,23 @@
+"""Configuration management."""
+import os
+from typing import Optional
+
+class Settings:
+ """Application settings."""
+ FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
+ ALLOW_ALL_ORIGINS: bool = os.getenv("ALLOW_ALL_ORIGINS", "True").lower() in ("true", "1", "yes")
+ SAMPLE_SIZE: Optional[int] = None
+ USE_GRAPH_EMBEDDINGS: bool = os.getenv("USE_GRAPH_EMBEDDINGS", "false").lower() == "true"
+ PORT: int = int(os.getenv("PORT", 8000))
+
+ @classmethod
+ def get_sample_size(cls) -> Optional[int]:
+ """Get sample size from environment."""
+ sample_size_env = os.getenv("SAMPLE_SIZE")
+ if sample_size_env:
+ sample_size_val = int(sample_size_env)
+ return sample_size_val if sample_size_val > 0 else None
+ return None
+
+settings = Settings()
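+# Consumers are expected to import this singleton, e.g. `from core.config import settings`.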
+
diff --git a/backend/core/exceptions.py b/backend/core/exceptions.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27af9dc314e746f2046d8a86054be76a7858fab
--- /dev/null
+++ b/backend/core/exceptions.py
@@ -0,0 +1,18 @@
+"""Custom exceptions."""
+from fastapi import HTTPException
+
+class ModelNotFoundError(HTTPException):
+ """Model not found exception."""
+ def __init__(self, model_id: str):
+ super().__init__(status_code=404, detail=f"Model not found: {model_id}")
+
+class DataNotLoadedError(HTTPException):
+ """Data not loaded exception."""
+ def __init__(self):
+ super().__init__(status_code=503, detail="Data not loaded")
+
+class EmbeddingsNotReadyError(HTTPException):
+ """Embeddings not ready exception."""
+ def __init__(self):
+ super().__init__(status_code=503, detail="Embeddings not ready")
+
diff --git a/backend/models/__init__.py b/backend/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef80f4c2c943e8b1049df1a8e5bc7092d1e34117
--- /dev/null
+++ b/backend/models/__init__.py
@@ -0,0 +1,2 @@
+"""Data models and schemas."""
+
diff --git a/backend/models/schemas.py b/backend/models/schemas.py
new file mode 100644
index 0000000000000000000000000000000000000000..60dec7be36259a9a470ab3a73d6e540c6c69d740
--- /dev/null
+++ b/backend/models/schemas.py
@@ -0,0 +1,22 @@
+"""Pydantic models for API."""
+from pydantic import BaseModel
+from typing import Optional
+
+class ModelPoint(BaseModel):
+ """Model point in 3D space."""
+ model_id: str
+ x: float
+ y: float
+ z: float
+ library_name: Optional[str]
+ pipeline_tag: Optional[str]
+ downloads: int
+ likes: int
+ trending_score: Optional[float]
+ tags: Optional[str]
+ parent_model: Optional[str] = None
+ licenses: Optional[str] = None
+ family_depth: Optional[int] = None
+ cluster_id: Optional[int] = None
+ created_at: Optional[str] = None # ISO format date string
+
diff --git a/backend/scripts/export_binary.py b/backend/scripts/export_binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff710fa9b8b3456f1f27fb9812680c14a8504fcc
--- /dev/null
+++ b/backend/scripts/export_binary.py
@@ -0,0 +1,263 @@
+"""
+Export minimal dataset to binary format for fast client-side loading.
+This creates a compact binary representation optimized for WebGL rendering.
+"""
+import struct
+import json
+import numpy as np
+import pandas as pd
+from pathlib import Path
+import sys
+import os
+
+# Add parent directory to path
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from utils.data_loader import ModelDataLoader
+from utils.dimensionality_reduction import DimensionReducer
+from utils.embeddings import ModelEmbedder
+
+
+def calculate_family_depths(df: pd.DataFrame) -> dict:
+ """Calculate depth of each model in its family tree."""
+ depths = {}
+
+ def get_depth(model_id: str, visited: set = None) -> int:
+ if visited is None:
+ visited = set()
+ if model_id in visited:
+ return 0 # Cycle detected
+ visited.add(model_id)
+
+ if model_id in depths:
+ return depths[model_id]
+
+ parent_col = df.get('parent_model', pd.Series([None] * len(df), index=df.index))
+ model_row = df[df['model_id'] == model_id]
+
+ if model_row.empty:
+ depths[model_id] = 0
+ return 0
+
+ parent = model_row.iloc[0].get('parent_model')
+ if pd.isna(parent) or parent == '' or str(parent) == 'nan':
+ depths[model_id] = 0
+ return 0
+
+ parent_depth = get_depth(str(parent), visited.copy())
+ depth = parent_depth + 1
+ depths[model_id] = depth
+ return depth
+
+ for model_id in df['model_id'].unique():
+ if model_id not in depths:
+ get_depth(str(model_id))
+
+ return depths
+
+
+def export_binary_dataset(df: pd.DataFrame, reduced_embeddings: np.ndarray, output_dir: Path):
+ """
+ Export minimal dataset to binary format for fast client-side loading.
+
+ Binary format:
+ - Header (64 bytes): magic, version, counts, lookup table sizes
+ - Domain lookup table (32 bytes per domain)
+ - License lookup table (32 bytes per license)
+    - Model records (17 bytes each, little-endian): x, y, z, domain_id, license_id, family_id, flags
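+
+    A matching reader (a sketch; this mirrors the writer below) can unpack the header with
+    struct.unpack('<5sBIIIIBBH38s', buf[:64]), skip 32 * (num_domains + num_licenses) bytes
+    of lookup tables, and then iterate records with struct.iter_unpack('<fffBBHB', rest).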
+ """
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ print(f"Exporting {len(df)} models to binary format...")
+
+ # Ensure we have coordinates
+ if 'x' not in df.columns or 'y' not in df.columns:
+ if reduced_embeddings is None or len(reduced_embeddings) != len(df):
+ raise ValueError("Need reduced embeddings to generate coordinates")
+
+ df['x'] = reduced_embeddings[:, 0] if reduced_embeddings.shape[1] > 0 else 0.0
+ df['y'] = reduced_embeddings[:, 1] if reduced_embeddings.shape[1] > 1 else 0.0
+ df['z'] = reduced_embeddings[:, 2] if reduced_embeddings.shape[1] > 2 else 0.0
+
+ # Create lookup tables
+ # Domain = library_name
+ domains = sorted(df['library_name'].dropna().astype(str).unique())
+ domains = [d for d in domains if d and d != 'nan'][:255] # Limit to 255
+
+ # License
+ licenses = sorted(df['license'].dropna().astype(str).unique())
+ licenses = [l for l in licenses if l and l != 'nan'][:255] # Limit to 255
+
+ # Family ID mapping (use parent_model to create family groups)
+ family_depths = calculate_family_depths(df)
+
+ # Create family mapping: group models by root parent
+ def get_root_parent(model_id: str) -> str:
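+        """Walk parent_model links upward to the family root; `visited` guards against cycles."""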
+ visited = set()
+ current = str(model_id)
+        while current not in visited:
+ visited.add(current)
+ model_row = df[df['model_id'] == current]
+ if model_row.empty:
+ return current
+ parent = model_row.iloc[0].get('parent_model')
+ if pd.isna(parent) or parent == '' or str(parent) == 'nan':
+ return current
+ current = str(parent)
+ return current
+
+ root_parents = {}
+ family_counter = 0
+ for model_id in df['model_id'].unique():
+ root = get_root_parent(str(model_id))
+ if root not in root_parents:
+ root_parents[root] = family_counter
+ family_counter += 1
+
+ # Map each model to its family
+ model_to_family = {}
+ for model_id in df['model_id'].unique():
+ root = get_root_parent(str(model_id))
+ model_to_family[str(model_id)] = root_parents.get(root, 65535)
+
+ # Limit families to 65535 (u16 max)
+ if len(root_parents) > 65535:
+ # Use hash-based family IDs
+ import hashlib
+ for model_id in df['model_id'].unique():
+ root = get_root_parent(str(model_id))
+ family_hash = int(hashlib.md5(root.encode()).hexdigest()[:4], 16) % 65535
+ model_to_family[str(model_id)] = family_hash
+
+ # Prepare model records
+ records = []
+ model_ids = []
+
+ for idx, row in df.iterrows():
+ model_id = str(row['model_id'])
+ model_ids.append(model_id)
+
+ # Get coordinates
+ x = float(row.get('x', 0.0))
+ y = float(row.get('y', 0.0))
+ z = float(row.get('z', 0.0))
+
+ # Encode domain (library_name)
+ domain_str = str(row.get('library_name', ''))
+ domain_id = domains.index(domain_str) if domain_str in domains else 255
+
+ # Encode license
+ license_str = str(row.get('license', ''))
+ license_id = licenses.index(license_str) if license_str in licenses else 255
+
+ # Encode family
+ family_id = model_to_family.get(model_id, 65535)
+
+ # Encode flags
+ flags = 0
+ parent = row.get('parent_model')
+ if pd.isna(parent) or parent == '' or str(parent) == 'nan':
+ flags |= 0x01 # is_base_model
+
+ # Check if has children (simple check - could be improved)
+ children = df[df['parent_model'] == model_id]
+ if len(children) > 0:
+ flags |= 0x04 # has_children
+ elif not pd.isna(parent) and parent != '' and str(parent) != 'nan':
+ flags |= 0x02 # has_parent
+
+        # Pack record (17 bytes, little-endian): f32 x, f32 y, f32 z, u8 domain, u8 license, u16 family, u8 flags
+        records.append(struct.pack('<fffBBHB', x, y, z, domain_id, license_id, family_id, flags))
+
+ num_models = len(records)
+
+ # Write binary file
+ with open(output_dir / 'embeddings.bin', 'wb') as f:
+        # Header (exactly 64 bytes): magic(5) + version(1) + four u32 counts(16) + reserved(4) + padding(38)
+        header = struct.pack('<5sBIIIIBBH38s',
+                             b'HFVIZ',                            # magic (5 bytes)
+                             1,                                   # version (1 byte)
+                             num_models,                          # num_models (4 bytes)
+                             len(domains),                        # num_domains (4 bytes)
+                             len(licenses),                       # num_licenses (4 bytes)
+                             len(set(model_to_family.values())),  # num_families (4 bytes)
+                             0,                                   # reserved (1 byte)
+                             0,                                   # reserved (1 byte)
+                             0,                                   # reserved (2 bytes)
+                             b'\x00' * 38)                        # padding (38 bytes)
+ f.write(header)
+
+ # Domain lookup table (32 bytes per domain, null-terminated)
+ for domain in domains:
+ domain_bytes = domain.encode('utf-8')[:31]
+ f.write(domain_bytes.ljust(32, b'\x00'))
+
+ # License lookup table (32 bytes per license)
+ for license in licenses:
+ license_bytes = license.encode('utf-8')[:31]
+ f.write(license_bytes.ljust(32, b'\x00'))
+
+ # Model records
+ f.write(b''.join(records))
+
+ # Write model IDs JSON (separate file for string table)
+ with open(output_dir / 'model_ids.json', 'w') as f:
+ json.dump(model_ids, f)
+
+ # Write metadata JSON
+ metadata = {
+ 'domains': domains,
+ 'licenses': licenses,
+ 'num_models': num_models,
+ 'num_families': len(set(model_to_family.values())),
+ 'version': 1
+ }
+ with open(output_dir / 'metadata.json', 'w') as f:
+ json.dump(metadata, f, indent=2)
+
+ binary_size = (output_dir / 'embeddings.bin').stat().st_size
+ json_size = (output_dir / 'model_ids.json').stat().st_size
+
+ print(f"✓ Exported {num_models} models")
+ print(f"✓ Binary size: {binary_size / 1024 / 1024:.2f} MB")
+ print(f"✓ Model IDs JSON: {json_size / 1024 / 1024:.2f} MB")
+ print(f"✓ Total: {(binary_size + json_size) / 1024 / 1024:.2f} MB")
+ print(f"✓ Domains: {len(domains)}")
+ print(f"✓ Licenses: {len(licenses)}")
+ print(f"✓ Families: {len(set(model_to_family.values()))}")
+
+
+if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser(description='Export dataset to binary format')
+ parser.add_argument('--output', type=str, default='backend/cache/binary', help='Output directory')
+ parser.add_argument('--sample-size', type=int, default=None, help='Sample size (for testing)')
+ args = parser.parse_args()
+
+ output_dir = Path(args.output)
+
+ # Load data
+ print("Loading dataset...")
+ data_loader = ModelDataLoader()
+ df = data_loader.load_data(sample_size=args.sample_size)
+ df = data_loader.preprocess_for_embedding(df)
+
+ # Generate embeddings and reduce dimensions if needed
+ if 'x' not in df.columns or 'y' not in df.columns:
+ print("Generating embeddings...")
+ embedder = ModelEmbedder()
+ embeddings = embedder.generate_embeddings(df['combined_text'].tolist())
+
+ print("Reducing dimensions...")
+ reducer = DimensionReducer()
+ reduced_embeddings = reducer.reduce_dimensions(embeddings, n_components=3, method='umap')
+ else:
+ reduced_embeddings = None
+
+ # Export
+ export_binary_dataset(df, reduced_embeddings, output_dir)
+ print("Done!")
+
diff --git a/backend/services/model_tracker.py b/backend/services/model_tracker.py
index eb93d3f419492b53d9129237598f0aac88307858..092e59eaaa3d46594ed55085a229c593ec9826b4 100644
--- a/backend/services/model_tracker.py
+++ b/backend/services/model_tracker.py
@@ -5,11 +5,16 @@ Tracks the number of models over time and provides historical data.
import os
import json
import sqlite3
+import logging
+import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from huggingface_hub import HfApi
import pandas as pd
from pathlib import Path
+import httpx
+
+logger = logging.getLogger(__name__)
class ModelCountTracker:
@@ -34,7 +39,6 @@ class ModelCountTracker:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- # Create table for model counts
cursor.execute("""
CREATE TABLE IF NOT EXISTS model_counts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -47,7 +51,6 @@ class ModelCountTracker:
)
""")
- # Create index for faster queries
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_timestamp
ON model_counts(timestamp)
@@ -56,27 +59,90 @@ class ModelCountTracker:
conn.commit()
conn.close()
- def get_current_model_count(self) -> Dict:
+ def get_count_from_models_page(self) -> Optional[Dict]:
"""
- Fetch current model count from Hugging Face Hub API.
- Uses efficient pagination to get accurate count.
+ Get model count by scraping the Hugging Face models page.
+ Extracts count from the div with class "font-normal text-gray-400" on https://huggingface.co/models
+ or from window.__hf_deferred["numTotalItems"] in the page script.
Returns:
- Dictionary with total count and breakdowns
+ Dictionary with total_models count, or None if extraction fails
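+
+            A successful result looks like (illustrative values only):
+                {"total_models": 2249310, "timestamp": "<ISO timestamp>", "source": "hf_models_page",
+                 "models_by_library": {}, "models_by_pipeline": {}, "models_by_author": {}}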
"""
try:
- # Use pagination to efficiently count models
- # The API returns paginated results, so we iterate through pages
- # For large counts, we sample and extrapolate for speed
+ url = "https://huggingface.co/models"
+ response = httpx.get(url, timeout=10.0, follow_redirects=True)
+ response.raise_for_status()
+
+ html_content = response.text
+
+ deferred_pattern = r'window\.__hf_deferred\["numTotalItems"\]\s*=\s*(\d+);'
+ deferred_matches = re.findall(deferred_pattern, html_content)
+
+ if deferred_matches:
+ total_models = int(deferred_matches[0])
+ logger.info(f"Extracted model count from window.__hf_deferred: {total_models}")
+
+ return {
+ "total_models": total_models,
+ "timestamp": datetime.utcnow().isoformat(),
+ "source": "hf_models_page",
+ "models_by_library": {},
+ "models_by_pipeline": {},
+ "models_by_author": {}
+ }
+            pattern = r'<div[^>]*class="[^"]*font-normal[^"]*text-gray-400[^"]*"[^>]*>([\d,]+)</div>'
+ matches = re.findall(pattern, html_content)
+
+ if matches:
+ count_str = matches[0].replace(',', '')
+ total_models = int(count_str)
+
+ logger.info(f"Extracted model count from div: {total_models}")
+
+ return {
+ "total_models": total_models,
+ "timestamp": datetime.utcnow().isoformat(),
+ "source": "hf_models_page",
+ "models_by_library": {},
+ "models_by_pipeline": {},
+ "models_by_author": {}
+ }
+
+ logger.warning("Could not find model count in HF models page HTML")
+ return None
+
+ except httpx.HTTPError as e:
+ logger.error(f"HTTP error fetching HF models page: {e}", exc_info=True)
+ return None
+ except Exception as e:
+ logger.error(f"Error extracting count from HF models page: {e}", exc_info=True)
+ return None
+
+ def get_current_model_count(self, use_models_page: bool = True) -> Dict:
+ """
+ Fetch current model count from Hugging Face Hub.
+ Uses multiple strategies: models page scraping (fastest), then API enumeration.
+
+ Args:
+ use_models_page: Try to get count from HF models page first (default: True)
+
+ Returns:
+ Dictionary with total count and breakdowns
+ """
+ if use_models_page:
+ page_count = self.get_count_from_models_page()
+ if page_count:
+ return page_count
+
+ try:
total_count = 0
library_counts = {}
pipeline_counts = {}
- page_size = 1000 # Process in batches
- max_pages = 100 # Limit to prevent timeout (can adjust)
- sample_size = 10000 # Sample size for breakdowns
+ page_size = 1000
+ max_pages = 100
+ sample_size = 10000
- # Count total models efficiently
models_iter = self.api.list_models(full=False)
sampled_models = []
@@ -87,25 +153,18 @@ class ModelCountTracker:
if i < sample_size:
sampled_models.append(model)
- # Safety limit to prevent infinite loops
if i >= max_pages * page_size:
- # If we hit the limit, estimate total from sample
- # This is a rough estimate - for exact count, increase max_pages
break
- # Calculate breakdowns from sample (extrapolate if needed)
for model in sampled_models:
- # Count by library
if hasattr(model, 'library_name') and model.library_name:
lib = model.library_name
library_counts[lib] = library_counts.get(lib, 0) + 1
- # Count by pipeline
if hasattr(model, 'pipeline_tag') and model.pipeline_tag:
pipeline = model.pipeline_tag
pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
- # If we sampled, scale up the breakdowns proportionally
if len(sampled_models) < total_count and len(sampled_models) > 0:
scale_factor = total_count / len(sampled_models)
library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
@@ -118,7 +177,7 @@ class ModelCountTracker:
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
- print(f"Error fetching model count: {e}")
+ logger.error(f"Error fetching model count: {e}", exc_info=True)
return {
"total_models": 0,
"models_by_library": {},
@@ -162,7 +221,7 @@ class ModelCountTracker:
conn.close()
return True
except Exception as e:
- print(f"Error recording count: {e}")
+ logger.error(f"Error recording count: {e}", exc_info=True)
return False
def get_historical_counts(
@@ -211,7 +270,7 @@ class ModelCountTracker:
conn.close()
return results
except Exception as e:
- print(f"Error fetching historical counts: {e}")
+ logger.error(f"Error fetching historical counts: {e}", exc_info=True)
return []
def get_latest_count(self) -> Optional[Dict]:
@@ -239,7 +298,7 @@ class ModelCountTracker:
}
return None
except Exception as e:
- print(f"Error fetching latest count: {e}")
+ logger.error(f"Error fetching latest count: {e}", exc_info=True)
return None
def get_growth_stats(self, days: int = 7) -> Dict:
diff --git a/backend/services/model_tracker_improved.py b/backend/services/model_tracker_improved.py
index 3264597131f2039cddffd74e89346e9310f09959..685504c1abf7ae2b27f336488862659cfdf19f24 100644
--- a/backend/services/model_tracker_improved.py
+++ b/backend/services/model_tracker_improved.py
@@ -11,12 +11,17 @@ Key improvements:
import os
import json
import sqlite3
+import logging
+import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from huggingface_hub import HfApi
import pandas as pd
from pathlib import Path
import time
+import httpx
+
+logger = logging.getLogger(__name__)
class ImprovedModelCountTracker:
@@ -78,72 +83,73 @@ class ImprovedModelCountTracker:
elapsed = (datetime.utcnow() - self._cache_timestamp).total_seconds()
return elapsed < self.cache_ttl
- def get_current_model_count(self, use_cache: bool = True, force_refresh: bool = False) -> Dict:
+ def get_current_model_count(self, use_cache: bool = True, force_refresh: bool = False, use_models_page: bool = True) -> Dict:
"""
- Fetch current model count from Hugging Face Hub API.
- Uses caching and efficient sampling strategies.
+ Fetch current model count from Hugging Face Hub.
+ Uses multiple strategies: models page scraping (fastest), API, or dataset snapshot.
Args:
use_cache: Whether to use cached results if available
force_refresh: Force refresh even if cache is valid
+ use_models_page: Try to get count from HF models page first (default: True)
Returns:
Dictionary with total count and breakdowns
"""
- # Check cache first
if use_cache and not force_refresh and self._is_cache_valid():
return self._cache
+ if use_models_page:
+ page_count = self.get_count_from_models_page()
+ if page_count:
+ dataset_count = self.get_count_from_dataset_snapshot()
+ if dataset_count and dataset_count.get("models_by_library"):
+ page_count["models_by_library"] = dataset_count.get("models_by_library", {})
+ page_count["models_by_pipeline"] = dataset_count.get("models_by_pipeline", {})
+ page_count["models_by_author"] = dataset_count.get("models_by_author", {})
+
+ self._cache = page_count
+ self._cache_timestamp = datetime.utcnow()
+ return page_count
+
try:
- # Strategy 1: Try to get count efficiently using pagination
- # The HfApi.list_models() returns an iterator, so we can count efficiently
total_count = 0
library_counts = {}
pipeline_counts = {}
author_counts = {}
- # For breakdowns, we sample a subset for efficiency
- sample_size = 20000 # Sample 20K models for breakdowns
- max_count_for_full_breakdown = 50000 # If less than this, do full breakdown
+ sample_size = 20000
+ max_count_for_full_breakdown = 50000
models_iter = self.api.list_models(full=False, sort="created", direction=-1)
sampled_models = []
start_time = time.time()
- timeout_seconds = 30 # Don't spend more than 30 seconds
+ timeout_seconds = 30
for i, model in enumerate(models_iter):
- # Check timeout
if time.time() - start_time > timeout_seconds:
- # If we hit timeout, use sampling strategy
break
total_count += 1
- # Sample models for breakdowns
if i < sample_size:
sampled_models.append(model)
- # For smaller datasets, we can do full breakdown
if total_count < max_count_for_full_breakdown:
- # Count by library
if hasattr(model, 'library_name') and model.library_name:
lib = model.library_name
library_counts[lib] = library_counts.get(lib, 0) + 1
- # Count by pipeline
if hasattr(model, 'pipeline_tag') and model.pipeline_tag:
pipeline = model.pipeline_tag
pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
- # Count by author (extract from model_id)
if hasattr(model, 'id') and model.id:
author = model.id.split('/')[0] if '/' in model.id else 'unknown'
author_counts[author] = author_counts.get(author, 0) + 1
- # If we sampled, calculate breakdowns from sample and extrapolate
if total_count > len(sampled_models) and len(sampled_models) > 0:
- # Calculate breakdowns from sample
for model in sampled_models:
if hasattr(model, 'library_name') and model.library_name:
lib = model.library_name
@@ -157,7 +163,6 @@ class ImprovedModelCountTracker:
author = model.id.split('/')[0] if '/' in model.id else 'unknown'
author_counts[author] = author_counts.get(author, 0) + 1
- # Scale up breakdowns proportionally
if len(sampled_models) > 0:
scale_factor = total_count / len(sampled_models)
library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
@@ -168,20 +173,19 @@ class ImprovedModelCountTracker:
"total_models": total_count,
"models_by_library": library_counts,
"models_by_pipeline": pipeline_counts,
- "models_by_author": dict(sorted(author_counts.items(), key=lambda x: x[1], reverse=True)[:20]), # Top 20 authors
+ "models_by_author": dict(sorted(author_counts.items(), key=lambda x: x[1], reverse=True)[:20]),
"timestamp": datetime.utcnow().isoformat(),
"sampling_used": total_count > len(sampled_models) if sampled_models else False,
"sample_size": len(sampled_models) if sampled_models else total_count
}
- # Update cache
self._cache = result
self._cache_timestamp = datetime.utcnow()
return result
except Exception as e:
- print(f"Error fetching model count: {e}")
+ logger.error(f"Error fetching model count: {e}", exc_info=True)
return {
"total_models": 0,
"models_by_library": {},
@@ -191,6 +195,70 @@ class ImprovedModelCountTracker:
"error": str(e)
}
+ def get_count_from_models_page(self) -> Optional[Dict]:
+ """
+ Get model count by scraping the Hugging Face models page.
+ Extracts count from the div with class "font-normal text-gray-400" on https://huggingface.co/models
+
+ Returns:
+ Dictionary with total_models count, or None if extraction fails
+ """
+ try:
+ url = "https://huggingface.co/models"
+ response = httpx.get(url, timeout=10.0, follow_redirects=True)
+ response.raise_for_status()
+
+ html_content = response.text
+
+            # Look for a pattern like: <div class="font-normal text-gray-400">2,249,310</div>
+            # The number is rendered with thousands separators (commas)
+            pattern = r'<div[^>]*class="[^"]*font-normal[^"]*text-gray-400[^"]*"[^>]*>([\d,]+)</div>'
+ matches = re.findall(pattern, html_content)
+
+ if matches:
+ # Take the first match and remove commas
+ count_str = matches[0].replace(',', '')
+ total_models = int(count_str)
+
+ logger.info(f"Extracted model count from HF models page: {total_models}")
+
+ return {
+ "total_models": total_models,
+ "timestamp": datetime.utcnow().isoformat(),
+ "source": "hf_models_page",
+ "models_by_library": {},
+ "models_by_pipeline": {},
+ "models_by_author": {}
+ }
+ else:
+ # Fallback: try to find the number in the window.__hf_deferred object
+ # The page has: window.__hf_deferred["numTotalItems"] = 2249312;
+ deferred_pattern = r'window\.__hf_deferred\["numTotalItems"\]\s*=\s*(\d+);'
+ deferred_matches = re.findall(deferred_pattern, html_content)
+
+ if deferred_matches:
+ total_models = int(deferred_matches[0])
+ logger.info(f"Extracted model count from window.__hf_deferred: {total_models}")
+
+ return {
+ "total_models": total_models,
+ "timestamp": datetime.utcnow().isoformat(),
+ "source": "hf_models_page_deferred",
+ "models_by_library": {},
+ "models_by_pipeline": {},
+ "models_by_author": {}
+ }
+
+ logger.warning("Could not find model count in HF models page HTML")
+ return None
+
+ except httpx.HTTPError as e:
+ logger.error(f"HTTP error fetching HF models page: {e}", exc_info=True)
+ return None
+ except Exception as e:
+ logger.error(f"Error extracting count from HF models page: {e}", exc_info=True)
+ return None
+
def get_count_from_dataset_snapshot(self, dataset_name: str = "modelbiome/ai_ecosystem_withmodelcards") -> Optional[Dict]:
"""
Alternative method: Get count from dataset snapshot (like ai-ecosystem repo does).
@@ -205,11 +273,9 @@ class ImprovedModelCountTracker:
try:
from datasets import load_dataset
- # Load just metadata to get count quickly
dataset = load_dataset(dataset_name, split="train")
total_count = len(dataset)
- # Sample for breakdowns
sample_size = min(10000, total_count)
sample = dataset.shuffle(seed=42).select(range(sample_size))
@@ -225,7 +291,6 @@ class ImprovedModelCountTracker:
pipeline = item['pipeline_tag']
pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
- # Scale up
if sample_size < total_count:
scale_factor = total_count / sample_size
library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
@@ -239,7 +304,7 @@ class ImprovedModelCountTracker:
"source": "dataset_snapshot"
}
except Exception as e:
- print(f"Error loading from dataset snapshot: {e}")
+ logger.error(f"Error loading from dataset snapshot: {e}", exc_info=True)
return None
def record_count(self, count_data: Optional[Dict] = None, source: str = "api") -> bool:
@@ -279,7 +344,7 @@ class ImprovedModelCountTracker:
conn.close()
return True
except Exception as e:
- print(f"Error recording count: {e}")
+ logger.error(f"Error recording count: {e}", exc_info=True)
return False
def get_historical_counts(
@@ -329,7 +394,7 @@ class ImprovedModelCountTracker:
conn.close()
return results
except Exception as e:
- print(f"Error fetching historical counts: {e}")
+ logger.error(f"Error fetching historical counts: {e}", exc_info=True)
return []
def get_latest_count(self) -> Optional[Dict]:
@@ -358,7 +423,7 @@ class ImprovedModelCountTracker:
}
return None
except Exception as e:
- print(f"Error fetching latest count: {e}")
+ logger.error(f"Error fetching latest count: {e}", exc_info=True)
return None
def get_growth_stats(self, days: int = 7) -> Dict:
diff --git a/backend/utils/data_loader.py b/backend/utils/data_loader.py
index 72e0a3f0a2258b8959d977e026bd457a8b54458f..f454029940b9fd7ad273022a14034bad88a50e7f 100644
--- a/backend/utils/data_loader.py
+++ b/backend/utils/data_loader.py
@@ -50,18 +50,16 @@ class ModelDataLoader:
else:
df = df.copy()
- # Fill NaN values
text_fields = ['tags', 'pipeline_tag', 'library_name', 'modelCard']
for field in text_fields:
if field in df.columns:
df[field] = df[field].fillna('')
- # Combine text fields for embedding
df['combined_text'] = (
df.get('tags', '').astype(str) + ' ' +
df.get('pipeline_tag', '').astype(str) + ' ' +
df.get('library_name', '').astype(str) + ' ' +
- df['modelCard'].astype(str).str[:500] # Limit modelCard to first 500 chars
+ df['modelCard'].astype(str).str[:500]
)
return df
@@ -94,7 +92,6 @@ class ModelDataLoader:
else:
df = df.copy()
- # Optimized filtering with vectorized operations
if min_downloads is not None:
downloads_col = df.get('downloads', pd.Series([0] * len(df), index=df.index))
df = df[downloads_col >= min_downloads]
diff --git a/backend/utils/embeddings.py b/backend/utils/embeddings.py
index 61a30e0b737cbfe5957c4c67cc9680c4ad65b623..4e38f7d16e03cace63fa05796eeca55ef1670ef6 100644
--- a/backend/utils/embeddings.py
+++ b/backend/utils/embeddings.py
@@ -27,7 +27,7 @@ class ModelEmbedder:
def generate_embeddings(
self,
texts: List[str],
- batch_size: int = 128, # Increased default batch size for speed
+ batch_size: int = 128,
show_progress: bool = True
) -> np.ndarray:
"""
diff --git a/backend/utils/family_tree.py b/backend/utils/family_tree.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43fc92a996b7f4e945862ec65540e9d7db52697
--- /dev/null
+++ b/backend/utils/family_tree.py
@@ -0,0 +1,66 @@
+"""Family tree utility functions."""
+import pandas as pd
+from typing import Dict
+
+def calculate_family_depths(df: pd.DataFrame) -> Dict[str, int]:
+ """Calculate family depth for each model."""
+ depths = {}
+ computing = set()
+
+ def get_depth(model_id: str) -> int:
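+        # Depth = parent's depth + 1; the `computing` set marks models currently being
+        # resolved so that a parent cycle is treated as depth 0 instead of recursing forever.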
+ if model_id in depths:
+ return depths[model_id]
+ if model_id in computing:
+ depths[model_id] = 0
+ return 0
+
+ computing.add(model_id)
+
+ try:
+ if df.index.name == 'model_id':
+ row = df.loc[model_id]
+ else:
+ rows = df[df.get('model_id', '') == model_id]
+ if len(rows) == 0:
+ depths[model_id] = 0
+ computing.remove(model_id)
+ return 0
+ row = rows.iloc[0]
+
+ parent_id = row.get('parent_model')
+ if parent_id and pd.notna(parent_id):
+ parent_str = str(parent_id)
+ if parent_str != 'nan' and parent_str != '':
+ if df.index.name == 'model_id' and parent_str in df.index:
+ depth = get_depth(parent_str) + 1
+ elif df.index.name != 'model_id':
+ parent_rows = df[df.get('model_id', '') == parent_str]
+ if len(parent_rows) > 0:
+ depth = get_depth(parent_str) + 1
+ else:
+ depth = 0
+ else:
+ depth = 0
+ else:
+ depth = 0
+ else:
+ depth = 0
+ except (KeyError, IndexError):
+ depth = 0
+
+ depths[model_id] = depth
+ computing.remove(model_id)
+ return depth
+
+ if df.index.name == 'model_id':
+ for model_id in df.index:
+ if model_id not in depths:
+ get_depth(str(model_id))
+ else:
+ for _, row in df.iterrows():
+ model_id = str(row.get('model_id', ''))
+ if model_id and model_id not in depths:
+ get_depth(model_id)
+
+ return depths
+
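
A quick usage sketch for `calculate_family_depths` on a toy two-level lineage (the data is invented for illustration and assumes the backend package is on the import path):

```python
import pandas as pd
from utils.family_tree import calculate_family_depths

# Toy lineage: base model -> fine-tune -> quantised re-upload.
df = pd.DataFrame({
    "model_id": ["org/base", "org/base-ft", "org/base-ft-q4"],
    "parent_model": [None, "org/base", "org/base-ft"],
})

depths = calculate_family_depths(df)
print(depths)  # {'org/base': 0, 'org/base-ft': 1, 'org/base-ft-q4': 2}
```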
diff --git a/backend/utils/graph_embeddings.py b/backend/utils/graph_embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..e14e6fc9480b2b970eaa2130ed48f1ecd76ffee5
--- /dev/null
+++ b/backend/utils/graph_embeddings.py
@@ -0,0 +1,177 @@
+"""
+Graph-aware embeddings for hierarchical model relationships.
+Uses Node2Vec to create embeddings that respect family tree structure.
+"""
+import numpy as np
+import pandas as pd
+from typing import Dict, List, Optional, Tuple
+import networkx as nx
+import pickle
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+try:
+ from node2vec import Node2Vec
+ NODE2VEC_AVAILABLE = True
+except ImportError:
+ NODE2VEC_AVAILABLE = False
+ logger.warning("node2vec not available. Install with: pip install node2vec")
+
+
+class GraphEmbedder:
+ """
+ Generate graph embeddings that respect hierarchical relationships.
+ Combines text embeddings with graph structure embeddings.
+ """
+
+ def __init__(self, dimensions: int = 128, walk_length: int = 30, num_walks: int = 200):
+ """
+ Initialize graph embedder.
+
+ Args:
+ dimensions: Embedding dimensions
+ walk_length: Length of random walks
+ num_walks: Number of walks per node
+ """
+ self.dimensions = dimensions
+ self.walk_length = walk_length
+ self.num_walks = num_walks
+ self.graph: Optional[nx.DiGraph] = None
+ self.embeddings: Optional[np.ndarray] = None
+ self.model = None  # fitted gensim Word2Vec returned by node2vec.fit(); untyped so the name Node2Vec is never referenced when the import failed
+
+ def build_family_graph(self, df: pd.DataFrame) -> nx.DiGraph:
+ """
+ Build directed graph from family relationships.
+
+ Args:
+ df: DataFrame with model_id and parent_model columns
+
+ Returns:
+ NetworkX DiGraph
+ """
+ graph = nx.DiGraph()
+
+ for idx, row in df.iterrows():
+ model_id = str(row.get('model_id', idx))
+ graph.add_node(model_id)
+
+ parent_id = row.get('parent_model')
+ if parent_id and pd.notna(parent_id):
+ parent_str = str(parent_id)
+ if parent_str != 'nan' and parent_str != '':
+ graph.add_edge(parent_str, model_id)
+
+ self.graph = graph
+ logger.info(f"Built graph with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges")
+ return graph
+
+ def generate_graph_embeddings(
+ self,
+ graph: Optional[nx.DiGraph] = None,
+ workers: int = 4
+ ) -> Dict[str, np.ndarray]:
+ """
+ Generate Node2Vec embeddings for graph nodes.
+
+ Args:
+ graph: NetworkX graph (uses self.graph if None)
+ workers: Number of parallel workers
+
+ Returns:
+ Dictionary mapping model_id to embedding vector
+ """
+ if not NODE2VEC_AVAILABLE:
+ logger.warning("Node2Vec not available, returning empty embeddings")
+ return {}
+
+ if graph is None:
+ graph = self.graph
+
+ if graph is None or graph.number_of_nodes() == 0:
+ logger.warning("No graph available for embedding generation")
+ return {}
+
+ try:
+ node2vec = Node2Vec(
+ graph,
+ dimensions=self.dimensions,
+ walk_length=self.walk_length,
+ num_walks=self.num_walks,
+ workers=workers
+ )
+
+ model = node2vec.fit(window=10, min_count=1, batch_words=4)
+ self.model = model
+
+ embeddings_dict = {}
+ for node in graph.nodes():
+ if node in model.wv:
+ embeddings_dict[node] = model.wv[node]
+
+ logger.info(f"Generated graph embeddings for {len(embeddings_dict)} nodes")
+ return embeddings_dict
+
+ except Exception as e:
+ logger.error(f"Error generating graph embeddings: {e}", exc_info=True)
+ return {}
+
+ def combine_embeddings(
+ self,
+ text_embeddings: np.ndarray,
+ graph_embeddings: Dict[str, np.ndarray],
+ model_ids: List[str],
+ text_weight: float = 0.7,
+ graph_weight: float = 0.3
+ ) -> np.ndarray:
+ """
+ Combine text and graph embeddings by concatenating the L2-normalized vectors, each scaled by its weight.
+
+ Args:
+ text_embeddings: Text-based embeddings (n_samples, text_dim)
+ graph_embeddings: Graph embeddings dictionary
+ model_ids: List of model IDs corresponding to text_embeddings
+ text_weight: Weight for text embeddings
+ graph_weight: Weight for graph embeddings
+
+ Returns:
+ Combined embeddings (n_samples, text_dim + graph_dim), or text_embeddings unchanged if no graph embeddings are available
+ """
+ if not graph_embeddings:
+ return text_embeddings
+
+ text_dim = text_embeddings.shape[1]
+ graph_dim = next(iter(graph_embeddings.values())).shape[0]
+
+ combined = np.zeros((len(model_ids), text_dim + graph_dim))
+
+ for i, model_id in enumerate(model_ids):
+ model_id_str = str(model_id)
+
+ text_emb = text_embeddings[i]
+ graph_emb = graph_embeddings.get(model_id_str, np.zeros(graph_dim))
+
+ normalized_text = text_emb / (np.linalg.norm(text_emb) + 1e-8)
+ normalized_graph = graph_emb / (np.linalg.norm(graph_emb) + 1e-8)
+
+ combined[i] = np.concatenate([
+ normalized_text * text_weight,
+ normalized_graph * graph_weight
+ ])
+
+ return combined
+
+ def save_embeddings(self, embeddings: Dict[str, np.ndarray], filepath: str):
+ """Save graph embeddings to disk."""
+ os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)
+ with open(filepath, 'wb') as f:
+ pickle.dump(embeddings, f)
+
+ def load_embeddings(self, filepath: str) -> Dict[str, np.ndarray]:
+ """Load graph embeddings from disk."""
+ with open(filepath, 'rb') as f:
+ return pickle.load(f)
+
+
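
An end-to-end sketch of how `GraphEmbedder` is meant to be driven — build the family graph, run Node2Vec over it, then fuse the result with text embeddings. The DataFrame and the random 384-dimensional "text embeddings" are stand-ins; `node2vec` must be installed for the graph step, otherwise the combiner simply returns the text embeddings unchanged:

```python
import numpy as np
import pandas as pd
from utils.graph_embeddings import GraphEmbedder

df = pd.DataFrame({
    "model_id": ["org/base", "org/base-ft", "org/base-ft-q4"],
    "parent_model": [None, "org/base", "org/base-ft"],
})

embedder = GraphEmbedder(dimensions=16, walk_length=10, num_walks=20)
graph = embedder.build_family_graph(df)

# Empty dict if node2vec is not installed; combine_embeddings then falls back to text only.
graph_embs = embedder.generate_graph_embeddings(graph, workers=1)

# Stand-in for sentence-transformer output: one 384-dim vector per model.
text_embs = np.random.rand(len(df), 384).astype(np.float32)

combined = embedder.combine_embeddings(
    text_embs, graph_embs, df["model_id"].tolist(),
    text_weight=0.7, graph_weight=0.3,
)
print(combined.shape)  # (3, 400) with graph embeddings; (3, 384) without
```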
diff --git a/backend/utils/network_analysis.py b/backend/utils/network_analysis.py
index a82801bc183cedeabd995fc870507be4d3c9b44c..7f4983ed201ce1a784b1924c8a141f88822636d6 100644
--- a/backend/utils/network_analysis.py
+++ b/backend/utils/network_analysis.py
@@ -1,6 +1,7 @@
"""
Network analysis module inspired by Open Syllabus Project.
Builds co-occurrence networks for models based on shared contexts.
+Supports multiple relationship types: finetune, quantized, adapter, merge.
"""
import pandas as pd
import numpy as np
@@ -8,12 +9,66 @@ from collections import Counter
from itertools import combinations
from typing import List, Dict, Tuple, Optional, Set
import networkx as nx
+import ast
+from datetime import datetime
+
+
+def _parse_parent_list(value) -> List[str]:
+ """
+ Parse parent model list from string/eval format.
+ Handles both string representations and actual lists.
+ """
+ if isinstance(value, (list, tuple, np.ndarray)):
+ return [str(p) for p in value if p and str(p) != 'nan']
+ if pd.isna(value) or value == '' or str(value) == 'nan':
+ return []
+
+ try:
+ if isinstance(value, str):
+ if value.startswith('[') or value.startswith('('):
+ parsed = ast.literal_eval(value)
+ else:
+ parsed = [value]
+ else:
+ parsed = value
+
+ if isinstance(parsed, list):
+ return [str(p) for p in parsed if p and str(p) != 'nan']
+ elif parsed:
+ return [str(parsed)]
+ else:
+ return []
+ except (ValueError, SyntaxError):
+ return []
+
+
+def _get_all_parents(row: pd.Series) -> Dict[str, List[str]]:
+ """
+ Extract all parent types from a row.
+ Returns dict mapping relationship type to list of parent IDs.
+ """
+ parents = {}
+
+ parent_columns = {
+ 'parent_model': 'parent',
+ 'finetune_parent': 'finetune',
+ 'quantized_parent': 'quantized',
+ 'adapter_parent': 'adapter',
+ 'merge_parent': 'merge'
+ }
+
+ for col, rel_type in parent_columns.items():
+ if col in row:
+ parent_list = _parse_parent_list(row.get(col))
+ if parent_list:
+ parents[rel_type] = parent_list
+
+ return parents
class ModelNetworkBuilder:
"""
Build network graphs for models based on co-occurrence patterns.
Similar to Open Syllabus approach of connecting texts that appear together.
+ Supports multiple relationship types: finetune, quantized, adapter, merge.
"""
def __init__(self, df: pd.DataFrame):
@@ -22,13 +77,13 @@ class ModelNetworkBuilder:
Args:
df: DataFrame with model data including model_id, library_name,
- pipeline_tag, tags, parent_model, downloads, likes
+ pipeline_tag, tags, parent_model, finetune_parent, quantized_parent,
+ adapter_parent, merge_parent, downloads, likes, createdAt
"""
self.df = df.copy()
if 'model_id' not in self.df.columns:
raise ValueError("DataFrame must contain 'model_id' column")
- # Ensure model_id is index for fast lookups
if self.df.index.name != 'model_id':
if 'model_id' in self.df.columns:
self.df.set_index('model_id', drop=False, inplace=True)
@@ -208,23 +263,41 @@ class ModelNetworkBuilder:
def build_family_tree_network(
self,
root_model_id: str,
- max_depth: int = 5
+ max_depth: Optional[int] = 5,
+ include_edge_attributes: bool = True,
+ filter_edge_types: Optional[List[str]] = None
) -> nx.DiGraph:
"""
- Build directed graph of model family tree.
+ Build directed graph of model family tree with multiple relationship types.
Args:
root_model_id: Root model to start from
- max_depth: Maximum depth to traverse
+ max_depth: Maximum depth to traverse. If None, traverses entire tree without limit.
+ include_edge_attributes: Whether to calculate edge attributes (change in likes, downloads, etc.)
+ filter_edge_types: List of edge types to include (e.g., ['finetune', 'quantized']).
+ If None, includes all types.
Returns:
- NetworkX DiGraph representing family tree
+ NetworkX DiGraph representing family tree with edge types and attributes
"""
graph = nx.DiGraph()
visited = set()
- def add_family(current_id: str, depth: int):
- if depth <= 0 or current_id in visited:
+ children_index: Dict[str, List[Tuple[str, str]]] = {}
+ for idx, row in self.df.iterrows():
+ model_id = str(row.get('model_id', idx))
+ all_parents = _get_all_parents(row)
+
+ for rel_type, parent_list in all_parents.items():
+ for parent_id in parent_list:
+ if parent_id not in children_index:
+ children_index[parent_id] = []
+ children_index[parent_id].append((model_id, rel_type))
+
+ def add_family(current_id: str, depth: Optional[int]):
+ if current_id in visited:
+ return
+ if depth is not None and depth <= 0:
return
visited.add(current_id)
@@ -233,28 +306,98 @@ class ModelNetworkBuilder:
row = self.df.loc[current_id]
- # Add node
graph.add_node(str(current_id))
graph.nodes[str(current_id)]['title'] = self._format_title(current_id)
graph.nodes[str(current_id)]['freq'] = int(row.get('downloads', 0))
+ graph.nodes[str(current_id)]['likes'] = int(row.get('likes', 0))
+ graph.nodes[str(current_id)]['downloads'] = int(row.get('downloads', 0))
+ graph.nodes[str(current_id)]['library'] = str(row.get('library_name', '')) if pd.notna(row.get('library_name')) else ''
+ graph.nodes[str(current_id)]['pipeline'] = str(row.get('pipeline_tag', '')) if pd.notna(row.get('pipeline_tag')) else ''
- # Add edge to parent
- parent_id = row.get('parent_model')
- if parent_id and pd.notna(parent_id) and str(parent_id) != 'nan':
- parent_id_str = str(parent_id)
- graph.add_edge(parent_id_str, str(current_id))
- add_family(parent_id_str, depth - 1)
+ createdAt = row.get('createdAt')
+ if pd.notna(createdAt):
+ graph.nodes[str(current_id)]['createdAt'] = str(createdAt)
- # Add edges to children
- children = self.df[self.df.get('parent_model', '') == current_id]
- for child_id, child_row in children.iterrows():
+ all_parents = _get_all_parents(row)
+ for rel_type, parent_list in all_parents.items():
+ if filter_edge_types and rel_type not in filter_edge_types:
+ continue
+
+ for parent_id in parent_list:
+ if parent_id in self.df.index:
+ if not graph.has_edge(parent_id, str(current_id)):
+ graph.add_edge(parent_id, str(current_id))
+ graph[parent_id][str(current_id)]['edge_types'] = [rel_type]
+ graph[parent_id][str(current_id)]['edge_type'] = rel_type
+ elif rel_type not in graph[parent_id][str(current_id)].get('edge_types', []):
+ graph[parent_id][str(current_id)]['edge_types'].append(rel_type)
+
+ next_depth = depth - 1 if depth is not None else None
+ add_family(parent_id, next_depth)
+
+ children = children_index.get(current_id, [])
+ for child_id, rel_type in children:
+ if filter_edge_types and rel_type not in filter_edge_types:
+ continue
+
if str(child_id) not in visited:
- graph.add_edge(str(current_id), str(child_id))
- add_family(str(child_id), depth - 1)
+ if not graph.has_edge(str(current_id), child_id):
+ graph.add_edge(str(current_id), child_id)
+ graph[str(current_id)][child_id]['edge_types'] = [rel_type]
+ graph[str(current_id)][child_id]['edge_type'] = rel_type
+ else:
+ if rel_type not in graph[str(current_id)][child_id].get('edge_types', []):
+ graph[str(current_id)][child_id]['edge_types'].append(rel_type)
+
+ next_depth = depth - 1 if depth is not None else None
+ add_family(child_id, next_depth)
add_family(root_model_id, max_depth)
+
+ if include_edge_attributes:
+ self._add_edge_attributes(graph)
+
return graph
+ def _add_edge_attributes(self, graph: nx.DiGraph):
+ """
+ Add edge attributes like change in likes, downloads, time difference.
+ Similar to the notebook's edge attribute calculation.
+ """
+ for edge in graph.edges():
+ parent_model = edge[0]
+ model_id = edge[1]
+
+ if parent_model not in graph.nodes() or model_id not in graph.nodes():
+ continue
+
+ parent_likes = graph.nodes[parent_model].get('likes', 0)
+ model_likes = graph.nodes[model_id].get('likes', 0)
+ parent_downloads = graph.nodes[parent_model].get('downloads', 0)
+ model_downloads = graph.nodes[model_id].get('downloads', 0)
+
+ graph.edges[edge]['change_in_likes'] = model_likes - parent_likes
+ if parent_likes != 0:
+ graph.edges[edge]['percentage_change_in_likes'] = (model_likes - parent_likes) / parent_likes
+ else:
+ graph.edges[edge]['percentage_change_in_likes'] = np.nan
+
+ graph.edges[edge]['change_in_downloads'] = model_downloads - parent_downloads
+ if parent_downloads != 0:
+ graph.edges[edge]['percentage_change_in_downloads'] = (model_downloads - parent_downloads) / parent_downloads
+ else:
+ graph.edges[edge]['percentage_change_in_downloads'] = np.nan
+
+ parent_created = graph.nodes[parent_model].get('createdAt')
+ model_created = graph.nodes[model_id].get('createdAt')
+
+ if parent_created and model_created:
+ try:
+ parent_dt = datetime.strptime(str(parent_created), '%Y-%m-%dT%H:%M:%S.%fZ')
+ model_dt = datetime.strptime(str(model_created), '%Y-%m-%dT%H:%M:%S.%fZ')
+ graph.edges[edge]['change_in_createdAt_days'] = (model_dt - parent_dt).days
+ except (ValueError, TypeError):
+ graph.edges[edge]['change_in_createdAt_days'] = np.nan
+ else:
+ graph.edges[edge]['change_in_createdAt_days'] = np.nan
+
def export_graphml(self, graph: nx.Graph, filename: str):
"""Export graph to GraphML format (like Open Syllabus)."""
nx.write_graphml(graph, filename)
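
A usage sketch for the reworked `build_family_tree_network`: unlimited depth, edges restricted to the fine-tune and quantized relationship types, and the computed edge attributes read back off the graph. The DataFrame below is a toy stand-in for the real model dump:

```python
import pandas as pd
from utils.network_analysis import ModelNetworkBuilder

df = pd.DataFrame({
    "model_id": ["org/base", "org/base-ft", "org/base-ft-q4"],
    "parent_model": [None, "org/base", "org/base-ft"],
    "finetune_parent": [None, "['org/base']", None],
    "quantized_parent": [None, None, "['org/base-ft']"],
    "downloads": [100_000, 25_000, 5_000],
    "likes": [500, 120, 30],
    "createdAt": ["2023-01-01T00:00:00.000Z",
                  "2023-03-01T00:00:00.000Z",
                  "2023-04-15T00:00:00.000Z"],
})

builder = ModelNetworkBuilder(df)
tree = builder.build_family_tree_network(
    "org/base",
    max_depth=None,                          # walk the whole tree
    include_edge_attributes=True,
    filter_edge_types=["finetune", "quantized"],
)

for parent, child, attrs in tree.edges(data=True):
    print(parent, "->", child, attrs["edge_types"], attrs["change_in_likes"])
```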
diff --git a/frontend/.npmrc b/frontend/.npmrc
index 8c1d73ae125936d71261229d76e34cba870dc990..7356968121dade60fef934ab95fda9fd10761283 100644
--- a/frontend/.npmrc
+++ b/frontend/.npmrc
@@ -1,2 +1,4 @@
legacy-peer-deps=true
+
+
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index 2140e6ab206bff4c53ebf3d29a263db0473beb05..74fbbb1ee5dfd4b61fe42f2cd46ed6ab481b3270 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -32,7 +32,8 @@
"react-dom": "^18.2.0",
"react-scripts": "5.0.1",
"three": "^0.160.1",
- "typescript": "^5.0.0"
+ "typescript": "^5.0.0",
+ "zustand": "^5.0.8"
}
},
"node_modules/@alloc/quick-lru": {
diff --git a/frontend/package.json b/frontend/package.json
index 6c5dd299ae4098925421aa759357f3c29aa07fdc..df73512b094c59e8854eded3c1053fd2a4750ba0 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -28,7 +28,8 @@
"react-dom": "^18.2.0",
"react-scripts": "5.0.1",
"three": "^0.160.1",
- "typescript": "^5.0.0"
+ "typescript": "^5.0.0",
+ "zustand": "^5.0.8"
},
"scripts": {
"start": "react-scripts start",
diff --git a/frontend/public/index.html b/frontend/public/index.html
index 0a04459565a0ba5a79038d5b8ae8aec4f41feb2a..023b865876569432bbc201528bd683184f41b8b3 100644
--- a/frontend/public/index.html
+++ b/frontend/public/index.html
@@ -10,7 +10,7 @@
/>
-
+
Anatomy of a Machine Learning Ecosystem: 2 Million Models on Hugging Face
diff --git a/frontend/src/App.css b/frontend/src/App.css
index a8317e31f12ae417be65760b5d5b95387dde1e0b..76b15732029dc3fd1a9ef6975b18794b5cf287b1 100644
--- a/frontend/src/App.css
+++ b/frontend/src/App.css
@@ -7,86 +7,24 @@
}
.App-header {
- background: linear-gradient(135deg, #1a237e 0%, #283593 20%, #3949ab 40%, #5e35b1 60%, #7b1fa2 80%, #6a1b9a 100%);
- background-size: 200% 200%;
- animation: gradientShift 20s ease infinite;
+ background: #2d2d2d;
color: #ffffff;
- padding: 3rem 2.5rem;
+ padding: 2.5rem 2rem;
text-align: center;
- border-bottom: 2px solid rgba(100, 181, 246, 0.3);
- box-shadow: 0 4px 20px rgba(0, 0, 0, 0.25), 0 2px 10px rgba(123, 31, 162, 0.3);
+ border-bottom: 1px solid #404040;
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
position: relative;
- overflow: hidden;
}
-.App-header::before {
- content: '';
- position: absolute;
- top: 0;
- left: 0;
- right: 0;
- bottom: 0;
- background:
- radial-gradient(circle at 20% 50%, rgba(100, 181, 246, 0.15) 0%, transparent 50%),
- radial-gradient(circle at 80% 80%, rgba(156, 39, 176, 0.1) 0%, transparent 50%),
- radial-gradient(circle at 40% 20%, rgba(33, 150, 243, 0.1) 0%, transparent 50%);
- pointer-events: none;
- animation: pulse 8s ease-in-out infinite;
-}
-
-.App-header::after {
- content: '';
- position: absolute;
- top: 0;
- left: 0;
- right: 0;
- bottom: 0;
- background-image:
- repeating-linear-gradient(
- 0deg,
- transparent,
- transparent 2px,
- rgba(255, 255, 255, 0.03) 2px,
- rgba(255, 255, 255, 0.03) 4px
- );
- pointer-events: none;
- opacity: 0.5;
-}
-
-@keyframes gradientShift {
- 0% {
- background-position: 0% 50%;
- }
- 50% {
- background-position: 100% 50%;
- }
- 100% {
- background-position: 0% 50%;
- }
-}
-
-@keyframes pulse {
- 0%, 100% {
- opacity: 1;
- }
- 50% {
- opacity: 0.8;
- }
-}
.App-header h1 {
margin: 0 0 1rem 0;
- font-size: 2.25rem;
- font-weight: 700;
- letter-spacing: -0.02em;
- line-height: 1.2;
- position: relative;
- z-index: 1;
- text-shadow: 0 2px 8px rgba(0, 0, 0, 0.4), 0 4px 16px rgba(123, 31, 162, 0.3);
- background: linear-gradient(180deg, #ffffff 0%, #e1bee7 100%);
- -webkit-background-clip: text;
- -webkit-text-fill-color: transparent;
- background-clip: text;
+ font-size: 2rem;
+ font-weight: 600;
+ letter-spacing: -0.01em;
+ line-height: 1.3;
+ color: #ffffff;
+ text-shadow: 0 1px 3px rgba(0, 0, 0, 0.3);
}
.App-header p {
@@ -122,23 +60,17 @@
}
.stats span {
- padding: 0.75rem 1.5rem;
- background: rgba(255, 255, 255, 0.15);
- border-radius: 12px;
- backdrop-filter: blur(20px);
- -webkit-backdrop-filter: blur(20px);
- border: 2px solid rgba(255, 255, 255, 0.25);
- transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1), inset 0 1px 0 rgba(255, 255, 255, 0.3);
- font-weight: 600;
- letter-spacing: 0.02em;
+ padding: 0.625rem 1.25rem;
+ background: rgba(255, 255, 255, 0.1);
+ border-radius: 6px;
+ border: 1px solid rgba(255, 255, 255, 0.2);
+ transition: all 0.2s ease;
+ font-weight: 500;
}
.stats span:hover {
- background: rgba(255, 255, 255, 0.25);
- transform: translateY(-2px) scale(1.05);
- box-shadow: 0 6px 20px rgba(0, 0, 0, 0.15), inset 0 1px 0 rgba(255, 255, 255, 0.4);
- border-color: rgba(255, 255, 255, 0.4);
+ background: rgba(255, 255, 255, 0.15);
+ transform: translateY(-1px);
}
.main-content {
@@ -149,10 +81,9 @@
.sidebar {
width: 340px;
padding: 1.5rem;
- background: linear-gradient(to bottom, #fafafa 0%, #ffffff 100%);
+ background: #fafafa;
overflow-y: auto;
- border-right: 2px solid #e0e0e0;
- box-shadow: 2px 0 8px rgba(0, 0, 0, 0.05);
+ border-right: 1px solid #e0e0e0;
}
.sidebar h2 {
@@ -164,12 +95,11 @@
}
.sidebar h3 {
- font-size: 0.95rem;
- font-weight: 700;
- color: #5e35b1;
- margin: 0 0 1rem 0;
+ font-size: 0.9rem;
+ font-weight: 600;
+ color: #2d2d2d;
+ margin: 0 0 0.875rem 0;
letter-spacing: -0.01em;
- text-transform: none;
}
.sidebar label {
@@ -202,9 +132,8 @@
.sidebar input[type="text"]:focus,
.sidebar select:focus {
outline: none;
- border-color: #5e35b1;
- box-shadow: 0 0 0 3px rgba(94, 53, 177, 0.12), 0 2px 6px rgba(0, 0, 0, 0.1);
- transform: translateY(-1px);
+ border-color: #4a4a4a;
+ box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.08);
}
.sidebar input[type="range"] {
@@ -227,20 +156,20 @@
.sidebar input[type="range"]::-webkit-slider-thumb {
-webkit-appearance: none;
appearance: none;
- width: 20px;
- height: 20px;
+ width: 18px;
+ height: 18px;
border-radius: 50%;
- background: linear-gradient(135deg, #5e35b1 0%, #7b1fa2 100%);
+ background: #4a4a4a;
cursor: pointer;
- box-shadow: 0 2px 6px rgba(94, 53, 177, 0.3), 0 4px 12px rgba(94, 53, 177, 0.2);
- transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
- border: 3px solid #ffffff;
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
+ transition: all 0.2s ease;
+ border: 2px solid #ffffff;
}
.sidebar input[type="range"]::-webkit-slider-thumb:hover {
- background: linear-gradient(135deg, #512da8 0%, #6a1b9a 100%);
- transform: scale(1.2);
- box-shadow: 0 3px 8px rgba(94, 53, 177, 0.4), 0 6px 16px rgba(94, 53, 177, 0.3);
+ background: #2d2d2d;
+ transform: scale(1.1);
+ box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3);
}
.sidebar input[type="range"]::-webkit-slider-thumb:active {
@@ -248,20 +177,20 @@
}
.sidebar input[type="range"]::-moz-range-thumb {
- width: 20px;
- height: 20px;
+ width: 18px;
+ height: 18px;
border-radius: 50%;
- background: linear-gradient(135deg, #5e35b1 0%, #7b1fa2 100%);
+ background: #4a4a4a;
cursor: pointer;
- border: 3px solid #ffffff;
- box-shadow: 0 2px 6px rgba(94, 53, 177, 0.3), 0 4px 12px rgba(94, 53, 177, 0.2);
- transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
+ border: 2px solid #ffffff;
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
+ transition: all 0.2s ease;
}
.sidebar input[type="range"]::-moz-range-thumb:hover {
- background: linear-gradient(135deg, #512da8 0%, #6a1b9a 100%);
- transform: scale(1.2);
- box-shadow: 0 3px 8px rgba(94, 53, 177, 0.4), 0 6px 16px rgba(94, 53, 177, 0.3);
+ background: #2d2d2d;
+ transform: scale(1.1);
+ box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3);
}
.sidebar input[type="range"]::-moz-range-thumb:active {
@@ -288,17 +217,16 @@
.sidebar-section {
background: #ffffff;
- border-radius: 8px;
+ border-radius: 6px;
padding: 1.25rem;
- margin-bottom: 1.25rem;
+ margin-bottom: 1rem;
border: 1px solid #e0e0e0;
- box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
- transition: all 0.3s ease;
+ transition: all 0.2s ease;
}
.sidebar-section:hover {
- box-shadow: 0 2px 8px rgba(0, 0, 0, 0.12);
border-color: #d0d0d0;
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.05);
}
.filter-chip {
@@ -380,22 +308,21 @@
}
.loading {
- color: #5e35b1;
+ color: #2d2d2d;
font-weight: 600;
- background: linear-gradient(135deg, #f5f3ff 0%, #ede7f6 100%);
- border: 2px solid #d1c4e9;
- box-shadow: 0 4px 12px rgba(94, 53, 177, 0.1);
+ background: #f5f5f5;
+ border: 1px solid #d0d0d0;
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
}
.loading::after {
content: '';
- width: 48px;
- height: 48px;
- border: 5px solid #e1bee7;
- border-top-color: #5e35b1;
- border-right-color: #7b1fa2;
+ width: 40px;
+ height: 40px;
+ border: 4px solid #e0e0e0;
+ border-top-color: #4a4a4a;
border-radius: 50%;
- animation: spin 0.8s cubic-bezier(0.68, -0.55, 0.265, 1.55) infinite;
+ animation: spin 0.8s linear infinite;
}
@keyframes spin {
@@ -403,101 +330,62 @@
}
.error {
- color: #c62828;
- background: linear-gradient(135deg, #ffebee 0%, #ffcdd2 100%);
- border-radius: 12px;
- border: 2px solid #ef5350;
+ color: #d32f2f;
+ background: #ffebee;
+ border-radius: 8px;
+ border: 1px solid #ffcdd2;
max-width: 550px;
margin: 0 auto;
- box-shadow: 0 4px 12px rgba(198, 40, 40, 0.15);
font-weight: 500;
}
-.error::before {
- content: '⚠️';
- font-size: 2.5rem;
- display: block;
- margin-bottom: 0.5rem;
-}
-
.empty {
- color: #616161;
- background: linear-gradient(135deg, #fafafa 0%, #f5f5f5 100%);
- border-radius: 12px;
- border: 2px solid #e0e0e0;
+ color: #6a6a6a;
+ background: #f5f5f5;
+ border-radius: 8px;
+ border: 1px solid #e0e0e0;
max-width: 550px;
margin: 0 auto;
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08);
font-weight: 500;
}
-.empty::before {
- content: '🔍';
- font-size: 2.5rem;
- display: block;
- margin-bottom: 0.5rem;
-}
-
.btn {
padding: 0.625rem 1.25rem;
- border-radius: 6px;
+ border-radius: 4px;
border: none;
font-size: 0.9rem;
font-weight: 600;
cursor: pointer;
- transition: all 0.25s cubic-bezier(0.4, 0, 0.2, 1);
+ transition: all 0.2s ease;
font-family: 'Instrument Sans', sans-serif;
display: inline-flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
- position: relative;
- overflow: hidden;
-}
-
-.btn::before {
- content: '';
- position: absolute;
- top: 50%;
- left: 50%;
- width: 0;
- height: 0;
- border-radius: 50%;
- background: rgba(255, 255, 255, 0.3);
- transform: translate(-50%, -50%);
- transition: width 0.6s, height 0.6s;
}
-.btn:hover::before {
- width: 300px;
- height: 300px;
-}
.btn-primary {
- background: linear-gradient(135deg, #5e35b1 0%, #7b1fa2 100%);
+ background: #2d2d2d;
color: white;
- box-shadow: 0 2px 4px rgba(94, 53, 177, 0.3);
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.15);
}
.btn-primary:hover {
- background: linear-gradient(135deg, #512da8 0%, #6a1b9a 100%);
- transform: translateY(-2px);
- box-shadow: 0 4px 12px rgba(94, 53, 177, 0.4);
+ background: #1a1a1a;
+ transform: translateY(-1px);
+ box-shadow: 0 3px 8px rgba(0, 0, 0, 0.2);
}
.btn-secondary {
background: #f5f5f5;
- color: #1a1a1a;
- border: 2px solid #e0e0e0;
- box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
+ color: #2d2d2d;
+ border: 1px solid #d0d0d0;
}
.btn-secondary:hover {
- background: #ffffff;
- border-color: #5e35b1;
- color: #5e35b1;
- transform: translateY(-1px);
- box-shadow: 0 2px 6px rgba(0, 0, 0, 0.12);
+ background: #e8e8e8;
+ border-color: #b0b0b0;
}
.btn-small {
@@ -585,7 +473,7 @@
--text-primary: #ffffff;
--text-secondary: #cccccc;
--border-color: #444444;
- --accent-color: #64b5f6;
+ --accent-color: #4a4a4a;
}
[data-theme="light"] {
@@ -595,32 +483,31 @@
--text-primary: #1a1a1a;
--text-secondary: #666666;
--border-color: #d0d0d0;
- --accent-color: #1976d2;
+ --accent-color: #4a4a4a;
}
/* Random Model Button */
.random-model-btn {
display: flex;
align-items: center;
- gap: 0.5rem;
- padding: 0.5rem 1rem;
- background: var(--accent-color, #4a90e2);
+ justify-content: center;
+ padding: 0.625rem 1.25rem;
+ background: #2d2d2d;
color: white;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 0.9rem;
font-family: 'Instrument Sans', sans-serif;
- font-weight: 500;
+ font-weight: 600;
transition: all 0.2s;
width: 100%;
- justify-content: center;
}
.random-model-btn:hover:not(:disabled) {
- background: var(--accent-color, #357abd);
+ background: #1a1a1a;
transform: translateY(-1px);
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
+ box-shadow: 0 2px 6px rgba(0, 0, 0, 0.2);
}
.random-model-btn:disabled {
@@ -628,10 +515,6 @@
cursor: not-allowed;
}
-.random-icon {
- font-size: 1.1rem;
-}
-
/* Zoom Slider */
.zoom-slider-container {
margin-bottom: 1rem;
@@ -859,7 +742,7 @@
width: 18px;
height: 18px;
cursor: pointer;
- accent-color: #5e35b1;
+ accent-color: #4a4a4a;
margin-right: 0.5rem;
}
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index fff1c9f7e7175148d2286e21e6d02709cf7dce24..b7028115b168e32fdba258e31fafaab691d453d3 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -506,28 +506,24 @@ function App() {
alignItems: 'center',
marginBottom: '1.5rem',
paddingBottom: '1rem',
- borderBottom: '2px solid #e8e8e8'
+ borderBottom: '1px solid #e0e0e0'
}}>
Filters & Controls
{activeFilterCount > 0 && (
{activeFilterCount} active
@@ -537,40 +533,40 @@ function App() {
{/* Filter Results Count */}
{!loading && data.length > 0 && (
-
+
{data.length.toLocaleString()}
-
+
{data.length === 1 ? 'model' : 'models'}
{embeddingType === 'graph-aware' && (
- 🌐 Graph
+ Graph
)}
{filteredCount !== null && filteredCount !== data.length && (
-
+
of {filteredCount.toLocaleString()} matching
)}
{stats && filteredCount !== null && filteredCount < stats.total_models && (
-
+
from {stats.total_models.toLocaleString()} total
)}
@@ -579,15 +575,7 @@ function App() {
{/* Search Section */}
-
- 🔍 Search Models
-
+
Search Models
-
- 📊 Popularity Filters
-
+
Popularity Filters