Apply clean grayscale design, remove all emojis
Browse files
Remove gradient backgrounds, purple theme, and emojis throughout UI
This view is limited to 50 files because it contains too many changes.
See raw diff
- backend/README.md +0 -24
- backend/api/dependencies.py +23 -0
- backend/api/main.py +764 -427
- backend/api/routes/__init__.py +6 -0
- backend/api/routes/clusters.py +102 -0
- backend/api/routes/models.py +247 -0
- backend/api/routes/stats.py +37 -0
- backend/config/requirements.txt +1 -0
- backend/core/__init__.py +2 -0
- backend/core/config.py +23 -0
- backend/core/exceptions.py +18 -0
- backend/models/__init__.py +2 -0
- backend/models/schemas.py +22 -0
- backend/scripts/export_binary.py +263 -0
- backend/services/model_tracker.py +83 -24
- backend/services/model_tracker_improved.py +95 -30
- backend/utils/data_loader.py +1 -4
- backend/utils/embeddings.py +1 -1
- backend/utils/family_tree.py +66 -0
- backend/utils/graph_embeddings.py +177 -0
- backend/utils/network_analysis.py +163 -20
- frontend/.npmrc +2 -0
- frontend/package-lock.json +2 -1
- frontend/package.json +2 -1
- frontend/public/index.html +1 -1
- frontend/src/App.css +85 -202
- frontend/src/App.tsx +49 -118
- frontend/src/components/PaperPlots.css +0 -92
- frontend/src/components/PaperPlots.tsx +0 -755
- frontend/src/components/ScatterPlot.tsx +0 -7
- frontend/src/components/controls/ClusterFilter.css +122 -0
- frontend/src/components/controls/ClusterFilter.tsx +142 -0
- frontend/src/components/controls/NodeDensitySlider.css +31 -0
- frontend/src/components/controls/NodeDensitySlider.tsx +39 -0
- frontend/src/components/controls/RandomModelButton.tsx +32 -0
- frontend/src/components/controls/RenderingStyleSelector.css +37 -0
- frontend/src/components/controls/RenderingStyleSelector.tsx +43 -0
- frontend/src/components/controls/ThemeToggle.tsx +22 -0
- frontend/src/components/controls/VisualizationModeButtons.css +65 -0
- frontend/src/components/controls/VisualizationModeButtons.tsx +46 -0
- frontend/src/components/controls/ZoomSlider.tsx +43 -0
- frontend/src/components/layout/SearchBar.css +181 -0
- frontend/src/components/layout/SearchBar.tsx +201 -0
- frontend/src/components/{FileTree.css → modals/FileTree.css} +171 -3
- frontend/src/components/{FileTree.tsx → modals/FileTree.tsx} +314 -26
- frontend/src/components/{ModelModal.css → modals/ModelModal.css} +43 -14
- frontend/src/components/{ModelModal.tsx → modals/ModelModal.tsx} +17 -9
- frontend/src/components/{ColorLegend.css → ui/ColorLegend.css} +0 -0
- frontend/src/components/{ColorLegend.tsx → ui/ColorLegend.tsx} +1 -1
- frontend/src/components/{ErrorBoundary.tsx → ui/ErrorBoundary.tsx} +0 -0
backend/README.md
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
# Backend API
|
| 2 |
-
|
| 3 |
-
FastAPI backend for serving model data to the React frontend.
|
| 4 |
-
|
| 5 |
-
## Structure
|
| 6 |
-
|
| 7 |
-
- `api/` - API routes and main application
|
| 8 |
-
- `services/` - External service integrations (arXiv, model tracking, scheduling)
|
| 9 |
-
- `utils/` - Utility modules (data loading, embeddings, dimensionality reduction, clustering, network analysis)
|
| 10 |
-
- `config/` - Configuration files (requirements.txt, etc.)
|
| 11 |
-
- `cache/` - Cached data (embeddings, reduced dimensions)
|
| 12 |
-
|
| 13 |
-
## Running
|
| 14 |
-
|
| 15 |
-
```bash
|
| 16 |
-
cd backend
|
| 17 |
-
uvicorn api.main:app --reload --host 0.0.0.0 --port 8000
|
| 18 |
-
```
|
| 19 |
-
|
| 20 |
-
## Environment Variables
|
| 21 |
-
|
| 22 |
-
- `SAMPLE_SIZE` - Limit number of models to load (for development). Set to 0 or leave unset to load all models.
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/api/dependencies.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared dependencies for API routes."""
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Optional, Dict
|
| 5 |
+
from utils.data_loader import ModelDataLoader
|
| 6 |
+
from utils.embeddings import ModelEmbedder
|
| 7 |
+
from utils.dimensionality_reduction import DimensionReducer
|
| 8 |
+
from utils.graph_embeddings import GraphEmbedder
|
| 9 |
+
|
| 10 |
+
# Global state (initialized in startup) - these are module-level variables
|
| 11 |
+
# that will be updated by main.py during startup
|
| 12 |
+
data_loader = ModelDataLoader()
|
| 13 |
+
embedder: Optional[ModelEmbedder] = None
|
| 14 |
+
graph_embedder: Optional[GraphEmbedder] = None
|
| 15 |
+
reducer: Optional[DimensionReducer] = None
|
| 16 |
+
df: Optional[pd.DataFrame] = None
|
| 17 |
+
embeddings: Optional[np.ndarray] = None
|
| 18 |
+
graph_embeddings_dict: Optional[Dict[str, np.ndarray]] = None
|
| 19 |
+
combined_embeddings: Optional[np.ndarray] = None
|
| 20 |
+
reduced_embeddings: Optional[np.ndarray] = None
|
| 21 |
+
reduced_embeddings_graph: Optional[np.ndarray] = None
|
| 22 |
+
cluster_labels: Optional[np.ndarray] = None
|
| 23 |
+
|
backend/api/main.py
CHANGED
|
@@ -1,202 +1,216 @@
|
|
| 1 |
-
"""
|
| 2 |
-
FastAPI backend for serving model data to React/Visx frontend.
|
| 3 |
-
"""
|
| 4 |
import sys
|
| 5 |
import os
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
from fastapi import FastAPI, HTTPException, Query, BackgroundTasks, Request
|
| 11 |
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
from fastapi.middleware.gzip import GZipMiddleware
|
| 13 |
from fastapi.responses import FileResponse, JSONResponse
|
| 14 |
from fastapi.exceptions import RequestValidationError
|
| 15 |
from starlette.exceptions import HTTPException as StarletteHTTPException
|
| 16 |
-
from typing import Optional, List, Dict
|
| 17 |
-
import pandas as pd
|
| 18 |
-
import numpy as np
|
| 19 |
from pydantic import BaseModel
|
| 20 |
from umap import UMAP
|
| 21 |
-
import tempfile
|
| 22 |
-
import traceback
|
| 23 |
-
import httpx
|
| 24 |
|
| 25 |
from utils.data_loader import ModelDataLoader
|
| 26 |
from utils.embeddings import ModelEmbedder
|
| 27 |
from utils.dimensionality_reduction import DimensionReducer
|
| 28 |
from utils.network_analysis import ModelNetworkBuilder
|
|
|
|
| 29 |
from services.model_tracker import get_tracker
|
| 30 |
-
from services.model_tracker_improved import get_improved_tracker
|
| 31 |
from services.arxiv_api import extract_arxiv_ids, fetch_arxiv_papers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
|
| 35 |
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
@app.exception_handler(Exception)
|
| 38 |
async def global_exception_handler(request: Request, exc: Exception):
|
| 39 |
-
"
|
| 40 |
-
import traceback
|
| 41 |
-
error_detail = str(exc)
|
| 42 |
-
traceback_str = traceback.format_exc()
|
| 43 |
-
import sys
|
| 44 |
-
sys.stderr.write(f"Unhandled exception: {error_detail}\n{traceback_str}\n")
|
| 45 |
return JSONResponse(
|
| 46 |
status_code=500,
|
| 47 |
-
content={"detail":
|
| 48 |
-
headers=
|
| 49 |
-
"Access-Control-Allow-Origin": "*",
|
| 50 |
-
"Access-Control-Allow-Methods": "*",
|
| 51 |
-
"Access-Control-Allow-Headers": "*",
|
| 52 |
-
}
|
| 53 |
)
|
| 54 |
|
| 55 |
@app.exception_handler(StarletteHTTPException)
|
| 56 |
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
| 57 |
-
"""HTTP exception handler with CORS headers."""
|
| 58 |
return JSONResponse(
|
| 59 |
status_code=exc.status_code,
|
| 60 |
content={"detail": exc.detail},
|
| 61 |
-
headers=
|
| 62 |
-
"Access-Control-Allow-Origin": "*",
|
| 63 |
-
"Access-Control-Allow-Methods": "*",
|
| 64 |
-
"Access-Control-Allow-Headers": "*",
|
| 65 |
-
}
|
| 66 |
)
|
| 67 |
|
| 68 |
@app.exception_handler(RequestValidationError)
|
| 69 |
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
| 70 |
-
"""Validation exception handler with CORS headers."""
|
| 71 |
return JSONResponse(
|
| 72 |
status_code=422,
|
| 73 |
content={"detail": exc.errors()},
|
| 74 |
-
headers=
|
| 75 |
-
"Access-Control-Allow-Origin": "*",
|
| 76 |
-
"Access-Control-Allow-Methods": "*",
|
| 77 |
-
"Access-Control-Allow-Headers": "*",
|
| 78 |
-
}
|
| 79 |
)
|
| 80 |
|
| 81 |
-
|
| 82 |
-
# Update allow_origins with your Netlify URL in production
|
| 83 |
-
# Note: Add your specific Netlify URL after deployment
|
| 84 |
-
FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
| 85 |
-
# Allow all origins for development (restrict in production)
|
| 86 |
-
ALLOW_ALL_ORIGINS = os.getenv("ALLOW_ALL_ORIGINS", "true").lower() == "true"
|
| 87 |
-
if ALLOW_ALL_ORIGINS:
|
| 88 |
app.add_middleware(
|
| 89 |
CORSMiddleware,
|
| 90 |
-
allow_origins=["*"],
|
| 91 |
-
allow_credentials=False,
|
| 92 |
allow_methods=["*"],
|
| 93 |
allow_headers=["*"],
|
| 94 |
)
|
| 95 |
else:
|
| 96 |
app.add_middleware(
|
| 97 |
CORSMiddleware,
|
| 98 |
-
allow_origins=[
|
| 99 |
-
"http://localhost:3000", # Local development
|
| 100 |
-
FRONTEND_URL, # Production frontend URL
|
| 101 |
-
# Add your Netlify URL here after deployment, e.g.:
|
| 102 |
-
# "https://your-app-name.netlify.app",
|
| 103 |
-
],
|
| 104 |
allow_credentials=True,
|
| 105 |
allow_methods=["*"],
|
| 106 |
allow_headers=["*"],
|
| 107 |
)
|
| 108 |
|
| 109 |
-
data_loader = ModelDataLoader()
|
| 110 |
-
embedder: Optional[ModelEmbedder] = None
|
| 111 |
-
reducer: Optional[DimensionReducer] = None
|
| 112 |
-
df: Optional[pd.DataFrame] = None
|
| 113 |
-
embeddings: Optional[np.ndarray] = None
|
| 114 |
-
reduced_embeddings: Optional[np.ndarray] = None
|
| 115 |
-
cluster_labels: Optional[np.ndarray] = None # Cached cluster assignments
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
class FilterParams(BaseModel):
|
| 119 |
-
min_downloads: int = 0
|
| 120 |
-
min_likes: int = 0
|
| 121 |
-
search_query: Optional[str] = None
|
| 122 |
-
libraries: Optional[List[str]] = None
|
| 123 |
-
pipeline_tags: Optional[List[str]] = None
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
class ModelPoint(BaseModel):
|
| 127 |
-
model_id: str
|
| 128 |
-
x: float
|
| 129 |
-
y: float
|
| 130 |
-
z: float # 3D coordinate
|
| 131 |
-
library_name: Optional[str]
|
| 132 |
-
pipeline_tag: Optional[str]
|
| 133 |
-
downloads: int
|
| 134 |
-
likes: int
|
| 135 |
-
trending_score: Optional[float]
|
| 136 |
-
tags: Optional[str]
|
| 137 |
-
parent_model: Optional[str] = None
|
| 138 |
-
licenses: Optional[str] = None
|
| 139 |
-
family_depth: Optional[int] = None # Generation depth in family tree (0 = root)
|
| 140 |
-
cluster_id: Optional[int] = None # Cluster assignment for visualization
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
@app.on_event("startup")
|
| 144 |
async def startup_event():
|
| 145 |
-
|
| 146 |
-
global df, embedder, reducer, embeddings, reduced_embeddings
|
| 147 |
|
| 148 |
-
import os
|
| 149 |
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 150 |
root_dir = os.path.dirname(backend_dir)
|
| 151 |
cache_dir = os.path.join(root_dir, "cache")
|
| 152 |
os.makedirs(cache_dir, exist_ok=True)
|
| 153 |
|
| 154 |
embeddings_cache = os.path.join(cache_dir, "embeddings.pkl")
|
|
|
|
|
|
|
| 155 |
reduced_cache_umap = os.path.join(cache_dir, "reduced_umap_3d.pkl")
|
|
|
|
| 156 |
reducer_cache_umap = os.path.join(cache_dir, "reducer_umap_3d.pkl")
|
|
|
|
| 157 |
|
| 158 |
-
|
| 159 |
-
if
|
| 160 |
-
sample_size =
|
| 161 |
else:
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
df = data_loader.
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
df.set_index('model_id', drop=False, inplace=True)
|
| 170 |
for col in ['downloads', 'likes']:
|
| 171 |
-
if col in df.columns:
|
| 172 |
-
df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0).astype(int)
|
| 173 |
|
| 174 |
-
embedder = ModelEmbedder()
|
| 175 |
|
|
|
|
| 176 |
if os.path.exists(embeddings_cache):
|
| 177 |
try:
|
| 178 |
-
embeddings = embedder.load_embeddings(embeddings_cache)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
except Exception as e:
|
| 180 |
-
embeddings
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
embeddings = embedder.generate_embeddings(texts, batch_size=128)
|
| 185 |
-
embedder.save_embeddings(embeddings, embeddings_cache)
|
| 186 |
|
| 187 |
-
reducer
|
|
|
|
| 188 |
|
| 189 |
if os.path.exists(reduced_cache_umap) and os.path.exists(reducer_cache_umap):
|
| 190 |
try:
|
| 191 |
-
import pickle
|
| 192 |
with open(reduced_cache_umap, 'rb') as f:
|
| 193 |
-
reduced_embeddings = pickle.load(f)
|
| 194 |
-
reducer.load_reducer(reducer_cache_umap)
|
| 195 |
-
except
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
| 200 |
n_components=3,
|
| 201 |
n_neighbors=30,
|
| 202 |
min_dist=0.3,
|
|
@@ -206,61 +220,57 @@ async def startup_event():
|
|
| 206 |
low_memory=True,
|
| 207 |
spread=1.5
|
| 208 |
)
|
| 209 |
-
reduced_embeddings = reducer.fit_transform(embeddings)
|
| 210 |
-
import pickle
|
| 211 |
with open(reduced_cache_umap, 'wb') as f:
|
| 212 |
-
pickle.dump(reduced_embeddings, f)
|
| 213 |
-
reducer.save_reducer(reducer_cache_umap)
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
def calculate_family_depths(df: pd.DataFrame) -> Dict[str, int]:
|
| 217 |
-
"""
|
| 218 |
-
Calculate family tree depth for each model.
|
| 219 |
-
Returns a dictionary mapping model_id to depth (0 = root, 1 = first generation, etc.)
|
| 220 |
-
"""
|
| 221 |
-
depths = {}
|
| 222 |
-
visited = set()
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
if model_id in visited:
|
| 228 |
-
# Circular reference, treat as root
|
| 229 |
-
depths[model_id] = 0
|
| 230 |
-
return 0
|
| 231 |
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
parent_id_str = str(parent_id)
|
| 241 |
-
if parent_id_str in df.index:
|
| 242 |
-
depth = get_depth(parent_id_str) + 1
|
| 243 |
-
else:
|
| 244 |
-
depth = 0 # Parent not in dataset, treat as root
|
| 245 |
-
else:
|
| 246 |
-
depth = 0 # No parent, this is a root
|
| 247 |
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
|
| 259 |
def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np.ndarray:
|
| 260 |
-
"""
|
| 261 |
-
Compute clusters using KMeans on reduced embeddings.
|
| 262 |
-
Returns cluster labels for each point.
|
| 263 |
-
"""
|
| 264 |
from sklearn.cluster import KMeans
|
| 265 |
|
| 266 |
n_samples = len(reduced_embeddings)
|
|
@@ -268,8 +278,7 @@ def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np
|
|
| 268 |
n_clusters = max(1, n_samples // 10)
|
| 269 |
|
| 270 |
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
|
| 271 |
-
|
| 272 |
-
return cluster_labels
|
| 273 |
|
| 274 |
|
| 275 |
@app.get("/")
|
|
@@ -284,24 +293,16 @@ async def get_models(
|
|
| 284 |
search_query: Optional[str] = Query(None),
|
| 285 |
color_by: str = Query("library_name"),
|
| 286 |
size_by: str = Query("downloads"),
|
| 287 |
-
max_points: Optional[int] = Query(None),
|
| 288 |
-
projection_method: str = Query("umap"),
|
| 289 |
-
base_models_only: bool = Query(False)
|
|
|
|
|
|
|
| 290 |
):
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
Supports multiple projection methods: UMAP or t-SNE.
|
| 294 |
-
If base_models_only=True, only returns root models (models without a parent_model).
|
| 295 |
-
|
| 296 |
-
Returns a JSON object with:
|
| 297 |
-
- models: List of ModelPoint objects
|
| 298 |
-
- filtered_count: Number of models matching filters (before max_points sampling)
|
| 299 |
-
- returned_count: Number of models actually returned (after max_points sampling)
|
| 300 |
-
"""
|
| 301 |
-
global df, embedder, reducer, embeddings, reduced_embeddings
|
| 302 |
|
| 303 |
-
|
| 304 |
-
raise HTTPException(status_code=503, detail="Data not loaded")
|
| 305 |
|
| 306 |
# Filter data
|
| 307 |
filtered_df = data_loader.filter_data(
|
|
@@ -321,7 +322,12 @@ async def get_models(
|
|
| 321 |
(filtered_df['parent_model'].astype(str) == 'nan')
|
| 322 |
]
|
| 323 |
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
filtered_count = len(filtered_df)
|
| 326 |
|
| 327 |
if len(filtered_df) == 0:
|
|
@@ -332,42 +338,53 @@ async def get_models(
|
|
| 332 |
}
|
| 333 |
|
| 334 |
if max_points is not None and len(filtered_df) > max_points:
|
| 335 |
-
# Use stratified sampling to preserve distribution of important attributes
|
| 336 |
-
# Sample proportionally from different libraries/pipelines for better representation
|
| 337 |
if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
|
| 338 |
-
#
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
|
|
|
| 343 |
if len(filtered_df) > max_points:
|
| 344 |
-
filtered_df = filtered_df.sample(n=max_points, random_state=42)
|
|
|
|
|
|
|
| 345 |
else:
|
| 346 |
-
filtered_df = filtered_df.sample(n=max_points, random_state=42)
|
| 347 |
-
|
| 348 |
-
if embeddings is None:
|
| 349 |
-
raise HTTPException(status_code=503, detail="Embeddings not loaded")
|
| 350 |
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 354 |
root_dir = os.path.dirname(backend_dir)
|
| 355 |
cache_dir = os.path.join(root_dir, "cache")
|
| 356 |
-
|
| 357 |
-
|
|
|
|
| 358 |
|
| 359 |
if os.path.exists(reduced_cache) and os.path.exists(reducer_cache):
|
| 360 |
try:
|
| 361 |
-
import pickle
|
| 362 |
with open(reduced_cache, 'rb') as f:
|
| 363 |
-
|
| 364 |
if reducer is None or reducer.method != projection_method.lower():
|
| 365 |
reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
|
| 366 |
reducer.load_reducer(reducer_cache)
|
| 367 |
-
except
|
| 368 |
-
|
|
|
|
| 369 |
|
| 370 |
-
if
|
| 371 |
if reducer is None or reducer.method != projection_method.lower():
|
| 372 |
reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
|
| 373 |
if projection_method.lower() == "umap":
|
|
@@ -381,52 +398,91 @@ async def get_models(
|
|
| 381 |
low_memory=True,
|
| 382 |
spread=1.5
|
| 383 |
)
|
| 384 |
-
|
| 385 |
-
import pickle
|
| 386 |
with open(reduced_cache, 'wb') as f:
|
| 387 |
-
pickle.dump(
|
| 388 |
reducer.save_reducer(reducer_cache)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
-
#
|
| 391 |
-
# Map filtered dataframe indices to original dataframe integer positions
|
| 392 |
-
# Since df is indexed by model_id, we need to get the integer positions
|
| 393 |
if df.index.name == 'model_id' or 'model_id' in df.index.names:
|
| 394 |
-
#
|
| 395 |
-
|
| 396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
else:
|
| 398 |
-
#
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
|
|
|
|
| 404 |
family_depths = calculate_family_depths(df)
|
| 405 |
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
|
|
|
|
|
|
| 409 |
|
| 410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
-
# Build response with optimized vectorized operations
|
| 413 |
-
# Pre-extract arrays for faster access
|
| 414 |
model_ids = filtered_df['model_id'].astype(str).values
|
| 415 |
-
library_names = filtered_df
|
| 416 |
-
pipeline_tags = filtered_df
|
| 417 |
-
downloads_arr = filtered_df
|
| 418 |
-
likes_arr = filtered_df
|
| 419 |
-
trending_scores = filtered_df.get('trendingScore', pd.Series()).values
|
| 420 |
-
tags_arr = filtered_df.get('tags', pd.Series()).values
|
| 421 |
-
parent_models = filtered_df.get('parent_model', pd.Series()).values
|
| 422 |
-
licenses_arr = filtered_df.get('licenses', pd.Series()).values
|
| 423 |
-
|
| 424 |
-
|
| 425 |
x_coords = filtered_reduced[:, 0].astype(float)
|
| 426 |
y_coords = filtered_reduced[:, 1].astype(float)
|
| 427 |
z_coords = filtered_reduced[:, 2].astype(float) if filtered_reduced.shape[1] > 2 else np.zeros(len(filtered_reduced), dtype=float)
|
| 428 |
-
|
| 429 |
-
# Build models list with optimized operations
|
| 430 |
models = [
|
| 431 |
ModelPoint(
|
| 432 |
model_id=model_ids[idx],
|
|
@@ -442,28 +498,42 @@ async def get_models(
|
|
| 442 |
parent_model=parent_models[idx] if idx < len(parent_models) and pd.notna(parent_models[idx]) else None,
|
| 443 |
licenses=licenses_arr[idx] if idx < len(licenses_arr) and pd.notna(licenses_arr[idx]) else None,
|
| 444 |
family_depth=family_depths.get(model_ids[idx], None),
|
| 445 |
-
cluster_id=int(filtered_clusters[idx]) if idx < len(filtered_clusters) else None
|
|
|
|
| 446 |
)
|
| 447 |
for idx in range(len(filtered_df))
|
| 448 |
]
|
| 449 |
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
|
| 453 |
@app.get("/api/stats")
|
| 454 |
async def get_stats():
|
| 455 |
"""Get dataset statistics."""
|
| 456 |
if df is None:
|
| 457 |
-
raise
|
| 458 |
|
| 459 |
-
# Use len(df.index) to handle both regular and indexed DataFrames correctly
|
| 460 |
total_models = len(df.index) if hasattr(df, 'index') else len(df)
|
| 461 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
return {
|
| 463 |
"total_models": total_models,
|
| 464 |
"unique_libraries": int(df['library_name'].nunique()) if 'library_name' in df.columns else 0,
|
| 465 |
"unique_pipelines": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0,
|
| 466 |
"unique_task_types": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0, # Alias for clarity
|
|
|
|
|
|
|
| 467 |
"avg_downloads": float(df['downloads'].mean()) if 'downloads' in df.columns else 0,
|
| 468 |
"avg_likes": float(df['likes'].mean()) if 'likes' in df.columns else 0
|
| 469 |
}
|
|
@@ -473,7 +543,7 @@ async def get_stats():
|
|
| 473 |
async def get_model_details(model_id: str):
|
| 474 |
"""Get detailed information about a specific model."""
|
| 475 |
if df is None:
|
| 476 |
-
raise
|
| 477 |
|
| 478 |
model = df[df.get('model_id', '') == model_id]
|
| 479 |
if len(model) == 0:
|
|
@@ -481,11 +551,9 @@ async def get_model_details(model_id: str):
|
|
| 481 |
|
| 482 |
model = model.iloc[0]
|
| 483 |
|
| 484 |
-
# Extract arXiv IDs from tags
|
| 485 |
tags_str = str(model.get('tags', '')) if pd.notna(model.get('tags')) else ''
|
| 486 |
arxiv_ids = extract_arxiv_ids(tags_str)
|
| 487 |
|
| 488 |
-
# Fetch arXiv papers if any IDs found
|
| 489 |
papers = []
|
| 490 |
if arxiv_ids:
|
| 491 |
papers = await fetch_arxiv_papers(arxiv_ids[:5]) # Limit to 5 papers
|
|
@@ -505,6 +573,8 @@ async def get_model_details(model_id: str):
|
|
| 505 |
}
|
| 506 |
|
| 507 |
|
|
|
|
|
|
|
| 508 |
@app.get("/api/family/stats")
|
| 509 |
async def get_family_stats():
|
| 510 |
"""
|
|
@@ -512,9 +582,8 @@ async def get_family_stats():
|
|
| 512 |
Returns family size distribution, depth statistics, model card length by depth, etc.
|
| 513 |
"""
|
| 514 |
if df is None:
|
| 515 |
-
raise
|
| 516 |
|
| 517 |
-
# Calculate family sizes
|
| 518 |
family_sizes = {}
|
| 519 |
root_models = set()
|
| 520 |
|
|
@@ -528,14 +597,13 @@ async def get_family_stats():
|
|
| 528 |
family_sizes[model_id] = 0
|
| 529 |
else:
|
| 530 |
parent_id_str = str(parent_id)
|
| 531 |
-
# Find root of this family
|
| 532 |
root = parent_id_str
|
| 533 |
visited = set()
|
| 534 |
while root in df.index and pd.notna(df.loc[root].get('parent_model')):
|
| 535 |
parent = df.loc[root].get('parent_model')
|
| 536 |
if pd.isna(parent) or str(parent) == 'nan' or str(parent) == '':
|
| 537 |
break
|
| 538 |
-
if str(parent) in visited:
|
| 539 |
break
|
| 540 |
visited.add(root)
|
| 541 |
root = str(parent)
|
|
@@ -544,18 +612,15 @@ async def get_family_stats():
|
|
| 544 |
family_sizes[root] = 0
|
| 545 |
family_sizes[root] += 1
|
| 546 |
|
| 547 |
-
# Count family sizes
|
| 548 |
size_distribution = {}
|
| 549 |
for root, size in family_sizes.items():
|
| 550 |
size_distribution[size] = size_distribution.get(size, 0) + 1
|
| 551 |
|
| 552 |
-
# Calculate depth statistics
|
| 553 |
depths = calculate_family_depths(df)
|
| 554 |
depth_counts = {}
|
| 555 |
for depth in depths.values():
|
| 556 |
depth_counts[depth] = depth_counts.get(depth, 0) + 1
|
| 557 |
|
| 558 |
-
# Calculate model card length by depth
|
| 559 |
model_card_lengths_by_depth = {}
|
| 560 |
if 'modelCard' in df.columns:
|
| 561 |
for idx, row in df.iterrows():
|
|
@@ -568,7 +633,6 @@ async def get_family_stats():
|
|
| 568 |
model_card_lengths_by_depth[depth] = []
|
| 569 |
model_card_lengths_by_depth[depth].append(card_length)
|
| 570 |
|
| 571 |
-
# Calculate statistics for each depth
|
| 572 |
model_card_stats = {}
|
| 573 |
for depth, lengths in model_card_lengths_by_depth.items():
|
| 574 |
if lengths:
|
|
@@ -593,99 +657,218 @@ async def get_family_stats():
|
|
| 593 |
}
|
| 594 |
|
| 595 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
@app.get("/api/family/{model_id}")
|
| 597 |
-
async def get_family_tree(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 598 |
"""
|
| 599 |
Get family tree for a model (ancestors and descendants).
|
| 600 |
Returns the model, its parent chain, and all children.
|
|
|
|
|
|
|
| 601 |
"""
|
| 602 |
if df is None:
|
| 603 |
-
raise
|
| 604 |
-
|
| 605 |
-
# Find the model
|
| 606 |
-
model_row = df[df.get('model_id', '') == model_id]
|
| 607 |
-
if len(model_row) == 0:
|
| 608 |
-
raise HTTPException(status_code=404, detail="Model not found")
|
| 609 |
-
|
| 610 |
-
family_models = []
|
| 611 |
-
visited = set()
|
| 612 |
|
| 613 |
-
# Get coordinates for family members
|
| 614 |
if reduced_embeddings is None:
|
| 615 |
raise HTTPException(status_code=503, detail="Embeddings not ready")
|
| 616 |
|
| 617 |
-
|
| 618 |
-
if 'parent_model' not in df.index.names and 'parent_model' in df.columns:
|
| 619 |
-
# Create a reverse index for faster parent lookups
|
| 620 |
-
parent_index = df[df['parent_model'].notna()].set_index('parent_model', drop=False, append=True)
|
| 621 |
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
return
|
| 626 |
visited.add(current_id)
|
| 627 |
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
|
|
|
| 644 |
|
| 645 |
-
def get_descendants(current_id: str, depth: int):
|
| 646 |
-
|
| 647 |
-
|
|
|
|
| 648 |
return
|
| 649 |
visited.add(current_id)
|
| 650 |
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
# Get descendants (children)
|
| 664 |
-
visited = set() # Reset for descendants
|
| 665 |
-
get_descendants(model_id, max_depth)
|
| 666 |
-
|
| 667 |
-
# Add the root model
|
| 668 |
-
visited.add(model_id)
|
| 669 |
-
|
| 670 |
-
# Get all family members with coordinates - optimized
|
| 671 |
-
if 'model_id' in df.index.names or df.index.name == 'model_id':
|
| 672 |
-
# Use index lookup if available
|
| 673 |
try:
|
| 674 |
family_df = df.loc[list(visited)]
|
| 675 |
except KeyError:
|
| 676 |
-
|
| 677 |
-
|
|
|
|
|
|
|
| 678 |
else:
|
| 679 |
family_df = df[df.get('model_id', '').isin(visited)]
|
| 680 |
|
| 681 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
family_reduced = reduced_embeddings[family_indices]
|
| 683 |
|
| 684 |
-
# Build family tree structure - optimized with vectorized operations
|
| 685 |
family_map = {}
|
| 686 |
for idx, (i, row) in enumerate(family_df.iterrows()):
|
| 687 |
-
model_id_val = str(row.get('model_id',
|
| 688 |
-
parent_id = row.get('parent_model')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 689 |
|
| 690 |
family_map[model_id_val] = {
|
| 691 |
"model_id": model_id_val,
|
|
@@ -696,12 +879,12 @@ async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10))
|
|
| 696 |
"pipeline_tag": str(row.get('pipeline_tag')) if pd.notna(row.get('pipeline_tag')) else None,
|
| 697 |
"downloads": int(row.get('downloads', 0)) if pd.notna(row.get('downloads')) else 0,
|
| 698 |
"likes": int(row.get('likes', 0)) if pd.notna(row.get('likes')) else 0,
|
| 699 |
-
"parent_model":
|
| 700 |
"licenses": str(row.get('licenses')) if pd.notna(row.get('licenses')) else None,
|
|
|
|
| 701 |
"children": []
|
| 702 |
}
|
| 703 |
|
| 704 |
-
# Build tree structure
|
| 705 |
root_models = []
|
| 706 |
for model_id_val, model_data in family_map.items():
|
| 707 |
parent_id = model_data["parent_model"]
|
|
@@ -711,7 +894,7 @@ async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10))
|
|
| 711 |
root_models.append(model_id_val)
|
| 712 |
|
| 713 |
return {
|
| 714 |
-
"root_model":
|
| 715 |
"family": list(family_map.values()),
|
| 716 |
"family_map": family_map,
|
| 717 |
"root_models": root_models
|
|
@@ -720,7 +903,9 @@ async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10))
|
|
| 720 |
|
| 721 |
@app.get("/api/search")
|
| 722 |
async def search_models(
|
| 723 |
-
|
|
|
|
|
|
|
| 724 |
graph_aware: bool = Query(False),
|
| 725 |
include_neighbors: bool = Query(True)
|
| 726 |
):
|
|
@@ -729,47 +914,79 @@ async def search_models(
|
|
| 729 |
Enhanced with graph-aware search option that includes network relationships.
|
| 730 |
"""
|
| 731 |
if df is None:
|
| 732 |
-
raise
|
|
|
|
|
|
|
|
|
|
| 733 |
|
| 734 |
if graph_aware:
|
| 735 |
-
# Use graph-aware search
|
| 736 |
try:
|
| 737 |
network_builder = ModelNetworkBuilder(df)
|
| 738 |
-
# Build network for top models (for performance)
|
| 739 |
top_models = network_builder.get_top_models_by_field(n=1000)
|
| 740 |
model_ids = [mid for mid, _ in top_models]
|
| 741 |
graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined')
|
| 742 |
|
| 743 |
results = network_builder.search_graph_aware(
|
| 744 |
-
query=
|
| 745 |
graph=graph,
|
| 746 |
-
max_results=
|
| 747 |
include_neighbors=include_neighbors
|
| 748 |
)
|
| 749 |
|
| 750 |
-
return {"results": results, "search_type": "graph_aware"}
|
| 751 |
-
except
|
| 752 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 753 |
|
| 754 |
-
|
| 755 |
-
matches = df[
|
| 756 |
-
df.get('model_id', '').astype(str).str.lower().str.contains(query_lower, na=False)
|
| 757 |
-
].head(20) # Limit to 20 results
|
| 758 |
|
| 759 |
results = []
|
| 760 |
for _, row in matches.iterrows():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 761 |
results.append({
|
| 762 |
-
"model_id":
|
| 763 |
-
"
|
| 764 |
-
"
|
| 765 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 766 |
"downloads": int(row.get('downloads', 0)),
|
| 767 |
"likes": int(row.get('likes', 0)),
|
| 768 |
"parent_model": row.get('parent_model') if pd.notna(row.get('parent_model')) else None,
|
| 769 |
"match_type": "direct"
|
| 770 |
})
|
| 771 |
|
| 772 |
-
return {"results": results, "search_type": "basic"}
|
| 773 |
|
| 774 |
|
| 775 |
@app.get("/api/similar/{model_id}")
|
|
@@ -778,12 +995,12 @@ async def get_similar_models(model_id: str, k: int = Query(10, ge=1, le=50)):
|
|
| 778 |
Get k-nearest neighbors of a model based on embedding similarity.
|
| 779 |
Returns similar models with distance scores.
|
| 780 |
"""
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
if df is None or embeddings is None:
|
| 784 |
raise HTTPException(status_code=503, detail="Data not loaded")
|
| 785 |
|
| 786 |
-
|
|
|
|
|
|
|
| 787 |
if 'model_id' in df.index.names or df.index.name == 'model_id':
|
| 788 |
try:
|
| 789 |
model_row = df.loc[[model_id]]
|
|
@@ -797,16 +1014,11 @@ async def get_similar_models(model_id: str, k: int = Query(10, ge=1, le=50)):
|
|
| 797 |
model_idx = model_row.index[0]
|
| 798 |
model_embedding = embeddings[model_idx]
|
| 799 |
|
| 800 |
-
# Calculate cosine similarity to all other models - optimized
|
| 801 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 802 |
-
# Use vectorized operations for better performance
|
| 803 |
model_embedding_2d = model_embedding.reshape(1, -1)
|
| 804 |
similarities = cosine_similarity(model_embedding_2d, embeddings)[0]
|
| 805 |
|
| 806 |
-
# Get top k similar models (excluding itself) - use argpartition for speed
|
| 807 |
-
# argpartition is faster than full sort for top-k
|
| 808 |
top_k_indices = np.argpartition(similarities, -k-1)[-k-1:-1]
|
| 809 |
-
# Sort only the top k (much faster than sorting all)
|
| 810 |
top_k_indices = top_k_indices[np.argsort(similarities[top_k_indices])][::-1]
|
| 811 |
|
| 812 |
similar_models = []
|
|
@@ -817,7 +1029,7 @@ async def get_similar_models(model_id: str, k: int = Query(10, ge=1, le=50)):
|
|
| 817 |
similar_models.append({
|
| 818 |
"model_id": row.get('model_id', 'Unknown'),
|
| 819 |
"similarity": float(similarities[idx]),
|
| 820 |
-
"distance": float(1 - similarities[idx]),
|
| 821 |
"library_name": row.get('library_name'),
|
| 822 |
"pipeline_tag": row.get('pipeline_tag'),
|
| 823 |
"downloads": int(row.get('downloads', 0)),
|
|
@@ -843,11 +1055,12 @@ async def get_models_by_semantic_similarity(
|
|
| 843 |
Returns models with their similarity scores and coordinates.
|
| 844 |
Useful for exploring the embedding space around a specific model.
|
| 845 |
"""
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
if df is None or embeddings is None:
|
| 849 |
raise HTTPException(status_code=503, detail="Data not loaded")
|
| 850 |
|
|
|
|
|
|
|
|
|
|
| 851 |
# Find the query model
|
| 852 |
if 'model_id' in df.index.names or df.index.name == 'model_id':
|
| 853 |
try:
|
|
@@ -863,7 +1076,6 @@ async def get_models_by_semantic_similarity(
|
|
| 863 |
|
| 864 |
query_embedding = embeddings[model_idx]
|
| 865 |
|
| 866 |
-
# Filter by downloads/likes first for performance
|
| 867 |
filtered_df = data_loader.filter_data(
|
| 868 |
df=df,
|
| 869 |
min_downloads=min_downloads,
|
|
@@ -873,32 +1085,26 @@ async def get_models_by_semantic_similarity(
|
|
| 873 |
pipeline_tags=None
|
| 874 |
)
|
| 875 |
|
| 876 |
-
# Get indices of filtered models
|
| 877 |
if df.index.name == 'model_id' or 'model_id' in df.index.names:
|
| 878 |
filtered_indices = [df.index.get_loc(idx) for idx in filtered_df.index]
|
| 879 |
filtered_indices = np.array(filtered_indices, dtype=int)
|
| 880 |
else:
|
| 881 |
filtered_indices = filtered_df.index.values.astype(int)
|
| 882 |
|
| 883 |
-
# Calculate similarities only for filtered models
|
| 884 |
filtered_embeddings = embeddings[filtered_indices]
|
| 885 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 886 |
query_embedding_2d = query_embedding.reshape(1, -1)
|
| 887 |
similarities = cosine_similarity(query_embedding_2d, filtered_embeddings)[0]
|
| 888 |
|
| 889 |
-
# Get top k similar models
|
| 890 |
top_k_local_indices = np.argpartition(similarities, -k)[-k:]
|
| 891 |
top_k_local_indices = top_k_local_indices[np.argsort(similarities[top_k_local_indices])][::-1]
|
| 892 |
|
| 893 |
-
# Get reduced embeddings for visualization
|
| 894 |
if reduced_embeddings is None:
|
| 895 |
raise HTTPException(status_code=503, detail="Reduced embeddings not ready")
|
| 896 |
|
| 897 |
-
# Map back to original indices
|
| 898 |
top_k_original_indices = filtered_indices[top_k_local_indices]
|
| 899 |
top_k_reduced = reduced_embeddings[top_k_original_indices]
|
| 900 |
|
| 901 |
-
# Build response
|
| 902 |
similar_models = []
|
| 903 |
for i, orig_idx in enumerate(top_k_original_indices):
|
| 904 |
row = df.iloc[orig_idx]
|
|
@@ -935,11 +1141,12 @@ async def get_distance(
|
|
| 935 |
"""
|
| 936 |
Calculate distance/similarity between two models.
|
| 937 |
"""
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
if df is None or embeddings is None:
|
| 941 |
raise HTTPException(status_code=503, detail="Data not loaded")
|
| 942 |
|
|
|
|
|
|
|
|
|
|
| 943 |
# Find both models - optimized with index lookup
|
| 944 |
if 'model_id' in df.index.names or df.index.name == 'model_id':
|
| 945 |
try:
|
|
@@ -976,7 +1183,7 @@ async def export_models(model_ids: List[str]):
|
|
| 976 |
Export selected models as JSON with full metadata.
|
| 977 |
"""
|
| 978 |
if df is None:
|
| 979 |
-
raise
|
| 980 |
|
| 981 |
# Optimized export with index lookup
|
| 982 |
if 'model_id' in df.index.names or df.index.name == 'model_id':
|
|
@@ -991,7 +1198,6 @@ async def export_models(model_ids: List[str]):
|
|
| 991 |
if len(exported) == 0:
|
| 992 |
return {"models": []}
|
| 993 |
|
| 994 |
-
# Use list comprehension for faster building
|
| 995 |
models = [
|
| 996 |
{
|
| 997 |
"model_id": str(row.get('model_id', '')),
|
|
@@ -1029,12 +1235,10 @@ async def get_cooccurrence_network(
|
|
| 1029 |
Returns network graph data suitable for visualization.
|
| 1030 |
"""
|
| 1031 |
if df is None:
|
| 1032 |
-
raise
|
| 1033 |
|
| 1034 |
try:
|
| 1035 |
network_builder = ModelNetworkBuilder(df)
|
| 1036 |
-
|
| 1037 |
-
# Get top models by field
|
| 1038 |
top_models = network_builder.get_top_models_by_field(
|
| 1039 |
library=library,
|
| 1040 |
pipeline_tag=pipeline_tag,
|
|
@@ -1051,14 +1255,11 @@ async def get_cooccurrence_network(
|
|
| 1051 |
}
|
| 1052 |
|
| 1053 |
model_ids = [mid for mid, _ in top_models]
|
| 1054 |
-
|
| 1055 |
-
# Build co-occurrence network
|
| 1056 |
graph = network_builder.build_cooccurrence_network(
|
| 1057 |
model_ids=model_ids,
|
| 1058 |
cooccurrence_method=cooccurrence_method
|
| 1059 |
)
|
| 1060 |
|
| 1061 |
-
# Convert to JSON-serializable format
|
| 1062 |
nodes = []
|
| 1063 |
for node_id, attrs in graph.nodes(data=True):
|
| 1064 |
nodes.append({
|
|
@@ -1086,45 +1287,70 @@ async def get_cooccurrence_network(
|
|
| 1086 |
"links": links,
|
| 1087 |
"statistics": stats
|
| 1088 |
}
|
| 1089 |
-
|
| 1090 |
-
|
| 1091 |
raise HTTPException(status_code=500, detail=f"Error building network: {str(e)}")
|
| 1092 |
|
| 1093 |
|
| 1094 |
@app.get("/api/network/family/{model_id}")
|
| 1095 |
async def get_family_network(
|
| 1096 |
model_id: str,
|
| 1097 |
-
max_depth: int = Query(
|
|
|
|
|
|
|
| 1098 |
):
|
| 1099 |
"""
|
| 1100 |
Build family tree network for a model (directed graph).
|
| 1101 |
-
Returns network graph data showing parent-child relationships.
|
|
|
|
| 1102 |
"""
|
| 1103 |
if df is None:
|
| 1104 |
-
raise
|
| 1105 |
|
| 1106 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1107 |
network_builder = ModelNetworkBuilder(df)
|
| 1108 |
graph = network_builder.build_family_tree_network(
|
| 1109 |
root_model_id=model_id,
|
| 1110 |
-
max_depth=max_depth
|
|
|
|
|
|
|
| 1111 |
)
|
| 1112 |
|
| 1113 |
-
# Convert to JSON-serializable format
|
| 1114 |
nodes = []
|
| 1115 |
for node_id, attrs in graph.nodes(data=True):
|
| 1116 |
nodes.append({
|
| 1117 |
"id": node_id,
|
| 1118 |
"title": attrs.get('title', node_id),
|
| 1119 |
-
"freq": attrs.get('freq', 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1120 |
})
|
| 1121 |
|
| 1122 |
links = []
|
| 1123 |
-
for source, target in graph.edges():
|
| 1124 |
-
|
| 1125 |
"source": source,
|
| 1126 |
-
"target": target
|
| 1127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1128 |
|
| 1129 |
stats = network_builder.get_network_statistics(graph)
|
| 1130 |
|
|
@@ -1134,8 +1360,8 @@ async def get_family_network(
|
|
| 1134 |
"statistics": stats,
|
| 1135 |
"root_model": model_id
|
| 1136 |
}
|
| 1137 |
-
|
| 1138 |
-
|
| 1139 |
raise HTTPException(status_code=500, detail=f"Error building family network: {str(e)}")
|
| 1140 |
|
| 1141 |
|
|
@@ -1150,11 +1376,10 @@ async def get_model_neighbors(
|
|
| 1150 |
Similar to graph database queries for finding connected nodes.
|
| 1151 |
"""
|
| 1152 |
if df is None:
|
| 1153 |
-
raise
|
| 1154 |
|
| 1155 |
try:
|
| 1156 |
network_builder = ModelNetworkBuilder(df)
|
| 1157 |
-
# Build network for top models (for performance)
|
| 1158 |
top_models = network_builder.get_top_models_by_field(n=1000)
|
| 1159 |
model_ids = [mid for mid, _ in top_models]
|
| 1160 |
graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined')
|
|
@@ -1171,8 +1396,8 @@ async def get_model_neighbors(
|
|
| 1171 |
"neighbors": neighbors,
|
| 1172 |
"count": len(neighbors)
|
| 1173 |
}
|
| 1174 |
-
|
| 1175 |
-
|
| 1176 |
raise HTTPException(status_code=500, detail=f"Error finding neighbors: {str(e)}")
|
| 1177 |
|
| 1178 |
|
|
@@ -1187,7 +1412,7 @@ async def find_path_between_models(
|
|
| 1187 |
Similar to graph database path queries.
|
| 1188 |
"""
|
| 1189 |
if df is None:
|
| 1190 |
-
raise
|
| 1191 |
|
| 1192 |
try:
|
| 1193 |
network_builder = ModelNetworkBuilder(df)
|
|
@@ -1235,7 +1460,7 @@ async def search_by_cooccurrence(
|
|
| 1235 |
Similar to graph database queries for co-assignment patterns.
|
| 1236 |
"""
|
| 1237 |
if df is None:
|
| 1238 |
-
raise
|
| 1239 |
|
| 1240 |
try:
|
| 1241 |
network_builder = ModelNetworkBuilder(df)
|
|
@@ -1272,7 +1497,7 @@ async def get_model_relationships(
|
|
| 1272 |
Similar to graph database relationship queries.
|
| 1273 |
"""
|
| 1274 |
if df is None:
|
| 1275 |
-
raise
|
| 1276 |
|
| 1277 |
try:
|
| 1278 |
network_builder = ModelNetworkBuilder(df)
|
|
@@ -1297,32 +1522,57 @@ async def get_model_relationships(
|
|
| 1297 |
async def get_current_model_count(
|
| 1298 |
use_cache: bool = Query(True),
|
| 1299 |
force_refresh: bool = Query(False),
|
| 1300 |
-
use_dataset_snapshot: bool = Query(False)
|
|
|
|
| 1301 |
):
|
| 1302 |
"""
|
| 1303 |
Get the current number of models on Hugging Face Hub.
|
| 1304 |
-
|
| 1305 |
|
| 1306 |
Query Parameters:
|
| 1307 |
use_cache: Use cached results if available (default: True)
|
| 1308 |
force_refresh: Force refresh even if cache is valid (default: False)
|
| 1309 |
-
use_dataset_snapshot: Use dataset snapshot
|
|
|
|
| 1310 |
"""
|
| 1311 |
try:
|
|
|
|
|
|
|
| 1312 |
if use_dataset_snapshot:
|
| 1313 |
-
|
| 1314 |
-
tracker = get_improved_tracker()
|
| 1315 |
-
count_data = tracker.get_count_from_dataset_snapshot()
|
| 1316 |
if count_data is None:
|
| 1317 |
-
|
| 1318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1319 |
else:
|
| 1320 |
-
|
| 1321 |
-
tracker = get_improved_tracker()
|
| 1322 |
-
count_data = tracker.get_current_model_count(use_cache=use_cache, force_refresh=force_refresh)
|
| 1323 |
|
| 1324 |
return count_data
|
| 1325 |
except Exception as e:
|
|
|
|
| 1326 |
raise HTTPException(status_code=500, detail=f"Error fetching model count: {str(e)}")
|
| 1327 |
|
| 1328 |
|
|
@@ -1343,7 +1593,7 @@ async def get_historical_model_counts(
|
|
| 1343 |
try:
|
| 1344 |
from datetime import datetime
|
| 1345 |
|
| 1346 |
-
tracker =
|
| 1347 |
|
| 1348 |
start = None
|
| 1349 |
end = None
|
|
@@ -1373,7 +1623,7 @@ async def get_historical_model_counts(
|
|
| 1373 |
async def get_latest_model_count():
|
| 1374 |
"""Get the most recently recorded model count from database."""
|
| 1375 |
try:
|
| 1376 |
-
tracker =
|
| 1377 |
latest = tracker.get_latest_count()
|
| 1378 |
if latest is None:
|
| 1379 |
raise HTTPException(status_code=404, detail="No model counts recorded yet")
|
|
@@ -1397,16 +1647,14 @@ async def record_model_count(
|
|
| 1397 |
use_dataset_snapshot: Use dataset snapshot instead of API (faster, default: False)
|
| 1398 |
"""
|
| 1399 |
try:
|
| 1400 |
-
tracker =
|
| 1401 |
|
| 1402 |
-
# Fetch and record in background to avoid blocking
|
| 1403 |
def record():
|
| 1404 |
if use_dataset_snapshot:
|
| 1405 |
count_data = tracker.get_count_from_dataset_snapshot()
|
| 1406 |
if count_data:
|
| 1407 |
tracker.record_count(count_data, source="dataset_snapshot")
|
| 1408 |
else:
|
| 1409 |
-
# Fallback to API
|
| 1410 |
count_data = tracker.get_current_model_count(use_cache=False)
|
| 1411 |
tracker.record_count(count_data, source="api")
|
| 1412 |
else:
|
|
@@ -1433,7 +1681,7 @@ async def get_growth_stats(days: int = Query(7, ge=1, le=365)):
|
|
| 1433 |
days: Number of days to analyze
|
| 1434 |
"""
|
| 1435 |
try:
|
| 1436 |
-
tracker =
|
| 1437 |
stats = tracker.get_growth_stats(days)
|
| 1438 |
return stats
|
| 1439 |
except Exception as e:
|
|
@@ -1455,12 +1703,11 @@ async def export_network_graphml(
|
|
| 1455 |
Similar to Open Syllabus graph export functionality.
|
| 1456 |
"""
|
| 1457 |
if df is None:
|
| 1458 |
-
raise
|
| 1459 |
|
| 1460 |
try:
|
| 1461 |
network_builder = ModelNetworkBuilder(df)
|
| 1462 |
|
| 1463 |
-
# Get top models by field
|
| 1464 |
top_models = network_builder.get_top_models_by_field(
|
| 1465 |
library=library,
|
| 1466 |
pipeline_tag=pipeline_tag,
|
|
@@ -1473,29 +1720,24 @@ async def export_network_graphml(
|
|
| 1473 |
raise HTTPException(status_code=404, detail="No models found matching criteria")
|
| 1474 |
|
| 1475 |
model_ids = [mid for mid, _ in top_models]
|
| 1476 |
-
|
| 1477 |
-
# Build co-occurrence network
|
| 1478 |
graph = network_builder.build_cooccurrence_network(
|
| 1479 |
model_ids=model_ids,
|
| 1480 |
cooccurrence_method=cooccurrence_method
|
| 1481 |
)
|
| 1482 |
|
| 1483 |
-
# Create temporary file
|
| 1484 |
with tempfile.NamedTemporaryFile(mode='w', suffix='.graphml', delete=False) as tmp_file:
|
| 1485 |
tmp_path = tmp_file.name
|
| 1486 |
network_builder.export_graphml(graph, tmp_path)
|
| 1487 |
|
| 1488 |
-
# Schedule cleanup after response is sent
|
| 1489 |
background_tasks.add_task(os.unlink, tmp_path)
|
| 1490 |
|
| 1491 |
-
# Return file for download
|
| 1492 |
return FileResponse(
|
| 1493 |
tmp_path,
|
| 1494 |
media_type='application/xml',
|
| 1495 |
filename=f'network_{cooccurrence_method}_{n}_models.graphml'
|
| 1496 |
)
|
| 1497 |
-
|
| 1498 |
-
|
| 1499 |
raise HTTPException(status_code=500, detail=f"Error exporting network: {str(e)}")
|
| 1500 |
|
| 1501 |
|
|
@@ -1506,7 +1748,7 @@ async def get_model_papers(model_id: str):
|
|
| 1506 |
Extracts arXiv IDs from model tags and fetches paper information.
|
| 1507 |
"""
|
| 1508 |
if df is None:
|
| 1509 |
-
raise
|
| 1510 |
|
| 1511 |
model = df[df.get('model_id', '') == model_id]
|
| 1512 |
if len(model) == 0:
|
|
@@ -1535,36 +1777,131 @@ async def get_model_papers(model_id: str):
|
|
| 1535 |
}
|
| 1536 |
|
| 1537 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1538 |
@app.get("/api/model/{model_id}/files")
|
| 1539 |
async def get_model_files(model_id: str, branch: str = Query("main")):
|
| 1540 |
"""
|
| 1541 |
Get file tree for a model from Hugging Face.
|
| 1542 |
Proxies the request to avoid CORS issues.
|
|
|
|
| 1543 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1544 |
try:
|
| 1545 |
-
|
| 1546 |
-
branches_to_try = [branch, "main", "master"] if branch not in ["main", "master"] else [branch, "main" if branch == "master" else "master"]
|
| 1547 |
-
|
| 1548 |
-
async with httpx.AsyncClient(timeout=10.0) as client:
|
| 1549 |
for branch_name in branches_to_try:
|
| 1550 |
try:
|
| 1551 |
url = f"https://huggingface.co/api/models/{model_id}/tree/{branch_name}"
|
| 1552 |
response = await client.get(url)
|
|
|
|
| 1553 |
if response.status_code == 200:
|
| 1554 |
-
|
| 1555 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1556 |
continue
|
| 1557 |
|
| 1558 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1559 |
except httpx.TimeoutException:
|
| 1560 |
-
raise HTTPException(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1561 |
except Exception as e:
|
| 1562 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1563 |
|
| 1564 |
|
| 1565 |
if __name__ == "__main__":
|
| 1566 |
import uvicorn
|
| 1567 |
-
# Use PORT environment variable for cloud platforms (Railway, Render, Heroku)
|
| 1568 |
port = int(os.getenv("PORT", 8000))
|
| 1569 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 1570 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import sys
|
| 2 |
import os
|
| 3 |
+
import pickle
|
| 4 |
+
import tempfile
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Optional, List, Dict
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
import httpx
|
| 12 |
from fastapi import FastAPI, HTTPException, Query, BackgroundTasks, Request
|
| 13 |
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
from fastapi.middleware.gzip import GZipMiddleware
|
| 15 |
from fastapi.responses import FileResponse, JSONResponse
|
| 16 |
from fastapi.exceptions import RequestValidationError
|
| 17 |
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
|
|
|
|
|
|
|
|
| 18 |
from pydantic import BaseModel
|
| 19 |
from umap import UMAP
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
from utils.data_loader import ModelDataLoader
|
| 22 |
from utils.embeddings import ModelEmbedder
|
| 23 |
from utils.dimensionality_reduction import DimensionReducer
|
| 24 |
from utils.network_analysis import ModelNetworkBuilder
|
| 25 |
+
from utils.graph_embeddings import GraphEmbedder
|
| 26 |
from services.model_tracker import get_tracker
|
|
|
|
| 27 |
from services.arxiv_api import extract_arxiv_ids, fetch_arxiv_papers
|
| 28 |
+
from core.config import settings
|
| 29 |
+
from core.exceptions import DataNotLoadedError, EmbeddingsNotReadyError
|
| 30 |
+
from models.schemas import ModelPoint
|
| 31 |
+
from utils.family_tree import calculate_family_depths
|
| 32 |
+
import api.dependencies as deps
|
| 33 |
+
from api.routes import models, stats, clusters
|
| 34 |
+
|
| 35 |
+
# Create aliases for backward compatibility with existing routes
|
| 36 |
+
# Note: These are set at module load time and may be None initially
|
| 37 |
+
# Functions should access via deps.* to get current values
|
| 38 |
+
data_loader = deps.data_loader
|
| 39 |
+
|
| 40 |
+
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 41 |
+
if backend_dir not in sys.path:
|
| 42 |
+
sys.path.insert(0, backend_dir)
|
| 43 |
|
| 44 |
+
logger = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
app = FastAPI(title="HF Model Ecosystem API", version="2.0.0")
|
| 47 |
|
| 48 |
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
| 49 |
|
| 50 |
+
CORS_HEADERS = {
|
| 51 |
+
"Access-Control-Allow-Origin": "*",
|
| 52 |
+
"Access-Control-Allow-Methods": "*",
|
| 53 |
+
"Access-Control-Allow-Headers": "*",
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
@app.exception_handler(Exception)
|
| 57 |
async def global_exception_handler(request: Request, exc: Exception):
|
| 58 |
+
logger.exception("Unhandled exception", exc_info=exc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
return JSONResponse(
|
| 60 |
status_code=500,
|
| 61 |
+
content={"detail": "Internal server error"},
|
| 62 |
+
headers=CORS_HEADERS,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
)
|
| 64 |
|
| 65 |
@app.exception_handler(StarletteHTTPException)
|
| 66 |
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
|
|
|
| 67 |
return JSONResponse(
|
| 68 |
status_code=exc.status_code,
|
| 69 |
content={"detail": exc.detail},
|
| 70 |
+
headers=CORS_HEADERS,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
)
|
| 72 |
|
| 73 |
@app.exception_handler(RequestValidationError)
|
| 74 |
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
|
|
|
| 75 |
return JSONResponse(
|
| 76 |
status_code=422,
|
| 77 |
content={"detail": exc.errors()},
|
| 78 |
+
headers=CORS_HEADERS,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
+
if settings.ALLOW_ALL_ORIGINS:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
app.add_middleware(
|
| 83 |
CORSMiddleware,
|
| 84 |
+
allow_origins=["*"],
|
| 85 |
+
allow_credentials=False,
|
| 86 |
allow_methods=["*"],
|
| 87 |
allow_headers=["*"],
|
| 88 |
)
|
| 89 |
else:
|
| 90 |
app.add_middleware(
|
| 91 |
CORSMiddleware,
|
| 92 |
+
allow_origins=["http://localhost:3000", settings.FRONTEND_URL],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
allow_credentials=True,
|
| 94 |
allow_methods=["*"],
|
| 95 |
allow_headers=["*"],
|
| 96 |
)
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
# Include routers
|
| 100 |
+
app.include_router(models.router)
|
| 101 |
+
app.include_router(stats.router)
|
| 102 |
+
app.include_router(clusters.router)
|
| 103 |
|
| 104 |
@app.on_event("startup")
|
| 105 |
async def startup_event():
|
| 106 |
+
# All variables are accessed via deps module, no need for global declarations
|
|
|
|
| 107 |
|
|
|
|
| 108 |
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 109 |
root_dir = os.path.dirname(backend_dir)
|
| 110 |
cache_dir = os.path.join(root_dir, "cache")
|
| 111 |
os.makedirs(cache_dir, exist_ok=True)
|
| 112 |
|
| 113 |
embeddings_cache = os.path.join(cache_dir, "embeddings.pkl")
|
| 114 |
+
graph_embeddings_cache = os.path.join(cache_dir, "graph_embeddings.pkl")
|
| 115 |
+
combined_embeddings_cache = os.path.join(cache_dir, "combined_embeddings.pkl")
|
| 116 |
reduced_cache_umap = os.path.join(cache_dir, "reduced_umap_3d.pkl")
|
| 117 |
+
reduced_cache_umap_graph = os.path.join(cache_dir, "reduced_umap_3d_graph.pkl")
|
| 118 |
reducer_cache_umap = os.path.join(cache_dir, "reducer_umap_3d.pkl")
|
| 119 |
+
reducer_cache_umap_graph = os.path.join(cache_dir, "reducer_umap_3d_graph.pkl")
|
| 120 |
|
| 121 |
+
sample_size = settings.get_sample_size()
|
| 122 |
+
if sample_size:
|
| 123 |
+
logger.info(f"Loading limited dataset: {sample_size} models (SAMPLE_SIZE={sample_size})")
|
| 124 |
else:
|
| 125 |
+
logger.info("No SAMPLE_SIZE set, loading full dataset")
|
| 126 |
+
|
| 127 |
+
deps.df = deps.data_loader.load_data(sample_size=sample_size)
|
| 128 |
+
deps.df = deps.data_loader.preprocess_for_embedding(deps.df)
|
| 129 |
+
|
| 130 |
+
if 'model_id' in deps.df.columns:
|
| 131 |
+
deps.df.set_index('model_id', drop=False, inplace=True)
|
|
|
|
| 132 |
for col in ['downloads', 'likes']:
|
| 133 |
+
if col in deps.df.columns:
|
| 134 |
+
deps.df[col] = pd.to_numeric(deps.df[col], errors='coerce').fillna(0).astype(int)
|
| 135 |
|
| 136 |
+
deps.embedder = ModelEmbedder()
|
| 137 |
|
| 138 |
+
# Load or generate text embeddings
|
| 139 |
if os.path.exists(embeddings_cache):
|
| 140 |
try:
|
| 141 |
+
deps.embeddings = deps.embedder.load_embeddings(embeddings_cache)
|
| 142 |
+
except (IOError, pickle.UnpicklingError, EOFError) as e:
|
| 143 |
+
logger.warning(f"Failed to load cached embeddings: {e}")
|
| 144 |
+
deps.embeddings = None
|
| 145 |
+
|
| 146 |
+
if deps.embeddings is None:
|
| 147 |
+
texts = deps.df['combined_text'].tolist()
|
| 148 |
+
deps.embeddings = deps.embedder.generate_embeddings(texts, batch_size=128)
|
| 149 |
+
deps.embedder.save_embeddings(deps.embeddings, embeddings_cache)
|
| 150 |
+
|
| 151 |
+
# Initialize graph embedder and generate graph embeddings (optional, lazy-loaded)
|
| 152 |
+
if settings.USE_GRAPH_EMBEDDINGS:
|
| 153 |
+
try:
|
| 154 |
+
deps.graph_embedder = GraphEmbedder()
|
| 155 |
+
logger.info("Building family graph for graph embeddings...")
|
| 156 |
+
graph = deps.graph_embedder.build_family_graph(deps.df)
|
| 157 |
+
|
| 158 |
+
if os.path.exists(graph_embeddings_cache):
|
| 159 |
+
try:
|
| 160 |
+
deps.graph_embeddings_dict = deps.graph_embedder.load_embeddings(graph_embeddings_cache)
|
| 161 |
+
logger.info(f"Loaded cached graph embeddings for {len(deps.graph_embeddings_dict)} models")
|
| 162 |
+
except (IOError, pickle.UnpicklingError, EOFError) as e:
|
| 163 |
+
logger.warning(f"Failed to load cached graph embeddings: {e}")
|
| 164 |
+
deps.graph_embeddings_dict = None
|
| 165 |
+
|
| 166 |
+
if deps.graph_embeddings_dict is None or len(deps.graph_embeddings_dict) == 0:
|
| 167 |
+
logger.info("Generating graph embeddings (this may take a while)...")
|
| 168 |
+
deps.graph_embeddings_dict = deps.graph_embedder.generate_graph_embeddings(graph, workers=4)
|
| 169 |
+
if deps.graph_embeddings_dict:
|
| 170 |
+
deps.graph_embedder.save_embeddings(deps.graph_embeddings_dict, graph_embeddings_cache)
|
| 171 |
+
logger.info(f"Generated graph embeddings for {len(deps.graph_embeddings_dict)} models")
|
| 172 |
+
|
| 173 |
+
# Combine text and graph embeddings
|
| 174 |
+
if deps.graph_embeddings_dict and len(deps.graph_embeddings_dict) > 0:
|
| 175 |
+
model_ids = deps.df['model_id'].astype(str).tolist()
|
| 176 |
+
if os.path.exists(combined_embeddings_cache):
|
| 177 |
+
try:
|
| 178 |
+
with open(combined_embeddings_cache, 'rb') as f:
|
| 179 |
+
deps.combined_embeddings = pickle.load(f)
|
| 180 |
+
logger.info("Loaded cached combined embeddings")
|
| 181 |
+
except (IOError, pickle.UnpicklingError, EOFError) as e:
|
| 182 |
+
logger.warning(f"Failed to load cached combined embeddings: {e}")
|
| 183 |
+
deps.combined_embeddings = None
|
| 184 |
+
|
| 185 |
+
if deps.combined_embeddings is None:
|
| 186 |
+
logger.info("Combining text and graph embeddings...")
|
| 187 |
+
deps.combined_embeddings = deps.graph_embedder.combine_embeddings(
|
| 188 |
+
deps.embeddings, deps.graph_embeddings_dict, model_ids,
|
| 189 |
+
text_weight=0.7, graph_weight=0.3
|
| 190 |
+
)
|
| 191 |
+
with open(combined_embeddings_cache, 'wb') as f:
|
| 192 |
+
pickle.dump(deps.combined_embeddings, f)
|
| 193 |
+
logger.info("Combined embeddings saved")
|
| 194 |
except Exception as e:
|
| 195 |
+
logger.warning(f"Graph embeddings not available: {e}. Continuing with text-only embeddings.")
|
| 196 |
+
deps.graph_embedder = None
|
| 197 |
+
deps.graph_embeddings_dict = None
|
| 198 |
+
deps.combined_embeddings = None
|
|
|
|
|
|
|
| 199 |
|
| 200 |
+
# Initialize reducer for text embeddings
|
| 201 |
+
deps.reducer = DimensionReducer(method="umap", n_components=3)
|
| 202 |
|
| 203 |
if os.path.exists(reduced_cache_umap) and os.path.exists(reducer_cache_umap):
|
| 204 |
try:
|
|
|
|
| 205 |
with open(reduced_cache_umap, 'rb') as f:
|
| 206 |
+
deps.reduced_embeddings = pickle.load(f)
|
| 207 |
+
deps.reducer.load_reducer(reducer_cache_umap)
|
| 208 |
+
except (IOError, pickle.UnpicklingError, EOFError) as e:
|
| 209 |
+
logger.warning(f"Failed to load cached reduced embeddings: {e}")
|
| 210 |
+
deps.reduced_embeddings = None
|
| 211 |
+
|
| 212 |
+
if deps.reduced_embeddings is None:
|
| 213 |
+
deps.reducer.reducer = UMAP(
|
| 214 |
n_components=3,
|
| 215 |
n_neighbors=30,
|
| 216 |
min_dist=0.3,
|
|
|
|
| 220 |
low_memory=True,
|
| 221 |
spread=1.5
|
| 222 |
)
|
| 223 |
+
deps.reduced_embeddings = deps.reducer.fit_transform(deps.embeddings)
|
|
|
|
| 224 |
with open(reduced_cache_umap, 'wb') as f:
|
| 225 |
+
pickle.dump(deps.reduced_embeddings, f)
|
| 226 |
+
deps.reducer.save_reducer(reducer_cache_umap)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
+
# Initialize reducer for graph-aware embeddings if available
|
| 229 |
+
if deps.combined_embeddings is not None:
|
| 230 |
+
reducer_graph = DimensionReducer(method="umap", n_components=3)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
+
if os.path.exists(reduced_cache_umap_graph) and os.path.exists(reducer_cache_umap_graph):
|
| 233 |
+
try:
|
| 234 |
+
with open(reduced_cache_umap_graph, 'rb') as f:
|
| 235 |
+
deps.reduced_embeddings_graph = pickle.load(f)
|
| 236 |
+
reducer_graph.load_reducer(reducer_cache_umap_graph)
|
| 237 |
+
except (IOError, pickle.UnpicklingError, EOFError) as e:
|
| 238 |
+
logger.warning(f"Failed to load cached graph-aware reduced embeddings: {e}")
|
| 239 |
+
deps.reduced_embeddings_graph = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
+
if deps.reduced_embeddings_graph is None:
|
| 242 |
+
reducer_graph.reducer = UMAP(
|
| 243 |
+
n_components=3,
|
| 244 |
+
n_neighbors=30,
|
| 245 |
+
min_dist=0.3,
|
| 246 |
+
metric='cosine',
|
| 247 |
+
random_state=42,
|
| 248 |
+
n_jobs=-1,
|
| 249 |
+
low_memory=True,
|
| 250 |
+
spread=1.5
|
| 251 |
+
)
|
| 252 |
+
deps.reduced_embeddings_graph = reducer_graph.fit_transform(deps.combined_embeddings)
|
| 253 |
+
with open(reduced_cache_umap_graph, 'wb') as f:
|
| 254 |
+
pickle.dump(deps.reduced_embeddings_graph, f)
|
| 255 |
+
reducer_graph.save_reducer(reducer_cache_umap_graph)
|
| 256 |
+
logger.info("Graph-aware embeddings reduced and cached")
|
| 257 |
|
| 258 |
+
# Update module-level aliases
|
| 259 |
+
df = deps.df
|
| 260 |
+
embedder = deps.embedder
|
| 261 |
+
graph_embedder = deps.graph_embedder
|
| 262 |
+
reducer = deps.reducer
|
| 263 |
+
embeddings = deps.embeddings
|
| 264 |
+
graph_embeddings_dict = deps.graph_embeddings_dict
|
| 265 |
+
combined_embeddings = deps.combined_embeddings
|
| 266 |
+
reduced_embeddings = deps.reduced_embeddings
|
| 267 |
+
reduced_embeddings_graph = deps.reduced_embeddings_graph
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
from utils.family_tree import calculate_family_depths
|
| 271 |
|
| 272 |
|
| 273 |
def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
from sklearn.cluster import KMeans
|
| 275 |
|
| 276 |
n_samples = len(reduced_embeddings)
|
|
|
|
| 278 |
n_clusters = max(1, n_samples // 10)
|
| 279 |
|
| 280 |
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
|
| 281 |
+
return kmeans.fit_predict(reduced_embeddings)
|
|
|
|
| 282 |
|
| 283 |
|
| 284 |
@app.get("/")
|
|
|
|
| 293 |
search_query: Optional[str] = Query(None),
|
| 294 |
color_by: str = Query("library_name"),
|
| 295 |
size_by: str = Query("downloads"),
|
| 296 |
+
max_points: Optional[int] = Query(None),
|
| 297 |
+
projection_method: str = Query("umap"),
|
| 298 |
+
base_models_only: bool = Query(False),
|
| 299 |
+
max_hierarchy_depth: Optional[int] = Query(None, ge=0, description="Filter to models at or below this hierarchy depth."),
|
| 300 |
+
use_graph_embeddings: bool = Query(False, description="Use graph-aware embeddings that respect family tree structure")
|
| 301 |
):
|
| 302 |
+
if deps.df is None:
|
| 303 |
+
raise DataNotLoadedError()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
+
df = deps.df
|
|
|
|
| 306 |
|
| 307 |
# Filter data
|
| 308 |
filtered_df = data_loader.filter_data(
|
|
|
|
| 322 |
(filtered_df['parent_model'].astype(str) == 'nan')
|
| 323 |
]
|
| 324 |
|
| 325 |
+
if max_hierarchy_depth is not None:
|
| 326 |
+
family_depths = calculate_family_depths(df)
|
| 327 |
+
filtered_df = filtered_df[
|
| 328 |
+
filtered_df['model_id'].astype(str).map(lambda x: family_depths.get(x, 0) <= max_hierarchy_depth)
|
| 329 |
+
]
|
| 330 |
+
|
| 331 |
filtered_count = len(filtered_df)
|
| 332 |
|
| 333 |
if len(filtered_df) == 0:
|
|
|
|
| 338 |
}
|
| 339 |
|
| 340 |
if max_points is not None and len(filtered_df) > max_points:
|
|
|
|
|
|
|
| 341 |
if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
|
| 342 |
+
# Sample proportionally by library, preserving all columns
|
| 343 |
+
sampled_dfs = []
|
| 344 |
+
for lib_name, group in filtered_df.groupby('library_name', group_keys=False):
|
| 345 |
+
n_samples = max(1, int(max_points * len(group) / len(filtered_df)))
|
| 346 |
+
sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42))
|
| 347 |
+
filtered_df = pd.concat(sampled_dfs, ignore_index=True)
|
| 348 |
if len(filtered_df) > max_points:
|
| 349 |
+
filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
|
| 350 |
+
else:
|
| 351 |
+
filtered_df = filtered_df.reset_index(drop=True)
|
| 352 |
else:
|
| 353 |
+
filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
+
# Determine which embeddings to use
|
| 356 |
+
if use_graph_embeddings and combined_embeddings is not None:
|
| 357 |
+
current_embeddings = combined_embeddings
|
| 358 |
+
current_reduced = reduced_embeddings_graph
|
| 359 |
+
embedding_type = "graph-aware"
|
| 360 |
+
else:
|
| 361 |
+
if embeddings is None:
|
| 362 |
+
raise EmbeddingsNotReadyError()
|
| 363 |
+
current_embeddings = embeddings
|
| 364 |
+
current_reduced = reduced_embeddings
|
| 365 |
+
embedding_type = "text-only"
|
| 366 |
+
|
| 367 |
+
# Handle reduced embeddings loading/generation
|
| 368 |
+
if current_reduced is None or (reducer and reducer.method != projection_method.lower()):
|
| 369 |
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 370 |
root_dir = os.path.dirname(backend_dir)
|
| 371 |
cache_dir = os.path.join(root_dir, "cache")
|
| 372 |
+
cache_suffix = "_graph" if use_graph_embeddings and combined_embeddings is not None else ""
|
| 373 |
+
reduced_cache = os.path.join(cache_dir, f"reduced_{projection_method.lower()}_3d{cache_suffix}.pkl")
|
| 374 |
+
reducer_cache = os.path.join(cache_dir, f"reducer_{projection_method.lower()}_3d{cache_suffix}.pkl")
|
| 375 |
|
| 376 |
if os.path.exists(reduced_cache) and os.path.exists(reducer_cache):
|
| 377 |
try:
|
|
|
|
| 378 |
with open(reduced_cache, 'rb') as f:
|
| 379 |
+
current_reduced = pickle.load(f)
|
| 380 |
if reducer is None or reducer.method != projection_method.lower():
|
| 381 |
reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
|
| 382 |
reducer.load_reducer(reducer_cache)
|
| 383 |
+
except (IOError, pickle.UnpicklingError, EOFError) as e:
|
| 384 |
+
logger.warning(f"Failed to load cached reduced embeddings: {e}")
|
| 385 |
+
current_reduced = None
|
| 386 |
|
| 387 |
+
if current_reduced is None:
|
| 388 |
if reducer is None or reducer.method != projection_method.lower():
|
| 389 |
reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
|
| 390 |
if projection_method.lower() == "umap":
|
|
|
|
| 398 |
low_memory=True,
|
| 399 |
spread=1.5
|
| 400 |
)
|
| 401 |
+
current_reduced = reducer.fit_transform(current_embeddings)
|
|
|
|
| 402 |
with open(reduced_cache, 'wb') as f:
|
| 403 |
+
pickle.dump(current_reduced, f)
|
| 404 |
reducer.save_reducer(reducer_cache)
|
| 405 |
+
|
| 406 |
+
# Update global variable
|
| 407 |
+
if use_graph_embeddings and deps.combined_embeddings is not None:
|
| 408 |
+
deps.reduced_embeddings_graph = current_reduced
|
| 409 |
+
else:
|
| 410 |
+
deps.reduced_embeddings = current_reduced
|
| 411 |
+
|
| 412 |
+
# Get indices for filtered data
|
| 413 |
+
# Use model_id column to map between filtered_df and original df
|
| 414 |
+
# This is safer than using index positions which can change after filtering
|
| 415 |
+
filtered_model_ids = filtered_df['model_id'].astype(str).values
|
| 416 |
|
| 417 |
+
# Map model_ids to positions in original df
|
|
|
|
|
|
|
| 418 |
if df.index.name == 'model_id' or 'model_id' in df.index.names:
|
| 419 |
+
# When df is indexed by model_id, use get_loc directly
|
| 420 |
+
filtered_indices = []
|
| 421 |
+
for model_id in filtered_model_ids:
|
| 422 |
+
try:
|
| 423 |
+
pos = df.index.get_loc(model_id)
|
| 424 |
+
# Handle both single position and array of positions
|
| 425 |
+
if isinstance(pos, (int, np.integer)):
|
| 426 |
+
filtered_indices.append(int(pos))
|
| 427 |
+
elif isinstance(pos, (slice, np.ndarray)):
|
| 428 |
+
# If multiple matches, take first
|
| 429 |
+
if isinstance(pos, slice):
|
| 430 |
+
filtered_indices.append(int(pos.start))
|
| 431 |
+
else:
|
| 432 |
+
filtered_indices.append(int(pos[0]))
|
| 433 |
+
except (KeyError, TypeError):
|
| 434 |
+
continue
|
| 435 |
+
filtered_indices = np.array(filtered_indices, dtype=np.int32)
|
| 436 |
else:
|
| 437 |
+
# When df is not indexed by model_id, find positions by matching model_id column
|
| 438 |
+
df_model_ids = df['model_id'].astype(str).values
|
| 439 |
+
model_id_to_pos = {mid: pos for pos, mid in enumerate(df_model_ids)}
|
| 440 |
+
filtered_indices = np.array([
|
| 441 |
+
model_id_to_pos[mid] for mid in filtered_model_ids
|
| 442 |
+
if mid in model_id_to_pos
|
| 443 |
+
], dtype=np.int32)
|
| 444 |
+
|
| 445 |
+
if len(filtered_indices) == 0:
|
| 446 |
+
return {
|
| 447 |
+
"models": [],
|
| 448 |
+
"embedding_type": embedding_type,
|
| 449 |
+
"filtered_count": filtered_count,
|
| 450 |
+
"returned_count": 0
|
| 451 |
+
}
|
| 452 |
|
| 453 |
+
filtered_reduced = current_reduced[filtered_indices]
|
| 454 |
family_depths = calculate_family_depths(df)
|
| 455 |
|
| 456 |
+
# Use appropriate embeddings for clustering
|
| 457 |
+
clustering_embeddings = current_reduced
|
| 458 |
+
# Compute clusters if not already computed or if size changed
|
| 459 |
+
if models.cluster_labels is None or len(models.cluster_labels) != len(clustering_embeddings):
|
| 460 |
+
models.cluster_labels = compute_clusters(clustering_embeddings, n_clusters=min(50, len(clustering_embeddings) // 100))
|
| 461 |
|
| 462 |
+
# Handle case where cluster_labels might not match filtered data yet
|
| 463 |
+
if models.cluster_labels is not None and len(models.cluster_labels) > 0:
|
| 464 |
+
if len(filtered_indices) <= len(models.cluster_labels):
|
| 465 |
+
filtered_clusters = models.cluster_labels[filtered_indices]
|
| 466 |
+
else:
|
| 467 |
+
# Fallback: use first cluster for all if indices don't match
|
| 468 |
+
filtered_clusters = np.zeros(len(filtered_indices), dtype=int)
|
| 469 |
+
else:
|
| 470 |
+
filtered_clusters = np.zeros(len(filtered_indices), dtype=int)
|
| 471 |
|
|
|
|
|
|
|
| 472 |
model_ids = filtered_df['model_id'].astype(str).values
|
| 473 |
+
library_names = filtered_df.get('library_name', pd.Series([None] * len(filtered_df))).values
|
| 474 |
+
pipeline_tags = filtered_df.get('pipeline_tag', pd.Series([None] * len(filtered_df))).values
|
| 475 |
+
downloads_arr = filtered_df.get('downloads', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values
|
| 476 |
+
likes_arr = filtered_df.get('likes', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values
|
| 477 |
+
trending_scores = filtered_df.get('trendingScore', pd.Series([None] * len(filtered_df))).values
|
| 478 |
+
tags_arr = filtered_df.get('tags', pd.Series([None] * len(filtered_df))).values
|
| 479 |
+
parent_models = filtered_df.get('parent_model', pd.Series([None] * len(filtered_df))).values
|
| 480 |
+
licenses_arr = filtered_df.get('licenses', pd.Series([None] * len(filtered_df))).values
|
| 481 |
+
created_at_arr = filtered_df.get('createdAt', pd.Series([None] * len(filtered_df))).values
|
| 482 |
+
|
| 483 |
x_coords = filtered_reduced[:, 0].astype(float)
|
| 484 |
y_coords = filtered_reduced[:, 1].astype(float)
|
| 485 |
z_coords = filtered_reduced[:, 2].astype(float) if filtered_reduced.shape[1] > 2 else np.zeros(len(filtered_reduced), dtype=float)
|
|
|
|
|
|
|
| 486 |
models = [
|
| 487 |
ModelPoint(
|
| 488 |
model_id=model_ids[idx],
|
|
|
|
| 498 |
parent_model=parent_models[idx] if idx < len(parent_models) and pd.notna(parent_models[idx]) else None,
|
| 499 |
licenses=licenses_arr[idx] if idx < len(licenses_arr) and pd.notna(licenses_arr[idx]) else None,
|
| 500 |
family_depth=family_depths.get(model_ids[idx], None),
|
| 501 |
+
cluster_id=int(filtered_clusters[idx]) if idx < len(filtered_clusters) else None,
|
| 502 |
+
created_at=str(created_at_arr[idx]) if idx < len(created_at_arr) and pd.notna(created_at_arr[idx]) else None
|
| 503 |
)
|
| 504 |
for idx in range(len(filtered_df))
|
| 505 |
]
|
| 506 |
|
| 507 |
+
# Return models with metadata about embedding type
|
| 508 |
+
return {
|
| 509 |
+
"models": models,
|
| 510 |
+
"embedding_type": embedding_type,
|
| 511 |
+
"filtered_count": filtered_count,
|
| 512 |
+
"returned_count": len(models)
|
| 513 |
+
}
|
| 514 |
|
| 515 |
|
| 516 |
@app.get("/api/stats")
async def get_stats():
    """Get dataset statistics.

    Returns a dict with the total row count, the number of distinct
    libraries / pipeline tags / licenses, a ``license -> count`` mapping,
    and the mean downloads and likes.

    Raises:
        DataNotLoadedError: if the module-level ``df`` has not been loaded.
    """
    if df is None:
        raise DataNotLoadedError()

    def _unique_count(column: str) -> int:
        # Number of distinct non-null values, or 0 when the column is absent.
        return int(df[column].nunique()) if column in df.columns else 0

    def _mean_or_zero(column: str) -> float:
        # Mean of the column, or 0 when the column is absent.
        if column not in df.columns:
            return 0
        mean_value = df[column].mean()
        # An empty or all-null column yields NaN, which cannot be encoded as
        # strict JSON; report 0.0 instead of failing the whole response.
        return float(mean_value) if pd.notna(mean_value) else 0.0

    # Get unique licenses with counts; null keys and the literal string
    # 'nan' are filtered out.
    licenses = {}
    if 'license' in df.columns:
        license_counts = df['license'].value_counts().to_dict()
        licenses = {
            str(name): int(count)
            for name, count in license_counts.items()
            if pd.notna(name) and str(name) != 'nan'
        }

    # Computed once: "unique_task_types" is an alias of "unique_pipelines".
    unique_pipelines = _unique_count('pipeline_tag')

    return {
        "total_models": len(df),
        "unique_libraries": _unique_count('library_name'),
        "unique_pipelines": unique_pipelines,
        "unique_task_types": unique_pipelines,  # Alias for clarity
        "unique_licenses": len(licenses),
        "licenses": licenses,  # License name -> count mapping
        "avg_downloads": _mean_or_zero('downloads'),
        "avg_likes": _mean_or_zero('likes'),
    }
|
|
|
|
| 543 |
async def get_model_details(model_id: str):
|
| 544 |
"""Get detailed information about a specific model."""
|
| 545 |
if df is None:
|
| 546 |
+
raise DataNotLoadedError()
|
| 547 |
|
| 548 |
model = df[df.get('model_id', '') == model_id]
|
| 549 |
if len(model) == 0:
|
|
|
|
| 551 |
|
| 552 |
model = model.iloc[0]
|
| 553 |
|
|
|
|
| 554 |
tags_str = str(model.get('tags', '')) if pd.notna(model.get('tags')) else ''
|
| 555 |
arxiv_ids = extract_arxiv_ids(tags_str)
|
| 556 |
|
|
|
|
| 557 |
papers = []
|
| 558 |
if arxiv_ids:
|
| 559 |
papers = await fetch_arxiv_papers(arxiv_ids[:5]) # Limit to 5 papers
|
|
|
|
| 573 |
}
|
| 574 |
|
| 575 |
|
| 576 |
+
# Clusters endpoint is handled by routes/clusters.py router
|
| 577 |
+
|
| 578 |
@app.get("/api/family/stats")
|
| 579 |
async def get_family_stats():
|
| 580 |
"""
|
|
|
|
| 582 |
Returns family size distribution, depth statistics, model card length by depth, etc.
|
| 583 |
"""
|
| 584 |
if df is None:
|
| 585 |
+
raise DataNotLoadedError()
|
| 586 |
|
|
|
|
| 587 |
family_sizes = {}
|
| 588 |
root_models = set()
|
| 589 |
|
|
|
|
| 597 |
family_sizes[model_id] = 0
|
| 598 |
else:
|
| 599 |
parent_id_str = str(parent_id)
|
|
|
|
| 600 |
root = parent_id_str
|
| 601 |
visited = set()
|
| 602 |
while root in df.index and pd.notna(df.loc[root].get('parent_model')):
|
| 603 |
parent = df.loc[root].get('parent_model')
|
| 604 |
if pd.isna(parent) or str(parent) == 'nan' or str(parent) == '':
|
| 605 |
break
|
| 606 |
+
if str(parent) in visited:
|
| 607 |
break
|
| 608 |
visited.add(root)
|
| 609 |
root = str(parent)
|
|
|
|
| 612 |
family_sizes[root] = 0
|
| 613 |
family_sizes[root] += 1
|
| 614 |
|
|
|
|
| 615 |
size_distribution = {}
|
| 616 |
for root, size in family_sizes.items():
|
| 617 |
size_distribution[size] = size_distribution.get(size, 0) + 1
|
| 618 |
|
|
|
|
| 619 |
depths = calculate_family_depths(df)
|
| 620 |
depth_counts = {}
|
| 621 |
for depth in depths.values():
|
| 622 |
depth_counts[depth] = depth_counts.get(depth, 0) + 1
|
| 623 |
|
|
|
|
| 624 |
model_card_lengths_by_depth = {}
|
| 625 |
if 'modelCard' in df.columns:
|
| 626 |
for idx, row in df.iterrows():
|
|
|
|
| 633 |
model_card_lengths_by_depth[depth] = []
|
| 634 |
model_card_lengths_by_depth[depth].append(card_length)
|
| 635 |
|
|
|
|
| 636 |
model_card_stats = {}
|
| 637 |
for depth, lengths in model_card_lengths_by_depth.items():
|
| 638 |
if lengths:
|
|
|
|
| 657 |
}
|
| 658 |
|
| 659 |
|
| 660 |
+
@app.get("/api/family/path/{model_id}")
async def get_family_path(
    model_id: str,
    target_id: Optional[str] = Query(None, description="Target model ID. If None, returns path to root.")
):
    """
    Get path from model to root or to target model.

    Walks the ``parent_model`` chain upward from ``model_id``. The walk stops
    when the target is reached, when a model has no parent, when a model
    cannot be found, or when a parent has already been visited (cycle guard).

    Returns a dict with:
        path: list of model IDs, starting at ``model_id``.
        source: the starting model ID.
        target: ``target_id``, or the string "root" when no target was given.
        path_length: number of hops, i.e. ``len(path) - 1``.

    If a target is given but is not an ancestor of ``model_id``, the returned
    path is the chain that could be walked; callers can compare ``path[-1]``
    with the target to see whether it was reached.

    Raises:
        DataNotLoadedError: if the module-level ``df`` has not been loaded.
        HTTPException(404): if the model (or, when ``df`` is indexed by
            ``model_id``, the target) does not exist.
    """
    if df is None:
        raise DataNotLoadedError()

    model_id_str = str(model_id)
    indexed_by_id = df.index.name == 'model_id'

    if indexed_by_id:
        if model_id_str not in df.index:
            raise HTTPException(status_code=404, detail="Model not found")
    else:
        model_rows = df[df.get('model_id', '') == model_id_str]
        if len(model_rows) == 0:
            raise HTTPException(status_code=404, detail="Model not found")

    target_str = None
    if target_id:
        target_str = str(target_id)
        if indexed_by_id and target_str not in df.index:
            raise HTTPException(status_code=404, detail="Target model not found")

    def _find_row(current_id: str):
        # Return the row for current_id, or None when it cannot be located.
        try:
            if indexed_by_id:
                return df.loc[current_id]
            rows = df[df.get('model_id', '') == current_id]
            if len(rows) == 0:
                return None
            return rows.iloc[0]
        except (KeyError, IndexError):
            return None

    path = [model_id_str]
    visited = {model_id_str}
    current = model_id_str

    # With no target, target_str is None and never equals a model ID, so the
    # walk continues to the root. The previous condition also required
    # `current not in visited`, which was false on entry because the source
    # is seeded into `visited`, so the target walk never ran.
    while current != target_str:
        row = _find_row(current)
        if row is None:
            break

        parent_id = row.get('parent_model')
        if not (parent_id and pd.notna(parent_id)):
            break  # reached a root: no parent recorded

        parent_str = str(parent_id)
        if parent_str in visited:
            break  # cycle in the parent chain

        path.append(parent_str)
        visited.add(parent_str)
        current = parent_str

    return {
        "path": path,
        "source": model_id_str,
        "target": target_id if target_id else "root",
        "path_length": len(path) - 1
    }
|
| 749 |
+
|
| 750 |
+
|
| 751 |
@app.get("/api/family/{model_id}")
|
| 752 |
+
async def get_family_tree(
|
| 753 |
+
model_id: str,
|
| 754 |
+
max_depth: Optional[int] = Query(None, ge=1, le=100, description="Maximum depth to traverse. If None, traverses entire tree without limit."),
|
| 755 |
+
max_depth_filter: Optional[int] = Query(None, ge=0, description="Filter results to models at or below this hierarchy depth.")
|
| 756 |
+
):
|
| 757 |
"""
|
| 758 |
Get family tree for a model (ancestors and descendants).
|
| 759 |
Returns the model, its parent chain, and all children.
|
| 760 |
+
|
| 761 |
+
If max_depth is None, traverses the entire family tree without depth limits.
|
| 762 |
"""
|
| 763 |
if df is None:
|
| 764 |
+
raise DataNotLoadedError()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 765 |
|
|
|
|
| 766 |
if reduced_embeddings is None:
|
| 767 |
raise HTTPException(status_code=503, detail="Embeddings not ready")
|
| 768 |
|
| 769 |
+
model_id_str = str(model_id)
|
|
|
|
|
|
|
|
|
|
| 770 |
|
| 771 |
+
if df.index.name == 'model_id':
|
| 772 |
+
if model_id_str not in df.index:
|
| 773 |
+
raise HTTPException(status_code=404, detail="Model not found")
|
| 774 |
+
model_lookup = df.loc
|
| 775 |
+
else:
|
| 776 |
+
model_rows = df[df.get('model_id', '') == model_id_str]
|
| 777 |
+
if len(model_rows) == 0:
|
| 778 |
+
raise HTTPException(status_code=404, detail="Model not found")
|
| 779 |
+
model_lookup = lambda x: df[df.get('model_id', '') == x]
|
| 780 |
+
|
| 781 |
+
from utils.network_analysis import _get_all_parents, _parse_parent_list
|
| 782 |
+
|
| 783 |
+
children_index: Dict[str, List[str]] = {}
|
| 784 |
+
parent_columns = ['parent_model', 'finetune_parent', 'quantized_parent', 'adapter_parent', 'merge_parent']
|
| 785 |
+
|
| 786 |
+
for idx, row in df.iterrows():
|
| 787 |
+
model_id_from_row = str(row.get('model_id', idx))
|
| 788 |
+
all_parents = _get_all_parents(row)
|
| 789 |
+
|
| 790 |
+
for rel_type, parent_list in all_parents.items():
|
| 791 |
+
for parent_str in parent_list:
|
| 792 |
+
if parent_str not in children_index:
|
| 793 |
+
children_index[parent_str] = []
|
| 794 |
+
children_index[parent_str].append(model_id_from_row)
|
| 795 |
+
|
| 796 |
+
visited = set()
|
| 797 |
+
|
| 798 |
+
def get_ancestors(current_id: str, depth: Optional[int]):
|
| 799 |
+
if current_id in visited:
|
| 800 |
+
return
|
| 801 |
+
if depth is not None and depth <= 0:
|
| 802 |
return
|
| 803 |
visited.add(current_id)
|
| 804 |
|
| 805 |
+
try:
|
| 806 |
+
if df.index.name == 'model_id':
|
| 807 |
+
row = df.loc[current_id]
|
| 808 |
+
else:
|
| 809 |
+
rows = model_lookup(current_id)
|
| 810 |
+
if len(rows) == 0:
|
| 811 |
+
return
|
| 812 |
+
row = rows.iloc[0]
|
| 813 |
+
|
| 814 |
+
all_parents = _get_all_parents(row)
|
| 815 |
+
for rel_type, parent_list in all_parents.items():
|
| 816 |
+
for parent_str in parent_list:
|
| 817 |
+
if parent_str != 'nan' and parent_str != '':
|
| 818 |
+
next_depth = depth - 1 if depth is not None else None
|
| 819 |
+
get_ancestors(parent_str, next_depth)
|
| 820 |
+
except (KeyError, IndexError):
|
| 821 |
+
return
|
| 822 |
|
| 823 |
+
def get_descendants(current_id: str, depth: Optional[int]):
|
| 824 |
+
if current_id in visited:
|
| 825 |
+
return
|
| 826 |
+
if depth is not None and depth <= 0:
|
| 827 |
return
|
| 828 |
visited.add(current_id)
|
| 829 |
|
| 830 |
+
children = children_index.get(current_id, [])
|
| 831 |
+
for child_id in children:
|
| 832 |
+
if child_id not in visited:
|
| 833 |
+
next_depth = depth - 1 if depth is not None else None
|
| 834 |
+
get_descendants(child_id, next_depth)
|
| 835 |
+
|
| 836 |
+
get_ancestors(model_id_str, max_depth)
|
| 837 |
+
visited = set()
|
| 838 |
+
get_descendants(model_id_str, max_depth)
|
| 839 |
+
visited.add(model_id_str)
|
| 840 |
+
|
| 841 |
+
if df.index.name == 'model_id':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 842 |
try:
|
| 843 |
family_df = df.loc[list(visited)]
|
| 844 |
except KeyError:
|
| 845 |
+
missing = [v for v in visited if v not in df.index]
|
| 846 |
+
if missing:
|
| 847 |
+
logger.warning(f"Some family members not found in index: {missing}")
|
| 848 |
+
family_df = df.loc[[v for v in visited if v in df.index]]
|
| 849 |
else:
|
| 850 |
family_df = df[df.get('model_id', '').isin(visited)]
|
| 851 |
|
| 852 |
+
if len(family_df) == 0:
|
| 853 |
+
raise HTTPException(status_code=404, detail="Family tree data not available")
|
| 854 |
+
|
| 855 |
+
family_indices = family_df.index.values
|
| 856 |
+
if len(family_indices) > len(reduced_embeddings):
|
| 857 |
+
raise HTTPException(status_code=503, detail="Embedding indices mismatch")
|
| 858 |
+
|
| 859 |
family_reduced = reduced_embeddings[family_indices]
|
| 860 |
|
|
|
|
| 861 |
family_map = {}
|
| 862 |
for idx, (i, row) in enumerate(family_df.iterrows()):
|
| 863 |
+
model_id_val = str(row.get('model_id', i))
|
| 864 |
+
parent_id = row.get('parent_model')
|
| 865 |
+
parent_id_str = str(parent_id) if parent_id and pd.notna(parent_id) else None
|
| 866 |
+
|
| 867 |
+
depths = calculate_family_depths(df)
|
| 868 |
+
model_depth = depths.get(model_id_val, 0)
|
| 869 |
+
|
| 870 |
+
if max_depth_filter is not None and model_depth > max_depth_filter:
|
| 871 |
+
continue
|
| 872 |
|
| 873 |
family_map[model_id_val] = {
|
| 874 |
"model_id": model_id_val,
|
|
|
|
| 879 |
"pipeline_tag": str(row.get('pipeline_tag')) if pd.notna(row.get('pipeline_tag')) else None,
|
| 880 |
"downloads": int(row.get('downloads', 0)) if pd.notna(row.get('downloads')) else 0,
|
| 881 |
"likes": int(row.get('likes', 0)) if pd.notna(row.get('likes')) else 0,
|
| 882 |
+
"parent_model": parent_id_str,
|
| 883 |
"licenses": str(row.get('licenses')) if pd.notna(row.get('licenses')) else None,
|
| 884 |
+
"family_depth": model_depth,
|
| 885 |
"children": []
|
| 886 |
}
|
| 887 |
|
|
|
|
| 888 |
root_models = []
|
| 889 |
for model_id_val, model_data in family_map.items():
|
| 890 |
parent_id = model_data["parent_model"]
|
|
|
|
| 894 |
root_models.append(model_id_val)
|
| 895 |
|
| 896 |
return {
|
| 897 |
+
"root_model": model_id_str,
|
| 898 |
"family": list(family_map.values()),
|
| 899 |
"family_map": family_map,
|
| 900 |
"root_models": root_models
|
|
|
|
| 903 |
|
| 904 |
@app.get("/api/search")
|
| 905 |
async def search_models(
|
| 906 |
+
q: str = Query(..., min_length=1, alias="query"),
|
| 907 |
+
query: str = Query(None, min_length=1),
|
| 908 |
+
limit: int = Query(20, ge=1, le=100),
|
| 909 |
graph_aware: bool = Query(False),
|
| 910 |
include_neighbors: bool = Query(True)
|
| 911 |
):
|
|
|
|
| 914 |
Enhanced with graph-aware search option that includes network relationships.
|
| 915 |
"""
|
| 916 |
if df is None:
|
| 917 |
+
raise DataNotLoadedError()
|
| 918 |
+
|
| 919 |
+
# Support both 'q' and 'query' parameters
|
| 920 |
+
search_query = query or q
|
| 921 |
|
| 922 |
if graph_aware:
|
|
|
|
| 923 |
try:
|
| 924 |
network_builder = ModelNetworkBuilder(df)
|
|
|
|
| 925 |
top_models = network_builder.get_top_models_by_field(n=1000)
|
| 926 |
model_ids = [mid for mid, _ in top_models]
|
| 927 |
graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined')
|
| 928 |
|
| 929 |
results = network_builder.search_graph_aware(
|
| 930 |
+
query=search_query,
|
| 931 |
graph=graph,
|
| 932 |
+
max_results=limit,
|
| 933 |
include_neighbors=include_neighbors
|
| 934 |
)
|
| 935 |
|
| 936 |
+
return {"results": results, "search_type": "graph_aware", "query": search_query}
|
| 937 |
+
except (ValueError, KeyError, AttributeError) as e:
|
| 938 |
+
logger.warning(f"Graph-aware search failed, falling back to basic search: {e}")
|
| 939 |
+
|
| 940 |
+
query_lower = search_query.lower()
|
| 941 |
+
|
| 942 |
+
# Enhanced search: search model_id, org, tags, library, pipeline
|
| 943 |
+
model_id_col = df.get('model_id', '').astype(str).str.lower()
|
| 944 |
+
library_col = df.get('library_name', '').astype(str).str.lower()
|
| 945 |
+
pipeline_col = df.get('pipeline_tag', '').astype(str).str.lower()
|
| 946 |
+
tags_col = df.get('tags', '').astype(str).str.lower()
|
| 947 |
+
license_col = df.get('license', '').astype(str).str.lower()
|
| 948 |
+
|
| 949 |
+
# Extract org from model_id
|
| 950 |
+
org_col = model_id_col.str.split('/').str[0]
|
| 951 |
+
|
| 952 |
+
# Multi-field search
|
| 953 |
+
mask = (
|
| 954 |
+
model_id_col.str.contains(query_lower, na=False) |
|
| 955 |
+
org_col.str.contains(query_lower, na=False) |
|
| 956 |
+
library_col.str.contains(query_lower, na=False) |
|
| 957 |
+
pipeline_col.str.contains(query_lower, na=False) |
|
| 958 |
+
tags_col.str.contains(query_lower, na=False) |
|
| 959 |
+
license_col.str.contains(query_lower, na=False)
|
| 960 |
+
)
|
| 961 |
|
| 962 |
+
matches = df[mask].head(limit)
|
|
|
|
|
|
|
|
|
|
| 963 |
|
| 964 |
results = []
|
| 965 |
for _, row in matches.iterrows():
|
| 966 |
+
model_id = str(row.get('model_id', ''))
|
| 967 |
+
org = model_id.split('/')[0] if '/' in model_id else ''
|
| 968 |
+
|
| 969 |
+
# Get coordinates if available
|
| 970 |
+
x = float(row.get('x', 0.0)) if 'x' in row else None
|
| 971 |
+
y = float(row.get('y', 0.0)) if 'y' in row else None
|
| 972 |
+
z = float(row.get('z', 0.0)) if 'z' in row else None
|
| 973 |
+
|
| 974 |
results.append({
|
| 975 |
+
"model_id": model_id,
|
| 976 |
+
"x": x,
|
| 977 |
+
"y": y,
|
| 978 |
+
"z": z,
|
| 979 |
+
"org": org,
|
| 980 |
+
"library": row.get('library_name'),
|
| 981 |
+
"pipeline": row.get('pipeline_tag'),
|
| 982 |
+
"license": row.get('license') if pd.notna(row.get('license')) else None,
|
| 983 |
"downloads": int(row.get('downloads', 0)),
|
| 984 |
"likes": int(row.get('likes', 0)),
|
| 985 |
"parent_model": row.get('parent_model') if pd.notna(row.get('parent_model')) else None,
|
| 986 |
"match_type": "direct"
|
| 987 |
})
|
| 988 |
|
| 989 |
+
return {"results": results, "search_type": "basic", "query": search_query}
|
| 990 |
|
| 991 |
|
| 992 |
@app.get("/api/similar/{model_id}")
|
|
|
|
| 995 |
Get k-nearest neighbors of a model based on embedding similarity.
|
| 996 |
Returns similar models with distance scores.
|
| 997 |
"""
|
| 998 |
+
if deps.df is None or deps.embeddings is None:
|
|
|
|
|
|
|
| 999 |
raise HTTPException(status_code=503, detail="Data not loaded")
|
| 1000 |
|
| 1001 |
+
df = deps.df
|
| 1002 |
+
embeddings = deps.embeddings
|
| 1003 |
+
|
| 1004 |
if 'model_id' in df.index.names or df.index.name == 'model_id':
|
| 1005 |
try:
|
| 1006 |
model_row = df.loc[[model_id]]
|
|
|
|
| 1014 |
model_idx = model_row.index[0]
|
| 1015 |
model_embedding = embeddings[model_idx]
|
| 1016 |
|
|
|
|
| 1017 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
| 1018 |
model_embedding_2d = model_embedding.reshape(1, -1)
|
| 1019 |
similarities = cosine_similarity(model_embedding_2d, embeddings)[0]
|
| 1020 |
|
|
|
|
|
|
|
| 1021 |
top_k_indices = np.argpartition(similarities, -k-1)[-k-1:-1]
|
|
|
|
| 1022 |
top_k_indices = top_k_indices[np.argsort(similarities[top_k_indices])][::-1]
|
| 1023 |
|
| 1024 |
similar_models = []
|
|
|
|
| 1029 |
similar_models.append({
|
| 1030 |
"model_id": row.get('model_id', 'Unknown'),
|
| 1031 |
"similarity": float(similarities[idx]),
|
| 1032 |
+
"distance": float(1 - similarities[idx]),
|
| 1033 |
"library_name": row.get('library_name'),
|
| 1034 |
"pipeline_tag": row.get('pipeline_tag'),
|
| 1035 |
"downloads": int(row.get('downloads', 0)),
|
|
|
|
| 1055 |
Returns models with their similarity scores and coordinates.
|
| 1056 |
Useful for exploring the embedding space around a specific model.
|
| 1057 |
"""
|
| 1058 |
+
if deps.df is None or deps.embeddings is None:
|
|
|
|
|
|
|
| 1059 |
raise HTTPException(status_code=503, detail="Data not loaded")
|
| 1060 |
|
| 1061 |
+
df = deps.df
|
| 1062 |
+
embeddings = deps.embeddings
|
| 1063 |
+
|
| 1064 |
# Find the query model
|
| 1065 |
if 'model_id' in df.index.names or df.index.name == 'model_id':
|
| 1066 |
try:
|
|
|
|
| 1076 |
|
| 1077 |
query_embedding = embeddings[model_idx]
|
| 1078 |
|
|
|
|
| 1079 |
filtered_df = data_loader.filter_data(
|
| 1080 |
df=df,
|
| 1081 |
min_downloads=min_downloads,
|
|
|
|
| 1085 |
pipeline_tags=None
|
| 1086 |
)
|
| 1087 |
|
|
|
|
| 1088 |
if df.index.name == 'model_id' or 'model_id' in df.index.names:
|
| 1089 |
filtered_indices = [df.index.get_loc(idx) for idx in filtered_df.index]
|
| 1090 |
filtered_indices = np.array(filtered_indices, dtype=int)
|
| 1091 |
else:
|
| 1092 |
filtered_indices = filtered_df.index.values.astype(int)
|
| 1093 |
|
|
|
|
| 1094 |
filtered_embeddings = embeddings[filtered_indices]
|
| 1095 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 1096 |
query_embedding_2d = query_embedding.reshape(1, -1)
|
| 1097 |
similarities = cosine_similarity(query_embedding_2d, filtered_embeddings)[0]
|
| 1098 |
|
|
|
|
| 1099 |
top_k_local_indices = np.argpartition(similarities, -k)[-k:]
|
| 1100 |
top_k_local_indices = top_k_local_indices[np.argsort(similarities[top_k_local_indices])][::-1]
|
| 1101 |
|
|
|
|
| 1102 |
if reduced_embeddings is None:
|
| 1103 |
raise HTTPException(status_code=503, detail="Reduced embeddings not ready")
|
| 1104 |
|
|
|
|
| 1105 |
top_k_original_indices = filtered_indices[top_k_local_indices]
|
| 1106 |
top_k_reduced = reduced_embeddings[top_k_original_indices]
|
| 1107 |
|
|
|
|
| 1108 |
similar_models = []
|
| 1109 |
for i, orig_idx in enumerate(top_k_original_indices):
|
| 1110 |
row = df.iloc[orig_idx]
|
|
|
|
| 1141 |
"""
|
| 1142 |
Calculate distance/similarity between two models.
|
| 1143 |
"""
|
| 1144 |
+
if deps.df is None or deps.embeddings is None:
|
|
|
|
|
|
|
| 1145 |
raise HTTPException(status_code=503, detail="Data not loaded")
|
| 1146 |
|
| 1147 |
+
df = deps.df
|
| 1148 |
+
embeddings = deps.embeddings
|
| 1149 |
+
|
| 1150 |
# Find both models - optimized with index lookup
|
| 1151 |
if 'model_id' in df.index.names or df.index.name == 'model_id':
|
| 1152 |
try:
|
|
|
|
| 1183 |
Export selected models as JSON with full metadata.
|
| 1184 |
"""
|
| 1185 |
if df is None:
|
| 1186 |
+
raise DataNotLoadedError()
|
| 1187 |
|
| 1188 |
# Optimized export with index lookup
|
| 1189 |
if 'model_id' in df.index.names or df.index.name == 'model_id':
|
|
|
|
| 1198 |
if len(exported) == 0:
|
| 1199 |
return {"models": []}
|
| 1200 |
|
|
|
|
| 1201 |
models = [
|
| 1202 |
{
|
| 1203 |
"model_id": str(row.get('model_id', '')),
|
|
|
|
| 1235 |
Returns network graph data suitable for visualization.
|
| 1236 |
"""
|
| 1237 |
if df is None:
|
| 1238 |
+
raise DataNotLoadedError()
|
| 1239 |
|
| 1240 |
try:
|
| 1241 |
network_builder = ModelNetworkBuilder(df)
|
|
|
|
|
|
|
| 1242 |
top_models = network_builder.get_top_models_by_field(
|
| 1243 |
library=library,
|
| 1244 |
pipeline_tag=pipeline_tag,
|
|
|
|
| 1255 |
}
|
| 1256 |
|
| 1257 |
model_ids = [mid for mid, _ in top_models]
|
|
|
|
|
|
|
| 1258 |
graph = network_builder.build_cooccurrence_network(
|
| 1259 |
model_ids=model_ids,
|
| 1260 |
cooccurrence_method=cooccurrence_method
|
| 1261 |
)
|
| 1262 |
|
|
|
|
| 1263 |
nodes = []
|
| 1264 |
for node_id, attrs in graph.nodes(data=True):
|
| 1265 |
nodes.append({
|
|
|
|
| 1287 |
"links": links,
|
| 1288 |
"statistics": stats
|
| 1289 |
}
|
| 1290 |
+
except (ValueError, KeyError, AttributeError) as e:
|
| 1291 |
+
logger.error(f"Error building network: {e}", exc_info=True)
|
| 1292 |
raise HTTPException(status_code=500, detail=f"Error building network: {str(e)}")
|
| 1293 |
|
| 1294 |
|
| 1295 |
@app.get("/api/network/family/{model_id}")
|
| 1296 |
async def get_family_network(
|
| 1297 |
model_id: str,
|
| 1298 |
+
max_depth: Optional[int] = Query(None, ge=1, le=100, description="Maximum depth to traverse. If None, traverses entire tree without limit."),
|
| 1299 |
+
edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."),
|
| 1300 |
+
include_edge_attributes: bool = Query(True, description="Whether to include edge attributes (change in likes, downloads, etc.)")
|
| 1301 |
):
|
| 1302 |
"""
|
| 1303 |
Build family tree network for a model (directed graph).
|
| 1304 |
+
Returns network graph data showing parent-child relationships with multiple relationship types.
|
| 1305 |
+
Supports filtering by edge type (finetune, quantized, adapter, merge, parent).
|
| 1306 |
"""
|
| 1307 |
if df is None:
|
| 1308 |
+
raise DataNotLoadedError()
|
| 1309 |
|
| 1310 |
try:
|
| 1311 |
+
filter_types = None
|
| 1312 |
+
if edge_types:
|
| 1313 |
+
filter_types = [t.strip() for t in edge_types.split(',') if t.strip()]
|
| 1314 |
+
|
| 1315 |
network_builder = ModelNetworkBuilder(df)
|
| 1316 |
graph = network_builder.build_family_tree_network(
|
| 1317 |
root_model_id=model_id,
|
| 1318 |
+
max_depth=max_depth,
|
| 1319 |
+
include_edge_attributes=include_edge_attributes,
|
| 1320 |
+
filter_edge_types=filter_types
|
| 1321 |
)
|
| 1322 |
|
|
|
|
| 1323 |
nodes = []
|
| 1324 |
for node_id, attrs in graph.nodes(data=True):
|
| 1325 |
nodes.append({
|
| 1326 |
"id": node_id,
|
| 1327 |
"title": attrs.get('title', node_id),
|
| 1328 |
+
"freq": attrs.get('freq', 0),
|
| 1329 |
+
"likes": attrs.get('likes', 0),
|
| 1330 |
+
"downloads": attrs.get('downloads', 0),
|
| 1331 |
+
"library": attrs.get('library', ''),
|
| 1332 |
+
"pipeline": attrs.get('pipeline', '')
|
| 1333 |
})
|
| 1334 |
|
| 1335 |
links = []
|
| 1336 |
+
for source, target, edge_attrs in graph.edges(data=True):
|
| 1337 |
+
link_data = {
|
| 1338 |
"source": source,
|
| 1339 |
+
"target": target,
|
| 1340 |
+
"edge_type": edge_attrs.get('edge_type'),
|
| 1341 |
+
"edge_types": edge_attrs.get('edge_types', [])
|
| 1342 |
+
}
|
| 1343 |
+
|
| 1344 |
+
if include_edge_attributes:
|
| 1345 |
+
link_data.update({
|
| 1346 |
+
"change_in_likes": edge_attrs.get('change_in_likes'),
|
| 1347 |
+
"percentage_change_in_likes": edge_attrs.get('percentage_change_in_likes'),
|
| 1348 |
+
"change_in_downloads": edge_attrs.get('change_in_downloads'),
|
| 1349 |
+
"percentage_change_in_downloads": edge_attrs.get('percentage_change_in_downloads'),
|
| 1350 |
+
"change_in_createdAt_days": edge_attrs.get('change_in_createdAt_days')
|
| 1351 |
+
})
|
| 1352 |
+
|
| 1353 |
+
links.append(link_data)
|
| 1354 |
|
| 1355 |
stats = network_builder.get_network_statistics(graph)
|
| 1356 |
|
|
|
|
| 1360 |
"statistics": stats,
|
| 1361 |
"root_model": model_id
|
| 1362 |
}
|
| 1363 |
+
except (ValueError, KeyError, AttributeError) as e:
|
| 1364 |
+
logger.error(f"Error building family network: {e}", exc_info=True)
|
| 1365 |
raise HTTPException(status_code=500, detail=f"Error building family network: {str(e)}")
|
| 1366 |
|
| 1367 |
|
|
|
|
| 1376 |
Similar to graph database queries for finding connected nodes.
|
| 1377 |
"""
|
| 1378 |
if df is None:
|
| 1379 |
+
raise DataNotLoadedError()
|
| 1380 |
|
| 1381 |
try:
|
| 1382 |
network_builder = ModelNetworkBuilder(df)
|
|
|
|
| 1383 |
top_models = network_builder.get_top_models_by_field(n=1000)
|
| 1384 |
model_ids = [mid for mid, _ in top_models]
|
| 1385 |
graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined')
|
|
|
|
| 1396 |
"neighbors": neighbors,
|
| 1397 |
"count": len(neighbors)
|
| 1398 |
}
|
| 1399 |
+
except (ValueError, KeyError, AttributeError) as e:
|
| 1400 |
+
logger.error(f"Error finding neighbors: {e}", exc_info=True)
|
| 1401 |
raise HTTPException(status_code=500, detail=f"Error finding neighbors: {str(e)}")
|
| 1402 |
|
| 1403 |
|
|
|
|
| 1412 |
Similar to graph database path queries.
|
| 1413 |
"""
|
| 1414 |
if df is None:
|
| 1415 |
+
raise DataNotLoadedError()
|
| 1416 |
|
| 1417 |
try:
|
| 1418 |
network_builder = ModelNetworkBuilder(df)
|
|
|
|
| 1460 |
Similar to graph database queries for co-assignment patterns.
|
| 1461 |
"""
|
| 1462 |
if df is None:
|
| 1463 |
+
raise DataNotLoadedError()
|
| 1464 |
|
| 1465 |
try:
|
| 1466 |
network_builder = ModelNetworkBuilder(df)
|
|
|
|
| 1497 |
Similar to graph database relationship queries.
|
| 1498 |
"""
|
| 1499 |
if df is None:
|
| 1500 |
+
raise DataNotLoadedError()
|
| 1501 |
|
| 1502 |
try:
|
| 1503 |
network_builder = ModelNetworkBuilder(df)
|
|
|
|
| 1522 |
async def get_current_model_count(
|
| 1523 |
use_cache: bool = Query(True),
|
| 1524 |
force_refresh: bool = Query(False),
|
| 1525 |
+
use_dataset_snapshot: bool = Query(False),
|
| 1526 |
+
use_models_page: bool = Query(True)
|
| 1527 |
):
|
| 1528 |
"""
|
| 1529 |
Get the current number of models on Hugging Face Hub.
|
| 1530 |
+
Uses multiple strategies: models page scraping (fastest), dataset snapshot, or API.
|
| 1531 |
|
| 1532 |
Query Parameters:
|
| 1533 |
use_cache: Use cached results if available (default: True)
|
| 1534 |
force_refresh: Force refresh even if cache is valid (default: False)
|
| 1535 |
+
use_dataset_snapshot: Use dataset snapshot for breakdowns (default: False)
|
| 1536 |
+
use_models_page: Try to get count from HF models page first (default: True)
|
| 1537 |
"""
|
| 1538 |
try:
|
| 1539 |
+
tracker = get_tracker()
|
| 1540 |
+
|
| 1541 |
if use_dataset_snapshot:
|
| 1542 |
+
count_data = tracker.get_count_from_models_page()
|
|
|
|
|
|
|
| 1543 |
if count_data is None:
|
| 1544 |
+
count_data = tracker.get_current_model_count(use_models_page=False)
|
| 1545 |
+
else:
|
| 1546 |
+
try:
|
| 1547 |
+
from utils.data_loader import ModelDataLoader
|
| 1548 |
+
data_loader = ModelDataLoader()
|
| 1549 |
+
df = data_loader.load_data(sample_size=10000)
|
| 1550 |
+
library_counts = {}
|
| 1551 |
+
pipeline_counts = {}
|
| 1552 |
+
|
| 1553 |
+
for _, row in df.iterrows():
|
| 1554 |
+
if pd.notna(row.get('library_name')):
|
| 1555 |
+
lib = str(row.get('library_name'))
|
| 1556 |
+
library_counts[lib] = library_counts.get(lib, 0) + 1
|
| 1557 |
+
if pd.notna(row.get('pipeline_tag')):
|
| 1558 |
+
pipeline = str(row.get('pipeline_tag'))
|
| 1559 |
+
pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
|
| 1560 |
+
|
| 1561 |
+
if len(df) > 0 and count_data["total_models"] > len(df):
|
| 1562 |
+
scale_factor = count_data["total_models"] / len(df)
|
| 1563 |
+
library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
|
| 1564 |
+
pipeline_counts = {k: int(v * scale_factor) for k, v in pipeline_counts.items()}
|
| 1565 |
+
|
| 1566 |
+
count_data["models_by_library"] = library_counts
|
| 1567 |
+
count_data["models_by_pipeline"] = pipeline_counts
|
| 1568 |
+
except Exception as e:
|
| 1569 |
+
logger.warning(f"Could not get breakdowns from dataset: {e}")
|
| 1570 |
else:
|
| 1571 |
+
count_data = tracker.get_current_model_count(use_models_page=use_models_page)
|
|
|
|
|
|
|
| 1572 |
|
| 1573 |
return count_data
|
| 1574 |
except Exception as e:
|
| 1575 |
+
logger.error(f"Error fetching model count: {e}", exc_info=True)
|
| 1576 |
raise HTTPException(status_code=500, detail=f"Error fetching model count: {str(e)}")
|
| 1577 |
|
| 1578 |
|
|
|
|
| 1593 |
try:
|
| 1594 |
from datetime import datetime
|
| 1595 |
|
| 1596 |
+
tracker = get_tracker()
|
| 1597 |
|
| 1598 |
start = None
|
| 1599 |
end = None
|
|
|
|
| 1623 |
async def get_latest_model_count():
|
| 1624 |
"""Get the most recently recorded model count from database."""
|
| 1625 |
try:
|
| 1626 |
+
tracker = get_tracker()
|
| 1627 |
latest = tracker.get_latest_count()
|
| 1628 |
if latest is None:
|
| 1629 |
raise HTTPException(status_code=404, detail="No model counts recorded yet")
|
|
|
|
| 1647 |
use_dataset_snapshot: Use dataset snapshot instead of API (faster, default: False)
|
| 1648 |
"""
|
| 1649 |
try:
|
| 1650 |
+
tracker = get_tracker()
|
| 1651 |
|
|
|
|
| 1652 |
def record():
|
| 1653 |
if use_dataset_snapshot:
|
| 1654 |
count_data = tracker.get_count_from_dataset_snapshot()
|
| 1655 |
if count_data:
|
| 1656 |
tracker.record_count(count_data, source="dataset_snapshot")
|
| 1657 |
else:
|
|
|
|
| 1658 |
count_data = tracker.get_current_model_count(use_cache=False)
|
| 1659 |
tracker.record_count(count_data, source="api")
|
| 1660 |
else:
|
|
|
|
| 1681 |
days: Number of days to analyze
|
| 1682 |
"""
|
| 1683 |
try:
|
| 1684 |
+
tracker = get_tracker()
|
| 1685 |
stats = tracker.get_growth_stats(days)
|
| 1686 |
return stats
|
| 1687 |
except Exception as e:
|
|
|
|
| 1703 |
Similar to Open Syllabus graph export functionality.
|
| 1704 |
"""
|
| 1705 |
if df is None:
|
| 1706 |
+
raise DataNotLoadedError()
|
| 1707 |
|
| 1708 |
try:
|
| 1709 |
network_builder = ModelNetworkBuilder(df)
|
| 1710 |
|
|
|
|
| 1711 |
top_models = network_builder.get_top_models_by_field(
|
| 1712 |
library=library,
|
| 1713 |
pipeline_tag=pipeline_tag,
|
|
|
|
| 1720 |
raise HTTPException(status_code=404, detail="No models found matching criteria")
|
| 1721 |
|
| 1722 |
model_ids = [mid for mid, _ in top_models]
|
|
|
|
|
|
|
| 1723 |
graph = network_builder.build_cooccurrence_network(
|
| 1724 |
model_ids=model_ids,
|
| 1725 |
cooccurrence_method=cooccurrence_method
|
| 1726 |
)
|
| 1727 |
|
|
|
|
| 1728 |
with tempfile.NamedTemporaryFile(mode='w', suffix='.graphml', delete=False) as tmp_file:
|
| 1729 |
tmp_path = tmp_file.name
|
| 1730 |
network_builder.export_graphml(graph, tmp_path)
|
| 1731 |
|
|
|
|
| 1732 |
background_tasks.add_task(os.unlink, tmp_path)
|
| 1733 |
|
|
|
|
| 1734 |
return FileResponse(
|
| 1735 |
tmp_path,
|
| 1736 |
media_type='application/xml',
|
| 1737 |
filename=f'network_{cooccurrence_method}_{n}_models.graphml'
|
| 1738 |
)
|
| 1739 |
+
except (ValueError, KeyError, AttributeError, IOError) as e:
|
| 1740 |
+
logger.error(f"Error exporting network: {e}", exc_info=True)
|
| 1741 |
raise HTTPException(status_code=500, detail=f"Error exporting network: {str(e)}")
|
| 1742 |
|
| 1743 |
|
|
|
|
| 1748 |
Extracts arXiv IDs from model tags and fetches paper information.
|
| 1749 |
"""
|
| 1750 |
if df is None:
|
| 1751 |
+
raise DataNotLoadedError()
|
| 1752 |
|
| 1753 |
model = df[df.get('model_id', '') == model_id]
|
| 1754 |
if len(model) == 0:
|
|
|
|
| 1777 |
}
|
| 1778 |
|
| 1779 |
|
| 1780 |
+
@app.get("/api/models/minimal.bin")
async def get_minimal_binary():
    """
    Return the pre-exported binary embeddings file as a download.

    The file is looked up at ``<project root>/cache/binary/embeddings.bin``,
    where the project root is three directory levels above this module.

    Raises:
        HTTPException: 404 when the binary export does not exist yet.
    """
    module_path = os.path.abspath(__file__)
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(module_path)))
    target = os.path.join(project_root, "cache", "binary", "embeddings.bin")

    if os.path.exists(target):
        # Let clients and proxies reuse the payload for an hour.
        response_headers = {
            "Content-Disposition": "attachment; filename=embeddings.bin",
            "Cache-Control": "public, max-age=3600"
        }
        return FileResponse(
            target,
            media_type="application/octet-stream",
            headers=response_headers
        )

    raise HTTPException(status_code=404, detail="Binary dataset not found. Run export_binary.py first.")
|
| 1801 |
+
|
| 1802 |
+
|
| 1803 |
+
@app.get("/api/models/model_ids.json")
async def get_model_ids_json():
    """
    Return the exported model-ID list (``cache/binary/model_ids.json``).

    Raises:
        HTTPException: 404 when the file has not been generated.
    """
    api_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(os.path.dirname(api_dir))
    ids_file = os.path.join(project_root, "cache", "binary", "model_ids.json")

    if os.path.exists(ids_file):
        caching = {"Cache-Control": "public, max-age=3600"}
        return FileResponse(ids_file, media_type="application/json", headers=caching)

    raise HTTPException(status_code=404, detail="Model IDs file not found.")
|
| 1818 |
+
|
| 1819 |
+
|
| 1820 |
+
@app.get("/api/models/metadata.json")
async def get_metadata_json():
    """
    Return the exported metadata lookup tables (``cache/binary/metadata.json``).

    Raises:
        HTTPException: 404 when the file has not been generated.
    """
    # Walk up from this module: api/ -> backend/ -> project root.
    location = os.path.abspath(__file__)
    for _ in range(3):
        location = os.path.dirname(location)
    metadata_file = os.path.join(location, "cache", "binary", "metadata.json")

    missing = not os.path.exists(metadata_file)
    if missing:
        raise HTTPException(status_code=404, detail="Metadata file not found.")

    return FileResponse(
        metadata_file,
        media_type="application/json",
        headers={"Cache-Control": "public, max-age=3600"}
    )
|
| 1835 |
+
|
| 1836 |
+
|
| 1837 |
@app.get("/api/model/{model_id}/files")
async def get_model_files(model_id: str, branch: str = Query("main")):
    """
    Get file tree for a model from Hugging Face.

    Proxies the request to avoid CORS issues. The requested branch is tried
    first, followed by ``main`` and ``master`` as fallbacks.

    Args:
        model_id: Hugging Face model identifier.
        branch: Branch to look up first (default ``main``).

    Returns:
        A flat list of file entries. If the upstream payload is a dict with a
        ``tree`` key, that list is returned; any other payload shape yields
        an empty list.

    Raises:
        HTTPException: 400 for a blank model ID; 504 when no branch could be
            fetched and at least one attempt timed out; 404 when no branch
            returned a file tree; 500 for any other failure.
    """
    if not model_id or not model_id.strip():
        raise HTTPException(status_code=400, detail="Invalid model ID")

    # Requested branch first, then the conventional defaults, without repeats.
    branches_to_try = list(dict.fromkeys([branch, "main", "master"]))

    # httpx.TimeoutException is a subclass of httpx.HTTPError, so timeouts must
    # be caught before the generic handler or they would be reported as 404.
    timed_out = False

    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            for branch_name in branches_to_try:
                url = f"https://huggingface.co/api/models/{model_id}/tree/{branch_name}"
                try:
                    response = await client.get(url)
                except httpx.TimeoutException as e:
                    timed_out = True
                    logger.warning(f"Timeout for branch {branch_name}: {e}")
                    continue
                except httpx.HTTPError as e:
                    logger.warning(f"HTTP error for branch {branch_name}: {e}")
                    continue

                if response.status_code == 200:
                    data = response.json()
                    # Ensure we return an array
                    if isinstance(data, list):
                        return data
                    if isinstance(data, dict) and 'tree' in data:
                        return data['tree']
                    return []

                # 404 simply means "try the next branch"; anything else is logged.
                if response.status_code != 404:
                    logger.warning(f"Unexpected status {response.status_code} for {url}")

        # All branches failed
        if timed_out:
            raise HTTPException(
                status_code=504,
                detail="Request to Hugging Face timed out. Please try again later."
            )
        raise HTTPException(
            status_code=404,
            detail=f"File tree not found for model '{model_id}'. The model may not exist or may not have any files."
        )

    except HTTPException:
        raise  # Re-raise HTTP exceptions
    except Exception as e:
        logger.error(f"Error fetching file tree: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error fetching file tree: {str(e)}"
        )
|
| 1901 |
|
| 1902 |
|
| 1903 |
if __name__ == "__main__":
|
| 1904 |
import uvicorn
|
|
|
|
| 1905 |
port = int(os.getenv("PORT", 8000))
|
| 1906 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 1907 |
|
backend/api/routes/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API route modules.
|
| 3 |
+
"""
|
| 4 |
+
from . import models, stats, clusters
|
| 5 |
+
|
| 6 |
+
__all__ = ['models', 'stats', 'clusters']
|
backend/api/routes/clusters.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API routes for cluster endpoints.
|
| 3 |
+
"""
|
| 4 |
+
from fastapi import APIRouter
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from core.exceptions import DataNotLoadedError
|
| 8 |
+
import api.dependencies as deps
|
| 9 |
+
|
| 10 |
+
router = APIRouter(prefix="/api", tags=["clusters"])
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Fixed ten-colour palette; a cluster's colour is picked by cluster_id modulo its length.
_CLUSTER_COLORS = (
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
    "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
)

# Library-name substrings mapped to a display domain, checked in this order.
_LIBRARY_DOMAINS = (
    ("transformers", "NLP"),
    ("diffusers", "Multimodal"),
    ("timm", "Computer Vision"),
)


def _most_common(frame, column):
    """Return the most frequent value of *column*, or None if the column is absent, empty or blank."""
    if column not in frame.columns:
        return None
    counts = frame[column].value_counts()
    if len(counts) == 0:
        return None
    top = counts.index[0]
    return top if top and pd.notna(top) else None


def _cluster_domain(frame) -> str:
    """Derive the top-level label part from the cluster's dominant library."""
    primary_lib = _most_common(frame, 'library_name')
    if primary_lib is None:
        return "Other"
    lowered = str(primary_lib).lower()
    for needle, domain in _LIBRARY_DOMAINS:
        if needle in lowered:
            return domain
    return str(primary_lib).replace('_', ' ').title()


def _cluster_subdomain(frame) -> str:
    """Derive the second label part from the cluster's dominant pipeline tag."""
    primary_pipeline = _most_common(frame, 'pipeline_tag')
    if primary_pipeline is None:
        return "General"
    return str(primary_pipeline).replace('-', ' ').replace('_', ' ').title()


def _cluster_characteristics(frame) -> list:
    """List the descriptive traits shown in parentheses in the cluster label."""
    characteristics = []
    if 'model_id' in frame.columns:
        model_ids_lower = frame['model_id'].astype(str).str.lower()
        if model_ids_lower.str.contains('gpt', na=False).any():
            characteristics.append("GPT-based")
    if 'parent_model' in frame.columns and frame['parent_model'].notna().any():
        characteristics.append("Fine-tuned")
    if not characteristics:
        characteristics.append("Base Models")
    return characteristics


@router.get("/clusters")
async def get_clusters():
    """
    Get all clusters with metadata and hierarchical labels.

    Returns:
        ``{"clusters": [...]}`` where each entry has ``cluster_id``,
        ``cluster_label``, ``count`` and ``color``, sorted by ``count``
        descending. The list is empty while cluster labels are not yet
        available for the currently loaded data.

    Raises:
        DataNotLoadedError: when the dataset has not been loaded.
    """
    if deps.df is None:
        raise DataNotLoadedError()

    # Imported at call time so the current value of the shared cache is read.
    from api.routes.models import cluster_labels

    df = deps.df

    # If clusters haven't been computed yet, return an empty list instead of an
    # error so the frontend keeps working while data is still loading. Labels
    # whose length does not match the dataframe cannot be used as a row mask,
    # so they are treated the same way.
    if cluster_labels is None or len(cluster_labels) != len(df):
        return {"clusters": []}

    clusters = []
    for cluster_id in np.unique(cluster_labels):
        cluster_models = df[cluster_labels == cluster_id]
        if len(cluster_models) == 0:
            continue

        domain = _cluster_domain(cluster_models)
        subdomain = _cluster_subdomain(cluster_models)
        char_str = "; ".join(_cluster_characteristics(cluster_models))
        label = f"{domain} — {subdomain} ({char_str})"

        clusters.append({
            "cluster_id": int(cluster_id),
            "cluster_label": label,
            "count": int(len(cluster_models)),
            # Consistent colour per cluster id.
            "color": _CLUSTER_COLORS[int(cluster_id) % len(_CLUSTER_COLORS)]
        })

    # Sort by count descending
    clusters.sort(key=lambda x: x["count"], reverse=True)

    return {"clusters": clusters}
|
| 102 |
+
|
backend/api/routes/models.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API routes for model data endpoints.
|
| 3 |
+
"""
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from fastapi import APIRouter, Query, HTTPException
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import pickle
|
| 9 |
+
import os
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
from umap import UMAP
|
| 13 |
+
from models.schemas import ModelPoint
|
| 14 |
+
from utils.family_tree import calculate_family_depths
|
| 15 |
+
from utils.dimensionality_reduction import DimensionReducer
|
| 16 |
+
from core.exceptions import DataNotLoadedError, EmbeddingsNotReadyError
|
| 17 |
+
import api.dependencies as deps
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
router = APIRouter(prefix="/api", tags=["models"])
|
| 22 |
+
|
| 23 |
+
# Global cluster labels cache (shared across routes)
|
| 24 |
+
cluster_labels = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np.ndarray:
    """
    Assign a KMeans cluster label to every row of *reduced_embeddings*.

    Args:
        reduced_embeddings: One row per sample.
        n_clusters: Requested number of clusters. When there are fewer
            samples than requested clusters, ``n_samples // 10`` is used
            instead. The final value is always clamped to the range
            ``1..n_samples`` so KMeans never receives an invalid count.

    Returns:
        One integer label per input row; an empty array for empty input.
    """
    n_samples = len(reduced_embeddings)
    if n_samples == 0:
        # Nothing to cluster; KMeans would raise on an empty array.
        return np.empty(0, dtype=np.int32)

    from sklearn.cluster import KMeans

    if n_samples < n_clusters:
        n_clusters = n_samples // 10
    # Callers may pass 0 (e.g. len(data) // 100 for small datasets).
    n_clusters = max(1, min(n_clusters, n_samples))

    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    return kmeans.fit_predict(reduced_embeddings)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _index_position(index, model_id):
    """Return the first integer position of *model_id* in *index*, or None if it cannot be located."""
    try:
        pos = index.get_loc(model_id)
    except (KeyError, TypeError):
        return None
    if isinstance(pos, (int, np.integer)):
        return int(pos)
    if isinstance(pos, slice):
        return int(pos.start)
    if isinstance(pos, np.ndarray):
        # get_loc returns a boolean mask for non-unique, non-monotonic indexes;
        # the position is where the mask is set, not the mask value itself.
        hits = np.flatnonzero(pos) if pos.dtype == bool else pos
        return int(hits[0]) if len(hits) else None
    return None


@router.get("/models")
async def get_models(
    min_downloads: int = Query(0),
    min_likes: int = Query(0),
    search_query: Optional[str] = Query(None),
    color_by: str = Query("library_name"),
    size_by: str = Query("downloads"),
    max_points: Optional[int] = Query(None),
    projection_method: str = Query("umap"),
    base_models_only: bool = Query(False),
    max_hierarchy_depth: Optional[int] = Query(None, ge=0, description="Filter to models at or below this hierarchy depth."),
    use_graph_embeddings: bool = Query(False, description="Use graph-aware embeddings that respect family tree structure")
):
    """
    Return filtered models with 3-D projection coordinates and cluster ids.

    The dataset is filtered (downloads, likes, search text, optional
    base-model and hierarchy-depth filters), optionally down-sampled to
    ``max_points`` proportionally per library, and each remaining model is
    paired with its row of the reduced embedding matrix.

    Returns:
        A dict with ``models`` (list of ModelPoint), ``filtered_count`` (rows
        matching the filters before sampling) and ``returned_count``;
        ``embedding_type`` is included except when the filters match nothing.

    Raises:
        DataNotLoadedError: when the dataset has not been loaded.
        EmbeddingsNotReadyError: when text embeddings are needed but missing.
    """
    global cluster_labels

    if deps.df is None:
        raise DataNotLoadedError()

    df = deps.df
    data_loader = deps.data_loader
    method = projection_method.lower()
    use_graph = use_graph_embeddings and deps.combined_embeddings is not None

    # Filter data
    filtered_df = data_loader.filter_data(
        df=df,
        min_downloads=min_downloads,
        min_likes=min_likes,
        search_query=search_query,
        libraries=None,
        pipeline_tags=None
    )

    if base_models_only and 'parent_model' in filtered_df.columns:
        parent_text = filtered_df['parent_model'].astype(str)
        filtered_df = filtered_df[
            filtered_df['parent_model'].isna() |
            (parent_text.str.strip() == '') |
            (parent_text == 'nan')
        ]

    # Computed at most once per request; reused for the response below.
    family_depths = None
    if max_hierarchy_depth is not None:
        family_depths = calculate_family_depths(df)
        filtered_df = filtered_df[
            filtered_df['model_id'].astype(str).map(lambda x: family_depths.get(x, 0) <= max_hierarchy_depth)
        ]

    filtered_count = len(filtered_df)

    if len(filtered_df) == 0:
        return {
            "models": [],
            "filtered_count": 0,
            "returned_count": 0
        }

    if max_points is not None and len(filtered_df) > max_points:
        if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
            # Sample each library in proportion to its share, at least one row each.
            sampled_dfs = []
            for lib_name, group in filtered_df.groupby('library_name', group_keys=False):
                n_samples = max(1, int(max_points * len(group) / len(filtered_df)))
                sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42))
            filtered_df = pd.concat(sampled_dfs, ignore_index=True)
            if len(filtered_df) > max_points:
                filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
            else:
                filtered_df = filtered_df.reset_index(drop=True)
        else:
            filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)

    # Determine which embeddings to use
    if use_graph:
        current_embeddings = deps.combined_embeddings
        current_reduced = deps.reduced_embeddings_graph
        embedding_type = "graph-aware"
    else:
        if deps.embeddings is None:
            raise EmbeddingsNotReadyError()
        current_embeddings = deps.embeddings
        current_reduced = deps.reduced_embeddings
        embedding_type = "text-only"

    # Handle reduced embeddings loading/generation
    reducer = deps.reducer
    if current_reduced is None or (reducer is not None and reducer.method != method):
        backend_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        root_dir = os.path.dirname(backend_dir)
        cache_dir = os.path.join(root_dir, "cache")
        cache_suffix = "_graph" if use_graph else ""
        reduced_cache = os.path.join(cache_dir, f"reduced_{method}_3d{cache_suffix}.pkl")
        reducer_cache = os.path.join(cache_dir, f"reducer_{method}_3d{cache_suffix}.pkl")

        # The shared matrix belongs to a different method (or is absent), so
        # it must not be reused if the cache cannot be loaded.
        current_reduced = None

        if os.path.exists(reduced_cache) and os.path.exists(reducer_cache):
            try:
                # NOTE(review): pickle is only safe because these files are
                # written by this service; never point the cache at untrusted data.
                with open(reduced_cache, 'rb') as f:
                    current_reduced = pickle.load(f)
                if reducer is None or reducer.method != method:
                    reducer = DimensionReducer(method=method, n_components=3)
                reducer.load_reducer(reducer_cache)
            except (IOError, pickle.UnpicklingError, EOFError) as e:
                logger.warning(f"Failed to load cached reduced embeddings: {e}")
                current_reduced = None

        if current_reduced is None:
            if reducer is None or reducer.method != method:
                reducer = DimensionReducer(method=method, n_components=3)
            if method == "umap":
                # NOTE(review): applied whenever a UMAP projection is fitted
                # here; confirm this matches the intended nesting.
                reducer.reducer = UMAP(
                    n_components=3,
                    n_neighbors=30,
                    min_dist=0.3,
                    metric='cosine',
                    random_state=42,
                    n_jobs=-1,
                    low_memory=True,
                    spread=1.5
                )
            current_reduced = reducer.fit_transform(current_embeddings)
            os.makedirs(cache_dir, exist_ok=True)
            with open(reduced_cache, 'wb') as f:
                pickle.dump(current_reduced, f)
            reducer.save_reducer(reducer_cache)

        # Update the shared state. The reducer is stored together with the
        # matrix so a later request can tell which method the matrix holds.
        # NOTE(review): one reducer is shared by both embedding variants;
        # alternating variants and methods can still leave one matrix stale.
        deps.reducer = reducer
        if use_graph:
            deps.reduced_embeddings_graph = current_reduced
        else:
            deps.reduced_embeddings = current_reduced

    # Map every filtered row to its row position in df (and thus in the
    # reduced matrix). Rows that cannot be located are dropped from the
    # response so coordinates and metadata stay aligned.
    filtered_model_ids = filtered_df['model_id'].astype(str).values

    if df.index.name == 'model_id' or 'model_id' in df.index.names:
        def locate(mid):
            return _index_position(df.index, mid)
    else:
        df_model_ids = df['model_id'].astype(str).values
        model_id_to_pos = {mid: pos for pos, mid in enumerate(df_model_ids)}
        locate = model_id_to_pos.get

    kept_rows = []
    located = []
    for row_no, mid in enumerate(filtered_model_ids):
        pos = locate(mid)
        if pos is not None:
            kept_rows.append(row_no)
            located.append(pos)
    filtered_indices = np.array(located, dtype=np.int32)

    if len(filtered_indices) == 0:
        return {
            "models": [],
            "embedding_type": embedding_type,
            "filtered_count": filtered_count,
            "returned_count": 0
        }

    if len(kept_rows) != len(filtered_df):
        filtered_df = filtered_df.iloc[kept_rows]

    filtered_reduced = current_reduced[filtered_indices]
    if family_depths is None:
        family_depths = calculate_family_depths(df)

    clustering_embeddings = current_reduced
    if cluster_labels is None or len(cluster_labels) != len(clustering_embeddings):
        # At least one cluster, even for datasets with fewer than 100 rows.
        requested = max(1, min(50, len(clustering_embeddings) // 100))
        cluster_labels = compute_clusters(clustering_embeddings, n_clusters=requested)

    filtered_clusters = cluster_labels[filtered_indices]

    n_rows = len(filtered_df)
    model_ids = filtered_df['model_id'].astype(str).values
    library_names = filtered_df.get('library_name', pd.Series([None] * n_rows)).values
    pipeline_tags = filtered_df.get('pipeline_tag', pd.Series([None] * n_rows)).values
    downloads_arr = filtered_df.get('downloads', pd.Series([0] * n_rows)).fillna(0).astype(int).values
    likes_arr = filtered_df.get('likes', pd.Series([0] * n_rows)).fillna(0).astype(int).values
    trending_scores = filtered_df.get('trendingScore', pd.Series([None] * n_rows)).values
    tags_arr = filtered_df.get('tags', pd.Series([None] * n_rows)).values
    parent_models = filtered_df.get('parent_model', pd.Series([None] * n_rows)).values
    licenses_arr = filtered_df.get('licenses', pd.Series([None] * n_rows)).values
    created_at_arr = filtered_df.get('createdAt', pd.Series([None] * n_rows)).values

    x_coords = filtered_reduced[:, 0].astype(float)
    y_coords = filtered_reduced[:, 1].astype(float)
    z_coords = filtered_reduced[:, 2].astype(float) if filtered_reduced.shape[1] > 2 else np.zeros(len(filtered_reduced), dtype=float)

    models = [
        ModelPoint(
            model_id=model_ids[idx],
            x=float(x_coords[idx]),
            y=float(y_coords[idx]),
            z=float(z_coords[idx]),
            library_name=library_names[idx] if pd.notna(library_names[idx]) else None,
            pipeline_tag=pipeline_tags[idx] if pd.notna(pipeline_tags[idx]) else None,
            downloads=int(downloads_arr[idx]),
            likes=int(likes_arr[idx]),
            trending_score=float(trending_scores[idx]) if pd.notna(trending_scores[idx]) else None,
            tags=tags_arr[idx] if pd.notna(tags_arr[idx]) else None,
            parent_model=parent_models[idx] if pd.notna(parent_models[idx]) else None,
            licenses=licenses_arr[idx] if pd.notna(licenses_arr[idx]) else None,
            family_depth=family_depths.get(model_ids[idx], None),
            cluster_id=int(filtered_clusters[idx]),
            created_at=str(created_at_arr[idx]) if pd.notna(created_at_arr[idx]) else None
        )
        for idx in range(n_rows)
    ]

    return {
        "models": models,
        "embedding_type": embedding_type,
        "filtered_count": filtered_count,
        "returned_count": len(models)
    }
|
| 247 |
+
|
backend/api/routes/stats.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API routes for statistics endpoints.
|
| 3 |
+
"""
|
| 4 |
+
from fastapi import APIRouter
|
| 5 |
+
from core.exceptions import DataNotLoadedError
|
| 6 |
+
import api.dependencies as deps
|
| 7 |
+
|
| 8 |
+
router = APIRouter(prefix="/api", tags=["stats"])
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@router.get("/stats")
async def get_stats():
    """Get dataset statistics.

    Returns:
        A dict with the total number of models, the number of distinct
        libraries / pipeline tags / licenses, a ``license -> count`` mapping,
        and the mean ``downloads`` and ``likes``. Counts and means fall back
        to 0 when the corresponding column is absent.

    Raises:
        DataNotLoadedError: If the shared dataframe has not been loaded yet.
    """
    if deps.df is None:
        raise DataNotLoadedError()

    # Function-scope import: this module does not import pandas at file level.
    import pandas as pd

    df = deps.df
    total_models = len(df.index) if hasattr(df, 'index') else len(df)

    def unique_count(column: str) -> int:
        """Number of distinct non-null values in *column*, or 0 if absent."""
        return int(df[column].nunique()) if column in df.columns else 0

    def mean_or_zero(column: str) -> float:
        """Mean of *column*; 0 when the column is absent or the mean is NaN."""
        if column not in df.columns:
            return 0
        mean_value = df[column].mean()
        # An empty or all-missing column averages to NaN, which strict JSON
        # cannot represent; report 0.0 instead of failing the response.
        return float(mean_value) if pd.notna(mean_value) else 0.0

    # Unique licenses with their occurrence counts.
    licenses = {}
    if 'license' in df.columns:
        license_counts = df['license'].value_counts().to_dict()
        licenses = {
            str(k): int(v)
            for k, v in license_counts.items()
            if pd.notna(k) and str(k) != 'nan'
        }

    # Both "unique_pipelines" and "unique_task_types" report the same figure.
    pipeline_count = unique_count('pipeline_tag')

    return {
        "total_models": total_models,
        "unique_libraries": unique_count('library_name'),
        "unique_pipelines": pipeline_count,
        "unique_task_types": pipeline_count,
        "unique_licenses": len(licenses),
        "licenses": licenses,
        "avg_downloads": mean_or_zero('downloads'),
        "avg_likes": mean_or_zero('likes')
    }
|
| 37 |
+
|
backend/config/requirements.txt
CHANGED
|
@@ -11,5 +11,6 @@ huggingface-hub>=0.17.0
|
|
| 11 |
schedule>=1.2.0
|
| 12 |
tqdm>=4.66.0
|
| 13 |
networkx>=3.0
|
|
|
|
| 14 |
httpx>=0.24.0
|
| 15 |
|
|
|
|
| 11 |
schedule>=1.2.0
|
| 12 |
tqdm>=4.66.0
|
| 13 |
networkx>=3.0
|
| 14 |
+
node2vec>=0.4.6
|
| 15 |
httpx>=0.24.0
|
| 16 |
|
backend/core/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core configuration and utilities."""
|
| 2 |
+
|
backend/core/config.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration management."""
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
class Settings:
    """Application settings sourced from environment variables.

    The class-level values are read from the environment once, when the class
    body is executed. ``get_sample_size`` re-reads ``SAMPLE_SIZE`` on each call.
    """

    # URL of the frontend; defaults to the local development address.
    FRONTEND_URL: str = os.environ.get("FRONTEND_URL", "http://localhost:3000")
    # Enabled for "true", "1" or "yes" (case-insensitive); on by default.
    ALLOW_ALL_ORIGINS: bool = os.environ.get("ALLOW_ALL_ORIGINS", "True").lower() in {"true", "1", "yes"}
    # Static placeholder; the environment value is exposed via get_sample_size().
    SAMPLE_SIZE: Optional[int] = None
    # Enabled only for the literal value "true" (case-insensitive); off by default.
    USE_GRAPH_EMBEDDINGS: bool = os.environ.get("USE_GRAPH_EMBEDDINGS", "false").lower() == "true"
    PORT: int = int(os.environ.get("PORT", 8000))

    @classmethod
    def get_sample_size(cls) -> Optional[int]:
        """Return ``SAMPLE_SIZE`` from the environment as a positive int.

        Returns:
            The parsed value when it is greater than zero; ``None`` when the
            variable is unset, empty, zero or negative.

        Raises:
            ValueError: If the variable is set to a non-integer string.
        """
        raw_value = os.environ.get("SAMPLE_SIZE")
        if not raw_value:
            return None
        parsed = int(raw_value)
        if parsed > 0:
            return parsed
        return None
|
| 21 |
+
|
| 22 |
+
settings = Settings()
|
| 23 |
+
|
backend/core/exceptions.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Custom exceptions."""
|
| 2 |
+
from fastapi import HTTPException
|
| 3 |
+
|
| 4 |
+
class ModelNotFoundError(HTTPException):
    """HTTP 404 error reporting that the given model ID was not found."""

    def __init__(self, model_id: str):
        # The offending ID is echoed back in the response detail.
        message = f"Model not found: {model_id}"
        super().__init__(status_code=404, detail=message)
|
| 8 |
+
|
| 9 |
+
class DataNotLoadedError(HTTPException):
    """HTTP 503 error reporting that the data has not been loaded."""

    def __init__(self):
        message = "Data not loaded"
        super().__init__(status_code=503, detail=message)
|
| 13 |
+
|
| 14 |
+
class EmbeddingsNotReadyError(HTTPException):
    """HTTP 503 error reporting that the embeddings are not ready."""

    def __init__(self):
        message = "Embeddings not ready"
        super().__init__(status_code=503, detail=message)
|
| 18 |
+
|
backend/models/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data models and schemas."""
|
| 2 |
+
|
backend/models/schemas.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic models for API."""
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
class ModelPoint(BaseModel):
    """Model point in 3D space.

    Coordinates and popularity counters are always required. Every
    ``Optional`` field defaults to ``None`` so that all optional fields behave
    the same way (an ``Optional[...]`` annotation without a default is still a
    required field under Pydantic v2).
    """
    model_id: str
    # Position of the model in the reduced embedding space.
    x: float
    y: float
    z: float
    library_name: Optional[str] = None
    pipeline_tag: Optional[str] = None
    downloads: int
    likes: int
    trending_score: Optional[float] = None
    tags: Optional[str] = None
    parent_model: Optional[str] = None
    licenses: Optional[str] = None
    family_depth: Optional[int] = None
    cluster_id: Optional[int] = None
    created_at: Optional[str] = None  # ISO format date string
|
| 22 |
+
|
backend/scripts/export_binary.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Export minimal dataset to binary format for fast client-side loading.
|
| 3 |
+
This creates a compact binary representation optimized for WebGL rendering.
|
| 4 |
+
"""
|
| 5 |
+
import struct
|
| 6 |
+
import json
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import sys
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
# Add parent directory to path
|
| 14 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 15 |
+
|
| 16 |
+
from utils.data_loader import ModelDataLoader
|
| 17 |
+
from utils.dimensionality_reduction import DimensionReducer
|
| 18 |
+
from utils.embeddings import ModelEmbedder
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def calculate_family_depths(df: pd.DataFrame) -> dict:
    """Calculate depth of each model in its family tree.

    A model with no parent (missing, empty or ``'nan'``) has depth 0, and so
    does any parent that is referenced but has no row of its own in ``df``.
    Every other model has its parent's depth plus one. If following parents
    leads back to a model already on the current chain (a cycle), the
    revisited model contributes depth 0 and the chain is counted up from it.

    Args:
        df: DataFrame with a ``model_id`` column and, optionally, a
            ``parent_model`` column. For duplicated model IDs the first row
            determines the parent.

    Returns:
        Dict mapping model ID (as ``str``) to its integer depth. Parents that
        are referenced but absent from ``df`` are included with depth 0.
    """
    if 'parent_model' in df.columns:
        parent_values = df['parent_model']
    else:
        parent_values = [None] * len(df)

    # One pass builds an id -> parent lookup so each step up the tree is a dict
    # access instead of a scan over the whole DataFrame.
    parent_of = {}
    for raw_id, raw_parent in zip(df['model_id'], parent_values):
        parent_of.setdefault(raw_id, raw_parent)

    depths = {}

    def get_depth(model_id: str) -> int:
        # Iterative walk towards the root: long chains cannot hit the
        # interpreter's recursion limit.
        chain = []        # models waiting for their ancestor's depth, child first
        on_chain = set()  # same models, for O(1) cycle detection
        current = model_id
        while True:
            if current in on_chain:
                base = 0  # Cycle detected
                break
            if current in depths:
                base = depths[current]
                break
            parent = parent_of.get(current)
            if (
                current not in parent_of
                or pd.isna(parent)
                or parent == ''
                or str(parent) == 'nan'
            ):
                # Unknown model or a root: depth 0.
                depths[current] = 0
                base = 0
                break
            chain.append(current)
            on_chain.add(current)
            current = str(parent)

        # Unwind from the nearest resolved ancestor back down to model_id.
        for node in reversed(chain):
            base += 1
            depths[node] = base
        return base

    for model_id in df['model_id'].unique():
        if model_id not in depths:
            get_depth(str(model_id))

    return depths
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def export_binary_dataset(df: pd.DataFrame, reduced_embeddings: np.ndarray, output_dir: Path):
    """
    Export minimal dataset to binary format for fast client-side loading.

    Three files are written into ``output_dir``:

    - ``embeddings.bin`` (all fields little-endian, no alignment padding):
        - Header (64 bytes): magic ``HFVIZ`` (5 bytes), version (u8),
          num_models (u32), num_domains (u32), num_licenses (u32),
          num_families (u32), reserved (u8, u8, u16), zero padding (38 bytes)
        - Domain lookup table (32 bytes per domain, NUL-padded UTF-8)
        - License lookup table (32 bytes per license, NUL-padded UTF-8)
        - Model records (17 bytes each): f32 x, f32 y, f32 z, u8 domain_id,
          u8 license_id, u16 family_id, u8 flags
    - ``model_ids.json``: model IDs in record order
    - ``metadata.json``: lookup tables, counts and format version

    ``domain_id`` / ``license_id`` 255 and ``family_id`` 65535 mean "unknown".
    Flag bits are independent: 0x01 is_base_model (no parent), 0x02 has_parent,
    0x04 has_children (some row names this model as its parent).

    Args:
        df: Model table. ``model_id`` is required; ``library_name``,
            ``license``, ``parent_model`` and ``x``/``y``/``z`` are used when
            present. If ``x`` or ``y`` is missing, ``x``/``y``/``z`` columns
            are added to ``df`` in place from ``reduced_embeddings``.
        reduced_embeddings: Array with one row per model, used only when ``df``
            has no coordinates. May be ``None`` otherwise.
        output_dir: Directory to write to; created if it does not exist.

    Raises:
        ValueError: If coordinates are missing from ``df`` and
            ``reduced_embeddings`` is ``None`` or has a different length.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Exporting {len(df)} models to binary format...")

    # Ensure we have coordinates
    if 'x' not in df.columns or 'y' not in df.columns:
        if reduced_embeddings is None or len(reduced_embeddings) != len(df):
            raise ValueError("Need reduced embeddings to generate coordinates")

        df['x'] = reduced_embeddings[:, 0] if reduced_embeddings.shape[1] > 0 else 0.0
        df['y'] = reduced_embeddings[:, 1] if reduced_embeddings.shape[1] > 1 else 0.0
        df['z'] = reduced_embeddings[:, 2] if reduced_embeddings.shape[1] > 2 else 0.0

    def is_missing(value) -> bool:
        """True for the ways a parent can be absent: NaN/None, '' or 'nan'."""
        return pd.isna(value) or value == '' or str(value) == 'nan'

    def label_table(column: str) -> list:
        """Sorted distinct labels of *column*, capped at 255 (id 255 = unknown)."""
        if column not in df.columns:
            return []
        labels = sorted(df[column].dropna().astype(str).unique())
        return [label for label in labels if label and label != 'nan'][:255]

    # Create lookup tables (domain = library_name); the index dicts give O(1)
    # label -> id encoding instead of a list scan per row.
    domains = label_table('library_name')
    licenses = label_table('license')
    domain_index = {label: i for i, label in enumerate(domains)}
    license_index = {label: i for i, label in enumerate(licenses)}

    # id -> parent lookup (first row wins for duplicated IDs).
    if 'parent_model' in df.columns:
        parent_values = df['parent_model']
    else:
        parent_values = [None] * len(df)
    parent_of = {}
    for raw_id, raw_parent in zip(df['model_id'], parent_values):
        parent_of.setdefault(raw_id, raw_parent)

    # Every string that appears as somebody's parent, for the has_children flag.
    parents_with_children = {p for p in parent_values if isinstance(p, str)}

    # Create family mapping: group models by root parent
    def get_root_parent(model_id: str) -> str:
        visited = set()
        current = str(model_id)
        # Walk up until a root, an unknown model, or a cycle is reached.
        while current not in visited:
            visited.add(current)
            if current not in parent_of:
                return current
            parent = parent_of[current]
            if is_missing(parent):
                return current
            current = str(parent)
        return current

    # Resolve each model's root once and reuse it below.
    model_roots = {}
    for raw_id in df['model_id'].unique():
        key = str(raw_id)
        model_roots[key] = get_root_parent(key)

    # Family IDs are assigned in order of first appearance of each root.
    root_parents = {}
    for root in model_roots.values():
        if root not in root_parents:
            root_parents[root] = len(root_parents)

    # Map each model to its family
    model_to_family = {key: root_parents[root] for key, root in model_roots.items()}

    # Limit families to 65535 (u16 max)
    if len(root_parents) > 65535:
        # Use hash-based family IDs (distinct roots may then share an ID)
        import hashlib
        for key, root in model_roots.items():
            family_hash = int(hashlib.md5(root.encode()).hexdigest()[:4], 16) % 65535
            model_to_family[key] = family_hash

    num_families = len(set(model_to_family.values()))

    # Prepare model records
    # Record: f32 x, f32 y, f32 z, u8 domain, u8 license, u16 family, u8 flags
    record_struct = struct.Struct('<fffBBHB')
    records = []
    model_ids = []

    for _, row in df.iterrows():
        model_id = str(row['model_id'])
        model_ids.append(model_id)

        # Get coordinates
        x = float(row.get('x', 0.0))
        y = float(row.get('y', 0.0))
        z = float(row.get('z', 0.0))

        # Encode domain (library_name), license and family; unknown -> sentinel
        domain_id = domain_index.get(str(row.get('library_name', '')), 255)
        license_id = license_index.get(str(row.get('license', '')), 255)
        family_id = model_to_family.get(model_id, 65535)

        # Encode flags
        flags = 0
        if is_missing(row.get('parent_model')):
            flags |= 0x01  # is_base_model
        else:
            flags |= 0x02  # has_parent
        if model_id in parents_with_children:
            flags |= 0x04  # has_children

        records.append(record_struct.pack(x, y, z, domain_id, license_id, family_id, flags))

    num_models = len(records)

    # Write binary file
    with open(output_dir / 'embeddings.bin', 'wb') as f:
        # Header (64 bytes)
        header = struct.pack(
            '<5sBIIIIBBH38s',
            b'HFVIZ',        # magic (5 bytes)
            1,               # version (1 byte)
            num_models,      # num_models (4 bytes)
            len(domains),    # num_domains (4 bytes)
            len(licenses),   # num_licenses (4 bytes)
            num_families,    # num_families (4 bytes)
            0,               # reserved (1 byte)
            0,               # reserved (1 byte)
            0,               # reserved (2 bytes)
            b'\x00' * 38     # padding up to 64 bytes
        )
        f.write(header)

        # Domain lookup table (32 bytes per domain, null-terminated)
        for domain in domains:
            domain_bytes = domain.encode('utf-8')[:31]
            f.write(domain_bytes.ljust(32, b'\x00'))

        # License lookup table (32 bytes per license)
        for license_name in licenses:
            license_bytes = license_name.encode('utf-8')[:31]
            f.write(license_bytes.ljust(32, b'\x00'))

        # Model records
        f.write(b''.join(records))

    # Write model IDs JSON (separate file for string table)
    with open(output_dir / 'model_ids.json', 'w') as f:
        json.dump(model_ids, f)

    # Write metadata JSON
    metadata = {
        'domains': domains,
        'licenses': licenses,
        'num_models': num_models,
        'num_families': num_families,
        'version': 1
    }
    with open(output_dir / 'metadata.json', 'w') as f:
        json.dump(metadata, f, indent=2)

    binary_size = (output_dir / 'embeddings.bin').stat().st_size
    json_size = (output_dir / 'model_ids.json').stat().st_size

    print(f"β Exported {num_models} models")
    print(f"β Binary size: {binary_size / 1024 / 1024:.2f} MB")
    print(f"β Model IDs JSON: {json_size / 1024 / 1024:.2f} MB")
    print(f"β Total: {(binary_size + json_size) / 1024 / 1024:.2f} MB")
    print(f"β Domains: {len(domains)}")
    print(f"β Licenses: {len(licenses)}")
    print(f"β Families: {num_families}")
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
if __name__ == '__main__':
    import argparse

    # Command-line interface: where to write, and an optional row cap.
    cli = argparse.ArgumentParser(description='Export dataset to binary format')
    cli.add_argument('--output', type=str, default='backend/cache/binary', help='Output directory')
    cli.add_argument('--sample-size', type=int, default=None, help='Sample size (for testing)')
    args = cli.parse_args()

    output_dir = Path(args.output)

    # Load data
    print("Loading dataset...")
    data_loader = ModelDataLoader()
    df = data_loader.preprocess_for_embedding(
        data_loader.load_data(sample_size=args.sample_size)
    )

    # Coordinates already in the data are used as-is; otherwise embed the
    # combined text and reduce it to three dimensions.
    if 'x' in df.columns and 'y' in df.columns:
        reduced_embeddings = None
    else:
        print("Generating embeddings...")
        embedder = ModelEmbedder()
        embeddings = embedder.generate_embeddings(df['combined_text'].tolist())

        print("Reducing dimensions...")
        reducer = DimensionReducer()
        reduced_embeddings = reducer.reduce_dimensions(embeddings, n_components=3, method='umap')

    # Export
    export_binary_dataset(df, reduced_embeddings, output_dir)
    print("Done!")
|
| 263 |
+
|
backend/services/model_tracker.py
CHANGED
|
@@ -5,11 +5,16 @@ Tracks the number of models over time and provides historical data.
|
|
| 5 |
import os
|
| 6 |
import json
|
| 7 |
import sqlite3
|
|
|
|
|
|
|
| 8 |
from datetime import datetime, timedelta
|
| 9 |
from typing import Dict, List, Optional, Tuple
|
| 10 |
from huggingface_hub import HfApi
|
| 11 |
import pandas as pd
|
| 12 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class ModelCountTracker:
|
|
@@ -34,7 +39,6 @@ class ModelCountTracker:
|
|
| 34 |
conn = sqlite3.connect(self.db_path)
|
| 35 |
cursor = conn.cursor()
|
| 36 |
|
| 37 |
-
# Create table for model counts
|
| 38 |
cursor.execute("""
|
| 39 |
CREATE TABLE IF NOT EXISTS model_counts (
|
| 40 |
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
@@ -47,7 +51,6 @@ class ModelCountTracker:
|
|
| 47 |
)
|
| 48 |
""")
|
| 49 |
|
| 50 |
-
# Create index for faster queries
|
| 51 |
cursor.execute("""
|
| 52 |
CREATE INDEX IF NOT EXISTS idx_timestamp
|
| 53 |
ON model_counts(timestamp)
|
|
@@ -56,27 +59,90 @@ class ModelCountTracker:
|
|
| 56 |
conn.commit()
|
| 57 |
conn.close()
|
| 58 |
|
| 59 |
-
def
|
| 60 |
"""
|
| 61 |
-
|
| 62 |
-
|
|
|
|
| 63 |
|
| 64 |
Returns:
|
| 65 |
-
Dictionary with
|
| 66 |
"""
|
| 67 |
try:
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
total_count = 0
|
| 73 |
library_counts = {}
|
| 74 |
pipeline_counts = {}
|
| 75 |
-
page_size = 1000
|
| 76 |
-
max_pages = 100
|
| 77 |
-
sample_size = 10000
|
| 78 |
|
| 79 |
-
# Count total models efficiently
|
| 80 |
models_iter = self.api.list_models(full=False)
|
| 81 |
sampled_models = []
|
| 82 |
|
|
@@ -87,25 +153,18 @@ class ModelCountTracker:
|
|
| 87 |
if i < sample_size:
|
| 88 |
sampled_models.append(model)
|
| 89 |
|
| 90 |
-
# Safety limit to prevent infinite loops
|
| 91 |
if i >= max_pages * page_size:
|
| 92 |
-
# If we hit the limit, estimate total from sample
|
| 93 |
-
# This is a rough estimate - for exact count, increase max_pages
|
| 94 |
break
|
| 95 |
|
| 96 |
-
# Calculate breakdowns from sample (extrapolate if needed)
|
| 97 |
for model in sampled_models:
|
| 98 |
-
# Count by library
|
| 99 |
if hasattr(model, 'library_name') and model.library_name:
|
| 100 |
lib = model.library_name
|
| 101 |
library_counts[lib] = library_counts.get(lib, 0) + 1
|
| 102 |
|
| 103 |
-
# Count by pipeline
|
| 104 |
if hasattr(model, 'pipeline_tag') and model.pipeline_tag:
|
| 105 |
pipeline = model.pipeline_tag
|
| 106 |
pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
|
| 107 |
|
| 108 |
-
# If we sampled, scale up the breakdowns proportionally
|
| 109 |
if len(sampled_models) < total_count and len(sampled_models) > 0:
|
| 110 |
scale_factor = total_count / len(sampled_models)
|
| 111 |
library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
|
|
@@ -118,7 +177,7 @@ class ModelCountTracker:
|
|
| 118 |
"timestamp": datetime.utcnow().isoformat()
|
| 119 |
}
|
| 120 |
except Exception as e:
|
| 121 |
-
|
| 122 |
return {
|
| 123 |
"total_models": 0,
|
| 124 |
"models_by_library": {},
|
|
@@ -162,7 +221,7 @@ class ModelCountTracker:
|
|
| 162 |
conn.close()
|
| 163 |
return True
|
| 164 |
except Exception as e:
|
| 165 |
-
|
| 166 |
return False
|
| 167 |
|
| 168 |
def get_historical_counts(
|
|
@@ -211,7 +270,7 @@ class ModelCountTracker:
|
|
| 211 |
conn.close()
|
| 212 |
return results
|
| 213 |
except Exception as e:
|
| 214 |
-
|
| 215 |
return []
|
| 216 |
|
| 217 |
def get_latest_count(self) -> Optional[Dict]:
|
|
@@ -239,7 +298,7 @@ class ModelCountTracker:
|
|
| 239 |
}
|
| 240 |
return None
|
| 241 |
except Exception as e:
|
| 242 |
-
|
| 243 |
return None
|
| 244 |
|
| 245 |
def get_growth_stats(self, days: int = 7) -> Dict:
|
|
|
|
| 5 |
import os
|
| 6 |
import json
|
| 7 |
import sqlite3
|
| 8 |
+
import logging
|
| 9 |
+
import re
|
| 10 |
from datetime import datetime, timedelta
|
| 11 |
from typing import Dict, List, Optional, Tuple
|
| 12 |
from huggingface_hub import HfApi
|
| 13 |
import pandas as pd
|
| 14 |
from pathlib import Path
|
| 15 |
+
import httpx
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
|
| 19 |
|
| 20 |
class ModelCountTracker:
|
|
|
|
| 39 |
conn = sqlite3.connect(self.db_path)
|
| 40 |
cursor = conn.cursor()
|
| 41 |
|
|
|
|
| 42 |
cursor.execute("""
|
| 43 |
CREATE TABLE IF NOT EXISTS model_counts (
|
| 44 |
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
|
|
| 51 |
)
|
| 52 |
""")
|
| 53 |
|
|
|
|
| 54 |
cursor.execute("""
|
| 55 |
CREATE INDEX IF NOT EXISTS idx_timestamp
|
| 56 |
ON model_counts(timestamp)
|
|
|
|
| 59 |
conn.commit()
|
| 60 |
conn.close()
|
| 61 |
|
| 62 |
+
def get_count_from_models_page(self) -> Optional[Dict]:
    """
    Read the total model count from the Hugging Face models listing page.

    The page at https://huggingface.co/models is fetched and searched, in
    order, for the window.__hf_deferred["numTotalItems"] assignment in the
    page script and then for a div whose class contains "font-normal" and
    "text-gray-400" holding a comma-grouped number.

    Returns:
        Dictionary with the total_models count, a UTC timestamp, the source
        tag and empty breakdown dicts, or None if the request fails or no
        count can be extracted.
    """
    def build_result(count: int) -> Dict:
        # The page only yields a total, so the breakdowns are left empty.
        return {
            "total_models": count,
            "timestamp": datetime.utcnow().isoformat(),
            "source": "hf_models_page",
            "models_by_library": {},
            "models_by_pipeline": {},
            "models_by_author": {}
        }

    try:
        page = httpx.get("https://huggingface.co/models", timeout=10.0, follow_redirects=True)
        page.raise_for_status()
        markup = page.text

        # First choice: the count embedded in the page script.
        script_hit = re.search(r'window\.__hf_deferred\["numTotalItems"\]\s*=\s*(\d+);', markup)
        if script_hit is not None:
            total_models = int(script_hit.group(1))
            logger.info(f"Extracted model count from window.__hf_deferred: {total_models}")
            return build_result(total_models)

        # Fallback: the count rendered in the page, with thousands separators.
        div_hit = re.search(
            r'<div[^>]*class="[^"]*font-normal[^"]*text-gray-400[^"]*"[^>]*>([\d,]+)</div>',
            markup,
        )
        if div_hit is not None:
            total_models = int(div_hit.group(1).replace(',', ''))
            logger.info(f"Extracted model count from div: {total_models}")
            return build_result(total_models)

        logger.warning("Could not find model count in HF models page HTML")
        return None

    except httpx.HTTPError as e:
        logger.error(f"HTTP error fetching HF models page: {e}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"Error extracting count from HF models page: {e}", exc_info=True)
        return None
|
| 121 |
+
|
| 122 |
+
def get_current_model_count(self, use_models_page: bool = True) -> Dict:
|
| 123 |
+
"""
|
| 124 |
+
Fetch current model count from Hugging Face Hub.
|
| 125 |
+
Uses multiple strategies: models page scraping (fastest), then API enumeration.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
use_models_page: Try to get count from HF models page first (default: True)
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
Dictionary with total count and breakdowns
|
| 132 |
+
"""
|
| 133 |
+
if use_models_page:
|
| 134 |
+
page_count = self.get_count_from_models_page()
|
| 135 |
+
if page_count:
|
| 136 |
+
return page_count
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
total_count = 0
|
| 140 |
library_counts = {}
|
| 141 |
pipeline_counts = {}
|
| 142 |
+
page_size = 1000
|
| 143 |
+
max_pages = 100
|
| 144 |
+
sample_size = 10000
|
| 145 |
|
|
|
|
| 146 |
models_iter = self.api.list_models(full=False)
|
| 147 |
sampled_models = []
|
| 148 |
|
|
|
|
| 153 |
if i < sample_size:
|
| 154 |
sampled_models.append(model)
|
| 155 |
|
|
|
|
| 156 |
if i >= max_pages * page_size:
|
|
|
|
|
|
|
| 157 |
break
|
| 158 |
|
|
|
|
| 159 |
for model in sampled_models:
|
|
|
|
| 160 |
if hasattr(model, 'library_name') and model.library_name:
|
| 161 |
lib = model.library_name
|
| 162 |
library_counts[lib] = library_counts.get(lib, 0) + 1
|
| 163 |
|
|
|
|
| 164 |
if hasattr(model, 'pipeline_tag') and model.pipeline_tag:
|
| 165 |
pipeline = model.pipeline_tag
|
| 166 |
pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
|
| 167 |
|
|
|
|
| 168 |
if len(sampled_models) < total_count and len(sampled_models) > 0:
|
| 169 |
scale_factor = total_count / len(sampled_models)
|
| 170 |
library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
|
|
|
|
| 177 |
"timestamp": datetime.utcnow().isoformat()
|
| 178 |
}
|
| 179 |
except Exception as e:
|
| 180 |
+
logger.error(f"Error fetching model count: {e}", exc_info=True)
|
| 181 |
return {
|
| 182 |
"total_models": 0,
|
| 183 |
"models_by_library": {},
|
|
|
|
| 221 |
conn.close()
|
| 222 |
return True
|
| 223 |
except Exception as e:
|
| 224 |
+
logger.error(f"Error recording count: {e}", exc_info=True)
|
| 225 |
return False
|
| 226 |
|
| 227 |
def get_historical_counts(
|
|
|
|
| 270 |
conn.close()
|
| 271 |
return results
|
| 272 |
except Exception as e:
|
| 273 |
+
logger.error(f"Error fetching historical counts: {e}", exc_info=True)
|
| 274 |
return []
|
| 275 |
|
| 276 |
def get_latest_count(self) -> Optional[Dict]:
|
|
|
|
| 298 |
}
|
| 299 |
return None
|
| 300 |
except Exception as e:
|
| 301 |
+
logger.error(f"Error fetching latest count: {e}", exc_info=True)
|
| 302 |
return None
|
| 303 |
|
| 304 |
def get_growth_stats(self, days: int = 7) -> Dict:
|
backend/services/model_tracker_improved.py
CHANGED
|
@@ -11,12 +11,17 @@ Key improvements:
|
|
| 11 |
import os
|
| 12 |
import json
|
| 13 |
import sqlite3
|
|
|
|
|
|
|
| 14 |
from datetime import datetime, timedelta
|
| 15 |
from typing import Dict, List, Optional, Tuple
|
| 16 |
from huggingface_hub import HfApi
|
| 17 |
import pandas as pd
|
| 18 |
from pathlib import Path
|
| 19 |
import time
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
class ImprovedModelCountTracker:
|
|
@@ -78,72 +83,73 @@ class ImprovedModelCountTracker:
|
|
| 78 |
elapsed = (datetime.utcnow() - self._cache_timestamp).total_seconds()
|
| 79 |
return elapsed < self.cache_ttl
|
| 80 |
|
| 81 |
-
def get_current_model_count(self, use_cache: bool = True, force_refresh: bool = False) -> Dict:
|
| 82 |
"""
|
| 83 |
-
Fetch current model count from Hugging Face Hub
|
| 84 |
-
Uses
|
| 85 |
|
| 86 |
Args:
|
| 87 |
use_cache: Whether to use cached results if available
|
| 88 |
force_refresh: Force refresh even if cache is valid
|
|
|
|
| 89 |
|
| 90 |
Returns:
|
| 91 |
Dictionary with total count and breakdowns
|
| 92 |
"""
|
| 93 |
-
# Check cache first
|
| 94 |
if use_cache and not force_refresh and self._is_cache_valid():
|
| 95 |
return self._cache
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
try:
|
| 98 |
-
# Strategy 1: Try to get count efficiently using pagination
|
| 99 |
-
# The HfApi.list_models() returns an iterator, so we can count efficiently
|
| 100 |
total_count = 0
|
| 101 |
library_counts = {}
|
| 102 |
pipeline_counts = {}
|
| 103 |
author_counts = {}
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
max_count_for_full_breakdown = 50000 # If less than this, do full breakdown
|
| 108 |
|
| 109 |
models_iter = self.api.list_models(full=False, sort="created", direction=-1)
|
| 110 |
sampled_models = []
|
| 111 |
|
| 112 |
start_time = time.time()
|
| 113 |
-
timeout_seconds = 30
|
| 114 |
|
| 115 |
for i, model in enumerate(models_iter):
|
| 116 |
-
# Check timeout
|
| 117 |
if time.time() - start_time > timeout_seconds:
|
| 118 |
-
# If we hit timeout, use sampling strategy
|
| 119 |
break
|
| 120 |
|
| 121 |
total_count += 1
|
| 122 |
|
| 123 |
-
# Sample models for breakdowns
|
| 124 |
if i < sample_size:
|
| 125 |
sampled_models.append(model)
|
| 126 |
|
| 127 |
-
# For smaller datasets, we can do full breakdown
|
| 128 |
if total_count < max_count_for_full_breakdown:
|
| 129 |
-
# Count by library
|
| 130 |
if hasattr(model, 'library_name') and model.library_name:
|
| 131 |
lib = model.library_name
|
| 132 |
library_counts[lib] = library_counts.get(lib, 0) + 1
|
| 133 |
|
| 134 |
-
# Count by pipeline
|
| 135 |
if hasattr(model, 'pipeline_tag') and model.pipeline_tag:
|
| 136 |
pipeline = model.pipeline_tag
|
| 137 |
pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
|
| 138 |
|
| 139 |
-
# Count by author (extract from model_id)
|
| 140 |
if hasattr(model, 'id') and model.id:
|
| 141 |
author = model.id.split('/')[0] if '/' in model.id else 'unknown'
|
| 142 |
author_counts[author] = author_counts.get(author, 0) + 1
|
| 143 |
|
| 144 |
-
# If we sampled, calculate breakdowns from sample and extrapolate
|
| 145 |
if total_count > len(sampled_models) and len(sampled_models) > 0:
|
| 146 |
-
# Calculate breakdowns from sample
|
| 147 |
for model in sampled_models:
|
| 148 |
if hasattr(model, 'library_name') and model.library_name:
|
| 149 |
lib = model.library_name
|
|
@@ -157,7 +163,6 @@ class ImprovedModelCountTracker:
|
|
| 157 |
author = model.id.split('/')[0] if '/' in model.id else 'unknown'
|
| 158 |
author_counts[author] = author_counts.get(author, 0) + 1
|
| 159 |
|
| 160 |
-
# Scale up breakdowns proportionally
|
| 161 |
if len(sampled_models) > 0:
|
| 162 |
scale_factor = total_count / len(sampled_models)
|
| 163 |
library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
|
|
@@ -168,20 +173,19 @@ class ImprovedModelCountTracker:
|
|
| 168 |
"total_models": total_count,
|
| 169 |
"models_by_library": library_counts,
|
| 170 |
"models_by_pipeline": pipeline_counts,
|
| 171 |
-
"models_by_author": dict(sorted(author_counts.items(), key=lambda x: x[1], reverse=True)[:20]),
|
| 172 |
"timestamp": datetime.utcnow().isoformat(),
|
| 173 |
"sampling_used": total_count > len(sampled_models) if sampled_models else False,
|
| 174 |
"sample_size": len(sampled_models) if sampled_models else total_count
|
| 175 |
}
|
| 176 |
|
| 177 |
-
# Update cache
|
| 178 |
self._cache = result
|
| 179 |
self._cache_timestamp = datetime.utcnow()
|
| 180 |
|
| 181 |
return result
|
| 182 |
|
| 183 |
except Exception as e:
|
| 184 |
-
|
| 185 |
return {
|
| 186 |
"total_models": 0,
|
| 187 |
"models_by_library": {},
|
|
@@ -191,6 +195,70 @@ class ImprovedModelCountTracker:
|
|
| 191 |
"error": str(e)
|
| 192 |
}
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
def get_count_from_dataset_snapshot(self, dataset_name: str = "modelbiome/ai_ecosystem_withmodelcards") -> Optional[Dict]:
|
| 195 |
"""
|
| 196 |
Alternative method: Get count from dataset snapshot (like ai-ecosystem repo does).
|
|
@@ -205,11 +273,9 @@ class ImprovedModelCountTracker:
|
|
| 205 |
try:
|
| 206 |
from datasets import load_dataset
|
| 207 |
|
| 208 |
-
# Load just metadata to get count quickly
|
| 209 |
dataset = load_dataset(dataset_name, split="train")
|
| 210 |
total_count = len(dataset)
|
| 211 |
|
| 212 |
-
# Sample for breakdowns
|
| 213 |
sample_size = min(10000, total_count)
|
| 214 |
sample = dataset.shuffle(seed=42).select(range(sample_size))
|
| 215 |
|
|
@@ -225,7 +291,6 @@ class ImprovedModelCountTracker:
|
|
| 225 |
pipeline = item['pipeline_tag']
|
| 226 |
pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
|
| 227 |
|
| 228 |
-
# Scale up
|
| 229 |
if sample_size < total_count:
|
| 230 |
scale_factor = total_count / sample_size
|
| 231 |
library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
|
|
@@ -239,7 +304,7 @@ class ImprovedModelCountTracker:
|
|
| 239 |
"source": "dataset_snapshot"
|
| 240 |
}
|
| 241 |
except Exception as e:
|
| 242 |
-
|
| 243 |
return None
|
| 244 |
|
| 245 |
def record_count(self, count_data: Optional[Dict] = None, source: str = "api") -> bool:
|
|
@@ -279,7 +344,7 @@ class ImprovedModelCountTracker:
|
|
| 279 |
conn.close()
|
| 280 |
return True
|
| 281 |
except Exception as e:
|
| 282 |
-
|
| 283 |
return False
|
| 284 |
|
| 285 |
def get_historical_counts(
|
|
@@ -329,7 +394,7 @@ class ImprovedModelCountTracker:
|
|
| 329 |
conn.close()
|
| 330 |
return results
|
| 331 |
except Exception as e:
|
| 332 |
-
|
| 333 |
return []
|
| 334 |
|
| 335 |
def get_latest_count(self) -> Optional[Dict]:
|
|
@@ -358,7 +423,7 @@ class ImprovedModelCountTracker:
|
|
| 358 |
}
|
| 359 |
return None
|
| 360 |
except Exception as e:
|
| 361 |
-
|
| 362 |
return None
|
| 363 |
|
| 364 |
def get_growth_stats(self, days: int = 7) -> Dict:
|
|
|
|
| 11 |
import os
|
| 12 |
import json
|
| 13 |
import sqlite3
|
| 14 |
+
import logging
|
| 15 |
+
import re
|
| 16 |
from datetime import datetime, timedelta
|
| 17 |
from typing import Dict, List, Optional, Tuple
|
| 18 |
from huggingface_hub import HfApi
|
| 19 |
import pandas as pd
|
| 20 |
from pathlib import Path
|
| 21 |
import time
|
| 22 |
+
import httpx
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
|
| 26 |
|
| 27 |
class ImprovedModelCountTracker:
|
|
|
|
| 83 |
elapsed = (datetime.utcnow() - self._cache_timestamp).total_seconds()
|
| 84 |
return elapsed < self.cache_ttl
|
| 85 |
|
| 86 |
+
def get_current_model_count(self, use_cache: bool = True, force_refresh: bool = False, use_models_page: bool = True) -> Dict:
|
| 87 |
"""
|
| 88 |
+
Fetch current model count from Hugging Face Hub.
|
| 89 |
+
Uses multiple strategies: models page scraping (fastest), API, or dataset snapshot.
|
| 90 |
|
| 91 |
Args:
|
| 92 |
use_cache: Whether to use cached results if available
|
| 93 |
force_refresh: Force refresh even if cache is valid
|
| 94 |
+
use_models_page: Try to get count from HF models page first (default: True)
|
| 95 |
|
| 96 |
Returns:
|
| 97 |
Dictionary with total count and breakdowns
|
| 98 |
"""
|
|
|
|
| 99 |
if use_cache and not force_refresh and self._is_cache_valid():
|
| 100 |
return self._cache
|
| 101 |
|
| 102 |
+
if use_models_page:
|
| 103 |
+
page_count = self.get_count_from_models_page()
|
| 104 |
+
if page_count:
|
| 105 |
+
dataset_count = self.get_count_from_dataset_snapshot()
|
| 106 |
+
if dataset_count and dataset_count.get("models_by_library"):
|
| 107 |
+
page_count["models_by_library"] = dataset_count.get("models_by_library", {})
|
| 108 |
+
page_count["models_by_pipeline"] = dataset_count.get("models_by_pipeline", {})
|
| 109 |
+
page_count["models_by_author"] = dataset_count.get("models_by_author", {})
|
| 110 |
+
|
| 111 |
+
self._cache = page_count
|
| 112 |
+
self._cache_timestamp = datetime.utcnow()
|
| 113 |
+
return page_count
|
| 114 |
+
|
| 115 |
try:
|
|
|
|
|
|
|
| 116 |
total_count = 0
|
| 117 |
library_counts = {}
|
| 118 |
pipeline_counts = {}
|
| 119 |
author_counts = {}
|
| 120 |
|
| 121 |
+
sample_size = 20000
|
| 122 |
+
max_count_for_full_breakdown = 50000
|
|
|
|
| 123 |
|
| 124 |
models_iter = self.api.list_models(full=False, sort="created", direction=-1)
|
| 125 |
sampled_models = []
|
| 126 |
|
| 127 |
start_time = time.time()
|
| 128 |
+
timeout_seconds = 30
|
| 129 |
|
| 130 |
for i, model in enumerate(models_iter):
|
|
|
|
| 131 |
if time.time() - start_time > timeout_seconds:
|
|
|
|
| 132 |
break
|
| 133 |
|
| 134 |
total_count += 1
|
| 135 |
|
|
|
|
| 136 |
if i < sample_size:
|
| 137 |
sampled_models.append(model)
|
| 138 |
|
|
|
|
| 139 |
if total_count < max_count_for_full_breakdown:
|
|
|
|
| 140 |
if hasattr(model, 'library_name') and model.library_name:
|
| 141 |
lib = model.library_name
|
| 142 |
library_counts[lib] = library_counts.get(lib, 0) + 1
|
| 143 |
|
|
|
|
| 144 |
if hasattr(model, 'pipeline_tag') and model.pipeline_tag:
|
| 145 |
pipeline = model.pipeline_tag
|
| 146 |
pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
|
| 147 |
|
|
|
|
| 148 |
if hasattr(model, 'id') and model.id:
|
| 149 |
author = model.id.split('/')[0] if '/' in model.id else 'unknown'
|
| 150 |
author_counts[author] = author_counts.get(author, 0) + 1
|
| 151 |
|
|
|
|
| 152 |
if total_count > len(sampled_models) and len(sampled_models) > 0:
|
|
|
|
| 153 |
for model in sampled_models:
|
| 154 |
if hasattr(model, 'library_name') and model.library_name:
|
| 155 |
lib = model.library_name
|
|
|
|
| 163 |
author = model.id.split('/')[0] if '/' in model.id else 'unknown'
|
| 164 |
author_counts[author] = author_counts.get(author, 0) + 1
|
| 165 |
|
|
|
|
| 166 |
if len(sampled_models) > 0:
|
| 167 |
scale_factor = total_count / len(sampled_models)
|
| 168 |
library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
|
|
|
|
| 173 |
"total_models": total_count,
|
| 174 |
"models_by_library": library_counts,
|
| 175 |
"models_by_pipeline": pipeline_counts,
|
| 176 |
+
"models_by_author": dict(sorted(author_counts.items(), key=lambda x: x[1], reverse=True)[:20]),
|
| 177 |
"timestamp": datetime.utcnow().isoformat(),
|
| 178 |
"sampling_used": total_count > len(sampled_models) if sampled_models else False,
|
| 179 |
"sample_size": len(sampled_models) if sampled_models else total_count
|
| 180 |
}
|
| 181 |
|
|
|
|
| 182 |
self._cache = result
|
| 183 |
self._cache_timestamp = datetime.utcnow()
|
| 184 |
|
| 185 |
return result
|
| 186 |
|
| 187 |
except Exception as e:
|
| 188 |
+
logger.error(f"Error fetching model count: {e}", exc_info=True)
|
| 189 |
return {
|
| 190 |
"total_models": 0,
|
| 191 |
"models_by_library": {},
|
|
|
|
| 195 |
"error": str(e)
|
| 196 |
}
|
| 197 |
|
| 198 |
+
def get_count_from_models_page(self) -> Optional[Dict]:
    """
    Read the total model count off the public Hugging Face models page.

    Downloads https://huggingface.co/models and tries two extraction
    strategies in order: the comma-formatted number inside the
    "font-normal text-gray-400" div, then the ``numTotalItems`` value
    assigned on ``window.__hf_deferred``.

    Returns:
        Dictionary with total_models, timestamp, source and empty breakdown
        dictionaries, or None if the request fails or no count can be found
    """
    def build_result(count: int, origin: str) -> Dict:
        # The page only exposes the grand total, so the breakdowns stay empty.
        return {
            "total_models": count,
            "timestamp": datetime.utcnow().isoformat(),
            "source": origin,
            "models_by_library": {},
            "models_by_pipeline": {},
            "models_by_author": {}
        }

    try:
        page = httpx.get("https://huggingface.co/models", timeout=10.0, follow_redirects=True)
        page.raise_for_status()
        markup = page.text

        # Primary strategy, e.g. <div class="font-normal text-gray-400">2,249,310</div>
        div_hit = re.search(
            r'<div[^>]*class="[^"]*font-normal[^"]*text-gray-400[^"]*"[^>]*>([\d,]+)</div>',
            markup,
        )
        if div_hit is not None:
            total_models = int(div_hit.group(1).replace(',', ''))
            logger.info(f"Extracted model count from HF models page: {total_models}")
            return build_result(total_models, "hf_models_page")

        # Fallback, e.g. window.__hf_deferred["numTotalItems"] = 2249312;
        deferred_hit = re.search(
            r'window\.__hf_deferred\["numTotalItems"\]\s*=\s*(\d+);',
            markup,
        )
        if deferred_hit is not None:
            total_models = int(deferred_hit.group(1))
            logger.info(f"Extracted model count from window.__hf_deferred: {total_models}")
            return build_result(total_models, "hf_models_page_deferred")

        logger.warning("Could not find model count in HF models page HTML")
        return None

    except httpx.HTTPError as e:
        logger.error(f"HTTP error fetching HF models page: {e}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"Error extracting count from HF models page: {e}", exc_info=True)
        return None
|
| 261 |
+
|
| 262 |
def get_count_from_dataset_snapshot(self, dataset_name: str = "modelbiome/ai_ecosystem_withmodelcards") -> Optional[Dict]:
|
| 263 |
"""
|
| 264 |
Alternative method: Get count from dataset snapshot (like ai-ecosystem repo does).
|
|
|
|
| 273 |
try:
|
| 274 |
from datasets import load_dataset
|
| 275 |
|
|
|
|
| 276 |
dataset = load_dataset(dataset_name, split="train")
|
| 277 |
total_count = len(dataset)
|
| 278 |
|
|
|
|
| 279 |
sample_size = min(10000, total_count)
|
| 280 |
sample = dataset.shuffle(seed=42).select(range(sample_size))
|
| 281 |
|
|
|
|
| 291 |
pipeline = item['pipeline_tag']
|
| 292 |
pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
|
| 293 |
|
|
|
|
| 294 |
if sample_size < total_count:
|
| 295 |
scale_factor = total_count / sample_size
|
| 296 |
library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
|
|
|
|
| 304 |
"source": "dataset_snapshot"
|
| 305 |
}
|
| 306 |
except Exception as e:
|
| 307 |
+
logger.error(f"Error loading from dataset snapshot: {e}", exc_info=True)
|
| 308 |
return None
|
| 309 |
|
| 310 |
def record_count(self, count_data: Optional[Dict] = None, source: str = "api") -> bool:
|
|
|
|
| 344 |
conn.close()
|
| 345 |
return True
|
| 346 |
except Exception as e:
|
| 347 |
+
logger.error(f"Error recording count: {e}", exc_info=True)
|
| 348 |
return False
|
| 349 |
|
| 350 |
def get_historical_counts(
|
|
|
|
| 394 |
conn.close()
|
| 395 |
return results
|
| 396 |
except Exception as e:
|
| 397 |
+
logger.error(f"Error fetching historical counts: {e}", exc_info=True)
|
| 398 |
return []
|
| 399 |
|
| 400 |
def get_latest_count(self) -> Optional[Dict]:
|
|
|
|
| 423 |
}
|
| 424 |
return None
|
| 425 |
except Exception as e:
|
| 426 |
+
logger.error(f"Error fetching latest count: {e}", exc_info=True)
|
| 427 |
return None
|
| 428 |
|
| 429 |
def get_growth_stats(self, days: int = 7) -> Dict:
|
backend/utils/data_loader.py
CHANGED
|
@@ -50,18 +50,16 @@ class ModelDataLoader:
|
|
| 50 |
else:
|
| 51 |
df = df.copy()
|
| 52 |
|
| 53 |
-
# Fill NaN values
|
| 54 |
text_fields = ['tags', 'pipeline_tag', 'library_name', 'modelCard']
|
| 55 |
for field in text_fields:
|
| 56 |
if field in df.columns:
|
| 57 |
df[field] = df[field].fillna('')
|
| 58 |
|
| 59 |
-
# Combine text fields for embedding
|
| 60 |
df['combined_text'] = (
|
| 61 |
df.get('tags', '').astype(str) + ' ' +
|
| 62 |
df.get('pipeline_tag', '').astype(str) + ' ' +
|
| 63 |
df.get('library_name', '').astype(str) + ' ' +
|
| 64 |
-
df['modelCard'].astype(str).str[:500]
|
| 65 |
)
|
| 66 |
|
| 67 |
return df
|
|
@@ -94,7 +92,6 @@ class ModelDataLoader:
|
|
| 94 |
else:
|
| 95 |
df = df.copy()
|
| 96 |
|
| 97 |
-
# Optimized filtering with vectorized operations
|
| 98 |
if min_downloads is not None:
|
| 99 |
downloads_col = df.get('downloads', pd.Series([0] * len(df), index=df.index))
|
| 100 |
df = df[downloads_col >= min_downloads]
|
|
|
|
| 50 |
else:
|
| 51 |
df = df.copy()
|
| 52 |
|
|
|
|
| 53 |
text_fields = ['tags', 'pipeline_tag', 'library_name', 'modelCard']
|
| 54 |
for field in text_fields:
|
| 55 |
if field in df.columns:
|
| 56 |
df[field] = df[field].fillna('')
|
| 57 |
|
|
|
|
| 58 |
df['combined_text'] = (
|
| 59 |
df.get('tags', '').astype(str) + ' ' +
|
| 60 |
df.get('pipeline_tag', '').astype(str) + ' ' +
|
| 61 |
df.get('library_name', '').astype(str) + ' ' +
|
| 62 |
+
df['modelCard'].astype(str).str[:500]
|
| 63 |
)
|
| 64 |
|
| 65 |
return df
|
|
|
|
| 92 |
else:
|
| 93 |
df = df.copy()
|
| 94 |
|
|
|
|
| 95 |
if min_downloads is not None:
|
| 96 |
downloads_col = df.get('downloads', pd.Series([0] * len(df), index=df.index))
|
| 97 |
df = df[downloads_col >= min_downloads]
|
backend/utils/embeddings.py
CHANGED
|
@@ -27,7 +27,7 @@ class ModelEmbedder:
|
|
| 27 |
def generate_embeddings(
|
| 28 |
self,
|
| 29 |
texts: List[str],
|
| 30 |
-
batch_size: int = 128,
|
| 31 |
show_progress: bool = True
|
| 32 |
) -> np.ndarray:
|
| 33 |
"""
|
|
|
|
| 27 |
def generate_embeddings(
|
| 28 |
self,
|
| 29 |
texts: List[str],
|
| 30 |
+
batch_size: int = 128,
|
| 31 |
show_progress: bool = True
|
| 32 |
) -> np.ndarray:
|
| 33 |
"""
|
backend/utils/family_tree.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Family tree utility functions."""
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
def calculate_family_depths(df: pd.DataFrame) -> Dict[str, int]:
    """Calculate family depth for each model.

    The depth of a model is the number of ``parent_model`` links between it
    and the top of its lineage: a model with no usable parent, or whose parent
    is not present in ``df``, has depth 0; otherwise its depth is the parent's
    depth plus one.

    Model ids are taken from the index when it is named ``model_id`` and from
    the ``model_id`` column otherwise. Ids and parents are compared as strings.
    When an id occurs more than once, the first row wins. Rows with a missing
    or empty id are ignored.

    Cycles terminate: the model that closes the loop is treated as depth 0
    at the point it is met again, and the members of the loop are numbered
    upwards from there.

    Args:
        df: DataFrame with model ids (index or ``model_id`` column) and an
            optional ``parent_model`` column.

    Returns:
        Dictionary mapping each model id (as a string) to its depth.
    """
    if df.index.name == 'model_id':
        raw_ids = df.index
    elif 'model_id' in df.columns:
        raw_ids = df['model_id']
    else:
        return {}

    if 'parent_model' in df.columns:
        raw_parents = df['parent_model']
    else:
        raw_parents = [None] * len(df)

    def _parent_key(value) -> str:
        # Normalise a parent cell to an id string; '' means "no parent".
        if value is None:
            return ''
        if pd.api.types.is_scalar(value) and (pd.isna(value) or not value):
            return ''
        text = str(value)
        return '' if text == 'nan' else text

    # One pass over the frame builds an id -> parent lookup, so each later
    # step is a dict access instead of a full-frame boolean scan.
    parent_of: Dict[str, str] = {}
    for raw_id, raw_parent in zip(raw_ids, raw_parents):
        if raw_id is None or (pd.api.types.is_scalar(raw_id) and pd.isna(raw_id)):
            continue
        model_id = str(raw_id)
        if not model_id or model_id in parent_of:
            continue
        parent_of[model_id] = _parent_key(raw_parent)

    depths: Dict[str, int] = {}
    for start in parent_of:
        if start in depths:
            continue

        # Walk up the lineage iteratively (no recursion limit), collecting
        # every model whose depth is still unknown.
        chain = []
        on_chain = set()
        node = start
        while True:
            chain.append(node)
            on_chain.add(node)
            parent = parent_of[node]
            if not parent or parent not in parent_of:
                depth = 0  # top of the lineage, or parent outside the data
                break
            if parent in depths:
                depth = depths[parent] + 1
                break
            if parent in on_chain:
                depth = 1  # cycle: the revisited parent counts as depth 0
                break
            node = parent

        # Unwind: the last model reached is the shallowest of the chain.
        for member in reversed(chain):
            depths[member] = depth
            depth += 1

    return depths
|
| 66 |
+
|
backend/utils/graph_embeddings.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Graph-aware embeddings for hierarchical model relationships.
|
| 3 |
+
Uses Node2Vec to create embeddings that respect family tree structure.
|
| 4 |
+
"""
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from typing import Dict, List, Optional, Tuple
|
| 8 |
+
import networkx as nx
|
| 9 |
+
import pickle
|
| 10 |
+
import os
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from node2vec import Node2Vec
|
| 17 |
+
NODE2VEC_AVAILABLE = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
NODE2VEC_AVAILABLE = False
|
| 20 |
+
logger.warning("node2vec not available. Install with: pip install node2vec")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class GraphEmbedder:
    """
    Produce Node2Vec embeddings of the model family graph.

    The embedder builds a directed parent -> child graph from a DataFrame,
    fits Node2Vec on it, and can join the resulting vectors to existing text
    embeddings.
    """

    def __init__(self, dimensions: int = 128, walk_length: int = 30, num_walks: int = 200):
        """
        Store the Node2Vec hyper-parameters.

        Args:
            dimensions: Size of each graph embedding vector
            walk_length: Number of steps in every random walk
            num_walks: How many walks are started from each node
        """
        self.dimensions = dimensions
        self.walk_length = walk_length
        self.num_walks = num_walks
        self.graph: Optional[nx.DiGraph] = None
        self.embeddings: Optional[np.ndarray] = None
        self.model: Optional[Node2Vec] = None

    def build_family_graph(self, df: pd.DataFrame) -> nx.DiGraph:
        """
        Create the parent -> child graph and remember it on the instance.

        Args:
            df: DataFrame with model_id and parent_model columns; the row
                label is used as the node id when model_id is absent

        Returns:
            NetworkX DiGraph with one node per row and one edge per usable
            parent reference
        """
        family = nx.DiGraph()

        for row_label, record in df.iterrows():
            node_id = str(record.get('model_id', row_label))
            family.add_node(node_id)

            raw_parent = record.get('parent_model')
            if not raw_parent or not pd.notna(raw_parent):
                continue

            parent_node = str(raw_parent)
            if parent_node == 'nan' or parent_node == '':
                continue

            # Adding the edge also creates the parent node if it has no row.
            family.add_edge(parent_node, node_id)

        self.graph = family
        logger.info(f"Built graph with {family.number_of_nodes()} nodes and {family.number_of_edges()} edges")
        return family

    def generate_graph_embeddings(
        self,
        graph: Optional[nx.DiGraph] = None,
        workers: int = 4
    ) -> Dict[str, np.ndarray]:
        """
        Fit Node2Vec and collect one vector per graph node.

        Args:
            graph: Graph to embed; falls back to self.graph when None
            workers: Number of parallel workers handed to Node2Vec

        Returns:
            Dictionary mapping node id to embedding vector. Empty when
            node2vec is not installed, no graph is available, or fitting fails.
        """
        if not NODE2VEC_AVAILABLE:
            logger.warning("Node2Vec not available, returning empty embeddings")
            return {}

        target = self.graph if graph is None else graph

        if target is None or target.number_of_nodes() == 0:
            logger.warning("No graph available for embedding generation")
            return {}

        try:
            walker = Node2Vec(
                target,
                dimensions=self.dimensions,
                walk_length=self.walk_length,
                num_walks=self.num_walks,
                workers=workers
            )

            fitted = walker.fit(window=10, min_count=1, batch_words=4)
            self.model = fitted

            vectors = fitted.wv
            node_vectors = {}
            for node in target.nodes():
                if node in vectors:
                    node_vectors[node] = vectors[node]

            logger.info(f"Generated graph embeddings for {len(node_vectors)} nodes")
            return node_vectors

        except Exception as e:
            # Best effort: callers get an empty mapping instead of an exception.
            logger.error(f"Error generating graph embeddings: {e}", exc_info=True)
            return {}

    def combine_embeddings(
        self,
        text_embeddings: np.ndarray,
        graph_embeddings: Dict[str, np.ndarray],
        model_ids: List[str],
        text_weight: float = 0.7,
        graph_weight: float = 0.3
    ) -> np.ndarray:
        """
        Join text and graph embeddings into one weighted, concatenated vector.

        Each row is the L2-normalised text vector scaled by text_weight,
        followed by the L2-normalised graph vector scaled by graph_weight.
        Models without a graph embedding get zeros in the graph part.

        Args:
            text_embeddings: Text-based embeddings (n_samples, text_dim)
            graph_embeddings: Graph embeddings keyed by model id
            model_ids: Model ids in the same order as the text_embeddings rows
            text_weight: Scale applied to the normalised text part
            graph_weight: Scale applied to the normalised graph part

        Returns:
            Array of shape (len(model_ids), text_dim + graph_dim), or
            text_embeddings unchanged when graph_embeddings is empty
        """
        if not graph_embeddings:
            return text_embeddings

        text_dim = text_embeddings.shape[1]
        graph_dim = next(iter(graph_embeddings.values())).shape[0]

        combined = np.zeros((len(model_ids), text_dim + graph_dim))

        for position, model_id in enumerate(model_ids):
            text_vec = text_embeddings[position]
            graph_vec = graph_embeddings.get(str(model_id), np.zeros(graph_dim))

            # The small epsilon keeps all-zero vectors from dividing by zero.
            text_unit = text_vec / (np.linalg.norm(text_vec) + 1e-8)
            graph_unit = graph_vec / (np.linalg.norm(graph_vec) + 1e-8)

            combined[position] = np.concatenate([
                text_unit * text_weight,
                graph_unit * graph_weight
            ])

        return combined

    def save_embeddings(self, embeddings: Dict[str, np.ndarray], filepath: str):
        """Pickle the embeddings to filepath, creating its directory if needed."""
        target_dir = os.path.dirname(filepath)
        if not target_dir:
            target_dir = '.'
        os.makedirs(target_dir, exist_ok=True)
        with open(filepath, 'wb') as f:
            pickle.dump(embeddings, f)

    def load_embeddings(self, filepath: str) -> Dict[str, np.ndarray]:
        """Load pickled graph embeddings from filepath."""
        # NOTE(review): pickle.load can execute arbitrary code; only use this
        # on files this application wrote itself.
        with open(filepath, 'rb') as f:
            return pickle.load(f)
|
| 176 |
+
|
| 177 |
+
|
backend/utils/network_analysis.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""
|
| 2 |
Network analysis module inspired by Open Syllabus Project.
|
| 3 |
Builds co-occurrence networks for models based on shared contexts.
|
|
|
|
| 4 |
"""
|
| 5 |
import pandas as pd
|
| 6 |
import numpy as np
|
|
@@ -8,12 +9,66 @@ from collections import Counter
|
|
| 8 |
from itertools import combinations
|
| 9 |
from typing import List, Dict, Tuple, Optional, Set
|
| 10 |
import networkx as nx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class ModelNetworkBuilder:
|
| 14 |
"""
|
| 15 |
Build network graphs for models based on co-occurrence patterns.
|
| 16 |
Similar to Open Syllabus approach of connecting texts that appear together.
|
|
|
|
| 17 |
"""
|
| 18 |
|
| 19 |
def __init__(self, df: pd.DataFrame):
|
|
@@ -22,13 +77,13 @@ class ModelNetworkBuilder:
|
|
| 22 |
|
| 23 |
Args:
|
| 24 |
df: DataFrame with model data including model_id, library_name,
|
| 25 |
-
pipeline_tag, tags, parent_model,
|
|
|
|
| 26 |
"""
|
| 27 |
self.df = df.copy()
|
| 28 |
if 'model_id' not in self.df.columns:
|
| 29 |
raise ValueError("DataFrame must contain 'model_id' column")
|
| 30 |
|
| 31 |
-
# Ensure model_id is index for fast lookups
|
| 32 |
if self.df.index.name != 'model_id':
|
| 33 |
if 'model_id' in self.df.columns:
|
| 34 |
self.df.set_index('model_id', drop=False, inplace=True)
|
|
@@ -208,23 +263,41 @@ class ModelNetworkBuilder:
|
|
| 208 |
def build_family_tree_network(
|
| 209 |
self,
|
| 210 |
root_model_id: str,
|
| 211 |
-
max_depth: int = 5
|
|
|
|
|
|
|
| 212 |
) -> nx.DiGraph:
|
| 213 |
"""
|
| 214 |
-
Build directed graph of model family tree.
|
| 215 |
|
| 216 |
Args:
|
| 217 |
root_model_id: Root model to start from
|
| 218 |
-
max_depth: Maximum depth to traverse
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
Returns:
|
| 221 |
-
NetworkX DiGraph representing family tree
|
| 222 |
"""
|
| 223 |
graph = nx.DiGraph()
|
| 224 |
visited = set()
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
return
|
| 229 |
visited.add(current_id)
|
| 230 |
|
|
@@ -233,28 +306,98 @@ class ModelNetworkBuilder:
|
|
| 233 |
|
| 234 |
row = self.df.loc[current_id]
|
| 235 |
|
| 236 |
-
# Add node
|
| 237 |
graph.add_node(str(current_id))
|
| 238 |
graph.nodes[str(current_id)]['title'] = self._format_title(current_id)
|
| 239 |
graph.nodes[str(current_id)]['freq'] = int(row.get('downloads', 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
parent_id_str = str(parent_id)
|
| 245 |
-
graph.add_edge(parent_id_str, str(current_id))
|
| 246 |
-
add_family(parent_id_str, depth - 1)
|
| 247 |
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
if str(child_id) not in visited:
|
| 252 |
-
graph.
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
add_family(root_model_id, max_depth)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
return graph
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
def export_graphml(self, graph: nx.Graph, filename: str):
|
| 259 |
"""Export graph to GraphML format (like Open Syllabus)."""
|
| 260 |
nx.write_graphml(graph, filename)
|
|
|
|
| 1 |
"""
|
| 2 |
Network analysis module inspired by Open Syllabus Project.
|
| 3 |
Builds co-occurrence networks for models based on shared contexts.
|
| 4 |
+
Supports multiple relationship types: finetune, quantized, adapter, merge.
|
| 5 |
"""
|
| 6 |
import pandas as pd
|
| 7 |
import numpy as np
|
|
|
|
| 9 |
from itertools import combinations
|
| 10 |
from typing import List, Dict, Tuple, Optional, Set
|
| 11 |
import networkx as nx
|
| 12 |
+
import ast
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _parse_parent_list(value) -> List[str]:
|
| 17 |
+
"""
|
| 18 |
+
Parse parent model list from string/eval format.
|
| 19 |
+
Handles both string representations and actual lists.
|
| 20 |
+
"""
|
| 21 |
+
if pd.isna(value) or value == '' or str(value) == 'nan':
|
| 22 |
+
return []
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
if isinstance(value, str):
|
| 26 |
+
if value.startswith('[') or value.startswith('('):
|
| 27 |
+
parsed = ast.literal_eval(value)
|
| 28 |
+
else:
|
| 29 |
+
parsed = [value]
|
| 30 |
+
else:
|
| 31 |
+
parsed = value
|
| 32 |
+
|
| 33 |
+
if isinstance(parsed, list):
|
| 34 |
+
return [str(p) for p in parsed if p and str(p) != 'nan']
|
| 35 |
+
elif parsed:
|
| 36 |
+
return [str(parsed)]
|
| 37 |
+
else:
|
| 38 |
+
return []
|
| 39 |
+
except (ValueError, SyntaxError):
|
| 40 |
+
return []
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _get_all_parents(row: pd.Series) -> Dict[str, List[str]]:
    """
    Collect the parents of one model, grouped by relationship type.

    Each known parent column present in the row is parsed with
    _parse_parent_list; relationship types with no parents are left out.

    Returns dict mapping relationship type to list of parent IDs.
    """
    # (column name, relationship type) pairs, in lookup order.
    column_relationships = (
        ('parent_model', 'parent'),
        ('finetune_parent', 'finetune'),
        ('quantized_parent', 'quantized'),
        ('adapter_parent', 'adapter'),
        ('merge_parent', 'merge'),
    )

    found = {}
    for column, relationship in column_relationships:
        if column not in row:
            continue
        parent_ids = _parse_parent_list(row.get(column))
        if parent_ids:
            found[relationship] = parent_ids

    return found
|
| 65 |
|
| 66 |
|
| 67 |
class ModelNetworkBuilder:
|
| 68 |
"""
|
| 69 |
Build network graphs for models based on co-occurrence patterns.
|
| 70 |
Similar to Open Syllabus approach of connecting texts that appear together.
|
| 71 |
+
Supports multiple relationship types: finetune, quantized, adapter, merge.
|
| 72 |
"""
|
| 73 |
|
| 74 |
def __init__(self, df: pd.DataFrame):
|
|
|
|
| 77 |
|
| 78 |
Args:
|
| 79 |
df: DataFrame with model data including model_id, library_name,
|
| 80 |
+
pipeline_tag, tags, parent_model, finetune_parent, quantized_parent,
|
| 81 |
+
adapter_parent, merge_parent, downloads, likes, createdAt
|
| 82 |
"""
|
| 83 |
self.df = df.copy()
|
| 84 |
if 'model_id' not in self.df.columns:
|
| 85 |
raise ValueError("DataFrame must contain 'model_id' column")
|
| 86 |
|
|
|
|
| 87 |
if self.df.index.name != 'model_id':
|
| 88 |
if 'model_id' in self.df.columns:
|
| 89 |
self.df.set_index('model_id', drop=False, inplace=True)
|
|
|
|
| 263 |
def build_family_tree_network(
|
| 264 |
self,
|
| 265 |
root_model_id: str,
|
| 266 |
+
max_depth: Optional[int] = 5,
|
| 267 |
+
include_edge_attributes: bool = True,
|
| 268 |
+
filter_edge_types: Optional[List[str]] = None
|
| 269 |
) -> nx.DiGraph:
|
| 270 |
"""
|
| 271 |
+
Build directed graph of model family tree with multiple relationship types.
|
| 272 |
|
| 273 |
Args:
|
| 274 |
root_model_id: Root model to start from
|
| 275 |
+
max_depth: Maximum depth to traverse. If None, traverses entire tree without limit.
|
| 276 |
+
include_edge_attributes: Whether to calculate edge attributes (change in likes, downloads, etc.)
|
| 277 |
+
filter_edge_types: List of edge types to include (e.g., ['finetune', 'quantized']).
|
| 278 |
+
If None, includes all types.
|
| 279 |
|
| 280 |
Returns:
|
| 281 |
+
NetworkX DiGraph representing family tree with edge types and attributes
|
| 282 |
"""
|
| 283 |
graph = nx.DiGraph()
|
| 284 |
visited = set()
|
| 285 |
|
| 286 |
+
children_index: Dict[str, List[Tuple[str, str]]] = {}
|
| 287 |
+
for idx, row in self.df.iterrows():
|
| 288 |
+
model_id = str(row.get('model_id', idx))
|
| 289 |
+
all_parents = _get_all_parents(row)
|
| 290 |
+
|
| 291 |
+
for rel_type, parent_list in all_parents.items():
|
| 292 |
+
for parent_id in parent_list:
|
| 293 |
+
if parent_id not in children_index:
|
| 294 |
+
children_index[parent_id] = []
|
| 295 |
+
children_index[parent_id].append((model_id, rel_type))
|
| 296 |
+
|
| 297 |
+
def add_family(current_id: str, depth: Optional[int]):
|
| 298 |
+
if current_id in visited:
|
| 299 |
+
return
|
| 300 |
+
if depth is not None and depth <= 0:
|
| 301 |
return
|
| 302 |
visited.add(current_id)
|
| 303 |
|
|
|
|
| 306 |
|
| 307 |
row = self.df.loc[current_id]
|
| 308 |
|
|
|
|
| 309 |
graph.add_node(str(current_id))
|
| 310 |
graph.nodes[str(current_id)]['title'] = self._format_title(current_id)
|
| 311 |
graph.nodes[str(current_id)]['freq'] = int(row.get('downloads', 0))
|
| 312 |
+
graph.nodes[str(current_id)]['likes'] = int(row.get('likes', 0))
|
| 313 |
+
graph.nodes[str(current_id)]['downloads'] = int(row.get('downloads', 0))
|
| 314 |
+
graph.nodes[str(current_id)]['library'] = str(row.get('library_name', '')) if pd.notna(row.get('library_name')) else ''
|
| 315 |
+
graph.nodes[str(current_id)]['pipeline'] = str(row.get('pipeline_tag', '')) if pd.notna(row.get('pipeline_tag')) else ''
|
| 316 |
|
| 317 |
+
createdAt = row.get('createdAt')
|
| 318 |
+
if pd.notna(createdAt):
|
| 319 |
+
graph.nodes[str(current_id)]['createdAt'] = str(createdAt)
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
+
all_parents = _get_all_parents(row)
|
| 322 |
+
for rel_type, parent_list in all_parents.items():
|
| 323 |
+
if filter_edge_types and rel_type not in filter_edge_types:
|
| 324 |
+
continue
|
| 325 |
+
|
| 326 |
+
for parent_id in parent_list:
|
| 327 |
+
if parent_id in self.df.index:
|
| 328 |
+
graph.add_edge(parent_id, str(current_id))
|
| 329 |
+
graph[parent_id][str(current_id)]['edge_types'] = [rel_type]
|
| 330 |
+
graph[parent_id][str(current_id)]['edge_type'] = rel_type
|
| 331 |
+
|
| 332 |
+
next_depth = depth - 1 if depth is not None else None
|
| 333 |
+
add_family(parent_id, next_depth)
|
| 334 |
+
|
| 335 |
+
children = children_index.get(current_id, [])
|
| 336 |
+
for child_id, rel_type in children:
|
| 337 |
+
if filter_edge_types and rel_type not in filter_edge_types:
|
| 338 |
+
continue
|
| 339 |
+
|
| 340 |
if str(child_id) not in visited:
|
| 341 |
+
if not graph.has_edge(str(current_id), child_id):
|
| 342 |
+
graph.add_edge(str(current_id), child_id)
|
| 343 |
+
graph[str(current_id)][child_id]['edge_types'] = [rel_type]
|
| 344 |
+
graph[str(current_id)][child_id]['edge_type'] = rel_type
|
| 345 |
+
else:
|
| 346 |
+
if rel_type not in graph[str(current_id)][child_id].get('edge_types', []):
|
| 347 |
+
graph[str(current_id)][child_id]['edge_types'].append(rel_type)
|
| 348 |
+
|
| 349 |
+
next_depth = depth - 1 if depth is not None else None
|
| 350 |
+
add_family(child_id, next_depth)
|
| 351 |
|
| 352 |
add_family(root_model_id, max_depth)
|
| 353 |
+
|
| 354 |
+
if include_edge_attributes:
|
| 355 |
+
self._add_edge_attributes(graph)
|
| 356 |
+
|
| 357 |
return graph
|
| 358 |
|
| 359 |
+
def _add_edge_attributes(self, graph: nx.DiGraph):
    """
    Annotate every parent -> child edge with popularity and age deltas.

    For each edge the following attributes are written in place:
        change_in_likes / change_in_downloads: child value minus parent value.
        percentage_change_in_likes / percentage_change_in_downloads:
            the change as a fraction of the parent's value (not multiplied
            by 100); NaN when the parent's value is 0.
        change_in_createdAt_days: whole days between the parent's and the
            child's 'createdAt'; NaN when either timestamp is missing or
            cannot be parsed.

    Node attributes that are absent ('likes', 'downloads') count as 0.

    Args:
        graph: Directed graph whose edges point from parent to child.
    """
    def _relative_change(child_value, parent_value):
        # A relative change is undefined against a zero baseline.
        if parent_value == 0:
            return np.nan
        return (child_value - parent_value) / parent_value

    def _parse_created(raw):
        # Accept any timestamp layout pandas understands (with or without
        # fractional seconds, 'Z' suffix, or str(pd.Timestamp) output)
        # instead of one fixed strptime format. Everything is normalised
        # to UTC so the two endpoints are always comparable.
        if not raw:
            return None
        stamp = pd.to_datetime(str(raw), utc=True, errors='coerce')
        return None if pd.isna(stamp) else stamp

    for parent_model, model_id in graph.edges():
        parent_attrs = graph.nodes[parent_model]
        child_attrs = graph.nodes[model_id]
        edge_attrs = graph.edges[parent_model, model_id]

        parent_likes = parent_attrs.get('likes', 0)
        model_likes = child_attrs.get('likes', 0)
        edge_attrs['change_in_likes'] = model_likes - parent_likes
        edge_attrs['percentage_change_in_likes'] = _relative_change(model_likes, parent_likes)

        parent_downloads = parent_attrs.get('downloads', 0)
        model_downloads = child_attrs.get('downloads', 0)
        edge_attrs['change_in_downloads'] = model_downloads - parent_downloads
        edge_attrs['percentage_change_in_downloads'] = _relative_change(model_downloads, parent_downloads)

        parent_dt = _parse_created(parent_attrs.get('createdAt'))
        model_dt = _parse_created(child_attrs.get('createdAt'))
        if parent_dt is not None and model_dt is not None:
            edge_attrs['change_in_createdAt_days'] = (model_dt - parent_dt).days
        else:
            edge_attrs['change_in_createdAt_days'] = np.nan
|
| 400 |
+
|
| 401 |
def export_graphml(self, graph: nx.Graph, filename: str):
    """Export graph to GraphML format (like Open Syllabus).

    GraphML data values must be scalars, so container-valued attributes
    (for example the 'edge_types' lists that build_family_tree_network
    stores on edges) are written as comma-separated strings. The graph
    passed in is left untouched; a copy is serialized.

    Args:
        graph: Graph to serialize.
        filename: Destination path of the GraphML file.
    """
    def _flatten_containers(attrs):
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for key, val in attrs.items():
            if isinstance(val, (list, tuple, set)):
                attrs[key] = ','.join(str(item) for item in val)

    # graph.copy() creates fresh attribute dicts, so rewriting values here
    # does not leak back into the caller's graph.
    export_graph = graph.copy()
    _flatten_containers(export_graph.graph)
    for _, node_attrs in export_graph.nodes(data=True):
        _flatten_containers(node_attrs)
    for _, _, edge_attrs in export_graph.edges(data=True):
        _flatten_containers(edge_attrs)

    nx.write_graphml(export_graph, filename)
|
frontend/.npmrc
CHANGED
|
@@ -1,2 +1,4 @@
|
|
| 1 |
legacy-peer-deps=true
|
| 2 |
|
|
|
|
|
|
|
|
|
| 1 |
legacy-peer-deps=true
|
| 2 |
|
| 3 |
+
|
| 4 |
+
|
frontend/package-lock.json
CHANGED
|
@@ -32,7 +32,8 @@
|
|
| 32 |
"react-dom": "^18.2.0",
|
| 33 |
"react-scripts": "5.0.1",
|
| 34 |
"three": "^0.160.1",
|
| 35 |
-
"typescript": "^5.0.0"
|
|
|
|
| 36 |
}
|
| 37 |
},
|
| 38 |
"node_modules/@alloc/quick-lru": {
|
|
|
|
| 32 |
"react-dom": "^18.2.0",
|
| 33 |
"react-scripts": "5.0.1",
|
| 34 |
"three": "^0.160.1",
|
| 35 |
+
"typescript": "^5.0.0",
|
| 36 |
+
"zustand": "^5.0.8"
|
| 37 |
}
|
| 38 |
},
|
| 39 |
"node_modules/@alloc/quick-lru": {
|
frontend/package.json
CHANGED
|
@@ -28,7 +28,8 @@
|
|
| 28 |
"react-dom": "^18.2.0",
|
| 29 |
"react-scripts": "5.0.1",
|
| 30 |
"three": "^0.160.1",
|
| 31 |
-
"typescript": "^5.0.0"
|
|
|
|
| 32 |
},
|
| 33 |
"scripts": {
|
| 34 |
"start": "react-scripts start",
|
|
|
|
| 28 |
"react-dom": "^18.2.0",
|
| 29 |
"react-scripts": "5.0.1",
|
| 30 |
"three": "^0.160.1",
|
| 31 |
+
"typescript": "^5.0.0",
|
| 32 |
+
"zustand": "^5.0.8"
|
| 33 |
},
|
| 34 |
"scripts": {
|
| 35 |
"start": "react-scripts start",
|
frontend/public/index.html
CHANGED
|
@@ -10,7 +10,7 @@
|
|
| 10 |
/>
|
| 11 |
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 12 |
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 13 |
-
<link href="https://fonts.googleapis.com/css2?family=
|
| 14 |
<title>Anatomy of a Machine Learning Ecosystem: 2 Million Models on Hugging Face</title>
|
| 15 |
</head>
|
| 16 |
<body>
|
|
|
|
| 10 |
/>
|
| 11 |
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 12 |
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 13 |
+
<link href="https://fonts.googleapis.com/css2?family=Instrument+Sans:wght@400;500;600;700&display=swap" rel="stylesheet">
|
| 14 |
<title>Anatomy of a Machine Learning Ecosystem: 2 Million Models on Hugging Face</title>
|
| 15 |
</head>
|
| 16 |
<body>
|
frontend/src/App.css
CHANGED
|
@@ -7,86 +7,24 @@
|
|
| 7 |
}
|
| 8 |
|
| 9 |
.App-header {
|
| 10 |
-
background:
|
| 11 |
-
background-size: 200% 200%;
|
| 12 |
-
animation: gradientShift 20s ease infinite;
|
| 13 |
color: #ffffff;
|
| 14 |
-
padding:
|
| 15 |
text-align: center;
|
| 16 |
-
border-bottom:
|
| 17 |
-
box-shadow: 0
|
| 18 |
position: relative;
|
| 19 |
-
overflow: hidden;
|
| 20 |
}
|
| 21 |
|
| 22 |
-
.App-header::before {
|
| 23 |
-
content: '';
|
| 24 |
-
position: absolute;
|
| 25 |
-
top: 0;
|
| 26 |
-
left: 0;
|
| 27 |
-
right: 0;
|
| 28 |
-
bottom: 0;
|
| 29 |
-
background:
|
| 30 |
-
radial-gradient(circle at 20% 50%, rgba(100, 181, 246, 0.15) 0%, transparent 50%),
|
| 31 |
-
radial-gradient(circle at 80% 80%, rgba(156, 39, 176, 0.1) 0%, transparent 50%),
|
| 32 |
-
radial-gradient(circle at 40% 20%, rgba(33, 150, 243, 0.1) 0%, transparent 50%);
|
| 33 |
-
pointer-events: none;
|
| 34 |
-
animation: pulse 8s ease-in-out infinite;
|
| 35 |
-
}
|
| 36 |
-
|
| 37 |
-
.App-header::after {
|
| 38 |
-
content: '';
|
| 39 |
-
position: absolute;
|
| 40 |
-
top: 0;
|
| 41 |
-
left: 0;
|
| 42 |
-
right: 0;
|
| 43 |
-
bottom: 0;
|
| 44 |
-
background-image:
|
| 45 |
-
repeating-linear-gradient(
|
| 46 |
-
0deg,
|
| 47 |
-
transparent,
|
| 48 |
-
transparent 2px,
|
| 49 |
-
rgba(255, 255, 255, 0.03) 2px,
|
| 50 |
-
rgba(255, 255, 255, 0.03) 4px
|
| 51 |
-
);
|
| 52 |
-
pointer-events: none;
|
| 53 |
-
opacity: 0.5;
|
| 54 |
-
}
|
| 55 |
-
|
| 56 |
-
@keyframes gradientShift {
|
| 57 |
-
0% {
|
| 58 |
-
background-position: 0% 50%;
|
| 59 |
-
}
|
| 60 |
-
50% {
|
| 61 |
-
background-position: 100% 50%;
|
| 62 |
-
}
|
| 63 |
-
100% {
|
| 64 |
-
background-position: 0% 50%;
|
| 65 |
-
}
|
| 66 |
-
}
|
| 67 |
-
|
| 68 |
-
@keyframes pulse {
|
| 69 |
-
0%, 100% {
|
| 70 |
-
opacity: 1;
|
| 71 |
-
}
|
| 72 |
-
50% {
|
| 73 |
-
opacity: 0.8;
|
| 74 |
-
}
|
| 75 |
-
}
|
| 76 |
|
| 77 |
.App-header h1 {
|
| 78 |
margin: 0 0 1rem 0;
|
| 79 |
-
font-size:
|
| 80 |
-
font-weight:
|
| 81 |
-
letter-spacing: -0.
|
| 82 |
-
line-height: 1.
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
text-shadow: 0 2px 8px rgba(0, 0, 0, 0.4), 0 4px 16px rgba(123, 31, 162, 0.3);
|
| 86 |
-
background: linear-gradient(180deg, #ffffff 0%, #e1bee7 100%);
|
| 87 |
-
-webkit-background-clip: text;
|
| 88 |
-
-webkit-text-fill-color: transparent;
|
| 89 |
-
background-clip: text;
|
| 90 |
}
|
| 91 |
|
| 92 |
.App-header p {
|
|
@@ -122,23 +60,17 @@
|
|
| 122 |
}
|
| 123 |
|
| 124 |
.stats span {
|
| 125 |
-
padding: 0.
|
| 126 |
-
background: rgba(255, 255, 255, 0.
|
| 127 |
-
border-radius:
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
| 132 |
-
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1), inset 0 1px 0 rgba(255, 255, 255, 0.3);
|
| 133 |
-
font-weight: 600;
|
| 134 |
-
letter-spacing: 0.02em;
|
| 135 |
}
|
| 136 |
|
| 137 |
.stats span:hover {
|
| 138 |
-
background: rgba(255, 255, 255, 0.
|
| 139 |
-
transform: translateY(-
|
| 140 |
-
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.15), inset 0 1px 0 rgba(255, 255, 255, 0.4);
|
| 141 |
-
border-color: rgba(255, 255, 255, 0.4);
|
| 142 |
}
|
| 143 |
|
| 144 |
.main-content {
|
|
@@ -149,10 +81,9 @@
|
|
| 149 |
.sidebar {
|
| 150 |
width: 340px;
|
| 151 |
padding: 1.5rem;
|
| 152 |
-
background:
|
| 153 |
overflow-y: auto;
|
| 154 |
-
border-right:
|
| 155 |
-
box-shadow: 2px 0 8px rgba(0, 0, 0, 0.05);
|
| 156 |
}
|
| 157 |
|
| 158 |
.sidebar h2 {
|
|
@@ -164,12 +95,11 @@
|
|
| 164 |
}
|
| 165 |
|
| 166 |
.sidebar h3 {
|
| 167 |
-
font-size: 0.
|
| 168 |
-
font-weight:
|
| 169 |
-
color: #
|
| 170 |
-
margin: 0 0
|
| 171 |
letter-spacing: -0.01em;
|
| 172 |
-
text-transform: none;
|
| 173 |
}
|
| 174 |
|
| 175 |
.sidebar label {
|
|
@@ -202,9 +132,8 @@
|
|
| 202 |
.sidebar input[type="text"]:focus,
|
| 203 |
.sidebar select:focus {
|
| 204 |
outline: none;
|
| 205 |
-
border-color: #
|
| 206 |
-
box-shadow: 0 0 0
|
| 207 |
-
transform: translateY(-1px);
|
| 208 |
}
|
| 209 |
|
| 210 |
.sidebar input[type="range"] {
|
|
@@ -227,20 +156,20 @@
|
|
| 227 |
.sidebar input[type="range"]::-webkit-slider-thumb {
|
| 228 |
-webkit-appearance: none;
|
| 229 |
appearance: none;
|
| 230 |
-
width:
|
| 231 |
-
height:
|
| 232 |
border-radius: 50%;
|
| 233 |
-
background:
|
| 234 |
cursor: pointer;
|
| 235 |
-
box-shadow: 0 2px
|
| 236 |
-
transition: all 0.
|
| 237 |
-
border:
|
| 238 |
}
|
| 239 |
|
| 240 |
.sidebar input[type="range"]::-webkit-slider-thumb:hover {
|
| 241 |
-
background:
|
| 242 |
-
transform: scale(1.
|
| 243 |
-
box-shadow: 0
|
| 244 |
}
|
| 245 |
|
| 246 |
.sidebar input[type="range"]::-webkit-slider-thumb:active {
|
|
@@ -248,20 +177,20 @@
|
|
| 248 |
}
|
| 249 |
|
| 250 |
.sidebar input[type="range"]::-moz-range-thumb {
|
| 251 |
-
width:
|
| 252 |
-
height:
|
| 253 |
border-radius: 50%;
|
| 254 |
-
background:
|
| 255 |
cursor: pointer;
|
| 256 |
-
border:
|
| 257 |
-
box-shadow: 0 2px
|
| 258 |
-
transition: all 0.
|
| 259 |
}
|
| 260 |
|
| 261 |
.sidebar input[type="range"]::-moz-range-thumb:hover {
|
| 262 |
-
background:
|
| 263 |
-
transform: scale(1.
|
| 264 |
-
box-shadow: 0
|
| 265 |
}
|
| 266 |
|
| 267 |
.sidebar input[type="range"]::-moz-range-thumb:active {
|
|
@@ -288,17 +217,16 @@
|
|
| 288 |
|
| 289 |
.sidebar-section {
|
| 290 |
background: #ffffff;
|
| 291 |
-
border-radius:
|
| 292 |
padding: 1.25rem;
|
| 293 |
-
margin-bottom:
|
| 294 |
border: 1px solid #e0e0e0;
|
| 295 |
-
|
| 296 |
-
transition: all 0.3s ease;
|
| 297 |
}
|
| 298 |
|
| 299 |
.sidebar-section:hover {
|
| 300 |
-
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.12);
|
| 301 |
border-color: #d0d0d0;
|
|
|
|
| 302 |
}
|
| 303 |
|
| 304 |
.filter-chip {
|
|
@@ -380,22 +308,21 @@
|
|
| 380 |
}
|
| 381 |
|
| 382 |
.loading {
|
| 383 |
-
color: #
|
| 384 |
font-weight: 600;
|
| 385 |
-
background:
|
| 386 |
-
border:
|
| 387 |
-
box-shadow: 0
|
| 388 |
}
|
| 389 |
|
| 390 |
.loading::after {
|
| 391 |
content: '';
|
| 392 |
-
width:
|
| 393 |
-
height:
|
| 394 |
-
border:
|
| 395 |
-
border-top-color: #
|
| 396 |
-
border-right-color: #7b1fa2;
|
| 397 |
border-radius: 50%;
|
| 398 |
-
animation: spin 0.8s
|
| 399 |
}
|
| 400 |
|
| 401 |
@keyframes spin {
|
|
@@ -403,101 +330,62 @@
|
|
| 403 |
}
|
| 404 |
|
| 405 |
.error {
|
| 406 |
-
color: #
|
| 407 |
-
background:
|
| 408 |
-
border-radius:
|
| 409 |
-
border:
|
| 410 |
max-width: 550px;
|
| 411 |
margin: 0 auto;
|
| 412 |
-
box-shadow: 0 4px 12px rgba(198, 40, 40, 0.15);
|
| 413 |
font-weight: 500;
|
| 414 |
}
|
| 415 |
|
| 416 |
-
.error::before {
|
| 417 |
-
content: 'β οΈ';
|
| 418 |
-
font-size: 2.5rem;
|
| 419 |
-
display: block;
|
| 420 |
-
margin-bottom: 0.5rem;
|
| 421 |
-
}
|
| 422 |
-
|
| 423 |
.empty {
|
| 424 |
-
color: #
|
| 425 |
-
background:
|
| 426 |
-
border-radius:
|
| 427 |
-
border:
|
| 428 |
max-width: 550px;
|
| 429 |
margin: 0 auto;
|
| 430 |
-
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08);
|
| 431 |
font-weight: 500;
|
| 432 |
}
|
| 433 |
|
| 434 |
-
.empty::before {
|
| 435 |
-
content: 'π';
|
| 436 |
-
font-size: 2.5rem;
|
| 437 |
-
display: block;
|
| 438 |
-
margin-bottom: 0.5rem;
|
| 439 |
-
}
|
| 440 |
-
|
| 441 |
.btn {
|
| 442 |
padding: 0.625rem 1.25rem;
|
| 443 |
-
border-radius:
|
| 444 |
border: none;
|
| 445 |
font-size: 0.9rem;
|
| 446 |
font-weight: 600;
|
| 447 |
cursor: pointer;
|
| 448 |
-
transition: all 0.
|
| 449 |
font-family: 'Instrument Sans', sans-serif;
|
| 450 |
display: inline-flex;
|
| 451 |
align-items: center;
|
| 452 |
justify-content: center;
|
| 453 |
gap: 0.5rem;
|
| 454 |
-
position: relative;
|
| 455 |
-
overflow: hidden;
|
| 456 |
-
}
|
| 457 |
-
|
| 458 |
-
.btn::before {
|
| 459 |
-
content: '';
|
| 460 |
-
position: absolute;
|
| 461 |
-
top: 50%;
|
| 462 |
-
left: 50%;
|
| 463 |
-
width: 0;
|
| 464 |
-
height: 0;
|
| 465 |
-
border-radius: 50%;
|
| 466 |
-
background: rgba(255, 255, 255, 0.3);
|
| 467 |
-
transform: translate(-50%, -50%);
|
| 468 |
-
transition: width 0.6s, height 0.6s;
|
| 469 |
}
|
| 470 |
|
| 471 |
-
.btn:hover::before {
|
| 472 |
-
width: 300px;
|
| 473 |
-
height: 300px;
|
| 474 |
-
}
|
| 475 |
|
| 476 |
.btn-primary {
|
| 477 |
-
background:
|
| 478 |
color: white;
|
| 479 |
-
box-shadow: 0 2px 4px rgba(
|
| 480 |
}
|
| 481 |
|
| 482 |
.btn-primary:hover {
|
| 483 |
-
background:
|
| 484 |
-
transform: translateY(-
|
| 485 |
-
box-shadow: 0
|
| 486 |
}
|
| 487 |
|
| 488 |
.btn-secondary {
|
| 489 |
background: #f5f5f5;
|
| 490 |
-
color: #
|
| 491 |
-
border:
|
| 492 |
-
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
|
| 493 |
}
|
| 494 |
|
| 495 |
.btn-secondary:hover {
|
| 496 |
-
background: #
|
| 497 |
-
border-color: #
|
| 498 |
-
color: #5e35b1;
|
| 499 |
-
transform: translateY(-1px);
|
| 500 |
-
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.12);
|
| 501 |
}
|
| 502 |
|
| 503 |
.btn-small {
|
|
@@ -585,7 +473,7 @@
|
|
| 585 |
--text-primary: #ffffff;
|
| 586 |
--text-secondary: #cccccc;
|
| 587 |
--border-color: #444444;
|
| 588 |
-
--accent-color: #
|
| 589 |
}
|
| 590 |
|
| 591 |
[data-theme="light"] {
|
|
@@ -595,32 +483,31 @@
|
|
| 595 |
--text-primary: #1a1a1a;
|
| 596 |
--text-secondary: #666666;
|
| 597 |
--border-color: #d0d0d0;
|
| 598 |
-
--accent-color: #
|
| 599 |
}
|
| 600 |
|
| 601 |
/* Random Model Button */
|
| 602 |
.random-model-btn {
|
| 603 |
display: flex;
|
| 604 |
align-items: center;
|
| 605 |
-
|
| 606 |
-
padding: 0.
|
| 607 |
-
background:
|
| 608 |
color: white;
|
| 609 |
border: none;
|
| 610 |
border-radius: 4px;
|
| 611 |
cursor: pointer;
|
| 612 |
font-size: 0.9rem;
|
| 613 |
font-family: 'Instrument Sans', sans-serif;
|
| 614 |
-
font-weight:
|
| 615 |
transition: all 0.2s;
|
| 616 |
width: 100%;
|
| 617 |
-
justify-content: center;
|
| 618 |
}
|
| 619 |
|
| 620 |
.random-model-btn:hover:not(:disabled) {
|
| 621 |
-
background:
|
| 622 |
transform: translateY(-1px);
|
| 623 |
-
box-shadow: 0 2px
|
| 624 |
}
|
| 625 |
|
| 626 |
.random-model-btn:disabled {
|
|
@@ -628,10 +515,6 @@
|
|
| 628 |
cursor: not-allowed;
|
| 629 |
}
|
| 630 |
|
| 631 |
-
.random-icon {
|
| 632 |
-
font-size: 1.1rem;
|
| 633 |
-
}
|
| 634 |
-
|
| 635 |
/* Zoom Slider */
|
| 636 |
.zoom-slider-container {
|
| 637 |
margin-bottom: 1rem;
|
|
@@ -859,7 +742,7 @@
|
|
| 859 |
width: 18px;
|
| 860 |
height: 18px;
|
| 861 |
cursor: pointer;
|
| 862 |
-
accent-color: #
|
| 863 |
margin-right: 0.5rem;
|
| 864 |
}
|
| 865 |
|
|
|
|
| 7 |
}
|
| 8 |
|
| 9 |
.App-header {
|
| 10 |
+
background: #2d2d2d;
|
|
|
|
|
|
|
| 11 |
color: #ffffff;
|
| 12 |
+
padding: 2.5rem 2rem;
|
| 13 |
text-align: center;
|
| 14 |
+
border-bottom: 1px solid #404040;
|
| 15 |
+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
|
| 16 |
position: relative;
|
|
|
|
| 17 |
}
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
.App-header h1 {
|
| 21 |
margin: 0 0 1rem 0;
|
| 22 |
+
font-size: 2rem;
|
| 23 |
+
font-weight: 600;
|
| 24 |
+
letter-spacing: -0.01em;
|
| 25 |
+
line-height: 1.3;
|
| 26 |
+
color: #ffffff;
|
| 27 |
+
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.3);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
}
|
| 29 |
|
| 30 |
.App-header p {
|
|
|
|
| 60 |
}
|
| 61 |
|
| 62 |
.stats span {
|
| 63 |
+
padding: 0.625rem 1.25rem;
|
| 64 |
+
background: rgba(255, 255, 255, 0.1);
|
| 65 |
+
border-radius: 6px;
|
| 66 |
+
border: 1px solid rgba(255, 255, 255, 0.2);
|
| 67 |
+
transition: all 0.2s ease;
|
| 68 |
+
font-weight: 500;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
}
|
| 70 |
|
| 71 |
.stats span:hover {
|
| 72 |
+
background: rgba(255, 255, 255, 0.15);
|
| 73 |
+
transform: translateY(-1px);
|
|
|
|
|
|
|
| 74 |
}
|
| 75 |
|
| 76 |
.main-content {
|
|
|
|
| 81 |
.sidebar {
|
| 82 |
width: 340px;
|
| 83 |
padding: 1.5rem;
|
| 84 |
+
background: #fafafa;
|
| 85 |
overflow-y: auto;
|
| 86 |
+
border-right: 1px solid #e0e0e0;
|
|
|
|
| 87 |
}
|
| 88 |
|
| 89 |
.sidebar h2 {
|
|
|
|
| 95 |
}
|
| 96 |
|
| 97 |
.sidebar h3 {
|
| 98 |
+
font-size: 0.9rem;
|
| 99 |
+
font-weight: 600;
|
| 100 |
+
color: #2d2d2d;
|
| 101 |
+
margin: 0 0 0.875rem 0;
|
| 102 |
letter-spacing: -0.01em;
|
|
|
|
| 103 |
}
|
| 104 |
|
| 105 |
.sidebar label {
|
|
|
|
| 132 |
.sidebar input[type="text"]:focus,
|
| 133 |
.sidebar select:focus {
|
| 134 |
outline: none;
|
| 135 |
+
border-color: #4a4a4a;
|
| 136 |
+
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.08);
|
|
|
|
| 137 |
}
|
| 138 |
|
| 139 |
.sidebar input[type="range"] {
|
|
|
|
| 156 |
.sidebar input[type="range"]::-webkit-slider-thumb {
|
| 157 |
-webkit-appearance: none;
|
| 158 |
appearance: none;
|
| 159 |
+
width: 18px;
|
| 160 |
+
height: 18px;
|
| 161 |
border-radius: 50%;
|
| 162 |
+
background: #4a4a4a;
|
| 163 |
cursor: pointer;
|
| 164 |
+
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
|
| 165 |
+
transition: all 0.2s ease;
|
| 166 |
+
border: 2px solid #ffffff;
|
| 167 |
}
|
| 168 |
|
| 169 |
.sidebar input[type="range"]::-webkit-slider-thumb:hover {
|
| 170 |
+
background: #2d2d2d;
|
| 171 |
+
transform: scale(1.1);
|
| 172 |
+
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3);
|
| 173 |
}
|
| 174 |
|
| 175 |
.sidebar input[type="range"]::-webkit-slider-thumb:active {
|
|
|
|
| 177 |
}
|
| 178 |
|
| 179 |
.sidebar input[type="range"]::-moz-range-thumb {
|
| 180 |
+
width: 18px;
|
| 181 |
+
height: 18px;
|
| 182 |
border-radius: 50%;
|
| 183 |
+
background: #4a4a4a;
|
| 184 |
cursor: pointer;
|
| 185 |
+
border: 2px solid #ffffff;
|
| 186 |
+
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
|
| 187 |
+
transition: all 0.2s ease;
|
| 188 |
}
|
| 189 |
|
| 190 |
.sidebar input[type="range"]::-moz-range-thumb:hover {
|
| 191 |
+
background: #2d2d2d;
|
| 192 |
+
transform: scale(1.1);
|
| 193 |
+
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3);
|
| 194 |
}
|
| 195 |
|
| 196 |
.sidebar input[type="range"]::-moz-range-thumb:active {
|
|
|
|
| 217 |
|
| 218 |
.sidebar-section {
|
| 219 |
background: #ffffff;
|
| 220 |
+
border-radius: 6px;
|
| 221 |
padding: 1.25rem;
|
| 222 |
+
margin-bottom: 1rem;
|
| 223 |
border: 1px solid #e0e0e0;
|
| 224 |
+
transition: all 0.2s ease;
|
|
|
|
| 225 |
}
|
| 226 |
|
| 227 |
.sidebar-section:hover {
|
|
|
|
| 228 |
border-color: #d0d0d0;
|
| 229 |
+
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.05);
|
| 230 |
}
|
| 231 |
|
| 232 |
.filter-chip {
|
|
|
|
| 308 |
}
|
| 309 |
|
| 310 |
.loading {
|
| 311 |
+
color: #2d2d2d;
|
| 312 |
font-weight: 600;
|
| 313 |
+
background: #f5f5f5;
|
| 314 |
+
border: 1px solid #d0d0d0;
|
| 315 |
+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
|
| 316 |
}
|
| 317 |
|
| 318 |
.loading::after {
|
| 319 |
content: '';
|
| 320 |
+
width: 40px;
|
| 321 |
+
height: 40px;
|
| 322 |
+
border: 4px solid #e0e0e0;
|
| 323 |
+
border-top-color: #4a4a4a;
|
|
|
|
| 324 |
border-radius: 50%;
|
| 325 |
+
animation: spin 0.8s linear infinite;
|
| 326 |
}
|
| 327 |
|
| 328 |
@keyframes spin {
|
|
|
|
| 330 |
}
|
| 331 |
|
| 332 |
.error {
|
| 333 |
+
color: #d32f2f;
|
| 334 |
+
background: #ffebee;
|
| 335 |
+
border-radius: 8px;
|
| 336 |
+
border: 1px solid #ffcdd2;
|
| 337 |
max-width: 550px;
|
| 338 |
margin: 0 auto;
|
|
|
|
| 339 |
font-weight: 500;
|
| 340 |
}
|
| 341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
.empty {
|
| 343 |
+
color: #6a6a6a;
|
| 344 |
+
background: #f5f5f5;
|
| 345 |
+
border-radius: 8px;
|
| 346 |
+
border: 1px solid #e0e0e0;
|
| 347 |
max-width: 550px;
|
| 348 |
margin: 0 auto;
|
|
|
|
| 349 |
font-weight: 500;
|
| 350 |
}
|
| 351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
.btn {
|
| 353 |
padding: 0.625rem 1.25rem;
|
| 354 |
+
border-radius: 4px;
|
| 355 |
border: none;
|
| 356 |
font-size: 0.9rem;
|
| 357 |
font-weight: 600;
|
| 358 |
cursor: pointer;
|
| 359 |
+
transition: all 0.2s ease;
|
| 360 |
font-family: 'Instrument Sans', sans-serif;
|
| 361 |
display: inline-flex;
|
| 362 |
align-items: center;
|
| 363 |
justify-content: center;
|
| 364 |
gap: 0.5rem;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
}
|
| 366 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
.btn-primary {
|
| 369 |
+
background: #2d2d2d;
|
| 370 |
color: white;
|
| 371 |
+
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.15);
|
| 372 |
}
|
| 373 |
|
| 374 |
.btn-primary:hover {
|
| 375 |
+
background: #1a1a1a;
|
| 376 |
+
transform: translateY(-1px);
|
| 377 |
+
box-shadow: 0 3px 8px rgba(0, 0, 0, 0.2);
|
| 378 |
}
|
| 379 |
|
| 380 |
.btn-secondary {
|
| 381 |
background: #f5f5f5;
|
| 382 |
+
color: #2d2d2d;
|
| 383 |
+
border: 1px solid #d0d0d0;
|
|
|
|
| 384 |
}
|
| 385 |
|
| 386 |
.btn-secondary:hover {
|
| 387 |
+
background: #e8e8e8;
|
| 388 |
+
border-color: #b0b0b0;
|
|
|
|
|
|
|
|
|
|
| 389 |
}
|
| 390 |
|
| 391 |
.btn-small {
|
|
|
|
| 473 |
--text-primary: #ffffff;
|
| 474 |
--text-secondary: #cccccc;
|
| 475 |
--border-color: #444444;
|
| 476 |
+
--accent-color: #4a4a4a;
|
| 477 |
}
|
| 478 |
|
| 479 |
[data-theme="light"] {
|
|
|
|
| 483 |
--text-primary: #1a1a1a;
|
| 484 |
--text-secondary: #666666;
|
| 485 |
--border-color: #d0d0d0;
|
| 486 |
+
--accent-color: #4a4a4a;
|
| 487 |
}
|
| 488 |
|
| 489 |
/* Random Model Button */
|
| 490 |
.random-model-btn {
|
| 491 |
display: flex;
|
| 492 |
align-items: center;
|
| 493 |
+
justify-content: center;
|
| 494 |
+
padding: 0.625rem 1.25rem;
|
| 495 |
+
background: #2d2d2d;
|
| 496 |
color: white;
|
| 497 |
border: none;
|
| 498 |
border-radius: 4px;
|
| 499 |
cursor: pointer;
|
| 500 |
font-size: 0.9rem;
|
| 501 |
font-family: 'Instrument Sans', sans-serif;
|
| 502 |
+
font-weight: 600;
|
| 503 |
transition: all 0.2s;
|
| 504 |
width: 100%;
|
|
|
|
| 505 |
}
|
| 506 |
|
| 507 |
.random-model-btn:hover:not(:disabled) {
|
| 508 |
+
background: #1a1a1a;
|
| 509 |
transform: translateY(-1px);
|
| 510 |
+
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.2);
|
| 511 |
}
|
| 512 |
|
| 513 |
.random-model-btn:disabled {
|
|
|
|
| 515 |
cursor: not-allowed;
|
| 516 |
}
|
| 517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
/* Zoom Slider */
|
| 519 |
.zoom-slider-container {
|
| 520 |
margin-bottom: 1rem;
|
|
|
|
| 742 |
width: 18px;
|
| 743 |
height: 18px;
|
| 744 |
cursor: pointer;
|
| 745 |
+
accent-color: #4a4a4a;
|
| 746 |
margin-right: 0.5rem;
|
| 747 |
}
|
| 748 |
|
frontend/src/App.tsx
CHANGED
|
@@ -506,28 +506,24 @@ function App() {
|
|
| 506 |
alignItems: 'center',
|
| 507 |
marginBottom: '1.5rem',
|
| 508 |
paddingBottom: '1rem',
|
| 509 |
-
borderBottom: '
|
| 510 |
}}>
|
| 511 |
<h2 style={{
|
| 512 |
margin: 0,
|
| 513 |
fontSize: '1.5rem',
|
| 514 |
-
fontWeight: '
|
| 515 |
-
|
| 516 |
-
WebkitBackgroundClip: 'text',
|
| 517 |
-
WebkitTextFillColor: 'transparent',
|
| 518 |
-
backgroundClip: 'text'
|
| 519 |
}}>
|
| 520 |
Filters & Controls
|
| 521 |
</h2>
|
| 522 |
{activeFilterCount > 0 && (
|
| 523 |
<div style={{
|
| 524 |
fontSize: '0.75rem',
|
| 525 |
-
background: '
|
| 526 |
color: 'white',
|
| 527 |
-
padding: '0.
|
| 528 |
-
borderRadius: '
|
| 529 |
-
fontWeight: '600'
|
| 530 |
-
boxShadow: '0 2px 6px rgba(94, 53, 177, 0.3)'
|
| 531 |
}}>
|
| 532 |
{activeFilterCount} active
|
| 533 |
</div>
|
|
@@ -537,40 +533,40 @@ function App() {
|
|
| 537 |
{/* Filter Results Count */}
|
| 538 |
{!loading && data.length > 0 && (
|
| 539 |
<div className="sidebar-section" style={{
|
| 540 |
-
background: '
|
| 541 |
-
border: '
|
| 542 |
fontSize: '0.9rem',
|
| 543 |
marginBottom: '1.5rem'
|
| 544 |
}}>
|
| 545 |
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
| 546 |
<div>
|
| 547 |
-
<strong style={{ fontSize: '1.1rem', color: '#
|
| 548 |
{data.length.toLocaleString()}
|
| 549 |
</strong>
|
| 550 |
-
<span style={{ marginLeft: '0.4rem', color: '#
|
| 551 |
{data.length === 1 ? 'model' : 'models'}
|
| 552 |
</span>
|
| 553 |
</div>
|
| 554 |
{embeddingType === 'graph-aware' && (
|
| 555 |
<span style={{
|
| 556 |
fontSize: '0.7rem',
|
| 557 |
-
background: '#
|
| 558 |
color: 'white',
|
| 559 |
padding: '0.3rem 0.6rem',
|
| 560 |
borderRadius: '12px',
|
| 561 |
fontWeight: '600'
|
| 562 |
}}>
|
| 563 |
-
|
| 564 |
</span>
|
| 565 |
)}
|
| 566 |
</div>
|
| 567 |
{filteredCount !== null && filteredCount !== data.length && (
|
| 568 |
-
<div style={{ fontSize: '0.8rem', color: '#
|
| 569 |
of {filteredCount.toLocaleString()} matching
|
| 570 |
</div>
|
| 571 |
)}
|
| 572 |
{stats && filteredCount !== null && filteredCount < stats.total_models && (
|
| 573 |
-
<div style={{ fontSize: '0.75rem', color: '#
|
| 574 |
from {stats.total_models.toLocaleString()} total
|
| 575 |
</div>
|
| 576 |
)}
|
|
@@ -579,15 +575,7 @@ function App() {
|
|
| 579 |
|
| 580 |
{/* Search Section */}
|
| 581 |
<div className="sidebar-section">
|
| 582 |
-
<h3
|
| 583 |
-
display: 'flex',
|
| 584 |
-
alignItems: 'center',
|
| 585 |
-
gap: '0.5rem',
|
| 586 |
-
color: '#5e35b1',
|
| 587 |
-
marginBottom: '0.75rem'
|
| 588 |
-
}}>
|
| 589 |
-
π Search Models
|
| 590 |
-
</h3>
|
| 591 |
<input
|
| 592 |
type="text"
|
| 593 |
value={searchQuery}
|
|
@@ -602,14 +590,7 @@ function App() {
|
|
| 602 |
|
| 603 |
{/* Popularity Filters */}
|
| 604 |
<div className="sidebar-section">
|
| 605 |
-
<h3
|
| 606 |
-
display: 'flex',
|
| 607 |
-
alignItems: 'center',
|
| 608 |
-
gap: '0.5rem',
|
| 609 |
-
color: '#5e35b1'
|
| 610 |
-
}}>
|
| 611 |
-
π Popularity Filters
|
| 612 |
-
</h3>
|
| 613 |
|
| 614 |
<label style={{ marginBottom: '1rem', display: 'block' }}>
|
| 615 |
<div style={{ display: 'flex', justifyContent: 'space-between', marginBottom: '0.5rem' }}>
|
|
@@ -706,14 +687,7 @@ function App() {
|
|
| 706 |
|
| 707 |
{/* Discovery */}
|
| 708 |
<div className="sidebar-section">
|
| 709 |
-
<h3
|
| 710 |
-
display: 'flex',
|
| 711 |
-
alignItems: 'center',
|
| 712 |
-
gap: '0.5rem',
|
| 713 |
-
color: '#5e35b1'
|
| 714 |
-
}}>
|
| 715 |
-
π² Discovery
|
| 716 |
-
</h3>
|
| 717 |
<RandomModelButton
|
| 718 |
data={data}
|
| 719 |
onSelect={(model: ModelPoint) => {
|
|
@@ -727,15 +701,7 @@ function App() {
|
|
| 727 |
{/* Visualization Options */}
|
| 728 |
<div className="sidebar-section">
|
| 729 |
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '1rem' }}>
|
| 730 |
-
<h3 style={{
|
| 731 |
-
margin: 0,
|
| 732 |
-
display: 'flex',
|
| 733 |
-
alignItems: 'center',
|
| 734 |
-
gap: '0.5rem',
|
| 735 |
-
color: '#5e35b1'
|
| 736 |
-
}}>
|
| 737 |
-
π¨ Visualization
|
| 738 |
-
</h3>
|
| 739 |
<ThemeToggle />
|
| 740 |
</div>
|
| 741 |
|
|
@@ -862,10 +828,10 @@ function App() {
|
|
| 862 |
</select>
|
| 863 |
</label>
|
| 864 |
|
| 865 |
-
<div className="sidebar-section" style={{ background: '#
|
| 866 |
<label style={{ display: 'block', marginBottom: '0' }}>
|
| 867 |
-
<span style={{ fontWeight: '600', display: 'block', marginBottom: '0.5rem', color: '#
|
| 868 |
-
|
| 869 |
</span>
|
| 870 |
<select
|
| 871 |
value={projectionMethod}
|
|
@@ -875,7 +841,7 @@ function App() {
|
|
| 875 |
<option value="umap">UMAP (better global structure)</option>
|
| 876 |
<option value="tsne">t-SNE (better local clusters)</option>
|
| 877 |
</select>
|
| 878 |
-
<div style={{ fontSize: '0.75rem', color: '#
|
| 879 |
<strong>UMAP:</strong> Preserves global structure, better for exploring relationships<br/>
|
| 880 |
<strong>t-SNE:</strong> Emphasizes local clusters, better for finding groups
|
| 881 |
</div>
|
|
@@ -884,15 +850,8 @@ function App() {
|
|
| 884 |
</div>
|
| 885 |
|
| 886 |
{/* View Modes */}
|
| 887 |
-
<div className="sidebar-section"
|
| 888 |
-
<h3
|
| 889 |
-
display: 'flex',
|
| 890 |
-
alignItems: 'center',
|
| 891 |
-
gap: '0.5rem',
|
| 892 |
-
color: '#5e35b1'
|
| 893 |
-
}}>
|
| 894 |
-
⚡ View Modes
|
| 895 |
-
</h3>
|
| 896 |
|
| 897 |
<label style={{ marginBottom: '1rem', display: 'flex', alignItems: 'center', cursor: 'pointer' }}>
|
| 898 |
<input
|
|
@@ -937,7 +896,7 @@ function App() {
|
|
| 937 |
style={{ marginRight: '0.5rem', cursor: 'pointer' }}
|
| 938 |
/>
|
| 939 |
<div>
|
| 940 |
-
<span style={{ fontWeight: '500' }}
|
| 941 |
<div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
|
| 942 |
Use embeddings that respect family tree structure. Models in the same family will be closer together.
|
| 943 |
</div>
|
|
@@ -955,11 +914,11 @@ function App() {
|
|
| 955 |
color: '#666'
|
| 956 |
}}>
|
| 957 |
<div style={{ display: 'flex', alignItems: 'center', gap: '0.5rem', marginBottom: '0.25rem' }}>
|
| 958 |
-
<strong style={{ color:
|
| 959 |
-
{embeddingType === 'graph-aware' ? '
|
| 960 |
</strong>
|
| 961 |
</div>
|
| 962 |
-
<div style={{ fontSize: '0.7rem', color: '#
|
| 963 |
{embeddingType === 'graph-aware'
|
| 964 |
? 'Models in the same family tree are positioned closer together, revealing hierarchical relationships.'
|
| 965 |
: 'Standard text-based embeddings showing semantic similarity from model descriptions and tags.'}
|
|
@@ -1006,15 +965,8 @@ function App() {
|
|
| 1006 |
|
| 1007 |
{/* Structural Visualization Options */}
|
| 1008 |
{viewMode === '3d' && (
|
| 1009 |
-
<div className="sidebar-section"
|
| 1010 |
-
<h3
|
| 1011 |
-
display: 'flex',
|
| 1012 |
-
alignItems: 'center',
|
| 1013 |
-
gap: '0.5rem',
|
| 1014 |
-
color: '#5e35b1'
|
| 1015 |
-
}}>
|
| 1016 |
-
π Network Structure
|
| 1017 |
-
</h3>
|
| 1018 |
<div style={{ fontSize: '0.75rem', color: '#666', marginBottom: '1rem', lineHeight: '1.4' }}>
|
| 1019 |
Explore relationships and structure in the model ecosystem
|
| 1020 |
</div>
|
|
@@ -1026,12 +978,12 @@ function App() {
|
|
| 1026 |
onChange={(e) => setOverviewMode(e.target.checked)}
|
| 1027 |
style={{ marginRight: '0.5rem', cursor: 'pointer' }}
|
| 1028 |
/>
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
</div>
|
| 1034 |
</div>
|
|
|
|
| 1035 |
</label>
|
| 1036 |
|
| 1037 |
<label style={{ marginBottom: '1rem', display: 'flex', alignItems: 'center', cursor: 'pointer' }}>
|
|
@@ -1041,12 +993,12 @@ function App() {
|
|
| 1041 |
onChange={(e) => setShowNetworkEdges(e.target.checked)}
|
| 1042 |
style={{ marginRight: '0.5rem', cursor: 'pointer' }}
|
| 1043 |
/>
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
</div>
|
| 1049 |
</div>
|
|
|
|
| 1050 |
</label>
|
| 1051 |
|
| 1052 |
{showNetworkEdges && (
|
|
@@ -1073,26 +1025,19 @@ function App() {
|
|
| 1073 |
onChange={(e) => setShowStructuralGroups(e.target.checked)}
|
| 1074 |
style={{ marginRight: '0.5rem', cursor: 'pointer' }}
|
| 1075 |
/>
|
| 1076 |
-
|
| 1077 |
-
|
| 1078 |
-
|
| 1079 |
-
|
| 1080 |
-
</div>
|
| 1081 |
</div>
|
|
|
|
| 1082 |
</label>
|
| 1083 |
</div>
|
| 1084 |
)}
|
| 1085 |
|
| 1086 |
{/* Quick Filters */}
|
| 1087 |
<div className="sidebar-section">
|
| 1088 |
-
<h3
|
| 1089 |
-
display: 'flex',
|
| 1090 |
-
alignItems: 'center',
|
| 1091 |
-
gap: '0.5rem',
|
| 1092 |
-
color: '#5e35b1'
|
| 1093 |
-
}}>
|
| 1094 |
-
⚡ Quick Actions
|
| 1095 |
-
</h3>
|
| 1096 |
<div style={{ display: 'flex', flexWrap: 'wrap', gap: '0.5rem' }}>
|
| 1097 |
<button
|
| 1098 |
onClick={() => {
|
|
@@ -1137,14 +1082,7 @@ function App() {
|
|
| 1137 |
</div>
|
| 1138 |
|
| 1139 |
<div className="sidebar-section">
|
| 1140 |
-
<h3
|
| 1141 |
-
display: 'flex',
|
| 1142 |
-
alignItems: 'center',
|
| 1143 |
-
gap: '0.5rem',
|
| 1144 |
-
color: '#5e35b1'
|
| 1145 |
-
}}>
|
| 1146 |
-
🌳 Hierarchy Navigation
|
| 1147 |
-
</h3>
|
| 1148 |
<label style={{ marginBottom: '1rem', display: 'block' }}>
|
| 1149 |
<span style={{ fontWeight: '500', display: 'block', marginBottom: '0.5rem' }}>
|
| 1150 |
Max Hierarchy Depth
|
|
@@ -1224,14 +1162,7 @@ function App() {
|
|
| 1224 |
</div>
|
| 1225 |
|
| 1226 |
<div className="sidebar-section">
|
| 1227 |
-
<h3
|
| 1228 |
-
display: 'flex',
|
| 1229 |
-
alignItems: 'center',
|
| 1230 |
-
gap: '0.5rem',
|
| 1231 |
-
color: '#5e35b1'
|
| 1232 |
-
}}>
|
| 1233 |
-
👥 Family Tree Explorer
|
| 1234 |
-
</h3>
|
| 1235 |
<div style={{ position: 'relative' }}>
|
| 1236 |
<input
|
| 1237 |
type="text"
|
|
|
|
| 506 |
alignItems: 'center',
|
| 507 |
marginBottom: '1.5rem',
|
| 508 |
paddingBottom: '1rem',
|
| 509 |
+
borderBottom: '1px solid #e0e0e0'
|
| 510 |
}}>
|
| 511 |
<h2 style={{
|
| 512 |
margin: 0,
|
| 513 |
fontSize: '1.5rem',
|
| 514 |
+
fontWeight: '600',
|
| 515 |
+
color: '#2d2d2d'
|
|
|
|
|
|
|
|
|
|
| 516 |
}}>
|
| 517 |
Filters & Controls
|
| 518 |
</h2>
|
| 519 |
{activeFilterCount > 0 && (
|
| 520 |
<div style={{
|
| 521 |
fontSize: '0.75rem',
|
| 522 |
+
background: '#4a4a4a',
|
| 523 |
color: 'white',
|
| 524 |
+
padding: '0.35rem 0.7rem',
|
| 525 |
+
borderRadius: '12px',
|
| 526 |
+
fontWeight: '600'
|
|
|
|
| 527 |
}}>
|
| 528 |
{activeFilterCount} active
|
| 529 |
</div>
|
|
|
|
| 533 |
{/* Filter Results Count */}
|
| 534 |
{!loading && data.length > 0 && (
|
| 535 |
<div className="sidebar-section" style={{
|
| 536 |
+
background: '#f5f5f5',
|
| 537 |
+
border: '1px solid #d0d0d0',
|
| 538 |
fontSize: '0.9rem',
|
| 539 |
marginBottom: '1.5rem'
|
| 540 |
}}>
|
| 541 |
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
| 542 |
<div>
|
| 543 |
+
<strong style={{ fontSize: '1.1rem', color: '#2d2d2d' }}>
|
| 544 |
{data.length.toLocaleString()}
|
| 545 |
</strong>
|
| 546 |
+
<span style={{ marginLeft: '0.4rem', color: '#4a4a4a' }}>
|
| 547 |
{data.length === 1 ? 'model' : 'models'}
|
| 548 |
</span>
|
| 549 |
</div>
|
| 550 |
{embeddingType === 'graph-aware' && (
|
| 551 |
<span style={{
|
| 552 |
fontSize: '0.7rem',
|
| 553 |
+
background: '#4a4a4a',
|
| 554 |
color: 'white',
|
| 555 |
padding: '0.3rem 0.6rem',
|
| 556 |
borderRadius: '12px',
|
| 557 |
fontWeight: '600'
|
| 558 |
}}>
|
| 559 |
+
Graph
|
| 560 |
</span>
|
| 561 |
)}
|
| 562 |
</div>
|
| 563 |
{filteredCount !== null && filteredCount !== data.length && (
|
| 564 |
+
<div style={{ fontSize: '0.8rem', color: '#666', marginTop: '0.25rem' }}>
|
| 565 |
of {filteredCount.toLocaleString()} matching
|
| 566 |
</div>
|
| 567 |
)}
|
| 568 |
{stats && filteredCount !== null && filteredCount < stats.total_models && (
|
| 569 |
+
<div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
|
| 570 |
from {stats.total_models.toLocaleString()} total
|
| 571 |
</div>
|
| 572 |
)}
|
|
|
|
| 575 |
|
| 576 |
{/* Search Section */}
|
| 577 |
<div className="sidebar-section">
|
| 578 |
+
<h3>Search Models</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
<input
|
| 580 |
type="text"
|
| 581 |
value={searchQuery}
|
|
|
|
| 590 |
|
| 591 |
{/* Popularity Filters */}
|
| 592 |
<div className="sidebar-section">
|
| 593 |
+
<h3>Popularity Filters</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
|
| 595 |
<label style={{ marginBottom: '1rem', display: 'block' }}>
|
| 596 |
<div style={{ display: 'flex', justifyContent: 'space-between', marginBottom: '0.5rem' }}>
|
|
|
|
| 687 |
|
| 688 |
{/* Discovery */}
|
| 689 |
<div className="sidebar-section">
|
| 690 |
+
<h3>Discovery</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 691 |
<RandomModelButton
|
| 692 |
data={data}
|
| 693 |
onSelect={(model: ModelPoint) => {
|
|
|
|
| 701 |
{/* Visualization Options */}
|
| 702 |
<div className="sidebar-section">
|
| 703 |
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '1rem' }}>
|
| 704 |
+
<h3 style={{ margin: 0 }}>Visualization Options</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
<ThemeToggle />
|
| 706 |
</div>
|
| 707 |
|
|
|
|
| 828 |
</select>
|
| 829 |
</label>
|
| 830 |
|
| 831 |
+
<div className="sidebar-section" style={{ background: '#f5f5f5', borderColor: '#d0d0d0', marginBottom: '1rem', padding: '0.75rem', borderRadius: '4px', border: '1px solid' }}>
|
| 832 |
<label style={{ display: 'block', marginBottom: '0' }}>
|
| 833 |
+
<span style={{ fontWeight: '600', display: 'block', marginBottom: '0.5rem', color: '#2d2d2d' }}>
|
| 834 |
+
Projection Method
|
| 835 |
</span>
|
| 836 |
<select
|
| 837 |
value={projectionMethod}
|
|
|
|
| 841 |
<option value="umap">UMAP (better global structure)</option>
|
| 842 |
<option value="tsne">t-SNE (better local clusters)</option>
|
| 843 |
</select>
|
| 844 |
+
<div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.5rem', lineHeight: '1.4' }}>
|
| 845 |
<strong>UMAP:</strong> Preserves global structure, better for exploring relationships<br/>
|
| 846 |
<strong>t-SNE:</strong> Emphasizes local clusters, better for finding groups
|
| 847 |
</div>
|
|
|
|
| 850 |
</div>
|
| 851 |
|
| 852 |
{/* View Modes */}
|
| 853 |
+
<div className="sidebar-section">
|
| 854 |
+
<h3>View Modes</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
|
| 856 |
<label style={{ marginBottom: '1rem', display: 'flex', alignItems: 'center', cursor: 'pointer' }}>
|
| 857 |
<input
|
|
|
|
| 896 |
style={{ marginRight: '0.5rem', cursor: 'pointer' }}
|
| 897 |
/>
|
| 898 |
<div>
|
| 899 |
+
<span style={{ fontWeight: '500' }}>Graph-Aware Embeddings</span>
|
| 900 |
<div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
|
| 901 |
Use embeddings that respect family tree structure. Models in the same family will be closer together.
|
| 902 |
</div>
|
|
|
|
| 914 |
color: '#666'
|
| 915 |
}}>
|
| 916 |
<div style={{ display: 'flex', alignItems: 'center', gap: '0.5rem', marginBottom: '0.25rem' }}>
|
| 917 |
+
<strong style={{ color: '#2d2d2d' }}>
|
| 918 |
+
{embeddingType === 'graph-aware' ? 'Graph-Aware' : 'Text-Only'} Embeddings
|
| 919 |
</strong>
|
| 920 |
</div>
|
| 921 |
+
<div style={{ fontSize: '0.7rem', color: '#666', lineHeight: '1.4' }}>
|
| 922 |
{embeddingType === 'graph-aware'
|
| 923 |
? 'Models in the same family tree are positioned closer together, revealing hierarchical relationships.'
|
| 924 |
: 'Standard text-based embeddings showing semantic similarity from model descriptions and tags.'}
|
|
|
|
| 965 |
|
| 966 |
{/* Structural Visualization Options */}
|
| 967 |
{viewMode === '3d' && (
|
| 968 |
+
<div className="sidebar-section">
|
| 969 |
+
<h3>Network Structure</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 970 |
<div style={{ fontSize: '0.75rem', color: '#666', marginBottom: '1rem', lineHeight: '1.4' }}>
|
| 971 |
Explore relationships and structure in the model ecosystem
|
| 972 |
</div>
|
|
|
|
| 978 |
onChange={(e) => setOverviewMode(e.target.checked)}
|
| 979 |
style={{ marginRight: '0.5rem', cursor: 'pointer' }}
|
| 980 |
/>
|
| 981 |
+
<div>
|
| 982 |
+
<span style={{ fontWeight: '500' }}>Overview Mode</span>
|
| 983 |
+
<div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
|
| 984 |
+
Zoom out to see full ecosystem structure with all relationships visible. Camera will automatically adjust.
|
|
|
|
| 985 |
</div>
|
| 986 |
+
</div>
|
| 987 |
</label>
|
| 988 |
|
| 989 |
<label style={{ marginBottom: '1rem', display: 'flex', alignItems: 'center', cursor: 'pointer' }}>
|
|
|
|
| 993 |
onChange={(e) => setShowNetworkEdges(e.target.checked)}
|
| 994 |
style={{ marginRight: '0.5rem', cursor: 'pointer' }}
|
| 995 |
/>
|
| 996 |
+
<div>
|
| 997 |
+
<span style={{ fontWeight: '500' }}>Network Relationships</span>
|
| 998 |
+
<div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
|
| 999 |
+
Show connections between related models (same library, pipeline, or tags). Blue = library, Pink = pipeline.
|
|
|
|
| 1000 |
</div>
|
| 1001 |
+
</div>
|
| 1002 |
</label>
|
| 1003 |
|
| 1004 |
{showNetworkEdges && (
|
|
|
|
| 1025 |
onChange={(e) => setShowStructuralGroups(e.target.checked)}
|
| 1026 |
style={{ marginRight: '0.5rem', cursor: 'pointer' }}
|
| 1027 |
/>
|
| 1028 |
+
<div>
|
| 1029 |
+
<span style={{ fontWeight: '500' }}>Structural Groupings</span>
|
| 1030 |
+
<div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
|
| 1031 |
+
Highlight clusters and groups with wireframe boundaries. Shows top library and pipeline clusters.
|
|
|
|
| 1032 |
</div>
|
| 1033 |
+
</div>
|
| 1034 |
</label>
|
| 1035 |
</div>
|
| 1036 |
)}
|
| 1037 |
|
| 1038 |
{/* Quick Filters */}
|
| 1039 |
<div className="sidebar-section">
|
| 1040 |
+
<h3>Quick Actions</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1041 |
<div style={{ display: 'flex', flexWrap: 'wrap', gap: '0.5rem' }}>
|
| 1042 |
<button
|
| 1043 |
onClick={() => {
|
|
|
|
| 1082 |
</div>
|
| 1083 |
|
| 1084 |
<div className="sidebar-section">
|
| 1085 |
+
<h3>Hierarchy Navigation</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1086 |
<label style={{ marginBottom: '1rem', display: 'block' }}>
|
| 1087 |
<span style={{ fontWeight: '500', display: 'block', marginBottom: '0.5rem' }}>
|
| 1088 |
Max Hierarchy Depth
|
|
|
|
| 1162 |
</div>
|
| 1163 |
|
| 1164 |
<div className="sidebar-section">
|
| 1165 |
+
<h3>Family Tree Explorer</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1166 |
<div style={{ position: 'relative' }}>
|
| 1167 |
<input
|
| 1168 |
type="text"
|
frontend/src/components/PaperPlots.css
DELETED
|
@@ -1,92 +0,0 @@
|
|
| 1 |
-
.paper-plots {
|
| 2 |
-
display: flex;
|
| 3 |
-
flex-direction: column;
|
| 4 |
-
gap: 1rem;
|
| 5 |
-
padding: 1rem;
|
| 6 |
-
background: white;
|
| 7 |
-
border-radius: 8px;
|
| 8 |
-
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
| 9 |
-
}
|
| 10 |
-
|
| 11 |
-
.plot-selector {
|
| 12 |
-
border-bottom: 1px solid #e0e0e0;
|
| 13 |
-
padding-bottom: 1rem;
|
| 14 |
-
}
|
| 15 |
-
|
| 16 |
-
.plot-selector h3 {
|
| 17 |
-
margin: 0 0 0.75rem 0;
|
| 18 |
-
font-size: 1.25rem;
|
| 19 |
-
color: #333;
|
| 20 |
-
}
|
| 21 |
-
|
| 22 |
-
.plot-buttons {
|
| 23 |
-
display: flex;
|
| 24 |
-
flex-wrap: wrap;
|
| 25 |
-
gap: 0.5rem;
|
| 26 |
-
}
|
| 27 |
-
|
| 28 |
-
.plot-button {
|
| 29 |
-
padding: 0.5rem 1rem;
|
| 30 |
-
border: 1px solid #ccc;
|
| 31 |
-
background: white;
|
| 32 |
-
border-radius: 4px;
|
| 33 |
-
cursor: pointer;
|
| 34 |
-
font-size: 0.875rem;
|
| 35 |
-
transition: all 0.2s;
|
| 36 |
-
color: #333;
|
| 37 |
-
}
|
| 38 |
-
|
| 39 |
-
.plot-button:hover {
|
| 40 |
-
background: #f5f5f5;
|
| 41 |
-
border-color: #999;
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
.plot-button.active {
|
| 45 |
-
background: #4a90e2;
|
| 46 |
-
color: white;
|
| 47 |
-
border-color: #4a90e2;
|
| 48 |
-
}
|
| 49 |
-
|
| 50 |
-
.plot-container {
|
| 51 |
-
position: relative;
|
| 52 |
-
min-height: 600px;
|
| 53 |
-
display: flex;
|
| 54 |
-
align-items: center;
|
| 55 |
-
justify-content: center;
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
.plot-loading {
|
| 59 |
-
position: absolute;
|
| 60 |
-
top: 50%;
|
| 61 |
-
left: 50%;
|
| 62 |
-
transform: translate(-50%, -50%);
|
| 63 |
-
color: #666;
|
| 64 |
-
font-size: 1rem;
|
| 65 |
-
}
|
| 66 |
-
|
| 67 |
-
.plot-tooltip {
|
| 68 |
-
position: absolute;
|
| 69 |
-
padding: 0.5rem;
|
| 70 |
-
background: rgba(0, 0, 0, 0.8);
|
| 71 |
-
color: white;
|
| 72 |
-
border-radius: 4px;
|
| 73 |
-
pointer-events: none;
|
| 74 |
-
font-size: 0.875rem;
|
| 75 |
-
z-index: 1000;
|
| 76 |
-
}
|
| 77 |
-
|
| 78 |
-
.plot-container svg {
|
| 79 |
-
display: block;
|
| 80 |
-
margin: 0 auto;
|
| 81 |
-
}
|
| 82 |
-
|
| 83 |
-
@media (max-width: 768px) {
|
| 84 |
-
.plot-buttons {
|
| 85 |
-
flex-direction: column;
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
.plot-button {
|
| 89 |
-
width: 100%;
|
| 90 |
-
}
|
| 91 |
-
}
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
frontend/src/components/PaperPlots.tsx
DELETED
|
@@ -1,755 +0,0 @@
|
|
| 1 |
-
/**
|
| 2 |
-
* Interactive D3.js visualizations based on plots from the research paper.
|
| 3 |
-
* "Anatomy of a Machine Learning Ecosystem: 2 Million Models on Hugging Face"
|
| 4 |
-
*/
|
| 5 |
-
import React, { useRef, useEffect, useState, useMemo } from 'react';
|
| 6 |
-
import * as d3 from 'd3';
|
| 7 |
-
import { ModelPoint } from '../types';
|
| 8 |
-
import './PaperPlots.css';
|
| 9 |
-
|
| 10 |
-
const API_BASE = process.env.REACT_APP_API_URL || 'http://localhost:8000';
|
| 11 |
-
|
| 12 |
-
interface PaperPlotsProps {
|
| 13 |
-
data: ModelPoint[];
|
| 14 |
-
width?: number;
|
| 15 |
-
height?: number;
|
| 16 |
-
}
|
| 17 |
-
|
| 18 |
-
type PlotType = 'family-size' | 'similarity-comparison' | 'license-drift' | 'model-card-length' | 'growth-timeline';
|
| 19 |
-
|
| 20 |
-
export default function PaperPlots({ data, width = 800, height = 600 }: PaperPlotsProps) {
|
| 21 |
-
const [activePlot, setActivePlot] = useState<PlotType>('family-size');
|
| 22 |
-
const familySizeRef = useRef<SVGSVGElement>(null);
|
| 23 |
-
const similarityRef = useRef<SVGSVGElement>(null);
|
| 24 |
-
const licenseDriftRef = useRef<SVGSVGElement>(null);
|
| 25 |
-
const modelCardLengthRef = useRef<SVGSVGElement>(null);
|
| 26 |
-
const growthTimelineRef = useRef<SVGSVGElement>(null);
|
| 27 |
-
const [familyTreeData, setFamilyTreeData] = useState<any>(null);
|
| 28 |
-
const [loading, setLoading] = useState(false);
|
| 29 |
-
|
| 30 |
-
// Fetch family tree statistics
|
| 31 |
-
useEffect(() => {
|
| 32 |
-
const fetchFamilyStats = async () => {
|
| 33 |
-
setLoading(true);
|
| 34 |
-
try {
|
| 35 |
-
const response = await fetch(`${API_BASE}/api/family/stats`);
|
| 36 |
-
if (response.ok) {
|
| 37 |
-
const stats = await response.json();
|
| 38 |
-
setFamilyTreeData(stats);
|
| 39 |
-
}
|
| 40 |
-
} catch (err) {
|
| 41 |
-
console.error('Error fetching family stats:', err);
|
| 42 |
-
} finally {
|
| 43 |
-
setLoading(false);
|
| 44 |
-
}
|
| 45 |
-
};
|
| 46 |
-
fetchFamilyStats();
|
| 47 |
-
}, []);
|
| 48 |
-
|
| 49 |
-
// Plot 1: Family Size Distribution
|
| 50 |
-
useEffect(() => {
|
| 51 |
-
if (activePlot !== 'family-size' || !familySizeRef.current) return;
|
| 52 |
-
|
| 53 |
-
const svg = d3.select(familySizeRef.current);
|
| 54 |
-
svg.selectAll('*').remove();
|
| 55 |
-
|
| 56 |
-
const margin = { top: 40, right: 40, bottom: 60, left: 60 };
|
| 57 |
-
const innerWidth = width - margin.left - margin.right;
|
| 58 |
-
const innerHeight = height - margin.top - margin.bottom;
|
| 59 |
-
|
| 60 |
-
// Use API data if available, otherwise calculate from current data
|
| 61 |
-
let binData: Array<{ x0: number; x1: number; count: number }>;
|
| 62 |
-
|
| 63 |
-
if (familyTreeData && familyTreeData.family_size_distribution) {
|
| 64 |
-
const sizeDist = familyTreeData.family_size_distribution;
|
| 65 |
-
const sizes = Object.keys(sizeDist).map(Number);
|
| 66 |
-
const counts = Object.values(sizeDist) as number[];
|
| 67 |
-
|
| 68 |
-
// Create histogram bins from distribution
|
| 69 |
-
const maxSize = d3.max(sizes) || 1;
|
| 70 |
-
const bins = d3.bin().thresholds(20).domain([0, maxSize])(sizes);
|
| 71 |
-
|
| 72 |
-
binData = bins.map(bin => {
|
| 73 |
-
let count = 0;
|
| 74 |
-
sizes.forEach((size, i) => {
|
| 75 |
-
if (size >= (bin.x0 || 0) && size < (bin.x1 || maxSize)) {
|
| 76 |
-
count += counts[i];
|
| 77 |
-
}
|
| 78 |
-
});
|
| 79 |
-
return {
|
| 80 |
-
x0: bin.x0 || 0,
|
| 81 |
-
x1: bin.x1 || maxSize,
|
| 82 |
-
count: count
|
| 83 |
-
};
|
| 84 |
-
}).filter(d => d.count > 0);
|
| 85 |
-
} else {
|
| 86 |
-
// Fallback: Calculate from current data
|
| 87 |
-
const familySizes = new Map<string, number>();
|
| 88 |
-
data.forEach(model => {
|
| 89 |
-
const familyKey = model.parent_model || model.model_id;
|
| 90 |
-
familySizes.set(familyKey, (familySizes.get(familyKey) || 0) + 1);
|
| 91 |
-
});
|
| 92 |
-
|
| 93 |
-
const sizes = Array.from(familySizes.values());
|
| 94 |
-
const bins = d3.bin().thresholds(20)(sizes);
|
| 95 |
-
binData = bins.map(bin => ({
|
| 96 |
-
x0: bin.x0 || 0,
|
| 97 |
-
x1: bin.x1 || 0,
|
| 98 |
-
count: bin.length
|
| 99 |
-
}));
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
-
const g = svg.append('g')
|
| 103 |
-
.attr('transform', `translate(${margin.left},${margin.top})`);
|
| 104 |
-
|
| 105 |
-
// Scales
|
| 106 |
-
const xScale = d3.scaleLinear()
|
| 107 |
-
.domain([0, d3.max(binData, d => d.x1) || 1])
|
| 108 |
-
.range([0, innerWidth])
|
| 109 |
-
.nice();
|
| 110 |
-
|
| 111 |
-
const yScale = d3.scaleLinear()
|
| 112 |
-
.domain([0, d3.max(binData, d => d.count) || 1])
|
| 113 |
-
.range([innerHeight, 0])
|
| 114 |
-
.nice();
|
| 115 |
-
|
| 116 |
-
// Bars
|
| 117 |
-
g.selectAll('rect')
|
| 118 |
-
.data(binData)
|
| 119 |
-
.enter()
|
| 120 |
-
.append('rect')
|
| 121 |
-
.attr('x', d => xScale(d.x0))
|
| 122 |
-
.attr('width', d => Math.max(0, xScale(d.x1) - xScale(d.x0) - 1))
|
| 123 |
-
.attr('y', d => yScale(d.count))
|
| 124 |
-
.attr('height', d => innerHeight - yScale(d.count))
|
| 125 |
-
.attr('fill', '#4a90e2')
|
| 126 |
-
.attr('opacity', 0.7)
|
| 127 |
-
.on('mouseover', function(event, d) {
|
| 128 |
-
d3.select(this).attr('opacity', 1);
|
| 129 |
-
const tooltip = d3.select('body').append('div')
|
| 130 |
-
.attr('class', 'plot-tooltip')
|
| 131 |
-
.style('opacity', 0);
|
| 132 |
-
tooltip.transition().duration(200).style('opacity', 0.9);
|
| 133 |
-
tooltip.html(`Family Size: ${d.x0.toFixed(0)}-${d.x1.toFixed(0)}<br/>Count: ${d.count}`)
|
| 134 |
-
.style('left', (event.pageX + 10) + 'px')
|
| 135 |
-
.style('top', (event.pageY - 28) + 'px');
|
| 136 |
-
})
|
| 137 |
-
.on('mouseout', function() {
|
| 138 |
-
d3.select(this).attr('opacity', 0.7);
|
| 139 |
-
d3.selectAll('.plot-tooltip').remove();
|
| 140 |
-
});
|
| 141 |
-
|
| 142 |
-
// Axes
|
| 143 |
-
const xAxis = d3.axisBottom(xScale).tickFormat(d3.format('d'));
|
| 144 |
-
const yAxis = d3.axisLeft(yScale);
|
| 145 |
-
|
| 146 |
-
g.append('g')
|
| 147 |
-
.attr('transform', `translate(0,${innerHeight})`)
|
| 148 |
-
.call(xAxis)
|
| 149 |
-
.append('text')
|
| 150 |
-
.attr('x', innerWidth / 2)
|
| 151 |
-
.attr('y', 45)
|
| 152 |
-
.attr('fill', 'currentColor')
|
| 153 |
-
.style('text-anchor', 'middle')
|
| 154 |
-
.style('font-size', '14px')
|
| 155 |
-
.text('Family Size (number of models)');
|
| 156 |
-
|
| 157 |
-
g.append('g')
|
| 158 |
-
.call(yAxis)
|
| 159 |
-
.append('text')
|
| 160 |
-
.attr('transform', 'rotate(-90)')
|
| 161 |
-
.attr('y', -45)
|
| 162 |
-
.attr('x', -innerHeight / 2)
|
| 163 |
-
.attr('fill', 'currentColor')
|
| 164 |
-
.style('text-anchor', 'middle')
|
| 165 |
-
.style('font-size', '14px')
|
| 166 |
-
.text('Number of Families');
|
| 167 |
-
|
| 168 |
-
// Title
|
| 169 |
-
svg.append('text')
|
| 170 |
-
.attr('x', width / 2)
|
| 171 |
-
.attr('y', 20)
|
| 172 |
-
.attr('text-anchor', 'middle')
|
| 173 |
-
.style('font-size', '16px')
|
| 174 |
-
.style('font-weight', 'bold')
|
| 175 |
-
.text('Family Size Distribution');
|
| 176 |
-
|
| 177 |
-
}, [activePlot, data, width, height, familyTreeData]);
|
| 178 |
-
|
| 179 |
-
// Plot 2: Similarity Comparison (Sibling vs Parent-Child)
|
| 180 |
-
useEffect(() => {
|
| 181 |
-
if (activePlot !== 'similarity-comparison' || !similarityRef.current || !data.length) return;
|
| 182 |
-
|
| 183 |
-
const svg = d3.select(similarityRef.current);
|
| 184 |
-
svg.selectAll('*').remove();
|
| 185 |
-
|
| 186 |
-
const margin = { top: 40, right: 40, bottom: 60, left: 60 };
|
| 187 |
-
const innerWidth = width - margin.left - margin.right;
|
| 188 |
-
const innerHeight = height - margin.top - margin.bottom;
|
| 189 |
-
|
| 190 |
-
// This would require similarity data - for now, create a placeholder visualization
|
| 191 |
-
// In the paper, this shows that siblings are more similar than parent-child pairs
|
| 192 |
-
const g = svg.append('g')
|
| 193 |
-
.attr('transform', `translate(${margin.left},${margin.top})`);
|
| 194 |
-
|
| 195 |
-
// Placeholder: Box plot or violin plot showing similarity distributions
|
| 196 |
-
// Sibling similarity (higher)
|
| 197 |
-
const siblingData = Array.from({ length: 100 }, () => 0.6 + Math.random() * 0.3);
|
| 198 |
-
// Parent-child similarity (lower)
|
| 199 |
-
const parentChildData = Array.from({ length: 100 }, () => 0.3 + Math.random() * 0.3);
|
| 200 |
-
|
| 201 |
-
const xScale = d3.scaleBand()
|
| 202 |
-
.domain(['Sibling Pairs', 'Parent-Child Pairs'])
|
| 203 |
-
.range([0, innerWidth])
|
| 204 |
-
.padding(0.3);
|
| 205 |
-
|
| 206 |
-
const yScale = d3.scaleLinear()
|
| 207 |
-
.domain([0, 1])
|
| 208 |
-
.range([innerHeight, 0])
|
| 209 |
-
.nice();
|
| 210 |
-
|
| 211 |
-
// Box plot visualization
|
| 212 |
-
[siblingData, parentChildData].forEach((dataset, i) => {
|
| 213 |
-
const label = i === 0 ? 'Sibling Pairs' : 'Parent-Child Pairs';
|
| 214 |
-
const x = xScale(label);
|
| 215 |
-
const bandWidth = xScale.bandwidth();
|
| 216 |
-
|
| 217 |
-
if (x === undefined) return;
|
| 218 |
-
|
| 219 |
-
// Calculate quartiles
|
| 220 |
-
const sorted = dataset.sort((a, b) => a - b);
|
| 221 |
-
const q1 = d3.quantile(sorted, 0.25) || 0;
|
| 222 |
-
const q2 = d3.quantile(sorted, 0.5) || 0;
|
| 223 |
-
const q3 = d3.quantile(sorted, 0.75) || 0;
|
| 224 |
-
const min = sorted[0];
|
| 225 |
-
const max = sorted[sorted.length - 1];
|
| 226 |
-
|
| 227 |
-
// Box
|
| 228 |
-
g.append('rect')
|
| 229 |
-
.attr('x', x)
|
| 230 |
-
.attr('y', yScale(q3))
|
| 231 |
-
.attr('width', bandWidth)
|
| 232 |
-
.attr('height', yScale(q1) - yScale(q3))
|
| 233 |
-
.attr('fill', i === 0 ? '#4a90e2' : '#e24a4a')
|
| 234 |
-
.attr('opacity', 0.6)
|
| 235 |
-
.attr('stroke', '#333')
|
| 236 |
-
.attr('stroke-width', 1);
|
| 237 |
-
|
| 238 |
-
// Median line
|
| 239 |
-
g.append('line')
|
| 240 |
-
.attr('x1', x)
|
| 241 |
-
.attr('x2', x + bandWidth)
|
| 242 |
-
.attr('y1', yScale(q2))
|
| 243 |
-
.attr('y2', yScale(q2))
|
| 244 |
-
.attr('stroke', '#333')
|
| 245 |
-
.attr('stroke-width', 2);
|
| 246 |
-
|
| 247 |
-
// Whiskers
|
| 248 |
-
g.append('line')
|
| 249 |
-
.attr('x1', x + bandWidth / 2)
|
| 250 |
-
.attr('x2', x + bandWidth / 2)
|
| 251 |
-
.attr('y1', yScale(min))
|
| 252 |
-
.attr('y2', yScale(q1))
|
| 253 |
-
.attr('stroke', '#333')
|
| 254 |
-
.attr('stroke-width', 1);
|
| 255 |
-
|
| 256 |
-
g.append('line')
|
| 257 |
-
.attr('x1', x + bandWidth / 2)
|
| 258 |
-
.attr('x2', x + bandWidth / 2)
|
| 259 |
-
.attr('y1', yScale(q3))
|
| 260 |
-
.attr('y2', yScale(max))
|
| 261 |
-
.attr('stroke', '#333')
|
| 262 |
-
.attr('stroke-width', 1);
|
| 263 |
-
|
| 264 |
-
// Min/Max lines
|
| 265 |
-
g.append('line')
|
| 266 |
-
.attr('x1', x + bandWidth * 0.25)
|
| 267 |
-
.attr('x2', x + bandWidth * 0.75)
|
| 268 |
-
.attr('y1', yScale(min))
|
| 269 |
-
.attr('y2', yScale(min))
|
| 270 |
-
.attr('stroke', '#333')
|
| 271 |
-
.attr('stroke-width', 1);
|
| 272 |
-
|
| 273 |
-
g.append('line')
|
| 274 |
-
.attr('x1', x + bandWidth * 0.25)
|
| 275 |
-
.attr('x2', x + bandWidth * 0.75)
|
| 276 |
-
.attr('y1', yScale(max))
|
| 277 |
-
.attr('y2', yScale(max))
|
| 278 |
-
.attr('stroke', '#333')
|
| 279 |
-
.attr('stroke-width', 1);
|
| 280 |
-
});
|
| 281 |
-
|
| 282 |
-
// Axes
|
| 283 |
-
const yAxis = d3.axisLeft(yScale);
|
| 284 |
-
g.append('g').call(yAxis);
|
| 285 |
-
|
| 286 |
-
g.append('text')
|
| 287 |
-
.attr('transform', 'rotate(-90)')
|
| 288 |
-
.attr('y', -45)
|
| 289 |
-
.attr('x', -innerHeight / 2)
|
| 290 |
-
.attr('fill', 'currentColor')
|
| 291 |
-
.style('text-anchor', 'middle')
|
| 292 |
-
.style('font-size', '14px')
|
| 293 |
-
.text('Similarity Score');
|
| 294 |
-
|
| 295 |
-
// Title
|
| 296 |
-
svg.append('text')
|
| 297 |
-
.attr('x', width / 2)
|
| 298 |
-
.attr('y', 20)
|
| 299 |
-
.attr('text-anchor', 'middle')
|
| 300 |
-
.style('font-size', '16px')
|
| 301 |
-
.style('font-weight', 'bold')
|
| 302 |
-
.text('Similarity: Siblings vs Parent-Child Pairs');
|
| 303 |
-
|
| 304 |
-
}, [activePlot, data, width, height, familyTreeData]);
|
| 305 |
-
|
| 306 |
-
// Plot 3: License Drift (over family depth)
|
| 307 |
-
useEffect(() => {
|
| 308 |
-
if (activePlot !== 'license-drift' || !licenseDriftRef.current || !data.length) return;
|
| 309 |
-
|
| 310 |
-
const svg = d3.select(licenseDriftRef.current);
|
| 311 |
-
svg.selectAll('*').remove();
|
| 312 |
-
|
| 313 |
-
const margin = { top: 40, right: 40, bottom: 60, left: 80 };
|
| 314 |
-
const innerWidth = width - margin.left - margin.right;
|
| 315 |
-
const innerHeight = height - margin.top - margin.bottom;
|
| 316 |
-
|
| 317 |
-
// Group by family depth and license type
|
| 318 |
-
const depthGroups = new Map<number, Map<string, number>>();
|
| 319 |
-
data.forEach(model => {
|
| 320 |
-
const depth = model.family_depth || 0;
|
| 321 |
-
const license = model.licenses ? (model.licenses.split(',')[0].trim() || 'unknown') : 'unknown';
|
| 322 |
-
|
| 323 |
-
if (!depthGroups.has(depth)) {
|
| 324 |
-
depthGroups.set(depth, new Map());
|
| 325 |
-
}
|
| 326 |
-
const licenseMap = depthGroups.get(depth)!;
|
| 327 |
-
licenseMap.set(license, (licenseMap.get(license) || 0) + 1);
|
| 328 |
-
});
|
| 329 |
-
|
| 330 |
-
const depths = Array.from(depthGroups.keys()).sort((a, b) => a - b);
|
| 331 |
-
const allLicenses = new Set<string>();
|
| 332 |
-
depthGroups.forEach(licenseMap => {
|
| 333 |
-
licenseMap.forEach((_, license) => allLicenses.add(license));
|
| 334 |
-
});
|
| 335 |
-
|
| 336 |
-
const licenseTypes = Array.from(allLicenses).slice(0, 5); // Top 5 licenses
|
| 337 |
-
const colorScale = d3.scaleOrdinal(d3.schemeCategory10).domain(licenseTypes);
|
| 338 |
-
|
| 339 |
-
const g = svg.append('g')
|
| 340 |
-
.attr('transform', `translate(${margin.left},${margin.top})`);
|
| 341 |
-
|
| 342 |
-
const xScale = d3.scaleBand()
|
| 343 |
-
.domain(depths.map(d => d.toString()))
|
| 344 |
-
.range([0, innerWidth])
|
| 345 |
-
.padding(0.1);
|
| 346 |
-
|
| 347 |
-
const yScale = d3.scaleLinear()
|
| 348 |
-
.domain([0, 1])
|
| 349 |
-
.range([innerHeight, 0]);
|
| 350 |
-
|
| 351 |
-
// Stacked area or bars showing license distribution
|
| 352 |
-
licenseTypes.forEach((license, i) => {
|
| 353 |
-
const stack = depths.map(depth => {
|
| 354 |
-
const licenseMap = depthGroups.get(depth) || new Map();
|
| 355 |
-
const total = Array.from(licenseMap.values()).reduce((a, b) => a + b, 0);
|
| 356 |
-
const count = licenseMap.get(license) || 0;
|
| 357 |
-
return { depth, proportion: total > 0 ? count / total : 0 };
|
| 358 |
-
});
|
| 359 |
-
|
| 360 |
-
// Draw as line chart showing proportion over depth
|
| 361 |
-
const line = d3.line<{ depth: number; proportion: number }>()
|
| 362 |
-
.x(d => (xScale(d.depth.toString()) || 0) + xScale.bandwidth() / 2)
|
| 363 |
-
.y(d => yScale(d.proportion))
|
| 364 |
-
.curve(d3.curveMonotoneX);
|
| 365 |
-
|
| 366 |
-
g.append('path')
|
| 367 |
-
.datum(stack)
|
| 368 |
-
.attr('fill', 'none')
|
| 369 |
-
.attr('stroke', colorScale(license))
|
| 370 |
-
.attr('stroke-width', 2)
|
| 371 |
-
.attr('d', line);
|
| 372 |
-
|
| 373 |
-
// Add circles for data points
|
| 374 |
-
g.selectAll(`.dot-${i}`)
|
| 375 |
-
.data(stack)
|
| 376 |
-
.enter()
|
| 377 |
-
.append('circle')
|
| 378 |
-
.attr('cx', d => (xScale(d.depth.toString()) || 0) + xScale.bandwidth() / 2)
|
| 379 |
-
.attr('cy', d => yScale(d.proportion))
|
| 380 |
-
.attr('r', 4)
|
| 381 |
-
.attr('fill', colorScale(license));
|
| 382 |
-
});
|
| 383 |
-
|
| 384 |
-
// Axes
|
| 385 |
-
const xAxis = d3.axisBottom(xScale);
|
| 386 |
-
const yAxis = d3.axisLeft(yScale).tickFormat(d3.format('.0%'));
|
| 387 |
-
|
| 388 |
-
g.append('g')
|
| 389 |
-
.attr('transform', `translate(0,${innerHeight})`)
|
| 390 |
-
.call(xAxis)
|
| 391 |
-
.append('text')
|
| 392 |
-
.attr('x', innerWidth / 2)
|
| 393 |
-
.attr('y', 45)
|
| 394 |
-
.attr('fill', 'currentColor')
|
| 395 |
-
.style('text-anchor', 'middle')
|
| 396 |
-
.style('font-size', '14px')
|
| 397 |
-
.text('Family Depth (generation)');
|
| 398 |
-
|
| 399 |
-
g.append('g').call(yAxis)
|
| 400 |
-
.append('text')
|
| 401 |
-
.attr('transform', 'rotate(-90)')
|
| 402 |
-
.attr('y', -60)
|
| 403 |
-
.attr('x', -innerHeight / 2)
|
| 404 |
-
.attr('fill', 'currentColor')
|
| 405 |
-
.style('text-anchor', 'middle')
|
| 406 |
-
.style('font-size', '14px')
|
| 407 |
-
.text('Proportion of Models');
|
| 408 |
-
|
| 409 |
-
// Legend
|
| 410 |
-
const legend = g.append('g')
|
| 411 |
-
.attr('transform', `translate(${innerWidth - 150}, 20)`);
|
| 412 |
-
|
| 413 |
-
licenseTypes.forEach((license, i) => {
|
| 414 |
-
const legendRow = legend.append('g')
|
| 415 |
-
.attr('transform', `translate(0, ${i * 20})`);
|
| 416 |
-
|
| 417 |
-
legendRow.append('rect')
|
| 418 |
-
.attr('width', 15)
|
| 419 |
-
.attr('height', 15)
|
| 420 |
-
.attr('fill', colorScale(license));
|
| 421 |
-
|
| 422 |
-
legendRow.append('text')
|
| 423 |
-
.attr('x', 20)
|
| 424 |
-
.attr('y', 12)
|
| 425 |
-
.style('font-size', '12px')
|
| 426 |
-
.text(license.length > 15 ? license.substring(0, 15) + '...' : license);
|
| 427 |
-
});
|
| 428 |
-
|
| 429 |
-
// Title
|
| 430 |
-
svg.append('text')
|
| 431 |
-
.attr('x', width / 2)
|
| 432 |
-
.attr('y', 20)
|
| 433 |
-
.attr('text-anchor', 'middle')
|
| 434 |
-
.style('font-size', '16px')
|
| 435 |
-
.style('font-weight', 'bold')
|
| 436 |
-
.text('License Distribution Across Family Generations');
|
| 437 |
-
|
| 438 |
-
}, [activePlot, data, width, height, familyTreeData]);
|
| 439 |
-
|
| 440 |
-
// Plot 4: Model Card Length Distribution
|
| 441 |
-
useEffect(() => {
|
| 442 |
-
if (activePlot !== 'model-card-length' || !modelCardLengthRef.current || !data.length) return;
|
| 443 |
-
|
| 444 |
-
const svg = d3.select(modelCardLengthRef.current);
|
| 445 |
-
svg.selectAll('*').remove();
|
| 446 |
-
|
| 447 |
-
const margin = { top: 40, right: 40, bottom: 60, left: 60 };
|
| 448 |
-
const innerWidth = width - margin.left - margin.right;
|
| 449 |
-
const innerHeight = height - margin.top - margin.bottom;
|
| 450 |
-
|
| 451 |
-
// Placeholder: Would need model card length data
|
| 452 |
-
// In the paper, this shows model cards getting shorter and more standardized
|
| 453 |
-
const g = svg.append('g')
|
| 454 |
-
.attr('transform', `translate(${margin.left},${margin.top})`);
|
| 455 |
-
|
| 456 |
-
// Use real model card length data from API if available
|
| 457 |
-
let depthData = new Map<number, number[]>();
|
| 458 |
-
|
| 459 |
-
if (familyTreeData && familyTreeData.model_card_length_by_depth) {
|
| 460 |
-
// Use real data from API
|
| 461 |
-
const cardStats = familyTreeData.model_card_length_by_depth;
|
| 462 |
-
Object.keys(cardStats).forEach(depthStr => {
|
| 463 |
-
const depth = parseInt(depthStr);
|
| 464 |
-
const stats = cardStats[depthStr];
|
| 465 |
-
// Create synthetic distribution from stats (mean, q1, q3)
|
| 466 |
-
const lengths: number[] = [];
|
| 467 |
-
const count = Math.min(stats.count, 100); // Limit for performance
|
| 468 |
-
for (let i = 0; i < count; i++) {
|
| 469 |
-
// Generate values around the mean with spread based on quartiles
|
| 470 |
-
const spread = (stats.q3 - stats.q1) / 2;
|
| 471 |
-
const length = stats.mean + (Math.random() - 0.5) * spread * 2;
|
| 472 |
-
lengths.push(Math.max(0, length));
|
| 473 |
-
}
|
| 474 |
-
depthData.set(depth, lengths);
|
| 475 |
-
});
|
| 476 |
-
} else {
|
| 477 |
-
// Fallback: Calculate from current data
|
| 478 |
-
const depthGroups = new Map<number, number[]>();
|
| 479 |
-
data.forEach(model => {
|
| 480 |
-
const depth = model.family_depth || 0;
|
| 481 |
-
// We don't have model card length in ModelPoint, so use placeholder
|
| 482 |
-
// In a real implementation, this would come from the API
|
| 483 |
-
if (!depthGroups.has(depth)) {
|
| 484 |
-
depthGroups.set(depth, []);
|
| 485 |
-
}
|
| 486 |
-
});
|
| 487 |
-
depthData = depthGroups;
|
| 488 |
-
}
|
| 489 |
-
|
| 490 |
-
// If still no data, use simulated data
|
| 491 |
-
if (depthData.size === 0) {
|
| 492 |
-
for (let depth = 0; depth <= 5; depth++) {
|
| 493 |
-
const lengths = Array.from({ length: 50 }, () => {
|
| 494 |
-
const baseLength = 2000 - depth * 200;
|
| 495 |
-
return baseLength + (Math.random() - 0.5) * 500;
|
| 496 |
-
});
|
| 497 |
-
depthData.set(depth, lengths);
|
| 498 |
-
}
|
| 499 |
-
}
|
| 500 |
-
|
| 501 |
-
const depths = Array.from(depthData.keys()).sort((a, b) => a - b);
|
| 502 |
-
const maxDepth = d3.max(depths) || 5;
|
| 503 |
-
const allLengths = Array.from(depthData.values()).flat();
|
| 504 |
-
const maxLength = d3.max(allLengths) || 3000;
|
| 505 |
-
|
| 506 |
-
const xScale = d3.scaleBand()
|
| 507 |
-
.domain(depths.map(d => d.toString()))
|
| 508 |
-
.range([0, innerWidth])
|
| 509 |
-
.padding(0.2);
|
| 510 |
-
|
| 511 |
-
const yScale = d3.scaleLinear()
|
| 512 |
-
.domain([0, maxLength])
|
| 513 |
-
.range([innerHeight, 0])
|
| 514 |
-
.nice();
|
| 515 |
-
|
| 516 |
-
// Violin plot or box plot
|
| 517 |
-
depthData.forEach((lengths, depth) => {
|
| 518 |
-
const x = xScale(depth.toString());
|
| 519 |
-
const bandWidth = xScale.bandwidth();
|
| 520 |
-
|
| 521 |
-
if (x === undefined) return;
|
| 522 |
-
|
| 523 |
-
// Simple box plot
|
| 524 |
-
const sorted = lengths.sort((a, b) => a - b);
|
| 525 |
-
const q1 = d3.quantile(sorted, 0.25) || 0;
|
| 526 |
-
const q2 = d3.quantile(sorted, 0.5) || 0;
|
| 527 |
-
const q3 = d3.quantile(sorted, 0.75) || 0;
|
| 528 |
-
|
| 529 |
-
g.append('rect')
|
| 530 |
-
.attr('x', x)
|
| 531 |
-
.attr('y', yScale(q3))
|
| 532 |
-
.attr('width', bandWidth)
|
| 533 |
-
.attr('height', yScale(q1) - yScale(q3))
|
| 534 |
-
.attr('fill', '#4a90e2')
|
| 535 |
-
.attr('opacity', 0.6)
|
| 536 |
-
.attr('stroke', '#333');
|
| 537 |
-
|
| 538 |
-
g.append('line')
|
| 539 |
-
.attr('x1', x)
|
| 540 |
-
.attr('x2', x + bandWidth)
|
| 541 |
-
.attr('y1', yScale(q2))
|
| 542 |
-
.attr('y2', yScale(q2))
|
| 543 |
-
.attr('stroke', '#333')
|
| 544 |
-
.attr('stroke-width', 2);
|
| 545 |
-
});
|
| 546 |
-
|
| 547 |
-
const yAxis = d3.axisLeft(yScale);
|
| 548 |
-
g.append('g').call(yAxis)
|
| 549 |
-
.append('text')
|
| 550 |
-
.attr('transform', 'rotate(-90)')
|
| 551 |
-
.attr('y', -45)
|
| 552 |
-
.attr('x', -innerHeight / 2)
|
| 553 |
-
.attr('fill', 'currentColor')
|
| 554 |
-
.style('text-anchor', 'middle')
|
| 555 |
-
.style('font-size', '14px')
|
| 556 |
-
.text('Model Card Length (characters)');
|
| 557 |
-
|
| 558 |
-
// Title
|
| 559 |
-
svg.append('text')
|
| 560 |
-
.attr('x', width / 2)
|
| 561 |
-
.attr('y', 20)
|
| 562 |
-
.attr('text-anchor', 'middle')
|
| 563 |
-
.style('font-size', '16px')
|
| 564 |
-
.style('font-weight', 'bold')
|
| 565 |
-
.text('Model Card Length by Family Generation');
|
| 566 |
-
|
| 567 |
-
}, [activePlot, data, width, height, familyTreeData]);
|
| 568 |
-
|
| 569 |
-
// Plot 5: Growth Timeline
|
| 570 |
-
useEffect(() => {
|
| 571 |
-
if (activePlot !== 'growth-timeline' || !growthTimelineRef.current) return;
|
| 572 |
-
|
| 573 |
-
const svg = d3.select(growthTimelineRef.current);
|
| 574 |
-
svg.selectAll('*').remove();
|
| 575 |
-
|
| 576 |
-
const margin = { top: 40, right: 40, bottom: 60, left: 60 };
|
| 577 |
-
const innerWidth = width - margin.left - margin.right;
|
| 578 |
-
const innerHeight = height - margin.top - margin.bottom;
|
| 579 |
-
|
| 580 |
-
// Fetch growth data from model tracker API
|
| 581 |
-
fetch(`${API_BASE}/api/model-count/historical?days=365`)
|
| 582 |
-
.then(res => res.json())
|
| 583 |
-
.then(data => {
|
| 584 |
-
if (!data.counts || data.counts.length === 0) {
|
| 585 |
-
svg.append('text')
|
| 586 |
-
.attr('x', width / 2)
|
| 587 |
-
.attr('y', height / 2)
|
| 588 |
-
.attr('text-anchor', 'middle')
|
| 589 |
-
.text('No historical data available');
|
| 590 |
-
return;
|
| 591 |
-
}
|
| 592 |
-
|
| 593 |
-
const g = svg.append('g')
|
| 594 |
-
.attr('transform', `translate(${margin.left},${margin.top})`);
|
| 595 |
-
|
| 596 |
-
const counts = data.counts.map((d: any) => ({
|
| 597 |
-
date: new Date(d.timestamp),
|
| 598 |
-
count: d.total_models
|
| 599 |
-
})).sort((a: any, b: any) => a.date - b.date);
|
| 600 |
-
|
| 601 |
-
const extent = d3.extent(counts, (d: any) => d.date) as [Date | undefined, Date | undefined];
|
| 602 |
-
const minDate = extent[0];
|
| 603 |
-
const maxDate = extent[1];
|
| 604 |
-
if (!minDate || !maxDate) return;
|
| 605 |
-
|
| 606 |
-
const xScale = d3.scaleTime()
|
| 607 |
-
.domain([minDate, maxDate])
|
| 608 |
-
.range([0, innerWidth]);
|
| 609 |
-
|
| 610 |
-
const yScale = d3.scaleLinear()
|
| 611 |
-
.domain([0, d3.max(counts, (d: any) => d.count) || 0] as [number, number])
|
| 612 |
-
.range([innerHeight, 0])
|
| 613 |
-
.nice();
|
| 614 |
-
|
| 615 |
-
const line = d3.line<any>()
|
| 616 |
-
.x(d => xScale(d.date))
|
| 617 |
-
.y(d => yScale(d.count))
|
| 618 |
-
.curve(d3.curveMonotoneX);
|
| 619 |
-
|
| 620 |
-
g.append('path')
|
| 621 |
-
.datum(counts)
|
| 622 |
-
.attr('fill', 'none')
|
| 623 |
-
.attr('stroke', '#4a90e2')
|
| 624 |
-
.attr('stroke-width', 2)
|
| 625 |
-
.attr('d', line);
|
| 626 |
-
|
| 627 |
-
g.selectAll('circle')
|
| 628 |
-
.data(counts)
|
| 629 |
-
.enter()
|
| 630 |
-
.append('circle')
|
| 631 |
-
.attr('cx', (d: any) => xScale(d.date))
|
| 632 |
-
.attr('cy', (d: any) => yScale(d.count))
|
| 633 |
-
.attr('r', 3)
|
| 634 |
-
.attr('fill', '#4a90e2')
|
| 635 |
-
.on('mouseover', function(event, d: any) {
|
| 636 |
-
d3.select(this).attr('r', 5);
|
| 637 |
-
const tooltip = d3.select('body').append('div')
|
| 638 |
-
.attr('class', 'plot-tooltip')
|
| 639 |
-
.style('opacity', 0);
|
| 640 |
-
tooltip.transition().duration(200).style('opacity', 0.9);
|
| 641 |
-
tooltip.html(`${d.date.toLocaleDateString()}<br/>Models: ${d.count.toLocaleString()}`)
|
| 642 |
-
.style('left', (event.pageX + 10) + 'px')
|
| 643 |
-
.style('top', (event.pageY - 28) + 'px');
|
| 644 |
-
})
|
| 645 |
-
.on('mouseout', function() {
|
| 646 |
-
d3.select(this).attr('r', 3);
|
| 647 |
-
d3.selectAll('.plot-tooltip').remove();
|
| 648 |
-
});
|
| 649 |
-
|
| 650 |
-
const xAxis = d3.axisBottom(xScale).ticks(6);
|
| 651 |
-
const yAxis = d3.axisLeft(yScale).tickFormat(d3.format('.2s'));
|
| 652 |
-
|
| 653 |
-
g.append('g')
|
| 654 |
-
.attr('transform', `translate(0,${innerHeight})`)
|
| 655 |
-
.call(xAxis)
|
| 656 |
-
.append('text')
|
| 657 |
-
.attr('x', innerWidth / 2)
|
| 658 |
-
.attr('y', 45)
|
| 659 |
-
.attr('fill', 'currentColor')
|
| 660 |
-
.style('text-anchor', 'middle')
|
| 661 |
-
.style('font-size', '14px')
|
| 662 |
-
.text('Date');
|
| 663 |
-
|
| 664 |
-
g.append('g').call(yAxis)
|
| 665 |
-
.append('text')
|
| 666 |
-
.attr('transform', 'rotate(-90)')
|
| 667 |
-
.attr('y', -45)
|
| 668 |
-
.attr('x', -innerHeight / 2)
|
| 669 |
-
.attr('fill', 'currentColor')
|
| 670 |
-
.style('text-anchor', 'middle')
|
| 671 |
-
.style('font-size', '14px')
|
| 672 |
-
.text('Total Models');
|
| 673 |
-
|
| 674 |
-
svg.append('text')
|
| 675 |
-
.attr('x', width / 2)
|
| 676 |
-
.attr('y', 20)
|
| 677 |
-
.attr('text-anchor', 'middle')
|
| 678 |
-
.style('font-size', '16px')
|
| 679 |
-
.style('font-weight', 'bold')
|
| 680 |
-
.text('Model Count Growth Over Time');
|
| 681 |
-
})
|
| 682 |
-
.catch(err => {
|
| 683 |
-
console.error('Error fetching growth data:', err);
|
| 684 |
-
svg.append('text')
|
| 685 |
-
.attr('x', width / 2)
|
| 686 |
-
.attr('y', height / 2)
|
| 687 |
-
.attr('text-anchor', 'middle')
|
| 688 |
-
.text('Error loading growth data');
|
| 689 |
-
});
|
| 690 |
-
|
| 691 |
-
}, [activePlot, width, height]);
|
| 692 |
-
|
| 693 |
-
const plotOptions: { value: PlotType; label: string; description: string }[] = [
|
| 694 |
-
{ value: 'family-size', label: 'Family Size Distribution', description: 'Distribution of family tree sizes' },
|
| 695 |
-
{ value: 'similarity-comparison', label: 'Similarity Comparison', description: 'Sibling vs parent-child similarity' },
|
| 696 |
-
{ value: 'license-drift', label: 'License Drift', description: 'License changes across generations' },
|
| 697 |
-
{ value: 'model-card-length', label: 'Model Card Length', description: 'Model card length by generation' },
|
| 698 |
-
{ value: 'growth-timeline', label: 'Growth Timeline', description: 'Model count over time' },
|
| 699 |
-
];
|
| 700 |
-
|
| 701 |
-
return (
|
| 702 |
-
<div className="paper-plots">
|
| 703 |
-
<div className="plot-selector">
|
| 704 |
-
<h3>Paper Visualizations</h3>
|
| 705 |
-
<div className="plot-buttons">
|
| 706 |
-
{plotOptions.map(option => (
|
| 707 |
-
<button
|
| 708 |
-
key={option.value}
|
| 709 |
-
className={`plot-button ${activePlot === option.value ? 'active' : ''}`}
|
| 710 |
-
onClick={() => setActivePlot(option.value)}
|
| 711 |
-
title={option.description}
|
| 712 |
-
>
|
| 713 |
-
{option.label}
|
| 714 |
-
</button>
|
| 715 |
-
))}
|
| 716 |
-
</div>
|
| 717 |
-
</div>
|
| 718 |
-
|
| 719 |
-
<div className="plot-container">
|
| 720 |
-
{loading && <div className="plot-loading">Loading data...</div>}
|
| 721 |
-
<svg
|
| 722 |
-
ref={familySizeRef}
|
| 723 |
-
width={width}
|
| 724 |
-
height={height}
|
| 725 |
-
style={{ display: activePlot === 'family-size' ? 'block' : 'none' }}
|
| 726 |
-
/>
|
| 727 |
-
<svg
|
| 728 |
-
ref={similarityRef}
|
| 729 |
-
width={width}
|
| 730 |
-
height={height}
|
| 731 |
-
style={{ display: activePlot === 'similarity-comparison' ? 'block' : 'none' }}
|
| 732 |
-
/>
|
| 733 |
-
<svg
|
| 734 |
-
ref={licenseDriftRef}
|
| 735 |
-
width={width}
|
| 736 |
-
height={height}
|
| 737 |
-
style={{ display: activePlot === 'license-drift' ? 'block' : 'none' }}
|
| 738 |
-
/>
|
| 739 |
-
<svg
|
| 740 |
-
ref={modelCardLengthRef}
|
| 741 |
-
width={width}
|
| 742 |
-
height={height}
|
| 743 |
-
style={{ display: activePlot === 'model-card-length' ? 'block' : 'none' }}
|
| 744 |
-
/>
|
| 745 |
-
<svg
|
| 746 |
-
ref={growthTimelineRef}
|
| 747 |
-
width={width}
|
| 748 |
-
height={height}
|
| 749 |
-
style={{ display: activePlot === 'growth-timeline' ? 'block' : 'none' }}
|
| 750 |
-
/>
|
| 751 |
-
</div>
|
| 752 |
-
</div>
|
| 753 |
-
);
|
| 754 |
-
}
|
| 755 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
frontend/src/components/ScatterPlot.tsx
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
/**
|
| 2 |
-
* Legacy Visx scatter plot - kept for reference.
|
| 3 |
-
* Use EnhancedScatterPlot.tsx for D3.js implementation.
|
| 4 |
-
*/
|
| 5 |
-
// This file is kept for compatibility but EnhancedScatterPlot is preferred
|
| 6 |
-
|
| 7 |
-
export { default } from './EnhancedScatterPlot';
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
frontend/src/components/controls/ClusterFilter.css
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.cluster-filter {
|
| 2 |
+
margin-bottom: 1.5rem;
|
| 3 |
+
}
|
| 4 |
+
|
| 5 |
+
.cluster-filter-header {
|
| 6 |
+
margin-bottom: 0.75rem;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
.cluster-filter-header h3 {
|
| 10 |
+
margin: 0;
|
| 11 |
+
font-size: 1rem;
|
| 12 |
+
font-weight: 600;
|
| 13 |
+
color: var(--text-primary, #1a1a1a);
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
.cluster-filter-search {
|
| 17 |
+
margin-bottom: 0.75rem;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
.cluster-search-input {
|
| 21 |
+
width: 100%;
|
| 22 |
+
padding: 0.5rem;
|
| 23 |
+
border: 1px solid var(--border-color, #e0e0e0);
|
| 24 |
+
border-radius: 4px;
|
| 25 |
+
font-size: 0.9rem;
|
| 26 |
+
background: var(--bg-primary, #ffffff);
|
| 27 |
+
color: var(--text-primary, #1a1a1a);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
.cluster-search-input:focus {
|
| 31 |
+
outline: none;
|
| 32 |
+
border-color: var(--accent-color, #4a90e2);
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
.cluster-filter-actions {
|
| 36 |
+
display: flex;
|
| 37 |
+
gap: 0.5rem;
|
| 38 |
+
margin-bottom: 0.75rem;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
.cluster-action-btn {
|
| 42 |
+
flex: 1;
|
| 43 |
+
padding: 0.4rem 0.6rem;
|
| 44 |
+
border: 1px solid var(--border-color, #e0e0e0);
|
| 45 |
+
border-radius: 4px;
|
| 46 |
+
background: var(--bg-primary, #ffffff);
|
| 47 |
+
color: var(--text-primary, #1a1a1a);
|
| 48 |
+
font-size: 0.85rem;
|
| 49 |
+
cursor: pointer;
|
| 50 |
+
transition: all 0.2s;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.cluster-action-btn:hover:not(:disabled) {
|
| 54 |
+
background: var(--bg-secondary, #f5f5f5);
|
| 55 |
+
border-color: var(--accent-color, #4a90e2);
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
.cluster-action-btn:disabled {
|
| 59 |
+
opacity: 0.5;
|
| 60 |
+
cursor: not-allowed;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
.cluster-list {
|
| 64 |
+
max-height: 300px;
|
| 65 |
+
overflow-y: auto;
|
| 66 |
+
border: 1px solid var(--border-color, #e0e0e0);
|
| 67 |
+
border-radius: 4px;
|
| 68 |
+
padding: 0.5rem;
|
| 69 |
+
background: var(--bg-primary, #ffffff);
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
.cluster-item {
|
| 73 |
+
display: flex;
|
| 74 |
+
align-items: center;
|
| 75 |
+
gap: 0.5rem;
|
| 76 |
+
padding: 0.5rem;
|
| 77 |
+
cursor: pointer;
|
| 78 |
+
border-radius: 3px;
|
| 79 |
+
transition: background 0.15s;
|
| 80 |
+
font-size: 0.85rem;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
.cluster-item:hover {
|
| 84 |
+
background: var(--bg-secondary, #f5f5f5);
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
.cluster-item.selected {
|
| 88 |
+
background: var(--bg-secondary, #f5f5f5);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
.cluster-checkbox {
|
| 92 |
+
margin: 0;
|
| 93 |
+
cursor: pointer;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
.cluster-color-indicator {
|
| 97 |
+
width: 12px;
|
| 98 |
+
height: 12px;
|
| 99 |
+
border-radius: 2px;
|
| 100 |
+
flex-shrink: 0;
|
| 101 |
+
border: 1px solid var(--border-color, #e0e0e0);
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
.cluster-label {
|
| 105 |
+
flex: 1;
|
| 106 |
+
color: var(--text-primary, #1a1a1a);
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
.cluster-count {
|
| 110 |
+
font-size: 0.75rem;
|
| 111 |
+
color: var(--text-secondary, #666);
|
| 112 |
+
margin-left: auto;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
.cluster-filter-loading,
|
| 116 |
+
.cluster-filter-empty {
|
| 117 |
+
padding: 1rem;
|
| 118 |
+
text-align: center;
|
| 119 |
+
color: var(--text-secondary, #666);
|
| 120 |
+
font-size: 0.9rem;
|
| 121 |
+
}
|
| 122 |
+
|
frontend/src/components/controls/ClusterFilter.tsx
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Enhanced cluster filter component with search, Select All/Clear All/Random buttons.
|
| 3 |
+
* Inspired by LAION's cluster filtering UI.
|
| 4 |
+
*/
|
| 5 |
+
import React, { useState, useMemo } from 'react';
|
| 6 |
+
import { useFilterStore } from '../../stores/filterStore';
|
| 7 |
+
import './ClusterFilter.css';
|
| 8 |
+
|
| 9 |
+
export interface Cluster {
|
| 10 |
+
cluster_id: number;
|
| 11 |
+
cluster_label: string;
|
| 12 |
+
count: number;
|
| 13 |
+
color?: string;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
interface ClusterFilterProps {
|
| 17 |
+
clusters: Cluster[];
|
| 18 |
+
loading?: boolean;
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
export default function ClusterFilter({ clusters, loading = false }: ClusterFilterProps) {
|
| 22 |
+
const { selectedClusters, setSelectedClusters } = useFilterStore();
|
| 23 |
+
const [searchTerm, setSearchTerm] = useState('');
|
| 24 |
+
|
| 25 |
+
const filteredClusters = useMemo(() => {
|
| 26 |
+
if (!searchTerm) return clusters;
|
| 27 |
+
const lowerSearch = searchTerm.toLowerCase();
|
| 28 |
+
return clusters.filter(c =>
|
| 29 |
+
c.cluster_label.toLowerCase().includes(lowerSearch) ||
|
| 30 |
+
c.cluster_id.toString().includes(lowerSearch)
|
| 31 |
+
);
|
| 32 |
+
}, [clusters, searchTerm]);
|
| 33 |
+
|
| 34 |
+
const handleSelectAll = () => {
|
| 35 |
+
setSelectedClusters(clusters.map(c => c.cluster_id));
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
const handleClearAll = () => {
|
| 39 |
+
setSelectedClusters([]);
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
const handleRandom = () => {
|
| 43 |
+
if (clusters.length === 0) return;
|
| 44 |
+
const randomCluster = clusters[Math.floor(Math.random() * clusters.length)];
|
| 45 |
+
setSelectedClusters([randomCluster.cluster_id]);
|
| 46 |
+
};
|
| 47 |
+
|
| 48 |
+
const handleToggleCluster = (clusterId: number) => {
|
| 49 |
+
if (selectedClusters.includes(clusterId)) {
|
| 50 |
+
setSelectedClusters(selectedClusters.filter(id => id !== clusterId));
|
| 51 |
+
} else {
|
| 52 |
+
setSelectedClusters([...selectedClusters, clusterId]);
|
| 53 |
+
}
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
if (loading) {
|
| 57 |
+
return (
|
| 58 |
+
<div className="cluster-filter">
|
| 59 |
+
<div className="cluster-filter-loading">Loading clusters...</div>
|
| 60 |
+
</div>
|
| 61 |
+
);
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
if (clusters.length === 0) {
|
| 65 |
+
return (
|
| 66 |
+
<div className="cluster-filter">
|
| 67 |
+
<div className="cluster-filter-empty">No clusters available</div>
|
| 68 |
+
</div>
|
| 69 |
+
);
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
return (
|
| 73 |
+
<div className="cluster-filter">
|
| 74 |
+
<div className="cluster-filter-header">
|
| 75 |
+
<h3>Dataset Clusters</h3>
|
| 76 |
+
</div>
|
| 77 |
+
|
| 78 |
+
<div className="cluster-filter-search">
|
| 79 |
+
<input
|
| 80 |
+
type="text"
|
| 81 |
+
placeholder={`Search ${clusters.length} clusters...`}
|
| 82 |
+
value={searchTerm}
|
| 83 |
+
onChange={(e) => setSearchTerm(e.target.value)}
|
| 84 |
+
className="cluster-search-input"
|
| 85 |
+
/>
|
| 86 |
+
</div>
|
| 87 |
+
|
| 88 |
+
<div className="cluster-filter-actions">
|
| 89 |
+
<button
|
| 90 |
+
onClick={handleSelectAll}
|
| 91 |
+
className="cluster-action-btn"
|
| 92 |
+
disabled={clusters.length === 0}
|
| 93 |
+
>
|
| 94 |
+
Select All
|
| 95 |
+
</button>
|
| 96 |
+
<button
|
| 97 |
+
onClick={handleClearAll}
|
| 98 |
+
className="cluster-action-btn"
|
| 99 |
+
disabled={selectedClusters.length === 0}
|
| 100 |
+
>
|
| 101 |
+
Clear All
|
| 102 |
+
</button>
|
| 103 |
+
<button
|
| 104 |
+
onClick={handleRandom}
|
| 105 |
+
className="cluster-action-btn"
|
| 106 |
+
disabled={clusters.length === 0}
|
| 107 |
+
>
|
| 108 |
+
Random
|
| 109 |
+
</button>
|
| 110 |
+
</div>
|
| 111 |
+
|
| 112 |
+
<div className="cluster-list">
|
| 113 |
+
{filteredClusters.length === 0 ? (
|
| 114 |
+
<div className="cluster-filter-empty">No clusters match your search</div>
|
| 115 |
+
) : (
|
| 116 |
+
filteredClusters.map(cluster => (
|
| 117 |
+
<label
|
| 118 |
+
key={cluster.cluster_id}
|
| 119 |
+
className={`cluster-item ${selectedClusters.includes(cluster.cluster_id) ? 'selected' : ''}`}
|
| 120 |
+
>
|
| 121 |
+
<input
|
| 122 |
+
type="checkbox"
|
| 123 |
+
checked={selectedClusters.includes(cluster.cluster_id)}
|
| 124 |
+
onChange={() => handleToggleCluster(cluster.cluster_id)}
|
| 125 |
+
className="cluster-checkbox"
|
| 126 |
+
/>
|
| 127 |
+
{cluster.color && (
|
| 128 |
+
<span
|
| 129 |
+
className="cluster-color-indicator"
|
| 130 |
+
style={{ backgroundColor: cluster.color }}
|
| 131 |
+
/>
|
| 132 |
+
)}
|
| 133 |
+
<span className="cluster-label">{cluster.cluster_label}</span>
|
| 134 |
+
<span className="cluster-count">({cluster.count.toLocaleString()})</span>
|
| 135 |
+
</label>
|
| 136 |
+
))
|
| 137 |
+
)}
|
| 138 |
+
</div>
|
| 139 |
+
</div>
|
| 140 |
+
);
|
| 141 |
+
}
|
| 142 |
+
|
frontend/src/components/controls/NodeDensitySlider.css
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.node-density-slider {
|
| 2 |
+
margin-bottom: 1rem;
|
| 3 |
+
}
|
| 4 |
+
|
| 5 |
+
.node-density-label {
|
| 6 |
+
display: block;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
.node-density-title {
|
| 10 |
+
font-weight: 500;
|
| 11 |
+
display: block;
|
| 12 |
+
margin-bottom: 0.5rem;
|
| 13 |
+
color: var(--text-primary, #1a1a1a);
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
.node-density-input {
|
| 17 |
+
width: 100%;
|
| 18 |
+
cursor: pointer;
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
.node-density-input:disabled {
|
| 22 |
+
opacity: 0.5;
|
| 23 |
+
cursor: not-allowed;
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
.node-density-hint {
|
| 27 |
+
font-size: 0.75rem;
|
| 28 |
+
color: var(--text-secondary, #666);
|
| 29 |
+
margin-top: 0.25rem;
|
| 30 |
+
}
|
| 31 |
+
|
frontend/src/components/controls/NodeDensitySlider.tsx
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Node density slider for controlling rendering performance.
|
| 3 |
+
* Lower density improves performance for large datasets.
|
| 4 |
+
*/
|
| 5 |
+
import React from 'react';
|
| 6 |
+
import { useFilterStore } from '../../stores/filterStore';
|
| 7 |
+
import './NodeDensitySlider.css';
|
| 8 |
+
|
| 9 |
+
interface NodeDensitySliderProps {
|
| 10 |
+
disabled?: boolean;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
export default function NodeDensitySlider({ disabled = false }: NodeDensitySliderProps) {
|
| 14 |
+
const { nodeDensity, setNodeDensity } = useFilterStore();
|
| 15 |
+
|
| 16 |
+
return (
|
| 17 |
+
<div className="node-density-slider">
|
| 18 |
+
<label className="node-density-label">
|
| 19 |
+
<span className="node-density-title">
|
| 20 |
+
Node Density ({nodeDensity}%)
|
| 21 |
+
</span>
|
| 22 |
+
<input
|
| 23 |
+
type="range"
|
| 24 |
+
min="10"
|
| 25 |
+
max="100"
|
| 26 |
+
step="10"
|
| 27 |
+
value={nodeDensity}
|
| 28 |
+
onChange={(e) => setNodeDensity(parseInt(e.target.value))}
|
| 29 |
+
disabled={disabled}
|
| 30 |
+
className="node-density-input"
|
| 31 |
+
/>
|
| 32 |
+
<div className="node-density-hint">
|
| 33 |
+
Lower density improves performance for large datasets
|
| 34 |
+
</div>
|
| 35 |
+
</label>
|
| 36 |
+
</div>
|
| 37 |
+
);
|
| 38 |
+
}
|
| 39 |
+
|
frontend/src/components/controls/RandomModelButton.tsx
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Button to select a random model from the dataset for discovery.
|
| 3 |
+
*/
|
| 4 |
+
import React from 'react';
|
| 5 |
+
import { ModelPoint } from '../../types';
|
| 6 |
+
|
| 7 |
+
interface RandomModelButtonProps {
|
| 8 |
+
data: ModelPoint[];
|
| 9 |
+
onSelect: (model: ModelPoint) => void;
|
| 10 |
+
disabled?: boolean;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
export default function RandomModelButton({ data, onSelect, disabled }: RandomModelButtonProps) {
|
| 14 |
+
const handleRandomSelect = () => {
|
| 15 |
+
if (data.length === 0) return;
|
| 16 |
+
const randomIndex = Math.floor(Math.random() * data.length);
|
| 17 |
+
onSelect(data[randomIndex]);
|
| 18 |
+
};
|
| 19 |
+
|
| 20 |
+
return (
|
| 21 |
+
<button
|
| 22 |
+
onClick={handleRandomSelect}
|
| 23 |
+
disabled={disabled || data.length === 0}
|
| 24 |
+
className="random-model-btn"
|
| 25 |
+
title="Select a random model"
|
| 26 |
+
aria-label="Select random model"
|
| 27 |
+
>
|
| 28 |
+
<span>Select Random Model</span>
|
| 29 |
+
</button>
|
| 30 |
+
);
|
| 31 |
+
}
|
| 32 |
+
|
frontend/src/components/controls/RenderingStyleSelector.css
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.rendering-style-selector {
|
| 2 |
+
margin-bottom: 1rem;
|
| 3 |
+
}
|
| 4 |
+
|
| 5 |
+
.rendering-style-label {
|
| 6 |
+
display: block;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
.rendering-style-title {
|
| 10 |
+
font-weight: 500;
|
| 11 |
+
display: block;
|
| 12 |
+
margin-bottom: 0.5rem;
|
| 13 |
+
color: var(--text-primary, #1a1a1a);
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
.rendering-style-select {
|
| 17 |
+
width: 100%;
|
| 18 |
+
padding: 0.5rem;
|
| 19 |
+
border: 1px solid var(--border-color, #e0e0e0);
|
| 20 |
+
border-radius: 4px;
|
| 21 |
+
background: var(--bg-primary, #ffffff);
|
| 22 |
+
color: var(--text-primary, #1a1a1a);
|
| 23 |
+
font-size: 0.9rem;
|
| 24 |
+
cursor: pointer;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
.rendering-style-select:focus {
|
| 28 |
+
outline: none;
|
| 29 |
+
border-color: var(--accent-color, #4a90e2);
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
.rendering-style-hint {
|
| 33 |
+
font-size: 0.75rem;
|
| 34 |
+
color: var(--text-secondary, #666);
|
| 35 |
+
margin-top: 0.25rem;
|
| 36 |
+
}
|
| 37 |
+
|
frontend/src/components/controls/RenderingStyleSelector.tsx
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Rendering style selector dropdown.
|
| 3 |
+
* Allows users to choose different 3D layout/geometry styles.
|
| 4 |
+
*/
|
| 5 |
+
import React from 'react';
|
| 6 |
+
import { useFilterStore, RenderingStyle } from '../../stores/filterStore';
|
| 7 |
+
import './RenderingStyleSelector.css';
|
| 8 |
+
|
| 9 |
+
const STYLES: { value: RenderingStyle; label: string; description: string }[] = [
|
| 10 |
+
{ value: 'embeddings', label: 'Embeddings', description: 'Standard embedding-based layout' },
|
| 11 |
+
{ value: 'sphere', label: 'Sphere', description: 'Spherical arrangement of points' },
|
| 12 |
+
{ value: 'galaxy', label: 'Galaxy', description: 'Spiral galaxy-like layout' },
|
| 13 |
+
{ value: 'wave', label: 'Wave', description: 'Wave pattern arrangement' },
|
| 14 |
+
{ value: 'helix', label: 'Helix', description: 'Helical/spiral arrangement' },
|
| 15 |
+
{ value: 'torus', label: 'Torus', description: 'Torus/donut-shaped layout' },
|
| 16 |
+
];
|
| 17 |
+
|
| 18 |
+
export default function RenderingStyleSelector() {
|
| 19 |
+
const { renderingStyle, setRenderingStyle } = useFilterStore();
|
| 20 |
+
|
| 21 |
+
return (
|
| 22 |
+
<div className="rendering-style-selector">
|
| 23 |
+
<label className="rendering-style-label">
|
| 24 |
+
<span className="rendering-style-title">Rendering Style</span>
|
| 25 |
+
<select
|
| 26 |
+
value={renderingStyle}
|
| 27 |
+
onChange={(e) => setRenderingStyle(e.target.value as RenderingStyle)}
|
| 28 |
+
className="rendering-style-select"
|
| 29 |
+
>
|
| 30 |
+
{STYLES.map(style => (
|
| 31 |
+
<option key={style.value} value={style.value}>
|
| 32 |
+
{style.label}
|
| 33 |
+
</option>
|
| 34 |
+
))}
|
| 35 |
+
</select>
|
| 36 |
+
<div className="rendering-style-hint">
|
| 37 |
+
{STYLES.find(s => s.value === renderingStyle)?.description}
|
| 38 |
+
</div>
|
| 39 |
+
</label>
|
| 40 |
+
</div>
|
| 41 |
+
);
|
| 42 |
+
}
|
| 43 |
+
|
frontend/src/components/controls/ThemeToggle.tsx
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Toggle button for switching between light and dark themes.
|
| 3 |
+
*/
|
| 4 |
+
import React from 'react';
|
| 5 |
+
import { useFilterStore } from '../../stores/filterStore';
|
| 6 |
+
|
| 7 |
+
export default function ThemeToggle() {
|
| 8 |
+
const theme = useFilterStore((state) => state.theme);
|
| 9 |
+
const toggleTheme = useFilterStore((state) => state.toggleTheme);
|
| 10 |
+
|
| 11 |
+
return (
|
| 12 |
+
<button
|
| 13 |
+
onClick={toggleTheme}
|
| 14 |
+
className="theme-toggle"
|
| 15 |
+
title={`Switch to ${theme === 'light' ? 'dark' : 'light'} mode`}
|
| 16 |
+
aria-label={`Current theme: ${theme}. Click to switch to ${theme === 'light' ? 'dark' : 'light'} mode`}
|
| 17 |
+
>
|
| 18 |
+
{theme === 'light' ? '🌙' : '☀️'}
|
| 19 |
+
</button>
|
| 20 |
+
);
|
| 21 |
+
}
|
| 22 |
+
|
frontend/src/components/controls/VisualizationModeButtons.css
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.visualization-mode-buttons {
|
| 2 |
+
position: sticky;
|
| 3 |
+
top: 0;
|
| 4 |
+
z-index: 100;
|
| 5 |
+
background: var(--bg-primary, #ffffff);
|
| 6 |
+
border-bottom: 1px solid var(--border-color, #e0e0e0);
|
| 7 |
+
padding: 0.75rem 1rem;
|
| 8 |
+
margin-bottom: 1rem;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
.mode-buttons-container {
|
| 12 |
+
display: flex;
|
| 13 |
+
gap: 0.5rem;
|
| 14 |
+
flex-wrap: wrap;
|
| 15 |
+
justify-content: center;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
.mode-button {
|
| 19 |
+
display: flex;
|
| 20 |
+
align-items: center;
|
| 21 |
+
gap: 0.5rem;
|
| 22 |
+
padding: 0.5rem 1rem;
|
| 23 |
+
border: 1px solid var(--border-color, #e0e0e0);
|
| 24 |
+
border-radius: 6px;
|
| 25 |
+
background: var(--bg-primary, #ffffff);
|
| 26 |
+
color: var(--text-primary, #1a1a1a);
|
| 27 |
+
font-size: 0.9rem;
|
| 28 |
+
cursor: pointer;
|
| 29 |
+
transition: all 0.2s;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
.mode-button:hover {
|
| 33 |
+
background: var(--bg-secondary, #f5f5f5);
|
| 34 |
+
border-color: var(--accent-color, #4a90e2);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.mode-button.active {
|
| 38 |
+
background: var(--accent-color, #4a90e2);
|
| 39 |
+
color: white;
|
| 40 |
+
border-color: var(--accent-color, #4a90e2);
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
.mode-icon {
|
| 44 |
+
font-size: 1.1rem;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
.mode-label {
|
| 48 |
+
font-weight: 500;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
@media (max-width: 768px) {
|
| 52 |
+
.mode-buttons-container {
|
| 53 |
+
gap: 0.25rem;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
.mode-button {
|
| 57 |
+
padding: 0.4rem 0.6rem;
|
| 58 |
+
font-size: 0.8rem;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
.mode-label {
|
| 62 |
+
display: none;
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
frontend/src/components/controls/VisualizationModeButtons.tsx
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Visualization mode buttons with sticky header.
|
| 3 |
+
* Inspired by LAION's mode selection UI.
|
| 4 |
+
*/
|
| 5 |
+
import React from 'react';
|
| 6 |
+
import { useFilterStore, ViewMode } from '../../stores/filterStore';
|
| 7 |
+
import './VisualizationModeButtons.css';
|
| 8 |
+
|
| 9 |
+
interface ModeOption {
|
| 10 |
+
value: ViewMode;
|
| 11 |
+
label: string;
|
| 12 |
+
icon: string;
|
| 13 |
+
description: string;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
const MODES: ModeOption[] = [
|
| 17 |
+
{ value: '3d', label: '3D Embedding', icon: '🎯', description: 'Interactive 3D exploration' },
|
| 18 |
+
{ value: 'scatter', label: '2D Scatter', icon: 'π', description: '2D projection view' },
|
| 19 |
+
{ value: 'network', label: 'Network', icon: '🕸️', description: 'Network graph view' },
|
| 20 |
+
{ value: 'distribution', label: 'Distribution', icon: 'π', description: 'Statistical distributions' },
|
| 21 |
+
{ value: 'stacked', label: 'Stacked', icon: 'π', description: 'Hierarchical view' },
|
| 22 |
+
{ value: 'heatmap', label: 'Heatmap', icon: '🔥', description: 'Density heatmap' },
|
| 23 |
+
];
|
| 24 |
+
|
| 25 |
+
export default function VisualizationModeButtons() {
|
| 26 |
+
const { viewMode, setViewMode } = useFilterStore();
|
| 27 |
+
|
| 28 |
+
return (
|
| 29 |
+
<div className="visualization-mode-buttons">
|
| 30 |
+
<div className="mode-buttons-container">
|
| 31 |
+
{MODES.map(mode => (
|
| 32 |
+
<button
|
| 33 |
+
key={mode.value}
|
| 34 |
+
className={`mode-button ${viewMode === mode.value ? 'active' : ''}`}
|
| 35 |
+
onClick={() => setViewMode(mode.value)}
|
| 36 |
+
title={mode.description}
|
| 37 |
+
>
|
| 38 |
+
<span className="mode-icon">{mode.icon}</span>
|
| 39 |
+
<span className="mode-label">{mode.label}</span>
|
| 40 |
+
</button>
|
| 41 |
+
))}
|
| 42 |
+
</div>
|
| 43 |
+
</div>
|
| 44 |
+
);
|
| 45 |
+
}
|
| 46 |
+
|
frontend/src/components/controls/ZoomSlider.tsx
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Slider control for zoom level in 3D visualization.
|
| 3 |
+
*/
|
| 4 |
+
import React from 'react';
|
| 5 |
+
|
| 6 |
+
interface ZoomSliderProps {
|
| 7 |
+
value: number;
|
| 8 |
+
onChange: (value: number) => void;
|
| 9 |
+
min?: number;
|
| 10 |
+
max?: number;
|
| 11 |
+
step?: number;
|
| 12 |
+
disabled?: boolean;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
export default function ZoomSlider({
|
| 16 |
+
value,
|
| 17 |
+
onChange,
|
| 18 |
+
min = 0.1,
|
| 19 |
+
max = 5,
|
| 20 |
+
step = 0.1,
|
| 21 |
+
disabled = false,
|
| 22 |
+
}: ZoomSliderProps) {
|
| 23 |
+
return (
|
| 24 |
+
<div className="zoom-slider-container">
|
| 25 |
+
<label className="zoom-slider-label">
|
| 26 |
+
<span>Zoom Level</span>
|
| 27 |
+
<span className="zoom-value">{value.toFixed(1)}x</span>
|
| 28 |
+
</label>
|
| 29 |
+
<input
|
| 30 |
+
type="range"
|
| 31 |
+
min={min}
|
| 32 |
+
max={max}
|
| 33 |
+
step={step}
|
| 34 |
+
value={value}
|
| 35 |
+
onChange={(e) => onChange(parseFloat(e.target.value))}
|
| 36 |
+
disabled={disabled}
|
| 37 |
+
className="zoom-slider"
|
| 38 |
+
aria-label="Zoom level"
|
| 39 |
+
/>
|
| 40 |
+
</div>
|
| 41 |
+
);
|
| 42 |
+
}
|
| 43 |
+
|
frontend/src/components/layout/SearchBar.css
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.search-bar-container {
|
| 2 |
+
position: relative;
|
| 3 |
+
width: 100%;
|
| 4 |
+
max-width: 600px;
|
| 5 |
+
z-index: 1000;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
.search-bar {
|
| 9 |
+
position: relative;
|
| 10 |
+
display: flex;
|
| 11 |
+
align-items: center;
|
| 12 |
+
background: white;
|
| 13 |
+
border: 2px solid #e0e0e0;
|
| 14 |
+
border-radius: 8px;
|
| 15 |
+
padding: 8px 12px;
|
| 16 |
+
transition: border-color 0.2s;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
.search-bar:focus-within {
|
| 20 |
+
border-color: #4a90e2;
|
| 21 |
+
box-shadow: 0 0 0 3px rgba(74, 144, 226, 0.1);
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
.search-input {
|
| 25 |
+
flex: 1;
|
| 26 |
+
border: none;
|
| 27 |
+
outline: none;
|
| 28 |
+
font-size: 14px;
|
| 29 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 30 |
+
color: #333;
|
| 31 |
+
background: transparent;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
.search-input::placeholder {
|
| 35 |
+
color: #999;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
.search-loading {
|
| 39 |
+
margin-left: 8px;
|
| 40 |
+
color: #4a90e2;
|
| 41 |
+
animation: spin 1s linear infinite;
|
| 42 |
+
font-size: 16px;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
.search-clear {
|
| 46 |
+
margin-left: 8px;
|
| 47 |
+
background: none;
|
| 48 |
+
border: none;
|
| 49 |
+
color: #999;
|
| 50 |
+
cursor: pointer;
|
| 51 |
+
font-size: 20px;
|
| 52 |
+
line-height: 1;
|
| 53 |
+
padding: 0;
|
| 54 |
+
width: 20px;
|
| 55 |
+
height: 20px;
|
| 56 |
+
display: flex;
|
| 57 |
+
align-items: center;
|
| 58 |
+
justify-content: center;
|
| 59 |
+
transition: color 0.2s;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
.search-clear:hover {
|
| 63 |
+
color: #333;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
.search-results {
|
| 67 |
+
position: absolute;
|
| 68 |
+
top: 100%;
|
| 69 |
+
left: 0;
|
| 70 |
+
right: 0;
|
| 71 |
+
margin-top: 4px;
|
| 72 |
+
background: white;
|
| 73 |
+
border: 1px solid #e0e0e0;
|
| 74 |
+
border-radius: 8px;
|
| 75 |
+
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
|
| 76 |
+
max-height: 400px;
|
| 77 |
+
overflow-y: auto;
|
| 78 |
+
z-index: 1001;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
.search-result {
|
| 82 |
+
padding: 12px;
|
| 83 |
+
cursor: pointer;
|
| 84 |
+
border-bottom: 1px solid #f0f0f0;
|
| 85 |
+
transition: background-color 0.15s;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
.search-result:last-child {
|
| 89 |
+
border-bottom: none;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
.search-result:hover,
|
| 93 |
+
.search-result.selected {
|
| 94 |
+
background-color: #f5f5f5;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
.result-header {
|
| 98 |
+
display: flex;
|
| 99 |
+
align-items: center;
|
| 100 |
+
gap: 8px;
|
| 101 |
+
margin-bottom: 6px;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
.result-model-id {
|
| 105 |
+
font-size: 14px;
|
| 106 |
+
font-weight: 600;
|
| 107 |
+
color: #333;
|
| 108 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
.result-org {
|
| 112 |
+
font-size: 12px;
|
| 113 |
+
color: #666;
|
| 114 |
+
background: #f0f0f0;
|
| 115 |
+
padding: 2px 6px;
|
| 116 |
+
border-radius: 4px;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
.result-meta {
|
| 120 |
+
display: flex;
|
| 121 |
+
flex-wrap: wrap;
|
| 122 |
+
gap: 6px;
|
| 123 |
+
margin-bottom: 4px;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
.result-tag {
|
| 127 |
+
font-size: 11px;
|
| 128 |
+
color: #666;
|
| 129 |
+
background: #e8e8e8;
|
| 130 |
+
padding: 2px 6px;
|
| 131 |
+
border-radius: 3px;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
.result-snippet {
|
| 135 |
+
font-size: 12px;
|
| 136 |
+
color: #666;
|
| 137 |
+
margin-top: 4px;
|
| 138 |
+
line-height: 1.4;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
.result-snippet mark {
|
| 142 |
+
background: #fff3cd;
|
| 143 |
+
padding: 1px 2px;
|
| 144 |
+
border-radius: 2px;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
.search-no-results {
|
| 148 |
+
padding: 16px;
|
| 149 |
+
text-align: center;
|
| 150 |
+
color: #999;
|
| 151 |
+
font-size: 14px;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
@keyframes spin {
|
| 155 |
+
from {
|
| 156 |
+
transform: rotate(0deg);
|
| 157 |
+
}
|
| 158 |
+
to {
|
| 159 |
+
transform: rotate(360deg);
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
/* Scrollbar styling */
|
| 164 |
+
.search-results::-webkit-scrollbar {
|
| 165 |
+
width: 8px;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
.search-results::-webkit-scrollbar-track {
|
| 169 |
+
background: #f1f1f1;
|
| 170 |
+
border-radius: 4px;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
.search-results::-webkit-scrollbar-thumb {
|
| 174 |
+
background: #c1c1c1;
|
| 175 |
+
border-radius: 4px;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
.search-results::-webkit-scrollbar-thumb:hover {
|
| 179 |
+
background: #a8a8a8;
|
| 180 |
+
}
|
| 181 |
+
|
frontend/src/components/layout/SearchBar.tsx
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Enhanced search bar with autocomplete and keyboard navigation.
|
| 3 |
+
* Integrates with filter store and triggers map zoom/modal open.
|
| 4 |
+
*/
|
| 5 |
+
import React, { useState, useEffect, useRef, useCallback } from 'react';
|
| 6 |
+
import { useFilterStore } from '../../stores/filterStore';
|
| 7 |
+
import './SearchBar.css';
|
| 8 |
+
|
| 9 |
+
import { API_BASE } from '../../config/api';
|
| 10 |
+
|
| 11 |
+
interface SearchResult {
|
| 12 |
+
model_id: string;
|
| 13 |
+
x: number;
|
| 14 |
+
y: number;
|
| 15 |
+
z: number;
|
| 16 |
+
org: string;
|
| 17 |
+
library?: string;
|
| 18 |
+
pipeline?: string;
|
| 19 |
+
license?: string;
|
| 20 |
+
snippet?: string;
|
| 21 |
+
match_score?: number;
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
interface SearchBarProps {
|
| 25 |
+
onSelect?: (result: SearchResult) => void;
|
| 26 |
+
onZoomTo?: (x: number, y: number, z: number) => void;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
export default function SearchBar({ onSelect, onZoomTo }: SearchBarProps) {
|
| 30 |
+
const [query, setQuery] = useState('');
|
| 31 |
+
const [results, setResults] = useState<SearchResult[]>([]);
|
| 32 |
+
const [selectedIndex, setSelectedIndex] = useState(-1);
|
| 33 |
+
const [isOpen, setIsOpen] = useState(false);
|
| 34 |
+
const [isLoading, setIsLoading] = useState(false);
|
| 35 |
+
const inputRef = useRef<HTMLInputElement>(null);
|
| 36 |
+
const resultsRef = useRef<HTMLDivElement>(null);
|
| 37 |
+
|
| 38 |
+
const setSearchQuery = useFilterStore((state) => state.setSearchQuery);
|
| 39 |
+
|
| 40 |
+
// Debounced search
|
| 41 |
+
useEffect(() => {
|
| 42 |
+
if (query.length < 2) {
|
| 43 |
+
setResults([]);
|
| 44 |
+
setIsOpen(false);
|
| 45 |
+
return;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
setIsLoading(true);
|
| 49 |
+
const timer = setTimeout(async () => {
|
| 50 |
+
try {
|
| 51 |
+
const response = await fetch(
|
| 52 |
+
`${API_BASE}/api/search?q=${encodeURIComponent(query)}&limit=20`
|
| 53 |
+
);
|
| 54 |
+
if (!response.ok) throw new Error('Search failed');
|
| 55 |
+
const data = await response.json();
|
| 56 |
+
setResults(data.results || []);
|
| 57 |
+
setIsOpen(true);
|
| 58 |
+
setSelectedIndex(-1);
|
| 59 |
+
} catch (err) {
|
| 60 |
+
console.error('Search error:', err);
|
| 61 |
+
setResults([]);
|
| 62 |
+
} finally {
|
| 63 |
+
setIsLoading(false);
|
| 64 |
+
}
|
| 65 |
+
}, 150);
|
| 66 |
+
|
| 67 |
+
return () => clearTimeout(timer);
|
| 68 |
+
}, [query]);
|
| 69 |
+
|
| 70 |
+
const handleSelect = useCallback((result: SearchResult) => {
|
| 71 |
+
setSearchQuery(result.model_id);
|
| 72 |
+
|
| 73 |
+
// Trigger zoom if coordinates available
|
| 74 |
+
if (onZoomTo && result.x !== undefined && result.y !== undefined) {
|
| 75 |
+
onZoomTo(result.x, result.y, result.z || 0);
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
// Trigger select callback
|
| 79 |
+
if (onSelect) {
|
| 80 |
+
onSelect(result);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
setIsOpen(false);
|
| 84 |
+
setQuery('');
|
| 85 |
+
inputRef.current?.blur();
|
| 86 |
+
}, [onSelect, onZoomTo, setSearchQuery]);
|
| 87 |
+
|
| 88 |
+
const handleKeyDown = (e: React.KeyboardEvent) => {
|
| 89 |
+
if (!isOpen || results.length === 0) return;
|
| 90 |
+
|
| 91 |
+
if (e.key === 'ArrowDown') {
|
| 92 |
+
e.preventDefault();
|
| 93 |
+
setSelectedIndex(prev =>
|
| 94 |
+
prev < results.length - 1 ? prev + 1 : prev
|
| 95 |
+
);
|
| 96 |
+
// Scroll into view
|
| 97 |
+
if (resultsRef.current && selectedIndex >= 0) {
|
| 98 |
+
const selectedElement = resultsRef.current.children[selectedIndex + 1] as HTMLElement;
|
| 99 |
+
selectedElement?.scrollIntoView({ block: 'nearest' });
|
| 100 |
+
}
|
| 101 |
+
} else if (e.key === 'ArrowUp') {
|
| 102 |
+
e.preventDefault();
|
| 103 |
+
setSelectedIndex(prev => prev > 0 ? prev - 1 : -1);
|
| 104 |
+
} else if (e.key === 'Enter') {
|
| 105 |
+
e.preventDefault();
|
| 106 |
+
if (selectedIndex >= 0 && results[selectedIndex]) {
|
| 107 |
+
handleSelect(results[selectedIndex]);
|
| 108 |
+
} else if (results.length > 0) {
|
| 109 |
+
handleSelect(results[0]);
|
| 110 |
+
}
|
| 111 |
+
} else if (e.key === 'Escape') {
|
| 112 |
+
setIsOpen(false);
|
| 113 |
+
inputRef.current?.blur();
|
| 114 |
+
}
|
| 115 |
+
};
|
| 116 |
+
|
| 117 |
+
const handleFocus = () => {
|
| 118 |
+
if (results.length > 0) {
|
| 119 |
+
setIsOpen(true);
|
| 120 |
+
}
|
| 121 |
+
};
|
| 122 |
+
|
| 123 |
+
const handleBlur = (e: React.FocusEvent) => {
|
| 124 |
+
// Delay to allow click events on results
|
| 125 |
+
setTimeout(() => {
|
| 126 |
+
if (!resultsRef.current?.contains(document.activeElement)) {
|
| 127 |
+
setIsOpen(false);
|
| 128 |
+
}
|
| 129 |
+
}, 200);
|
| 130 |
+
};
|
| 131 |
+
|
| 132 |
+
return (
|
| 133 |
+
<div className="search-bar-container">
|
| 134 |
+
<div className="search-bar">
|
| 135 |
+
<input
|
| 136 |
+
ref={inputRef}
|
| 137 |
+
type="text"
|
| 138 |
+
value={query}
|
| 139 |
+
onChange={(e) => setQuery(e.target.value)}
|
| 140 |
+
onKeyDown={handleKeyDown}
|
| 141 |
+
onFocus={handleFocus}
|
| 142 |
+
onBlur={handleBlur}
|
| 143 |
+
placeholder="Search models, orgs, tasks, licenses..."
|
| 144 |
+
className="search-input"
|
| 145 |
+
aria-label="Search models"
|
| 146 |
+
aria-expanded={isOpen}
|
| 147 |
+
aria-haspopup="listbox"
|
| 148 |
+
/>
|
| 149 |
+
{isLoading && <div className="search-loading">⏳</div>}
|
| 150 |
+
{query.length > 0 && !isLoading && (
|
| 151 |
+
<button
|
| 152 |
+
className="search-clear"
|
| 153 |
+
onClick={() => {
|
| 154 |
+
setQuery('');
|
| 155 |
+
setResults([]);
|
| 156 |
+
setIsOpen(false);
|
| 157 |
+
}}
|
| 158 |
+
aria-label="Clear search"
|
| 159 |
+
>
|
| 160 |
+
×
|
| 161 |
+
</button>
|
| 162 |
+
)}
|
| 163 |
+
</div>
|
| 164 |
+
{isOpen && results.length > 0 && (
|
| 165 |
+
<div ref={resultsRef} className="search-results" role="listbox">
|
| 166 |
+
{results.map((result, idx) => (
|
| 167 |
+
<div
|
| 168 |
+
key={result.model_id}
|
| 169 |
+
className={`search-result ${idx === selectedIndex ? 'selected' : ''}`}
|
| 170 |
+
onClick={() => handleSelect(result)}
|
| 171 |
+
role="option"
|
| 172 |
+
aria-selected={idx === selectedIndex}
|
| 173 |
+
>
|
| 174 |
+
<div className="result-header">
|
| 175 |
+
<strong className="result-model-id">{result.model_id}</strong>
|
| 176 |
+
{result.org && <span className="result-org">{result.org}</span>}
|
| 177 |
+
</div>
|
| 178 |
+
<div className="result-meta">
|
| 179 |
+
{result.library && <span className="result-tag">{result.library}</span>}
|
| 180 |
+
{result.pipeline && <span className="result-tag">{result.pipeline}</span>}
|
| 181 |
+
{result.license && <span className="result-tag">{result.license}</span>}
|
| 182 |
+
</div>
|
| 183 |
+
{result.snippet && (
|
| 184 |
+
<div
|
| 185 |
+
className="result-snippet"
|
| 186 |
+
dangerouslySetInnerHTML={{ __html: result.snippet }}
|
| 187 |
+
/>
|
| 188 |
+
)}
|
| 189 |
+
</div>
|
| 190 |
+
))}
|
| 191 |
+
</div>
|
| 192 |
+
)}
|
| 193 |
+
{isOpen && query.length >= 2 && results.length === 0 && !isLoading && (
|
| 194 |
+
<div className="search-results">
|
| 195 |
+
<div className="search-no-results">No results found</div>
|
| 196 |
+
</div>
|
| 197 |
+
)}
|
| 198 |
+
</div>
|
| 199 |
+
);
|
| 200 |
+
}
|
| 201 |
+
|
frontend/src/components/{FileTree.css → modals/FileTree.css}
RENAMED
|
@@ -3,8 +3,11 @@
|
|
| 3 |
border: 1px solid #e0e0e0;
|
| 4 |
border-radius: 4px;
|
| 5 |
background: #fafafa;
|
| 6 |
-
max-height:
|
| 7 |
overflow-y: auto;
|
|
|
|
|
|
|
|
|
|
| 8 |
}
|
| 9 |
|
| 10 |
.file-tree-header {
|
|
@@ -16,6 +19,19 @@
|
|
| 16 |
border-bottom: 1px solid #e0e0e0;
|
| 17 |
font-size: 0.9rem;
|
| 18 |
font-weight: 600;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
}
|
| 20 |
|
| 21 |
.file-tree-link {
|
|
@@ -23,16 +39,110 @@
|
|
| 23 |
text-decoration: none;
|
| 24 |
font-size: 0.85rem;
|
| 25 |
font-weight: 400;
|
|
|
|
| 26 |
}
|
| 27 |
|
| 28 |
.file-tree-link:hover {
|
| 29 |
text-decoration: underline;
|
| 30 |
}
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
.file-tree {
|
| 33 |
padding: 0.5rem;
|
| 34 |
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
|
| 35 |
font-size: 0.85rem;
|
|
|
|
|
|
|
| 36 |
}
|
| 37 |
|
| 38 |
.file-tree-node {
|
|
@@ -43,9 +153,14 @@
|
|
| 43 |
display: flex;
|
| 44 |
align-items: center;
|
| 45 |
gap: 0.5rem;
|
| 46 |
-
padding: 0.
|
| 47 |
-
border-radius:
|
| 48 |
transition: background 0.15s;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
}
|
| 50 |
|
| 51 |
.file-tree-item.directory:hover {
|
|
@@ -56,6 +171,37 @@
|
|
| 56 |
background: #f0f0f0;
|
| 57 |
}
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
.file-icon {
|
| 60 |
font-size: 1rem;
|
| 61 |
width: 1.25rem;
|
|
@@ -83,6 +229,9 @@
|
|
| 83 |
|
| 84 |
.file-tree-children {
|
| 85 |
margin-left: 0.5rem;
|
|
|
|
|
|
|
|
|
|
| 86 |
}
|
| 87 |
|
| 88 |
.file-tree-loading,
|
|
@@ -98,3 +247,22 @@
|
|
| 98 |
color: #d32f2f;
|
| 99 |
}
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
border: 1px solid #e0e0e0;
|
| 4 |
border-radius: 4px;
|
| 5 |
background: #fafafa;
|
| 6 |
+
max-height: 600px;
|
| 7 |
overflow-y: auto;
|
| 8 |
+
overflow-x: hidden;
|
| 9 |
+
display: flex;
|
| 10 |
+
flex-direction: column;
|
| 11 |
}
|
| 12 |
|
| 13 |
.file-tree-header {
|
|
|
|
| 19 |
border-bottom: 1px solid #e0e0e0;
|
| 20 |
font-size: 0.9rem;
|
| 21 |
font-weight: 600;
|
| 22 |
+
flex-shrink: 0;
|
| 23 |
+
position: sticky;
|
| 24 |
+
top: 0;
|
| 25 |
+
z-index: 10;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
.file-count-badge {
|
| 29 |
+
background: #e3f2fd;
|
| 30 |
+
color: #1976d2;
|
| 31 |
+
padding: 0.2rem 0.5rem;
|
| 32 |
+
border-radius: 12px;
|
| 33 |
+
font-size: 0.75rem;
|
| 34 |
+
font-weight: 500;
|
| 35 |
}
|
| 36 |
|
| 37 |
.file-tree-link {
|
|
|
|
| 39 |
text-decoration: none;
|
| 40 |
font-size: 0.85rem;
|
| 41 |
font-weight: 400;
|
| 42 |
+
white-space: nowrap;
|
| 43 |
}
|
| 44 |
|
| 45 |
.file-tree-link:hover {
|
| 46 |
text-decoration: underline;
|
| 47 |
}
|
| 48 |
|
| 49 |
+
.file-tree-button {
|
| 50 |
+
background: #f0f0f0;
|
| 51 |
+
border: 1px solid #d0d0d0;
|
| 52 |
+
border-radius: 3px;
|
| 53 |
+
padding: 0.25rem 0.5rem;
|
| 54 |
+
font-size: 0.75rem;
|
| 55 |
+
cursor: pointer;
|
| 56 |
+
color: #333;
|
| 57 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 58 |
+
transition: background 0.15s;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
.file-tree-button:hover {
|
| 62 |
+
background: #e0e0e0;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
.file-tree-button:active {
|
| 66 |
+
background: #d0d0d0;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
.file-tree-filters {
|
| 70 |
+
padding: 0.75rem 1rem;
|
| 71 |
+
background: #ffffff;
|
| 72 |
+
border-bottom: 1px solid #e0e0e0;
|
| 73 |
+
display: flex;
|
| 74 |
+
gap: 0.75rem;
|
| 75 |
+
flex-shrink: 0;
|
| 76 |
+
position: sticky;
|
| 77 |
+
top: 48px;
|
| 78 |
+
z-index: 9;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
.file-tree-search {
|
| 82 |
+
flex: 1;
|
| 83 |
+
position: relative;
|
| 84 |
+
display: flex;
|
| 85 |
+
align-items: center;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
.file-tree-search-input {
|
| 89 |
+
width: 100%;
|
| 90 |
+
padding: 0.5rem 2rem 0.5rem 0.75rem;
|
| 91 |
+
border: 1px solid #d0d0d0;
|
| 92 |
+
border-radius: 4px;
|
| 93 |
+
font-size: 0.85rem;
|
| 94 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
.file-tree-search-input:focus {
|
| 98 |
+
outline: none;
|
| 99 |
+
border-color: #4a90e2;
|
| 100 |
+
box-shadow: 0 0 0 2px rgba(74, 144, 226, 0.1);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
.file-tree-clear {
|
| 104 |
+
position: absolute;
|
| 105 |
+
right: 0.5rem;
|
| 106 |
+
background: none;
|
| 107 |
+
border: none;
|
| 108 |
+
cursor: pointer;
|
| 109 |
+
color: #666;
|
| 110 |
+
font-size: 1rem;
|
| 111 |
+
padding: 0.25rem;
|
| 112 |
+
display: flex;
|
| 113 |
+
align-items: center;
|
| 114 |
+
justify-content: center;
|
| 115 |
+
border-radius: 2px;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
.file-tree-clear:hover {
|
| 119 |
+
background: #f0f0f0;
|
| 120 |
+
color: #1a1a1a;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
.file-tree-type-filter {
|
| 124 |
+
padding: 0.5rem 0.75rem;
|
| 125 |
+
border: 1px solid #d0d0d0;
|
| 126 |
+
border-radius: 4px;
|
| 127 |
+
font-size: 0.85rem;
|
| 128 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 129 |
+
background: white;
|
| 130 |
+
cursor: pointer;
|
| 131 |
+
min-width: 150px;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
.file-tree-type-filter:focus {
|
| 135 |
+
outline: none;
|
| 136 |
+
border-color: #4a90e2;
|
| 137 |
+
box-shadow: 0 0 0 2px rgba(74, 144, 226, 0.1);
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
.file-tree {
|
| 141 |
padding: 0.5rem;
|
| 142 |
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
|
| 143 |
font-size: 0.85rem;
|
| 144 |
+
flex: 1;
|
| 145 |
+
overflow-y: auto;
|
| 146 |
}
|
| 147 |
|
| 148 |
.file-tree-node {
|
|
|
|
| 153 |
display: flex;
|
| 154 |
align-items: center;
|
| 155 |
gap: 0.5rem;
|
| 156 |
+
padding: 0.375rem 0.5rem;
|
| 157 |
+
border-radius: 3px;
|
| 158 |
transition: background 0.15s;
|
| 159 |
+
user-select: none;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
.file-tree-item.directory {
|
| 163 |
+
cursor: pointer;
|
| 164 |
}
|
| 165 |
|
| 166 |
.file-tree-item.directory:hover {
|
|
|
|
| 171 |
background: #f0f0f0;
|
| 172 |
}
|
| 173 |
|
| 174 |
+
.file-actions {
|
| 175 |
+
display: flex;
|
| 176 |
+
gap: 0.25rem;
|
| 177 |
+
margin-left: auto;
|
| 178 |
+
opacity: 0;
|
| 179 |
+
transition: opacity 0.2s;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
.file-tree-item:hover .file-actions {
|
| 183 |
+
opacity: 1;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
.file-action-btn {
|
| 187 |
+
background: none;
|
| 188 |
+
border: none;
|
| 189 |
+
cursor: pointer;
|
| 190 |
+
font-size: 0.9rem;
|
| 191 |
+
padding: 0.25rem;
|
| 192 |
+
border-radius: 2px;
|
| 193 |
+
display: flex;
|
| 194 |
+
align-items: center;
|
| 195 |
+
justify-content: center;
|
| 196 |
+
transition: background 0.15s;
|
| 197 |
+
text-decoration: none;
|
| 198 |
+
color: inherit;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
.file-action-btn:hover {
|
| 202 |
+
background: rgba(0, 0, 0, 0.1);
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
.file-icon {
|
| 206 |
font-size: 1rem;
|
| 207 |
width: 1.25rem;
|
|
|
|
| 229 |
|
| 230 |
.file-tree-children {
|
| 231 |
margin-left: 0.5rem;
|
| 232 |
+
border-left: 1px solid #e8e8e8;
|
| 233 |
+
padding-left: 0.5rem;
|
| 234 |
+
margin-top: 0.125rem;
|
| 235 |
}
|
| 236 |
|
| 237 |
.file-tree-loading,
|
|
|
|
| 247 |
color: #d32f2f;
|
| 248 |
}
|
| 249 |
|
| 250 |
+
/* Scrollbar styling */
|
| 251 |
+
.file-tree-container::-webkit-scrollbar {
|
| 252 |
+
width: 8px;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
.file-tree-container::-webkit-scrollbar-track {
|
| 256 |
+
background: #f1f1f1;
|
| 257 |
+
border-radius: 4px;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
.file-tree-container::-webkit-scrollbar-thumb {
|
| 261 |
+
background: #c1c1c1;
|
| 262 |
+
border-radius: 4px;
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
.file-tree-container::-webkit-scrollbar-thumb:hover {
|
| 266 |
+
background: #a8a8a8;
|
| 267 |
+
}
|
| 268 |
+
|
frontend/src/components/{FileTree.tsx → modals/FileTree.tsx}
RENAMED
|
@@ -2,9 +2,12 @@
|
|
| 2 |
* File tree component for displaying model file structure.
|
| 3 |
* Fetches and displays files from Hugging Face model repository.
|
| 4 |
*/
|
| 5 |
-
import React, { useState, useEffect } from 'react';
|
|
|
|
| 6 |
import './FileTree.css';
|
| 7 |
|
|
|
|
|
|
|
| 8 |
interface FileNode {
|
| 9 |
path: string;
|
| 10 |
type: 'file' | 'directory';
|
|
@@ -21,50 +24,84 @@ export default function FileTree({ modelId }: FileTreeProps) {
|
|
| 21 |
const [loading, setLoading] = useState(true);
|
| 22 |
const [error, setError] = useState<string | null>(null);
|
| 23 |
const [expandedPaths, setExpandedPaths] = useState<Set<string>>(new Set());
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
useEffect(() => {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
const fetchFiles = async () => {
|
| 27 |
setLoading(true);
|
| 28 |
setError(null);
|
| 29 |
try {
|
| 30 |
-
// Fetch file tree through our backend API (avoids CORS issues)
|
| 31 |
-
// Use same API base as main app
|
| 32 |
-
const apiBase = (window as any).__API_BASE__ || process.env.REACT_APP_API_URL || 'http://localhost:8000';
|
| 33 |
const response = await fetch(
|
| 34 |
-
`${
|
| 35 |
);
|
| 36 |
|
| 37 |
-
if (
|
| 38 |
throw new Error('File tree not available for this model');
|
| 39 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
const data = await response.json();
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
// Convert flat list to tree structure
|
| 44 |
const tree = buildFileTree(data);
|
| 45 |
setFiles(tree);
|
| 46 |
} catch (err: any) {
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
} finally {
|
| 50 |
setLoading(false);
|
| 51 |
}
|
| 52 |
};
|
| 53 |
|
| 54 |
-
|
| 55 |
-
fetchFiles();
|
| 56 |
-
}
|
| 57 |
}, [modelId]);
|
| 58 |
|
| 59 |
const buildFileTree = (fileList: any[]): FileNode[] => {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
const tree: FileNode[] = [];
|
| 61 |
const pathMap = new Map<string, FileNode>();
|
| 62 |
|
| 63 |
-
// Sort files by path
|
| 64 |
-
const sortedFiles = [...fileList].sort((a, b) =>
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
for (const file of sortedFiles) {
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
let currentPath = '';
|
| 69 |
let parent: FileNode | null = null;
|
| 70 |
|
|
@@ -77,7 +114,7 @@ export default function FileTree({ modelId }: FileTreeProps) {
|
|
| 77 |
const node: FileNode = {
|
| 78 |
path: currentPath,
|
| 79 |
type: isDirectory ? 'directory' : 'file',
|
| 80 |
-
size: file.size,
|
| 81 |
children: isDirectory ? [] : undefined,
|
| 82 |
};
|
| 83 |
|
|
@@ -111,6 +148,26 @@ export default function FileTree({ modelId }: FileTreeProps) {
|
|
| 111 |
});
|
| 112 |
};
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
const formatFileSize = (bytes?: number): string => {
|
| 115 |
if (!bytes) return '';
|
| 116 |
if (bytes < 1024) return `${bytes} B`;
|
|
@@ -119,6 +176,109 @@ export default function FileTree({ modelId }: FileTreeProps) {
|
|
| 119 |
return `${(bytes / (1024 * 1024 * 1024)).toFixed(1)} GB`;
|
| 120 |
};
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
const getFileIcon = (node: FileNode): string => {
|
| 123 |
if (node.type === 'directory') {
|
| 124 |
return expandedPaths.has(node.path) ? 'π' : 'π';
|
|
@@ -141,9 +301,28 @@ export default function FileTree({ modelId }: FileTreeProps) {
|
|
| 141 |
return iconMap[ext || ''] || 'π';
|
| 142 |
};
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
const renderNode = (node: FileNode, depth: number = 0): React.ReactNode => {
|
| 145 |
const isExpanded = expandedPaths.has(node.path);
|
| 146 |
const hasChildren = node.children && node.children.length > 0;
|
|
|
|
| 147 |
|
| 148 |
return (
|
| 149 |
<div key={node.path} className="file-tree-node" style={{ paddingLeft: `${depth * 1.5}rem` }}>
|
|
@@ -153,13 +332,37 @@ export default function FileTree({ modelId }: FileTreeProps) {
|
|
| 153 |
style={{ cursor: node.type === 'directory' ? 'pointer' : 'default' }}
|
| 154 |
>
|
| 155 |
<span className="file-icon">{getFileIcon(node)}</span>
|
| 156 |
-
<span className="file-name"
|
| 157 |
{node.type === 'file' && node.size && (
|
| 158 |
<span className="file-size">{formatFileSize(node.size)}</span>
|
| 159 |
)}
|
| 160 |
{node.type === 'directory' && (
|
| 161 |
<span className="file-expand">{isExpanded ? '▼' : '▶'}</span>
|
| 162 |
)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
</div>
|
| 164 |
{isExpanded && hasChildren && (
|
| 165 |
<div className="file-tree-children">
|
|
@@ -199,21 +402,106 @@ export default function FileTree({ modelId }: FileTreeProps) {
|
|
| 199 |
);
|
| 200 |
}
|
| 201 |
|
|
|
|
|
|
|
| 202 |
return (
|
| 203 |
<div className="file-tree-container">
|
| 204 |
<div className="file-tree-header">
|
| 205 |
-
<
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
<div className="file-tree">
|
| 216 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
</div>
|
| 218 |
</div>
|
| 219 |
);
|
|
|
|
| 2 |
* File tree component for displaying model file structure.
|
| 3 |
* Fetches and displays files from Hugging Face model repository.
|
| 4 |
*/
|
| 5 |
+
import React, { useState, useEffect, useMemo } from 'react';
|
| 6 |
+
import { getHuggingFaceFileTreeUrl } from '../../utils/api/hfUrl';
|
| 7 |
import './FileTree.css';
|
| 8 |
|
| 9 |
+
import { API_BASE } from '../../config/api';
|
| 10 |
+
|
| 11 |
interface FileNode {
|
| 12 |
path: string;
|
| 13 |
type: 'file' | 'directory';
|
|
|
|
| 24 |
const [loading, setLoading] = useState(true);
|
| 25 |
const [error, setError] = useState<string | null>(null);
|
| 26 |
const [expandedPaths, setExpandedPaths] = useState<Set<string>>(new Set());
|
| 27 |
+
const [searchQuery, setSearchQuery] = useState('');
|
| 28 |
+
const [fileTypeFilter, setFileTypeFilter] = useState<string>('all');
|
| 29 |
+
const [showSearch, setShowSearch] = useState(false);
|
| 30 |
+
const searchInputRef = React.useRef<HTMLInputElement>(null);
|
| 31 |
|
| 32 |
useEffect(() => {
|
| 33 |
+
if (!modelId) {
|
| 34 |
+
setLoading(false);
|
| 35 |
+
setError('No model ID provided');
|
| 36 |
+
return;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
const fetchFiles = async () => {
|
| 40 |
setLoading(true);
|
| 41 |
setError(null);
|
| 42 |
try {
|
|
|
|
|
|
|
|
|
|
| 43 |
const response = await fetch(
|
| 44 |
+
`${API_BASE}/api/model/${encodeURIComponent(modelId)}/files?branch=main`
|
| 45 |
);
|
| 46 |
|
| 47 |
+
if (response.status === 404) {
|
| 48 |
throw new Error('File tree not available for this model');
|
| 49 |
}
|
| 50 |
+
|
| 51 |
+
if (response.status === 503) {
|
| 52 |
+
throw new Error('Backend service unavailable');
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
if (!response.ok) {
|
| 56 |
+
const errorText = await response.text();
|
| 57 |
+
throw new Error(`Failed to load file tree: ${response.status} ${errorText}`);
|
| 58 |
+
}
|
| 59 |
|
| 60 |
const data = await response.json();
|
| 61 |
|
| 62 |
+
if (!Array.isArray(data)) {
|
| 63 |
+
throw new Error('Invalid response format');
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
// Convert flat list to tree structure
|
| 67 |
const tree = buildFileTree(data);
|
| 68 |
setFiles(tree);
|
| 69 |
} catch (err: any) {
|
| 70 |
+
const errorMessage = err instanceof Error ? err.message : 'Failed to load files';
|
| 71 |
+
setError(errorMessage);
|
| 72 |
+
// Only log in development
|
| 73 |
+
if (process.env.NODE_ENV === 'development') {
|
| 74 |
+
console.error('Error fetching file tree:', err);
|
| 75 |
+
}
|
| 76 |
} finally {
|
| 77 |
setLoading(false);
|
| 78 |
}
|
| 79 |
};
|
| 80 |
|
| 81 |
+
fetchFiles();
|
|
|
|
|
|
|
| 82 |
}, [modelId]);
|
| 83 |
|
| 84 |
const buildFileTree = (fileList: any[]): FileNode[] => {
|
| 85 |
+
if (!Array.isArray(fileList) || fileList.length === 0) {
|
| 86 |
+
return [];
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
const tree: FileNode[] = [];
|
| 90 |
const pathMap = new Map<string, FileNode>();
|
| 91 |
|
| 92 |
+
// Sort files by path for consistent ordering
|
| 93 |
+
const sortedFiles = [...fileList].sort((a, b) => {
|
| 94 |
+
const pathA = a.path || '';
|
| 95 |
+
const pathB = b.path || '';
|
| 96 |
+
return pathA.localeCompare(pathB);
|
| 97 |
+
});
|
| 98 |
|
| 99 |
for (const file of sortedFiles) {
|
| 100 |
+
if (!file.path) continue;
|
| 101 |
+
|
| 102 |
+
const parts = file.path.split('/').filter((p: string) => p.length > 0);
|
| 103 |
+
if (parts.length === 0) continue;
|
| 104 |
+
|
| 105 |
let currentPath = '';
|
| 106 |
let parent: FileNode | null = null;
|
| 107 |
|
|
|
|
| 114 |
const node: FileNode = {
|
| 115 |
path: currentPath,
|
| 116 |
type: isDirectory ? 'directory' : 'file',
|
| 117 |
+
size: isDirectory ? undefined : (file.size || undefined), // Only set size for files
|
| 118 |
children: isDirectory ? [] : undefined,
|
| 119 |
};
|
| 120 |
|
|
|
|
| 148 |
});
|
| 149 |
};
|
| 150 |
|
| 151 |
+
const expandAll = () => {
|
| 152 |
+
const allPaths = new Set<string>();
|
| 153 |
+
const collectPaths = (nodes: FileNode[]) => {
|
| 154 |
+
nodes.forEach(node => {
|
| 155 |
+
if (node.type === 'directory' && node.children) {
|
| 156 |
+
allPaths.add(node.path);
|
| 157 |
+
if (node.children.length > 0) {
|
| 158 |
+
collectPaths(node.children);
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
});
|
| 162 |
+
};
|
| 163 |
+
collectPaths(files);
|
| 164 |
+
setExpandedPaths(allPaths);
|
| 165 |
+
};
|
| 166 |
+
|
| 167 |
+
const collapseAll = () => {
|
| 168 |
+
setExpandedPaths(new Set());
|
| 169 |
+
};
|
| 170 |
+
|
| 171 |
const formatFileSize = (bytes?: number): string => {
|
| 172 |
if (!bytes) return '';
|
| 173 |
if (bytes < 1024) return `${bytes} B`;
|
|
|
|
| 176 |
return `${(bytes / (1024 * 1024 * 1024)).toFixed(1)} GB`;
|
| 177 |
};
|
| 178 |
|
| 179 |
+
// Get all file extensions from the tree
|
| 180 |
+
const getAllFileExtensions = useMemo(() => {
|
| 181 |
+
const extensions = new Set<string>();
|
| 182 |
+
const collectExtensions = (nodes: FileNode[]) => {
|
| 183 |
+
nodes.forEach(node => {
|
| 184 |
+
if (node.type === 'file') {
|
| 185 |
+
const ext = node.path.split('.').pop()?.toLowerCase();
|
| 186 |
+
if (ext) extensions.add(ext);
|
| 187 |
+
}
|
| 188 |
+
if (node.children) {
|
| 189 |
+
collectExtensions(node.children);
|
| 190 |
+
}
|
| 191 |
+
});
|
| 192 |
+
};
|
| 193 |
+
collectExtensions(files);
|
| 194 |
+
return Array.from(extensions).sort();
|
| 195 |
+
}, [files]);
|
| 196 |
+
|
| 197 |
+
// Auto-expand directories when searching
|
| 198 |
+
useEffect(() => {
|
| 199 |
+
if (searchQuery) {
|
| 200 |
+
const pathsToExpand = new Set<string>();
|
| 201 |
+
const findMatchingPaths = (nodes: FileNode[], query: string) => {
|
| 202 |
+
nodes.forEach(node => {
|
| 203 |
+
if (node.path.toLowerCase().includes(query.toLowerCase())) {
|
| 204 |
+
// Expand all parent directories
|
| 205 |
+
const parts = node.path.split('/');
|
| 206 |
+
let currentPath = '';
|
| 207 |
+
for (let i = 0; i < parts.length - 1; i++) {
|
| 208 |
+
currentPath = currentPath ? `${currentPath}/${parts[i]}` : parts[i];
|
| 209 |
+
pathsToExpand.add(currentPath);
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
if (node.children) {
|
| 213 |
+
findMatchingPaths(node.children, query);
|
| 214 |
+
}
|
| 215 |
+
});
|
| 216 |
+
};
|
| 217 |
+
findMatchingPaths(files, searchQuery);
|
| 218 |
+
setExpandedPaths(pathsToExpand);
|
| 219 |
+
}
|
| 220 |
+
}, [searchQuery, files]);
|
| 221 |
+
|
| 222 |
+
// Filter files based on search and file type
|
| 223 |
+
const filterNodes = (nodes: FileNode[]): FileNode[] => {
|
| 224 |
+
return nodes
|
| 225 |
+
.map(node => {
|
| 226 |
+
const matchesSearch = !searchQuery ||
|
| 227 |
+
node.path.toLowerCase().includes(searchQuery.toLowerCase());
|
| 228 |
+
const matchesType = fileTypeFilter === 'all' ||
|
| 229 |
+
(node.type === 'file' && node.path.toLowerCase().endsWith(`.${fileTypeFilter}`)) ||
|
| 230 |
+
(node.type === 'directory');
|
| 231 |
+
|
| 232 |
+
if (!matchesSearch || !matchesType) {
|
| 233 |
+
return null;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
const filteredChildren = node.children ? filterNodes(node.children) : undefined;
|
| 237 |
+
const result: FileNode | null = filteredChildren && filteredChildren.length > 0
|
| 238 |
+
? { ...node, children: filteredChildren }
|
| 239 |
+
: filteredChildren === undefined && matchesSearch && matchesType
|
| 240 |
+
? { ...node }
|
| 241 |
+
: null;
|
| 242 |
+
return result;
|
| 243 |
+
})
|
| 244 |
+
.filter((node): node is FileNode => node !== null);
|
| 245 |
+
};
|
| 246 |
+
|
| 247 |
+
const filteredFiles = useMemo(() => {
|
| 248 |
+
if (!searchQuery && fileTypeFilter === 'all') return files;
|
| 249 |
+
return filterNodes(files);
|
| 250 |
+
}, [files, searchQuery, fileTypeFilter]);
|
| 251 |
+
|
| 252 |
+
// Count total files
|
| 253 |
+
const countFiles = (nodes: FileNode[]): number => {
|
| 254 |
+
let count = 0;
|
| 255 |
+
nodes.forEach(node => {
|
| 256 |
+
if (node.type === 'file') count++;
|
| 257 |
+
if (node.children) count += countFiles(node.children);
|
| 258 |
+
});
|
| 259 |
+
return count;
|
| 260 |
+
};
|
| 261 |
+
|
| 262 |
+
const totalFileCount = useMemo(() => countFiles(files), [files]);
|
| 263 |
+
const visibleFileCount = useMemo(() => countFiles(filteredFiles), [filteredFiles]);
|
| 264 |
+
|
| 265 |
+
// Keyboard shortcut for search (Cmd+K / Ctrl+K)
|
| 266 |
+
useEffect(() => {
|
| 267 |
+
const handleKeyDown = (e: KeyboardEvent) => {
|
| 268 |
+
if ((e.metaKey || e.ctrlKey) && e.key === 'k') {
|
| 269 |
+
e.preventDefault();
|
| 270 |
+
setShowSearch(true);
|
| 271 |
+
setTimeout(() => searchInputRef.current?.focus(), 0);
|
| 272 |
+
}
|
| 273 |
+
if (e.key === 'Escape' && showSearch) {
|
| 274 |
+
setShowSearch(false);
|
| 275 |
+
setSearchQuery('');
|
| 276 |
+
}
|
| 277 |
+
};
|
| 278 |
+
window.addEventListener('keydown', handleKeyDown);
|
| 279 |
+
return () => window.removeEventListener('keydown', handleKeyDown);
|
| 280 |
+
}, [showSearch]);
|
| 281 |
+
|
| 282 |
const getFileIcon = (node: FileNode): string => {
|
| 283 |
if (node.type === 'directory') {
|
| 284 |
return expandedPaths.has(node.path) ? 'π' : 'π';
|
|
|
|
| 301 |
return iconMap[ext || ''] || 'π';
|
| 302 |
};
|
| 303 |
|
| 304 |
+
const copyFilePath = (path: string) => {
|
| 305 |
+
navigator.clipboard.writeText(path).then(() => {
|
| 306 |
+
// Show temporary feedback
|
| 307 |
+
const button = document.querySelector(`[data-file-path="${path}"]`) as HTMLElement;
|
| 308 |
+
if (button) {
|
| 309 |
+
const originalText = button.textContent;
|
| 310 |
+
button.textContent = 'Copied!';
|
| 311 |
+
setTimeout(() => {
|
| 312 |
+
if (button) button.textContent = originalText;
|
| 313 |
+
}, 1000);
|
| 314 |
+
}
|
| 315 |
+
});
|
| 316 |
+
};
|
| 317 |
+
|
| 318 |
+
const getFileUrl = (path: string) => {
|
| 319 |
+
return `https://huggingface.co/${modelId}/resolve/main/${path}`;
|
| 320 |
+
};
|
| 321 |
+
|
| 322 |
const renderNode = (node: FileNode, depth: number = 0): React.ReactNode => {
|
| 323 |
const isExpanded = expandedPaths.has(node.path);
|
| 324 |
const hasChildren = node.children && node.children.length > 0;
|
| 325 |
+
const fileName = node.path.split('/').pop() || node.path;
|
| 326 |
|
| 327 |
return (
|
| 328 |
<div key={node.path} className="file-tree-node" style={{ paddingLeft: `${depth * 1.5}rem` }}>
|
|
|
|
| 332 |
style={{ cursor: node.type === 'directory' ? 'pointer' : 'default' }}
|
| 333 |
>
|
| 334 |
<span className="file-icon">{getFileIcon(node)}</span>
|
| 335 |
+
<span className="file-name" title={node.path}>{fileName}</span>
|
| 336 |
{node.type === 'file' && node.size && (
|
| 337 |
<span className="file-size">{formatFileSize(node.size)}</span>
|
| 338 |
)}
|
| 339 |
{node.type === 'directory' && (
|
| 340 |
<span className="file-expand">{isExpanded ? '▼' : '▶'}</span>
|
| 341 |
)}
|
| 342 |
+
{node.type === 'file' && (
|
| 343 |
+
<div className="file-actions" onClick={(e) => e.stopPropagation()}>
|
| 344 |
+
<button
|
| 345 |
+
className="file-action-btn"
|
| 346 |
+
onClick={() => copyFilePath(node.path)}
|
| 347 |
+
data-file-path={node.path}
|
| 348 |
+
title="Copy file path"
|
| 349 |
+
aria-label="Copy path"
|
| 350 |
+
>
|
| 351 |
+
π
|
| 352 |
+
</button>
|
| 353 |
+
<a
|
| 354 |
+
href={getFileUrl(node.path)}
|
| 355 |
+
target="_blank"
|
| 356 |
+
rel="noopener noreferrer"
|
| 357 |
+
className="file-action-btn"
|
| 358 |
+
title="Download file"
|
| 359 |
+
aria-label="Download"
|
| 360 |
+
onClick={(e) => e.stopPropagation()}
|
| 361 |
+
>
|
| 362 |
+
⬇️
|
| 363 |
+
</a>
|
| 364 |
+
</div>
|
| 365 |
+
)}
|
| 366 |
</div>
|
| 367 |
{isExpanded && hasChildren && (
|
| 368 |
<div className="file-tree-children">
|
|
|
|
| 402 |
);
|
| 403 |
}
|
| 404 |
|
| 405 |
+
const hasDirectories = files.some(node => node.type === 'directory');
|
| 406 |
+
|
| 407 |
return (
|
| 408 |
<div className="file-tree-container">
|
| 409 |
<div className="file-tree-header">
|
| 410 |
+
<div style={{ display: 'flex', alignItems: 'center', gap: '0.5rem' }}>
|
| 411 |
+
<strong>Repository Files</strong>
|
| 412 |
+
<span className="file-count-badge">
|
| 413 |
+
{visibleFileCount === totalFileCount
|
| 414 |
+
? `${totalFileCount} file${totalFileCount !== 1 ? 's' : ''}`
|
| 415 |
+
: `${visibleFileCount} of ${totalFileCount} files`}
|
| 416 |
+
</span>
|
| 417 |
+
</div>
|
| 418 |
+
<div style={{ display: 'flex', gap: '0.5rem', alignItems: 'center', flexWrap: 'wrap' }}>
|
| 419 |
+
<button
|
| 420 |
+
onClick={() => setShowSearch(!showSearch)}
|
| 421 |
+
className="file-tree-button"
|
| 422 |
+
title="Search files (Cmd+K)"
|
| 423 |
+
aria-label="Search"
|
| 424 |
+
>
|
| 425 |
+
π Search
|
| 426 |
+
</button>
|
| 427 |
+
{hasDirectories && (
|
| 428 |
+
<>
|
| 429 |
+
<button
|
| 430 |
+
onClick={expandAll}
|
| 431 |
+
className="file-tree-button"
|
| 432 |
+
title="Expand all directories"
|
| 433 |
+
aria-label="Expand all"
|
| 434 |
+
>
|
| 435 |
+
Expand All
|
| 436 |
+
</button>
|
| 437 |
+
<button
|
| 438 |
+
onClick={collapseAll}
|
| 439 |
+
className="file-tree-button"
|
| 440 |
+
title="Collapse all directories"
|
| 441 |
+
aria-label="Collapse all"
|
| 442 |
+
>
|
| 443 |
+
Collapse All
|
| 444 |
+
</button>
|
| 445 |
+
</>
|
| 446 |
+
)}
|
| 447 |
+
<a
|
| 448 |
+
href={getHuggingFaceFileTreeUrl(modelId, 'main')}
|
| 449 |
+
target="_blank"
|
| 450 |
+
rel="noopener noreferrer"
|
| 451 |
+
className="file-tree-link"
|
| 452 |
+
>
|
| 453 |
+
View on HF β
|
| 454 |
+
</a>
|
| 455 |
+
</div>
|
| 456 |
</div>
|
| 457 |
+
|
| 458 |
+
{/* Search and Filter Bar */}
|
| 459 |
+
{(showSearch || searchQuery || fileTypeFilter !== 'all') && (
|
| 460 |
+
<div className="file-tree-filters">
|
| 461 |
+
<div className="file-tree-search">
|
| 462 |
+
<input
|
| 463 |
+
ref={searchInputRef}
|
| 464 |
+
type="text"
|
| 465 |
+
placeholder="Search files... (Cmd+K)"
|
| 466 |
+
value={searchQuery}
|
| 467 |
+
onChange={(e) => setSearchQuery(e.target.value)}
|
| 468 |
+
className="file-tree-search-input"
|
| 469 |
+
/>
|
| 470 |
+
{searchQuery && (
|
| 471 |
+
<button
|
| 472 |
+
onClick={() => setSearchQuery('')}
|
| 473 |
+
className="file-tree-clear"
|
| 474 |
+
aria-label="Clear search"
|
| 475 |
+
>
|
| 476 |
+
β
|
| 477 |
+
</button>
|
| 478 |
+
)}
|
| 479 |
+
</div>
|
| 480 |
+
{getAllFileExtensions.length > 0 && (
|
| 481 |
+
<select
|
| 482 |
+
value={fileTypeFilter}
|
| 483 |
+
onChange={(e) => setFileTypeFilter(e.target.value)}
|
| 484 |
+
className="file-tree-type-filter"
|
| 485 |
+
>
|
| 486 |
+
<option value="all">All file types</option>
|
| 487 |
+
{getAllFileExtensions.map(ext => (
|
| 488 |
+
<option key={ext} value={ext}>.{ext}</option>
|
| 489 |
+
))}
|
| 490 |
+
</select>
|
| 491 |
+
)}
|
| 492 |
+
</div>
|
| 493 |
+
)}
|
| 494 |
+
|
| 495 |
<div className="file-tree">
|
| 496 |
+
{filteredFiles.length === 0 ? (
|
| 497 |
+
<div className="file-tree-empty">
|
| 498 |
+
{searchQuery || fileTypeFilter !== 'all'
|
| 499 |
+
? 'No files match your filters'
|
| 500 |
+
: 'No files found'}
|
| 501 |
+
</div>
|
| 502 |
+
) : (
|
| 503 |
+
filteredFiles.map((node) => renderNode(node))
|
| 504 |
+
)}
|
| 505 |
</div>
|
| 506 |
</div>
|
| 507 |
);
|
frontend/src/components/{ModelModal.css → modals/ModelModal.css}
RENAMED
|
@@ -25,7 +25,7 @@
|
|
| 25 |
.modal-content {
|
| 26 |
background: #ffffff;
|
| 27 |
border-radius: 8px;
|
| 28 |
-
max-width:
|
| 29 |
width: 100%;
|
| 30 |
max-height: 90vh;
|
| 31 |
overflow-y: auto;
|
|
@@ -34,11 +34,16 @@
|
|
| 34 |
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
|
| 35 |
border: 1px solid #d0d0d0;
|
| 36 |
animation: slideUp 0.3s ease-out;
|
| 37 |
-
font-family: '
|
| 38 |
display: flex;
|
| 39 |
flex-direction: column;
|
| 40 |
}
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
@keyframes slideUp {
|
| 43 |
from {
|
| 44 |
transform: translateY(20px);
|
|
@@ -66,7 +71,7 @@
|
|
| 66 |
justify-content: center;
|
| 67 |
border-radius: 2px;
|
| 68 |
transition: all 0.2s;
|
| 69 |
-
font-family: '
|
| 70 |
}
|
| 71 |
|
| 72 |
.modal-close:hover {
|
|
@@ -89,7 +94,7 @@
|
|
| 89 |
font-size: 1.5rem;
|
| 90 |
color: #1a1a1a;
|
| 91 |
word-break: break-word;
|
| 92 |
-
font-family: '
|
| 93 |
font-weight: 600;
|
| 94 |
line-height: 1.3;
|
| 95 |
}
|
|
@@ -117,7 +122,7 @@
|
|
| 117 |
border-radius: 4px;
|
| 118 |
cursor: pointer;
|
| 119 |
font-size: 0.85rem;
|
| 120 |
-
font-family: '
|
| 121 |
transition: all 0.2s;
|
| 122 |
font-weight: 500;
|
| 123 |
}
|
|
@@ -136,6 +141,12 @@
|
|
| 136 |
gap: 0.5rem;
|
| 137 |
margin-bottom: 1.5rem;
|
| 138 |
border-bottom: 2px solid #e0e0e0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
}
|
| 140 |
|
| 141 |
.modal-tab {
|
|
@@ -145,11 +156,29 @@
|
|
| 145 |
border-bottom: 2px solid transparent;
|
| 146 |
cursor: pointer;
|
| 147 |
font-size: 0.9rem;
|
| 148 |
-
font-family: '
|
| 149 |
color: #666;
|
| 150 |
font-weight: 500;
|
| 151 |
margin-bottom: -2px;
|
| 152 |
transition: all 0.2s;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
}
|
| 154 |
|
| 155 |
.modal-tab:hover {
|
|
@@ -186,14 +215,14 @@
|
|
| 186 |
text-transform: uppercase;
|
| 187 |
letter-spacing: 0.5px;
|
| 188 |
font-weight: 600;
|
| 189 |
-
font-family: '
|
| 190 |
}
|
| 191 |
|
| 192 |
.info-value {
|
| 193 |
font-size: 1.1rem;
|
| 194 |
color: #1a1a1a;
|
| 195 |
font-weight: 500;
|
| 196 |
-
font-family: '
|
| 197 |
}
|
| 198 |
|
| 199 |
.info-value.highlight {
|
|
@@ -230,7 +259,7 @@
|
|
| 230 |
letter-spacing: 0.5px;
|
| 231 |
font-weight: 600;
|
| 232 |
margin-bottom: 0.75rem;
|
| 233 |
-
font-family: '
|
| 234 |
}
|
| 235 |
|
| 236 |
.section-content {
|
|
@@ -287,7 +316,7 @@
|
|
| 287 |
color: #4a4a4a;
|
| 288 |
text-transform: uppercase;
|
| 289 |
letter-spacing: 0.5px;
|
| 290 |
-
font-family: '
|
| 291 |
}
|
| 292 |
|
| 293 |
.modal-info-grid {
|
|
@@ -306,14 +335,14 @@
|
|
| 306 |
font-size: 0.875rem;
|
| 307 |
color: #6a6a6a;
|
| 308 |
font-weight: 500;
|
| 309 |
-
font-family: '
|
| 310 |
}
|
| 311 |
|
| 312 |
.modal-info-item span {
|
| 313 |
font-size: 1rem;
|
| 314 |
color: #1a1a1a;
|
| 315 |
font-weight: 500;
|
| 316 |
-
font-family: '
|
| 317 |
}
|
| 318 |
|
| 319 |
.modal-tags {
|
|
@@ -324,7 +353,7 @@
|
|
| 324 |
color: #1a1a1a;
|
| 325 |
font-size: 0.9rem;
|
| 326 |
line-height: 1.5;
|
| 327 |
-
font-family: '
|
| 328 |
}
|
| 329 |
|
| 330 |
.modal-footer {
|
|
@@ -345,7 +374,7 @@
|
|
| 345 |
text-decoration: none;
|
| 346 |
border-radius: 4px;
|
| 347 |
font-weight: 500;
|
| 348 |
-
font-family: '
|
| 349 |
transition: all 0.2s;
|
| 350 |
border: 1px solid #1a1a1a;
|
| 351 |
}
|
|
|
|
| 25 |
.modal-content {
|
| 26 |
background: #ffffff;
|
| 27 |
border-radius: 8px;
|
| 28 |
+
max-width: 900px;
|
| 29 |
width: 100%;
|
| 30 |
max-height: 90vh;
|
| 31 |
overflow-y: auto;
|
|
|
|
| 34 |
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
|
| 35 |
border: 1px solid #d0d0d0;
|
| 36 |
animation: slideUp 0.3s ease-out;
|
| 37 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 38 |
display: flex;
|
| 39 |
flex-direction: column;
|
| 40 |
}
|
| 41 |
|
| 42 |
+
.modal-content[data-tab="files"] {
|
| 43 |
+
max-width: 1000px;
|
| 44 |
+
max-height: 95vh;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
@keyframes slideUp {
|
| 48 |
from {
|
| 49 |
transform: translateY(20px);
|
|
|
|
| 71 |
justify-content: center;
|
| 72 |
border-radius: 2px;
|
| 73 |
transition: all 0.2s;
|
| 74 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 75 |
}
|
| 76 |
|
| 77 |
.modal-close:hover {
|
|
|
|
| 94 |
font-size: 1.5rem;
|
| 95 |
color: #1a1a1a;
|
| 96 |
word-break: break-word;
|
| 97 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 98 |
font-weight: 600;
|
| 99 |
line-height: 1.3;
|
| 100 |
}
|
|
|
|
| 122 |
border-radius: 4px;
|
| 123 |
cursor: pointer;
|
| 124 |
font-size: 0.85rem;
|
| 125 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 126 |
transition: all 0.2s;
|
| 127 |
font-weight: 500;
|
| 128 |
}
|
|
|
|
| 141 |
gap: 0.5rem;
|
| 142 |
margin-bottom: 1.5rem;
|
| 143 |
border-bottom: 2px solid #e0e0e0;
|
| 144 |
+
position: sticky;
|
| 145 |
+
top: 0;
|
| 146 |
+
background: #ffffff;
|
| 147 |
+
z-index: 10;
|
| 148 |
+
padding-top: 0.5rem;
|
| 149 |
+
margin-top: -0.5rem;
|
| 150 |
}
|
| 151 |
|
| 152 |
.modal-tab {
|
|
|
|
| 156 |
border-bottom: 2px solid transparent;
|
| 157 |
cursor: pointer;
|
| 158 |
font-size: 0.9rem;
|
| 159 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 160 |
color: #666;
|
| 161 |
font-weight: 500;
|
| 162 |
margin-bottom: -2px;
|
| 163 |
transition: all 0.2s;
|
| 164 |
+
display: flex;
|
| 165 |
+
align-items: center;
|
| 166 |
+
gap: 0.5rem;
|
| 167 |
+
position: relative;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
.tab-icon {
|
| 171 |
+
font-size: 1rem;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
.tab-badge {
|
| 175 |
+
background: #4a90e2;
|
| 176 |
+
color: white;
|
| 177 |
+
font-size: 0.7rem;
|
| 178 |
+
padding: 0.15rem 0.4rem;
|
| 179 |
+
border-radius: 10px;
|
| 180 |
+
font-weight: 600;
|
| 181 |
+
margin-left: 0.25rem;
|
| 182 |
}
|
| 183 |
|
| 184 |
.modal-tab:hover {
|
|
|
|
| 215 |
text-transform: uppercase;
|
| 216 |
letter-spacing: 0.5px;
|
| 217 |
font-weight: 600;
|
| 218 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 219 |
}
|
| 220 |
|
| 221 |
.info-value {
|
| 222 |
font-size: 1.1rem;
|
| 223 |
color: #1a1a1a;
|
| 224 |
font-weight: 500;
|
| 225 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 226 |
}
|
| 227 |
|
| 228 |
.info-value.highlight {
|
|
|
|
| 259 |
letter-spacing: 0.5px;
|
| 260 |
font-weight: 600;
|
| 261 |
margin-bottom: 0.75rem;
|
| 262 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 263 |
}
|
| 264 |
|
| 265 |
.section-content {
|
|
|
|
| 316 |
color: #4a4a4a;
|
| 317 |
text-transform: uppercase;
|
| 318 |
letter-spacing: 0.5px;
|
| 319 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 320 |
}
|
| 321 |
|
| 322 |
.modal-info-grid {
|
|
|
|
| 335 |
font-size: 0.875rem;
|
| 336 |
color: #6a6a6a;
|
| 337 |
font-weight: 500;
|
| 338 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 339 |
}
|
| 340 |
|
| 341 |
.modal-info-item span {
|
| 342 |
font-size: 1rem;
|
| 343 |
color: #1a1a1a;
|
| 344 |
font-weight: 500;
|
| 345 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 346 |
}
|
| 347 |
|
| 348 |
.modal-tags {
|
|
|
|
| 353 |
color: #1a1a1a;
|
| 354 |
font-size: 0.9rem;
|
| 355 |
line-height: 1.5;
|
| 356 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 357 |
}
|
| 358 |
|
| 359 |
.modal-footer {
|
|
|
|
| 374 |
text-decoration: none;
|
| 375 |
border-radius: 4px;
|
| 376 |
font-weight: 500;
|
| 377 |
+
font-family: 'Instrument Sans', sans-serif;
|
| 378 |
transition: all 0.2s;
|
| 379 |
border: 1px solid #1a1a1a;
|
| 380 |
}
|
frontend/src/components/{ModelModal.tsx → modals/ModelModal.tsx}
RENAMED
|
@@ -3,12 +3,12 @@
|
|
| 3 |
* Enhanced with bookmark, comparison, similar models, and file tree features.
|
| 4 |
*/
|
| 5 |
import React, { useState, useEffect } from 'react';
|
| 6 |
-
import { ModelPoint } from '
|
| 7 |
import FileTree from './FileTree';
|
|
|
|
|
|
|
| 8 |
import './ModelModal.css';
|
| 9 |
|
| 10 |
-
const API_BASE = process.env.REACT_APP_API_URL || 'http://localhost:8000';
|
| 11 |
-
|
| 12 |
interface ArxivPaper {
|
| 13 |
arxiv_id: string;
|
| 14 |
title: string;
|
|
@@ -71,7 +71,7 @@ export default function ModelModal({
|
|
| 71 |
|
| 72 |
if (!isOpen || !model) return null;
|
| 73 |
|
| 74 |
-
const hfUrl =
|
| 75 |
|
| 76 |
// Parse tags if it's a string representation of an array
|
| 77 |
const parseTags = (tags: string | null | undefined): string[] => {
|
|
@@ -156,7 +156,11 @@ export default function ModelModal({
|
|
| 156 |
|
| 157 |
return (
|
| 158 |
<div className="modal-overlay" onClick={onClose}>
|
| 159 |
-
<div
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
<div className="modal-header">
|
| 161 |
<h2>{model.model_id}</h2>
|
| 162 |
<button className="modal-close" onClick={onClose}>Close</button>
|
|
@@ -197,20 +201,24 @@ export default function ModelModal({
|
|
| 197 |
className={`modal-tab ${activeTab === 'details' ? 'active' : ''}`}
|
| 198 |
onClick={() => setActiveTab('details')}
|
| 199 |
>
|
| 200 |
-
|
|
|
|
| 201 |
</button>
|
| 202 |
<button
|
| 203 |
className={`modal-tab ${activeTab === 'files' ? 'active' : ''}`}
|
| 204 |
onClick={() => setActiveTab('files')}
|
| 205 |
>
|
| 206 |
-
|
|
|
|
| 207 |
</button>
|
| 208 |
{(papers.length > 0 || papersLoading) && (
|
| 209 |
<button
|
| 210 |
className={`modal-tab ${activeTab === 'papers' ? 'active' : ''}`}
|
| 211 |
onClick={() => setActiveTab('papers')}
|
| 212 |
>
|
| 213 |
-
|
|
|
|
|
|
|
| 214 |
</button>
|
| 215 |
)}
|
| 216 |
</div>
|
|
@@ -283,7 +291,7 @@ export default function ModelModal({
|
|
| 283 |
<div className="section-title">Parent Model</div>
|
| 284 |
<div className="section-content">
|
| 285 |
<a
|
| 286 |
-
href={
|
| 287 |
target="_blank"
|
| 288 |
rel="noopener noreferrer"
|
| 289 |
className="model-link"
|
|
|
|
| 3 |
* Enhanced with bookmark, comparison, similar models, and file tree features.
|
| 4 |
*/
|
| 5 |
import React, { useState, useEffect } from 'react';
|
| 6 |
+
import { ModelPoint } from '../../types';
|
| 7 |
import FileTree from './FileTree';
|
| 8 |
+
import { getHuggingFaceUrl } from '../../utils/api/hfUrl';
|
| 9 |
+
import { API_BASE } from '../../config/api';
|
| 10 |
import './ModelModal.css';
|
| 11 |
|
|
|
|
|
|
|
| 12 |
interface ArxivPaper {
|
| 13 |
arxiv_id: string;
|
| 14 |
title: string;
|
|
|
|
| 71 |
|
| 72 |
if (!isOpen || !model) return null;
|
| 73 |
|
| 74 |
+
const hfUrl = getHuggingFaceUrl(model.model_id);
|
| 75 |
|
| 76 |
// Parse tags if it's a string representation of an array
|
| 77 |
const parseTags = (tags: string | null | undefined): string[] => {
|
|
|
|
| 156 |
|
| 157 |
return (
|
| 158 |
<div className="modal-overlay" onClick={onClose}>
|
| 159 |
+
<div
|
| 160 |
+
className="modal-content"
|
| 161 |
+
onClick={(e) => e.stopPropagation()}
|
| 162 |
+
data-tab={activeTab}
|
| 163 |
+
>
|
| 164 |
<div className="modal-header">
|
| 165 |
<h2>{model.model_id}</h2>
|
| 166 |
<button className="modal-close" onClick={onClose}>Close</button>
|
|
|
|
| 201 |
className={`modal-tab ${activeTab === 'details' ? 'active' : ''}`}
|
| 202 |
onClick={() => setActiveTab('details')}
|
| 203 |
>
|
| 204 |
+
<span className="tab-icon">π</span>
|
| 205 |
+
<span>Details</span>
|
| 206 |
</button>
|
| 207 |
<button
|
| 208 |
className={`modal-tab ${activeTab === 'files' ? 'active' : ''}`}
|
| 209 |
onClick={() => setActiveTab('files')}
|
| 210 |
>
|
| 211 |
+
<span className="tab-icon">π</span>
|
| 212 |
+
<span>Files</span>
|
| 213 |
</button>
|
| 214 |
{(papers.length > 0 || papersLoading) && (
|
| 215 |
<button
|
| 216 |
className={`modal-tab ${activeTab === 'papers' ? 'active' : ''}`}
|
| 217 |
onClick={() => setActiveTab('papers')}
|
| 218 |
>
|
| 219 |
+
<span className="tab-icon">π</span>
|
| 220 |
+
<span>Papers</span>
|
| 221 |
+
{papers.length > 0 && <span className="tab-badge">{papers.length}</span>}
|
| 222 |
</button>
|
| 223 |
)}
|
| 224 |
</div>
|
|
|
|
| 291 |
<div className="section-title">Parent Model</div>
|
| 292 |
<div className="section-content">
|
| 293 |
<a
|
| 294 |
+
href={getHuggingFaceUrl(model.parent_model)}
|
| 295 |
target="_blank"
|
| 296 |
rel="noopener noreferrer"
|
| 297 |
className="model-link"
|
frontend/src/components/{ColorLegend.css → ui/ColorLegend.css}
RENAMED
|
File without changes
|
frontend/src/components/{ColorLegend.tsx → ui/ColorLegend.tsx}
RENAMED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
* Shows color mappings for categorical and continuous data.
|
| 4 |
*/
|
| 5 |
import React from 'react';
|
| 6 |
-
import { getCategoricalColorMap, getContinuousColorScale } from '
|
| 7 |
import './ColorLegend.css';
|
| 8 |
|
| 9 |
interface ColorLegendProps {
|
|
|
|
| 3 |
* Shows color mappings for categorical and continuous data.
|
| 4 |
*/
|
| 5 |
import React from 'react';
|
| 6 |
+
import { getCategoricalColorMap, getContinuousColorScale } from '../../utils/rendering/colors';
|
| 7 |
import './ColorLegend.css';
|
| 8 |
|
| 9 |
interface ColorLegendProps {
|
frontend/src/components/{ErrorBoundary.tsx → ui/ErrorBoundary.tsx}
RENAMED
|
File without changes
|